[mlir:python] Change PyOperation::create to actually return a PyOperation. #114542
In the tablegen-generated Python bindings, we typically see a pattern
like:
```
class ConstantOp(_ods_ir.OpView):
...
def __init__(self, value, *, loc=None, ip=None):
...
super().__init__(self.build_generic(attributes=attributes, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip))
```
i.e., the generated code calls `OpView.__init__()` with the output of
`build_generic`. The purpose of `OpView` is to wrap another operation
object, and `OpView.__init__` can accept any `PyOperationBase`
subclass. Presumably the intention is that `build_generic` returns a
`PyOperation`, so the user ends up with a `PyOpView` wrapping a
`PyOperation`.
However, `PyOpView::buildGeneric` calls `PyOperation::create`, which
does not just build a `PyOperation`: it also calls `createOpView` to wrap
that operation in a subclass of `PyOpView`, and returns that view. That's rather pointless:
we called this code from the constructor of an `OpView` subclass, so we
already have a view object ready to go; we don't need to build another
one!
If we change `PyOperation::create` to return the underlying
`PyOperation`, rather than a view wrapper, we can save allocating a
useless `PyOpView` object for each ODS-generated Python object.
This saves approximately 1.5s of Python time in a JAX LLM benchmark that
generates a mixture of upstream dialects and StableHLO.
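To make the redundant allocation concrete, here is a minimal pure-Python sketch of the pattern. It does not use the real MLIR bindings: `Operation`, `OpView`, `build_generic_old`, `build_generic_new`, `ConstantOpOld`, `ConstantOpNew` and `count_views` are all hypothetical stand-ins, and the view counter exists only for illustration.

```python
class Operation:
    """Hypothetical stand-in for PyOperation: the underlying operation."""

    def __init__(self, name):
        self.name = name


class OpView:
    """Hypothetical stand-in for PyOpView: wraps an operation object."""

    allocations = 0  # illustration only: counts how many views get constructed

    def __init__(self, operation):
        OpView.allocations += 1
        # Accept either a bare Operation or another view (unwrap the latter).
        if isinstance(operation, OpView):
            operation = operation.operation
        self.operation = operation

    @classmethod
    def build_generic_old(cls, name):
        # Models the behavior before the change: build the operation,
        # then wrap it in a fresh view and return that view.
        return OpView(Operation(name))

    @classmethod
    def build_generic_new(cls, name):
        # Models the behavior after the change: return the operation itself.
        return Operation(name)


class ConstantOpOld(OpView):
    def __init__(self):
        super().__init__(self.build_generic_old("arith.constant"))


class ConstantOpNew(OpView):
    def __init__(self):
        super().__init__(self.build_generic_new("arith.constant"))


def count_views(op_class):
    """Return how many view objects constructing one op_class instance creates."""
    OpView.allocations = 0
    op_class()
    return OpView.allocations
```

In this model, constructing the "old" op builds two views (one of which is discarded right away), while the "new" op builds only the view the caller asked for. Both end up wrapping the same kind of underlying operation object.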
@llvm/pr-subscribers-mlir
Author: Peter Hawkins (hawkinsp)
Full diff: https://github.com/llvm/llvm-project/pull/114542.diff
1 Files Affected:
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index c12f75e7d224a8..3562ff38201dc3 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1534,7 +1534,7 @@ py::object PyOperation::create(const std::string &name,
PyOperation::createDetached(location->getContext(), operation);
maybeInsertOperation(created, maybeIp);
- return created->createOpView();
+ return created.getObject();
}
py::object PyOperation::clone(const py::object &maybeIp) {
- Check `isinstance(results, (Operation, OpView))` when checking for ops with no results in `Graph._add_op`. This check is required after an upstream change, after which the Python bindings for `mo.OutputOp` are no longer a subclass of `mlir.Operation`. LLVM PR [114542](llvm/llvm-project#114542) is potentially the related upstream change.
- Disabled the `Fill.getScalarZeroes` test, which is failing due to a new assertion in LLVM. Filed KERN-1196 to track fixing the test.

MAX_GRAPH_API_ORIG_REV_ID: f570b71be0d6fb4a71d38b0b180c9493ac148758
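The downstream check mentioned above can be sketched roughly as follows. This is an assumption-laden illustration, not the downstream project's code: `Operation` and `OpView` are local stand-ins for the binding classes, and `is_op_without_results` is a hypothetical helper name.

```python
class Operation:
    """Hypothetical stand-in for the bindings' Operation class."""


class OpView:
    """Hypothetical stand-in for the bindings' OpView class."""

    def __init__(self, operation):
        self.operation = operation


def is_op_without_results(results):
    # Hypothetical helper: treat both a bare Operation and an OpView wrapper
    # as "an op object came back", since a generated op class may be an
    # OpView subclass rather than an Operation subclass.
    return isinstance(results, (Operation, OpView))
```

Checking against both types keeps the calling code working regardless of which of the two kinds of object it receives.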
Flame graph for calls to `arith_ops_gen.ConstantOp` in that benchmark before:

and after:
