Commit ffaee7d

import vllm cutlass fp8w8a8-block gemm
1 parent: 7c05dec

File tree: 1 file changed (+37, −0)

lightllm/common/quantization/vllm_quant.py

Lines changed: 37 additions & 0 deletions
@@ -165,3 +165,40 @@ def apply_pingpong_fp8(
        from fp8_pingpong_gemm import cutlass_scaled_mm

        return cutlass_scaled_mm(x_q, weights[0], x_scale, weights[1], out)


@QUANTMETHODS.register(["vllm-fp8w8a8-b128"])
class vLLMFP8w8a8B128QuantizationMethod(vLLMBaseQuantizationMethod):
    def __init__(self):
        super().__init__()
        # group size along the K dimension used when quantizing the activations
        self.block_size = 128

    def quantize(self, weight: torch.Tensor):
        if self.is_moe:
            return self.quantize_moe(weight)
        # dynamic FP8 quantization of the weight via vLLM's scaled_fp8_quant
        qweight, weight_scale = ops.scaled_fp8_quant(
            weight.contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True
        )
        return qweight.transpose(0, 1), weight_scale

    def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True):
        qweight, weight_scale, input_scale = weights
        m, k = input_tensor.shape
        n = weights[0].shape[1]
        # one FP32 scale per (token, 128-wide group along K) for the activations
        if input_scale is None:
            input_scale = self.cache_manager.alloc_tensor(
                (m, k // self.block_size), torch.float32, device=input_tensor.device, is_graph_out=False
            )
        qinput_tensor = self.cache_manager.alloc_tensor(
            (m, k), qweight.dtype, device=qweight.device, is_graph_out=False
        )
        per_token_group_quant_fp8(input_tensor, self.block_size, qinput_tensor, input_scale)
        if out is None:
            if use_custom_tensor_mananger:
                out = self.cache_manager.alloc_tensor(
                    (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False
                )
            else:
                out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device)
        # FP8 GEMM through vLLM's CUTLASS kernel, rescaled by the input and weight scales
        torch.ops._C.cutlass_scaled_mm(out, qinput_tensor, qweight, input_scale, weight_scale, bias)
        return out
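
The new method quantizes the activations per token in groups of 128 elements along the K dimension before calling the CUTLASS FP8 GEMM. The snippet below is a minimal pure-PyTorch reference of that grouping scheme, added only for illustration: the function name per_token_group_quant_fp8_ref, the amax-based scale convention, and the e4m3 format are assumptions and may differ from the fused per_token_group_quant_fp8 kernel the commit actually calls.

# Reference sketch of per-token-group FP8 quantization (group size 128 along K).
# Assumption-heavy illustration only: the real kernel may use a different
# scale/rounding convention and writes into preallocated buffers instead of returning new ones.
import torch

def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int = 128):
    m, k = x.shape
    assert k % group_size == 0
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3
    xg = x.float().view(m, k // group_size, group_size)
    # one scale per (token, group): amax of the 128-element group mapped to the FP8 range
    scale = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    q = (xg / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q.view(m, k), scale.view(m, k // group_size).to(torch.float32)

x = torch.randn(4, 512)
q, s = per_token_group_quant_fp8_ref(x)
print(q.shape, s.shape)  # torch.Size([4, 512]) torch.Size([4, 4])

With group size 128 and k = 512, each of the 4 tokens gets 4 scales, which matches the (m, k // self.block_size) scale buffer that apply() allocates when input_scale is None.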
