Commit d77e518

Fix examples
1 parent 4121b74 commit d77e518

3 files changed: +29 −20 lines


example.py

Lines changed: 1 addition & 3 deletions
@@ -9,12 +9,10 @@
 examples = ["auto_fp8 is an easy-to-use model quantization library"]
 examples = tokenizer(examples, return_tensors="pt").to("cuda")
 
-ignore_patterns = ["re:.*gate"]
-
 quantize_config = BaseQuantizeConfig(
     quant_method="fp8",
     activation_scheme="dynamic",  # or "static"
-    ignore_patterns=ignore_patterns,
+    ignore_patterns=["re:.*lm_head"],
 )
 
 model = AutoFP8ForCausalLM.from_pretrained(
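
The replacement pattern exempts the output projection rather than the router: quantization error in `lm_head` feeds straight into the final logits, so it is a common layer to keep at higher precision. As a minimal sketch of how a `re:`-prefixed ignore pattern could be matched against module names (the `should_ignore` helper below is hypothetical, not part of the auto_fp8 API):

import re

def should_ignore(module_name: str, ignore_patterns: list[str]) -> bool:
    # Hypothetical matcher: entries prefixed with "re:" are treated as
    # regular expressions over the module name; anything else must match
    # the name exactly.
    for pattern in ignore_patterns:
        if pattern.startswith("re:"):
            if re.match(pattern[3:], module_name):
                return True
        elif pattern == module_name:
            return True
    return False

print(should_ignore("lm_head", ["re:.*lm_head"]))                     # True
print(should_ignore("model.layers.0.mlp.up_proj", ["re:.*lm_head"]))  # False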

example_dataset.py

Lines changed: 4 additions & 17 deletions
@@ -7,24 +7,11 @@
 quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8"
 
 tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
+ds = load_dataset("mgoin/ultrachat_2k", split="train_sft").select(range(512))
+examples = [tokenizer.apply_chat_template(batch["messages"], tokenize=False) for batch in ds]
+examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to("cuda")
 
-DATASET_ID = "mgoin/ultrachat_2k"
-DATASET_SPLIT = "train_sft"
-ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
-ds = ds.map(
-    lambda batch: {
-        "text": tokenizer.apply_chat_template(batch["messages"], tokenize=False)
-    }
-)
-examples = [sample["text"] for sample in ds]
-tokenizer.pad_token = tokenizer.eos_token
-examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to(
-    "cuda"
-)
-
-quantize_config = BaseQuantizeConfig(
-    quant_method="fp8", activation_scheme="static"
-) # or "static"
+quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static")
 
 model = AutoFP8ForCausalLM.from_pretrained(
     pretrained_model_dir, quantize_config=quantize_config
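
The switch to `activation_scheme="static"` is what the 512 calibration samples are for: a static scheme fixes each activation scale offline from the ranges observed on calibration data, instead of recomputing scales per batch as the dynamic scheme does. Below is a minimal sketch of per-tensor scale calibration, assuming the FP8 E4M3 format and PyTorch 2.1+ for `torch.float8_e4m3fn`; it is illustrative only, not auto_fp8's internal calibration code.

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for E4M3

def compute_static_scale(calib_activations: torch.Tensor) -> torch.Tensor:
    # Map the largest magnitude observed during calibration onto the FP8
    # range; the resulting scale is frozen and reused at inference time.
    return calib_activations.abs().max() / FP8_MAX

x = torch.randn(512, 4096) * 3.0               # stand-in for real activations
scale = compute_static_scale(x)
x_fp8 = (x / scale).to(torch.float8_e4m3fn)    # quantize
x_back = x_fp8.to(torch.float32) * scale       # dequantize to inspect error
print(scale.item(), (x - x_back).abs().mean().item())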

examples/example_mixtral.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+from datasets import load_dataset
+from transformers import AutoTokenizer
+
+from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig
+
+pretrained_model_dir = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+quantized_model_dir = "Mixtral-8x7B-Instruct-v0.1-FP8"
+
+tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
+ds = load_dataset("mgoin/ultrachat_2k", split="train_sft").select(range(10))
+examples = [tokenizer.apply_chat_template(batch["messages"], tokenize=False) for batch in ds]
+examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to("cuda")
+
+quantize_config = BaseQuantizeConfig(
+    quant_method="fp8",
+    activation_scheme="static",
+    ignore_patterns=["re:.*lm_head", "re:.*gate"],
+)
+
+model = AutoFP8ForCausalLM.from_pretrained(
+    pretrained_model_dir, quantize_config=quantize_config
+)
+model.quantize(examples)
+model.save_quantized(quantized_model_dir)
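
The new Mixtral example keeps `re:.*gate` in the ignore list: each `block_sparse_moe.gate` module is a small router linear layer whose expert-selection logits are sensitive to quantization error, so it stays at higher precision alongside `lm_head` for negligible memory cost. To check which modules the two patterns cover without materializing the full model's weights, one can build the module tree on the meta device; this helper script is illustrative (it assumes `accelerate` is installed) and is not part of the commit.

import re

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
with init_empty_weights():
    # Builds the module tree without allocating real weight tensors.
    model = AutoModelForCausalLM.from_config(config)

patterns = [re.compile(p[3:]) for p in ["re:.*lm_head", "re:.*gate"]]
for name, _ in model.named_modules():
    if any(p.match(name) for p in patterns):
        print(name)  # e.g. "model.layers.0.block_sparse_moe.gate", "lm_head"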
