Commit 0347f35

Convert model inference test from pytest to unittest (#2644)
* convert ao inference test from pytest to unittest (pattern sketched below)
* refactor: `common_utils` for common parameters
* inline common params
* fix incorrect library import
1 parent d8bb51f commit 0347f35
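For context, a minimal sketch of the pytest-to-unittest parametrization pattern this commit adopts, using PyTorch's internal test harness (illustrative only: MyTestCase and test_shapes are hypothetical names, not part of the commit):

    import unittest

    import torch
    from torch.testing._internal import common_utils


    class MyTestCase(unittest.TestCase):
        # The decorator marks this method as a template; one real test is
        # generated per parameter value.
        @common_utils.parametrize("batch_size", [1, 4])
        def test_shapes(self, batch_size):
            x = torch.zeros(batch_size, 16)
            self.assertEqual(x.shape[0], batch_size)


    # Expands the template into concrete test methods on the class,
    # e.g. test_shapes_batch_size_1 and test_shapes_batch_size_4.
    common_utils.instantiate_parametrized_tests(MyTestCase)

    if __name__ == "__main__":
        unittest.main()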

File tree

1 file changed: +39 -18 lines

test/test_ao_models.py

Lines changed: 39 additions & 18 deletions
@@ -3,32 +3,53 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-import pytest
+import unittest
+
 import torch
+from torch.testing._internal import common_utils
 
 from torchao._models.llama.model import Transformer
 
-_AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
-
 
 def init_model(name="stories15M", device="cpu", precision=torch.bfloat16):
+    """Initialize and return a Transformer model with specified configuration."""
     model = Transformer.from_name(name)
     model.to(device=device, dtype=precision)
     return model.eval()
 
 
-@pytest.mark.parametrize("device", _AVAILABLE_DEVICES)
-@pytest.mark.parametrize("batch_size", [1, 4])
-@pytest.mark.parametrize("is_training", [True, False])
-def test_ao_llama_model_inference_mode(device, batch_size, is_training):
-    random_model = init_model(device=device)
-    seq_len = 16
-    input_ids = torch.randint(0, 1024, (batch_size, seq_len)).to(device)
-    input_pos = None if is_training else torch.arange(seq_len).to(device)
-    with torch.device(device):
-        random_model.setup_caches(
-            max_batch_size=batch_size, max_seq_length=seq_len, training=is_training
-        )
-    for i in range(3):
-        out = random_model(input_ids, input_pos)
-    assert out is not None, "model failed to run"
+class TorchAOBasicTestCase(unittest.TestCase):
+    """Test suite for basic Transformer inference functionality."""
+
+    @common_utils.parametrize(
+        "device", ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
+    )
+    @common_utils.parametrize("batch_size", [1, 4])
+    @common_utils.parametrize("is_training", [True, False])
+    def test_ao_inference_mode(self, device, batch_size, is_training):
+        # Initialize model with specified device
+        random_model = init_model(device=device)
+
+        # Set up test input parameters
+        seq_len = 16
+        input_ids = torch.randint(0, 1024, (batch_size, seq_len)).to(device)
+
+        # input_pos is None for training mode, tensor for inference mode
+        input_pos = None if is_training else torch.arange(seq_len).to(device)
+
+        # Setup model caches within the device context
+        with torch.device(device):
+            random_model.setup_caches(
+                max_batch_size=batch_size, max_seq_length=seq_len, training=is_training
+            )
+
+        # Run multiple inference iterations to ensure consistency
+        for i in range(3):
+            out = random_model(input_ids, input_pos)
+            self.assertIsNotNone(out, f"Model failed to run on iteration {i}")
+
+
+common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)
+
+if __name__ == "__main__":
+    unittest.main()
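After this change the file should be runnable both directly (python test/test_ao_models.py, via the unittest.main() entry point) and under pytest, which also collects unittest.TestCase subclasses. Since instantiate_parametrized_tests expands each device/batch_size/is_training combination into its own concrete test method at import time, every combination is reported as a separate test.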
