
Commit ed3337d

Fix quark fp8 format loading. (ROCm#395)
* fix quark fp8 loading
* fix undefined variables

Co-authored-by: Bowen Bao <[email protected]>
1 parent: d586c39

File tree

1 file changed: +14 -11 lines changed

vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py

Lines changed: 14 additions & 11 deletions

@@ -35,23 +35,26 @@ def process_weights_after_loading(self, layer) -> None:
         # tensor scales (thus N scales being passed to the kernel),
         # requantize so we can always run per tensor
         if self.qscheme == "per_tensor":
-            max_w_scale, weight = requantize_with_max_scale(
-                weight=layer.weight,
-                weight_scale=layer.weight_scale,
-                logical_widths=layer.logical_widths,
-            )
-
             if current_platform.is_rocm():
                 weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
-                    weight=weight,
-                    weight_scale=max_w_scale,
+                    weight=layer.weight,
+                    weight_scale=layer.weight_scale,
                     input_scale=layer.input_scale)
-                if input_scale is not None:
-                    layer.input_scale = Parameter(input_scale,
-                                                  requires_grad=False)
+            else:
+                max_w_scale = layer.weight_scale
+                weight = layer.weight
+                input_scale = layer.input_scale
+
+            max_w_scale, weight = requantize_with_max_scale(
+                weight=weight,
+                weight_scale=max_w_scale,
+                logical_widths=layer.logical_widths,
+            )
 
             layer.weight = Parameter(weight.t(), requires_grad=False)
             layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
+            if input_scale is not None:
+                layer.input_scale = Parameter(input_scale, requires_grad=False)
 
         # If channelwise, scales are already lined up, so just transpose.
         elif self.qscheme == "per_channel":
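
Taken together, the patch does two things: the ROCm branch now normalizes the raw checkpoint tensors instead of already-requantized ones, and the new else branch assigns weight, max_w_scale, and input_scale explicitly so all three names are defined on non-ROCm platforms before the shared requantize-and-register tail runs. The sketch below renders the fixed per-tensor branch as a standalone function. It assumes the helpers live at their upstream vLLM paths (vllm.platforms.current_platform and vllm.model_executor.layers.quantization.utils.w8a8_utils); this ROCm fork may lay the modules out differently, and process_per_tensor_weights is an illustrative name, not a function in the patched file.

from torch.nn import Parameter

from vllm.platforms import current_platform
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)


def process_per_tensor_weights(layer) -> None:
    # Standalone rendering of the fixed per-tensor branch above.
    if current_platform.is_rocm():
        # ROCm kernels consume float8_e4m3fnuz rather than float8_e4m3fn;
        # the helper converts the weight bits and doubles the scales to
        # compensate for the two formats' differing exponent bias.
        weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale)
    else:
        # Non-ROCm path: use the loaded tensors as-is. Assigning all three
        # names here is the "fix undefined variables" part of the commit.
        max_w_scale = layer.weight_scale
        weight = layer.weight
        input_scale = layer.input_scale

    # Requantize only after any fnuz normalization, so the shared per-tensor
    # scale is computed from scales already in the kernel's target format.
    max_w_scale, weight = requantize_with_max_scale(
        weight=weight,
        weight_scale=max_w_scale,
        logical_widths=layer.logical_widths,
    )

    layer.weight = Parameter(weight.t(), requires_grad=False)
    layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
    if input_scale is not None:
        layer.input_scale = Parameter(input_scale, requires_grad=False)

Moving the input_scale registration out of the ROCm-only branch also means a checkpoint-provided input scale is preserved on every platform, not just when the e4m3fnuz conversion runs.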
