diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py
index 5e967e1f35e..977aeeabc5c 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py
@@ -31,31 +31,11 @@
 
 from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
+from ...utils._graph import del_attr_by_name, get_attr_by_name, set_attr_by_name
 from ...utils.logger import ad_logger
 from ...utils.pattern_matcher import ADPatternMatcherPass
 from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
 
-
-def _get_attr_by_name(obj, name):
-    for part in name.split("."):
-        obj = getattr(obj, part)
-    return obj
-
-
-def _set_attr_by_name(obj, name, value):
-    parts = name.split(".")
-    for part in parts[:-1]:
-        obj = getattr(obj, part)
-    setattr(obj, parts[-1], value)
-
-
-def _del_attr_by_name(obj, name):
-    parts = name.split(".")
-    for part in parts[:-1]:
-        obj = getattr(obj, part)
-    delattr(obj, parts[-1])
-
-
 _PATTERN_INPUT_NAME = "a_log_like"
 
 
@@ -82,13 +62,13 @@ def _ensure_a_fused_param(gm: GraphModule, param_name: str) -> Optional[str]:
     new_param_name = param_name.replace("A_log", "A_fused")
 
     try:
-        _get_attr_by_name(gm, new_param_name)
+        get_attr_by_name(gm, new_param_name)
         return new_param_name
     except AttributeError:
         pass
 
     try:
-        a_log = _get_attr_by_name(gm, param_name)
+        a_log = get_attr_by_name(gm, param_name)
     except AttributeError:
         ad_logger.warning(f"Could not find attribute {param_name} in gm.")
         return None
@@ -96,7 +76,7 @@ def _ensure_a_fused_param(gm: GraphModule, param_name: str) -> Optional[str]:
     with torch.no_grad():
         a_fused = -torch.exp(a_log.float())
 
-    _set_attr_by_name(
+    set_attr_by_name(
         gm,
         new_param_name,
         nn.Parameter(a_fused, requires_grad=False),
@@ -120,7 +100,7 @@ def _maybe_remove(name: str) -> None:
         if not name.endswith("A_log") or name in used_a_log_targets:
             return
         try:
-            _del_attr_by_name(gm, name)
+            del_attr_by_name(gm, name)
             removed = True
         except AttributeError:
             ad_logger.warning(f"Failed to delete unused parameter {name} from GraphModule.")
diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
index 3a5d606aebe..deb82a43e66 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
@@ -33,6 +33,7 @@
 from ...custom_ops.trtllm_dist import is_trtllm_op_available
 from ...models.factory import ModelFactory, ShardingConfigSource
 from ...shim.interface import CachedSequenceInterface
+from ...utils._graph import del_attr_by_name
 from ...utils.logger import ad_logger
 from ...utils.node_utils import (
     LayerSubgraph,
@@ -1467,7 +1468,12 @@ def get_partition(lst, world_size, rank):
     for expert in (
         w_up_list_to_remove + w_down_list_to_remove + w_gate_list_to_remove + scales_to_remove
     ):
-        delattr(gm, expert.target)
+        try:
+            del_attr_by_name(gm, expert.target)
+        except AttributeError:
+            ad_logger.warning(
+                f"Failed to delete unused parameter {expert.target} from GraphModule."
+            )
 
 
 def _slice_expert_dim(gm: GraphModule, tensor_node: Node, lo: int, hi: int) -> Node:
diff --git a/tensorrt_llm/_torch/auto_deploy/utils/_graph.py b/tensorrt_llm/_torch/auto_deploy/utils/_graph.py
index cd61bd52f1e..04db9bd80d8 100644
--- a/tensorrt_llm/_torch/auto_deploy/utils/_graph.py
+++ b/tensorrt_llm/_torch/auto_deploy/utils/_graph.py
@@ -401,3 +401,55 @@ def get_lm_head_weights(model: nn.Module) -> torch.Tensor:
     gm, output_node = get_output_node(model)
     lm_head_node = get_lm_head_node(gm, output_node)
     return get_weight_tensor(gm, lm_head_node)
+
+
+def get_attr_by_name(obj, name):
+    """Get an attribute specified by a dot-separated path on an object.
+
+    Args:
+        obj: The root object from which to resolve the attribute path.
+        name (str): Dot-separated attribute path (e.g., "a.b.c").
+
+    Returns:
+        The value of the resolved attribute.
+
+    Raises:
+        AttributeError: If any component in the path does not exist.
+    """
+    for part in name.split("."):
+        obj = getattr(obj, part)
+    return obj
+
+
+def set_attr_by_name(obj, name, value):
+    """Set an attribute specified by a dot-separated path on an object.
+
+    Args:
+        obj: The root object on which to set the attribute.
+        name (str): Dot-separated attribute path (e.g., "a.b.c").
+        value: The value to assign to the target attribute.
+
+    Raises:
+        AttributeError: If any intermediate component in the path does not exist.
+    """
+    parts = name.split(".")
+    for part in parts[:-1]:
+        obj = getattr(obj, part)
+    setattr(obj, parts[-1], value)
+
+
+def del_attr_by_name(obj, name):
+    """Delete an attribute specified by a dot-separated path from an object.
+
+    Args:
+        obj: The root object from which to delete the attribute.
+        name (str): Dot-separated attribute path (e.g., "a.b.c").
+
+    Raises:
+        AttributeError: If any intermediate component in the path does not exist
+            or if the final attribute does not exist.
+    """
+    parts = name.split(".")
+    for part in parts[:-1]:
+        obj = getattr(obj, part)
+    delattr(obj, parts[-1])
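
For context, a minimal standalone sketch of how the new dot-path helpers behave. The `Block`/`Model` classes below are hypothetical stand-ins for nested module attributes on a GraphModule and are used only for illustration; the import path is the one this diff adds in _graph.py.

import torch
import torch.nn as nn

from tensorrt_llm._torch.auto_deploy.utils._graph import (
    del_attr_by_name,
    get_attr_by_name,
    set_attr_by_name,
)


class Block(nn.Module):  # hypothetical, for illustration only
    def __init__(self):
        super().__init__()
        self.A_log = nn.Parameter(torch.zeros(8))


class Model(nn.Module):  # hypothetical, for illustration only
    def __init__(self):
        super().__init__()
        self.layer = Block()


m = Model()

# Resolve "layer.A_log" one getattr at a time.
a_log = get_attr_by_name(m, "layer.A_log")

# Register a fused parameter next to the original, mirroring what
# _ensure_a_fused_param does for "A_fused".
set_attr_by_name(
    m, "layer.A_fused", nn.Parameter(-torch.exp(a_log.float()), requires_grad=False)
)

# Remove the now-unused parameter; a missing path raises AttributeError,
# which the call sites in this diff catch and log as a warning.
del_attr_by_name(m, "layer.A_log")
try:
    del_attr_by_name(m, "layer.A_log")  # already deleted -> AttributeError
except AttributeError:
    pass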