@@ -922,3 +922,189 @@ module attributes {transform.with_named_sequence} {
922922 transform.yield
923923 }
924924}
925+
926+ // -----
927+
928+ // Verify that the basic attention matcher works.
929+
// Indexing maps shared by the payload attention op and the matcher below.
// All maps range over the same six iteration dimensions (d0..d5):
//   * d0, d1 appear in every tensor operand map;
//   * d2 is shared by query and output, d3 by value and output;
//   * d4 is shared by query and key, d5 by key and value;
//   * the scale operand is a scalar, hence the empty result list.
#map_query  = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
#map_key    = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
#map_value  = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d5)>
#map_scale  = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
#map_output = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
935+
936+ // CHECK-LABEL: func.func @attention_ops
// Payload for the basic matcher test. It holds one attention op and one
// matmul, both pre-tagged with match_status = "unmatched". The FileCheck
// directives inside the body expect the transform script to rewrite the tag
// to "matched" on the attention op only and to leave the matmul untouched.
func.func @attention_ops(
    %query: tensor<2x10x6x4xf16>,
    %key: tensor<2x10x4x4xf16>,
    %value: tensor<2x10x4x4xf16>,
    %scale: f16,
    %output: tensor<2x10x6x4xf16>,
    %input_mm: tensor<32x64xf32>,
    %filter_mm: tensor<64x32xf32>,
    %output_mm: tensor<32x32xf32>) -> (tensor<2x10x6x4xf16>, tensor<32x32xf32>) {

  // CHECK: iree_linalg_ext.attention
  // CHECK-SAME: match_status = "matched"
  // The attribute value must be spelled exactly "unmatched" (no padding) so
  // that a failed match is distinguishable from the expected output above.
  %res1 = iree_linalg_ext.attention {indexing_maps = [#map_query, #map_key, #map_value, #map_scale, #map_output], match_status = "unmatched"}
    ins(%query, %key, %value, %scale : tensor<2x10x6x4xf16>, tensor<2x10x4x4xf16>, tensor<2x10x4x4xf16>, f16)
    outs(%output : tensor<2x10x6x4xf16>) {
  // The region takes a single f32 argument and yields it unchanged.
  ^bb0(%arg: f32):
    iree_linalg_ext.yield %arg : f32
  } -> tensor<2x10x6x4xf16>

  // Negative case: a non-attention op must keep its original tag.
  // CHECK: linalg.matmul
  // CHECK-SAME: match_status = "unmatched"
  %res2 = linalg.matmul
    ins(%input_mm, %filter_mm : tensor<32x64xf32>, tensor<64x32xf32>)
    outs(%output_mm : tensor<32x32xf32>) {match_status = "unmatched"} -> tensor<32x32xf32>

  return %res1, %res2 : tensor<2x10x6x4xf16>, tensor<32x32xf32>
}
964+
// Transform script for the basic matcher test: every op under the module is
// offered to @match_attention, and ops for which the matcher succeeds are
// handed to @annotate.
module attributes {transform.with_named_sequence} {
  // Matcher: succeeds only for attention ops whose operand/result element
  // types are f16 and whose indexing maps equal the five aliases listed.
  // The five dimension results are not used in this test.
  // NOTE(review): the result names here are ordered (batch, m, k1, k2, n),
  // while the dimension-constraint test in this file names them
  // (batch, m, n, k1, k2) — confirm the op's actual result order and align.
  transform.named_sequence @match_attention(%op: !transform.any_op {transform.readonly}) -> !transform.any_op {
    %batch, %m, %k1, %k2, %n =
      transform.iree.match.attention %op,
        query_type = f16, key_type = f16, value_type = f16, output_type = f16,
        indexing_maps = [#map_query, #map_key, #map_value, #map_scale, #map_output] :
        !transform.any_op -> !transform.param<i64>

    // Forward the matched op to the action sequence.
    transform.yield %op : !transform.any_op
  }

  // Action: overwrite the op's match_status attribute with "matched".
  // Both the attribute name and the value must be free of padding so that the
  // result is `match_status = "matched"`, as the payload's checks expect.
  transform.named_sequence @annotate(%op: !transform.any_op {transform.readonly}) {
    %0 = transform.param.constant "matched" -> !transform.any_param
    transform.annotate %op "match_status" = %0 : !transform.any_op, !transform.any_param
    transform.yield
  }

  // Entry point: pair the matcher with the action over the whole module.
  transform.named_sequence @__transform_main(%module: !transform.any_op) {
    transform.foreach_match in %module
        @match_attention -> @annotate
      : (!transform.any_op) -> (!transform.any_op)
    transform.yield
  }
}
989+
990+ // -----
991+
992+ // Verify dimension size constraints for the attention op with lowering config.
993+
// Indexing maps for the attention op in this test case, all over the
// iteration dimensions d0..d5. The scale operand is a scalar, so its map has
// no results.
#map_query  = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
#map_key    = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
#map_value  = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d5)>
#map_scale  = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
#map_output = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>

// Extra attribute attached to the payload op; the test expects it to still be
// present on the op after annotation.
#lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1]}>
1001+
1002+ // CHECK-LABEL: func.func @attention_constraints
// Payload for the dimension-constraint test: a single attention op that also
// carries a lowering_config attribute. The FileCheck directives expect the
// lowering_config to be printed unchanged and match_status to become
// "matched" once the transform script has run.
func.func @attention_constraints(
    %query: tensor<2x10x6x4xf16>,
    %key: tensor<2x10x4x4xf16>,
    %value: tensor<2x10x4x4xf16>,
    %scale: f16,
    %output: tensor<2x10x6x4xf16>) -> tensor<2x10x6x4xf16> {

  // CHECK: iree_linalg_ext.attention
  // CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1]}>
  // CHECK-SAME: match_status = "matched"
  // The initial tag is exactly "unmatched" (no padding inside the quotes).
  %res = iree_linalg_ext.attention {
      indexing_maps = [#map_query, #map_key, #map_value, #map_scale, #map_output],
      lowering_config = #lowering_config,
      match_status = "unmatched"}
    ins(%query, %key, %value, %scale : tensor<2x10x6x4xf16>, tensor<2x10x4x4xf16>, tensor<2x10x4x4xf16>, f16)
    outs(%output : tensor<2x10x6x4xf16>) {
  // The region takes a single f32 argument and yields it unchanged.
  ^bb0(%arg: f32):
    iree_linalg_ext.yield %arg : f32
  } -> tensor<2x10x6x4xf16>

  return %res : tensor<2x10x6x4xf16>
}
1025+
// Transform script for the dimension-constraint test: the matcher must accept
// the attention op and every dims_equal constraint must hold before the op is
// handed to @annotate.
module attributes {transform.with_named_sequence} {
  // Matcher: match an f16 attention op with the listed indexing maps, then
  // constrain each group of dimension sizes returned by the match op.
  transform.named_sequence @match_attention_all_dims(%op: !transform.any_op {transform.readonly}) -> !transform.any_op {
    %batch, %m, %n, %k1, %k2 =
      transform.iree.match.attention %op,
        query_type = f16, key_type = f16, value_type = f16, output_type = f16,
        indexing_maps = [#map_query, #map_key, #map_value, #map_scale, #map_output] :
        !transform.any_op -> !transform.param<i64>

    // Expected sizes for the payload: two batch dimensions (2, 10) and m = 6.
    transform.iree.match.dims_equal %batch, [2, 10] : !transform.param<i64>
    transform.iree.match.dims_equal %m, [6] : !transform.param<i64>
    // NOTE(review): k1, k2 and n are all expected to be [4] for this payload,
    // so these three checks cannot detect a swapped result order — consider
    // a payload with distinct sizes if that coverage is wanted.
    transform.iree.match.dims_equal %k1, [4] : !transform.param<i64>
    transform.iree.match.dims_equal %k2, [4] : !transform.param<i64>
    transform.iree.match.dims_equal %n, [4] : !transform.param<i64>
    transform.yield %op : !transform.any_op
  }

  // Action: overwrite the op's match_status attribute with "matched".
  // Name and value are written without padding so the printed attribute is
  // exactly `match_status = "matched"`.
  transform.named_sequence @annotate(%op: !transform.any_op {transform.readonly}) {
    %0 = transform.param.constant "matched" -> !transform.any_param
    transform.annotate %op "match_status" = %0 : !transform.any_op, !transform.any_param
    transform.yield
  }

  // Entry point: pair the matcher with the action over the whole module.
  transform.named_sequence @__transform_main(%module: !transform.any_op) {
    transform.foreach_match in %module
        @match_attention_all_dims -> @annotate
      : (!transform.any_op) -> (!transform.any_op)
    transform.yield
  }
}
1055+
1056+ // -----
1057+
1058+ // Verify that the attention matcher does not match when the requested indexing maps differ from those on the op.
1059+
// Indexing maps carried by the payload attention op, all over the iteration
// dimensions d0..d5. The scale operand is a scalar, so its map has no results.
#map_query  = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
#map_key    = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
#map_value  = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d5)>
#map_scale  = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
#map_output = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>

// Deliberately different key map given to the matcher: identical to #map_key
// except that the last two results are swapped (d4, d5 instead of d5, d4).
#map_wrong_key = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>
1067+
1068+ // CHECK-LABEL: func.func @indexing_maps_test
// Payload for the indexing-map mismatch test: a single attention op tagged
// maps_match = "unmatched". The FileCheck directives expect the tag to stay
// "unmatched", i.e. the transform script must not annotate this op.
func.func @indexing_maps_test(
    %query: tensor<2x10x6x4xf16>,
    %key: tensor<2x10x4x4xf16>,
    %value: tensor<2x10x4x4xf16>,
    %scale: f16,
    %output: tensor<2x10x6x4xf16>) -> tensor<2x10x6x4xf16> {

  // CHECK: iree_linalg_ext.attention
  // CHECK-SAME: maps_match = "unmatched"
  // The op itself uses #map_key; only the matcher is given the wrong key map.
  %res = iree_linalg_ext.attention {indexing_maps = [#map_query, #map_key, #map_value, #map_scale, #map_output], maps_match = "unmatched"}
    ins(%query, %key, %value, %scale : tensor<2x10x6x4xf16>, tensor<2x10x4x4xf16>, tensor<2x10x4x4xf16>, f16)
    outs(%output : tensor<2x10x6x4xf16>) {
  // The region takes a single f32 argument and yields it unchanged.
  ^bb0(%arg: f32):
    iree_linalg_ext.yield %arg : f32
  } -> tensor<2x10x6x4xf16>

  return %res : tensor<2x10x6x4xf16>
}
1087+
// Transform script for the indexing-map mismatch test: the matcher asks for
// #map_wrong_key in the key position, which differs from the map on the
// payload op, so the payload's checks expect @annotate never to run on it.
module attributes {transform.with_named_sequence} {
  // Matcher: same element types as the payload, but with the key indexing map
  // replaced by #map_wrong_key. The dimension results are not used here.
  // NOTE(review): the result names here are ordered (batch, m, k1, k2, n),
  // while the dimension-constraint test in this file names them
  // (batch, m, n, k1, k2) — confirm the op's actual result order and align.
  transform.named_sequence @match_with_wrong_maps(%op: !transform.any_op {transform.readonly}) -> !transform.any_op {
    %batch, %m, %k1, %k2, %n =
      transform.iree.match.attention %op,
        query_type = f16, key_type = f16, value_type = f16, output_type = f16,
        indexing_maps = [#map_query, #map_wrong_key, #map_value, #map_scale, #map_output] :
        !transform.any_op -> !transform.param<i64>
    transform.yield %op : !transform.any_op
  }

  // Action: would set maps_match = "matched" on a matched op. Name and value
  // are written without padding so that, if this ever ran, it would overwrite
  // the payload's existing maps_match attribute and fail the payload's check.
  transform.named_sequence @annotate(%op: !transform.any_op {transform.readonly}) {
    %0 = transform.param.constant "matched" -> !transform.any_param
    transform.annotate %op "maps_match" = %0 : !transform.any_op, !transform.any_param
    transform.yield
  }

  // Entry point: pair the matcher with the action over the whole module.
  transform.named_sequence @__transform_main(%module: !transform.any_op) {
    transform.foreach_match in %module
        @match_with_wrong_maps -> @annotate
      : (!transform.any_op) -> (!transform.any_op)
    transform.yield
  }
}
0 commit comments