Skip to content

Commit a9e227e

Browse files
committed
Revert "Attention Pattern Matcher (closes NVIDIA#4404) (#88)"
This reverts commit a8b54f9.
1 parent 810df7d commit a9e227e

File tree

6 files changed: 1050 additions and 354 deletions.

tensorrt_llm/_torch/auto_deploy/transformations/_graph.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -96,24 +96,23 @@ def named_graphmodules(gm: fx.GraphModule) -> Iterator[Tuple[str, fx.GraphModule
9696
yield name, m
9797

9898

99-
def _move_single_gm_to_device(gm: GraphModule, device: torch.device) -> None:
99+
def _move_single_gm_to_device(
100+
gm: GraphModule, device: torch.device, recompile_graph: bool = False
101+
) -> None:
100102
"""Move one GraphModule and its nodes to the specified device in-place.
101103
Partially inspired by https://github.com/pytorch/pytorch/blob/05cb98f91d49df9eadfcb3fc29bbd1b621d88860/torch/export/passes/__init__.py#L11
102104
"""
103105
# move state dict
104106
gm.to(device)
105-
recompile_graph = False
106107

107108
for node in gm.graph.nodes:
108109
# move all the nodes kwargs with burnt-in device
109110
if "device" in node.kwargs:
110-
recompile_graph = True
111111
kwargs = node.kwargs.copy()
112112
kwargs["device"] = device
113113
node.kwargs = kwargs
114114

115115
if is_op(node, torch.ops.aten.to.device):
116-
recompile_graph = True
117116
args = list(node.args)
118117
args[1] = device
119118
node.args = tuple(args)
@@ -136,7 +135,7 @@ def move_to_device(gm: fx.GraphModule, device: DeviceLikeType) -> fx.GraphModule
136135

137136
for _, subgm in reversed(list(named_graphmodules(gm))):
138137
# recompile graph to update self generated codes in subgraph
139-
_move_single_gm_to_device(subgm, device)
138+
_move_single_gm_to_device(subgm, device, subgm is not gm)
140139

141140

142141
def _is_impure_node(node: Node) -> bool:

0 commit comments

Comments
 (0)