Skip to content

Commit e69f389

Browse files
authored
Merge pull request #470 from Xilinx/jose.fix_insert_tile_problem
Fix problem where the shape of the insert slice was calculated incorrectly
2 parents b31574a + fe9d73c commit e69f389

File tree

5 files changed

+112
-14
lines changed

5 files changed

+112
-14
lines changed

mlir/include/mlir/IR/AffineExpr.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,11 @@ class AffineExpr {
110110
/// floordiv, ceildiv, and mod are only allowed w.r.t. constants.
111111
bool isPureAffine() const;
112112

113+
/// Returns true if this expression is monotonically increasing with respect
114+
/// to the AffineDimExprs, i.e. increasing the value of any AffineDimExpr will
115+
/// never decrease the value of the result.
116+
bool isMonotonicallyIncreasing() const;
117+
113118
/// Returns the greatest known integral divisor of this affine expression. The
114119
/// result is always positive.
115120
int64_t getLargestKnownDivisor() const;

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,10 +218,15 @@ struct LinalgOpTilingInterface
218218
}));
219219

220220
OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
221+
SmallVector<OpFoldResult> allShapeSizes =
222+
linalgOp.createFlatListOfOperandDims(b, linalgOp.getLoc());
223+
SmallVector<OpFoldResult> sizeBounds =
224+
mlir::affine::makeComposedFoldedMultiResultAffineApply(
225+
b, loc, linalgOp.getShapesToLoopsMap(), allShapeSizes);
221226
SliceParameters sliceParams = computeSliceParameters(
222227
b, loc, outOperand->get(), sizes,
223228
linalgOp.getMatchingIndexingMap(outOperand), offsets,
224-
/*ubs*/ {}, subShapeSizes, true);
229+
/*ubs*/ sizeBounds, subShapeSizes, true);
225230
resultOffsets = sliceParams.offsets;
226231
resultSizes = sliceParams.sizes;
227232
return success();

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,19 +56,24 @@ namespace {
5656
// `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
5757
//
5858
struct TileCheck : public AffineExprVisitor<TileCheck> {
59-
TileCheck(ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> sizeBounds)
60-
: tileSizes(tileSizes), sizeBounds(sizeBounds) {}
59+
TileCheck(ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> sizeBounds,
60+
bool isMonotonicallyIncreasing)
61+
: tileSizes(tileSizes), sizeBounds(sizeBounds),
62+
isMonotonicallyIncreasing(isMonotonicallyIncreasing) {}
6163

6264
void visitDimExpr(AffineDimExpr expr) {
6365
unsigned pos = expr.getPosition();
6466

65-
// This dimension is tiled if the tile size is larger than zero and not
66-
// equal to its domain size (if statically known).
67-
std::optional<int64_t> tileSize = getConstantIntValue(tileSizes[pos]);
68-
if (tileSize && !sizeBounds.empty()) {
69-
std::optional<int64_t> sizeBound = getConstantIntValue(sizeBounds[pos]);
70-
if (sizeBound && *sizeBound == *tileSize) {
71-
return;
67+
// If the expression is non-monotonic, this dimension is tiled if the tile
68+
// size is larger than zero and not equal to its domain size (if statically
69+
// known).
70+
if (!isMonotonicallyIncreasing) {
71+
std::optional<int64_t> tileSize = getConstantIntValue(tileSizes[pos]);
72+
if (tileSize && !sizeBounds.empty()) {
73+
std::optional<int64_t> sizeBound = getConstantIntValue(sizeBounds[pos]);
74+
if (sizeBound && *sizeBound == *tileSize) {
75+
return;
76+
}
7277
}
7378
}
7479

@@ -84,6 +89,7 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
8489
bool isTiled = false;
8590
ArrayRef<OpFoldResult> tileSizes;
8691
ArrayRef<OpFoldResult> sizeBounds;
92+
bool isMonotonicallyIncreasing;
8793
};
8894

8995
} // namespace
@@ -92,7 +98,7 @@ static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes,
9298
ArrayRef<OpFoldResult> sizeBounds) {
9399
if (!expr)
94100
return false;
95-
TileCheck t(tileSizes, sizeBounds);
101+
TileCheck t(tileSizes, sizeBounds, expr.isMonotonicallyIncreasing());
96102
t.visit(expr);
97103
return t.isTiled;
98104
}

mlir/lib/IR/AffineExpr.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,42 @@ bool AffineExpr::isPureAffine() const {
239239
llvm_unreachable("Unknown AffineExpr");
240240
}
241241

242+
/// Returns true if `expr` is an affine constant expression whose value is
/// greater than or equal to zero.
static bool isNonNegativeConstant(AffineExpr expr) {
  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
    return constExpr.getValue() >= 0;
  return false;
}
246+
247+
bool AffineExpr::isMonotonicallyIncreasing() const {
248+
switch (getKind()) {
249+
case AffineExprKind::SymbolId:
250+
case AffineExprKind::DimId:
251+
case AffineExprKind::Constant:
252+
return true;
253+
case AffineExprKind::Add: {
254+
auto op = llvm::cast<AffineBinaryOpExpr>(*this);
255+
return op.getLHS().isMonotonicallyIncreasing() &&
256+
op.getRHS().isMonotonicallyIncreasing();
257+
}
258+
case AffineExprKind::Mul: {
259+
// One operand must be a non-negative constant.
260+
auto op = llvm::cast<AffineBinaryOpExpr>(*this);
261+
return op.getLHS().isMonotonicallyIncreasing() &&
262+
op.getRHS().isMonotonicallyIncreasing() &&
263+
(isNonNegativeConstant(op.getLHS()) ||
264+
isNonNegativeConstant(op.getRHS()));
265+
}
266+
case AffineExprKind::FloorDiv:
267+
case AffineExprKind::CeilDiv: {
268+
auto op = llvm::cast<AffineBinaryOpExpr>(*this);
269+
return op.getLHS().isMonotonicallyIncreasing() &&
270+
isNonNegativeConstant(op.getRHS());
271+
}
272+
case AffineExprKind::Mod:
273+
return false;
274+
}
275+
llvm_unreachable("Unknown AffineExpr");
276+
}
277+
242278
// Returns the greatest known integral divisor of this affine expression.
243279
int64_t AffineExpr::getLargestKnownDivisor() const {
244280
AffineBinaryOpExpr binExpr(nullptr);

mlir/test/Dialect/Linalg/tile-tensors.mlir

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,9 +177,14 @@ func.func @non_monotonic_affine_expr(%arg0 : tensor<7xf32>) -> tensor<7xf32> {
177177
%0 = tensor.dim %arg0, %c0 : tensor<7xf32>
178178
%empty = tensor.empty() : tensor<7xf32>
179179

180-
// CHECK: %[[OUT:.*]] = tensor.empty() : tensor<7xf32>
181-
// CHECK: scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[OUT]]) -> (tensor<7xf32>) {
182-
// CHECK: tensor.extract_slice %[[TC0]][0] [7] [1] : tensor<7xf32> to tensor<7xf32>
180+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
181+
// CHECK-DAG: %[[OUT:.*]] = tensor.empty() : tensor<7xf32>
182+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
183+
// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
184+
// CHECK-DAG: %[[C7_1:.*]] = arith.constant 7 : index
185+
// CHECK: scf.for %[[IV0:.+]] = %[[C0]] to %[[C7]] step %[[C7_1]] iter_args(%[[TC0:.*]] = %[[OUT]]) -> (tensor<7xf32>) {
186+
// CHECK: tensor.extract_slice %[[ARG0]][0] [7] [1] : tensor<7xf32> to tensor<7xf32>
187+
// CHECK: tensor.extract_slice %[[TC0]][%[[IV0]]] [7] [1] : tensor<7xf32> to tensor<7xf32>
183188
%generic = linalg.generic
184189
{indexing_maps = [affine_map<(d0) -> (d0 mod 4)>,
185190
affine_map<(d0) -> (d0)>],
@@ -199,3 +204,44 @@ module attributes {transform.with_named_sequence} {
199204
transform.yield
200205
}
201206
}
207+
208+
// -----
209+
210+
#identity = affine_map<(d0, d1) -> (d0, d1)>
211+
#identity1 = affine_map<(d0, d1) -> (d0 mod 3, d1)>
212+
213+
// CHECK-LABEL: func @tile_monotonic_outer_dim
214+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x10xf32>
215+
func.func @tile_monotonic_outer_dim(%in: tensor<4x10xf32>) -> tensor<4x10xf32> {
216+
%empty = tensor.empty() : tensor<4x10xf32>
217+
%1 = linalg.generic {indexing_maps = [#identity, #identity1], iterator_types = ["parallel", "parallel"]}
218+
ins(%in : tensor<4x10xf32>) outs(%empty : tensor<4x10xf32>) {
219+
^bb1(%a: f32, %b: f32):
220+
linalg.yield %a : f32
221+
} -> tensor<4x10xf32>
222+
223+
// CHECK: %[[C4:.+]] = arith.constant 4 : index
224+
// CHECK: %[[C4_1:.+]] = arith.constant 4 : index
225+
// CHECK: %[[C5:.+]] = arith.constant 5 : index
226+
// CHECK: scf.for %[[IV0:.+]] = %{{.+}} to %[[C4]] step %[[C4_1]] iter_args(%[[ARG1:.+]] = %[[OUT:.+]]) -> (tensor<4x10xf32>) {
227+
// CHECK: scf.for %[[IV1:.+]] = %{{.+}} to %{{.+}} step %[[C5]] iter_args(%[[ARG2:.+]] = %[[ARG1]]) -> (tensor<4x10xf32>) {
228+
// CHECK: %[[INSLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], %[[IV1]]] [4, 5] [1, 1] : tensor<4x10xf32> to tensor<4x5xf32>
229+
// CHECK: %[[OUTSLICE:.+]] = tensor.extract_slice %[[ARG2]][0, %[[IV1]]] [4, 5] [1, 1] : tensor<4x10xf32> to tensor<4x5xf32>
230+
// CHECK: %[[RES:.+]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%[[INSLICE]] : tensor<4x5xf32>) outs(%[[OUTSLICE]] : tensor<4x5xf32>) {
231+
// CHECK: ^bb0(%in: f32, %out: f32):
232+
// CHECK: linalg.yield %in : f32
233+
// CHECK: } -> tensor<4x5xf32>
234+
// CHECK: %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[RES]] into %[[ARG2]][0, %[[IV1]]] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<4x10xf32>
235+
// CHECK: scf.yield %[[INSERT_SLICE]] : tensor<4x10xf32>
236+
// CHECK: }
237+
238+
return %1 : tensor<4x10xf32>
239+
}
240+
241+
module attributes {transform.with_named_sequence} {
242+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
243+
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
244+
%1, %loops:2 = transform.structured.tile_using_for %0 tile_sizes [4, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
245+
transform.yield
246+
}
247+
}

0 commit comments

Comments
 (0)