Merged
4 changes: 3 additions & 1 deletion nemo_skills/inference/generate.py
@@ -37,6 +37,7 @@
get_evaluator_class,
supports_single_eval,
)
+from nemo_skills.inference.litellm_hybrid_cache import StableLiteLLMCache
⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, find and examine the generate.py file
fd -t f "generate.py" | head -5

Repository: NVIDIA-NeMo/Skills

Length of output: 254


🏁 Script executed:

# Get the actual code around lines 40 and 821-833
cat -n nemo_skills/inference/generate.py | sed -n '35,45p'

Repository: NVIDIA-NeMo/Skills

Length of output: 474


🏁 Script executed:

# Get cleanup code around lines 821-833
cat -n nemo_skills/inference/generate.py | sed -n '815,840p'

Repository: NVIDIA-NeMo/Skills

Length of output: 1273


🏁 Script executed:

# Find the StableLiteLLMCache implementation
fd -t f "litellm_hybrid_cache.py"

Repository: NVIDIA-NeMo/Skills

Length of output: 107


🏁 Script executed:

# Read the StableLiteLLMCache implementation to understand threading behavior
cat -n nemo_skills/inference/litellm_hybrid_cache.py

Repository: NVIDIA-NeMo/Skills

Length of output: 10684


🏁 Script executed:

# Check if HybridCache is used and what shutdown methods it has
rg "class HybridCache|class StableLiteLLMCache|def shutdown|def close|def _shutdown" nemo_skills/inference/litellm_hybrid_cache.py -A 5

Repository: NVIDIA-NeMo/Skills

Length of output: 728


🏁 Script executed:

# Check for background threads or save loops in the file
rg "Thread|thread|save.*loop|while.*save" nemo_skills/inference/litellm_hybrid_cache.py -B 2 -A 2

Repository: NVIDIA-NeMo/Skills

Length of output: 1210


🏁 Script executed:

# Check how cleanup_litellm_cache is called and if there are multiple instances
rg "cleanup_litellm_cache" nemo_skills/

Repository: NVIDIA-NeMo/Skills

Length of output: 279


🏁 Script executed:

# Check if disconnect() would be better - search for any async context in cleanup flow
rg "cleanup_litellm_cache|setup_litellm_cache" nemo_skills/inference/generate.py -B 5 -A 10

Repository: NVIDIA-NeMo/Skills

Length of output: 2472


🏁 Script executed:

# Look for the calling context to see if we could use async disconnect
ast-grep --pattern 'def generate(self) {
  $$$
  cleanup_litellm_cache
  $$$
}' nemo_skills/inference/generate.py

Repository: NVIDIA-NeMo/Skills

Length of output: 44


🏁 Script executed:

# Find where cleanup_litellm_cache is actually invoked
rg "cleanup_litellm_cache\(\)" nemo_skills/ -B 3 -A 1

Repository: NVIDIA-NeMo/Skills

Length of output: 381


🏁 Script executed:

# Check if there are any error handling or try-catch around cleanup_litellm_cache
rg "try:|except|cleanup_litellm_cache" nemo_skills/inference/generate.py -B 2 -A 2

Repository: NVIDIA-NeMo/Skills

Length of output: 1107


🏁 Script executed:

# Verify the async context - is generate() ever awaited or is it always sync?
rg "async def generate|await.*generate\(\)" nemo_skills/inference/generate.py

Repository: NVIDIA-NeMo/Skills

Length of output: 127


Stop the background save thread before deleting the cache directory to prevent race conditions.

The HybridCache instance runs a background thread that saves to disk every 300 seconds. Calling force_save() persists the current state but leaves that thread running. When rmtree() deletes the directory immediately afterwards, the background thread may still attempt to write to the now-missing path or recreate the directory.

Store the cache instance and call its _shutdown() method before deleting the directory. This stops the background thread and performs a final save in the correct order:

Suggested change
     def setup_litellm_cache(self):
         if self.cfg.enable_litellm_cache:
             # One cache per (output_file_name, chunk_id) pair
             output_file_name = Path(self.cfg.output_file).name
             self.litellm_cache_dir = (
                 Path(self.cfg.output_file).parent / "litellm_cache" / f"{output_file_name}_{self.cfg.chunk_id or 0}"
             )
-            litellm.cache = StableLiteLLMCache(cache_file_path=str(self.litellm_cache_dir / "cache.pkl"))
+            self.litellm_cache = StableLiteLLMCache(
+                cache_file_path=str(self.litellm_cache_dir / "cache.pkl")
+            )
+            litellm.cache = self.litellm_cache
 
     def cleanup_litellm_cache(self):
         if self.cfg.enable_litellm_cache:
-            litellm.cache.cache.force_save()
+            self.litellm_cache.cache._shutdown()
             shutil.rmtree(self.litellm_cache_dir)
🤖 Prompt for AI Agents
In `@nemo_skills/inference/generate.py` at line 40, The background save thread in
the HybridCache (created via StableLiteLLMCache / HybridCache instance) must be
stopped before deleting its storage directory: instead of only calling
force_save() then shutil.rmtree(), retain the cache instance and call its
_shutdown() method to stop the background thread and perform a final save, then
call shutil.rmtree() to remove the directory; update any code that currently
only calls force_save() to call cache._shutdown() (or cache.shutdown wrapper if
available) prior to deleting the directory.
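
One further wrinkle, beyond the suggested change: HybridCache registers _shutdown with atexit, so a second save can fire at interpreter exit after the directory is gone. A minimal hardening sketch, assuming setup_litellm_cache() stored the instance as self.litellm_cache and that atexit and shutil are imported (the atexit.unregister call is an addition, not part of the suggestion above):

    def cleanup_litellm_cache(self):
        if self.cfg.enable_litellm_cache:
            cache = self.litellm_cache.cache
            cache._shutdown()  # stops the saver thread and performs a final save
            atexit.unregister(cache._shutdown)  # avoid a second save at interpreter exit
            shutil.rmtree(self.litellm_cache_dir, ignore_errors=True)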

from nemo_skills.inference.model import (
ParallelThinkingConfig,
get_code_execution_model,
@@ -824,10 +825,11 @@ def setup_litellm_cache(self):
self.litellm_cache_dir = (
Path(self.cfg.output_file).parent / "litellm_cache" / f"{output_file_name}_{self.cfg.chunk_id or 0}"
)
-            litellm.cache = litellm.Cache(type="disk", disk_cache_dir=self.litellm_cache_dir)
+            litellm.cache = StableLiteLLMCache(cache_file_path=str(self.litellm_cache_dir / "cache.pkl"))

def cleanup_litellm_cache(self):
if self.cfg.enable_litellm_cache:
litellm.cache.cache.force_save()
shutil.rmtree(self.litellm_cache_dir)

def generate(self):
240 changes: 240 additions & 0 deletions nemo_skills/inference/litellm_hybrid_cache.py
@@ -0,0 +1,240 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
Hybrid Cache implementation that keeps data in memory and periodically saves to a single file on disk.

This avoids the issue with disk cache creating too many files while still persisting cache across runs.

Also includes StableLiteLLMCache which fixes litellm's cache key generation to be order-independent.
"""

import atexit
import json
import os
import pickle
import threading
from pathlib import Path
from typing import List, Optional

import litellm
from litellm.caching.caching import Cache as LiteLLMCache


class HybridCache:
def __init__(
self,
cache_file_path: str,
save_interval_seconds: float = 300.0, # 5 minutes
):
self.cache_file_path = cache_file_path
self.save_interval_seconds = save_interval_seconds

self.cache_dict: dict = {}
self._lock = threading.RLock()
self._dirty = False # Track if cache has been modified since last save
self._stop_event = threading.Event()
self._save_thread: Optional[threading.Thread] = None

self._load_from_disk()
self._start_background_save_thread()

atexit.register(self._shutdown)

def _check_no_ttl(self, **kwargs):
"""Raise error if TTL is provided since TTL is not supported."""
if kwargs.get("ttl") is not None:
raise ValueError("TTL is not supported by HybridCache")

def _load_from_disk(self):
"""Load cache from disk if the file exists."""
if os.path.exists(self.cache_file_path):
with open(self.cache_file_path, "rb") as f:
data = pickle.load(f)
self.cache_dict = data["cache_dict"]

def _save_to_disk(self):
"""Save cache to disk."""
with self._lock:
if not self._dirty:
return
data = {
"cache_dict": self.cache_dict.copy(),
}
self._dirty = False
_dirty flag is cleared before the disk write completes. If the write fails (lines 78-82), the cache will incorrectly believe it's already saved and won't retry.

Suggested change
-            self._dirty = False
+            self._dirty = False  # Move this after successful write
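
Concretely, the reordering might look like this — a sketch that keeps the whole write under the lock for simplicity, at the cost of blocking writers during the dump:

    def _save_to_disk(self):
        """Save cache to disk."""
        with self._lock:
            if not self._dirty:
                return
            data = {
                "cache_dict": self.cache_dict.copy(),
            }
            temp_path = self.cache_file_path + ".tmp"
            Path(self.cache_file_path).parent.mkdir(parents=True, exist_ok=True)
            with open(temp_path, "wb") as f:
                pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
            os.replace(temp_path, self.cache_file_path)
            # Clear the flag only after the atomic replace succeeds,
            # so a failed write is retried on the next save cycle.
            self._dirty = False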


temp_path = self.cache_file_path + ".tmp"
Path(self.cache_file_path).parent.mkdir(parents=True, exist_ok=True)
with open(temp_path, "wb") as f:
pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
os.replace(temp_path, self.cache_file_path)

def _start_background_save_thread(self):
"""Start a background thread that periodically saves to disk."""

def save_loop():
while not self._stop_event.wait(timeout=self.save_interval_seconds):
self._save_to_disk()

self._save_thread = threading.Thread(target=save_loop, daemon=True)
self._save_thread.start()

def _shutdown(self):
"""Shutdown the background thread and save cache."""
self._stop_event.set()
if self._save_thread is not None:
self._save_thread.join(timeout=5.0)
self._save_to_disk()

def set_cache(self, key, value, **kwargs):
"""Set a value in the cache."""
self._check_no_ttl(**kwargs)
with self._lock:
self.cache_dict[key] = value
self._dirty = True

async def async_set_cache(self, key, value, **kwargs):
"""Async set - delegates to sync implementation since we're using in-memory."""
self.set_cache(key=key, value=value, **kwargs)

async def async_set_cache_pipeline(self, cache_list, **kwargs):
"""Set multiple cache entries."""
for cache_key, cache_value in cache_list:
self.set_cache(key=cache_key, value=cache_value, **kwargs)

def get_cache(self, key, **kwargs):
"""Get a value from the cache."""
with self._lock:
if key not in self.cache_dict:
return None
cached_response = self.cache_dict[key]
if isinstance(cached_response, str):
try:
cached_response = json.loads(cached_response)
except json.JSONDecodeError:
pass
return cached_response

async def async_get_cache(self, key, **kwargs):
"""Async get - delegates to sync implementation."""
return self.get_cache(key=key, **kwargs)

def batch_get_cache(self, keys: list, **kwargs):
"""Get multiple values from cache."""
return [self.get_cache(key=k, **kwargs) for k in keys]

async def async_batch_get_cache(self, keys: list, **kwargs):
"""Async batch get."""
return self.batch_get_cache(keys=keys, **kwargs)

def increment_cache(self, key, value: int, **kwargs) -> int:
"""Increment a cache value."""
with self._lock:
init_value = self.get_cache(key=key) or 0
new_value = init_value + value
self.set_cache(key, new_value, **kwargs)
return new_value

async def async_increment(self, key, value: float, **kwargs) -> float:
"""Async increment."""
return self.increment_cache(key, int(value), **kwargs)

Comment on lines +101 to +153

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

find . -name "litellm_hybrid_cache.py" -type f

Repository: NVIDIA-NeMo/Skills

Length of output: 109


🏁 Script executed:

# Get file size and read the full implementation
wc -l nemo_skills/inference/litellm_hybrid_cache.py

Repository: NVIDIA-NeMo/Skills

Length of output: 111


🏁 Script executed:

# Read the file to understand _check_no_ttl and the full context
cat -n nemo_skills/inference/litellm_hybrid_cache.py

Repository: NVIDIA-NeMo/Skills

Length of output: 10684


🌐 Web query:

LiteLLM cache interface documentation - what kwargs are supported by get_cache set_cache increment_cache methods

💡 Result:

LiteLLM’s cache backends expose a small “core” interface (get_cache, set_cache, increment_cache) and then accept cache-control knobs via kwargs (passed through from the cache={...} dict on a request).

Cache-control kwargs (passed via cache={...} on a request)

These are the per-request kwargs LiteLLM documents and routes into cache read/write behavior (i.e., they affect get_cache() / cache lookup and cache write):

  • ttl (int, seconds) – per-entry TTL override (used on write) [1][2]
  • s-maxage (int, seconds) – freshness requirement for accepting an entry (validated after a hit / on read) [1][2]
  • no-cache (bool) – bypass cache read (still allows write unless no-store) [1][2]
  • no-store (bool) – prevent cache write [1][2]
  • namespace (str) – per-request key prefix / cache segmentation [2]

Method-level kwargs / parameters you can rely on

Because the concrete backends differ, LiteLLM’s docs are most consistent about these method parameters:

  • get_cache(key, ...) (reads; influenced by no-cache, s-maxage, namespace) [2]
  • set_cache(key, value, ttl=None, ...) (writes; TTL commonly supported as an argument) [3]
  • increment_cache(key, ...) is used for counters (rate limits, etc.); docs don’t enumerate a separate, richer kwarg set beyond the same TTL/namespace concepts used throughout caching [2].

If you tell me which cache backend you’re implementing (e.g., InMemoryCache, RedisCache, DualCache, proxy cache), I can map these to the exact Python method signatures used in that backend.

Sources: [1] https://docs.litellm.ai/docs/caching/all_caches • [2] https://deepwiki.com/BerriAI/litellm/5.3-dynamic-cache-controls • [3] https://readmex.com/en-US/BerriAI/litellm/page-6d9ab5a93-b77f-4311-b15f-ac9ae2a3ec34
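
For reference, these controls are passed per request via the cache dict; a hypothetical call, assuming caching has already been enabled globally:

    import litellm

    response = litellm.completion(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "hello"}],
        cache={"no-store": True},  # read from cache if present, but skip writing this response
    )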


🏁 Script executed:

rg "\.get_cache\(|\.batch_get_cache\(|\.increment_cache\(" --type py -A 2 -B 2

Repository: NVIDIA-NeMo/Skills

Length of output: 2494


🏁 Script executed:

# Also check calls to async variants
rg "\.async_get_cache\(|\.async_batch_get_cache\(|\.async_increment\(" --type py -A 2 -B 2

Repository: NVIDIA-NeMo/Skills

Length of output: 44


🏁 Script executed:

# Search for external calls to HybridCache methods
rg "HybridCache|StableLiteLLMCache" --type py -A 3 | head -100

Repository: NVIDIA-NeMo/Skills

Length of output: 2256


🏁 Script executed:

# Check how LiteLLM uses cache interface - look for cache.get_cache or cache.set_cache patterns
rg "cache\.(get_cache|set_cache|batch_get_cache|increment_cache|async_get_cache)" --type py

Repository: NVIDIA-NeMo/Skills

Length of output: 44


🏁 Script executed:

# Search for where cache methods are actually invoked from LiteLLM
# Check the StableLiteLLMCache parent class inheritance and method overrides
rg "def (get_cache|set_cache|batch_get_cache|increment_cache)" nemo_skills/inference/litellm_hybrid_cache.py

Repository: NVIDIA-NeMo/Skills

Length of output: 312


🏁 Script executed:

# Check if HybridCache methods are actually called by LiteLLM by searching generate.py
cat -n nemo_skills/inference/generate.py | grep -A 30 "StableLiteLLMCache\|litellm.cache"

Repository: NVIDIA-NeMo/Skills

Length of output: 7738


Validate kwargs across all cache methods to fail fast on unsupported parameters.

Currently, get_cache(), batch_get_cache(), and the async variants silently drop **kwargs, while only set_cache() validates via _check_no_ttl(). This inconsistency violates the guideline to avoid silently ignoring user-passed parameters. Since HybridCache is an in-memory cache with periodic disk persistence and no support for cache-control features (ttl, namespace, s-maxage, no-cache, no-store), all methods should validate and reject any unsupported kwargs.

Replace _check_no_ttl() with a comprehensive validation method that fails on any kwargs:

def _validate_kwargs(self, **kwargs):
    if kwargs:
        raise ValueError(f"Unsupported cache kwargs: {', '.join(kwargs)}")

Apply this to all cache methods: get_cache(), batch_get_cache(), increment_cache(), and their async variants—not just set_cache().
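
For example, each entry point would start with the check; a sketch against the existing get_cache signature:

    def get_cache(self, key, **kwargs):
        """Get a value from the cache."""
        self._validate_kwargs(**kwargs)  # fail fast on ttl, namespace, s-maxage, no-cache, no-store
        with self._lock:
            ...  # existing lookup logic unchanged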

🧰 Tools
🪛 Ruff (0.14.14)

[warning] 117-117: Unused method argument: kwargs

(ARG002)

🤖 Prompt for AI Agents
In `@nemo_skills/inference/litellm_hybrid_cache.py` around lines 101 - 153, The
cache methods currently only validate kwargs in set_cache via _check_no_ttl,
letting unsupported kwargs silently pass in get_cache, batch_get_cache,
increment_cache and async variants; replace _check_no_ttl with a new
_validate_kwargs(self, **kwargs) that raises ValueError when kwargs is non-empty
and call _validate_kwargs at the start of all cache entry points (set_cache,
async_set_cache, async_set_cache_pipeline, get_cache, async_get_cache,
batch_get_cache, async_batch_get_cache, increment_cache, async_increment) so
every method fails fast on unsupported parameters.

def flush_cache(self):
"""Clear all cache entries."""
with self._lock:
self.cache_dict.clear()
self._dirty = True

def delete_cache(self, key):
"""Delete a specific key from cache."""
with self._lock:
self.cache_dict.pop(key, None)
self._dirty = True

async def disconnect(self):
"""Disconnect and save cache to disk."""
self._shutdown()

async def async_set_cache_sadd(self, key, value: List):
"""Add values to a set."""
with self._lock:
init_value = self.get_cache(key=key) or set()
for val in value:
init_value.add(val)
self.set_cache(key, init_value)
return value

def force_save(self):
"""Force an immediate save to disk."""
self._dirty = True
self._save_to_disk()


class StableLiteLLMCache(LiteLLMCache):
"""
A litellm Cache subclass that generates order-independent cache keys.

The default litellm cache key generation iterates through kwargs in order,
which means the same request with different parameter ordering produces
different cache keys. This class fixes that by sorting kwargs before iteration.
"""

def __init__(self, cache_file_path: str, save_interval_seconds: float = 300.0, **kwargs):
super().__init__(type="local", **kwargs)
self.cache = HybridCache(
cache_file_path=cache_file_path,
save_interval_seconds=save_interval_seconds,
)

def _stable_str(self, value) -> str:
"""Convert value to string deterministically (handles nested dicts/lists)."""
if isinstance(value, (dict, list)):
return json.dumps(value, sort_keys=True, default=str)
return str(value)

def get_cache_key(self, **kwargs) -> str:
"""
Get the cache key for the given arguments.
Same as parent but with sorted iteration for deterministic keys.
"""
from litellm.litellm_core_utils.model_param_helper import ModelParamHelper
from litellm.types.utils import all_litellm_params

cache_key = ""

preset_cache_key = self._get_preset_cache_key_from_kwargs(**kwargs)
if preset_cache_key is not None:
return preset_cache_key

combined_kwargs = ModelParamHelper._get_all_llm_api_params()
litellm_param_kwargs = all_litellm_params

# FIX: Sort kwargs for deterministic cache key generation
for param in sorted(kwargs.keys()):
if param in combined_kwargs:
param_value = self._get_param_value(param, kwargs)
if param_value is not None:
cache_key += f"{str(param)}: {self._stable_str(param_value)}"
elif param not in litellm_param_kwargs:
if litellm.enable_caching_on_provider_specific_optional_params is True:
if kwargs[param] is None:
continue
param_value = kwargs[param]
cache_key += f"{str(param)}: {self._stable_str(param_value)}"

hashed_cache_key = self._get_hashed_cache_key(cache_key)
hashed_cache_key = self._add_namespace_to_cache_key(hashed_cache_key, **kwargs)
self._set_preset_cache_key_in_kwargs(preset_cache_key=hashed_cache_key, **kwargs)
return hashed_cache_key
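
The practical effect of the sorted iteration can be sanity-checked with a hypothetical round trip, assuming get_cache_key can be called directly outside a live request (behavior depends on litellm internals):

    cache = StableLiteLLMCache(cache_file_path="/tmp/litellm_demo/cache.pkl")
    key_a = cache.get_cache_key(model="gpt-4o-mini", temperature=0.0,
                                messages=[{"role": "user", "content": "hi"}])
    key_b = cache.get_cache_key(messages=[{"role": "user", "content": "hi"}],
                                temperature=0.0, model="gpt-4o-mini")
    assert key_a == key_b  # same request, different kwarg order, same key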
8 changes: 5 additions & 3 deletions recipes/proof-gen-verification/scripts/script_generation.py
@@ -105,9 +105,11 @@ async def process_single_datapoint(self, data_point, all_data):
kwargs["llm_kwargs"] = self.llm_kwargs
if "random_seed" in sig.parameters:
kwargs["random_seed"] = self.random_seed
-        result = await self._script_module.process_single(**kwargs)
-        if "generation" not in result:
-            result["generation"] = "dummy generation key"  # To avoid error in dumping
+        async with self.semaphore:
+            result = await self._script_module.process_single(**kwargs)
+            if "generation" not in result:
+                result["generation"] = "dummy generation key"  # To avoid error in dumping

return result
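
For context, the async with self.semaphore wrapper presumes a semaphore constructed elsewhere in the class; a hypothetical setup (the attribute and limit names are assumptions, not taken from this diff):

    import asyncio

    # In __init__: cap how many datapoints are processed concurrently.
    self.semaphore = asyncio.Semaphore(max_concurrent_requests)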

