@@ -309,6 +309,12 @@ def apply(self, module: Linear, input: torch.Tensor,
               bias: Optional[torch.Tensor], *args, **kwargs):
         raise NotImplementedError
 
+    @abstractmethod
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        raise NotImplementedError
+
     def load_weights(self,
                      module: Linear,
                      weights: List[Dict],
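For context on the contract this new abstract hook establishes: compute the local partial GEMM against this rank's weight shard, then sum the partial outputs across the tensor-parallel group. The sketch below is illustrative only and is not part of the diff (the PR leaves the unquantized path unimplemented); it assumes torch.distributed is already initialized, and a real implementation would cache the process group instead of rebuilding it per call.

from typing import List, Optional

import torch
import torch.distributed as dist
import torch.nn.functional as F


def reference_apply_linear_allreduce(module, input: torch.Tensor,
                                     bias: Optional[torch.Tensor], tp_rank: int,
                                     tp_group: List[int]) -> torch.Tensor:
    # Partial GEMM on this rank's shard of the weight. Callers are expected to pass
    # bias only on tp_rank 0, so it contributes exactly once to the reduced result.
    output = F.linear(input, module.weight, bias)
    # Sum the partial outputs across the tensor-parallel ranks.
    group = dist.new_group(ranks=tp_group)  # illustrative; cache this in practice
    dist.all_reduce(output, op=dist.ReduceOp.SUM, group=group)
    return output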
@@ -395,6 +401,11 @@ def apply(self, module: Linear, input: torch.Tensor,
         output = F.linear(input, module.weight, bias)
         return output
 
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        raise NotImplementedError
+
     def load_weights_vanilla(self,
                              module: Linear,
                              weights: List[Dict],
@@ -511,6 +522,11 @@ def apply(self, module: Linear, input: torch.Tensor,
             output = output + bias
         return output
 
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        raise NotImplementedError
+
     def load_kv_scales(self, weights: List[Dict]):
         k_scale, v_scale = [], []
         for w in weights:
@@ -655,6 +671,11 @@ def apply(self, module: Linear, input: torch.Tensor,
             output = output + bias
         return output
 
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        raise NotImplementedError
+
     def _get_scale_name(self, weights: List[Dict]):
         # `weight_scale_inv` for DS recipe and `weight_scale` for ModelOpt recipe.
         # Actually they hold identical values of data_amax / 448.
@@ -769,6 +790,11 @@ def apply(self, module: Linear, input: torch.Tensor,
             output = output + bias
         return output
 
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        raise NotImplementedError
+
     def _get_scale_name(self, weights: List[Dict]):
         # `weight_scale_inv` for DS recipe and `weight_scale` for ModelOpt recipe.
         # Actually they hold identical values of data_amax / 448.
@@ -950,6 +976,28 @@ def apply(self, module: Linear, input: torch.Tensor,
             output = output + bias
         return output
 
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        if isinstance(input, Fp4QuantizedTensor):
+            act_fp4, act_sf = input.fp4_tensor, input.scaling_factor
+        elif isinstance(input, tuple):
+            act_fp4, act_sf = input
+        else:
+            act_fp4, act_sf = torch.ops.trtllm.fp4_quantize(
+                input, module.input_scale, module.scaling_vector_size, False)
+
+        output = torch.ops.trtllm.nvfp4_gemm_allreduce(
+            act_fp4, module.weight, act_sf, module.weight_scale, module.alpha,
+            module.dtype, tp_rank, tp_group)
+        # Take the dim of out_features if padded. Make sure the output is contiguous.
+        if output.shape[-1] > module.out_features:
+            output = output[..., :module.out_features].contiguous()
+
+        if bias is not None:
+            output = output + bias
+        return output
+
     def load_kv_scales(self, weights: List[Dict]):
         k_scale, v_scale = [], []
         for w in weights:
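The slice at the end of the new NVFP4 method exists because the fused GEMM may return an output whose last dimension is padded past module.out_features. The snippet below only demonstrates that trimming step in isolation; the concrete sizes and the alignment value are assumptions, not values taken from the kernel.

import torch

# Hypothetical sizes: a layer with out_features=3000 whose fused GEMM output came
# back padded to 3072 columns (assumed alignment).
out_features = 3000
output = torch.randn(8, 3072)

# Trim the padded columns and make the result contiguous, as in the diff above.
if output.shape[-1] > out_features:
    output = output[..., :out_features].contiguous()

assert output.shape == (8, 3000) and output.is_contiguous()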
@@ -1189,6 +1237,11 @@ def apply(self, module: Linear, input: torch.Tensor,
             output = output + bias
         return output
 
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        raise NotImplementedError
+
     def load_weight_scales(
         self,
         weights: List[Dict],
@@ -1357,6 +1410,11 @@ def apply(self, module: Linear, input: torch.Tensor,
             output = output + bias
         return output
 
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        raise NotImplementedError
+
     def load_weight_scales(self,
                            weights: List[Dict],
                            tp_size: int = 1,
@@ -2016,6 +2079,7 @@ def __init__(
         use_cublaslt_nvfp4_blockscaling_mm: bool = False,
         disable_deep_gemm: bool = False,
         fused_weight_shard_indices_mapping: Optional[dict] = None,
+        use_fused_gemm_allreduce: bool = False,
     ):
         from ..distributed import AllReduce
 
@@ -2065,6 +2129,8 @@ def __init__(
         self.reduce_output = reduce_output
         self.use_custom_cublas_mm = use_custom_cublas_mm
         self.lora = lora
+        self.use_fused_gemm_allreduce = use_fused_gemm_allreduce and self.quant_config.layer_quant_mode.has_nvfp4(
+        )
 
         self.enable_cuda_core = False
         if torch.cuda.is_available():
@@ -2164,6 +2230,20 @@ def apply_linear(self,
                 output = output + lora_result
         return output
 
+    def apply_linear_allreduce(self,
+                               input,
+                               bias,
+                               lora_params: Optional[dict] = None,
+                               layer_idx: Optional[int] = None):
+        output = self.quant_method.apply_linear_allreduce(
+            self, input, bias, self.tp_rank, self.mapping.tp_group)
+
+        if self.lora is not None and bool(lora_params):
+            lora_result = self.lora(input, lora_params, layer_idx)
+            if lora_result is not None:
+                output = output + lora_result
+        return output
+
     def _maybe_fuse_bias_into_allreduce(
         self,
         bias: Optional[torch.Tensor],
@@ -2190,16 +2270,23 @@ def forward(
         layer_idx: Optional[int] = None,
     ) -> torch.Tensor:
         if self.tp_mode == TensorParallelMode.ROW:
+            use_fused_gemm_allreduce = self.use_fused_gemm_allreduce and (
+                all_reduce_params is None or
+                (all_reduce_params.enable_allreduce
+                 and all_reduce_params.fusion_op == AllReduceFusionOp.NONE))
             bias = None if (self.tp_rank > 0) else self.bias
             if self.reduce_output:
-                fuse_bias = self._maybe_fuse_bias_into_allreduce(
-                    bias, all_reduce_params)
-                bias = None if fuse_bias else bias
-                output = self.apply_linear(input, bias, lora_params, layer_idx)
-                output = self.all_reduce(
-                    output,
-                    all_reduce_params=all_reduce_params,
-                )
+                if use_fused_gemm_allreduce:
+                    output = self.apply_linear_allreduce(
+                        input, bias, lora_params, layer_idx)
+                else:
+                    fuse_bias = self._maybe_fuse_bias_into_allreduce(
+                        bias, all_reduce_params)
+                    bias = None if fuse_bias else bias
+                    output = self.apply_linear(input, bias, lora_params,
+                                               layer_idx)
+                    output = self.all_reduce(
+                        output, all_reduce_params=all_reduce_params)
             else:
                 output = self.apply_linear(input, bias, lora_params, layer_idx)
         elif self.tp_mode == TensorParallelMode.COLUMN:
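To make the new gating condition in forward() easy to check in isolation, here is a hedged, self-contained sketch. AllReduceFusionOp and AllReduceParams below are minimal stand-ins that mirror only the fields the predicate reads; they are not the real classes from the library.

from dataclasses import dataclass
from enum import Enum
from typing import Optional


class AllReduceFusionOp(Enum):
    NONE = 0
    RESIDUAL_RMS_NORM = 1


@dataclass
class AllReduceParams:
    enable_allreduce: bool = True
    fusion_op: AllReduceFusionOp = AllReduceFusionOp.NONE


def should_use_fused_gemm_allreduce(layer_enabled: bool,
                                    params: Optional[AllReduceParams]) -> bool:
    # Fused GEMM+all-reduce is taken only when the layer opted in (NVFP4 weights) and
    # the caller either passed no AllReduceParams or asked for a plain, unfused
    # all-reduce; any requested fusion falls back to GEMM followed by AllReduce.
    return layer_enabled and (params is None or
                              (params.enable_allreduce
                               and params.fusion_op == AllReduceFusionOp.NONE))


assert should_use_fused_gemm_allreduce(True, None)
assert not should_use_fused_gemm_allreduce(
    True, AllReduceParams(fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM))
assert not should_use_fused_gemm_allreduce(False, None)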