@@ -51,13 +51,79 @@ module {
5151    }
5252    tt.return 
5353  }
54-   // CHECK: tt.func public @test1([[PARAM_0_:%.+]]: !tt.ptr<f16> {tt.divisibility = 16 : i32}, [[PARAM_1_:%.+]]: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {  
54+   // CHECK: tt.func public @test1 
5555  // CHECK:   scf.for 
5656  // CHECK:     [[PTR:%.+]] = tt.addptr {{.*}} : tensor<32x32x!tt.ptr<f16>>, tensor<32x32xi32> 
5757  // CHECK:     [[LOAD:%.+]] = tt.load [[PTR]] evictionPolicy = evict_last : tensor<32x32x!tt.ptr<f16>> 
5858  // CHECK:     arith.extf [[LOAD]] : tensor<32x32xf16> to tensor<32x32xf32> 
5959  // CHECK:     [[ORI:%.+]] = arith.ori {{.*}} : tensor<32x32xi1> 
6060  // CHECK:     [[SEL:%.+]] = arith.select [[ORI]], {{.*}}, {{.*}} : tensor<32x32xi1>, tensor<32x32xf32> 
6161  // CHECK:     scf.yield [[SEL]] : tensor<32x32xf32> 
62-   // CHECK: } 
62+ 
63+  tt.func  public  @test2 (%arg0:  !tt.ptr <f32 > {tt.divisibility  = 16  : i32 }, %arg1:  !tt.ptr <f32 > {tt.divisibility  = 16  : i32 }, %arg2:  !tt.ptr <f32 > {tt.divisibility  = 16  : i32 }, %arg3:  !tt.ptr <f32 > {tt.divisibility  = 16  : i32 }, %arg4:  i32  {tt.divisibility  = 16  : i32 }, %arg5:  i32  {tt.divisibility  = 16  : i32 }) {
64+     %cst  = arith.constant  0.000000e+00  : f32 
65+     %cst_0  = arith.constant  dense <1.000000e+00 > : tensor <64 x8 xf32 >
66+     %c8_i32  = arith.constant  8  : i32 
67+     %c128_i32  = arith.constant  128  : i32 
68+     %cst_1  = arith.constant  dense <0.000000e+00 > : tensor <64 x8 xf32 >
69+     %cst_2  = arith.constant  dense <16384 > : tensor <64 x1 xi32 >
70+     %cst_3  = arith.constant  dense <128 > : tensor <1 x8 xi32 >
71+     %c0_i32  = arith.constant  0  : i32 
72+     %cst_4  = arith.constant  dense <128 > : tensor <64 x1 xi32 >
73+     %c64_i32  = arith.constant  64  : i32 
74+     %0  = tt.get_program_id  x  : i32 
75+     %1  = arith.muli  %0 , %c64_i32  : i32 
76+     %2  = tt.make_range  {end  = 64  : i32 , start  = 0  : i32 } : tensor <64 xi32 >
77+     %3  = tt.expand_dims  %2  {axis  = 1  : i32 } : tensor <64 xi32 > -> tensor <64 x1 xi32 >
78+     %4  = tt.splat  %1  : i32  -> tensor <64 x1 xi32 >
79+     %5  = arith.addi  %4 , %3  : tensor <64 x1 xi32 >
80+     %6  = tt.make_range  {end  = 8  : i32 , start  = 0  : i32 } : tensor <8 xi32 >
81+     %7  = tt.expand_dims  %6  {axis  = 0  : i32 } : tensor <8 xi32 > -> tensor <1 x8 xi32 >
82+     %8  = arith.remsi  %5 , %cst_4  : tensor <64 x1 xi32 >
83+     %9  = arith.divsi  %5 , %cst_4  : tensor <64 x1 xi32 >
84+     %10  = tt.broadcast  %8  : tensor <64 x1 xi32 > -> tensor <64 x8 xi32 >
85+     %11  = arith.muli  %9 , %cst_2  : tensor <64 x1 xi32 >
86+     %12  = tt.broadcast  %11  : tensor <64 x1 xi32 > -> tensor <64 x8 xi32 >
87+     %13  = tt.splat  %arg0  : !tt.ptr <f32 > -> tensor <64 x8 x!tt.ptr <f32 >>
88+     %14:3  = scf.for  %arg6  = %c0_i32  to  %c128_i32  step  %c8_i32  iter_args (%arg7  = %cst_1 , %arg8  = %cst_1 , %arg9  = %cst_1 ) -> (tensor <64 x8 xf32 >, tensor <64 x8 xf32 >, tensor <64 x8 xf32 >)  : i32  {
89+       %25  = tt.splat  %arg6  : i32  -> tensor <1 x8 xi32 >
90+       %26  = arith.addi  %25 , %7  : tensor <1 x8 xi32 >
91+       %27  = arith.cmpi  slt , %26 , %cst_3  : tensor <1 x8 xi32 >
92+       %28  = arith.muli  %26 , %cst_3  : tensor <1 x8 xi32 >
93+       %29  = tt.broadcast  %28  : tensor <1 x8 xi32 > -> tensor <64 x8 xi32 >
94+       %30  = arith.addi  %10 , %29  : tensor <64 x8 xi32 >
95+       %31  = arith.addi  %30 , %12  : tensor <64 x8 xi32 >
96+       %32  = tt.addptr  %13 , %31  : tensor <64 x8 x!tt.ptr <f32 >>, tensor <64 x8 xi32 >
97+       %33  = tt.broadcast  %27  : tensor <1 x8 xi1 > -> tensor <64 x8 xi1 >
98+       %34  = tt.load  %32 , %33 , %cst_1  evictionPolicy  = evict_first  : tensor <64 x8 x!tt.ptr <f32 >>
99+       %35  = arith.cmpi  eq , %arg6 , %c0_i32  : i32 
100+       %36:3  = scf.if  %35  -> (tensor <64 x8 xf32 >, tensor <64 x8 xf32 >, tensor <64 x8 xf32 >) {
101+         scf.yield  %cst_1 , %34 , %cst_0  : tensor <64 x8 xf32 >, tensor <64 x8 xf32 >, tensor <64 x8 xf32 >
102+       } else  {
103+         %40  = arith.subf  %34 , %arg7  : tensor <64 x8 xf32 >
104+         %41  = arith.addf  %arg9 , %cst_0  : tensor <64 x8 xf32 >
105+         %42  = arith.divf  %40 , %41  : tensor <64 x8 xf32 >
106+         %43  = arith.addf  %arg7 , %42  : tensor <64 x8 xf32 >
107+         %44  = arith.subf  %34 , %43  : tensor <64 x8 xf32 >
108+         %45  = arith.mulf  %40 , %44  : tensor <64 x8 xf32 >
109+         %46  = arith.addf  %arg8 , %45  : tensor <64 x8 xf32 >
110+         scf.yield  %46 , %43 , %41  : tensor <64 x8 xf32 >, tensor <64 x8 xf32 >, tensor <64 x8 xf32 >
111+       }
112+       %37  = arith.select  %33 , %36#1 , %arg7  : tensor <64 x8 xi1 >, tensor <64 x8 xf32 >
113+       %38  = arith.select  %33 , %36#0 , %arg8  : tensor <64 x8 xi1 >, tensor <64 x8 xf32 >
114+       %39  = arith.select  %33 , %36#2 , %arg9  : tensor <64 x8 xi1 >, tensor <64 x8 xf32 >
115+       scf.yield  %37 , %38 , %39  : tensor <64 x8 xf32 >, tensor <64 x8 xf32 >, tensor <64 x8 xf32 >
116+     }
117+     tt.return 
118+   }
119+   // CHECK: tt.func public @test2 
120+   // CHECK:   scf.for 
121+   // CHECK:     [[PTR:%.+]] = tt.addptr {{.*}} : tensor<64x8x!tt.ptr<f32>>, tensor<64x8xi32> 
122+   // CHECK:     [[LOAD:%.+]] = tt.load [[PTR]] evictionPolicy = evict_first : tensor<64x8x!tt.ptr<f32>> 
123+   // CHECK:     [[IF_RES:%.+]]:3 = scf.if {{.*}} -> (tensor<64x8xf32>, tensor<64x8xf32>, tensor<64x8xf32>) 
124+   // CHECK:       scf.yield {{.*}}, [[LOAD]], {{.*}} : tensor<64x8xf32>, tensor<64x8xf32>, tensor<64x8xf32> 
125+   // CHECK:     else 
126+   // CHECK-2:     arith.subf [[LOAD]], {{.*}} : tensor<64x8xf32 
127+   // CHECK:     } 
128+   // CHECK:     scf.yield [[IF_RES]]#1, [[IF_RES]]#0, [[IF_RES]]#2 : tensor<64x8xf32>, tensor<64x8xf32>, tensor<64x8xf32> 
63129}
0 commit comments