Commit c062b1f
Last fixes
1 parent 3b16e08 commit c062b1f

2 files changed: +11 -9 lines changed

auto_fp8/modeling.py

Lines changed: 5 additions & 5 deletions
@@ -75,7 +75,7 @@ def skip(*args, **kwargs):
         model_init_kwargs["device_map"] = "auto"
 
         merged_kwargs = {**model_init_kwargs, **cached_file_kwargs}
-        print(merged_kwargs)
+        print("Loading model with the following kwargs:", merged_kwargs)
         model = AutoModelForCausalLM.from_pretrained(
             pretrained_model_name_or_path, **merged_kwargs
         )
@@ -102,10 +102,10 @@ def _prepare_calibration_data(calibration_tokens):
                 return calibration_tokens.input_ids
             return calibration_tokens
 
-        if self.quantize_config.activation_scheme == "dynamic":
-            quantize_weights(self.model)
-        else:
-            quantize_weights(self.model)
+        # Always quantize the weights as they do not require calibration data
+        quantize_weights(self.model)
+
+        if self.quantize_config.activation_scheme == "static":
             quantize_activations(
                 self.model, _prepare_calibration_data(calibration_tokens)
             )
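The second hunk collapses a redundant if/else (both branches called quantize_weights) into one unconditional call, gating activation quantization on the static scheme instead. A minimal runnable sketch of that control flow, with stand-in stubs for quantize_weights and quantize_activations (illustrative only, not AutoFP8's real implementations):

from dataclasses import dataclass
from typing import Any, Optional

def quantize_weights(model: Any) -> None:
    # Stub: stands in for FP8 weight quantization, which needs no calibration data.
    print("weights quantized")

def quantize_activations(model: Any, calibration_tokens: Any) -> None:
    # Stub: stands in for static activation calibration over the given tokens.
    print("activations quantized with calibration data")

@dataclass
class QuantizeConfig:
    activation_scheme: str = "dynamic"  # "dynamic" or "static"

def quantize(model: Any, config: QuantizeConfig, calibration_tokens: Optional[Any] = None) -> None:
    # Always quantize the weights: they do not require calibration data.
    quantize_weights(model)
    # Only the static scheme needs calibration tokens to fix activation scales ahead of time.
    if config.activation_scheme == "static":
        if calibration_tokens is None:
            raise ValueError("static activation scheme requires calibration tokens")
        quantize_activations(model, calibration_tokens)

quantize(object(), QuantizeConfig(activation_scheme="dynamic"))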

example.py

Lines changed: 6 additions & 4 deletions
@@ -1,14 +1,16 @@
 from transformers import AutoTokenizer
 from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig
 
-pretrained_model_dir = "facebook/opt-125m"
-quantized_model_dir = "opt-125m-fp8"
+pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
+quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8"
 
 tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
-examples = ["auto-fp8 is an easy-to-use model quantization library"]
+examples = ["auto_fp8 is an easy-to-use model quantization library"]
 examples = tokenizer(examples, return_tensors="pt").to("cuda")
 
-quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static")
+quantize_config = BaseQuantizeConfig(
+    quant_method="fp8", activation_scheme="dynamic"
+)  # or "static"
 
 model = AutoFP8ForCausalLM.from_pretrained(
     pretrained_model_dir, quantize_config=quantize_config
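The example now defaults to the dynamic scheme, where activation scales are computed at runtime, so the tokenized calibration batch is only exercised when "static" is chosen. The hunk is cut off before the quantize and save steps; a hedged end-to-end sketch of how a script like this typically finishes (model.quantize() and model.save_quantized() are assumed method names in the AutoGPTQ style and are not shown in this diff):

from transformers import AutoTokenizer
from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
# Calibration examples matter only for activation_scheme="static";
# the dynamic scheme scales activations on the fly.
examples = tokenizer(
    ["auto_fp8 is an easy-to-use model quantization library"], return_tensors="pt"
).to("cuda")

quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="dynamic")
model = AutoFP8ForCausalLM.from_pretrained(
    pretrained_model_dir, quantize_config=quantize_config
)
model.quantize(examples)                   # assumed API, per the quantize() flow above
model.save_quantized(quantized_model_dir)  # assumed API for writing the FP8 checkpoint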
