Skip to content

Commit 0ae6633

Browse files
authored
Code geex (#38)
* codegeex代码示例 * codegeex增加异步参数
1 parent d165385 commit 0ae6633

File tree

7 files changed

+278
-68
lines changed

7 files changed

+278
-68
lines changed
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import os.path
2+
3+
from zhipuai import ZhipuAI
4+
import zhipuai
5+
import time
6+
7+
import logging
8+
import logging.config
9+
10+
11+
def test_code_geex(logging_conf):
12+
logging.config.dictConfig(logging_conf) # type: ignore
13+
client = ZhipuAI() # 填写您自己的APIKey
14+
try:
15+
# 生成request_id
16+
request_id = time.time()
17+
print(f"request_id:{request_id}")
18+
response = client.chat.completions.create(
19+
request_id=request_id,
20+
model="codegeex-4",
21+
messages=[
22+
{
23+
"role": "system",
24+
"content": """你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。
25+
任务:请为输入代码提供格式规范的注释,包含多行注释和单行注释,请注意不要改动原始代码,只需要添加注释。
26+
请用中文回答。"""
27+
},
28+
{
29+
"role": "user",
30+
"content": """写一个快速排序函数"""
31+
}
32+
],
33+
top_p=0.7,
34+
temperature=0.9,
35+
max_tokens=2000,
36+
stop=["<|endoftext|>", "<|user|>", "<|assistant|>", "<|observation|>"],
37+
extra={
38+
"target": {
39+
"path": "11111",
40+
"language": "Python",
41+
"code_prefix": "EventSource.Factory factory = EventSources.createFactory(OkHttpUtils.getInstance());",
42+
"code_suffix": "TaskMonitorLocal taskMonitorLocal = getTaskMonitorLocal(algoMqReq);"
43+
},
44+
"contexts": [
45+
{
46+
"path": "/1/2",
47+
"code": "if(!sensitiveUser){ZpTraceUtils.addAsyncAttribute(algoMqReq.getTaskOrderNo(), ApiTraceProperty.request_params.getCode(), modelSendMap);"
48+
}
49+
50+
]
51+
}
52+
)
53+
print(response)
54+
55+
except zhipuai.core._errors.APIRequestFailedError as err:
56+
print(err)
57+
except zhipuai.core._errors.APIInternalError as err:
58+
print(err)
59+
except zhipuai.core._errors.APIStatusError as err:
60+
print(err)
61+
62+
63+
def test_code_geex_async(logging_conf):
64+
logging.config.dictConfig(logging_conf) # type: ignore
65+
client = ZhipuAI() # 填写您自己的APIKey
66+
try:
67+
# 生成request_id
68+
request_id = time.time()
69+
print(f"request_id:{request_id}")
70+
response = client.chat.asyncCompletions.create(
71+
request_id=request_id,
72+
model="codegeex-4",
73+
messages=[
74+
{
75+
"role": "system",
76+
"content": """你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。
77+
任务:请为输入代码提供格式规范的注释,包含多行注释和单行注释,请注意不要改动原始代码,只需要添加注释。
78+
请用中文回答。"""
79+
},
80+
{
81+
"role": "user",
82+
"content": """写一个快速排序函数"""
83+
}
84+
],
85+
top_p=0.7,
86+
temperature=0.9,
87+
max_tokens=2000,
88+
stop=["<|endoftext|>", "<|user|>", "<|assistant|>", "<|observation|>"],
89+
extra={
90+
"target": {
91+
"path": "11111",
92+
"language": "Python",
93+
"code_prefix": "EventSource.Factory factory = EventSources.createFactory(OkHttpUtils.getInstance());",
94+
"code_suffix": "TaskMonitorLocal taskMonitorLocal = getTaskMonitorLocal(algoMqReq);"
95+
},
96+
"contexts": [
97+
{
98+
"path": "/1/2",
99+
"code": "if(!sensitiveUser){ZpTraceUtils.addAsyncAttribute(algoMqReq.getTaskOrderNo(), ApiTraceProperty.request_params.getCode(), modelSendMap);"
100+
}
101+
102+
]
103+
}
104+
)
105+
print(response)
106+
107+
except zhipuai.core._errors.APIRequestFailedError as err:
108+
print(err)
109+
except zhipuai.core._errors.APIInternalError as err:
110+
print(err)
111+
except zhipuai.core._errors.APIStatusError as err:
112+
print(err)
113+
114+
115+
def test_geex_result(logging_conf):
116+
logging.config.dictConfig(logging_conf) # type: ignore
117+
client = ZhipuAI() # 请填写您自己的APIKey
118+
try:
119+
response = client.chat.asyncCompletions.retrieve_completion_result(id="1014908807577524653187108")
120+
print(response)
121+
122+
123+
except zhipuai.core._errors.APIRequestFailedError as err:
124+
print(err)
125+
except zhipuai.core._errors.APIInternalError as err:
126+
print(err)
127+
except zhipuai.core._errors.APIStatusError as err:
128+
print(err)

zhipuai/api_resource/chat/async_completions.py

Lines changed: 56 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
from __future__ import annotations
22

3-
from typing import Union, List, Optional, TYPE_CHECKING
3+
from typing import Union, List, Optional, TYPE_CHECKING, Dict
44

55
import httpx
6+
import logging
67
from typing_extensions import Literal
78

8-
from ...core import BaseAPI
9+
from ...core import BaseAPI, maybe_transform, drop_prefix_image_data
910
from ...core import NotGiven, NOT_GIVEN, Headers, Body
1011
from ...core import make_request_options
1112
from ...types.chat.async_chat_completion import AsyncTaskStatus, AsyncCompletion
13+
from ...types.chat.code_geex import code_geex_params
14+
15+
logger = logging.getLogger(__name__)
1216

1317
if TYPE_CHECKING:
1418
from ..._client import ZhipuAI
@@ -18,7 +22,6 @@ class AsyncCompletions(BaseAPI):
1822
def __init__(self, client: "ZhipuAI") -> None:
1923
super().__init__(client)
2024

21-
2225
def create(
2326
self,
2427
*,
@@ -34,33 +37,58 @@ def create(
3437
sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
3538
tools: Optional[object] | NotGiven = NOT_GIVEN,
3639
tool_choice: str | NotGiven = NOT_GIVEN,
37-
meta: Optional[Dict[str,str]] | NotGiven = NOT_GIVEN,
40+
meta: Optional[Dict[str, str]] | NotGiven = NOT_GIVEN,
41+
extra: Optional[code_geex_params.CodeGeexExtra] | NotGiven = NOT_GIVEN,
3842
extra_headers: Headers | None = None,
3943
extra_body: Body | None = None,
40-
disable_strict_validation: Optional[bool] | None = None,
4144
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
4245
) -> AsyncTaskStatus:
4346
_cast_type = AsyncTaskStatus
47+
logger.debug(f"temperature:{temperature}, top_p:{top_p}")
48+
if temperature is not None and temperature != NOT_GIVEN:
49+
50+
if temperature <= 0:
51+
do_sample = False
52+
temperature = 0.01
53+
# logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间,do_sample重写为:false(参数top_p temperture不生效)")
54+
if temperature >= 1:
55+
temperature = 0.99
56+
# logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间")
57+
if top_p is not None and top_p != NOT_GIVEN:
4458

45-
if disable_strict_validation:
46-
_cast_type = object
59+
if top_p >= 1:
60+
top_p = 0.99
61+
# logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1")
62+
if top_p <= 0:
63+
top_p = 0.01
64+
# logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1")
65+
66+
logger.debug(f"temperature:{temperature}, top_p:{top_p}")
67+
if isinstance(messages, List):
68+
for item in messages:
69+
if item.get('content'):
70+
item['content'] = drop_prefix_image_data(item['content'])
71+
72+
body = {
73+
"model": model,
74+
"request_id": request_id,
75+
"temperature": temperature,
76+
"top_p": top_p,
77+
"do_sample": do_sample,
78+
"max_tokens": max_tokens,
79+
"seed": seed,
80+
"messages": messages,
81+
"stop": stop,
82+
"sensitive_word_check": sensitive_word_check,
83+
"tools": tools,
84+
"tool_choice": tool_choice,
85+
"meta": meta,
86+
"extra": maybe_transform(extra, code_geex_params.CodeGeexExtra),
87+
}
4788
return self._post(
4889
"/async/chat/completions",
49-
body={
50-
"model": model,
51-
"request_id": request_id,
52-
"temperature": temperature,
53-
"top_p": top_p,
54-
"do_sample": do_sample,
55-
"max_tokens": max_tokens,
56-
"seed": seed,
57-
"messages": messages,
58-
"stop": stop,
59-
"sensitive_word_check": sensitive_word_check,
60-
"tools": tools,
61-
"tool_choice": tool_choice,
62-
"meta": meta,
63-
},
90+
91+
body=body,
6492
options=make_request_options(
6593
extra_headers=extra_headers, extra_body=extra_body, timeout=timeout
6694
),
@@ -69,22 +97,17 @@ def create(
6997
)
7098

7199
def retrieve_completion_result(
72-
self,
73-
id: str,
74-
extra_headers: Headers | None = None,
75-
extra_body: Body | None = None,
76-
disable_strict_validation: Optional[bool] | None = None,
77-
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
100+
self,
101+
id: str,
102+
extra_headers: Headers | None = None,
103+
extra_body: Body | None = None,
104+
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
78105
) -> Union[AsyncCompletion, AsyncTaskStatus]:
79-
_cast_type = Union[AsyncCompletion,AsyncTaskStatus]
80-
if disable_strict_validation:
81-
_cast_type = object
106+
_cast_type = Union[AsyncCompletion, AsyncTaskStatus]
82107
return self._get(
83108
path=f"/async-result/{id}",
84109
cast_type=_cast_type,
85110
options=make_request_options(
86111
extra_headers=extra_headers, extra_body=extra_body, timeout=timeout
87112
),
88113
)
89-
90-

zhipuai/api_resource/chat/completions.py

Lines changed: 23 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
from __future__ import annotations
22

3-
from typing import Union, List, Optional, TYPE_CHECKING
3+
from typing import Union, List, Optional, TYPE_CHECKING, Dict
44

55
import httpx
66
import logging
77
from typing_extensions import Literal
88

9-
from ...core import BaseAPI
9+
from ...core import BaseAPI, deepcopy_minimal, maybe_transform, drop_prefix_image_data
1010
from ...core import NotGiven, NOT_GIVEN, Headers, Query, Body
1111
from ...core import make_request_options
1212
from ...core import StreamResponse
1313
from ...types.chat.chat_completion import Completion
1414
from ...types.chat.chat_completion_chunk import ChatCompletionChunk
15+
from ...types.chat.code_geex import code_geex_params
1516

1617
logger = logging.getLogger(__name__)
1718

@@ -40,6 +41,7 @@ def create(
4041
tools: Optional[object] | NotGiven = NOT_GIVEN,
4142
tool_choice: str | NotGiven = NOT_GIVEN,
4243
meta: Optional[Dict[str,str]] | NotGiven = NOT_GIVEN,
44+
extra: Optional[code_geex_params.CodeGeexExtra] | NotGiven = NOT_GIVEN,
4345
extra_headers: Headers | None = None,
4446
extra_body: Body | None = None,
4547
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
@@ -67,26 +69,28 @@ def create(
6769
if isinstance(messages, List):
6870
for item in messages:
6971
if item.get('content'):
70-
item['content'] = self._drop_prefix_image_data(item['content'])
72+
item['content'] = drop_prefix_image_data(item['content'])
7173

74+
body = deepcopy_minimal({
75+
"model": model,
76+
"request_id": request_id,
77+
"temperature": temperature,
78+
"top_p": top_p,
79+
"do_sample": do_sample,
80+
"max_tokens": max_tokens,
81+
"seed": seed,
82+
"messages": messages,
83+
"stop": stop,
84+
"sensitive_word_check": sensitive_word_check,
85+
"stream": stream,
86+
"tools": tools,
87+
"tool_choice": tool_choice,
88+
"meta": meta,
89+
"extra": maybe_transform(extra, code_geex_params.CodeGeexExtra),
90+
})
7291
return self._post(
7392
"/chat/completions",
74-
body={
75-
"model": model,
76-
"request_id": request_id,
77-
"temperature": temperature,
78-
"top_p": top_p,
79-
"do_sample": do_sample,
80-
"max_tokens": max_tokens,
81-
"seed": seed,
82-
"messages": messages,
83-
"stop": stop,
84-
"sensitive_word_check": sensitive_word_check,
85-
"stream": stream,
86-
"tools": tools,
87-
"tool_choice": tool_choice,
88-
"meta": meta,
89-
},
93+
body=body,
9094
options=make_request_options(
9195
extra_headers=extra_headers, extra_body=extra_body, timeout=timeout
9296
),
@@ -95,19 +99,4 @@ def create(
9599
stream_cls=StreamResponse[ChatCompletionChunk],
96100
)
97101

98-
def _drop_prefix_image_data(self, content: Union[str,List[dict]]) -> Union[str,List[dict]]:
99-
"""
100-
删除 ;base64, 前缀
101-
:param image_data:
102-
:return:
103-
"""
104-
if isinstance(content, List):
105-
for data in content:
106-
if data.get('type') == 'image_url':
107-
image_data = data.get("image_url").get("url")
108-
if image_data.startswith("data:image/"):
109-
image_data = image_data.split("base64,")[-1]
110-
data["image_url"]["url"] = image_data
111-
112-
return content
113102

zhipuai/core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
maybe_transform,
5959
deepcopy_minimal,
6060
extract_files,
61+
drop_prefix_image_data,
6162
)
6263

6364
from ._sse_client import StreamResponse

zhipuai/core/_utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
get_required_header as get_required_header,
3030
maybe_coerce_boolean as maybe_coerce_boolean,
3131
maybe_coerce_integer as maybe_coerce_integer,
32+
drop_prefix_image_data as drop_prefix_image_data,
3233
)
3334

3435

0 commit comments

Comments
 (0)