Skip to content

Commit 5c26cbf

Browse files
authored
增加apikey与token认证逻辑 (#31)
1 parent 769f8ed commit 5c26cbf

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

zhipuai/_client.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
class ZhipuAI(HttpClient):
1818
chat: api_resource.chat.Chat
1919
api_key: str
20+
_disable_token_cache: bool = True
2021

2122
def __init__(
2223
self,
@@ -26,13 +27,15 @@ def __init__(
2627
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
2728
max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES,
2829
http_client: httpx.Client | None = None,
29-
custom_headers: Mapping[str, str] | None = None
30+
custom_headers: Mapping[str, str] | None = None,
31+
disable_token_cache: bool = True
3032
) -> None:
3133
if api_key is None:
3234
api_key = os.environ.get("ZHIPUAI_API_KEY")
3335
if api_key is None:
3436
raise ZhipuAIError("未提供api_key,请通过参数或环境变量提供")
3537
self.api_key = api_key
38+
self._disable_token_cache = disable_token_cache
3639

3740
if base_url is None:
3841
base_url = os.environ.get("ZHIPUAI_BASE_URL")
@@ -57,8 +60,10 @@ def __init__(
5760
@override
5861
def auth_headers(self) -> dict[str, str]:
5962
api_key = self.api_key
60-
# return {"Authorization": f"{_jwt_token.generate_token(api_key)}"}
61-
return {"Authorization": f"{api_key}"}
63+
if self._disable_token_cache:
64+
return {"Authorization": f"{api_key}"}
65+
else:
66+
return {"Authorization": f"{_jwt_token.generate_token(api_key)}"}
6267

6368
def __del__(self) -> None:
6469
if (not hasattr(self, "_has_custom_http_client")

zhipuai/core/_jwt_token.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
import cachetools.func
55
import jwt
66

7-
API_TOKEN_TTL_SECONDS = 3 * 60
7+
# 缓存时间 3分钟
8+
CACHE_TTL_SECONDS = 3 * 60
89

9-
CACHE_TTL_SECONDS = API_TOKEN_TTL_SECONDS - 30
10+
# token 有效期比缓存时间 多30秒
11+
API_TOKEN_TTL_SECONDS = CACHE_TTL_SECONDS + 30
1012

1113

1214
@cachetools.func.ttl_cache(maxsize=10, ttl=CACHE_TTL_SECONDS)

0 commit comments

Comments
 (0)