
Commit 39e7ff2

Merge branch 'main' into dependabot/pip/transformers-gte-4.45-and-lt-4.51
Signed-off-by: chichun-charlie-liu <[email protected]>
2 parents: cf81bca + f265075

18 files changed: +124 −53 lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
@@ -40,9 +40,9 @@ jobs:
     strategy:
       matrix:
         python:
-          - "3.9"
           - "3.10"
           - "3.11"
+          - "3.12"
         platform:
           - "ubuntu-latest"

README.md

Lines changed: 7 additions & 6 deletions
@@ -36,9 +36,7 @@ FMS Model Optimizer is a framework for developing reduced precision neural netwo
 ### Requirements
 
 1. **🐧 Linux system with Nvidia GPU (V100/A100/H100)**
-2. Python 3.9 to Python 3.11
-
-   📋 Python 3.12 is currently not supported due to PyTorch Dynamo constraint
+2. Python 3.10 to Python 3.12
 3. CUDA >=12
 
 *Optional packages based on optimization functionality required:*
@@ -47,9 +45,12 @@ FMS Model Optimizer is a framework for developing reduced precision neural netwo
 - [auto_gptq](https://pypi.org/project/auto-gptq/) or build from [source](https://github.com/AutoGPTQ/AutoGPTQ)
 - If you want to experiment with **INT8** deployment in [QAT](./examples/QAT_INT8/) and [PTQ](./examples/PTQ_INT8/) examples:
   - Nvidia GPU with compute capability > 8.0 (A100 family or higher)
-  - [Ninja](https://ninja-build.org/)
-  - Clone the [CUTLASS](https://github.com/NVIDIA/cutlass) repository
-  - `PyTorch 2.3.1` (as newer version will cause issue for the custom CUDA kernel used in these examples)
+  - Option 1:
+    - [Ninja](https://ninja-build.org/)
+    - Clone the [CUTLASS](https://github.com/NVIDIA/cutlass) repository
+    - `PyTorch 2.3.1` (newer versions will cause issues for the custom CUDA kernel used in these examples)
+  - Option 2:
+    - Use the included Triton kernel. Note that this kernel is currently not faster than FP16.
 - **FP8** is a reduced precision format like **INT8**:
   - Nvidia A100 family or higher
   - [llm-compressor](https://github.com/vllm-project/llm-compressor)
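
Editor's note: to make the Option 1 / Option 2 split above concrete, a minimal pre-flight check along these lines can tell you which path is available. This is an illustrative snippet, not part of the repository.

# Hypothetical environment check for the INT8 example requirements above.
import importlib.util

import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"GPU compute capability: {major}.{minor} (A100 family or higher needed)")
else:
    print("No CUDA GPU detected")

# Option 2 only needs triton, which ships with recent PyTorch builds.
print("triton available:", importlib.util.find_spec("triton") is not None)
print("torch version:", torch.__version__, "(Option 1 expects 2.3.1)")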

fms_mo/fx/dynamo_utils.py

Lines changed: 53 additions & 22 deletions
@@ -33,6 +33,28 @@
 logger = logging.getLogger(__name__)
 
 
+def run_fwd_once(model, sample_inp):
+    """Convenience function to run the model once, unpacking the example input correctly."""
+    out = None  # stays None if the fallback call below fails
+    with torch.no_grad():
+        if isinstance(sample_inp, dict) or all(
+            hasattr(sample_inp, k) for k in ("keys", "values", "items")
+        ):
+            out = model(**sample_inp)
+        elif isinstance(sample_inp, tuple):
+            out = model(*sample_inp)
+        elif isinstance(sample_inp, torch.Tensor):
+            out = model(sample_inp)
+        else:
+            try:
+                # assume the user-provided input is ready to run
+                out = model(sample_inp)
+            except RuntimeError:
+                logger.info(
+                    f"Unknown data structure for example_input: {type(sample_inp)}. Please check."
+                )
+    return out
+
+
 def dfs_gm(
     gm,
     targetOp=None,
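
Editor's note: a rough usage sketch of the new helper. The toy module and inputs below are illustrative and assume run_fwd_once from the hunk above is in scope; the point is the dispatch on the example input's structure (mapping expands to keyword arguments, tuple to positional arguments, a bare tensor is passed directly).

# Sketch only, not part of the commit.
import torch


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 2)

    def forward(self, x, scale=1.0):
        return self.lin(x) * scale


toy = Toy()
out_tensor = run_fwd_once(toy, torch.randn(1, 4))                            # model(x)
out_args = run_fwd_once(toy, (torch.randn(1, 4), 2.0))                       # model(*args)
out_kwargs = run_fwd_once(toy, {"x": torch.randn(1, 4), "scale": 2.0})       # model(**kwargs)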
@@ -229,7 +251,9 @@ def _dfs(curr_node, depth):
 
 
 def find_conv_on_shortcut_gm(
-    gm: torch.fx.GraphModule, lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None
+    gm: torch.fx.GraphModule,
+    lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None,
+    lut_name_to_mod=None,
 ):
     """Identify Conv on shortcut using FX GM DFS
     It's (almost) specific for ResNet-like CNNs, will return a list of module names (as used in the
@@ -254,6 +278,9 @@ def find_conv_on_shortcut_gm(
     5. count levels of each branch, decide which one is the shortcut
     """
 
+    if lut_name_to_mod is None:
+        lut_name_to_mod = {}
+
     # 1. Find "add" nodes, including inplace add as some may use "out+=shortcut"
     nodes_add = dfs_gm(gm, ["add"], return_nodes=True)
 
@@ -337,9 +364,13 @@ def find_conv_on_shortcut_gm(
             if n_conv_i.op == "call_module":
                 conv_mod = gm.get_submodule(n_conv_i.target)
             else:
-                conv_mod = get_org_mod_name_of_fx_node(
+                # in case aten IR is being used
+                conv_mod_name = get_org_mod_name_of_fx_node(
                     n_conv_i, lut_fx2org=lut_fx_mod_name_to_org
                 )
+                conv_mod = lut_name_to_mod.get(conv_mod_name, None)
+            if not isinstance(conv_mod, torch.nn.Conv2d):
+                continue
             if conv_mod.out_channels > conv_mod.in_channels:  # see Note 2
                 qconv_candidate.append(
                     get_org_mod_name_of_fx_node(
@@ -1003,8 +1034,17 @@ def cus_backend_model_analyzer(
         for _, m in gm_fx.named_modules()
         if isinstance(m, torch.nn.Conv2d) or issubclass(type(m), torch.nn.Conv2d)
     ]
-    if len(all_conv) > 0:
-        skip_candidates += find_conv_on_shortcut_gm(gm_fx, lut_fx_mod_name_to_org)
+    # if gm is using aten IR, only ops can be seen, no modules.
+    conv_ops = dfs_gm(
+        gm_fx,
+        targetOp=[torch.nn.Conv2d, torch.nn.functional.conv2d],
+        return_nodes=True,
+    )
+    lut_name_to_mod = {n: m for m, n in qcfg["LUTmodule_name"].items()}
+    if len(all_conv) > 0 or len(conv_ops) > 0:
+        skip_candidates += find_conv_on_shortcut_gm(
+            gm_fx, lut_fx_mod_name_to_org, lut_name_to_mod
+        )
 
     # Check 2. first/last, see Note 2 and 3, NOTE that transformers are handled differently
     if qcfg["N_backend_called"] > 1:
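
Editor's note: the lut_name_to_mod mapping built above is simply the inverse of a module-to-name lookup. A small standalone sketch of the two lookups (illustrative only, not the actual qcfg plumbing):

# Illustrative: module<->name lookups like qcfg["LUTmodule_name"] and its inversion above.
import torch

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
lut_module_to_name = {m: n for n, m in model.named_modules() if n}  # module -> name
lut_name_to_mod = {n: m for m, n in lut_module_to_name.items()}     # name -> module, as used above

conv = lut_name_to_mod.get("0")
print(isinstance(conv, torch.nn.Conv2d))  # True, so it would pass the Conv2d check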
@@ -1064,6 +1104,7 @@ def cus_backend_model_analyzer(
     from functools import partial
 
     # Third Party
+    from torchvision.models import VisionTransformer
     from transformers import PreTrainedModel
 
     if issubclass(type(model), torch.nn.Module):
@@ -1075,7 +1116,7 @@ def cus_backend_model_analyzer(
         model_to_be_traced = model
         model_param_size = 999
 
-    is_transformers = issubclass(type(model), PreTrainedModel)
+    is_transformers = issubclass(type(model), (PreTrainedModel, VisionTransformer))
     if model_param_size > 1:
         # Standard
         import sys
@@ -1111,35 +1152,25 @@ def call_seq_hook(mod, *_args, **_kwargs):
             h_hooks.append(m.register_forward_hook(call_seq_hook))
 
         with torch.no_grad():
-            model(**sample_inp)
+            run_fwd_once(model, sample_inp)
 
         for h in h_hooks:
             h.remove()
 
         # only add last layer
         qcfg["qskip_layer_name"] += [qcfg["mod_call_seq"][-1]]
+        # if it's a ViT, also skip the first Conv (patch embedding)
+        if issubclass(type(model), VisionTransformer) and isinstance(
+            model.get_submodule(qcfg["mod_call_seq"][0]), torch.nn.Conv2d
+        ):
+            qcfg["qskip_layer_name"] += [qcfg["mod_call_seq"][0]]
 
     with torch.no_grad():
         model_opt = torch.compile(
             model_to_be_traced,
             backend=cus_bknd,
         )
-        if isinstance(sample_inp, dict) or all(
-            hasattr(sample_inp, k) for k in ("keys", "values", "items")
-        ):
-            model_opt(**sample_inp)
-        elif isinstance(sample_inp, tuple):
-            model_opt(*sample_inp)
-        elif isinstance(sample_inp, torch.Tensor):
-            model_opt(sample_inp)
-        else:
-            try:
-                # assume user provided input is ready-to-run...
-                model_opt(sample_inp)
-            except RuntimeError:
-                logger.info(
-                    f"Unknown data structure for example_input.{type(sample_inp)} Please check."
-                )
+        run_fwd_once(model_opt, sample_inp)
 
     del model_opt
 
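
Editor's note: the analyzer records module call order with forward hooks, then compiles with a custom Dynamo backend. A minimal standalone sketch of that pattern (names and the trivial backend are illustrative, not the actual fms_mo API; assumes run_fwd_once from the first hunk):

# Sketch of the hook-then-compile pattern used by the analyzer.
import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU(), torch.nn.Linear(4, 2))
call_seq, hooks = [], []

for name, mod in model.named_modules():
    if name:  # skip the root module
        hooks.append(mod.register_forward_hook(lambda m, *_, _n=name: call_seq.append(_n)))

run_fwd_once(model, torch.randn(1, 4))  # record the call order once
for h in hooks:
    h.remove()
print(call_seq)  # e.g. ['0', '1', '2']; the last entry is a typical "skip quantization" candidate


def my_backend(gm, example_inputs):
    # a trivial custom backend: inspect gm here, then return a callable
    return gm.forward


compiled = torch.compile(model, backend=my_backend)
run_fwd_once(compiled, torch.randn(1, 4))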

fms_mo/fx/utils.py

Lines changed: 3 additions & 1 deletion
@@ -343,6 +343,8 @@ def get_org_mod_name_of_fx_node(
         str: corresponding name on original graph
     """
     org_name = f"Unknown:{node.name}"
+    if lut_fx2org is None:
+        lut_fx2org = {}
     if "nn_module_stack" in node.meta:
         n_fx_mod_name = list(node.meta["nn_module_stack"].keys())[-1]
         n_fx_org_mod_name = list(node.meta["nn_module_stack"].values())[-1][0]
@@ -360,7 +362,7 @@ def get_org_mod_name_of_fx_node(
                 org_name = v[: -len(suffix)]
                 break
 
-    if org_name is None:
+    if org_name.startswith("Unknown:"):
         org_name = lname_to_org_name(n_fx_org_mod_name)
 
     return org_name

fms_mo/modules/linear.py

Lines changed: 12 additions & 6 deletions
@@ -27,9 +27,6 @@
 import torch.nn.functional as F
 
 # Local
-from fms_mo.custom_ext_kernels.triton_kernels import (
-    tl_matmul_chunk_truncate as tl_matmul,
-)
 from fms_mo.custom_ext_kernels.utils import pack_vectorized
 from fms_mo.quant.quantizers import (
     HardPrune,
@@ -39,6 +36,13 @@
     get_weight_quantizer,
     mask_fc_kij,
 )
+from fms_mo.utils.import_utils import available_packages
+
+if available_packages["triton"]:
+    # Local
+    from fms_mo.custom_ext_kernels.triton_kernels import (
+        tl_matmul_chunk_truncate as tl_matmul,
+    )
 
 logger = logging.getLogger(__name__)
 
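
Editor's note: the available_packages guard above follows a common optional-dependency pattern. A self-contained approximation (not the actual fms_mo.utils.import_utils implementation) looks roughly like this:

# Rough approximation of an availability lookup such as available_packages.
import importlib.util

candidate_packages = ["triton", "ninja", "llmcompressor"]
available_packages = {p: importlib.util.find_spec(p) is not None for p in candidate_packages}

if available_packages["triton"]:
    import triton  # only imported when the wheel is actually installed
else:
    triton = None  # callers fall back to a non-triton kernel path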

@@ -879,7 +883,9 @@ def from_torch_iW(cls, nnlin_iW, prec, a_cv, a_cvn, w_cv, zero_shift, **kwargs):
         qlinear_iW.nbits_w = 8
         qlinear_iW.acc_dtype = kwargs.get("acc_dtype", torch.float)
         qlinear_iW.usePTnativeQfunc = kwargs.get("use_PT_native_Qfunc", True)
-        qlinear_iW.use_int_kernel = kwargs.get("use_int_kernel", "triton")
+        qlinear_iW.use_int_kernel = kwargs.get(
+            "use_int_kernel", "triton" if available_packages["triton"] else False
+        )
         qlinear_iW.weight = nn.Parameter(
             nnlin_iW.weight.to(torch.int8), requires_grad=False
         )
@@ -1119,15 +1125,15 @@ def set_matmul_op(self):
             imatmul_ops_reg,
         )
 
-        if self.use_int_kernel == "triton":
+        if self.use_int_kernel == "triton" and available_packages["triton"]:
             # will use real imatmul written in triton
             imm_func = partial(
                 tl_matmul,
                 chunk_trun_bits=self.truncate_lsb,
                 chunk_size=self.chunk_size,
             )
 
-        elif self.use_int_kernel == "cutlass":
+        elif self.use_int_kernel == "cutlass" and available_packages["cutlass"]:
            # will use real imatmul written in cutlass
            cutlass_ops_load_and_reg()
            # Third Party
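
Editor's note: set_matmul_op pre-binds the kernel configuration with functools.partial so the call site only passes tensors. A short illustrative sketch of that idea (the stand-in int_matmul below is hypothetical, not the real triton kernel):

# Sketch: pre-binding kernel configuration with functools.partial.
from functools import partial

import torch


def int_matmul(a, b, chunk_trun_bits=0, chunk_size=32):
    # stand-in for tl_matmul_chunk_truncate; the real kernel runs on GPU via triton
    return a @ b


imm_func = partial(int_matmul, chunk_trun_bits=4, chunk_size=64)
out = imm_func(torch.randn(2, 3), torch.randn(3, 5))  # call site only supplies the operands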

fms_mo/run_quant.py

Lines changed: 1 addition & 0 deletions
@@ -92,6 +92,7 @@ def quantize(
             "auto_gptq module not found. For more instructions on installing the appropriate "
             "package, see https://github.com/AutoGPTQ/AutoGPTQ?tab=readme-ov-file#installation"
         )
+        gptq_args.use_triton = gptq_args.use_triton and available_packages["triton"]
         run_gptq(model_args, data_args, opt_args, gptq_args)
     elif opt_args.quant_method == "fp8":
         if not available_packages["llmcompressor"]:

fms_mo/utils/import_utils.py

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@
     "graphviz",
     "pygraphviz",
     "fms",
+    "triton",
 ]
 
 available_packages = {}

pyproject.toml

Lines changed: 6 additions & 5 deletions
@@ -7,25 +7,25 @@ name = "fms-model-optimizer"
 description = "Quantization Techniques"
 readme = "README.md"
 license = {text = "Apache-2.0"}
-requires-python = ">=3.9,<3.12"
+requires-python = ">3.9,<3.13"
 classifiers=[
     "Development Status :: 3 - Alpha",
     "License :: OSI Approved :: Apache Software License",
     "License :: OSI Approved :: MIT License",
     "Operating System :: POSIX :: Linux",
     "Topic :: Scientific/Engineering :: Artificial Intelligence",
     "Programming Language :: Python :: 3",
-    "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
     "Programming Language :: Python :: Implementation :: CPython",
 ]
 dynamic = ["version"]
 dependencies = [
     "numpy>=1.26.4,<2.3.0",
-    "accelerate>=0.20.3,!=0.34,<1.4",
+    "accelerate>=0.20.3,!=0.34,<1.7",
     "transformers>=4.45,<4.51",
-    "torch>=2.2.0,<2.5",
+    "torch>=2.2.0,<2.6",
     "triton>=3.0,<3.2",
     "tqdm>=4.66.2,<5.0",
     "datasets>=3.0.0,<4.0",
@@ -37,7 +37,8 @@ dependencies = [
     "huggingface_hub",
     "pandas",
     "safetensors",
-    "ibm-fms>=0.0.8"
+    "ibm-fms>=0.0.8",
+    "pkginfo>1.10"
 ]
 
 [project.optional-dependencies]

tests/aiu_addons/conftest.py

Lines changed: 25 additions & 5 deletions
@@ -74,11 +74,31 @@ def get_gptq_gemm_inputs(request) -> tuple[torch.Tensor, ...]:
         "atype": "per_tensor_symm",  # per_tensor_asymm, per_token
         "smoothquant": False,
     },
-    # {
-    #     "wtype": "per_channel",  # per_channel
-    #     "atype": "per_tensor_symm",  # per_tensor_asymm, per_token
-    #     "smoothquant": False,
-    # },
+    {
+        "wtype": "per_tensor",  # per_channel
+        "atype": "per_tensor_asymm",  # per_tensor_asymm, per_token
+        "smoothquant": False,
+    },
+    {
+        "wtype": "per_channel",  # per_channel
+        "atype": "per_tensor_symm",  # per_tensor_asymm, per_token
+        "smoothquant": False,
+    },
+    {
+        "wtype": "per_tensor",  # per_channel
+        "atype": "per_token",  # per_tensor_asymm, per_token
+        "smoothquant": False,
+    },
+    {
+        "wtype": "per_channel",  # per_channel
+        "atype": "per_tensor_asymm",  # per_tensor_asymm, per_token
+        "smoothquant": False,
+    },
+    {
+        "wtype": "per_channel",  # per_channel
+        "atype": "per_token",  # per_tensor_asymm, per_token
+        "smoothquant": False,
+    },
 ]
 
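
Editor's note: config dicts like the ones re-enabled above are typically consumed as fixture params via request.param. A hedged sketch of that pytest pattern (illustrative only, not the actual get_gptq_gemm_inputs fixture):

# Illustrative pytest pattern for sweeping quantization configs like those listed above.
import pytest

QUANT_CONFIGS = [
    {"wtype": "per_tensor", "atype": "per_tensor_symm", "smoothquant": False},
    {"wtype": "per_channel", "atype": "per_token", "smoothquant": False},
]


@pytest.fixture(params=QUANT_CONFIGS, ids=lambda c: f"{c['wtype']}-{c['atype']}")
def quant_config(request):
    return request.param


def test_gptq_gemm_config(quant_config):
    assert quant_config["wtype"] in ("per_tensor", "per_channel")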

Binary file not shown.
