Skip to content

Commit 3b0f98a

Browse files
committed
move customized copy function to example utils
Signed-off-by: Zhiyu Cheng <[email protected]>
1 parent 39ccfad commit 3b0f98a

File tree

2 files changed

+60
-53
lines changed

2 files changed

+60
-53
lines changed

examples/llm_ptq/example_utils.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
# limitations under the License.
1515

1616
import os
17+
import shutil
1718
import sys
1819
import warnings
20+
from pathlib import Path
1921
from typing import Any
2022

2123
import torch
@@ -263,3 +265,53 @@ def apply_kv_cache_quant(quant_cfg: dict[str, Any], kv_cache_quant_cfg: dict[str
263265
quant_cfg["algorithm"] = "max"
264266

265267
return quant_cfg
268+
269+
270+
def copy_custom_model_files(source_path: str, export_path: str, trust_remote_code: bool = False):
271+
"""Copy custom model files (configuration_*.py, modeling_*.py, etc.) from source to export directory.
272+
273+
Args:
274+
source_path: Path to the original model directory
275+
export_path: Path to the exported model directory
276+
trust_remote_code: Whether trust_remote_code was used (only copy files if True)
277+
"""
278+
if not trust_remote_code:
279+
return
280+
281+
source_dir = Path(source_path)
282+
export_dir = Path(export_path)
283+
284+
if not source_dir.exists():
285+
print(f"Warning: Source directory {source_path} does not exist")
286+
return
287+
288+
if not export_dir.exists():
289+
print(f"Warning: Export directory {export_path} does not exist")
290+
return
291+
292+
# Common patterns for custom model files that need to be copied
293+
custom_file_patterns = [
294+
"configuration_*.py",
295+
"modeling_*.py",
296+
"tokenization_*.py",
297+
"processing_*.py",
298+
"image_processing_*.py",
299+
"feature_extraction_*.py",
300+
]
301+
302+
copied_files = []
303+
for pattern in custom_file_patterns:
304+
for file_path in source_dir.glob(pattern):
305+
if file_path.is_file():
306+
dest_path = export_dir / file_path.name
307+
try:
308+
shutil.copy2(file_path, dest_path)
309+
copied_files.append(file_path.name)
310+
print(f"Copied custom model file: {file_path.name}")
311+
except Exception as e:
312+
print(f"Warning: Failed to copy {file_path.name}: {e}")
313+
314+
if copied_files:
315+
print(f"Successfully copied {len(copied_files)} custom model files to {export_path}")
316+
else:
317+
print("No custom model files found to copy")

examples/llm_ptq/hf_ptq.py

Lines changed: 8 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,21 @@
1616
import argparse
1717
import copy
1818
import random
19-
import shutil
2019
import time
2120
import warnings
22-
from pathlib import Path
2321
from typing import Any
2422

2523
import numpy as np
2624
import torch
2725
from accelerate.hooks import remove_hook_from_module
28-
from example_utils import apply_kv_cache_quant, get_model, get_processor, get_tokenizer, is_enc_dec
26+
from example_utils import (
27+
apply_kv_cache_quant,
28+
copy_custom_model_files,
29+
get_model,
30+
get_processor,
31+
get_tokenizer,
32+
is_enc_dec,
33+
)
2934
from transformers import (
3035
AutoConfig,
3136
AutoModelForCausalLM,
@@ -85,56 +90,6 @@
8590
mto.enable_huggingface_checkpointing()
8691

8792

88-
def copy_custom_model_files(source_path: str, export_path: str, trust_remote_code: bool = False):
89-
"""Copy custom model files (configuration_*.py, modeling_*.py, etc.) from source to export directory.
90-
91-
Args:
92-
source_path: Path to the original model directory
93-
export_path: Path to the exported model directory
94-
trust_remote_code: Whether trust_remote_code was used (only copy files if True)
95-
"""
96-
if not trust_remote_code:
97-
return
98-
99-
source_dir = Path(source_path)
100-
export_dir = Path(export_path)
101-
102-
if not source_dir.exists():
103-
print(f"Warning: Source directory {source_path} does not exist")
104-
return
105-
106-
if not export_dir.exists():
107-
print(f"Warning: Export directory {export_path} does not exist")
108-
return
109-
110-
# Common patterns for custom model files that need to be copied
111-
custom_file_patterns = [
112-
"configuration_*.py",
113-
"modeling_*.py",
114-
"tokenization_*.py",
115-
"processing_*.py",
116-
"image_processing_*.py",
117-
"feature_extraction_*.py",
118-
]
119-
120-
copied_files = []
121-
for pattern in custom_file_patterns:
122-
for file_path in source_dir.glob(pattern):
123-
if file_path.is_file():
124-
dest_path = export_dir / file_path.name
125-
try:
126-
shutil.copy2(file_path, dest_path)
127-
copied_files.append(file_path.name)
128-
print(f"Copied custom model file: {file_path.name}")
129-
except Exception as e:
130-
print(f"Warning: Failed to copy {file_path.name}: {e}")
131-
132-
if copied_files:
133-
print(f"Successfully copied {len(copied_files)} custom model files to {export_path}")
134-
else:
135-
print("No custom model files found to copy")
136-
137-
13893
def auto_quantize(
13994
model, qformat, auto_quantize_bits, calib_dataloader, calibrate_loop, batch_size=1
14095
):

0 commit comments

Comments
 (0)