Commit 8e42af9

Tianxing/moe int8 w8a8 (#765)
Add support for int8 fused moe
1 parent f669d30 commit 8e42af9

4 files changed: +133 -32 lines changed

python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X,dtype=int8_w8a16.json
Lines changed: 2 additions & 2 deletions

@@ -11,7 +11,7 @@
         "kpack": 2
     },
     "medium_M": {
-        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_M": 128,
         "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 64,
         "GROUP_SIZE_M": 1,
@@ -23,7 +23,7 @@
     },
     "large_M": {
         "BLOCK_SIZE_M": 128,
-        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 1,
         "num_warps": 8,
Lines changed: 35 additions & 0 deletions (new file)

@@ -0,0 +1,35 @@
+{
+    "small_M": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 8,
+        "num_stages": 2,
+        "waves_per_eu": 0,
+        "matrix_instr_nonkdim": 16,
+        "kpack": 2
+    },
+    "medium_M": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 2,
+        "waves_per_eu": 0,
+        "matrix_instr_nonkdim": 16,
+        "kpack": 2
+    },
+    "large_M": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 2,
+        "waves_per_eu": 0,
+        "matrix_instr_nonkdim": 16,
+        "kpack": 2
+    }
+}
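For orientation (not part of the commit): try_get_optimal_moe_config presumably selects one of the small_M / medium_M / large_M entries from a config file like this one based on the token count M. A minimal sketch of that kind of bucketed lookup, with purely illustrative thresholds and a hypothetical helper name, might look like:

    import json

    def pick_moe_config_sketch(config_path: str, M: int) -> dict:
        # Hypothetical bucket selection; the real logic lives in try_get_optimal_moe_config.
        with open(config_path) as f:
            configs = json.load(f)
        if M <= 32:
            return configs["small_M"]
        if M <= 256:
            return configs["medium_M"]
        return configs["large_M"]

    # e.g. block sizes for a 64-token batch on MI300X with the int8_w8a16 config:
    # cfg = pick_moe_config_sketch("configs/device_name=AMD_Instinct_MI300X,dtype=int8_w8a16.json", 64)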

python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X.json
Lines changed: 1 addition & 1 deletion

@@ -30,6 +30,6 @@
         "num_stages": 2,
         "waves_per_eu": 0,
         "matrix_instr_nonkdim": 16,
-        "kpack": 2
+        "kpack": 1
     }
 }

python/perf-kernels/fused_moe/moe-gemm.py
Lines changed: 95 additions & 29 deletions

@@ -34,8 +34,10 @@
 class MetaData():
     use_fp8_w8a8 = False
     use_int8_w8a16 = False
+    use_int8_w8a8 = False

-    def __init__(self, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config):
+    def __init__(self, top_k, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config):
+        self.top_k = top_k
         self.topk_weights = topk_weights
         self.topk_ids = topk_ids
         self.sorted_token_ids = sorted_token_ids
@@ -54,10 +56,15 @@ def set_use_int8_w8a16(self, b_descale):
         self.b_descale = b_descale
         self.a_descale = None

+    def set_use_int8_w8a8(self, a_descale, b_descale):
+        self.use_int8_w8a8 = True
+        self.a_descale = a_descale
+        self.b_descale = b_descale
+
     def check_args(self, a, b, o):
         assert a.shape[-1] == b.shape[-1] and b.shape[1] == o.shape[-1]

-        assert not (self.use_fp8_w8a8 and self.use_int8_w8a16)
+        assert not (self.use_fp8_w8a8 and self.use_int8_w8a16 and self.use_int8_w8a8)
         if self.use_fp8_w8a8:
             assert self.fp8_type in supported_fp8, f"fp8 type {self.fp8_type} not supported"
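As a usage note (not part of the diff): for the w8a8 path, set_use_int8_w8a8 carries a single activation descale plus one descale per expert for the weights. A tiny self-contained sketch of those scale shapes, inferred from how the scales are consumed further down in this file:

    import torch

    # Shapes only; this is not repository code.
    E, N, K = 8, 1792, 128
    a = torch.randn(64, K)        # activations (M, K)
    b = torch.randn(E, N, K)      # expert weights (E, N, K)

    a_descale = a.abs().amax() / 127            # one scalar for all activations
    b_descale = b.abs().amax(dim=(1, 2)) / 127  # one scale per expert -> shape (E,)
    print(a_descale.shape, b_descale.shape)     # torch.Size([]) torch.Size([8])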

@@ -89,6 +96,7 @@ def moe_gemm_kernel(
         MUL_ROUTED_WEIGHT: tl.constexpr,
         use_fp8_w8a8: tl.constexpr,
         use_int8_w8a16: tl.constexpr,
+        use_int8_w8a8: tl.constexpr,
         BLOCK_SIZE_M: tl.constexpr,
         BLOCK_SIZE_N: tl.constexpr,
         BLOCK_SIZE_K: tl.constexpr,
@@ -146,7 +154,7 @@ def moe_gemm_kernel(
         b_scale_ptrs = B_scale + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
         b_scale = tl.load(b_scale_ptrs)

-    if use_fp8_w8a8:
+    if use_fp8_w8a8 or use_int8_w8a8:
         a_scale = tl.load(A_scale)
         b_scale = tl.load(B_scale + off_experts)

@@ -163,7 +171,7 @@ def moe_gemm_kernel(

         if use_int8_w8a16:
             accumulator = tl.dot(a, b.to(a.dtype), acc=accumulator)
-        elif use_fp8_w8a8:
+        elif use_fp8_w8a8 or use_int8_w8a8:
            accumulator += tl.dot(a, b)
         else:
             accumulator = tl.dot(a, b, acc=accumulator)
@@ -177,7 +185,7 @@

     if use_int8_w8a16:
         accumulator = (accumulator * b_scale).to(Out.dtype.element_ty)
-    elif use_fp8_w8a8:
+    elif use_fp8_w8a8 or use_int8_w8a8:
         accumulator = (accumulator * a_scale * b_scale).to(Out.dtype.element_ty)
     else:
         accumulator = accumulator.to(Out.dtype.element_ty)
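Side note (not part of the commit): for both the fp8 path and the new int8 w8a8 path, the kernel accumulates the quantized dot products and then rescales the whole tile once by a_scale * b_scale. A tiny self-contained torch sketch of that dequantization math:

    import torch

    # Quantize A and B symmetrically to int8, accumulate the integer matmul,
    # then rescale by a_scale * b_scale (the kernel accumulates in int32).
    torch.manual_seed(0)
    a, b = torch.randn(16, 64), torch.randn(64, 32)
    a_scale = a.abs().amax() / 127
    b_scale = b.abs().amax() / 127
    a_q = (a / a_scale).round().clamp(-127, 127).to(torch.int8)
    b_q = (b / b_scale).round().clamp(-127, 127).to(torch.int8)
    acc = a_q.long() @ b_q.long()              # integer accumulation
    out = acc.float() * a_scale * b_scale      # dequantized result
    print((out - a @ b).abs().max())           # small quantization error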
@@ -278,11 +286,13 @@ def moe_align_block_size(topk_ids: torch.Tensor, block_size: int,


 def get_config_dtype_str(dtype: torch.dtype, use_int8_w8a16: Optional[bool] = False,
-                         use_fp8_w8a8: Optional[bool] = False):
+                         use_int8_w8a8: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False):
     if use_fp8_w8a8:
         return "fp8_w8a8"
     elif use_int8_w8a16:
         return "int8_w8a16"
+    elif use_int8_w8a8:
+        return "int8_w8a8"
     elif dtype == torch.float:
         # avoiding cases where kernel fails when float32 MoE
         # use fp16/bfloat16 configs
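For orientation (not part of the diff): the string returned here is the dtype tag that appears in the tuned-config filenames touched by this commit, e.g. device_name=AMD_Instinct_MI300X,dtype=int8_w8a16.json, with device_name=AMD_Instinct_MI300X.json as the untagged variant. A hedged sketch of what such a filename lookup could look like; the helper name and the fallback behaviour are assumptions, not the repository's logic:

    import os
    from typing import Optional

    def moe_config_path_sketch(config_dir: str, device_name: str, config_dtype: Optional[str]) -> str:
        # Hypothetical filename builder mirroring the config files in this commit.
        name = f"device_name={device_name}"
        if config_dtype is not None:
            name += f",dtype={config_dtype}"
        path = os.path.join(config_dir, name + ".json")
        # Assumed fallback: use the dtype-less file when no specialized config exists.
        if not os.path.exists(path):
            path = os.path.join(config_dir, f"device_name={device_name}.json")
        return path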
@@ -360,19 +370,19 @@ def moe_gemm(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, metadata: MetaDa
     # TODO shard M dim
     metadata.check_args(a, b, c)

-    topk_ids, num_tokens_post_padded, topk_weights, sorted_token_ids, expert_ids, config = metadata.topk_ids, metadata.num_tokens_post_padded, metadata.topk_weights, metadata.sorted_token_ids, metadata.expert_ids, metadata.config
+    num_tokens_post_padded, topk_weights, sorted_token_ids, expert_ids, config = metadata.num_tokens_post_padded, metadata.topk_weights, metadata.sorted_token_ids, metadata.expert_ids, metadata.config

-    use_fp8_w8a8, use_int8_w8a16 = metadata.use_fp8_w8a8, metadata.use_int8_w8a16
+    use_fp8_w8a8, use_int8_w8a16, use_int8_w8a8 = metadata.use_fp8_w8a8, metadata.use_int8_w8a16, metadata.use_int8_w8a8
     a_descale, b_descale = None, None
     stride_bse = None
     stride_bsn = None
-    if use_fp8_w8a8 or use_int8_w8a16:
+    if use_fp8_w8a8 or use_int8_w8a16 or use_int8_w8a8:
         a_descale, b_descale = metadata.a_descale, metadata.b_descale
         if use_int8_w8a16:
             stride_bse = b_descale.stride(0)
             stride_bsn = b_descale.stride(1)

-    _, top_k = topk_ids.shape
+    top_k = metadata.top_k

     EM = num_tokens_post_padded.item()
     _, N, K = b.shape
@@ -384,7 +394,7 @@
         b_descale, a.stride(0), a.stride(1), b.stride(0), b.stride(1), b.stride(2), c.stride(1),
         c.stride(2), stride_bse, stride_bsn, top_k, topk_weights, sorted_token_ids, expert_ids, EM, N,
         K, EVEN_K, MUL_ROUTED_WEIGHT=topk_weights is not None, use_fp8_w8a8=use_fp8_w8a8,
-        use_int8_w8a16=use_int8_w8a16, **config)
+        use_int8_w8a16=use_int8_w8a16, use_int8_w8a8=use_int8_w8a8, **config)
     return c


@@ -410,8 +420,9 @@ def quantize_tensor(tensor: torch.Tensor, dtype, dim=()) -> tuple[torch.Tensor,
     return tensor_quantized, scale, 1 / scale


-def quantize_input(a, b, use_fp8_w8a8: tl.constexpr, use_int8_w8a16: tl.constexpr, metatdata: MetaData, fp8_type=None):
-    assert not (use_fp8_w8a8 and use_int8_w8a16)
+def quantize_input(a, b, use_fp8_w8a8: tl.constexpr, use_int8_w8a16: tl.constexpr, use_int8_w8a8: tl.constexpr,
+                   metatdata: MetaData, fp8_type=None):
+    assert not (use_fp8_w8a8 and use_int8_w8a16 and use_int8_w8a8)
     assert not (use_fp8_w8a8 and fp8_type is None)

     if use_fp8_w8a8:
@@ -420,14 +431,20 @@ def quantize_input(a, b, use_fp8_w8a8: tl.constexpr, use_int8_w8a16: tl.constexp
         metatdata.set_use_fp8_w8a8(a_descale, b_descale, fp8_type)
         return a_quantized, b_quantized

+    if use_int8_w8a8:
+        a_quantized, _, a_descale = quantize_tensor(a, dtype=torch.int8)
+        b_quantized, _, b_descale = quantize_tensor(b, dim=(0, ), dtype=torch.int8)
+        metatdata.set_use_int8_w8a8(a_descale, b_descale)
+        return a_quantized, b_quantized
+
     if use_int8_w8a16:
         b_quantized, _, b_descale = quantize_tensor(b, dim=(0, 1), dtype=torch.int8)
         metatdata.set_use_int8_w8a16(b_descale)
         return a, b_quantized


 def input_helper(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool, use_fp8_w8a8: bool,
-                 use_int8_w8a16: bool, fp8_type, dtype):
+                 use_int8_w8a16: bool, use_int8_w8a8: bool, fp8_type, dtype):
     a = torch.randn((M, K), dtype=dtype, device='cuda')
     b = torch.randn((E, N, K), dtype=dtype, device='cuda')
     c = torch.zeros((M, top_k, N), dtype=dtype, device='cuda')
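quantize_tensor itself is outside this diff, but its call sites above (dim=() for the activations, dim=(0, ) for per-expert weight scales, dim=(0, 1) for per-expert, per-channel scales) and its return signature (tensor_quantized, scale, 1 / scale) constrain what it does. The following is only an illustrative reconstruction of a symmetric int8 quantizer consistent with those call sites, not the repository's implementation:

    import torch

    def quantize_tensor_sketch(tensor: torch.Tensor, dtype=torch.int8, dim=()):
        # Keep one scale per index along the dims listed in `dim`; reduce the
        # absolute maximum over every other dim. dim=() -> a single per-tensor
        # scale; dim=(0,) on an (E, N, K) weight -> one scale per expert.
        reduce_dims = tuple(d for d in range(tensor.dim()) if d not in dim)
        amax = tensor.abs().amax(dim=reduce_dims)
        qmax = torch.iinfo(dtype).max
        scale = qmax / amax                    # multiply to quantize
        bshape = [tensor.shape[d] if d in dim else 1 for d in range(tensor.dim())]
        q = (tensor * scale.reshape(bshape)).round().clamp(-qmax, qmax).to(dtype)
        return q, scale, 1 / scale             # descale = 1 / scale

    # e.g. per-expert weight scales for an (E, N, K) tensor:
    # b_q, b_scale, b_descale = quantize_tensor_sketch(torch.randn(8, 32, 16), dim=(0, ))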
@@ -437,7 +454,8 @@ def input_helper(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool
     softmax_vals = torch.softmax(values, dim=1)
     topk_weights, topk_ids = torch.topk(softmax_vals, k=top_k, dim=1)

-    config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, dtype=dtype)
+    config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16,
+                                        use_int8_w8a8=use_int8_w8a8, dtype=dtype)
     get_config_func = functools.partial(
         try_get_optimal_moe_config,
         E,
@@ -446,11 +464,11 @@
     config = get_config_func(M)
     sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'], E)

-    metadata = MetaData(topk_weights if routed_weight else None, topk_ids, sorted_token_ids, expert_ids,
+    metadata = MetaData(top_k, topk_weights if routed_weight else None, topk_ids, sorted_token_ids, expert_ids,
                         num_tokens_post_padded, config)

-    if use_fp8_w8a8 or use_int8_w8a16:
-        a, b = quantize_input(a, b, use_fp8_w8a8, use_int8_w8a16, metadata, fp8_type)
+    if use_fp8_w8a8 or use_int8_w8a16 or use_int8_w8a8:
+        a, b = quantize_input(a, b, use_fp8_w8a8, use_int8_w8a16, use_int8_w8a8, metadata, fp8_type)

     return a, b, c, metadata

@@ -471,7 +489,7 @@ def input_helper(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool
 def test_correctness(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool, dtype=torch.float16):
     torch.manual_seed(20)
     a, b, c, metadata = input_helper(M, N, K, top_k, E, routed_weight=routed_weight, use_fp8_w8a8=False,
-                                     use_int8_w8a16=False, fp8_type=None, dtype=dtype)
+                                     use_int8_w8a16=False, use_int8_w8a8=False, fp8_type=None, dtype=dtype)

     tri_out = moe_gemm(a, b, c, metadata)

@@ -508,7 +526,7 @@ def test_correctness_fp8(M: int, N: int, K: int, top_k: int, E: int, routed_weig
                          dtype=torch.float16):
     torch.manual_seed(20)
     a, b, c, metadata = input_helper(M, N, K, top_k, E, routed_weight=routed_weight, use_fp8_w8a8=use_fp8_w8a8,
-                                     use_int8_w8a16=False, fp8_type=fp8_type, dtype=dtype)
+                                     use_int8_w8a16=False, fp8_type=fp8_type, use_int8_w8a8=False, dtype=dtype)

     tri_out = moe_gemm(a, b, c, metadata)

@@ -545,11 +563,11 @@ def test_correctness_fp8(M: int, N: int, K: int, top_k: int, E: int, routed_weig
 ])
 @pytest.mark.parametrize('routed_weight', [True, False])
 @pytest.mark.parametrize('use_int8_w8a16', [True])
-def test_correctness_int8(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool, use_int8_w8a16,
-                          dtype=torch.float16):
+def test_correctness_int8_w8a16(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool, use_int8_w8a16,
+                                dtype=torch.float16):
     torch.manual_seed(20)
     a, b, c, metadata = input_helper(M, N, K, top_k, E, routed_weight=routed_weight, use_fp8_w8a8=False,
-                                     use_int8_w8a16=use_int8_w8a16, fp8_type=None, dtype=dtype)
+                                     use_int8_w8a16=use_int8_w8a16, use_int8_w8a8=False, fp8_type=None, dtype=dtype)

     tri_out = moe_gemm(a, b, c, metadata)

@@ -560,7 +578,7 @@
     a_expanded = a.unsqueeze(1).repeat(1, top_k, 1)
     # (M, top_k, N, K)
     b_indexed = b[topk_ids]
-    ref_out = torch.einsum("mek,menk->men", a_expanded.to(torch.float32), b_indexed.to(torch.float32))
+    ref_out = torch.einsum("mek,menk->men", a_expanded.float(), b_indexed.float())
     if routed_weight:
         ref_out *= topk_weights.unsqueeze(-1)

@@ -571,6 +589,46 @@
     torch.testing.assert_close(tri_out, ref_out, atol=1e-2, rtol=1e-2)


+@pytest.mark.parametrize("M, N, K, top_k, E", [
+    (64, 14336, 4096, 2, 8),
+    (16, 14336, 1, 2, 4),
+    (1, 14336, 128, 2, 4),
+    (16, 14336, 128, 1, 4),
+    (16, 14336, 128, 1, 1),
+    (64, 7186, 128, 2, 8),
+    (64, 3584, 128, 2, 8),
+    (64, 1792, 128, 2, 8),
+    (64, 64, 128, 2, 8),
+])
+@pytest.mark.parametrize('routed_weight', [True, False])
+@pytest.mark.parametrize('use_int8_w8a8', [True])
+def test_correctness_int8_w8a8(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool, use_int8_w8a8,
+                               dtype=torch.float16):
+    torch.manual_seed(20)
+    a, b, c, metadata = input_helper(M, N, K, top_k, E, routed_weight=routed_weight, use_fp8_w8a8=False,
+                                     use_int8_w8a16=False, use_int8_w8a8=use_int8_w8a8, fp8_type=None, dtype=dtype)
+
+    tri_out = moe_gemm(a, b, c, metadata)
+
+    topk_ids = metadata.topk_ids
+    topk_weights = metadata.topk_weights
+    ref_out = torch.empty_like(c)
+    # Repeat a -> (M, top_k, K)
+    a_expanded = a.unsqueeze(1).repeat(1, top_k, 1)
+    # (M, top_k, N, K)
+    b_indexed = b[topk_ids]
+    ref_out = torch.einsum("mek,menk->men", a_expanded.float(), b_indexed.float())
+    if routed_weight:
+        ref_out *= topk_weights.unsqueeze(-1)
+
+    ref_out = ref_out * metadata.b_descale[topk_ids].unsqueeze(-1)
+    ref_out = ref_out * metadata.a_descale
+    ref_out = ref_out.to(dtype)
+
+    # Validate correctness
+    torch.testing.assert_close(tri_out, ref_out, atol=1e-2, rtol=1e-2)
+
+
 def get_configs():
     configs = [
         {"M": 64, "N": 256, "K": 128, "E": 8, "top_k": 2},
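To exercise the new path on hardware (a ROCm GPU and a working Triton install are assumed), the test added above can be selected by name with pytest, e.g. something like:

    pytest python/perf-kernels/fused_moe/moe-gemm.py -k test_correctness_int8_w8a8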
@@ -606,8 +664,10 @@ def model_benchmark_configs(args):

         E = 8
         top_k = 2
+        # The first moe layer
         moe_configs.append((model_name, M, N1, K1, E, top_k))
-        moe_configs.append((model_name, M, N2, K2, E, top_k))
+        # The second moe layer
+        moe_configs.append((model_name, M * top_k, N2, K2, E, 1))

     return moe_configs
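A note on the change above (the reasoning is inferred from the code, not stated in the commit): after the first MoE GEMM, each of the M tokens has been expanded into top_k per-expert rows, so the second GEMM is presumably benchmarked with M * top_k rows that each go to a single expert, hence the (M * top_k, N2, K2, E, 1) tuple. The benchmark's FLOP formula gives the same total either way:

    # flops = 2.0 * M * top_k * K * N  (see bench_moe_gemm below)
    M, top_k, K2, N2 = 64, 2, 4096, 14336
    assert 2.0 * (M * top_k) * 1 * K2 * N2 == 2.0 * M * top_k * K2 * N2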

@@ -616,6 +676,7 @@ def run_benchmark(custom, args):
     routed_weight = args.routed_weight
     use_int8_w8a16 = args.int8_w8a16
     use_fp8_w8a8 = args.fp8_w8a8
+    use_int8_w8a8 = args.int8_w8a8
     dtype = arg_to_torch_dtype[args.dtype]
     fp8_type = arg_to_torch_dtype[args.fp8_type]

@@ -640,14 +701,15 @@ def run_benchmark(custom, args):
         styles=[('red', '-'), ('blue', '-'),
                 ('yellow', '-')], ylabel='ms / TFLOPS / GB/s', plot_name='moe-gemm-benchmark', args={
                     'dtype': dtype, 'routed_weight': routed_weight, 'use_fp8_w8a8': use_fp8_w8a8, 'use_int8_w8a16':
-                        use_int8_w8a16, 'fp8_type': fp8_type
+                        use_int8_w8a16, 'use_int8_w8a8': use_int8_w8a8, 'fp8_type': fp8_type
                 })

     @triton.testing.perf_report([benchmark])
-    def bench_moe_gemm(M, N, K, E, top_k, dtype, routed_weight, metric, use_fp8_w8a8, use_int8_w8a16, fp8_type,
-                       model=None):
+    def bench_moe_gemm(M, N, K, E, top_k, dtype, routed_weight, metric, use_fp8_w8a8, use_int8_w8a16, use_int8_w8a8,
+                       fp8_type, model=None):
         a, b, c, metadata = input_helper(M, N, K, top_k, E, routed_weight=routed_weight, use_fp8_w8a8=use_fp8_w8a8,
-                                         use_int8_w8a16=use_int8_w8a16, fp8_type=fp8_type, dtype=dtype)
+                                         use_int8_w8a16=use_int8_w8a16, use_int8_w8a8=use_int8_w8a8, fp8_type=fp8_type,
+                                         dtype=dtype)

         # (M, K) * (top_k, N, K) -> (M, top_k, N). 2 for multiplication and accumulation
         flops = 2.0 * M * top_k * K * N
@@ -658,6 +720,9 @@ def bench_moe_gemm(M, N, K, E, top_k, dtype, routed_weight, metric, use_fp8_w8a8
         if use_fp8_w8a8:
             a_bytes = b_bytes = torch.tensor([], dtype=fp8_type).element_size()
             c_bytes = torch.tensor([], dtype=dtype).element_size()
+        if use_int8_w8a8:
+            a_bytes = b_bytes = torch.tensor([], dtype=torch.int8).element_size()
+            c_bytes = torch.tensor([], dtype=torch.int8).element_size()
         elif use_int8_w8a16:
             b_bytes = torch.tensor([], dtype=torch.int8).element_size()
             a_bytes = c_bytes = torch.tensor([], dtype=dtype).element_size()
@@ -705,6 +770,7 @@ def parse_args():
     parser.add_argument("-top_k", type=int, default=0, help="top_k experts per token")
     parser.add_argument("-routed_weight", action='store_true', default=False)
     parser.add_argument("-int8_w8a16", action='store_true', default=False)
+    parser.add_argument("-int8_w8a8", action='store_true', default=False)
     parser.add_argument("-fp8_w8a8", action='store_true', default=False)
     parser.add_argument("-dtype", default='fp16')
     parser.add_argument("-fp8_type", default='e5m2fnuz')
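With the new flag wired into parse_args, an int8 w8a8 benchmark run would presumably be launched with something like the following (only flags shown in this diff are used; the script's main entry point and remaining arguments are assumed to behave as before):

    python python/perf-kernels/fused_moe/moe-gemm.py -int8_w8a8 -routed_weight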
