Skip to content

Commit 2fa9f60

Browse files
sxufacebook-github-bot
authored andcommitted
LLM export pass to swap in custom SDPA (#10355)
Summary: Pull Request resolved: #10355 Differential Revision: D73444078
1 parent 3eac583 commit 2fa9f60

File tree

3 files changed

+110
-1
lines changed

3 files changed

+110
-1
lines changed

extension/llm/export/TARGETS

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,21 @@ runtime.python_library(
4141
"//executorch/exir:lib",
4242
"//executorch/exir/backend:backend_details",
4343
"//executorch/extension/export_util:export_util",
44+
"//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
45+
"//executorch/extension/llm/custom_ops:custom_ops_aot_py",
4446
"//pytorch/tokenizers/pytorch_tokenizers:tokenizers",
4547
],
4648
)
49+
50+
runtime.python_test(
51+
name = "export_passes_test",
52+
srcs = [
53+
"test_export_passes.py",
54+
],
55+
preload_deps = [
56+
"//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
57+
],
58+
deps = [
59+
":export_lib",
60+
],
61+
)

extension/llm/export/export_passes.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,60 @@ def call(self, graph_module: torch.fx.GraphModule):
9595
graph_module.recompile()
9696

9797
return PassResult(graph_module, graph_changed)
98+
99+
100+
class ReplaceSDPAWithCustomSDPAPass(ExportPass):
101+
def call_operator(self, op, args, kwargs, meta):
102+
from executorch.extension.llm.custom_ops import custom_ops # noqa
103+
104+
if op != torch.ops.aten.scaled_dot_product_attention.default:
105+
return super().call_operator(op, args, kwargs, meta)
106+
107+
q, k, v, *rest = args
108+
mask = None
109+
dropout = 0.0
110+
is_causal = False
111+
scale = None
112+
if len(rest) > 0:
113+
mask = rest[0]
114+
if len(rest) > 1:
115+
dropout = rest[1]
116+
if len(rest) > 2:
117+
is_causal = rest[2]
118+
if "scale" in kwargs:
119+
scale = kwargs["scale"]
120+
121+
qT = self._transpose(q, meta)
122+
kT = self._transpose(k, meta)
123+
vT = self._transpose(v, meta)
124+
125+
if mask is not None and mask.node.meta["val"].dtype == torch.bool:
126+
mask = super().call_operator(
127+
torch.ops.aten.where.Scalar,
128+
(mask, 0.0, float("-inf")),
129+
{},
130+
meta,
131+
)
132+
133+
custom_sdpa = super().call_operator(
134+
torch.ops.llama.custom_sdpa.default,
135+
(qT, kT, vT, 0, mask, dropout, is_causal, scale),
136+
{},
137+
meta,
138+
)
139+
return self._transpose(custom_sdpa, meta)
140+
141+
def _transpose(self, x, meta):
142+
transpose = super().call_operator(
143+
torch.ops.aten.transpose.int,
144+
(x, 1, 2),
145+
{},
146+
meta,
147+
)
148+
contiguous = super().call_operator(
149+
torch.ops.aten.contiguous.default,
150+
(transpose,),
151+
{},
152+
meta,
153+
)
154+
return contiguous

extension/llm/export/test_export_passes.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22

33
import torch
44

5-
from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes
5+
from executorch.extension.llm.export.export_passes import (
6+
RemoveRedundantTransposes,
7+
ReplaceSDPAWithCustomSDPAPass,
8+
)
69

710
from torch.export import export_for_training
811
from torch.testing import FileCheck
@@ -160,3 +163,37 @@ def forward(self, x):
160163

161164
m = TestModule2()
162165
self._check(m, (x,), key, 3, 2)
166+
167+
168+
class ReplaceSDPAWithCustomSDPAPassTest(unittest.TestCase):
169+
class TestModule(torch.nn.Module):
170+
def forward(self, x, mask, is_causal):
171+
return torch.nn.functional.scaled_dot_product_attention(
172+
x, x, x, attn_mask=mask, is_causal=is_causal
173+
)
174+
175+
def setUp(self):
176+
torch.manual_seed(0)
177+
178+
def _test(self, *args):
179+
m = self.TestModule()
180+
gm = export_for_training(m, args, strict=True).module()
181+
182+
sdpa_key = "torch.ops.aten.scaled_dot_product_attention.default"
183+
custom_sdpa_key = "torch.ops.llama.custom_sdpa.default"
184+
FileCheck().check_count(sdpa_key, 1, exactly=True).run(gm.code)
185+
gm = ReplaceSDPAWithCustomSDPAPass()(gm).graph_module
186+
FileCheck().check_count(sdpa_key, 0, exactly=True).run(gm.code)
187+
FileCheck().check_count(custom_sdpa_key, 1, exactly=True).run(gm.code)
188+
189+
y1 = m(*args)
190+
y2 = gm(*args)
191+
self.assertTrue(torch.allclose(y1, y2))
192+
193+
def test_causal_mask(self):
194+
self._test(torch.rand(1, 4, 32, 64), None, True)
195+
196+
def test_custom_mask(self):
197+
m1 = torch.tril(torch.ones(32, 32, dtype=torch.bool))
198+
m2 = torch.tril(torch.ones(32, 32, dtype=torch.bool), diagonal=-16)
199+
self._test(torch.rand(1, 4, 32, 64), torch.logical_xor(m1, m2), False)

0 commit comments

Comments
 (0)