Skip to content

Commit 6816a39

Browse files
authored
[LinalgExt] Added TilingInterface support for ExpReductionOp (#22316)
This is part 2 of #21761. In this PR, we:
- Added TilingInterface to ExpReductionOp.
- Added LinalgStructuredInterface to ExpReductionOp.
- Added invalid and tiling tests.
- ExpReductionOp now does not accept use of linalg.index.
- All indexing maps used by ExpReductionOp must now be projected permutations.
- Refactored logic shared with other components:
  - AttentionOp's `getPermutedRange` utility function is shared.
  - CustomOp's `createFlatListOfOperandDims` utility function is shared.
  - The `StaticizeLinalgExtOp` rewrite now calls a utility function `allIndexingsAreProjectedPermutation` to check that all indexing maps are projected permutations.

Co-authored-by: Kunwar Grover <[email protected]>

---------

Signed-off-by: Ivan Ho <[email protected]>
1 parent c65dc6d commit 6816a39

File tree

5 files changed

+457
-38
lines changed

5 files changed

+457
-38
lines changed

compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,12 @@ static bool isInvalid(ArrayRef<int64_t> dimsPos, int64_t rank) {
9898
dimsPos, [rank](int64_t dimPos) { return dimPos < 0 || dimPos >= rank; });
9999
}
100100

101+
static bool allIndexingsAreProjectedPermutation(IndexingMapOpInterface op) {
102+
return llvm::all_of(op.getIndexingMapsArray(), [](AffineMap m) {
103+
return m.isProjectedPermutation(/*allowZeroInResults=*/true);
104+
});
105+
}
106+
101107
/// Emit an error and return failure when `seq` is invalid. It is only valid
102108
/// when it is a permutation of the sequence 0...length(seq) - 1.
103109
static LogicalResult
@@ -339,9 +345,7 @@ struct StaticizeLinalgExtOp : public OpRewritePattern<OpTy> {
339345
return failure();
340346
}
341347

342-
if (llvm::any_of(op.getIndexingMapsArray(), [](AffineMap map) {
343-
return !map.isProjectedPermutation();
344-
})) {
348+
if (!allIndexingsAreProjectedPermutation(op)) {
345349
return failure();
346350
}
347351

@@ -2304,6 +2308,14 @@ LogicalResult ExpReductionOp::verify() {
23042308
}
23052309
}
23062310

2311+
if (!allIndexingsAreProjectedPermutation(*this)) {
2312+
return op->emitOpError("all indexing maps must be projected permutations");
2313+
}
2314+
2315+
if (!getBody()->getOps<linalg::IndexOp>().empty()) {
2316+
return op->emitOpError("linalg.index is not supported in body");
2317+
}
2318+
23072319
return success();
23082320
}
23092321

compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,8 +1016,18 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_Op<"online_attention",
10161016
// ExpReduction
10171017
//===----------------------------------------------------------------------===//
10181018

