@@ -317,25 +317,26 @@ def kernel_basic(src, N, BLOCK_SIZE: tl.constexpr):
317317 # CHECK: #loc = loc("{{.*}}":316:0)
318318 # CHECK-LABEL: tt.func public @kernel_basic(
319319 # CHECK-SAME: %src: !tt.ptr<f32> loc("src"(#loc)), %N: i32 loc("N"(#loc)))
320- # CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<16xf32> loc(#loc1 )
321- # CHECK: %c16_i32 = arith.constant 16 : i32 loc(#loc1 )
322- # CHECK: %pid = tt.get_program_id x : i32 loc(#loc14 )
323- # CHECK: %offset = arith.muli %pid, %c16_i32 : i32 loc(#loc15 )
324- # CHECK: %offsets = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc16 )
325- # CHECK: %offsets_0 = tt.splat %offset : i32 -> tensor<16xi32> loc(#loc17 )
326- # CHECK: %offsets_1 = arith.addi %offsets_0, %offsets : tensor<16xi32> loc(#loc17 )
327- # CHECK: %load_src_store_dst = tt.splat %src : !tt.ptr<f32> -> tensor<16x!tt.ptr<f32>> loc(#loc18 )
328- # CHECK: %load_src_store_dst_2 = tt.addptr %load_src_store_dst, %offsets_1 : tensor<16x!tt.ptr<f32>>, tensor<16xi32> loc(#loc18 )
329- # CHECK: %mask = tt.splat %N : i32 -> tensor<16xi32> loc(#loc19 )
330- # CHECK: %mask_3 = arith.cmpi slt, %offsets_1, %mask : tensor<16xi32> loc(#loc19 )
331- # CHECK: %x_plus_1 = tt.load %load_src_store_dst_2, %mask_3 : tensor<16x!tt.ptr<f32>> loc(#loc20 )
332- # CHECK: %x_plus_1_4 = arith.addf %x_plus_1 , %cst : tensor<16xf32> loc(#loc21 )
333- # CHECK: tt.store %load_src_store_dst_2, %x_plus_1_4 , %mask_3 : tensor<16x!tt.ptr<f32>> loc(#loc10)
320+ # CHECK: %x_plus_1 = arith.constant dense<1.000000e+00> : tensor<16xf32> loc(#loc14 )
321+ # CHECK: %c16_i32 = arith.constant 16 : i32 loc(#loc2 )
322+ # CHECK: %pid = tt.get_program_id x : i32 loc(#loc15 )
323+ # CHECK: %offset = arith.muli %pid, %c16_i32 : i32 loc(#loc16 )
324+ # CHECK: %offsets = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc17 )
325+ # CHECK: %offsets_0 = tt.splat %offset : i32 -> tensor<16xi32> loc(#loc18 )
326+ # CHECK: %offsets_1 = arith.addi %offsets_0, %offsets : tensor<16xi32> loc(#loc18 )
327+ # CHECK: %load_src_store_dst = tt.splat %src : !tt.ptr<f32> -> tensor<16x!tt.ptr<f32>> loc(#loc19 )
328+ # CHECK: %load_src_store_dst_2 = tt.addptr %load_src_store_dst, %offsets_1 : tensor<16x!tt.ptr<f32>>, tensor<16xi32> loc(#loc19 )
329+ # CHECK: %mask = tt.splat %N : i32 -> tensor<16xi32> loc(#loc20 )
330+ # CHECK: %mask_3 = arith.cmpi slt, %offsets_1, %mask : tensor<16xi32> loc(#loc20 )
331+ # CHECK: %x_plus_1_4 = tt.load %load_src_store_dst_2, %mask_3 : tensor<16x!tt.ptr<f32>> loc(#loc21 )
332+ # CHECK: %x_plus_1_5 = arith.addf %x_plus_1_4 , %x_plus_1 : tensor<16xf32> loc(#loc14 )
333+ # CHECK: tt.store %load_src_store_dst_2, %x_plus_1_5 , %mask_3 : tensor<16x!tt.ptr<f32>> loc(#loc10)
334334 # CHECK: tt.return loc(#loc11)
335- # CHECK: } loc(#loc)
335+ # CHECK: } loc(#loc)
336+ # CHECK: } loc(#loc)
336337
337- # CHECK: #loc1 = loc(unknown )
338- # CHECK: #loc2 = loc({{.*}} )
338+ # CHECK: #loc1 = loc({{.*}} )
339+ # CHECK: #loc2 = loc(unknown )
339340 # CHECK: #loc3 = loc({{.*}})
340341 # CHECK: #loc4 = loc({{.*}})
341342 # CHECK: #loc5 = loc({{.*}})
@@ -345,13 +346,13 @@ def kernel_basic(src, N, BLOCK_SIZE: tl.constexpr):
345346 # CHECK: #loc9 = loc({{.*}})
346347 # CHECK: #loc10 = loc({{.*}})
347348 # CHECK: #loc11 = loc({{.*}})
348- # CHECK: #loc14 = loc("pid "(#loc2 ))
349- # CHECK: #loc15 = loc("offset "(#loc3))
350- # CHECK: #loc16 = loc("offsets "(#loc4))
349+ # CHECK: #loc14 = loc("x_plus_1 "(#loc1 ))
350+ # CHECK: #loc15 = loc("pid "(#loc3))
351+ # CHECK: #loc16 = loc("offset "(#loc4))
351352 # CHECK: #loc17 = loc("offsets"(#loc5))
352- # CHECK: #loc18 = loc("load_src_store_dst "(#loc6))
353- # CHECK: #loc19 = loc("mask "(#loc7))
354- # CHECK: #loc20 = loc("x_plus_1 "(#loc8))
353+ # CHECK: #loc18 = loc("offsets "(#loc6))
354+ # CHECK: #loc19 = loc("load_src_store_dst "(#loc7))
355+ # CHECK: #loc20 = loc("mask "(#loc8))
355356 # CHECK: #loc21 = loc("x_plus_1"(#loc9))
356357
357358 pid = tl .program_id (0 )
@@ -459,20 +460,20 @@ def kernel_basic_while(N):
459460 # CHECK: %arange = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32>
460461 arange = tl .arange (0 , 16 )
461462 ivar = 0
462- # CHECK: %ivar: 2 = scf.while (%arange_0 = %arange, %ivar_1 = %c0_i32 ) : (tensor<16xi32>, i32) -> (tensor<16xi32>, i32)
463- # CHECK: %[[COND:.*]] = arith.cmpi slt, %ivar_1 , %N : i32
464- # CHECK: scf.condition(%[[COND]]) %arange_0 , %ivar_1 : tensor<16xi32>, i32
463+ # CHECK: %ivar_[[IV0:.+]]: 2 = scf.while (%arange_[[AR0:.+]] = %arange, %ivar_[[IV1:.+]] = %ivar ) : (tensor<16xi32>, i32) -> (tensor<16xi32>, i32)
464+ # CHECK: %[[COND:.*]] = arith.cmpi slt, %ivar_[[IV1]] , %N : i32
465+ # CHECK: scf.condition(%[[COND]]) %arange_[[AR0]] , %ivar_[[IV1]] : tensor<16xi32>, i32
465466 while ivar < N :
466- # CHECK: ^bb0(%arange_0 : tensor<16xi32> loc("arange"), %ivar_1 : i32
467+ # CHECK: ^bb0(%arange_[[AR0]] : tensor<16xi32> loc("arange"), %ivar_[[IV1]] : i32
467468
468- # CHECK: %ivar_2 = arith.addi %ivar_1 , %c1_i32 : i32
469+ # CHECK: %ivar_[[IV2:.+]] = arith.addi %ivar_[[IV1]] , %c1_i32 : i32
469470 ivar += 1
470- # CHECK: %arange_3 = tt.splat %ivar_2 : i32 -> tensor<16xi32>
471- # CHECK: %arange_4 = arith.muli %arange_0 , %arange_3 : tensor<16xi32>
472- # CHECK: scf.yield %arange_4 , %ivar_2 : tensor<16xi32>, i32
471+ # CHECK: %arange_[[AR1:.+]] = tt.splat %ivar_[[IV2]] : i32 -> tensor<16xi32>
472+ # CHECK: %arange_[[AR2:.+]] = arith.muli %arange_[[AR0]] , %arange_[[AR1]] : tensor<16xi32>
473+ # CHECK: scf.yield %arange_[[AR2]] , %ivar_[[IV2]] : tensor<16xi32>, i32
473474 arange *= ivar
474475
475- # CHECK: tt.print ": " {hex = false, isSigned = array<i32: 1>} : %ivar #0 : tensor<16xi32>
476+ # CHECK: tt.print ": " {hex = false, isSigned = array<i32: 1>} : %ivar_[[IV0]] #0 : tensor<16xi32>
476477 tl .device_print ("" , arange )
477478
478479 h = triton .compile (triton .compiler .ASTSource (fn = kernel_basic_while , signature = {"N" : "i32" }, constexprs = {}))
0 commit comments