-
Notifications
You must be signed in to change notification settings - Fork 74
Open
Description
I'm measuring GEMM performance on this benchmark https://github.com/intel/intel-xpu-backend-for-triton/blob/main/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py
However, I replaced `tl.constexpr` with `tl.int64` for the strides, since strides are often not `tl.constexpr` in external source code — see, for example, https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html. After changing `tl.constexpr` -> `tl.int64`, performance degrades by ~10x on GPU Max 1100. For simplicity, I keep only one shape.
(triton) (base) jovyan@jupyter-ekrivov:~/triton/intel-xpu-backend-for-triton$ TRITON_PRINT_AUTOTUNING=1 python benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py --brief
Stopped warmup after 542 iterations
Stopped warmup after 600 iterations
Stopped warmup after 540 iterations
matmul-tensor-desc-performance:
B M N K Triton-GB/s OneDNN-GB/s CUTLASS-GB/s Triton-TFlops OneDNN-TFlops CUTLASS-TFlops
0 1.0 8192.0 4096.0 4096.0 188.227071 207.767366 172.915159 220.279452 243.147181 202.360141
(triton) (base) jovyan@jupyter-ekrivov:~/triton/intel-xpu-backend-for-triton$ TRITON_PRINT_AUTOTUNING=1 python benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py --brief
Stopped warmup after 59 iterations
Stopped warmup after 638 iterations
Stopped warmup after 512 iterations
matmul-tensor-desc-performance:
B M N K Triton-GB/s OneDNN-GB/s CUTLASS-GB/s Triton-TFlops OneDNN-TFlops CUTLASS-TFlops
0 1.0 8192.0 4096.0 4096.0 17.448587 208.605146 166.736016 20.419832 244.127623 195.128778
TTIR diff (the slow `tl.int64` version is on the left):
(triton) (base) jovyan@jupyter-ekrivov:~/triton/intel-xpu-backend-for-triton$ diff -y -W 300 ~/.triton/cache/4DKJ3ESMKW2C4H54ULBAFW3HFVM7NE7FA7XQXO6NJLEEU6YACGWQ/matmul_kernel_with_tensor_descriptors_int64.ttir ~/.triton/cache/D5D6E7VHPOP2SEMQZQCHAX2IWSLJWG335O6ATL52V5RT42USZD7Q/matmul_kernel_with_tensor_descriptors.ttir
#loc = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":24:0) | #loc = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":24:0)
#loc25 = loc("a_ptr"(#loc)) #loc25 = loc("a_ptr"(#loc))
#loc26 = loc("b_ptr"(#loc)) #loc26 = loc("b_ptr"(#loc))
#loc27 = loc("c_ptr"(#loc)) #loc27 = loc("c_ptr"(#loc))
#loc28 = loc("stride_am"(#loc)) <
#loc29 = loc("stride_ak"(#loc)) <
#loc30 = loc("stride_bk"(#loc)) <
#loc31 = loc("stride_bn"(#loc)) <
#loc32 = loc("stride_cm"(#loc)) <
#loc33 = loc("stride_cn"(#loc)) <
module { module {
tt.func public @matmul_kernel_with_tensor_descriptors_int64(%a_ptr: !tt.ptr<bf16> {tt.divisibility = 16 : i32} loc("a_ptr"(#loc)), %b_ptr: !tt.ptr | tt.func public @matmul_kernel_with_tensor_descriptors(%a_ptr: !tt.ptr<bf16> {tt.divisibility = 16 : i32} loc("a_ptr"(#loc)), %b_ptr: !tt.ptr<bf16>
%a_desc = arith.constant 4096 : i64 loc(#loc34) | %a_desc = arith.constant 8192 : i64 loc(#loc28)
%a_desc_0 = arith.constant 8192 : i64 loc(#loc34) | %num_pid_in_group = arith.constant 64 : i32 loc(#loc29)
%num_pid_in_group = arith.constant 64 : i32 loc(#loc35) <
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32> loc(#loc3) %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32> loc(#loc3)
%c256_i32 = arith.constant 256 : i32 loc(#loc3) %c256_i32 = arith.constant 256 : i32 loc(#loc3)
%c32_i32 = arith.constant 32 : i32 loc(#loc3) %c32_i32 = arith.constant 32 : i32 loc(#loc3)
%c0_i32 = arith.constant 0 : i32 loc(#loc3) %c0_i32 = arith.constant 0 : i32 loc(#loc3)
> %c1_i64 = arith.constant 1 : i64 loc(#loc3)
> %c4096_i64 = arith.constant 4096 : i64 loc(#loc3)
%c4096_i32 = arith.constant 4096 : i32 loc(#loc3) %c4096_i32 = arith.constant 4096 : i32 loc(#loc3)
%c4_i32 = arith.constant 4 : i32 loc(#loc3) %c4_i32 = arith.constant 4 : i32 loc(#loc3)
%pid = tt.get_program_id x : i32 loc(#loc36) | %pid = tt.get_program_id x : i32 loc(#loc30)
%group_id = arith.divsi %pid, %num_pid_in_group : i32 loc(#loc37) | %group_id = arith.divsi %pid, %num_pid_in_group : i32 loc(#loc31)
%first_pid_m = arith.muli %group_id, %c4_i32 : i32 loc(#loc38) | %first_pid_m = arith.muli %group_id, %c4_i32 : i32 loc(#loc32)
%group_size_m = arith.subi %c32_i32, %first_pid_m : i32 loc(#loc39) | %group_size_m = arith.subi %c32_i32, %first_pid_m : i32 loc(#loc33)
%group_size_m_1 = arith.minsi %group_size_m, %c4_i32 : i32 loc(#loc40) | %group_size_m_0 = arith.minsi %group_size_m, %c4_i32 : i32 loc(#loc34)
%pid_m = arith.remsi %pid, %num_pid_in_group : i32 loc(#loc41) | %pid_m = arith.remsi %pid, %num_pid_in_group : i32 loc(#loc35)
%pid_m_2 = arith.remsi %pid_m, %group_size_m_1 : i32 loc(#loc42) | %pid_m_1 = arith.remsi %pid_m, %group_size_m_0 : i32 loc(#loc36)
%pid_m_3 = arith.addi %first_pid_m, %pid_m_2 : i32 loc(#loc43) | %pid_m_2 = arith.addi %first_pid_m, %pid_m_1 : i32 loc(#loc37)
%pid_n = arith.divsi %pid_m, %group_size_m_1 : i32 loc(#loc44) | %pid_n = arith.divsi %pid_m, %group_size_m_0 : i32 loc(#loc38)
%a_desc_4 = tt.make_tensor_ptr %a_ptr, [%a_desc_0, %a_desc], [%stride_am, %stride_ak], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<2 | %a_desc_3 = tt.make_tensor_ptr %a_ptr, [%a_desc, %c4096_i64], [%c4096_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256
%b_desc = tt.make_tensor_ptr %b_ptr, [%a_desc, %a_desc], [%stride_bk, %stride_bn], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x25 | %b_desc = tt.make_tensor_ptr %b_ptr, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32
%a = arith.muli %pid_m_3, %c256_i32 : i32 loc(#loc46) | %a = arith.muli %pid_m_2, %c256_i32 : i32 loc(#loc40)
%b = arith.muli %pid_n, %c256_i32 : i32 loc(#loc47) | %b = arith.muli %pid_n, %c256_i32 : i32 loc(#loc41)
%off_k:2 = scf.for %_ = %c0_i32 to %c4096_i32 step %c32_i32 iter_args(%accumulator = %cst, %off_k_5 = %c0_i32) -> (tensor<256x256xf32>, i32) : | %off_k:2 = scf.for %_ = %c0_i32 to %c4096_i32 step %c32_i32 iter_args(%accumulator = %cst, %off_k_4 = %c0_i32) -> (tensor<256x256xf32>, i32) :
%a_6 = tt.advance %a_desc_4, [%a, %off_k_5] : <tensor<256x32xbf16>> loc(#loc49) | %a_5 = tt.advance %a_desc_3, [%a, %off_k_4] : <tensor<256x32xbf16>> loc(#loc43)
%a_7 = tt.load %a_6 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xbf16>> loc(#loc49) | %a_6 = tt.load %a_5 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xbf16>> loc(#loc43)
%b_8 = tt.advance %b_desc, [%off_k_5, %b] : <tensor<32x256xbf16>> loc(#loc50) | %b_7 = tt.advance %b_desc, [%off_k_4, %b] : <tensor<32x256xbf16>> loc(#loc44)
%b_9 = tt.load %b_8 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xbf16>> loc(#loc50) | %b_8 = tt.load %b_7 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xbf16>> loc(#loc44)
%accumulator_10 = tt.dot %a_7, %b_9, %accumulator, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32> lo | %accumulator_9 = tt.dot %a_6, %b_8, %accumulator, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32> loc
%off_k_11 = arith.addi %off_k_5, %c32_i32 : i32 loc(#loc52) | %off_k_10 = arith.addi %off_k_4, %c32_i32 : i32 loc(#loc46)
scf.yield %accumulator_10, %off_k_11 : tensor<256x256xf32>, i32 loc(#loc21) | scf.yield %accumulator_9, %off_k_10 : tensor<256x256xf32>, i32 loc(#loc21)
} loc(#loc54) | } loc(#loc48)
%c_desc = tt.make_tensor_ptr %c_ptr, [%a_desc_0, %a_desc], [%stride_cm, %stride_cn], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256 | %c_desc = tt.make_tensor_ptr %c_ptr, [%a_desc, %c4096_i64], [%c4096_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x2
%0 = tt.advance %c_desc, [%a, %b] : <tensor<256x256xf32>> loc(#loc23) %0 = tt.advance %c_desc, [%a, %b] : <tensor<256x256xf32>> loc(#loc23)
tt.store %0, %off_k#0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf32>> loc(#loc23) tt.store %0, %off_k#0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf32>> loc(#loc23)
tt.return loc(#loc24) tt.return loc(#loc24)
} loc(#loc) } loc(#loc)
} loc(#loc) } loc(#loc)
#loc1 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":46:39) | #loc1 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":46:39)
#loc2 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":38:38) | #loc2 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":38:38)
#loc3 = loc(unknown) #loc3 = loc(unknown)
#loc4 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":35:24) | #loc4 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":35:24)
#loc5 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":39:22) | #loc5 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":39:22)
#loc6 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":40:29) | #loc6 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":40:29)
#loc7 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":41:35) | #loc7 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":41:35)
#loc8 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":41:48) | #loc8 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":41:48)
#loc9 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":42:34) | #loc9 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":42:34)
#loc10 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":42:54) | #loc10 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":42:54)
#loc11 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":42:27) | #loc11 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":42:27)
#loc12 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":43:40) | #loc12 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":43:40)
#loc13 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":48:39) | #loc13 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":48:39)
#loc14 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":53:33) | #loc14 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":53:33)
#loc15 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":54:40) | #loc15 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":54:40)
#loc16 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":52:25) | #loc16 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":52:25)
#loc17 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":53:24) | #loc17 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":53:24)
#loc18 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":54:24) | #loc18 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":54:24)
#loc19 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":55:33) | #loc19 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":55:33)
#loc20 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":56:17) | #loc20 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":56:17)
#loc21 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":56:8) | #loc21 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":56:8)
#loc22 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":60:39) | #loc22 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":60:39)
#loc23 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":61:63) | #loc23 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":61:63)
#loc24 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":61:4) | #loc24 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":61:4)
#loc34 = loc("a_desc"(#loc1)) | #loc28 = loc("a_desc"(#loc1))
#loc35 = loc("num_pid_in_group"(#loc2)) | #loc29 = loc("num_pid_in_group"(#loc2))
#loc36 = loc("pid"(#loc4)) | #loc30 = loc("pid"(#loc4))
#loc37 = loc("group_id"(#loc5)) | #loc31 = loc("group_id"(#loc5))
#loc38 = loc("first_pid_m"(#loc6)) | #loc32 = loc("first_pid_m"(#loc6))
#loc39 = loc("group_size_m"(#loc7)) | #loc33 = loc("group_size_m"(#loc7))
#loc40 = loc("group_size_m"(#loc8)) | #loc34 = loc("group_size_m"(#loc8))
#loc41 = loc("pid_m"(#loc9)) | #loc35 = loc("pid_m"(#loc9))
#loc42 = loc("pid_m"(#loc10)) | #loc36 = loc("pid_m"(#loc10))
#loc43 = loc("pid_m"(#loc11)) | #loc37 = loc("pid_m"(#loc11))
#loc44 = loc("pid_n"(#loc12)) | #loc38 = loc("pid_n"(#loc12))
#loc45 = loc("b_desc"(#loc13)) | #loc39 = loc("b_desc"(#loc13))
#loc46 = loc("a"(#loc14)) | #loc40 = loc("a"(#loc14))
#loc47 = loc("b"(#loc15)) | #loc41 = loc("b"(#loc15))
#loc48 = loc("accumulator"(#loc16)) | #loc42 = loc("accumulator"(#loc16))
#loc49 = loc("a"(#loc17)) | #loc43 = loc("a"(#loc17))
#loc50 = loc("b"(#loc18)) | #loc44 = loc("b"(#loc18))
#loc51 = loc("accumulator"(#loc19)) | #loc45 = loc("accumulator"(#loc19))
#loc52 = loc("off_k"(#loc20)) | #loc46 = loc("off_k"(#loc20))
#loc53 = loc("c_desc"(#loc22)) | #loc47 = loc("c_desc"(#loc22))
#loc54 = loc("off_k"(#loc48)) | #loc48 = loc("off_k"(#loc42))
TTGIR diff (again, the slow `tl.int64` version is on the left):
(triton) (base) jovyan@jupyter-ekrivov:~/triton/intel-xpu-backend-for-triton$ diff -y -W 300 ~/.triton/cache/4DKJ3ESMKW2C4H54ULBAFW3HFVM7NE7FA7XQXO6NJLEEU6YACGWQ/matmul_kernel_with_tensor_descriptors_int64.ttgir ~/.triton/cache/D5D6E7VHPOP2SEMQZQCHAX2IWSLJWG335O6ATL52V5RT42USZD7Q/matmul_kernel_with_tensor_descriptors.ttgir
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}> | #loc = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":24:0)
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 16], order = [1, 0]}> <
#loc = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":24:0) <
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [ #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}> | #loc23 = loc("a_ptr"(#loc))
#smem = #ttg.shared_memory | #loc24 = loc("b_ptr"(#loc))
#loc24 = loc("a_ptr"(#loc)) | #loc25 = loc("c_ptr"(#loc))
#loc25 = loc("b_ptr"(#loc)) <
#loc26 = loc("c_ptr"(#loc)) <
#loc27 = loc("stride_am"(#loc)) <
#loc28 = loc("stride_ak"(#loc)) <
#loc29 = loc("stride_bk"(#loc)) <
#loc30 = loc("stride_bn"(#loc)) <
#loc31 = loc("stride_cm"(#loc)) <
#loc32 = loc("stride_cn"(#loc)) <
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 1 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 1
tt.func public @matmul_kernel_with_tensor_descriptors_int64(%a_ptr: !tt.ptr<bf16> {tt.divisibility = 16 : i32} loc("a_ptr"(#loc)), %b_ptr: !tt.ptr | tt.func public @matmul_kernel_with_tensor_descriptors(%a_ptr: !tt.ptr<bf16> {tt.divisibility = 16 : i32} loc("a_ptr"(#loc)), %b_ptr: !tt.ptr<bf16>
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> loc(#loc1) %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> loc(#loc1)
%c64_i32 = arith.constant 64 : i32 loc(#loc1) %c64_i32 = arith.constant 64 : i32 loc(#loc1)
%c8192_i64 = arith.constant 8192 : i64 loc(#loc1) %c8192_i64 = arith.constant 8192 : i64 loc(#loc1)
%c4096_i64 = arith.constant 4096 : i64 loc(#loc1) <
%c256_i32 = arith.constant 256 : i32 loc(#loc1) %c256_i32 = arith.constant 256 : i32 loc(#loc1)
%c32_i32 = arith.constant 32 : i32 loc(#loc1) %c32_i32 = arith.constant 32 : i32 loc(#loc1)
%c0_i32 = arith.constant 0 : i32 loc(#loc1) %c0_i32 = arith.constant 0 : i32 loc(#loc1)
> %c1_i64 = arith.constant 1 : i64 loc(#loc1)
> %c4096_i64 = arith.constant 4096 : i64 loc(#loc1)
%c4096_i32 = arith.constant 4096 : i32 loc(#loc1) %c4096_i32 = arith.constant 4096 : i32 loc(#loc1)
%c4_i32 = arith.constant 4 : i32 loc(#loc1) %c4_i32 = arith.constant 4 : i32 loc(#loc1)
%pid = tt.get_program_id x : i32 loc(#loc33) | %c4032_i32 = arith.constant 4032 : i32 loc(#loc1)
%group_id = arith.divsi %pid, %c64_i32 : i32 loc(#loc34) | %pid = tt.get_program_id x : i32 loc(#loc26)
%first_pid_m = arith.muli %group_id, %c4_i32 : i32 loc(#loc35) | %group_id = arith.divsi %pid, %c64_i32 : i32 loc(#loc27)
%group_size_m = arith.subi %c32_i32, %first_pid_m : i32 loc(#loc36) | %first_pid_m = arith.muli %group_id, %c4_i32 : i32 loc(#loc28)
%group_size_m_0 = arith.minsi %group_size_m, %c4_i32 : i32 loc(#loc37) | %group_size_m = arith.subi %c32_i32, %first_pid_m : i32 loc(#loc29)
%pid_m = arith.remsi %pid, %c64_i32 : i32 loc(#loc38) | %group_size_m_0 = arith.minsi %group_size_m, %c4_i32 : i32 loc(#loc30)
%pid_m_1 = arith.remsi %pid_m, %group_size_m_0 : i32 loc(#loc39) | %pid_m = arith.remsi %pid, %c64_i32 : i32 loc(#loc31)
%pid_m_2 = arith.addi %first_pid_m, %pid_m_1 : i32 loc(#loc40) | %pid_m_1 = arith.remsi %pid_m, %group_size_m_0 : i32 loc(#loc32)
%pid_n = arith.divsi %pid_m, %group_size_m_0 : i32 loc(#loc41) | %pid_m_2 = arith.addi %first_pid_m, %pid_m_1 : i32 loc(#loc33)
%a_desc = tt.make_tensor_ptr %a_ptr, [%c8192_i64, %c4096_i64], [%stride_am, %stride_ak], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor | %pid_n = arith.divsi %pid_m, %group_size_m_0 : i32 loc(#loc34)
%b_desc = tt.make_tensor_ptr %b_ptr, [%c4096_i64, %c4096_i64], [%stride_bk, %stride_bn], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor | %a_desc = tt.make_tensor_ptr %a_ptr, [%c8192_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<25
%a = arith.muli %pid_m_2, %c256_i32 : i32 loc(#loc44) | %b_desc = tt.make_tensor_ptr %b_ptr, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32
%b = arith.muli %pid_n, %c256_i32 : i32 loc(#loc45) | %a = arith.muli %pid_m_2, %c256_i32 : i32 loc(#loc37)
%off_k:2 = scf.for %off_k_3 = %c0_i32 to %c4096_i32 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %c0_i32) -> (tensor<256x256xf32, #mma>, i32) | %b = arith.muli %pid_n, %c256_i32 : i32 loc(#loc38)
%a_4 = tt.advance %a_desc, [%a, %arg11] : <tensor<256x32xbf16, #blocked>> loc(#loc47) | %a_3 = tt.advance %a_desc, [%a, %c0_i32] : <tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>> loc(#loc39)
%a_5 = tt.load %a_4 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xbf16, #blocked>> loc(#loc47) | ttig.prefetch %a_3 {boundaryCheck = array<i32: 0, 1>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1,
%a_6 = ttg.local_alloc %a_5 : (tensor<256x32xbf16, #blocked>) -> !ttg.memdesc<256x32xbf16, #shared, #smem> loc(#loc47) | %b_4 = tt.advance %b_desc, [%c0_i32, %b] : <tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>> loc(#loc40)
%b_7 = tt.advance %b_desc, [%arg11, %b] : <tensor<32x256xbf16, #blocked1>> loc(#loc48) | ttig.prefetch %b_4 {boundaryCheck = array<i32: 0, 1>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1,
%b_8 = tt.load %b_7 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xbf16, #blocked1>> loc(#loc48) | %a_5 = tt.advance %a_desc, [%a, %c32_i32] : <tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>> loc(#loc39)
%b_9 = ttg.local_alloc %b_8 : (tensor<32x256xbf16, #blocked1>) -> !ttg.memdesc<32x256xbf16, #shared, #smem> loc(#loc48) | ttig.prefetch %a_5 {boundaryCheck = array<i32: 0, 1>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1,
%a_10 = ttg.local_load %a_6 : !ttg.memdesc<256x32xbf16, #shared, #smem> -> tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = | %b_6 = tt.advance %b_desc, [%c32_i32, %b] : <tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>> loc(#loc40)
%b_11 = ttg.local_load %b_9 : !ttg.memdesc<32x256xbf16, #shared, #smem> -> tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = | ttig.prefetch %b_6 {boundaryCheck = array<i32: 0, 1>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1,
%accumulator = tt.dot %a_10, %b_11, %arg10, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * | %off_k:6 = scf.for %off_k_7 = %c0_i32 to %c4096_i32 step %c32_i32 iter_args(%arg4 = %cst, %off_k_8 = %c32_i32, %a_9 = %a_3, %a_10 = %a_5, %b_11
%off_k_12 = arith.addi %arg11, %c32_i32 : i32 loc(#loc50) | %off_k_13 = arith.cmpi slt, %off_k_7, %c4032_i32 : i32 loc(#loc45)
scf.yield %accumulator, %off_k_12 : tensor<256x256xf32, #mma>, i32 loc(#loc20) | %off_k_14 = arith.addi %off_k_8, %c32_i32 : i32 loc(#loc42)
} loc(#loc52) | %a_15 = tt.advance %a_desc, [%a, %off_k_14] : <tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>> loc(#loc39)
%c_desc = tt.make_tensor_ptr %c_ptr, [%c8192_i64, %c4096_i64], [%stride_cm, %stride_cn], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor | %off_k_16 = tt.splat %off_k_13 : i1 -> tensor<256x32xi1, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc45)
%0 = tt.advance %c_desc, [%a, %b] : <tensor<256x256xf32, #mma>> loc(#loc22) | ttig.prefetch %a_15, %off_k_16 {boundaryCheck = array<i32: 0, 1>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes =
tt.store %0, %off_k#0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf32, #mma>> loc(#loc22) | %b_17 = tt.advance %b_desc, [%off_k_14, %b] : <tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>> loc(#loc40)
tt.return loc(#loc23) | %off_k_18 = tt.splat %off_k_13 : i1 -> tensor<32x256xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> loc(#loc45)
> ttig.prefetch %b_17, %off_k_18 {boundaryCheck = array<i32: 0, 1>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes =
> %a_19 = tt.load %a_9 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, par
> %b_20 = tt.load %b_11 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, pa
> %accumulator = tt.dot %a_19, %b_20, %arg4, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> *
> scf.yield %accumulator, %off_k_14, %a_10, %a_15, %b_12, %b_17 : tensor<256x256xf32, #mma>, i32, !tt.ptr<tensor<256x32xbf16, #ttg.dot_op<{opIdx
> } loc(#loc45)
> %c_desc = tt.make_tensor_ptr %c_ptr, [%c8192_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<25
> %0 = tt.advance %c_desc, [%a, %b] : <tensor<256x256xf32, #mma>> loc(#loc21)
> tt.store %0, %off_k#0 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<256x256xf32, #mma>> loc(#loc21)
> tt.return loc(#loc22)
} loc(#loc) } loc(#loc)
} loc(#loc) } loc(#loc)
#loc1 = loc(unknown) #loc1 = loc(unknown)
#loc2 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":35:24) | #loc2 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":35:24)
#loc3 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":39:22) | #loc3 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":39:22)
#loc4 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":40:29) | #loc4 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":40:29)
#loc5 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":41:35) | #loc5 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":41:35)
#loc6 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":41:48) | #loc6 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":41:48)
#loc7 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":42:34) | #loc7 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":42:34)
#loc8 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":42:54) | #loc8 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":42:54)
#loc9 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":42:27) | #loc9 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":42:27)
#loc10 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":43:40) | #loc10 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":43:40)
#loc11 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":46:39) | #loc11 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":46:39)
#loc12 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":48:39) | #loc12 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":48:39)
#loc13 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":53:33) | #loc13 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":53:33)
#loc14 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":54:40) | #loc14 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":54:40)
#loc15 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":52:25) | #loc15 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":53:24)
#loc16 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":53:24) | #loc16 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":54:24)
#loc17 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":54:24) | #loc17 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":52:25)
#loc18 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":55:33) | #loc18 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":56:17)
#loc19 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":56:17) | #loc19 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":55:33)
#loc20 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":56:8) | #loc20 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":60:39)
#loc21 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":60:39) | #loc21 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":61:63)
#loc22 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":61:63) | #loc22 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py":61:4)
#loc23 = loc("/home/jovyan/triton/intel-xpu-backend-for-triton/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark_int64.py":61:4) | #loc26 = loc("pid"(#loc2))
#loc33 = loc("pid"(#loc2)) | #loc27 = loc("group_id"(#loc3))
#loc34 = loc("group_id"(#loc3)) | #loc28 = loc("first_pid_m"(#loc4))
#loc35 = loc("first_pid_m"(#loc4)) | #loc29 = loc("group_size_m"(#loc5))
#loc36 = loc("group_size_m"(#loc5)) | #loc30 = loc("group_size_m"(#loc6))
#loc37 = loc("group_size_m"(#loc6)) | #loc31 = loc("pid_m"(#loc7))
#loc38 = loc("pid_m"(#loc7)) | #loc32 = loc("pid_m"(#loc8))
#loc39 = loc("pid_m"(#loc8)) | #loc33 = loc("pid_m"(#loc9))
#loc40 = loc("pid_m"(#loc9)) | #loc34 = loc("pid_n"(#loc10))
#loc41 = loc("pid_n"(#loc10)) | #loc35 = loc("a_desc"(#loc11))
#loc42 = loc("a_desc"(#loc11)) | #loc36 = loc("b_desc"(#loc12))
#loc43 = loc("b_desc"(#loc12)) | #loc37 = loc("a"(#loc13))
#loc44 = loc("a"(#loc13)) | #loc38 = loc("b"(#loc14))
#loc45 = loc("b"(#loc14)) | #loc39 = loc("a"(#loc15))
#loc46 = loc("accumulator"(#loc15)) | #loc40 = loc("b"(#loc16))
#loc47 = loc("a"(#loc16)) | #loc41 = loc("accumulator"(#loc17))
#loc48 = loc("b"(#loc17)) | #loc42 = loc("off_k"(#loc18))
#loc49 = loc("accumulator"(#loc18)) | #loc43 = loc("accumulator"(#loc19))
#loc50 = loc("off_k"(#loc19)) | #loc44 = loc("c_desc"(#loc20))
#loc51 = loc("c_desc"(#loc21)) | #loc45 = loc("off_k"(#loc41))
#loc52 = loc("off_k"(#loc46)) <
Metadata
Metadata
Assignees
Labels
No labels