1+ import collections
12import os
23import sys
34import time
5+ from subprocess import PIPE , STDOUT , Popen
46from typing import Dict , Type
57
68import gradio as gr
@@ -191,16 +193,31 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
191193 Quantization .build_ui (base_tab )
192194 SelfCog .build_ui (base_tab )
193195 Advanced .build_ui (base_tab )
194- submit .click (
195- cls .train , [
196- value for value in cls .elements ().values ()
197- if not isinstance (value , (Tab , Accordion ))
198- ], [
199- cls .element ('running_cmd' ),
200- cls .element ('logging_dir' ),
201- cls .element ('runtime_tab' )
202- ],
203- show_progress = True )
196+ if os .environ .get ('MODELSCOPE_ENVIRONMENT' ) == 'studio' :
197+ submit .click (
198+ cls .update_runtime , [],
199+ [cls .element ('runtime_tab' ),
200+ cls .element ('log' )]).then (
201+ cls .train_studio , [
202+ value for value in cls .elements ().values ()
203+ if not isinstance (value , (Tab , Accordion ))
204+ ], [cls .element ('log' )],
205+ queue = True )
206+ else :
207+ submit .click (
208+ cls .train_local , [
209+ value for value in cls .elements ().values ()
210+ if not isinstance (value , (Tab , Accordion ))
211+ ], [
212+ cls .element ('running_cmd' ),
213+ cls .element ('logging_dir' ),
214+ cls .element ('runtime_tab' ),
215+ ],
216+ queue = True )
217+
218+ @classmethod
219+ def update_runtime (cls ):
220+ return gr .update (visible = True ), gr .update (visible = True )
204221
205222 @classmethod
206223 def train (cls , * args ):
@@ -239,7 +256,8 @@ def train(cls, *args):
239256 params += f'--{ e } { kwargs [e ]} '
240257 else :
241258 params += f'--{ e } "{ kwargs [e ]} " '
242- params += '--add_output_dir_suffix False '
259+ params += f'--add_output_dir_suffix False --output_dir { sft_args .output_dir } ' \
260+ f'--logging_dir { sft_args .logging_dir } '
243261 for key , param in more_params .items ():
244262 params += f'--{ key } "{ param } " '
245263 ddp_param = ''
@@ -260,9 +278,30 @@ def train(cls, *args):
260278 if ddp_param :
261279 ddp_param = f'set { ddp_param } && '
262280 run_command = f'{ cuda_param } { ddp_param } start /b swift sft { params } > { log_file } 2>&1'
281+ elif os .environ .get ('MODELSCOPE_ENVIRONMENT' ) == 'studio' :
282+ run_command = f'{ cuda_param } { ddp_param } swift sft { params } '
263283 else :
264284 run_command = f'{ cuda_param } { ddp_param } nohup swift sft { params } > { log_file } 2>&1 &'
265285 logger .info (f'Run training: { run_command } ' )
286+ return run_command , sft_args , other_kwargs
287+
288+ @classmethod
289+ def train_studio (cls , * args ):
290+ run_command , sft_args , other_kwargs = cls .train (* args )
291+ if os .environ .get ('MODELSCOPE_ENVIRONMENT' ) == 'studio' :
292+ lines = collections .deque (
293+ maxlen = int (os .environ .get ('MAX_LOG_LINES' , 50 )))
294+ process = Popen (
295+ run_command , shell = True , stdout = PIPE , stderr = STDOUT )
296+ with process .stdout :
297+ for line in iter (process .stdout .readline , b'' ):
298+ line = line .decode ('utf-8' )
299+ lines .append (line )
300+ yield '\n ' .join (lines )
301+
302+ @classmethod
303+ def train_local (cls , * args ):
304+ run_command , sft_args , other_kwargs = cls .train (* args )
266305 if not other_kwargs ['dry_run' ]:
267306 os .makedirs (sft_args .logging_dir , exist_ok = True )
268307 os .system (run_command )
0 commit comments