Skip to content

Commit 46db06f

Browse files
authored
Add metadata handling. (#74)
* Add metadata handling. * Add and fix tests. * Resolve comments
1 parent e19b8b0 commit 46db06f

File tree

5 files changed

+236
-110
lines changed

5 files changed

+236
-110
lines changed

google/generativeai/client.py

Lines changed: 182 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
from __future__ import annotations
1616

1717
import os
18-
from typing import cast, Optional, Union
18+
import dataclasses
19+
import types
20+
from typing import Any, cast
21+
from collections.abc import Sequence
1922

2023
import google.ai.generativelanguage as glm
2124

@@ -26,15 +29,163 @@
2629

2730
from google.generativeai import version
2831

29-
3032
USER_AGENT = "genai-py"
3133

32-
default_client_config = {}
33-
default_discuss_client = None
34-
default_discuss_async_client = None
35-
default_model_client = None
36-
default_text_client = None
37-
default_operations_client = None
34+
35+
@dataclasses.dataclass
36+
class _ClientManager:
37+
client_config: dict[str, Any] = dataclasses.field(default_factory=dict)
38+
default_metadata: Sequence[tuple[str, str]] = ()
39+
discuss_client: glm.DiscussServiceClient | None = None
40+
discuss_async_client: glm.DiscussServiceAsyncClient | None = None
41+
model_client: glm.ModelServiceClient | None = None
42+
text_client: glm.TextServiceClient | None = None
43+
operations_client = None
44+
45+
def configure(
46+
self,
47+
*,
48+
api_key: str | None = None,
49+
credentials: ga_credentials.Credentials | dict | None = None,
50+
# The user can pass a string to choose `rest` or `grpc` or 'grpc_asyncio'.
51+
# See `_transport_registry` in `DiscussServiceClientMeta`.
52+
# Since the transport classes align with the client classes it wouldn't make
53+
# sense to accept a `Transport` object here even though the client classes can.
54+
# We could accept a dict since all the `Transport` classes take the same args,
55+
# but that seems rare. Users that need it can just switch to the low level API.
56+
transport: str | None = None,
57+
client_options: client_options_lib.ClientOptions | dict | None = None,
58+
client_info: gapic_v1.client_info.ClientInfo | None = None,
59+
default_metadata: Sequence[tuple[str, str]] = (),
60+
) -> None:
61+
"""Captures default client configuration.
62+
63+
If no API key has been provided (either directly, or on `client_options`) and the
64+
`GOOGLE_API_KEY` environment variable is set, it will be used as the API key.
65+
66+
Note: Not all arguments are detailed below. Refer to the `*ServiceClient` classes in
67+
`google.ai.generativelanguage` for details on the other arguments.
68+
69+
Args:
70+
transport: A string, one of: [`rest`, `grpc`, `grpc_asyncio`].
71+
api_key: The API-Key to use when creating the default clients (each service uses
72+
a separate client). This is a shortcut for `client_options={"api_key": api_key}`.
73+
If omitted, and the `GOOGLE_API_KEY` environment variable is set, it will be
74+
used.
75+
default_metadata: Default (key, value) metadata pairs to send with every request.
76+
when using `transport="rest"` these are sent as HTTP headers.
77+
"""
78+
if isinstance(client_options, dict):
79+
client_options = client_options_lib.from_dict(client_options)
80+
if client_options is None:
81+
client_options = client_options_lib.ClientOptions()
82+
client_options = cast(client_options_lib.ClientOptions, client_options)
83+
had_api_key_value = getattr(client_options, "api_key", None)
84+
85+
if had_api_key_value:
86+
if api_key is not None:
87+
raise ValueError("You can't set both `api_key` and `client_options['api_key']`.")
88+
else:
89+
if api_key is None:
90+
# If no key is provided explicitly, attempt to load one from the
91+
# environment.
92+
api_key = os.getenv("GOOGLE_API_KEY")
93+
94+
client_options.api_key = api_key
95+
96+
user_agent = f"{USER_AGENT}/{version.__version__}"
97+
if client_info:
98+
# Be respectful of any existing agent setting.
99+
if client_info.user_agent:
100+
client_info.user_agent += f" {user_agent}"
101+
else:
102+
client_info.user_agent = user_agent
103+
else:
104+
client_info = gapic_v1.client_info.ClientInfo(user_agent=user_agent)
105+
106+
client_config = {
107+
"credentials": credentials,
108+
"transport": transport,
109+
"client_options": client_options,
110+
"client_info": client_info,
111+
}
112+
113+
client_config = {key: value for key, value in client_config.items() if value is not None}
114+
115+
self.client_config = client_config
116+
self.default_metadata = default_metadata
117+
self.discuss_client = None
118+
self.text_client = None
119+
self.model_client = None
120+
self.operations_client = None
121+
122+
def make_client(self, cls):
123+
# Attempt to configure using defaults.
124+
if not self.client_config:
125+
configure()
126+
127+
client = cls(**self.client_config)
128+
129+
if not self.default_metadata:
130+
return client
131+
132+
def keep(name, f):
133+
if name.startswith("_"):
134+
return False
135+
elif not isinstance(f, types.FunctionType):
136+
return False
137+
elif isinstance(f, classmethod):
138+
return False
139+
elif isinstance(f, staticmethod):
140+
return False
141+
else:
142+
return True
143+
144+
def add_default_metadata_wrapper(f):
145+
def call(*args, metadata=(), **kwargs):
146+
metadata = list(metadata) + list(self.default_metadata)
147+
return f(*args, **kwargs, metadata=metadata)
148+
149+
return call
150+
151+
for name, value in cls.__dict__.items():
152+
if not keep(name, value):
153+
continue
154+
f = getattr(client, name)
155+
f = add_default_metadata_wrapper(f)
156+
setattr(client, name, f)
157+
158+
return client
159+
160+
def get_default_discuss_client(self) -> glm.DiscussServiceClient:
161+
if self.discuss_client is None:
162+
self.discuss_client = self.make_client(glm.DiscussServiceClient)
163+
return self.discuss_client
164+
165+
def get_default_text_client(self) -> glm.TextServiceClient:
166+
if self.text_client is None:
167+
self.text_client = self.make_client(glm.TextServiceClient)
168+
return self.text_client
169+
170+
def get_default_discuss_async_client(self) -> glm.DiscussServiceAsyncClient:
171+
if self.discuss_async_client is None:
172+
self.discuss_async_client = self.make_client(glm.DiscussServiceAsyncClient)
173+
return self.discuss_async_client
174+
175+
def get_default_model_client(self) -> glm.ModelServiceClient:
176+
if self.model_client is None:
177+
self.model_client = self.make_client(glm.ModelServiceClient)
178+
return self.model_client
179+
180+
def get_default_operations_client(self) -> operations_v1.OperationsClient:
181+
if self.operations_client is None:
182+
self.model_client = get_default_model_client()
183+
self.operations_client = self.model_client._transport.operations_client
184+
185+
return self.operations_client
186+
187+
188+
_client_manager = _ClientManager()
38189

39190

40191
def configure(
@@ -50,119 +201,50 @@ def configure(
50201
transport: str | None = None,
51202
client_options: client_options_lib.ClientOptions | dict | None = None,
52203
client_info: gapic_v1.client_info.ClientInfo | None = None,
204+
default_metadata: Sequence[tuple[str, str]] = (),
53205
):
54206
"""Captures default client configuration.
55207
56208
If no API key has been provided (either directly, or on `client_options`) and the
57209
`GOOGLE_API_KEY` environment variable is set, it will be used as the API key.
58210
211+
Note: Not all arguments are detailed below. Refer to the `*ServiceClient` classes in
212+
`google.ai.generativelanguage` for details on the other arguments.
213+
59214
Args:
60-
Refer to `glm.DiscussServiceClient`, and `glm.ModelsServiceClient` for details on additional arguments.
215+
transport: A string, one of: [`rest`, `grpc`, `grpc_asyncio`].
61216
api_key: The API-Key to use when creating the default clients (each service uses
62217
a separate client). This is a shortcut for `client_options={"api_key": api_key}`.
63218
If omitted, and the `GOOGLE_API_KEY` environment variable is set, it will be
64219
used.
220+
default_metadata: Default (key, value) metadata pairs to send with every request.
221+
when using `transport="rest"` these are sent as HTTP headers.
65222
"""
66-
global default_client_config
67-
global default_discuss_client
68-
global default_model_client
69-
global default_text_client
70-
global default_operations_client
71-
72-
if isinstance(client_options, dict):
73-
client_options = client_options_lib.from_dict(client_options)
74-
if client_options is None:
75-
client_options = client_options_lib.ClientOptions()
76-
client_options = cast(client_options_lib.ClientOptions, client_options)
77-
had_api_key_value = getattr(client_options, "api_key", None)
78-
79-
if had_api_key_value:
80-
if api_key is not None:
81-
raise ValueError("You can't set both `api_key` and `client_options['api_key']`.")
82-
else:
83-
if api_key is None:
84-
# If no key is provided explicitly, attempt to load one from the
85-
# environment.
86-
api_key = os.getenv("GOOGLE_API_KEY")
87-
88-
client_options.api_key = api_key
89-
90-
user_agent = f"{USER_AGENT}/{version.__version__}"
91-
if client_info:
92-
# Be respectful of any existing agent setting.
93-
if client_info.user_agent:
94-
client_info.user_agent += f" {user_agent}"
95-
else:
96-
client_info.user_agent = user_agent
97-
else:
98-
client_info = gapic_v1.client_info.ClientInfo(user_agent=user_agent)
99-
100-
new_default_client_config = {
101-
"credentials": credentials,
102-
"transport": transport,
103-
"client_options": client_options,
104-
"client_info": client_info,
105-
}
106-
107-
new_default_client_config = {
108-
key: value for key, value in new_default_client_config.items() if value is not None
109-
}
110-
111-
default_client_config = new_default_client_config
112-
default_discuss_client = None
113-
default_text_client = None
114-
default_model_client = None
115-
default_operations_client = None
223+
return _client_manager.configure(
224+
api_key=api_key,
225+
credentials=credentials,
226+
transport=transport,
227+
client_options=client_options,
228+
client_info=client_info,
229+
default_metadata=default_metadata,
230+
)
116231

117232

118233
def get_default_discuss_client() -> glm.DiscussServiceClient:
119-
global default_discuss_client
120-
if default_discuss_client is None:
121-
# Attempt to configure using defaults.
122-
if not default_client_config:
123-
configure()
124-
default_discuss_client = glm.DiscussServiceClient(**default_client_config)
125-
126-
return default_discuss_client
234+
return _client_manager.get_default_discuss_client()
127235

128236

129237
def get_default_text_client() -> glm.TextServiceClient:
130-
global default_text_client
131-
if default_text_client is None:
132-
# Attempt to configure using defaults.
133-
if not default_client_config:
134-
configure()
135-
default_text_client = glm.TextServiceClient(**default_client_config)
136-
137-
return default_text_client
238+
return _client_manager.get_default_text_client()
138239

139240

140-
def get_default_discuss_async_client() -> glm.DiscussServiceAsyncClient:
141-
global default_discuss_async_client
142-
if default_discuss_async_client is None:
143-
# Attempt to configure using defaults.
144-
if not default_client_config:
145-
configure()
146-
default_discuss_async_client = glm.DiscussServiceAsyncClient(**default_client_config)
147-
148-
return default_discuss_async_client
149-
150-
151-
def get_default_model_client() -> glm.ModelServiceClient:
152-
global default_model_client
153-
if default_model_client is None:
154-
# Attempt to configure using defaults.
155-
if not default_client_config:
156-
configure()
157-
default_model_client = glm.ModelServiceClient(**default_client_config)
241+
def get_default_operations_client() -> operations_v1.OperationsClient:
242+
return _client_manager.get_default_operations_client()
158243

159-
return default_model_client
160244

245+
def get_default_discuss_async_client() -> glm.DiscussServiceAsyncClient:
246+
return _client_manager.get_default_discuss_async_client()
161247

162-
def get_default_operations_client() -> operations_v1.OperationsClient:
163-
global default_operations_client
164-
if default_operations_client is None:
165-
model_client = get_default_model_client()
166-
default_operations_client = model_client._transport.operations_client
167248

168-
return default_operations_client
249+
def get_default_model_client() -> glm.ModelServiceAsyncClient:
250+
return _client_manager.get_default_model_client()

0 commit comments

Comments
 (0)