Skip to content

Commit f417260

Browse files
committed
test: Added non-torchvision resnet + vit tests
Signed-off-by: Brandon Groth <[email protected]>
1 parent 030bae9 commit f417260

File tree

2 files changed

+114
-0
lines changed

2 files changed

+114
-0
lines changed

tests/models/conftest.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,11 @@
2222
import os
2323

2424
# Third Party
25+
from PIL import Image # pylint: disable=import-error
2526
from torch.utils.data import DataLoader, TensorDataset
2627
from transformers import (
28+
AutoImageProcessor,
29+
AutoModelForImageClassification,
2730
BertConfig,
2831
BertModel,
2932
BertTokenizer,
@@ -1200,6 +1203,81 @@ def model_vit():
12001203
return vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
12011204

12021205

1206+
img = Image.open(
1207+
os.path.realpath(
1208+
os.path.join(os.path.dirname(__file__), "grace_hopper_517x606.jpg")
1209+
)
1210+
).convert("RGB")
1211+
1212+
1213+
def process_img(
1214+
pretrained_model: str,
1215+
input_img: Image.Image,
1216+
):
1217+
"""
1218+
Process an image w/ AutoImageProcessor
1219+
1220+
Args:
1221+
processor (AutoImageProcessor): Processor weights for pretrained model
1222+
pretrained_model (str): Weight object
1223+
input_img (Image.Image): Image data
1224+
1225+
Returns:
1226+
torch.FloatTensor: Processed image
1227+
"""
1228+
img_processor = AutoImageProcessor.from_pretrained(pretrained_model)
1229+
batch_dict = img_processor(images=input_img, return_tensor="pt", use_fast=False)
1230+
# Data is {pixel_values: numpy_array[0]=data} w/ tensor.shape [C,W,H]
1231+
# Needs to be [1,C,W,H] -> unsqueeze(0)
1232+
return torch.from_numpy(batch_dict["pixel_values"][0]).unsqueeze(0)
1233+
1234+
1235+
@pytest.fixture(scope="function")
1236+
def batch_resnet18():
1237+
"""
1238+
Preprocess an image w/ ms resnet18 processor
1239+
1240+
Returns:
1241+
torch.FloatTensor: Preprocessed image
1242+
"""
1243+
return process_img("microsoft/resnet-18", img)
1244+
1245+
1246+
@pytest.fixture(scope="function")
1247+
def model_resnet18():
1248+
"""
1249+
Create MS ResNet18 model + weights
1250+
1251+
Returns:
1252+
AutoModelForImageClassification: Resnet18 model
1253+
"""
1254+
return AutoModelForImageClassification.from_pretrained("microsoft/resnet-18")
1255+
1256+
1257+
@pytest.fixture(scope="function")
1258+
def batch_vit_base():
1259+
"""
1260+
Preprocess an image w/ Google ViT-base processor
1261+
1262+
Returns:
1263+
torch.FloatTensor: Preprocessed image
1264+
"""
1265+
return process_img("google/vit-base-patch16-224", img)
1266+
1267+
1268+
@pytest.fixture(scope="function")
1269+
def model_vit_base():
1270+
"""
1271+
Create Google ViT-base model + weights
1272+
1273+
Returns:
1274+
AutoModelForImageClassification: Google ViT-base model
1275+
"""
1276+
return AutoModelForImageClassification.from_pretrained(
1277+
"google/vit-base-patch16-224"
1278+
)
1279+
1280+
12031281
#######################
12041282
# BERT Model Fixtures #
12051283
#######################

tests/models/test_qmodelprep.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,42 @@ def test_vit_dynamo(
272272
qmodule_error(model_vit, 2, 36)
273273

274274

275+
def test_resnet18(
276+
model_resnet18,
277+
batch_resnet18,
278+
config_int8: dict,
279+
):
280+
"""
281+
Perform int8 quantization on ResNet-18 w/ Dynamo tracer
282+
283+
Args:
284+
model_resnet18 (AutoModelForImageClassification): Resnet18 model + weights
285+
batch_resnet18 (torch.FloatTensor): Batch image data for Resnet18
286+
config (dict): Recipe Config w/ int8 settings
287+
"""
288+
# Run qmodel_prep w/ Dynamo tracer
289+
qmodel_prep(model_resnet18, batch_resnet18, config_int8, use_dynamo=True)
290+
qmodule_error(model_resnet18, 4, 17)
291+
292+
293+
def test_vit_base(
294+
model_vit_base,
295+
batch_vit_base,
296+
config_int8: dict,
297+
):
298+
"""
299+
Perform int8 quantization on ViT-base w/ Dynamo tracer
300+
301+
Args:
302+
model_vit_base (AutoModelForImageClassification): Resnet18 model + weights
303+
batch_vit_base (torch.FloatTensor): Batch image data for Resnet18
304+
config (dict): Recipe Config w/ int8 settings
305+
"""
306+
# Run qmodel_prep w/ Dynamo tracer
307+
qmodel_prep(model_vit_base, batch_vit_base, config_int8, use_dynamo=True)
308+
qmodule_error(model_vit_base, 1, 73)
309+
310+
275311
def test_bert_dynamo(
276312
model_bert: transformers.models.bert.modeling_bert.BertModel,
277313
input_bert: torch.FloatTensor,

0 commit comments

Comments
 (0)