Skip to content

Commit de490d8

Browse files
committed
fix: Fix merger import error and missing processor/config saving
1 parent 5e458a5 commit de490d8

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

src/lmms_engine/merger/fsdp2.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@
77

88
import torch
99
from accelerate import init_empty_weights
10+
from loguru import logger
1011
from tqdm import tqdm
11-
from transformers import AutoConfig
12+
from transformers import AutoConfig, AutoProcessor
1213

1314
from lmms_engine.mapping_func import create_model_from_pretrained
1415
from lmms_engine.merger.base import CheckpointMerger
16+
from lmms_engine.models import *
1517

1618
CheckpointType = Literal["regular", "ema"]
1719

@@ -159,12 +161,14 @@ def merge(
159161
ValueError: If checkpoint type directory is not found
160162
"""
161163
# Resolve checkpoint path (handles parent directories with checkpoint-* subdirs)
164+
original_checkpoint_path = checkpoint_path
162165
checkpoint_path = self._resolve_checkpoint_path(checkpoint_path)
163166

164167
if output_path is None:
165-
output_path = checkpoint_path
168+
output_path = original_checkpoint_path
166169

167170
shard_path = checkpoint_path / self._state_dict_dirname
171+
logger.info(f"Selecting Checkpoint: {checkpoint_path} with state dict dirname: {self._state_dict_dirname}")
168172
if not shard_path.exists():
169173
raise ValueError(f"Checkpoint type '{self.checkpoint_type}' not found at {shard_path}")
170174

@@ -182,7 +186,9 @@ def merge(
182186
with init_empty_weights():
183187
model = model_cls.from_config(config)
184188
model.load_state_dict(full_state_dict, assign=True)
185-
189+
processor = AutoProcessor.from_pretrained(checkpoint_path)
190+
processor.save_pretrained(output_path)
191+
config.save_pretrained(output_path)
186192
# Save merged checkpoint
187193
model.save_pretrained(output_path)
188194

0 commit comments

Comments
 (0)