@@ -453,9 +453,9 @@ tt.func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr
453453 // CHECK-NOT: ttg.convert_layout
454454 %cst = arith.constant dense <30000 > : tensor <1 x1 xi32 , #blocked2 >
455455 %cst_0 = arith.constant dense <30000 > : tensor <1 x512 xi32 , #blocked2 >
456- %c512 = arith.constant 512 : index
457- %c30000 = arith.constant 30000 : index
458- %c0 = arith.constant 0 : index
456+ %c512 = arith.constant 512 : i32
457+ %c30000 = arith.constant 30000 : i32
458+ %c0 = arith.constant 0 : i32
459459 %cst_1 = arith.constant dense <2048 > : tensor <1 x1 xi32 , #blocked2 >
460460 %cst_2 = arith.constant dense <0.000000e+00 > : tensor <1 x512 xf64 , #blocked2 >
461461 %0 = tt.get_program_id x : i32
@@ -473,9 +473,8 @@ tt.func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr
473473 %12 = tt.broadcast %11 : tensor <1 x1 xi32 , #blocked2 > -> tensor <1 x512 xi32 , #blocked2 >
474474 %13 = tt.splat %arg0 : !tt.ptr <f64 > -> tensor <1 x512 x!tt.ptr <f64 >, #blocked2 >
475475 %14 = tt.broadcast %7 : tensor <1 x1 xi1 , #blocked2 > -> tensor <1 x512 xi1 , #blocked2 >
476- %15 = scf.for %arg3 = %c0 to %c30000 step %c512 iter_args (%arg4 = %cst_2 ) -> (tensor <1 x512 xf64 , #blocked2 >) {
477- %16 = arith.index_cast %arg3 : index to i32
478- %17 = tt.splat %16 : i32 -> tensor <1 x512 xi32 , #blocked2 >
476+ %15 = scf.for %arg3 = %c0 to %c30000 step %c512 iter_args (%arg4 = %cst_2 ) -> (tensor <1 x512 xf64 , #blocked2 >) : i32 {
477+ %17 = tt.splat %arg3 : i32 -> tensor <1 x512 xi32 , #blocked2 >
479478 %18 = arith.addi %17 , %10 : tensor <1 x512 xi32 , #blocked2 >
480479 %19 = arith.cmpi " slt" , %18 , %cst_0 : tensor <1 x512 xi32 , #blocked2 >
481480 %20 = arith.addi %18 , %12 : tensor <1 x512 xi32 , #blocked2 >
@@ -999,9 +998,9 @@ tt.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !
999998// CHECK-LABEL: cmp
1000999module attributes {" ttg.num-warps" = 8 : i32 , " ttg.num-ctas" = 1 : i32 } {
10011000tt.func public @cmp (%arg0: !tt.ptr <f16 > {tt.divisibility = 16 : i32 }, %arg1: !tt.ptr <f32 > {tt.divisibility = 16 : i32 }, %arg2: !tt.ptr <f32 > {tt.divisibility = 16 : i32 }, %arg3: !tt.ptr <f16 > {tt.divisibility = 16 : i32 }, %arg4: i32 {tt.divisibility = 16 : i32 }, %arg5: i32 {tt.divisibility = 16 : i32 }) {
1002- %c64 = arith.constant 64 : index
1003- %c2048 = arith.constant 2048 : index
1004- %c0 = arith.constant 0 : index
1001+ %c64 = arith.constant 64 : i32
1002+ %c2048 = arith.constant 2048 : i32
1003+ %c0 = arith.constant 0 : i32
10051004 %c64_i32 = arith.constant 64 : i32
10061005 %cst = arith.constant dense <-3.40282347E+38 > : tensor <64 x64 xf32 , #blocked2 >
10071006 %cst_0 = arith.constant dense <4194304 > : tensor <64 x1 xi32 , #blocked2 >
@@ -1036,9 +1035,8 @@ tt.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt
10361035 %22 = arith.muli %21 , %cst_0 : tensor <64 x1 xi32 , #blocked2 >
10371036 %23 = tt.broadcast %22 : tensor <64 x1 xi32 , #blocked2 > -> tensor <64 x64 xi32 , #blocked2 >
10381037 %24 = tt.splat %arg1 : !tt.ptr <f32 > -> tensor <64 x64 x!tt.ptr <f32 >, #blocked2 >
1039- %25 = scf.for %arg6 = %c0 to %c2048 step %c64 iter_args (%arg7 = %14 ) -> (tensor <64 x64 xf32 , #blocked2 >) {
1040- %44 = arith.index_cast %arg6 : index to i32
1041- %45 = tt.splat %44 : i32 -> tensor <1 x64 xi32 , #blocked3 >
1038+ %25 = scf.for %arg6 = %c0 to %c2048 step %c64 iter_args (%arg7 = %14 ) -> (tensor <64 x64 xf32 , #blocked2 >) : i32 {
1039+ %45 = tt.splat %arg6 : i32 -> tensor <1 x64 xi32 , #blocked3 >
10421040 %46 = arith.addi %45 , %10 : tensor <1 x64 xi32 , #blocked3 >
10431041 %47 = arith.cmpi " slt" , %46 , %cst_2 : tensor <1 x64 xi32 , #blocked3 >
10441042 %48 = tt.broadcast %46 : tensor <1 x64 xi32 , #blocked3 > -> tensor <64 x64 xi32 , #blocked3 >
@@ -1092,9 +1090,8 @@ tt.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt
10921090 %41 = tt.broadcast %30 : tensor <64 x1 xf32 , #blocked2 > -> tensor <64 x64 xf32 , #blocked2 >
10931091 %42 = tt.splat %arg2 : !tt.ptr <f32 > -> tensor <64 x64 x!tt.ptr <f32 >, #blocked2 >
10941092 %43 = tt.splat %arg3 : !tt.ptr <f16 > -> tensor <64 x64 x!tt.ptr <f16 >, #blocked2 >
1095- scf.for %arg6 = %c0 to %c2048 step %c64 {
1096- %44 = arith.index_cast %arg6 : index to i32
1097- %45 = tt.splat %44 : i32 -> tensor <1 x64 xi32 , #blocked3 >
1093+ scf.for %arg6 = %c0 to %c2048 step %c64 : i32 {
1094+ %45 = tt.splat %arg6 : i32 -> tensor <1 x64 xi32 , #blocked3 >
10981095 %46 = arith.addi %45 , %10 : tensor <1 x64 xi32 , #blocked3 >
10991096 %47 = arith.cmpi " slt" , %46 , %cst_2 : tensor <1 x64 xi32 , #blocked3 >
11001097 %48 = tt.broadcast %46 : tensor <1 x64 xi32 , #blocked3 > -> tensor <64 x64 xi32 , #blocked3 >
@@ -1226,9 +1223,9 @@ module attributes {"ttg.num-warps" = 2 : i32, "ttg.num-ctas" = 1 : i32} {
12261223module attributes {" ttg.num-warps" = 4 : i32 , " ttg.num-ctas" = 1 : i32 } {
12271224 tt.func public @reduce_cvt2 (%arg0: !tt.ptr <f32 > {tt.divisibility = 16 : i32 }, %arg1: !tt.ptr <f32 > {tt.divisibility = 16 : i32 }, %arg2: !tt.ptr <f32 > {tt.divisibility = 16 : i32 }, %arg3: i32 {tt.divisibility = 16 : i32 }, %arg4: i32 {tt.divisibility = 16 : i32 }) {
12281225 %cst = arith.constant dense <0.000000e+00 > : tensor <1 x256 xf32 , #blocked >
1229- %c3136_i32 = arith.constant 3136 : index
1230- %c256_i32 = arith.constant 256 : index
1231- %c0_i32 = arith.constant 0 : index
1226+ %c3136_i32 = arith.constant 3136 : i32
1227+ %c256_i32 = arith.constant 256 : i32
1228+ %c0_i32 = arith.constant 0 : i32
12321229 %cst_0 = arith.constant dense <3.136000e+03 > : tensor <1 x1 xf32 , #blocked >
12331230 %cst_1 = arith.constant dense <50176 > : tensor <1 x256 xi32 , #blocked >
12341231 %cst_2 = arith.constant dense <196 > : tensor <1 x1 xi32 , #blocked >
@@ -1250,9 +1247,8 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} {
12501247 %12 = tt.broadcast %11 : tensor <1 x1 xi32 , #blocked > -> tensor <1 x256 xi32 , #blocked >
12511248 %13 = tt.splat %arg1 : !tt.ptr <f32 > -> tensor <1 x256 x!tt.ptr <f32 >, #blocked >
12521249 %14 = tt.broadcast %7 : tensor <1 x1 xi1 , #blocked > -> tensor <1 x256 xi1 , #blocked >
1253- %15 = scf.for %arg5 = %c0_i32 to %c3136_i32 step %c256_i32 iter_args (%arg6 = %cst ) -> (tensor <1 x256 xf32 , #blocked >) {
1254- %42 = arith.index_cast %arg5 : index to i32
1255- %43 = tt.splat %42 : i32 -> tensor <1 x256 xi32 , #blocked >
1250+ %15 = scf.for %arg5 = %c0_i32 to %c3136_i32 step %c256_i32 iter_args (%arg6 = %cst ) -> (tensor <1 x256 xf32 , #blocked >) : i32 {
1251+ %43 = tt.splat %arg5 : i32 -> tensor <1 x256 xi32 , #blocked >
12561252 %44 = arith.addi %43 , %10 : tensor <1 x256 xi32 , #blocked >
12571253 %45 = arith.cmpi " slt" , %44 , %cst_4 : tensor <1 x256 xi32 , #blocked >
12581254 %46 = arith.remsi %44 , %cst_3 : tensor <1 x256 xi32 , #blocked >
0 commit comments