Skip to content

Commit 4569cd1

Browse files
authored
Merge pull request #75 from MetaGLM/feature/cogtts-0106
Feature/cogtts 0106
2 parents 4bec3a4 + 712e278 commit 4569cd1

File tree

8 files changed

+213
-4
lines changed

8 files changed

+213
-4
lines changed
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from zhipuai import ZhipuAI
2+
import zhipuai
3+
4+
import logging
5+
import logging.config
6+
from pathlib import Path
7+
8+
9+
def test_audio_speech(logging_conf):
10+
logging.config.dictConfig(logging_conf) # type: ignore
11+
client = ZhipuAI() # 填写您自己的APIKey
12+
try:
13+
speech_file_path = Path(__file__).parent / "speech.wav"
14+
response = client.audio.speech(
15+
model="cogtts",
16+
input="你好呀,欢迎来到智谱开放平台",
17+
voice="female",
18+
response_format="wav"
19+
)
20+
response.stream_to_file(speech_file_path)
21+
22+
except zhipuai.core._errors.APIRequestFailedError as err:
23+
print(err)
24+
except zhipuai.core._errors.APIInternalError as err:
25+
print(err)
26+
except zhipuai.core._errors.APIStatusError as err:
27+
print(err)
28+
29+
def test_audio_customization(logging_conf):
30+
logging.config.dictConfig(logging_conf)
31+
client = ZhipuAI() # 填写您自己的APIKey
32+
with open('/Users/jhy/Desktop/tts/test_case_8s.wav', 'rb') as file:
33+
try:
34+
speech_file_path = Path(__file__).parent / "customization.wav"
35+
response = client.audio.customization(
36+
model="cogtts",
37+
input="你好呀,欢迎来到智谱开放平台",
38+
voice_text="这是一条测试用例",
39+
voice_data=file,
40+
response_format="wav"
41+
)
42+
response.stream_to_file(speech_file_path)
43+
44+
except zhipuai.core._errors.APIRequestFailedError as err:
45+
print(err)
46+
except zhipuai.core._errors.APIInternalError as err:
47+
print(err)
48+
except zhipuai.core._errors.APIStatusError as err:
49+
print(err)

tests/integration_tests/test_videos.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def test_videos(logging_conf):
1010
client = ZhipuAI() # 填写您自己的APIKey
1111
try:
1212
response = client.videos.generations(
13-
model="cogvideo",
13+
model="cogvideox",
1414
prompt="一个开船的人",
1515

1616
user_id="1212222"

zhipuai/api_resource/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@
4949
Agents
5050
)
5151

