2 changes: 2 additions & 0 deletions auto_round/export/export_to_awq/export.py
@@ -137,6 +137,7 @@ def wrapper(name):
return model

quantization_config = kwargs["serialization_dict"]
quantization_config.pop("regex_config")  # AWQ does not support saving mixed-bits configs
Contributor
If it's 16 bits, we could convert it to not_convert_module; I forget the exact name.
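One possible shape of that suggestion, sketched here rather than taken from the PR (regex_config's schema and the fp_layers helper are assumptions): layers left at 16 bits could be recorded under the AWQ config's modules_to_not_convert instead of being dropped together with regex_config.

    regex_config = quantization_config.pop("regex_config", None) or {}
    # Collect layers kept in 16-bit / full precision (assumed schema: {"bits": int, ...}).
    fp_layers = [name for name, cfg in regex_config.items() if cfg.get("bits", 4) >= 16]
    if fp_layers:
        # AWQ-style configs expose modules_to_not_convert for layers left unquantized.
        quantization_config.setdefault("modules_to_not_convert", []).extend(fp_layers)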
if output_dir is None:
return compressed_model
@@ -159,3 +160,4 @@ def wrapper(name):
save_model(compressed_model, output_dir, safe_serialization=safe_serialization, dtype=dtype)

return compressed_model

26 changes: 12 additions & 14 deletions test/test_cpu/test_mix_bits.py
@@ -8,10 +8,9 @@
sys.path.insert(0, "../..")
import torch
from transformers import AutoModelForCausalLM, AutoRoundConfig, AutoTokenizer

from auto_round import AutoRound
from auto_round.testing_utils import require_gptqmodel

from auto_round import AutoRound

def _get_folder_size(path: str) -> float:
"""Return folder size in GB."""
@@ -36,7 +35,7 @@ def __iter__(self):
class TestAutoRound(unittest.TestCase):
@classmethod
def setUpClass(self):
model_name = "facebook/opt-125m"
model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
self.save_dir = "./saved"
self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)
self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
@@ -46,11 +45,11 @@ def setUpClass(self):
def tearDownClass(self):
shutil.rmtree("./saved", ignore_errors=True)
shutil.rmtree("runs", ignore_errors=True)

@require_gptqmodel
def test_mixed_gptqmodel(self):
bits, sym, group_size = 4, True, 128
model_name = "facebook/opt-125m"
model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
layer_config = {
"k_proj": {"bits": 8},
"lm_head": {"bits": 16},
@@ -69,17 +68,16 @@ def setUpClass(self):
quantized_model_path = "./saved"
autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_gptq")
from gptqmodel import GPTQModel

model = GPTQModel.load(quantized_model_path)
assert model.model.model.decoder.layers[0].self_attn.k_proj.bits == 8
assert model.model.model.decoder.layers[0].self_attn.q_proj.bits == 4
result = model.generate("Uncovering deep insights begins with")[0] # tokens
assert "!!!" not in model.tokenizer.decode(result) # string output
assert (model.model.model.decoder.layers[0].self_attn.k_proj.bits == 8)
assert (model.model.model.decoder.layers[0].self_attn.q_proj.bits == 4)
result = model.generate("Uncovering deep insights begins with")[0] # tokens
assert("!!!" not in model.tokenizer.decode(result)) # string output
shutil.rmtree(quantized_model_path, ignore_errors=True)

def test_mixed_autoround_format(self):
bits, sym, group_size = 4, True, 128
model_name = "facebook/opt-125m"
model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
layer_config = {
"k_proj": {"bits": 8},
"q_proj": {"bits": 3},
@@ -99,14 +97,14 @@ def test_mixed_autoround_format(self):
quantized_model_path = "./saved"
autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round")
model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="cpu")
assert model.model.decoder.layers[0].self_attn.k_proj.bits == 8
assert model.model.decoder.layers[0].self_attn.q_proj.bits == 3
assert (model.model.decoder.layers[0].self_attn.k_proj.bits == 8)
assert (model.model.decoder.layers[0].self_attn.q_proj.bits == 3)
tokenizer = AutoTokenizer.from_pretrained(quantized_model_path)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))
shutil.rmtree(quantized_model_path, ignore_errors=True)


if __name__ == "__main__":
unittest.main()

45 changes: 19 additions & 26 deletions test/test_cuda/test_mix_bits.py
@@ -8,10 +8,9 @@
sys.path.insert(0, "../..")
import torch
from transformers import AutoModelForCausalLM, AutoRoundConfig, AutoTokenizer

from auto_round import AutoRound
from auto_round.testing_utils import require_gptqmodel

from auto_round import AutoRound

class LLMDataLoader:
def __init__(self):
@@ -25,7 +24,7 @@ def __iter__(self):
class TestAutoRound(unittest.TestCase):
@classmethod
def setUpClass(self):
model_name = "facebook/opt-125m"
model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
self.save_dir = "./saved"
self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)
self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
@@ -35,11 +34,11 @@ def setUpClass(self):
def tearDownClass(self):
shutil.rmtree("./saved", ignore_errors=True)
shutil.rmtree("runs", ignore_errors=True)

@require_gptqmodel
def test_mixed_gptqmodel(self):
bits, sym, group_size = 4, True, 128
model_name = "facebook/opt-125m"
model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
layer_config = {
"k_proj": {"bits": 8},
"lm_head": {"bits": 16},
@@ -58,17 +57,16 @@ def test_mixed_gptqmodel(self):
quantized_model_path = "./saved"
autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_gptq")
from gptqmodel import GPTQModel

model = GPTQModel.load(quantized_model_path)
assert model.model.model.decoder.layers[0].self_attn.k_proj.bits == 8
assert model.model.model.decoder.layers[0].self_attn.q_proj.bits == 4
result = model.generate("Uncovering deep insights begins with")[0] # tokens
assert "!!!" not in model.tokenizer.decode(result) # string output
assert (model.model.model.decoder.layers[0].self_attn.k_proj.bits == 8)
assert (model.model.model.decoder.layers[0].self_attn.q_proj.bits == 4)
Contributor
Test the regex, the full name, and the partial name, and make sure the model can run inference correctly.
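A rough sketch of that coverage, assuming layer_config accepts regex patterns alongside fully qualified and partial module names (the matching rules and the specific keys below are assumptions, not taken from this PR):

    layer_config = {
        r"model\.decoder\.layers\.[0-3]\.self_attn\.k_proj": {"bits": 8},  # regex pattern
        "model.decoder.layers.0.self_attn.q_proj": {"bits": 8},            # full name
        "fc1": {"bits": 8},                                                # partial name
    }
    # After quantize_and_save, reload the checkpoint and run a short generate()
    # call to confirm the mixed-bits model still produces sensible text.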
result = model.generate("Uncovering deep insights begins with")[0] # tokens
assert("!!!" not in model.tokenizer.decode(result)) # string output
shutil.rmtree(quantized_model_path, ignore_errors=True)

def test_mixed_autoround_format(self):
bits, sym, group_size = 4, True, 128
model_name = "facebook/opt-125m"
model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
layer_config = {
"k_proj": {"bits": 8},
"q_proj": {"bits": 3},
@@ -88,8 +86,8 @@ def test_mixed_autoround_format(self):
quantized_model_path = "./saved"
autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round")
model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="cpu")
assert model.model.decoder.layers[0].self_attn.k_proj.bits == 8
assert model.model.decoder.layers[0].self_attn.q_proj.bits == 3
assert (model.model.decoder.layers[0].self_attn.k_proj.bits == 8)
assert (model.model.decoder.layers[0].self_attn.q_proj.bits == 3)
tokenizer = AutoTokenizer.from_pretrained(quantized_model_path)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
@@ -115,7 +113,6 @@ def test_mixed_autoround_format_vllm(self):
autoround.save_quantized(output_dir=quantized_model_path, inplace=False, format="auto_round")

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
"The capital of France is",
@@ -124,7 +121,7 @@ def test_mixed_autoround_format_vllm(self):
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM.
QUANTIZATION = "auto-round" # quantized_model_path
QUANTIZATION = "auto-round" #quantized_model_path
llm = LLM(model=quantized_model_path, quantization=QUANTIZATION, trust_remote_code=True, tensor_parallel_size=1)
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
@@ -136,16 +133,13 @@ def test_mixed_autoround_format_vllm(self):
print(f"{prompt}: {generated_text}")
shutil.rmtree(quantized_model_path, ignore_errors=True)


def test_mixed_llmcompressor_format_vllm(self):
model_name = "facebook/opt-125m"
model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
layer_config = {
"self_attn": {"bits": 16, "act_bits": 16, "data_type": "float"},
"lm_head": {"bits": 16, "act_bits": 16, "data_type": "float"},
"fc1": {
"bits": 16,
"act_bits": 16,
"data_type": "float",
},
"fc1": {"bits": 16, "act_bits": 16, "data_type": "float", },
}
autoround = AutoRound(
model_name,
@@ -156,11 +150,8 @@ def test_mixed_llmcompressor_format_vllm(self):
layer_config=layer_config,
)
quantized_model_path = self.save_dir
compressed, _ = autoround.quantize_and_save(
output_dir=quantized_model_path, inplace=False, format="llm_compressor"
)
compressed,_ = autoround.quantize_and_save(output_dir=quantized_model_path, inplace=False, format="llm_compressor")
from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
"The capital of France is",
@@ -169,7 +160,7 @@ def test_mixed_llmcompressor_format_vllm(self):
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM.
QUANTIZATION = "auto-round" # quantized_model_path
QUANTIZATION = "auto-round" #quantized_model_path
llm = LLM(model=quantized_model_path, trust_remote_code=True, tensor_parallel_size=1)
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
Expand All @@ -181,5 +172,7 @@ def test_mixed_llmcompressor_format_vllm(self):
shutil.rmtree(quantized_model_path, ignore_errors=True)



if __name__ == "__main__":
unittest.main()
Contributor
@wenhuach21 wenhuach21 Sep 24, 2025
For the AutoRound format, please make sure inference is ready first, then support it on the exporting side.
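One way to sanity-check that inference is ready before extending the export path would be a small smoke test against the saved mixed-bits checkpoint; the path and prompt below are placeholders, not part of this PR:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Reload the mixed-bits checkpoint produced by the tests above and generate a few tokens.
    model = AutoModelForCausalLM.from_pretrained("./saved", device_map="cpu")
    tokenizer = AutoTokenizer.from_pretrained("./saved")
    inputs = tokenizer("There is a girl who likes adventure,", return_tensors="pt")
    print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0]))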