@@ -676,6 +676,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
676676 // CHECK-NEXT: [[CST_0:%.*]] = llvm.mlir.constant(0 : i32) : i32
677677 // CHECK-NEXT: [[IE1:%.*]] = llvm.insertelement [[BCAST0]], [[VEC1]][[[CST_0]] : i32] : vector<1xf32>
678678 // CHECK-NEXT: [[BCAST1:%.*]] = llvm.bitcast [[IE1]] : vector<1xf32> to i32
679+ // CHECK-NEXT: [[TRUE1:%.*]] = llvm.mlir.constant(true) : i1
679680 // CHECK-NEXT: [[AND1:%.*]] = llvm.and {{.*}}, [[ARG2_0]] : i1
680681 // CHECK-NEXT: [[VEC2:%.*]] = llvm.mlir.undef : vector<1xi32>
681682 // CHECK-NEXT: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
@@ -1059,17 +1060,23 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
10591060module attributes {" ttg.target" = " xpu" , " ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 } {
10601061 // CHECK-LABEL: atomic_cas_f32_scalar_no_store
10611062 tt.func @atomic_cas_f32_scalar_no_store (%ptr : !tt.ptr <f32 >, %cmp : f32 , %val : f32 ) {
1062- // CHECK: [[TRUE:%.*]] = llvm.mlir.constant(true) : i1
1063- // CHECK: [[CMP0:%.*]] = llvm.icmp "eq"
1064- // CHECK: [[MASK0:%.*]] = llvm.and [[TRUE]], [[CMP0]]
1065- // CHECK: [[CMP:%.*]] = llvm.icmp "eq"
1066- // CHECK: [[MASK:%.*]] = llvm.and [[MASK0]], [[CMP]]
1067- // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
1063+ // CHECK: [[ZERO0:%.*]] = llvm.mlir.constant(0 : i32) : i32
1064+ // CHECK: [[TRUE:%.*]] = llvm.mlir.constant(-1 : i32) : i32
1065+ // CHECK: [[MASKLANE:%.*]] = llvm.and
1066+ // CHECK-NEXT: [[CMPLANE:%.*]] = llvm.icmp "eq" [[MASKLANE]], [[ZERO0]]
1067+ // CHECK: [[MASKWARP:%.*]] = llvm.and
1068+ // CHECK-NEXT: [[CMPWARP:%.*]] = llvm.icmp "eq" [[MASKWARP]], [[ZERO0]]
1069+ // CHECK-NEXT: [[MASKWARPANDLANE:%.*]] = llvm.and [[CMPLANE]], [[CMPWARP]]
1070+ // CHECK: llvm.mlir.constant(-1 : i32) : i32
1071+ // CHECK: [[MASKBLOCK:%.*]] = llvm.and
1072+ // CHECK-NEXT: [[CMPBLOCK:%.*]] = llvm.icmp "eq" [[MASKBLOCK]], [[ZERO0]]
1073+ // CHECK-NEXT: [[MASK:%.*]] = llvm.and [[MASKWARPANDLANE]], [[CMPBLOCK]]
1074+ // CHECK: [[ZERO1:%.*]] = llvm.mlir.constant(0 : i32) : i32
10681075 // CHECK: [[WGSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32
10691076 // CHECK: [[WGMEMSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32
10701077 // CHECK: [[GLOBAL:%.*]] = llvm.mlir.constant(528 : i32) : i32
10711078 // CHECK: llvm.call spir_funccc @_Z22__spirv_ControlBarrieriii([[WGSCOPE]], [[WGMEMSCOPE]], [[GLOBAL]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> ()
1072- // CHECK-NEXT: llvm.cond_br [[MASK]], ^bb1, ^bb2([[ZERO]] : i32)
1079+ // CHECK-NEXT: llvm.cond_br [[MASK]], ^bb1, ^bb2([[ZERO1]] : i32)
10731080 // CHECK-NEXT: ^bb1:
10741081 // CHECK-NEXT: [[BCAST1:%.*]] = llvm.bitcast %arg1 : f32 to i32
10751082 // CHECK-NEXT: [[BCAST2:%.*]] = llvm.bitcast %arg2 : f32 to i32
@@ -1089,13 +1096,19 @@ module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warp
10891096 // CHECK: llvm.func spir_funccc @_Z7barrierj(i32) attributes {convergent, no_unwind, will_return}
10901097 // CHECK-LABEL: atomic_cas_f32_scalar
10911098 tt.func @atomic_cas_f32_scalar (%ptr : !tt.ptr <f32 >, %cmp : f32 , %val : f32 ) {
1092- // CHECK: [[TRUE:%.*]] = llvm.mlir.constant(true) : i1
1093- // CHECK: [[CMP0:%.*]] = llvm.icmp "eq"
1094- // CHECK: [[MASK0:%.*]] = llvm.and [[TRUE]], [[CMP0]]
1095- // CHECK: [[CMP:%.*]] = llvm.icmp "eq"
1096- // CHECK: [[MASK:%.*]] = llvm.and [[MASK0]], [[CMP]]
1097- // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
1098- // CHECK-NEXT: llvm.cond_br [[MASK]], ^bb1, ^bb2([[ZERO]] : i32)
1099+ // CHECK: [[ZERO0:%.*]] = llvm.mlir.constant(0 : i32) : i32
1100+ // CHECK: [[TRUE:%.*]] = llvm.mlir.constant(-1 : i32) : i32
1101+ // CHECK: [[MASKLANE:%.*]] = llvm.and
1102+ // CHECK-NEXT: [[CMPLANE:%.*]] = llvm.icmp "eq" [[MASKLANE]], [[ZERO0]]
1103+ // CHECK: [[MASKWARP:%.*]] = llvm.and
1104+ // CHECK-NEXT: [[CMPWARP:%.*]] = llvm.icmp "eq" [[MASKWARP]], [[ZERO0]]
1105+ // CHECK-NEXT: [[MASKWARPANDLANE:%.*]] = llvm.and [[CMPLANE]], [[CMPWARP]]
1106+ // CHECK: llvm.mlir.constant(-1 : i32) : i32
1107+ // CHECK: [[MASKBLOCK:%.*]] = llvm.and
1108+ // CHECK-NEXT: [[CMPBLOCK:%.*]] = llvm.icmp "eq" [[MASKBLOCK]], [[ZERO0]]
1109+ // CHECK-NEXT: [[MASK:%.*]] = llvm.and [[MASKWARPANDLANE]], [[CMPBLOCK]]
1110+ // CHECK: [[ZERO1:%.*]] = llvm.mlir.constant(0 : i32) : i32
1111+ // CHECK-NEXT: llvm.cond_br [[MASK]], ^bb1, ^bb2([[ZERO1]] : i32)
10991112 // CHECK-NEXT: ^bb1:
11001113 // CHECK-NEXT: [[BCAST1:%.*]] = llvm.bitcast %arg1 : f32 to i32
11011114 // CHECK-NEXT: [[BCAST2:%.*]] = llvm.bitcast %arg2 : f32 to i32
@@ -1131,14 +1144,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
11311144 // CHECK-NEXT: [[EV1_ARG2:%.*]] = llvm.extractvalue %arg2[1] : !llvm.struct<(f32, f32)>
11321145 // CHECK: [[EV0_ARG0:%.*]] = llvm.extractvalue %arg0[0] : !llvm.struct<(ptr<1>, ptr<1>)>
11331146 // CHECK-NEXT: [[EV1_ARG0:%.*]] = llvm.extractvalue %arg0[1] : !llvm.struct<(ptr<1>, ptr<1>)>
1134- // CHECK: llvm.mlir.constant(true) : i1
1135- // CHECK: [[CST_TRUE:%.*]] = llvm.mlir.constant(true) : i1
1136- // CHECK: [[PRED0:%.*]] = llvm.and [[CST_TRUE]], {{.*}} : i1
1137- // CHECK-NEXT: [[UNDEF1:%.*]] = llvm.mlir.undef : vector<1xf32>
1147+ // CHECK: [[EV0_ARG1:%.*]] = llvm.extractvalue %arg1[0] : !llvm.struct<(i1, i1)>
1148+ // CHECK-NEXT: [[EV1_ARG1:%.*]] = llvm.extractvalue %arg1[1] : !llvm.struct<(i1, i1)>
1149+ // CHECK: [[UNDEF1:%.*]] = llvm.mlir.undef : vector<1xf32>
11381150 // CHECK: [[IE1:%.*]] = llvm.insertelement [[EV0_ARG2]], [[UNDEF1]][{{.*}} : i64] : vector<1xf32>
1139- // CHECK-NEXT: [[PRED1:%.*]] = llvm.and [[PRED0]], {{.*}} : i1
11401151 // CHECK-NEXT: [[ZERO1:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
1141- // CHECK: llvm.cond_br [[PRED1]], ^bb1, ^bb2([[ZERO1]] : f32)
1152+ // CHECK: llvm.cond_br [[EV0_ARG1]], ^bb1, ^bb2([[ZERO1]] : f32)
11421153 // CHECK-NEXT: ^bb1:
11431154 // CHECK-NEXT: [[BCAST2:%.*]] = llvm.bitcast [[IE1]] : vector<1xf32> to f32
11441155 // CHECK-NEXT: [[RMW_RES1:%.*]] = llvm.atomicrmw fadd [[EV0_ARG0]], [[BCAST2]] monotonic : !llvm.ptr<1>, f32
@@ -1147,13 +1158,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
11471158 // CHECK-NEXT: [[RMW_CAST:%.*]] = llvm.bitcast [[RMW_PHI1]] : f32 to f32
11481159 // CHECK-NEXT: [[UNDEF2:%.*]] = llvm.mlir.undef : vector<1xf32>
11491160 // CHECK: [[IE2:%.*]] = llvm.insertelement [[EV1_ARG2]], [[UNDEF2]][{{.*}} : i64] : vector<1xf32>
1150- // CHECK-NEXT: [[PRED2:%.*]] = llvm.and [[PRED0]], {{.*}} : i1
11511161 // CHECK-NEXT: [[ZERO2:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
11521162 // CHECK: [[WGSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32
11531163 // CHECK: [[WGMEMSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32
11541164 // CHECK: [[GLOBAL:%.*]] = llvm.mlir.constant(528 : i32) : i32
11551165 // CHECK: llvm.call spir_funccc @_Z22__spirv_ControlBarrieriii([[WGSCOPE]], [[WGMEMSCOPE]], [[GLOBAL]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> ()
1156- // CHECK-NEXT: llvm.cond_br [[PRED2]], ^bb3, ^bb4([[ZERO2]] : f32)
1166+ // CHECK-NEXT: llvm.cond_br [[EV1_ARG1]], ^bb3, ^bb4([[ZERO2]] : f32)
11571167 // CHECK-NEXT: ^bb3:
11581168 // CHECK-NEXT: [[BCAST2:%.*]] = llvm.bitcast [[IE2]] : vector<1xf32> to f32
11591169 // CHECK-NEXT: [[RMW_RES2:%.*]] = llvm.atomicrmw fadd [[EV1_ARG0]], [[BCAST2]] monotonic : !llvm.ptr<1>, f32
@@ -1169,14 +1179,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
11691179module attributes {" ttg.target" = " xpu" , " ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 } {
11701180 // CHECK-LABEL: atomic_add_f32_scalar_no_store
11711181 tt.func @atomic_add_f32_scalar_no_store (%arg0 : !tt.ptr <f32 >, %arg1 : i1 , %arg2 : f32 ) {
1172- // CHECK: [[CST_TRUE:%.*]] = llvm.mlir.constant(true) : i1
1173- // CHECK: [[CMP:%.*]] = llvm.icmp "eq"
1174- // CHECK-NEXT: [[AND:%.*]] = llvm.and [[CST_TRUE]], [[CMP]] : i1
1175- // CHECK: [[CMP1:%.*]] = llvm.icmp "eq"
1176- // CHECK-NEXT: [[AND1:%.*]] = llvm.and [[AND]], [[CMP1]] : i1
1177- // CHECK: [[UNDEF1:%.*]] = llvm.mlir.undef : vector<1xf32>
1182+ // CHECK: [[ZERO0:%.*]] = llvm.mlir.constant(0 : i32) : i32
1183+ // CHECK: [[MASKLANE:%.*]] = llvm.and
1184+ // CHECK-NEXT: [[CMPLANE:%.*]] = llvm.icmp "eq" [[MASKLANE]], [[ZERO0]]
1185+ // CHECK: [[MASKWARP:%.*]] = llvm.and
1186+ // CHECK-NEXT: [[CMPWARP:%.*]] = llvm.icmp "eq" [[MASKWARP]], [[ZERO0]]
1187+ // CHECK-NEXT: [[MASKWARPANDLANE:%.*]] = llvm.and [[CMPLANE]], [[CMPWARP]]
1188+ // CHECK: llvm.mlir.constant(-1 : i32) : i32
1189+ // CHECK: [[MASKBLOCK:%.*]] = llvm.and
1190+ // CHECK-NEXT: [[CMPBLOCK:%.*]] = llvm.icmp "eq" [[MASKBLOCK]], [[ZERO0]]
1191+ // CHECK-NEXT: [[MASK:%.*]] = llvm.and [[MASKWARPANDLANE]], [[CMPBLOCK]]
1192+ // CHECK-NEXT: [[UNDEF1:%.*]] = llvm.mlir.undef : vector<1xf32>
11781193 // CHECK: [[IE1:%.*]] = llvm.insertelement %arg2, [[UNDEF1]][{{.*}} : i64] : vector<1xf32>
1179- // CHECK: [[PRED:%.*]] = llvm.and [[AND1]], %arg1 : i1
1194+ // CHECK: [[PRED:%.*]] = llvm.and %arg1, [[MASK]] : i1
11801195 // CHECK-NEXT: [[ZERO:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
11811196 // CHECK: [[WGSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32
11821197 // CHECK: [[WGMEMSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32
@@ -1200,14 +1215,19 @@ module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warp
12001215 // CHECK: llvm.func spir_funccc @_Z7barrierj(i32) attributes {convergent, no_unwind, will_return}
12011216 // CHECK-LABEL: atomic_add_f32_scalar
12021217 tt.func @atomic_add_f32_scalar (%arg0 : !tt.ptr <f32 >, %arg1 : i1 , %arg2 : f32 ) {
1203- // CHECK: [[CST_TRUE:%.*]] = llvm.mlir.constant(true) : i1
1204- // CHECK: [[CMP:%.*]] = llvm.icmp "eq"
1205- // CHECK-NEXT: [[AND:%.*]] = llvm.and [[CST_TRUE]], [[CMP]] : i1
1206- // CHECK: [[CMP1:%.*]] = llvm.icmp "eq"
1207- // CHECK-NEXT: [[AND1:%.*]] = llvm.and [[AND]], [[CMP1]] : i1
1208- // CHECK: [[UNDEF1:%.*]] = llvm.mlir.undef : vector<1xf32>
1218+ // CHECK: [[ZERO0:%.*]] = llvm.mlir.constant(0 : i32) : i32
1219+ // CHECK: [[MASKLANE:%.*]] = llvm.and
1220+ // CHECK-NEXT: [[CMPLANE:%.*]] = llvm.icmp "eq" [[MASKLANE]], [[ZERO0]]
1221+ // CHECK: [[MASKWARP:%.*]] = llvm.and
1222+ // CHECK-NEXT: [[CMPWARP:%.*]] = llvm.icmp "eq" [[MASKWARP]], [[ZERO0]]
1223+ // CHECK-NEXT: [[MASKWARPANDLANE:%.*]] = llvm.and [[CMPLANE]], [[CMPWARP]]
1224+ // CHECK: llvm.mlir.constant(-1 : i32) : i32
1225+ // CHECK: [[MASKBLOCK:%.*]] = llvm.and
1226+ // CHECK-NEXT: [[CMPBLOCK:%.*]] = llvm.icmp "eq" [[MASKBLOCK]], [[ZERO0]]
1227+ // CHECK-NEXT: [[MASK:%.*]] = llvm.and [[MASKWARPANDLANE]], [[CMPBLOCK]]
1228+ // CHECK-NEXT: [[UNDEF1:%.*]] = llvm.mlir.undef : vector<1xf32>
12091229 // CHECK: [[IE1:%.*]] = llvm.insertelement %arg2, [[UNDEF1]][{{.*}} : i64] : vector<1xf32>
1210- // CHECK: [[PRED:%.*]] = llvm.and [[AND1]], %arg1 : i1
1230+ // CHECK: [[PRED:%.*]] = llvm.and %arg1, [[MASK]] : i1
12111231 // CHECK-NEXT: [[ZERO:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
12121232 // CHECK-NEXT: llvm.cond_br [[PRED]], ^bb1, ^bb2([[ZERO]] : f32)
12131233 // CHECK-NEXT: ^bb1:
@@ -1295,22 +1315,22 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
12951315 // CHECK-NEXT: [[ARG0_1:%.*]] = llvm.extractvalue %arg0[1] : !llvm.struct<(ptr<1>, ptr<1>)>
12961316 // CHECK-NEXT: [[ARG1_0:%.*]] = llvm.extractvalue %arg1[0] : !llvm.struct<(f32, f32)>
12971317 // CHECK-NEXT: [[ARG1_1:%.*]] = llvm.extractvalue %arg1[1] : !llvm.struct<(f32, f32)>
1298- // CHECK: llvm.mlir.constant(true) : i1
12991318 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
1300- // CHECK-NEXT: llvm.call spir_funccc @_Z12get_local_idj([[ZERO]]) {{.*}} : (i32) -> i64
1301- // CHECK: [[TRUE1:%.*]] = llvm.mlir.constant(true) : i1
1302- // CHECK: [[TRUE2:%.*]] = llvm.mlir.constant(true) : i1
1303- // CHECK: [[PRED:%.*]] = llvm.and [[TRUE1]], [[TRUE2]] : i1
1319+ // CHECK: [[ZERO1:%.*]] = llvm.mlir.constant(0 : i32) : i32
1320+ // CHECK-NEXT: llvm.call spir_funccc @_Z12get_local_idj([[ZERO1]]) {{.*}} : (i32) -> i64
1321+ // CHECK: [[PRED:%.*]] = llvm.mlir.constant(true) : i1
13041322 // CHECK: llvm.cond_br [[PRED]], ^bb1, ^bb2
13051323 // CHECK-NEXT: ^bb1:
13061324 // CHECK-NEXT: [[BCAST:%.*]] = llvm.bitcast [[ARG0_0]] : !llvm.ptr<1> to !llvm.ptr<1>
13071325 // CHECK-NEXT: llvm.store {{.*}}, [[BCAST]] {alignment = 4 : i64} : vector<1xi32>, !llvm.ptr<1>
13081326 // CHECK-NEXT: llvm.br ^bb2
13091327 // CHECK-NEXT: ^bb2:
1328+ // CHECK: llvm.mlir.undef : vector<1xf32>
1329+ // CHECK: [[PRED2:%.*]] = llvm.mlir.constant(true) : i1
13101330 // CHECK: [[VEC:%.*]] = llvm.mlir.undef : vector<1xi32>
13111331 // CHECK-NEXT: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
13121332 // CHECK-NEXT: [[IE1:%.*]] = llvm.insertelement {{.*}}, [[VEC]][[[ZERO]] : i32] : vector<1xi32>
1313- // CHECK: llvm.cond_br [[PRED]], ^bb3, ^bb4
1333+ // CHECK: llvm.cond_br [[PRED2]], ^bb3, ^bb4
13141334 // CHECK-NEXT: ^bb3:
13151335 // CHECK-NEXT: [[BCAST1:%.*]] = llvm.bitcast [[ARG0_1]] : !llvm.ptr<1> to !llvm.ptr<1>
13161336 // CHECK-NEXT: llvm.store [[IE1]], [[BCAST1]] {alignment = 4 : i64} : vector<1xi32>, !llvm.ptr<1>
0 commit comments