Skip to content

Commit 48e3cf5

Browse files
Add lit tests.
Signed-off-by: MaheshRavishankar <[email protected]>
1 parent e901ea3 commit 48e3cf5

File tree

1 file changed

+170
-0
lines changed

1 file changed

+170
-0
lines changed

compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir

Lines changed: 170 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1041,3 +1041,173 @@ util.func @move_captured_from_above_ops(%arg0 : tensor<1x1x2x4xi32>,
10411041
// CHECK: linalg.yield
10421042
// CHECK: flow.return %[[GENERIC]]
10431043
// CHECK: util.return %[[DISPATCH]]
1044+
1045+
// -----
1046+
1047+
// Horizontal fusion: three contractions that share the same LHS are expressed
// as one multi-result linalg.generic (reduction over d4), each result then
// truncated elementwise from f32 to f16. Dispatch-region formation should put
// the producer and all three truncation consumers into a single dispatch.
util.func @horizontal_fusion1(%lhs : tensor<2x4096x640xf16>,
    %rhs0 : tensor<10x64x640xf16>, %rhs1 : tensor<10x64x640xf16>,
    %rhs2 : tensor<10x64x640xf16>) ->
    (tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>,
     tensor<2x10x4096x64xf16>) {
  %4 = tensor.empty() : tensor<2x10x4096x64xf32>
  %cst = arith.constant 0.0 : f32
  // Zero-filled f32 accumulator shared by all three results.
  %5 = linalg.fill ins(%cst : f32)
      outs(%4 : tensor<2x10x4096x64xf32>) -> tensor<2x10x4096x64xf32>
  %6:3 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]}
      ins(%lhs, %rhs0, %rhs1, %rhs2
          : tensor<2x4096x640xf16>, tensor<10x64x640xf16>, tensor<10x64x640xf16>,
            tensor<10x64x640xf16>)
      outs(%5, %5, %5
          : tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>) {
  ^bb0(%in: f16, %in_0: f16, %in_1: f16, %in_2: f16, %out: f32, %out_3: f32, %out_4: f32):
    // Shared LHS element (%14) multiplied against each RHS, accumulated in f32.
    %14 = arith.extf %in : f16 to f32
    %15 = arith.extf %in_0 : f16 to f32
    %16 = arith.mulf %14, %15 : f32
    %17 = arith.addf %out, %16 : f32
    %18 = arith.extf %in_1 : f16 to f32
    %19 = arith.mulf %14, %18 : f32
    %20 = arith.addf %out_3, %19 : f32
    %21 = arith.extf %in_2 : f16 to f32
    %22 = arith.mulf %14, %21 : f32
    %23 = arith.addf %out_4, %22 : f32
    linalg.yield %17, %20, %23 : f32, f32, f32
  } -> (tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>)
  %7 = tensor.empty() : tensor<2x10x4096x64xf16>
  // Elementwise f32 -> f16 truncation of each contraction result.
  %8 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%6#0 : tensor<2x10x4096x64xf32>) outs(%7 : tensor<2x10x4096x64xf16>) {
  ^bb0(%in: f32, %out: f16):
    %14 = arith.truncf %in : f32 to f16
    linalg.yield %14 : f16
  } -> tensor<2x10x4096x64xf16>
  %9 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%6#1 : tensor<2x10x4096x64xf32>) outs(%7 : tensor<2x10x4096x64xf16>) {
  ^bb0(%in: f32, %out: f16):
    %14 = arith.truncf %in : f32 to f16
    linalg.yield %14 : f16
  } -> tensor<2x10x4096x64xf16>
  %10 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%6#2 : tensor<2x10x4096x64xf32>) outs(%7 : tensor<2x10x4096x64xf16>) {
  ^bb0(%in: f32, %out: f16):
    %14 = arith.truncf %in : f32 to f16
    linalg.yield %14 : f16
  } -> tensor<2x10x4096x64xf16>
  util.return %8, %9, %10 : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>
}
// All four generics end up in one dispatch region with three results.
// CHECK-LABEL: func public @horizontal_fusion1
// CHECK: %[[DISPATCH:.+]]:3 = flow.dispatch.region
// CHECK: %[[GENERIC:.+]]:3 = linalg.generic
// CHECK: %[[TRUNC0:.+]] = linalg.generic
// CHECK-SAME: ins(%[[GENERIC]]#0 :
// CHECK: %[[TRUNC1:.+]] = linalg.generic
// CHECK-SAME: ins(%[[GENERIC]]#1 :
// CHECK: %[[TRUNC2:.+]] = linalg.generic
// CHECK-SAME: ins(%[[GENERIC]]#2 :
// CHECK: flow.return %[[TRUNC0]], %[[TRUNC1]], %[[TRUNC2]]
// CHECK: util.return %[[DISPATCH]]#0, %[[DISPATCH]]#1, %[[DISPATCH]]#2
1124+
1125+
// -----
1126+
1127+
// Horizontal fusion: two i8 contractions sharing the LHS, expressed as one
// two-result linalg.generic, with a single elementwise consumer that combines
// both results. Everything should fuse into a single dispatch region.
util.func @horizontal_fusion2(%lhs : tensor<2x4096x640xi8>,
    %rhs0 : tensor<2x640x640xi8>, %rhs1 : tensor<2x640x640xi8>)
    -> tensor<2x4096x640xf16> {
  // NOTE(review): source showed `arith.constant 32 : i32` under the name
  // %c0_i32; a zero init is what the name and the fill-for-accumulation
  // pattern imply, so the value is restored to 0 here.
  %c0_i32 = arith.constant 0 : i32
  %0 = tensor.empty() : tensor<2x4096x640xf16>
  %1 = tensor.empty() : tensor<2x4096x640xi32>
  %2 = linalg.fill ins(%c0_i32 : i32)
      outs(%1 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>
  %3:2 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
      ins(%lhs, %rhs0, %rhs1
          : tensor<2x4096x640xi8>, tensor<2x640x640xi8>, tensor<2x640x640xi8>)
      outs(%2, %2 : tensor<2x4096x640xi32>, tensor<2x4096x640xi32>) {
  ^bb0(%in: i8, %in_0: i8, %in_1: i8, %out: i32, %out_2: i32):
    %4 = arith.extsi %in : i8 to i32
    %5 = arith.extsi %in_0 : i8 to i32
    %6 = arith.muli %4, %5 : i32
    %7 = arith.addi %out, %6 : i32
    %8 = arith.extsi %in_1 : i8 to i32
    // Second accumulator chains off the first partial sum (%7).
    %9 = arith.muli %7, %8 : i32
    %10 = arith.addi %out_2, %9 : i32
    linalg.yield %7, %10 : i32, i32
  } -> (tensor<2x4096x640xi32>, tensor<2x4096x640xi32>)
  // Elementwise consumer reading both contraction results.
  %4 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%3#1, %3#0 : tensor<2x4096x640xi32>, tensor<2x4096x640xi32>)
      outs(%0 : tensor<2x4096x640xf16>) {
  ^bb0(%in: i32, %in_0: i32, %out: f16):
    %5 = arith.sitofp %in : i32 to f32
    %6 = arith.truncf %5 : f32 to f16
    %7 = arith.sitofp %in_0 : i32 to f32
    %8 = arith.truncf %7 : f32 to f16
    %9 = arith.addf %6, %8 : f16
    linalg.yield %9 : f16
  } -> tensor<2x4096x640xf16>
  util.return %4 : tensor<2x4096x640xf16>
}
// Both generics land in one single-result dispatch region.
// CHECK-LABEL: func public @horizontal_fusion2
// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region
// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
// CHECK: %[[TRUNC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[GENERIC]]#1, %[[GENERIC]]#0 :
// CHECK: flow.return %[[TRUNC]]
// CHECK: util.return %[[DISPATCH]]
1179+
1180+
// -----
1181+
1182+
// The consumer generic reads both the reduction result (%0) and a value (%1)
// derived from %0 through an optimization barrier that must stay outside any
// dispatch. Fusing the consumer into the producer's dispatch would make the
// dispatch depend on its own result; the pass must form two dispatches.
util.func @avoid_use_def_violation_on_consumer_fusion(%arg0 : tensor<?xf32>,
    %arg1 : tensor<f32>) -> tensor<f32> {
  // Full reduction of a dynamic 1-D tensor into a 0-D tensor.
  %0 = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>],
      iterator_types = ["reduction"]}
      ins(%arg0 : tensor<?xf32>) outs(%arg1 : tensor<f32>) {
  ^bb0(%b0 : f32, %b1 : f32):
    %1 = arith.addf %b0, %b1 : f32
    linalg.yield %1 : f32
  } -> tensor<f32>
  // Barrier pins %1 outside dispatch regions, creating the use-def hazard.
  %1 = util.optimization_barrier %0 : tensor<f32>
  %2 = tensor.empty() : tensor<f32>
  %3 = linalg.generic {
      indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>],
      iterator_types = []}
      ins(%0, %1 : tensor<f32>, tensor<f32>) outs(%2 : tensor<f32>) {
  ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
    %4 = arith.mulf %b0, %b1 : f32
    linalg.yield %4 : f32
  } -> tensor<f32>
  util.return %3 : tensor<f32>
}
// Two separate dispatches with the barrier between them.
// CHECK-LABEL: func public @avoid_use_def_violation_on_consumer_fusion
// CHECK: %[[DISPATCH1:.+]] = flow.dispatch.region
// CHECK: %[[GENERIC1:.+]] = linalg.generic
// CHECK: flow.return %[[GENERIC1]]
// CHECK: %[[BARRIER:.+]] = util.optimization_barrier %[[DISPATCH1]]
// CHECK: %[[DISPATCH2:.+]] = flow.dispatch.region
// CHECK: %[[GENERIC2:.+]] = linalg.generic
// CHECK-SAME: ins(%[[DISPATCH1]], %[[BARRIER]] :
// CHECK: flow.return %[[GENERIC2]]
// CHECK: util.return %[[DISPATCH2]]

0 commit comments

Comments
 (0)