Commit a0b980a
[mlir] Fix padding shape computation in PadTilingInterface
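Previously, computePaddedShape summed the padded size of every term indexing a result dimension, which over-pads any dimension indexed by more than one term (e.g. the convolved dimensions of a conv). Each term now contributes its maximum accessed index, i.e. size - 1, and a trailing + 1 turns the summed maximum index back into a size.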
1 parent 13f7786 commit a0b980a

File tree

3 files changed, +104 -24 lines changed

mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp

Lines changed: 13 additions & 4 deletions
@@ -114,24 +114,31 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
         /*compressDims=*/true);
 
     // If we are padding to the next multiple of, compose with ceil(sz) * sz.
+    OpFoldResult paddingDimOfr;
     if (options.padToMultipleOf) {
       AffineExpr d0, s0;
       bindDims(rewriter.getContext(), d0);
       bindSymbols(rewriter.getContext(), s0);
       AffineMap ceilMap = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0);
       AffineMap composedMap = projectedMap.compose(ceilMap);
-      OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
+      paddingDimOfr = affine::makeComposedFoldedAffineApply(
           rewriter, loc, composedMap,
           {indexingSizes[paddingDim], paddingSize},
           /*composeAffineMin=*/true);
-      terms.push_back(paddingDimOfr);
     } else {
       // Otherwise just set to paddingSize.
-      OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
+      paddingDimOfr = affine::makeComposedFoldedAffineApply(
           rewriter, loc, projectedMap, paddingSize);
-      terms.push_back(paddingDimOfr);
     }
 
+    // Adjust for the maximum accessed index which is (padding_size - 1).
+    AffineExpr d0;
+    bindDims(rewriter.getContext(), d0);
+    AffineMap subtractOneMap = AffineMap::get(1, 0, d0 - 1);
+    OpFoldResult maxAccessIdx = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, subtractOneMap, {paddingDimOfr});
+    terms.push_back(maxAccessIdx);
+
     LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n");
   }
 
@@ -148,6 +155,8 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
     AffineExpr sumExpr = dims.front();
     for (unsigned i = 1; i < dims.size(); ++i)
       sumExpr = sumExpr + dims[i];
+    // Add 1 to the maximum accessed index and get the final padded size.
+    sumExpr = sumExpr + rewriter.getAffineConstantExpr(1);
    OpFoldResult paddedDimOfr =
        affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, terms);
    paddedShape[resultIndex] = paddedDimOfr;
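In effect, each padded dimension is now computed from maximum accessed indices rather than from a plain sum of term sizes:

  \mathrm{padded}(d) = \sum_{t \in \mathrm{terms}(d)} \bigl(\mathrm{size}_t - 1\bigr) + 1

A sanity check against the @generic test updated below, assuming its #map1 sends the third result dimension to d0 + d2 (which the updated CHECK shapes reflect): with d0 padded to \lceil 7/3 \rceil \cdot 3 = 9 and d2 of size 5, the old sum gave 9 + 5 = 14 while the new computation gives (9 - 1) + (5 - 1) + 1 = 13, hence tensor<9x15x14xf32> becoming tensor<9x15x13xf32>.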

mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir

Lines changed: 79 additions & 8 deletions
@@ -52,22 +52,22 @@ module {
 
 // CHECK-LABEL: @generic
 // CHECK-SAME: %[[T0:.*]]: tensor<7x5xf32>,
-// CHECK-SAME: %[[T1:.*]]: tensor<7x11x12xf32>)
-func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> {
+// CHECK-SAME: %[[T1:.*]]: tensor<7x11x11xf32>)
+func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> {
 
 // CHECK-DAG: %[[CST:.*]] = arith.constant 0.
 
 // CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[2, 0]
 // CHECK:   : tensor<7x5xf32> to tensor<9x5xf32>
 // CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[2, 4, 2] {
-// CHECK:   : tensor<7x11x12xf32> to tensor<9x15x14xf32>
+// CHECK:   : tensor<7x11x11xf32> to tensor<9x15x13xf32>
 // CHECK-NEXT: linalg.generic
-// CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<9x15x14xf32> to tensor<7x11x12xf32>
-  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) {
+// CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<9x15x13xf32> to tensor<7x11x11xf32>
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) {
   ^bb0(%in: f32, %out: f32):
     linalg.yield %in : f32
-  } -> tensor<7x11x12xf32>
-  return %0 : tensor<7x11x12xf32>
+  } -> tensor<7x11x11xf32>
+  return %0 : tensor<7x11x11xf32>
 }
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
@@ -83,7 +83,7 @@ module {
 // -----
 
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 5)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 4)>
 // CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)>
 
 #map = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -272,3 +272,74 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// -----
+
+// CHECK-LABEL: pad_conv
+func.func @pad_conv(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+
+  // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 12]
+  // CHECK:   : tensor<1x16x16x4xf32> to tensor<1x16x18x16xf32>
+  // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+  // CHECK:   : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+  // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0]
+  // CHECK:   : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32>
+  // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+  // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32>
+
+  %0 = linalg.conv_2d_nhwc_fhwc
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+     ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>)
+    outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+  return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+      padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16 + 2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16)>
+
+// CHECK-LABEL: pad_conv_dynamic
+func.func @pad_conv_dynamic(%arg0: tensor<1x16x?x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32> {
+
+  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK: %[[D0_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+  // CHECK: %[[D0_1:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x16x?x4xf32>
+  // CHECK: %[[H0:.*]] = affine.apply #[[$MAP0]]()[%[[D0_0]], %[[D0_1]]]
+  // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H0]], 12]
+  // CHECK:   : tensor<1x16x?x4xf32> to tensor<1x16x?x16xf32>
+  // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+  // CHECK:   : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+  // CHECK: %[[D1_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+  // CHECK: %[[H1:.*]] = affine.apply #[[$MAP1]]()[%[[D0_0]], %[[D1_0]]]
+  // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H1]], 0]
+  // CHECK:   : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32>
+  // CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+  // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+  // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, %[[D2_0]], 16] [1, 1, 1, 1] : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32>
+
+  %0 = linalg.conv_2d_nhwc_fhwc
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+     ins(%arg0, %arg1: tensor<1x16x?x4xf32>, tensor<16x3x3x4xf32>)
+    outs(%arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32>
+  return %0 : tensor<1x14x?x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+      padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
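The new pad_conv checks follow the same arithmetic. With stride 1, the input W dimension is accessed at ow + kw; ow is padded to the next multiple of 16 (\lceil 14/16 \rceil \cdot 16 = 16) and the filter width is 3, so

  \mathrm{padded}(W_{\mathrm{in}}) = (16 - 1) + (3 - 1) + 1 = 18

hence the high padding of 2 on W (tensor<1x16x16x4xf32> to tensor<1x16x18x16xf32>); the former size-sum would have produced 16 + 3 = 19.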

mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir

Lines changed: 12 additions & 12 deletions
@@ -69,22 +69,22 @@ module {
 
 // CHECK-LABEL: @generic
 // CHECK-SAME: %[[T0:.*]]: tensor<7x5xf32>,
-// CHECK-SAME: %[[T1:.*]]: tensor<7x11x12xf32>)
-func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> {
+// CHECK-SAME: %[[T1:.*]]: tensor<7x11x11xf32>)
+func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> {
 
 // CHECK-DAG: %[[CST:.*]] = arith.constant 0.
 
 // CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[1, 0]
 // CHECK:   : tensor<7x5xf32> to tensor<8x5xf32>
 // CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[1, 3, 1] {
-// CHECK:   : tensor<7x11x12xf32> to tensor<8x14x13xf32>
+// CHECK:   : tensor<7x11x11xf32> to tensor<8x14x12xf32>
 // CHECK-NEXT: linalg.generic
-// CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<8x14x13xf32> to tensor<7x11x12xf32>
-  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) {
+// CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<8x14x12xf32> to tensor<7x11x11xf32>
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) {
   ^bb0(%in: f32, %out: f32):
     linalg.yield %in : f32
-  } -> tensor<7x11x12xf32>
-  return %0 : tensor<7x11x12xf32>
+  } -> tensor<7x11x11xf32>
+  return %0 : tensor<7x11x11xf32>
 }
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
@@ -102,7 +102,7 @@ module {
 
 
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (-s0 + 8)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 13)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 12)>
 // CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)>
 
 #map = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -127,13 +127,13 @@ module {
 // CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<?x11x?xf32>
 // CHECK: %[[H2:.*]] = affine.apply #[[$MAP1]]()[%[[D2_0]]]
 // CHECK: tensor.pad %{{.*}} low[0, 0, 0] high[%[[H1]], 3, %[[H2]]] {
-// CHECK:   : tensor<?x11x?xf32> to tensor<8x14x13xf32>
+// CHECK:   : tensor<?x11x?xf32> to tensor<8x14x12xf32>
 //
 // CHECK: %[[D0_2:.*]] = tensor.dim %{{.*}}, %[[C0]] : tensor<?x5xf32>
 // CHECK: %[[D2_1:.*]] = affine.apply #[[$MAP2]]()[%[[D0_2]]]
-// CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x13xf32>) {
-// CHECK: } -> tensor<8x14x13xf32>
-// CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x13xf32> to tensor<?x11x?xf32>
+// CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x12xf32>) {
+// CHECK: } -> tensor<8x14x12xf32>
+// CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x12xf32> to tensor<?x11x?xf32>
 //
   %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<?x5xf32>) outs(%arg1 : tensor<?x11x?xf32>) {
   ^bb0(%in: f32, %out: f32):
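The #[[$MAP1]] update (-s0 + 13 to -s0 + 12) is the same adjustment in dynamic form: assuming the third result dimension is indexed by d0 + d2 as above, it is now padded to (8 - 1) + (5 - 1) + 1 = 12 rather than 8 + 5 = 13, so the high padding along it becomes 12 - D2_0.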
