From 24d6062913128a85c388273629a64b1b0326049d Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912@gmail.com>
Date: Thu, 20 Mar 2025 16:54:56 -0700
Subject: [PATCH 1/2] [mlir][Linalg] Preserve encodings in static shape
 inference.

Previously, the encodings were unconditionally dropped during the shape
inference. The revision adds support for preserving the encodings in the
linalg ops.

Signed-off-by: hanhanW <hanhan0912@gmail.com>
---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp   |  3 ++-
 mlir/test/Dialect/Linalg/canonicalize.mlir | 23 ++++++++++++++++++++++
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..275c107cd70f8 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2539,7 +2539,8 @@ static void createNewOperandWithStaticSizes(
     newShape.push_back(affineExprToSize[dimExpr]);
     newOperandNeeded = true;
   }
-  resultType = RankedTensorType::get(newShape, sourceType.getElementType());
+  resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
+                                     sourceType.getEncoding());
   if (newOperandNeeded) {
     changeNeeded = true;
     // Get the new operand value given its size and element type by
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index db4f6181f517c..103ec55dfa441 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -649,6 +649,29 @@ func.func @cast_dest(%arg0: tensor<?x?x?xf32>, %arg1: tensor<1x?x?xf32>, %arg2:
 
 // -----
 
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+// CHECK-DAG: #[[$SPARSE:.+]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+// CHECK-LABEL: func @static_shape_inference_with_encoding(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
+func.func @static_shape_inference_with_encoding(%arg0: tensor<?x?xf32, #sparse>, %arg1: tensor<?x?xf32>) -> tensor<3x4xf32> {
+  %0 = tensor.empty() : tensor<3x4xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32, #sparse>, tensor<?x?xf32>) outs(%0 : tensor<3x4xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %2 = arith.addf %in, %in_0 : f32
+    linalg.yield %2 : f32
+  } -> tensor<3x4xf32>
+  return %1 : tensor<3x4xf32>
+  // CHECK:      %[[CAST_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor<?x?xf32, #[[$SPARSE]]> to tensor<3x4xf32, #[[$SPARSE]]>
+  // CHECK-NEXT: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?xf32> to tensor<3x4xf32>
+  // CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
+  // CHECK-SAME:   ins(%[[CAST_ARG0]], %[[CAST_ARG1]] : tensor<3x4xf32, #[[$SPARSE]]>, tensor<3x4xf32>)
+  // CHECK-SAME:   outs({{.*}} : tensor<3x4xf32>)
+}
+
+// -----
+
 // CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 1)>
 // CHECK-LABEL: func @insert_pad_into_fill
 // CHECK-SAME: (%[[INPUT:.+]]: tensor<?x?xf32>, %[[LOW0:.+]]: index, %[[LOW1:.+]]: index, %{{.+}}: index, %{{.+}}: index)

From 9acf0324865b15e122f47ad9076efd37adc59f4d Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912@gmail.com>
Date: Fri, 21 Mar 2025 10:25:51 -0700
Subject: [PATCH 2/2] format generic op

Signed-off-by: hanhanW <hanhan0912@gmail.com>
---
 mlir/test/Dialect/Linalg/canonicalize.mlir | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 103ec55dfa441..f99491c25d832 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -657,7 +657,11 @@ func.func @cast_dest(%arg0: tensor<?x?x?xf32>, %arg1: tensor<1x?x?xf32>, %arg2:
 // CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
 func.func @static_shape_inference_with_encoding(%arg0: tensor<?x?xf32, #sparse>, %arg1: tensor<?x?xf32>) -> tensor<3x4xf32> {
   %0 = tensor.empty() : tensor<3x4xf32>
-  %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32, #sparse>, tensor<?x?xf32>) outs(%0 : tensor<3x4xf32>) {
+  %1 = linalg.generic {
+      indexing_maps = [#map, #map, #map],
+      iterator_types = ["parallel", "parallel"]
+    } ins(%arg0, %arg1 : tensor<?x?xf32, #sparse>, tensor<?x?xf32>)
+    outs(%0 : tensor<3x4xf32>) {
   ^bb0(%in: f32, %in_0: f32, %out: f32):
     %2 = arith.addf %in, %in_0 : f32
     linalg.yield %2 : f32