Skip to content

Commit ff19043

Browse files
greg-kwasniewski1 authored and suyoggupta committed
[TRTLLM-6342][feat] Factory TP sharding of quantized models (NVIDIA#8123)
Signed-off-by: greg-kwasniewski1 <[email protected]>
Co-authored-by: Suyog Gupta <[email protected]>
1 parent 19156a3 commit ff19043

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

Lines changed: 15 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -292,7 +292,7 @@ def detect_sharding_from_factory_config(
292292
num_simple_shards = 0
293293
num_row_col_shards = 0
294294

295-
for lin_node in filtered_nodes(gm.graph.nodes, is_linear_op):
295+
for lin_node in filtered_nodes(gm.graph.nodes, [is_linear_op, is_fake_quantized_linear_op]):
296296
# use node's weight name to get the module name
297297
module_name = lin_node.args[1].target
298298

@@ -368,7 +368,7 @@ def detect_sharding_from_factory_config(
368368
)
369369
num_row_col_shards += 1
370370
else:
371-
ad_logger.warning("Invalid sharding config. Skipping.")
371+
ad_logger.warning(f"Unsupported sharding action {config}. Skipping.")
372372
else:
373373
# TODO: local refers to hybrid EP+TP parallelism. Not supported yet.
374374
ad_logger.warning("Local EP+TP sharding is not supported yet. Skipping.")
@@ -387,7 +387,19 @@ def detect_sharding_from_factory_config(
387387
)
388388
num_simple_shards += 1
389389
else:
390-
ad_logger.warning("Invalid sharding config. Skipping.")
390+
ad_logger.warning(
391+
f"Unsupported sharding action {config}. Fallback to simple shard"
392+
)
393+
sharding_config.tp_transforms.append(
394+
TPShardingInfo.from_node(
395+
lin_node,
396+
split_dim=SplitDimension.COLUMN,
397+
rank=rank,
398+
world_size=world_size,
399+
dist_op="all_gather",
400+
min_local_shape=1,
401+
)
402+
)
391403
# after successful match, break the loop
392404
break
393405

tensorrt_llm/_torch/auto_deploy/utils/node_utils.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -239,6 +239,12 @@ def filtered_nodes(
239239
for node in nodes:
240240
if target(node):
241241
yield node
242+
elif isinstance(target, Iterable) and all(isinstance(t, Callable) for t in target):
243+
for node in nodes:
244+
for t in target:
245+
if t(node):
246+
yield node
247+
break
242248
else:
243249
# Handle the case where target or ops contains operations
244250
operations = ops if ops is not None else target

0 commit comments

Comments
 (0)