11# Copyright (c) Alibaba, Inc. and its affiliates.
22import os
3+ from copy import copy
34from dataclasses import fields
45from functools import partial
56from typing import List , Union
@@ -72,6 +73,7 @@ def run(self):
7273 for f in fields (self .args ):
7374 if getattr (self .args , f .name ):
7475 LLMInfer .default_dict [f .name ] = getattr (self .args , f .name )
76+
7577 LLMInfer .is_gradio_app = True
7678 LLMInfer .is_multimodal = self .args .model_meta .is_multimodal
7779 LLMInfer .build_ui (LLMInfer )
@@ -93,10 +95,20 @@ def run(self):
9395 value = getattr (self .args , f .name )
9496 if isinstance (value , list ):
9597 value = ' ' .join ([v or '' for v in value ])
96- LLMInfer .elements ()[f .name ].value = value
97- app .load (LLMInfer .deploy_model , list (LLMInfer .valid_elements ().values ()),
98- [LLMInfer .element ('runtime_tab' ),
99- LLMInfer .element ('running_tasks' )])
98+ LLMInfer .elements ()[f .name ].value = str (value )
99+
100+ args = copy (self .args )
101+ args .port = find_free_port ()
102+
103+ values = []
104+ for key in LLMInfer .valid_elements ():
105+ if key in args .__dict__ :
106+ value = getattr (args , key )
107+ else :
108+ value = LLMInfer .element (key ).value
109+ values .append (value )
110+ _ , running_task = LLMInfer .deploy_model (* values )
111+ LLMInfer .element ('running_tasks' ).value = running_task ['value' ]
100112 else :
101113 app .load (
102114 partial (LLMTrain .update_input_model , arg_cls = RLHFArguments ),
0 commit comments