
Commit cb163d0

add fp4 gemm + allreduce

Signed-off-by: benzh <[email protected]>
1 parent e06c582

File tree

4 files changed, +345 -10 lines changed


cpp/tensorrt_llm/thop/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
@@ -104,7 +104,8 @@ add_library(
   loraOp.cpp
   finegrained_mixed_dtype_gemm_thop.cpp
   tinygemm2.cpp
-  dsv3RopeOp.cpp)
+  dsv3RopeOp.cpp
+  fusedGemmAllreduceOp.cpp)
 set_property(TARGET th_common PROPERTY POSITION_INDEPENDENT_CODE ON)
 target_link_libraries(
   th_common PRIVATE ${TORCH_LIBRARIES} th_utils ${Python3_LIBRARIES}

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 93 additions & 0 deletions
@@ -1614,3 +1614,96 @@ def record_stream(tensor: torch.Tensor, stream_id: int) -> None:
     stream = get_stream(stream_id)
     assert stream is not None
     tensor.record_stream(stream)
+
+
+class Fp4GemmAllreduceRunner(TunableRunner):
+    runner_dict = dict()
+    tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec(
+        0, 0, get_last_power_of_2_num_tokens_buckets,
+        last_positive_power_of_2), ),
+                                 constraint_specs=(ConstraintSpec(
+                                     2, 0, fp4_scale_infer_shape), ))
+
+    def __init__(
+        self,
+        output_dtype: torch.dtype,
+        tp_rank: int,
+        tp_group: List[int],
+    ):
+        self.output_dtype = output_dtype
+        self.tp_rank = tp_rank
+        self.tp_group_str = '-'.join(str(g) for g in tp_group)
+        instance_key = (output_dtype, self.tp_group_str)
+        if instance_key not in Fp4GemmAllreduceRunner.runner_dict:
+            Fp4GemmAllreduceRunner.runner_dict[
+                instance_key] = torch.classes.trtllm.Fp4GemmAllreduceRunner(
+                    output_dtype, tp_rank, tp_group)
+        self.fp4_gemm_all_reduce_runner = Fp4GemmAllreduceRunner.runner_dict[
+            instance_key]
+
+    def unique_id(self):
+        return (self.output_dtype, self.tp_group_str)
+
+    def get_valid_tactics(self, inputs: List[torch.Tensor],
+                          profile: OptimizationProfile, **kwargs) -> List[int]:
+        return list(range(self.fp4_gemm_all_reduce_runner.get_num_configs()))
+
+    def forward(
+        self,
+        inputs: List[torch.Tensor],
+        tactic: int = 0,
+    ) -> torch.Tensor:
+        mat1, mat2, mat1_scale, mat2_scale, global_scale = inputs
+        return self.fp4_gemm_all_reduce_runner.run_gemm(
+            mat1,
+            mat2,
+            mat1_scale,
+            mat2_scale,
+            global_scale,
+            tactic,
+        )
+
+
+@torch.library.custom_op("trtllm::nvfp4_gemm_allreduce", mutates_args=())
+def nvfp4_gemm_allreduce(
+    act_fp4: torch.Tensor,
+    weight: torch.Tensor,
+    act_sf: torch.Tensor,
+    weight_scale: torch.Tensor,
+    alpha: torch.Tensor,
+    output_dtype: torch.dtype,
+    tp_rank: int,
+    tp_group: List[int],
+) -> torch.Tensor:
+    tuner = AutoTuner.get()
+
+    # Use Cutlass runner with predefined configs
+    nvfp4_gemm_allreduce_runner = Fp4GemmAllreduceRunner(
+        output_dtype, tp_rank, tp_group)
+
+    runner_type = type(nvfp4_gemm_allreduce_runner).__name__
+    _, best_tactic = tuner.choose_one(
+        f"trtllm::nvfp4_gemm_allreduce::{runner_type}",
+        [nvfp4_gemm_allreduce_runner],
+        nvfp4_gemm_allreduce_runner.tuning_config,
+        [act_fp4, weight, act_sf, weight_scale, alpha],
+    )
+
+    return nvfp4_gemm_allreduce_runner(
+        inputs=[act_fp4, weight, act_sf, weight_scale, alpha],
+        tactic=best_tactic)
+
+
+@nvfp4_gemm_allreduce.register_fake
+def _(
+    act_fp4: torch.Tensor,
+    weight: torch.Tensor,
+    act_sf: torch.Tensor,
+    weight_scale: torch.Tensor,
+    alpha: torch.Tensor,
+    output_dtype: torch.dtype,
+    tp_rank: int,
+    tp_group: List[int],
+) -> torch.Tensor:
+    return act_fp4.new_empty((act_fp4.size(0), weight.size(0)),
+                             dtype=output_dtype)
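
For reference, the new op can also be invoked directly. Below is a minimal call-site sketch that mirrors the apply_linear_allreduce path added to linear.py further down; only the op names and argument order come from this diff, while the helper name, the bfloat16 output dtype, and the example inputs are illustrative placeholders.

# Minimal call-site sketch (assumes a TensorRT-LLM build that registers the
# trtllm ops touched by this commit; dtype and tp_group are placeholders).
from typing import List

import torch


def fused_row_parallel_gemm(act: torch.Tensor, weight_fp4: torch.Tensor,
                            input_scale: torch.Tensor,
                            weight_scale: torch.Tensor, alpha: torch.Tensor,
                            scaling_vector_size: int, tp_rank: int,
                            tp_group: List[int]) -> torch.Tensor:
    # Quantize the activation to NVFP4, as apply_linear_allreduce does for
    # unquantized inputs.
    act_fp4, act_sf = torch.ops.trtllm.fp4_quantize(act, input_scale,
                                                    scaling_vector_size, False)
    # One fused op: local NVFP4 GEMM followed by an allreduce across tp_group.
    return torch.ops.trtllm.nvfp4_gemm_allreduce(act_fp4, weight_fp4, act_sf,
                                                 weight_scale, alpha,
                                                 torch.bfloat16, tp_rank,
                                                 tp_group)

Note that Fp4GemmAllreduceRunner caches one underlying torch.classes.trtllm.Fp4GemmAllreduceRunner per (output_dtype, tp_group) pair, so repeated calls with the same group reuse the C++ runner while the AutoTuner picks the GEMM tactic per token-count bucket.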

tensorrt_llm/_torch/modules/linear.py

Lines changed: 95 additions & 8 deletions
@@ -307,6 +307,12 @@ def apply(self, module: Linear, input: torch.Tensor,
               bias: Optional[torch.Tensor], *args, **kwargs):
         raise NotImplementedError
 
+    @abstractmethod
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        raise NotImplementedError
+
     def load_weights(self,
                      module: Linear,
                      weights: List[Dict],
@@ -393,6 +399,11 @@ def apply(self, module: Linear, input: torch.Tensor,
         output = F.linear(input, module.weight, bias)
         return output
 
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        raise NotImplementedError
+
     def load_weights_vanilla(self,
                              module: Linear,
                              weights: List[Dict],
@@ -509,6 +520,11 @@ def apply(self, module: Linear, input: torch.Tensor,
             output = output + bias
         return output
 
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        raise NotImplementedError
+
     def load_kv_scales(self, weights: List[Dict]):
         k_scale, v_scale = [], []
         for w in weights:
@@ -653,6 +669,11 @@ def apply(self, module: Linear, input: torch.Tensor,
             output = output + bias
         return output
 
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        raise NotImplementedError
+
     def _get_scale_name(self, weights: List[Dict]):
         # `weight_scale_inv` for DS recipe and `weight_scale` for ModelOpt recipe.
         # Actually they hold identical values of data_amax / 448.
@@ -767,6 +788,11 @@ def apply(self, module: Linear, input: torch.Tensor,
             output = output + bias
         return output
 
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        raise NotImplementedError
+
     def _get_scale_name(self, weights: List[Dict]):
         # `weight_scale_inv` for DS recipe and `weight_scale` for ModelOpt recipe.
         # Actually they hold identical values of data_amax / 448.
@@ -953,6 +979,28 @@ def apply(self, module: Linear, input: torch.Tensor,
             output = output + bias
         return output
 
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        if isinstance(input, Fp4QuantizedTensor):
+            act_fp4, act_sf = input.fp4_tensor, input.scaling_factor
+        elif isinstance(input, tuple):
+            act_fp4, act_sf = input
+        else:
+            act_fp4, act_sf = torch.ops.trtllm.fp4_quantize(
+                input, module.input_scale, module.scaling_vector_size, False)
+
+        output = torch.ops.trtllm.nvfp4_gemm_allreduce(
+            act_fp4, module.weight, act_sf, module.weight_scale, module.alpha,
+            module.dtype, tp_rank, tp_group)
+        # Take the dim of out_features if padded. Make sure the output is contiguous
+        if output.shape[-1] > module.out_features:
+            output = output[..., :module.out_features].contiguous()
+
+        if bias is not None:
+            output = output + bias
+        return output
+
     def load_kv_scales(self, weights: List[Dict]):
         k_scale, v_scale = [], []
         for w in weights:
@@ -1229,6 +1277,11 @@ def apply(self, module: Linear, input: torch.Tensor,
             output = output + bias
         return output
 
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        raise NotImplementedError
+
     def load_weight_scales(
         self,
         weights: List[Dict],
@@ -1397,6 +1450,16 @@ def apply(self, module: Linear, input: torch.Tensor,
             output = output + bias
         return output
 
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        raise NotImplementedError
+
+    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
+                               bias: Optional[torch.Tensor], tp_rank: int,
+                               tp_group: List[int], *args, **kwargs):
+        raise NotImplementedError
+
     def load_weight_scales(self,
                            weights: List[Dict],
                            tp_size: int = 1,
@@ -2055,6 +2118,7 @@ def __init__(
         disable_deep_gemm: bool = False,
         fused_weight_shard_indices_mapping: Optional[dict] = None,
         nvfp4_backend: str = "auto",
+        use_fused_gemm_allreduce: bool = False,
     ):
         """
         Args:
@@ -2124,6 +2188,8 @@ def __init__(
         self.reduce_output = reduce_output
         self.use_custom_cublas_mm = use_custom_cublas_mm
         self.lora = lora
+        self.use_fused_gemm_allreduce = use_fused_gemm_allreduce and self.quant_config.layer_quant_mode.has_nvfp4(
+        )
 
         self.enable_cuda_core = False
         if torch.cuda.is_available():
@@ -2223,6 +2289,20 @@ def apply_linear(self,
                 output = output + lora_result
         return output
 
+    def apply_linear_allreduce(self,
+                               input,
+                               bias,
+                               lora_params: Optional[dict] | None = None,
+                               layer_idx: Optional[int] | None = None):
+        output = self.quant_method.apply_linear_allreduce(
+            self, input, bias, self.tp_rank, self.mapping.tp_group)
+
+        if self.lora is not None and bool(lora_params):
+            lora_result = self.lora(input, lora_params, layer_idx)
+            if lora_result is not None:
+                output = output + lora_result
+        return output
+
     def _maybe_fuse_bias_into_allreduce(
         self,
         bias: Optional[torch.Tensor],
@@ -2249,16 +2329,23 @@ def forward(
         layer_idx: Optional[int] = None,
     ) -> torch.Tensor:
         if self.tp_mode == TensorParallelMode.ROW:
+            use_fused_gemm_allreduce = self.use_fused_gemm_allreduce and (
+                all_reduce_params is None or
+                (all_reduce_params.enable_allreduce == True
+                 and all_reduce_params.fusion_op == AllReduceFusionOp.NONE))
             bias = None if (self.tp_rank > 0) else self.bias
             if self.reduce_output:
-                fuse_bias = self._maybe_fuse_bias_into_allreduce(
-                    bias, all_reduce_params)
-                bias = None if fuse_bias else bias
-                output = self.apply_linear(input, bias, lora_params, layer_idx)
-                output = self.all_reduce(
-                    output,
-                    all_reduce_params=all_reduce_params,
-                )
+                if use_fused_gemm_allreduce:
+                    output = self.apply_linear_allreduce(
+                        input, bias, lora_params, layer_idx)
+                else:
+                    fuse_bias = self._maybe_fuse_bias_into_allreduce(
+                        bias, all_reduce_params)
+                    bias = None if fuse_bias else bias
+                    output = self.apply_linear(input, bias, lora_params,
+                                               layer_idx)
+                    output = self.all_reduce(
+                        output, all_reduce_params=all_reduce_params)
             else:
                 output = self.apply_linear(input, bias, lora_params, layer_idx)
         elif self.tp_mode == TensorParallelMode.COLUMN:
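
The forward change above takes the fused path only for a plain tensor-parallel allreduce. A condensed, standalone restatement of that gate is sketched below; the helper name is illustrative and the import path for AllReduceFusionOp is assumed.

# Standalone restatement of the gating added to Linear.forward (sketch).
from tensorrt_llm.functional import AllReduceFusionOp  # assumed import path


def should_use_fused_gemm_allreduce(fused_path_enabled: bool,
                                    all_reduce_params) -> bool:
    # fused_path_enabled corresponds to self.use_fused_gemm_allreduce, which is
    # only True for NVFP4-quantized layers constructed with
    # use_fused_gemm_allreduce=True.
    if not fused_path_enabled:
        return False
    # No explicit params means a plain allreduce is wanted, so fuse it.
    if all_reduce_params is None:
        return True
    # Otherwise fuse only when the allreduce is enabled and carries no extra
    # fusion op; fused residual/norm variants keep the unfused path.
    return bool(all_reduce_params.enable_allreduce
                and all_reduce_params.fusion_op == AllReduceFusionOp.NONE)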

0 commit comments