
Commit 275144b

Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code sharing"
Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definition of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How. Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple tranposes when KVCache and SDPA are replaced by custom modules, but we will write graph pass to undo those. Test Plan: Existing tests. Make sure perf doesnt regress [ghstack-poisoned]
2 parents 5eb4c6f + e105c4c commit 275144b
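To make the intended interface concrete, here is a minimal sketch (not the actual ExecuTorch modules; SimpleKVCache, SimpleSDPA, and their exact signatures are invented for illustration) of a KV cache and an SDPA module that both operate on [B, n_heads, seq_len, head_dim] tensors, so either can be swapped for a custom implementation independently of the other:

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleKVCache(nn.Module):
    # Hypothetical cache: k/v buffers laid out as [B, n_heads, max_seq_len, head_dim].
    def __init__(self, batch_size, n_heads, max_seq_len, head_dim):
        super().__init__()
        self.register_buffer("k_cache", torch.zeros(batch_size, n_heads, max_seq_len, head_dim))
        self.register_buffer("v_cache", torch.zeros(batch_size, n_heads, max_seq_len, head_dim))

    def update(self, start_pos, k, v):
        # k, v: [B, n_heads, seq_len, head_dim]
        seq_len = k.shape[2]
        self.k_cache[:, :, start_pos : start_pos + seq_len] = k
        self.v_cache[:, :, start_pos : start_pos + seq_len] = v
        return self.k_cache, self.v_cache


class SimpleSDPA(nn.Module):
    # Hypothetical SDPA wrapper: consumes the same [B, n_heads, seq_len, head_dim] layout.
    def forward(self, q, k, v, mask=None):
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

If a custom replacement for either module internally prefers the [B, seq_len, n_heads, head_dim] layout, it will transpose its inputs and outputs, which is how back-to-back transpose pairs end up in the exported graph; the RemoveRedundantTransposes pass added in this stack cleans those up.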

File tree: 4 files changed, +259 -0 lines


examples/models/llama/export_llama_lib.py

Lines changed: 2 additions & 0 deletions
@@ -657,6 +657,8 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
     # export_to_edge
     builder_exported = _prepare_for_llama_export(args).export()
 
+    builder_exported.run_canonical_optimizations()
+
     if args.export_only:
         exit()

extension/llm/export/builder.py

Lines changed: 10 additions & 0 deletions
@@ -37,6 +37,8 @@
 from torch.export import export_for_training
 from torch.nn.attention import SDPBackend
 
+from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes
+
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)

@@ -108,6 +110,7 @@ def __init__(
         self.calibration_seq_length = calibration_seq_length
         self.calibration_data = calibration_data
         self.tokenizer_path = tokenizer_path
+        self.canonical_passes = [RemoveRedundantTransposes()]
 
     def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
         """
@@ -212,6 +215,13 @@ def export(self) -> "LLMEdgeManager":
 
         return self
 
+    def run_canonical_optimizations(self):
+        for pass_instance in self.canonical_passes:
+            logging.info(f"Running canonical pass: {pass_instance.__class__.__name__}")
+            res = pass_instance(self.pre_autograd_graph_module)
+            assert res.graph_module is not None, "Pass returned None"
+            self.pre_autograd_graph_module = res.graph_module
+
     def pt2e_calibrate(
         self,
         prepared_module,
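For context, canonical_passes is a list of pass objects that accept a graph module and return a torch.fx PassResult, which is the contract run_canonical_optimizations relies on. A minimal sketch of that contract, using a hypothetical do-nothing pass (NoOpPass is not part of the codebase):

import torch

from executorch.exir.pass_base import ExportPass
from torch.fx.passes.infra.pass_base import PassResult


class NoOpPass(ExportPass):
    # Hypothetical pass illustrating the interface expected by canonical_passes.
    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        # A real pass would rewrite graph_module.graph here and report whether
        # anything changed; this sketch leaves the graph untouched.
        return PassResult(graph_module, False)

run_canonical_optimizations would apply such a pass the same way it applies RemoveRedundantTransposes: call it on self.pre_autograd_graph_module and keep the returned graph_module.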
extension/llm/export/export_passes.py

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
import torch
from torch._subclasses import FakeTensor

from executorch.exir.pass_base import ExportPass
from torch.fx.passes.infra.pass_base import PassResult


def _normalize_dims(tensor: FakeTensor, dim_0: int, dim_1: int):
    """
    Normalize (possibly negative) dimension indices into the range [0, ndim).
    """
    assert tensor is not None, "Tensor is None"
    ndim = tensor.ndim
    if dim_0 < 0:
        dim_0 = ndim + dim_0
    if dim_1 < 0:
        dim_1 = ndim + dim_1
    assert dim_0 < ndim and dim_1 < ndim, f"Invalid dimensions: {dim_0}, {dim_1}"
    return dim_0, dim_1


class RemoveRedundantTransposes(ExportPass):
    """
    This pass removes redundant transpose nodes in the graph.
    It checks whether a user of a transpose node is also a transpose node and whether
    the two transposes undo each other. For example, if the graph has the following nodes:

    node1 = torch.ops.aten.transpose.int(x, 0, 1)
    node2 = torch.ops.aten.transpose.int(node1, 0, 1)

    then node2's uses can be replaced by x.

    It also checks for permute nodes:

    node1 = torch.ops.aten.permute(x, [0, 2, 1])
    node2 = torch.ops.aten.permute(node1, [0, 2, 1])

    Then node2's uses can likewise be replaced by x.

    NB: Does not work for in-place ops or functionalized _copy suffix ops.
    """

    def call(self, graph_module: torch.fx.GraphModule):
        graph_changed = False
        for node in graph_module.graph.nodes:
            if node.op == 'call_function' and node.target == torch.ops.aten.transpose.int:
                # Check if any user is also a transpose node
                transpose_users = list(node.users.keys())
                dim_0 = node.args[1]
                dim_1 = node.args[2]
                dim_0, dim_1 = _normalize_dims(node.args[0].meta["val"], dim_0, dim_1)

                for user in transpose_users:
                    if user.op == 'call_function' and user.target == torch.ops.aten.transpose.int:
                        # Get the arguments of the current and next transpose nodes
                        user_dim_0 = user.args[1]
                        user_dim_1 = user.args[2]
                        user_dim_0, user_dim_1 = _normalize_dims(
                            user.args[0].meta["val"], user_dim_0, user_dim_1
                        )

                        # Check if the two transpose nodes undo each other
                        if dim_0 == user_dim_0 and dim_1 == user_dim_1:
                            graph_changed = True
                            user.replace_all_uses_with(node.args[0])

        for node in graph_module.graph.nodes:
            if node.op == 'call_function' and node.target == torch.ops.aten.permute.default:
                # Check if any user is also a permute node
                permute_users = list(node.users.keys())
                dim_list = node.args[1]

                for user in permute_users:
                    if user.op == 'call_function' and user.target == torch.ops.aten.permute.default:
                        # Get the argument of the current and next permute nodes
                        user_dim_list = user.args[1]

                        # Check if the two permutes undo each other
                        if dim_list == user_dim_list:
                            graph_changed = True
                            user.replace_all_uses_with(node.args[0])

        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()

        return PassResult(graph_module, graph_changed)
Lines changed: 167 additions & 0 deletions
@@ -0,0 +1,167 @@
import unittest
import os

import torch
from torch.testing import FileCheck

from torch.export import export_for_training

from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes


class RemoveRedundantTransposesPassTest(unittest.TestCase):
    def _export(self, model, example_inputs):
        exported_module = export_for_training(
            model,
            example_inputs,
        )
        return exported_module.module()

    def _check(self, model, example_inputs, key, before_count, after_count):
        gm = self._export(model, example_inputs)
        FileCheck().check_count(key, before_count, exactly=True).run(gm.code)
        pass_res = RemoveRedundantTransposes()(gm)
        FileCheck().check_count(key, after_count, exactly=True).run(
            pass_res.graph_module.code
        )

    def test_transpose_removal(self):
        class TestModule1(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = torch.transpose(x, 1, 2)
                x = torch.transpose(x, 1, 2)
                return x + 1

        class TestModule2(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = torch.transpose(x, 1, 2)
                x = torch.transpose(x, 1, 2)
                x = x + 1

                x = torch.transpose(x, 2, 3)
                x = torch.transpose(x, 2, 3)

                return x + 2

        x = torch.rand((1, 2, 3, 4))
        key = "torch.ops.aten.transpose.int"
        m = TestModule1()
        self._check(m, (x,), key, 2, 0)

        m = TestModule2()
        self._check(m, (x,), key, 4, 0)

    def test_transpose_no_removal(self):
        class TestModule1(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = torch.transpose(x, 1, 2)
                x = torch.transpose(x, 1, 2)
                x = x + 1

                x = torch.transpose(x, 2, 3)
                x = torch.transpose(x, 1, 2)

                return x + 2

        x = torch.rand((1, 2, 3, 4))
        key = "torch.ops.aten.transpose.int"

        m = TestModule1()
        self._check(m, (x,), key, 4, 2)

        class TestModule2(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x_1 = torch.transpose(x, 1, 2)
                x_2 = torch.transpose(x_1, 1, 2)
                x_2 = x_2 + 1

                x = x_1 + 2
                x = torch.transpose(x, 1, 2)

                return x + x_2

        m = TestModule2()
        self._check(m, (x,), key, 3, 2)

    def test_permute_removal(self):
        class TestModule1(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = torch.permute(x, [0, 2, 1, 3])
                x = torch.permute(x, [0, 2, 1, 3])
                return x + 1

        class TestModule2(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = torch.permute(x, [0, 2, 1, 3])
                x = torch.permute(x, [0, 2, 1, 3])
                x = x + 1

                x = torch.permute(x, [0, 1, 3, 2])
                x = torch.permute(x, [0, 1, 3, 2])

                return x + 2

        x = torch.rand((1, 2, 3, 4))
        key = "torch.ops.aten.permute.default"
        m = TestModule1()
        self._check(m, (x,), key, 2, 0)

        m = TestModule2()
        self._check(m, (x,), key, 4, 0)

    def test_permute_no_removal(self):
        class TestModule1(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = torch.permute(x, [0, 2, 1, 3])
                x = torch.permute(x, [0, 2, 1, 3])
                x = x + 1

                x = torch.permute(x, [0, 1, 3, 2])
                x = torch.permute(x, [0, 2, 1, 3])

                return x + 2

        x = torch.rand((1, 2, 3, 4))
        key = "torch.ops.aten.permute.default"

        m = TestModule1()
        self._check(m, (x,), key, 4, 2)

        class TestModule2(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x_1 = torch.permute(x, [0, 2, 1, 3])
                x_2 = torch.permute(x_1, [0, 2, 1, 3])
                x_2 = x_2 + 1

                x = x_1 + 2
                x = torch.permute(x, [0, 2, 1, 3])

                return x + x_2

        m = TestModule2()
        self._check(m, (x,), key, 3, 2)
