Skip to content

Commit 9279844

Browse files
committed
[megatron] fix multimodal pp (#5857)
1 parent 3c01327 commit 9279844

File tree

1 file changed

+17
-17
lines changed

1 file changed

+17
-17
lines changed

swift/llm/model/utils.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -288,23 +288,23 @@ def safe_snapshot_download(model_id_or_path: str,
288288
hub = get_hub(use_hf)
289289
if model_id_or_path.startswith('~'):
290290
model_id_or_path = os.path.abspath(os.path.expanduser(model_id_or_path))
291-
with safe_ddp_context(hash_id=model_id_or_path):
292-
model_path_to_check = '/'.join(model_id_or_path.split(':', 1))
293-
if os.path.exists(model_id_or_path):
294-
model_dir = model_id_or_path
295-
sub_folder = None
296-
elif os.path.exists(model_path_to_check):
297-
model_dir = model_path_to_check
298-
sub_folder = None
299-
else:
300-
if model_id_or_path.startswith('/'): # startswith
301-
raise ValueError(f"path: '{model_id_or_path}' not found")
302-
model_id_or_path = model_id_or_path.split(':', 1) # get sub_folder
303-
if len(model_id_or_path) == 1:
304-
model_id_or_path = [model_id_or_path[0], None]
305-
model_id_or_path, sub_folder = model_id_or_path
306-
if sub_folder is not None:
307-
kwargs['allow_patterns'] = [f"{sub_folder.rstrip('/')}/*"]
291+
model_path_to_check = '/'.join(model_id_or_path.split(':', 1))
292+
if os.path.exists(model_id_or_path):
293+
model_dir = model_id_or_path
294+
sub_folder = None
295+
elif os.path.exists(model_path_to_check):
296+
model_dir = model_path_to_check
297+
sub_folder = None
298+
else:
299+
if model_id_or_path.startswith('/'): # startswith
300+
raise ValueError(f"path: '{model_id_or_path}' not found")
301+
model_id_or_path = model_id_or_path.split(':', 1) # get sub_folder
302+
if len(model_id_or_path) == 1:
303+
model_id_or_path = [model_id_or_path[0], None]
304+
model_id_or_path, sub_folder = model_id_or_path
305+
if sub_folder is not None:
306+
kwargs['allow_patterns'] = [f"{sub_folder.rstrip('/')}/*"]
307+
with safe_ddp_context(hash_id=model_id_or_path):
308308
model_dir = hub.download_model(model_id_or_path, revision, ignore_patterns, token=hub_token, **kwargs)
309309

310310
logger.info(f'Loading the model using model_dir: {model_dir}')

0 commit comments

Comments
 (0)