-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[MLIR][Python] Add structured.fuseop to generator. #120601
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir-sparse @llvm/pr-subscribers-mlir. Author: Hugo Trachino (nujaa). Changes: Implements a Python interface for the structured FuseOp, allowing more freedom on inputs. Full diff: https://github.com/llvm/llvm-project/pull/120601.diff. 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
index 90021ffa7c380bc..efbe5c56a219b37 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
@@ -162,7 +162,7 @@ struct SparsifierOptions : public PassPipelineOptions<SparsifierOptions> {
}
/// Projects out the options for `createConvertVectorToLLVMPass`.
- ConvertVectorToLLVMPassOptions lowerVectorToLLVMOptions() const {
+ ConvertVectorToLLVMPassOptions convertVectorToLLVMOptions() const {
ConvertVectorToLLVMPassOptions opts{};
opts.reassociateFPReductions = reassociateFPReductions;
opts.force32BitVectorIndices = force32BitVectorIndices;
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 55143d5939ba257..842d239cf6a512c 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -35,8 +35,8 @@ using namespace mlir;
using namespace mlir::vector;
namespace {
-struct LowerVectorToLLVMPass
- : public impl::ConvertVectorToLLVMPassBase<LowerVectorToLLVMPass> {
+struct ConvertVectorToLLVMPass
+ : public impl::ConvertVectorToLLVMPassBase<ConvertVectorToLLVMPass> {
using Base::Base;
@@ -58,7 +58,7 @@ struct LowerVectorToLLVMPass
};
} // namespace
-void LowerVectorToLLVMPass::runOnOperation() {
+void ConvertVectorToLLVMPass::runOnOperation() {
// Perform progressive lowering of operations on slices and
// all contraction operations. Also applies folding and DCE.
{
diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
index c5eb965884396ae..5e49252c0e57d98 100644
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -76,16 +76,16 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm,
pm.addNestedPass<func::FuncOp>(createConvertSCFToCFPass());
pm.addPass(memref::createExpandStridedMetadataPass());
pm.addPass(createLowerAffinePass());
- pm.addPass(createConvertVectorToLLVMPass(options.lowerVectorToLLVMOptions()));
+ pm.addPass(createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions()));
pm.addPass(createFinalizeMemRefToLLVMConversionPass());
pm.addNestedPass<func::FuncOp>(createConvertComplexToStandardPass());
pm.addNestedPass<func::FuncOp>(arith::createArithExpandOpsPass());
pm.addNestedPass<func::FuncOp>(createConvertMathToLLVMPass());
pm.addPass(createConvertMathToLibmPass());
pm.addPass(createConvertComplexToLibmPass());
- pm.addPass(createConvertVectorToLLVMPass(options.lowerVectorToLLVMOptions()));
+ pm.addPass(createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions()));
pm.addPass(createConvertComplexToLLVMPass());
- pm.addPass(createConvertVectorToLLVMPass(options.lowerVectorToLLVMOptions()));
+ pm.addPass(createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions()));
pm.addPass(createConvertFuncToLLVMPass());
// Finalize GPU code generation.
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index 41051c0d5b2ffb6..b97a1aa8a82bfcf 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -139,6 +139,58 @@ def __init__(
ip=ip,
)
+@_ods_cext.register_operation(_Dialect, replace=True)
+class FuseOp(FuseOp):
+ """Specialization for FuseOp class."""
+
+ @overload
+ def __init__(
+ self,
+ target: Union[Operation, Value, OpView],
+ *,
+ sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
+ interchange: OptionalIntList = None,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ def __init__(
+ self,
+ loop_types_or_target: Union[Type, List[Type], Operation, Value],
+ target_or_none: Optional[Union[Operation, Value, OpView]] = None,
+ *,
+ sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
+ interchange: OptionalIntList = None,
+ loc=None,
+ ip=None,
+ ):
+ sizes = sizes if sizes else []
+ num_loops = sum(v if v == 0 else 1 for v in sizes)
+
+ if isinstance(loop_types_or_target, (Operation, Value, OpView)):
+ loop_types = [transform.AnyOpType.get()] * num_loops
+ target = loop_types_or_target
+ assert (
+ target_or_none is None
+ ), "Cannot construct FuseOp with two targets."
+ else:
+ loop_types = (
+ ([loop_types_or_target] * num_loops)
+ if isinstance(loop_types_or_target, Type)
+ else loop_types_or_target
+ )
+ target = target_or_none
+ super().__init__(
+ target.type,
+ loop_types,
+ target,
+ tile_sizes=sizes,
+ tile_interchange=interchange,
+ loc=loc,
+ ip=ip,
+ )
+
@_ods_cext.register_operation(_Dialect, replace=True)
class GeneralizeOp(GeneralizeOp):
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 3ea73e8beea3688..551c2fa1e48acd3 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -101,6 +101,27 @@ def testFuseIntoContainingOpCompact(target):
# CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+@run
+@create_sequence
+def testFuseOpCompact(target):
+ structured.FuseOp(target, sizes=[4, 8], interchange=[0, 1])
+ # CHECK-LABEL: TEST: testFuseOpCompact
+ # CHECK: transform.sequence
+ # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
+ # CHECK-SAME: interchange [0, 1]
+ # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+
+@run
+@create_sequence
+def testFuseOpNoArg(target):
+ structured.FuseOp(target)
+ # CHECK-LABEL: TEST: testFuseOpNoArg
+ # CHECK: transform.sequence
+ # CHECK: %{{.+}} = transform.structured.fuse %{{.*}} :
+ # CHECK-SAME: (!transform.any_op) -> !transform.any_op
+
+
@run
@create_sequence
def testGeneralize(target):
|
8a4bf10 to
dd7406b
Compare
|
✅ With the latest revision this PR passed the Python code formatter. |
makslevental
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is probably good but I'm not a transform user so we'll have to wait for @ftynse to comment on the API
ftynse
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with some nits.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggested change:
- num_loops = sum(v if v == 0 else 1 for v in sizes)
+ num_loops = sum(0 if v == 0 else 1 for v in sizes)
This is slightly easier to read.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggested change:
- loop_types_or_target: Union[Type, List[Type], Operation, Value],
+ loop_types_or_target: Union[Type, Sequence[Type], Operation, OpView, Value],
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you need to explicitly list both overloads. And use exactly the same names for arguments, otherwise the linter complains.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, sorry for the long silence. What do you mean by "you need to explicitly list both overloads"? Do you mean adding an `@overload` under the `...`? I followed what was done in other examples: the last overload never has `@overload`.
- Renamed arguments to `tile_sizes` and `tile_interchange`
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AFAIU, the code needs to list overloads separately from the implementation; the code currently lists only one overload. See https://mypy.readthedocs.io/en/stable/more_types.html#function-overloading
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added an overload setting output types.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: can we add a test where sizes are not constant?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added support for arrayAttr.
e75878a to
9379d82
Compare
Implements a python interface for structured fuseOp allowing more freedom on inputs.