Skip to content

Commit c7aded5

Browse files
authored
[Add] QwenTranscription (#74)
1 parent 4565d97 commit c7aded5

File tree

2 files changed

+196
-0
lines changed

2 files changed

+196
-0
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Copyright (c) Alibaba, Inc. and its affiliates.
2+
3+
from .qwen_transcription import (QwenTranscription)
4+
5+
__all__ = [
6+
'QwenTranscription',
7+
]
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
# Copyright (c) Alibaba, Inc. and its affiliates.
2+
3+
import asyncio
4+
import time
5+
from typing import List, Union
6+
7+
import aiohttp
8+
9+
from dashscope.api_entities.dashscope_response import (DashScopeAPIResponse,
10+
TranscriptionResponse)
11+
from dashscope.client.base_api import BaseAsyncApi
12+
from dashscope.common.constants import ApiProtocol, HTTPMethod
13+
from dashscope.common.logging import logger
14+
15+
16+
class QwenTranscription(BaseAsyncApi):
17+
"""API for File Transcription models.
18+
"""
19+
20+
MAX_QUERY_TRY_COUNT = 3
21+
22+
@classmethod
23+
def call(cls,
24+
model: str,
25+
file_url: str,
26+
api_key: str = None,
27+
workspace: str = None,
28+
**kwargs) -> TranscriptionResponse:
29+
"""Transcribe the given files synchronously.
30+
31+
Args:
32+
model (str): The requested model_id.
33+
file_url (str): stored URL.
34+
workspace (str): The dashscope workspace id.
35+
36+
Returns:
37+
TranscriptionResponse: The result of batch transcription.
38+
"""
39+
kwargs = cls._tidy_kwargs(**kwargs)
40+
response = super().call(model,
41+
file_url,
42+
api_key=api_key,
43+
workspace=workspace,
44+
**kwargs)
45+
return TranscriptionResponse.from_api_response(response)
46+
47+
@classmethod
48+
def async_call(cls,
49+
model: str,
50+
file_url: str,
51+
api_key: str = None,
52+
workspace: str = None,
53+
**kwargs) -> TranscriptionResponse:
54+
"""Transcribe the given files asynchronously,
55+
return the status of task submission for querying results subsequently.
56+
57+
Args:
58+
model (str): The requested model, such as paraformer-16k-1
59+
file_url (str): stored URL.
60+
workspace (str): The dashscope workspace id.
61+
62+
Returns:
63+
TranscriptionResponse: The response including task_id.
64+
"""
65+
kwargs = cls._tidy_kwargs(**kwargs)
66+
response = cls._launch_request(model,
67+
file_url,
68+
api_key=api_key,
69+
workspace=workspace,
70+
**kwargs)
71+
return TranscriptionResponse.from_api_response(response)
72+
73+
@classmethod
74+
def fetch(cls,
75+
task: Union[str, TranscriptionResponse],
76+
api_key: str = None,
77+
workspace: str = None,
78+
**kwargs) -> TranscriptionResponse:
79+
"""Fetch the status of task, including results of batch transcription when task_status is SUCCEEDED. # noqa: E501
80+
81+
Args:
82+
task (Union[str, TranscriptionResponse]): The task_id or
83+
response including task_id returned from async_call().
84+
workspace (str): The dashscope workspace id.
85+
86+
Returns:
87+
TranscriptionResponse: The status of task_id,
88+
including results of batch transcription when task_status is SUCCEEDED.
89+
"""
90+
try_count: int = 0
91+
while True:
92+
try:
93+
response = super().fetch(task,
94+
api_key=api_key,
95+
workspace=workspace,
96+
**kwargs)
97+
except (asyncio.TimeoutError, aiohttp.ClientConnectorError) as e:
98+
logger.error(e)
99+
try_count += 1
100+
if try_count <= QwenTranscription.MAX_QUERY_TRY_COUNT:
101+
time.sleep(2)
102+
continue
103+
104+
try_count = 0
105+
break
106+
107+
return TranscriptionResponse.from_api_response(response)
108+
109+
@classmethod
110+
def wait(cls,
111+
task: Union[str, TranscriptionResponse],
112+
api_key: str = None,
113+
workspace: str = None,
114+
**kwargs) -> TranscriptionResponse:
115+
"""Poll task until the final results of transcription is obtained.
116+
117+
Args:
118+
task (Union[str, TranscriptionResponse]): The task_id or
119+
response including task_id returned from async_call().
120+
workspace (str): The dashscope workspace id.
121+
122+
Returns:
123+
TranscriptionResponse: The result of batch transcription.
124+
"""
125+
response = super().wait(task,
126+
api_key=api_key,
127+
workspace=workspace,
128+
**kwargs)
129+
return TranscriptionResponse.from_api_response(response)
130+
131+
@classmethod
132+
def _launch_request(cls,
133+
model: str,
134+
file: str,
135+
api_key: str = None,
136+
workspace: str = None,
137+
**kwargs) -> DashScopeAPIResponse:
138+
"""Submit transcribe request.
139+
140+
Args:
141+
model (str): The requested model, such as paraformer-16k-1
142+
files (List[str]): List of stored URLs.
143+
workspace (str): The dashscope workspace id.
144+
145+
Returns:
146+
DashScopeAPIResponse: The result of task submission.
147+
"""
148+
149+
try_count: int = 0
150+
while True:
151+
try:
152+
response = super().async_call(model=model,
153+
task_group='audio',
154+
task='asr',
155+
function='transcription',
156+
input={'file_url': file},
157+
api_protocol=ApiProtocol.HTTP,
158+
http_method=HTTPMethod.POST,
159+
api_key=api_key,
160+
workspace=workspace,
161+
**kwargs)
162+
except (asyncio.TimeoutError, aiohttp.ClientConnectorError) as e:
163+
logger.error(e)
164+
try_count += 1
165+
if try_count <= QwenTranscription.MAX_QUERY_TRY_COUNT:
166+
time.sleep(2)
167+
continue
168+
break
169+
170+
return response
171+
172+
@classmethod
173+
def _fill_resource_id(cls, phrase_id: str, **kwargs):
174+
resources_list: list = []
175+
if phrase_id is not None and len(phrase_id) > 0:
176+
item = {'resource_id': phrase_id, 'resource_type': 'asr_phrase'}
177+
resources_list.append(item)
178+
179+
if len(resources_list) > 0:
180+
kwargs['resources'] = resources_list
181+
182+
return kwargs
183+
184+
@classmethod
185+
def _tidy_kwargs(cls, **kwargs):
186+
for k in kwargs.copy():
187+
if kwargs[k] is None:
188+
kwargs.pop(k, None)
189+
return kwargs

0 commit comments

Comments
 (0)