Commit 44a8897

Author: Tobias Gysi

[mlir][linalg] Fold ExtractSliceOps during tiling.
Add the makeComposedExtractSliceOp method that creates an ExtractSliceOp and
folds chains of ExtractSliceOps by computing the sum of their offsets and by
multiplying their strides.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D109601
Parent: 125e8ef
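To make the folding rule concrete, here is a self-contained sketch of the offset/stride arithmetic on plain integers. It is an illustration only, not the MLIR API; the `Slice` struct and `compose` helper are invented for this example. With unit strides the rule reduces to a plain sum of offsets, which is the only case the patch actually folds (non-unit strides and rank-reducing slices are left alone).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Toy model of a slice: per-dimension offset, size, and stride into a source.
struct Slice {
  std::vector<int64_t> offsets, sizes, strides;
};

// Compose an outer slice of an inner slice's result into a single slice of the
// inner slice's source: offsets add (the outer offset scaled by the inner
// stride) and strides multiply.
Slice compose(const Slice &inner, const Slice &outer) {
  assert(inner.offsets.size() == outer.offsets.size() && "rank mismatch");
  Slice folded = outer;
  for (std::size_t d = 0; d < outer.offsets.size(); ++d) {
    folded.offsets[d] = inner.offsets[d] + outer.offsets[d] * inner.strides[d];
    folded.strides[d] = inner.strides[d] * outer.strides[d];
  }
  return folded;
}

int main() {
  // Mirrors the example in the new doxygen comment:
  //   %0 = tensor.extract_slice %arg0[3, 4][3, 32][1, 1]
  //   %1 = tensor.extract_slice %0[0, 5][3, 4][1, 1]
  Slice inner{{3, 4}, {3, 32}, {1, 1}};
  Slice outer{{0, 5}, {3, 4}, {1, 1}};
  Slice folded = compose(inner, outer);
  // Folded slice of %arg0: offsets [3, 9], sizes [3, 4], strides [1, 1].
  std::cout << folded.offsets[0] << ", " << folded.offsets[1] << "\n"; // 3, 9
  return 0;
}
```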

5 files changed: +135 additions, -9 deletions

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h (19 additions & 0 deletions)

@@ -56,6 +56,25 @@ SmallVector<Value, 4> getDynOperands(Location loc, Value val, OpBuilder &b);
 /// Otherwise return nullptr.
 IntegerAttr getSmallestBoundingIndex(Value size);
 
+/// Create an ExtractSliceOp and, if `source` is defined by an ExtractSliceOp,
+/// fold it by adding the offsets.
+///
+/// Example:
+/// ```
+/// %0 = tensor.extract_slice %arg0[3, 4][3, 32][1, 1] : tensor<64x64xf32> to
+/// tensor<3x32xf32>
+/// %1 = tensor.extract_slice %0[0, 5][3, 4][1, 1] : tensor<3x32xf32> to
+/// tensor<3x4xf32>
+/// ```
+/// folds into:
+/// ```
+/// %1 = tensor.extract_slice %arg0[3, 9][3, 4][1, 1] : tensor<64x64xf32> to
+/// tensor<3x4xf32>
+/// ```
+tensor::ExtractSliceOp makeComposedExtractSliceOp(
+    OpBuilder &b, Location loc, Value source, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides);
+
 //===----------------------------------------------------------------------===//
 // Fusion utilities
 //===----------------------------------------------------------------------===//
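A hedged usage sketch of the new helper, not taken from the patch: it assumes the declaration is visible to the caller (e.g. through the mlir::linalg namespace suggested by the header path) and that an OpBuilder `b`, a Location `loc`, and a tensor-typed Value `producerResult` already exist. All other names are made up.

```cpp
// Hypothetical caller, e.g. a tiling transform with `b`, `loc`, and
// `producerResult` in scope. Offsets/sizes/strides are OpFoldResults built
// from constant index attributes here.
SmallVector<OpFoldResult> offsets = {b.getIndexAttr(0), b.getIndexAttr(5)};
SmallVector<OpFoldResult> sizes = {b.getIndexAttr(3), b.getIndexAttr(4)};
SmallVector<OpFoldResult> strides = {b.getIndexAttr(1), b.getIndexAttr(1)};

// If `producerResult` is defined by a unit-stride, non-rank-reducing
// tensor.extract_slice, the helper adds the offsets and slices the producer's
// source tensor directly; otherwise it simply builds a new ExtractSliceOp.
tensor::ExtractSliceOp slice = makeComposedExtractSliceOp(
    b, loc, producerResult, offsets, sizes, strides);
```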

mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h (6 additions & 0 deletions)

@@ -72,6 +72,12 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
   }
 };
 
+/// Converts an OpFoldResult to a Value. Returns the fold result if it casts to
+/// a Value or creates a ConstantIndexOp if it casts to an IntegerAttribute.
+/// Other attribute types are not supported.
+Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
+                                      OpFoldResult ofr);
+
 /// Helper struct to build simple arithmetic quantities with minimal type
 /// inference support.
 struct ArithBuilder {
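A brief usage sketch for the conversion helper (hypothetical surrounding code; an OpBuilder `b`, a Location `loc`, and an index-typed SSA value `dynamicSize` are assumed to exist):

```cpp
// Attribute case: materializes `constant 42 : index` in the IR.
OpFoldResult staticOfr = b.getIndexAttr(42);
Value staticVal = getValueOrCreateConstantIndexOp(b, loc, staticOfr);

// Value case: the wrapped SSA value is returned unchanged.
OpFoldResult dynamicOfr = dynamicSize;
Value dynamicVal = getValueOrCreateConstantIndexOp(b, loc, dynamicOfr);
```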

mlir/lib/Dialect/Linalg/Utils/Utils.cpp (55 additions & 9 deletions)

@@ -28,6 +28,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/LoopUtils.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "linalg-utils"

@@ -194,6 +195,48 @@ IntegerAttr getSmallestBoundingIndex(Value size) {
   return nullptr;
 }
 
+tensor::ExtractSliceOp makeComposedExtractSliceOp(
+    OpBuilder &b, Location loc, Value source, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
+  assert(source && "expect source to be nonzero");
+
+  // Do not fold if the producer is not an ExtractSliceOp.
+  auto producerOp = source.getDefiningOp<tensor::ExtractSliceOp>();
+  if (!producerOp)
+    return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
+                                            strides);
+
+  // Do not fold if the producer is rank reducing or if there are any non-unit
+  // strides. Supporting non-unit strides complicates the offset computation
+  // since the consumer offsets need to be multiplied by the producer strides.
+  // TODO: support non-unit strides once there are use cases.
+  SmallVector<OpFoldResult> allStrides = producerOp.getMixedStrides();
+  allStrides.append(strides.begin(), strides.end());
+  bool hasNonUnitStride = any_of(allStrides, [](OpFoldResult ofr) {
+    return getConstantIntValue(ofr) != static_cast<int64_t>(1);
+  });
+  if (hasNonUnitStride ||
+      producerOp.getSourceType().getRank() !=
+          producerOp.getResult().getType().cast<ShapedType>().getRank())
+    return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
+                                            strides);
+
+  // Fold the producer by adding the offsets and extracting the slice directly
+  // from the producer source tensor.
+  SmallVector<OpFoldResult> foldedOffsets(offsets.begin(), offsets.end());
+  AffineExpr dim1, dim2;
+  bindDims(b.getContext(), dim1, dim2);
+  for (auto en : enumerate(producerOp.getMixedOffsets())) {
+    SmallVector<Value> offsetValues = {
+        getValueOrCreateConstantIndexOp(b, loc, foldedOffsets[en.index()]),
+        getValueOrCreateConstantIndexOp(b, loc, en.value())};
+    foldedOffsets[en.index()] =
+        makeComposedAffineApply(b, loc, dim1 + dim2, offsetValues).getResult();
+  }
+  return b.create<tensor::ExtractSliceOp>(loc, producerOp.source(),
+                                          foldedOffsets, sizes, strides);
+}
+
 /// Specialization to build an scf "for" nest.
 template <>
 void GenerateLoopNest<scf::ForOp>::doit(

@@ -603,15 +646,18 @@ Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
     strides.push_back(builder.getIndexAttr(1));
   }
 
-  Operation *sliceOp = shapedType.isa<MemRefType>()
-                           ? builder
-                                 .create<memref::SubViewOp>(
-                                     loc, valueToTile, offsets, sizes, strides)
-                                 .getOperation()
-                           : builder
-                                 .create<tensor::ExtractSliceOp>(
-                                     loc, valueToTile, offsets, sizes, strides)
-                                 .getOperation();
+  auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
+                      .Case([&](MemRefType) {
+                        return builder.create<memref::SubViewOp>(
+                            loc, valueToTile, offsets, sizes, strides);
+                      })
+                      .Case([&](RankedTensorType) {
+                        return makeComposedExtractSliceOp(
+                            builder, loc, valueToTile, offsets, sizes, strides);
+                      })
+                      .Default([](ShapedType) -> Operation * {
+                        llvm_unreachable("Unexpected shaped type");
+                      });
   return sliceOp->getResult(0);
 }

mlir/lib/Dialect/StandardOps/Utils/Utils.cpp (9 additions & 0 deletions)

@@ -49,6 +49,15 @@ void mlir::getPositionsOfShapeOne(
   }
 }
 
+Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
+                                            OpFoldResult ofr) {
+  if (auto value = ofr.dyn_cast<Value>())
+    return value;
+  auto attr = ofr.dyn_cast<Attribute>().dyn_cast<IntegerAttr>();
+  assert(attr && "expect the op fold result casts to an integer attribute");
+  return b.create<ConstantIndexOp>(loc, attr.getValue().getSExtValue());
+}
+
 Value ArithBuilder::_and(Value lhs, Value rhs) {
   return b.create<AndOp>(loc, lhs, rhs);
 }

mlir/test/Dialect/Linalg/tile-tensors.mlir (46 additions & 0 deletions)

@@ -130,3 +130,49 @@ func @generic_op_tensors(
 // TLOOP-SAME: ins (%{{.*}} = %[[ARG_0]]: [[TY]], %{{.*}} = %[[ARG_1]]: [[TY]])
 // TLOOP-SAME: outs (%{{.*}} = %[[INIT]]: [[TY]])
 // TLOOP-SAME: distribution["block_x", "block_y", "none"] {
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (2, -d0 + s0)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (d0 + 3)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0) -> (d0 + 4)>
+
+// CHECK: fold_extract_slice
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<?x128xf32>
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<?x42xf32>
+func @fold_extract_slice(
+  %arg0 : tensor<?x128xf32>, %arg1 : tensor<?x42xf32>, %arg2 : tensor<?x42x?xf32>) -> tensor<?x42xf32> {
+
+  // CHECK: %[[C0:.*]] = constant 0
+  %c0 = constant 0 : index
+
+  // CHECK: %[[DIM:.*]] = tensor.dim %[[ARG1]], %[[C0]]
+  %0 = tensor.dim %arg1, %c0 : tensor<?x42xf32>
+  %1 = tensor.extract_slice %arg0[3, 4] [%0, 42] [1, 1] : tensor<?x128xf32> to tensor<?x42xf32>
+
+  // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+
+  // Fold the existing extract slice op into the one created by the tiling.
+  // CHECK: %[[SIZE0:.*]] = affine.min #[[MAP0]](%[[IV0]])[%[[DIM]]
+  // CHECK: %[[OFF0:.*]] = affine.apply #[[MAP1]](%[[IV0]]
+  // CHECK: %[[OFF1:.*]] = affine.apply #[[MAP2]](%[[IV1]]
+  // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+  // CHECK-SAME: %[[OFF0]], %[[OFF1]]
+  // CHECK-SAME: %[[SIZE0]], 3
+  // CHECK-SAME: 1, 1
+  // CHECK: {{.*}} = linalg.generic {{.*}} ins(%[[T0]]
+  %2 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>,
+                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                      affine_map<(d0, d1, d2) -> (d0, d1)>],
+     iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%1, %arg2 : tensor<?x42xf32>, tensor<?x42x?xf32>)
+    outs(%arg1 : tensor<?x42xf32>) {
+  ^bb0(%arg3 : f32, %arg4: f32, %arg5: f32):
+    %5 = addf %arg3, %arg5 : f32
+    linalg.yield %5 : f32
+  } -> tensor<?x42xf32>
+  return %2 : tensor<?x42xf32>
+}
