
Commit 03b44fc

Fix tests
1 parent 416c91b commit 03b44fc

File tree

2 files changed (+85, -40 lines)


extension/llm/modules/mha.py

Lines changed: 1 addition & 0 deletions

@@ -354,6 +354,7 @@ def forward(
         q = q.transpose(1, 2)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
+
         output = self._attention_fn(
             q,
             k,

extension/llm/modules/test/test_mha.py

Lines changed: 84 additions & 40 deletions

@@ -7,12 +7,14 @@
 import unittest

 import torch
+from executorch.exir import EdgeCompileConfig, to_edge

 from executorch.extension.llm.modules.mha import (
     MultiHeadAttention as ETMultiHeadAttention,
 )
+from executorch.runtime import Runtime
+from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
 from torchtune.modules.attention import MultiHeadAttention as TTMultiHeadAttention
-from torchtune.modules.kv_cache import KVCache


 torch.manual_seed(0)

@@ -21,76 +23,118 @@
 class AttentionTest(unittest.TestCase):
     def setUp(self):
         super().setUp()
-        self.embed_dim=2048
-        self.num_heads=32
-        self.num_kv_heads=8
-        self.head_dim=64
+
+        # Constants
+        self.embed_dim = 2048
+        self.num_heads = 32
+        self.num_kv_heads = 8
+        self.head_dim = 64
         self.max_seq_len = 128
+        self.rope_base = 500_000
+        self.scale_factor = 32
+
+        # Module dependency injections.
+        self.q_proj = torch.nn.Linear(
+            self.embed_dim, self.num_heads * self.head_dim, bias=False
+        )
+        self.k_proj = torch.nn.Linear(
+            self.embed_dim, self.num_kv_heads * self.head_dim, bias=False
+        )
+        self.v_proj = torch.nn.Linear(
+            self.embed_dim, self.num_kv_heads * self.head_dim, bias=False
+        )
+        self.output_proj = torch.nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+        self.pos_embeddings = Llama3ScaledRoPE(dim=self.head_dim, max_seq_len=self.max_seq_len, base=self.rope_base, scale_factor=self.scale_factor)
+
+        # Original TorchTune reference module to test accuracy against.
         self.tt_mha = TTMultiHeadAttention(
             embed_dim=self.embed_dim,
             num_heads=self.num_heads,
             num_kv_heads=self.num_kv_heads,
             head_dim=self.head_dim,
-            q_proj=torch.nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False),
-            k_proj=torch.nn.Linear(self.embed_dim, self.num_kv_heads * self.head_dim, bias=False),
-            v_proj=torch.nn.Linear(self.embed_dim, self.num_kv_heads * self.head_dim, bias=False),
-            output_proj=torch.nn.Linear(self.embed_dim, self.embed_dim, bias=False),
-            # pos_embeddings=rope,
+            q_proj=self.q_proj,
+            k_proj=self.k_proj,
+            v_proj=self.v_proj,
+            output_proj=self.output_proj,
+            pos_embeddings=self.pos_embeddings,
             max_seq_len=self.max_seq_len,
-            # attn_dropout=attn_dropout,
         )
+
+        # Source transformed module that we are testing.
         self.et_mha = ETMultiHeadAttention(
             embed_dim=self.embed_dim,
             num_heads=self.num_heads,
             num_kv_heads=self.num_kv_heads,
             head_dim=self.head_dim,
-            q_proj=torch.nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False),
-            k_proj=torch.nn.Linear(self.embed_dim, self.num_kv_heads * self.head_dim, bias=False),
-            v_proj=torch.nn.Linear(self.embed_dim, self.num_kv_heads * self.head_dim, bias=False),
-            output_proj=torch.nn.Linear(self.embed_dim, self.embed_dim, bias=False),
-            # pos_embeddings=rope,
+            q_proj=self.q_proj,
+            k_proj=self.k_proj,
+            v_proj=self.v_proj,
+            output_proj=self.output_proj,
+            pos_embeddings=self.pos_embeddings,
             max_seq_len=self.max_seq_len,
-            # attn_dropout=attn_dropout,
         )

-    def test_self_attention_eager(self):
+        # Common inputs.
         seq_len = 10
-        x = torch.randn(1, seq_len, self.embed_dim)
-        et_res = self.et_mha(x, x) # Self attention.
-        tt_res = self.tt_mha(x, x) # Self attention.
-
+        self.x = torch.randn(1, seq_len, self.embed_dim)
+        seq_len_dim = torch.export.Dim("seq_len", min=1, max=100)
+        self.dynamic_shapes = (
+            {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
+            {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
+        )
+
+    def test_attention_eager(self):
+        et_res = self.et_mha(self.x, self.x) # Self attention.
+        tt_res = self.tt_mha(self.x, self.x) # Self attention.
+
         self.assertTrue(torch.allclose(et_res, tt_res))

         # TODO: KV cache.
         # self.et_mha.setup_cache(1, dtype=torch.float16, max_seq_len=20)
         # self.tt_mha.setup_cache(1, dtype=torch.float16, max_seq_len=20)
-
-        # et_res = self.et_mha(x, x) # Self attention.
-        # tt_res = self.tt_mha(x, x) # Self attention.

-        # self.assertTrue(torch.allclose(et_res, tt_res))
+        # et_res = self.et_mha(self.x, self.x) # Self attention.
+        # tt_res = self.tt_mha(self.x, self.x) # Self attention.

-    def test_self_attention_export(self):
-        seq_len = 10
-        x = torch.randn(1, seq_len, self.embed_dim)
-        seq_len_dim = torch.export.Dim("seq_len", min=1, max=100)
-        dynamic_shapes = (
-            {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
-            {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
-        )
+        # self.assertTrue(torch.allclose(et_res, tt_res))

+    def test_attention_export(self):
         # Self attention.
         et_mha_ep = torch.export.export(
             self.et_mha,
-            (x, x),
+            (self.x, self.x),
             kwargs=None,
-            dynamic_shapes=dynamic_shapes,
+            dynamic_shapes=self.dynamic_shapes,
         )
-        et_res = et_mha_ep.module()(x, x)
-        tt_res = self.tt_mha(x, x)
+        et_res = et_mha_ep.module()(self.x, self.x)
+        tt_res = self.tt_mha(self.x, self.x)
         self.assertTrue(torch.allclose(et_res, tt_res))
-
+
         # TODO: KV cache.

-    def test_cross_attention_export(self):
+    def test_attention_aoti(self):
+        # TODO.
         pass
+
+    def test_attention_executorch(self):
+        # Self attention.
+        et_mha_ep = torch.export.export(
+            self.et_mha,
+            (self.x, self.x),
+            kwargs=None,
+            dynamic_shapes=self.dynamic_shapes,
+        )
+        et_program = to_edge(
+            et_mha_ep,
+            compile_config=EdgeCompileConfig(),
+        ).to_executorch()
+        runtime = Runtime.get()
+        program = runtime.load_program(et_program.buffer)
+        method = program.load_method("forward")
+        et_res = method.execute((self.x, self.x))
+        tt_res = self.tt_mha(self.x, self.x)
+
+        self.assertTrue(torch.allclose(et_res[0], tt_res, atol=1e-06))
+
+        # TODO: KV cache.
+
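
The new test_attention_executorch test walks the full torch.export → to_edge → to_executorch → Runtime path. Below is a minimal standalone sketch of that same round trip, reusing only APIs that appear in the diff above; the smaller dimensions, the second run at a different sequence length, and the printed shape are illustrative assumptions, not part of the commit.

# Hedged sketch of the export-and-run flow exercised by test_attention_executorch.
# The module configuration here (embed_dim=256, num_heads=8, etc.) is an
# illustrative assumption, not the values used in the test.
import torch
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.extension.llm.modules.mha import MultiHeadAttention as ETMultiHeadAttention
from executorch.runtime import Runtime
from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE

embed_dim, num_heads, num_kv_heads, head_dim, max_seq_len = 256, 8, 2, 32, 128

mha = ETMultiHeadAttention(
    embed_dim=embed_dim,
    num_heads=num_heads,
    num_kv_heads=num_kv_heads,
    head_dim=head_dim,
    q_proj=torch.nn.Linear(embed_dim, num_heads * head_dim, bias=False),
    k_proj=torch.nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
    v_proj=torch.nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
    output_proj=torch.nn.Linear(embed_dim, embed_dim, bias=False),
    pos_embeddings=Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len),
    max_seq_len=max_seq_len,
)

# Export with a dynamic sequence-length dimension, mirroring the test's setUp.
x = torch.randn(1, 10, embed_dim)
seq_len_dim = torch.export.Dim("seq_len", min=1, max=100)
dynamic_shapes = (
    {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
    {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
)
ep = torch.export.export(mha, (x, x), kwargs=None, dynamic_shapes=dynamic_shapes)

# Lower to an ExecuTorch program and run it through the Python runtime.
et_program = to_edge(ep, compile_config=EdgeCompileConfig()).to_executorch()
runtime = Runtime.get()
method = runtime.load_program(et_program.buffer).load_method("forward")

# A different (still in-range) sequence length exercises the dynamic dimension.
y = torch.randn(1, 24, embed_dim)
out = method.execute((y, y))[0]
print(out.shape)  # expected: torch.Size([1, 24, 256])

Running the exported program at a sequence length other than the example input is exactly what the seq_len dynamic dimension allows; a length outside the declared [1, 100] range would be rejected by the exported program.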
