@@ -30,12 +30,19 @@
 from accelerate import Accelerator
 from example_utils import build_quant_cfg, get_tokenizer
 from tqdm import tqdm
-from transformers import AutoModelForCausalLM, PreTrainedTokenizer, PreTrainedTokenizerFast
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoProcessor,
+    PreTrainedTokenizer,
+    PreTrainedTokenizerFast,
+)

 import modelopt.torch.opt as mto
 import modelopt.torch.quantization as mtq
 from modelopt.torch.export import get_model_type
 from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format
+from modelopt.torch.export.model_utils import is_multimodal_model
 from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint
 from modelopt.torch.quantization.config import need_calibration
 from modelopt.torch.quantization.utils import patch_fsdp_mp_dtypes
@@ -243,6 +250,28 @@ def export_model(
     export_dir = Path(export_path)
     export_dir.mkdir(parents=True, exist_ok=True)

+    # Check whether the model is a multimodal (VLM) model.
+    is_vlm = is_multimodal_model(model)
+
+    if is_vlm:
+        # For VLMs, save the original model config and the processor config to the export path.
+        print(f"Saving original model config to {export_path}")
+
+        config_kwargs = {"trust_remote_code": args.trust_remote_code}
+        if args.attn_implementation is not None:
+            config_kwargs["attn_implementation"] = args.attn_implementation
+        AutoConfig.from_pretrained(args.pyt_ckpt_path, **config_kwargs).save_pretrained(export_path)
+
+        # Try to save the processor config if one is available.
+        try:
+            print(f"Saving processor config to {export_path}")
+            AutoProcessor.from_pretrained(
+                args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code
+            ).save_pretrained(export_path)
+        except Exception as e:
+            print(f"Warning: Could not save processor config: {e}")
+            print("This is normal for some VLM architectures that don't use AutoProcessor.")
+
     post_state_dict, hf_quant_config = _export_hf_checkpoint(
         model, torch.bfloat16, accelerator=accelerator
     )
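For reference, is_multimodal_model comes from modelopt.torch.export.model_utils and its body is not part of this diff. A minimal sketch of the kind of heuristic such a helper could use, assuming VLMs are detectable by a nested vision_config or a vision-tower attribute on the Hugging Face model (an assumption for illustration, not the actual implementation):

def _is_multimodal_sketch(model) -> bool:
    # Hypothetical heuristic -- NOT the real modelopt implementation.
    config = getattr(model, "config", None)
    if config is None:
        return False
    # Composite VLM configs (e.g. LLaVA-style) expose a nested vision_config.
    if getattr(config, "vision_config", None) is not None:
        return True
    # Some architectures attach the vision encoder directly to the model.
    return any(hasattr(model, attr) for attr in ("vision_tower", "vision_model", "visual"))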
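One way to sanity-check the new VLM branch (an assumed workflow, not part of the PR) is to confirm that the saved config and processor round-trip from the export directory through the standard transformers loaders:

from transformers import AutoConfig, AutoProcessor

# export_dir and args.trust_remote_code mirror the names used in export_model above.
config = AutoConfig.from_pretrained(export_dir, trust_remote_code=args.trust_remote_code)
print(type(config).__name__)  # should match the original VLM config class

try:
    processor = AutoProcessor.from_pretrained(export_dir, trust_remote_code=args.trust_remote_code)
    print(type(processor).__name__)
except Exception:
    # Expected for architectures without an AutoProcessor mapping,
    # mirroring the try/except in export_model above.
    print("No processor config was exported for this checkpoint")

Because save_pretrained writes config.json and the processor files alongside the quantized weights, downstream tooling can load the export directory like any other Hugging Face checkpoint.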