
Commit d567353

fix_vllm_quant (ROCm#342)
1 parent: 466334a


vllm/model_executor/layers/quantization/fp8.py

Lines changed: 20 additions & 0 deletions
@@ -423,6 +423,26 @@ def process_weights_after_loading(self, layer: Module) -> None:
                                                   requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                  requires_grad=False)
+
+            if envs.VLLM_MOE_SHUFFLE:
+                layer.w13_weight.data = permute_weight_fp8(layer.w13_weight.data)
+                layer.w2_weight.data = permute_weight_fp8(layer.w2_weight.data)
+
+            if envs.VLLM_MOE_PADDING:
+                pad_size = 256
+                layer.w13_weight = torch.nn.Parameter(
+                    F.pad(layer.w13_weight.data, (0, pad_size), "constant",
+                          0)[..., :-pad_size],
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    F.pad(layer.w2_weight.data, (0, pad_size), "constant",
+                          0)[..., :-pad_size],
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+
             return
 
         # If checkpoint is fp8, we need to handle that the
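
As a side note for readers of this change: the sketch below isolates the F.pad(...)[..., :-pad_size] idiom used in the VLLM_MOE_PADDING branch, in plain PyTorch and with a made-up weight shape. Padding the last dimension by pad_size and then slicing the padding back off leaves the logical shape and values untouched, but the returned view sits in storage whose row pitch is K + pad_size, which is presumably the layout the ROCm MoE GEMM kernels benefit from.

# Hypothetical standalone illustration; the shape below is invented.
import torch
import torch.nn.functional as F

pad_size = 256
w = torch.randn(8, 1024, 512)  # e.g. (num_experts, N, K); not the real model shape

# Pad the last dim on the right by pad_size, then slice the padding off again.
padded = F.pad(w, (0, pad_size), "constant", 0)[..., :-pad_size]

print(padded.shape)            # torch.Size([8, 1024, 512]): logical shape unchanged
print(padded.stride())         # (786432, 768, 1): row pitch is now K + pad_size = 768
print(torch.equal(padded, w))  # True: values are identical to the unpadded weight

Read this way, VLLM_MOE_SHUFFLE appears to reorder the fp8 expert weights via permute_weight_fp8 into the layout the fused MoE kernel expects, and VLLM_MOE_PADDING trades a small amount of extra memory for the padded stride; the torch.cuda.empty_cache() calls after each reassignment release the memory held by the replaced, unpadded parameters.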
