Skip to content

Commit f4aca38

Browse files
fix webui (#296)
1 parent b142266 commit f4aca38

File tree

3 files changed

+28
-22
lines changed

3 files changed

+28
-22
lines changed

swift/ui/base.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import os
2+
import typing
3+
from dataclasses import fields
24
from functools import partial, wraps
35
from typing import Any, Dict, List, OrderedDict, Type
46

@@ -146,3 +148,23 @@ def set_lang(cls, lang):
146148
cls.lang = lang
147149
for sub_ui in cls.sub_ui:
148150
sub_ui.lang = lang
151+
152+
@staticmethod
153+
def get_choices_from_dataclass(dataclass):
154+
choice_dict = {}
155+
for f in fields(dataclass):
156+
if 'choices' in f.metadata:
157+
choice_dict[f.name] = f.metadata['choices']
158+
if 'Literal' in type(f.type).__name__ and typing.get_args(f.type):
159+
choice_dict[f.name] = typing.get_args(f.type)
160+
return choice_dict
161+
162+
@staticmethod
163+
def get_default_value_from_dataclass(dataclass):
164+
default_dict = {}
165+
for f in fields(dataclass):
166+
if hasattr(dataclass, f.name):
167+
default_dict[f.name] = getattr(dataclass, f.name)
168+
else:
169+
default_dict[f.name] = None
170+
return default_dict

swift/ui/llm_infer/llm_infer.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,7 @@ class LLMInfer(BaseUI):
7878
},
7979
}
8080

81-
choice_dict = {}
82-
for f in fields(InferArguments):
83-
if 'choices' in f.metadata:
84-
choice_dict[f.name] = f.metadata['choices']
81+
choice_dict = BaseUI.get_choices_from_dataclass(InferArguments)
8582

8683
@classmethod
8784
def do_build_ui(cls, base_tab: Type['BaseUI']):
@@ -143,11 +140,7 @@ def reset_memory(cls):
143140
def prepare_checkpoint(cls, *args):
144141
global model, tokenizer, template
145142
torch.cuda.empty_cache()
146-
infer_args = fields(InferArguments)
147-
infer_args = {
148-
arg.name: getattr(InferArguments, arg.name)
149-
for arg in infer_args
150-
}
143+
infer_args = cls.get_default_value_from_dataclass(InferArguments)
151144
kwargs = {}
152145
kwargs_is_list = {}
153146
other_kwargs = {}

swift/ui/llm_train/llm_train.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22
import sys
33
import time
4-
from dataclasses import fields
54
from typing import Dict, Type
65

76
import gradio as gr
@@ -147,12 +146,8 @@ class LLMTrain(BaseUI):
147146
}
148147
}
149148

150-
choice_dict = {}
151-
default_dict = {}
152-
for f in fields(SftArguments):
153-
if 'choices' in f.metadata:
154-
choice_dict[f.name] = f.metadata['choices']
155-
default_dict[f.name] = getattr(SftArguments, f.name)
149+
choice_dict = BaseUI.get_choices_from_dataclass(SftArguments)
150+
default_dict = BaseUI.get_default_value_from_dataclass(SftArguments)
156151

157152
@classmethod
158153
def do_build_ui(cls, base_tab: Type['BaseUI']):
@@ -210,11 +205,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
210205
@classmethod
211206
def train(cls, *args):
212207
ignore_elements = ('model_type', 'logging_dir', 'more_params')
213-
sft_args = fields(SftArguments)
214-
sft_args = {
215-
arg.name: getattr(SftArguments, arg.name)
216-
for arg in sft_args
217-
}
208+
sft_args = cls.get_default_value_from_dataclass(SftArguments)
218209
kwargs = {}
219210
kwargs_is_list = {}
220211
other_kwargs = {}
@@ -244,7 +235,7 @@ def train(cls, *args):
244235
params = ''
245236

246237
for e in kwargs:
247-
if kwargs_is_list[e]:
238+
if e in kwargs_is_list and kwargs_is_list[e]:
248239
params += f'--{e} {kwargs[e]} '
249240
else:
250241
params += f'--{e} "{kwargs[e]}" '

0 commit comments

Comments
 (0)