Skip to content

Commit 15e21ec

Browse files
Fix ui (#335)
1 parent 8a28429 commit 15e21ec

File tree

4 files changed

+19
-8
lines changed

4 files changed

+19
-8
lines changed

swift/ui/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ class BaseUI:
7676
sub_ui: List[Type['BaseUI']] = []
7777
group: str = None
7878
lang: str = all_langs[0]
79+
int_regex = r'^[-+]?[0-9]+$'
80+
float_regex = r'[-+]?(?:\d*\.*\d+)'
7981

8082
@classmethod
8183
def build_ui(cls, base_tab: Type['BaseUI']):

swift/ui/llm_infer/llm_infer.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch
88
from gradio import Accordion, Tab
99

10+
from swift import snapshot_download
1011
from swift.llm import (InferArguments, inference_stream, limit_history_length,
1112
prepare_model_template)
1213
from swift.ui.base import BaseUI
@@ -19,9 +20,6 @@ class LLMInfer(BaseUI):
1920

2021
sub_ui = [Model]
2122

22-
int_regex = r'^[-+]?[0-9]+$'
23-
float_regex = r'[-+]?(?:\d*\.*\d+)'
24-
2523
locale_dict = {
2624
'generate_alert': {
2725
'value': {
@@ -137,7 +135,6 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
137135

138136
@classmethod
139137
def reset_load_button(cls):
140-
gr.Info(cls.locale('loaded_alert', cls.lang)['value'])
141138
return gr.update(
142139
value=cls.locale('load_checkpoint', cls.lang)['value'])
143140

@@ -184,7 +181,10 @@ def prepare_checkpoint(cls, *args):
184181

185182
kwargs.update(more_params)
186183
if kwargs['model_type'] == cls.locale('checkpoint', cls.lang)['value']:
187-
kwargs['ckpt_dir'] = kwargs.pop('model_id_or_path')
184+
model_dir = kwargs.pop('model_id_or_path')
185+
if not os.path.exists(model_dir):
186+
model_dir = snapshot_download(model_dir)
187+
kwargs['ckpt_dir'] = model_dir
188188
if 'ckpt_dir' in kwargs or 'model_id_or_path' in kwargs:
189189
kwargs.pop('model_type', None)
190190

@@ -194,8 +194,9 @@ def prepare_checkpoint(cls, *args):
194194
gpus = ','.join(devices)
195195
if gpus != 'cpu':
196196
os.environ['CUDA_VISIBLE_DEVICES'] = gpus
197-
inter_args = InferArguments(**kwargs)
198-
model, template = prepare_model_template(inter_args)
197+
infer_args = InferArguments(**kwargs)
198+
model, template = prepare_model_template(infer_args)
199+
gr.Info(cls.locale('loaded_alert', cls.lang)['value'])
199200
return [model, template]
200201

201202
@classmethod

swift/ui/llm_infer/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def update_input_model(choice):
121121
cls.lang)['value'])
122122

123123
def update_model_id_or_path(model_type, path):
124-
if not path:
124+
if not path or not os.path.exists(path):
125125
return None, None, None
126126
local_path = os.path.join(path, 'sft_args.json')
127127
if not os.path.exists(local_path):

swift/ui/llm_train/llm_train.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import collections
22
import os
3+
import re
34
import sys
45
import time
56
from subprocess import PIPE, STDOUT, Popen
@@ -243,6 +244,13 @@ def train(cls, *args):
243244
compare_value, (list, dict)) else compare_value
244245
compare_value_ui = str(value) if not isinstance(
245246
value, (list, dict)) else value
247+
248+
if isinstance(value, str) and re.fullmatch(cls.int_regex, value):
249+
value = int(value)
250+
elif isinstance(value, str) and re.fullmatch(
251+
cls.float_regex, value):
252+
value = float(value)
253+
246254
if key not in ignore_elements and key in sft_args and compare_value_ui != compare_value_arg and value:
247255
kwargs[key] = value if not isinstance(
248256
value, list) else ' '.join(value)

0 commit comments

Comments
 (0)