[mlir][Linalg] Preserve encodings in static shape inference. #132311
Conversation
Previously, encodings were unconditionally dropped during shape inference. This revision adds support for preserving encodings in Linalg ops.

Signed-off-by: hanhanW <[email protected]>

@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg

Author: Han-Chung Wang (hanhanW)

Changes: Previously, encodings were unconditionally dropped during shape inference. This revision adds support for preserving encodings in Linalg ops.

Full diff: https://github.com/llvm/llvm-project/pull/132311.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..275c107cd70f8 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2539,7 +2539,8 @@ static void createNewOperandWithStaticSizes(
newShape.push_back(affineExprToSize[dimExpr]);
newOperandNeeded = true;
}
- resultType = RankedTensorType::get(newShape, sourceType.getElementType());
+ resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
+ sourceType.getEncoding());
if (newOperandNeeded) {
changeNeeded = true;
// Get the new operand value given its size and element type by
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index db4f6181f517c..103ec55dfa441 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -649,6 +649,29 @@ func.func @cast_dest(%arg0: tensor<?x?x?xf32>, %arg1: tensor<1x?x?xf32>, %arg2:
// -----
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+// CHECK-DAG: #[[$SPARSE:.+]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+// CHECK-LABEL: func @static_shape_inference_with_encoding(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+func.func @static_shape_inference_with_encoding(%arg0: tensor<?x?xf32, #sparse>, %arg1: tensor<?x?xf32>) -> tensor<3x4xf32> {
+ %0 = tensor.empty() : tensor<3x4xf32>
+ %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32, #sparse>, tensor<?x?xf32>) outs(%0 : tensor<3x4xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %2 = arith.addf %in, %in_0 : f32
+ linalg.yield %2 : f32
+ } -> tensor<3x4xf32>
+ return %1 : tensor<3x4xf32>
+ // CHECK: %[[CAST_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor<?x?xf32, #[[$SPARSE]]> to tensor<3x4xf32, #[[$SPARSE]]>
+ // CHECK-NEXT: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?xf32> to tensor<3x4xf32>
+ // CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
+ // CHECK-SAME: ins(%[[CAST_ARG0]], %[[CAST_ARG1]] : tensor<3x4xf32, #[[$SPARSE]]>, tensor<3x4xf32>)
+ // CHECK-SAME: outs({{.*}} : tensor<3x4xf32>)
+}
+
+// -----
+
// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 1)>
// CHECK-LABEL: func @insert_pad_into_fill
// CHECK-SAME: (%[[INPUT:.+]]: tensor<?x?x?xf32>, %[[LOW0:.+]]: index, %[[LOW1:.+]]: index, %{{.+}}: index, %{{.+}}: index)
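
To illustrate the effect of the change above (a minimal sketch; `#enc` is a placeholder for an arbitrary tensor encoding attribute, not an identifier from this PR): before the fix, the statically-shaped type rebuilt during inference dropped the operand's encoding, so the inserted `tensor.cast` produced an unencoded tensor; after the fix, the encoding carries over to the inferred static type.

```mlir
// Sketch only; #enc stands for an arbitrary encoding attribute.
// Before: shape inference rebuilt the type without the encoding:
//   %cast = tensor.cast %arg0 : tensor<?x?xf32, #enc> to tensor<3x4xf32>
// After: the encoding is preserved on the inferred static type:
%cast = tensor.cast %arg0 : tensor<?x?xf32, #enc> to tensor<3x4xf32, #enc>
```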
Review thread on:

#map = affine_map<(d0, d1) -> (d0, d1)>
#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
I didn't find an encoding for testing, so I followed the other example that uses sparse encoding.
It does feel like we should be able to create more tests for this, but I'm struggling to find a specific suggestion 🤷🏻
[nit] Your linalg.generic is quite wide. Would you mind splitting it across multiple lines?
I'm wondering whether I should use linalg.elemwise_unary/binary in the test. I followed the other tests in the file, but I think they were added a long time ago:
%0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = linalg.elemwise_binary ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
You could use linalg.generic above as a source of inspiration:
%2 = linalg.generic {
indexing_maps = [#map, #map, #map],
iterator_types = ["parallel", "parallel", "parallel"]
} ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<1x?x?xf32>)
outs(%0 : tensor<?x?x?xf32>) {
^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
%3 = arith.subf %arg5, %arg6 : f32
linalg.yield %3 : f32
} -> tensor<?x?x?xf32>

IMHO, we should strive to make the official docs suggest some good reference point and then use that in tests (the docs show formatted examples for both linalg.elementwise and linalg.generic). But I am bike-shedding a bit 😅 My main motivation was to follow the example immediately above yours.
Whatever you decide will be great; I will be happy and MLIR will be in a better place :)
Okay, I like consistency better, so I'll format the generic op a bit!
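
One possible multi-line layout for the generic op in the new test (a sketch of the requested reformat; it is the same op as in the diff above, just reflowed):

```mlir
%1 = linalg.generic {
      indexing_maps = [#map, #map, #map],
      iterator_types = ["parallel", "parallel"]
    } ins(%arg0, %arg1 : tensor<?x?xf32, #sparse>, tensor<?x?xf32>)
      outs(%0 : tensor<3x4xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %2 = arith.addf %in, %in_0 : f32
    linalg.yield %2 : f32
  } -> tensor<3x4xf32>
```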
banach-space left a comment:
Thank you, LGTM!