
Commit 4e6ed88

Big refactor into package
1 parent 424a4c9

File tree

9 files changed, +907 -0 lines changed

.gitignore

Lines changed: 189 additions & 0 deletions
@@ -0,0 +1,189 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/
docs/source/getting_started/examples/*.rst
!**/*.template.rst

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

# VSCode
.vscode/

# DS Store
.DS_Store

# Results
*.csv

# Python pickle files
*.pkl

# Sphinx documentation
_build/

# vim swap files
*.swo
*.swp

# hip files generated by PyTorch
*.hip
*_hip*
hip_compat.h

# Benchmark dataset
*.json

auto_fp8/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
from .modeling import AutoFP8ForCausalLM
from .config import BaseQuantizeConfig
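
These re-exports define the package's public surface, so both entry points can be imported straight from the package root:

from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig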

auto_fp8/config.py

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
class BaseQuantizeConfig:
    def __init__(self, quant_method="fp8", activation_scheme="static"):
        if quant_method != "fp8":
            raise ValueError("Only FP8 quantization is supported.")
        if activation_scheme not in ["static", "dynamic"]:
            raise ValueError(
                "Invalid activation_scheme. Choose either 'static' or 'dynamic'."
            )
        self.quant_method = quant_method
        self.activation_scheme = activation_scheme
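
For illustration, a minimal sketch of the validation behavior (the rejected values below are arbitrary examples, not part of this commit):

cfg = BaseQuantizeConfig()                               # fp8 + static, the defaults
cfg = BaseQuantizeConfig(activation_scheme="dynamic")    # accepted
cfg = BaseQuantizeConfig(quant_method="int8")            # raises ValueError
cfg = BaseQuantizeConfig(activation_scheme="per-token")  # raises ValueError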

auto_fp8/modeling.py

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
import torch
from transformers import AutoModelForCausalLM, PreTrainedModel
from auto_fp8.quantize import (
    quantize_weights,
    quantize_activations,
    save_quantized_model,
)
from auto_fp8.config import BaseQuantizeConfig


class AutoFP8ForCausalLM:
    def __init__(
        self,
        model: PreTrainedModel,
        quantize_config: BaseQuantizeConfig,
    ):
        self.model = model
        self.model_type = self.model.config.model_type
        self.quantize_config = quantize_config
        self.config = self.model.config

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        quantize_config: BaseQuantizeConfig,
        **model_init_kwargs,
    ):
        """Load an unquantized pretrained model for quantization."""
        if not torch.cuda.is_available():
            raise EnvironmentError(
                "Loading a pretrained model for quantization requires CUDA."
            )

        # Skip random weight initialization; the checkpoint overwrites the
        # weights immediately afterwards, so initializing them is wasted work.
        def skip(*args, **kwargs):
            pass

        torch.nn.init.kaiming_uniform_ = skip
        torch.nn.init.uniform_ = skip
        torch.nn.init.normal_ = skip

        # Parameters related to loading from the Hugging Face Hub
        cache_dir = model_init_kwargs.pop("cache_dir", None)
        force_download = model_init_kwargs.pop("force_download", False)
        resume_download = model_init_kwargs.pop("resume_download", False)
        proxies = model_init_kwargs.pop("proxies", None)
        local_files_only = model_init_kwargs.pop("local_files_only", False)
        use_auth_token = model_init_kwargs.pop("use_auth_token", None)
        revision = model_init_kwargs.pop("revision", None)
        subfolder = model_init_kwargs.pop("subfolder", "")
        commit_hash = model_init_kwargs.pop("_commit_hash", None)

        cached_file_kwargs = {
            "cache_dir": cache_dir,
            "force_download": force_download,
            "proxies": proxies,
            "resume_download": resume_download,
            "local_files_only": local_files_only,
            "use_auth_token": use_auth_token,
            "revision": revision,
            "subfolder": subfolder,
            "_commit_hash": commit_hash,
        }

        torch.cuda.empty_cache()

        # Important defaults. Note: model_init_kwargs is a dict, so membership
        # is checked with `in`, not hasattr (which is always False for dicts).
        if "torch_dtype" not in model_init_kwargs:
            model_init_kwargs["torch_dtype"] = "auto"

        if "device_map" not in model_init_kwargs:
            model_init_kwargs["device_map"] = "auto"

        merged_kwargs = {**model_init_kwargs, **cached_file_kwargs}
        model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path, **merged_kwargs
        )

        # Record the model's maximum sequence length, falling back to 2048
        # when the config exposes none of the known keys.
        model_config = model.config.to_dict()
        seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
        if any(k in model_config for k in seq_len_keys):
            for key in seq_len_keys:
                if key in model_config:
                    model.seqlen = model_config[key]
                    break
        else:
            print(
                "Can't get the model's sequence length from the model config; "
                "setting it to 2048."
            )
            model.seqlen = 2048
        model.eval()

        return cls(model, quantize_config)

    def quantize(self, calibration_tokens):
        def _prepare_calibration_data(calibration_tokens):
            # Accept either a raw tensor of token ids or a tokenizer output
            # that carries the ids in .input_ids.
            if hasattr(calibration_tokens, "input_ids"):
                return calibration_tokens.input_ids
            return calibration_tokens

        # Weights are quantized under both schemes; the static scheme
        # additionally calibrates activation scales on the calibration data.
        quantize_weights(self.model)
        if self.quantize_config.activation_scheme == "static":
            quantize_activations(
                self.model, _prepare_calibration_data(calibration_tokens)
            )

    def save_quantized(self, save_dir):
        save_quantized_model(
            self.model,
            activation_scheme=self.quantize_config.activation_scheme,
            save_dir=save_dir,
        )
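
Taken together, the refactored package is driven end to end like this; a minimal usage sketch, assuming an illustrative checkpoint name and calibration text (neither appears in this commit) and the quantize_weights/quantize_activations/save_quantized_model helpers imported above from auto_fp8/quantize.py:

from transformers import AutoTokenizer
from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

model_id = "facebook/opt-125m"  # placeholder checkpoint, for illustration only
quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static")

# quantize() accepts either a tensor of token ids or a tokenizer output
# with an .input_ids attribute (see _prepare_calibration_data above).
tokenizer = AutoTokenizer.from_pretrained(model_id)
calibration_tokens = tokenizer(["Example calibration text."], return_tensors="pt")

model = AutoFP8ForCausalLM.from_pretrained(model_id, quantize_config)
model.quantize(calibration_tokens)    # static scheme: calibrates activation scales
model.save_quantized("opt-125m-fp8")  # writes the quantized checkpoint to disk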

0 commit comments
