From c86b2604f32def17bd980455de59e9da80a1f73c Mon Sep 17 00:00:00 2001 From: Gal Hubara Agam <96368689+galagam@users.noreply.github.com> Date: Thu, 1 Jan 2026 00:18:11 -0800 Subject: [PATCH 1/3] fix: deletion of nested params in sharding Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com> --- .../transform/library/fuse_mamba_a_log.py | 24 +++---------------- .../auto_deploy/transform/library/sharding.py | 8 ++++++- tensorrt_llm/_torch/auto_deploy/utils/attr.py | 22 +++++++++++++++++ 3 files changed, 32 insertions(+), 22 deletions(-) create mode 100644 tensorrt_llm/_torch/auto_deploy/utils/attr.py diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py index 5e967e1f35e..8ae3fac2aa9 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py @@ -31,31 +31,13 @@ from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface +from ...utils.attr import del_attr_by_name as _del_attr_by_name +from ...utils.attr import get_attr_by_name as _get_attr_by_name +from ...utils.attr import set_attr_by_name as _set_attr_by_name from ...utils.logger import ad_logger from ...utils.pattern_matcher import ADPatternMatcherPass from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry - -def _get_attr_by_name(obj, name): - for part in name.split("."): - obj = getattr(obj, part) - return obj - - -def _set_attr_by_name(obj, name, value): - parts = name.split(".") - for part in parts[:-1]: - obj = getattr(obj, part) - setattr(obj, parts[-1], value) - - -def _del_attr_by_name(obj, name): - parts = name.split(".") - for part in parts[:-1]: - obj = getattr(obj, part) - delattr(obj, parts[-1]) - - _PATTERN_INPUT_NAME = "a_log_like" diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index 3a5d606aebe..54ef9312a4e 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -33,6 +33,7 @@ from ...custom_ops.trtllm_dist import is_trtllm_op_available from ...models.factory import ModelFactory, ShardingConfigSource from ...shim.interface import CachedSequenceInterface +from ...utils.attr import del_attr_by_name as _del_attr_by_name from ...utils.logger import ad_logger from ...utils.node_utils import ( LayerSubgraph, @@ -1467,7 +1468,12 @@ def get_partition(lst, world_size, rank): for expert in ( w_up_list_to_remove + w_down_list_to_remove + w_gate_list_to_remove + scales_to_remove ): - delattr(gm, expert.target) + try: + _del_attr_by_name(gm, expert.target) + except AttributeError: + ad_logger.warning( + f"Failed to delete unused parameter {expert.target} from GraphModule." + ) def _slice_expert_dim(gm: GraphModule, tensor_node: Node, lo: int, hi: int) -> Node: diff --git a/tensorrt_llm/_torch/auto_deploy/utils/attr.py b/tensorrt_llm/_torch/auto_deploy/utils/attr.py new file mode 100644 index 00000000000..8edd8c2aa4e --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/utils/attr.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. 
+# SPDX-License-Identifier: Apache-2.0
+
+
+def get_attr_by_name(obj, name):
+    for part in name.split("."):
+        obj = getattr(obj, part)
+    return obj
+
+
+def set_attr_by_name(obj, name, value):
+    parts = name.split(".")
+    for part in parts[:-1]:
+        obj = getattr(obj, part)
+    setattr(obj, parts[-1], value)
+
+
+def del_attr_by_name(obj, name):
+    parts = name.split(".")
+    for part in parts[:-1]:
+        obj = getattr(obj, part)
+    delattr(obj, parts[-1])

From 74d11fc977aacf026a96454c04512cce5d0668e6 Mon Sep 17 00:00:00 2001
From: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
Date: Thu, 1 Jan 2026 00:43:07 -0800
Subject: [PATCH 2/3] small refactor

Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
---
 .../transform/library/fuse_mamba_a_log.py     | 12 +++++-----
 .../auto_deploy/transform/library/sharding.py |  4 ++--
 .../_torch/auto_deploy/utils/_graph.py        | 20 +++++++++++++++++
 tensorrt_llm/_torch/auto_deploy/utils/attr.py | 22 -------------------
 4 files changed, 27 insertions(+), 31 deletions(-)
 delete mode 100644 tensorrt_llm/_torch/auto_deploy/utils/attr.py

diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py
index 8ae3fac2aa9..977aeeabc5c 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py
@@ -31,9 +31,7 @@
 from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
-from ...utils.attr import del_attr_by_name as _del_attr_by_name
-from ...utils.attr import get_attr_by_name as _get_attr_by_name
-from ...utils.attr import set_attr_by_name as _set_attr_by_name
+from ...utils._graph import del_attr_by_name, get_attr_by_name, set_attr_by_name
 from ...utils.logger import ad_logger
 from ...utils.pattern_matcher import ADPatternMatcherPass
 from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
@@ -64,13 +62,13 @@ def _ensure_a_fused_param(gm: GraphModule, param_name: str) -> Optional[str]:
     new_param_name = param_name.replace("A_log", "A_fused")
 
     try:
-        _get_attr_by_name(gm, new_param_name)
+        get_attr_by_name(gm, new_param_name)
         return new_param_name
     except AttributeError:
         pass
 
     try:
-        a_log = _get_attr_by_name(gm, param_name)
+        a_log = get_attr_by_name(gm, param_name)
     except AttributeError:
         ad_logger.warning(f"Could not find attribute {param_name} in gm.")
         return None
@@ -78,7 +76,7 @@ def _ensure_a_fused_param(gm: GraphModule, param_name: str) -> Optional[str]:
     with torch.no_grad():
         a_fused = -torch.exp(a_log.float())
 
-    _set_attr_by_name(
+    set_attr_by_name(
         gm,
         new_param_name,
         nn.Parameter(a_fused, requires_grad=False),
@@ -102,7 +100,7 @@ def _maybe_remove(name: str) -> None:
         if not name.endswith("A_log") or name in used_a_log_targets:
             return
         try:
-            _del_attr_by_name(gm, name)
+            del_attr_by_name(gm, name)
             removed = True
         except AttributeError:
             ad_logger.warning(f"Failed to delete unused parameter {name} from GraphModule.")
diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
index 54ef9312a4e..deb82a43e66 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
@@ -33,7 +33,7 @@
 from ...custom_ops.trtllm_dist import is_trtllm_op_available
 from ...models.factory import ModelFactory, ShardingConfigSource
 from ...shim.interface import CachedSequenceInterface
-from ...utils.attr import del_attr_by_name as _del_attr_by_name
+from ...utils._graph import del_attr_by_name
 from ...utils.logger import ad_logger
 from ...utils.node_utils import (
     LayerSubgraph,
@@ -1469,7 +1469,7 @@ def get_partition(lst, world_size, rank):
             w_up_list_to_remove + w_down_list_to_remove + w_gate_list_to_remove + scales_to_remove
         ):
             try:
-                _del_attr_by_name(gm, expert.target)
+                del_attr_by_name(gm, expert.target)
             except AttributeError:
                 ad_logger.warning(
                     f"Failed to delete unused parameter {expert.target} from GraphModule."
diff --git a/tensorrt_llm/_torch/auto_deploy/utils/_graph.py b/tensorrt_llm/_torch/auto_deploy/utils/_graph.py
index cd61bd52f1e..73f972ea583 100644
--- a/tensorrt_llm/_torch/auto_deploy/utils/_graph.py
+++ b/tensorrt_llm/_torch/auto_deploy/utils/_graph.py
@@ -401,3 +401,23 @@ def get_lm_head_weights(model: nn.Module) -> torch.Tensor:
     gm, output_node = get_output_node(model)
     lm_head_node = get_lm_head_node(gm, output_node)
     return get_weight_tensor(gm, lm_head_node)
+
+
+def get_attr_by_name(obj, name):
+    for part in name.split("."):
+        obj = getattr(obj, part)
+    return obj
+
+
+def set_attr_by_name(obj, name, value):
+    parts = name.split(".")
+    for part in parts[:-1]:
+        obj = getattr(obj, part)
+    setattr(obj, parts[-1], value)
+
+
+def del_attr_by_name(obj, name):
+    parts = name.split(".")
+    for part in parts[:-1]:
+        obj = getattr(obj, part)
+    delattr(obj, parts[-1])
diff --git a/tensorrt_llm/_torch/auto_deploy/utils/attr.py b/tensorrt_llm/_torch/auto_deploy/utils/attr.py
deleted file mode 100644
index 8edd8c2aa4e..00000000000
--- a/tensorrt_llm/_torch/auto_deploy/utils/attr.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES.
-# SPDX-License-Identifier: Apache-2.0
-
-
-def get_attr_by_name(obj, name):
-    for part in name.split("."):
-        obj = getattr(obj, part)
-    return obj
-
-
-def set_attr_by_name(obj, name, value):
-    parts = name.split(".")
-    for part in parts[:-1]:
-        obj = getattr(obj, part)
-    setattr(obj, parts[-1], value)
-
-
-def del_attr_by_name(obj, name):
-    parts = name.split(".")
-    for part in parts[:-1]:
-        obj = getattr(obj, part)
-    delattr(obj, parts[-1])

From 8b1e44d7ba8c80ff5d7a29eeca20842c3a72ea00 Mon Sep 17 00:00:00 2001
From: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
Date: Thu, 1 Jan 2026 00:58:45 -0800
Subject: [PATCH 3/3] docstrings

Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
---
 .../_torch/auto_deploy/utils/_graph.py        | 32 +++++++++++++++++++
 1 file changed, 32 insertions(+)

diff --git a/tensorrt_llm/_torch/auto_deploy/utils/_graph.py b/tensorrt_llm/_torch/auto_deploy/utils/_graph.py
index 73f972ea583..04db9bd80d8 100644
--- a/tensorrt_llm/_torch/auto_deploy/utils/_graph.py
+++ b/tensorrt_llm/_torch/auto_deploy/utils/_graph.py
@@ -404,12 +404,34 @@ def get_attr_by_name(obj, name):
+    """Get an attribute specified by a dot-separated path on an object.
+
+    Args:
+        obj: The root object from which to resolve the attribute path.
+        name (str): Dot-separated attribute path (e.g., "a.b.c").
+
+    Returns:
+        The value of the resolved attribute.
+
+    Raises:
+        AttributeError: If any component in the path does not exist.
+    """
     for part in name.split("."):
         obj = getattr(obj, part)
     return obj
 
 
 def set_attr_by_name(obj, name, value):
+    """Set an attribute specified by a dot-separated path on an object.
+ + Args: + obj: The root object on which to set the attribute. + name (str): Dot-separated attribute path (e.g., "a.b.c"). + value: The value to assign to the target attribute. + + Raises: + AttributeError: If any intermediate component in the path does not exist. + """ parts = name.split(".") for part in parts[:-1]: obj = getattr(obj, part) @@ -417,6 +439,16 @@ def set_attr_by_name(obj, name, value): def del_attr_by_name(obj, name): + """Delete an attribute specified by a dot-separated path from an object. + + Args: + obj: The root object from which to delete the attribute. + name (str): Dot-separated attribute path (e.g., "a.b.c"). + + Raises: + AttributeError: If any intermediate component in the path does not exist + or if the final attribute does not exist. + """ parts = name.split(".") for part in parts[:-1]: obj = getattr(obj, part)
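
A minimal usage sketch of the helpers this series introduces (the Expert and
ToyMoE classes below are illustrative stand-ins for a GraphModule with nested
expert weights, not code from the TensorRT-LLM tree). Parameters live on
nested submodules, so a dotted target such as "experts.0.w_up" cannot be
removed with a plain delattr(gm, target); the helpers walk the path one
attribute at a time and operate only on the leaf:

    import torch
    from torch import nn

    from tensorrt_llm._torch.auto_deploy.utils._graph import (
        del_attr_by_name,
        get_attr_by_name,
        set_attr_by_name,
    )

    class Expert(nn.Module):  # hypothetical toy module
        def __init__(self):
            super().__init__()
            self.w_up = nn.Parameter(torch.zeros(4))

    class ToyMoE(nn.Module):  # hypothetical toy module
        def __init__(self):
            super().__init__()
            self.experts = nn.ModuleList([Expert(), Expert()])

    m = ToyMoE()

    # delattr(m, "experts.0.w_up") raises AttributeError: the dotted string
    # is treated as a single attribute name -- the bug PATCH 1 fixes.

    # The helpers resolve "experts" -> "0" -> "w_up" step by step instead:
    w = get_attr_by_name(m, "experts.0.w_up")      # returns m.experts[0].w_up
    set_attr_by_name(m, "experts.0.w_up", nn.Parameter(torch.ones_like(w)))
    del_attr_by_name(m, "experts.0.w_up")          # deletes only the leaf

    assert "experts.0.w_up" not in dict(m.named_parameters())

The try/except AttributeError around each deletion site mirrors the helpers'
documented behavior: a missing intermediate or leaf attribute is logged as a
warning rather than aborting the transform.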