Skip to content

Commit fc7f21f

Browse files
committed
tuning searching space
1 parent ba6258c commit fc7f21f

File tree

2 files changed

+11
-9
lines changed

2 files changed

+11
-9
lines changed

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -343,11 +343,13 @@ def grouped_matmul_kernel(
 group_id = pid // num_pid_in_group
 first_pid_m = group_id * GROUP_SIZE_M
 group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
-in_group_index = pid % num_pid_in_group
-back_mark = (in_group_index // group_size_m) % 2
-back_mark1 = -1 * (2 * back_mark - 1)
-pid_m = first_pid_m + back_mark * (group_size_m - 1) + back_mark1 * (in_group_index % group_size_m)
+pid_m = first_pid_m + pid % num_pid_in_group % group_size_m
 pid_n = (pid % num_pid_in_group) // group_size_m
+# in_group_index = pid % num_pid_in_group
+# back_mark = (in_group_index // group_size_m) % 2
+# back_mark1 = -1 * (2 * back_mark - 1)
+# pid_m = first_pid_m + back_mark * (group_size_m - 1) + back_mark1 * (in_group_index % group_size_m)
+# pid_n = (pid % num_pid_in_group) // group_size_m
 
 expert_id = tl.load(mblocks_to_expert_id + pid_m)
 
@@ -488,12 +490,12 @@ def _get_grouped_matmul_configs():
 "num_stages": ns,
 "NEED_TRANS": need_trans,
 }
-for ns in [1, 2, 3, 4, 5]
-for gm in [1, 2, 4, 8]
-for nw in [2, 4, 8]
+for ns in [2, 3, 4, 5]
+for gm in [1, 16, 32, 64]
+for nw in [4, 8]
 for bm in [16, 32, 64, 128]
 for bn in [16, 32, 64, 128]
-for bk in [16, 32, 64, 128]
+for bk in [32, 64, 128]
 for need_trans in [True, False]
 ]
 
lightllm/common/triton_utils/autotuner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def _try_load_cache(self, static_key):
 self.cached_configs[static_key] = orjson.loads(f.read())
 return
 
-def _bench(self, *args, n_repeat=3, n_retries=1, **kwargs):
+def _bench(self, *args, n_repeat=3, n_retries=5, **kwargs):
 from triton.compiler.errors import CompileTimeAssertionFailure
 from triton.runtime.errors import OutOfResources, PTXASError
 
0 commit comments

Comments
 (0)