Skip to content

Commit 47b8716

Browse files
Merge pull request #169 from BrandonGroth/quant_refactor_perCh
The new functions are isolated from the existing functions, so this is safe to merge. However, we will need a new example to demonstrate how to use the new `_rc` functions.
2 parents 207eb06 + 759dab5 commit 47b8716

34 files changed

+15193
-79
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ qcfg.json
1717
configs
1818
pytest.out
1919

20+
# Log file
21+
fms_mo.log
22+
2023
# IDEs
2124
.vscode/
2225
.idea/

fms_mo/quant/quantizers.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -513,9 +513,12 @@ def forward(
513513

514514
if istraining:
515515
# only recalc clipvals under training mode
516+
num_bits_int = (
517+
num_bits.item() if isinstance(num_bits, torch.Tensor) else num_bits
518+
)
516519
SAWBcode_mapping = {8: 803, 4: 403, 2: 103}
517520
if num_bits in [2, 4, 8]:
518-
sawb_code = SAWBcode_mapping[num_bits]
521+
sawb_code = SAWBcode_mapping[num_bits_int]
519522
clip_val, _ = sawb_params_code(
520523
num_bits, sawb_code, input_tensor, perCh=True
521524
)
@@ -552,9 +555,13 @@ def forward(
552555
clip_val.dtype
553556
) # NOTE return will be a fp32 tensor; function only support float()
554557
else:
555-
output = torch.quantize_per_channel(
556-
input_tensor, scale, zero_point, 0, torch.qint8
557-
).int_repr()
558+
output = (
559+
torch.quantize_per_channel(
560+
input_tensor, scale, zero_point, 0, torch.qint8
561+
)
562+
.int_repr()
563+
.clamp(int_l, int_u)
564+
)
558565
# NOTE return will be a torch.int8 tensor
559566

560567
return output
@@ -2540,7 +2547,7 @@ def asymmetric_linear_quantization_params(
25402547
return scale, zero_point
25412548

25422549

2543-
def clamp(input_tensor: torch.FloatTensor, clamp_min, clamp_max, inplace=False):
2550+
def clamp(input_tensor: torch.Tensor, clamp_min, clamp_max, inplace=False):
25442551
"""
25452552
Returns:
25462553
Clamped Torch Tensor.
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# Copyright The FMS Model Optimizer Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
Base QModel Class
17+
"""
18+
19+
# Standard
20+
# pylint: disable=keyword-arg-before-vararg
21+
import logging
22+
23+
# Third Party
24+
import torch
25+
26+
logger = logging.getLogger(__name__)


class Qmodel:  # do not inherit nn.Module, or self.model will not show up in __dict__
    """
    A wrapper for fms_mo model, mainly for user API simplification purpose.
    Everything is the same as the original model, but we can add new member functions.
    Make sure the naming will be unique enough so that we won't override any existing functions
    in Huggingface models.

    Attribute access that is not satisfied by the wrapper itself is forwarded to the
    wrapped model, restricted to the names the model exposed (via ``dir()``) at the
    time the wrapper was created.
    """

    def __init__(self, original_model, qcfg=None, *args, **kwargs) -> None:
        """
        Wrap ``original_model``.

        Args:
            original_model: The model object to wrap. Stored as ``self.model``.
            qcfg (optional): Quantization config. Only stored as ``self.qcfg`` when
                truthy, so a falsy value never shadows a ``qcfg`` attribute of the
                wrapped model.
            *args, **kwargs: Passed through to ``super().__init__``.
        """
        super().__init__(*args, **kwargs)
        # Snapshot of the model's attribute names; __getattr__ only forwards these.
        self.org_attr = dir(original_model)
        self.model = original_model
        if qcfg:
            self.qcfg = qcfg

    def __getattr__(self, name: str):
        """
        Forward lookups that failed on the wrapper to the wrapped model.

        Python only calls ``__getattr__`` after normal attribute lookup has failed.

        Args:
            name (str): Attribute name being looked up.

        Returns:
            The attribute of ``self.model`` with the same name.

        Raises:
            AttributeError: If ``name`` is not one of the names recorded from the
                wrapped model, or the wrapper is not initialised yet.
        """
        # If our own bookkeeping attributes are missing (e.g. instance created via
        # __new__, copy or unpickling), reading self.org_attr below would re-enter
        # __getattr__ and recurse forever. Fail fast instead.
        if name in ("org_attr", "model"):
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'"
            )
        if name in self.org_attr:
            logger.info(f"Trying to access self.{name}, forward to self.model.{name}")
            return getattr(self.model, name)
            # NOTE: this self.model is in __dict__, so it will not trigger __getattr__
            # recursively.!!
        # Unknown names must raise (not silently return None) so that hasattr() and
        # getattr(obj, name, default) behave correctly.
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{name}'"
        )

    def __call__(self, *args, **kwargs):
        """Call the wrapped model with the given arguments and return its result."""
        logger.info(
            "Make this object callable, but actually just calling self.model.__call__()"
        )
        return self.model(*args, **kwargs)

    def __repr__(self):
        """Return the wrapped model's repr framed by a cyan-coloured wrapper tag."""
        OKCYAN = "\033[96m"
        ENDC = "\033[0m"
        rep_txt = f"{OKCYAN}FMSMO_Qmodel_wrapper({ENDC}\n{self.model.__repr__()}{OKCYAN}){ENDC}"
        return rep_txt

    def to(self, tar_dev: torch.device):
        """
        Demonstrate that we can override a function in original model
        it should not call __getattr__(), i.e. will not see the printout from that func

        Args:
            tar_dev (torch.device): A new device

        Returns:
            The result of ``self.model.to(tar_dev)``, i.e. the wrapped model's own
            return value, not this wrapper.
        """
        logger.info(
            f"Override a function in original model. moving the model to a new device {tar_dev}"
        )
        return self.model.to(tar_dev)

    def save_model_in_pt_fmt(
        self, filename: str = "model.pt", exam_inp: torch.Tensor = None
    ):
        """
        Save entire model to file

        Args:
            filename (str, optional): File path to save model. Defaults to "model.pt".
            exam_inp (torch.Tensor, optional): Example input for model. Stored under
                the "exam_inp" key whenever it is not None. Defaults to None.
        """
        # NOTE self.qcfg has a lot of info already, like transformers_version
        # NOTE cannot save wrapped self, can only save self.model...
        save_dict = {"model": self.model}
        # Compare against None explicitly: the truth value of a multi-element tensor
        # raises, and a single-element zero tensor would be dropped.
        if exam_inp is not None:
            save_dict["exam_inp"] = exam_inp
        torch.save(save_dict, filename)
        logger.info(f"{filename} saved successfully.")

    def save_statedict_in_pt_fmt(
        self,
        filename: str = "model.pt",
    ):
        """
        Save the model state dict to file

        Args:
            filename (str, optional): File path to save model state dict. Defaults to "model.pt".
        """
        torch.save(self.model.state_dict(), filename)
        logger.info(f"model.state_dict() is saved to {filename} successfully.")

    def run_gptq(self):
        """
        Check model is supported by AutoGPTQ first

        Placeholder: currently does nothing and returns None.
        """
        return

0 commit comments

Comments
 (0)