Skip to content

Commit a46cf1c

Browse files
committed
fix: 修复语音模型传入不正确参数报错的问题
1 parent a0ad4c9 commit a46cf1c

File tree

9 files changed

+76
-50
lines changed

9 files changed

+76
-50
lines changed

apps/application/serializers/application_serializers.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1028,7 +1028,11 @@ def play_demo_text(self, form_data, with_valid=True):
10281028
application_id = self.data.get('application_id')
10291029
application = QuerySet(Application).filter(id=application_id).first()
10301030
if application.tts_model_enable:
1031-
model = get_model_instance_by_model_user_id(application.tts_model_id, application.user_id, **form_data)
1031+
tts_model_id = application.tts_model_id
1032+
if 'tts_model_id' in form_data:
1033+
tts_model_id = form_data.get('tts_model_id')
1034+
del form_data['tts_model_id']
1035+
model = get_model_instance_by_model_user_id(tts_model_id, application.user_id, **form_data)
10321036
return model.text_to_speech(text)
10331037

10341038
class ApplicationKeySerializerModel(serializers.ModelSerializer):

apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/tts.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,21 @@
1010
class AliyunBaiLianTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
1111
api_key: str
1212
model: str
13-
voice: str
14-
speech_rate: float
13+
params: dict
1514

1615
def __init__(self, **kwargs):
1716
super().__init__(**kwargs)
1817
self.api_key = kwargs.get('api_key')
1918
self.model = kwargs.get('model')
20-
self.voice = kwargs.get('voice')
21-
self.speech_rate = kwargs.get('speech_rate')
19+
self.params = kwargs.get('params')
2220

