// RUN: triton-opt %s -split-input-file -tritongpu-optimize-accumulator-init | FileCheck %s

#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
@@ -292,42 +292,27 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
292292 tt.return %17 : tensor <128 x16 xf32 , #mma1 >
293293 }
294294
// Check that we bail out in unsupported cases

// CHECK-LABEL: @non_zero_init
// CHECK-NOT: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !ttg.memdesc
// The accumulator is initialized to a non-zero constant, so the pass must not
// rewrite the dot to use a use-accumulator flag.
tt.func @non_zero_init(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> {
  %c0_i32 = arith.constant 0 : i32
  %cst_2 = arith.constant dense<1.000000e+00> : tensor<128x16xf32, #mma1>
  %c1_i32 = arith.constant 1 : i32
  %c8_i32 = arith.constant 8 : i32
  %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 {
    %cnd = arith.cmpi slt, %arg3, %ext : i32
    %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1>
    %acc_ = arith.select %cnd, %cst_2, %acc : tensor<128x16xf32, #mma1>
    scf.yield %acc_ : tensor<128x16xf32, #mma1>
  }
  tt.return %17 : tensor<128x16xf32, #mma1>
}
312-
// CHECK-LABEL: @zero_init_dist_2
// The zero-initialized accumulator is carried at distance 2 through the loop;
// the pass should still thread a use-accumulator flag into the dot.
tt.func @zero_init_dist_2(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> {
  %c0_i32 = arith.constant 0 : i32
  // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00>
  %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1>
  %c1_i32 = arith.constant 1 : i32
  %c8_i32 = arith.constant 8 : i32
  // CHECK: scf.for {{.*}} = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg{{[1-9]+}} = %{{.*}}, %[[ACC:.*]] = %[[CST]], %[[INIT_FLAG:.*]] = %false)
  %17:2 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %cst_2) -> (tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1>) : i32 {
    %cnd = arith.cmpi slt, %arg3, %ext : i32
    // CHECK: ttng.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[INIT_FLAG]]
    %acc = ttng.warp_group_dot %A, %B, %arg5 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1>
    %acc_ = arith.select %cnd, %cst_2, %acc : tensor<128x16xf32, #mma1>
    // CHECK: scf.yield {{.*}}, {{.*}}, %true
    scf.yield %acc_, %arg4 : tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1>
  }
  // %17 has two results; returning one requires the #0 projection syntax.
  tt.return %17#0 : tensor<128x16xf32, #mma1>
}
328313
329314// CHECK-LABEL: @if_defines_alternative
330- // CHECK-NOT : %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !ttg.memdesc
315+ // CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, %arg {{.*}} : !ttg.memdesc
331316 tt.func @if_defines_alternative (%A: !ttg.memdesc <128 x64 xf16 , #shared , #smem >, %B: !ttg.memdesc <64 x16 xf16 , #shared1 , #smem >, %arg0: !tt.ptr <f16 > {tt.divisibility = 16 : i32 }, %arg1: !tt.ptr <f16 > {tt.divisibility = 16 : i32 }, %ext: i32 , %inc: tensor <64 x16 xi32 , #blocked > {tt.divisibility = 16 : i32 }) -> tensor <128 x16 xf32 , #mma1 > {
332317 %c0_i32 = arith.constant 0 : i32
333318 %cst_2 = arith.constant dense <0.000000e+00 > : tensor <128 x16 xf32 , #mma1 >
@@ -343,13 +328,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
343328 %acc_alt = arith.addf %acc , %cst_3 : tensor <128 x16 xf32 , #mma1 >
344329 scf.yield %acc_alt : tensor <128 x16 xf32 , #mma1 >
345330 }
331+ // CHECK: scf.yield {{.*}}, %true
346332 scf.yield %acc_: tensor <128 x16 xf32 , #mma1 >
347333 }
348334 tt.return %17 : tensor <128 x16 xf32 , #mma1 >
349335 }
350336
351337// CHECK-LABEL: @non_cond_override
352- // CHECK-NOT : %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !ttg.memdesc
338+ // CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, %arg {{.*}} : !ttg.memdesc
353339 tt.func @non_cond_override (%A: !ttg.memdesc <128 x64 xf16 , #shared , #smem >, %B: !ttg.memdesc <64 x16 xf16 , #shared1 , #smem >, %arg0: !tt.ptr <f16 > {tt.divisibility = 16 : i32 }, %arg1: !tt.ptr <f16 > {tt.divisibility = 16 : i32 }, %inc: tensor <64 x16 xi32 , #blocked > {tt.divisibility = 16 : i32 }) -> tensor <128 x16 xf32 , #mma1 > {
354340 %c0_i32 = arith.constant 0 : i32
355341 %cst_2 = arith.constant dense <0.000000e+00 > : tensor <128 x16 xf32 , #mma1 >
@@ -359,6 +345,26 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
359345 %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args (%arg4 = %cst_2 ) -> (tensor <128 x16 xf32 , #mma1 >) : i32 {
360346 %acc = ttng.warp_group_dot %A , %B , %arg4 : !ttg.memdesc <128 x64 xf16 , #shared , #smem > * !ttg.memdesc <64 x16 xf16 , #shared1 , #smem > -> tensor <128 x16 xf32 , #mma1 >
361347 %acc_ = arith.addf %acc , %cst_3 : tensor <128 x16 xf32 , #mma1 >
348+ // CHECK: scf.yield {{.*}}, %true
349+ scf.yield %acc_: tensor <128 x16 xf32 , #mma1 >
350+ }
351+ tt.return %17 : tensor <128 x16 xf32 , #mma1 >
352+ }
353+
354+
// Check that we bail out in unsupported cases

// CHECK-LABEL: @non_zero_init
// CHECK-NOT: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !ttg.memdesc
// Non-zero accumulator init: the pass must leave the dot untouched.
tt.func @non_zero_init(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> {
  %c0_i32 = arith.constant 0 : i32
  %cst_2 = arith.constant dense<1.000000e+00> : tensor<128x16xf32, #mma1>
  %c1_i32 = arith.constant 1 : i32
  %c8_i32 = arith.constant 8 : i32
  %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 {
    %cnd = arith.cmpi slt, %arg3, %ext : i32
    %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1>
    %acc_ = arith.select %cnd, %cst_2, %acc : tensor<128x16xf32, #mma1>
    scf.yield %acc_ : tensor<128x16xf32, #mma1>
  }
  tt.return %17 : tensor<128x16xf32, #mma1>