
Commit 5412d24

Explicit Caching (#355)
* Initial prototype for explicit caching:
  * Add basic CRUD support for caching
  * Remove INPUT_ONLY-marked fields from the CachedContent dataclass
  * Rename files 'cached_content*' -> 'caching*'
  * Update the 'Create' method for explicit instantiation of 'CachedContent'
  * Add a factory method to instantiate a model with `CachedContent` as its context
  * blacken
  * Add tests
  Change-Id: I694545243efda467d6fd599beded0dc6679b727d
* Rename get_cached_content to get
* Strike out the functional approach for CachedContent CRUD ops
* blacken
* Improve tests
* Fix tests. Change-Id: I39f61012f850a82e09a7afb80b527a0f99ad0ec7
* Fix tests. Change-Id: I39f61012f850a82e09a7afb80b527a0f99ad0ec7
* Validate name checks for CachedContent creation. Change-Id: Ie41602621d99ddff6404c6708c7278e0da790652
* Add tests. Change-Id: I249188fa585bd9b7193efa48b1cfca20b8a79821
* Mark name as OPTIONAL for CachedContent creation; if not provided, the name will be randomly generated. Change-Id: Ib95fbafd3dfe098b43164d7ee4d6c2a84b0aae2e
* Add type annotations to __new__ to fix pytype checks. Change-Id: I6c69c036e54d56d18ea60368fa0a1dcda2d315fd
* Add 'cached_content' to GenerativeModel's repr. Change-Id: I06676fad23895e3e1a6393baa938fc1f2df57d80
* blacken. Change-Id: I4e073d821d29eea30801bdb7e2a8dc01bb7d6b9a
* Fix types. Change-Id: Ia4bf6b936fab4c1992798c65cff91c15e51a92c0
* Fix docstrings. Change-Id: I6020df4e862a4f1d58462a4cd70876a8448293cf
* Fix types. Change-Id: Id3e7316562f4029e5b7409ae725bb66e2207f075
* Fix types. Change-Id: Id3e7316562f4029e5b7409ae725bb66e2207f075
* Refactor for the genai.protos module. Change-Id: I2f02d2421d7303f0309ec86f05d33c07332c03c1
* Use preview build. Change-Id: Ic1cd4fc28f591794dc5fbff0647a00a77ea7f601

Co-authored-by: Mark Daoust <[email protected]>
1 parent 7b9758f commit 5412d24

File tree

7 files changed: +712, -3 lines changed

google/generativeai/caching.py

Lines changed: 260 additions & 0 deletions
@@ -0,0 +1,260 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import dataclasses
import datetime
from typing import Any, Iterable, Optional

from google.generativeai import protos
from google.generativeai.types.model_types import idecode_time
from google.generativeai.types import caching_types
from google.generativeai.types import content_types
from google.generativeai.utils import flatten_update_paths
from google.generativeai.client import get_default_cache_client

from google.protobuf import field_mask_pb2
import google.ai.generativelanguage as glm


@dataclasses.dataclass
class CachedContent:
    """Cached content resource."""

    name: str
    model: str
    create_time: datetime.datetime
    update_time: datetime.datetime
    expire_time: datetime.datetime

    # NOTE: Automatic CachedContent deletion via a context manager is not P0 (P1+).
    # Adding basic support for now.
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.delete()

    def _to_dict(self) -> protos.CachedContent:
        proto_paths = {
            "name": self.name,
            "model": self.model,
        }
        return protos.CachedContent(**proto_paths)

    def _apply_update(self, path, value):
        parts = path.split(".")
        for part in parts[:-1]:
            self = getattr(self, part)
        if parts[-1] == "ttl":
            value = self.expire_time + datetime.timedelta(seconds=value["seconds"])
            parts[-1] = "expire_time"
        setattr(self, parts[-1], value)

    @classmethod
    def _decode_cached_content(cls, cached_content: protos.CachedContent) -> CachedContent:
        # INPUT_ONLY repeated fields are not expected here, but the local gapic lib build
        # returns them, so `including_default_value_fields` is set to False.
        cached_content = type(cached_content).to_dict(
            cached_content, including_default_value_fields=False
        )

        idecode_time(cached_content, "create_time")
        idecode_time(cached_content, "update_time")
        # always decode `expire_time`, as a Timestamp is returned
        # regardless of what was sent on input
        idecode_time(cached_content, "expire_time")
        return cls(**cached_content)

    @staticmethod
    def _prepare_create_request(
        model: str,
        name: str | None = None,
        system_instruction: Optional[content_types.ContentType] = None,
        contents: Optional[content_types.ContentsType] = None,
        tools: Optional[content_types.FunctionLibraryType] = None,
        tool_config: Optional[content_types.ToolConfigType] = None,
        ttl: Optional[caching_types.ExpirationTypes] = datetime.timedelta(hours=1),
    ) -> protos.CreateCachedContentRequest:
        """Prepares a `CreateCachedContentRequest`."""
        if name is not None:
            if not caching_types.valid_cached_content_name(name):
                raise ValueError(caching_types.NAME_ERROR_MESSAGE.format(name=name))

            name = "cachedContents/" + name

        if "/" not in model:
            model = "models/" + model

        if system_instruction:
            system_instruction = content_types.to_content(system_instruction)

        tools_lib = content_types.to_function_library(tools)
        if tools_lib:
            tools_lib = tools_lib.to_proto()

        if tool_config:
            tool_config = content_types.to_tool_config(tool_config)

        if contents:
            contents = content_types.to_contents(contents)

        if ttl:
            ttl = caching_types.to_ttl(ttl)

        cached_content = protos.CachedContent(
            name=name,
            model=model,
            system_instruction=system_instruction,
            contents=contents,
            tools=tools_lib,
            tool_config=tool_config,
            ttl=ttl,
        )

        return protos.CreateCachedContentRequest(cached_content=cached_content)

    @classmethod
    def create(
        cls,
        model: str,
        name: str | None = None,
        system_instruction: Optional[content_types.ContentType] = None,
        contents: Optional[content_types.ContentsType] = None,
        tools: Optional[content_types.FunctionLibraryType] = None,
        tool_config: Optional[content_types.ToolConfigType] = None,
        ttl: Optional[caching_types.ExpirationTypes] = datetime.timedelta(hours=1),
        client: glm.CacheServiceClient | None = None,
    ) -> CachedContent:
        """Creates a `CachedContent` resource.

        Args:
            model: The name of the `model` to use for cached content creation.
                Any `CachedContent` resource can only be used with the
                `model` it was created for.
            name: The resource name referring to the cached content.
            system_instruction: Developer-set system instruction.
            contents: Contents to cache.
            tools: A list of `Tools` the model may use to generate a response.
            tool_config: Config to apply to all tools.
            ttl: TTL for the cached resource (in seconds). Defaults to 1 hour.

        Returns:
            A `CachedContent` resource with the specified name.
        """
        if client is None:
            client = get_default_cache_client()

        request = cls._prepare_create_request(
            model=model,
            name=name,
            system_instruction=system_instruction,
            contents=contents,
            tools=tools,
            tool_config=tool_config,
            ttl=ttl,
        )

        response = client.create_cached_content(request)
        return cls._decode_cached_content(response)

    @classmethod
    def get(cls, name: str, client: glm.CacheServiceClient | None = None) -> CachedContent:
        """Fetches the requested `CachedContent` resource.

        Args:
            name: The resource name referring to the cached content.

        Returns:
            The `CachedContent` resource with the specified `name`.
        """
        if client is None:
            client = get_default_cache_client()

        if "cachedContents/" not in name:
            name = "cachedContents/" + name

        request = protos.GetCachedContentRequest(name=name)
        response = client.get_cached_content(request)
        return cls._decode_cached_content(response)

    @classmethod
    def list(
        cls, page_size: Optional[int] = 1, client: glm.CacheServiceClient | None = None
    ) -> Iterable[CachedContent]:
        """Lists `CachedContent` objects associated with the project.

        Args:
            page_size: The maximum number of `CachedContent` objects to return (per page).
                The service may return fewer `CachedContent` objects.

        Returns:
            A paginated list of `CachedContent` objects.
        """
        if client is None:
            client = get_default_cache_client()

        request = protos.ListCachedContentsRequest(page_size=page_size)
        for cached_content in client.list_cached_contents(request):
            yield cls._decode_cached_content(cached_content)

    def delete(self, client: glm.CacheServiceClient | None = None) -> None:
        """Deletes the `CachedContent` resource."""
        if client is None:
            client = get_default_cache_client()

        request = protos.DeleteCachedContentRequest(name=self.name)
        client.delete_cached_content(request)
        return

    def update(
        self,
        updates: dict[str, Any],
        client: glm.CacheServiceClient | None = None,
    ) -> CachedContent:
        """Updates the requested `CachedContent` resource.

        Args:
            updates: The list of fields to update. Currently only
                `ttl`/`expire_time` is supported as an update path.

        Returns:
            The `CachedContent` object with the specified updates.
        """
        if client is None:
            client = get_default_cache_client()

        updates = flatten_update_paths(updates)
        for update_path in updates:
            if update_path == "ttl":
                updates = updates.copy()
                update_path_val = updates.get(update_path)
                updates[update_path] = caching_types.to_ttl(update_path_val)
            else:
                raise ValueError(
                    f"As of now, only `ttl` can be updated for `CachedContent`. Got: `{update_path}` instead."
                )
        field_mask = field_mask_pb2.FieldMask()

        for path in updates.keys():
            field_mask.paths.append(path)
        for path, value in updates.items():
            self._apply_update(path, value)

        request = protos.UpdateCachedContentRequest(
            cached_content=self._to_dict(), update_mask=field_mask
        )
        client.update_cached_content(request)
        return self
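
For orientation, here is a minimal usage sketch of the CRUD surface defined above. It is not part of the commit; the API key, model name, and contents are illustrative placeholders, and it assumes `genai.configure` has been called with valid credentials.

import datetime

import google.generativeai as genai
from google.generativeai import caching

genai.configure(api_key="...")  # placeholder credentials

# Create a cache entry holding a large context; `ttl` defaults to 1 hour when omitted.
cache = caching.CachedContent.create(
    model="models/some-cache-capable-model",  # illustrative model name
    system_instruction="You answer questions about the attached transcript.",
    contents=["<large transcript text>"],
    ttl=datetime.timedelta(hours=2),
)

# Fetch by resource name and list everything cached for the project.
same_cache = caching.CachedContent.get(name=cache.name)
for entry in caching.CachedContent.list(page_size=10):
    print(entry.name, entry.expire_time)

# Only the `ttl`/`expire_time` path may be updated; any other key raises ValueError.
cache.update(updates={"ttl": datetime.timedelta(hours=4)})

# Delete explicitly, or use the context-manager support (`with cache: ...`),
# which calls delete() on exit.
cache.delete()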

google/generativeai/client.py

Lines changed: 4 additions & 0 deletions
@@ -315,6 +315,10 @@ def configure(
     _client_manager.configure()
 
 
+def get_default_cache_client() -> glm.CacheServiceClient:
+    return _client_manager.get_default_client("cache")
+
+
 def get_default_discuss_client() -> glm.DiscussServiceClient:
     return _client_manager.get_default_client("discuss")
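
The helper above follows the existing `get_default_*_client` pattern; as a hedged sketch (assuming an API key has already been configured), downstream code such as caching.py obtains the shared client like this:

from google.generativeai.client import get_default_cache_client

# Lazily constructs and reuses a glm.CacheServiceClient through the module-level
# client manager, mirroring get_default_discuss_client directly below it.
cache_client = get_default_cache_client()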

google/generativeai/generative_models.py

Lines changed: 70 additions & 1 deletion
@@ -4,7 +4,7 @@
 
 from collections.abc import Iterable
 import textwrap
-from typing import Any
+from typing import Any, Union, overload
 import reprlib
 
 # pylint: disable=bad-continuation, line-too-long
@@ -13,6 +13,8 @@
 import google.api_core.exceptions
 from google.generativeai import protos
 from google.generativeai import client
+
+from google.generativeai import caching
 from google.generativeai.types import content_types
 from google.generativeai.types import generation_types
 from google.generativeai.types import helper_types
@@ -94,6 +96,15 @@ def __init__(
         self._client = None
         self._async_client = None
 
+    def __new__(cls, *args, **kwargs) -> GenerativeModel:
+        self = super().__new__(cls)
+
+        if cached_instance := kwargs.pop("cached_content", None):
+            setattr(self, "_cached_content", cached_instance.name)
+            setattr(cls, "cached_content", property(fget=lambda self: self._cached_content))
+
+        return self
+
     @property
     def model_name(self):
         return self._model_name
@@ -112,6 +123,7 @@ def maybe_text(content):
                 safety_settings={self._safety_settings},
                 tools={self._tools},
                 system_instruction={maybe_text(self._system_instruction)},
+                cached_content={getattr(self, "cached_content", None)}
             )"""
         )
 
@@ -127,6 +139,13 @@ def _prepare_request(
         tool_config: content_types.ToolConfigType | None,
     ) -> protos.GenerateContentRequest:
         """Creates a `protos.GenerateContentRequest` from raw inputs."""
+        if hasattr(self, "cached_content") and any([self._system_instruction, tools, tool_config]):
+            raise ValueError(
+                "`tools`, `tool_config`, `system_instruction` cannot be set on a model instantiated with `cached_content` as its context."
+            )
+
+        cached_content = getattr(self, "cached_content", None)
+
         tools_lib = self._get_tools_lib(tools)
         if tools_lib is not None:
             tools_lib = tools_lib.to_proto()
@@ -155,6 +174,7 @@
             tools=tools_lib,
             tool_config=tool_config,
             system_instruction=self._system_instruction,
+            cached_content=cached_content,
         )
 
     def _get_tools_lib(
@@ -165,6 +185,55 @@ def _get_tools_lib(
         else:
             return content_types.to_function_library(tools)
 
+    @overload
+    @classmethod
+    def from_cached_content(
+        cls,
+        cached_content: str,
+        generation_config: generation_types.GenerationConfigType | None = None,
+        safety_settings: safety_types.SafetySettingOptions | None = None,
+    ) -> GenerativeModel: ...
+
+    @overload
+    @classmethod
+    def from_cached_content(
+        cls,
+        cached_content: caching.CachedContent,
+        generation_config: generation_types.GenerationConfigType | None = None,
+        safety_settings: safety_types.SafetySettingOptions | None = None,
+    ) -> GenerativeModel: ...
+
+    @classmethod
+    def from_cached_content(
+        cls,
+        cached_content: str | caching.CachedContent,
+        generation_config: generation_types.GenerationConfigType | None = None,
+        safety_settings: safety_types.SafetySettingOptions | None = None,
+    ) -> GenerativeModel:
+        """Creates a model with `cached_content` as the model's context.
+
+        Args:
+            cached_content: Context for the model.
+
+        Returns:
+            A `GenerativeModel` object with `cached_content` as its context.
+        """
+        if isinstance(cached_content, str):
+            cached_content = caching.CachedContent.get(name=cached_content)
+
+        # Call __new__ with the cached_content to set the model's context. This is done to
+        # avoid exposing `cached_content` as a public attribute.
+        self = cls.__new__(cls, cached_content=cached_content)
+
+        # Call __init__ to set the model's `generation_config` and `safety_settings`.
+        # `model_name` will be the name of the model for which the `cached_content` was created.
+        self.__init__(
+            model_name=cached_content.model,
+            generation_config=generation_config,
+            safety_settings=safety_settings,
+        )
+        return self
+
     def generate_content(
         self,
         contents: content_types.ContentsType,
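
To tie the two files together, a hedged sketch of the new factory method in use (the cache name and prompt are placeholders, and the cached content is assumed to still exist):

import google.generativeai as genai
from google.generativeai import caching

# Accepts either a CachedContent object or its resource name as a string.
cache = caching.CachedContent.get(name="cachedContents/example-cache")  # placeholder name
model = genai.GenerativeModel.from_cached_content(cached_content=cache)

# The cached context is attached by reference; `tools`, `tool_config`, and
# `system_instruction` must not be set on such a model, or _prepare_request raises ValueError.
response = model.generate_content("Summarize the cached transcript in three bullet points.")
print(response.text)
print(model.cached_content)  # read-only property installed by __new__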
