
Commit 23020a8

[mlir] Add optimization to bubbleUpPadSlice pattern for no pad case (#135859)
In cases where there is no padding on a dim, we do not need to compute new offsets, lengths, and padding for it. For example, the new test case added here can be lowered to just

```mlir
%extracted_slice = tensor.extract_slice %arg0[%arg2, 1, 2] [%arg2, 2, 1] [1, 1, 1] : tensor<3x4x5xf32> to tensor<?x2x1xf32>
```

Without this PR, we would instead get unnecessary affine maps:

```mlir
#map = affine_map<()[s0] -> (3, s0)>
#map1 = affine_map<()[s0, s1] -> (-s0 + 3, s1)>
%0 = affine.min #map()[%arg2]
%1 = affine.min #map1()[%0, %arg2]
%extracted_slice = tensor.extract_slice %arg0[%0, 1, 2] [%1, 2, 1] [1, 1, 1] : tensor<3x4x5xf32> to tensor<?x2x1xf32>
```

Signed-off-by: Nirvedh <[email protected]>
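For context, both lowerings quoted above come from the new test case added in this commit: a `tensor.pad` that leaves dim 0 unpadded, followed by a `tensor.extract_slice` with a dynamic offset and size on that dim (the full test appears in the diff below):

```mlir
%0 = tensor.pad %arg0 low[0, 0, 0] high[0, 7, 8] {
^bb0(%arg1: index, %arg2: index, %arg3: index):
  tensor.yield %pad : f32
} : tensor<3x4x5xf32> to tensor<3x11x13xf32>
%1 = tensor.extract_slice %0[%index, 1, 2] [%index, 2, 1] [1, 1, 1] : tensor<3x11x13xf32> to tensor<?x2x1xf32>
```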
1 parent 6182015 · commit 23020a8

3 files changed (+33, -7 lines)

mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp

Lines changed: 13 additions & 4 deletions

```diff
@@ -122,7 +122,7 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
   OpFoldResult zero = b.getIndexAttr(0);
 
   // Compute new offsets, lengths, low padding, high padding.
-  SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
+  SmallVector<OpFoldResult> newOffsets, newLengths;
   SmallVector<OpFoldResult> newLows, newHighs;
   // Set to true if the original data source is not read at all.
   bool hasZeroLen = false;
@@ -131,13 +131,25 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
   Value dynHasZeroLenCond;
 
   int64_t rank = padOp.getSourceType().getRank();
+  // Only unit stride supported.
+  SmallVector<OpFoldResult> newStrides(rank, b.getIndexAttr(1));
   for (unsigned dim = 0; dim < rank; ++dim) {
     auto low = padOp.getMixedLowPad()[dim];
     bool hasLowPad = !isConstantIntValue(low, 0);
     auto high = padOp.getMixedHighPad()[dim];
     bool hasHighPad = !isConstantIntValue(high, 0);
     auto offset = offsets[dim];
     auto length = sizes[dim];
+    // If the dim has no padding, we don't need to calculate new values for
+    // that dim, as the existing ones are correct even after the pattern.
+    if (!hasLowPad && !hasHighPad) {
+      newOffsets.push_back(offset);
+      newLengths.push_back(length);
+      newLows.push_back(low);
+      newHighs.push_back(high);
+      continue;
+    }
+
     auto srcSize = tensor::getMixedSize(b, loc, padOp.getSource(), dim);
 
     // The new amount of low padding is `low - offset`. Except for the case
@@ -216,9 +228,6 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
     OpFoldResult newHigh =
         hasHighPad ? sub(sub(length, newLength), newLow) : zero;
     newHighs.push_back(newHigh);
-
-    // Only unit stride supported.
-    newStrides.push_back(b.getIndexAttr(1));
   }
 
   // The shape of the result can be obtained from the sizes passed in.
```
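Two details make the early exit sound. First, `newStrides` is now pre-filled with `rank` unit strides before the loop, because the new `continue` would otherwise skip the per-iteration `push_back` and leave the stride list short. Second, padding a dim with `low = high = 0` is the identity along that dim, so the incoming offset and length already index the source correctly. A minimal sketch of a mixed case (hypothetical IR, not from this commit; `%src`, `%cst`, `%off`, and `%len` are assumed to be defined):

```mlir
// Dim 0 is unpadded (low = high = 0); dim 1 is padded high by 3.
%0 = tensor.pad %src low[0, 0] high[0, 3] {
^bb0(%i: index, %j: index):
  tensor.yield %cst : f32
} : tensor<8x8xf32> to tensor<8x11xf32>

// After bubbleUpPadSlice, dim 0 of this slice reuses %off and %len
// unchanged; only dim 1 goes through the new offset/length/padding
// computation.
%1 = tensor.extract_slice %0[%off, 0] [%len, 2] [1, 1]
    : tensor<8x11xf32> to tensor<?x2xf32>
```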

mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir

Lines changed: 3 additions & 3 deletions

```diff
@@ -7,17 +7,17 @@
 // CHECK: scf.for
 // CHECK: memref.alloc() : memref<128x16xf32, 3>
 // CHECK: scf.forall
-// CHECK: vector.create_mask
+// CHECK: vector.constant_mask [16, 4] : vector<128x4xi1>
 // CHECK: vector.transfer_read
 // CHECK: vector.transfer_write
 // CHECK: memref.alloc() : memref<16x128xf32, 3>
 // CHECK: scf.forall
-// CHECK: vector.create_mask
+// CHECK: vector.constant_mask [16, 4] : vector<128x4xi1>
 // CHECK: vector.transfer_read
 // CHECK: vector.transfer_write
 // CHECK: memref.alloc() : memref<128x128xf32, 3>
 // CHECK: scf.forall
-// CHECK: vector.create_mask
+// CHECK-NOT: mask
 // CHECK: vector.transfer_read
 // CHECK: vector.transfer_write
 // CHECK: linalg.matmul
```
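These CHECK updates track a canonicalization effect rather than a behavioral change in the test: once the pad-slice bounds are passed through unchanged on unpadded dims, the mask bounds fold to constants, and `vector.create_mask` with all-constant operands canonicalizes to `vector.constant_mask` (in the last region the mask vanishes entirely, hence the `CHECK-NOT`). A rough sketch of that folding, assuming constant bounds:

```mlir
%c16 = arith.constant 16 : index
%c4 = arith.constant 4 : index
// A mask built from constant bounds...
%m = vector.create_mask %c16, %c4 : vector<128x4xi1>
// ...canonicalizes to the equivalent constant mask:
%m2 = vector.constant_mask [16, 4] : vector<128x4xi1>
```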

mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir

Lines changed: 17 additions & 0 deletions

```diff
@@ -216,3 +216,20 @@ func.func @dynamic_zero_high_padding(%arg0 : tensor<?x?xf32>, %pad : f32,
   %1 = tensor.extract_slice %0[%o1, %o2] [%s1, %s2] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
+
+// -----
+// CHECK-LABEL: @nopaddim_with_dynamic_extract(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<3x4x5xf32>
+// CHECK-SAME: %[[ARG1:.*]]: f32
+// CHECK-SAME: %[[ARG2:.*]]: index
+// CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 1, 2] [%[[ARG2]], 2, 1] [1, 1, 1] : tensor<3x4x5xf32> to tensor<?x2x1xf32>
+// CHECK: return %[[RESULT]]
+func.func @nopaddim_with_dynamic_extract(%arg0 : tensor<3x4x5xf32>, %pad : f32, %index : index)
+    -> tensor<?x2x1xf32> {
+  %0 = tensor.pad %arg0 low[0, 0, 0] high[0, 7, 8] {
+    ^bb0(%arg1: index, %arg2: index, %arg3: index):
+      tensor.yield %pad : f32
+  } : tensor<3x4x5xf32> to tensor<3x11x13xf32>
+  %1 = tensor.extract_slice %0[%index, 1, 2] [%index, 2, 1] [1, 1, 1] : tensor<3x11x13xf32> to tensor<?x2x1xf32>
+  return %1 : tensor<?x2x1xf32>
+}
```
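The new test pins down the optimized form: the slice is rewritten directly onto the source tensor with the original `%index` offset and length, with no `affine.min` bound computations expected for the unpadded dim.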
