@@ -188,8 +188,6 @@ def create_weights(
         weight_loader = extra_weight_attrs.get("weight_loader")
 
         if self.block_quant:
-            assert not envs.VLLM_FP8_PADDING, (
-                "FP8 weight padding is not supported in block quantization.")
             tp_size = get_tensor_model_parallel_world_size()
             assert self.quant_config.weight_block_size is not None
             block_n, block_k = (
@@ -273,6 +271,17 @@ def create_weights(
         else:
             layer.register_parameter("input_scale", None)
 
+    def add_padding_to_weight(self, weight: torch.Tensor) -> torch.Tensor:
+        # Pad the weight tensor. This is an optimization on ROCm platform, which
+        # can benefit from tensors located far enough from one another in memory
+        if (current_platform.is_rocm() and envs.VLLM_FP8_PADDING
+                and weight.stride(-1) == 1
+                and (weight.stride(-2) * weight.element_size()) % 512 == 0):
+            num_pad = 256 // weight.element_size()
+            weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
+            torch.cuda.empty_cache()
+        return weight
+
     def process_weights_after_loading(self, layer: Module) -> None:
         # TODO(rob): refactor block quant into separate class.
         if self.block_quant:
@@ -286,6 +295,8 @@ def process_weights_after_loading(self, layer: Module) -> None:
             weight = layer.weight.data
             weight_scale_inv = layer.weight_scale_inv.data
 
+            weight = self.add_padding_to_weight(weight)
+
             # Torch.compile cannot use Parameter subclasses.
             layer.weight = Parameter(weight, requires_grad=False)
             layer.weight_scale_inv = Parameter(weight_scale_inv,
@@ -353,14 +364,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 logical_widths=layer.logical_widths,
             )
 
-            # Pad the weight
-            if envs.VLLM_FP8_PADDING and weight.stride(-1) == 1 \
-                    and (weight.stride(-2) * weight.element_size()) % 512 == 0:
-                num_pad = 256 // weight.element_size()
-                weight = F.pad(weight, (0, num_pad), "constant",
-                               0)[..., :-num_pad]
-                torch.cuda.empty_cache()
-
+            weight = self.add_padding_to_weight(weight)
             # Update layer with new values.
             layer.weight = Parameter(weight.t(), requires_grad=False)
             layer.weight_scale = Parameter(weight_scale, requires_grad=False)
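
For context, the new `add_padding_to_weight` helper keeps the logical weight shape unchanged and only widens the row stride: it pads the last dimension with zeros and immediately slices the padding back off, so consecutive rows land farther apart in memory. Below is a minimal, self-contained sketch of that trick, runnable with plain PyTorch; the helper name `pad_weight_stride`, the float16 test tensor, and the omission of the ROCm/`VLLM_FP8_PADDING` gating are illustrative assumptions, not vLLM code.

```python
# Minimal sketch of the stride-padding trick from this diff (illustrative,
# not vLLM code). Only the stride/alignment check is kept; the ROCm and
# VLLM_FP8_PADDING gating from the real helper is omitted.
import torch
import torch.nn.functional as F


def pad_weight_stride(weight: torch.Tensor) -> torch.Tensor:
    # Only touch tensors with contiguous rows whose byte stride between rows
    # is a multiple of 512.
    if (weight.stride(-1) == 1
            and (weight.stride(-2) * weight.element_size()) % 512 == 0):
        num_pad = 256 // weight.element_size()
        # Pad the last dim with zeros, then slice the padding off again:
        # the logical shape is unchanged, but stride(-2) grows by num_pad,
        # spacing consecutive rows further apart in memory.
        weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
    return weight


w = torch.zeros(128, 512, dtype=torch.float16)
padded = pad_weight_stride(w)
print(w.shape, padded.shape)        # both torch.Size([128, 512])
print(w.stride(), padded.stride())  # (512, 1) -> (640, 1)
```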