
Conversation

hanhanW (Contributor) commented Mar 20, 2025

Previously, encodings were unconditionally dropped during shape inference. This revision adds support for preserving the encodings in linalg ops.

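For illustration, a minimal sketch of the rewrite this change affects, using the sparse-encoded types from the new test below (SSA names are illustrative):

  // Shape inference replaces the dynamic operand with a statically shaped cast.
  // Previously the inferred type dropped the encoding:
  //   %cast = tensor.cast %arg0 : tensor<?x?xf32, #sparse> to tensor<3x4xf32>
  // With this change the encoding is carried over:
  //   %cast = tensor.cast %arg0 : tensor<?x?xf32, #sparse> to tensor<3x4xf32, #sparse>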

Signed-off-by: hanhanW <[email protected]>
llvmbot (Member) commented Mar 20, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Han-Chung Wang (hanhanW)

Changes

Previously, encodings were unconditionally dropped during shape inference. This revision adds support for preserving the encodings in linalg ops.


Full diff: https://github.com/llvm/llvm-project/pull/132311.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+2-1)
  • (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+23)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..275c107cd70f8 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2539,7 +2539,8 @@ static void createNewOperandWithStaticSizes(
     newShape.push_back(affineExprToSize[dimExpr]);
     newOperandNeeded = true;
   }
-  resultType = RankedTensorType::get(newShape, sourceType.getElementType());
+  resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
+                                     sourceType.getEncoding());
   if (newOperandNeeded) {
     changeNeeded = true;
     // Get the new operand value given its size and element type by
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index db4f6181f517c..103ec55dfa441 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -649,6 +649,29 @@ func.func @cast_dest(%arg0: tensor<?x?x?xf32>, %arg1: tensor<1x?x?xf32>, %arg2:
 
 // -----
 
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+// CHECK-DAG:   #[[$SPARSE:.+]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+// CHECK-LABEL: func @static_shape_inference_with_encoding(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
+func.func @static_shape_inference_with_encoding(%arg0: tensor<?x?xf32, #sparse>, %arg1: tensor<?x?xf32>) -> tensor<3x4xf32> {
+  %0 = tensor.empty() : tensor<3x4xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32, #sparse>, tensor<?x?xf32>) outs(%0 : tensor<3x4xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %2 = arith.addf %in, %in_0 : f32
+    linalg.yield %2 : f32
+  } -> tensor<3x4xf32>
+  return %1 : tensor<3x4xf32>
+    //  CHECK:      %[[CAST_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor<?x?xf32, #[[$SPARSE]]> to tensor<3x4xf32, #[[$SPARSE]]>
+    //  CHECK-NEXT: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?xf32> to tensor<3x4xf32>
+    //  CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
+    //  CHECK-SAME: ins(%[[CAST_ARG0]], %[[CAST_ARG1]] : tensor<3x4xf32, #[[$SPARSE]]>, tensor<3x4xf32>)
+    //  CHECK-SAME: outs({{.*}} : tensor<3x4xf32>)
+}
+
+// -----
+
 //       CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 1)>
 // CHECK-LABEL: func @insert_pad_into_fill
 //  CHECK-SAME: (%[[INPUT:.+]]: tensor<?x?x?xf32>, %[[LOW0:.+]]: index, %[[LOW1:.+]]: index, %{{.+}}: index, %{{.+}}: index)

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
hanhanW (Author):

I didn't find an encoding meant for testing, so I followed the other example that uses the sparse encoding.

banach-space (Contributor):

It does feel like we should be able to create more tests for this, but I'm struggling to find a specific suggestion 🤷🏻

[nit] Your linalg.generic is quite wide. Would you mind splitting it across multiple lines?

hanhanW (Author):

I'm wondering whether I should use linalg.elemwise_unary/binary in the test. I followed the other tests in the file, but I think they were added a long time ago.

  %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
  %1 = linalg.elemwise_binary ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
                              outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>

banach-space (Contributor):

You could use the linalg.generic above as a source of inspiration:

  %2 = linalg.generic {
    indexing_maps = [#map, #map, #map],
    iterator_types = ["parallel", "parallel", "parallel"]
  } ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<1x?x?xf32>)
    outs(%0 : tensor<?x?x?xf32>) {
  ^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
    %3 = arith.subf %arg5, %arg6 : f32
    linalg.yield %3 : f32
  } -> tensor<?x?x?xf32>

IMHO, we should strive to make the official docs suggest some good reference point and then use that in tests. For example, the docs for linalg.elementwise and linalg.generic each lead with a reference snippet (the screenshots from the original comment are omitted here; a sketch follows).
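As a sketch of such a reference point, assuming the linalg.elementwise syntax of the time (the kind attribute, operands, and shapes here are illustrative, not taken from this PR or its docs):

  %add = linalg.elementwise
      kind=#linalg.elementwise_kind<add>
      ins(%a, %b : tensor<8x16xf32>, tensor<8x16xf32>)
      outs(%c : tensor<8x16xf32>) -> tensor<8x16xf32>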

But I am bike-shedding a bit 😅 My main motivation was to follow the example immediately above yours.

Whatever you decide will be great; I will be happy and MLIR will be in a better place :)

hanhanW (Author):

Okay, I like consistency better, so I'll format the generic op a bit!

banach-space (Contributor) left a comment:

Thank you, LGTM!

Signed-off-by: hanhanW <[email protected]>
hanhanW merged commit 900be71 into llvm:main on Mar 21, 2025
11 checks passed
hanhanW deleted the linalg-infer-static-shape-do-not-drop-encoding branch on March 21, 2025 at 20:36