Commit 3078d3f
[LinalgExt] add TilingInterface support for ArgmaxOp (#21077)
This PR adds tiling support for the `iree_linalg_ext.argmax` operation (implemented on the generalized `ArgCompareOp`) by implementing the TilingInterface. The following methods are introduced:

- `getLoopIteratorTypes`
- `getResultTilePosition`
- `getTiledImplementation`
- `generateResultTileValue`

Additionally, corresponding test cases are provided to verify the tiling behavior for both tensor and memref cases.

Signed-off-by: Bangtian Liu <[email protected]>
1 parent ac5be26 · commit 3078d3f
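For orientation, here is a minimal sketch of how a tiling driver consumes the four methods this PR implements. It assumes upstream MLIR's generic SCF tiling entry point (`scf::tileUsingSCF` from `mlir/Dialect/SCF/Transforms/TileUsingInterface.h`, on which the `transform.structured.tile_using_for` op used in the tests below is built); the wrapper function is hypothetical and exact signatures vary across MLIR versions:

```cpp
// Sketch: tiling any TilingInterface op (such as the arg_compare op in this
// PR) through MLIR's generic SCF driver. tileWithSizes is an illustrative
// helper, not IREE code; the driver API is upstream MLIR.
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/TilingInterface.h"

using namespace mlir;

static LogicalResult tileWithSizes(RewriterBase &rewriter, TilingInterface op,
                                   ArrayRef<OpFoldResult> tileSizes) {
  scf::SCFTilingOptions options;
  // e.g. [10, 0]: tile dim 0 by 10, leave dim 1 (the reduction) whole,
  // matching the first test case below.
  options.setTileSizes(tileSizes);
  FailureOr<scf::SCFTilingResult> tiled =
      scf::tileUsingSCF(rewriter, op, options);
  if (failed(tiled))
    return failure();
  // Replace the original op's results with the values produced by the
  // loop nest the driver materialized around the tiled op.
  rewriter.replaceOp(op, tiled->replacements);
  return success();
}
```

The driver queries `getLoopIteratorTypes` to learn which dimensions are parallel and which are reductions, calls `getTiledImplementation` to build the per-tile body, and uses `getResultTilePosition` to insert each result tile back into the destination operands.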

File tree

3 files changed: +331 −0 lines changed

- compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
- compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp
- compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir

compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td

Lines changed: 4 additions & 0 deletions
@@ -679,6 +679,10 @@ def IREELinalgExt_ArgCompareOp : IREELinalgExt_Op<"arg_compare", [
     DeclareOpInterfaceMethods<LinalgExtInterface>,
     DeclareOpInterfaceMethods<TilingInterface,
         ["generateScalarImplementation",
+         "getLoopIteratorTypes",
+         "getResultTilePosition",
+         "getTiledImplementation",
+         "generateResultTileValue",
          "getIterationDomain"]
     >
 ]> {

compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp

Lines changed: 89 additions & 0 deletions
@@ -1357,6 +1357,95 @@ SmallVector<Range> ArgCompareOp::getIterationDomain(OpBuilder &builder) {
   return ranges;
 }
 
+SmallVector<utils::IteratorType> ArgCompareOp::getLoopIteratorTypes() {
+  SmallVector<utils::IteratorType> iteratorTypes(getInputRank(),
+                                                 utils::IteratorType::parallel);
+  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
+  return iteratorTypes;
+}
+
+FailureOr<TilingResult>
+ArgCompareOp::getTiledImplementation(OpBuilder &builder,
+                                     ArrayRef<OpFoldResult> offsets,
+                                     ArrayRef<OpFoldResult> sizes) {
+  Location loc = getLoc();
+  int64_t rank = getInputRank();
+  assert(offsets.size() == static_cast<size_t>(rank) &&
+         "Unexpected offsets size");
+  assert(sizes.size() == static_cast<size_t>(rank) && "Unexpected sizes size");
+
+  SmallVector<Operation *> slices;
+  SmallVector<Value> tiledOperands;
+
+  SmallVector<OpFoldResult> strides(rank, builder.getIndexAttr(1));
+  Operation *inputSlice =
+      getSlice(builder, loc, getInputValue(), offsets, sizes, strides);
+  tiledOperands.push_back(inputSlice->getResult(0));
+  slices.push_back(inputSlice);
+
+  SmallVector<OpFoldResult> outputOffsets, outputSizes;
+  if (failed(getResultTilePosition(builder, 0, offsets, sizes, outputOffsets,
+                                   outputSizes))) {
+    return emitOpError("failed to compute output tile position");
+  }
+
+  SmallVector<OpFoldResult> outputStrides(outputOffsets.size(),
+                                          builder.getIndexAttr(1));
+  Operation *outputValSlice = getSlice(
+      builder, loc, outputValue(), outputOffsets, outputSizes, outputStrides);
+  tiledOperands.push_back(outputValSlice->getResult(0));
+  slices.push_back(outputValSlice);
+
+  Operation *outputIdxSlice = getSlice(
+      builder, loc, outputIndex(), outputOffsets, outputSizes, outputStrides);
+  tiledOperands.push_back(outputIdxSlice->getResult(0));
+  slices.push_back(outputIdxSlice);
+
+  if (getIndexBase()) {
+    tiledOperands.push_back(getIndexBase());
+  }
+
+  SmallVector<Type> resultTypes;
+  if (hasPureTensorSemantics()) {
+    resultTypes.push_back(outputValSlice->getResult(0).getType());
+    resultTypes.push_back(outputIdxSlice->getResult(0).getType());
+  }
+
+  Operation *tiledArgmaxOp =
+      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+  return TilingResult{
+      {tiledArgmaxOp}, SmallVector<Value>(tiledArgmaxOp->getResults()), slices};
+}
+
+LogicalResult ArgCompareOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  int64_t dim = getDimension();
+  int64_t inputRank = getInputRank();
+
+  resultOffsets.clear();
+  resultSizes.clear();
+
+  for (int64_t i = 0; i < inputRank; ++i) {
+    if (i == dim) {
+      continue;
+    }
+    resultOffsets.push_back(offsets[i]);
+    resultSizes.push_back(sizes[i]);
+  }
+
+  return success();
+}
+
+FailureOr<TilingResult>
+ArgCompareOp::generateResultTileValue(OpBuilder &builder, unsigned resultNumber,
+                                      ArrayRef<OpFoldResult> offsets,
+                                      ArrayRef<OpFoldResult> sizes) {
+  return getTiledImplementation(builder, offsets, sizes);
+}
+
 LogicalResult ArgCompareOp::generateScalarImplementation(OpBuilder &b,
                                                          Location loc,
                                                          ValueRange ivs) {
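A note on the shape logic above, since it drives both output slices: `arg_compare` produces one result element per non-reduction point, so `getResultTilePosition` simply drops the reduction dimension's offset/size pair from the input tile. A standalone sketch of that index math (plain C++ mock; `Tile` and `resultTilePosition` are illustrative names, not IREE APIs, with `int64_t` standing in for `OpFoldResult`):

```cpp
// Standalone illustration of the result-tile mapping: for an input tile
// described by (offsets, sizes) and a reduction dimension `dim`, the result
// tile omits that dimension entirely.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

struct Tile {
  std::vector<int64_t> offsets;
  std::vector<int64_t> sizes;
};

static Tile resultTilePosition(const Tile &inputTile, int64_t dim) {
  Tile result;
  for (int64_t i = 0; i < static_cast<int64_t>(inputTile.offsets.size()); ++i) {
    if (i == dim)
      continue; // The reduced dimension does not appear in the results.
    result.offsets.push_back(inputTile.offsets[i]);
    result.sizes.push_back(inputTile.sizes[i]);
  }
  return result;
}

int main() {
  // Mirrors @arg_compare_tile_tensor below: a 2-D input tiled along dim 0,
  // reducing along dim 1. The tile starting at row 20 covering 10 full rows
  // (of, say, 64 columns) maps to the 1-D result tile at offset [20], size [10].
  Tile in{{20, 0}, {10, 64}};
  Tile out = resultTilePosition(in, /*dim=*/1);
  assert(out.offsets == std::vector<int64_t>({20}));
  assert(out.sizes == std::vector<int64_t>({10}));
  std::cout << "result tile offset=" << out.offsets[0]
            << " size=" << out.sizes[0] << "\n";
}
```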

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir

Lines changed: 238 additions & 0 deletions
@@ -788,6 +788,244 @@ module attributes { transform.with_named_sequence } {
 
 // -----
 
+func.func @arg_compare_tile_tensor(
+    %input: tensor<?x?xf32>,
+    %outv: tensor<?xf32>,
+    %outi: tensor<?xi32>
+) -> (tensor<?xf32>, tensor<?xi32>) {
+  %0:2 = iree_linalg_ext.arg_compare
+      dimension(1)
+      ins(%input : tensor<?x?xf32>)
+      outs(%outv, %outi : tensor<?xf32>, tensor<?xi32>) {
+    ^bb0(%a: f32, %b: f32):
+      %cmp = arith.cmpf ogt, %a, %b : f32
+      iree_linalg_ext.yield %cmp : i1
+  } -> tensor<?xf32>, tensor<?xi32>
+  return %0#0, %0#1 : tensor<?xf32>, tensor<?xi32>
+}
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["iree_linalg_ext.arg_compare"]} in %module_op
+        : (!transform.any_op) -> !transform.any_op
+    %1, %loops = transform.structured.tile_using_for %0 tile_sizes [10, 0]
+        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 10)>
+// CHECK: func.func @arg_compare_tile_tensor
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[RESULT:.+]]:2 = scf.for %[[IV:.+]] = %[[C0]] to %[[D0]] step %[[C10]] iter_args(%[[V0:.+]] = %[[ARG1]], %[[V1:.+]] = %[[ARG2]])
+// CHECK: %[[MIN:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D0]]]
+// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[MIN]], %[[D1]]] [1, 1]
+// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[V0]][%[[IV]]] [%[[MIN]]] [1]
+// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[V1]][%[[IV]]] [%[[MIN]]] [1]
+// CHECK: %[[CMP:.+]]:2 = iree_linalg_ext.arg_compare
+// CHECK-SAME: ins(%[[SLICE0]]
+// CHECK-SAME: outs(%[[SLICE1]], %[[SLICE2]]
+// CHECK: %[[INS0:.+]] = tensor.insert_slice %[[CMP]]#0 into %[[V0]][%[[IV]]] [%[[MIN]]] [1]
+// CHECK: %[[INS1:.+]] = tensor.insert_slice %[[CMP]]#1 into %[[V1]][%[[IV]]] [%[[MIN]]] [1]
+// CHECK: scf.yield %[[INS0]], %[[INS1]]
+// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
+func.func @arg_compare_tile_memref(
+    %input: memref<?x?xf32>,
+    %outv: memref<?xf32>,
+    %outi: memref<?xi32>
+) {
+  iree_linalg_ext.arg_compare
+      dimension(1)
+      ins(%input : memref<?x?xf32>)
+      outs(%outv, %outi : memref<?xf32>, memref<?xi32>) {
+    ^bb0(%a: f32, %b: f32):
+      %cmp = arith.cmpf ogt, %a, %b : f32
+      iree_linalg_ext.yield %cmp : i1
+  }
+  return
+}
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["iree_linalg_ext.arg_compare"]} in %module_op
+        : (!transform.any_op) -> !transform.any_op
+    %1, %loops = transform.structured.tile_using_for %0 tile_sizes [10, 0]
+        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 10)>
+// CHECK: func.func @arg_compare_tile_memref
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[D1:.+]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK: scf.for %[[IV:.+]] = %[[C0]] to %[[D0]] step %[[C10]]
+// CHECK: %[[MIN:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D0]]]
+// CHECK: %[[SV0:.+]] = memref.subview %[[ARG0]][%[[IV]], 0] [%[[MIN]], %[[D1]]] [1, 1]
+// CHECK: %[[SV1:.+]] = memref.subview %[[ARG1]][%[[IV]]] [%[[MIN]]] [1]
+// CHECK: %[[SV2:.+]] = memref.subview %[[ARG2]][%[[IV]]] [%[[MIN]]] [1]
+// CHECK: iree_linalg_ext.arg_compare
+// CHECK-SAME: dimension(1)
+// CHECK-SAME: ins(%[[SV0]]
+// CHECK-SAME: outs(%[[SV1]], %[[SV2]]
+// CHECK: return
+
+// -----
+
+func.func @arg_compare_1d(%input: tensor<128xf32>) -> tensor<i32> {
+  %outv = tensor.empty() : tensor<f32>
+  %outi = tensor.empty() : tensor<i32>
+  %result:2 = iree_linalg_ext.arg_compare
+      dimension(0)
+      ins(%input : tensor<128xf32>)
+      outs(%outv, %outi : tensor<f32>, tensor<i32>) {
+    ^bb0(%a: f32, %b: f32):
+      %cmp = arith.cmpf ogt, %a, %b : f32
+      iree_linalg_ext.yield %cmp : i1
+  } -> tensor<f32>, tensor<i32>
+  return %result#1 : tensor<i32>
+}
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["iree_linalg_ext.arg_compare"]} in %module_op
+        : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.tile_using_for %0 tile_sizes [0]
+        : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: func.func @arg_compare_1d(
+// CHECK-SAME: %[[OPERAND:.+]]: tensor<128xf32>
+// CHECK: %[[ACCV:.+]] = tensor.empty() : tensor<f32>
+// CHECK: %[[ACCI:.+]] = tensor.empty() : tensor<i32>
+// CHECK: %[[RESULT:.+]]:2 = iree_linalg_ext.arg_compare
+// CHECK-SAME: ins(%[[OPERAND]] :
+// CHECK-SAME: outs(%[[ACCV]], %[[ACCI]] :
+// CHECK: return %[[RESULT]]#1
+
+// -----
+
+func.func @arg_compare_2d_dim0(%input: tensor<16x32xf32>) -> tensor<32xi32> {
+  %outv = tensor.empty() : tensor<32xf32>
+  %outi = tensor.empty() : tensor<32xi32>
+  %result:2 = iree_linalg_ext.arg_compare
+      dimension(0)
+      ins(%input : tensor<16x32xf32>)
+      outs(%outv, %outi : tensor<32xf32>, tensor<32xi32>) {
+    ^bb0(%a: f32, %b: f32):
+      %cmp = arith.cmpf ogt, %a, %b : f32
+      iree_linalg_ext.yield %cmp : i1
+  } -> tensor<32xf32>, tensor<32xi32>
+  return %result#1 : tensor<32xi32>
+}
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["iree_linalg_ext.arg_compare"]} in %module_op
+        : (!transform.any_op) -> !transform.any_op
+    // Only tile the non-reduction dimension: columns.
+    %1, %loops = transform.structured.tile_using_for %0 tile_sizes [0, 20]
+        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (-d0 + 32, 20)>
+// CHECK: func.func @arg_compare_2d_dim0(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[ACCV:.+]] = tensor.empty() : tensor<32xf32>
+// CHECK-DAG: %[[ACCI:.+]] = tensor.empty() : tensor<32xi32>
+// CHECK: %[[RESULT:.+]]:2 = scf.for %[[I:.+]] = %[[C0]] to %[[C32]] step %[[C20]]
+// CHECK-SAME: iter_args(%[[ARG2:.+]] = %[[ACCV]], %[[ARG3:.+]] = %[[ACCI]])
+// CHECK: %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]])
+// CHECK: %[[UPDATE_SLICE_IN:.+]] = tensor.extract_slice %[[ARG0]][0, %[[I]]] [16, %[[SIZE]]] [1, 1]
+// CHECK: %[[UPDATE_SLICE_OUTV:.+]] = tensor.extract_slice %[[ARG2]][%[[I]]] [%[[SIZE]]] [1]
+// CHECK: %[[UPDATE_SLICE_OUTI:.+]] = tensor.extract_slice %[[ARG3]][%[[I]]] [%[[SIZE]]] [1]
+// CHECK: %[[ARGCMP_TILE:.+]]:2 = iree_linalg_ext.arg_compare
+// CHECK-SAME: dimension(0)
+// CHECK-SAME: ins(%[[UPDATE_SLICE_IN]]
+// CHECK-SAME: outs(%[[UPDATE_SLICE_OUTV]], %[[UPDATE_SLICE_OUTI]]
+// CHECK: %[[ACCV_YIELD:.+]] = tensor.insert_slice %[[ARGCMP_TILE]]#0 into %[[ARG2]][%[[I]]] [%[[SIZE]]] [1]
+// CHECK: %[[ACCI_YIELD:.+]] = tensor.insert_slice %[[ARGCMP_TILE]]#1 into %[[ARG3]][%[[I]]] [%[[SIZE]]] [1]
+// CHECK: scf.yield %[[ACCV_YIELD]], %[[ACCI_YIELD]] : tensor<32xf32>, tensor<32xi32>
+// CHECK: return %[[RESULT]]#1
+
+// -----
+
+func.func @arg_compare_with_base(
+    %input : tensor<2x6xf32>,
+    %outv : tensor<2xf32>,
+    %outi : tensor<2xindex>,
+    %base : index
+) -> (tensor<2xf32>, tensor<2xindex>) {
+  %0:2 = iree_linalg_ext.arg_compare
+      dimension(1)
+      ins(%input : tensor<2x6xf32>)
+      outs(%outv, %outi : tensor<2xf32>, tensor<2xindex>)
+      index_base(%base : index) {
+    ^bb0(%a: f32, %b: f32):
+      %cmp = arith.cmpf ogt, %a, %b : f32
+      iree_linalg_ext.yield %cmp : i1
+  } -> tensor<2xf32>, tensor<2xindex>
+  return %0#0, %0#1 : tensor<2xf32>, tensor<2xindex>
+}
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["iree_linalg_ext.arg_compare"]} in %module_op
+        : (!transform.any_op) -> !transform.any_op
+    %1, %loops = transform.structured.tile_using_for %0 tile_sizes [1, 0]
+        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @arg_compare_with_base(
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<2x6xf32>
+// CHECK-SAME: %[[OUTV:[a-zA-Z0-9_]+]]: tensor<2xf32>
+// CHECK-SAME: %[[OUTI:[a-zA-Z0-9_]+]]: tensor<2xindex>
+// CHECK-SAME: %[[BASE:[a-zA-Z0-9_]+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[RESULT:.+]]:2 = scf.for %[[IV:.+]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK-SAME: iter_args(%[[VARG:.+]] = %[[OUTV]], %[[IARG:.+]] = %[[OUTI]])
+// CHECK: %[[SLICE_IN:.+]] = tensor.extract_slice %[[INPUT]][%[[IV]], 0] [1, 6] [1, 1]
+// CHECK: %[[SLICE_OUTV:.+]] = tensor.extract_slice %[[VARG]][%[[IV]]] [1] [1]
+// CHECK: %[[SLICE_OUTI:.+]] = tensor.extract_slice %[[IARG]][%[[IV]]] [1] [1]
+// CHECK: %[[ARGCMP:.+]]:2 = iree_linalg_ext.arg_compare
+// CHECK-SAME: dimension(1)
+// CHECK-SAME: ins(%[[SLICE_IN]]
+// CHECK-SAME: outs(%[[SLICE_OUTV]], %[[SLICE_OUTI]]
+// CHECK-SAME: index_base(%[[BASE]]
+// CHECK: %[[INS_OUTV:.+]] = tensor.insert_slice %[[ARGCMP]]#0 into %[[VARG]][%[[IV]]] [1] [1]
+// CHECK: %[[INS_OUTI:.+]] = tensor.insert_slice %[[ARGCMP]]#1 into %[[IARG]][%[[IV]]] [1] [1]
+// CHECK: scf.yield %[[INS_OUTV]], %[[INS_OUTI]]
+// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
 func.func @im2col(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x1024x5760xf32> {
   %0 = tensor.empty() : tensor<2x1024x5760xf32>
   %1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
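One detail worth reading out of the CHECK lines above: the recurring map `affine_map<(d0)[s0] -> (-d0 + s0, 10)>`, applied through `affine.min`, clamps the last tile so each iteration covers `min(dimSize - iv, tileSize)` elements. The same arithmetic in a standalone sketch (plain C++; the concrete sizes are made up for illustration):

```cpp
// Standalone illustration of the affine.min tile-bound computation in the
// tests above: for induction variable `iv` in [0, dimSize) stepping by
// `tileSize`, the tile extent is min(dimSize - iv, tileSize).
#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t dimSize = 23;  // e.g. a dynamic tensor.dim result
  const int64_t tileSize = 10; // matches tile_sizes [10, 0]
  for (int64_t iv = 0; iv < dimSize; iv += tileSize) {
    int64_t extent = std::min(dimSize - iv, tileSize);
    std::cout << "tile at " << iv << " extent " << extent << "\n";
  }
  // Prints extents 10, 10, 3: the final partial tile is clamped, which is
  // exactly what affine_map<(d0)[s0] -> (-d0 + s0, 10)> expresses.
}
```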
