Commit 0bd3585

add shape check for triton
1 parent c7dd020 commit 0bd3585

File tree

1 file changed

+17 -5 lines changed


vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py

Lines changed: 17 additions & 5 deletions
@@ -25,8 +25,10 @@
 
 def aiter_triton_gemm_check(m, n, k):
     if m <= 64:
-        return ((n == 8192 and k == 8192) or (n == 10240 and k == 8192)
-                or (n == 57344 and k == 8192) or (n == 8192 and k == 28672))
+        return (
+            (n == 10240 and k == 8192) or (n == 8192 and k == 8192) or (n == 57344 and k == 8192) or (n == 8192 and k == 28672) or
+            (n == 1280 and k == 8192) or (n == 8192 and k == 1024) or (n == 7168 and k == 8192) or (n == 8192 and k == 3584)
+        )
     return False
 
 def gemm_with_dynamic_quant(
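
The gist of this hunk: for skinny GEMMs (m <= 64), the aiter Triton kernel is only used for an explicit whitelist of (n, k) shapes, and the commit extends that whitelist with four more pairs. As a minimal, runnable sketch, the same gate can be written table-driven; this is an illustration only, not the code the commit adds:

# Illustration: equivalent shape gate expressed as a set lookup.
# The commit itself keeps the chained boolean expression shown above.
_AITER_TRITON_SHAPES = {
    (10240, 8192), (8192, 8192), (57344, 8192), (8192, 28672),
    (1280, 8192), (8192, 1024), (7168, 8192), (8192, 3584),
}

def aiter_triton_gemm_check(m, n, k):
    # Route to the Triton kernel only for small-m GEMMs whose
    # (n, k) shape is on the validated list.
    return m <= 64 and (n, k) in _AITER_TRITON_SHAPES

assert aiter_triton_gemm_check(64, 8192, 1024)
assert not aiter_triton_gemm_check(128, 8192, 1024)  # m too large
assert not aiter_triton_gemm_check(64, 4096, 4096)   # shape not whitelisted

A set lookup stays O(1) and is easy to extend as more shapes are tuned, at the cost of diverging from the file's existing style.
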
@@ -74,16 +76,26 @@ def gemm_with_dynamic_quant(
         # use hip quant kernel for performance
         x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
     else:
-        x_q = x
-        x_s = x_scales
+        x_q = x.view(torch.float4_e2m1fn_x2)
+        x_s = x_scales.view(torch.float8_e8m0fnu)
 
     # 32 alignment is enough for dim0 padding of output for
     # gemm_a4w4 kernel
     y = torch.empty((M + 31) // 32 * 32,
                     weight.shape[0],
                     device=x_q.device,
                     dtype=out_dtype)
-
+
+    # weight = weight.view(x_q.dtype)
+    # weight_scale = weight_scale.view(x_s.dtype)
+    # print("fp4dtype", x_q.dtype, weight.dtype, x_s.dtype, weight_scale.dtype)
+
+    # gemm_a4w4(x_q,
+    #           weight,
+    #           x_s,
+    #           weight_scale,
+    #           y,
+    #           bpreshuffle=True)
     gemm_a4w4(x_q,
               weight.view(x_q.dtype),
               x_s,
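
Two things happen in this hunk. First, the else branch now reinterprets the already-quantized input and its scales via Tensor.view(dtype), so gemm_a4w4 receives tensors typed as packed FP4 (float4_e2m1fn_x2) and E8M0 scales (float8_e8m0fnu) rather than raw storage. Second, the output buffer's first dimension is rounded up to a multiple of 32, e.g. M = 100 gives (100 + 31) // 32 * 32 = 128 rows. A minimal sketch of the view-based reinterpretation, assuming a PyTorch build that defines these dtypes (the shapes below are invented for illustration):

import torch

# MXFP4 data is byte-packed: each uint8 holds two e2m1 values, and each
# scale byte is an e8m0 shared exponent for a 1x32 block of elements.
x_packed = torch.randint(0, 256, (64, 4096), dtype=torch.uint8)  # 64 rows x 8192 fp4 values
x_scales = torch.randint(0, 256, (64, 256), dtype=torch.uint8)   # one scale per 32 values

# Tensor.view(dtype) reinterprets the same bytes without copying, so the
# kernel sees dtypes that match its input checks.
x_q = x_packed.view(torch.float4_e2m1fn_x2)
x_s = x_scales.view(torch.float8_e8m0fnu)

print(x_q.dtype, x_s.dtype)  # torch.float4_e2m1fn_x2 torch.float8_e8m0fnu

Because view changes only metadata, the reinterpretation is free at runtime; the commented-out block above the live gemm_a4w4 call appears to be leftover debugging of these dtypes.
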
