Commit 520c78f

fix mxfp exporting (#806)
* enable act quantization for mxfp datatype
  Signed-off-by: Zhang, Weiwei1 <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  (for more information, see https://pre-commit.ci)
* add ut, fix typo
  Signed-off-by: Zhang, Weiwei1 <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  (for more information, see https://pre-commit.ci)
* fix typo
  Signed-off-by: Zhang, Weiwei1 <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  (for more information, see https://pre-commit.ci)

Signed-off-by: Zhang, Weiwei1 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 9534d2a commit 520c78f
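
For reference, a minimal end-to-end sketch of what this commit enables, based on the updated unit tests below (test_mxfp4_llmcompressor_format): quantizing a model with an MXFP scheme and exporting the activation-quantized result to the llm_compressor format. The calibration dataset and layer_config arguments used in the tests are omitted here and assumed to fall back to the library defaults; treat this as an illustrative sketch, not documented API.

    # Sketch derived from test/test_cpu/test_export.py in this commit; relying on
    # the default dataset/layer_config is an assumption.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from auto_round import AutoRound

    model_name = "facebook/opt-125m"
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # "MXFP4" selects 4-bit mx_fp weights and activations; after this commit mx_fp
    # activation quantization passes _check_configs and _check_supported_format.
    autoround = AutoRound(model, tokenizer, scheme="MXFP4", iters=2, seqlen=2)
    autoround.quantize()

    # Exporting mx_fp activation-quantized models to llm_compressor is now allowed;
    # exporting to auto_round only warns that loading is not yet supported.
    autoround.save_quantized(output_dir="./saved", inplace=True, format="llm_compressor")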

File tree

8 files changed: +172 / -97 lines


auto_round/autoround.py

Lines changed: 28 additions & 15 deletions
@@ -686,7 +686,11 @@ def _check_configs(self) -> None:
         if self.gradient_accumulate_steps <= 0:
             raise ValueError("`gradient_accumulate_steps` must be positive")

-        if self.act_bits <= 8 and (not is_nv_fp(self.act_data_type) or "static_gs" not in self.act_data_type):
+        if (
+            self.act_bits <= 8
+            and (not is_nv_fp(self.act_data_type) or "static_gs" not in self.act_data_type)
+            and not is_mx_fp(self.act_data_type)
+        ):
             logger.warning(
                 "activation quantization is an experimental feature with limited support and a complex API. "
                 "And please save the quantized model to fake format as real deployment is not supported currently"
@@ -897,25 +901,22 @@ def _check_supported_format(self, format: str) -> bool:
         # format check for fp8
         w_fp8 = self.data_type == "fp" and self.bits == 8
         act_fp8 = self.act_data_type == "fp" and self.act_bits == 8
-        if (w_fp8 or act_fp8) and re.search("^auto_round|^llmcompressor", format) is None:
+        if (w_fp8 or act_fp8) and re.search("^auto_round|^llm_compressor", format) is None:
             error_msg = (
-                f"is only supported to export auto_round or llmcompressor format," f" but got {format}, please check."
+                f"is only supported to export auto_round or llm_compressor format," f" but got {format}, please check."
             )
             error_msg = ("act_data_type<fp8> " + error_msg) if act_fp8 else error_msg
             error_msg = ("data_type<fp8> " + error_msg) if w_fp8 else error_msg
             logger.error(error_msg)
             sys.exit(-1)

-        # Only support to export afp8/nv_fp
+        # Only support to export afp8/nv_fp/mx_fp
         if self.act_bits <= 8:
             if not is_standard_fp(self.act_data_type) or self.act_dynamic:
                 if "llm_compressor" in format:
-                    if is_nv_fp(self.act_data_type) and "static_gs" in self.act_data_type:
-                        logger.warning(
-                            f"AutoRound supports exporting to format '{format}', "
-                            "but loading quantized models in this format is not yet supported. "
-                            "It is currently recommended to export to the 'llm_compressor' format."
-                        )
+                    if (is_nv_fp(self.act_data_type) and "static_gs" in self.act_data_type) or (
+                        is_mx_fp(self.act_data_type)
+                    ):
                         return format
                     bits, group_size, sym, act_bits = 8, -1, True, 8
                     assert (
@@ -925,15 +926,23 @@ def _check_supported_format(self, format: str) -> bool:
                         and self.act_bits == act_bits
                         and self.act_dynamic
                     ), (
-                        f"Currently only support to export llmcompressor format for sym dynamic quantized"
+                        f"Currently only support to export llm_compressor format for sym dynamic quantized"
                         f" W{self.bits}A{self.act_bits} model with group_size={group_size},"
                         f" but got bits={self.bits}, group_size={self.group_size}, sym={self.sym},"
                         f" act_bits={self.act_bits}"
                     )
-                elif format != "fake" and (not is_nv_fp(format) or "static_gs" not in self.act_data_type):
+                elif "auto_round" in format and (
+                    is_mx_fp(self.act_data_type) or (is_nv_fp(format) and "static_gs" in self.act_data_type)
+                ):
+                    logger.warning(
+                        f"AutoRound supports exporting to format '{format}', "
+                        "but loading quantized models in this format is not yet supported. "
+                        "It is currently recommended to export to the 'llm_compressor' format."
+                    )
+                elif format != "fake":
                     logger.warning(
                         "Currently only support to export auto_round format quantized model"
-                        " with fp8 or nv_fp4 dtype activation for activation quantization."
+                        " with fp8, mx_fp and nv_fp4 dtype activation for activation quantization."
                         " Change format to fake and save."
                     )
                     format = "fake"
@@ -2159,7 +2168,7 @@ def calib(self, nsamples, bs):
             exit(-1)
         elif total_cnt < nsamples:
             logger.warning(
-                f"An insufficient number of samples likely reduces the accuracy of the quantized model."
+                f"An insufficient number of samples likely reduces the accuracy of the quantized model. "
                 f"Target samples count is {nsamples}, while valid samples count is {total_cnt}"
             )

@@ -2838,7 +2847,11 @@ def _quantize_block(
         with torch.no_grad():
             unwrapper_block(block, best_params)

-        if is_nv_fp(self.act_data_type) and any("nv_fp" in format_ for format_ in self.formats):
+        if (
+            is_nv_fp(self.act_data_type)
+            and hasattr(self, "formats")
+            and any("nv_fp" in format_ for format_ in self.formats)
+        ):
             # enable moe experts act_max automatic generation for WrapperWALayer
             set_amax_for_all_moe_layers(block, attr_name="orig_layer.act_max")

auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py

Lines changed: 13 additions & 8 deletions
@@ -166,16 +166,21 @@ def save_quantized_as_fp(output_dir, inplace=True, **kwargs):
     for n, m in model.named_modules():
         if isinstance(m, WrapperWALayer):
             orig_layer = m.orig_layer
-            if not getattr(orig_layer, "input_global_scale", None):
-                assert hasattr(orig_layer, "act_max")
-                from auto_round.data_type.nvfp import calculate_gparam
-
-                input_global_scale = calculate_gparam(orig_layer.act_max, orig_layer.group_size, model.device)
-                setattr(orig_layer, "input_global_scale", input_global_scale)
-                delattr(orig_layer, "act_max")
             set_module(model, n, orig_layer)

-    # update input_global_scale
+    if is_nv_fp(act_data_type) and "static_gs" in str(act_data_type).lower():
+        # generate static input_global_scale
+        for n, m in model.named_modules():
+            if isinstance(m, SUPPORTED_LAYER_TYPES):
+                layer = m
+                if layer.act_bits < 8 and not getattr(layer, "input_global_scale", None):
+                    assert hasattr(layer, "act_max")
+                    from auto_round.data_type.nvfp import calculate_gparam
+
+                    input_global_scale = calculate_gparam(layer.act_max, layer.group_size, model.device)
+                    setattr(layer, "input_global_scale", input_global_scale)
+                    delattr(layer, "act_max")
+    # update fused input_global_scale
     from auto_round.data_type.utils import update_fused_layer_global_scales

     modules = list(model.modules())

auto_round/export/export_to_autoround/qlinear_fp.py

Lines changed: 0 additions & 12 deletions
@@ -95,18 +95,6 @@ def __init__(
             weight_name,
             torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32),
         )
-        if not self.sym:
-            ## TODO Currently only sym quant is supported for mxfp dtype. Is weight_zero_point param necessary?
-            self.register_buffer(
-                "weight_zero_point",
-                torch.zeros(
-                    (
-                        math.ceil(infeatures / self.group_size),
-                        outfeatures // 32 * self.bits,
-                    ),
-                    dtype=torch.int32,
-                ),
-            )
         self.register_buffer(
             "weight_scale",
             torch.zeros(

auto_round/export/export_to_llmcompressor/export_to_fp.py

Lines changed: 13 additions & 8 deletions
@@ -161,16 +161,21 @@ def save_quantized_as_fp(output_dir, inplace=True, **kwargs):
     for n, m in model.named_modules():
         if isinstance(m, WrapperWALayer):
             orig_layer = m.orig_layer
-            if not getattr(orig_layer, "input_global_scale", None):
-                assert hasattr(orig_layer, "act_max")
-                from auto_round.data_type.nvfp import calculate_gparam
-
-                input_global_scale = calculate_gparam(orig_layer.act_max, orig_layer.group_size, model.device)
-                setattr(orig_layer, "input_global_scale", input_global_scale)
-                delattr(orig_layer, "act_max")
             set_module(model, n, orig_layer)

-    # update input_global_scale
+    if is_nv_fp(act_data_type) and "static_gs" in str(act_data_type).lower():
+        # generate static input_global_scale
+        for n, m in model.named_modules():
+            if isinstance(m, SUPPORTED_LAYER_TYPES):
+                layer = m
+                if layer.act_bits < 8 and not getattr(layer, "input_global_scale", None):
+                    assert hasattr(layer, "act_max")
+                    from auto_round.data_type.nvfp import calculate_gparam
+
+                    input_global_scale = calculate_gparam(layer.act_max, layer.group_size, model.device)
+                    setattr(layer, "input_global_scale", input_global_scale)
+                    delattr(layer, "act_max")
+    # update fused input_global_scale
     from auto_round.data_type.utils import update_fused_layer_global_scales

     modules = list(model.modules())

auto_round/utils.py

Lines changed: 3 additions & 0 deletions
@@ -2538,14 +2538,17 @@ class BackendDataType(str, Enum):


 def is_standard_fp(backend):
+    backend = backend.lower()
     return BackendDataType.STANDARD_FP in backend and not is_mx_fp(backend) and not is_nv_fp(backend)


 def is_mx_fp(backend):
+    backend = backend.lower()
     return BackendDataType.MX_FP in backend


 def is_nv_fp(backend):
+    backend = backend.lower()
     return BackendDataType.NV_FP in backend
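
The lower-casing added above makes the backend data-type checks case-insensitive. A rough standalone approximation of the three helpers follows; the substrings "fp", "mx_fp", and "nv_fp" are assumed here to be the values behind BackendDataType.STANDARD_FP, MX_FP, and NV_FP, so check the enum in auto_round/utils.py for the authoritative values.

    # Approximation only; the enum values are assumptions, not copied from the repo.
    def is_mx_fp(backend: str) -> bool:
        return "mx_fp" in backend.lower()

    def is_nv_fp(backend: str) -> bool:
        return "nv_fp" in backend.lower()

    def is_standard_fp(backend: str) -> bool:
        backend = backend.lower()
        return "fp" in backend and not is_mx_fp(backend) and not is_nv_fp(backend)

    # With .lower() applied first, upper- and mixed-case data-type names match too:
    assert is_mx_fp("MX_FP4") and is_mx_fp("mx_fp_rceil")
    assert is_nv_fp("NV_FP4_WITH_STATIC_GS")
    assert is_standard_fp("fp8") and not is_standard_fp("mx_fp8")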

test/test_cpu/test_export.py

Lines changed: 80 additions & 41 deletions
@@ -1,3 +1,4 @@
+import os
 import shutil
 import sys
 import unittest
@@ -11,6 +12,17 @@
 from auto_round import AutoRound


+def _get_folder_size(path: str) -> float:
+    """Return folder size in GB."""
+    total_size = 0
+    for dirpath, _, filenames in os.walk(path):
+        for f in filenames:
+            fp = os.path.join(dirpath, f)
+            if os.path.isfile(fp):
+                total_size += os.path.getsize(fp)
+    return total_size / (1024**3)  # convert to GB
+
+
 class LLMDataLoader:
     def __init__(self):
         self.batch_size = 1
@@ -25,7 +37,7 @@ class TestAutoRound(unittest.TestCase):
     def setUpClass(self):
         model_name = "facebook/opt-125m"
         self.save_dir = "./saved"
-        self.model = AutoModelForCausalLM.from_pretrained(model_name, dtype="auto", trust_remote_code=True)
+        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)
         self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
         self.llm_dataloader = LLMDataLoader()
@@ -268,10 +280,7 @@ def test_mxfp4_llmcompressor_format(self):
         model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)
         from transformers import AutoConfig

-        bits = 4
-        data_type = "mx_fp"
-        group_size = 32
-        sym = True
+        scheme = "MXFP4"
         layer_config = {}
         fp_layers_str = "k_proj"
         from auto_round.utils import get_fp_layer_names
@@ -282,12 +291,58 @@ def test_mxfp4_llmcompressor_format(self):
         autoround = AutoRound(
             model,
             self.tokenizer,
-            bits=bits,
-            group_size=group_size,
-            sym=sym,
+            scheme=scheme,
             iters=2,
             seqlen=2,
-            data_type=data_type,
+            layer_config=layer_config,
+            dataset=self.llm_dataloader,
+        )
+        quantized_model_path = self.save_dir
+        autoround.quantize()
+        compressed_model = autoround.save_quantized(
+            output_dir=quantized_model_path, inplace=True, format="llm_compressor"
+        )
+        tmp_layer = compressed_model.model.decoder.layers[3].self_attn.q_proj
+        skip_layer = compressed_model.model.decoder.layers[3].self_attn.k_proj
+        assert (
+            hasattr(tmp_layer, "weight_scale")
+            and hasattr(tmp_layer, "weight_packed")
+            and tmp_layer.weight_scale.dtype is torch.uint8
+            and tmp_layer.weight_scale.shape[0] == 768
+        ), "Illegal MXFP4 packing name or data_type or shape"
+        assert not hasattr(skip_layer, "weight_scale") and not hasattr(  ## check skipped layers
+            skip_layer, "weight_packed"
+        ), "Illegal MXFP4 quantization for fp_layers"
+        quantization_config = AutoConfig.from_pretrained(
+            quantized_model_path, trust_remote_code=True
+        ).quantization_config
+        assert (
+            quantization_config["format"] == "float-quantized"
+            and quantization_config["config_groups"]["group_0"]["weights"]["is_mx"] is True
+            and quantization_config["config_groups"]["group_0"]["weights"]["num_bits"] == 4
+        ), f"Invalid MXFP4 quantization configuration: {quantization_config}"
+
+        shutil.rmtree("./saved", ignore_errors=True)
+
+    def test_rtn_mxfp4_llmcompressor_format(self):
+        model_name = "facebook/opt-125m"
+        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)
+        from transformers import AutoConfig
+
+        scheme = "MXFP4"
+        layer_config = {}
+        fp_layers_str = "k_proj"
+        from auto_round.utils import get_fp_layer_names
+
+        not_quantize_layer_names = get_fp_layer_names(model, fp_layers_str)
+        for name in not_quantize_layer_names:
+            layer_config[name] = {"bits": 16, "act_bits": 16, "data_type": "float"}
+        autoround = AutoRound(
+            model,
+            self.tokenizer,
+            scheme=scheme,
+            iters=0,
+            seqlen=2,
             layer_config=layer_config,
             dataset=self.llm_dataloader,
         )
@@ -322,19 +377,13 @@ def test_mxfp8_llmcompressor_format(self):
         model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)
         from transformers import AutoConfig

-        bits = 8
-        data_type = "mx_fp_rceil"
-        group_size = 32
-        sym = True
+        scheme = "MXFP8"
         autoround = AutoRound(
             model,
             self.tokenizer,
-            bits=bits,
-            group_size=group_size,
-            sym=sym,
+            scheme=scheme,
             iters=2,
             seqlen=2,
-            data_type=data_type,
             dataset=self.llm_dataloader,
         )
         quantized_model_path = self.save_dir
@@ -355,28 +404,23 @@ def test_mxfp8_llmcompressor_format(self):
             and quantization_config["config_groups"]["group_0"]["weights"]["is_mx"] is True
             and quantization_config["config_groups"]["group_0"]["weights"]["num_bits"] == 8
         ), f"Invalid MXFP8 quantization configuration: {quantization_config}"
+        folder_size_gb = _get_folder_size(quantized_model_path)
+        # Original opt-125m is < 0.5GB -> quantized mxfp8 model should be smaller but not empty
+        assert (
+            0.15 < folder_size_gb < 0.2
+        ), f"Quantized model folder size {folder_size_gb:.2f} GB is outside the expected range (0.1~0.2 GB)"
         shutil.rmtree("./saved", ignore_errors=True)

     def test_nvfp4_llmcompressor_format(self):
         model_name = "facebook/opt-125m"
         model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)
         from transformers import AutoConfig

-        bits = 4
-        act_bits = 4
-        data_type = "nv_fp"
-        act_data_type = "nv_fp4_with_static_gs"
-        group_size = 16
-        sym = True
+        scheme = "NVFP4"
         autoround = AutoRound(
             model,
             self.tokenizer,
-            bits=bits,
-            act_bits=act_bits,
-            data_type=data_type,
-            act_data_type=act_data_type,
-            group_size=group_size,
-            sym=sym,
+            scheme=scheme,
             iters=2,
             seqlen=2,
             dataset=self.llm_dataloader,
@@ -399,28 +443,23 @@ def test_nvfp4_llmcompressor_format(self):
             quantization_config["format"] == "nvfp4-pack-quantized"
             and quantization_config["config_groups"]["group_0"]["input_activations"]["num_bits"] == 4
         ), f"Invalid NVFP4 quantization configuration: {quantization_config}"
+        folder_size_gb = _get_folder_size(quantized_model_path)
+        # Original opt-125m is < 0.5GB -> quantized nvfp4 model should be smaller but not empty
+        assert (
+            0.1 < folder_size_gb < 0.15
+        ), f"Quantized model folder size {folder_size_gb:.2f} GB is outside the expected range (0.1~0.15 GB)"
         shutil.rmtree("./saved", ignore_errors=True)

     def test_nvfp4_autoround_format(self):
         model_name = "facebook/opt-125m"
         model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)
         from transformers import AutoConfig

-        bits = 4
-        act_bits = 4
-        data_type = "nv_fp"
-        act_data_type = "nv_fp4_with_static_gs"
-        group_size = 16
-        sym = True
+        scheme = "NVFP4"
         autoround = AutoRound(
             model,
             self.tokenizer,
-            bits=bits,
-            act_bits=act_bits,
-            data_type=data_type,
-            act_data_type=act_data_type,
-            group_size=group_size,
-            sym=sym,
+            scheme="NVFP4",
             iters=2,
             seqlen=2,
             dataset=self.llm_dataloader,
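
As a sanity check on the folder-size bounds asserted in the new tests, a rough back-of-envelope lands inside both ranges. The numbers below are assumptions, not values from the repo: roughly 85M quantizable linear-layer parameters and roughly 40M embedding parameters kept in 16-bit for opt-125m, one scale byte per 32-element group, and safetensors overhead ignored.

    # Back-of-envelope only; the 85M/40M parameter split and scale layout are assumptions.
    linear_params = 85e6   # opt-125m transformer linear weights (assumed quantized)
    embed_params = 40e6    # token + positional embeddings (assumed kept in bf16)
    group_size = 32        # mx block size: one shared scale per 32 values

    def estimated_size_gb(bits_per_weight: float) -> float:
        packed = linear_params * bits_per_weight / 8   # packed quantized weights
        scales = linear_params / group_size            # ~1 byte per group scale
        kept_bf16 = embed_params * 2                   # 2 bytes per bf16 parameter
        return (packed + scales + kept_bf16) / (1024**3)

    print(f"MXFP8 estimate: {estimated_size_gb(8):.2f} GB")        # ~0.16 GB, inside (0.15, 0.2)
    print(f"MXFP4/NVFP4 estimate: {estimated_size_gb(4):.2f} GB")  # ~0.12 GB, inside (0.1, 0.15)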
