Skip to content

Commit 030bae9

Browse files
committed
build: Guarded torchvision calls with available_packages
Signed-off-by: Brandon Groth <[email protected]>
1 parent bade4cb commit 030bae9

File tree

3 files changed

+125
-86
lines changed

3 files changed

+125
-86
lines changed

fms_mo/fx/dynamo_utils.py

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
get_target_op_from_mod_or_str,
3030
get_target_op_from_node,
3131
)
32+
from fms_mo.utils.import_utils import available_packages
3233

3334
logger = logging.getLogger(__name__)
3435

@@ -1133,7 +1134,6 @@ def cus_backend_model_analyzer(
11331134
from functools import partial
11341135

11351136
# Third Party
1136-
from torchvision.models import VisionTransformer
11371137
from transformers import PreTrainedModel
11381138

11391139
if issubclass(type(model), torch.nn.Module):
@@ -1145,7 +1145,16 @@ def cus_backend_model_analyzer(
11451145
model_to_be_traced = model
11461146
model_param_size = 999
11471147

1148-
is_transformers = issubclass(type(model), (PreTrainedModel, VisionTransformer))
1148+
transformer_model_classes = (PreTrainedModel,)
1149+
1150+
if available_packages["torchvision"]:
1151+
# Third Party
1152+
# pylint: disable = import-error
1153+
from torchvision.models import VisionTransformer
1154+
1155+
transformer_model_classes += (VisionTransformer,)
1156+
1157+
is_transformers = issubclass(type(model), transformer_model_classes)
11491158
if model_param_size > 1:
11501159
# Standard
11511160
import sys
@@ -1188,11 +1197,13 @@ def call_seq_hook(mod, *_args, **_kwargs):
11881197

11891198
# only add last layer
11901199
qcfg["qskip_layer_name"] += [qcfg["mod_call_seq"][-1]]
1191-
# unless it's a ViT, skip first Conv as well
1192-
if issubclass(type(model), VisionTransformer) and isinstance(
1193-
model.get_submodule(qcfg["mod_call_seq"][0]), torch.nn.Conv2d
1194-
):
1195-
qcfg["qskip_layer_name"] += [qcfg["mod_call_seq"][0]]
1200+
1201+
if available_packages["torchvision"]:
1202+
# unless it's a ViT, skip first Conv as well
1203+
if issubclass(type(model), VisionTransformer) and isinstance(
1204+
model.get_submodule(qcfg["mod_call_seq"][0]), torch.nn.Conv2d
1205+
):
1206+
qcfg["qskip_layer_name"] += [qcfg["mod_call_seq"][0]]
11961207

11971208
with torch.no_grad():
11981209
model_opt = torch.compile(
@@ -1271,21 +1282,23 @@ def qbmm_auto_check(_mod, *_args, **_kwargs):
12711282
# c) identify RPN/FPN
12721283
# TODO this hack only works for torchvision models. will use find_rpn_fpn_gm()
12731284

1274-
# Third Party
1275-
from torchvision.models.detection.rpn import RegionProposalNetwork
1276-
from torchvision.ops import FeaturePyramidNetwork
1277-
1278-
rpnfpn_prefix = []
1279-
rpnfpn_convs = []
1280-
for n, m in model.named_modules():
1281-
if isinstance(m, (FeaturePyramidNetwork, RegionProposalNetwork)):
1282-
rpnfpn_prefix.append(n)
1283-
if isinstance(m, torch.nn.Conv2d) and any(
1284-
n.startswith(p) for p in rpnfpn_prefix
1285-
):
1286-
rpnfpn_convs.append(n)
1287-
if n not in qcfg["qskip_layer_name"]:
1288-
qcfg["qskip_layer_name"].append(n)
1285+
if available_packages["torchvision"]:
1286+
# Third Party
1287+
# pylint: disable = import-error
1288+
from torchvision.models.detection.rpn import RegionProposalNetwork
1289+
from torchvision.ops import FeaturePyramidNetwork
1290+
1291+
rpnfpn_prefix = []
1292+
rpnfpn_convs = []
1293+
for n, m in model.named_modules():
1294+
if isinstance(m, (FeaturePyramidNetwork, RegionProposalNetwork)):
1295+
rpnfpn_prefix.append(n)
1296+
if isinstance(m, torch.nn.Conv2d) and any(
1297+
n.startswith(p) for p in rpnfpn_prefix
1298+
):
1299+
rpnfpn_convs.append(n)
1300+
if n not in qcfg["qskip_layer_name"]:
1301+
qcfg["qskip_layer_name"].append(n)
12891302

12901303
if qcfg["N_backend_called"] > 1:
12911304
logger.warning(

tests/models/conftest.py

Lines changed: 64 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@
2323

2424
# Third Party
2525
from torch.utils.data import DataLoader, TensorDataset
26-
from torchvision.io import read_image
27-
from torchvision.models import ResNet50_Weights, ViT_B_16_Weights, resnet50, vit_b_16
2826
from transformers import (
2927
BertConfig,
3028
BertModel,
@@ -43,6 +41,7 @@
4341
# fms_mo imports
4442
from fms_mo import qconfig_init
4543
from fms_mo.modules import QLSTM, QBmm, QConv2d, QConvTranspose2d, QLinear
44+
from fms_mo.utils.import_utils import available_packages
4645
from fms_mo.utils.qconfig_utils import get_mx_specs_defaults, set_mx_specs
4746

4847
########################
@@ -1123,75 +1122,82 @@ def required_pair(request):
11231122
# Vision Model Fixtures #
11241123
#########################
11251124

1126-
# Create img
1127-
# downloaded from torchvision github (vision/test/assets/encoder_jpeg/ directory)
1128-
img = read_image(
1129-
os.path.realpath(
1130-
os.path.join(os.path.dirname(__file__), "grace_hopper_517x606.jpg")
1131-
)
1132-
)
1133-
11341125

1135-
# Create resnet/vit batch fixtures from weights
1136-
def prepocess_img(image, weights):
1137-
"""
1138-
Preprocess an image w/ a weights.transform()
1139-
1140-
Args:
1141-
img (torch.FloatTensor): Image data
1142-
weights (torchvision.models): Weight object
1143-
1144-
Returns:
1145-
torch.FloatTensor: Preprocessed image
1146-
"""
1147-
preprocess = weights.transforms()
1148-
batch = preprocess(image).unsqueeze(0)
1149-
return batch
1126+
if available_packages["torchvision"]:
1127+
# Third Party
1128+
# pylint: disable = import-error
1129+
from torchvision.io import read_image
1130+
from torchvision.models import (
1131+
ResNet50_Weights,
1132+
ViT_B_16_Weights,
1133+
resnet50,
1134+
vit_b_16,
1135+
)
11501136

1137+
# Create img
1138+
# downloaded from torchvision github (vision/test/assets/encoder_jpeg/ directory)
1139+
img_tv = read_image(
1140+
os.path.realpath(
1141+
os.path.join(os.path.dirname(__file__), "grace_hopper_517x606.jpg")
1142+
)
1143+
)
11511144

1152-
@pytest.fixture(scope="session")
1153-
def batch_resnet():
1154-
"""
1155-
Preprocess an image w/ Resnet weights.transform()
1145+
# Create resnet/vitbatch fixtures from weights
1146+
def prepocess_img(image, weights):
1147+
"""
1148+
Preprocess an image w/ a weights.transform()
11561149
1157-
Returns:
1158-
torch.FloatTensor: Preprocessed image
1159-
"""
1160-
return prepocess_img(img, ResNet50_Weights.IMAGENET1K_V2)
1150+
Args:
1151+
img_tv (torch.FloatTensor): Image data
1152+
weights (torchvision.models): Weight object
11611153
1154+
Returns:
1155+
torch.FloatTensor: Preprocessed image
1156+
"""
1157+
preprocess = weights.transforms()
1158+
batch = preprocess(image).unsqueeze(0)
1159+
return batch
11621160

1163-
@pytest.fixture(scope="session")
1164-
def batch_vit():
1165-
"""
1166-
Preprocess an image w/ ViT weights.transform()
1161+
@pytest.fixture(scope="session")
1162+
def batch_resnet():
1163+
"""
1164+
Preprocess an image w/ Resnet weights.transform()
11671165
1168-
Returns:
1169-
torch.FloatTensor: Preprocessed image
1170-
"""
1171-
return prepocess_img(img, ViT_B_16_Weights.IMAGENET1K_V1)
1166+
Returns:
1167+
torch.FloatTensor: Preprocessed image
1168+
"""
1169+
return prepocess_img(img_tv, ResNet50_Weights.IMAGENET1K_V2)
11721170

1171+
@pytest.fixture(scope="session")
1172+
def batch_vit():
1173+
"""
1174+
Preprocess an image w/ ViT weights.transform()
11731175
1174-
# Create resnet/vit model fixtures from weights
1175-
@pytest.fixture(scope="function")
1176-
def model_resnet():
1177-
"""
1178-
Create Resnet50 model + weights
1176+
Returns:
1177+
torch.FloatTensor: Preprocessed image
1178+
"""
1179+
return prepocess_img(img_tv, ViT_B_16_Weights.IMAGENET1K_V1)
11791180

1180-
Returns:
1181-
torchvision.models.resnet.ResNet: Resnet50 model
1182-
"""
1183-
return resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
1181+
# Create resnet/vit model fixtures from weights
1182+
@pytest.fixture(scope="function")
1183+
def model_resnet():
1184+
"""
1185+
Create Resnet50 model + weights
11841186
1187+
Returns:
1188+
torchvision.models.resnet.ResNet: Resnet50 model
1189+
"""
1190+
return resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
11851191

1186-
@pytest.fixture(scope="function")
1187-
def model_vit():
1188-
"""
1189-
Create ViT model + weights
1192+
@pytest.fixture(scope="function")
1193+
def model_vit():
1194+
"""
1195+
Create ViT model + weights
11901196
1191-
Returns:
1192-
torchvision.models.vision_transformer.VisionTransformer: ViT model
1193-
"""
1194-
return vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
1197+
Returns:
1198+
torchvision.models.vision_transformer.VisionTransformer: ViT model
1199+
"""
1200+
return vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
11951201

11961202

11971203
#######################

tests/models/test_qmodelprep.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@
1919
# Third Party
2020
import pytest
2121
import torch
22-
import torchvision
2322
import transformers
2423

2524
# Local
2625
# fms_mo imports
2726
from fms_mo import qconfig_init, qmodel_prep
2827
from fms_mo.prep import has_quantized_module
28+
from fms_mo.utils.import_utils import available_packages
2929
from fms_mo.utils.utils import patch_torch_bmm
3030
from tests.models.test_model_utils import count_qmodules, delete_file, qmodule_error
3131

@@ -159,8 +159,12 @@ def test_config_fp32_qmodes(
159159
###########################
160160

161161

162+
@pytest.mark.skipif(
163+
not available_packages["torchvision"],
164+
reason="Requires torchvision",
165+
)
162166
def test_resnet50_torchscript(
163-
model_resnet: torchvision.models.resnet.ResNet,
167+
model_resnet,
164168
batch_resnet: torch.FloatTensor,
165169
config_int8: dict,
166170
):
@@ -177,8 +181,12 @@ def test_resnet50_torchscript(
177181
qmodule_error(model_resnet, 6, 48)
178182

179183

184+
@pytest.mark.skipif(
185+
not available_packages["torchvision"],
186+
reason="Requires torchvision",
187+
)
180188
def test_resnet50_dynamo(
181-
model_resnet: torchvision.models.resnet.ResNet,
189+
model_resnet,
182190
batch_resnet: torch.FloatTensor,
183191
config_int8: dict,
184192
):
@@ -195,8 +203,12 @@ def test_resnet50_dynamo(
195203
qmodule_error(model_resnet, 6, 48)
196204

197205

206+
@pytest.mark.skipif(
207+
not available_packages["torchvision"],
208+
reason="Requires torchvision",
209+
)
198210
def test_resnet50_dynamo_layers(
199-
model_resnet: torchvision.models.resnet.ResNet,
211+
model_resnet,
200212
batch_resnet: torch.FloatTensor,
201213
config_int8: dict,
202214
):
@@ -216,8 +228,12 @@ def test_resnet50_dynamo_layers(
216228

217229

218230
# Vision Transformer tests
231+
@pytest.mark.skipif(
232+
not available_packages["torchvision"],
233+
reason="Requires torchvision",
234+
)
219235
def test_vit_torchscript(
220-
model_vit: torchvision.models.vision_transformer.VisionTransformer,
236+
model_vit,
221237
batch_vit: torch.FloatTensor,
222238
config_int8: dict,
223239
):
@@ -234,8 +250,12 @@ def test_vit_torchscript(
234250
qmodule_error(model_vit, 2, 36)
235251

236252

253+
@pytest.mark.skipif(
254+
not available_packages["torchvision"],
255+
reason="Requires torchvision",
256+
)
237257
def test_vit_dynamo(
238-
model_vit: torchvision.models.vision_transformer.VisionTransformer,
258+
model_vit,
239259
batch_vit: torch.FloatTensor,
240260
config_int8: dict,
241261
):

0 commit comments

Comments (0)