Commit daf8f9f

[Linalg] Add rank zero operand support to push down extract slice pattern (llvm#157532)
Currently the pattern crashes on a rank-0 operand because it decides the padding based on affine results, but a rank-0 operand has no affine results in its indexing map.

Signed-off-by: Nirvedh Meshram <[email protected]>
1 parent e7e4caf commit daf8f9f

2 files changed: +22 -1 lines changed

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 4 additions & 0 deletions
@@ -1399,6 +1399,10 @@ pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
       continue;
     }
     AffineMap IndexingMap = genericOp.getMatchingIndexingMap(operand);
+    if (IndexingMap.getNumResults() == 0) {
+      paddedInputs.push_back(operand->get());
+      continue;
+    }
     SmallVector<OpFoldResult> operandLowPads(IndexingMap.getNumResults(),
                                              getAsIndexOpFoldResult(ctx, 0));
     SmallVector<OpFoldResult> operandHighPads(IndexingMap.getNumResults(),
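
The early `continue` added in this hunk can be illustrated with a standalone toy model. This is only a sketch of the control flow, not the MLIR implementation: `ToyIndexingMap`, `ToyOperand`, and `collectPaddedInputs` are hypothetical names invented for the illustration, and the "padding" is reduced to building a descriptive string. The assumption it encodes is the one stated in the commit message: the pad amounts are derived from the indexing map results, so an operand whose map has zero results has nothing to pad and is forwarded as-is.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for an operand's indexing map. Only the number of
// results (the operand's rank) matters for this illustration.
struct ToyIndexingMap {
  std::size_t numResults;
};

// Hypothetical stand-in for one input operand of the generic op.
struct ToyOperand {
  std::string name;
  ToyIndexingMap map;
};

// Models the loop over input operands. Rank-0 operands are forwarded
// unchanged; every other operand is marked as padded, with one (zero) low
// pad entry per indexing map result.
std::vector<std::string>
collectPaddedInputs(const std::vector<ToyOperand> &operands) {
  std::vector<std::string> paddedInputs;
  for (const ToyOperand &operand : operands) {
    // Mirrors the guard added by the commit: no map results, no padding.
    if (operand.map.numResults == 0) {
      paddedInputs.push_back(operand.name);
      continue;
    }
    // One pad amount per map result, initialised to zero.
    std::vector<long> lowPads(operand.map.numResults, 0);
    paddedInputs.push_back("pad(" + operand.name +
                           ", rank=" + std::to_string(lowPads.size()) + ")");
  }
  return paddedInputs;
}
```

Calling `collectPaddedInputs` with a rank-2 operand followed by a rank-0 operand (the shape of the new test case, where `%arg3` is a scalar `f32`) yields a padded entry for the first and the untouched name for the second.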

mlir/test/Dialect/Linalg/data-layout-propagation.mlir

Lines changed: 18 additions & 1 deletion
@@ -1559,4 +1559,21 @@ func.func @nopush_rankreducingextract(%arg0: tensor<128x128x128xf32>, %arg1: ten
 
 // CHECK-LABEL: func.func @nopush_rankreducingextract
 // CHECK: %[[GENERIC:.+]] = linalg.generic
-// CHECK: return %[[GENERIC]]
+// CHECK: return %[[GENERIC]]
+
+// -----
+
+func.func @push_extract_through_generic_rank0_operand(%arg0: tensor<128x128xf32>, %arg1: tensor<?x?xbf16>, %arg2: index, %arg3 : f32) -> tensor<?x?xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,affine_map<(d0, d1) -> ()> ,affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice, %arg3 : tensor<?x?xf32>, f32) outs(%arg1 : tensor<?x?xbf16>) {
+  ^bb0(%in: f32, %in_1 : f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    linalg.yield %1 : bf16
+  } -> tensor<?x?xbf16>
+  return %0 : tensor<?x?xbf16>
+}
+
+// CHECK-LABEL: func.func @push_extract_through_generic_rank0_operand
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]]
+// CHECK: return %[[EXTRACT]]
