26
26
27
27
import torch
28
28
import torch .nn as nn
29
+ from safetensors .torch import save_file
29
30
30
31
from modelopt .torch .quantization import set_quantizer_by_cfg_context
31
32
from modelopt .torch .quantization .nn import SequentialQuantizer , TensorQuantizer
53
54
QUANTIZATION_W4A8_AWQ ,
54
55
QUANTIZATION_W4A8_NVFP4_FP8 ,
55
56
)
56
- from .plugins import rename_and_prune_if_spec_decoding , set_config_if_spec_decoding
57
+ from .plugins import export_spec_ckpt_config , export_spec_ckpt_state_dict , spec_opt_only
57
58
from .quant_utils import (
58
59
fuse_prequant_layernorm ,
59
60
get_activation_scaling_factor ,
@@ -507,18 +508,24 @@ def export_hf_checkpoint(
507
508
"""
508
509
export_dir = Path (export_dir )
509
510
export_dir .mkdir (parents = True , exist_ok = True )
511
+
512
+ # Early exit for speculative decoding models
513
+ # We do this since some spec models get error in convert_hf_quant_config_format
514
+ if spec_opt_only (model ):
515
+ save_file (export_spec_ckpt_state_dict (model ), f"{ export_dir } /model.safetensors" )
516
+ with open (f"{ export_dir } /config.json" , "w" ) as file :
517
+ json .dump (export_spec_ckpt_config (model ), file , indent = 4 )
518
+ return
519
+
510
520
try :
511
521
post_state_dict , hf_quant_config = _export_hf_checkpoint (model , dtype )
512
522
513
- # NOTE: (hg) Should we save hf_quant_config when there's no quantization applied?
514
523
# Save hf_quant_config.json for backward compatibility
515
524
with open (f"{ export_dir } /hf_quant_config.json" , "w" ) as file :
516
525
json .dump (hf_quant_config , file , indent = 4 )
517
526
518
527
hf_quant_config = convert_hf_quant_config_format (hf_quant_config )
519
528
520
- post_state_dict = rename_and_prune_if_spec_decoding (model , post_state_dict )
521
-
522
529
# Save model
523
530
model .save_pretrained (
524
531
export_dir , state_dict = post_state_dict , save_modelopt_state = save_modelopt_state
@@ -532,8 +539,6 @@ def export_hf_checkpoint(
532
539
533
540
config_data ["quantization_config" ] = hf_quant_config
534
541
535
- config_data = set_config_if_spec_decoding (model , config_data )
536
-
537
542
with open (original_config , "w" ) as file :
538
543
json .dump (config_data , file , indent = 4 )
539
544
0 commit comments