-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[MLIR] [python] A few improvements to the Python bindings #131686
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -361,37 +361,45 @@ class PyRegionIterator { | |
|
|
||
| /// Regions of an op are fixed length and indexed numerically so are represented | ||
| /// with a sequence-like container. | ||
| class PyRegionList { | ||
| class PyRegionList : public Sliceable<PyRegionList, PyRegion> { | ||
| public: | ||
| PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} | ||
| static constexpr const char *pyClassName = "RegionSequence"; | ||
|
|
||
| PyRegionList(PyOperationRef operation, intptr_t startIndex = 0, | ||
| intptr_t length = -1, intptr_t step = 1) | ||
| : Sliceable(startIndex, | ||
| length == -1 ? mlirOperationGetNumRegions(operation->get()) | ||
| : length, | ||
| step), | ||
| operation(std::move(operation)) {} | ||
|
|
||
| PyRegionIterator dunderIter() { | ||
| operation->checkValid(); | ||
| return PyRegionIterator(operation); | ||
| } | ||
|
|
||
| intptr_t dunderLen() { | ||
| static void bindDerived(ClassTy &c) { | ||
| c.def("__iter__", &PyRegionList::dunderIter); | ||
| } | ||
|
|
||
| private: | ||
| /// Give the parent CRTP class access to hook implementations below. | ||
| friend class Sliceable<PyRegionList, PyRegion>; | ||
|
|
||
| intptr_t getRawNumElements() { | ||
| operation->checkValid(); | ||
| return mlirOperationGetNumRegions(operation->get()); | ||
| } | ||
|
|
||
| PyRegion dunderGetItem(intptr_t index) { | ||
| // dunderLen checks validity. | ||
| if (index < 0 || index >= dunderLen()) { | ||
| throw nb::index_error("attempt to access out of bounds region"); | ||
| } | ||
| MlirRegion region = mlirOperationGetRegion(operation->get(), index); | ||
| return PyRegion(operation, region); | ||
| PyRegion getRawElement(intptr_t pos) { | ||
| operation->checkValid(); | ||
| return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos)); | ||
| } | ||
|
|
||
| static void bind(nb::module_ &m) { | ||
| nb::class_<PyRegionList>(m, "RegionSequence") | ||
| .def("__len__", &PyRegionList::dunderLen) | ||
| .def("__iter__", &PyRegionList::dunderIter) | ||
| .def("__getitem__", &PyRegionList::dunderGetItem); | ||
| PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) { | ||
| return PyRegionList(operation, startIndex, length, step); | ||
| } | ||
|
|
||
| private: | ||
| PyOperationRef operation; | ||
| }; | ||
|
|
||
|
|
@@ -450,6 +458,9 @@ class PyBlockList { | |
|
|
||
| PyBlock dunderGetItem(intptr_t index) { | ||
| operation->checkValid(); | ||
| if (index < 0) { | ||
| index += dunderLen(); | ||
| } | ||
| if (index < 0) { | ||
| throw nb::index_error("attempt to access out of bounds block"); | ||
| } | ||
|
Comment on lines
464
to
466
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. DCE?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It can still be negative, right?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh whoops duh. Ok ignore me then!
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we have this already in the test suite? Out-of-scope for your PR here, but could these exceptions be a bit more expressive by including the original index, the adjusted one, and the bounds in the error message?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The test suite doesn't check OOB errors at the moment, I think. |
||
|
|
@@ -546,6 +557,9 @@ class PyOperationList { | |
|
|
||
| nb::object dunderGetItem(intptr_t index) { | ||
| parentOperation->checkValid(); | ||
| if (index < 0) { | ||
| index += dunderLen(); | ||
| } | ||
superbobry marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if (index < 0) { | ||
| throw nb::index_error("attempt to access out of bounds operation"); | ||
| } | ||
makslevental marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
@@ -2629,6 +2643,9 @@ class PyOpAttributeMap { | |
| } | ||
|
|
||
| PyNamedAttribute dunderGetItemIndexed(intptr_t index) { | ||
| if (index < 0) { | ||
| index += dunderLen(); | ||
| } | ||
superbobry marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if (index < 0 || index >= dunderLen()) { | ||
| throw nb::index_error("attempt to access out of bounds attribute"); | ||
| } | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.