15
15
from __future__ import annotations
16
16
17
17
import os
18
- from typing import cast , Optional , Union
18
+ import dataclasses
19
+ import types
20
+ from typing import Any , cast
21
+ from collections .abc import Sequence
19
22
20
23
import google .ai .generativelanguage as glm
21
24
26
29
27
30
from google .generativeai import version
28
31
29
-
30
32
USER_AGENT = "genai-py"
31
33
32
- default_client_config = {}
33
- default_discuss_client = None
34
- default_discuss_async_client = None
35
- default_model_client = None
36
- default_text_client = None
37
- default_operations_client = None
34
+
35
+ @dataclasses .dataclass
36
+ class _ClientManager :
37
+ client_config : dict [str , Any ] = dataclasses .field (default_factory = dict )
38
+ default_metadata : Sequence [tuple [str , str ]] = ()
39
+ discuss_client : glm .DiscussServiceClient | None = None
40
+ discuss_async_client : glm .DiscussServiceAsyncClient | None = None
41
+ model_client : glm .ModelServiceClient | None = None
42
+ text_client : glm .TextServiceClient | None = None
43
+ operations_client = None
44
+
45
+ def configure (
46
+ self ,
47
+ * ,
48
+ api_key : str | None = None ,
49
+ credentials : ga_credentials .Credentials | dict | None = None ,
50
+ # The user can pass a string to choose `rest` or `grpc` or 'grpc_asyncio'.
51
+ # See `_transport_registry` in `DiscussServiceClientMeta`.
52
+ # Since the transport classes align with the client classes it wouldn't make
53
+ # sense to accept a `Transport` object here even though the client classes can.
54
+ # We could accept a dict since all the `Transport` classes take the same args,
55
+ # but that seems rare. Users that need it can just switch to the low level API.
56
+ transport : str | None = None ,
57
+ client_options : client_options_lib .ClientOptions | dict | None = None ,
58
+ client_info : gapic_v1 .client_info .ClientInfo | None = None ,
59
+ default_metadata : Sequence [tuple [str , str ]] = (),
60
+ ) -> None :
61
+ """Captures default client configuration.
62
+
63
+ If no API key has been provided (either directly, or on `client_options`) and the
64
+ `GOOGLE_API_KEY` environment variable is set, it will be used as the API key.
65
+
66
+ Note: Not all arguments are detailed below. Refer to the `*ServiceClient` classes in
67
+ `google.ai.generativelanguage` for details on the other arguments.
68
+
69
+ Args:
70
+ transport: A string, one of: [`rest`, `grpc`, `grpc_asyncio`].
71
+ api_key: The API-Key to use when creating the default clients (each service uses
72
+ a separate client). This is a shortcut for `client_options={"api_key": api_key}`.
73
+ If omitted, and the `GOOGLE_API_KEY` environment variable is set, it will be
74
+ used.
75
+ default_metadata: Default (key, value) metadata pairs to send with every request.
76
+ when using `transport="rest"` these are sent as HTTP headers.
77
+ """
78
+ if isinstance (client_options , dict ):
79
+ client_options = client_options_lib .from_dict (client_options )
80
+ if client_options is None :
81
+ client_options = client_options_lib .ClientOptions ()
82
+ client_options = cast (client_options_lib .ClientOptions , client_options )
83
+ had_api_key_value = getattr (client_options , "api_key" , None )
84
+
85
+ if had_api_key_value :
86
+ if api_key is not None :
87
+ raise ValueError ("You can't set both `api_key` and `client_options['api_key']`." )
88
+ else :
89
+ if api_key is None :
90
+ # If no key is provided explicitly, attempt to load one from the
91
+ # environment.
92
+ api_key = os .getenv ("GOOGLE_API_KEY" )
93
+
94
+ client_options .api_key = api_key
95
+
96
+ user_agent = f"{ USER_AGENT } /{ version .__version__ } "
97
+ if client_info :
98
+ # Be respectful of any existing agent setting.
99
+ if client_info .user_agent :
100
+ client_info .user_agent += f" { user_agent } "
101
+ else :
102
+ client_info .user_agent = user_agent
103
+ else :
104
+ client_info = gapic_v1 .client_info .ClientInfo (user_agent = user_agent )
105
+
106
+ client_config = {
107
+ "credentials" : credentials ,
108
+ "transport" : transport ,
109
+ "client_options" : client_options ,
110
+ "client_info" : client_info ,
111
+ }
112
+
113
+ client_config = {key : value for key , value in client_config .items () if value is not None }
114
+
115
+ self .client_config = client_config
116
+ self .default_metadata = default_metadata
117
+ self .discuss_client = None
118
+ self .text_client = None
119
+ self .model_client = None
120
+ self .operations_client = None
121
+
122
+ def make_client (self , cls ):
123
+ # Attempt to configure using defaults.
124
+ if not self .client_config :
125
+ configure ()
126
+
127
+ client = cls (** self .client_config )
128
+
129
+ if not self .default_metadata :
130
+ return client
131
+
132
+ def keep (name , f ):
133
+ if name .startswith ("_" ):
134
+ return False
135
+ elif not isinstance (f , types .FunctionType ):
136
+ return False
137
+ elif isinstance (f , classmethod ):
138
+ return False
139
+ elif isinstance (f , staticmethod ):
140
+ return False
141
+ else :
142
+ return True
143
+
144
+ def add_default_metadata_wrapper (f ):
145
+ def call (* args , metadata = (), ** kwargs ):
146
+ metadata = list (metadata ) + list (self .default_metadata )
147
+ return f (* args , ** kwargs , metadata = metadata )
148
+
149
+ return call
150
+
151
+ for name , value in cls .__dict__ .items ():
152
+ if not keep (name , value ):
153
+ continue
154
+ f = getattr (client , name )
155
+ f = add_default_metadata_wrapper (f )
156
+ setattr (client , name , f )
157
+
158
+ return client
159
+
160
+ def get_default_discuss_client (self ) -> glm .DiscussServiceClient :
161
+ if self .discuss_client is None :
162
+ self .discuss_client = self .make_client (glm .DiscussServiceClient )
163
+ return self .discuss_client
164
+
165
+ def get_default_text_client (self ) -> glm .TextServiceClient :
166
+ if self .text_client is None :
167
+ self .text_client = self .make_client (glm .TextServiceClient )
168
+ return self .text_client
169
+
170
+ def get_default_discuss_async_client (self ) -> glm .DiscussServiceAsyncClient :
171
+ if self .discuss_async_client is None :
172
+ self .discuss_async_client = self .make_client (glm .DiscussServiceAsyncClient )
173
+ return self .discuss_async_client
174
+
175
+ def get_default_model_client (self ) -> glm .ModelServiceClient :
176
+ if self .model_client is None :
177
+ self .model_client = self .make_client (glm .ModelServiceClient )
178
+ return self .model_client
179
+
180
+ def get_default_operations_client (self ) -> operations_v1 .OperationsClient :
181
+ if self .operations_client is None :
182
+ self .model_client = get_default_model_client ()
183
+ self .operations_client = self .model_client ._transport .operations_client
184
+
185
+ return self .operations_client
186
+
187
+
188
+ _client_manager = _ClientManager ()
38
189
39
190
40
191
def configure (
@@ -50,119 +201,50 @@ def configure(
50
201
transport : str | None = None ,
51
202
client_options : client_options_lib .ClientOptions | dict | None = None ,
52
203
client_info : gapic_v1 .client_info .ClientInfo | None = None ,
204
+ default_metadata : Sequence [tuple [str , str ]] = (),
53
205
):
54
206
"""Captures default client configuration.
55
207
56
208
If no API key has been provided (either directly, or on `client_options`) and the
57
209
`GOOGLE_API_KEY` environment variable is set, it will be used as the API key.
58
210
211
+ Note: Not all arguments are detailed below. Refer to the `*ServiceClient` classes in
212
+ `google.ai.generativelanguage` for details on the other arguments.
213
+
59
214
Args:
60
- Refer to `glm.DiscussServiceClient `, and `glm.ModelsServiceClient` for details on additional arguments .
215
+ transport: A string, one of: [`rest `, `grpc`, `grpc_asyncio`] .
61
216
api_key: The API-Key to use when creating the default clients (each service uses
62
217
a separate client). This is a shortcut for `client_options={"api_key": api_key}`.
63
218
If omitted, and the `GOOGLE_API_KEY` environment variable is set, it will be
64
219
used.
220
+ default_metadata: Default (key, value) metadata pairs to send with every request.
221
+ when using `transport="rest"` these are sent as HTTP headers.
65
222
"""
66
- global default_client_config
67
- global default_discuss_client
68
- global default_model_client
69
- global default_text_client
70
- global default_operations_client
71
-
72
- if isinstance (client_options , dict ):
73
- client_options = client_options_lib .from_dict (client_options )
74
- if client_options is None :
75
- client_options = client_options_lib .ClientOptions ()
76
- client_options = cast (client_options_lib .ClientOptions , client_options )
77
- had_api_key_value = getattr (client_options , "api_key" , None )
78
-
79
- if had_api_key_value :
80
- if api_key is not None :
81
- raise ValueError ("You can't set both `api_key` and `client_options['api_key']`." )
82
- else :
83
- if api_key is None :
84
- # If no key is provided explicitly, attempt to load one from the
85
- # environment.
86
- api_key = os .getenv ("GOOGLE_API_KEY" )
87
-
88
- client_options .api_key = api_key
89
-
90
- user_agent = f"{ USER_AGENT } /{ version .__version__ } "
91
- if client_info :
92
- # Be respectful of any existing agent setting.
93
- if client_info .user_agent :
94
- client_info .user_agent += f" { user_agent } "
95
- else :
96
- client_info .user_agent = user_agent
97
- else :
98
- client_info = gapic_v1 .client_info .ClientInfo (user_agent = user_agent )
99
-
100
- new_default_client_config = {
101
- "credentials" : credentials ,
102
- "transport" : transport ,
103
- "client_options" : client_options ,
104
- "client_info" : client_info ,
105
- }
106
-
107
- new_default_client_config = {
108
- key : value for key , value in new_default_client_config .items () if value is not None
109
- }
110
-
111
- default_client_config = new_default_client_config
112
- default_discuss_client = None
113
- default_text_client = None
114
- default_model_client = None
115
- default_operations_client = None
223
+ return _client_manager .configure (
224
+ api_key = api_key ,
225
+ credentials = credentials ,
226
+ transport = transport ,
227
+ client_options = client_options ,
228
+ client_info = client_info ,
229
+ default_metadata = default_metadata ,
230
+ )
116
231
117
232
118
233
def get_default_discuss_client () -> glm .DiscussServiceClient :
119
- global default_discuss_client
120
- if default_discuss_client is None :
121
- # Attempt to configure using defaults.
122
- if not default_client_config :
123
- configure ()
124
- default_discuss_client = glm .DiscussServiceClient (** default_client_config )
125
-
126
- return default_discuss_client
234
+ return _client_manager .get_default_discuss_client ()
127
235
128
236
129
237
def get_default_text_client () -> glm .TextServiceClient :
130
- global default_text_client
131
- if default_text_client is None :
132
- # Attempt to configure using defaults.
133
- if not default_client_config :
134
- configure ()
135
- default_text_client = glm .TextServiceClient (** default_client_config )
136
-
137
- return default_text_client
238
+ return _client_manager .get_default_text_client ()
138
239
139
240
140
- def get_default_discuss_async_client () -> glm .DiscussServiceAsyncClient :
141
- global default_discuss_async_client
142
- if default_discuss_async_client is None :
143
- # Attempt to configure using defaults.
144
- if not default_client_config :
145
- configure ()
146
- default_discuss_async_client = glm .DiscussServiceAsyncClient (** default_client_config )
147
-
148
- return default_discuss_async_client
149
-
150
-
151
- def get_default_model_client () -> glm .ModelServiceClient :
152
- global default_model_client
153
- if default_model_client is None :
154
- # Attempt to configure using defaults.
155
- if not default_client_config :
156
- configure ()
157
- default_model_client = glm .ModelServiceClient (** default_client_config )
241
+ def get_default_operations_client () -> operations_v1 .OperationsClient :
242
+ return _client_manager .get_default_operations_client ()
158
243
159
- return default_model_client
160
244
245
+ def get_default_discuss_async_client () -> glm .DiscussServiceAsyncClient :
246
+ return _client_manager .get_default_discuss_async_client ()
161
247
162
- def get_default_operations_client () -> operations_v1 .OperationsClient :
163
- global default_operations_client
164
- if default_operations_client is None :
165
- model_client = get_default_model_client ()
166
- default_operations_client = model_client ._transport .operations_client
167
248
168
- return default_operations_client
249
+ def get_default_model_client () -> glm .ModelServiceAsyncClient :
250
+ return _client_manager .get_default_model_client ()
0 commit comments