+import collections
 import os
 import re
+import sys
+from subprocess import PIPE, STDOUT, Popen
 from typing import Type
 
 import gradio as gr
 from gradio import Accordion, Tab
 
 from swift import snapshot_download
-from swift.llm import (InferArguments, inference_stream, limit_history_length,
-                       prepare_model_template)
+from swift.llm import (DeployArguments, InferArguments, XRequestConfig,
+                       inference_client)
 from swift.ui.base import BaseUI
 from swift.ui.llm_infer.model import Model
+from swift.ui.llm_infer.runtime import Runtime
+from swift.utils import get_logger
 
+logger = get_logger()
 
-class LLMInfer(BaseUI):
 
+class LLMInfer(BaseUI):
     group = 'llm_infer'
 
-    sub_ui = [Model]
+    sub_ui = [Model, Runtime]
 
     locale_dict = {
         'generate_alert': {
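Note on the import change: this PR swaps the in-process `prepare_model_template` + `inference_stream` path for an HTTP client talking to a `swift deploy` server. A minimal sketch of the new call path, based only on the usage visible in this diff (the model_type value is a placeholder):

    from swift.llm import XRequestConfig, inference_client

    request_config = XRequestConfig(seed=42)
    request_config.stream = True  # ask the server for incremental deltas

    # mirrors generate_chat below; 'qwen-7b-chat' is illustrative only
    resp = inference_client(
        'qwen-7b-chat', 'Hello', [], request_config=request_config)
    for chunk in resp:
        print(chunk.choices[0].delta.content, end='')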
@@ -92,8 +98,9 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
         gpu_count = torch.cuda.device_count()
         default_device = '0'
         with gr.Blocks():
-            model_and_template = gr.State([])
+            model_and_template_type = gr.State([])
             Model.build_ui(base_tab)
+            Runtime.build_ui(base_tab)
             gr.Dropdown(
                 elem_id='gpu_id',
                 multiselect=True,
@@ -112,7 +119,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
             submit.click(
                 cls.generate_chat,
                 inputs=[
-                    model_and_template,
+                    model_and_template_type,
                     cls.element('template_type'), prompt, chatbot,
                     cls.element('max_new_tokens'),
                     cls.element('system')
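The `gr.State` is also renamed to `model_and_template_type`: it now carries the lightweight `[model_type, template_type]` strings instead of live model/template objects, since inference happens in the deployed server process. A toy sketch of the State round-trip this wiring relies on (names here are illustrative, not from the diff):

    import gradio as gr

    with gr.Blocks() as demo:
        state = gr.State([])  # persists between event handlers, per session
        btn = gr.Button('load')
        out = gr.Textbox()
        # the handler receives the current state and returns the updated one
        btn.click(lambda s: (s + ['x'], str(s)), [state], [state, out])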
@@ -121,18 +128,40 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
                 queue=True)
             clear_history.click(
                 fn=cls.clear_session, inputs=[], outputs=[prompt, chatbot])
-            cls.element('load_checkpoint').click(
-                cls.reset_memory, [], [model_and_template])\
-                .then(cls.reset_loading_button, [], [cls.element('load_checkpoint')]).then(
-                    cls.prepare_checkpoint, [
-                        value for value in cls.elements().values()
-                        if not isinstance(value, (Tab, Accordion))
-                    ], [model_and_template]).then(cls.change_interactive, [],
-                                                  [prompt]).then(  # noqa
-                        cls.clear_session,
-                        inputs=[],
-                        outputs=[prompt, chatbot],
-                        queue=True).then(cls.reset_load_button, [], [cls.element('load_checkpoint')])
+
+            if os.environ.get('MODELSCOPE_ENVIRONMENT') == 'studio':
+                cls.element('load_checkpoint').click(
+                    cls.update_runtime, [],
+                    [cls.element('runtime_tab'),
+                     cls.element('log')]).then(
+                        cls.deploy_studio, [
+                            value for value in cls.elements().values()
+                            if not isinstance(value, (Tab, Accordion))
+                        ], [cls.element('log')],
+                        queue=True)
+            else:
+                cls.element('load_checkpoint').click(
+                    cls.reset_memory, [], [model_and_template_type]).then(
+                        cls.reset_loading_button, [],
+                        [cls.element('load_checkpoint')
+                         ]).then(cls.get_model_template_type, [
+                             value for value in cls.elements().values()
+                             if not isinstance(value, (Tab, Accordion))
+                         ], [model_and_template_type]).then(
+                             cls.deploy_local, [
+                                 value
+                                 for value in cls.elements().values()
+                                 if not isinstance(value, (Tab, Accordion))
+                             ], []).then(
+                                 cls.change_interactive, [],
+                                 [prompt]).then(  # noqa
+                                     cls.clear_session,
+                                     inputs=[],
+                                     outputs=[prompt,
+                                              chatbot],
+                                     queue=True).then(
+                                         cls.reset_load_button, [],
+                                         [cls.element('load_checkpoint')])
 
     @classmethod
     def reset_load_button(cls):
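Both branches rely on Gradio's sequential event chaining: each `.then()` handler fires only after the previous one returns, so the button resets, the model/template type resolves, the deploy subprocess launches, and the prompt re-enables in that order. A stripped-down sketch of the pattern (toy callbacks, not the real handlers):

    import gradio as gr

    with gr.Blocks() as demo:
        btn = gr.Button('go')
        box = gr.Textbox()
        btn.click(lambda: 'step 1', [], [box]) \
           .then(lambda: 'step 2', [], [box])  # runs after the first completes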
@@ -148,9 +177,46 @@ def reset_memory(cls):
         return []
 
     @classmethod
-    def prepare_checkpoint(cls, *args):
-        torch.cuda.empty_cache()
-        infer_args = cls.get_default_value_from_dataclass(InferArguments)
+    def clear_session(cls):
+        return '', None
+
+    @classmethod
+    def change_interactive(cls):
+        return gr.update(interactive=True)
+
+    @classmethod
+    def generate_chat(cls,
+                      model_and_template_type,
+                      template_type,
+                      prompt: str,
+                      history,
+                      max_new_tokens,
+                      system,
+                      seed=42):
+        model_type = model_and_template_type[0]
+        old_history, history = history, []
+        request_config = XRequestConfig(seed=seed)
+        request_config.stream = True
+        stream_resp_with_history = ''
+        if not template_type.endswith('generation'):
+            stream_resp = inference_client(
+                model_type,
+                prompt,
+                old_history,
+                system=system,
+                request_config=request_config)
+        else:
+            stream_resp = inference_client(
+                model_type, prompt, request_config=request_config)
+        for chunk in stream_resp:
+            stream_resp_with_history += chunk.choices[0].delta.content
+            qr_pair = [prompt, stream_resp_with_history]
+            total_history = old_history + [qr_pair]
+            yield '', total_history
+
+    @classmethod
+    def deploy(cls, *args):
+        deploy_args = cls.get_default_value_from_dataclass(DeployArguments)
         kwargs = {}
         kwargs_is_list = {}
         other_kwargs = {}
@@ -160,12 +226,12 @@ def prepare_checkpoint(cls, *args):
             if not isinstance(value, (Tab, Accordion))
         ]
         for key, value in zip(keys, args):
-            compare_value = infer_args.get(key)
+            compare_value = deploy_args.get(key)
             compare_value_arg = str(compare_value) if not isinstance(
                 compare_value, (list, dict)) else compare_value
             compare_value_ui = str(value) if not isinstance(
                 value, (list, dict)) else value
-            if key in infer_args and compare_value_ui != compare_value_arg and value:
+            if key in deploy_args and compare_value_ui != compare_value_arg and value:
                 if isinstance(value, str) and re.fullmatch(
                         cls.int_regex, value):
                     value = int(value)
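`deploy()` forwards only the UI values that differ from the `DeployArguments` defaults, so the generated command line stays minimal. `get_default_value_from_dataclass` is a `BaseUI` helper not shown in this diff; it presumably amounts to something like the sketch below (default_factory fields omitted for brevity):

    from dataclasses import MISSING, fields

    def default_values(dataclass_type):
        # map field name -> declared default, skipping fields without one
        return {
            f.name: f.default
            for f in fields(dataclass_type) if f.default is not MISSING
        }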
@@ -190,50 +256,66 @@ def prepare_checkpoint(cls, *args):
                 'model_id_or_path' in kwargs
                 and not os.path.exists(kwargs['model_id_or_path'])):
             kwargs.pop('model_type', None)
-
+        deploy_args = DeployArguments(
+            **{
+                key: value.split(' ')
+                if key in kwargs_is_list and kwargs_is_list[key] else value
+                for key, value in kwargs.items()
+            })
+        params = ''
+        for e in kwargs:
+            if e in kwargs_is_list and kwargs_is_list[e]:
+                params += f'--{e} {kwargs[e]} '
+            else:
+                params += f'--{e} "{kwargs[e]}" '
         devices = other_kwargs['gpu_id']
         devices = [d for d in devices if d]
         assert (len(devices) == 1 or 'cpu' not in devices)
         gpus = ','.join(devices)
+        cuda_param = ''
         if gpus != 'cpu':
-            os.environ['CUDA_VISIBLE_DEVICES'] = gpus
-        infer_args = InferArguments(**kwargs)
-        model, template = prepare_model_template(infer_args)
-        gr.Info(cls.locale('loaded_alert', cls.lang)['value'])
-        return [model, template]
+            cuda_param = f'CUDA_VISIBLE_DEVICES={gpus}'
+
+        log_file = os.path.join(os.getcwd(), 'run_deploy.log')
+        if sys.platform == 'win32':
+            if cuda_param:
+                cuda_param = f'set {cuda_param} && '
+            run_command = f'{cuda_param}start /b swift deploy {params} > {log_file} 2>&1'
+        elif os.environ.get('MODELSCOPE_ENVIRONMENT') == 'studio':
+            run_command = f'{cuda_param} swift deploy {params}'
+        else:
+            run_command = f'{cuda_param} nohup swift deploy {params} > {log_file} 2>&1 &'
+        return run_command, deploy_args
 
     @classmethod
-    def clear_session(cls):
-        return '', None
+    def deploy_studio(cls, *args):
+        run_command, deploy_args = cls.deploy(*args)
+        if os.environ.get('MODELSCOPE_ENVIRONMENT') == 'studio':
+            lines = collections.deque(
+                maxlen=int(os.environ.get('MAX_LOG_LINES', 50)))
+            logger.info(f'Run deploying: {run_command}')
+            process = Popen(
+                run_command, shell=True, stdout=PIPE, stderr=STDOUT)
+            with process.stdout:
+                for line in iter(process.stdout.readline, b''):
+                    line = line.decode('utf-8')
+                    lines.append(line)
+                    yield '\n'.join(lines)
 
     @classmethod
-    def change_interactive(cls):
-        return gr.update(interactive=True)
+    def deploy_local(cls, *args):
+        run_command, deploy_args = cls.deploy(*args)
+        lines = collections.deque(
+            maxlen=int(os.environ.get('MAX_LOG_LINES', 50)))
+        logger.info(f'Run deploying: {run_command}')
+        process = Popen(run_command, shell=True, stdout=PIPE, stderr=STDOUT)
+        with process.stdout:
+            for line in iter(process.stdout.readline, b''):
+                line = line.decode('utf-8')
+                lines.append(line)
+                yield '\n'.join(lines)
 
     @classmethod
-    def generate_chat(cls, model_and_template, template_type, prompt: str,
-                      history, max_new_tokens, system):
-        if not model_and_template:
-            gr.Warning(cls.locale('generate_alert', cls.lang)['value'])
-            return '', None
-        model, template = model_and_template
-        if os.environ.get('MODELSCOPE_ENVIRONMENT') == 'studio':
-            model.cuda()
-        if not template_type.endswith('generation'):
-            old_history, history = limit_history_length(
-                template, prompt, history, int(max_new_tokens))
-        else:
-            old_history = []
-            history = []
-        gen = inference_stream(
-            model,
-            template,
-            prompt,
-            history,
-            system=system,
-            stop_words=['Observation:'])
-        for _, history in gen:
-            total_history = old_history + history
-            yield '', total_history
-        if os.environ.get('MODELSCOPE_ENVIRONMENT') == 'studio':
-            model.cpu()
+    def get_model_template_type(cls, *args):
+        run_command, deploy_args = cls.deploy(*args)
+        return [deploy_args.model_type, deploy_args.template_type]
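For reference, on Linux `deploy_local` ends up spawning a detached server of roughly this shape (the flag value is illustrative):

    CUDA_VISIBLE_DEVICES=0 nohup swift deploy --model_type "qwen-7b-chat" > ./run_deploy.log 2>&1 &

while `deploy_studio` keeps the process in the foreground and streams its stdout into the `log` component line by line.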