Commit d8e6e22

[https://nvbugs/5819002][fix] fix sharding tests (NVIDIA#10775)
Signed-off-by: greg-kwasniewski1 <[email protected]>
Parent: d43be7b

4 files changed, 20 insertions(+), 31 deletions(-)

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

Lines changed: 1 addition & 25 deletions
@@ -43,6 +43,7 @@
     extract_weight_nodes,
     filtered_nodes,
     get_all_layer_subgraphs,
+    get_all_weights_in_subgraph,
     get_layer_after_linear_node,
     is_any_attention_op,
     is_any_lin_op,
@@ -1060,31 +1061,6 @@ def _resolve_tp_cls_from_node(node: Node):
     return WeightShardingInfo
 
 
-def _get_dim0_from_arg(gm: GraphModule, arg: Union[Node, torch.Tensor]) -> int:
-    """Helper to get the first dimension size of an argument (Node or Tensor)."""
-    if isinstance(arg, torch.Tensor):
-        return arg.shape[0]
-    if isinstance(arg, Node):
-        if arg.op == "get_attr":
-            # Traverse attributes to find the tensor
-            obj = gm
-            for atom in arg.target.split("."):
-                obj = getattr(obj, atom)
-            return obj.shape[0]
-        if "val" in arg.meta:
-            return shape(arg)[0]
-    raise ValueError(f"Cannot determine shape[0] for {arg}")
-
-
-def get_all_weights_in_subgraph(
-    sources: list[Node],
-    sinks: list[Node],
-):
-    """Get all weight nodes (get_attr nodes) in the subgraph between sources and sinks."""
-    weight_nodes = subgraph(sources, sinks, include=lambda n: n.op == "get_attr")
-    return weight_nodes
-
-
 def init_process_grid_from_config(
     config: ShardingTransformConfig,
 ) -> Dict[ShardingDim, Dict[str, int]]:

tensorrt_llm/_torch/auto_deploy/utils/node_utils.py

Lines changed: 16 additions & 1 deletion
@@ -143,8 +143,19 @@ def get_quantization_params_from_linear_node(linear_op: torch.fx.node.Node):
     return input_params, weight_params, output_params
 
 
-def extract_weight_name(node: Node) -> str:
+def get_all_weights_in_subgraph(
+    sources: list[Node],
+    sinks: list[Node],
+):
+    """Get all weight nodes (get_attr nodes) in the subgraph between sources and sinks."""
+    weight_nodes = subgraph(sources, sinks, include=is_weight_node)
+    return weight_nodes
+
+
+def extract_weight_name(node: Node) -> Union[str, bool]:
     weight_nodes = extract_weight_nodes(node)
+    if len(weight_nodes.weights) == 0:
+        return False
     return weight_nodes.weights[0].node_key
 
 
@@ -431,6 +442,10 @@ def is_dist_op(node: Node) -> bool:
     return is_op(node, dist_ops)
 
 
+def is_weight_node(node: Node) -> bool:
+    return node.op == "get_attr" and node.target and has_shape(node) and len(shape(node)) > 0
+
+
 def get_user_if_pattern_match(node, ops, numusers, user_idx: int = 0):
     """Get a user from a node if the node matches a given op set and num of users."""
     if node is None:
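
For context, here is a minimal sketch (not part of this commit) of what a get_attr-based weight predicate like the new is_weight_node selects on a traced torch.fx graph. The TinyLinear module and the _looks_like_weight helper below are hypothetical names; the predicate is a simplified stand-in that resolves the attribute directly instead of using the repo's has_shape/shape helpers.

# Minimal sketch, assuming plain torch.fx tracing; _looks_like_weight is a
# simplified stand-in for is_weight_node, not the repo's implementation.
import torch
import torch.nn as nn
from torch.fx import GraphModule, Node, symbolic_trace


class TinyLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4, 8))
        self.bias = nn.Parameter(torch.zeros(4))

    def forward(self, x):
        # Accessing parameters in forward() produces get_attr nodes in the trace.
        return torch.nn.functional.linear(x, self.weight, self.bias)


def _looks_like_weight(gm: GraphModule, node: Node) -> bool:
    # Keep only get_attr nodes that resolve to a non-scalar tensor, mirroring
    # the "get_attr with a non-empty shape" intent of is_weight_node.
    if node.op != "get_attr":
        return False
    obj = gm
    for atom in node.target.split("."):
        obj = getattr(obj, atom)
    return isinstance(obj, torch.Tensor) and obj.dim() > 0


gm = symbolic_trace(TinyLinear())
weights = [n.target for n in gm.graph.nodes if _looks_like_weight(gm, n)]
print(weights)  # expected: ['weight', 'bias']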

tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 4 deletions
@@ -353,10 +353,6 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mt
 stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-MAX_UTILIZATION-pytorch-stress-test] SKIP (https://nvbugs/5814203)
 unittest/_torch/attention/test_trtllm_flashinfer_symbol_collision.py::test_flashinfer_fused_moe_matches_torch_moe SKIP (https://nvbugs/5814215)
 full:sm89/accuracy/test_llm_api_pytorch_multimodal.py::TestNVILA_8B::test_auto_dtype SKIP (https://nvbugs/5814504)
-unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py::test_sharding_pattern_detection[NemotronHMamba2Mixer-torch_dist_all_reduce-False-False-8] SKIP (https://nvbugs/5819002)
-unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py::test_sharding_pattern_detection[NemotronHMamba2Mixer-torch_dist_all_reduce-False-True-8] SKIP (https://nvbugs/5819002)
-unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py::test_sharding_pattern_detection[NemotronHMamba2Mixer-torch_dist_all_reduce-True-False-8] SKIP (https://nvbugs/5819002)
-unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py::test_sharding_pattern_detection[NemotronHMamba2Mixer-torch_dist_all_reduce-True-True-8] SKIP (https://nvbugs/5819002)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True] SKIP (https://nvbugs/5819005)
 unittest/llmapi/test_mpi_session.py::test_llmapi_launch_multiple_tasks SKIP (https://nvbugs/5819014)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5819019)

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py

Lines changed: 3 additions & 1 deletion
@@ -572,7 +572,9 @@ def _run_pattern_detection_job(
                 fused_weight_dims=None,
             )
         )
-        if len(node.args) > 1 and "norm_weight" in node.args[0].name:
+        if len(node.args) > 1 and (
+            "norm_weight" in node.args[0].name or "a_log" in node.args[0].name
+        ):
             expected_transformations.append(
                 WeightShardingInfo(
                     target_node=node.name,
