
Commit b47b08f

Merge pull request #144 from BrandonGroth/no_torchvision
build: Move torchvision to an optional dependency
2 parents 9623337 + 271ca4d commit b47b08f

7 files changed: +239 −65 lines changed


README.md

Lines changed: 20 additions & 0 deletions
````diff
@@ -98,6 +98,26 @@ cd fms-model-optimizer
 pip install -e .
 ```
 
+#### Optional Dependencies
+The following optional dependencies are available:
+- `fp8`: `llmcompressor` package for fp8 quantization
+- `gptq`: `GPTQModel` package for W4A16 quantization
+- `mx`: `microxcaling` package for MX quantization
+- `opt`: Shortcut for `fp8`, `gptq`, and `mx` installs
+- `torchvision`: `torchvision` package for image recognition training and inference
+- `visualize`: Dependencies for visualizing models and performance data
+- `test`: Dependencies needed for unit testing
+- `dev`: Dependencies needed for development
+
+To install an optional dependency, modify the `pip install` commands above with a list of these names enclosed in brackets. The example below installs `llmcompressor` and `torchvision` with FMS Model Optimizer:
+
+```shell
+pip install fms-model-optimizer[fp8,torchvision]
+
+pip install -e .[fp8,torchvision]
+```
+If you have already installed FMS Model Optimizer, then only the optional packages will be installed.
+
 ### Try It Out!
 
 To help you get up and running as quickly as possible with the FMS Model Optimizer framework, check out the following resources which demonstrate how to use the framework with different quantization techniques:
````
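A quick way to confirm that the packages pulled in by an extra actually landed in the environment is to probe for them with the standard library. This check is not part of the PR; it is a minimal sketch, and the package names are simply the ones the extras above install:

```python
import importlib.util

# Probe the distributions installed by the `fp8` and `torchvision` extras
for pkg in ("llmcompressor", "torchvision"):
    status = "installed" if importlib.util.find_spec(pkg) else "missing"
    print(f"{pkg}: {status}")
```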

fms_mo/fx/dynamo_utils.py

Lines changed: 35 additions & 22 deletions
```diff
@@ -29,6 +29,7 @@
     get_target_op_from_mod_or_str,
     get_target_op_from_node,
 )
+from fms_mo.utils.import_utils import available_packages
 
 logger = logging.getLogger(__name__)
 
@@ -1133,7 +1134,6 @@ def cus_backend_model_analyzer(
     from functools import partial
 
     # Third Party
-    from torchvision.models import VisionTransformer
     from transformers import PreTrainedModel
 
     if issubclass(type(model), torch.nn.Module):
@@ -1145,7 +1145,16 @@ def cus_backend_model_analyzer(
         model_to_be_traced = model
         model_param_size = 999
 
-    is_transformers = issubclass(type(model), (PreTrainedModel, VisionTransformer))
+    transformer_model_classes = (PreTrainedModel,)
+
+    if available_packages["torchvision"]:
+        # Third Party
+        # pylint: disable = import-error
+        from torchvision.models import VisionTransformer
+
+        transformer_model_classes += (VisionTransformer,)
+
+    is_transformers = issubclass(type(model), transformer_model_classes)
     if model_param_size > 1:
         # Standard
         import sys
@@ -1188,11 +1197,13 @@ def call_seq_hook(mod, *_args, **_kwargs):
 
         # only add last layer
         qcfg["qskip_layer_name"] += [qcfg["mod_call_seq"][-1]]
-        # unless it's a ViT, skip first Conv as well
-        if issubclass(type(model), VisionTransformer) and isinstance(
-            model.get_submodule(qcfg["mod_call_seq"][0]), torch.nn.Conv2d
-        ):
-            qcfg["qskip_layer_name"] += [qcfg["mod_call_seq"][0]]
+
+        if available_packages["torchvision"]:
+            # unless it's a ViT, skip first Conv as well
+            if issubclass(type(model), VisionTransformer) and isinstance(
+                model.get_submodule(qcfg["mod_call_seq"][0]), torch.nn.Conv2d
+            ):
+                qcfg["qskip_layer_name"] += [qcfg["mod_call_seq"][0]]
 
         with torch.no_grad():
             model_opt = torch.compile(
@@ -1271,21 +1282,23 @@ def qbmm_auto_check(_mod, *_args, **_kwargs):
         # c) identify RPN/FPN
         # TODO this hack only works for torchvision models. will use find_rpn_fpn_gm()
 
-        # Third Party
-        from torchvision.models.detection.rpn import RegionProposalNetwork
-        from torchvision.ops import FeaturePyramidNetwork
-
-        rpnfpn_prefix = []
-        rpnfpn_convs = []
-        for n, m in model.named_modules():
-            if isinstance(m, (FeaturePyramidNetwork, RegionProposalNetwork)):
-                rpnfpn_prefix.append(n)
-            if isinstance(m, torch.nn.Conv2d) and any(
-                n.startswith(p) for p in rpnfpn_prefix
-            ):
-                rpnfpn_convs.append(n)
-                if n not in qcfg["qskip_layer_name"]:
-                    qcfg["qskip_layer_name"].append(n)
+        if available_packages["torchvision"]:
+            # Third Party
+            # pylint: disable = import-error
+            from torchvision.models.detection.rpn import RegionProposalNetwork
+            from torchvision.ops import FeaturePyramidNetwork
+
+            rpnfpn_prefix = []
+            rpnfpn_convs = []
+            for n, m in model.named_modules():
+                if isinstance(m, (FeaturePyramidNetwork, RegionProposalNetwork)):
+                    rpnfpn_prefix.append(n)
+                if isinstance(m, torch.nn.Conv2d) and any(
+                    n.startswith(p) for p in rpnfpn_prefix
+                ):
+                    rpnfpn_convs.append(n)
+                    if n not in qcfg["qskip_layer_name"]:
+                        qcfg["qskip_layer_name"].append(n)
 
         if qcfg["N_backend_called"] > 1:
             logger.warning(
```
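The guarded checks above degrade gracefully when torchvision is missing: `issubclass()` against an empty tuple of classes is simply `False`, so none of the layer-skipping logic fires. A minimal sketch of that behavior, using a hypothetical stand-in class (not from this repo):

```python
class Model:  # hypothetical stand-in for a user model
    pass

transformer_model_classes: tuple[type, ...] = ()  # torchvision unavailable
print(issubclass(Model, transformer_model_classes))  # False: guarded checks are no-ops

transformer_model_classes += (Model,)  # torchvision available: tuple grows
print(issubclass(Model, transformer_model_classes))  # True
```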

fms_mo/utils/import_utils.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -31,6 +31,7 @@
     "pygraphviz",
     "fms",
     "triton",
+    "torchvision",
 ]
 
 available_packages = {}
```
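The hunk shows only the probe list and the empty `available_packages` dict; the loop that fills it sits outside the hunk. A plausible sketch of that step, assuming the module checks importability with `importlib.util.find_spec` (the real implementation may differ):

```python
import importlib.util

# Names mirrored from the probe list in the hunk above (list truncated here)
packages_to_check = ["pygraphviz", "fms", "triton", "torchvision"]

available_packages = {}
for pkg in packages_to_check:
    # find_spec returns None when the package cannot be imported
    available_packages[pkg] = importlib.util.find_spec(pkg) is not None
```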

pyproject.toml

Lines changed: 5 additions & 4 deletions
```diff
@@ -32,7 +32,6 @@ dependencies = [
 "ninja>=1.11.1.1,<2.0",
 "tensorboard",
 "notebook",
-"torchvision>=0.17",
 "evaluate",
 "huggingface_hub",
 "pandas",
@@ -42,13 +41,15 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-dev = ["pre-commit>=3.0.4,<5.0"]
 fp8 = ["llmcompressor"]
 gptq = ["Cython", "gptqmodel>=1.7.3"]
 mx = ["microxcaling>=1.1"]
-visualize = ["matplotlib", "graphviz", "pygraphviz"]
+opt = ["fms-model-optimizer[fp8, gptq, mx]"]
+torchvision = ["torchvision>=0.17"]
 flash-attn = ["flash-attn>=2.5.3,<3.0"]
-opt = ["fms-model-optimizer[fp8, gptq]"]
+visualize = ["matplotlib", "graphviz", "pygraphviz"]
+dev = ["pre-commit>=3.0.4,<5.0"]
+test = ["pytest", "pillow"]
 
 [project.urls]
 homepage = "https://github.com/foundation-model-stack/fms-model-optimizer"
```
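The `opt` extra references the package's own extras (`fms-model-optimizer[fp8, gptq, mx]`), the standard way to alias a group of extras, and it now picks up the new `mx` entry as well. To inspect how the declared extras resolve in an installed environment, here is a small sketch using only the standard library (it assumes `fms-model-optimizer` is installed):

```python
from importlib.metadata import requires

# Requirement strings carry an 'extra == "..."' marker for optional deps
for req in requires("fms-model-optimizer") or []:
    if 'extra == "torchvision"' in req:
        print(req)  # e.g. 'torchvision>=0.17; extra == "torchvision"'
```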

tests/models/conftest.py

Lines changed: 115 additions & 33 deletions
```diff
@@ -22,10 +22,11 @@
 import os
 
 # Third Party
+from PIL import Image  # pylint: disable=import-error
 from torch.utils.data import DataLoader, TensorDataset
-from torchvision.io import read_image
-from torchvision.models import ResNet50_Weights, ViT_B_16_Weights, resnet50, vit_b_16
 from transformers import (
+    AutoImageProcessor,
+    AutoModelForImageClassification,
     BertConfig,
     BertModel,
     BertTokenizer,
@@ -43,6 +44,7 @@
 # fms_mo imports
 from fms_mo import qconfig_init
 from fms_mo.modules import QLSTM, QBmm, QConv2d, QConvTranspose2d, QLinear
+from fms_mo.utils.import_utils import available_packages
 from fms_mo.utils.qconfig_utils import get_mx_specs_defaults, set_mx_specs
 
 ########################
@@ -1123,75 +1125,155 @@ def required_pair(request):
 # Vision Model Fixtures #
 #########################
 
-# Create img
-# downloaded from torchvision github (vision/test/assets/encoder_jpeg/ directory)
-img = read_image(
+
+if available_packages["torchvision"]:
+    # Third Party
+    # pylint: disable = import-error
+    from torchvision.io import read_image
+    from torchvision.models import (
+        ResNet50_Weights,
+        ViT_B_16_Weights,
+        resnet50,
+        vit_b_16,
+    )
+
+    # Create img
+    # downloaded from torchvision github (vision/test/assets/encoder_jpeg/ directory)
+    img_tv = read_image(
+        os.path.realpath(
+            os.path.join(os.path.dirname(__file__), "grace_hopper_517x606.jpg")
+        )
+    )
+
+    # Create resnet/vit batch fixtures from weights
+    def prepocess_img(image, weights):
+        """
+        Preprocess an image w/ a weights.transform()
+
+        Args:
+            img_tv (torch.FloatTensor): Image data
+            weights (torchvision.models): Weight object
+
+        Returns:
+            torch.FloatTensor: Preprocessed image
+        """
+        preprocess = weights.transforms()
+        batch = preprocess(image).unsqueeze(0)
+        return batch
+
+    @pytest.fixture(scope="session")
+    def batch_resnet():
+        """
+        Preprocess an image w/ Resnet weights.transform()
+
+        Returns:
+            torch.FloatTensor: Preprocessed image
+        """
+        return prepocess_img(img_tv, ResNet50_Weights.IMAGENET1K_V2)
+
+    @pytest.fixture(scope="session")
+    def batch_vit():
+        """
+        Preprocess an image w/ ViT weights.transform()
+
+        Returns:
+            torch.FloatTensor: Preprocessed image
+        """
+        return prepocess_img(img_tv, ViT_B_16_Weights.IMAGENET1K_V1)
+
+    # Create resnet/vit model fixtures from weights
+    @pytest.fixture(scope="function")
+    def model_resnet():
+        """
+        Create Resnet50 model + weights
+
+        Returns:
+            torchvision.models.resnet.ResNet: Resnet50 model
+        """
+        return resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
+
+    @pytest.fixture(scope="function")
+    def model_vit():
+        """
+        Create ViT model + weights
+
+        Returns:
+            torchvision.models.vision_transformer.VisionTransformer: ViT model
+        """
+        return vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
+
+
+img = Image.open(
     os.path.realpath(
         os.path.join(os.path.dirname(__file__), "grace_hopper_517x606.jpg")
     )
-)
+).convert("RGB")
 
 
-# Create resnet/vit batch fixtures from weights
-def prepocess_img(image, weights):
+def process_img(
+    pretrained_model: str,
+    input_img: Image.Image,
+):
     """
-    Preprocess an image w/ a weights.transform()
+    Process an image w/ AutoImageProcessor
 
     Args:
-        img (torch.FloatTensor): Image data
-        weights (torchvision.models): Weight object
+        processor (AutoImageProcessor): Processor weights for pretrained model
+        pretrained_model (str): Weight object
+        input_img (Image.Image): Image data
 
     Returns:
-        torch.FloatTensor: Preprocessed image
+        torch.FloatTensor: Processed image
     """
-    preprocess = weights.transforms()
-    batch = preprocess(image).unsqueeze(0)
-    return batch
+    img_processor = AutoImageProcessor.from_pretrained(pretrained_model, use_fast=True)
+    batch_dict = img_processor(images=input_img, return_tensors="pt")
+    return batch_dict["pixel_values"]
 
 
-@pytest.fixture(scope="session")
-def batch_resnet():
+@pytest.fixture(scope="function")
+def batch_resnet18():
     """
-    Preprocess an image w/ Resnet weights.transform()
+    Preprocess an image w/ ms resnet18 processor
 
     Returns:
         torch.FloatTensor: Preprocessed image
     """
-    return prepocess_img(img, ResNet50_Weights.IMAGENET1K_V2)
+    return process_img("microsoft/resnet-18", img)
 
 
-@pytest.fixture(scope="session")
-def batch_vit():
+@pytest.fixture(scope="function")
+def model_resnet18():
     """
-    Preprocess an image w/ ViT weights.transform()
+    Create MS ResNet18 model + weights
 
     Returns:
-        torch.FloatTensor: Preprocessed image
+        AutoModelForImageClassification: Resnet18 model
     """
-    return prepocess_img(img, ViT_B_16_Weights.IMAGENET1K_V1)
+    return AutoModelForImageClassification.from_pretrained("microsoft/resnet-18")
 
 
-# Create resnet/vit model fixtures from weights
 @pytest.fixture(scope="function")
-def model_resnet():
+def batch_vit_base():
     """
-    Create Resnet50 model + weights
+    Preprocess an image w/ Google ViT-base processor
 
     Returns:
-        torchvision.models.resnet.ResNet: Resnet50 model
+        torch.FloatTensor: Preprocessed image
     """
-    return resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
+    return process_img("google/vit-base-patch16-224", img)
 
 
 @pytest.fixture(scope="function")
-def model_vit():
+def model_vit_base():
     """
-    Create ViT model + weights
+    Create Google ViT-base model + weights
 
     Returns:
-        torchvision.models.vision_transformer.VisionTransformer: ViT model
+        AutoModelForImageClassification: Google ViT-base model
     """
-    return vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
+    return AutoModelForImageClassification.from_pretrained(
+        "google/vit-base-patch16-224"
    )
 
 
 #######################
```
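The new HF-backed fixtures pair a processed batch with a matching model, so a test can run a forward pass without torchvision installed. A minimal sketch of how they compose (a hypothetical test, not part of the diff; pytest injects the fixtures by name):

```python
import torch

def test_resnet18_forward(model_resnet18, batch_resnet18):
    model_resnet18.eval()
    with torch.no_grad():
        logits = model_resnet18(pixel_values=batch_resnet18).logits
    assert logits.shape[-1] == 1000  # microsoft/resnet-18 has ImageNet-1k classes
```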
