
Commit 83ef4a9

feat(model/qwen-tts) interface change to multimodal_conversation (#58)
1 parent 556be36 · commit 83ef4a9

File tree

3 files changed: +110, -23 lines

dashscope/aigc/multimodal_conversation.py
dashscope/api_entities/dashscope_response.py
samples/test_qwen_tts.py

dashscope/aigc/multimodal_conversation.py

Lines changed: 41 additions & 20 deletions
@@ -24,9 +24,10 @@ class Models:
     def call(
             cls,
             model: str,
-            messages: List,
+            messages: List = None,
             api_key: str = None,
             workspace: str = None,
+            text: str = None,
             **kwargs
     ) -> Union[MultiModalConversationResponse, Generator[
             MultiModalConversationResponse, None, None]]:
@@ -55,6 +56,7 @@ def call(
                 if None, will retrieve by rule [1].
                 [1]: https://help.aliyun.com/zh/dashscope/developer-reference/api-key-settings.  # noqa E501
             workspace (str): The dashscope workspace id.
+            text (str): The text to generate.
             **kwargs:
                 stream(bool, `optional`): Enable server-sent events
                     (ref: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events)  # noqa E501
@@ -68,8 +70,11 @@ def call(
                     tokens with top_p probability mass. So 0.1 means only
                     the tokens comprising the top 10% probability mass are
                     considered[qwen-turbo,bailian-v1].
+                voice(string, `optional`): The voice name for qwen-tts, e.g. 'Cherry', 'Ethan', 'Sunny' or 'Dylan';
+                    the full voice list is available at https://help.aliyun.com/zh/model-studio/qwen-tts.
                 top_k(float, `optional`):

+
         Raises:
             InvalidInput: The history and auto_history are mutually exclusive.

@@ -78,18 +83,24 @@ def call(
             Generator[MultiModalConversationResponse, None, None]]: If
             stream is True, return Generator, otherwise MultiModalConversationResponse.
         """
-        if (messages is None or not messages):
-            raise InputRequired('prompt or messages is required!')
         if model is None or not model:
             raise ModelRequired('Model is required!')
         task_group, _ = _get_task_group_and_task(__name__)
-        msg_copy = copy.deepcopy(messages)
-        has_upload = cls._preprocess_messages(model, msg_copy, api_key)
-        if has_upload:
-            headers = kwargs.pop('headers', {})
-            headers['X-DashScope-OssResourceResolve'] = 'enable'
-            kwargs['headers'] = headers
-        input = {'messages': msg_copy}
+        input = {}
+        msg_copy = None
+
+        if messages is not None and messages:
+            msg_copy = copy.deepcopy(messages)
+            has_upload = cls._preprocess_messages(model, msg_copy, api_key)
+            if has_upload:
+                headers = kwargs.pop('headers', {})
+                headers['X-DashScope-OssResourceResolve'] = 'enable'
+                kwargs['headers'] = headers
+
+        if text is not None and text:
+            input.update({'text': text})
+        if msg_copy is not None:
+            input.update({'messages': msg_copy})
         response = super().call(model=model,
                                 task_group=task_group,
                                 task=MultiModalConversation.task,
@@ -145,9 +156,10 @@ class Models:
     async def call(
             cls,
             model: str,
-            messages: List,
+            messages: List = None,
            api_key: str = None,
            workspace: str = None,
+            text: str = None,
             **kwargs
     ) -> Union[MultiModalConversationResponse, Generator[
             MultiModalConversationResponse, None, None]]:
@@ -176,6 +188,7 @@ async def call(
                 if None, will retrieve by rule [1].
                 [1]: https://help.aliyun.com/zh/dashscope/developer-reference/api-key-settings.  # noqa E501
             workspace (str): The dashscope workspace id.
+            text (str): The text to generate.
             **kwargs:
                 stream(bool, `optional`): Enable server-sent events
                     (ref: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events)  # noqa E501
@@ -189,6 +202,8 @@ async def call(
                     tokens with top_p probability mass. So 0.1 means only
                     the tokens comprising the top 10% probability mass are
                     considered[qwen-turbo,bailian-v1].
+                voice(string, `optional`): The voice name for qwen-tts, e.g. 'Cherry', 'Ethan', 'Sunny' or 'Dylan';
+                    the full voice list is available at https://help.aliyun.com/zh/model-studio/qwen-tts.
                 top_k(float, `optional`):

         Raises:
@@ -199,18 +214,24 @@ async def call(
             Generator[MultiModalConversationResponse, None, None]]: If
             stream is True, return Generator, otherwise MultiModalConversationResponse.
         """
-        if (messages is None or not messages):
-            raise InputRequired('prompt or messages is required!')
         if model is None or not model:
             raise ModelRequired('Model is required!')
         task_group, _ = _get_task_group_and_task(__name__)
-        msg_copy = copy.deepcopy(messages)
-        has_upload = cls._preprocess_messages(model, msg_copy, api_key)
-        if has_upload:
-            headers = kwargs.pop('headers', {})
-            headers['X-DashScope-OssResourceResolve'] = 'enable'
-            kwargs['headers'] = headers
-        input = {'messages': msg_copy}
+        input = {}
+        msg_copy = None
+
+        if messages is not None and messages:
+            msg_copy = copy.deepcopy(messages)
+            has_upload = cls._preprocess_messages(model, msg_copy, api_key)
+            if has_upload:
+                headers = kwargs.pop('headers', {})
+                headers['X-DashScope-OssResourceResolve'] = 'enable'
+                kwargs['headers'] = headers
+
+        if text is not None and text:
+            input.update({'text': text})
+        if msg_copy is not None:
+            input.update({'messages': msg_copy})
         response = await super().call(model=model,
                                       task_group=task_group,
                                       task=AioMultiModalConversation.task,
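
With this change, MultiModalConversation.call accepts either messages (the existing multimodal chat path) or a plain text string (the new qwen-tts path), and whichever is supplied ends up under the matching key of the request input dict. The old InputRequired guard is gone, so a call with neither argument now reaches the server with an empty input. A minimal sketch of the two call styles (the qwen-vl-plus model name and the sample message payload are illustrative, not part of this commit):

import os

from dashscope import MultiModalConversation

# New qwen-tts path: plain text plus a voice name, no messages needed.
tts_response = MultiModalConversation.call(
    api_key=os.getenv('DASHSCOPE_API_KEY'),
    model='qwen-tts',
    text='Hello from the new text parameter!',
    voice='Cherry')

# Existing chat path: messages only, exactly as before this commit.
chat_response = MultiModalConversation.call(
    api_key=os.getenv('DASHSCOPE_API_KEY'),
    model='qwen-vl-plus',  # illustrative multimodal chat model
    messages=[{'role': 'user',
               'content': [{'text': 'Describe this picture.'}]}])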

dashscope/api_entities/dashscope_response.py

Lines changed: 31 additions & 3 deletions
@@ -152,6 +152,26 @@ def __init__(self,
                          **kwargs)


+@dataclass(init=False)
+class Audio(DictMixin):
+    data: str
+    url: str
+    id: str
+    expires_at: int
+
+    def __init__(self,
+                 data: str = None,
+                 url: str = None,
+                 id: str = None,
+                 expires_at: int = None,
+                 **kwargs):
+        super().__init__(data=data,
+                         url=url,
+                         id=id,
+                         expires_at=expires_at,
+                         **kwargs)
+
+
 @dataclass(init=False)
 class GenerationOutput(DictMixin):
     text: str
@@ -217,36 +237,44 @@ def from_api_response(api_response: DashScopeAPIResponse):
 @dataclass(init=False)
 class MultiModalConversationOutput(DictMixin):
     choices: List[Choice]
+    audio: Audio

     def __init__(self,
                  text: str = None,
                  finish_reason: str = None,
                  choices: List[Choice] = None,
+                 audio: Audio = None,
                  **kwargs):
         chs = None
         if choices is not None:
             chs = []
             for choice in choices:
                 chs.append(Choice(**choice))
+        if audio is not None:
+            audio = Audio(**audio)
         super().__init__(text=text,
                          finish_reason=finish_reason,
                          choices=chs,
+                         audio=audio,
                          **kwargs)


 @dataclass(init=False)
 class MultiModalConversationUsage(DictMixin):
     input_tokens: int
     output_tokens: int
+    characters: int

     # TODO add image usage info.

     def __init__(self,
                  input_tokens: int = 0,
                  output_tokens: int = 0,
+                 characters: int = 0,
                  **kwargs):
         super().__init__(input_tokens=input_tokens,
                          output_tokens=output_tokens,
+                         characters=characters,
                          **kwargs)


@@ -378,7 +406,7 @@ def is_sentence_end(sentence: Dict[str, Any]) -> bool:
     """
     result = False
     if sentence is not None and 'end_time' in sentence and sentence[
-        'end_time'] is not None:
+            'end_time'] is not None:
         result = True
     return result

@@ -445,8 +473,8 @@ class ImageSynthesisOutput(DictMixin):
     results: List[ImageSynthesisResult]

     def __init__(self,
-                task_id: str = None,
-                task_status: str = None,
+                 task_id: str = None,
+                 task_status: str = None,
                  results: List[ImageSynthesisResult] = [],
                  **kwargs):
         res = []
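
The new Audio entity mirrors the qwen-tts response payload: streamed chunks carry base64-encoded data, while non-stream responses carry a downloadable url plus an expires_at timestamp (the new characters field on MultiModalConversationUsage is presumably the billed character count for synthesis). A small helper for collecting streamed audio, as a sketch only: it assumes each chunk's output.audio.data is an incremental base64 segment, which the qwen-tts docs describe but this diff itself does not establish.

import base64


def collect_stream_audio(responses) -> bytes:
    # Concatenate the streamed chunks of a qwen-tts call into raw bytes.
    # Assumption: each chunk's output.audio.data holds an incremental
    # base64-encoded audio segment; chunks without data are skipped.
    buffer = bytearray()
    for chunk in responses:
        audio = chunk.output.audio
        if audio is not None and audio.data:
            buffer.extend(base64.b64decode(audio.data))
    return bytes(buffer)

Called on the generator returned by a stream=True request, this yields the full clip once the stream is exhausted.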

samples/test_qwen_tts.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+import logging
+import os
+
+import dashscope
+
+logger = logging.getLogger('dashscope')
+logger.setLevel(logging.DEBUG)
+console_handler = logging.StreamHandler()
+# create formatter
+formatter = logging.Formatter(
+    '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+# add the formatter to the console handler
+console_handler.setFormatter(formatter)
+
+# add the console handler to the logger
+logger.addHandler(console_handler)
+
+# switch between stream and non-stream mode
+use_stream = True
+
+response = dashscope.MultiModalConversation.call(
+    api_key=os.getenv('DASHSCOPE_API_KEY'),
+    model="qwen-tts",
+    text="Today is a wonderful day to build something people love!",
+    voice="Cherry",
+    stream=use_stream
+)
+if use_stream:
+    # print the audio data in stream mode
+    for chunk in response:
+        audio = chunk.output.audio
+        print("base64 audio data is: {}".format(audio.data))
+        if chunk.output.finish_reason == "stop":
+            print("finish at: {}".format(audio.expires_at))
+else:
+    # print the audio url in non-stream mode
+    print("synthesized audio url is: {}".format(response.output.audio.url))
+    print("finish at: {}".format(response.output.audio.expires_at))
