|
| 1 | +# Copyright 2025 The AI Edge Torch Authors. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================== |
| 15 | + |
| 16 | +from ai_edge_torch import odml_torch |
| 17 | +from ai_edge_torch.generative.layers import scaled_dot_product_attention |
| 18 | +import torch |
| 19 | + |
| 20 | +from absl.testing import absltest as googletest |
| 21 | + |
| 22 | + |
| 23 | +class ScaledDotProductAttentionTest(googletest.TestCase): |
| 24 | + |
| 25 | + def test_scaled_dot_product_attention(self): |
| 26 | + query = torch.randn(1, 16, 16, 128, dtype=torch.float32) |
| 27 | + key = torch.randn(1, 16, 16, 128, dtype=torch.float32) |
| 28 | + value = torch.randn(1, 16, 16, 128, dtype=torch.float32) |
| 29 | + mask = torch.ones((1, 1, 1, 16), dtype=torch.float32) |
| 30 | + output = scaled_dot_product_attention.scaled_dot_product_attention( |
| 31 | + query, key, value, head_size=128, mask=mask, scale=1.0, softcap=10.0 |
| 32 | + ) |
| 33 | + self.assertEqual(output.shape, (1, 16, 16, 128)) |
| 34 | + |
| 35 | + def test_scaled_dot_product_attention_transposed(self): |
| 36 | + query = torch.randn(1, 16, 16, 128, dtype=torch.float32) |
| 37 | + key = torch.randn(1, 16, 16, 128, dtype=torch.float32) |
| 38 | + value = torch.randn(1, 16, 128, 16, dtype=torch.float32) |
| 39 | + mask = torch.ones((1, 1, 1, 16), dtype=torch.float32) |
| 40 | + output = ( |
| 41 | + scaled_dot_product_attention.scaled_dot_product_attention_transposed( |
| 42 | + query, key, value, head_size=128, mask=mask, scale=1.0, softcap=10.0 |
| 43 | + ) |
| 44 | + ) |
| 45 | + self.assertEqual(output.shape, (1, 16, 16, 128)) |
| 46 | + |
| 47 | + def test_scaled_dot_product_attention_with_hlfb(self): |
| 48 | + query = torch.randn(1, 16, 16, 128, dtype=torch.float32) |
| 49 | + key = torch.randn(1, 16, 16, 128, dtype=torch.float32) |
| 50 | + value = torch.randn(1, 16, 16, 128, dtype=torch.float32) |
| 51 | + mask = torch.ones((1, 1, 1, 16), dtype=torch.float32) |
| 52 | + output = ( |
| 53 | + scaled_dot_product_attention.scaled_dot_product_attention_with_hlfb( |
| 54 | + query, key, value, head_size=128, mask=mask, scale=1.0, softcap=10.0 |
| 55 | + ) |
| 56 | + ) |
| 57 | + self.assertEqual(output.shape, (1, 16, 16, 128)) |
| 58 | + |
| 59 | + def model_to_mlir(model, args): |
| 60 | + ep = torch.export.export(model, args) |
| 61 | + mlir = odml_torch.export.exported_program_to_mlir(ep) |
| 62 | + return mlir.get_text() |
| 63 | + |
| 64 | + class SDPAModule(torch.nn.Module): |
| 65 | + |
| 66 | + def __init__(self): |
| 67 | + super().__init__() |
| 68 | + |
| 69 | + def forward(self, query, key, value, mask): |
| 70 | + return ( |
| 71 | + scaled_dot_product_attention.scaled_dot_product_attention_with_hlfb( |
| 72 | + query, |
| 73 | + key, |
| 74 | + value, |
| 75 | + head_size=128, |
| 76 | + mask=mask, |
| 77 | + scale=1.0, |
| 78 | + softcap=10.0, |
| 79 | + ) |
| 80 | + ) |
| 81 | + |
| 82 | + ir_text = model_to_mlir(SDPAModule().eval(), (query, key, value, mask)) |
| 83 | + self.assertEqual(ir_text.count("stablehlo.custom_call @mark_tensor"), 5) |
| 84 | + |
| 85 | + |
| 86 | +if __name__ == "__main__": |
| 87 | + googletest.main() |
0 commit comments