
Commit f7c5f15

Merge branch 'develop'
2 parents d294731 + aafb74d

File tree

8 files changed: +2254 −213 lines


dashscope/aigc/generation.py

Lines changed: 75 additions & 10 deletions
@@ -2,7 +2,7 @@
 
 import copy
 import json
-from typing import Any, Dict, Generator, List, Union
+from typing import Any, Dict, Generator, List, Union, AsyncGenerator
 
 from dashscope.api_entities.dashscope_response import (GenerationResponse,
                                                        Message, Role)
@@ -13,6 +13,8 @@
 from dashscope.common.error import InputRequired, ModelRequired
 from dashscope.common.logging import logger
 from dashscope.common.utils import _get_task_group_and_task
+from dashscope.utils.param_utils import ParamUtil
+from dashscope.utils.message_utils import merge_single_response
 
 
 class Generation(BaseApi):
@@ -137,6 +139,16 @@ def call(
         kwargs['headers'] = headers
         input, parameters = cls._build_input_parameters(
             model, prompt, history, messages, **kwargs)
+
+        is_stream = parameters.get('stream', False)
+        # Check if we need to merge incremental output
+        is_incremental_output = kwargs.get('incremental_output', None)
+        to_merge_incremental_output = False
+        if (ParamUtil.should_modify_incremental_output(model) and
+                is_stream and is_incremental_output is False):
+            to_merge_incremental_output = True
+            parameters['incremental_output'] = True
+
         response = super().call(model=model,
                                 task_group=task_group,
                                 task=Generation.task,
@@ -145,10 +157,14 @@ def call(
                                 input=input,
                                 workspace=workspace,
                                 **parameters)
-        is_stream = kwargs.get('stream', False)
         if is_stream:
-            return (GenerationResponse.from_api_response(rsp)
-                    for rsp in response)
+            if to_merge_incremental_output:
+                # Extract n parameter for merge logic
+                n = parameters.get('n', 1)
+                return cls._merge_generation_response(response, n)
+            else:
+                return (GenerationResponse.from_api_response(rsp)
+                        for rsp in response)
         else:
             return GenerationResponse.from_api_response(response)
 
@@ -191,6 +207,20 @@ def _build_input_parameters(cls, model, prompt, history, messages,
 
         return input, {**parameters, **kwargs}
 
+    @classmethod
+    def _merge_generation_response(cls, response, n=1) -> Generator[GenerationResponse, None, None]:
+        """Merge incremental response chunks to simulate non-incremental output."""
+        accumulated_data = {}
+        for rsp in response:
+            parsed_response = GenerationResponse.from_api_response(rsp)
+            result = merge_single_response(parsed_response, accumulated_data, n)
+            if result is True:
+                yield parsed_response
+            elif isinstance(result, list):
+                # Multiple responses to yield (for n>1 non-stop cases)
+                for resp in result:
+                    yield resp
+
 
 class AioGeneration(BaseAioApi):
     task = 'text-generation'
@@ -220,7 +250,7 @@ async def call(
             plugins: Union[str, Dict[str, Any]] = None,
             workspace: str = None,
             **kwargs
-    ) -> Union[GenerationResponse, Generator[GenerationResponse, None, None]]:
+    ) -> Union[GenerationResponse, AsyncGenerator[GenerationResponse, None]]:
         """Call generation model service.
 
         Args:
@@ -296,8 +326,8 @@ async def call(
 
         Returns:
             Union[GenerationResponse,
-                  Generator[GenerationResponse, None, None]]: If
-            stream is True, return Generator, otherwise GenerationResponse.
+                  AsyncGenerator[GenerationResponse, None]]: If
+            stream is True, return AsyncGenerator, otherwise GenerationResponse.
         """
         if (prompt is None or not prompt) and (messages is None
                                                or not messages):
@@ -314,6 +344,16 @@ async def call(
         kwargs['headers'] = headers
         input, parameters = Generation._build_input_parameters(
             model, prompt, history, messages, **kwargs)
+
+        is_stream = parameters.get('stream', False)
+        # Check if we need to merge incremental output
+        is_incremental_output = kwargs.get('incremental_output', None)
+        to_merge_incremental_output = False
+        if (ParamUtil.should_modify_incremental_output(model) and
+                is_stream and is_incremental_output is False):
+            to_merge_incremental_output = True
+            parameters['incremental_output'] = True
+
         response = await super().call(model=model,
                                       task_group=task_group,
                                       task=Generation.task,
@@ -322,9 +362,34 @@ async def call(
                                       input=input,
                                       workspace=workspace,
                                       **parameters)
-        is_stream = kwargs.get('stream', False)
         if is_stream:
-            return (GenerationResponse.from_api_response(rsp)
-                    async for rsp in response)
+            if to_merge_incremental_output:
+                # Extract n parameter for merge logic
+                n = parameters.get('n', 1)
+                return cls._merge_generation_response(response, n)
+            else:
+                return cls._stream_responses(response)
         else:
             return GenerationResponse.from_api_response(response)
+
+    @classmethod
+    async def _stream_responses(cls, response) -> AsyncGenerator[GenerationResponse, None]:
+        """Convert async response stream to GenerationResponse stream."""
+        # Type hint: when stream=True, response is actually an AsyncIterable
+        async for rsp in response:  # type: ignore
+            yield GenerationResponse.from_api_response(rsp)
+
+    @classmethod
+    async def _merge_generation_response(cls, response, n=1) -> AsyncGenerator[GenerationResponse, None]:
+        """Async version of merge incremental response chunks."""
+        accumulated_data = {}
+
+        async for rsp in response:  # type: ignore
+            parsed_response = GenerationResponse.from_api_response(rsp)
+            result = merge_single_response(parsed_response, accumulated_data, n)
+            if result is True:
+                yield parsed_response
+            elif isinstance(result, list):
+                # Multiple responses to yield (for n>1 non-stop cases)
+                for resp in result:
+                    yield resp
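
Taken together, these generation.py changes mean that for models flagged by ParamUtil.should_modify_incremental_output, a streaming call with incremental_output=False is transparently rewritten to incremental_output=True and the deltas are merged back client-side. A minimal caller-side sketch of the behavior this preserves; the model name is a placeholder, not something this commit pins down:

# Hypothetical usage sketch; model name and result_format choice are assumptions.
from dashscope import Generation

responses = Generation.call(
    model='qwen-example',              # placeholder model name
    messages=[{'role': 'user', 'content': 'Hello'}],
    result_format='message',
    stream=True,
    incremental_output=False,          # caller wants cumulative chunks
)
# For affected models the SDK now streams incremental deltas internally and
# merges them, so each yielded GenerationResponse still carries the full
# accumulated text, as incremental_output=False callers expect.
for rsp in responses:
    print(rsp.output.choices[0].message.content)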

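The merge helpers themselves (merge_single_response and, below, merge_multimodal_single_response) live in dashscope/utils/message_utils.py, which this excerpt does not show. Inferring only from the call sites above, the contract appears to be: fold the parsed chunk into accumulated_data, return True when the now-cumulative chunk should be yielded, and return a list when several merged responses must be emitted at once (the n > 1 path). A toy illustration of that contract, with simplified field names that are assumptions rather than the real response schema:

# Toy illustration only; the real helper and response schema differ.
def toy_merge_single_response(chunk: dict, accumulated: dict, n: int = 1):
    """Mimic the call-site contract: True means "yield this (now cumulative)
    chunk"; a list means "yield these merged responses" (the n > 1 path)."""
    idx = chunk.get('index', 0)
    accumulated[idx] = accumulated.get(idx, '') + chunk.get('delta', '')
    chunk['text'] = accumulated[idx]   # cumulative text, as incremental_output=False implies
    if chunk.get('finish_reason') == 'stop' and n > 1:
        # All choices done: hand back every accumulated response at once.
        return [{'index': i, 'text': t} for i, t in sorted(accumulated.items())]
    return True

acc = {}
for delta in ({'delta': 'Hel'}, {'delta': 'lo'}):
    if toy_merge_single_response(delta, acc) is True:
        print(delta['text'])           # prints "Hel", then "Hello"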
dashscope/aigc/multimodal_conversation.py

Lines changed: 79 additions & 11 deletions
@@ -1,14 +1,16 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
 import copy
-from typing import Generator, List, Union
+from typing import AsyncGenerator, Generator, List, Union
 
 from dashscope.api_entities.dashscope_response import \
     MultiModalConversationResponse
 from dashscope.client.base_api import BaseAioApi, BaseApi
 from dashscope.common.error import InputRequired, ModelRequired
 from dashscope.common.utils import _get_task_group_and_task
 from dashscope.utils.oss_utils import preprocess_message_element
+from dashscope.utils.param_utils import ParamUtil
+from dashscope.utils.message_utils import merge_multimodal_single_response
 
 
 class MultiModalConversation(BaseApi):
@@ -108,6 +110,16 @@ def call(
             input.update({'language_type': language_type})
         if msg_copy is not None:
             input.update({'messages': msg_copy})
+
+        # Check if we need to merge incremental output
+        is_incremental_output = kwargs.get('incremental_output', None)
+        to_merge_incremental_output = False
+        is_stream = kwargs.get('stream', False)
+        if (ParamUtil.should_modify_incremental_output(model) and
+                is_stream and is_incremental_output is not None and is_incremental_output is False):
+            to_merge_incremental_output = True
+            kwargs['incremental_output'] = True
+
         response = super().call(model=model,
                                 task_group=task_group,
                                 task=MultiModalConversation.task,
@@ -116,10 +128,14 @@
                                 input=input,
                                 workspace=workspace,
                                 **kwargs)
-        is_stream = kwargs.get('stream', False)
         if is_stream:
-            return (MultiModalConversationResponse.from_api_response(rsp)
-                    for rsp in response)
+            if to_merge_incremental_output:
+                # Extract n parameter for merge logic
+                n = kwargs.get('n', 1)
+                return cls._merge_multimodal_response(response, n)
+            else:
+                return (MultiModalConversationResponse.from_api_response(rsp)
+                        for rsp in response)
         else:
             return MultiModalConversationResponse.from_api_response(response)
 
@@ -149,6 +165,21 @@ def _preprocess_messages(cls, model: str, messages: List[dict],
                     has_upload = True
         return has_upload
 
+    @classmethod
+    def _merge_multimodal_response(cls, response, n=1) -> Generator[MultiModalConversationResponse, None, None]:
+        """Merge incremental response chunks to simulate non-incremental output."""
+        accumulated_data = {}
+
+        for rsp in response:
+            parsed_response = MultiModalConversationResponse.from_api_response(rsp)
+            result = merge_multimodal_single_response(parsed_response, accumulated_data, n)
+            if result is True:
+                yield parsed_response
+            elif isinstance(result, list):
+                # Multiple responses to yield (for n>1 non-stop cases)
+                for resp in result:
+                    yield resp
+
 
 class AioMultiModalConversation(BaseAioApi):
     """Async MultiModal conversational robot interface.
@@ -170,8 +201,8 @@ async def call(
             voice: str = None,
             language_type: str = None,
             **kwargs
-    ) -> Union[MultiModalConversationResponse, Generator[
-            MultiModalConversationResponse, None, None]]:
+    ) -> Union[MultiModalConversationResponse, AsyncGenerator[
+            MultiModalConversationResponse, None]]:
         """Call the conversation model service asynchronously.
 
         Args:
@@ -221,8 +252,8 @@ async def call(
 
         Returns:
             Union[MultiModalConversationResponse,
-                  Generator[MultiModalConversationResponse, None, None]]: If
-            stream is True, return Generator, otherwise MultiModalConversationResponse.
+                  AsyncGenerator[MultiModalConversationResponse, None]]: If
+            stream is True, return AsyncGenerator, otherwise MultiModalConversationResponse.
         """
         if model is None or not model:
             raise ModelRequired('Model is required!')
@@ -246,6 +277,16 @@ async def call(
             input.update({'language_type': language_type})
         if msg_copy is not None:
             input.update({'messages': msg_copy})
+
+        # Check if we need to merge incremental output
+        is_incremental_output = kwargs.get('incremental_output', None)
+        to_merge_incremental_output = False
+        is_stream = kwargs.get('stream', False)
+        if (ParamUtil.should_modify_incremental_output(model) and
+                is_stream and is_incremental_output is not None and is_incremental_output is False):
+            to_merge_incremental_output = True
+            kwargs['incremental_output'] = True
+
         response = await super().call(model=model,
                                       task_group=task_group,
                                       task=AioMultiModalConversation.task,
@@ -254,10 +295,13 @@
                                       input=input,
                                       workspace=workspace,
                                       **kwargs)
-        is_stream = kwargs.get('stream', False)
         if is_stream:
-            return (MultiModalConversationResponse.from_api_response(rsp)
-                    async for rsp in response)
+            if to_merge_incremental_output:
+                # Extract n parameter for merge logic
+                n = kwargs.get('n', 1)
+                return cls._merge_multimodal_response(response, n)
+            else:
+                return cls._stream_responses(response)
         else:
             return MultiModalConversationResponse.from_api_response(response)
 
@@ -286,3 +330,27 @@ def _preprocess_messages(cls, model: str, messages: List[dict],
             if is_upload and not has_upload:
                 has_upload = True
         return has_upload
+
+    @classmethod
+    async def _stream_responses(cls, response) -> AsyncGenerator[MultiModalConversationResponse, None]:
+        """Convert async response stream to MultiModalConversationResponse stream."""
+        # Type hint: when stream=True, response is actually an AsyncIterable
+        async for rsp in response:  # type: ignore
+            yield MultiModalConversationResponse.from_api_response(rsp)
+
+    @classmethod
+    async def _merge_multimodal_response(cls, response, n=1) -> AsyncGenerator[MultiModalConversationResponse, None]:
+        """Async version of merge incremental response chunks."""
+        accumulated_data = {}
+
+        async for rsp in response:
+            parsed_response = MultiModalConversationResponse.from_api_response(rsp)
+            result = merge_multimodal_single_response(parsed_response, accumulated_data, n)
+            if result is True:
+                yield parsed_response
+            elif isinstance(result, list):
+                # Multiple responses to yield (for n>1 non-stop cases)
+                for resp in result:
+                    yield resp
+
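
One consequence of the annotation changes in both files: the async streaming return value is now correctly typed as AsyncGenerator (the old Generator annotation was inaccurate, since the bodies already produced async generators and were always consumed with async for). A minimal consumption sketch for the async multimodal path; the model name and message content are placeholders, not part of this commit:

# Hypothetical usage sketch; model name and prompt are placeholders.
import asyncio

from dashscope.aigc.multimodal_conversation import AioMultiModalConversation

async def main():
    responses = await AioMultiModalConversation.call(
        model='qwen-vl-example',       # placeholder model name
        messages=[{
            'role': 'user',
            'content': [{'text': 'Describe this picture.'}],
        }],
        stream=True,
        incremental_output=False,      # triggers the client-side merge on affected models
    )
    # `responses` is an AsyncGenerator per the updated type hints above.
    async for rsp in responses:
        print(rsp.output.choices[0].message.content)

asyncio.run(main())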
