4 files changed: +20 −31 lines

File tree:
tensorrt_llm/_torch/auto_deploy
unittest/_torch/auto_deploy/unit/multigpu/transformations/library

@@ -43,6 +43,7 @@
     extract_weight_nodes,
     filtered_nodes,
     get_all_layer_subgraphs,
+    get_all_weights_in_subgraph,
     get_layer_after_linear_node,
     is_any_attention_op,
     is_any_lin_op,
@@ -1060,31 +1061,6 @@ def _resolve_tp_cls_from_node(node: Node):
     return WeightShardingInfo


-def _get_dim0_from_arg(gm: GraphModule, arg: Union[Node, torch.Tensor]) -> int:
-    """Helper to get the first dimension size of an argument (Node or Tensor)."""
-    if isinstance(arg, torch.Tensor):
-        return arg.shape[0]
-    if isinstance(arg, Node):
-        if arg.op == "get_attr":
-            # Traverse attributes to find the tensor
-            obj = gm
-            for atom in arg.target.split("."):
-                obj = getattr(obj, atom)
-            return obj.shape[0]
-        if "val" in arg.meta:
-            return shape(arg)[0]
-    raise ValueError(f"Cannot determine shape[0] for {arg}")
-
-
-def get_all_weights_in_subgraph(
-    sources: list[Node],
-    sinks: list[Node],
-):
-    """Get all weight nodes (get_attr nodes) in the subgraph between sources and sinks."""
-    weight_nodes = subgraph(sources, sinks, include=lambda n: n.op == "get_attr")
-    return weight_nodes
-
-
 def init_process_grid_from_config(
     config: ShardingTransformConfig,
 ) -> Dict[ShardingDim, Dict[str, int]]:
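A minimal usage sketch for the relocated helper. The import path, the planning loop, and the default shard dim are illustrative assumptions; only the signature of get_all_weights_in_subgraph(sources, sinks) comes from this diff:

from torch.fx import Node

# Assumed module path after the move (the same module also gains the
# is_weight_node predicate shown in the next file).
from tensorrt_llm._torch.auto_deploy.utils.node_utils import get_all_weights_in_subgraph

def plan_weight_sharding(sources: list[Node], sinks: list[Node]) -> dict[str, int]:
    """Map every weight between a layer's sources and sinks to a shard dim."""
    plans: dict[str, int] = {}
    for w in get_all_weights_in_subgraph(sources, sinks):
        # Each result is a get_attr node whose target names the parameter,
        # e.g. "layers.0.mixer.in_proj.weight".
        plans[str(w.target)] = 0  # column-parallel (dim 0) as a default guess
    return plans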

@@ -143,8 +143,19 @@ def get_quantization_params_from_linear_node(linear_op: torch.fx.node.Node):
     return input_params, weight_params, output_params


-def extract_weight_name(node: Node) -> str:
+def get_all_weights_in_subgraph(
+    sources: list[Node],
+    sinks: list[Node],
+):
+    """Get all weight nodes (get_attr nodes) in the subgraph between sources and sinks."""
+    weight_nodes = subgraph(sources, sinks, include=is_weight_node)
+    return weight_nodes
+
+
+def extract_weight_name(node: Node) -> Union[str, bool]:
     weight_nodes = extract_weight_nodes(node)
+    if len(weight_nodes.weights) == 0:
+        return False
     return weight_nodes.weights[0].node_key


@@ -431,6 +442,10 @@ def is_dist_op(node: Node) -> bool:
     return is_op(node, dist_ops)


+def is_weight_node(node: Node) -> bool:
+    return node.op == "get_attr" and node.target and has_shape(node) and len(shape(node)) > 0
+
+
 def get_user_if_pattern_match(node, ops, numusers, user_idx: int = 0):
     """Get a user from a node if the node matches a given op set and num of users."""
     if node is None:
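The predicate deliberately rejects rank-0 tensors (scalars with no dimensions cannot be sharded), which is what distinguishes it from the old lambda n: n.op == "get_attr" filter. A self-contained sketch of that behavior on a toy traced module; has_shape/shape are repo helpers that read node metadata, so the sketch substitutes a direct parameter lookup as a stand-in:

import torch
import torch.nn.functional as F
from torch import nn
from torch.fx import symbolic_trace

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(8, 8))  # rank 2: a weight
        self.scale = nn.Parameter(torch.tensor(1.0))   # rank 0: not a weight

    def forward(self, x):
        return F.linear(x, self.weight) * self.scale

gm = symbolic_trace(Toy())
for node in gm.graph.nodes:
    if node.op != "get_attr":
        continue
    t = gm.get_parameter(str(node.target))
    # Stand-in for `has_shape(node) and len(shape(node)) > 0`: only tensors
    # with at least one dimension count as shardable weight nodes.
    print(node.target, "-> weight" if t.dim() > 0 else "-> skipped (rank 0)")
# weight -> weight
# scale -> skipped (rank 0)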

@@ -353,10 +353,6 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mt
 stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-MAX_UTILIZATION-pytorch-stress-test] SKIP (https://nvbugs/5814203)
 unittest/_torch/attention/test_trtllm_flashinfer_symbol_collision.py::test_flashinfer_fused_moe_matches_torch_moe SKIP (https://nvbugs/5814215)
 full:sm89/accuracy/test_llm_api_pytorch_multimodal.py::TestNVILA_8B::test_auto_dtype SKIP (https://nvbugs/5814504)
-unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py::test_sharding_pattern_detection[NemotronHMamba2Mixer-torch_dist_all_reduce-False-False-8] SKIP (https://nvbugs/5819002)
-unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py::test_sharding_pattern_detection[NemotronHMamba2Mixer-torch_dist_all_reduce-False-True-8] SKIP (https://nvbugs/5819002)
-unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py::test_sharding_pattern_detection[NemotronHMamba2Mixer-torch_dist_all_reduce-True-False-8] SKIP (https://nvbugs/5819002)
-unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py::test_sharding_pattern_detection[NemotronHMamba2Mixer-torch_dist_all_reduce-True-True-8] SKIP (https://nvbugs/5819002)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True] SKIP (https://nvbugs/5819005)
 unittest/llmapi/test_mpi_session.py::test_llmapi_launch_multiple_tasks SKIP (https://nvbugs/5819014)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5819019)

@@ -572,7 +572,9 @@ def _run_pattern_detection_job(
                     fused_weight_dims=None,
                 )
             )
-        if len(node.args) > 1 and "norm_weight" in node.args[0].name:
+        if len(node.args) > 1 and (
+            "norm_weight" in node.args[0].name or "a_log" in node.args[0].name
+        ):
             expected_transformations.append(
                 WeightShardingInfo(
                     target_node=node.name,
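The added "a_log" match folds Mamba2-style mixers into the expected-transformation set. A hedged restatement of the updated condition as a standalone predicate; the marker tuple, the predicate name, and the per-head rationale are illustrative assumptions, while the test itself keeps the inline form above:

from torch.fx import Node

# Assumption: like the gated-norm weight, Mamba2's A_log is a per-head,
# dim-0 parameter, so it must be sharded whenever the mixer's heads are
# split across tensor-parallel ranks.
PER_HEAD_PARAM_MARKERS = ("norm_weight", "a_log")

def expects_dim0_sharding(node: Node) -> bool:
    """Mirror of the test's check: the op's first arg is a per-head weight node."""
    if len(node.args) <= 1:
        return False
    first = node.args[0]
    return isinstance(first, Node) and any(m in first.name for m in PER_HEAD_PARAM_MARKERS)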