
Commit 7cf4e62

Author: wangzaijun
fix fp8 pingpong gemm
Parent: f552e4c


lightllm/common/quantization/vllm_quant.py

Lines changed: 7 additions & 2 deletions
@@ -125,7 +125,12 @@ def apply_scaled_mm_fp8(self, input_tensor, weights, bias=None, out=None, worksp
     def apply_pingpong_fp8(self, input_tensor, weights, bias=None, out=None, workspace=None):
         x_q, x_scale = ops.scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=False)
         assert bias is None
-        assert out is None
+        m = input_tensor.shape[0]
+        n = weights[0].shape[1]
+        if out is None:
+            out = g_cache_manager.alloc_tensor(
+                (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False
+            )
         from fp8_pingpong_gemm import cutlass_scaled_mm
 
-        return cutlass_scaled_mm(x_q, weights[0], x_scale, weights[1])
+        return cutlass_scaled_mm(x_q, weights[0], x_scale, weights[1], out)
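
For context, below is a minimal sketch of how the patched method reads after this commit. Only the method body comes from the diff above; the surrounding class name and the two import paths for ops and g_cache_manager are assumptions about lightllm's layout, not taken from this page.

# Sketch of apply_pingpong_fp8 after commit 7cf4e62 (vllm_quant.py).
# ASSUMED imports -- the real module resolves these elsewhere in the file:
from vllm import _custom_ops as ops  # assumed vLLM custom-ops module
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager  # assumed path


class VLLMFP8QuantizationMethod:  # placeholder name for the real class
    def apply_pingpong_fp8(self, input_tensor, weights, bias=None, out=None, workspace=None):
        # Dynamic per-tensor FP8 quantization of the activation.
        x_q, x_scale = ops.scaled_fp8_quant(
            input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=False
        )
        assert bias is None
        m = input_tensor.shape[0]  # GEMM output rows
        n = weights[0].shape[1]    # GEMM output cols (weights[0] is the FP8 weight)
        if out is None:
            # New in this commit: instead of asserting `out is None`, allocate
            # the output through lightllm's cache tensor manager so the buffer
            # is tracked and reusable.
            out = g_cache_manager.alloc_tensor(
                (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False
            )
        from fp8_pingpong_gemm import cutlass_scaled_mm

        # weights[1] is the weight scale; the kernel now writes into `out`.
        return cutlass_scaled_mm(x_q, weights[0], x_scale, weights[1], out)

The design change: the old code rejected any caller-provided output buffer (assert out is None) and returned whatever buffer the pingpong kernel produced on its own. The fix honors a caller-supplied out, otherwise draws one from g_cache_manager, and passes it to cutlass_scaled_mm as an explicit destination, presumably so the output tensor participates in lightllm's managed buffer reuse.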
