Skip to content

Commit ddbddad

Browse files
committed
fix
1 parent 3be751e commit ddbddad

File tree

1 file changed

+10
-30
lines changed

1 file changed

+10
-30
lines changed

lightllm/models/deepseek2/triton_kernel/rotary_emb.py

Lines changed: 10 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22

33
import triton
44
import triton.language as tl
5-
5+
import itertools
66
from lightllm.common.triton_utils.autotuner import autotune, nearest_power_of_2
77

88
@triton.jit
@@ -95,35 +95,15 @@ def _rotary_kernel(
9595

9696
def get_test_configs():
9797
configs = []
98-
for num_stages in [
99-
1,
100-
2,
101-
3,
102-
4,
103-
5,
104-
]:
105-
for GROUP_SIZE_M in [
106-
1,
107-
2,
108-
4,
109-
]:
110-
for num_warps in [
111-
2,
112-
4,
113-
8,
114-
]:
115-
for BLOCK_SIZE_M in [16, 32, 64, 128]:
116-
for BLOCK_SIZE_N in [32, 64, 128]:
117-
for BLOCK_SIZE_K in [32, 64, 128]:
118-
t_config = {
119-
"BLOCK_SIZE_M": BLOCK_SIZE_M,
120-
"BLOCK_SIZE_N": BLOCK_SIZE_N,
121-
"BLOCK_SIZE_K": BLOCK_SIZE_K,
122-
"GROUP_SIZE_M": GROUP_SIZE_M,
123-
"num_warps": num_warps,
124-
"num_stages": num_stages,
125-
}
126-
configs.append(t_config)
98+
result = itertools.product([1, 2, 4, 8, 16, 32], [1, 2, 4, 8], [1, 2, 3, 4, 5], [1, 2, 4, 8, 16])
99+
for BLOCK_SEQ, num_warps, num_stages, HEAD_PARALLEL_NUM in result:
100+
t_config = {
101+
"BLOCK_SEQ": BLOCK_SEQ,
102+
"HEAD_PARALLEL_NUM": HEAD_PARALLEL_NUM,
103+
"num_warps": num_warps,
104+
"num_stages": num_stages,
105+
}
106+
configs.append(t_config)
127107
return configs
128108

129109
def get_static_key(q, k):

0 commit comments

Comments (0)