Skip to content

Commit 74d11fc

Browse files
committed
small refactor
Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
1 parent c86b260 commit 74d11fc

File tree

4 files changed

+27
-31
lines changed

4 files changed

+27
-31
lines changed

tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,7 @@
3131

3232
from ...models.factory import ModelFactory
3333
from ...shim.interface import CachedSequenceInterface
34-
from ...utils.attr import del_attr_by_name as _del_attr_by_name
35-
from ...utils.attr import get_attr_by_name as _get_attr_by_name
36-
from ...utils.attr import set_attr_by_name as _set_attr_by_name
34+
from ...utils._graph import del_attr_by_name, get_attr_by_name, set_attr_by_name
3735
from ...utils.logger import ad_logger
3836
from ...utils.pattern_matcher import ADPatternMatcherPass
3937
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
@@ -64,21 +62,21 @@ def _ensure_a_fused_param(gm: GraphModule, param_name: str) -> Optional[str]:
6462

6563
new_param_name = param_name.replace("A_log", "A_fused")
6664
try:
67-
_get_attr_by_name(gm, new_param_name)
65+
get_attr_by_name(gm, new_param_name)
6866
return new_param_name
6967
except AttributeError:
7068
pass
7169

7270
try:
73-
a_log = _get_attr_by_name(gm, param_name)
71+
a_log = get_attr_by_name(gm, param_name)
7472
except AttributeError:
7573
ad_logger.warning(f"Could not find attribute {param_name} in gm.")
7674
return None
7775

7876
with torch.no_grad():
7977
a_fused = -torch.exp(a_log.float())
8078

81-
_set_attr_by_name(
79+
set_attr_by_name(
8280
gm,
8381
new_param_name,
8482
nn.Parameter(a_fused, requires_grad=False),
@@ -102,7 +100,7 @@ def _maybe_remove(name: str) -> None:
102100
if not name.endswith("A_log") or name in used_a_log_targets:
103101
return
104102
try:
105-
_del_attr_by_name(gm, name)
103+
del_attr_by_name(gm, name)
106104
removed = True
107105
except AttributeError:
108106
ad_logger.warning(f"Failed to delete unused parameter {name} from GraphModule.")

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from ...custom_ops.trtllm_dist import is_trtllm_op_available
3434
from ...models.factory import ModelFactory, ShardingConfigSource
3535
from ...shim.interface import CachedSequenceInterface
36-
from ...utils.attr import del_attr_by_name as _del_attr_by_name
36+
from ...utils._graph import del_attr_by_name
3737
from ...utils.logger import ad_logger
3838
from ...utils.node_utils import (
3939
LayerSubgraph,
@@ -1469,7 +1469,7 @@ def get_partition(lst, world_size, rank):
14691469
w_up_list_to_remove + w_down_list_to_remove + w_gate_list_to_remove + scales_to_remove
14701470
):
14711471
try:
1472-
_del_attr_by_name(gm, expert.target)
1472+
del_attr_by_name(gm, expert.target)
14731473
except AttributeError:
14741474
ad_logger.warning(
14751475
f"Failed to delete unused parameter {expert.target} from GraphModule."

tensorrt_llm/_torch/auto_deploy/utils/_graph.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,3 +401,23 @@ def get_lm_head_weights(model: nn.Module) -> torch.Tensor:
401401
gm, output_node = get_output_node(model)
402402
lm_head_node = get_lm_head_node(gm, output_node)
403403
return get_weight_tensor(gm, lm_head_node)
404+
405+
406+
def get_attr_by_name(obj, name):
407+
for part in name.split("."):
408+
obj = getattr(obj, part)
409+
return obj
410+
411+
412+
def set_attr_by_name(obj, name, value):
413+
parts = name.split(".")
414+
for part in parts[:-1]:
415+
obj = getattr(obj, part)
416+
setattr(obj, parts[-1], value)
417+
418+
419+
def del_attr_by_name(obj, name):
420+
parts = name.split(".")
421+
for part in parts[:-1]:
422+
obj = getattr(obj, part)
423+
delattr(obj, parts[-1])

tensorrt_llm/_torch/auto_deploy/utils/attr.py

Lines changed: 0 additions & 22 deletions
This file was deleted.

0 commit comments

Comments
 (0)