
Commit d266e75

sxu authored and facebook-github-bot committed

LLM export pass to swap in custom SDPA

Differential Revision: D73444078

1 parent 9e64882 commit d266e75

3 files changed: 108 additions & 1 deletion

extension/llm/export/TARGETS

Lines changed: 12 additions & 0 deletions

@@ -41,6 +41,18 @@ runtime.python_library(
         "//executorch/exir:lib",
         "//executorch/exir/backend:backend_details",
         "//executorch/extension/export_util:export_util",
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_py",
         "//pytorch/tokenizers/pytorch_tokenizers:tokenizers",
     ],
 )
+
+runtime.python_test(
+    name = "export_passes_test",
+    srcs = [
+        "test_export_passes.py",
+    ],
+    deps = [
+        ":export_lib",
+    ],
+)
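The new runtime.python_test target registers the pass's unit tests with the build. Assuming the //executorch cell seen in the deps above (the exact target path is an assumption, not stated in this diff), they could be run with something like:

buck2 test //executorch/extension/llm/export:export_passes_test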

extension/llm/export/export_passes.py

Lines changed: 58 additions & 0 deletions

@@ -4,6 +4,9 @@
 from torch._subclasses import FakeTensor
 from torch.fx.passes.infra.pass_base import PassResult
 
+torch.ops.load_library("//executorch/extension/llm/custom_ops:custom_ops_aot_lib")
+from executorch.extension.llm.custom_ops import custom_ops  # noqa
+
 
 def _normalize_dims(tensor: FakeTensor, dim_0: int, dim_1: int):
     """
@@ -95,3 +98,58 @@ def call(self, graph_module: torch.fx.GraphModule):
         graph_module.recompile()
 
         return PassResult(graph_module, graph_changed)
+
+
+class ReplaceSDPAWithCustomSDPAPass(ExportPass):
+    def call_operator(self, op, args, kwargs, meta):  # pyre-ignore
+        if op != torch.ops.aten.scaled_dot_product_attention.default:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # aten SDPA is (q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False,
+        # *, scale=None); unpack the optional positional args.
+        q, k, v, *rest = args
+        mask = None
+        dropout = 0.0
+        is_causal = False
+        scale = None
+        if len(rest) > 0:
+            mask = rest[0]
+        if len(rest) > 1:
+            dropout = rest[1]
+        if len(rest) > 2:
+            is_causal = rest[2]
+        if "scale" in kwargs:
+            scale = kwargs["scale"]
+
+        # custom_sdpa takes (batch, seq_len, heads, head_dim) inputs, so swap
+        # dims 1 and 2 on the way in and again on the way out.
+        qT = self._transpose(q, meta)
+        kT = self._transpose(k, meta)
+        vT = self._transpose(v, meta)
+
+        # Convert a boolean mask into an additive float mask:
+        # True -> 0.0 (attend), False -> -inf (masked out).
+        if mask is not None and mask.node.meta["val"].dtype == torch.bool:
+            mask = super().call_operator(
+                torch.ops.aten.where.Scalar,
+                (mask, 0.0, float("-inf")),
+                {},
+                meta,
+            )
+
+        custom_sdpa = super().call_operator(
+            torch.ops.llama.custom_sdpa.default,
+            (qT, kT, vT, 0, mask, dropout, is_causal, scale),  # 0 is start_pos
+            {},
+            meta,
+        )
+        return self._transpose(custom_sdpa, meta)
+
+    def _transpose(self, x, meta):  # pyre-ignore
+        transpose = super().call_operator(
+            torch.ops.aten.transpose.int,
+            (x, 1, 2),
+            {},
+            meta,
+        )
+        contiguous = super().call_operator(
+            torch.ops.aten.contiguous.default,
+            (transpose,),
+            {},
+            meta,
+        )
+        return contiguous
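The boolean-to-float mask conversion above preserves SDPA semantics: aten's SDPA treats a bool mask as "True = attend", while a float mask is added to the attention logits, so 0.0 keeps a position and -inf drops it. A minimal eager-mode sketch (plain PyTorch, no ExecuTorch dependency) of the equivalence the pass relies on:

import torch
import torch.nn.functional as F

# Boolean mask: True means the position may be attended to.
bool_mask = torch.tril(torch.ones(4, 4, dtype=torch.bool))
# Additive float mask: keep -> 0.0, drop -> -inf, mirroring aten.where.Scalar above.
float_mask = torch.where(bool_mask, 0.0, float("-inf"))

q = k = v = torch.rand(1, 2, 4, 8)  # (batch, heads, seq_len, head_dim)
y_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)
y_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)
assert torch.allclose(y_bool, y_float)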

extension/llm/export/test_export_passes.py

Lines changed: 38 additions & 1 deletion

@@ -2,7 +2,10 @@
 
 import torch
 
-from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes
+from executorch.extension.llm.export.export_passes import (
+    RemoveRedundantTransposes,
+    ReplaceSDPAWithCustomSDPAPass,
+)
 
 from torch.export import export_for_training
 from torch.testing import FileCheck
@@ -160,3 +163,37 @@ def forward(self, x):
 
         m = TestModule2()
         self._check(m, (x,), key, 3, 2)
+
+
+class ReplaceSDPAWithCustomSDPAPassTest(unittest.TestCase):
+    class TestModule(torch.nn.Module):
+        def forward(self, x, mask, is_causal):
+            return torch.nn.functional.scaled_dot_product_attention(
+                x, x, x, attn_mask=mask, is_causal=is_causal
+            )
+
+    def setUp(self):
+        torch.manual_seed(0)
+
+    def _test(self, *args):
+        m = self.TestModule()
+        gm = export_for_training(m, args, strict=True).module()
+
+        # The pass must remove the lone aten SDPA node, insert exactly one
+        # custom_sdpa call, and leave the numerics unchanged.
+        sdpa_key = "torch.ops.aten.scaled_dot_product_attention.default"
+        custom_sdpa_key = "torch.ops.llama.custom_sdpa.default"
+        FileCheck().check_count(sdpa_key, 1, exactly=True).run(gm.code)
+        gm = ReplaceSDPAWithCustomSDPAPass()(gm).graph_module
+        FileCheck().check_count(sdpa_key, 0, exactly=True).run(gm.code)
+        FileCheck().check_count(custom_sdpa_key, 1, exactly=True).run(gm.code)
+
+        y1 = m(*args)
+        y2 = gm(*args)
+        self.assertTrue(torch.allclose(y1, y2))
+
+    def test_causal_mask(self):
+        self._test(torch.rand(1, 4, 32, 64), None, True)
+
+    def test_custom_mask(self):
+        # Banded boolean mask: XOR of two lower-triangular masks.
+        m1 = torch.tril(torch.ones(32, 32, dtype=torch.bool))
+        m2 = torch.tril(torch.ones(32, 32, dtype=torch.bool), diagonal=-16)
+        self._test(torch.rand(1, 4, 32, 64), torch.logical_xor(m1, m2), False)
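For reference, a minimal end-to-end usage sketch mirroring the test above (the Attention module name and shapes are illustrative; shapes follow the test's (batch, heads, seq_len, head_dim) convention):

import torch
from torch.export import export_for_training

from executorch.extension.llm.export.export_passes import ReplaceSDPAWithCustomSDPAPass


class Attention(torch.nn.Module):
    def forward(self, q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)


args = (torch.rand(1, 4, 32, 64),) * 3
gm = export_for_training(Attention(), args, strict=True).module()
# Swap the aten SDPA node for the custom op; the pass returns a PassResult.
gm = ReplaceSDPAWithCustomSDPAPass()(gm).graph_module
print(gm.code)  # now calls torch.ops.llama.custom_sdpa.default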
