diff --git a/lisette/core.py b/lisette/core.py
index e1f8d77..7fc4bee 100644
--- a/lisette/core.py
+++ b/lisette/core.py
@@ -244,6 +244,7 @@ def __init__(
         cache=False, # Anthropic prompt caching
         cache_idxs:list=[-1], # Anthropic cache breakpoint idxs, use `0` for sys prompt if provided
         ttl=None, # Anthropic prompt caching ttl
+        cached_content=None, # Gemini prompt caching
     ):
         "LiteLLM chat client."
         self.model = model
@@ -251,11 +252,13 @@ def __init__(
         if ns is None and tools: ns = mk_ns(tools)
         elif ns is None: ns = globals()
         self.tool_schemas = [lite_mk_func(t) for t in tools] if tools else None
+        self.cache_name = cached_content
         store_attr()
 
     def _prep_msg(self, msg=None, prefill=None):
         "Prepare the messages list for the API call"
-        sp = [{"role": "system", "content": self.sp}] if self.sp else []
+        # Don't include sp if using a Gemini cache (it's already part of the cached content)
+        sp = [{"role": "system", "content": self.sp}] if self.sp and not getattr(self, '_sp_in_cache', False) else []
         if sp:
             if 0 in self.cache_idxs: sp[0] = _add_cache_control(sp[0])
             cache_idxs = L(self.cache_idxs).filter().map(lambda o: o-1 if o>0 else o)
@@ -271,6 +274,7 @@ def _call(self, msg=None, prefill=None, temp=None, think=None, search=None, stre
         if not get_model_info(self.model).get("supports_assistant_prefill"): prefill=None
         if _has_search(self.model) and (s:=ifnone(search,self.search)): kwargs['web_search_options'] = {"search_context_size": effort[s]}
         else: _=kwargs.pop('web_search_options',None)
+        if self.cache_name: kwargs['cached_content'] = self.cache_name
         res = completion(model=self.model, messages=self._prep_msg(msg, prefill), stream=stream,
                          tools=self.tool_schemas, reasoning_effort = effort.get(think), tool_choice=tool_choice,
                          # temperature is not supported when reasoning
@@ -310,6 +314,96 @@ def __call__(self,
         if stream: return result_gen # streaming
         elif return_all: return list(result_gen) # toolloop behavior
         else: return last(result_gen) # normal chat behavior
+
+    def create_cache(self, system_instruction=None, contents=None, tools=None, ttl="3600s"):
+        from google import genai
+        from google.genai import types
+        client = genai.Client()
+
+        # If model is "gemini/gemini-2.0-flash", extract "gemini-2.0-flash"
+        if "/" in self.model:
+            model_name = self.model.split("/")[1]
+        else:
+            model_name = self.model
+
+        # Caching requires an explicit version suffix such as `-001`
+        if not model_name.endswith("-001"):
+            model_name += "-001"
+
+        # Check if a cache already exists
+        if self.cache_name:
+            raise ValueError("Cache already exists. Delete it first with delete_cache()")
+
+        # Use defaults from Chat if not provided
+        system_instruction = system_instruction or self.sp
+        tools = tools or self.tool_schemas
+
+        # Create the cache using the google.genai client
+        if contents:
+            cache = client.caches.create(
+                model=model_name,
+                config=types.CreateCachedContentConfig(
+                    system_instruction=system_instruction,
+                    contents=contents,
+                    tools=tools,
+                    ttl=ttl
+                )
+            )
+        else:
+            cache = client.caches.create(
+                model=model_name,
+                config=types.CreateCachedContentConfig(
+                    system_instruction=system_instruction,
+                    tools=tools,
+                    ttl=ttl
+                )
+            )
+        # Store cache.name so _call can pass it as `cached_content`
+        self.cache_name = cache.name
+
+        # Flag that the system prompt now lives in the cache
+        self._sp_in_cache = bool(system_instruction)
+
+        # Return the cache object
+        return cache
+
+
+    def delete_cache(self):
+        from google import genai
+
+        if not self.cache_name:
+            raise ValueError("No cache exists to delete.")
+
+        client = genai.Client()
+        client.caches.delete(name=self.cache_name)
+        self.cache_name = None
+
+    def get_cache(self):
+        from google import genai
+
+        if not self.cache_name:
+            raise ValueError("No cache exists.")
+
+        client = genai.Client()
+        return client.caches.get(name=self.cache_name)
+
+    def update_cache(self, ttl='300s'):
+        # ttl must be a string in seconds, e.g. '300s'
+        from google import genai
+        from google.genai import types
+
+        if not self.cache_name:
+            raise ValueError("No cache exists to update.")
+
+        client = genai.Client()
+        client.caches.update(
+            name=self.cache_name,
+            config=types.UpdateCachedContentConfig(
+                ttl=ttl
+            )
+        )
+
+
 
 # %% ../nbs/00_core.ipynb
 @patch
@@ -467,3 +561,5 @@ async def adisplay_stream(rs):
         md+=o
         display(Markdown(md),clear=True)
     return fmt
+
+
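
A minimal usage sketch of the new Gemini caching flow (not part of the patch), assuming lisette's `Chat` is importable as shown, a `GEMINI_API_KEY` is set in the environment, and `long_context` / "big_doc.txt" are hypothetical stand-ins for whatever large content is worth caching:

    from lisette import Chat

    long_context = open("big_doc.txt").read()  # hypothetical large document
    chat = Chat("gemini/gemini-2.0-flash", sp="Answer questions about the attached document.")

    # Cache the system prompt plus the document once; _call then passes
    # cached_content on every subsequent request so the document isn't resent
    chat.create_cache(contents=[long_context], ttl="3600s")
    print(chat("Summarise the document in three bullet points."))

    chat.update_cache(ttl="7200s")  # extend the cache lifetime (string seconds)
    chat.delete_cache()             # remove the cache and clear chat.cache_name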