Skip to content

Conversation

@nujaa
Copy link
Contributor

@nujaa nujaa commented Dec 19, 2024

Implements a python interface for structured fuseOp allowing more freedom on inputs.

@llvmbot
Copy link
Member

llvmbot commented Dec 19, 2024

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Hugo Trachino (nujaa)

Changes

Implements a python interface for structured fuseOp allowing more freedom on inputs.


Full diff: https://github.com/llvm/llvm-project/pull/120601.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h (+1-1)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+3-3)
  • (modified) mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp (+3-3)
  • (modified) mlir/python/mlir/dialects/transform/structured.py (+52)
  • (modified) mlir/test/python/dialects/transform_structured_ext.py (+21)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
index 90021ffa7c380bc..efbe5c56a219b37 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
@@ -162,7 +162,7 @@ struct SparsifierOptions : public PassPipelineOptions<SparsifierOptions> {
   }
 
   /// Projects out the options for `createConvertVectorToLLVMPass`.
-  ConvertVectorToLLVMPassOptions lowerVectorToLLVMOptions() const {
+  ConvertVectorToLLVMPassOptions convertVectorToLLVMOptions() const {
     ConvertVectorToLLVMPassOptions opts{};
     opts.reassociateFPReductions = reassociateFPReductions;
     opts.force32BitVectorIndices = force32BitVectorIndices;
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 55143d5939ba257..842d239cf6a512c 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -35,8 +35,8 @@ using namespace mlir;
 using namespace mlir::vector;
 
 namespace {
-struct LowerVectorToLLVMPass
-    : public impl::ConvertVectorToLLVMPassBase<LowerVectorToLLVMPass> {
+struct ConvertVectorToLLVMPass
+    : public impl::ConvertVectorToLLVMPassBase<ConvertVectorToLLVMPass> {
 
   using Base::Base;
 
@@ -58,7 +58,7 @@ struct LowerVectorToLLVMPass
 };
 } // namespace
 
-void LowerVectorToLLVMPass::runOnOperation() {
+void ConvertVectorToLLVMPass::runOnOperation() {
   // Perform progressive lowering of operations on slices and
   // all contraction operations. Also applies folding and DCE.
   {
diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
index c5eb965884396ae..5e49252c0e57d98 100644
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -76,16 +76,16 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm,
   pm.addNestedPass<func::FuncOp>(createConvertSCFToCFPass());
   pm.addPass(memref::createExpandStridedMetadataPass());
   pm.addPass(createLowerAffinePass());
-  pm.addPass(createConvertVectorToLLVMPass(options.lowerVectorToLLVMOptions()));
+  pm.addPass(createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions()));
   pm.addPass(createFinalizeMemRefToLLVMConversionPass());
   pm.addNestedPass<func::FuncOp>(createConvertComplexToStandardPass());
   pm.addNestedPass<func::FuncOp>(arith::createArithExpandOpsPass());
   pm.addNestedPass<func::FuncOp>(createConvertMathToLLVMPass());
   pm.addPass(createConvertMathToLibmPass());
   pm.addPass(createConvertComplexToLibmPass());
-  pm.addPass(createConvertVectorToLLVMPass(options.lowerVectorToLLVMOptions()));
+  pm.addPass(createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions()));
   pm.addPass(createConvertComplexToLLVMPass());
-  pm.addPass(createConvertVectorToLLVMPass(options.lowerVectorToLLVMOptions()));
+  pm.addPass(createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions()));
   pm.addPass(createConvertFuncToLLVMPass());
 
   // Finalize GPU code generation.
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index 41051c0d5b2ffb6..b97a1aa8a82bfcf 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -139,6 +139,58 @@ def __init__(
             ip=ip,
         )
 
+@_ods_cext.register_operation(_Dialect, replace=True)
+class FuseOp(FuseOp):
+    """Specialization for FuseOp class."""
+
+    @overload
+    def __init__(
+        self,
+        target: Union[Operation, Value, OpView],
+        *,
+        sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
+        interchange: OptionalIntList = None,
+        loc=None,
+        ip=None,
+    ):
+        ...
+
+    def __init__(
+        self,
+        loop_types_or_target: Union[Type, List[Type], Operation, Value],
+        target_or_none: Optional[Union[Operation, Value, OpView]] = None,
+        *,
+        sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
+        interchange: OptionalIntList = None,
+        loc=None,
+        ip=None,
+    ):
+        sizes = sizes if sizes else []
+        num_loops = sum(v if v == 0 else 1 for v in sizes)
+
+        if isinstance(loop_types_or_target, (Operation, Value, OpView)):
+            loop_types = [transform.AnyOpType.get()] * num_loops
+            target = loop_types_or_target
+            assert (
+                target_or_none is None
+            ), "Cannot construct FuseOp with two targets."
+        else:
+            loop_types = (
+                ([loop_types_or_target] * num_loops)
+                if isinstance(loop_types_or_target, Type)
+                else loop_types_or_target
+            )
+            target = target_or_none
+        super().__init__(
+            target.type,
+            loop_types,
+            target,
+            tile_sizes=sizes,
+            tile_interchange=interchange,
+            loc=loc,
+            ip=ip,
+        )
+
 
 @_ods_cext.register_operation(_Dialect, replace=True)
 class GeneralizeOp(GeneralizeOp):
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 3ea73e8beea3688..551c2fa1e48acd3 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -101,6 +101,27 @@ def testFuseIntoContainingOpCompact(target):
     # CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
 
 
+@run
+@create_sequence
+def testFuseOpCompact(target):
+    structured.FuseOp(target, sizes=[4, 8], interchange=[0, 1])
+    # CHECK-LABEL: TEST: testFuseOpCompact
+    # CHECK: transform.sequence
+    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
+    # CHECK-SAME: interchange [0, 1]
+    # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+
+@run
+@create_sequence
+def testFuseOpNoArg(target):
+    structured.FuseOp(target)
+    # CHECK-LABEL: TEST: testFuseOpNoArg
+    # CHECK: transform.sequence
+    # CHECK: %{{.+}} = transform.structured.fuse %{{.*}} :
+    # CHECK-SAME: (!transform.any_op) -> !transform.any_op
+
+
 @run
 @create_sequence
 def testGeneralize(target):

@nujaa nujaa force-pushed the hugo.fuseopPython branch from 8a4bf10 to dd7406b Compare December 19, 2024 16:28
@github-actions
Copy link

github-actions bot commented Dec 19, 2024

✅ With the latest revision this PR passed the Python code formatter.

Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is probably good but I'm not a transform user so we'll have to wait for @ftynse to comment on the API

Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with some nits.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
num_loops = sum(v if v == 0 else 1 for v in sizes)
num_loops = sum(0 if v == 0 else 1 for v in sizes)

This is slightly easier to read.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
loop_types_or_target: Union[Type, List[Type], Operation, Value],
loop_types_or_target: Union[Type, Sequence[Type], Operation, OpView, Value],

Comment on lines 147 to 157
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you need to explicitly list both overloads. And use exactly the same names for arguments, otherwise the linter complains.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, sorry for the long silence, What do you mean by you need to explicitly list both overloads ? Do you mean adding a @overload under ... ? I followed what was done in other examples. The last overload one never has @overload.

  • Renamed arguments to tile_sizes and tile_interchange

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIU, the code needs to list overloads separately from the implementation, the code currently lists only one overload. https://mypy.readthedocs.io/en/stable/more_types.html#function-overloading

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added an overload setting output types.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: can we add a test where sizes are not constant?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added support for arrayAttr.

@nujaa nujaa force-pushed the hugo.fuseopPython branch from e75878a to 9379d82 Compare January 2, 2025 11:57
@nujaa nujaa merged commit 579ced4 into llvm:main Jan 3, 2025
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:python MLIR Python bindings mlir:sparse Sparse compiler in MLIR mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants