Skip to content

Commit 32651c7

Browse files
committed
Merge branch 'ds_moe' of https://github.com/ModelTC/lightllm into ds_moe
2 parents (10313cb + d8a143f) — commit 32651c7

File tree

3 files changed

+12
-3
lines changed

3 files changed

+12
-3
lines changed

lightllm/common/fused_moe/moe_silu_and_mul.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def _get_silu_and_mul_configs():
6868
{"BLOCK_M": bm, "BLOCK_N": bn, "num_warps": nw, "NUM_STAGES": ns}
6969
for ns in [1, 2, 4]
7070
for nw in [1, 4, 8]
71-
for bm in [32, 64, 128, 256]
71+
for bm in [1, 8, 32, 64, 128, 256]
7272
for bn in [32, 64, 128, 256]
7373
]
7474

lightllm/common/quantization/deepgemm_quant.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,5 +66,14 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
6666

6767
if out is None:
6868
out = alloc_func((m, n), dtype=input_tensor.dtype, device=input_tensor.device)
69-
deep_gemm.gemm_fp8_fp8_bf16_nt([qinput_tensor, input_scale], [qweight.t(), weight_scale.t()], out)
69+
_deepgemm_fp8_nt((qinput_tensor, input_scale), (qweight.t(), weight_scale.t()), out)
7070
return out
71+
72+
73+
def _deepgemm_fp8_nt(a_tuple, b_tuple, out):
74+
if HAS_DEEPGEMM:
75+
if hasattr(deep_gemm, "gemm_fp8_fp8_bf16_nt"):
76+
return deep_gemm.gemm_fp8_fp8_bf16_nt([a_tuple[0], a_tuple[1]], [b_tuple[0], b_tuple[1]], out)
77+
if hasattr(deep_gemm, "fp8_gemm_nt"):
78+
return deep_gemm.fp8_gemm_nt((a_tuple[0], a_tuple[1]), (b_tuple[0], b_tuple[1]), out)
79+
raise RuntimeError("deep_gemm does not provide fp8 NT GEMM kernel in this version")

lightllm/common/triton_utils/autotuner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def _try_load_cache(self, static_key):
192192
self.cached_configs[static_key] = orjson.loads(f.read())
193193
return
194194

195-
def _bench(self, *args, n_repeat=3, n_retries=5, **kwargs):
195+
def _bench(self, *args, n_repeat=3, n_retries=3, **kwargs):
196196
from triton.compiler.errors import CompileTimeAssertionFailure
197197
from triton.runtime.errors import OutOfResources, PTXASError
198198

0 commit comments

Comments (0)