Skip to content

Commit 9cff868

Browse files
Fix studio (#946)
1 parent 03b17dc commit 9cff868

File tree

4 files changed

+85
-58
lines changed

4 files changed

+85
-58
lines changed

swift/ui/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,10 @@ def wrapper(*args, **kwargs):
3333
self.is_list = kwargs.pop('is_list')
3434

3535
if base_builder and base_builder.default(elem_id) is not None:
36-
kwargs['value'] = base_builder.default(elem_id)
36+
if os.environ.get('MODELSCOPE_ENVIRONMENT') == 'studio' and kwargs.get('value') is not None:
37+
pass
38+
else:
39+
kwargs['value'] = base_builder.default(elem_id)
3740

3841
if builder is not None:
3942
if elem_id in builder.locales(lang):

swift/ui/llm_train/dataset.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
from typing import Type
23

34
import gradio as gr
@@ -10,6 +11,8 @@ class Dataset(BaseUI):
1011

1112
group = 'llm_train'
1213

14+
is_studio = os.environ.get('MODELSCOPE_ENVIRONMENT') == 'studio'
15+
1316
locale_dict = {
1417
'dataset': {
1518
'label': {
@@ -67,8 +70,11 @@ class Dataset(BaseUI):
6770
'en': 'The sample size from the train dataset'
6871
},
6972
'info': {
70-
'zh': '从训练集中采样一定行数进行训练',
71-
'en': 'Train with the sample size from the dataset'
73+
'zh':
74+
'从训练集中采样一定行数进行训练' if not is_studio else '为减少训练时间, 采样数量在space/studio条件下不可选',
75+
'en':
76+
'Train with the sample size from the dataset'
77+
if not is_studio else 'Not interactive in space/studio to reduce train time',
7278
}
7379
},
7480
'val_dataset_sample': {
@@ -77,8 +83,11 @@ class Dataset(BaseUI):
7783
'en': 'The sample size from the val dataset'
7884
},
7985
'info': {
80-
'zh': '从验证集中采样一定行数进行训练',
81-
'en': 'Validate with the sample size from the dataset'
86+
'zh':
87+
'从验证集中采样一定行数进行训练' if not is_studio else '为减少训练时间, 采样数量在space/studio条件下不可选',
88+
'en':
89+
'Validate with the sample size from the dataset'
90+
if not is_studio else 'Not interactive in space/studio to reduce train time',
8291
}
8392
},
8493
'truncation_strategy': {
@@ -113,6 +122,10 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
113122
with gr.Row():
114123
gr.Slider(elem_id='dataset_test_ratio', minimum=0.0, maximum=1.0, step=0.05, scale=20)
115124
gr.Slider(elem_id='max_length', minimum=32, maximum=8192, step=32, scale=20)
116-
gr.Textbox(elem_id='train_dataset_sample', scale=20)
117-
gr.Textbox(elem_id='val_dataset_sample', scale=20)
125+
if not cls.is_studio:
126+
gr.Textbox(elem_id='train_dataset_sample', scale=20)
127+
gr.Textbox(elem_id='val_dataset_sample', scale=20)
128+
else:
129+
gr.Textbox(elem_id='train_dataset_sample', value=500, interactive=False, scale=20)
130+
gr.Textbox(elem_id='val_dataset_sample', value=50, interactive=False, scale=20)
118131
gr.Dropdown(elem_id='truncation_strategy', scale=20)

swift/ui/llm_train/llm_train.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ class LLMTrain(BaseUI):
3535

3636
group = 'llm_train'
3737

38+
is_studio = os.environ.get('MODELSCOPE_ENVIRONMENT') == 'studio'
39+
3840
sub_ui = [
3941
Model,
4042
Dataset,
@@ -224,13 +226,13 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
224226
Quantization.build_ui(base_tab)
225227
SelfCog.build_ui(base_tab)
226228
Advanced.build_ui(base_tab)
227-
if os.environ.get('MODELSCOPE_ENVIRONMENT') == 'studio':
229+
if cls.is_studio:
228230
submit.click(
229231
cls.update_runtime, [],
230232
[cls.element('runtime_tab'), cls.element('log')]).then(
231233
cls.train_studio,
232234
[value for value in cls.elements().values() if not isinstance(value, (Tab, Accordion))],
233-
[cls.element('log')],
235+
[cls.element('log')] + Runtime.all_plots,
234236
queue=True)
235237
else:
236238
submit.click(
@@ -242,17 +244,18 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
242244
cls.element('running_tasks'),
243245
],
244246
queue=True)
245-
base_tab.element('running_tasks').change(
246-
partial(Runtime.task_changed, base_tab=base_tab), [base_tab.element('running_tasks')],
247-
[value for value in base_tab.elements().values() if not isinstance(value, (Tab, Accordion))]
248-
+ [cls.element('log')] + Runtime.all_plots,
249-
cancels=Runtime.log_event)
250-
Runtime.element('kill_task').click(
251-
Runtime.kill_task,
252-
[Runtime.element('running_tasks')],
253-
[Runtime.element('running_tasks')] + [Runtime.element('log')] + Runtime.all_plots,
254-
cancels=[Runtime.log_event],
255-
).then(Runtime.reset, [], [Runtime.element('logging_dir')] + [Save.element('output_dir')])
247+
if not cls.is_studio:
248+
base_tab.element('running_tasks').change(
249+
partial(Runtime.task_changed, base_tab=base_tab), [base_tab.element('running_tasks')],
250+
[value for value in base_tab.elements().values() if not isinstance(value, (Tab, Accordion))]
251+
+ [cls.element('log')] + Runtime.all_plots,
252+
cancels=Runtime.log_event)
253+
Runtime.element('kill_task').click(
254+
Runtime.kill_task,
255+
[Runtime.element('running_tasks')],
256+
[Runtime.element('running_tasks')] + [Runtime.element('log')] + Runtime.all_plots,
257+
cancels=[Runtime.log_event],
258+
).then(Runtime.reset, [], [Runtime.element('logging_dir')] + [Save.element('output_dir')])
256259

257260
@classmethod
258261
def update_runtime(cls):
@@ -329,7 +332,7 @@ def train(cls, *args):
329332
if ddp_param:
330333
ddp_param = f'set {ddp_param} && '
331334
run_command = f'{cuda_param}{ddp_param}start /b swift sft {params} > {log_file} 2>&1'
332-
elif os.environ.get('MODELSCOPE_ENVIRONMENT') == 'studio':
335+
elif cls.is_studio:
333336
run_command = f'{cuda_param} {ddp_param} swift sft {params}'
334337
else:
335338
run_command = f'{cuda_param} {ddp_param} nohup swift sft {params} > {log_file} 2>&1 &'
@@ -339,14 +342,14 @@ def train(cls, *args):
339342
@classmethod
340343
def train_studio(cls, *args):
341344
run_command, sft_args, other_kwargs = cls.train(*args)
342-
if os.environ.get('MODELSCOPE_ENVIRONMENT') == 'studio':
345+
if cls.is_studio:
343346
lines = collections.deque(maxlen=int(os.environ.get('MAX_LOG_LINES', 50)))
344347
process = Popen(run_command, shell=True, stdout=PIPE, stderr=STDOUT)
345348
with process.stdout:
346349
for line in iter(process.stdout.readline, b''):
347350
line = line.decode('utf-8')
348351
lines.append(line)
349-
yield '\n'.join(lines)
352+
yield ['\n'.join(lines)] + Runtime.plot(run_command)
350353

351354
@classmethod
352355
def train_local(cls, *args):

swift/ui/llm_train/runtime.py

Lines changed: 43 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ class Runtime(BaseUI):
3030

3131
log_event = None
3232

33+
is_studio = os.environ.get('MODELSCOPE_ENVIRONMENT') == 'studio'
34+
3335
sft_plot = [
3436
{
3537
'name': 'train/loss',
@@ -187,49 +189,52 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
187189
with gr.Blocks():
188190
with gr.Row():
189191
gr.Textbox(elem_id='running_cmd', lines=1, scale=20, interactive=False, max_lines=1)
190-
gr.Textbox(elem_id='logging_dir', lines=1, scale=20, max_lines=1)
191-
gr.Button(elem_id='show_log', scale=2, variant='primary')
192-
gr.Button(elem_id='stop_show_log', scale=2)
193-
gr.Textbox(elem_id='tb_url', lines=1, scale=10, interactive=False, max_lines=1)
194-
gr.Button(elem_id='start_tb', scale=2, variant='primary')
195-
gr.Button(elem_id='close_tb', scale=2)
192+
if not cls.is_studio:
193+
gr.Textbox(elem_id='logging_dir', lines=1, scale=20, max_lines=1)
194+
gr.Button(elem_id='show_log', scale=2, variant='primary')
195+
gr.Button(elem_id='stop_show_log', scale=2)
196+
gr.Textbox(elem_id='tb_url', lines=1, scale=10, interactive=False, max_lines=1)
197+
gr.Button(elem_id='start_tb', scale=2, variant='primary')
198+
gr.Button(elem_id='close_tb', scale=2)
196199
with gr.Row():
197200
gr.Textbox(elem_id='log', lines=6, visible=False)
198-
with gr.Row():
199-
gr.Dropdown(elem_id='running_tasks', scale=10)
200-
gr.Button(elem_id='refresh_tasks', scale=1)
201-
gr.Button(elem_id='kill_task', scale=1)
201+
if not cls.is_studio:
202+
with gr.Row():
203+
gr.Dropdown(elem_id='running_tasks', scale=10)
204+
gr.Button(elem_id='refresh_tasks', scale=1)
205+
gr.Button(elem_id='kill_task', scale=1)
202206

203207
with gr.Row():
204208
cls.all_plots = []
205209
for k in Runtime.sft_plot:
206210
name = k['name']
207211
cls.all_plots.append(gr.Plot(elem_id=name, label=name))
208212

209-
cls.log_event = base_tab.element('show_log').click(
210-
Runtime.update_log, [], [cls.element('log')] + cls.all_plots).then(
211-
Runtime.wait, [base_tab.element('logging_dir'),
212-
base_tab.element('running_tasks')], [cls.element('log')] + cls.all_plots)
213+
if not cls.is_studio:
214+
cls.log_event = base_tab.element('show_log').click(
215+
Runtime.update_log, [], [cls.element('log')] + cls.all_plots).then(
216+
Runtime.wait, [base_tab.element('logging_dir'),
217+
base_tab.element('running_tasks')], [cls.element('log')] + cls.all_plots)
213218

214-
base_tab.element('stop_show_log').click(lambda: None, cancels=cls.log_event)
219+
base_tab.element('stop_show_log').click(lambda: None, cancels=cls.log_event)
215220

216-
base_tab.element('start_tb').click(
217-
Runtime.start_tb,
218-
[base_tab.element('logging_dir')],
219-
[base_tab.element('tb_url')],
220-
)
221+
base_tab.element('start_tb').click(
222+
Runtime.start_tb,
223+
[base_tab.element('logging_dir')],
224+
[base_tab.element('tb_url')],
225+
)
221226

222-
base_tab.element('close_tb').click(
223-
Runtime.close_tb,
224-
[base_tab.element('logging_dir')],
225-
[],
226-
)
227+
base_tab.element('close_tb').click(
228+
Runtime.close_tb,
229+
[base_tab.element('logging_dir')],
230+
[],
231+
)
227232

228-
base_tab.element('refresh_tasks').click(
229-
Runtime.refresh_tasks,
230-
[base_tab.element('running_tasks')],
231-
[base_tab.element('running_tasks')],
232-
)
233+
base_tab.element('refresh_tasks').click(
234+
Runtime.refresh_tasks,
235+
[base_tab.element('running_tasks')],
236+
[base_tab.element('running_tasks')],
237+
)
233238

234239
@classmethod
235240
def update_log(cls):
@@ -362,11 +367,12 @@ def format_time(seconds):
362367
@staticmethod
363368
def parse_info_from_cmdline(task):
364369
pid = None
365-
for i in range(3):
366-
slash = task.find('/')
367-
if i == 0:
368-
pid = task[:slash].split(':')[1]
369-
task = task[slash + 1:]
370+
if '/cmd:' in task:
371+
for i in range(3):
372+
slash = task.find('/')
373+
if i == 0:
374+
pid = task[:slash].split(':')[1]
375+
task = task[slash + 1:]
370376
args = task.split('swift sft')[1]
371377
args = [arg.strip() for arg in args.split('--') if arg.strip()]
372378
all_args = {}
@@ -427,6 +433,8 @@ def plot(task):
427433
return [None] * len(Runtime.sft_plot)
428434
_, all_args = Runtime.parse_info_from_cmdline(task)
429435
tb_dir = all_args['logging_dir']
436+
if not os.path.exists(tb_dir):
437+
return [None] * len(Runtime.sft_plot)
430438
fname = [
431439
fname for fname in os.listdir(tb_dir)
432440
if os.path.isfile(os.path.join(tb_dir, fname)) and fname.startswith('events.out')

0 commit comments

Comments
 (0)