52+
from .audio import (
53+
Audio
54+
)
55+
5256
__all__ = [
5357
'Videos',
5458
'AsyncCompletions',
Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,112 @@
11
from __future__ import annotations
22

3-
from ...core import BaseAPI, cached_property
3+
from typing import TYPE_CHECKING, List, Mapping, cast, Optional, Dict
44
from .transcriptions import Transcriptions
55

6+
from zhipuai.core._utils import extract_files
7+
8+
from zhipuai.types.sensitive_word_check import SensitiveWordCheckRequest
9+
from zhipuai.types.audio import AudioSpeechParams
10+
from ...types.audio import audio_customization_param
611

12+
from zhipuai.core import BaseAPI, maybe_transform
13+
from zhipuai.core import NOT_GIVEN, Body, Headers, NotGiven, FileTypes
14+
from zhipuai.core import _legacy_response
15+
16+
import httpx
17+
from ...core import BaseAPI, cached_property
18+
19+
from zhipuai.core import (
20+
make_request_options,
21+
)
22+
from zhipuai.core import deepcopy_minimal
23+
24+
if TYPE_CHECKING:
25+
from zhipuai._client import ZhipuAI
726

827
__all__ = ["Audio"]
28+
29+
930
class Audio(BaseAPI):
31+
1032
@cached_property
1133
def transcriptions(self) -> Transcriptions:
1234
return Transcriptions(self._client)
35+
36+
def __init__(self, client: "ZhipuAI") -> None:
37+
super().__init__(client)
38+
39+
def speech(
40+
self,
41+
*,
42+
model: str,
43+
input: str = None,
44+
voice: str = None,
45+
response_format: str = None,
46+
sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
47+
request_id: str = None,
48+
user_id: str = None,
49+
extra_headers: Headers | None = None,
50+
extra_body: Body | None = None,
51+
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
52+
) -> _legacy_response.HttpxBinaryResponseContent:
53+
body = deepcopy_minimal(
54+
{
55+
"model": model,
56+
"input": input,
57+
"voice": voice,
58+
"response_format": response_format,
59+
"sensitive_word_check": sensitive_word_check,
60+
"request_id": request_id,
61+
"user_id": user_id
62+
}
63+
)
64+
return self._post(
65+
"/audio/speech",
66+
body=maybe_transform(body, AudioSpeechParams),
67+
options=make_request_options(
68+
extra_headers=extra_headers, extra_body=extra_body, timeout=timeout
69+
),
70+
cast_type=_legacy_response.HttpxBinaryResponseContent
71+
)
72+
73+
def customization(
74+
self,
75+
*,
76+
model: str,
77+
input: str = None,
78+
voice_text: str = None,
79+
voice_data: FileTypes = None,
80+
response_format: str = None,
81+
sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
82+
request_id: str = None,
83+
user_id: str = None,
84+
extra_headers: Headers | None = None,
85+
extra_body: Body | None = None,
86+
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
87+
) -> _legacy_response.HttpxBinaryResponseContent:
88+
body = deepcopy_minimal(
89+
{
90+
"model": model,
91+
"input": input,
92+
"voice_text": voice_text,
93+
"voice_data": voice_data,
94+
"response_format": response_format,
95+
"sensitive_word_check": sensitive_word_check,
96+
"request_id": request_id,
97+
"user_id": user_id
98+
}
99+
)
100+
files = extract_files(cast(Mapping[str, object], body), paths=[["voice_data"]])
101+
102+
if files:
103+
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
104+
return self._post(
105+
"/audio/customization",
106+
body=maybe_transform(body, audio_customization_param.AudioCustomizationParam),
107+
files=files,
108+
options=make_request_options(
109+
extra_headers=extra_headers, extra_body=extra_body, timeout=timeout
110+
),
111+
cast_type=_legacy_response.HttpxBinaryResponseContent
112+
)

zhipuai/core/_sse_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __stream__(self) -> Iterator[ResponseT]:
5454
data = sse.json_data()
5555
if isinstance(data, Mapping) and data.get("agent_id"):
5656
yield self._data_process_func(data=data, cast_type=self._cast_type, response=self.response)
57-
break
57+
continue
5858
if isinstance(data, Mapping) and data.get("error"):
5959
raise APIResponseError(
6060
message="An error occurred during streaming",

zhipuai/types/audio/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
1+
from .audio_speech_params import(
2+
AudioSpeechParams
3+
)
14

5+
from .audio_customization_param import(
6+
AudioCustomizationParam
7+
)
28
from .transcriptions_create_param import(
39
TranscriptionsParam
410
)
511

6-
__all__ = ["TranscriptionsParam"]
12+
__all__ = ["AudioSpeechParams","AudioCustomizationParam","TranscriptionsParam"]
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from __future__ import annotations
2+
3+
from typing import List, Optional
4+
5+
from typing_extensions import Literal, Required, TypedDict
6+
__all__ = ["AudioCustomizationParam"]
7+
8+
from ..sensitive_word_check import SensitiveWordCheckRequest
9+
10+
class AudioCustomizationParam(TypedDict, total=False):
11+
model: str
12+
"""模型编码"""
13+
input: str
14+
"""需要生成语音的文本"""
15+
voice_text: str
16+
"""需要生成语音的音色"""
17+
response_format: str
18+
"""需要生成语音文件的格式"""
19+
sensitive_word_check: Optional[SensitiveWordCheckRequest]
20+
request_id: str
21+
"""由用户端传参,需保证唯一性;用于区分每次请求的唯一标识,用户端不传时平台会默认生成。"""
22+
user_id: str
23+
"""用户端。"""
24+
25+
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from __future__ import annotations
2+
3+
from typing import List, Optional
4+
5+
from typing_extensions import Literal, Required, TypedDict
6+
7+
__all__ = ["AudioSpeechParams"]
8+
9+
from ..sensitive_word_check import SensitiveWordCheckRequest
10+
11+
12+
class AudioSpeechParams(TypedDict, total=False):
13+
model: str
14+
"""模型编码"""
15+
input: str
16+
"""需要生成语音的文本"""
17+
voice: str
18+
"""需要生成语音的音色"""
19+
response_format: str
20+
"""需要生成语音文件的格式"""
21+
sensitive_word_check: Optional[SensitiveWordCheckRequest]
22+
request_id: str
23+
"""由用户端传参,需保证唯一性;用于区分每次请求的唯一标识,用户端不传时平台会默认生成。"""
24+
user_id: str
25+
"""用户端。"""

0 commit comments

Comments
 (0)