
Commit 80485c7 (1 parent: bc25fff)

add fp4 gemm + allreduce

Signed-off-by: benzh <[email protected]>

File tree

4 files changed: +340 −10 lines

cpp/tensorrt_llm/thop/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
@@ -104,7 +104,8 @@ add_library(
   loraOp.cpp
   finegrained_mixed_dtype_gemm_thop.cpp
   tinygemm2.cpp
-  dsv3RopeOp.cpp)
+  dsv3RopeOp.cpp
+  fusedGemmAllreduceOp.cpp)
 set_property(TARGET th_common PROPERTY POSITION_INDEPENDENT_CODE ON)
 target_link_libraries(
   th_common PRIVATE ${TORCH_LIBRARIES} th_utils ${Python3_LIBRARIES}
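
Note: compiling fusedGemmAllreduceOp.cpp into th_common is what makes the
torch.classes.trtllm.Fp4GemmAllreduceRunner binding used below resolvable at
runtime. A minimal smoke check, sketched under the assumption that importing a
built tensorrt_llm package loads the th_common extension:

    import torch
    import tensorrt_llm  # assumption: importing the package loads th_common

    # Resolve the custom class registered by fusedGemmAllreduceOp.cpp;
    # torch raises a RuntimeError if the class was not compiled in.
    runner_cls = torch.classes.trtllm.Fp4GemmAllreduceRunner
    print(runner_cls)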

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 93 additions & 0 deletions
@@ -1247,3 +1247,96 @@ def record_stream(tensor: torch.Tensor, stream_id: int) -> None:
     stream = get_stream(stream_id)
     assert stream is not None
     tensor.record_stream(stream)
+
+
+class Fp4GemmAllreduceRunner(TunableRunner):
+    runner_dict = dict()
+    tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec(
+        0, 0, get_last_power_of_2_num_tokens_buckets,
+        last_positive_power_of_2), ),
+                                 constraint_specs=(ConstraintSpec(
+                                     2, 0, fp4_scale_infer_shape), ))
+
+    def __init__(
+        self,
+        output_dtype: torch.dtype,
+        tp_rank: int,
+        tp_group: List[int],
+    ):
+        self.output_dtype = output_dtype
+        self.tp_rank = tp_rank
+        self.tp_group_str = '-'.join(str(g) for g in tp_group)
+        instance_key = (output_dtype, self.tp_group_str)
+        if instance_key not in Fp4GemmAllreduceRunner.runner_dict:
+            Fp4GemmAllreduceRunner.runner_dict[
+                instance_key] = torch.classes.trtllm.Fp4GemmAllreduceRunner(
+                    output_dtype, tp_rank, tp_group)
+        self.fp4_gemm_all_reduce_runner = Fp4GemmAllreduceRunner.runner_dict[
+            instance_key]
+
+    def unique_id(self):
+        return (self.output_dtype, self.tp_group_str)
+
+    def get_valid_tactics(self, inputs: List[torch.Tensor],
+                          profile: OptimizationProfile, **kwargs) -> List[int]:
+        return list(range(self.fp4_gemm_all_reduce_runner.get_num_configs()))
+
+    def forward(
+        self,
+        inputs: List[torch.Tensor],
+        tactic: int = 0,
+    ) -> torch.Tensor:
+        mat1, mat2, mat1_scale, mat2_scale, global_scale = inputs
+        return self.fp4_gemm_all_reduce_runner.run_gemm(
+            mat1,
+            mat2,
+            mat1_scale,
+            mat2_scale,
+            global_scale,
+            tactic,
+        )
+
+
+@torch.library.custom_op("trtllm::nvfp4_gemm_allreduce", mutates_args=())
+def nvfp4_gemm_allreduce(
+    act_fp4: torch.Tensor,
+    weight: torch.Tensor,
+    act_sf: torch.Tensor,
+    weight_scale: torch.Tensor,
+    alpha: torch.Tensor,
+    output_dtype: torch.dtype,
+    tp_rank: int,
+    tp_group: List[int],
+) -> torch.Tensor:
+    tuner = AutoTuner.get()
+
+    # Use Cutlass runner with predefined configs
+    nvfp4_gemm_allreduce_runner = Fp4GemmAllreduceRunner(
+        output_dtype, tp_rank, tp_group)
+
+    runner_type = type(nvfp4_gemm_allreduce_runner).__name__
+    _, best_tactic = tuner.choose_one(
+        f"trtllm::nvfp4_gemm_allreduce::{runner_type}",
+        [nvfp4_gemm_allreduce_runner],
+        nvfp4_gemm_allreduce_runner.tuning_config,
+        [act_fp4, weight, act_sf, weight_scale, alpha],
+    )
+
+    return nvfp4_gemm_allreduce_runner(
+        inputs=[act_fp4, weight, act_sf, weight_scale, alpha],
+        tactic=best_tactic)
+
+
+@nvfp4_gemm_allreduce.register_fake
+def _(
+    act_fp4: torch.Tensor,
+    weight: torch.Tensor,
+    act_sf: torch.Tensor,
+    weight_scale: torch.Tensor,
+    alpha: torch.Tensor,
+    output_dtype: torch.dtype,
+    tp_rank: int,
+    tp_group: List[int],
+) -> torch.Tensor:
+    return act_fp4.new_empty((act_fp4.size(0), weight.size(0)),
+                             dtype=output_dtype)
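
Note: the new op can be exercised directly once the extension is available. A
minimal sketch, assuming a single-rank tp_group (multi-rank groups need the TP
runtime initialized) and assumed weight-side shapes/dtypes; the real layouts
come from trtllm::fp4_quantize and checkpoint conversion:

    import torch

    # Hypothetical sizes; k must be divisible by the NVFP4 scaling vector size (16).
    m, n, k = 128, 256, 512
    x = torch.randn(m, k, dtype=torch.bfloat16, device='cuda')
    input_scale = torch.tensor(1.0, device='cuda')

    # Quantize the activation to packed FP4 plus block scaling factors,
    # mirroring the call used by apply_linear_allreduce in linear.py below.
    act_fp4, act_sf = torch.ops.trtllm.fp4_quantize(x, input_scale, 16, False)

    # Assumed layouts: two FP4 values per uint8 byte, one e4m3 scale per
    # 16-element block. Real tensors come from checkpoint conversion.
    weight = torch.empty(n, k // 2, dtype=torch.uint8, device='cuda')
    weight_scale = torch.empty(n, k // 16, dtype=torch.float8_e4m3fn, device='cuda')
    alpha = torch.tensor(1.0, device='cuda')  # global scale

    out = torch.ops.trtllm.nvfp4_gemm_allreduce(
        act_fp4, weight, act_sf, weight_scale, alpha,
        torch.bfloat16,  # output_dtype
        0,               # tp_rank
        [0])             # tp_group: the GEMM output is allreduced across these ranks
    # out has shape (m, n), possibly padded on the last dim (see linear.py below).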

tensorrt_llm/_torch/modules/linear.py

Lines changed: 90 additions & 8 deletions
@@ -309,6 +309,12 @@ def apply(self, module: Linear, input: torch.Tensor,
               bias: Optional[torch.Tensor], *args, **kwargs):
         raise NotImplementedError
 
+    @abstractmethod
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        raise NotImplementedError
+
     def load_weights(self,
                      module: Linear,
                      weights: List[Dict],
@@ -395,6 +401,11 @@ def apply(self, module: Linear, input: torch.Tensor,
         output = F.linear(input, module.weight, bias)
         return output
 
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        raise NotImplementedError
+
     def load_weights_vanilla(self,
                              module: Linear,
                              weights: List[Dict],
@@ -511,6 +522,11 @@ def apply(self, module: Linear, input: torch.Tensor,
             output = output + bias
         return output
 
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        raise NotImplementedError
+
     def load_kv_scales(self, weights: List[Dict]):
         k_scale, v_scale = [], []
         for w in weights:
@@ -655,6 +671,11 @@ def apply(self, module: Linear, input: torch.Tensor,
             output = output + bias
         return output
 
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        raise NotImplementedError
+
     def _get_scale_name(self, weights: List[Dict]):
         # `weight_scale_inv` for DS recipe and `weight_scale` for ModelOpt recipe.
         # Actually they hold identical values of data_amax / 448.
@@ -769,6 +790,11 @@ def apply(self, module: Linear, input: torch.Tensor,
             output = output + bias
         return output
 
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        raise NotImplementedError
+
     def _get_scale_name(self, weights: List[Dict]):
         # `weight_scale_inv` for DS recipe and `weight_scale` for ModelOpt recipe.
         # Actually they hold identical values of data_amax / 448.
@@ -950,6 +976,28 @@ def apply(self, module: Linear, input: torch.Tensor,
             output = output + bias
         return output
 
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        if isinstance(input, Fp4QuantizedTensor):
+            act_fp4, act_sf = input.fp4_tensor, input.scaling_factor
+        elif isinstance(input, tuple):
+            act_fp4, act_sf = input
+        else:
+            act_fp4, act_sf = torch.ops.trtllm.fp4_quantize(
+                input, module.input_scale, module.scaling_vector_size, False)
+
+        output = torch.ops.trtllm.nvfp4_gemm_allreduce(
+            act_fp4, module.weight, act_sf, module.weight_scale, module.alpha,
+            module.dtype, tp_rank, tp_group)
+        # Slice off any padding beyond out_features and make the output contiguous.
+        if output.shape[-1] > module.out_features:
+            output = output[..., :module.out_features].contiguous()
+
+        if bias is not None:
+            output = output + bias
+        return output
+
     def load_kv_scales(self, weights: List[Dict]):
         k_scale, v_scale = [], []
         for w in weights:
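
Note: the NVFP4 implementation above accepts the activation in three forms: a
plain tensor (quantized on the fly), a raw (fp4_tensor, scaling_factor) tuple,
or an Fp4QuantizedTensor wrapper. A sketch of the first two, with method and
module as hypothetical stand-ins for the NVFP4 quant method and its Linear:

    # 1) Unquantized input: trtllm::fp4_quantize runs inside the call.
    out = method.apply_linear_allreduce(module, x, None, tp_rank=0, tp_group=[0])

    # 2) Pre-quantized pair: skips the quantization step, e.g. when the
    #    producer layer already emitted FP4 output.
    act_fp4, act_sf = torch.ops.trtllm.fp4_quantize(
        x, module.input_scale, module.scaling_vector_size, False)
    out = method.apply_linear_allreduce(module, (act_fp4, act_sf), None,
                                        tp_rank=0, tp_group=[0])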
@@ -1189,6 +1237,11 @@ def apply(self, module: Linear, input: torch.Tensor,
             output = output + bias
         return output
 
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        raise NotImplementedError
+
     def load_weight_scales(
         self,
         weights: List[Dict],
@@ -1357,6 +1410,11 @@ def apply(self, module: Linear, input: torch.Tensor,
             output = output + bias
         return output
 
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        raise NotImplementedError
+
     def load_weight_scales(self,
                            weights: List[Dict],
                            tp_size: int = 1,
@@ -2016,6 +2074,7 @@ def __init__(
         use_cublaslt_nvfp4_blockscaling_mm: bool = False,
         disable_deep_gemm: bool = False,
         fused_weight_shard_indices_mapping: Optional[dict] = None,
+        use_fused_gemm_allreduce: bool = False,
     ):
         from ..distributed import AllReduce
 
@@ -2065,6 +2124,8 @@ def __init__(
         self.reduce_output = reduce_output
         self.use_custom_cublas_mm = use_custom_cublas_mm
         self.lora = lora
+        self.use_fused_gemm_allreduce = use_fused_gemm_allreduce and self.quant_config.layer_quant_mode.has_nvfp4(
+        )
 
         self.enable_cuda_core = False
         if torch.cuda.is_available():
@@ -2164,6 +2225,20 @@ def apply_linear(self,
             output = output + lora_result
         return output
 
+    def apply_linear_allreduce(self,
+                               input,
+                               bias,
+                               lora_params: Optional[dict] = None,
+                               layer_idx: Optional[int] = None):
+        output = self.quant_method.apply_linear_allreduce(
+            self, input, bias, self.tp_rank, self.mapping.tp_group)
+
+        if self.lora is not None and bool(lora_params):
+            lora_result = self.lora(input, lora_params, layer_idx)
+            if lora_result is not None:
+                output = output + lora_result
+        return output
+
     def _maybe_fuse_bias_into_allreduce(
         self,
         bias: Optional[torch.Tensor],
@@ -2190,16 +2265,23 @@ def forward(
         layer_idx: Optional[int] = None,
     ) -> torch.Tensor:
         if self.tp_mode == TensorParallelMode.ROW:
+            use_fused_gemm_allreduce = self.use_fused_gemm_allreduce and (
+                all_reduce_params is None or
+                (all_reduce_params.enable_allreduce
+                 and all_reduce_params.fusion_op == AllReduceFusionOp.NONE))
             bias = None if (self.tp_rank > 0) else self.bias
             if self.reduce_output:
-                fuse_bias = self._maybe_fuse_bias_into_allreduce(
-                    bias, all_reduce_params)
-                bias = None if fuse_bias else bias
-                output = self.apply_linear(input, bias, lora_params, layer_idx)
-                output = self.all_reduce(
-                    output,
-                    all_reduce_params=all_reduce_params,
-                )
+                if use_fused_gemm_allreduce:
+                    output = self.apply_linear_allreduce(
+                        input, bias, lora_params, layer_idx)
+                else:
+                    fuse_bias = self._maybe_fuse_bias_into_allreduce(
+                        bias, all_reduce_params)
+                    bias = None if fuse_bias else bias
+                    output = self.apply_linear(input, bias, lora_params,
+                                               layer_idx)
+                    output = self.all_reduce(
+                        output, all_reduce_params=all_reduce_params)
             else:
                 output = self.apply_linear(input, bias, lora_params, layer_idx)
         elif self.tp_mode == TensorParallelMode.COLUMN:
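
Note: end to end, the fused path is opted into per layer at construction time
and only engages for ROW-parallel layers with NVFP4 quantization when the
caller requests a plain allreduce (or none at all). A minimal sketch; every
constructor argument except use_fused_gemm_allreduce is outside this diff and
assumed:

    linear = Linear(
        in_features=4096,
        out_features=4096,
        tensor_parallel_mode=TensorParallelMode.ROW,
        quant_config=quant_config,  # must satisfy layer_quant_mode.has_nvfp4()
        use_fused_gemm_allreduce=True,
    )

    # forward() now routes to apply_linear_allreduce(), replacing the separate
    # apply_linear() + all_reduce() pair with one fused GEMM+allreduce kernel
    # whenever all_reduce_params is None or requests AllReduceFusionOp.NONE.
    y = linear(x)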
