33from typing import Dict , Type
44
55import gradio as gr
6+ from packaging import version
67
78from swift .llm .argument .base_args .base_args import get_supported_tuners
89from swift .ui .base import BaseUI
1415from swift .ui .llm_grpo .model import GRPOModel
1516from swift .ui .llm_grpo .optimizer import GRPOOptimizer
1617from swift .ui .llm_grpo .quantization import GRPOQuantization
17- from swift .ui .llm_grpo .ref_model import RefModel
1818from swift .ui .llm_grpo .report_to import GRPOReportTo
1919from swift .ui .llm_grpo .reward import Reward
2020from swift .ui .llm_grpo .rollout import Rollout
2121from swift .ui .llm_grpo .runtime import GRPORuntime
2222from swift .ui .llm_grpo .save import GRPOSave
2323from swift .ui .llm_grpo .tuner import GRPOTuner
2424from swift .ui .llm_train .llm_train import LLMTrain
25+ from swift .ui .llm_train .runtime import Runtime
2526from swift .utils import get_device_count , get_logger
2627
2728logger = get_logger ()
@@ -32,7 +33,7 @@ class LLMGRPO(LLMTrain):
3233
3334 sub_ui = [
3435 GRPOModel , GRPODataset , Reward , GRPORuntime , Rollout , GRPOSave , GRPOTuner , GRPOOptimizer , GRPOHyper ,
35- GRPOQuantization , GRPOAdvanced , RefModel , GrpoAdvanced , GRPOReportTo , LLMRollout
36+ GRPOQuantization , GRPOAdvanced , GrpoAdvanced , GRPOReportTo , LLMRollout
3637 ]
3738
3839 locale_dict : Dict [str , Dict ] = {
@@ -146,16 +147,6 @@ class LLMGRPO(LLMTrain):
146147 'en' : 'The data parallel size of DDP'
147148 }
148149 },
149- 'tuner_backend' : {
150- 'label' : {
151- 'zh' : 'Tuner backend' ,
152- 'en' : 'Tuner backend'
153- },
154- 'info' : {
155- 'zh' : 'Tuner实现框架' ,
156- 'en' : 'The tuner backend'
157- }
158- },
159150 'use_liger_kernel' : {
160151 'label' : {
161152 'zh' : '使用Liger kernel' ,
@@ -239,11 +230,17 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
239230 with gr .Accordion (elem_id = 'train_param' , open = True ):
240231 with gr .Row ():
241232 gr .Dropdown (elem_id = 'train_type' , scale = 4 , choices = list (get_supported_tuners ()))
242- gr .Dropdown (elem_id = 'tuner_backend' , scale = 4 )
243233 gr .Textbox (elem_id = 'seed' , scale = 4 )
244234 gr .Dropdown (elem_id = 'torch_dtype' , scale = 4 )
245- with gr .Row ():
246235 gr .Checkbox (elem_id = 'use_liger_kernel' , scale = 4 )
236+ gr .Textbox (elem_id = 'sequence_parallel_size' , lines = 1 , scale = 4 )
237+ with gr .Row ():
238+ gr .Dropdown (
239+ elem_id = 'gpu_id' ,
240+ multiselect = True ,
241+ choices = [str (i ) for i in range (device_count )] + ['cpu' ],
242+ value = default_device ,
243+ scale = 8 )
247244 gr .Checkbox (elem_id = 'use_ddp' , value = False , scale = 4 )
248245 gr .Textbox (elem_id = 'ddp_num' , value = '1' , scale = 4 )
249246 gr .Dropdown (
@@ -252,25 +249,17 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
252249 allow_custom_value = True ,
253250 value = None ,
254251 choices = ['zero0' , 'zero1' , 'zero2' , 'zero3' , 'zero2_offload' , 'zero3_offload' ])
255- gr .Textbox (elem_id = 'sequence_parallel_size' , lines = 1 , scale = 4 )
256252 GRPOHyper .build_ui (base_tab )
257253 GRPORuntime .build_ui (base_tab )
258254 with gr .Row (equal_height = True ):
259- gr .Dropdown (
260- elem_id = 'gpu_id' ,
261- multiselect = True ,
262- choices = [str (i ) for i in range (device_count )] + ['cpu' ],
263- value = default_device ,
264- scale = 8 )
265- gr .Textbox (elem_id = 'envs' , scale = 8 )
255+ gr .Textbox (elem_id = 'envs' , scale = 12 )
266256 gr .Checkbox (elem_id = 'dry_run' , value = False , scale = 4 )
267257 submit = gr .Button (elem_id = 'submit' , scale = 4 , variant = 'primary' )
268258
269259 Rollout .build_ui (base_tab )
270260 LLMRollout .set_lang (cls .lang )
271261 LLMRollout .build_ui (LLMRollout )
272262 GRPOTuner .build_ui (base_tab )
273- RefModel .build_ui (base_tab )
274263 with gr .Accordion (elem_id = 'extra_params' , open = True ):
275264 with gr .Tabs ():
276265 GrpoAdvanced .build_ui (base_tab )
@@ -286,13 +275,6 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
286275 inputs = [base_tab .element ('train_type' )],
287276 outputs = [cls .element ('learning_rate' )])
288277
289- base_tab .element ('gpu_id' ).change (
290- cls .update_ddp_num ,
291- [base_tab .element ('gpu_id' ), base_tab .element ('use_ddp' )], base_tab .element ('ddp_num' ))
292- base_tab .element ('use_ddp' ).change (
293- cls .update_ddp_num ,
294- [base_tab .element ('gpu_id' ), base_tab .element ('use_ddp' )], base_tab .element ('ddp_num' ))
295-
296278 submit .click (
297279 cls .train_local ,
298280 list (cls .valid_elements ().values ()), [
@@ -312,15 +294,35 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
312294 cls .element ('template' )],
313295 [LLMRollout .element ('rollout_runtime_tab' ),
314296 LLMRollout .element ('rollout_running_tasks' )])
315- base_tab .element ('running_tasks' ).change (
316- partial (GRPORuntime .task_changed , base_tab = base_tab ), [base_tab .element ('running_tasks' )],
317- list (base_tab .valid_elements ().values ()) + [cls .element ('log' )] + GRPORuntime .all_plots )
297+
318298 GRPORuntime .element ('kill_task' ).click (
319299 GRPORuntime .kill_task ,
320300 [GRPORuntime .element ('running_tasks' )],
321301 [GRPORuntime .element ('running_tasks' )] + [GRPORuntime .element ('log' )] + GRPORuntime .all_plots ,
322302 ).then (GRPORuntime .reset , [], [GRPORuntime .element ('logging_dir' )] + [GRPOHyper .element ('output_dir' )])
323303
304+ base_tab .element ('gpu_id' ).change (
305+ cls .update_ddp_num ,
306+ [base_tab .element ('gpu_id' ), base_tab .element ('use_ddp' )], base_tab .element ('ddp_num' ))
307+ base_tab .element ('use_ddp' ).change (
308+ cls .update_ddp_num ,
309+ [base_tab .element ('gpu_id' ), base_tab .element ('use_ddp' )], base_tab .element ('ddp_num' ))
310+ base_tab .element ('ddp_num' ).change (Rollout .update_num_gen , [
311+ GRPOHyper .element ('per_device_train_batch_size' ),
312+ GRPOHyper .element ('gradient_accumulation_steps' ),
313+ cls .element ('ddp_num' )
314+ ], [Rollout .element ('num_generations' )])
315+ GRPOHyper .element ('gradient_accumulation_steps' ).change (Rollout .update_num_gen , [
316+ GRPOHyper .element ('per_device_train_batch_size' ),
317+ GRPOHyper .element ('gradient_accumulation_steps' ),
318+ cls .element ('ddp_num' )
319+ ], [Rollout .element ('num_generations' )])
320+ GRPOHyper .element ('per_device_train_batch_size' ).change (Rollout .update_num_gen , [
321+ GRPOHyper .element ('per_device_train_batch_size' ),
322+ GRPOHyper .element ('gradient_accumulation_steps' ),
323+ cls .element ('ddp_num' )
324+ ], [Rollout .element ('num_generations' )])
325+
324326 @classmethod
325327 def prepare_sub_to_filter (cls ):
326328 tabs_relation_dict = {
0 commit comments