Commit a8b54f9

Fridah-nv authored and lucaslie committed
Attention Pattern Matcher (closes NVIDIA#4404) (#88)
* WIP for attention matching: repeat_kv, eager_attention_matching
* works e2e with llama2 and llama3.1, eager and sdpa
* update for unit test test_attention_matcher
* minor
* minor
* unify into one transformation, update unit tests
* update hf_test to verify transformed output, update move_to_device to recompile graph
* update after rebase
* minor
* update docstring
* minor

Signed-off-by: Frida Hou <[email protected]>
1 parent d84ce33 commit a8b54f9
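The commit message above references matching a repeat_kv pattern. For orientation only, here is a minimal sketch of that pattern as it commonly appears in Hugging Face Llama-style models with grouped-query attention; the function below follows that public convention and is not copied from this commit's code.

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand key/value heads so they match the number of query heads (GQA).

    Input shape:  [batch, num_kv_heads, seq_len, head_dim]
    Output shape: [batch, num_kv_heads * n_rep, seq_len, head_dim]
    """
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim
    )
    return expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)


# Example: 8 KV heads repeated 4x to line up with 32 query heads.
kv = torch.randn(2, 8, 16, 64)
assert repeat_kv(kv, 4).shape == (2, 32, 16, 64)
```

A matcher operating on the exported fx graph would presumably look for the unsqueeze/expand/reshape subgraph this produces and fold it, together with the surrounding eager or SDPA attention ops, into a fused attention op.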

6 files changed (+354, -1050 lines)


tensorrt_llm/_torch/auto_deploy/transformations/_graph.py

Lines changed: 5 additions & 4 deletions
@@ -96,23 +96,24 @@ def named_graphmodules(gm: fx.GraphModule) -> Iterator[Tuple[str, fx.GraphModule
         yield name, m
 
 
-def _move_single_gm_to_device(
-    gm: GraphModule, device: torch.device, recompile_graph: bool = False
-) -> None:
+def _move_single_gm_to_device(gm: GraphModule, device: torch.device) -> None:
     """Move one GraphModule and its nodes to the specified device in-place.
     Partially inspired by https://github.com/pytorch/pytorch/blob/05cb98f91d49df9eadfcb3fc29bbd1b621d88860/torch/export/passes/__init__.py#L11
     """
     # move state dict
     gm.to(device)
+    recompile_graph = False
 
     for node in gm.graph.nodes:
         # move all the nodes kwargs with burnt-in device
         if "device" in node.kwargs:
+            recompile_graph = True
             kwargs = node.kwargs.copy()
             kwargs["device"] = device
             node.kwargs = kwargs
 
         if is_op(node, torch.ops.aten.to.device):
+            recompile_graph = True
             args = list(node.args)
             args[1] = device
             node.args = tuple(args)
@@ -135,7 +136,7 @@ def move_to_device(gm: fx.GraphModule, device: DeviceLikeType) -> fx.GraphModule
 
     for _, subgm in reversed(list(named_graphmodules(gm))):
         # recompile graph to update self generated codes in subgraph
-        _move_single_gm_to_device(subgm, device, subgm is not gm)
+        _move_single_gm_to_device(subgm, device)
 
 
 def _is_impure_node(node: Node) -> bool:
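The change above drops the recompile_graph parameter and instead raises the flag inside the helper whenever a node with a burnt-in device is actually rewritten. A rough, self-contained illustration of that idea is sketched below; the hand-built toy graph and target device are assumptions for the example, not the project's code.

```python
import torch
import torch.fx as fx

# Build a tiny fx graph by hand so a node with a burnt-in "device" kwarg exists.
graph = fx.Graph()
x = graph.placeholder("x")
ones = graph.call_function(torch.ones, args=(4,), kwargs={"device": torch.device("cpu")})
out = graph.call_function(torch.add, args=(x, ones))
graph.output(out)
gm = fx.GraphModule(torch.nn.Module(), graph)

target = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Same idea as the patched helper: flag recompilation only when a device was rewritten.
recompile_graph = False
for node in gm.graph.nodes:
    if "device" in node.kwargs:
        recompile_graph = True
        kwargs = dict(node.kwargs)
        kwargs["device"] = target
        node.kwargs = kwargs

if recompile_graph:
    # Regenerate gm.code so the rewritten device shows up in the generated forward.
    gm.recompile()

print(gm.code)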
