@@ -7,11 +7,13 @@
 
 import torch
 from accelerate import init_empty_weights
+from loguru import logger
 from tqdm import tqdm
-from transformers import AutoConfig
+from transformers import AutoConfig, AutoProcessor
 
 from lmms_engine.mapping_func import create_model_from_pretrained
 from lmms_engine.merger.base import CheckpointMerger
+from lmms_engine.models import *
 
 CheckpointType = Literal["regular", "ema"]
 
@@ -159,12 +161,14 @@ def merge(
             ValueError: If checkpoint type directory is not found
         """
         # Resolve checkpoint path (handles parent directories with checkpoint-* subdirs)
+        original_checkpoint_path = checkpoint_path
         checkpoint_path = self._resolve_checkpoint_path(checkpoint_path)
 
         if output_path is None:
-            output_path = checkpoint_path
+            output_path = original_checkpoint_path
 
         shard_path = checkpoint_path / self._state_dict_dirname
+        logger.info(f"Selecting Checkpoint: {checkpoint_path} with state dict dirname: {self._state_dict_dirname}")
         if not shard_path.exists():
             raise ValueError(f"Checkpoint type '{self.checkpoint_type}' not found at {shard_path}")
 
@@ -182,7 +186,9 @@ def merge(
         with init_empty_weights():
             model = model_cls.from_config(config)
         model.load_state_dict(full_state_dict, assign=True)
-
+        processor = AutoProcessor.from_pretrained(checkpoint_path)
+        processor.save_pretrained(output_path)
+        config.save_pretrained(output_path)
         # Save merged checkpoint
         model.save_pretrained(output_path)
 
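With these changes, `merge()` logs the resolved checkpoint directory, defaults `output_path` to the directory the caller originally passed in (rather than the resolved `checkpoint-*` subdirectory), and saves the processor and config alongside the merged weights. Below is a minimal usage sketch; the class name `FSDPCheckpointMerger`, its import path, and the `checkpoint_type` constructor argument are assumptions, since only `merge(checkpoint_path, output_path)`, the `self.checkpoint_type` attribute, and the `"regular"`/`"ema"` checkpoint types appear in the diff:

```python
from pathlib import Path

# Hypothetical concrete subclass of lmms_engine.merger.base.CheckpointMerger;
# the real class name and import path may differ.
from lmms_engine.merger import FSDPCheckpointMerger

# Parent run directory; merge() resolves checkpoint-* subdirectories itself
# via _resolve_checkpoint_path().
run_dir = Path("outputs/my_run")

# checkpoint_type is Literal["regular", "ema"]; passing it through the
# constructor is an assumption about the subclass's __init__.
merger = FSDPCheckpointMerger(checkpoint_type="ema")

# With output_path=None, the merged model, processor, and config are written
# to the original run directory (the path the caller passed), not to the
# resolved checkpoint-* subdirectory, which is the behavioral change above.
merger.merge(checkpoint_path=run_dir, output_path=None)
```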