1
- # -*- coding: utf-8 -*-
2
- # Copyright 2023 Google LLC
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
1
from __future__ import annotations
16
2
17
3
import os
27
13
from google .api_core import gapic_v1
28
14
from google .api_core import operations_v1
29
15
30
- from google .generativeai import version
16
+ try :
17
+ from google .generativeai import version
18
+
19
+ __version__ = version .__version__
20
+ except ImportError :
21
+ __version__ = "0.0.0"
31
22
32
23
USER_AGENT = "genai-py"
33
24
36
27
class _ClientManager :
37
28
client_config : dict [str , Any ] = dataclasses .field (default_factory = dict )
38
29
default_metadata : Sequence [tuple [str , str ]] = ()
30
+
39
31
discuss_client : glm .DiscussServiceClient | None = None
40
32
discuss_async_client : glm .DiscussServiceAsyncClient | None = None
41
- model_client : glm .ModelServiceClient | None = None
42
- text_client : glm .TextServiceClient | None = None
43
- operations_client = None
33
+ clients : dict [str , Any ] = dataclasses .field (default_factory = dict )
44
34
45
35
def configure (
46
36
self ,
@@ -54,7 +44,7 @@ def configure(
54
44
# We could accept a dict since all the `Transport` classes take the same args,
55
45
# but that seems rare. Users that need it can just switch to the low level API.
56
46
transport : str | None = None ,
57
- client_options : client_options_lib .ClientOptions | dict | None = None ,
47
+ client_options : client_options_lib .ClientOptions | dict [ str , Any ] | None = None ,
58
48
client_info : gapic_v1 .client_info .ClientInfo | None = None ,
59
49
default_metadata : Sequence [tuple [str , str ]] = (),
60
50
) -> None :
@@ -93,7 +83,7 @@ def configure(
93
83
94
84
client_options .api_key = api_key
95
85
96
- user_agent = f"{ USER_AGENT } /{ version . __version__ } "
86
+ user_agent = f"{ USER_AGENT } /{ __version__ } "
97
87
if client_info :
98
88
# Be respectful of any existing agent setting.
99
89
if client_info .user_agent :
@@ -114,12 +104,16 @@ def configure(
114
104
115
105
self .client_config = client_config
116
106
self .default_metadata = default_metadata
117
- self .discuss_client = None
118
- self .text_client = None
119
- self .model_client = None
120
- self .operations_client = None
121
107
122
- def make_client (self , cls ):
108
+ self .clients = {}
109
+
110
+ def make_client (self , name ):
111
+ if name .endswith ("_async" ):
112
+ name = name .split ("_" )[0 ]
113
+ cls = getattr (glm , name .title () + "ServiceAsyncClient" )
114
+ else :
115
+ cls = getattr (glm , name .title () + "ServiceClient" )
116
+
123
117
# Attempt to configure using defaults.
124
118
if not self .client_config :
125
119
configure ()
@@ -157,35 +151,25 @@ def call(*args, metadata=(), **kwargs):
157
151
158
152
return client
159
153
160
- def get_default_discuss_client (self ) -> glm .DiscussServiceClient :
161
- if self .discuss_client is None :
162
- self .discuss_client = self .make_client (glm .DiscussServiceClient )
163
- return self .discuss_client
164
-
165
- def get_default_text_client (self ) -> glm .TextServiceClient :
166
- if self .text_client is None :
167
- self .text_client = self .make_client (glm .TextServiceClient )
168
- return self .text_client
169
-
170
- def get_default_discuss_async_client (self ) -> glm .DiscussServiceAsyncClient :
171
- if self .discuss_async_client is None :
172
- self .discuss_async_client = self .make_client (glm .DiscussServiceAsyncClient )
173
- return self .discuss_async_client
154
+ def get_default_client (self , name ):
155
+ name = name .lower ()
156
+ if name == "operations" :
157
+ return self .get_default_operations_client ()
174
158
175
- def get_default_model_client (self ) -> glm .ModelServiceClient :
176
- if self .model_client is None :
177
- self .model_client = self .make_client (glm .ModelServiceClient )
178
- return self .model_client
159
+ client = self .clients .get (name )
160
+ if client is None :
161
+ client = self .make_client (name )
162
+ self .clients [name ] = client
163
+ return client
179
164
180
165
def get_default_operations_client (self ) -> operations_v1 .OperationsClient :
181
- if self .operations_client is None :
182
- self .model_client = get_default_model_client ()
183
- self .operations_client = self .model_client ._transport .operations_client
184
-
185
- return self .operations_client
186
-
166
+ client = self .clients .get ("operations" , None )
167
+ if client is None :
168
+ model_client = self .get_default_client ("Model" )
169
+ client = model_client ._transport .operations_client
170
+ self .clients ["operations" ] = client
187
171
188
- _client_manager = _ClientManager ()
172
+ return client
189
173
190
174
191
175
def configure (
@@ -230,21 +214,33 @@ def configure(
230
214
)
231
215
232
216
217
+ _client_manager = _ClientManager ()
218
+ _client_manager .configure ()
219
+
220
+
233
221
def get_default_discuss_client () -> glm .DiscussServiceClient :
234
- return _client_manager .get_default_discuss_client ( )
222
+ return _client_manager .get_default_client ( "discuss" )
235
223
236
224
237
- def get_default_text_client () -> glm .TextServiceClient :
238
- return _client_manager .get_default_text_client ( )
225
+ def get_default_discuss_async_client () -> glm .DiscussServiceAsyncClient :
226
+ return _client_manager .get_default_client ( "discuss_async" )
239
227
240
228
241
- def get_default_operations_client () -> operations_v1 . OperationsClient :
242
- return _client_manager .get_default_operations_client ( )
229
+ def get_default_generative_client () -> glm . GenerativeServiceClient :
230
+ return _client_manager .get_default_client ( "generative" )
243
231
244
232
245
- def get_default_discuss_async_client () -> glm .DiscussServiceAsyncClient :
246
- return _client_manager .get_default_discuss_async_client ()
233
+ def get_default_generative_async_client () -> glm .GenerativeServiceAsyncClient :
234
+ return _client_manager .get_default_client ("generative_async" )
235
+
236
+
237
+ def get_default_text_client () -> glm .TextServiceClient :
238
+ return _client_manager .get_default_client ("text" )
239
+
240
+
241
+ def get_default_operations_client () -> operations_v1 .OperationsClient :
242
+ return _client_manager .get_default_client ("operations" )
247
243
248
244
249
245
def get_default_model_client () -> glm .ModelServiceAsyncClient :
250
- return _client_manager .get_default_model_client ( )
246
+ return _client_manager .get_default_client ( "model" )
0 commit comments