1019-
def IREELinalgExt_ExpReductionOp : IREELinalgExt_Op<"exp_reduction",
1020-
[AttrSizedOperandSegments]> {
1019+
def IREELinalgExt_ExpReductionOp : IREELinalgExt_Op<"exp_reduction", [
1020+
AttrSizedOperandSegments,
1021+
DeclareOpInterfaceMethods<IndexingMapOpInterface>,
1022+
DeclareOpInterfaceMethods<TilingInterface,
1023+
[
1024+
"getLoopIteratorTypes",
1025+
"getIterationDomain",
1026+
"getTiledImplementation",
1027+
"getResultTilePosition",
1028+
"generateResultTileValue"
1029+
]>
1030+
]> {
10211031
let summary = [{
10221032
A linalg.generic extension with support for exponential reduction.
10231033
}];

compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp

Lines changed: 132 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,43 @@ getStaticOrReifiedInputDims(OpBuilder &builder, Location loc, Value input,
6767
return success();
6868
}
6969

70+
/// Returns the rank of the Value's type, and 0 if it is not a ShapedType.
71+
static int64_t getRank(Value v) {
72+
auto type = dyn_cast<ShapedType>(v.getType());
73+
if (type) {
74+
return type.getRank();
75+
}
76+
return 0;
77+
}
78+
79+
/// Method similar to `LinalgOp`s that concatenates shapes of all operands.
80+
static SmallVector<OpFoldResult>
81+
createFlatListOfOperandDims(OpBuilder &b, Location loc, Operation *op) {
82+
SmallVector<OpFoldResult> res;
83+
for (OpOperand &opOperand : op->getOpOperands()) {
84+
for (auto dim : llvm::seq(getRank(opOperand.get()))) {
85+
res.push_back(linalg::createFoldedDimOp(b, loc, opOperand.get(), dim));
86+
}
87+
}
88+
return res;
89+
}
90+
91+
/// Permutes the offset and size arrays by the result indexes of the provided
92+
/// affine map.
93+
static SmallVector<Range> getPermutedRange(AffineMap permutation,
94+
ArrayRef<OpFoldResult> offsets,
95+
ArrayRef<OpFoldResult> sizes) {
96+
auto one = IntegerAttr::get(IndexType::get(permutation.getContext()), 1);
97+
assert(permutation.isProjectedPermutation() &&
98+
"Affine map should be a projected permutation");
99+
SmallVector<Range> output;
100+
for (AffineExpr dimExpr : permutation.getResults()) {
101+
int dim = cast<AffineDimExpr>(dimExpr).getPosition();
102+
output.push_back(Range{offsets[dim], sizes[dim], one});
103+
}
104+
return output;
105+
}
106+
70107
//===----------------------------------------------------------------------===//
71108
// ScatterOp
72109
//===----------------------------------------------------------------------===//
@@ -1891,6 +1928,101 @@ SmallVector<Range> UnPackOp::getIterationDomain(OpBuilder &builder) {
18911928
return LinalgExt::getIterationDomain(*this, builder);
18921929
}
18931930

1931+
//===----------------------------------------------------------------------===//
1932+
// ExpReductionOp
1933+
//===----------------------------------------------------------------------===//
1934+
1935+
SmallVector<utils::IteratorType> ExpReductionOp::getLoopIteratorTypes() {
1936+
return llvm::to_vector(getIteratorTypes()
1937+
.getAsValueRange<IREE::LinalgExt::IteratorTypeAttr,
1938+
utils::IteratorType>());
1939+
}
1940+
1941+
SmallVector<Range> ExpReductionOp::getIterationDomain(OpBuilder &b) {
1942+
Location loc = getLoc();
1943+
OpFoldResult zero = b.getIndexAttr(0);
1944+
OpFoldResult one = b.getIndexAttr(1);
1945+
1946+
SmallVector<OpFoldResult> allShapesSizes =
1947+
createFlatListOfOperandDims(b, loc, getOperation());
1948+
AffineMap map = getShapesToLoopsMap();
1949+
return llvm::map_to_vector(map.getResults(), [&](AffineExpr loopExpr) {
1950+
OpFoldResult ofr =
1951+
affine::makeComposedFoldedAffineApply(b, loc, loopExpr, allShapesSizes);
1952+
return Range{zero, ofr, one};
1953+
});
1954+
}
1955+
1956+
FailureOr<TilingResult>
1957+
ExpReductionOp::getTiledImplementation(OpBuilder &b,
1958+
ArrayRef<OpFoldResult> offsets,
1959+
ArrayRef<OpFoldResult> sizes) {
1960+
Location loc = getLoc();
1961+
auto indexingMapOp = cast<IndexingMapOpInterface>(getOperation());
1962+
SmallVector<Value> tiledOperands;
1963+
SmallVector<Operation *> generatedSlices;
1964+
for (OpOperand &opOperand : getOperation()->getOpOperands()) {
1965+
AffineMap map = indexingMapOp.getMatchingIndexingMap(&opOperand);
1966+
SmallVector<Range> slice = getPermutedRange(map, offsets, sizes);
1967+
Operation *sliceOp = getSlice(b, loc, opOperand.get(), slice);
1968+
tiledOperands.emplace_back(sliceOp->getResult(0));
1969+
generatedSlices.push_back(sliceOp);
1970+
}
1971+
1972+
SmallVector<Type, 4> resultTensorTypes;
1973+
if (getNumResults()) {
1974+
resultTensorTypes = llvm::map_to_vector<4>(
1975+
getDpsInitsMutable(), [&generatedSlices](OpOperand &opOperand) {
1976+
return generatedSlices[opOperand.getOperandNumber()]
1977+
->getResultTypes()[0];
1978+
});
1979+
}
1980+
1981+
Operation *tiledOp = mlir::clone(b, *this, resultTensorTypes, tiledOperands);
1982+
return TilingResult{
1983+
{tiledOp}, SmallVector<Value>(tiledOp->getResults()), generatedSlices};
1984+
}
1985+
1986+
LogicalResult ExpReductionOp::getResultTilePosition(
1987+
OpBuilder &b, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
1988+
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
1989+
SmallVector<OpFoldResult> &resultSizes) {
1990+
auto indexingMapOp = cast<IndexingMapOpInterface>(getOperation());
1991+
OpOperand *outOperand = getDpsInitOperand(resultNumber);
1992+
AffineMap indexingMap = indexingMapOp.getMatchingIndexingMap(outOperand);
1993+
SmallVector<Range> range = getPermutedRange(indexingMap, offsets, sizes);
1994+
resultOffsets.resize(range.size());
1995+
resultSizes.resize(range.size());
1996+
for (auto [index, r] : llvm::enumerate(range)) {
1997+
resultOffsets[index] = r.offset;
1998+
resultSizes[index] = r.size;
1999+
}
2000+
return success();
2001+
}
2002+
2003+
FailureOr<TilingResult>
2004+
ExpReductionOp::generateResultTileValue(OpBuilder &b, unsigned resultNumber,
2005+
ArrayRef<OpFoldResult> offsets,
2006+
ArrayRef<OpFoldResult> sizes) {
2007+
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
2008+
if (failed(getIterationDomainTileFromResultTile(
2009+
b, resultNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
2010+
return failure();
2011+
}
2012+
FailureOr<TilingResult> tilingResult =
2013+
getTiledImplementation(b, mappedOffsets, mappedSizes);
2014+
if (failed(tilingResult)) {
2015+
return failure();
2016+
}
2017+
if (tilingResult->tiledOps.size() != 1) {
2018+
return emitOpError("failed to generate tiled implementation");
2019+
}
2020+
return TilingResult{
2021+
tilingResult->tiledOps,
2022+
SmallVector<Value>{tilingResult->tiledValues[resultNumber]},
2023+
tilingResult->generatedSlices};
2024+
}
2025+
18942026
//===----------------------------------------------------------------------===//
18952027
// Im2colOp
18962028
//===----------------------------------------------------------------------===//
@@ -2449,24 +2581,6 @@ getAttentionIteratorTypes(int64_t domainRank, AffineMap qMap, AffineMap kMap,
24492581
return iteratorTypes;
24502582
}
24512583

2452-
static SmallVector<Range> getPermutedRange(AffineMap permutation,
2453-
ArrayRef<OpFoldResult> offsets,
2454-
ArrayRef<OpFoldResult> sizes) {
2455-
auto one = IntegerAttr::get(IndexType::get(permutation.getContext()), 1);
2456-
assert(permutation.isProjectedPermutation() &&
2457-
"Indexing map should be a projected permutation");
2458-
SmallVector<Range> output;
2459-
for (AffineExpr dimExpr : permutation.getResults()) {
2460-
int dim = cast<AffineDimExpr>(dimExpr).getPosition();
2461-
Range dimRange;
2462-
dimRange.offset = offsets[dim];
2463-
dimRange.size = sizes[dim];
2464-
dimRange.stride = one;
2465-
output.push_back(dimRange);
2466-
}
2467-
return output;
2468-
}
2469-
24702584
static Operation *getPermutedSlice(OpBuilder &b, Location loc, Value val,
24712585
AffineMap permutation,
24722586
ArrayRef<OpFoldResult> offsets,
@@ -3088,19 +3202,6 @@ SmallVector<utils::IteratorType> CustomOp::getLoopIteratorTypes() {
30883202
});
30893203
}
30903204

3091-
/// Method similar to `LinalgOp`s that concatenates shapes of all operands.
3092-
static SmallVector<OpFoldResult>
3093-
createFlatListOfOperandDims(OpBuilder &builder, Location loc,
3094-
CustomOp customOp) {
3095-
SmallVector<OpFoldResult> result;
3096-
for (Value operand : customOp->getOperands()) {
3097-
for (auto dim : llvm::seq<unsigned>(customOp.getRank(operand))) {
3098-
result.push_back(getDim(builder, loc, operand, dim));
3099-
}
3100-
}
3101-
return result;
3102-
}
3103-
31043205
SmallVector<Range> CustomOp::getIterationDomainForDimensions(
31053206
OpBuilder &builder, ArrayRef<unsigned> dims, ArrayRef<unsigned> symbols) {
31063207
CustomOp customOp = *this;

compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir

Lines changed: 113 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,7 +1029,7 @@ func.func @unpack_mismatch_inner_tile_size_and_output_shape(
10291029

10301030
// -----
10311031

1032-
func.func @exp_reduction(%S: tensor<2x3xf32>) -> tensor<2xf32> {
1032+
func.func @exp_reduction_non_zero(%S: tensor<2x3xf32>) -> tensor<2xf32> {
10331033
%M = tensor.empty() : tensor<2xf32>
10341034
%out = tensor.empty() : tensor<2xf32>
10351035

@@ -1057,7 +1057,7 @@ func.func @exp_reduction(%S: tensor<2x3xf32>) -> tensor<2xf32> {
10571057

10581058
// -----
10591059

1060-
func.func @exp_reduction(%S: tensor<2x3xf32>) -> tensor<2xf32> {
1060+
func.func @exp_reduction_index_check(%S: tensor<2x3xf32>) -> tensor<2xf32> {
10611061
%M = tensor.empty() : tensor<2xf32>
10621062
%out = tensor.empty() : tensor<2xf32>
10631063

@@ -1085,6 +1085,117 @@ func.func @exp_reduction(%S: tensor<2x3xf32>) -> tensor<2xf32> {
10851085

10861086
// -----
10871087

1088+
func.func @exp_reduction_shaped_input(%S: f32) -> tensor<2xf32> {
1089+
%M = tensor.empty() : tensor<2xf32>
1090+
%out = tensor.empty() : tensor<2xf32>
1091+
1092+
// expected-error@+1 {{operand #0 must be variadic of ranked tensor of any type values, but got 'f32'}}
1093+
%max, %sum = iree_linalg_ext.exp_reduction {
1094+
indexing_maps = [
1095+
affine_map<(M,N)->()>,
1096+
affine_map<(M,N)->(M)>,
1097+
affine_map<(M,N)->(M)>
1098+
],
1099+
iterator_types = [
1100+
#iree_linalg_ext.iterator_type<parallel>,
1101+
#iree_linalg_ext.iterator_type<reduction>
1102+
],
1103+
exp_reduced_operands = [1]
1104+
} ins(%S: f32)
1105+
outs(%M, %out: tensor<2xf32>, tensor<2xf32>)
1106+
{
1107+
^bb0(%s: f32, %m: f32, %o: f32):
1108+
%add = arith.addf %s, %o: f32
1109+
iree_linalg_ext.yield %m, %add: f32, f32
1110+
} -> tensor<2xf32>, tensor<2xf32>
1111+
return %sum : tensor<2xf32>
1112+
}
1113+
1114+
// -----
1115+
1116+
func.func @exp_reduction_shaped_init(%S: tensor<2x3xf32>, %M : f32) -> tensor<2xf32> {
1117+
%out = tensor.empty() : tensor<2xf32>
1118+
1119+
// expected-error@+1 {{operand #1 must be variadic of ranked tensor of any type values, but got 'f32'}}
1120+
%max, %sum = iree_linalg_ext.exp_reduction {
1121+
indexing_maps = [
1122+
affine_map<(M,N)->(M,N)>,
1123+
affine_map<(M,N)->()>,
1124+
affine_map<(M,N)->(M)>
1125+
],
1126+
iterator_types = [
1127+
#iree_linalg_ext.iterator_type<parallel>,
1128+
#iree_linalg_ext.iterator_type<reduction>
1129+
],
1130+
exp_reduced_operands = [1]
1131+
} ins(%S: tensor<2x3xf32>)
1132+
outs(%M, %out: f32, tensor<2xf32>)
1133+
{
1134+
^bb0(%s: f32, %m: f32, %o: f32):
1135+
%add = arith.addf %s, %o: f32
1136+
iree_linalg_ext.yield %m, %add: f32, f32
1137+
} -> f32, tensor<2xf32>
1138+
return %sum : tensor<2xf32>
1139+
}
1140+
1141+
// -----
1142+
1143+
func.func @exp_reduction_projected(%S: tensor<2x3xf32>) -> tensor<2xf32> {
1144+
%M = tensor.empty() : tensor<2xf32>
1145+
%out = tensor.empty() : tensor<2xf32>
1146+
1147+
// expected-error@+1 {{all indexing maps must be projected permutations}}
1148+
%max, %sum = iree_linalg_ext.exp_reduction {
1149+
indexing_maps = [
1150+
affine_map<(M,N)[s0]->(M,N)>,
1151+
affine_map<(M,N)->(M)>,
1152+
affine_map<(M,N)->(M)>
1153+
],
1154+
iterator_types = [
1155+
#iree_linalg_ext.iterator_type<parallel>,
1156+
#iree_linalg_ext.iterator_type<reduction>
1157+
],
1158+
exp_reduced_operands = [1]
1159+
} ins(%S: tensor<2x3xf32>)
1160+
outs(%M, %out: tensor<2xf32>, tensor<2xf32>)
1161+
{
1162+
^bb0(%s: f32, %m: f32, %o: f32):
1163+
%add = arith.addf %s, %o: f32
1164+
iree_linalg_ext.yield %m, %add: f32, f32
1165+
} -> tensor<2xf32>, tensor<2xf32>
1166+
return %sum : tensor<2xf32>
1167+
}
1168+
1169+
// -----
1170+
1171+
func.func @exp_reduction_index(%S: tensor<2x3xf32>, %M : tensor<2xf32>) -> tensor<2xf32> {
1172+
%out = tensor.empty() : tensor<2xf32>
1173+
1174+
// expected-error@+1 {{linalg.index is not supported in body}}
1175+
%max, %sum = iree_linalg_ext.exp_reduction {
1176+
indexing_maps = [
1177+
affine_map<(M,N)->(M,N)>,
1178+
affine_map<(M,N)->(M)>,
1179+
affine_map<(M,N)->(M)>
1180+
],
1181+
iterator_types = [
1182+
#iree_linalg_ext.iterator_type<parallel>,
1183+
#iree_linalg_ext.iterator_type<reduction>
1184+
],
1185+
exp_reduced_operands = [1]
1186+
} ins(%S: tensor<2x3xf32>)
1187+
outs(%M, %out: tensor<2xf32>, tensor<2xf32>)
1188+
{
1189+
^bb0(%s: f32, %m: f32, %o: f32):
1190+
%add = arith.addf %s, %o: f32
1191+
%v = linalg.index 0: index
1192+
iree_linalg_ext.yield %m, %add: f32, f32
1193+
} -> tensor<2xf32>, tensor<2xf32>
1194+
return %sum : tensor<2xf32>
1195+
}
1196+
1197+
// -----
1198+
10881199
func.func @illegal_im2col_strides(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x1024x5760xf32> {
10891200
%0 = tensor.empty() : tensor<2x1024x5760xf32>
10901201
// expected-error @+1 {{expected strides rank to be equal to the kernel rank}}

0 commit comments

Comments (0)