 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-import pytest
+import unittest
+
 import torch
+from torch.testing._internal import common_utils

 from torchao._models.llama.model import Transformer

-_AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
-

 def init_model(name="stories15M", device="cpu", precision=torch.bfloat16):
+    """Initialize and return a Transformer model with the specified configuration."""
     model = Transformer.from_name(name)
     model.to(device=device, dtype=precision)
     return model.eval()


-@pytest.mark.parametrize("device", _AVAILABLE_DEVICES)
-@pytest.mark.parametrize("batch_size", [1, 4])
-@pytest.mark.parametrize("is_training", [True, False])
-def test_ao_llama_model_inference_mode(device, batch_size, is_training):
-    random_model = init_model(device=device)
-    seq_len = 16
-    input_ids = torch.randint(0, 1024, (batch_size, seq_len)).to(device)
-    input_pos = None if is_training else torch.arange(seq_len).to(device)
-    with torch.device(device):
-        random_model.setup_caches(
-            max_batch_size=batch_size, max_seq_length=seq_len, training=is_training
-        )
-    for i in range(3):
-        out = random_model(input_ids, input_pos)
-        assert out is not None, "model failed to run"
+class TorchAOBasicTestCase(unittest.TestCase):
+    """Test suite for basic Transformer inference functionality."""
+
+    @common_utils.parametrize(
+        "device", ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
+    )
+    @common_utils.parametrize("batch_size", [1, 4])
+    @common_utils.parametrize("is_training", [True, False])
+    def test_ao_inference_mode(self, device, batch_size, is_training):
+        # Initialize the model on the requested device
+        random_model = init_model(device=device)
+
+        # Set up the test inputs
+        seq_len = 16
+        input_ids = torch.randint(0, 1024, (batch_size, seq_len)).to(device)
+
+        # input_pos is None in training mode and a position tensor in inference mode
+        input_pos = None if is_training else torch.arange(seq_len).to(device)
+
+        # Set up the model caches within the device context
+        with torch.device(device):
+            random_model.setup_caches(
+                max_batch_size=batch_size, max_seq_length=seq_len, training=is_training
+            )
+
+        # Run a few forward passes to make sure repeated calls succeed
+        for i in range(3):
+            out = random_model(input_ids, input_pos)
+            self.assertIsNotNone(out, f"Model failed to run on iteration {i}")
+
+
+common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)
+
+if __name__ == "__main__":
+    unittest.main()
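
For context on the pattern this diff adopts: `common_utils.parametrize` records the parameter values on the decorated method, and `common_utils.instantiate_parametrized_tests` then expands each decorated method into one concrete test per parameter combination, so the variants run under plain `unittest` discovery (and `python -m unittest`) just as they did under `pytest.mark.parametrize`. A minimal, self-contained sketch of the same pattern; the class name, test name, and parameter values here are made up for illustration:

```python
import unittest

import torch
from torch.testing._internal import common_utils


class ParametrizeSketch(unittest.TestCase):
    # Stacked parametrize decorators take the cross-product:
    # 2 dtypes x 2 sizes -> 4 concrete tests after instantiation.
    @common_utils.parametrize("dtype", [torch.float32, torch.bfloat16])
    @common_utils.parametrize("size", [4, 8])
    def test_zeros(self, dtype, size):
        out = torch.zeros(size, dtype=dtype)
        self.assertEqual(out.shape, (size,))
        self.assertEqual(out.dtype, dtype)


# Expands the decorated method into individually named tests on the class;
# without this call the parametrized variants are never generated.
common_utils.instantiate_parametrized_tests(ParametrizeSketch)

if __name__ == "__main__":
    unittest.main()
```

Each generated variant carries its parameter values in the test name, so a failure pinpoints the exact device/batch-size/training combination.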
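
One detail shared by both versions of the test: `setup_caches` is called under `with torch.device(device):`. In PyTorch 2.0+, `torch.device` doubles as a context manager that changes the default device for tensor factory functions, so buffers allocated inside `setup_caches` land on the test device without a `device=` argument being threaded through the model code. A standalone illustration, independent of the Transformer model:

```python
import torch

# Pick a device the same way the test does; fall back to CPU without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Inside the block, factory calls such as torch.zeros default to `device`,
# even in helper code that never receives a device argument.
with torch.device(device):
    cache = torch.zeros(2, 16)

print(cache.device)  # the chosen device, with no device= passed to zeros()
```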