Commit da089e9

enable model python files saving (#802)
1 parent ab55a97 commit da089e9

Showing 10 changed files with 97 additions and 6 deletions.

auto_round/autoround.py

Lines changed: 7 additions & 1 deletion
@@ -58,6 +58,7 @@
     convert_dtype_str2torch,
     convert_fp8_layer_to_linear,
     convert_fp8_model_to_16b_model,
+    copy_python_files_from_model_cache,
     detect_device,
     estimate_tuning_block_mem,
     find_matching_blocks,
@@ -850,7 +851,8 @@ def remove_duplicates(lst):
         elif format == "llm_compressor":
             from auto_round.export.export_to_llmcompressor import check_compressed_tensors_supported
 
-            if check_compressed_tensors_supported() and (is_nv_fp(self.data_type) or is_mx_fp(self.data_type)):
+            if is_nv_fp(self.data_type) or is_mx_fp(self.data_type):
+                check_compressed_tensors_supported()
                 format = format.replace("llm_compressor", f"llm_compressor:{self.data_type}")
                 formats[index] = format
         elif not is_wfp8afp8(self):
@@ -3036,6 +3038,10 @@ def save_quantized(
             processor = kwargs.get("processor", None)
             if processor is not None:
                 processor.save_pretrained(output_dir)
+            try:
+                copy_python_files_from_model_cache(self.model, output_dir)
+            except Exception as e:
+                logger.warning("Skipping source model Python file copy due to error: %s", e)
             return
         if self.act_bits <= 8 and format == "qdq":
             logger.warning(
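
For context, a minimal end-to-end sketch of what this enables, assuming the standard AutoRound quantization flow and a hypothetical remote-code checkpoint name: after quantization, save_quantized now also tries to copy the source checkpoint's Python files next to the exported weights, and merely warns if the copy fails.

# a sketch, not part of the diff: quantize and save a trust_remote_code model
from transformers import AutoModelForCausalLM, AutoTokenizer

from auto_round import AutoRound

model_name = "some-org/remote-code-model"  # hypothetical placeholder
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

autoround = AutoRound(model, tokenizer, bits=4, group_size=128)
autoround.quantize()
# custom modeling_*.py / tokenization_*.py files are copied into the output dir
autoround.save_quantized("./tmp_autoround", format="auto_round")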

auto_round/export/export_to_autogptq/export.py

Lines changed: 6 additions & 0 deletions
@@ -50,6 +50,7 @@
 from auto_round.utils import (
     SUPPORTED_LAYER_TYPES,
     check_to_quantized,
+    copy_python_files_from_model_cache,
     filter_quantization_config,
     get_autogptq_packing_qlinear,
     get_block_names,
@@ -259,3 +260,8 @@ def save(
     if hasattr(model, "config") and hasattr(model.config, "quantization_config"):
         with open(os.path.join(save_dir, config_file), "w", encoding="utf-8") as f:
             json.dump(model.config.quantization_config, f, indent=2)
+
+    try:
+        copy_python_files_from_model_cache(model, save_dir)
+    except Exception as e:
+        logger.warning("Skipping source model Python file copy due to error: %s", e)

auto_round/export/export_to_autoround/export.py

Lines changed: 6 additions & 0 deletions
@@ -31,6 +31,7 @@
     SUPPORTED_LAYER_TYPES,
     check_start_with_block_name,
     check_to_quantized,
+    copy_python_files_from_model_cache,
     filter_quantization_config,
     get_autogptq_packing_qlinear,
     get_module,
@@ -399,3 +400,8 @@ def save(model: nn.Module, save_dir: str, max_shard_size: str = "5GB", safe_seri
     if hasattr(model, "config") and hasattr(model.config, "quantization_config"):
         with open(os.path.join(save_dir, config_file), "w", encoding="utf-8") as f:
             json.dump(model.config.quantization_config, f, indent=2)
+
+    try:
+        copy_python_files_from_model_cache(model, save_dir)
+    except Exception as e:
+        logger.warning("Skipping source model Python file copy due to error: %s", e)

auto_round/export/export_to_autoround/export_to_fp8.py

Lines changed: 6 additions & 0 deletions
@@ -29,6 +29,7 @@
     _get_packing_device,
     check_start_with_block_name,
     check_to_quantized,
+    copy_python_files_from_model_cache,
     filter_quantization_config,
     get_module,
     logger,
@@ -270,3 +271,8 @@ def save(
     if hasattr(model, "config") and hasattr(model.config, "quantization_config"):
         with open(os.path.join(save_dir, config_file), "w", encoding="utf-8") as f:
             json.dump(model.config.quantization_config, f, indent=2)
+
+    try:
+        copy_python_files_from_model_cache(model, save_dir)
+    except Exception as e:
+        logger.warning("Skipping source model Python file copy due to error: %s", e)

auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py

Lines changed: 6 additions & 0 deletions
@@ -30,6 +30,7 @@
     _get_packing_device,
     check_start_with_block_name,
     check_to_quantized,
+    copy_python_files_from_model_cache,
     filter_quantization_config,
     get_module,
     is_mx_fp,
@@ -282,3 +283,8 @@ def save(model: nn.Module, save_dir: str, max_shard_size: str = "5GB", safe_seri
     if hasattr(model, "config") and hasattr(model.config, "quantization_config"):
         with open(os.path.join(save_dir, config_file), "w", encoding="utf-8") as f:
             json.dump(model.config.quantization_config, f, indent=2)
+
+    try:
+        copy_python_files_from_model_cache(model, save_dir)
+    except Exception as e:
+        logger.warning("Skipping source model Python file copy due to error: %s", e)

auto_round/export/export_to_awq/export.py

Lines changed: 6 additions & 0 deletions
@@ -34,6 +34,7 @@
 from auto_round.utils import (
     SUPPORTED_LAYER_TYPES,
     check_to_quantized,
+    copy_python_files_from_model_cache,
     extract_block_names_to_str,
     filter_quantization_config,
     get_block_names,
@@ -197,3 +198,8 @@ def save(model: nn.Module, save_dir: str, max_shard_size: str = "5GB", safe_seri
     if hasattr(model, "config") and hasattr(model.config, "quantization_config"):
         with open(os.path.join(save_dir, config_file), "w", encoding="utf-8") as f:
             json.dump(model.config.quantization_config, f, indent=2)
+
+    try:
+        copy_python_files_from_model_cache(model, save_dir)
+    except Exception as e:
+        logger.warning("Skipping source model Python file copy due to error: %s", e)

auto_round/export/export_to_llmcompressor/config.py

Lines changed: 2 additions & 2 deletions
@@ -76,10 +76,10 @@ def check_compressed_tensors_supported():  # pragma: no cover
 
         return True
     except ImportError:
-        logger.warning(
+        logger.error(
             "Please install compressed-tensors via 'pip install compressed-tensors'" " to save as llm-compressor format"
         )
-        return False
+        exit(-1)
 
 
 if check_compressed_tensors_supported():
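
A behavior sketch of the fail-fast path, based on the diff above: the helper no longer returns False for callers to branch on, which is why the caller in autoround.py now invokes it purely for its side effect.

# sketch: with compressed-tensors missing, importing the config module (which
# runs the check at import time) or calling the check directly logs an error
# and terminates via exit(-1) instead of returning False
from auto_round.export.export_to_llmcompressor.config import check_compressed_tensors_supported

check_compressed_tensors_supported()  # True if importable; otherwise the process exits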

auto_round/export/export_to_llmcompressor/export.py

Lines changed: 15 additions & 1 deletion
@@ -17,7 +17,16 @@
 import torch
 
 from auto_round.export.export_to_llmcompressor.config import quantization_config
-from auto_round.utils import detect_device, get_module, is_mx_fp, is_nv_fp, is_standard_fp, logger, set_module
+from auto_round.utils import (
+    copy_python_files_from_model_cache,
+    detect_device,
+    get_module,
+    is_mx_fp,
+    is_nv_fp,
+    is_standard_fp,
+    logger,
+    set_module,
+)
 from auto_round.wrapper import WrapperWALayer
 
 from .export_to_fp import save_quantized_as_fp
@@ -111,3 +120,8 @@ def save_quantized_as_llmcompressor(output_dir, **kwargs):
     if hasattr(model, "generation_config"):
         setattr(model.generation_config, "do_sample", True)
     model.save_pretrained(output_dir, safe_serialization=safe_serialization)
+
+    try:
+        copy_python_files_from_model_cache(model, output_dir)
+    except Exception as e:
+        logger.warning("Skipping source model Python file copy due to error: %s", e)

auto_round/export/export_to_llmcompressor/export_to_fp.py

Lines changed: 6 additions & 0 deletions
@@ -29,6 +29,7 @@
     SUPPORTED_LAYER_TYPES,
     check_start_with_block_name,
     check_to_quantized,
+    copy_python_files_from_model_cache,
     filter_quantization_config,
     get_block_names,
     get_module,
@@ -274,3 +275,8 @@ def save(model: nn.Module, save_dir: str, max_shard_size: str = "5GB", safe_seri
     if hasattr(model, "config") and hasattr(model.config, "quantization_config"):
         with open(os.path.join(save_dir, config_file), "w", encoding="utf-8") as f:
             json.dump(model.config.quantization_config, f, indent=2)
+
+    try:
+        copy_python_files_from_model_cache(model, save_dir)
+    except Exception as e:
+        logger.warning("Skipping source model Python file copy due to error: %s", e)

auto_round/utils.py

Lines changed: 37 additions & 2 deletions
@@ -2550,8 +2550,11 @@ def is_nv_fp(backend):
 
 
 def is_wfp8afp8(ar):
-    if ("fp8" in ar.act_data_type or ("fp" in ar.act_data_type and ar.act_bits == 8)) and (
-        "fp8" in ar.data_type or ("fp" in ar.data_type and ar.bits == 8)
+    if (
+        ("fp8" in ar.act_data_type or ("fp" in ar.act_data_type and ar.act_bits == 8))
+        and ("fp8" in ar.data_type or ("fp" in ar.data_type and ar.bits == 8))
+        and is_standard_fp(ar.act_data_type)
+        and is_standard_fp(ar.data_type)
     ):
         return True
     else:
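
The new is_standard_fp guards matter for the MX/NV float formats, whose data-type strings would otherwise satisfy the plain "fp"/bits == 8 test. An illustrative check, assuming "fp8" and "mx_fp" as representative data-type strings and using stand-in objects that carry only the four attributes is_wfp8afp8 reads:

from types import SimpleNamespace

from auto_round.utils import is_wfp8afp8

# hypothetical stand-ins for an AutoRound instance
plain_fp8 = SimpleNamespace(data_type="fp8", bits=8, act_data_type="fp8", act_bits=8)
mx_fp8 = SimpleNamespace(data_type="mx_fp", bits=8, act_data_type="mx_fp", act_bits=8)

print(is_wfp8afp8(plain_fp8))  # True: standard fp8 weights and activations
print(is_wfp8afp8(mx_fp8))     # False: is_standard_fp now excludes mx_fp types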
@@ -2677,3 +2680,35 @@ def _get_packing_device(device: str | torch.device | None = "auto") -> torch.dev
             raise ValueError(f"Invalid device string: {device}") from e
 
     raise TypeError(f"Unsupported device type: {type(device)} ({device})")
+
+
+# Adapted from https://github.com/vllm-project/llm-compressor/blob/
+# 5b3ddff74cae9651f24bef15d3255c4ee053fc60/src/llmcompressor/pytorch/model_load/helpers.py#L144
+def copy_python_files_from_model_cache(model, save_path: str):
+    config = model.config
+    cache_path = None
+    if hasattr(config, "_name_or_path"):
+        import os
+        import shutil
+
+        from huggingface_hub import hf_hub_download
+        from transformers import TRANSFORMERS_CACHE
+        from transformers.utils import http_user_agent
+
+        cache_path = config._name_or_path
+        if not os.path.exists(cache_path):
+            user_agent = http_user_agent()
+            config_file_path = hf_hub_download(
+                repo_id=cache_path,
+                filename="config.json",
+                cache_dir=TRANSFORMERS_CACHE,
+                force_download=False,
+                user_agent=user_agent,
+            )
+            cache_path = os.path.sep.join(config_file_path.split(os.path.sep)[:-1])
+
+        for file in os.listdir(cache_path):
+            full_file_name = os.path.join(cache_path, file)
+            if file.endswith(".py") and os.path.isfile(full_file_name):
+                logger.debug(f"Transferring {full_file_name} to {save_path}")
+                shutil.copy(full_file_name, save_path)
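
The helper resolves the source checkpoint directory from config._name_or_path, falling back to locating the Hugging Face cache entry via hf_hub_download on config.json, then copies every top-level *.py file. A minimal usage sketch with a placeholder repo id:

from transformers import AutoModelForCausalLM

from auto_round.utils import copy_python_files_from_model_cache

# hypothetical placeholder; any checkpoint shipping custom modeling/tokenization
# .py files (trust_remote_code) behaves the same way
model = AutoModelForCausalLM.from_pretrained("some-org/remote-code-model", trust_remote_code=True)
copy_python_files_from_model_cache(model, "./quantized-model")  # copies the cached *.py files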
