
Commit b200614

haozha111 authored and copybara-github committed

Add unit test for the scaled_dot_product_attention.py module.

- Only the output shapes are tested for now.
- For HLFB, we test that the corresponding StableHLO (SHLO) composite appears in the converted MLIR.

PiperOrigin-RevId: 754046281

1 parent 42a5bb2 commit b200614

File tree

2 files changed: +88 −1 lines


ai_edge_torch/generative/layers/scaled_dot_product_attention.py
Lines changed: 1 addition & 1 deletion

@@ -160,7 +160,7 @@ def scaled_dot_product_attention_transposed(
   Args:
     query: Query tensor, with shape [B, T, N, H].
     key: Key tensor, with shape [B, T, KV_LEN, H].
-    value: Value tensor, with shape [B, T, KV_LEN, H].
+    value: Value tensor, with shape [B, T, H, KV_LEN].
     head_size (int): head dimension.
     mask (torch.Tensor): the optional mask tensor.
     scale (float): the optional scale factor.
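
Editor's note on the corrected shape (not part of the commit): with value laid out as [B, T, H, KV_LEN], the post-softmax attention weights of shape [B, T, N, KV_LEN] can consume it after a transpose of the last two axes. Below is a minimal shape walk-through using the dimensions from the new test; the matmuls shown are one consistent reading of the docstring shapes, not the library's actual kernel:

import torch

# Dimensions used by the new test below.
B, T, N, H, KV_LEN = 1, 16, 16, 128, 16

query = torch.randn(B, T, N, H)       # [B, T, N, H]
key = torch.randn(B, T, KV_LEN, H)    # [B, T, KV_LEN, H]
value = torch.randn(B, T, H, KV_LEN)  # [B, T, H, KV_LEN] -- the corrected layout

scores = query @ key.transpose(-1, -2)          # [B, T, N, KV_LEN]
probs = torch.softmax(scores / H**0.5, dim=-1)  # [B, T, N, KV_LEN]
out = probs @ value.transpose(-1, -2)           # [B, T, N, H]
assert out.shape == (B, T, N, H)  # matches the test's expected (1, 16, 16, 128)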
Lines changed: 87 additions & 0 deletions

@@ -0,0 +1,87 @@
# Copyright 2025 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from ai_edge_torch import odml_torch
from ai_edge_torch.generative.layers import scaled_dot_product_attention
import torch

from absl.testing import absltest as googletest


class ScaledDotProductAttentionTest(googletest.TestCase):

  def test_scaled_dot_product_attention(self):
    query = torch.randn(1, 16, 16, 128, dtype=torch.float32)
    key = torch.randn(1, 16, 16, 128, dtype=torch.float32)
    value = torch.randn(1, 16, 16, 128, dtype=torch.float32)
    mask = torch.ones((1, 1, 1, 16), dtype=torch.float32)
    output = scaled_dot_product_attention.scaled_dot_product_attention(
        query, key, value, head_size=128, mask=mask, scale=1.0, softcap=10.0
    )
    self.assertEqual(output.shape, (1, 16, 16, 128))

  def test_scaled_dot_product_attention_transposed(self):
    query = torch.randn(1, 16, 16, 128, dtype=torch.float32)
    key = torch.randn(1, 16, 16, 128, dtype=torch.float32)
    value = torch.randn(1, 16, 128, 16, dtype=torch.float32)
    mask = torch.ones((1, 1, 1, 16), dtype=torch.float32)
    output = (
        scaled_dot_product_attention.scaled_dot_product_attention_transposed(
            query, key, value, head_size=128, mask=mask, scale=1.0, softcap=10.0
        )
    )
    self.assertEqual(output.shape, (1, 16, 16, 128))

  def test_scaled_dot_product_attention_with_hlfb(self):
    query = torch.randn(1, 16, 16, 128, dtype=torch.float32)
    key = torch.randn(1, 16, 16, 128, dtype=torch.float32)
    value = torch.randn(1, 16, 16, 128, dtype=torch.float32)
    mask = torch.ones((1, 1, 1, 16), dtype=torch.float32)
    output = (
        scaled_dot_product_attention.scaled_dot_product_attention_with_hlfb(
            query, key, value, head_size=128, mask=mask, scale=1.0, softcap=10.0
        )
    )
    self.assertEqual(output.shape, (1, 16, 16, 128))

    def model_to_mlir(model, args):
      ep = torch.export.export(model, args)
      mlir = odml_torch.export.exported_program_to_mlir(ep)
      return mlir.get_text()

    class SDPAModule(torch.nn.Module):

      def __init__(self):
        super().__init__()

      def forward(self, query, key, value, mask):
        return (
            scaled_dot_product_attention.scaled_dot_product_attention_with_hlfb(
                query,
                key,
                value,
                head_size=128,
                mask=mask,
                scale=1.0,
                softcap=10.0,
            )
        )

    ir_text = model_to_mlir(SDPAModule().eval(), (query, key, value, mask))
    self.assertEqual(ir_text.count("stablehlo.custom_call @mark_tensor"), 5)


if __name__ == "__main__":
  googletest.main()
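
Editor's note on the final assertion (an inference, not stated in the commit): HLFB builds a StableHLO composite around the op, and the mark_tensor custom call appears to tag each tensor crossing the composite boundary, so four inputs (query, key, value, mask) plus one output would give the expected count of 5. A minimal negative-baseline sketch under that assumption; the non-HLFB path presumably builds no composite, so its converted MLIR should contain zero markers:

import torch
from ai_edge_torch import odml_torch
from ai_edge_torch.generative.layers import scaled_dot_product_attention


class PlainSDPA(torch.nn.Module):
  """Hypothetical baseline using the non-HLFB attention path."""

  def forward(self, query, key, value, mask):
    return scaled_dot_product_attention.scaled_dot_product_attention(
        query, key, value, head_size=128, mask=mask, scale=1.0, softcap=10.0
    )


query = torch.randn(1, 16, 16, 128, dtype=torch.float32)
key = torch.randn(1, 16, 16, 128, dtype=torch.float32)
value = torch.randn(1, 16, 16, 128, dtype=torch.float32)
mask = torch.ones((1, 1, 1, 16), dtype=torch.float32)

# Same export pipeline as the test above.
ep = torch.export.export(PlainSDPA().eval(), (query, key, value, mask))
ir_text = odml_torch.export.exported_program_to_mlir(ep).get_text()
# Assumption: without HLFB, no composite boundary tensors are marked.
print(ir_text.count("stablehlo.custom_call @mark_tensor"))  # expected: 0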
