11from __future__ import annotations
22
3- from typing import Union , List , Optional , TYPE_CHECKING
3+ from typing import Union , List , Optional , TYPE_CHECKING , Dict
44
55import httpx
6+ import logging
67from typing_extensions import Literal
78
8- from ...core import BaseAPI
9+ from ...core import BaseAPI , maybe_transform , drop_prefix_image_data
910from ...core import NotGiven , NOT_GIVEN , Headers , Body
1011from ...core import make_request_options
1112from ...types .chat .async_chat_completion import AsyncTaskStatus , AsyncCompletion
13+ from ...types .chat .code_geex import code_geex_params
14+
15+ logger = logging .getLogger (__name__ )
1216
1317if TYPE_CHECKING :
1418 from ..._client import ZhipuAI
@@ -18,7 +22,6 @@ class AsyncCompletions(BaseAPI):
1822 def __init__ (self , client : "ZhipuAI" ) -> None :
1923 super ().__init__ (client )
2024
21-
2225 def create (
2326 self ,
2427 * ,
@@ -34,33 +37,58 @@ def create(
3437 sensitive_word_check : Optional [object ] | NotGiven = NOT_GIVEN ,
3538 tools : Optional [object ] | NotGiven = NOT_GIVEN ,
3639 tool_choice : str | NotGiven = NOT_GIVEN ,
37- meta : Optional [Dict [str ,str ]] | NotGiven = NOT_GIVEN ,
40+ meta : Optional [Dict [str , str ]] | NotGiven = NOT_GIVEN ,
41+ extra : Optional [code_geex_params .CodeGeexExtra ] | NotGiven = NOT_GIVEN ,
3842 extra_headers : Headers | None = None ,
3943 extra_body : Body | None = None ,
40- disable_strict_validation : Optional [bool ] | None = None ,
4144 timeout : float | httpx .Timeout | None | NotGiven = NOT_GIVEN ,
4245 ) -> AsyncTaskStatus :
4346 _cast_type = AsyncTaskStatus
47+ logger .debug (f"temperature:{ temperature } , top_p:{ top_p } " )
48+ if temperature is not None and temperature != NOT_GIVEN :
49+
50+ if temperature <= 0 :
51+ do_sample = False
52+ temperature = 0.01
53+ # logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间,do_sample重写为:false(参数top_p temperture不生效)")
54+ if temperature >= 1 :
55+ temperature = 0.99
56+ # logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间")
57+ if top_p is not None and top_p != NOT_GIVEN :
4458
45- if disable_strict_validation :
46- _cast_type = object
59+ if top_p >= 1 :
60+ top_p = 0.99
61+ # logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1")
62+ if top_p <= 0 :
63+ top_p = 0.01
64+ # logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1")
65+
66+ logger .debug (f"temperature:{ temperature } , top_p:{ top_p } " )
67+ if isinstance (messages , List ):
68+ for item in messages :
69+ if item .get ('content' ):
70+ item ['content' ] = drop_prefix_image_data (item ['content' ])
71+
72+ body = {
73+ "model" : model ,
74+ "request_id" : request_id ,
75+ "temperature" : temperature ,
76+ "top_p" : top_p ,
77+ "do_sample" : do_sample ,
78+ "max_tokens" : max_tokens ,
79+ "seed" : seed ,
80+ "messages" : messages ,
81+ "stop" : stop ,
82+ "sensitive_word_check" : sensitive_word_check ,
83+ "tools" : tools ,
84+ "tool_choice" : tool_choice ,
85+ "meta" : meta ,
86+ "extra" : maybe_transform (extra , code_geex_params .CodeGeexExtra ),
87+ }
4788 return self ._post (
4889 "/async/chat/completions" ,
49- body = {
50- "model" : model ,
51- "request_id" : request_id ,
52- "temperature" : temperature ,
53- "top_p" : top_p ,
54- "do_sample" : do_sample ,
55- "max_tokens" : max_tokens ,
56- "seed" : seed ,
57- "messages" : messages ,
58- "stop" : stop ,
59- "sensitive_word_check" : sensitive_word_check ,
60- "tools" : tools ,
61- "tool_choice" : tool_choice ,
62- "meta" : meta ,
63- },
90+
91+ body = body ,
6492 options = make_request_options (
6593 extra_headers = extra_headers , extra_body = extra_body , timeout = timeout
6694 ),
@@ -69,22 +97,17 @@ def create(
6997 )
7098
7199 def retrieve_completion_result (
72- self ,
73- id : str ,
74- extra_headers : Headers | None = None ,
75- extra_body : Body | None = None ,
76- disable_strict_validation : Optional [bool ] | None = None ,
77- timeout : float | httpx .Timeout | None | NotGiven = NOT_GIVEN ,
100+ self ,
101+ id : str ,
102+ extra_headers : Headers | None = None ,
103+ extra_body : Body | None = None ,
104+ timeout : float | httpx .Timeout | None | NotGiven = NOT_GIVEN ,
78105 ) -> Union [AsyncCompletion , AsyncTaskStatus ]:
79- _cast_type = Union [AsyncCompletion ,AsyncTaskStatus ]
80- if disable_strict_validation :
81- _cast_type = object
106+ _cast_type = Union [AsyncCompletion , AsyncTaskStatus ]
82107 return self ._get (
83108 path = f"/async-result/{ id } " ,
84109 cast_type = _cast_type ,
85110 options = make_request_options (
86111 extra_headers = extra_headers , extra_body = extra_body , timeout = timeout
87112 ),
88113 )
89-
90-
0 commit comments