Skip to content

Commit 18b29b2

Browse files
committed
参数为0的逻辑失效问题
1 parent 024986f commit 18b29b2

File tree

3 files changed

+121
-3
lines changed

3 files changed

+121
-3
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "zhipuai"
3-
version = "2.0.1.20240426"
3+
version = "2.0.1.20240427"
44
description = "A SDK library for accessing big model apis from ZhipuAI"
55
authors = ["Zhipu AI"]
66
readme = "README.md"

tests/integration_tests/test_chat.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,112 @@
44
import zhipuai
55

66

7+
def test_completions_temp0():
8+
client = ZhipuAI() # 填写您自己的APIKey
9+
try:
10+
response = client.chat.completions.create(
11+
model="glm-4",
12+
messages=[
13+
{
14+
"role": "user",
15+
"content": "tell me a joke"
16+
}
17+
],
18+
top_p=0.7,
19+
temperature=0,
20+
max_tokens=2000,
21+
)
22+
print(response)
23+
24+
except zhipuai.core._errors.APIRequestFailedError as err:
25+
print(err)
26+
except zhipuai.core._errors.APIInternalError as err:
27+
print(err)
28+
except zhipuai.core._errors.APIStatusError as err:
29+
print(err)
30+
31+
32+
def test_completions_temp1():
33+
client = ZhipuAI() # 填写您自己的APIKey
34+
try:
35+
response = client.chat.completions.create(
36+
model="glm-4",
37+
messages=[
38+
{
39+
"role": "user",
40+
"content": "tell me a joke"
41+
}
42+
],
43+
top_p=0.7,
44+
temperature=1,
45+
max_tokens=2000,
46+
)
47+
print(response)
48+
49+
50+
51+
except zhipuai.core._errors.APIRequestFailedError as err:
52+
print(err)
53+
except zhipuai.core._errors.APIInternalError as err:
54+
print(err)
55+
except zhipuai.core._errors.APIStatusError as err:
56+
print(err)
57+
58+
59+
def test_completions_top0():
60+
client = ZhipuAI() # 填写您自己的APIKey
61+
try:
62+
response = client.chat.completions.create(
63+
model="glm-4",
64+
messages=[
65+
{
66+
"role": "user",
67+
"content": "tell me a joke"
68+
}
69+
],
70+
top_p=0,
71+
temperature=0.9,
72+
max_tokens=2000,
73+
)
74+
print(response)
75+
76+
77+
78+
except zhipuai.core._errors.APIRequestFailedError as err:
79+
print(err)
80+
except zhipuai.core._errors.APIInternalError as err:
81+
print(err)
82+
except zhipuai.core._errors.APIStatusError as err:
83+
print(err)
84+
85+
86+
def test_completions_top1():
87+
client = ZhipuAI() # 填写您自己的APIKey
88+
try:
89+
response = client.chat.completions.create(
90+
model="glm-4",
91+
messages=[
92+
{
93+
"role": "user",
94+
"content": "tell me a joke"
95+
}
96+
],
97+
top_p=1,
98+
temperature=0.9,
99+
max_tokens=2000,
100+
)
101+
print(response)
102+
103+
104+
105+
except zhipuai.core._errors.APIRequestFailedError as err:
106+
print(err)
107+
except zhipuai.core._errors.APIInternalError as err:
108+
print(err)
109+
except zhipuai.core._errors.APIStatusError as err:
110+
print(err)
111+
112+
7113
def test_completions():
8114
client = ZhipuAI() # 填写您自己的APIKey
9115
try:
@@ -38,6 +144,7 @@ def test_completions():
38144
except zhipuai.core._errors.APIStatusError as err:
39145
print(err)
40146

147+
41148
def test_completions_stream():
42149
client = ZhipuAI() # 填写您自己的APIKey
43150
try:
@@ -64,6 +171,7 @@ def test_completions_stream():
64171
except zhipuai.core._errors.APIStatusError as err:
65172
print(err)
66173

174+
67175
# Function to encode the image
68176
def encode_image(image_path):
69177
import base64
@@ -149,6 +257,7 @@ def test_completions_vis_base64(test_file_path):
149257
except zhipuai.core._errors.APIStatusError as err:
150258
print(err)
151259

260+
152261
def test_async_completions():
153262
client = ZhipuAI() # 请填写您自己的APIKey
154263
try:
@@ -183,6 +292,7 @@ def test_async_completions():
183292
except zhipuai.core._errors.APIStatusError as err:
184293
print(err)
185294

295+
186296
def test_retrieve_completion_result():
187297
client = ZhipuAI() # 请填写您自己的APIKey
188298
try:
@@ -196,3 +306,7 @@ def test_retrieve_completion_result():
196306
print(err)
197307
except zhipuai.core._errors.APIStatusError as err:
198308
print(err)
309+
310+
311+
if __name__ == '__main__':
312+
test_completions_top0()

zhipuai/api_resource/chat/completions.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ def create(
5050
_cast_type = object
5151
_stream_cls = StreamResponse[object]
5252

53-
if temperature:
53+
print(f"temperature:{temperature}")
54+
print(f"top_p:{top_p}")
55+
if temperature is not None and temperature != NOT_GIVEN:
5456

5557
if temperature <= 0:
5658
do_sample = False
@@ -60,7 +62,7 @@ def create(
6062
do_sample = False
6163
temperature = 0.99
6264
logger.warning("取值范围是:(0.0, 1.0) 开区间,do_sample重写为:false(参数top_p temperture不生效)")
63-
if top_p:
65+
if top_p is not None and top_p != NOT_GIVEN:
6466

6567
if top_p >= 1:
6668
top_p = 0.99
@@ -69,6 +71,8 @@ def create(
6971
top_p = 0.01
7072
logger.warning("取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1")
7173

74+
print(f"temperature:{temperature}")
75+
print(f"top_p:{top_p}")
7276
if isinstance(messages, List):
7377
for item in messages:
7478
if item.get('content'):

0 commit comments

Comments
 (0)