Skip to content

Commit 6c963d8

Browse files
Fix FSDP; Add training percentage to jsonl logging; Add a web-ui component (#1381)
1 parent 35dac29 commit 6c963d8

File tree

10 files changed

+76
-56
lines changed

10 files changed

+76
-56
lines changed

swift/llm/sft.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,8 @@
1818
plot_images, preprocess_logits_for_metrics, seed_everything, show_layers, use_torchacc)
1919
from .accelerator import ta_accelerate
2020
from .tuner import prepare_model
21-
from .utils import (TEMPLATE_MAPPING, LazyLLMDataset, SftArguments, Template, dataset_map, get_dataset,
22-
get_model_tokenizer, get_template, get_time_info, print_example, set_generation_config,
23-
sort_by_max_length, stat_dataset)
21+
from .utils import (LazyLLMDataset, SftArguments, Template, dataset_map, get_dataset, get_model_tokenizer, get_template,
22+
get_time_info, print_example, set_generation_config, sort_by_max_length, stat_dataset)
2423

2524
logger = get_logger()
2625

@@ -42,7 +41,7 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
4241
torch.cuda.set_per_process_memory_fraction(max(min(args.gpu_memory_fraction, 1.0), 0.01), device=device_id)
4342

4443
# Loading Model and Tokenizer
45-
if is_deepspeed_zero3_enabled():
44+
if is_deepspeed_zero3_enabled() or os.environ.get('ACCELERATE_USE_FSDP', 'False') == 'true':
4645
model_kwargs = {'device_map': None}
4746
elif is_torch_npu_available():
4847
model_kwargs = {'device_map': local_rank if local_rank >= 0 else 0}

swift/trainers/callback.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from transformers.trainer_utils import IntervalStrategy, has_length, speed_metrics
99

1010
from swift.utils import append_to_jsonl, is_pai_training_job, use_torchacc
11+
from ..utils.utils import format_time
1112
from .arguments import TrainingArguments
1213

1314

