
Commit 900be71

[mlir][Linalg] Preserve encodings in static shape inference. (#132311)
Previously, encodings were unconditionally dropped during shape inference. This revision adds support for preserving the encodings in Linalg ops.

Signed-off-by: hanhanW <[email protected]>
1 parent 5b09079 commit 900be71

File tree: 2 files changed, +29 -1 lines

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 2 additions & 1 deletion
@@ -2539,7 +2539,8 @@ static void createNewOperandWithStaticSizes(
     newShape.push_back(affineExprToSize[dimExpr]);
     newOperandNeeded = true;
   }
-  resultType = RankedTensorType::get(newShape, sourceType.getElementType());
+  resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
+                                     sourceType.getEncoding());
  if (newOperandNeeded) {
    changeNeeded = true;
    // Get the new operand value given its size and element type by
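
For context, RankedTensorType::get accepts an optional encoding attribute as its third argument; the fix simply forwards the source type's encoding instead of letting it default to null. The sketch below is illustrative only and is not part of the patch: the helper name refineToStaticShape is made up for this example, and it assumes a RankedTensorType that may or may not carry an encoding such as #sparse_tensor.encoding.

// Illustrative sketch, not part of the patch; the helper name is hypothetical.
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

static RankedTensorType refineToStaticShape(RankedTensorType sourceType,
                                            ArrayRef<int64_t> newShape) {
  // Forwarding sourceType.getEncoding() keeps attributes such as
  // #sparse_tensor.encoding attached to the refined type. A null encoding
  // is valid and simply means "no encoding", so this is safe for plain
  // dense tensors as well.
  return RankedTensorType::get(newShape, sourceType.getElementType(),
                               sourceType.getEncoding());
}

Before this change, the two-argument form was used, which builds the new type with a null encoding and therefore silently discarded layouts such as sparse encodings during static shape inference.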

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 27 additions & 0 deletions
@@ -649,6 +649,33 @@ func.func @cast_dest(%arg0: tensor<?x?x?xf32>, %arg1: tensor<1x?x?xf32>, %arg2:
 
 // -----
 
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+// CHECK-DAG: #[[$SPARSE:.+]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+// CHECK-LABEL: func @static_shape_inference_with_encoding(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
+func.func @static_shape_inference_with_encoding(%arg0: tensor<?x?xf32, #sparse>, %arg1: tensor<?x?xf32>) -> tensor<3x4xf32> {
+  %0 = tensor.empty() : tensor<3x4xf32>
+  %1 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<?x?xf32, #sparse>, tensor<?x?xf32>)
+    outs(%0 : tensor<3x4xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %2 = arith.addf %in, %in_0 : f32
+    linalg.yield %2 : f32
+  } -> tensor<3x4xf32>
+  return %1 : tensor<3x4xf32>
+  // CHECK:      %[[CAST_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor<?x?xf32, #[[$SPARSE]]> to tensor<3x4xf32, #[[$SPARSE]]>
+  // CHECK-NEXT: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?xf32> to tensor<3x4xf32>
+  // CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
+  // CHECK-SAME:   ins(%[[CAST_ARG0]], %[[CAST_ARG1]] : tensor<3x4xf32, #[[$SPARSE]]>, tensor<3x4xf32>)
+  // CHECK-SAME:   outs({{.*}} : tensor<3x4xf32>)
+}
+
+// -----
+
 // CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 1)>
 // CHECK-LABEL: func @insert_pad_into_fill
 // CHECK-SAME: (%[[INPUT:.+]]: tensor<?x?x?xf32>, %[[LOW0:.+]]: index, %[[LOW1:.+]]: index, %{{.+}}: index, %{{.+}}: index)
