@@ -262,25 +262,26 @@ def kernel_basic(src, N, BLOCK_SIZE: tl.constexpr):
262262 # CHECK: #loc = loc("{{.*}}":261:0)
263263 # CHECK-LABEL: tt.func public @kernel_basic(
264264 # CHECK-SAME: %src: !tt.ptr<f32> loc("src"(#loc)), %N: i32 loc("N"(#loc)))
265- # CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<16xf32> loc(#loc1 )
266- # CHECK: %c16_i32 = arith.constant 16 : i32 loc(#loc1 )
267- # CHECK: %pid = tt.get_program_id x : i32 loc(#loc14 )
268- # CHECK: %offset = arith.muli %pid, %c16_i32 : i32 loc(#loc15 )
269- # CHECK: %offsets = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc16 )
270- # CHECK: %offsets_0 = tt.splat %offset : i32 -> tensor<16xi32> loc(#loc17 )
271- # CHECK: %offsets_1 = arith.addi %offsets_0, %offsets : tensor<16xi32> loc(#loc17 )
272- # CHECK: %load_src_store_dst = tt.splat %src : !tt.ptr<f32> -> tensor<16x!tt.ptr<f32>> loc(#loc18 )
273- # CHECK: %load_src_store_dst_2 = tt.addptr %load_src_store_dst, %offsets_1 : tensor<16x!tt.ptr<f32>>, tensor<16xi32> loc(#loc18 )
274- # CHECK: %mask = tt.splat %N : i32 -> tensor<16xi32> loc(#loc19 )
275- # CHECK: %mask_3 = arith.cmpi slt, %offsets_1, %mask : tensor<16xi32> loc(#loc19 )
276- # CHECK: %x_plus_1 = tt.load %load_src_store_dst_2, %mask_3 : tensor<16x!tt.ptr<f32>> loc(#loc20 )
277- # CHECK: %x_plus_1_4 = arith.addf %x_plus_1 , %cst : tensor<16xf32> loc(#loc21 )
278- # CHECK: tt.store %load_src_store_dst_2, %x_plus_1_4 , %mask_3 : tensor<16x!tt.ptr<f32>> loc(#loc10)
265+ # CHECK: %x_plus_1 = arith.constant dense<1.000000e+00> : tensor<16xf32> loc(#loc14 )
266+ # CHECK: %c16_i32 = arith.constant 16 : i32 loc(#loc2 )
267+ # CHECK: %pid = tt.get_program_id x : i32 loc(#loc15 )
268+ # CHECK: %offset = arith.muli %pid, %c16_i32 : i32 loc(#loc16 )
269+ # CHECK: %offsets = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc17 )
270+ # CHECK: %offsets_0 = tt.splat %offset : i32 -> tensor<16xi32> loc(#loc18 )
271+ # CHECK: %offsets_1 = arith.addi %offsets_0, %offsets : tensor<16xi32> loc(#loc18 )
272+ # CHECK: %load_src_store_dst = tt.splat %src : !tt.ptr<f32> -> tensor<16x!tt.ptr<f32>> loc(#loc19 )
273+ # CHECK: %load_src_store_dst_2 = tt.addptr %load_src_store_dst, %offsets_1 : tensor<16x!tt.ptr<f32>>, tensor<16xi32> loc(#loc19 )
274+ # CHECK: %mask = tt.splat %N : i32 -> tensor<16xi32> loc(#loc20 )
275+ # CHECK: %mask_3 = arith.cmpi slt, %offsets_1, %mask : tensor<16xi32> loc(#loc20 )
276+ # CHECK: %x_plus_1_4 = tt.load %load_src_store_dst_2, %mask_3 : tensor<16x!tt.ptr<f32>> loc(#loc21 )
277+ # CHECK: %x_plus_1_5 = arith.addf %x_plus_1_4 , %x_plus_1 : tensor<16xf32> loc(#loc14 )
278+ # CHECK: tt.store %load_src_store_dst_2, %x_plus_1_5 , %mask_3 : tensor<16x!tt.ptr<f32>> loc(#loc10)
279279 # CHECK: tt.return loc(#loc11)
280- # CHECK: } loc(#loc)
280+ # CHECK: } loc(#loc)
281+ # CHECK: } loc(#loc)
281282
282- # CHECK: #loc1 = loc(unknown )
283- # CHECK: #loc2 = loc({{.*}} )
283+ # CHECK: #loc1 = loc({{.*}} )
284+ # CHECK: #loc2 = loc(unknown )
284285 # CHECK: #loc3 = loc({{.*}})
285286 # CHECK: #loc4 = loc({{.*}})
286287 # CHECK: #loc5 = loc({{.*}})
@@ -290,13 +291,13 @@ def kernel_basic(src, N, BLOCK_SIZE: tl.constexpr):
290291 # CHECK: #loc9 = loc({{.*}})
291292 # CHECK: #loc10 = loc({{.*}})
292293 # CHECK: #loc11 = loc({{.*}})
293- # CHECK: #loc14 = loc("pid "(#loc2 ))
294- # CHECK: #loc15 = loc("offset "(#loc3))
295- # CHECK: #loc16 = loc("offsets "(#loc4))
294+ # CHECK: #loc14 = loc("x_plus_1 "(#loc1 ))
295+ # CHECK: #loc15 = loc("pid "(#loc3))
296+ # CHECK: #loc16 = loc("offset "(#loc4))
296297 # CHECK: #loc17 = loc("offsets"(#loc5))
297- # CHECK: #loc18 = loc("load_src_store_dst "(#loc6))
298- # CHECK: #loc19 = loc("mask "(#loc7))
299- # CHECK: #loc20 = loc("x_plus_1 "(#loc8))
298+ # CHECK: #loc18 = loc("offsets "(#loc6))
299+ # CHECK: #loc19 = loc("load_src_store_dst "(#loc7))
300+ # CHECK: #loc20 = loc("mask "(#loc8))
300301 # CHECK: #loc21 = loc("x_plus_1"(#loc9))
301302
302303 pid = tl .program_id (0 )
@@ -404,20 +405,20 @@ def kernel_basic_while(N):
404405 # CHECK: %arange = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32>
405406 arange = tl .arange (0 , 16 )
406407 ivar = 0
407- # CHECK: %ivar: 2 = scf.while (%arange_0 = %arange, %ivar_1 = %c0_i32 ) : (tensor<16xi32>, i32) -> (tensor<16xi32>, i32)
408- # CHECK: %[[COND:.*]] = arith.cmpi slt, %ivar_1 , %N : i32
409- # CHECK: scf.condition(%[[COND]]) %arange_0 , %ivar_1 : tensor<16xi32>, i32
408+ # CHECK: %ivar_[[IV0:.+]]: 2 = scf.while (%arange_[[AR0:.+]] = %arange, %ivar_[[IV1:.+]] = %ivar ) : (tensor<16xi32>, i32) -> (tensor<16xi32>, i32)
409+ # CHECK: %[[COND:.*]] = arith.cmpi slt, %ivar_[[IV1]] , %N : i32
410+ # CHECK: scf.condition(%[[COND]]) %arange_[[AR0]] , %ivar_[[IV1]] : tensor<16xi32>, i32
410411 while ivar < N :
411- # CHECK: ^bb0(%arange_0 : tensor<16xi32> loc("arange"), %ivar_1 : i32
412+ # CHECK: ^bb0(%arange_[[AR0]] : tensor<16xi32> loc("arange"), %ivar_[[IV1]] : i32
412413
413- # CHECK: %ivar_2 = arith.addi %ivar_1 , %c1_i32 : i32
414+ # CHECK: %ivar_[[IV2:.+]] = arith.addi %ivar_[[IV1]] , %c1_i32 : i32
414415 ivar += 1
415- # CHECK: %arange_3 = tt.splat %ivar_2 : i32 -> tensor<16xi32>
416- # CHECK: %arange_4 = arith.muli %arange_0 , %arange_3 : tensor<16xi32>
417- # CHECK: scf.yield %arange_4 , %ivar_2 : tensor<16xi32>, i32
416+ # CHECK: %arange_[[AR1:.+]] = tt.splat %ivar_[[IV2]] : i32 -> tensor<16xi32>
417+ # CHECK: %arange_[[AR2:.+]] = arith.muli %arange_[[AR0]] , %arange_[[AR1]] : tensor<16xi32>
418+ # CHECK: scf.yield %arange_[[AR2]] , %ivar_[[IV2]] : tensor<16xi32>, i32
418419 arange *= ivar
419420
420- # CHECK: tt.print ": " {hex = false, isSigned = array<i32: 1>} : %ivar #0 : tensor<16xi32>
421+ # CHECK: tt.print ": " {hex = false, isSigned = array<i32: 1>} : %ivar_[[IV0]] #0 : tensor<16xi32>
421422 tl .device_print ("" , arange )
422423
423424 h = triton .compile (triton .compiler .ASTSource (fn = kernel_basic_while , signature = {"N" : "i32" }, constexprs = {}))
0 commit comments