@@ -17,6 +18,7 @@ def on_train_begin(self, args, state, control, **kwargs):
1718
if state.is_local_process_zero:
1819
self.training_bar = tqdm(desc='Train', total=state.max_steps, dynamic_ncols=True)
1920
self.current_step = 0
21+
self.start_time = time.time()
2022
if use_torchacc():
2123
self.warmup_start_time = 0
2224
self.warmup_metric = None
@@ -33,7 +35,14 @@ def on_prediction_step(self, args, state: TrainerState, control, eval_dataloader
3335
self.prediction_bar.update()
3436

3537
def on_log(self, args: TrainingArguments, state: TrainerState, control, logs=None, **kwargs):
36-
logs['global_step'] = state.global_step
38+
logs['steps[global_step/max_steps]'] = f'{state.global_step}/{state.max_steps}'
39+
train_percentage = state.global_step / state.max_steps if state.max_steps else 0.
40+
logs['percentage'] = f'{train_percentage * 100:.2f}%'
41+
elapsed = time.time() - self.start_time
42+
elapsed = max(0., elapsed)
43+
logs['elapsed_time'] = format_time(elapsed)
44+
logs['remaining_time'] = format_time(elapsed / train_percentage - elapsed)
45+
3746
if use_torchacc():
3847
if state.global_step >= self.metric_warmup_step and self.warmup_start_time == 0:
3948
self.warmup_start_time = time.time()

swift/ui/base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import os
22
import typing
33
from dataclasses import fields
4-
from functools import partial, wraps
4+
from functools import wraps
55
from typing import Any, Dict, List, OrderedDict, Type
66

7-
from gradio import Accordion, Button, Checkbox, Dropdown, Slider, Tab, TabItem, Textbox
7+
from gradio import Accordion, Audio, Button, Checkbox, Dropdown, File, Image, Slider, Tab, TabItem, Textbox, Video
88

99
from swift.llm.utils.model import MODEL_MAPPING, ModelType
1010

@@ -69,6 +69,10 @@ def wrapper(*args, **kwargs):
6969
TabItem.__init__ = update_data(TabItem.__init__)
7070
Accordion.__init__ = update_data(Accordion.__init__)
7171
Button.__init__ = update_data(Button.__init__)
72+
File.__init__ = update_data(File.__init__)
73+
Image.__init__ = update_data(Image.__init__)
74+
Video.__init__ = update_data(Video.__init__)
75+
Audio.__init__ = update_data(Audio.__init__)
7276

7377

7478
class BaseUI:

swift/ui/llm_eval/llm_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def eval(cls, *args):
126126
elif isinstance(value, str) and re.fullmatch(cls.bool_regex, value):
127127
value = True if value.lower() == 'true' else False
128128
kwargs[key] = value if not isinstance(value, list) else ' '.join(value)
129-
kwargs_is_list[key] = isinstance(value, list)
129+
kwargs_is_list[key] = isinstance(value, list) or getattr(cls.element(key), 'is_list', False)
130130
else:
131131
other_kwargs[key] = value
132132
if key == 'more_params' and value:

swift/ui/llm_export/llm_export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def export(cls, *args):
124124
elif isinstance(value, str) and re.fullmatch(cls.bool_regex, value):
125125
value = True if value.lower() == 'true' else False
126126
kwargs[key] = value if not isinstance(value, list) else ' '.join(value)
127-
kwargs_is_list[key] = isinstance(value, list)
127+
kwargs_is_list[key] = isinstance(value, list) or getattr(cls.element(key), 'is_list', False)
128128
else:
129129
other_kwargs[key] = value
130130
if key == 'more_params' and value:

swift/ui/llm_infer/llm_infer.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import re
33
import sys
44
import time
5-
from copy import copy
65
from datetime import datetime
76
from functools import partial
87
from typing import Type
@@ -14,7 +13,7 @@
1413
from modelscope import GenerationConfig, snapshot_download
1514

1615
from swift.llm import (TEMPLATE_MAPPING, DeployArguments, InferArguments, XRequestConfig, inference_client,
17-
inference_stream, limit_history_length, prepare_model_template)
16+
inference_stream, prepare_model_template)
1817
from swift.ui.base import BaseUI
1918
from swift.ui.llm_infer.model import Model
2019
from swift.ui.llm_infer.runtime import Runtime
@@ -69,6 +68,16 @@ class LLMInfer(BaseUI):
6968
'en': 'Chat bot'
7069
},
7170
},
71+
'infer_model_type': {
72+
'label': {
73+
'zh': 'Lora模块',
74+
'en': 'Lora module'
75+
},
76+
'info': {
77+
'zh': '发送给server端哪个LoRA,默认为`default-lora`',
78+
'en': 'Which LoRA to use on server, default value is `default-lora`'
79+
}
80+
},
7281
'prompt': {
7382
'label': {
7483
'zh': '请输入:',
@@ -116,12 +125,14 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
116125
history = gr.State([])
117126
Model.build_ui(base_tab)
118127
Runtime.build_ui(base_tab)
119-
gr.Dropdown(
120-
elem_id='gpu_id',
121-
multiselect=True,
122-
choices=[str(i) for i in range(gpu_count)] + ['cpu'],
123-
value=default_device,
124-
scale=8)
128+
with gr.Row():
129+
gr.Dropdown(
130+
elem_id='gpu_id',
131+
multiselect=True,
132+
choices=[str(i) for i in range(gpu_count)] + ['cpu'],
133+
value=default_device,
134+
scale=8)
135+
infer_model_type = gr.Textbox(elem_id='infer_model_type', scale=4)
125136
chatbot = gr.Chatbot(elem_id='chatbot', elem_classes='control-height')
126137
with gr.Row():
127138
prompt = gr.Textbox(elem_id='prompt', lines=1, interactive=True)
@@ -172,7 +183,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
172183
cls.send_message,
173184
inputs=[
174185
cls.element('running_tasks'), model_and_template,
175-
cls.element('template_type'), prompt, image, history,
186+
cls.element('template_type'), prompt, image, history, infer_model_type,
176187
cls.element('system'),
177188
cls.element('max_new_tokens'),
178189
cls.element('temperature'),
@@ -217,7 +228,7 @@ def deploy(cls, *args):
217228
elif isinstance(value, str) and re.fullmatch(cls.bool_regex, value):
218229
value = True if value.lower() == 'true' else False
219230
kwargs[key] = value if not isinstance(value, list) else ' '.join(value)
220-
kwargs_is_list[key] = isinstance(value, list)
231+
kwargs_is_list[key] = isinstance(value, list) or getattr(cls.element(key), 'is_list', False)
221232
else:
222233
other_kwargs[key] = value
223234
if key == 'more_params' and value:
@@ -374,8 +385,8 @@ def agent_type(cls, response):
374385
return None
375386

376387
@classmethod
377-
def send_message(cls, running_task, model_and_template, template_type, prompt: str, image, history, system,
378-
max_new_tokens, temperature, top_k, top_p, repetition_penalty):
388+
def send_message(cls, running_task, model_and_template, template_type, prompt: str, image, history,
389+
infer_model_type, system, max_new_tokens, temperature, top_k, top_p, repetition_penalty):
379390
if not model_and_template:
380391
gr.Warning(cls.locale('generate_alert', cls.lang)['value'])
381392
return '', None, None, []
@@ -393,7 +404,7 @@ def send_message(cls, running_task, model_and_template, template_type, prompt: s
393404
_, args = Runtime.parse_info_from_cmdline(running_task)
394405
model_type, template, sft_type = model_and_template
395406
if sft_type in ('lora', 'longlora') and not args.get('merge_lora'):
396-
model_type = 'default-lora'
407+
model_type = infer_model_type or 'default-lora'
397408
old_history, history = history or [], []
398409
request_config = XRequestConfig(
399410
temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty)

swift/ui/llm_infer/model.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,16 @@ class Model(BaseUI):
8080
'en': 'Only available when sft_type=lora'
8181
}
8282
},
83+
'lora_modules': {
84+
'label': {
85+
'zh': '外部lora模块',
86+
'en': 'More lora modules'
87+
},
88+
'info': {
89+
'zh': '空格分割的name=/path1/path2键值对',
90+
'en': 'name=/path1/path2 split by blanks'
91+
}
92+
},
8393
'more_params': {
8494
'label': {
8595
'zh': '更多参数',
@@ -117,6 +127,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
117127
system = gr.Textbox(elem_id='system', lines=4, scale=20)
118128
Generate.build_ui(base_tab)
119129
with gr.Row():
130+
gr.Textbox(elem_id='lora_modules', lines=1, is_list=True, scale=40)
120131
gr.Textbox(elem_id='more_params', lines=1, scale=20)
121132
gr.Button(elem_id='load_checkpoint', scale=2, variant='primary')
122133

swift/ui/llm_infer/runtime.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from swift.ui.base import BaseUI
1515
from swift.utils import get_logger
16+
from swift.utils.utils import format_time
1617

1718
logger = get_logger()
1819

@@ -211,23 +212,6 @@ def construct_running_task(proc):
211212
create_time = proc.create_time()
212213
create_time_formatted = datetime.fromtimestamp(create_time).strftime('%Y-%m-%d, %H:%M')
213214

214-
def format_time(seconds):
215-
days = int(seconds // (24 * 3600))
216-
hours = int((seconds % (24 * 3600)) // 3600)
217-
minutes = int((seconds % 3600) // 60)
218-
seconds = int(seconds % 60)
219-
220-
if days > 0:
221-
time_str = f'{days}d {hours}h {minutes}m {seconds}s'
222-
elif hours > 0:
223-
time_str = f'{hours}h {minutes}m {seconds}s'
224-
elif minutes > 0:
225-
time_str = f'{minutes}m {seconds}s'
226-
else:
227-
time_str = f'{seconds}s'
228-
229-
return time_str
230-
231215
return f'pid:{pid}/create:{create_time_formatted}' \
232216
f'/running:{format_time(ts - create_time)}/cmd:{" ".join(proc.cmdline())}'
233217

swift/ui/llm_train/runtime.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from swift.ui.base import BaseUI
1717
from swift.ui.llm_train.utils import close_loop, run_command_in_subprocess
1818
from swift.utils import TB_COLOR, TB_COLOR_SMOOTH, get_logger, read_tensorboard_file, tensorboard_smoothing
19+
from swift.utils.utils import format_time
1920

2021
logger = get_logger()
2122

@@ -423,23 +424,6 @@ def construct_running_task(proc):
423424
create_time = proc.create_time()
424425
create_time_formatted = datetime.fromtimestamp(create_time).strftime('%Y-%m-%d, %H:%M')
425426

426-
def format_time(seconds):
427-
days = int(seconds // (24 * 3600))
428-
hours = int((seconds % (24 * 3600)) // 3600)
429-
minutes = int((seconds % 3600) // 60)
430-
seconds = int(seconds % 60)
431-
432-
if days > 0:
433-
time_str = f'{days}d {hours}h {minutes}m {seconds}s'
434-
elif hours > 0:
435-
time_str = f'{hours}h {minutes}m {seconds}s'
436-
elif minutes > 0:
437-
time_str = f'{minutes}m {seconds}s'
438-
else:
439-
time_str = f'{seconds}s'
440-
441-
return time_str
442-
443427
return f'pid:{pid}/create:{create_time_formatted}' \
444428
f'/running:{format_time(ts-create_time)}/cmd:{" ".join(proc.cmdline())}'
445429

swift/utils/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,24 @@ def _get_version(work_dir: str) -> int:
6363
return max(v_list) + 1
6464

6565

66+
def format_time(seconds):
    """Render a duration given in seconds as a compact human-readable string.

    The result uses the units days, hours, minutes and seconds, and omits the
    leading units that are zero, e.g. ``'1d 1h 1m 1s'``, ``'1h 0m 0s'``,
    ``'1m 1s'`` or ``'59s'``.

    Args:
        seconds: Duration in seconds (int or float). Any fractional part is
            truncated. Negative values are treated as zero.

    Returns:
        The formatted duration string.
    """
    # Clamp negatives to zero: floor-division/modulo on a negative duration
    # (e.g. from clock skew) would otherwise yield misleading output such as
    # '23h 59m 55s' for -5 seconds.
    total = max(0, int(seconds))

    minutes, secs = divmod(total, 60)
    hours, minutes = divmod(minutes, 60)
    days, hours = divmod(hours, 24)

    if days > 0:
        return f'{days}d {hours}h {minutes}m {secs}s'
    if hours > 0:
        return f'{hours}h {minutes}m {secs}s'
    if minutes > 0:
        return f'{minutes}m {secs}s'
    return f'{secs}s'
82+
83+
6684
def seed_everything(seed: Optional[int] = None, full_determinism: bool = False, *, verbose: bool = True) -> int:
6785

6886
if seed is None:

0 commit comments

Comments
 (0)