Commit c4a1479

Support mxfp nvfp lmhead quant (#1051)
* fp8 exporting bugfix
* refine exllama backend cuda UT
* add lm_head layer act_max hook, enable mxfp/nvfp lm_head export
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fixtypo
* fixtypo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix ut typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* refine logs, fix pack_layer for awq&gptq
* refine log, fix pack_layer for awq&gptq
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* add awq&gptq lm_head UT
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix local path

Signed-off-by: Zhang, Weiwei1 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 7907c32 commit c4a1479
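
For orientation: this commit lets lm_head be quantized and packed like any other layer (the early returns for "lm_head" in the various pack_layer functions are removed, and an act_max hook is now registered for layers outside the transformer blocks). A minimal usage sketch, mirroring the tests added in test/test_cpu/test_export.py; the model path and export format below are placeholders, not taken from this commit:

    from auto_round import AutoRound

    # Opt the lm_head layer into quantization via layer_config.
    autoround = AutoRound(
        model="facebook/opt-125m",  # placeholder model
        bits=4,
        group_size=128,
        sym=True,
        iters=2,
        seqlen=2,
        layer_config={"lm_head": {"bits": 4}},
    )
    # "auto_gptq"/"auto_awq" are exercised by the new tests; mxfp/nvfp targets
    # also export lm_head now that pack_layer no longer skips it.
    autoround.quantize_and_save(output_dir="./saved", format="auto_gptq")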

File tree: 8 files changed (+298, -90 lines)

auto_round/compressors/base.py

Lines changed: 91 additions & 52 deletions

@@ -509,7 +509,7 @@ def _set_device(self, device_map: Union[str, torch.device, int, dict]) -> None:
         if len(tmp_devices) > 1:
             logger.warning(
                 f"there are multiple device types in the device_map, "
-                f"please make sure they are correct,use the first device {tmp_devices[0]} as the core device "
+                f"please make sure they are correct,use the first device {tmp_devices[0]} as the core device."
             )

         self.device = tmp_devices[0]
@@ -526,7 +526,7 @@ def _parse_and_set(scheme, kwargs):
             if "bits" not in kwargs:
                 data_type = kwargs["data_type"]
                 raise KeyError(
-                    f"please set bits when setting data_type={data_type}, or using scheme as an alternative.."
+                    f"please set bits when setting data_type={data_type}, or using scheme as an alternative."
                 )
             bits = kwargs["bits"]
             scheme = f"gguf:q{bits}_k" if bits == 6 else f"gguf:q{bits}_k_s"
@@ -1469,8 +1469,8 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
             raise ValueError("Could not find any blocks. Check the model or quant_block_list.")

         all_first_block_names = [block[0] for block in all_blocks]
-        if self.act_bits < 16 and not self.act_dynamic:
-            layer_names = self._get_quantized_layer_names_outside_blocks()
+        layer_names = self._get_quantized_layer_names_outside_blocks()
+        if self.act_bits < 16 and (not self.act_dynamic or len(layer_names) > 0):
             if len(layer_names) > 0:
                 logger.warning(
                     "quantize layers outside blocks for static activation quantizaiton"
@@ -1783,6 +1783,21 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:

         for layer_name in copy.deepcopy(layer_names):
             if layer_name not in layer_inputs:
+                if self.act_bits < 16 and not self.act_dynamic:
+                    # Activation quantization requires collected inputs
+                    msg_prefix = (
+                        f"Activation max hook for layer '{layer_name}' is unavailable due to "
+                        f"insufficient collected inputs. "
+                    )
+                    if "fp8_e5m2" in self.act_data_type:
+                        logger.warning(msg_prefix + "Please notes that unit scale is used for this layer.")
+                    else:
+                        logger.warning(
+                            msg_prefix + "Static activation quantization is not supported or ineffective, "
+                            "Skipping quantization for this layer."
+                        )
+                        layer_names.remove(layer_name)
+                        continue
                 logger.info(f"using rtn to quantize {layer_name}")
                 from auto_round.data_type import QUANT_FUNC_WITH_DTYPE

@@ -1813,6 +1828,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
         q_layer_inputs = None
         enable_quanted_input = self.enable_quanted_input
         has_gguf = False
+
         if hasattr(self, "formats"):
             has_gguf = any("gguf" in format_ for format_ in self.formats)
         if has_gguf and self.immediate_packing:
@@ -2334,6 +2350,64 @@ def _replace_forward(self):
             hook_handle = m.register_forward_hook(hook_func)
             self.hook_handles.append(hook_handle)

+    def _register_act_max_hook(self, model):
+        def get_act_max_hook(module, input, output):
+            if isinstance(input, (tuple, list)):
+                input = input[0]
+            if input.numel() == 0:
+                return  # as no needs for act_max update
+            input, _, _ = reshape_pad_tensor_by_group_size(input, self.act_group_size)
+            act_max = torch.max(torch.abs(input), dim=-1).values
+            if not hasattr(module, "act_max") or module.act_max.numel() == 0:
+                module.act_max = act_max
+            else:
+                act_max = act_max.to(module.act_max.device)
+                if is_nv_fp(self.act_data_type):  ## for nvfp per-tensor input_global_scale calculation usage
+                    module.act_max = torch.max(
+                        torch.tensor([act_max.max(), module.act_max.max()], device=act_max.device)
+                    )
+                else:
+                    module.act_max = torch.max(act_max, module.act_max)
+
+        hook_handles = []
+        # for single layers out of blocks, like lm_head
+        if isinstance(model, SUPPORTED_LAYER_TYPES):
+            m = model
+            if (
+                hasattr(m, "act_dynamic")
+                and check_need_act_calibration(m.act_dynamic, m.act_data_type, m.act_bits)
+                and check_to_quantized(m)
+            ):
+                hook = m.register_forward_hook(get_act_max_hook)
+                hook_handles.append(hook)
+            return hook_handles
+
+        for n, m in model.named_modules():
+            if (
+                hasattr(m, "act_dynamic")
+                and check_need_act_calibration(m.act_dynamic, m.act_data_type, m.act_bits)
+                and check_to_quantized(m)
+            ):
+                hook = m.register_forward_hook(get_act_max_hook)
+                hook_handles.append(hook)
+                continue
+
+            # for whole model, RTN
+            if n in self.layer_config:
+                config = self.layer_config[n]
+                act_dynamic = config.get("act_dynamic", True)
+                act_data_type = config.get("act_data_type", None)
+                act_bits = config.get("act_bits", 16)
+                if (
+                    config["bits"] <= 8
+                    and check_need_act_calibration(act_dynamic, act_data_type, act_bits)
+                    and check_to_quantized(config)
+                ):
+                    hook = m.register_forward_hook(get_act_max_hook)
+                    hook_handles.append(hook)
+                    continue
+        return hook_handles
+
     def _quantize_layer(
         self, layer_name: str, inputs: torch.Tensor, q_inputs: torch.Tensor = None, device: str = "cpu"
     ):
@@ -2359,6 +2433,19 @@ def _quantize_layer(
             if q_inputs is not None:
                 q_inputs[i] = q_inputs[i].to(layer.weight.dtype)

+        if q_inputs is None:
+            hook_handles = self._register_act_max_hook(layer)
+            with torch.no_grad():
+                layer(torch.cat(inputs, dim=0))
+            for handle in hook_handles:
+                handle.remove()
+        else:
+            hook_handles = self._register_act_max_hook(layer)
+            if hook_handles:
+                layer(torch.cat(q_inputs, dim=0))
+            for handle in hook_handles:
+                handle.remove()
+
         wrapper_linear = WrapperLinear(
             layer,
             enable_minmax_tuning=self.enable_minmax_tuning,
@@ -2495,54 +2582,6 @@ def _quantize_layer(
         dump_info = f"quantized {layer_name}, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}"
         logger.info(dump_info)

-    def _register_act_max_hook(self, model):
-
-        def get_act_max_hook(module, input, output):
-            if isinstance(input, (tuple, list)):
-                input = input[0]
-            if input.numel() == 0:
-                return  # as no needs for act_max update
-            input, _, _ = reshape_pad_tensor_by_group_size(input, self.act_group_size)
-            act_max = torch.max(torch.abs(input), dim=-1).values
-            if not hasattr(module, "act_max") or module.act_max.numel() == 0:
-                module.act_max = act_max
-            else:
-                act_max = act_max.to(module.act_max.device)
-                if is_nv_fp(self.act_data_type):  ## for nvfp per-tensor input_global_scale calculation usage
-                    module.act_max = torch.max(
-                        torch.tensor([act_max.max(), module.act_max.max()], device=act_max.device)
-                    )
-                else:
-                    module.act_max = torch.max(act_max, module.act_max)
-
-        hook_handles = []
-
-        for n, m in model.named_modules():
-            if (
-                hasattr(m, "act_dynamic")
-                and check_need_act_calibration(m.act_dynamic, m.act_data_type, m.act_bits)
-                and check_to_quantized(m)
-            ):
-                hook = m.register_forward_hook(get_act_max_hook)
-                hook_handles.append(hook)
-                continue
-
-            # for whole model, RTN
-            if n in self.layer_config:
-                config = self.layer_config[n]
-                act_dynamic = config.get("act_dynamic", True)
-                act_data_type = config.get("act_data_type", None)
-                act_bits = config.get("act_bits", 16)
-                if (
-                    config["bits"] <= 8
-                    and check_need_act_calibration(act_dynamic, act_data_type, act_bits)
-                    and check_to_quantized(config)
-                ):
-                    hook = m.register_forward_hook(get_act_max_hook)
-                    hook_handles.append(hook)
-                    continue
-        return hook_handles
-
     def _get_current_output(self, output: list[torch.Tensor], indices: list[int]) -> torch.Tensor:
         current_output = [output[x] for x in indices]
         current_output = torch.cat(current_output, dim=self.batch_dim)
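
The relocated _register_act_max_hook above reuses the standard PyTorch forward-hook pattern to record per-layer activation maxima during calibration. A self-contained sketch of that pattern, independent of AutoRound's helpers (the function names and the toy layer are illustrative only, not from this commit):

    import torch
    import torch.nn as nn

    def make_act_max_hook():
        # Track a running absolute maximum of the layer's input activations.
        def hook(module, inputs, output):
            x = inputs[0] if isinstance(inputs, (tuple, list)) else inputs
            if x.numel() == 0:
                return
            amax = x.detach().abs().amax()
            prev = getattr(module, "act_max", None)
            module.act_max = amax if prev is None else torch.maximum(prev, amax)
        return hook

    layer = nn.Linear(16, 16)
    handle = layer.register_forward_hook(make_act_max_hook())
    layer(torch.randn(4, 16))  # calibration pass
    handle.remove()
    print(layer.act_max)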

auto_round/export/export_to_autogptq/export.py

Lines changed: 0 additions & 2 deletions

@@ -130,8 +130,6 @@ def convert_from_autogptq_dynamic(dynamic_config: dict) -> dict:


 def pack_layer(name, model, backend, device=None):
-    if name == "lm_head":  ##dese not support lm-head
-        return
     layer = get_module(model, name)

     if type(layer) not in SUPPORTED_LAYER_TYPES:  # already packed

auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py

Lines changed: 0 additions & 4 deletions

@@ -52,8 +52,6 @@


 def pack_layer(name, model, backend, device=None):
-    if name == "lm_head":  # TODO: Check vLLM inference status to determine whether to enable this feature
-        return
     layer = get_module(model, name)
     if type(layer) not in SUPPORTED_LAYER_TYPES and not isinstance(layer, WrapperWALayer):  ##already packed
         return
@@ -82,8 +80,6 @@ def pack_layer(name, model, backend, device=None):
         setattr(layer, "input_global_scale", input_global_scale)
         delattr(layer, "act_max")

-    # QuantLinear = get_fp_qlinear(backend, bits, group_size, sym)
-
     if type(layer) == nn.Linear:
         in_features = layer.in_features
         out_features = layer.out_features
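
For the nvfp path, the act_max collected by the new hook is what feeds the per-tensor input_global_scale set just above before act_max is deleted. The exact formula lives elsewhere in the repo; the widely used NVFP4 convention (an assumption here, not shown in this diff) scales so that the observed activation amax maps onto the largest representable FP8-scaled FP4 value:

    import torch

    # Assumed NVFP4-style convention: FP8-E4M3 max (448.0) * FP4-E2M1 max (6.0) / amax.
    def nvfp4_input_global_scale(act_max: torch.Tensor) -> torch.Tensor:
        return (448.0 * 6.0) / act_max.float().max()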

auto_round/export/export_to_awq/export.py

Lines changed: 0 additions & 2 deletions

@@ -46,8 +46,6 @@


 def pack_layer(name, model, backend, device=None):
-    if name == "lm_head":  ##dese not support lm-head
-        return
     layer = get_module(model, name)

     if type(layer) not in SUPPORTED_LAYER_TYPES:  ##already packed

auto_round/export/export_to_llmcompressor/export_to_fp.py

Lines changed: 0 additions & 2 deletions

@@ -51,8 +51,6 @@


 def pack_layer(name, model, backend, device=None):
-    if name == "lm_head":  # TODO: Check vLLM inference status to determine whether to enable this feature
-        return
     layer = get_module(model, name)
     if type(layer) not in SUPPORTED_LAYER_TYPES and not isinstance(layer, WrapperWALayer):  ##already packed
         return

test/test_cpu/test_export.py

Lines changed: 71 additions & 8 deletions

@@ -35,10 +35,10 @@ def __iter__(self):
 class TestAutoRound(unittest.TestCase):
     @classmethod
     def setUpClass(self):
-        model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
+        self.model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
         self.save_dir = "./saved"
-        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+        self.model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype="auto", trust_remote_code=True)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
         self.llm_dataloader = LLMDataLoader()

     @classmethod
@@ -49,7 +49,7 @@ def tearDownClass(self):
     def test_autogptq_format(self):
         for group_size in [-1, 32, 128]:
             bits, sym = 4, False
-            model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
+            model_name = self.model_name
             autoround = AutoRound(
                 model=model_name,
                 bits=bits,
@@ -79,7 +79,7 @@ def test_autogptq_format(self):
     def test_autoround_format(self):
         for group_size in [-1, 32, 128]:
             bits, sym = 4, True
-            model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
+            model_name = self.model_name
             autoround = AutoRound(
                 model=model_name,
                 bits=bits,
@@ -105,7 +105,7 @@ def test_autoround_format(self):
     def test_autoround_awq_format(self):
         for group_size in [-1, 32, 128]:
             bits, sym = 4, False
-            model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
+            model_name = self.model_name
             autoround = AutoRound(
                 model=model_name,
                 bits=bits,
@@ -217,7 +217,7 @@ def test_static_afp8_export(self, static_kv_dtype):

         from safetensors import safe_open

-        model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
+        model_name = self.model_name
         model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)
         autoround = AutoRound(
             model,
@@ -307,7 +307,7 @@ def test_static_fp8_attn(self):

         from safetensors import safe_open

-        model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
+        model_name = self.model_name
         model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)
         autoround = AutoRound(
             model,
@@ -334,6 +334,69 @@ def test_static_fp8_attn(self):

         shutil.rmtree(quantized_model_path, ignore_errors=True)

+    def test_awq_lmhead_export(self):
+        bits, sym, group_size = 4, False, 128
+        model_name = "/tf_dataset/auto_round/models/microsoft/phi-2"
+        layer_config = {
+            "lm_head": {"bits": 4},  # set lm_head quant
+        }
+        autoround = AutoRound(
+            model=model_name,
+            bits=bits,
+            group_size=group_size,
+            sym=sym,
+            iters=2,
+            seqlen=2,
+            layer_config=layer_config,
+            dataset=self.llm_dataloader,
+        )
+        quantized_model_path = "./saved"
+        compressed_model, _ = autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_awq")
+        lm_head = compressed_model.lm_head
+        from auto_round.export.export_to_awq.utils import WQLinear_GEMM
+
+        assert isinstance(lm_head, WQLinear_GEMM), "Illegal GPTQ quantization for lm_head layer"
+        quantization_config = AutoRoundConfig()
+        model = AutoModelForCausalLM.from_pretrained(
+            quantized_model_path, device_map="cpu", quantization_config=quantization_config
+        )
+        tokenizer = AutoTokenizer.from_pretrained(quantized_model_path)
+        text = "There is a girl who likes adventure,"
+        inputs = tokenizer(text, return_tensors="pt").to(model.device)
+        print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))
+        shutil.rmtree(quantized_model_path, ignore_errors=True)
+
+    def test_gptq_lmhead_export(self):
+        bits, sym, group_size = 4, True, 128
+        model_name = "/tf_dataset/auto_round/models/microsoft/phi-2"
+        layer_config = {
+            "lm_head": {"bits": 4},  # set lm_head quant
+        }
+        autoround = AutoRound(
+            model=model_name,
+            bits=bits,
+            group_size=group_size,
+            sym=sym,
+            iters=2,
+            seqlen=2,
+            layer_config=layer_config,
+            dataset=self.llm_dataloader,
+        )
+        quantized_model_path = "./saved"
+        compressed_model, _ = autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_gptq")
+        lm_head = compressed_model.lm_head
+        assert hasattr(lm_head, "bits") and lm_head.bits == 4, "Illegal GPTQ quantization for lm_head layer"
+        quantization_config = AutoRoundConfig()
+        model = AutoModelForCausalLM.from_pretrained(
+            quantized_model_path, device_map="cpu", quantization_config=quantization_config
+        )
+        tokenizer = AutoTokenizer.from_pretrained(quantized_model_path)
+        text = "There is a girl who likes adventure,"
+        inputs = tokenizer(text, return_tensors="pt").to(model.device)
+        res = tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0])
+        print(res)
+        shutil.rmtree(quantized_model_path, ignore_errors=True)
+

 if __name__ == "__main__":
     unittest.main()