2321
@staticmethod
2422
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
25-
optional_params = {'voice': 'longxiaochun', 'speech_rate': 1.0}
26-
if 'voice' in model_kwargs and model_kwargs['voice'] is not None:
27-
optional_params['voice'] = model_kwargs['voice']
28-
if 'speech_rate' in model_kwargs and model_kwargs['speech_rate'] is not None:
29-
optional_params['speech_rate'] = model_kwargs['speech_rate']
23+
optional_params = {'params': {'voice': 'longxiaochun', 'speech_rate': 1.0}}
24+
for key, value in model_kwargs.items():
25+
if key not in ['model_id', 'use_local', 'streaming']:
26+
optional_params['params'][key] = value
27+
3028
return AliyunBaiLianTextToSpeech(
3129
model=model_name,
3230
api_key=model_credential.get('api_key'),
@@ -38,7 +36,7 @@ def check_auth(self):
3836

3937
def text_to_speech(self, text):
4038
dashscope.api_key = self.api_key
41-
synthesizer = SpeechSynthesizer(model=self.model, voice=self.voice, speech_rate=self.speech_rate)
39+
synthesizer = SpeechSynthesizer(model=self.model, **self.params)
4240
audio = synthesizer.call(text)
4341
if type(audio) == str:
4442
print(audio)

apps/setting/models_provider/impl/openai_model_provider/model/tts.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,21 @@ class OpenAITextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
1616
api_base: str
1717
api_key: str
1818
model: str
19-
voice: str
19+
params: dict
2020

2121
def __init__(self, **kwargs):
2222
super().__init__(**kwargs)
2323
self.api_key = kwargs.get('api_key')
2424
self.api_base = kwargs.get('api_base')
2525
self.model = kwargs.get('model')
26-
self.voice = kwargs.get('voice', 'alloy')
26+
self.params = kwargs.get('params')
2727

2828
@staticmethod
2929
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
30-
optional_params = {'voice': 'alloy'}
31-
if 'voice' in model_kwargs and model_kwargs['voice'] is not None:
32-
optional_params['voice'] = model_kwargs['voice']
30+
optional_params = {'params': {'voice': 'alloy'}}
31+
for key, value in model_kwargs.items():
32+
if key not in ['model_id', 'use_local', 'streaming']:
33+
optional_params['params'][key] = value
3334
return OpenAITextToSpeech(
3435
model=model_name,
3536
api_base=model_credential.get('api_base'),
@@ -52,10 +53,10 @@ def text_to_speech(self, text):
5253
)
5354
with client.audio.speech.with_streaming_response.create(
5455
model=self.model,
55-
voice=self.voice,
5656
input=text,
57+
**self.params
5758
) as response:
5859
return response.read()
5960

6061
def is_cache_model(self):
61-
return False
62+
return False

apps/setting/models_provider/impl/volcanic_engine_model_provider/model/tts.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,25 +45,22 @@ class VolcanicEngineTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
4545
volcanic_cluster: str
4646
volcanic_api_url: str
4747
volcanic_token: str
48-
speed_ratio: float
49-
voice_type: str
48+
params: dict
5049

5150
def __init__(self, **kwargs):
5251
super().__init__(**kwargs)
5352
self.volcanic_api_url = kwargs.get('volcanic_api_url')
5453
self.volcanic_token = kwargs.get('volcanic_token')
5554
self.volcanic_app_id = kwargs.get('volcanic_app_id')
5655
self.volcanic_cluster = kwargs.get('volcanic_cluster')
57-
self.voice_type = kwargs.get('voice_type')
58-
self.speed_ratio = kwargs.get('speed_ratio')
56+
self.params = kwargs.get('params')
5957

6058
@staticmethod
6159
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
62-
optional_params = {'voice_type': 'BV002_streaming', 'speed_ratio': 1.0}
63-
if 'voice_type' in model_kwargs and model_kwargs['voice_type'] is not None:
64-
optional_params['voice_type'] = model_kwargs['voice_type']
65-
if 'speed_ratio' in model_kwargs and model_kwargs['speed_ratio'] is not None:
66-
optional_params['speed_ratio'] = model_kwargs['speed_ratio']
60+
optional_params = {'params': {'voice_type': 'BV002_streaming', 'speed_ratio': 1.0}}
61+
for key, value in model_kwargs.items():
62+
if key not in ['model_id', 'use_local', 'streaming']:
63+
optional_params['params'][key] = value
6764
return VolcanicEngineTextToSpeech(
6865
volcanic_api_url=model_credential.get('volcanic_api_url'),
6966
volcanic_token=model_credential.get('volcanic_token'),
@@ -86,12 +83,10 @@ def text_to_speech(self, text):
8683
"uid": "uid"
8784
},
8885
"audio": {
89-
"voice_type": self.voice_type,
9086
"encoding": "mp3",
91-
"speed_ratio": self.speed_ratio,
9287
"volume_ratio": 1.0,
9388
"pitch_ratio": 1.0,
94-
},
89+
} | self.params,
9590
"request": {
9691
"reqid": str(uuid.uuid4()),
9792
"text": '',

apps/setting/models_provider/impl/xf_model_provider/model/tts.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,25 +37,22 @@ class XFSparkTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
3737
spark_api_key: str
3838
spark_api_secret: str
3939
spark_api_url: str
40-
speed: int
41-
vcn: str
40+
params: dict
4241

4342
def __init__(self, **kwargs):
4443
super().__init__(**kwargs)
4544
self.spark_api_url = kwargs.get('spark_api_url')
4645
self.spark_app_id = kwargs.get('spark_app_id')
4746
self.spark_api_key = kwargs.get('spark_api_key')
4847
self.spark_api_secret = kwargs.get('spark_api_secret')
49-
self.vcn = kwargs.get('vcn')
50-
self.speed = kwargs.get('speed')
48+
self.params = kwargs.get('params')
5149

5250
@staticmethod
5351
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
54-
optional_params = {'vcn': 'xiaoyan', 'speed': 50}
55-
if 'vcn' in model_kwargs and model_kwargs['vcn'] is not None:
56-
optional_params['vcn'] = model_kwargs['vcn']
57-
if 'speed' in model_kwargs and model_kwargs['speed'] is not None:
58-
optional_params['speed'] = model_kwargs['speed']
52+
optional_params = {'params': {'vcn': 'xiaoyan', 'speed': 50}}
53+
for key, value in model_kwargs.items():
54+
if key not in ['model_id', 'use_local', 'streaming']:
55+
optional_params['params'][key] = value
5956
return XFSparkTextToSpeech(
6057
spark_app_id=model_credential.get('spark_app_id'),
6158
spark_api_key=model_credential.get('spark_api_key'),
@@ -139,9 +136,10 @@ async def handle_message(ws):
139136
return audio_bytes
140137

141138
async def send(self, ws, text):
139+
business = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "tte": "utf8"}
142140
d = {
143141
"common": {"app_id": self.spark_app_id},
144-
"business": {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": self.vcn, "speed": self.speed, "tte": "utf8"},
142+
"business": business | self.params,
145143
"data": {"status": 2, "text": str(base64.b64encode(text.encode('utf-8')), "UTF8")},
146144
}
147145
d = json.dumps(d)

apps/setting/models_provider/impl/xinference_model_provider/model/tts.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,20 @@ def custom_get_token_ids(text: str):
1515
class XInferenceTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
1616
api_base: str
1717
api_key: str
18-
model: str
19-
voice: str
18+
params: dict
2019

2120
def __init__(self, **kwargs):
2221
super().__init__(**kwargs)
2322
self.api_key = kwargs.get('api_key')
2423
self.api_base = kwargs.get('api_base')
25-
self.model = kwargs.get('model')
26-
self.voice = kwargs.get('voice', '中文女')
24+
self.params = kwargs.get('params')
2725

2826
@staticmethod
2927
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
30-
optional_params = {'voice': '中文女'}
31-
if 'voice' in model_kwargs and model_kwargs['voice'] is not None:
32-
optional_params['voice'] = model_kwargs['voice']
28+
optional_params = {'params': {'voice': '中文女'}}
29+
for key, value in model_kwargs.items():
30+
if key not in ['model_id', 'use_local', 'streaming']:
31+
optional_params['params'][key] = value
3332
return XInferenceTextToSpeech(
3433
model=model_name,
3534
api_base=model_credential.get('api_base'),
@@ -54,8 +53,8 @@ def text_to_speech(self, text):
5453

5554
with client.audio.speech.with_streaming_response.create(
5655
model=self.model,
57-
voice=self.voice,
5856
input=text,
57+
**self.params
5958
) as response:
6059
return response.read()
6160

ui/src/views/application/ApplicationSetting.vue

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,7 @@
422422
v-model="applicationForm.tts_model_id"
423423
class="w-full"
424424
popper-class="select-model"
425+
@change="ttsModelChange()"
425426
placeholder="请选择语音合成模型"
426427
>
427428
<el-option-group
@@ -807,6 +808,14 @@ function getTTSModel() {
807808
})
808809
}
809810
811+
function ttsModelChange() {
812+
if (applicationForm.value.tts_model_id) {
813+
TTSModeParamSettingDialogRef.value?.reset_default(applicationForm.value.tts_model_id, id)
814+
} else {
815+
refreshTTSForm({})
816+
}
817+
}
818+
810819
function getProvider() {
811820
loading.value = true
812821
model

ui/src/views/application/component/TTSModeParamSettingDialog.vue

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,13 @@ import applicationApi from '@/api/application'
5050
import DynamicsForm from '@/components/dynamics-form/index.vue'
5151
import { keys } from 'lodash'
5252
import { app } from '@/main'
53+
import { MsgError } from '@/utils/message'
5354
5455
const {
5556
params: { id }
5657
} = app.config.globalProperties.$route as any
5758
59+
const tts_model_id = ref('')
5860
const model_form_field = ref<Array<FormField>>([])
5961
const emit = defineEmits(['refresh'])
6062
const dynamicsFormRef = ref<InstanceType<typeof DynamicsForm>>()
@@ -69,6 +71,7 @@ const getApi = (model_id: string, application_id?: string) => {
6971
}
7072
const open = (model_id: string, application_id?: string, model_setting_data?: any) => {
7173
form_data.value = {}
74+
tts_model_id.value = model_id
7275
const api = getApi(model_id, application_id)
7376
api.then((ok) => {
7477
model_form_field.value = ok.data
@@ -104,9 +107,18 @@ const submit = async () => {
104107
105108
const audioPlayer = ref<HTMLAudioElement | null>(null)
106109
const testPlay = () => {
110+
const data = {
111+
...form_data.value,
112+
tts_model_id: tts_model_id.value
113+
}
107114
applicationApi
108-
.playDemoText(id as string, form_data.value, playLoading)
109-
.then((res: any) => {
115+
.playDemoText(id as string, data, playLoading)
116+
.then(async (res: any) => {
117+
if (res.type === 'application/json') {
118+
const text = await res.text();
119+
MsgError(text)
120+
return
121+
}
110122
// 创建 Blob 对象
111123
const blob = new Blob([res], { type: 'audio/mp3' })
112124

ui/src/workflow/nodes/base-node/index.vue

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@
153153
v-model="form_data.tts_model_id"
154154
class="w-full"
155155
popper-class="select-model"
156+
@change="ttsModelChange()"
156157
placeholder="请选择语音合成模型"
157158
>
158159
<el-option-group
@@ -312,6 +313,15 @@ function getTTSModel() {
312313
})
313314
}
314315
316+
function ttsModelChange() {
317+
if (form_data.value.tts_model_id) {
318+
TTSModeParamSettingDialogRef.value?.reset_default(form_data.value.tts_model_id, id)
319+
} else {
320+
refreshTTSForm({})
321+
}
322+
}
323+
324+
315325
const openTTSParamSettingDialog = () => {
316326
const model_id = form_data.value.tts_model_id
317327
if (!model_id) {

0 commit comments

Comments
 (0)