Skip to content

Commit 4f93387

Browse files
authored
Refactor Web-UI (#4687)
1 parent 5df8180 commit 4f93387

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

70 files changed

+3501
-695
lines changed

swift/ui/app.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,14 @@
88
from transformers.utils import strtobool
99

1010
import swift
11-
from swift.llm import DeployArguments, EvalArguments, ExportArguments, RLHFArguments, SwiftPipeline, WebUIArguments
11+
from swift.llm import (DeployArguments, EvalArguments, ExportArguments, RLHFArguments, SamplingArguments, SwiftPipeline,
12+
WebUIArguments)
1213
from swift.ui.llm_eval.llm_eval import LLMEval
1314
from swift.ui.llm_export.llm_export import LLMExport
1415
from swift.ui.llm_grpo.llm_grpo import LLMGRPO
1516
from swift.ui.llm_infer.llm_infer import LLMInfer
17+
from swift.ui.llm_rlhf.llm_rlhf import LLMRLHF
18+
from swift.ui.llm_sample.llm_sample import LLMSample
1619
from swift.ui.llm_train.llm_train import LLMTrain
1720

1821
locale_dict = {
@@ -51,10 +54,12 @@ def run(self):
5154
port_env = os.environ.get('WEBUI_PORT')
5255
port = int(port_env) if port_env else self.args.server_port
5356
LLMTrain.set_lang(lang)
57+
LLMRLHF.set_lang(lang)
5458
LLMGRPO.set_lang(lang)
5559
LLMInfer.set_lang(lang)
5660
LLMExport.set_lang(lang)
5761
LLMEval.set_lang(lang)
62+
LLMSample.set_lang(lang)
5863
with gr.Blocks(title='SWIFT WebUI', theme=gr.themes.Base()) as app:
5964
try:
6065
_version = swift.__version__
@@ -64,10 +69,12 @@ def run(self):
6469
gr.HTML(f"<h3><center>{locale_dict['sub_title'][lang]}</center></h3>")
6570
with gr.Tabs():
6671
LLMTrain.build_ui(LLMTrain)
72+
LLMRLHF.build_ui(LLMRLHF)
6773
LLMGRPO.build_ui(LLMGRPO)
6874
LLMInfer.build_ui(LLMInfer)
6975
LLMExport.build_ui(LLMExport)
7076
LLMEval.build_ui(LLMEval)
77+
LLMSample.build_ui(LLMSample)
7178

7279
concurrent = {}
7380
if version.parse(gr.__version__) < version.parse('4.0.0'):
@@ -76,6 +83,10 @@ def run(self):
7683
partial(LLMTrain.update_input_model, arg_cls=RLHFArguments),
7784
inputs=[LLMTrain.element('model')],
7885
outputs=[LLMTrain.element('train_record')] + list(LLMTrain.valid_elements().values()))
86+
app.load(
87+
partial(LLMRLHF.update_input_model, arg_cls=RLHFArguments),
88+
inputs=[LLMRLHF.element('model')],
89+
outputs=[LLMRLHF.element('train_record')] + list(LLMRLHF.valid_elements().values()))
7990
app.load(
8091
partial(LLMGRPO.update_input_model, arg_cls=RLHFArguments),
8192
inputs=[LLMGRPO.element('model')],
@@ -92,6 +103,10 @@ def run(self):
92103
partial(LLMEval.update_input_model, arg_cls=EvalArguments, has_record=False),
93104
inputs=[LLMEval.element('model')],
94105
outputs=list(LLMEval.valid_elements().values()))
106+
app.load(
107+
partial(LLMSample.update_input_model, arg_cls=SamplingArguments, has_record=False),
108+
inputs=[LLMSample.element('model')],
109+
outputs=list(LLMSample.valid_elements().values()))
95110
app.queue(**concurrent).launch(server_name=server, inbrowser=True, server_port=port, height=800, share=share)
96111

97112

swift/ui/base.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ def get_default_value_from_dataclass(dataclass):
291291
default_dict[f.name] = ' '.join(default_dict[f.name])
292292
except TypeError:
293293
default_dict[f.name] = None
294-
if not default_dict[f.name]:
294+
if not default_dict[f.name] and default_dict[f.name] not in (0, False):
295295
default_dict[f.name] = None
296296
return default_dict
297297

@@ -407,3 +407,12 @@ def update_all_settings(cls, model, train_record, base_tab):
407407
else:
408408
updates.append(gr.update())
409409
return updates
410+
411+
@classmethod
412+
def update_ddp_num(cls, gpu_ids, use_ddp):
413+
if use_ddp:
414+
if 'cpu' in gpu_ids:
415+
return None
416+
else:
417+
return len(gpu_ids)
418+
return 1

swift/ui/llm_eval/llm_eval.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class LLMEval(BaseUI):
3232
'llm_eval': {
3333
'label': {
3434
'zh': 'LLM评测',
35-
'en': 'LLM evaluation',
35+
'en': 'LLM Evaluation',
3636
}
3737
},
3838
'more_params': {
@@ -78,7 +78,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
7878
Model.build_ui(base_tab)
7979
Eval.build_ui(base_tab)
8080
EvalRuntime.build_ui(base_tab)
81-
with gr.Row():
81+
with gr.Row(equal_height=True):
8282
gr.Textbox(elem_id='more_params', lines=4, scale=20)
8383
gr.Button(elem_id='evaluate', scale=2, variant='primary')
8484
gr.Dropdown(

swift/ui/llm_eval/runtime.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ class EvalRuntime(Runtime):
5252
'en': 'Logging content'
5353
},
5454
'info': {
55-
'zh': '如果日志无更新请再次点击"展示日志内容"',
56-
'en': 'Please press "Show log" if the log content is not updating'
55+
'zh': '如果日志无更新请再次点击"展示评测状态"',
56+
'en': 'Please press "Show eval status" if the log content is not updating'
5757
}
5858
},
5959
'running_tasks': {
@@ -84,7 +84,7 @@ class EvalRuntime(Runtime):
8484
def do_build_ui(cls, base_tab: Type['BaseUI']):
8585
with gr.Accordion(elem_id='runtime_tab', open=False, visible=True):
8686
with gr.Blocks():
87-
with gr.Row():
87+
with gr.Row(equal_height=True):
8888
gr.Dropdown(elem_id='running_tasks', scale=10)
8989
gr.Button(elem_id='refresh_tasks', scale=1, variant='primary')
9090
gr.Button(elem_id='show_log', scale=1, variant='primary')

swift/ui/llm_export/export.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,20 @@ class Export(BaseUI):
1414
locale_dict = {
1515
'merge_lora': {
1616
'label': {
17-
'zh': '合并lora',
18-
'en': 'Merge lora'
17+
'zh': '合并LoRA',
18+
'en': 'Merge LoRA'
1919
},
2020
'info': {
2121
'zh':
22-
'lora合并的路径在填入的checkpoint同级目录,请查看运行时log获取更具体的信息',
22+
'LoRA合并的路径在填入的checkpoint同级目录,请查看运行时log获取更具体的信息',
2323
'en':
2424
'The output path is in the sibling directory as the input checkpoint. '
2525
'Please refer to the runtime log for more specific information.'
2626
},
2727
},
2828
'device_map': {
2929
'label': {
30-
'zh': '合并lora使用的device_map',
30+
'zh': '合并LoRA使用的device_map',
3131
'en': 'The device_map when merge-lora'
3232
},
3333
'info': {

swift/ui/llm_export/llm_export.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class LLMExport(BaseUI):
3030
'llm_export': {
3131
'label': {
3232
'zh': 'LLM导出',
33-
'en': 'LLM export',
33+
'en': 'LLM Export',
3434
}
3535
},
3636
'more_params': {
@@ -76,7 +76,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
7676
Model.build_ui(base_tab)
7777
Export.build_ui(base_tab)
7878
ExportRuntime.build_ui(base_tab)
79-
with gr.Row():
79+
with gr.Row(equal_height=True):
8080
gr.Textbox(elem_id='more_params', lines=4, scale=20)
8181
gr.Button(elem_id='export', scale=2, variant='primary')
8282
gr.Dropdown(

swift/ui/llm_export/runtime.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ class ExportRuntime(Runtime):
4646
'en': 'Logging content'
4747
},
4848
'info': {
49-
'zh': '如果日志无更新请再次点击"展示日志内容"',
50-
'en': 'Please press "Show log" if the log content is not updating'
49+
'zh': '如果日志无更新请再次点击"展示导出状态"',
50+
'en': 'Please press "Show export status" if the log content is not updating'
5151
}
5252
},
5353
'running_tasks': {

swift/ui/llm_grpo/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Copyright (c) Alibaba, Inc. and its affiliates.

swift/ui/llm_grpo/advanced.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Copyright (c) Alibaba, Inc. and its affiliates.
2+
from swift.ui.llm_train.advanced import Advanced
3+
4+
5+
class GRPOAdvanced(Advanced):
6+
7+
group = 'llm_grpo'

swift/ui/llm_grpo/dataset.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Copyright (c) Alibaba, Inc. and its affiliates.
2+
from swift.ui.llm_train.dataset import Dataset
3+
4+
5+
class GRPODataset(Dataset):
6+
7+
group = 'llm_grpo'

0 commit comments

Comments
 (0)