1010#dot0 = #triton_gpu.dot_op <{opIdx = 0 , parent = #dpas , kWidth =2 }>
1111#dot1 = #triton_gpu.dot_op <{opIdx = 1 , parent = #dpas , kWidth =2 }>
1212module attributes {" triton_gpu.num-warps" = 64 : i32 , " triton_gpu.threads-per-warp" = 16 : i32 , " triton_intel_gpu.support_sg_2d_block" } {
13- tt.func public @matmul_kernel_with_block_pointers (%arg0: !tt.ptr <f16 > {tt.divisibility = 16 : i32 }, %arg1: !tt.ptr <f16 > {tt.divisibility = 16 : i32 }, %arg2: !tt.ptr <f16 > {tt.divisibility = 16 : i32 }, %arg3: i32 {tt.divisibility = 16 : i32 }, %arg4: i32 {tt.divisibility = 16 : i32 }, %arg5: i32 {tt.divisibility = 16 : i32 }, %arg6: i32 {tt.divisibility = 16 : i32 }, %arg7: i32 {tt.divisibility = 16 : i32 }, %arg8: i32 {tt.divisibility = 16 : i32 }) {
14- // CHECK: @matmul_kernel_with_block_pointers
13+ tt.func public @matmul_kernel_with_block_pointers (%arg0: !tt.ptr <f16 > {tt.divisibility = 16 : i32 }, %arg1: !tt.ptr <f16 > {tt.divisibility = 16 : i32 }, %arg2: !tt.ptr <f16 > {tt.divisibility = 16 : i32 }, %arg3: !tt.ptr <f32 > {tt.divisibility = 16 : i32 }, %arg4: i32 {tt.divisibility = 16 : i32 }, %arg5: i32 {tt.divisibility = 16 : i32 }, %arg6: i32 {tt.divisibility = 16 : i32 }, %arg7: i32 {tt.divisibility = 16 : i32 }, %arg8: i32 {tt.divisibility = 16 : i32 }, %arg9: i32 {tt.divisibility = 16 : i32 }) {
1514 %c4_i32 = arith.constant 4 : i32
1615 %c256_i32 = arith.constant 256 : i32
1716 %c1_i64 = arith.constant 1 : i64
@@ -20,9 +19,9 @@ module attributes {"triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-wa
2019 %c255_i32 = arith.constant 255 : i32
2120 %cst = arith.constant dense <0.000000e+00 > : tensor <256 x256 xf32 , #dpas >
2221 %0 = tt.get_program_id x : i32
23- %1 = arith.addi %arg3 , %c255_i32 : i32
22+ %1 = arith.addi %arg4 , %c255_i32 : i32
2423 %2 = arith.divsi %1 , %c256_i32 : i32
25- %3 = arith.addi %arg4 , %c255_i32 : i32
24+ %3 = arith.addi %arg5 , %c255_i32 : i32
2625 %4 = arith.divsi %3 , %c256_i32 : i32
2726 %5 = arith.muli %4 , %c4_i32 : i32
2827 %6 = arith.divsi %0 , %5 : i32
@@ -34,35 +33,41 @@ module attributes {"triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-wa
3433 %12 = arith.remsi %0 , %5 : i32
3534 %13 = arith.divsi %12 , %9 : i32
3635 %14 = arith.muli %11 , %c256_i32 : i32
37- %15 = arith.extsi %arg3 : i32 to i64
38- %16 = arith.extsi %arg5 : i32 to i64
39- %17 = arith.extsi %arg6 : i32 to i64
36+ %15 = arith.extsi %arg4 : i32 to i64
37+ %16 = arith.extsi %arg6 : i32 to i64
38+ %17 = arith.extsi %arg7 : i32 to i64
4039 // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
4140 %18 = tt.make_tensor_ptr %arg0 , [%15 , %16 ], [%17 , %c1_i64 ], [%14 , %c0_i32 ] {order = array<i32 : 1 , 0 >} : <tensor <256 x32 xf16 , #dot0 >>
4241 %19 = arith.muli %13 , %c256_i32 : i32
43- %20 = arith.extsi %arg4 : i32 to i64
44- %21 = arith.extsi %arg7 : i32 to i64
42+ %20 = arith.extsi %arg5 : i32 to i64
43+ %21 = arith.extsi %arg8 : i32 to i64
4544 // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
4645 %22 = tt.make_tensor_ptr %arg1 , [%16 , %20 ], [%21 , %c1_i64 ], [%c0_i32 , %19 ] {order = array<i32 : 1 , 0 >} : <tensor <32 x256 xf16 , #dot1 >>
47- %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args (%arg10 = %cst , %arg11 = %18 , %arg12 = %22 ) -> (tensor <256 x256 xf32 , #dpas >, !tt.ptr <tensor <256 x32 xf16 , #dot0 >>, !tt.ptr <tensor <32 x256 xf16 , #dot1 >>) : i32 {
46+ %23:3 = scf.for %arg10 = %c0_i32 to %arg6 step %c32_i32 iter_args (%arg11 = %cst , %arg12 = %18 , %arg13 = %22 ) -> (tensor <256 x256 xf32 , #dpas >, !tt.ptr <tensor <256 x32 xf16 , #dot0 >>, !tt.ptr <tensor <32 x256 xf16 , #dot1 >>) : i32 {
4847 // CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
4948 // CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
50- %28 = tt.load %arg11 {boundaryCheck = array<i32 : 0 , 1 >} : !tt.ptr <tensor <256 x32 xf16 , #dot0 >>
51- %29 = tt.load %arg12 {boundaryCheck = array<i32 : 0 , 1 >} : !tt.ptr <tensor <32 x256 xf16 , #dot1 >>
49+ %28 = tt.load %arg12 {boundaryCheck = array<i32 : 0 , 1 >} : !tt.ptr <tensor <256 x32 xf16 , #dot0 >>
50+ %29 = tt.load %arg13 {boundaryCheck = array<i32 : 0 , 1 >} : !tt.ptr <tensor <32 x256 xf16 , #dot1 >>
5251 // CHECK: tt.dot {{.*}}, {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>> -> tensor<256x256xf32, #[[DPAS]]>
5352 // CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
5453 // CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
55- %30 = tt.dot %28 , %29 , %arg10 , inputPrecision = tf32 : tensor <256 x32 xf16 , #dot0 > * tensor <32 x256 xf16 , #dot1 > -> tensor <256 x256 xf32 , #dpas >
56- %31 = tt.advance %arg11 , [%c0_i32 , %c32_i32 ] : <tensor <256 x32 xf16 , #dot0 >>
57- %32 = tt.advance %arg12 , [%c32_i32 , %c0_i32 ] : <tensor <32 x256 xf16 , #dot1 >>
54+ %30 = tt.dot %28 , %29 , %arg11 , inputPrecision = tf32 : tensor <256 x32 xf16 , #dot0 > * tensor <32 x256 xf16 , #dot1 > -> tensor <256 x256 xf32 , #dpas >
55+ %31 = tt.advance %arg12 , [%c0_i32 , %c32_i32 ] : <tensor <256 x32 xf16 , #dot0 >>
56+ %32 = tt.advance %arg13 , [%c32_i32 , %c0_i32 ] : <tensor <32 x256 xf16 , #dot1 >>
5857 scf.yield %30 , %31 , %32 : tensor <256 x256 xf32 , #dpas >, !tt.ptr <tensor <256 x32 xf16 , #dot0 >>, !tt.ptr <tensor <32 x256 xf16 , #dot1 >>
5958 }
60- %24 = arith.truncf %23#0 : tensor <256 x256 xf32 , #dpas > to tensor <256 x256 xf16 , #dpas >
61- %26 = arith.extsi %arg8 : i32 to i64
59+ %25 = arith.extsi %arg9 : i32 to i64
60+ // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<256x256xf32, #[[DPAS]]>>
61+ %26 = tt.make_tensor_ptr %arg3 , [%15 , %20 ], [%25 , %c1_i64 ], [%14 , %19 ] {order = array<i32 : 1 , 0 >} : <tensor <256 x256 xf32 , #dpas >>
62+ // CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf32, #[[DPAS]]>>
63+ %27 = tt.load %26 {boundaryCheck = array<i32 : 0 , 1 >} : !tt.ptr <tensor <256 x256 xf32 , #dpas >>
64+ %28 = arith.addf %23#0 , %27 : tensor <256 x256 xf32 , #dpas >
65+ %29 = arith.truncf %28 : tensor <256 x256 xf32 , #dpas > to tensor <256 x256 xf16 , #dpas >
66+
6267 // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<256x256xf16, #[[DPAS]]>>
63- %27 = tt.make_tensor_ptr %arg2 , [%15 , %20 ], [%26 , %c1_i64 ], [%14 , %19 ] {order = array<i32 : 1 , 0 >} : <tensor <256 x256 xf16 , #dpas >>
68+ %30 = tt.make_tensor_ptr %arg2 , [%15 , %20 ], [%25 , %c1_i64 ], [%14 , %19 ] {order = array<i32 : 1 , 0 >} : <tensor <256 x256 xf16 , #dpas >>
6469 // CHECK: tt.store {{.*}}, {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf16, #[[DPAS]]>>
65- tt.store %27 , %24 {boundaryCheck = array<i32 : 0 , 1 >} : !tt.ptr <tensor <256 x256 xf16 , #dpas >>
70+ tt.store %30 , %29 {boundaryCheck = array<i32 : 0 , 1 >} : !tt.ptr <tensor <256 x256 xf16 , #dpas >>
6671 tt.return
6772 }
6873}