From b4919c2c7bf969898bee99c9e416007a6307352a Mon Sep 17 00:00:00 2001 From: Nirvedh Meshram Date: Mon, 8 Sep 2025 11:47:55 -0700 Subject: [PATCH] [Linalg] Add rank zero operand support to push down extract slice pattern Currently the pattern would crash for rank 0 operand as it decides the padding based on affine results, but for rank 0 there are no affine results in the operand affine map Signed-off-by: Nirvedh Meshram --- .../Transforms/DataLayoutPropagation.cpp | 4 ++++ .../Linalg/data-layout-propagation.mlir | 19 ++++++++++++++++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp index 40085a2368009..ed2efd6fea5f7 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -1399,6 +1399,10 @@ pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter, continue; } AffineMap IndexingMap = genericOp.getMatchingIndexingMap(operand); + if (IndexingMap.getNumResults() == 0) { + paddedInputs.push_back(operand->get()); + continue; + } SmallVector operandLowPads(IndexingMap.getNumResults(), getAsIndexOpFoldResult(ctx, 0)); SmallVector operandHighPads(IndexingMap.getNumResults(), diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir index 0e42027644797..fb16e1e7dcda4 100644 --- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir +++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir @@ -1559,4 +1559,21 @@ func.func @nopush_rankreducingextract(%arg0: tensor<128x128x128xf32>, %arg1: ten // CHECK-LABEL: func.func @nopush_rankreducingextract // CHECK: %[[GENERIC:.+]] = linalg.generic -// CHECK: return %[[GENERIC]] +// CHECK: return %[[GENERIC]] + +// ----- + +func.func @push_extract_through_generic_rank0_operand(%arg0: tensor<128x128xf32>, %arg1: tensor, %arg2: index, %arg3 : f32) -> tensor { + %extracted_slice = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor + %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,affine_map<(d0, d1) -> ()> ,affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice, %arg3 : tensor, f32) outs(%arg1 : tensor) { + ^bb0(%in: f32, %in_1 : f32, %out: bf16): + %1 = arith.truncf %in : f32 to bf16 + linalg.yield %1 : bf16 + } -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func.func @push_extract_through_generic_rank0_operand +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]] +// CHECK: return %[[EXTRACT]]