
Commit 00bc355

yzhang93 authored and krishna2803 committed
[mlir][linalg] Fix padding shape computation in PadTilingInterface for convs (llvm#149576)
This PR fixes the computation of padded shapes for convolution-style affine maps (e.g., `d0 + d1`) in `PadTilingInterface`. Previously, the code used the direct sum of the loop upper bounds, leading to over-padding. For example, for the following `conv_2d_nhwc_fhwc` op, padding only the c dimension to a multiple of 16 also incorrectly padded the convolved dimensions and produced the wrong input shape:

```mlir
%padded = tensor.pad %arg0 low[0, 0, 0, 0] high[0, 1, 1, 12] {
^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
  tensor.yield %cst : f32
} : tensor<1x16x16x4xf32> to tensor<1x17x17x16xf32>
%padded_0 = tensor.pad %arg1 low[0, 0, 0, 0] high[0, 0, 0, 12] {
^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
  tensor.yield %cst : f32
} : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
%0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
  ins(%padded, %padded_0 : tensor<1x17x17x16xf32>, tensor<16x3x3x16xf32>)
  outs(%arg2 : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
return %0 : tensor<1x14x14x16xf32>
```

The new implementation uses the maximum accessed index as the input to the affine map and adds 1 after aggregating all the terms to obtain the final padded size. Fixes llvm#148679.
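As a sanity check of the rule described above, here is a minimal, self-contained arithmetic sketch (plain C++ with hypothetical helper names, not the MLIR implementation) for the convolved input-width dimension of this example, whose indexing expression is `d2 + d5` (output width plus kernel width):

```cpp
// Standalone sketch of the padded-size rule for one operand dimension whose
// indexing expression is a sum of loop dimensions (e.g. d2 + d5 for the input
// width of conv_2d_nhwc_fhwc). Hypothetical helper names, not the MLIR API.
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Old (buggy) rule: pad to the sum of the contributing loop upper bounds.
int64_t oldPaddedSize(const std::vector<int64_t> &loopSizes) {
  return std::accumulate(loopSizes.begin(), loopSizes.end(), int64_t{0});
}

// New rule: pad to the maximum accessed index plus one, i.e. sum the per-term
// maxima (size - 1) and add 1 once at the end.
int64_t newPaddedSize(const std::vector<int64_t> &loopSizes) {
  int64_t maxIndex = 0;
  for (int64_t size : loopSizes)
    maxIndex += size - 1; // maximum index contributed by this loop dimension
  return maxIndex + 1;
}

int main() {
  // conv_2d_nhwc_fhwc with 16x16 input, 3x3 filter, 14x14 output; only the
  // channel dimension is padded, so the convolved width dim uses the unpadded
  // sizes ow = 14 and kw = 3.
  std::vector<int64_t> widthTerms = {14, 3};
  assert(oldPaddedSize(widthTerms) == 17); // over-pads 16 -> 17 (the bug)
  assert(newPaddedSize(widthTerms) == 16); // max index 15, +1 -> no padding
  return 0;
}
```

With the old rule the 16-wide input is padded to 17; with the new rule the maximum accessed index is 15, so the padded size stays 16 and no padding is introduced for the convolved dimension.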
1 parent d2225b5 · commit 00bc355

5 files changed: +206 additions, -26 deletions

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
Lines changed: 1 addition & 0 deletions

```diff
@@ -1191,6 +1191,7 @@ def PadTilingInterfaceOp : Op<Transform_Dialect, "structured.pad_tiling_interfac
     iteration domain induces a padding of the operands that is consistent
     across the op semantics and, unlike for simple elementwise ops, may not be
     trivially deducible or specifiable on operands only (e.g. convolutions).
+    Currently, only a limited set of projected permutation maps are supported.

     The specification of `padding_sizes` follows that of `tile_sizes` during
     tiling: the value "0" on a particular iterator encode "no padding". Like in
```

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Lines changed: 7 additions & 0 deletions

```diff
@@ -611,6 +611,13 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
 /// affine.apply operations.
 /// The `indexingMap` + `indexingSizes` encoding suits StructuredOps and
 /// provides a gentle portability path for Linalg-like ops with affine maps.
+/// The padded shape is computed by evaluating the maximum accessed index per
+/// dimension, which may involve multiplying by constant factors derived from
+/// the affine indexing expressions. Currently, only a limited set of projected
+/// permutation indexing maps are supported, such as
+///   - affine_map<(d0, d1, d2) -> (d0, d1)>
+///   - affine_map<(d0, d1, d2) -> (d0, d1 + d2)>
+///   - affine_map<(d0, d1) -> (d0 * 3 + d1)>
 /// In the future, more general interfaces can be devised to encode similar
 /// shape evolutions and map between an op and its operands.
 SmallVector<OpFoldResult>
```

mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
Lines changed: 45 additions & 6 deletions

```diff
@@ -55,6 +55,28 @@ getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes,
   return paddingSizes;
 }

+/// Extracts the constant multiplier from an affine expression of the form
+/// `d * c` or `c * d`, where `d` is an AffineDimExpr and `c` is an
+/// AffineConstantExpr. Returns 1 if the expression is not a simple
+/// multiplication of a dimension and a constant.
+static int64_t extractConstantMultiplier(AffineExpr expr) {
+  if (auto binOp = dyn_cast<AffineBinaryOpExpr>(expr)) {
+    if (binOp.getKind() == AffineExprKind::Mul) {
+      auto lhsD = dyn_cast<AffineDimExpr>(binOp.getLHS());
+      auto rhsC = dyn_cast<AffineConstantExpr>(binOp.getRHS());
+      if (lhsD && rhsC) {
+        return rhsC.getValue();
+      }
+      auto lhsC = dyn_cast<AffineConstantExpr>(binOp.getLHS());
+      auto rhsD = dyn_cast<AffineDimExpr>(binOp.getRHS());
+      if (lhsC && rhsD) {
+        return lhsC.getValue();
+      }
+    }
+  }
+  return 1;
+}
+
 /// Compute the padded shape of the given value `v` of `RankedTensorType` given
 ///   - `indexingSizes` a list of OpFoldResult.
 ///   - an `indexingMap` that encodes how the shape of varies with increases
@@ -63,6 +85,13 @@ getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes,
 /// The `indexingMap` + `indexingSizes` encoding suits StructuredOps.
 /// The implementation below iteratively combines increases from contributing
 /// dimensions using affine.apply operations.
+/// The padded shape is computed by evaluating the maximum accessed index per
+/// dimension, which may involve multiplying by constant factors derived from
+/// the affine indexing expressions. Currently, only a limited set of projected
+/// permutation indexing maps are supported, such as
+///   - affine_map<(d0, d1, d2) -> (d0, d1)>
+///   - affine_map<(d0, d1, d2) -> (d0, d1 + d2)>
+///   - affine_map<(d0, d1) -> (d0 * 3 + d1)>
 /// In the future, more general interfaces can be devised to encode similar
 /// shape evolutions and map between an op and its operands.
 SmallVector<OpFoldResult> linalg::computePaddedShape(
@@ -114,24 +143,33 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
                                              /*compressDims=*/true);

     // If we are padding to the next multiple of, compose with ceil(sz) * sz.
+    OpFoldResult paddingDimOfr;
     if (options.padToMultipleOf) {
       AffineExpr d0, s0;
       bindDims(rewriter.getContext(), d0);
       bindSymbols(rewriter.getContext(), s0);
       AffineMap ceilMap = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0);
       AffineMap composedMap = projectedMap.compose(ceilMap);
-      OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
+      paddingDimOfr = affine::makeComposedFoldedAffineApply(
           rewriter, loc, composedMap,
           {indexingSizes[paddingDim], paddingSize},
          /*composeAffineMin=*/true);
-      terms.push_back(paddingDimOfr);
     } else {
       // Otherwise just set to paddingSize.
-      OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
+      paddingDimOfr = affine::makeComposedFoldedAffineApply(
           rewriter, loc, projectedMap, paddingSize);
-      terms.push_back(paddingDimOfr);
     }

+    // Adjust for the maximum accessed index, which is (paddingSize - 1) *
+    // multiplier.
+    AffineExpr d0;
+    bindDims(rewriter.getContext(), d0);
+    int64_t multiplier = extractConstantMultiplier(projectedMap.getResult(0));
+    AffineMap subtractMap = AffineMap::get(1, 0, d0 - multiplier);
+    OpFoldResult maxAccessIdx = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, subtractMap, {paddingDimOfr});
+    terms.push_back(maxAccessIdx);
+
     LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n");
   }

@@ -148,8 +186,9 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
     AffineExpr sumExpr = dims.front();
     for (unsigned i = 1; i < dims.size(); ++i)
       sumExpr = sumExpr + dims[i];
-    OpFoldResult paddedDimOfr =
-        affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, terms);
+    // Add 1 to the maximum accessed index and get the final padded size.
+    OpFoldResult paddedDimOfr = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, sumExpr + 1, terms);
     paddedShape[resultIndex] = paddedDimOfr;
   }

```
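For orientation, here is a minimal sketch (plain C++, hypothetical names such as `Term` and `paddedSize`, not part of this patch) of the arithmetic the hunks above implement: each projected term is rounded up to its pad-to-multiple-of, scaled by the extracted constant multiplier to get its maximum accessed index, and the per-term maxima are summed and incremented by 1. Under these assumptions it reproduces the shapes checked by the new tests below:

```cpp
// Sketch of the per-dimension padded-size computation: round each padded loop
// size up to its multiple, take multiplier * (roundedSize - 1) as the maximum
// accessed index per term, sum the terms, then add 1. Hypothetical helpers,
// not the MLIR implementation.
#include <cassert>
#include <cstdint>
#include <vector>

struct Term {
  int64_t loopSize;   // upper bound of the contributing loop dimension
  int64_t multiplier; // constant factor in the affine expr (stride/dilation)
  int64_t multipleOf; // pad-to-multiple-of for this loop dim (1 = no padding)
};

int64_t paddedSize(const std::vector<Term> &terms) {
  int64_t maxIndex = 0;
  for (const Term &t : terms) {
    // Round the loop size up to the requested multiple.
    int64_t rounded =
        ((t.loopSize + t.multipleOf - 1) / t.multipleOf) * t.multipleOf;
    // Maximum index accessed through this term.
    maxIndex += t.multiplier * (rounded - 1);
  }
  return maxIndex + 1;
}

int main() {
  // @pad_conv: d2 + d5, ow padded 14 -> 16, kw = 3:      input width 16 -> 18.
  assert(paddedSize({{14, 1, 16}, {3, 1, 1}}) == 18);
  // @pad_conv_strided: d2 * 3 + d5, ow padded 14 -> 16:  input width 42 -> 48.
  assert(paddedSize({{14, 3, 16}, {3, 1, 1}}) == 48);
  // @pad_conv_dilated: d2 + d5 * 2, ow padded 14 -> 16:  input width 18 -> 20.
  assert(paddedSize({{14, 1, 16}, {3, 2, 1}}) == 20);
  return 0;
}
```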

mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
Lines changed: 141 additions & 8 deletions

```diff
@@ -52,22 +52,22 @@ module {

 // CHECK-LABEL: @generic
 // CHECK-SAME: %[[T0:.*]]: tensor<7x5xf32>,
-// CHECK-SAME: %[[T1:.*]]: tensor<7x11x12xf32>)
-func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> {
+// CHECK-SAME: %[[T1:.*]]: tensor<7x11x11xf32>)
+func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> {

 // CHECK-DAG: %[[CST:.*]] = arith.constant 0.

 // CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[2, 0]
 // CHECK: : tensor<7x5xf32> to tensor<9x5xf32>
 // CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[2, 4, 2] {
-// CHECK: : tensor<7x11x12xf32> to tensor<9x15x14xf32>
+// CHECK: : tensor<7x11x11xf32> to tensor<9x15x13xf32>
 // CHECK-NEXT: linalg.generic
-// CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<9x15x14xf32> to tensor<7x11x12xf32>
-  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) {
+// CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<9x15x13xf32> to tensor<7x11x11xf32>
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) {
   ^bb0(%in: f32, %out: f32):
     linalg.yield %in : f32
-  } -> tensor<7x11x12xf32>
-  return %0 : tensor<7x11x12xf32>
+  } -> tensor<7x11x11xf32>
+  return %0 : tensor<7x11x11xf32>
 }
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
@@ -83,7 +83,7 @@ module {
 // -----

 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 5)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 4)>
 // CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)>

 #map = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -272,3 +272,136 @@ module attributes {transform.with_named_sequence} {
   }
 }

+// -----
+
+// CHECK-LABEL: pad_conv
+func.func @pad_conv(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+
+// CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 12]
+// CHECK: : tensor<1x16x16x4xf32> to tensor<1x16x18x16xf32>
+// CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+// CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+// CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0]
+// CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32>
+// CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+// CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32>
+
+  %0 = linalg.conv_2d_nhwc_fhwc
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+    ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>)
+    outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+  return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+      padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16 + 2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16)>
+
+// CHECK-LABEL: pad_conv_dynamic
+func.func @pad_conv_dynamic(%arg0: tensor<1x16x?x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32> {
+
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[D0_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+// CHECK: %[[D0_1:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x16x?x4xf32>
+// CHECK: %[[H0:.*]] = affine.apply #[[$MAP0]]()[%[[D0_0]], %[[D0_1]]]
+// CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H0]], 12]
+// CHECK: : tensor<1x16x?x4xf32> to tensor<1x16x?x16xf32>
+// CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+// CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+// CHECK: %[[D1_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+// CHECK: %[[H1:.*]] = affine.apply #[[$MAP1]]()[%[[D0_0]], %[[D1_0]]]
+// CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H1]], 0]
+// CHECK: : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32>
+// CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+// CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+// CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, %[[D2_0]], 16] [1, 1, 1, 1] : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32>
+
+  %0 = linalg.conv_2d_nhwc_fhwc
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+    ins(%arg0, %arg1: tensor<1x16x?x4xf32>, tensor<16x3x3x4xf32>)
+    outs(%arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32>
+  return %0 : tensor<1x14x?x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+      padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: pad_conv_strided
+func.func @pad_conv_strided(%arg0: tensor<1x42x42x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+
+// CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 6, 12]
+// CHECK: : tensor<1x42x42x4xf32> to tensor<1x42x48x16xf32>
+// CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+// CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+// CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0]
+// CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32>
+// CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+// CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32>
+
+  %0 = linalg.conv_2d_nhwc_fhwc
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<3> : tensor<2xi64> }
+    ins(%arg0, %arg1: tensor<1x42x42x4xf32>, tensor<16x3x3x4xf32>)
+    outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+  return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+      padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: pad_conv_dilated
+func.func @pad_conv_dilated(%arg0: tensor<1x18x18x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+
+// CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 12]
+// CHECK: : tensor<1x18x18x4xf32> to tensor<1x18x20x16xf32>
+// CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+// CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+// CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0]
+// CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32>
+// CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+// CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32>
+
+  %0 = linalg.conv_2d_nhwc_fhwc
+    {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+    ins(%arg0, %arg1: tensor<1x18x18x4xf32>, tensor<16x3x3x4xf32>)
+    outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+  return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+      padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
```

mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
Lines changed: 12 additions & 12 deletions

```diff
@@ -69,22 +69,22 @@ module {

 // CHECK-LABEL: @generic
 // CHECK-SAME: %[[T0:.*]]: tensor<7x5xf32>,
-// CHECK-SAME: %[[T1:.*]]: tensor<7x11x12xf32>)
-func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> {
+// CHECK-SAME: %[[T1:.*]]: tensor<7x11x11xf32>)
+func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> {

 // CHECK-DAG: %[[CST:.*]] = arith.constant 0.

 // CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[1, 0]
 // CHECK: : tensor<7x5xf32> to tensor<8x5xf32>
 // CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[1, 3, 1] {
-// CHECK: : tensor<7x11x12xf32> to tensor<8x14x13xf32>
+// CHECK: : tensor<7x11x11xf32> to tensor<8x14x12xf32>
 // CHECK-NEXT: linalg.generic
-// CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<8x14x13xf32> to tensor<7x11x12xf32>
-  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) {
+// CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<8x14x12xf32> to tensor<7x11x11xf32>
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) {
   ^bb0(%in: f32, %out: f32):
     linalg.yield %in : f32
-  } -> tensor<7x11x12xf32>
-  return %0 : tensor<7x11x12xf32>
+  } -> tensor<7x11x11xf32>
+  return %0 : tensor<7x11x11xf32>
 }
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
@@ -102,7 +102,7 @@ module {


 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (-s0 + 8)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 13)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 12)>
 // CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)>

 #map = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -127,13 +127,13 @@ module {
 // CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<?x11x?xf32>
 // CHECK: %[[H2:.*]] = affine.apply #[[$MAP1]]()[%[[D2_0]]]
 // CHECK: tensor.pad %{{.*}} low[0, 0, 0] high[%[[H1]], 3, %[[H2]]] {
-// CHECK: : tensor<?x11x?xf32> to tensor<8x14x13xf32>
+// CHECK: : tensor<?x11x?xf32> to tensor<8x14x12xf32>
 //
 // CHECK: %[[D0_2:.*]] = tensor.dim %{{.*}}, %[[C0]] : tensor<?x5xf32>
 // CHECK: %[[D2_1:.*]] = affine.apply #[[$MAP2]]()[%[[D0_2]]]
-// CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x13xf32>) {
-// CHECK: } -> tensor<8x14x13xf32>
-// CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x13xf32> to tensor<?x11x?xf32>
+// CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x12xf32>) {
+// CHECK: } -> tensor<8x14x12xf32>
+// CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x12xf32> to tensor<?x11x?xf32>
 //
   %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<?x5xf32>) outs(%arg1 : tensor<?x11x?xf32>) {
   ^bb0(%in: f32, %out: f32):
```
