|
22 | 22 | import os |
23 | 23 |
|
24 | 24 | # Third Party |
| 25 | +from PIL import Image # pylint: disable=import-error |
25 | 26 | from torch.utils.data import DataLoader, TensorDataset |
26 | | -from torchvision.io import read_image |
27 | | -from torchvision.models import ResNet50_Weights, ViT_B_16_Weights, resnet50, vit_b_16 |
28 | 27 | from transformers import ( |
| 28 | + AutoImageProcessor, |
| 29 | + AutoModelForImageClassification, |
29 | 30 | BertConfig, |
30 | 31 | BertModel, |
31 | 32 | BertTokenizer, |
|
43 | 44 | # fms_mo imports |
44 | 45 | from fms_mo import qconfig_init |
45 | 46 | from fms_mo.modules import QLSTM, QBmm, QConv2d, QConvTranspose2d, QLinear |
| 47 | +from fms_mo.utils.import_utils import available_packages |
46 | 48 | from fms_mo.utils.qconfig_utils import get_mx_specs_defaults, set_mx_specs |
47 | 49 |
|
48 | 50 | ######################## |
@@ -1123,75 +1125,155 @@ def required_pair(request): |
1123 | 1125 | # Vision Model Fixtures # |
1124 | 1126 | ######################### |
1125 | 1127 |
|
1126 | | -# Create img |
1127 | | -# downloaded from torchvision github (vision/test/assets/encoder_jpeg/ directory) |
1128 | | -img = read_image( |
| 1128 | + |
| 1129 | +if available_packages["torchvision"]: |
| 1130 | + # Third Party |
| 1131 | + # pylint: disable = import-error |
| 1132 | + from torchvision.io import read_image |
| 1133 | + from torchvision.models import ( |
| 1134 | + ResNet50_Weights, |
| 1135 | + ViT_B_16_Weights, |
| 1136 | + resnet50, |
| 1137 | + vit_b_16, |
| 1138 | + ) |
| 1139 | + |
| 1140 | + # Create img |
| 1141 | + # downloaded from torchvision github (vision/test/assets/encoder_jpeg/ directory) |
| 1142 | + img_tv = read_image( |
| 1143 | + os.path.realpath( |
| 1144 | + os.path.join(os.path.dirname(__file__), "grace_hopper_517x606.jpg") |
| 1145 | + ) |
| 1146 | + ) |
| 1147 | + |
 | 1148 | +    # Create resnet/vit batch fixtures from weights |
| 1149 | + def prepocess_img(image, weights): |
| 1150 | + """ |
 | 1151 | +        Preprocess an image w/ a weights.transforms() |
| 1152 | +
|
| 1153 | + Args: |
 | 1154 | +            image (torch.FloatTensor): Image data |
| 1155 | + weights (torchvision.models): Weight object |
| 1156 | +
|
| 1157 | + Returns: |
| 1158 | + torch.FloatTensor: Preprocessed image |
| 1159 | + """ |
| 1160 | + preprocess = weights.transforms() |
| 1161 | + batch = preprocess(image).unsqueeze(0) |
| 1162 | + return batch |
| 1163 | + |
| 1164 | + @pytest.fixture(scope="session") |
| 1165 | + def batch_resnet(): |
| 1166 | + """ |
 | 1167 | +        Preprocess an image w/ Resnet weights.transforms() |
| 1168 | +
|
| 1169 | + Returns: |
| 1170 | + torch.FloatTensor: Preprocessed image |
| 1171 | + """ |
| 1172 | + return prepocess_img(img_tv, ResNet50_Weights.IMAGENET1K_V2) |
| 1173 | + |
| 1174 | + @pytest.fixture(scope="session") |
| 1175 | + def batch_vit(): |
| 1176 | + """ |
 | 1177 | +        Preprocess an image w/ ViT weights.transforms() |
| 1178 | +
|
| 1179 | + Returns: |
| 1180 | + torch.FloatTensor: Preprocessed image |
| 1181 | + """ |
| 1182 | + return prepocess_img(img_tv, ViT_B_16_Weights.IMAGENET1K_V1) |
| 1183 | + |
| 1184 | + # Create resnet/vit model fixtures from weights |
| 1185 | + @pytest.fixture(scope="function") |
| 1186 | + def model_resnet(): |
| 1187 | + """ |
| 1188 | + Create Resnet50 model + weights |
| 1189 | +
|
| 1190 | + Returns: |
| 1191 | + torchvision.models.resnet.ResNet: Resnet50 model |
| 1192 | + """ |
| 1193 | + return resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) |
| 1194 | + |
| 1195 | + @pytest.fixture(scope="function") |
| 1196 | + def model_vit(): |
| 1197 | + """ |
| 1198 | + Create ViT model + weights |
| 1199 | +
|
| 1200 | + Returns: |
| 1201 | + torchvision.models.vision_transformer.VisionTransformer: ViT model |
| 1202 | + """ |
| 1203 | + return vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1) |
| 1204 | + |
| 1205 | + |
| 1206 | +img = Image.open( |
1129 | 1207 | os.path.realpath( |
1130 | 1208 | os.path.join(os.path.dirname(__file__), "grace_hopper_517x606.jpg") |
1131 | 1209 | ) |
1132 | | -) |
| 1210 | +).convert("RGB") |
1133 | 1211 |
|
1134 | 1212 |
|
1135 | | -# Create resnet/vit batch fixtures from weights |
1136 | | -def prepocess_img(image, weights): |
| 1213 | +def process_img( |
| 1214 | + pretrained_model: str, |
| 1215 | + input_img: Image.Image, |
| 1216 | +): |
1137 | 1217 | """ |
1138 | | - Preprocess an image w/ a weights.transform() |
| 1218 | + Process an image w/ AutoImageProcessor |
1139 | 1219 |
|
1140 | 1220 | Args: |
1141 | | - img (torch.FloatTensor): Image data |
1142 | | - weights (torchvision.models): Weight object |
 | 1221 | +        pretrained_model (str): Name or path of the pretrained model whose |
 | 1222 | +            AutoImageProcessor is loaded |
 | 1223 | +        input_img (Image.Image): Image data |
1143 | 1224 |
|
1144 | 1225 | Returns: |
1145 | | - torch.FloatTensor: Preprocessed image |
| 1226 | + torch.FloatTensor: Processed image |
1146 | 1227 | """ |
1147 | | - preprocess = weights.transforms() |
1148 | | - batch = preprocess(image).unsqueeze(0) |
1149 | | - return batch |
| 1228 | + img_processor = AutoImageProcessor.from_pretrained(pretrained_model, use_fast=True) |
| 1229 | + batch_dict = img_processor(images=input_img, return_tensors="pt") |
| 1230 | + return batch_dict["pixel_values"] |
1150 | 1231 |
|
1151 | 1232 |
|
1152 | | -@pytest.fixture(scope="session") |
1153 | | -def batch_resnet(): |
| 1233 | +@pytest.fixture(scope="function") |
| 1234 | +def batch_resnet18(): |
1154 | 1235 | """ |
1155 | | - Preprocess an image w/ Resnet weights.transform() |
 | 1236 | +    Preprocess an image w/ the Microsoft ResNet-18 (microsoft/resnet-18) processor |
1156 | 1237 |
|
1157 | 1238 | Returns: |
1158 | 1239 | torch.FloatTensor: Preprocessed image |
1159 | 1240 | """ |
1160 | | - return prepocess_img(img, ResNet50_Weights.IMAGENET1K_V2) |
| 1241 | + return process_img("microsoft/resnet-18", img) |
1161 | 1242 |
|
1162 | 1243 |
|
1163 | | -@pytest.fixture(scope="session") |
1164 | | -def batch_vit(): |
| 1244 | +@pytest.fixture(scope="function") |
| 1245 | +def model_resnet18(): |
1165 | 1246 | """ |
1166 | | - Preprocess an image w/ ViT weights.transform() |
| 1247 | + Create MS ResNet18 model + weights |
1167 | 1248 |
|
1168 | 1249 | Returns: |
1169 | | - torch.FloatTensor: Preprocessed image |
| 1250 | + AutoModelForImageClassification: Resnet18 model |
1170 | 1251 | """ |
1171 | | - return prepocess_img(img, ViT_B_16_Weights.IMAGENET1K_V1) |
| 1252 | + return AutoModelForImageClassification.from_pretrained("microsoft/resnet-18") |
1172 | 1253 |
|
1173 | 1254 |
|
1174 | | -# Create resnet/vit model fixtures from weights |
1175 | 1255 | @pytest.fixture(scope="function") |
1176 | | -def model_resnet(): |
| 1256 | +def batch_vit_base(): |
1177 | 1257 | """ |
1178 | | - Create Resnet50 model + weights |
| 1258 | + Preprocess an image w/ Google ViT-base processor |
1179 | 1259 |
|
1180 | 1260 | Returns: |
1181 | | - torchvision.models.resnet.ResNet: Resnet50 model |
| 1261 | + torch.FloatTensor: Preprocessed image |
1182 | 1262 | """ |
1183 | | - return resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) |
| 1263 | + return process_img("google/vit-base-patch16-224", img) |
1184 | 1264 |
|
1185 | 1265 |
|
1186 | 1266 | @pytest.fixture(scope="function") |
1187 | | -def model_vit(): |
| 1267 | +def model_vit_base(): |
1188 | 1268 | """ |
1189 | | - Create ViT model + weights |
| 1269 | + Create Google ViT-base model + weights |
1190 | 1270 |
|
1191 | 1271 | Returns: |
1192 | | - torchvision.models.vision_transformer.VisionTransformer: ViT model |
| 1272 | + AutoModelForImageClassification: Google ViT-base model |
1193 | 1273 | """ |
1194 | | - return vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1) |
| 1274 | + return AutoModelForImageClassification.from_pretrained( |
| 1275 | + "google/vit-base-patch16-224" |
| 1276 | + ) |
1195 | 1277 |
|
1196 | 1278 |
|
1197 | 1279 | ####################### |
|
0 commit comments