@@ -577,3 +577,119 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
577577 tt.return %11 : tensor <1024 xf32 , #blocked >
578578 }
579579}
580+
581+ // -----
582+
583+ #blocked = #triton_gpu.blocked <{sizePerThread = [4 ], threadsPerWarp = [64 ], warpsPerCTA = [4 ], order = [0 ]}>
584+ module attributes {" triton_gpu.num-ctas" = 1 : i32 , " triton_gpu.num-warps" = 4 : i32 , triton_gpu.target = " hip:gfx942" , " triton_gpu.threads-per-warp" = 64 : i32 } {
// Test: a scalar (non-tensor) !tt.ptr<i64> is advanced by one element and then
// carried through an scf.for loop as its only iter_arg. Each iteration stores
// the i64 constant 0 through the carried pointer and bumps the pointer by 1.
// The expectations below require the rewritten loop's iter_args to end with a
// pointer seeded from the pre-loop addptr, followed by a value seeded from an
// i64 zero constant (captured as ptr1/offset1).
// NOTE(review): the pass producing this output is named on the RUN line, which
// is outside this chunk -- presumably the AMD pointer canonicalization pass.
585+ // CHECK-LABEL: scalar_pointers
586+ tt.func public @scalar_pointers (%arg0: !tt.ptr <i64 > {tt.divisibility = 16 : i32 }, %arg1: !tt.ptr <i64 > {tt.divisibility = 16 : i32 }, %arg2: !tt.ptr <i32 > {tt.divisibility = 16 : i32 }) attributes {noinline = false } {
// NOTE(review): %0 and %c10_i64 below, and the arguments %arg1/%arg2, have no
// uses inside this function.
587+ %0 = tt.get_program_id x : i32
588+ %c1_i32 = arith.constant 1 : i32
589+ %c0_i64 = arith.constant 0 : i64
590+ %c10_i64 = arith.constant 10 : i64
591+ %c100_i32 = arith.constant 100 : i32
// Pre-loop pointer: %arg0 advanced by one element; this is the loop's initial value.
592+ %5 = tt.addptr %arg0 , %c1_i32 : !tt.ptr <i64 >, i32
// Three i64 zero constants are expected in the output before the addptr; only
// the third is captured (offset0) and reused as the initial offset iter_arg.
593+ // CHECK: arith.constant 0 : i64
594+ // CHECK: arith.constant 0 : i64
595+ // CHECK: %[[offset0:.*]] = arith.constant 0 : i64
596+ // CHECK: %[[ptr0:.*]] = tt.addptr %arg0, %c1_i32 : !tt.ptr<i64>, i32
597+ // CHECK: scf.for {{.*}} iter_args({{.*}}, %[[ptr1:.*]] = %[[ptr0]], %[[offset1:.*]] = %[[offset0]])
// Loop runs from 1 to 100 with step 1, carrying the scalar pointer in %arg4.
598+ %10:1 = scf.for %arg3 = %c1_i32 to %c100_i32 step %c1_i32 iter_args (%arg4 = %5 ) -> (!tt.ptr <i64 >) : i32 {
// The store and the pointer bump are both expected to use the carried pointer ptr1.
599+ // CHECK: tt.store %[[ptr1]]
600+ tt.store %arg4 , %c0_i64 : !tt.ptr <i64 >
601+ // CHECK: tt.addptr %[[ptr1]]
602+ %11 = tt.addptr %arg4 , %c1_i32 : !tt.ptr <i64 >, i32
603+ scf.yield %11 : !tt.ptr <i64 >
604+ }
605+ tt.return
606+ }
607+ }
608+
609+ // -----
610+
611+ #blocked = #triton_gpu.blocked <{sizePerThread = [4 ], threadsPerWarp = [64 ], warpsPerCTA = [4 ], order = [0 ]}>
612+ module attributes {" triton_gpu.num-ctas" = 1 : i32 , " triton_gpu.num-warps" = 4 : i32 , triton_gpu.target = " hip:gfx942" , " triton_gpu.threads-per-warp" = 64 : i32 } {
// Test: a scalar !tt.ptr<f32> is produced by both branches of an scf.if and
// then loaded from. The expectations require the rewritten scf.if to return
// additional results whose types end in (!tt.ptr<f32>, i64), and each branch
// to yield its own addptr result (derived from the shared pre-branch pointer).
613+ // CHECK-LABEL: @scalar_if
614+ tt.func @scalar_if (%arg0: !tt.ptr <f32 >, %init : tensor <1024 xf32 , #blocked >, %cond : i1 )->f32 {
// NOTE(review): %0, %c0_i64, %c10_i64 and the argument %init have no uses
// inside this function.
615+ %0 = tt.get_program_id x : i32
616+ %c1_i32 = arith.constant 1 : i32
617+ %c0_i64 = arith.constant 0 : i64
618+ %c10_i64 = arith.constant 10 : i64
619+ %c100_i32 = arith.constant 100 : i32
// Shared base pointer used by both branches: %arg0 advanced by one element.
620+ %5 = tt.addptr %arg0 , %c1_i32 : !tt.ptr <f32 >, i32
621+ // CHECK: %[[ptr0:.*]] = tt.addptr %arg0, %{{.*}}
622+ // CHECK: scf.if {{.*}} -> ({{.*}}, !tt.ptr<f32>, i64)
623+ %6 = scf.if %cond -> (!tt.ptr <f32 >){
// Then-branch: base pointer advanced by 1.
624+ %true = tt.addptr %5 , %c1_i32 : !tt.ptr <f32 >, i32
625+ // CHECK: %[[ptr1:.*]] = tt.addptr %[[ptr0]]
626+ // CHECK: scf.yield {{.*}}, %[[ptr1]]
627+ scf.yield %true : !tt.ptr <f32 >
628+ } else {
// Else-branch: base pointer advanced by 100.
629+ %false = tt.addptr %5 , %c100_i32 : !tt.ptr <f32 >, i32
630+ // CHECK: %[[ptr2:.*]] = tt.addptr %[[ptr0]]
631+ // CHECK: scf.yield {{.*}}, %[[ptr2]]
632+ scf.yield %false : !tt.ptr <f32 >
633+ }
// Scalar load through whichever pointer the scf.if selected.
634+ %11 = tt.load %6 : !tt.ptr <f32 >
635+ tt.return %11 : f32
636+ }
637+ }
638+
639+ // -----
640+
641+ #blocked = #triton_gpu.blocked <{sizePerThread = [4 ], threadsPerWarp = [64 ], warpsPerCTA = [4 ], order = [0 ]}>
642+ module attributes {" triton_gpu.num-warps" = 4 : i32 , " triton_gpu.threads-per-warp" = 64 : i32 } {
// Test: a scalar !tt.ptr<f32> is threaded through an scf.while: it enters as an
// init operand, is forwarded by scf.condition to the "do" region, and is
// yielded back unchanged. The result of the loop is then loaded from.
// The expectations require the rewritten scf.while, scf.condition, "do" block
// arguments and scf.yield to each carry the pointer alongside extra operands.
643+ // CHECK-LABEL: tt.func @scalar_while
644+ tt.func @scalar_while (%arg0: !tt.ptr <f32 >, %init : f32 , %cond : i1 )->f32 {
// NOTE(review): %1 (the only user of %c1024_i32), %c0, %c128, %c1 and the
// argument %init have no uses inside this function.
645+ %c1024_i32 = arith.constant 1024 : i32
646+ %c0 = arith.constant 0 : index
647+ %c128 = arith.constant 128 : index
648+ %c1 = arith.constant 1 : index
649+ %0 = tt.get_program_id x : i32
650+ %1 = arith.muli %0 , %c1024_i32 : i32
// The two expectations below describe the addptr and scf.while that follow them.
// NOTE(review): `%arg2` in the second expectation presumably refers to the third
// function argument (%cond here) as renamed by the printer -- confirm.
651+ // CHECK: %[[ptr0:.*]] = tt.addptr %arg0, %{{.*}}
652+ // CHECK: scf.while ({{.*}}, {{.*}} = %arg2, %[[ptr1:.*]] = %[[ptr0]], {{.*}})
// Initial pointer: %arg0 offset by the program id.
653+ %2 = tt.addptr %arg0 , %0 : !tt.ptr <f32 >, i32
654+ %6 = scf.while (%arg1 = %2 , %arg2 = %cond ) : (!tt.ptr <f32 >, i1 ) -> (!tt.ptr <f32 >) {
655+ // CHECK: scf.condition({{.*}}) {{.*}}, %[[ptr1]]
// "before" region: continue while %arg2 holds, forwarding only the pointer.
656+ scf.condition (%arg2 ) %arg1 : !tt.ptr <f32 >
657+ } do {
658+ // CHECK: ^bb0({{.*}}: !tt.ptr<f32>, %[[ptr2:.*]]: !tt.ptr<f32>, {{.*}})
659+ // CHECK: scf.yield %{{.*}}, {{.*}} %[[ptr2]], {{.*}}, {{.*}}
// "do" region: yields the pointer unchanged together with the original %cond.
660+ ^bb0 (%arg1: !tt.ptr <f32 >):
661+ scf.yield %arg1 , %cond : !tt.ptr <f32 >, i1
662+ }
// Scalar load through the pointer returned by the loop.
663+ %11 = tt.load %6 : !tt.ptr <f32 >
664+ tt.return %11 : f32
665+ }
666+ }
667+
668+ // -----
669+
670+ #blocked = #triton_gpu.blocked <{sizePerThread = [4 ], threadsPerWarp = [64 ], warpsPerCTA = [4 ], order = [0 ]}>
671+ module attributes {" triton_gpu.num-warps" = 4 : i32 , " triton_gpu.threads-per-warp" = 64 : i32 } {
// Test: a scalar !tt.ptr<f32> is passed as a block argument across an
// unstructured cf.cond_br. ^bb1 receives the offset pointer, ^bb2 receives the
// raw %arg0; each block loads from its argument and returns the value.
// The expectations require both successor blocks to take the pointer as a
// middle block argument, surrounded by extra arguments.
672+ // CHECK-LABEL: tt.func @scalar_cond_branch
673+ tt.func @scalar_cond_branch (%arg0 : !tt.ptr <f32 >, %i1 : i1 ) -> f32 {
// NOTE(review): %1 (the only user of %c1024_i32), %c0, %c128 and %c1 have no
// uses inside this function.
674+ %c1024_i32 = arith.constant 1024 : i32
675+ %c0 = arith.constant 0 : index
676+ %c128 = arith.constant 128 : index
677+ %c1 = arith.constant 1 : index
678+ %0 = tt.get_program_id x : i32
679+ %1 = arith.muli %0 , %c1024_i32 : i32
// Pointer for the true successor: %arg0 offset by the program id.
680+ %6 = tt.addptr %arg0 , %0 : !tt.ptr <f32 >, i32
// NOTE(review): `%arg1` in the branch expectation presumably refers to the
// second function argument (%i1 here) as renamed by the printer -- confirm.
681+ // CHECK: %[[ptr0:.*]] = tt.addptr %arg0
682+ // CHECK: cf.cond_br %arg1, ^bb1(%{{.*}}, %[[ptr0]], {{.*}}), ^bb2(%{{.*}}, %arg0, {{.*}})
683+ cf.cond_br %i1 , ^bb1 (%6 : !tt.ptr <f32 >), ^bb2 (%arg0 : !tt.ptr <f32 >)
684+ // CHECK: ^bb1({{.*}}, %[[ptr1:.*]]: !tt.ptr<f32>, {{.*}}):
685+ ^bb1 (%arg1 : !tt.ptr <f32 >):
686+ // CHECK: tt.load %[[ptr1]]
687+ %out1 = tt.load %arg1 : !tt.ptr <f32 >
688+ tt.return %out1 : f32
689+ // CHECK: ^bb2({{.*}}, %[[ptr2:.*]]: !tt.ptr<f32>, {{.*}}):
690+ ^bb2 (%arg2 : !tt.ptr <f32 >): // pred: ^bb0 (^bb1 ends in tt.return and never branches here)
691+ // CHECK: tt.load %[[ptr2]]
692+ %out2 = tt.load %arg2 : !tt.ptr <f32 >
693+ tt.return %out2 : f32
694+ }
695+ }
0 commit comments