Commit e1b89d2

[High Risk] reduce vram usage for optimized RTN mode (#1043)
1 parent 90871a7 commit e1b89d2

12 files changed: 711 additions & 356 deletions

README.md

Lines changed: 6 additions & 5 deletions
@@ -189,14 +189,14 @@ ar.quantize_and_save(output_dir="./qmodel", format="auto_round")
 <summary>Important Hyperparameters</summary>
 
 ##### Quantization Scheme & Configuration
-- **`scheme` (str|dict|AutoScheme)**: The predefined quantization keys, e.g. `W4A16`, `MXFP4`, `NVFP4`, `GGUF:Q4_K_M`.
+- **`scheme` (str|dict|AutoScheme)**: The predefined quantization keys, e.g. `W4A16`, `MXFP4`, `NVFP4`, `GGUF:Q4_K_M`. For MXFP4/NVFP4, we recommend exporting to LLM-Compressor format.
 - **`bits` (int)**: Number of bits for quantization (default is `None`). If not None, it will override the scheme setting.
 - **`group_size` (int)**: Size of the quantization group (default is `None`). If not None, it will override the scheme setting.
 - **`sym` (bool)**: Whether to use symmetric quantization (default is `None`). If not None, it will override the scheme setting.
-- **`layer_config` (dict)**: Configuration for weight quantization (default is `None`), mainly for mixed schemes.
+- **`layer_config` (dict)**: Layer-wise scheme configuration (default is `None`), mainly for customized mixed schemes.
 
 ##### Algorithm Settings
-- **`enable_alg_ext` (bool)**: [Experimental Feature] Enable algorithm variants for specific schemes (e.g., MXFP4/W2A16) that could bring notable improvements. Default is `False`.
+- **`enable_alg_ext` (bool)**: [Experimental Feature] Only effective when `iters>0`. Enable algorithm variants for specific schemes (e.g., MXFP4/W2A16) that could bring notable improvements. Default is `False`.
 - **`disable_opt_rtn` (bool)**: Use pure RTN mode for specific schemes (e.g., GGUF and WOQ). Default is `False` (improved RTN enabled).
 
 ##### Tuning Process Parameters
@@ -217,7 +217,8 @@ ar.quantize_and_save(output_dir="./qmodel", format="auto_round")
 
 </details>
 
-### AutoScheme Usage
+### Adaptive Bits/Dtype Usage
+AutoScheme automatically generates mixed bits/data_type quantization recipes. For accuracy results, please refer to this [doc](https://github.com/intel/auto-round/blob/main/docs/auto_scheme_acc.md).
 Please refer to the [user guide](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#autoscheme) for more details on AutoScheme.
 ~~~python
 from auto_round import AutoRound, AutoScheme
@@ -299,7 +300,7 @@ for output in outputs:
 
 
 ### SGLang (Intel GPU/CUDA)
-Please note that support for the MoE models and visual language models is currently limited.
+**Please note that support for the MoE models and visual language models is currently limited.**
 
 ```python
 import sglang as sgl
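
For orientation, here is a minimal usage sketch that ties the hyperparameters documented in the first hunk above to the `quantize_and_save` call shown in its header. The model id and the `layer_config` override are hypothetical, and the exact constructor keywords should be verified against the full README.

```python
# Hedged sketch, not part of this diff: assumes AutoRound accepts the keyword
# arguments described above along with a model name or loaded model.
from auto_round import AutoRound

ar = AutoRound(
    "Qwen/Qwen2.5-0.5B-Instruct",            # hypothetical model id
    scheme="W4A16",                          # predefined scheme key
    layer_config={"lm_head": {"bits": 8}},   # hypothetical per-layer override for a mixed scheme
    enable_alg_ext=False,                    # experimental, only effective when iters > 0
    disable_opt_rtn=False,                   # keep the improved RTN mode this commit optimizes
)
ar.quantize_and_save(output_dir="./qmodel", format="auto_round")
```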

auto_round/compressors/base.py

Lines changed: 51 additions & 33 deletions
@@ -713,11 +713,10 @@ def _check_compatibility(self) -> None:
             raise ValueError("Gguf format is not compatible with other formats, please choose only one of them")
         if has_gguf and self.iters != 0 and self.bits != 3 and not self.enable_alg_ext:
             logger.warning(
-                "`iters=0` is recommended when exporting to GGUF format except for bits 3,"
-                " as we have optimized the RTN method for this case."
-                " Or add enable_alg_ext to use the new algorithm,"
-                " refer to https://github.com/intel/auto-round/tree/main/docs/gguf_alg_ext_acc.md"
-                " to check the acc."
+                "`iters=0` is recommended when exporting to the current GGUF format,"
+                " or add `enable_alg_ext` for better accuracy at a much higher tuning cost."
+                " Please refer to https://github.com/intel/auto-round/tree/main/docs/gguf_alg_ext_acc.md"
+                " for the accuracy results."
             )
 
         if (
@@ -1087,11 +1086,16 @@ def _quantize_embedding_layer(self):
                 dtype = f"rtn_{dtype}"
 
             quant_func = QUANT_FUNC_WITH_DTYPE[dtype]
+            dtype = module.weight.dtype
+            # RTN typically searches scale/zp in float32, so cast to float32 up front
+            # rather than caching an extra bf16 copy of the weight on the device
+            if config["super_group_size"] is not None:
+                dtype = torch.float32
 
             # Attempt quantization on GPU, fall back to CPU if OOM
             try:
                 weight, scale, zp = quant_func(
-                    module.weight.to(self.device),
+                    module.weight.to(dtype=dtype, device=self.device),
                     **{k: config[k] for k in ["bits", "group_size", "super_bits", "super_group_size", "scale_dtype"]},
                 )
             except torch.OutOfMemoryError:
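
The new `super_group_size` branch casts the embedding weight to float32 in the same `.to()` call that moves it to the device, since the RTN scale/zero-point search works in float32 anyway. An illustrative, standalone sketch of the memory argument (not code from this repo):

```python
# Illustrative sketch of the VRAM argument behind the hunk above.
import torch

w = torch.randn(4096, 4096, dtype=torch.bfloat16)  # weight resident on CPU

# Two-step path: the bf16 device copy stays allocated while the fp32 copy is built.
w_dev = w.to("cuda")                   # bf16 copy in VRAM
w_fp32 = w_dev.to(torch.float32)       # fp32 working copy; the bf16 copy is still referenced

# Fused path (what the hunk switches to): dtype and device change in one call,
# so only the fp32 working copy needs to live in VRAM.
w_fp32 = w.to(dtype=torch.float32, device="cuda")
```
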
@@ -1124,8 +1128,9 @@ def _quantize_embedding_layer(self):
 
             # Update config
             self.layer_config.setdefault(name, {}).update(config)
-
-            # Release memory
+            del weight
+            del scale
+            del zp
             clear_memory(device_list=self.device_list)
 
         return is_quantized
@@ -1224,7 +1229,7 @@ def get_imatrix_hook(module, input, output):
         for hook in hooks:
             hook.remove()
 
-    def _quantize_layer_via_rtn(self, name: str) -> None:
+    def _quantize_layer_via_rtn(self, name: str, dtype: torch.dtype = None, to_cpu=True) -> None:
         """Quantizes a layer using RTN (Round-To-Nearest) if available.
 
         This function attempts to quantize a layer by switching its data type to a
@@ -1241,19 +1246,20 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
             RuntimeError: If quantization fails for reasons unrelated to memory.
         """
         m = get_module(self.model, name)
+        if dtype is not None:
+            m = m.to(dtype)
 
         if is_fp8_linear(m):
             m = convert_fp8_layer_to_linear(m, self.amp_dtype, self.device)
             set_module(self.model, name, m)
-
+        tuning_device = m.tuning_device if hasattr(m, "tuning_device") else self.device
         # Step 1: Try quantization on GPU first, fall back to CPU if OOM
-        # if only export gguf, using gguf-packing instead of rtn
         if self.immediate_packing and self.iters == 0 and "gguf" in self.formats[0] and not self.disable_opt_rtn:
+            m = m.to(tuning_device)
             m.scale = None
             m.zp = None
         else:
             try:
-                tuning_device = m.tuning_device if hasattr(m, "tuning_device") else self.device
                 m = m.to(tuning_device)
                 m = WrapperLinear(
                     m,
@@ -1265,7 +1271,6 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
                     disable_opt_rtn=self.disable_opt_rtn,
                 )
                 m = m.unwrapper({})
-                m.to("cpu")
             except torch.OutOfMemoryError:
                 cuda_error_msg = traceback.format_exc()
                 m = m.orig_layer if hasattr(m, "orig_layer") else m
@@ -1285,18 +1290,23 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
                 raise
 
         # Step 2: Optional immediate packing/export
-        if self.immediate_packing:
+        if self.immediate_packing:  # For gguf, packing is conducted at block level
             self._immediate_pack(name)
+            if to_cpu:
+                m = m.to("cpu")
         else:
+            if to_cpu:
+                m = m.to("cpu")
             set_module(self.model, name, m)
-
         if self.immediate_saving:
             all_to_quantized_module_names = [n for n, m in self.model.named_modules() if check_to_quantized(m)]
             last_module = (len(all_to_quantized_module_names) == 0) or (name == all_to_quantized_module_names[-1])
             m = get_module(self.model, name)
             immediate_saving(self, m, name, last_module)
 
     def _immediate_pack(self, name: str):
+        if not self.immediate_packing:
+            return
         m = get_module(self.model, name)
         if not check_to_quantized(m):
             return
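
The hunks above extend `_quantize_layer_via_rtn` with an optional `dtype` override and a `to_cpu` flag so the block-wise caller can defer offloading. A simplified, self-contained sketch of that control flow is shown below; `rtn_quantize_layer` and `toy_rtn_` are hypothetical stand-ins for the real `WrapperLinear`-based path.

```python
# Simplified sketch of the per-layer flow, with toy stand-ins for the real helpers.
import torch

def toy_rtn_(weight: torch.Tensor, bits: int = 4) -> None:
    """Toy symmetric round-to-nearest fake-quant, in place (stands in for WrapperLinear)."""
    qmax = 2 ** (bits - 1) - 1
    scale = weight.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8) / qmax
    weight.copy_((weight / scale).round().clamp(-qmax - 1, qmax) * scale)

def rtn_quantize_layer(layer, device="cuda", dtype=None, to_cpu=True):
    if dtype is not None:
        layer = layer.to(dtype)    # e.g. float32 for super-group (GGUF k-quant) schemes
    try:
        layer = layer.to(device)   # try the accelerator first
        toy_rtn_(layer.weight.data)
    except torch.OutOfMemoryError:
        layer = layer.to("cpu")    # fall back to CPU on OOM
        toy_rtn_(layer.weight.data)
    if to_cpu:
        layer = layer.to("cpu")    # block-wise callers pass to_cpu=False and offload the whole block later
    return layer
```

Deferring the per-layer move to CPU avoids repeated host/device transfers when the enclosing block is offloaded in one go.
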
@@ -1353,16 +1363,18 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
             for module in tqdm(modules, desc="Update weight global scale for fuse module"):
                 update_fused_layer_global_scales(module)
 
-        has_gguf_k = (
-            any("gguf" in fmt and "k" in fmt for fmt in getattr(self, "formats", [])) or self.super_bits is not None
-        )
-
-        self._quantize_embedding_layer()
+        if not (any("gguf" in fmt for fmt in getattr(self, "formats", [])) or self.super_bits is not None):
+            self._quantize_embedding_layer()  # leave it to the gguf path itself to handle
 
         self.model.to("cpu")
+        # Release memory
+        clear_memory(device_list=self.device_list)
 
         enable_imatrix = False
         if not self.disable_opt_rtn:
+            has_gguf_k = (
+                any("gguf" in fmt and "k" in fmt for fmt in getattr(self, "formats", [])) or self.super_bits is not None
+            )
             if has_gguf_k:
                 enable_imatrix = True
             elif self.data_type == "int" and self.sym:
@@ -1498,39 +1510,44 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
                 self.device,
                 self.cache_device,
             )
+
             if len(self.device_list) > 1:
                 accelerate.hooks.remove_hook_from_submodules(block)
 
             if is_nv_fp(self.act_data_type) or is_static_wfp8afp8(self):
                 # enable moe experts act_max automatic generation for Linear
                 set_amax_for_all_moe_layers(block, attr_name="act_max")
             # Normalize imatrix and quantize layers
+            if self.low_gpu_mem_usage:
+                block.to("cpu")
+                clear_memory(device_list=self.device_list)
+
             for _, m in block.named_modules():
                 # fix issue: Ling-flash-2.0-q2_k_s fail infer on cuda but well on cpu
                 # https://huggingface.co/Intel/Ling-flash-2.0-gguf-q2ks-mixed-AutoRound/discussions/1
                 if hasattr(m, "imatrix"):
                     m.imatrix /= m.imatrix_cnt
                 if hasattr(m, "tmp_name") and m.tmp_name in all_to_quantized_module_names:
-                    self._quantize_layer_via_rtn(m.tmp_name)
+                    self._quantize_layer_via_rtn(m.tmp_name, to_cpu=False)
                     all_to_quantized_module_names.remove(m.tmp_name)
             if not self.immediate_saving:
                 mv_module_from_gpu(block)
+            if block_name == block_names[-1]:
+                clear_memory(input_ids, device_list=self.device_list)
+            else:
+                clear_memory(device_list=self.device_list)
+
             memory_monitor.log_summary()
             pbar.update(1)
 
         pbar.close()
-        cnt = 1
-        block_names_cnt = len(flatten_list(get_block_names(self.model, True)))
-        clear_mem_freq = len(all_to_quantized_module_names) // block_names_cnt
-        if clear_mem_freq == 0:
-            clear_mem_freq = 1
         # Process remaining layers not in blocks
         for name in all_to_quantized_module_names:
-            self._quantize_layer_via_rtn(name)
-            if cnt % clear_mem_freq == 0:
-                clear_memory(device_list=self.device_list)
-                cnt = 1
-            cnt += 1
+            dtype = None
+            if self.super_group_size is not None:
+                dtype = torch.float32
+            self._quantize_layer_via_rtn(name, dtype=dtype)
+            # clear_memory(device_list=self.device_list)
 
     def _update_inputs(self, inputs: dict, q_inputs: dict) -> tuple[dict, torch.Tensor]:
         keys = inputs.keys()
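
Taken together, this hunk moves memory hygiene from a per-layer counter to a per-block rhythm: quantize every layer of a block in place on the device, offload the block once, then clear the device cache. A schematic of that pattern follows (hypothetical helper names, heavily simplified; the real loop also handles imatrix normalization, hooks, and immediate saving):

```python
# Schematic of the block-wise low-VRAM flow; not the actual AutoRound implementation.
import torch

def quantize_blockwise(blocks, quantize_layer_rtn, device="cuda"):
    for block in blocks:
        block.to(device)                        # only one block occupies VRAM at a time
        for _, layer in block.named_modules():
            if isinstance(layer, torch.nn.Linear):
                quantize_layer_rtn(layer, device=device, to_cpu=False)  # keep layers on device
        block.to("cpu")                         # offload the finished block as a whole
        torch.cuda.empty_cache()                # stand-in for clear_memory(device_list=...)
```

Clearing once per block (and once more after the last block, when `input_ids` is also released) replaces the old `clear_mem_freq` counter that fired every few layers.
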
@@ -1631,6 +1648,7 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]:
         logger.info("start to cache block inputs")
         all_inputs = self.try_cache_inter_data_gpucpu(all_first_block_names, self.nsamples, layer_names=layer_names)
         is_quantized_embedding = self._quantize_embedding_layer()
+        clear_memory(device_list=self.device_list)
         all_q_inputs = None
         if is_quantized_embedding:
             all_inputs = copy.deepcopy(self.inputs)
@@ -2838,7 +2856,7 @@ def _quantize_block(
         if auto_offload:
             mv_module_from_gpu(block)
 
-        clear_memory(input_ids)
+        clear_memory(input_ids, device_list=self.device_list)
         memory_info_summary = memory_monitor.get_summary()
         logger.infoclean(dump_info + "," + memory_info_summary)
 
@@ -2848,7 +2866,7 @@ def _quantize_block(
             accelerate.hooks.remove_hook_from_submodules(block)
         if auto_offload:
             mv_module_from_gpu(block)
-        clear_memory(input_ids)
+        clear_memory(input_ids, device_list=self.device_list)
         memory_info_summary = memory_monitor.get_summary()
         logger.infoclean(dump_info + "," + memory_info_summary)
 