Skip to content

Commit 5845951

Browse files
authored
[#10056][fix] AutoDeploy: Handle deletion of nested params in sharding (#10376)
Signed-off-by: Gal Hubara Agam <[email protected]>
1 parent 4868772 commit 5845951

File tree

3 files changed

+64
-26
lines changed

3 files changed

+64
-26
lines changed

tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -31,31 +31,11 @@
3131

3232
from ...models.factory import ModelFactory
3333
from ...shim.interface import CachedSequenceInterface
34+
from ...utils._graph import del_attr_by_name, get_attr_by_name, set_attr_by_name
3435
from ...utils.logger import ad_logger
3536
from ...utils.pattern_matcher import ADPatternMatcherPass
3637
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
3738

38-
39-
def _get_attr_by_name(obj, name):
40-
for part in name.split("."):
41-
obj = getattr(obj, part)
42-
return obj
43-
44-
45-
def _set_attr_by_name(obj, name, value):
46-
parts = name.split(".")
47-
for part in parts[:-1]:
48-
obj = getattr(obj, part)
49-
setattr(obj, parts[-1], value)
50-
51-
52-
def _del_attr_by_name(obj, name):
53-
parts = name.split(".")
54-
for part in parts[:-1]:
55-
obj = getattr(obj, part)
56-
delattr(obj, parts[-1])
57-
58-
5939
_PATTERN_INPUT_NAME = "a_log_like"
6040

6141

@@ -82,21 +62,21 @@ def _ensure_a_fused_param(gm: GraphModule, param_name: str) -> Optional[str]:
8262

8363
new_param_name = param_name.replace("A_log", "A_fused")
8464
try:
85-
_get_attr_by_name(gm, new_param_name)
65+
get_attr_by_name(gm, new_param_name)
8666
return new_param_name
8767
except AttributeError:
8868
pass
8969

9070
try:
91-
a_log = _get_attr_by_name(gm, param_name)
71+
a_log = get_attr_by_name(gm, param_name)
9272
except AttributeError:
9373
ad_logger.warning(f"Could not find attribute {param_name} in gm.")
9474
return None
9575

9676
with torch.no_grad():
9777
a_fused = -torch.exp(a_log.float())
9878

99-
_set_attr_by_name(
79+
set_attr_by_name(
10080
gm,
10181
new_param_name,
10282
nn.Parameter(a_fused, requires_grad=False),
@@ -120,7 +100,7 @@ def _maybe_remove(name: str) -> None:
120100
if not name.endswith("A_log") or name in used_a_log_targets:
121101
return
122102
try:
123-
_del_attr_by_name(gm, name)
103+
del_attr_by_name(gm, name)
124104
removed = True
125105
except AttributeError:
126106
ad_logger.warning(f"Failed to delete unused parameter {name} from GraphModule.")

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from ...custom_ops.trtllm_dist import is_trtllm_op_available
3434
from ...models.factory import ModelFactory, ShardingConfigSource
3535
from ...shim.interface import CachedSequenceInterface
36+
from ...utils._graph import del_attr_by_name
3637
from ...utils.logger import ad_logger
3738
from ...utils.node_utils import (
3839
LayerSubgraph,
@@ -1467,7 +1468,12 @@ def get_partition(lst, world_size, rank):
14671468
for expert in (
14681469
w_up_list_to_remove + w_down_list_to_remove + w_gate_list_to_remove + scales_to_remove
14691470
):
1470-
delattr(gm, expert.target)
1471+
try:
1472+
del_attr_by_name(gm, expert.target)
1473+
except AttributeError:
1474+
ad_logger.warning(
1475+
f"Failed to delete unused parameter {expert.target} from GraphModule."
1476+
)
14711477

14721478

14731479
def _slice_expert_dim(gm: GraphModule, tensor_node: Node, lo: int, hi: int) -> Node:

tensorrt_llm/_torch/auto_deploy/utils/_graph.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,3 +401,55 @@ def get_lm_head_weights(model: nn.Module) -> torch.Tensor:
401401
gm, output_node = get_output_node(model)
402402
lm_head_node = get_lm_head_node(gm, output_node)
403403
return get_weight_tensor(gm, lm_head_node)
404+
405+
406+
def get_attr_by_name(obj, name):
407+
"""Get an attribute specified by a dot-separated path on an object.
408+
409+
Args:
410+
obj: The root object from which to resolve the attribute path.
411+
name (str): Dot-separated attribute path (e.g., "a.b.c").
412+
413+
Returns:
414+
The value of the resolved attribute.
415+
416+
Raises:
417+
AttributeError: If any component in the path does not exist.
418+
"""
419+
for part in name.split("."):
420+
obj = getattr(obj, part)
421+
return obj
422+
423+
424+
def set_attr_by_name(obj, name, value):
425+
"""Set an attribute specified by a dot-separated path on an object.
426+
427+
Args:
428+
obj: The root object on which to set the attribute.
429+
name (str): Dot-separated attribute path (e.g., "a.b.c").
430+
value: The value to assign to the target attribute.
431+
432+
Raises:
433+
AttributeError: If any intermediate component in the path does not exist.
434+
"""
435+
parts = name.split(".")
436+
for part in parts[:-1]:
437+
obj = getattr(obj, part)
438+
setattr(obj, parts[-1], value)
439+
440+
441+
def del_attr_by_name(obj, name):
442+
"""Delete an attribute specified by a dot-separated path from an object.
443+
444+
Args:
445+
obj: The root object from which to delete the attribute.
446+
name (str): Dot-separated attribute path (e.g., "a.b.c").
447+
448+
Raises:
449+
AttributeError: If any intermediate component in the path does not exist
450+
or if the final attribute does not exist.
451+
"""
452+
parts = name.split(".")
453+
for part in parts[:-1]:
454+
obj = getattr(obj, part)
455+
delattr(obj, parts[-1])

0 commit comments

Comments (0)