Skip to content

Commit 5b3bba7

Browse files
authored
Merge pull request #45 from mose-x/t2i_support_sync_call
T2i support sync call...
2 parents 187611e + 4c7579d commit 5b3bba7

File tree

9 files changed

+1145
-108
lines changed

9 files changed

+1145
-108
lines changed

dashscope/aigc/image_synthesis.py

Lines changed: 353 additions & 48 deletions
Large diffs are not rendered by default.

dashscope/aigc/video_synthesis.py

Lines changed: 281 additions & 32 deletions
Large diffs are not rendered by default.

dashscope/api_entities/api_request_factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,4 +127,4 @@ def _build_api_request(model: str,
127127
request_data.add_resources(resources)
128128
request_data.add_parameters(**kwargs)
129129
request.data = request_data
130-
return request
130+
return request

dashscope/api_entities/dashscope_response.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -445,8 +445,8 @@ class ImageSynthesisOutput(DictMixin):
445445
results: List[ImageSynthesisResult]
446446

447447
def __init__(self,
448-
task_id: str,
449-
task_status: str,
448+
task_id: str = None,
449+
task_status: str = None,
450450
results: List[ImageSynthesisResult] = [],
451451
**kwargs):
452452
res = []

dashscope/api_entities/http_request.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
22

33
import json
4+
import ssl
45
from http import HTTPStatus
56
from typing import Optional
67

78
import aiohttp
9+
import certifi
810
import requests
911

1012
from dashscope.api_entities.base_request import AioBaseRequest
@@ -119,12 +121,18 @@ async def aio_call(self):
119121

120122
async def _handle_aio_request(self):
121123
try:
124+
connector = aiohttp.TCPConnector(
125+
ssl=ssl.create_default_context(
126+
cafile=certifi.where()))
122127
async with aiohttp.ClientSession(
128+
connector=connector,
123129
timeout=aiohttp.ClientTimeout(total=self.timeout),
124130
headers=self.headers) as session:
125131
logger.debug('Starting request: %s' % self.url)
126132
if self.method == HTTPMethod.POST:
127-
is_form, obj = self.data.get_aiohttp_payload()
133+
is_form, obj = False, {}
134+
if hasattr(self, 'data') and self.data is not None:
135+
is_form, obj = self.data.get_aiohttp_payload()
128136
if is_form:
129137
headers = {**self.headers, **obj.headers}
130138
response = await session.post(url=self.url,
@@ -136,8 +144,12 @@ async def _handle_aio_request(self):
136144
json=obj,
137145
headers=self.headers)
138146
elif self.method == HTTPMethod.GET:
147+
# 添加条件判断
148+
params = {}
149+
if hasattr(self, 'data') and self.data is not None:
150+
params = getattr(self.data, 'parameters', {})
139151
response = await session.get(url=self.url,
140-
params=self.data.parameters,
152+
params=params,
141153
headers=self.headers)
142154
else:
143155
raise UnsupportedHTTPMethod('Unsupported http method: %s' %
@@ -211,6 +223,12 @@ async def _handle_aio_response(self, response: aiohttp.ClientResponse):
211223
usage = None
212224
if 'output' in json_content and json_content['output'] is not None:
213225
output = json_content['output']
226+
# Compatible with wan
227+
elif 'data' in json_content and json_content['data'] is not None\
228+
and isinstance(json_content['data'], list)\
229+
and len(json_content['data']) > 0\
230+
and 'task_id' in json_content['data'][0]:
231+
output = json_content
214232
if 'usage' in json_content:
215233
usage = json_content['usage']
216234
if 'request_id' in json_content:

dashscope/client/base_api.py

Lines changed: 295 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
2-
2+
import asyncio
3+
import collections
34
import time
45
from http import HTTPStatus
56
from typing import Any, Dict, Iterator, List, Union
@@ -13,16 +14,307 @@
1314
from dashscope.common.constants import (DEFAULT_REQUEST_TIMEOUT_SECONDS,
1415
REPEATABLE_STATUS,
1516
REQUEST_TIMEOUT_KEYWORD,
16-
SSE_CONTENT_TYPE, TaskStatus)
17+
SSE_CONTENT_TYPE, TaskStatus, HTTPMethod)
1718
from dashscope.common.error import InvalidParameter, InvalidTask, ModelRequired
1819
from dashscope.common.logging import logger
1920
from dashscope.common.utils import (_handle_http_failed_response,
2021
_handle_http_response,
2122
_handle_http_stream_response,
2223
default_headers, join_url)
2324

25+
class AsyncAioTaskGetMixin:
26+
@classmethod
27+
async def _get(cls,
28+
task_id: str,
29+
api_key: str = None,
30+
workspace: str = None,
31+
**kwargs) -> DashScopeAPIResponse:
32+
base_url = kwargs.pop('base_address', None)
33+
url = _normalization_url(base_url, 'tasks', task_id)
34+
kwargs = cls._handle_kwargs(api_key, workspace, **kwargs)
35+
kwargs["base_address"] = url
36+
if not api_key:
37+
api_key = get_default_api_key()
38+
request = _build_api_request("", "", "",
39+
"", "", api_key=api_key,
40+
is_service=False, **kwargs)
41+
return await cls._handle_request(request)
42+
43+
@classmethod
44+
def _handle_kwargs(cls, api_key: str = None ,workspace: str = None, **kwargs):
45+
custom_headers = kwargs.pop('headers', None)
46+
headers = {
47+
**_workspace_header(workspace),
48+
**default_headers(api_key),
49+
}
50+
if custom_headers:
51+
headers = {
52+
**custom_headers,
53+
**headers,
54+
}
55+
if workspace is not None:
56+
headers = {
57+
'X-DashScope-WorkSpace': workspace,
58+
**kwargs.pop('headers', {})
59+
}
60+
kwargs['headers'] = headers
61+
kwargs['http_method'] = HTTPMethod.GET
62+
return kwargs
63+
64+
@classmethod
65+
async def _handle_request(cls, request):
66+
# 如果 aio_call 返回的是异步生成器,则需要从中获取响应
67+
response = await request.aio_call()
68+
# 处理异步生成器的情况
69+
if isinstance(response, collections.abc.AsyncGenerator):
70+
result = None
71+
async for item in response:
72+
result = item
73+
return result
74+
else:
75+
return response
76+
77+
class BaseAsyncAioApi(AsyncAioTaskGetMixin):
78+
"""BaseApi, internal use only.
79+
80+
"""
81+
@classmethod
82+
def _validate_params(cls, api_key, model):
83+
if api_key is None:
84+
api_key = get_default_api_key()
85+
if model is None or not model:
86+
raise ModelRequired('Model is required!')
87+
return api_key, model
88+
89+
@classmethod
90+
async def async_call(cls,
91+
model: str,
92+
input: object,
93+
task_group: str,
94+
task: str = None,
95+
function: str = None,
96+
api_key: str = None,
97+
workspace: str = None,
98+
**kwargs) -> DashScopeAPIResponse:
99+
api_key, model = cls._validate_params(api_key, model)
100+
if workspace is not None:
101+
headers = {
102+
'X-DashScope-WorkSpace': workspace,
103+
**kwargs.pop('headers', {})
104+
}
105+
kwargs['headers'] = headers
106+
kwargs['async_request'] = True
107+
request = _build_api_request(model=model,
108+
input=input,
109+
task_group=task_group,
110+
task=task,
111+
function=function,
112+
api_key=api_key,
113+
**kwargs)
114+
# call request service.
115+
return await request.aio_call()
116+
117+
@classmethod
118+
async def call(cls,
119+
model: str,
120+
input: object,
121+
task_group: str,
122+
task: str = None,
123+
function: str = None,
124+
api_key: str = None,
125+
workspace: str = None,
126+
**kwargs) -> DashScopeAPIResponse:
127+
# call request service.
128+
response = await BaseAsyncAioApi.async_call(model, input, task_group, task,
129+
function, api_key, workspace,
130+
**kwargs)
131+
response = await BaseAsyncAioApi.wait(response,
132+
api_key=api_key,
133+
workspace=workspace,
134+
**kwargs)
135+
return response
136+
137+
138+
@classmethod
139+
def _get_task_id(cls, task):
140+
if isinstance(task, str):
141+
task_id = task
142+
elif isinstance(task, DashScopeAPIResponse):
143+
if task.status_code == HTTPStatus.OK:
144+
task_id = task.output['task_id']
145+
else:
146+
raise InvalidTask('Invalid task, task create failed: %s' %
147+
task)
148+
else:
149+
raise InvalidParameter('Task invalid!')
150+
if task_id is None or task_id == '':
151+
raise InvalidParameter('Task id required!')
152+
return task_id
153+
154+
@classmethod
155+
async def wait(cls,
156+
task: Union[str, DashScopeAPIResponse],
157+
api_key: str = None,
158+
workspace: str = None,
159+
**kwargs) -> DashScopeAPIResponse:
160+
"""Wait for async task completion and return task result.
161+
162+
Args:
163+
task (Union[str, DashScopeAPIResponse]): The task_id, or
164+
async_call response.
165+
api_key (str, optional): The api_key. Defaults to None.
166+
167+
Returns:
168+
DashScopeAPIResponse: The async task information.
169+
"""
170+
task_id = cls._get_task_id(task)
171+
wait_seconds = 1
172+
max_wait_seconds = 5
173+
increment_steps = 3
174+
step = 0
175+
while True:
176+
step += 1
177+
# we start by querying once every second, and double
178+
# the query interval after every 3(increment_steps)
179+
# intervals, until we hit the max waiting interval
180+
# of 5(seconds)
181+
# (server side return immediately when ready)
182+
if wait_seconds < max_wait_seconds and step % increment_steps == 0:
183+
wait_seconds = min(wait_seconds * 2, max_wait_seconds)
184+
rsp = await cls._get(task_id, api_key, workspace=workspace, **kwargs)
185+
if rsp.status_code == HTTPStatus.OK:
186+
if rsp.output is None:
187+
return rsp
188+
189+
task_status = rsp.output['task_status']
190+
if task_status in [
191+
TaskStatus.FAILED, TaskStatus.CANCELED,
192+
TaskStatus.SUCCEEDED, TaskStatus.UNKNOWN
193+
]:
194+
return rsp
195+
else:
196+
logger.info('The task %s is %s' % (task_id, task_status))
197+
await asyncio.sleep(wait_seconds) # 异步等待
198+
elif rsp.status_code in REPEATABLE_STATUS:
199+
logger.warn(
200+
('Get task: %s temporary failure, \
201+
status_code: %s, code: %s message: %s, will try again.'
202+
) % (task_id, rsp.status_code, rsp.code, rsp.message))
203+
await asyncio.sleep(wait_seconds) # 异步等待
204+
else:
205+
return rsp
206+
207+
@classmethod
208+
async def cancel(
209+
cls,
210+
task: Union[str, DashScopeAPIResponse],
211+
api_key: str = None,
212+
workspace: str = None,
213+
**kwargs,
214+
) -> DashScopeAPIResponse:
215+
"""Cancel PENDING task.
216+
217+
Args:
218+
task (Union[str, DashScopeAPIResponse]): The task_id, or
219+
async_call response.
220+
api_key (str, optional): The api-key. Defaults to None.
221+
222+
Returns:
223+
DashScopeAPIResponse: The cancel result.
224+
"""
225+
task_id = cls._get_task_id(task)
226+
base_url = kwargs.pop('base_address', None)
227+
url = _normalization_url(base_url, 'tasks', task_id, 'cancel')
228+
kwargs = cls._handle_kwargs(api_key, workspace, **kwargs)
229+
kwargs["base_address"] = url
230+
if not api_key:
231+
api_key = get_default_api_key()
232+
request = _build_api_request("", "", "",
233+
"", "",api_key=api_key,
234+
is_service=False, **kwargs)
235+
return await cls._handle_request(request)
236+
237+
@classmethod
238+
async def list(cls,
239+
start_time: str = None,
240+
end_time: str = None,
241+
model_name: str = None,
242+
api_key_id: str = None,
243+
region: str = None,
244+
status: str = None,
245+
page_no: int = 1,
246+
page_size: int = 10,
247+
api_key: str = None,
248+
workspace: str = None,
249+
**kwargs) -> DashScopeAPIResponse:
250+
"""List async tasks.
251+
252+
Args:
253+
start_time (str, optional): The tasks start time,
254+
for example: 20230420000000. Defaults to None.
255+
end_time (str, optional): The tasks end time,
256+
for example: 20230420000000. Defaults to None.
257+
model_name (str, optional): The tasks model name.
258+
Defaults to None.
259+
api_key_id (str, optional): The tasks api-key-id.
260+
Defaults to None.
261+
region (str, optional): The service region,
262+
for example: cn-beijing. Defaults to None.
263+
status (str, optional): The status of tasks[PENDING,
264+
RUNNING, SUCCEEDED, FAILED, CANCELED]. Defaults to None.
265+
page_no (int, optional): The page number. Defaults to 1.
266+
page_size (int, optional): The page size. Defaults to 10.
267+
api_key (str, optional): The user api-key. Defaults to None.
268+
269+
Returns:
270+
DashScopeAPIResponse: The response data.
271+
"""
272+
base_url = kwargs.pop('base_address', None)
273+
url = _normalization_url(base_url, 'tasks')
274+
params = {'page_no': page_no, 'page_size': page_size}
275+
if start_time is not None:
276+
params['start_time'] = start_time
277+
if end_time is not None:
278+
params['end_time'] = end_time
279+
if model_name is not None:
280+
params['model_name'] = model_name
281+
if api_key_id is not None:
282+
params['api_key_id'] = api_key_id
283+
if region is not None:
284+
params['region'] = region
285+
if status is not None:
286+
params['status'] = status
287+
kwargs = cls._handle_kwargs(api_key, workspace, **kwargs)
288+
kwargs["base_address"] = url
289+
if not api_key:
290+
api_key = get_default_api_key()
291+
request = _build_api_request(model_name, "", "",
292+
"", "", api_key=api_key,
293+
is_service=False, extra_url_parameters=params,
294+
**kwargs)
295+
return await cls._handle_request(request)
296+
297+
@classmethod
298+
async def fetch(cls,
299+
task: Union[str, DashScopeAPIResponse],
300+
api_key: str = None,
301+
workspace: str = None,
302+
**kwargs) -> DashScopeAPIResponse:
303+
"""Query async task status.
304+
305+
Args:
306+
task (Union[str, DashScopeAPIResponse]): The task_id, or
307+
async_call response.
308+
api_key (str, optional): The api_key. Defaults to None.
309+
310+
Returns:
311+
DashScopeAPIResponse: The async task information.
312+
"""
313+
task_id = cls._get_task_id(task)
314+
return await cls._get(task_id, api_key, workspace, **kwargs)
315+
24316

25-
class BaseAioApi():
317+
class BaseAioApi:
26318
"""BaseApi, internal use only.
27319
28320
"""

0 commit comments

Comments
 (0)