Commit f155a52
[Codegen] add transform op for matching attention op (iree-org#22199)
The key difference is that the indexing map is required for attention operations, so `MatchAttentionOp` does not treat it as optional.

Signed-off-by: Bangtian Liu <[email protected]>
1 parent: 670c33c
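For orientation, here is the shape of the new matcher as the tests below exercise it. This is a minimal sketch, not normative: the `#map_*` affine maps are assumed to be declared as in the test file, `@match_attention_example` is a hypothetical name, and the result bindings follow the dimension-constraint test below.

  transform.named_sequence @match_attention_example(
      %op: !transform.any_op {transform.readonly}) -> !transform.any_op {
    // The matcher fails silenceably unless the element types and the
    // indexing maps both match; per the commit message, the indexing_maps
    // argument is required rather than optional for attention.
    %batch, %m, %n, %k1, %k2 =
        transform.iree.match.attention %op,
        query_type = f16, key_type = f16, value_type = f16, output_type = f16,
        indexing_maps = [#map_query, #map_key, #map_value, #map_scale, #map_output] :
        !transform.any_op -> !transform.param<i64>
    // Each result is an i64 param holding the static sizes of one dimension
    // group, so it composes directly with transform.iree.match.dims_equal.
    transform.yield %op : !transform.any_op
  }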

File tree: 5 files changed, +373 −0 lines

compiler/src/iree/compiler/Preprocessing/Common/test/preprocessing_match_ops.mlir

Lines changed: 186 additions & 0 deletions
@@ -922,3 +922,189 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// Verify that the basic attention matcher works.
+
+#map_query = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
+#map_key = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
+#map_value = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d5)>
+#map_scale = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+#map_output = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+
+// CHECK-LABEL: func.func @attention_ops
+func.func @attention_ops(
+    %query: tensor<2x10x6x4xf16>,
+    %key: tensor<2x10x4x4xf16>,
+    %value: tensor<2x10x4x4xf16>,
+    %scale: f16,
+    %output: tensor<2x10x6x4xf16>,
+    %input_mm: tensor<32x64xf32>,
+    %filter_mm: tensor<64x32xf32>,
+    %output_mm: tensor<32x32xf32>) -> (tensor<2x10x6x4xf16>, tensor<32x32xf32>) {
+
+  // CHECK: iree_linalg_ext.attention
+  // CHECK-SAME: match_status = "matched"
+  %res1 = iree_linalg_ext.attention {indexing_maps = [#map_query, #map_key, #map_value, #map_scale, #map_output], match_status = "unmatched"}
+      ins(%query, %key, %value, %scale : tensor<2x10x6x4xf16>, tensor<2x10x4x4xf16>, tensor<2x10x4x4xf16>, f16)
+      outs(%output : tensor<2x10x6x4xf16>) {
+    ^bb0(%arg: f32):
+      iree_linalg_ext.yield %arg : f32
+  } -> tensor<2x10x6x4xf16>
+
+  // CHECK: linalg.matmul
+  // CHECK-SAME: match_status = "unmatched"
+  %res2 = linalg.matmul
+      ins(%input_mm, %filter_mm : tensor<32x64xf32>, tensor<64x32xf32>)
+      outs(%output_mm : tensor<32x32xf32>) {match_status = "unmatched"} -> tensor<32x32xf32>
+
+  return %res1, %res2 : tensor<2x10x6x4xf16>, tensor<32x32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @match_attention(%op: !transform.any_op {transform.readonly}) -> !transform.any_op {
+    %batch, %m, %k1, %k2, %n =
+        transform.iree.match.attention %op,
+        query_type = f16, key_type = f16, value_type = f16, output_type = f16,
+        indexing_maps = [#map_query, #map_key, #map_value, #map_scale, #map_output] :
+        !transform.any_op -> !transform.param<i64>
+
+    transform.yield %op : !transform.any_op
+  }
+
+  transform.named_sequence @annotate(%op: !transform.any_op {transform.readonly}) {
+    %0 = transform.param.constant "matched" -> !transform.any_param
+    transform.annotate %op "match_status" = %0 : !transform.any_op, !transform.any_param
+    transform.yield
+  }
+
+  transform.named_sequence @__transform_main(%module: !transform.any_op) {
+    transform.foreach_match in %module
+        @match_attention -> @annotate
+        : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// Verify dimension size constraints for the attention op with lowering config.
+
+#map_query = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
+#map_key = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
+#map_value = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d5)>
+#map_scale = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+#map_output = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+
+#lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1]}>
+
+// CHECK-LABEL: func.func @attention_constraints
+func.func @attention_constraints(
+    %query: tensor<2x10x6x4xf16>,
+    %key: tensor<2x10x4x4xf16>,
+    %value: tensor<2x10x4x4xf16>,
+    %scale: f16,
+    %output: tensor<2x10x6x4xf16>) -> tensor<2x10x6x4xf16> {
+
+  // CHECK: iree_linalg_ext.attention
+  // CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1]}>
+  // CHECK-SAME: match_status = "matched"
+  %res = iree_linalg_ext.attention {
+      indexing_maps = [#map_query, #map_key, #map_value, #map_scale, #map_output],
+      lowering_config = #lowering_config,
+      match_status = "unmatched"}
+      ins(%query, %key, %value, %scale : tensor<2x10x6x4xf16>, tensor<2x10x4x4xf16>, tensor<2x10x4x4xf16>, f16)
+      outs(%output : tensor<2x10x6x4xf16>) {
+    ^bb0(%arg: f32):
+      iree_linalg_ext.yield %arg : f32
+  } -> tensor<2x10x6x4xf16>
+
+  return %res : tensor<2x10x6x4xf16>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @match_attention_all_dims(%op: !transform.any_op {transform.readonly}) -> !transform.any_op {
+    %batch, %m, %n, %k1, %k2 =
+        transform.iree.match.attention %op,
+        query_type = f16, key_type = f16, value_type = f16, output_type = f16,
+        indexing_maps = [#map_query, #map_key, #map_value, #map_scale, #map_output] :
+        !transform.any_op -> !transform.param<i64>
+
+    transform.iree.match.dims_equal %batch, [2, 10] : !transform.param<i64>
+    transform.iree.match.dims_equal %m, [6] : !transform.param<i64>
+    transform.iree.match.dims_equal %k1, [4] : !transform.param<i64>
+    transform.iree.match.dims_equal %k2, [4] : !transform.param<i64>
+    transform.iree.match.dims_equal %n, [4] : !transform.param<i64>
+    transform.yield %op : !transform.any_op
+  }
+
+  transform.named_sequence @annotate(%op: !transform.any_op {transform.readonly}) {
+    %0 = transform.param.constant "matched" -> !transform.any_param
+    transform.annotate %op "match_status" = %0 : !transform.any_op, !transform.any_param
+    transform.yield
+  }
+
+  transform.named_sequence @__transform_main(%module: !transform.any_op) {
+    transform.foreach_match in %module
+        @match_attention_all_dims -> @annotate
+        : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// Verify indexing maps mismatching for the attention op.
+
+#map_query = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
+#map_key = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
+#map_value = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d5)>
+#map_scale = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+#map_output = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+
+#map_wrong_key = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>
+
+// CHECK-LABEL: func.func @indexing_maps_test
+func.func @indexing_maps_test(
+    %query: tensor<2x10x6x4xf16>,
+    %key: tensor<2x10x4x4xf16>,
+    %value: tensor<2x10x4x4xf16>,
+    %scale: f16,
+    %output: tensor<2x10x6x4xf16>) -> tensor<2x10x6x4xf16> {
+
+  // CHECK: iree_linalg_ext.attention
+  // CHECK-SAME: maps_match = "unmatched"
+  %res = iree_linalg_ext.attention {indexing_maps = [#map_query, #map_key, #map_value, #map_scale, #map_output], maps_match = "unmatched"}
+      ins(%query, %key, %value, %scale : tensor<2x10x6x4xf16>, tensor<2x10x4x4xf16>, tensor<2x10x4x4xf16>, f16)
+      outs(%output : tensor<2x10x6x4xf16>) {
+    ^bb0(%arg: f32):
+      iree_linalg_ext.yield %arg : f32
+  } -> tensor<2x10x6x4xf16>
+
+  return %res : tensor<2x10x6x4xf16>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @match_with_wrong_maps(%op: !transform.any_op {transform.readonly}) -> !transform.any_op {
+    %batch, %m, %k1, %k2, %n =
+        transform.iree.match.attention %op,
+        query_type = f16, key_type = f16, value_type = f16, output_type = f16,
+        indexing_maps = [#map_query, #map_wrong_key, #map_value, #map_scale, #map_output] :
+        !transform.any_op -> !transform.param<i64>
+    transform.yield %op : !transform.any_op
+  }
+
+  transform.named_sequence @annotate(%op: !transform.any_op {transform.readonly}) {
+    %0 = transform.param.constant "matched" -> !transform.any_param
+    transform.annotate %op "maps_match" = %0 : !transform.any_op, !transform.any_param
+    transform.yield
+  }
+
+  transform.named_sequence @__transform_main(%module: !transform.any_op) {
+    transform.foreach_match in %module
+        @match_with_wrong_maps -> @annotate
+        : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
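A note on how the matched dimension groups fall out of these maps: over the iteration space (d0, ..., d5), d0 and d1 index every operand and form the batch dims with sizes [2, 10]; d2 appears in the query and output maps (M, size [6]); d3 in the value and output maps (N, [4]); d4 in the query and key maps (K1, [4]); and d5 in the key and value maps (K2, [4]). These are exactly the values asserted by the dims_equal checks in the dimension-constraint test.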

compiler/src/iree/compiler/Preprocessing/TransformExtensions/BUILD.bazel

Lines changed: 2 additions & 0 deletions
@@ -56,6 +56,8 @@ iree_compiler_cc_library(
     ],
     deps = [
         ":PreprocessingExtensionsOpGen",
+        "//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
+        "//compiler/src/iree/compiler/Dialect/LinalgExt/Utils",
         "//compiler/src/iree/compiler/Utils",
         "//llvm-external-projects/iree-dialects:IREEDialectsTransforms",
         "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",

compiler/src/iree/compiler/Preprocessing/TransformExtensions/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -42,6 +42,8 @@ iree_cc_library(
     MLIRTransformDialectInterfaces
     MLIRTransformUtils
     MLIRValueBoundsOpInterface
+    iree::compiler::Dialect::LinalgExt::IR
+    iree::compiler::Dialect::LinalgExt::Utils
     iree::compiler::Utils
   PUBLIC
 )

compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.cpp

Lines changed: 86 additions & 0 deletions
@@ -6,6 +6,8 @@
 
 #include "iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.h"
 
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
 #include "iree/compiler/Utils/EquivalenceUtils.h"
 #include "iree/compiler/Utils/ShapeUtils.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
@@ -302,6 +304,90 @@ IREE::transform_dialect::MatchContractionOp::matchOperation(
   return DiagnosedSilenceableFailure::success();
 }
 
+DiagnosedSilenceableFailure
+IREE::transform_dialect::MatchAttentionOp::matchOperation(
+    Operation *current, transform::TransformResults &results,
+    transform::TransformState &state) {
+  Location loc = current->getLoc();
+  auto attentionOp = dyn_cast<IREE::LinalgExt::AttentionOp>(current);
+  if (!attentionOp) {
+    return emitSilenceableFailure(loc)
+           << "Operation is not an attention operation.";
+  }
+
+  Type targetQueryType = getQueryType();
+  Type currentQueryType =
+      getElementTypeOrSelf(attentionOp.getQuery().getType());
+  if (currentQueryType != targetQueryType) {
+    return emitSilenceableFailure(loc)
+           << "Query type doesn't match: expected " << targetQueryType
+           << ", got " << currentQueryType;
+  }
+
+  Type targetKeyType = getKeyType();
+  Type currentKeyType = getElementTypeOrSelf(attentionOp.getKey().getType());
+  if (currentKeyType != targetKeyType) {
+    return emitSilenceableFailure(loc)
+           << "Key type doesn't match: expected " << targetKeyType << ", got "
+           << currentKeyType;
+  }
+
+  Type targetValueType = getValueType();
+  Type currentValueType =
+      getElementTypeOrSelf(attentionOp.getValue().getType());
+  if (currentValueType != targetValueType) {
+    return emitSilenceableFailure(loc)
+           << "Value type doesn't match: expected " << targetValueType
+           << ", got " << currentValueType;
+  }
+
+  Type targetOutputType = getOutputType();
+  Type currentOutputType =
+      getElementTypeOrSelf(attentionOp.getOutput().getType());
+  if (currentOutputType != targetOutputType) {
+    return emitSilenceableFailure(loc)
+           << "Output type doesn't match: expected " << targetOutputType
+           << ", got " << currentOutputType;
+  }
+
+  ArrayAttr currentIndexingMaps = attentionOp.getIndexingMaps();
+  ArrayAttr targetIndexingMaps = getIndexingMaps();
+  if (currentIndexingMaps != targetIndexingMaps) {
+    return emitSilenceableFailure(loc) << "indexing maps don't match";
+  }
+
+  FailureOr<IREE::LinalgExt::AttentionOpDetail> maybeOpInfo =
+      IREE::LinalgExt::AttentionOpDetail::get(
+          attentionOp.getQueryMap(), attentionOp.getKeyMap(),
+          attentionOp.getValueMap(), attentionOp.getOutputMap());
+  if (failed(maybeOpInfo)) {
+    return emitSilenceableFailure(loc)
+           << "Failed to infer attention dimensions";
+  }
+  IREE::LinalgExt::AttentionOpDetail opInfo = *maybeOpInfo;
+  SmallVector<int64_t> iterationDomain = attentionOp.getStaticLoopRanges();
+
+  Builder builder(getContext());
+  auto iterationSizes = [&](ArrayRef<int64_t> dimIndices) {
+    return llvm::map_to_vector(dimIndices, [&](int64_t dimIdx) -> Attribute {
+      return builder.getI64IntegerAttr(iterationDomain[dimIdx]);
+    });
+  };
+
+  results.setParams(cast<OpResult>(getBatchDims()),
+                    iterationSizes(opInfo.getBatchDims()));
+  results.setParams(cast<OpResult>(getMDims()),
+                    iterationSizes(opInfo.getMDims()));
+  results.setParams(cast<OpResult>(getNDims()),
+                    iterationSizes(opInfo.getNDims()));
+  results.setParams(cast<OpResult>(getK1Dims()),
+                    iterationSizes(opInfo.getK1Dims()));
+  results.setParams(cast<OpResult>(getK2Dims()),
+                    iterationSizes(opInfo.getK2Dims()));
+
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // MatchConvolutionOp
 //===----------------------------------------------------------------------===//
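Note that every rejection above is emitted with emitSilenceableFailure rather than a definite failure, so when the matcher runs under transform.foreach_match a non-attention op, a mismatched element type, or differing indexing maps simply counts as "no match" and the walk continues; the unmatched linalg.matmul and wrong-maps tests above rely on exactly this behavior.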
