Skip to content

Commit 344bf2c

Browse files
authored
Improve GEMM performance of shape 4096x8x128x16384 (#2601)
Locally: ~6.16 to 6.44(Default path) 6.79 (Advanced path) ``` Before: default: matmul-performance: B M K N Triton-GB/s XeTLA-GB/s Triton-GB/s-min XeTLA-GB/s-min Triton-GB/s-max XeTLA-GB/s-max Triton-TFlops XeTLA-TFlops Triton-TFlops-min XeTLA-TFlops-min Triton-TFlops-max XeTLA-TFlops-max Triton-CV XeTLA-CV 0 4096.0 8.0 128.0 16384.0 866.781315 1013.591803 860.94579 1009.020643 869.550433 1016.113302 6.161104 7.204637 6.119625 7.172145 6.180787 7.22256 0.00239 0.001517 Advanced: matmul-performance: B M K N Triton-GB/s XeTLA-GB/s Triton-GB/s-min XeTLA-GB/s-min Triton-GB/s-max XeTLA-GB/s-max Triton-TFlops XeTLA-TFlops Triton-TFlops-min XeTLA-TFlops-min Triton-TFlops-max XeTLA-TFlops-max Triton-CV XeTLA-CV 0 4096.0 8.0 128.0 16384.0 870.200789 1013.800005 862.488109 1010.581648 875.605071 1018.837387 6.18541 7.206117 6.130588 7.183241 6.223824 7.241923 0.004391 0.002175 After: default: matmul-performance: B M K N Triton-GB/s XeTLA-GB/s Triton-GB/s-min XeTLA-GB/s-min Triton-GB/s-max XeTLA-GB/s-max Triton-TFlops XeTLA-TFlops Triton-TFlops-min XeTLA-TFlops-min Triton-TFlops-max XeTLA-TFlops-max Triton-CV XeTLA-CV 0 4096.0 8.0 128.0 16384.0 906.811524 1012.40191 900.05344 1006.608438 910.558331 1019.53374 6.44564 7.196179 6.397603 7.154999 6.472272 7.246872 0.002137 0.001734 Advanced: matmul-performance: B M K N Triton-GB/s XeTLA-GB/s Triton-GB/s-min XeTLA-GB/s-min Triton-GB/s-max XeTLA-GB/s-max Triton-TFlops XeTLA-TFlops Triton-TFlops-min XeTLA-TFlops-min Triton-TFlops-max XeTLA-TFlops-max Triton-CV XeTLA-CV 0 4096.0 8.0 128.0 16384.0 954.660979 1012.176091 952.807487 1009.501134 959.487262 1015.609403 6.785755 7.194574 6.77258 7.17556 6.82006 7.218978 0.002269 0.001743 ```
1 parent 1dbef57 commit 344bf2c

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ def matmul_kernel_with_block_pointers(
9898
triton.Config(
9999
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
100100
num_stages=s, num_warps=32) for s in [2]
101+
] + [
102+
triton.Config(
103+
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 1024, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
104+
num_stages=s, num_warps=32) for s in [2, 3]
101105
] + [
102106
triton.Config(
103107
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},

0 commit comments

Comments
 (0)