|
25 | 25 | from accelerate.hooks import remove_hook_from_module
|
26 | 26 | from example_utils import apply_kv_cache_quant, get_model, get_processor, get_tokenizer, is_enc_dec
|
27 | 27 | from transformers import (
|
| 28 | + AutoConfig, |
28 | 29 | AutoModelForCausalLM,
|
| 30 | + AutoProcessor, |
29 | 31 | PreTrainedTokenizer,
|
30 | 32 | PreTrainedTokenizerFast,
|
31 | 33 | WhisperProcessor,
|
|
39 | 41 | export_tensorrt_llm_checkpoint,
|
40 | 42 | get_model_type,
|
41 | 43 | )
|
| 44 | +from modelopt.torch.export.model_utils import is_multimodal_model |
42 | 45 | from modelopt.torch.quantization.config import need_calibration
|
43 | 46 | from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights
|
44 | 47 | from modelopt.torch.quantization.utils import is_quantized
|
@@ -567,19 +570,26 @@ def output_decode(generated_ids, input_shape):
|
567 | 570 |
|
568 | 571 | export_path = args.export_path
|
569 | 572 |
|
570 |
| - if hasattr(full_model, "language_model"): |
571 |
| - # Save original model config and the preprocessor config to the export path for VLMs. |
572 |
| - from transformers import AutoConfig, AutoProcessor |
| 573 | + # Check if the model is a multimodal/VLM model |
| 574 | + is_vlm = is_multimodal_model(full_model) |
573 | 575 |
|
574 |
| - print(f"Saving original model and processor configs to {export_path}") |
| 576 | + if is_vlm: |
| 577 | + # Save original model config and the processor config to the export path for VLMs. |
| 578 | + print(f"Saving original model config to {export_path}") |
575 | 579 |
|
576 | 580 | AutoConfig.from_pretrained(
|
577 | 581 | args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code
|
578 | 582 | ).save_pretrained(export_path)
|
579 | 583 |
|
580 |
| - AutoProcessor.from_pretrained( |
581 |
| - args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code |
582 |
| - ).save_pretrained(export_path) |
| 584 | + # Try to save processor config if available |
| 585 | + try: |
| 586 | + print(f"Saving processor config to {export_path}") |
| 587 | + AutoProcessor.from_pretrained( |
| 588 | + args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code |
| 589 | + ).save_pretrained(export_path) |
| 590 | + except Exception as e: |
| 591 | + print(f"Warning: Could not save processor config: {e}") |
| 592 | + print("This is normal for some VLM architectures that don't use AutoProcessor") |
583 | 593 |
|
584 | 594 | if model_type == "mllama":
|
585 | 595 | full_model_config = model.config
|
|
0 commit comments