diff --git a/llmc/__main__.py b/llmc/__main__.py index 98db4ca26..ec60c1492 100755 --- a/llmc/__main__.py +++ b/llmc/__main__.py @@ -20,7 +20,8 @@ from llmc.models import * from llmc.utils import (check_config, deploy_all_modality, get_modality, mkdirs, print_important_package_version, seed_all, - update_autoawq_quant_config, update_vllm_quant_config) + update_autoawq_quant_config, + update_lightx2v_quant_config, update_vllm_quant_config) from llmc.utils.registry_factory import ALGO_REGISTRY, MODEL_REGISTRY @@ -158,6 +159,7 @@ def main(config): elif config.save.get('save_lightx2v', False): deploy_all_modality(blockwise_opts, 'lightx2v_quant') blockwise_opt.save_model(save_quant_path) + update_lightx2v_quant_config(save_quant_path) if 'opencompass' in config: assert config.save.get('save_trans', False) diff --git a/llmc/utils/__init__.py b/llmc/utils/__init__.py old mode 100644 new mode 100755 index cdd10a9e7..7f8a38411 --- a/llmc/utils/__init__.py +++ b/llmc/utils/__init__.py @@ -1,4 +1,5 @@ from .export_autoawq import update_autoawq_quant_config +from .export_lightx2v import update_lightx2v_quant_config from .export_vllm import update_vllm_quant_config from .utils import (check_config, copy_files, deploy_all_modality, get_modality, mkdirs, print_important_package_version, diff --git a/llmc/utils/export_lightx2v.py b/llmc/utils/export_lightx2v.py new file mode 100755 index 000000000..0c753ad89 --- /dev/null +++ b/llmc/utils/export_lightx2v.py @@ -0,0 +1,11 @@ +import json + + +def update_lightx2v_quant_config(save_quant_path): + + config_file = save_quant_path + '/config.json' + with open(config_file, 'r') as file: + config_lightx2v = json.load(file) + config_lightx2v['quant_method'] = 'advanced_ptq' + with open(config_file, 'w') as file: + json.dump(config_lightx2v, file, indent=4)