Skip to content

Commit 4a96f35

Browse files
Fix bugs (#1311)
1 parent 65ea69d commit 4a96f35

File tree

7 files changed

+92
-16
lines changed

7 files changed

+92
-16
lines changed

swift/llm/utils/media.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import os
22
import shutil
3-
from typing import Any, Dict, List, Literal, Optional, Union
3+
from typing import Any, Dict, Literal, Optional, Union
44

55
import numpy as np
6+
from modelscope.hub.utils.utils import get_cache_dir
67

7-
from swift.hub.utils.utils import get_cache_dir
88
from swift.utils import get_logger
99

1010
logger = get_logger()
@@ -125,10 +125,24 @@ def get_url(media_type):
125125
return f'{MediaCache.URL_PREFIX}{media_type}.{extension}'
126126

127127
@staticmethod
def download(media_type_or_url: str, local_alias: Optional[str] = None):
    """Download and extract an archive, returning the local directory.

    Args:
        media_type_or_url: Either one of the media types listed in the
            class field `media_type_urls`, or a remote URL to download and
            extract. In both cases the target must be a zip or tar file.
        local_alias: Local alias name for `media_type_or_url`. It may be
            left as None when the first argument is a media type listed in
            this class; when a URL is passed, supply a name for it. The
            extracted files are placed under {cache_dir}/{local_alias}.

    Returns:
        The local directory that contains the extracted files.
    """
    # Imported at call time, exactly as the original does.
    from swift.utils import FileLockContext, safe_ddp_context

    # Enter the DDP guard first, then the file lock keyed on the
    # media type / URL, before delegating to the real download routine.
    with safe_ddp_context(), FileLockContext(media_type_or_url):
        return MediaCache._safe_download(media_type=media_type_or_url, media_name=local_alias)
132146

133147
@staticmethod
134148
def _safe_download(media_type, media_name=None):

swift/llm/utils/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import transformers
1616
from modelscope import (AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,
1717
GenerationConfig, GPTQConfig, snapshot_download)
18+
from modelscope.hub.utils.utils import get_cache_dir
1819
from packaging import version
1920
from torch import Tensor
2021
from torch import dtype as Dtype
@@ -25,7 +26,6 @@
2526
from transformers.utils.versions import require_version
2627

2728
from swift import get_logger
28-
from swift.hub.utils.utils import get_cache_dir
2929
from swift.utils import get_dist_setting, safe_ddp_context, subprocess_run, use_torchacc
3030
from .template import TemplateType
3131
from .utils import get_max_model_len, is_unsloth_available

swift/llm/utils/preprocess.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,7 @@ def preprocess(self, d: Dict[str, Any]) -> Dict[str, Any]:
158158
medias = self.parse_medias(d)
159159
self.media_replacer(row, medias)
160160
if self.media_type:
161-
if not isinstance(self.media_key, str):
162-
row[self.media_name] = medias
161+
row[self.media_name] = medias
163162
return row
164163

165164
def __call__(self, dataset: HfDataset) -> HfDataset:
@@ -248,8 +247,7 @@ def preprocess(self, d: Dict[str, Any]) -> Dict[str, Any]:
248247
medias = self.parse_medias(d)
249248
self.media_replacer(row, medias)
250249
if self.media_type:
251-
if not isinstance(self.media_key, str):
252-
row[self.media_name] = medias
250+
row[self.media_name] = medias
253251
return row
254252
except (AssertionError, SyntaxError):
255253
if self.error_strategy == 'raise':
@@ -303,8 +301,7 @@ def preprocess(self, d: Dict[str, Any]) -> Dict[str, Any]:
303301
medias = self.parse_medias(d)
304302
self.media_replacer(row, medias)
305303
if self.media_type:
306-
if not isinstance(self.media_key, str):
307-
row[self.media_name] = medias
304+
row[self.media_name] = medias
308305
except Exception:
309306
if self.error_strategy == 'raise':
310307
raise ValueError(f'conversations: {conversations}')

swift/tuners/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
import numpy as np
1515
import torch
1616
from modelscope import snapshot_download
17+
from modelscope.hub.utils.utils import get_cache_dir
1718
from packaging import version
1819
from peft.utils import CONFIG_NAME
1920
from peft.utils import ModulesToSaveWrapper as _ModulesToSaveWrapper
2021
from peft.utils import _get_submodules
2122

22-
from swift.hub.utils.utils import get_cache_dir
2323
from swift.tuners.module_mapping import ModelKeys
2424
from swift.utils.constants import BIN_EXTENSIONS
2525
from swift.utils.logger import get_logger

swift/ui/llm_train/runtime.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,14 @@ def update_log(cls, task):
302302
ret.append(gr.update(visible=True, label=p['name']))
303303
return ret
304304

305+
@classmethod
def get_initial(cls, line):
    """Return the known progress prefix that `line` starts with, if any.

    Args:
        line: A single log line (str).

    Returns:
        The first of 'Train:', 'Map:', 'Val:' or 'Filter:' that `line`
        starts with, or None when the line starts with none of them.
    """
    # Presumably the labels of tqdm progress bars written to the log.
    known_prefixes = ('Train:', 'Map:', 'Val:', 'Filter:')
    # Prefixes are checked in declaration order; the first match wins.
    return next((prefix for prefix in known_prefixes if line.startswith(prefix)), None)
312+
305313
@classmethod
306314
def wait(cls, logging_dir, task):
307315
if not logging_dir:
@@ -334,6 +342,15 @@ def wait(cls, logging_dir, task):
334342
else:
335343
latest_data = ''
336344
lines.extend(latest_lines)
345+
start = cls.get_initial(lines[-1])
346+
if start:
347+
i = len(lines) - 2
348+
while i >= 0:
349+
if lines[i].startswith(start):
350+
del lines[i]
351+
i -= 1
352+
else:
353+
break
337354
yield ['\n'.join(lines)] + Runtime.plot(task)
338355
except IOError:
339356
pass

swift/utils/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@
1010
from .torch_utils import (activate_model_parameters, broadcast_string, freeze_model_parameters, get_dist_setting,
1111
get_model_info, is_ddp_plus_mp, is_dist, is_local_master, is_master, is_mp, is_on_same_device,
1212
show_layers, time_synchronize, torchacc_trim_graph, use_torchacc)
13-
from .utils import (add_version_to_work_dir, check_json_format, get_pai_tensorboard_dir, is_pai_training_job,
14-
lower_bound, parse_args, read_multi_line, safe_ddp_context, seed_everything, subprocess_run,
15-
test_time, upper_bound)
13+
from .utils import (FileLockContext, add_version_to_work_dir, check_json_format, get_pai_tensorboard_dir,
14+
is_pai_training_job, lower_bound, parse_args, read_multi_line, safe_ddp_context, seed_everything,
15+
subprocess_run, test_time, upper_bound)

swift/utils/utils.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
22
import datetime as dt
3+
import fcntl
4+
import hashlib
35
import os
46
import random
57
import re
@@ -11,6 +13,7 @@
1113

1214
import numpy as np
1315
import torch.distributed as dist
16+
from modelscope.hub.utils.utils import get_cache_dir
1417
from transformers import HfArgumentParser, enable_full_determinism, set_seed
1518

1619
from .logger import get_logger
@@ -20,6 +23,51 @@
2023
logger = get_logger()
2124

2225

26+
class FileLockContext:
    """Exclusive lock based on ``fcntl.flock`` over a lock file.

    The lock file is stored in ``cache_dir`` and named after the MD5 hex
    digest of ``origin_symbol``, so every caller that passes the same symbol
    contends for the same file. Instances can be used as context managers:
    the lock is acquired on entry and released on exit.
    """

    # Directory that holds all lock files; resolved once at class creation.
    cache_dir = os.path.join(get_cache_dir(), 'lockers')

    def __init__(self, origin_symbol: str, timeout: int = 60 * 30):
        """Prepare (but do not acquire) the lock.

        Args:
            origin_symbol: String identifying the guarded resource; hashed to
                build the lock file name.
            timeout: Maximum number of seconds `acquire` keeps retrying. A
                falsy value (0 / None) means retry without a time limit.
        """
        self.origin_symbol = origin_symbol
        lock_name = hashlib.md5(origin_symbol.encode('utf-8')).hexdigest() + '.lock'
        self.file_path = os.path.join(FileLockContext.cache_dir, lock_name)
        self.file_handle = None
        self.timeout = timeout

    def acquire(self):
        """Acquire the lock, retrying once per second until it is available.

        Returns:
            True once the lock is held.

        Raises:
            IOError: If the lock could not be obtained within `timeout`
                seconds (only when `timeout` is truthy).
        """
        start_time = time.time()
        while True:
            handle = None
            try:
                os.makedirs(FileLockContext.cache_dir, exist_ok=True)
                # Append mode creates the file if needed without truncating it.
                handle = open(self.file_path, 'a')
                # Non-blocking request: a blocking flock() would wait forever
                # and the timeout below could never trigger.
                fcntl.flock(handle, fcntl.LOCK_EX | fcntl.LOCK_NB)
            except OSError as e:
                if handle is not None:
                    handle.close()
                if self.timeout and (time.time() - start_time) >= self.timeout:
                    raise IOError(f'Cannot acquire the file lock from {self.origin_symbol} '
                                  f'as the timeout reaches: {self.timeout} seconds') from e
                time.sleep(1)
            else:
                self.file_handle = handle
                return True

    def release(self):
        """Release the lock if it is held; otherwise do nothing."""
        if self.file_handle:
            try:
                fcntl.flock(self.file_handle, fcntl.LOCK_UN)
            finally:
                # Always drop the handle, even if the unlock call fails;
                # closing the file also gives up the lock.
                self.file_handle.close()
                self.file_handle = None

    def __enter__(self):
        self.acquire()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Returns None, so exceptions raised inside the `with` body propagate.
        self.release()
69+
70+
2371
@contextmanager
2472
def safe_ddp_context():
2573
if is_dist() and not is_local_master():

0 commit comments

Comments
 (0)