Skip to content

Commit a15f90f

Browse files
authored
Tianxing/moe gemm (#685)
Implemented moe gemm, test and benchmarking. The gemm support weights so as the benchmark. The kernel uses the predefined block sizes. The bench mark prints time, flops and memory throughput
1 parent 9679203 commit a15f90f

File tree

6 files changed

+550
-0
lines changed

6 files changed

+550
-0
lines changed

.github/workflows/amd_perf_kernel_Integration_tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ jobs:
130130
pytest -vvvv ./python/perf-kernels/softmax.py
131131
pytest -vvv ./python/perf-kernels/rmsnorm.py
132132
pytest -vvv ./python/perf-kernels/layernorm.py
133+
pytest -vvv ./python/perf-kernels/fused_moe/moe-gemm.py
133134
sh ./python/perf-kernels/streamk/utils/unittest.sh
134135
pytest -vvv ./python/perf-kernels/multreduce_matmul_kernel.py
135136
- name: Run Perf Kernels Benchmark

python/perf-kernels/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,6 @@ Kernel that implements RMS Norm over a row of tensor.
9999

100100
## `layernorm.py`
101101
Kernel that implements Layer Normalization over a row on tensor
102+
103+
## `fused_moe/moe-gemm.py`
104+
Kernel that implements moe gemm.
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
{
2+
"small_M": {
3+
"BLOCK_SIZE_M": 64,
4+
"BLOCK_SIZE_N": 64,
5+
"BLOCK_SIZE_K": 64,
6+
"GROUP_SIZE_M": 4,
7+
"num_warps": 8,
8+
"num_stages": 2,
9+
"waves_per_eu": 0,
10+
"matrix_instr_nonkdim": 16,
11+
"kpack": 2
12+
},
13+
"medium_M": {
14+
"BLOCK_SIZE_M": 128,
15+
"BLOCK_SIZE_N": 128,
16+
"BLOCK_SIZE_K": 128,
17+
"GROUP_SIZE_M": 1,
18+
"num_warps": 8,
19+
"num_stages": 2,
20+
"waves_per_eu": 0,
21+
"matrix_instr_nonkdim": 16,
22+
"kpack": 2
23+
},
24+
"large_M": {
25+
"BLOCK_SIZE_M": 256,
26+
"BLOCK_SIZE_N": 256,
27+
"BLOCK_SIZE_K": 64,
28+
"GROUP_SIZE_M": 1,
29+
"num_warps": 8,
30+
"num_stages": 2,
31+
"waves_per_eu": 0,
32+
"matrix_instr_nonkdim": 16,
33+
"kpack": 2
34+
}
35+
}

0 commit comments

Comments
 (0)