@@ -307,6 +307,12 @@ def apply(self, module: Linear, input: torch.Tensor,
               bias: Optional[torch.Tensor], *args, **kwargs):
         raise NotImplementedError
 
+    @abstractmethod
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
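+        """Apply the linear op with the tensor-parallel all-reduce fused in.
+
+        Implementations that provide a fused GEMM+all-reduce kernel override
+        this; all other quantization methods raise NotImplementedError.
+        """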
+        raise NotImplementedError
+
     def load_weights(self,
                      module: Linear,
                      weights: List[Dict],
@@ -393,6 +399,11 @@ def apply(self, module: Linear, input: torch.Tensor,
         output = F.linear(input, module.weight, bias)
         return output
 
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        raise NotImplementedError
+
     def load_weights_vanilla(self,
                              module: Linear,
                              weights: List[Dict],
@@ -509,6 +520,11 @@ def apply(self, module: Linear, input: torch.Tensor,
             output = output + bias
         return output
 
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        raise NotImplementedError
+
     def load_kv_scales(self, weights: List[Dict]):
         k_scale, v_scale = [], []
         for w in weights:
@@ -653,6 +669,11 @@ def apply(self, module: Linear, input: torch.Tensor,
             output = output + bias
         return output
 
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        raise NotImplementedError
+
     def _get_scale_name(self, weights: List[Dict]):
         # `weight_scale_inv` for DS recipe and `weight_scale` for ModelOpt recipe.
         # Actually they hold identical values of data_amax / 448.
@@ -767,6 +788,11 @@ def apply(self, module: Linear, input: torch.Tensor,
             output = output + bias
         return output
 
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        raise NotImplementedError
+
     def _get_scale_name(self, weights: List[Dict]):
         # `weight_scale_inv` for DS recipe and `weight_scale` for ModelOpt recipe.
         # Actually they hold identical values of data_amax / 448.
@@ -953,6 +979,28 @@ def apply(self, module: Linear, input: torch.Tensor,
             output = output + bias
         return output
 
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
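+        # The activation may already be NVFP4-quantized, either as an
+        # Fp4QuantizedTensor or as an (fp4_tensor, scale_factor) tuple;
+        # otherwise quantize it here.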
+        if isinstance(input, Fp4QuantizedTensor):
+            act_fp4, act_sf = input.fp4_tensor, input.scaling_factor
+        elif isinstance(input, tuple):
+            act_fp4, act_sf = input
+        else:
+            act_fp4, act_sf = torch.ops.trtllm.fp4_quantize(
+                input, module.input_scale, module.scaling_vector_size, False)
+
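+        # The kernel fuses the NVFP4 GEMM with the all-reduce over tp_group,
+        # so no separate all_reduce call is needed on the returned output.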
+        output = torch.ops.trtllm.nvfp4_gemm_allreduce(
+            act_fp4, module.weight, act_sf, module.weight_scale, module.alpha,
+            module.dtype, tp_rank, tp_group)
+        # Slice off any out_features padding and keep the output contiguous.
+        if output.shape[-1] > module.out_features:
+            output = output[..., :module.out_features].contiguous()
+
+        if bias is not None:
+            output = output + bias
+        return output
+
     def load_kv_scales(self, weights: List[Dict]):
         k_scale, v_scale = [], []
         for w in weights:
@@ -1229,6 +1277,11 @@ def apply(self, module: Linear, input: torch.Tensor,
             output = output + bias
         return output
 
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        raise NotImplementedError
+
     def load_weight_scales(
         self,
         weights: List[Dict],
@@ -1397,6 +1450,11 @@ def apply(self, module: Linear, input: torch.Tensor,
             output = output + bias
         return output
 
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        raise NotImplementedError
+
     def load_weight_scales(self,
                            weights: List[Dict],
                            tp_size: int = 1,
@@ -2055,6 +2118,7 @@ def __init__(
         disable_deep_gemm: bool = False,
         fused_weight_shard_indices_mapping: Optional[dict] = None,
         nvfp4_backend: str = "auto",
+        use_fused_gemm_allreduce: bool = False,
     ):
         """
         Args:
@@ -2124,6 +2188,8 @@ def __init__(
         self.reduce_output = reduce_output
         self.use_custom_cublas_mm = use_custom_cublas_mm
         self.lora = lora
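+        # The fused GEMM+all-reduce path is only implemented for NVFP4 weights;
+        # every other quantization method falls back to a plain GEMM followed
+        # by a separate all_reduce.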
+        self.use_fused_gemm_allreduce = (
+            use_fused_gemm_allreduce
+            and self.quant_config.layer_quant_mode.has_nvfp4())
 
         self.enable_cuda_core = False
         if torch.cuda.is_available():
@@ -2223,6 +2289,20 @@ def apply_linear(self,
                 output = output + lora_result
         return output
 
+    def apply_linear_allreduce(self,
+                               input,
+                               bias,
+                               lora_params: Optional[dict] = None,
+                               layer_idx: Optional[int] = None):
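+        """Run the GEMM with the tensor-parallel all-reduce fused into it.
+
+        Mirrors apply_linear(), but delegates to the quant method's fused
+        GEMM+all-reduce kernel; the optional LoRA contribution is added to the
+        already-reduced output.
+        """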
+        output = self.quant_method.apply_linear_allreduce(
+            self, input, bias, self.tp_rank, self.mapping.tp_group)
+
+        if self.lora is not None and bool(lora_params):
+            lora_result = self.lora(input, lora_params, layer_idx)
+            if lora_result is not None:
+                output = output + lora_result
+        return output
+
     def _maybe_fuse_bias_into_allreduce(
         self,
         bias: Optional[torch.Tensor],
@@ -2249,16 +2329,23 @@ def forward(
         layer_idx: Optional[int] = None,
     ) -> torch.Tensor:
         if self.tp_mode == TensorParallelMode.ROW:
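+            # Take the fused GEMM+all-reduce path only when the caller did not
+            # request an all-reduce fusion pattern of its own
+            # (fusion_op != AllReduceFusionOp.NONE); fused patterns still go
+            # through the separate all_reduce call below.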
+            use_fused_gemm_allreduce = self.use_fused_gemm_allreduce and (
+                all_reduce_params is None
+                or (all_reduce_params.enable_allreduce
+                    and all_reduce_params.fusion_op == AllReduceFusionOp.NONE))
             bias = None if (self.tp_rank > 0) else self.bias
             if self.reduce_output:
-                fuse_bias = self._maybe_fuse_bias_into_allreduce(
-                    bias, all_reduce_params)
-                bias = None if fuse_bias else bias
-                output = self.apply_linear(input, bias, lora_params, layer_idx)
-                output = self.all_reduce(
-                    output,
-                    all_reduce_params=all_reduce_params,
-                )
+                if use_fused_gemm_allreduce:
+                    output = self.apply_linear_allreduce(
+                        input, bias, lora_params, layer_idx)
+                else:
+                    fuse_bias = self._maybe_fuse_bias_into_allreduce(
+                        bias, all_reduce_params)
+                    bias = None if fuse_bias else bias
+                    output = self.apply_linear(input, bias, lora_params,
+                                               layer_idx)
+                    output = self.all_reduce(
+                        output, all_reduce_params=all_reduce_params)
             else:
                 output = self.apply_linear(input, bias, lora_params, layer_idx)
         elif self.tp_mode == TensorParallelMode.COLUMN: