Skip to content

Commit 5329bcd

Browse files
authored
fix bf16 option in AutoScheme (#1079)
1 parent 727e573 commit 5329bcd

File tree

3 files changed

+22
-8
lines changed

3 files changed

+22
-8
lines changed

auto_round/__main__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,9 @@ def __init__(self, *args, **kwargs):
7878
help="The batch size for tuning/calibration."
7979
"Larger batch sizes may improve stability but require more memory.",
8080
)
81-
basic.add_argument("--avg_bits", default=None, type=float, help="for auto scheme, number of avg weight bits")
81+
basic.add_argument(
82+
"--avg_bits", "--target_bits", default=None, type=float, help="for auto scheme, number of avg weight bits"
83+
)
8284
basic.add_argument(
8385
"--options", default=None, type=str, help="for auto scheme, options for auto scheme, e.g. 'W4A16,W8A16'"
8486
)

auto_round/compressors/base.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -427,10 +427,6 @@ def _gen_auto_scheme(
427427

428428
all_dtypes = []
429429
for option in scheme.options:
430-
# Skip pure BF16 option
431-
if option == "BF16":
432-
continue
433-
434430
# Resolve the quantization scheme or data type
435431
dtype = "int"
436432
if isinstance(option, str):

test/test_cpu/test_autoround.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -713,14 +713,22 @@ def test_invalid_layer_config(self):
713713

714714
def test_quant_lm_head(self):
715715
model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-8B"
716-
ar = AutoRound(model_name, quant_lm_head=True, iters=0, disable_opt_rtn=True)
716+
ar = AutoRound(model_name, quant_lm_head=True, iters=0, seqlen=8, nsamples=1, disable_opt_rtn=True)
717717
ar.quantize_and_save(output_dir=self.save_folder, format="auto_round")
718718
model = AutoModelForCausalLM.from_pretrained(self.save_folder, device_map="cpu")
719719
assert "lm_head" in model.config.quantization_config.extra_config
720720
assert model.config.quantization_config.extra_config["lm_head"]["bits"] == 4
721721

722722
layer_config = {"lm_head": {"bits": 4}}
723-
ar = AutoRound(model_name, quant_lm_head=False, iters=0, disable_opt_rtn=True, layer_config=layer_config)
723+
ar = AutoRound(
724+
model_name,
725+
quant_lm_head=False,
726+
iters=0,
727+
seqlen=8,
728+
nsamples=1,
729+
disable_opt_rtn=True,
730+
layer_config=layer_config,
731+
)
724732
ar.quantize_and_save(output_dir=self.save_folder, format="auto_round")
725733
model = AutoModelForCausalLM.from_pretrained(self.save_folder, device_map="cpu")
726734
assert "lm_head" in model.config.quantization_config.extra_config
@@ -729,7 +737,15 @@ def test_quant_lm_head(self):
729737
def test_quant_lm_head_layer_config(self):
730738
model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-8B"
731739
layer_config = {"lm_head": {"bits": 4}}
732-
ar = AutoRound(model_name, quant_lm_head=True, iters=0, disable_opt_rtn=True, layer_config=layer_config)
740+
ar = AutoRound(
741+
model_name,
742+
quant_lm_head=True,
743+
iters=0,
744+
seqlen=8,
745+
nsamples=1,
746+
disable_opt_rtn=True,
747+
layer_config=layer_config,
748+
)
733749
ar.quantize_and_save(output_dir=self.save_folder, format="auto_round")
734750
model = AutoModelForCausalLM.from_pretrained(self.save_folder, device_map="cpu")
735751
assert "lm_head" in model.config.quantization_config.extra_config

0 commit comments

Comments (0)