Commit 01f2f3a

Add cases for non-unit strides and dilations
1 parent a0b980a commit 01f2f3a

2 files changed, +89 -3 lines changed

mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp

Lines changed: 27 additions & 3 deletions
@@ -55,6 +55,28 @@ getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes,
   return paddingSizes;
 }
 
+/// Extracts the constant multiplier from an affine expression of the form
+/// `d * c` or `c * d`, where `d` is an AffineDimExpr and `c` is an
+/// AffineConstantExpr. Returns 1 if the expression is not a simple
+/// multiplication of a dimension and a constant.
+static int64_t extractConstantMultiplier(AffineExpr expr) {
+  if (auto binOp = dyn_cast<AffineBinaryOpExpr>(expr)) {
+    if (binOp.getKind() == AffineExprKind::Mul) {
+      auto lhsD = dyn_cast<AffineDimExpr>(binOp.getLHS());
+      auto rhsC = dyn_cast<AffineConstantExpr>(binOp.getRHS());
+      if (lhsD && rhsC) {
+        return rhsC.getValue();
+      }
+      auto lhsC = dyn_cast<AffineConstantExpr>(binOp.getLHS());
+      auto rhsD = dyn_cast<AffineDimExpr>(binOp.getRHS());
+      if (lhsC && rhsD) {
+        return lhsC.getValue();
+      }
+    }
+  }
+  return 1;
+}
+
 /// Compute the padded shape of the given value `v` of `RankedTensorType` given
 /// - `indexingSizes` a list of OpFoldResult.
 /// - an `indexingMap` that encodes how the shape of `v` varies with increases

@@ -131,12 +153,14 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
         rewriter, loc, projectedMap, paddingSize);
     }
 
-    // Adjust for the maximum accessed index which is (padding_size - 1).
+    // Adjust for the maximum accessed index, which is (paddingSize - 1) *
+    // multiplier.
     AffineExpr d0;
     bindDims(rewriter.getContext(), d0);
-    AffineMap subtractOneMap = AffineMap::get(1, 0, d0 - 1);
+    int64_t multiplier = extractConstantMultiplier(projectedMap.getResult(0));
+    AffineMap subtractMap = AffineMap::get(1, 0, d0 - multiplier);
     OpFoldResult maxAccessIdx = affine::makeComposedFoldedAffineApply(
-        rewriter, loc, subtractOneMap, {paddingDimOfr});
+        rewriter, loc, subtractMap, {paddingDimOfr});
     terms.push_back(maxAccessIdx);
 
     LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n");
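
Note: the helper only recognizes expressions that are literally a dimension times a constant. A minimal standalone sketch (not part of the commit) of how it behaves, assuming extractConstantMultiplier is copied into the same translation unit and linked against MLIR's IR library:

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

int main() {
  MLIRContext ctx;
  AffineExpr d0 = getAffineDimExpr(0, &ctx);

  // d0 * 3: a dim times a constant, the shape a stride-3 (or dilation-3)
  // convolution's indexing expression projects to -> multiplier 3.
  llvm::outs() << extractConstantMultiplier(d0 * 3) << "\n"; // prints 3
  // d0 + 3: an addition, not a multiplication -> falls back to 1.
  llvm::outs() << extractConstantMultiplier(d0 + 3) << "\n"; // prints 1
  // A bare dim is not a binary op at all -> falls back to 1.
  llvm::outs() << extractConstantMultiplier(d0) << "\n";     // prints 1
  return 0;
}

With the unit-multiplier fallback, the previous (d0 - 1) adjustment is recovered unchanged for unit strides and dilations.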

mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir

Lines changed: 62 additions & 0 deletions
@@ -343,3 +343,65 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: pad_conv_strided
+func.func @pad_conv_strided(%arg0: tensor<1x42x42x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+
+  // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 6, 12]
+  // CHECK:   : tensor<1x42x42x4xf32> to tensor<1x42x48x16xf32>
+  // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+  // CHECK:   : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+  // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0]
+  // CHECK:   : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32>
+  // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+  // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32>
+
+  %0 = linalg.conv_2d_nhwc_fhwc
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<3> : tensor<2xi64>}
+    ins(%arg0, %arg1 : tensor<1x42x42x4xf32>, tensor<16x3x3x4xf32>)
+    outs(%arg2 : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+  return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+      padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: pad_conv_dilated
+func.func @pad_conv_dilated(%arg0: tensor<1x18x18x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+
+  // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 12]
+  // CHECK:   : tensor<1x18x18x4xf32> to tensor<1x18x20x16xf32>
+  // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+  // CHECK:   : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+  // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0]
+  // CHECK:   : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32>
+  // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+  // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32>
+
+  %0 = linalg.conv_2d_nhwc_fhwc
+    {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+    ins(%arg0, %arg1 : tensor<1x18x18x4xf32>, tensor<16x3x3x4xf32>)
+    outs(%arg2 : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+  return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+      padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
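
The padded widths asserted above follow from the new adjustment: for each loop dimension the pass accumulates the maximum accessed index (size * multiplier - multiplier), and the padded extent is the sum of these terms plus one. A back-of-the-envelope sketch (not part of the commit; the helper name maxAccessIdx is illustrative):

#include <cstdint>
#include <cstdio>

// Maximum index accessed along one loop dim: (size - 1) * multiplier,
// written the way the pass computes it: size * multiplier - multiplier.
static int64_t maxAccessIdx(int64_t paddedSize, int64_t multiplier) {
  return paddedSize * multiplier - multiplier;
}

int main() {
  // @pad_conv_strided: input W is indexed as ow * 3 + kw * 1; output W is
  // padded from 14 up to 16, and kernel W is 3.
  int64_t strided = maxAccessIdx(16, 3) + maxAccessIdx(3, 1) + 1; // 45 + 2 + 1
  std::printf("strided: padded input W = %lld\n", (long long)strided); // 48

  // @pad_conv_dilated: input W is indexed as ow * 1 + kw * 2.
  int64_t dilated = maxAccessIdx(16, 1) + maxAccessIdx(3, 2) + 1; // 15 + 4 + 1
  std::printf("dilated: padded input W = %lld\n", (long long)dilated); // 20
  return 0;
}

Both match the tensor.pad shapes checked in the tests: W goes from 42 to 48 (high padding 6) in the strided case and from 18 to 20 (high padding 2) in the dilated case.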
