|
22 | 22 | import os |
23 | 23 |
|
24 | 24 | # Third Party |
| 25 | +from PIL import Image # pylint: disable=import-error |
25 | 26 | from torch.utils.data import DataLoader, TensorDataset |
26 | 27 | from transformers import ( |
| 28 | + AutoImageProcessor, |
| 29 | + AutoModelForImageClassification, |
27 | 30 | BertConfig, |
28 | 31 | BertModel, |
29 | 32 | BertTokenizer, |
@@ -1200,6 +1203,81 @@ def model_vit(): |
1200 | 1203 | return vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1) |
1201 | 1204 |
|
1202 | 1205 |
|
# Shared test fixture data: the Grace Hopper sample image, loaded once at
# import time and reused by the image-preprocessing fixtures below.
_IMG_PATH = os.path.realpath(
    os.path.join(os.path.dirname(__file__), "grace_hopper_517x606.jpg")
)
img = Image.open(_IMG_PATH).convert("RGB")
| 1212 | + |
def process_img(
    pretrained_model: str,
    input_img: Image.Image,
):
    """
    Process an image w/ AutoImageProcessor

    Args:
        pretrained_model (str): Hugging Face model id whose processor to load
        input_img (Image.Image): Image data

    Returns:
        torch.FloatTensor: Processed image batch of shape [1, C, H, W]
    """
    # use_fast selects the slow (PIL-based) processor implementation and is a
    # from_pretrained() argument; passing it to the processor *call* (as the
    # old code did) is silently ignored.
    img_processor = AutoImageProcessor.from_pretrained(
        pretrained_model, use_fast=False
    )
    # NOTE: the previous code misspelled this kwarg as "return_tensor", so the
    # request was ignored and numpy arrays came back, which then had to be
    # converted with torch.from_numpy and unsqueezed by hand. With the correct
    # return_tensors="pt", pixel_values is already a batched torch tensor of
    # shape [1, C, H, W] — identical to what the old round-trip produced.
    batch_dict = img_processor(images=input_img, return_tensors="pt")
    return batch_dict["pixel_values"]
| 1233 | + |
| 1234 | + |
@pytest.fixture(scope="function")
def batch_resnet18():
    """
    Preprocess the shared test image with the Microsoft ResNet-18 processor.

    Returns:
        torch.FloatTensor: Preprocessed image
    """
    processed = process_img("microsoft/resnet-18", img)
    return processed
| 1244 | + |
| 1245 | + |
@pytest.fixture(scope="function")
def model_resnet18():
    """
    Load the pretrained Microsoft ResNet-18 classification model.

    Returns:
        AutoModelForImageClassification: Resnet18 model
    """
    pretrained = "microsoft/resnet-18"
    return AutoModelForImageClassification.from_pretrained(pretrained)
| 1255 | + |
| 1256 | + |
@pytest.fixture(scope="function")
def batch_vit_base():
    """
    Preprocess the shared test image with the Google ViT-base processor.

    Returns:
        torch.FloatTensor: Preprocessed image
    """
    processed = process_img("google/vit-base-patch16-224", img)
    return processed
| 1266 | + |
| 1267 | + |
@pytest.fixture(scope="function")
def model_vit_base():
    """
    Load the pretrained Google ViT-base classification model.

    Returns:
        AutoModelForImageClassification: Google ViT-base model
    """
    pretrained = "google/vit-base-patch16-224"
    return AutoModelForImageClassification.from_pretrained(pretrained)
| 1279 | + |
| 1280 | + |
1203 | 1281 | ####################### |
1204 | 1282 | # BERT Model Fixtures # |
1205 | 1283 | ####################### |
|
0 commit comments