Commit aa63571
Applying weight padding to deepseek (ROCm#421)
1 parent 5f8d758 commit aa63571

2 files changed (+15 / -11 lines)

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 14 additions & 10 deletions
@@ -188,8 +188,6 @@ def create_weights(
         weight_loader = extra_weight_attrs.get("weight_loader")
 
         if self.block_quant:
-            assert not envs.VLLM_FP8_PADDING, (
-                "FP8 weight padding is not supported in block quantization.")
             tp_size = get_tensor_model_parallel_world_size()
             assert self.quant_config.weight_block_size is not None
             block_n, block_k = (
@@ -273,6 +271,17 @@ def create_weights(
         else:
             layer.register_parameter("input_scale", None)
 
+    def add_padding_to_weight(self, weight: torch.Tensor) -> torch.Tensor:
+        # Pad the weight tensor. This is an optimization on ROCm platform, which
+        # can benefit from tensors located far enough from one another in memory
+        if (current_platform.is_rocm() and envs.VLLM_FP8_PADDING
+                and weight.stride(-1) == 1
+                and (weight.stride(-2) * weight.element_size()) % 512 == 0):
+            num_pad = 256 // weight.element_size()
+            weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
+            torch.cuda.empty_cache()
+        return weight
+
     def process_weights_after_loading(self, layer: Module) -> None:
         # TODO(rob): refactor block quant into separate class.
         if self.block_quant:
@@ -286,6 +295,8 @@ def process_weights_after_loading(self, layer: Module) -> None:
             weight = layer.weight.data
             weight_scale_inv = layer.weight_scale_inv.data
 
+            weight = self.add_padding_to_weight(weight)
+
             # Torch.compile cannot use Parameter subclasses.
             layer.weight = Parameter(weight, requires_grad=False)
             layer.weight_scale_inv = Parameter(weight_scale_inv,
@@ -353,14 +364,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 logical_widths=layer.logical_widths,
             )
 
-            # Pad the weight
-            if envs.VLLM_FP8_PADDING and weight.stride(-1) == 1 \
-                and (weight.stride(-2) * weight.element_size()) % 512 == 0:
-                num_pad = 256 // weight.element_size()
-                weight = F.pad(weight, (0, num_pad), "constant",
-                               0)[..., :-num_pad]
-                torch.cuda.empty_cache()
-
+            weight = self.add_padding_to_weight(weight)
             # Update layer with new values.
             layer.weight = Parameter(weight.t(), requires_grad=False)
             layer.weight_scale = Parameter(weight_scale, requires_grad=False)
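Note on the padding trick (commentary, not part of the diff): add_padding_to_weight pads the last dimension by num_pad elements and then slices the padding off again, so the logical shape is unchanged while the underlying storage keeps the wider row stride, spacing rows (and neighbouring tensors) further apart in memory. The sketch below illustrates the effect on shape, stride, and contiguity; it uses a float16 CPU tensor as a stand-in for the fp8 weight and illustrative sizes, so the exact numbers are assumptions.

import torch
import torch.nn.functional as F

# Stand-in for an fp8 weight: float16 on CPU, illustrative 4096 x 4096 shape.
weight = torch.zeros(4096, 4096, dtype=torch.float16)
elem = weight.element_size()  # 2 bytes here; real fp8 weights use 1 byte

# Same eligibility check as add_padding_to_weight (minus the ROCm/env guards):
# unit stride along the last dim and a row pitch that is a multiple of 512 bytes.
assert weight.stride(-1) == 1
assert (weight.stride(-2) * elem) % 512 == 0

num_pad = 256 // elem  # 128 columns for 2-byte elements, 256 for 1-byte fp8
padded = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]

print(padded.shape)            # torch.Size([4096, 4096]) -- logical shape unchanged
print(padded.stride())         # (4224, 1) -- rows now sit 4096 + 128 elements apart
print(padded.is_contiguous())  # False -- the slice keeps the padded storage

In the commit itself the padding only triggers on ROCm with VLLM_FP8_PADDING enabled, and the torch.cuda.empty_cache() call lets the allocator release the now-unused unpadded copy of the weight.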

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 1 addition & 1 deletion
@@ -478,7 +478,7 @@ def w8a8_block_fp8_matmul(
     assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
     M = A.numel() // A.shape[-1]
 
-    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
+    assert B.ndim == 2 and Bs.ndim == 2
     N, K = B.shape
     assert triton.cdiv(N, block_n) == Bs.shape[0]
     assert triton.cdiv(K, block_k) == Bs.shape[1]
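The relaxed assert is needed because a weight padded by add_padding_to_weight is no longer contiguous: its row stride is larger than its row length, even though elements within a row stay adjacent (stride(-1) == 1), which is presumably what the stride-aware block matmul kernel actually relies on. A minimal check of that layout, again using float16 and illustrative sizes as stand-ins:

import torch
import torch.nn.functional as F

# Reproduce the layout of a padded-then-sliced weight (assumed sizes, float16 stand-in).
B = F.pad(torch.zeros(512, 512, dtype=torch.float16), (0, 128))[..., :-128]

print(B.is_contiguous())  # False: row stride is 640 elements for a 512-wide row
print(B.stride(-1) == 1)  # True: elements within each row remain adjacent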
