Skip to content

Commit af44eee

Browse files
committed
add swizzle launch
1 parent 11015ed commit af44eee

File tree

1 file changed

+16
-10
lines changed

1 file changed

+16
-10
lines changed

lightllm/common/quantization/triton_quant/fp8/fp8w8a8_scaled_mm_per_token_kernel.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -61,17 +61,19 @@ def save_config(cls, N: int, K: int, out_dtype: str, config_json: Dict[int, Dict
6161

6262

6363
@triton.jit
64-
def grouped_launch(pid, m, n, block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr):
64+
def grouped_launch(pid, m_block_num, n_block_num, group_m: tl.constexpr):
6565

66-
grid_m = tl.cdiv(m, block_m)
67-
grid_n = tl.cdiv(n, block_n)
66+
num_pid_in_group = group_m * n_block_num
67+
group_id = pid // num_pid_in_group
68+
first_pid_m = group_id * group_m
69+
group_size_m = tl.minimum(m_block_num - first_pid_m, group_m)
70+
in_group_index = pid % num_pid_in_group
6871

69-
width = group_m * grid_n
70-
group_id = pid // width
71-
group_size = tl.minimum(grid_m - group_id * group_m, group_m)
72-
73-
pid_m = group_id * group_m + (pid % group_size)
74-
pid_n = (pid % width) // group_size
72+
# Swizzle pattern: zigzag traversal
73+
back_mark = (in_group_index // group_size_m) % 2
74+
back_mark1 = -1 * (2 * back_mark - 1)
75+
pid_m = first_pid_m + back_mark * (group_size_m - 1) + back_mark1 * (in_group_index % group_size_m)
76+
pid_n = (pid % num_pid_in_group) // group_size_m
7577

7678
return pid_m, pid_n
7779

@@ -89,6 +91,8 @@ def _scaled_mm_per_token(
8991
M,
9092
N,
9193
K,
94+
m_block_num,
95+
n_block_num,
9296
stride_am,
9397
stride_ak,
9498
stride_bk,
@@ -105,7 +109,7 @@ def _scaled_mm_per_token(
105109
GROUP_M: tl.constexpr,
106110
):
107111
pid = tl.program_id(0)
108-
pid_m, pid_n = grouped_launch(pid, M, N, BLOCK_M, BLOCK_N, GROUP_M)
112+
pid_m, pid_n = grouped_launch(pid, m_block_num, n_block_num, GROUP_M)
109113

110114
start_m = pid_m * BLOCK_M
111115
start_n = pid_n * BLOCK_N
@@ -289,6 +293,8 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
289293
M=M,
290294
N=N,
291295
K=K,
296+
m_block_num=triton.cdiv(M, BLOCK_M),
297+
n_block_num=triton.cdiv(N, BLOCK_N),
292298
stride_am=A.stride(0),
293299
stride_ak=A.stride(1),
294300
stride_bk=B.stride(0),

0 commit comments

Comments
 (0)