Skip to content

Commit 566a113

Browse files
committed
fix hf_space
1 parent a2ef243 commit 566a113

File tree

3 files changed

+24
-12
lines changed

3 files changed

+24
-12
lines changed

swift/ui/app.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,12 @@ def run_ui():
5353
gr.HTML('<p><center>You have duplicated the space, remember remove the `MODELSCOPE_ENVIRONMENT` '
5454
'variable to start an unlimited version</center></p>')
5555
with gr.Tabs():
56-
LLMTrain.build_ui(LLMTrain)
57-
LLMInfer.build_ui(LLMInfer)
56+
if is_shared_ui:
57+
LLMInfer.build_ui(LLMInfer)
58+
LLMTrain.build_ui(LLMTrain)
59+
else:
60+
LLMTrain.build_ui(LLMTrain)
61+
LLMInfer.build_ui(LLMInfer)
5862

5963
port = os.environ.get('WEBUI_PORT', None)
6064
concurrent = {}

swift/ui/llm_train/dataset.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,6 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
122122
with gr.Row():
123123
gr.Slider(elem_id='dataset_test_ratio', minimum=0.0, maximum=1.0, step=0.05, scale=20)
124124
gr.Slider(elem_id='max_length', minimum=32, maximum=8192, step=32, scale=20)
125-
if not cls.is_studio:
126-
gr.Textbox(elem_id='train_dataset_sample', scale=20)
127-
gr.Textbox(elem_id='val_dataset_sample', scale=20)
128-
else:
129-
gr.Textbox(elem_id='train_dataset_sample', value=500, interactive=False, scale=20)
130-
gr.Textbox(elem_id='val_dataset_sample', value=50, interactive=False, scale=20)
125+
gr.Textbox(elem_id='train_dataset_sample', scale=20)
126+
gr.Textbox(elem_id='val_dataset_sample', scale=20)
131127
gr.Dropdown(elem_id='truncation_strategy', scale=20)

swift/ui/llm_train/llm_train.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@
3030

3131
logger = get_logger()
3232

33+
is_spaces = True if 'SPACE_ID' in os.environ else False
34+
if is_spaces:
35+
is_shared_ui = True if 'modelscope/swift' in os.environ['SPACE_ID'] else False
36+
else:
37+
is_shared_ui = False
38+
3339

3440
class LLMTrain(BaseUI):
3541

@@ -214,7 +220,10 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
214220
value=default_device,
215221
scale=8)
216222
gr.Textbox(elem_id='gpu_memory_fraction', scale=4)
217-
gr.Checkbox(elem_id='dry_run', value=False, scale=4)
223+
if is_shared_ui:
224+
gr.Checkbox(elem_id='dry_run', value=True, interactive=False, scale=4)
225+
else:
226+
gr.Checkbox(elem_id='dry_run', value=False, scale=4)
218227
submit = gr.Button(elem_id='submit', scale=4, variant='primary')
219228

220229
Save.build_ui(base_tab)
@@ -232,7 +241,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
232241
[cls.element('runtime_tab'), cls.element('log')]).then(
233242
cls.train_studio,
234243
[value for value in cls.elements().values() if not isinstance(value, (Tab, Accordion))],
235-
[cls.element('log')] + Runtime.all_plots,
244+
[cls.element('log')] + Runtime.all_plots + [cls.element('running_cmd')],
236245
queue=True)
237246
else:
238247
submit.click(
@@ -342,14 +351,17 @@ def train(cls, *args):
342351
@classmethod
343352
def train_studio(cls, *args):
344353
run_command, sft_args, other_kwargs = cls.train(*args)
345-
if cls.is_studio:
354+
if not other_kwargs['dry_run']:
346355
lines = collections.deque(maxlen=int(os.environ.get('MAX_LOG_LINES', 50)))
347356
process = Popen(run_command, shell=True, stdout=PIPE, stderr=STDOUT)
348357
with process.stdout:
349358
for line in iter(process.stdout.readline, b''):
350359
line = line.decode('utf-8')
351360
lines.append(line)
352-
yield ['\n'.join(lines)] + Runtime.plot(run_command)
361+
yield ['\n'.join(lines)] + Runtime.plot(run_command) + [run_command]
362+
else:
363+
yield ['Current is dryrun mode, you can only view the training cmd, please '
364+
'duplicate this space to do training.'] + [None] * len(Runtime.sft_plot) + [run_command]
353365

354366
@classmethod
355367
def train_local(cls, *args):

0 commit comments

Comments
 (0)