Skip to content

Commit 02008db

Browse files
Merge branch 'main' into dependabot/pip/accelerate-gte-0.20.3-and-neq-0.34-and-lt-1.9
Signed-off-by: chichun-charlie-liu <[email protected]>
2 parents ef850f1 + 06e371a commit 02008db

File tree

17 files changed

+572
-128
lines changed

17 files changed

+572
-128
lines changed

README.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,26 @@ cd fms-model-optimizer
9898
pip install -e .
9999
```
100100

101+
#### Optional Dependencies
102+
The following optional dependencies are available:
103+
- `fp8`: `llmcompressor` package for fp8 quantization
104+
- `gptq`: `GPTQModel` package for W4A16 quantization
105+
- `mx`: `microxcaling` package for MX quantization
106+
- `opt`: Shortcut for `fp8`, `gptq`, and `mx` installs
107+
- `torchvision`: `torch` package for image recognition training and inference
108+
- `visualize`: Dependencies for visualizing models and performance data
109+
- `test`: Dependencies needed for unit testing
110+
- `dev`: Dependencies needed for development
111+
112+
To install an optional dependency, modify the `pip install` commands above with a list of these names enclosed in brackets. The example below installs `llm-compressor` and `torchvision` with FMS Model Optimizer:
113+
114+
```shell
115+
pip install fms-model-optimizer[fp8,torchvision]
116+
117+
pip install -e .[fp8,torchvision]
118+
```
119+
If you have already installed FMS Model Optimizer, then only the optional packages will be installed.
120+
101121
### Try It Out!
102122

103123
To help you get up and running as quickly as possible with the FMS Model Optimizer framework, check out the following resources which demonstrate how to use the framework with different quantization techniques:

fms_mo/calib.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -574,13 +574,19 @@ def qmodel_calib(
574574
f"Qmodel calibration (clip_val analysis) in progress: {i}/{Nbatch}"
575575
)
576576

577-
if "perCh" not in qcfg["qw_mode"]:
578-
cv_sum_dict = {"layer": [], "value": []}
579-
for k, v in tempmodel.state_dict().items():
580-
if "clip" in k:
581-
cv_sum_dict["layer"].append(k)
582-
cv_sum_dict["value"].append(v.item())
583-
logger.info(f"Observed clipvals: \n{ pd.DataFrame(cv_sum_dict) }")
577+
cv_sum_dict = {"layer": [], "value": []}
578+
for k, v in tempmodel.state_dict().items():
579+
if "clip" not in k:
580+
continue
581+
582+
if v.numel() > 1:
583+
k = k + "*"
584+
v = v.mean()
585+
cv_sum_dict["layer"].append(k)
586+
cv_sum_dict["value"].append(v.item())
587+
logger.info(
588+
f"Observed clipvals: ('*' if it's a vector) \n{ pd.DataFrame(cv_sum_dict) }"
589+
)
584590

585591
# Step 3: extract new clip_vals, params and buffers, then remove handles if needed
586592
temp_new_clipvals = {

fms_mo/dq.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@
3636

3737
# Local
3838
from fms_mo import qconfig_init, qmodel_prep
39+
from fms_mo.custom_ext_kernels.utils import (
40+
lower_qmodel_triton, # pylint: disable=unused-import
41+
)
3942
from fms_mo.fx.utils import model_size_Wb
4043
from fms_mo.quant.ptq import (
4144
calibration_llm_1GPU_v2,
@@ -256,6 +259,15 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
256259
model.save_pretrained(opt_args.output_dir, use_safetensors=True)
257260
tokenizer.save_pretrained(opt_args.output_dir)
258261

262+
if fms_mo_args.aiu_sim_triton:
263+
lower_qmodel_triton(
264+
model,
265+
use_dyn_max_act=-1 if qcfg["qa_mode"] == "pertokenmax" else False,
266+
max_acc_bits=qcfg.get("max_acc_bits", 32),
267+
num_lsb_to_truncate=qcfg.get("lsb_trun_bits", 0),
268+
chunk_size=qcfg.get("chunk_size", 1024),
269+
)
270+
259271
if fms_mo_args.eval_ppl:
260272
path_test = Path(data_args.test_data_path)
261273
arrow_files = list(path_test.glob("*.arrow"))

fms_mo/fx/dynamo_utils.py

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
get_target_op_from_mod_or_str,
3030
get_target_op_from_node,
3131
)
32+
from fms_mo.utils.import_utils import available_packages
3233

3334
logger = logging.getLogger(__name__)
3435

@@ -1133,7 +1134,6 @@ def cus_backend_model_analyzer(
11331134
from functools import partial
11341135

11351136
# Third Party
1136-
from torchvision.models import VisionTransformer
11371137
from transformers import PreTrainedModel
11381138

11391139
if issubclass(type(model), torch.nn.Module):
@@ -1145,7 +1145,16 @@ def cus_backend_model_analyzer(
11451145
model_to_be_traced = model
11461146
model_param_size = 999
11471147

1148-
is_transformers = issubclass(type(model), (PreTrainedModel, VisionTransformer))
1148+
transformer_model_classes = (PreTrainedModel,)
1149+
1150+
if available_packages["torchvision"]:
1151+
# Third Party
1152+
# pylint: disable = import-error
1153+
from torchvision.models import VisionTransformer
1154+
1155+
transformer_model_classes += (VisionTransformer,)
1156+
1157+
is_transformers = issubclass(type(model), transformer_model_classes)
11491158
if model_param_size > 1:
11501159
# Standard
11511160
import sys
@@ -1188,11 +1197,13 @@ def call_seq_hook(mod, *_args, **_kwargs):
11881197

11891198
# only add last layer
11901199
qcfg["qskip_layer_name"] += [qcfg["mod_call_seq"][-1]]
1191-
# unless it's a ViT, skip first Conv as well
1192-
if issubclass(type(model), VisionTransformer) and isinstance(
1193-
model.get_submodule(qcfg["mod_call_seq"][0]), torch.nn.Conv2d
1194-
):
1195-
qcfg["qskip_layer_name"] += [qcfg["mod_call_seq"][0]]
1200+
1201+
if available_packages["torchvision"]:
1202+
# unless it's a ViT, skip first Conv as well
1203+
if issubclass(type(model), VisionTransformer) and isinstance(
1204+
model.get_submodule(qcfg["mod_call_seq"][0]), torch.nn.Conv2d
1205+
):
1206+
qcfg["qskip_layer_name"] += [qcfg["mod_call_seq"][0]]
11961207

11971208
with torch.no_grad():
11981209
model_opt = torch.compile(
@@ -1271,21 +1282,23 @@ def qbmm_auto_check(_mod, *_args, **_kwargs):
12711282
# c) identify RPN/FPN
12721283
# TODO this hack only works for torchvision models. will use find_rpn_fpn_gm()
12731284

1274-
# Third Party
1275-
from torchvision.models.detection.rpn import RegionProposalNetwork
1276-
from torchvision.ops import FeaturePyramidNetwork
1277-
1278-
rpnfpn_prefix = []
1279-
rpnfpn_convs = []
1280-
for n, m in model.named_modules():
1281-
if isinstance(m, (FeaturePyramidNetwork, RegionProposalNetwork)):
1282-
rpnfpn_prefix.append(n)
1283-
if isinstance(m, torch.nn.Conv2d) and any(
1284-
n.startswith(p) for p in rpnfpn_prefix
1285-
):
1286-
rpnfpn_convs.append(n)
1287-
if n not in qcfg["qskip_layer_name"]:
1288-
qcfg["qskip_layer_name"].append(n)
1285+
if available_packages["torchvision"]:
1286+
# Third Party
1287+
# pylint: disable = import-error
1288+
from torchvision.models.detection.rpn import RegionProposalNetwork
1289+
from torchvision.ops import FeaturePyramidNetwork
1290+
1291+
rpnfpn_prefix = []
1292+
rpnfpn_convs = []
1293+
for n, m in model.named_modules():
1294+
if isinstance(m, (FeaturePyramidNetwork, RegionProposalNetwork)):
1295+
rpnfpn_prefix.append(n)
1296+
if isinstance(m, torch.nn.Conv2d) and any(
1297+
n.startswith(p) for p in rpnfpn_prefix
1298+
):
1299+
rpnfpn_convs.append(n)
1300+
if n not in qcfg["qskip_layer_name"]:
1301+
qcfg["qskip_layer_name"].append(n)
12891302

12901303
if qcfg["N_backend_called"] > 1:
12911304
logger.warning(

fms_mo/modules/bmm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def forward(self, m1, m2):
192192
torch.Tensor: Output tensor after quantized bmm.
193193
"""
194194
# pylint: disable = access-member-before-definition
195-
if self.calib_counter:
195+
if self.calib_counter > 0:
196196
with torch.no_grad():
197197
qm1 = self.quantize_calib_m1(m1)
198198
qm2 = self.quantize_calib_m2(m2)

fms_mo/modules/conv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def forward(self, x):
270270
torch.Tensor: Output tensor of shape (batch_size, out_channels, out_height, out_width).
271271
"""
272272
# pylint: disable = access-member-before-definition
273-
if self.calib_counter:
273+
if self.calib_counter > 0:
274274
with torch.no_grad():
275275
qinput = self.quantize_calib_feature(x)
276276
qweight = self.quantize_calib_weight(self.weight)

0 commit comments

Comments
 (0)