diff --git a/docs/caching.md b/docs/caching.md index 3ea5ac45..599bbed3 100644 --- a/docs/caching.md +++ b/docs/caching.md @@ -5,8 +5,15 @@ LiSSA implements a sophisticated caching system to improve performance and ensure reproducibility of results. The caching system consists of the following components: 1. **Cache Interface** (`cache` package) - - [`Cache`](../src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/Cache.java): Core interface defining cache operations - - [`CacheKey`](../src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/CacheKey.java): Represents a unique key for cached items, including model name, seed, mode (EMBEDDING/CHAT), and content + - [`Cache`](../src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/Cache.java): Core generic interface defining cache operations, parameterized by cache key type + - [`CacheKey`](../src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/CacheKey.java): Base interface for cache keys with JSON serialization support and local key generation + - [`CacheParameter`](../src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/CacheParameter.java): Interface defining cache configuration and key creation logic + - **Specialized Cache Keys**: + - [`ClassifierCacheKey`](../src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/classifier/ClassifierCacheKey.java): Cache key for classifier operations (model name, seed, temperature, mode, content) + - [`EmbeddingCacheKey`](../src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/embedding/EmbeddingCacheKey.java): Cache key for embedding operations (model name, content) + - **Cache Parameters**: + - [`ClassifierCacheParameter`](../src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/classifier/ClassifierCacheParameter.java): Configuration for classifier caches (model name, seed, temperature) + - [`EmbeddingCacheParameter`](../src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/embedding/EmbeddingCacheParameter.java): Configuration for embedding caches (model name) 2. 
**Cache Implementations** - [`LocalCache`](../src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/LocalCache.java): File-based cache implementation that stores data in JSON format - Implements dirty tracking to optimize writes @@ -20,20 +27,61 @@ LiSSA implements a sophisticated caching system to improve performance and ensur - [`CacheManager`](../src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/CacheManager.java): Central manager for cache instances - Manages cache directory configuration - Provides singleton access to cache instances - - Handles cache creation and retrieval + - Handles cache creation and retrieval based on origin and cache parameters + - Ensures cache uniqueness by validating parameters 4. **Caching Usage** The caching system is used in several key components: - **Embedding Creators**: Caches vector embeddings to avoid recalculating them + - Uses `EmbeddingCacheParameter` to identify unique embedding configurations + - Cache keys are automatically generated based on content using the model name - **Classifiers**: Caches LLM responses for classification tasks + - Uses `ClassifierCacheParameter` to identify unique classifier configurations + - Cache keys include model name, seed, temperature, and content - **Preprocessors**: Caches preprocessing results for text summarization and other operations -5. 
**Configuration** + - Uses `ClassifierCacheParameter` for LLM-based preprocessing + +## Key Concepts + +### Cache Keys + +Cache keys uniquely identify cached items and consist of two parts: +- **JSON Key**: Serialized representation including all cache parameters (model, seed, temperature, content, mode) +- **Local Key**: Generated UUID-based key for in-memory identification and logging + +### Cache Parameters + +Cache parameters define the configuration that makes a cache unique: +- **ClassifierCacheParameter**: Model name, seed, and temperature for reproducible LLM results +- **EmbeddingCacheParameter**: Model name only (embeddings are deterministic) + +Parameters are used to: +1. Generate unique cache file names (via `parameters()` method) +2. Create cache keys from content (via `createCacheKey()` method) +3. Validate cache consistency when retrieving existing caches + +### Cache API + +The `Cache` interface provides two API levels: +1. **String-based API** (preferred): Pass the content as a string; the cache handles key generation internally + - `get(String key, Class clazz)` + - `put(String key, T value)` + - `containsKey(String key)` + +2. **Internal Key API** (deprecated, avoid in general code): Direct cache key manipulation for special cases + - `getViaInternalKey(K key, Class clazz)` + - `putViaInternalKey(K key, T value)` + - Only use for backward compatibility or special handling scenarios + +## Usage Instructions + +1. **Configuration** ```json { "cache_dir": "./cache/path" // Directory for cache storage } ``` -6. **Redis Setup** +2. **Redis Setup** To use Redis for caching, you need to set up a Redis server. Here's a recommended Docker Compose configuration: ```yaml @@ -54,9 +102,9 @@ LiSSA implements a sophisticated caching system to improve performance and ensur To use Redis with LiSSA: 1. Start the Redis server using Docker Compose 2. The system will automatically use Redis if available - 3. If Redis is unavailable, it will fall back to local file-based caching + 3. 
If Redis is unavailable, it will fall back to local file-based caching (useful for replication packages) -7. **Best Practices** +3. **Best Practices** - Use the cache directory specified in the configuration - Clear the cache directory if you encounter issues diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/Cache.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/Cache.java index 08438f4e..831b7f34 100644 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/Cache.java +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/Cache.java @@ -1,4 +1,4 @@ -/* Licensed under MIT 2025. */ +/* Licensed under MIT 2025-2026. */ package edu.kit.kastel.sdq.lissa.ratlr.cache; import org.jspecify.annotations.Nullable; @@ -8,7 +8,7 @@ * This interface defines the contract for caching mechanisms that store and retrieve * values associated with cache keys. */ -public interface Cache { +public interface Cache { /** * Retrieves a value from the cache and deserializes it to the specified type. * @@ -17,7 +17,20 @@ public interface Cache { * @param clazz The class of the type to deserialize to * @return The deserialized value, or null if not found */ - @Nullable T get(CacheKey key, Class clazz); + @Nullable T get(String key, Class clazz); + + /** + * Retrieves a value from the cache and deserializes it to the specified type. + * DO NOT USE UNLESS YOU KNOW WHAT YOU ARE DOING. + * + * @param The type to deserialize the cached value to + * @param key The cache key to look up + * @param clazz The class of the type to deserialize to + * @return The deserialized value, or null if not found + * @deprecated This method exposes internal cache key handling and should not be used in general code. + */ + @Deprecated(forRemoval = false) + @Nullable T getViaInternalKey(K key, Class clazz); /** * Stores a string value in the cache. 
@@ -25,7 +38,18 @@ public interface Cache { * @param key The cache key to store the value under * @param value The string value to store */ - void put(CacheKey key, String value); + void put(String key, String value); + + /** + * Stores an object value in the cache under the given internal cache key. + * + * @param The type of the value to store + * @param key The cache key to store the value under + * @param value The value to store + * @deprecated This method exposes internal cache key handling and should not be used in general code. + */ + @Deprecated(forRemoval = false) + void putViaInternalKey(K key, T value); /** * Stores an object value in the cache. @@ -35,7 +59,7 @@ public interface Cache { * @param key The cache key to store the value under * @param value The object value to store */ - void put(CacheKey key, T value); + void put(String key, T value); /** * Flushes any pending changes to the cache storage. @@ -51,5 +75,7 @@ public interface Cache { * @param key The cache key to check for existence * @return true if this map contains a mapping for the specified key */ - boolean containsKey(CacheKey key); + boolean containsKey(String key); + + CacheParameter getCacheParameter(); } diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/CacheKey.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/CacheKey.java index c99804e7..a68b17f3 100644 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/CacheKey.java +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/CacheKey.java @@ -1,4 +1,4 @@ -/* Licensed under MIT 2025. */ +/* Licensed under MIT 2025-2026. */ package edu.kit.kastel.sdq.lissa.ratlr.cache; import com.fasterxml.jackson.core.JsonProcessingException; @@ -8,11 +8,6 @@ /** * Represents a key for caching operations in the LiSSA framework. * - * The current types of cache keys are: - *
    - *
  • {@link edu.kit.kastel.sdq.lissa.ratlr.cache.EmbeddingCacheKey EmbeddingCacheKey} for caching embedding generation operations.
  • - *
  • {@link edu.kit.kastel.sdq.lissa.ratlr.cache.ClassifierCacheKey ClassifierCacheKey} for caching classification operations.
  • - *
*/ public interface CacheKey { /** diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/CacheManager.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/CacheManager.java index 4656630e..720ae42b 100644 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/CacheManager.java +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/CacheManager.java @@ -1,4 +1,4 @@ -/* Licensed under MIT 2025. */ +/* Licensed under MIT 2025-2026. */ package edu.kit.kastel.sdq.lissa.ratlr.cache; import java.io.IOException; @@ -29,7 +29,7 @@ public final class CacheManager { private static @Nullable CacheManager defaultInstanceManager; private final Path directoryOfCaches; - private final Map caches = new HashMap<>(); + private final Map> caches = new HashMap<>(); /** * Sets the cache directory for the default cache manager instance. @@ -79,61 +79,46 @@ public static CacheManager getDefaultInstance() { * @param parameters a list of parameters that define what makes a cache unique. E.g., the model name, temperature, and seed. * @return A cache instance for the specified name */ - public Cache getCache(Object origin, String[] parameters) { + public Cache getCache(Object origin, CacheParameter parameters) { if (origin == null || parameters == null) { throw new IllegalArgumentException("Origin and parameters must not be null"); } - for (String param : parameters) { - if (param == null) { - throw new IllegalArgumentException("Parameters must not contain null values"); - } - } - String name = origin.getClass().getSimpleName() + "_" + String.join("_", parameters); - return getCache(name, true); + String name = origin.getClass().getSimpleName() + "_" + parameters.parameters(); + return getCache(name, parameters); } /** - * Gets a cache instance for the specified name, optionally appending a file extension. + * Gets a cache instance for the specified name and parameters. 
* * @param name The name of the cache - * @param appendEnding Whether to append the .json extension to the cache name + * @param parameters The parameters that define the cache configuration * @return A cache instance for the specified name */ - private Cache getCache(String name, boolean appendEnding) { + private Cache getCache(String name, CacheParameter parameters) { name = name.replace(":", "__"); if (caches.containsKey(name)) { - return caches.get(name); + @SuppressWarnings("unchecked") + Cache cached = (Cache) caches.get(name); + if (!cached.getCacheParameter().equals(parameters)) { + throw new IllegalArgumentException( + "Cache with name " + name + " already exists with different parameters"); + } + return cached; } - LocalCache localCache = new LocalCache(directoryOfCaches + "/" + name + (appendEnding ? ".json" : "")); - RedisCache cache = new RedisCache(localCache, DEFAULT_REPLACE_LOCAL_CACHE_ON_CONFLICT); + LocalCache localCache = new LocalCache<>(directoryOfCaches + "/" + name + ".json", parameters); + RedisCache cache = new RedisCache<>(parameters, localCache, DEFAULT_REPLACE_LOCAL_CACHE_ON_CONFLICT); caches.put(name, cache); return cache; } - /** - * Gets a cache instance for an existing cache file. - * - * @param path The path to the existing cache file - * @param create Whether to create the cache file if it doesn't exist - * @return A cache instance for the specified file - * @throws IllegalArgumentException If the file doesn't exist (and create is false) or is a directory - */ - public Cache getCache(Path path, boolean create) { - path = directoryOfCaches.resolve(path.getFileName()); - if ((!create && Files.notExists(path)) || Files.isDirectory(path)) { - throw new IllegalArgumentException("file does not exist or is a directory: " + path); - } - return getCache(path.getFileName().toString(), false); - } - /** * Flushes all caches managed by this cache manager. * This ensures that all pending changes are written to disk. 
*/ public void flush() { - for (Cache cache : caches.values()) { + for (Cache cache : caches.values()) { cache.flush(); } } diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/CacheParameter.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/CacheParameter.java new file mode 100644 index 00000000..806962e9 --- /dev/null +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/CacheParameter.java @@ -0,0 +1,28 @@ +/* Licensed under MIT 2025-2026. */ +package edu.kit.kastel.sdq.lissa.ratlr.cache; + +/** + * Interface for cache parameter implementations that define how cache keys are created and configured. + * Implementations specify the parameters that make a cache unique (e.g., model name, seed, temperature) + * and provide factory methods for creating cache keys. + * + * @param The type of cache key this parameter creates + */ +public interface CacheParameter { + /** + * Provides a unique string based on the actual cache parameters. + * This string is used for the file name of LocalCache and must uniquely identify the cache configuration. + * + * @return A unique string based on the cache parameters + */ + String parameters(); + + /** + * Creates a cache key based on the content and the cache parameters. + * The created key combines the cache configuration with the content to be cached. + * + * @param content The content to create the cache key for + * @return The created cache key + */ + K createCacheKey(String content); +} diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/ClassifierCacheKey.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/ClassifierCacheKey.java deleted file mode 100644 index 54875e17..00000000 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/ClassifierCacheKey.java +++ /dev/null @@ -1,41 +0,0 @@ -/* Licensed under MIT 2025. 
*/ -package edu.kit.kastel.sdq.lissa.ratlr.cache; - -import com.fasterxml.jackson.annotation.JsonAutoDetect; -import com.fasterxml.jackson.annotation.JsonIgnore; -import com.fasterxml.jackson.annotation.JsonInclude; - -import edu.kit.kastel.sdq.lissa.ratlr.utils.KeyGenerator; - -/** - * Represents a key for classification caching operations in the LiSSA framework. - * This record is used to uniquely identify cached values based on various parameters - * such as the model used, seed value, operation mode, and content. - *

- * The key can be serialized to JSON for storage and retrieval from the cache. - *

- * Please always use the {@link #of(String, int, double, String)} method to create a new instance. - * - * @param model The identifier of the model used for the cached operation. - * @param seed The seed value used for randomization in the cached operation. - * @param temperature The temperature setting used in the cached operation. - * @param mode The mode of operation that was cached (classification for backward compatibility). - * @param content The content that was processed in the cached operation. - * @param localKey A local key for additional identification, not included in JSON serialization. - */ -@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY) -@JsonInclude(JsonInclude.Include.NON_NULL) -public record ClassifierCacheKey( - String model, - int seed, - double temperature, - LargeLanguageModelCacheMode mode, - String content, - @JsonIgnore String localKey) - implements CacheKey { - - public static ClassifierCacheKey of(String model, int seed, double temperature, String content) { - return new ClassifierCacheKey( - model, seed, temperature, LargeLanguageModelCacheMode.CHAT, content, KeyGenerator.generateKey(content)); - } -} diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/EmbeddingCacheKey.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/EmbeddingCacheKey.java deleted file mode 100644 index 4f869d3a..00000000 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/EmbeddingCacheKey.java +++ /dev/null @@ -1,51 +0,0 @@ -/* Licensed under MIT 2025. */ -package edu.kit.kastel.sdq.lissa.ratlr.cache; - -import com.fasterxml.jackson.annotation.JsonAutoDetect; -import com.fasterxml.jackson.annotation.JsonIgnore; -import com.fasterxml.jackson.annotation.JsonInclude; - -import edu.kit.kastel.sdq.lissa.ratlr.utils.KeyGenerator; - -/** - * Represents a key for embedding caching operations in the LiSSA framework. 
- * This record is used to uniquely identify cached values based on various parameters - * such as the model used, seed value, operation mode, and content. - *

- * The key can be serialized to JSON for storage and retrieval from the cache. - *

- * Please always use the {@link #of(String, String)} method to create a new instance. - * - * @param model The identifier of the model used for the cached operation. - * @param seed The seed value used for randomization in the cached operation (-1 for backward compatibility). - * @param temperature The temperature setting used in the cached operation (-1 for backward compatibility). - * @param mode The mode of operation that was cached (embedding generation for backward compatibility). - * @param content The content that was processed in the cached operation. - * @param localKey A local key for additional identification, not included in JSON serialization. - */ -@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY) -@JsonInclude(JsonInclude.Include.NON_NULL) -public record EmbeddingCacheKey( - String model, - int seed, - double temperature, - LargeLanguageModelCacheMode mode, - String content, - @JsonIgnore String localKey) - implements CacheKey { - - public static EmbeddingCacheKey of(String model, String content) { - return new EmbeddingCacheKey( - model, -1, -1, LargeLanguageModelCacheMode.EMBEDDING, content, KeyGenerator.generateKey(content)); - } - - /** - * Only use this method if you want to use a custom local key. You mostly do not want to do this. Only for special handling of embeddings. - * You should always prefer the {@link #of(String, String)} method. - * @deprecated please use {@link #of(String, String)} instead. 
- */ - @Deprecated(forRemoval = false) - public static EmbeddingCacheKey ofRaw(String model, String content, String localKey) { - return new EmbeddingCacheKey(model, -1, -1, LargeLanguageModelCacheMode.EMBEDDING, content, localKey); - } -} diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/LocalCache.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/LocalCache.java index f43d104f..5c8b3091 100644 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/LocalCache.java +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/LocalCache.java @@ -1,4 +1,4 @@ -/* Licensed under MIT 2025. */ +/* Licensed under MIT 2025-2026. */ package edu.kit.kastel.sdq.lissa.ratlr.cache; import java.io.File; @@ -19,7 +19,7 @@ * to a JSON file. It includes automatic flushing of changes when a certain threshold * of modifications is reached. */ -class LocalCache { +class LocalCache { private final ObjectMapper mapper; /** @@ -27,6 +27,8 @@ class LocalCache { */ private static final int MAX_DIRTY = 50; + private final CacheParameter cacheParameter; + /** * Counter for unflushed modifications. */ @@ -46,7 +48,8 @@ class LocalCache { * * @param cacheFile The path to the cache file */ - LocalCache(String cacheFile) { + LocalCache(String cacheFile, CacheParameter cacheParameter) { + this.cacheParameter = cacheParameter; this.cacheFile = new File(cacheFile); mapper = new ObjectMapper(); createLocalStore(); @@ -117,7 +120,20 @@ public synchronized void write() { * @param key The cache key to look up * @return The cached value, or null if not found */ - public synchronized @Nullable String get(CacheKey key) { + public synchronized @Nullable String get(String key) { + K cacheKey = cacheParameter.createCacheKey(key); + return cache.get(cacheKey.localKey()); + } + + /** + * Retrieves a value from the cache. 
+ * + * @param key The cache key to look up + * @return The cached value, or null if not found + * @deprecated This method exposes internal cache key handling and should not be used in general code. + */ + @Deprecated(forRemoval = false) + public synchronized @Nullable String getViaInternalKey(K key) { return cache.get(key.localKey()); } @@ -129,8 +145,30 @@ public synchronized void write() { * @param key The cache key to store the value under * @param value The value to store */ - public synchronized void put(CacheKey key, String value) { - String old = cache.put(key.localKey(), value); + public synchronized void put(String key, String value) { + K cacheKey = cacheParameter.createCacheKey(key); + String old = cache.put(cacheKey.localKey(), value); + if (old == null || !old.equals(value)) { + dirty++; + } + + if (dirty > MAX_DIRTY) { + write(); + } + } + + /** + * Stores a value in the cache. + * If the value is different from the existing value (if any), the dirty counter is incremented. + * If the dirty counter exceeds the maximum threshold, the cache is automatically flushed to disk. + * + * @param cacheKey The cache key to store the value under + * @param value The value to store + * @deprecated This method exposes internal cache key handling and should not be used in general code. 
+ */ + @Deprecated(forRemoval = false) + public synchronized void putViaInternalKey(K cacheKey, String value) { + String old = cache.put(cacheKey.localKey(), value); if (old == null || !old.equals(value)) { dirty++; } @@ -146,7 +184,12 @@ public synchronized void put(CacheKey key, String value) { * @param key The cache key to look up * @return true if this map contains a mapping for the specified key */ - public boolean containsKey(CacheKey key) { - return cache.containsKey(key.localKey()); + public boolean containsKey(String key) { + K cacheKey = cacheParameter.createCacheKey(key); + return cache.containsKey(cacheKey.localKey()); + } + + public CacheParameter getCacheParameter() { + return this.cacheParameter; } } diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/RedisCache.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/RedisCache.java index 4b5f0e1b..be550e88 100644 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/RedisCache.java +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/RedisCache.java @@ -1,4 +1,4 @@ -/* Licensed under MIT 2025. */ +/* Licensed under MIT 2025-2026. */ package edu.kit.kastel.sdq.lissa.ratlr.cache; import java.time.Instant; @@ -26,14 +26,17 @@ * 2. Local-only: When Redis is unavailable and local cache is configured * 3. Hybrid: When both Redis and local cache are available (default) */ -class RedisCache implements Cache { +class RedisCache implements Cache { private static final Logger logger = LoggerFactory.getLogger(RedisCache.class); - private final ObjectMapper mapper; + + private final CacheParameter cacheParameter; /** * Local file-based cache used as a backup. */ - private final @Nullable LocalCache localCache; + private final @Nullable LocalCache localCache; + + private final ObjectMapper mapper; /** * Redis client instance. 
@@ -51,8 +54,14 @@ class RedisCache implements Cache { * @param localCache The local cache to use as backup, or null if no backup is needed * @throws IllegalArgumentException If neither Redis nor local cache can be initialized */ - RedisCache(@Nullable LocalCache localCache, boolean replaceLocalCacheOnConflict) { + RedisCache( + CacheParameter cacheParameter, @Nullable LocalCache localCache, boolean replaceLocalCacheOnConflict) { + this.cacheParameter = Objects.requireNonNull(cacheParameter); this.localCache = localCache == null || !localCache.isReady() ? null : localCache; + if (this.localCache != null && !this.getCacheParameter().equals(this.localCache.getCacheParameter())) { + throw new IllegalArgumentException("Cache parameter of local cache does not match the one of Redis cache"); + } + mapper = new ObjectMapper(); createRedisConnection(); if (jedis == null && this.localCache == null) { @@ -69,8 +78,9 @@ public void flush() { } @Override - public boolean containsKey(CacheKey key) { - if (jedis != null && jedis.exists(key.toJsonKey())) { + public boolean containsKey(String key) { + K cacheKey = cacheParameter.createCacheKey(key); + if (jedis != null && jedis.exists(cacheKey.toJsonKey())) { return true; } return localCache != null && localCache.containsKey(key); @@ -112,8 +122,9 @@ private void createRedisConnection() { * @return The deserialized value, or null if not found */ @Override - public synchronized T get(CacheKey key, Class clazz) { - String jsonData = jedis == null ? null : jedis.hget(key.toJsonKey(), "data"); + public synchronized T get(String key, Class clazz) { + K cacheKey = cacheParameter.createCacheKey(key); + String jsonData = jedis == null ? 
null : jedis.hget(cacheKey.toJsonKey(), "data"); if (localCache == null) { return convert(jsonData, clazz); } @@ -124,7 +135,7 @@ public synchronized T get(CacheKey key, Class clazz) { } // Value is in local cache but not in redis cache if (localData != null && jsonData == null && jedis != null) { - jedis.hset(key.toJsonKey(), "data", localData); + jedis.hset(cacheKey.toJsonKey(), "data", localData); } // Value is in both caches, but they differ if (replaceLocalCacheOnConflict && jsonData != null && localData != null && !jsonData.equals(localData)) { @@ -136,6 +147,32 @@ public synchronized T get(CacheKey key, Class clazz) { return convert(valueToReturn, clazz); } + @Override + @SuppressWarnings("deprecation") + public synchronized @Nullable T getViaInternalKey(K cacheKey, Class clazz) { + String jsonData = jedis == null ? null : jedis.hget(cacheKey.toJsonKey(), "data"); + if (localCache == null) { + return convert(jsonData, clazz); + } + String localData = localCache.getViaInternalKey(cacheKey); + // Value is in redis cache but not in local cache + if (localData == null && jsonData != null) { + localCache.putViaInternalKey(cacheKey, jsonData); + } + // Value is in local cache but not in redis cache + if (localData != null && jsonData == null && jedis != null) { + jedis.hset(cacheKey.toJsonKey(), "data", localData); + } + // Value is in both caches, but they differ + if (replaceLocalCacheOnConflict && jsonData != null && localData != null && !jsonData.equals(localData)) { + logger.info("Cache inconsistency detected for key {}, using Redis value and replacing local one", cacheKey); + localCache.putViaInternalKey(cacheKey, jsonData); + } + + String valueToReturn = jsonData != null ? jsonData : localData; + return convert(valueToReturn, clazz); + } + /** * Converts a JSON string to an object of the specified type. * If the target type is String, the JSON string is returned as is. 
@@ -171,9 +208,10 @@ public synchronized T get(CacheKey key, Class clazz) { * @param value The string value to store */ @Override - public synchronized void put(CacheKey key, String value) { + public synchronized void put(String key, String value) { + K cacheKey = cacheParameter.createCacheKey(key); if (jedis != null) { - String jsonKey = key.toJsonKey(); + String jsonKey = cacheKey.toJsonKey(); jedis.hset(jsonKey, "data", value); jedis.hset(jsonKey, "timestamp", String.valueOf(Instant.now().getEpochSecond())); } @@ -193,11 +231,35 @@ public synchronized void put(CacheKey key, String value) { * @throws NullPointerException If value is null */ @Override - public synchronized void put(CacheKey key, T value) { + public synchronized void put(String key, T value) { try { put(key, mapper.writeValueAsString(Objects.requireNonNull(value))); } catch (JsonProcessingException e) { throw new IllegalArgumentException("Could not serialize object", e); } } + + @Override + @SuppressWarnings("deprecation") + public synchronized void putViaInternalKey(K key, T value) { + String data; + try { + data = mapper.writeValueAsString(Objects.requireNonNull(value)); + } catch (JsonProcessingException e) { + throw new IllegalArgumentException("Could not serialize object", e); + } + if (jedis != null) { + String jsonKey = key.toJsonKey(); + jedis.hset(jsonKey, "data", data); + jedis.hset(jsonKey, "timestamp", String.valueOf(Instant.now().getEpochSecond())); + } + if (localCache != null) { + localCache.putViaInternalKey(key, data); + } + } + + @Override + public CacheParameter getCacheParameter() { + return this.cacheParameter; + } } diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/classifier/ClassifierCacheKey.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/classifier/ClassifierCacheKey.java new file mode 100644 index 00000000..01f7fcd9 --- /dev/null +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/classifier/ClassifierCacheKey.java @@ -0,0 +1,141 @@ +/* 
Licensed under MIT 2025-2026. */ +package edu.kit.kastel.sdq.lissa.ratlr.cache.classifier; + +import java.util.Objects; + +import com.fasterxml.jackson.annotation.JsonAutoDetect; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; + +import edu.kit.kastel.sdq.lissa.ratlr.cache.CacheKey; +import edu.kit.kastel.sdq.lissa.ratlr.cache.LargeLanguageModelCacheMode; +import edu.kit.kastel.sdq.lissa.ratlr.utils.KeyGenerator; + +/** + * Represents a key for classification caching operations in the LiSSA framework. + * This class is used to uniquely identify cached values based on various parameters + * such as the model used, seed value, operation mode, and content. + *

+ * The key can be serialized to JSON for storage and retrieval from the cache. + */ +@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY) +@JsonInclude(JsonInclude.Include.NON_NULL) +public final class ClassifierCacheKey implements CacheKey { + private final String model; + private final int seed; + private final double temperature; + private final LargeLanguageModelCacheMode mode; + private final String content; + + @JsonIgnore + private final String localKey; + + /** + * Creates a new classifier cache key with the specified parameters. + * + * @param model The identifier of the model used for the cached operation + * @param seed The seed value used for randomization in the cached operation + * @param temperature The temperature setting used in the cached operation + * @param mode The mode of operation that was cached (classification for backward compatibility) + * @param content The content that was processed in the cached operation + * @param localKey A local key for additional identification, not included in JSON serialization + */ + private ClassifierCacheKey( + String model, + int seed, + double temperature, + LargeLanguageModelCacheMode mode, + String content, + String localKey) { + this.model = model; + this.seed = seed; + this.temperature = temperature; + this.mode = mode; + this.content = content; + this.localKey = localKey; + } + + /** + * Creates a classifier cache key from the given cache parameter and content. + * This is the preferred way to create cache keys. 
+ * + * @param cacheParameter The cache parameter containing model configuration + * @param content The content to be cached + * @return A new classifier cache key + */ + static ClassifierCacheKey of(ClassifierCacheParameter cacheParameter, String content) { + return new ClassifierCacheKey( + cacheParameter.modelName(), + cacheParameter.seed(), + cacheParameter.temperature(), + LargeLanguageModelCacheMode.CHAT, + content, + KeyGenerator.generateKey(content)); + } + + /** + * Gets the identifier of the model used for the cached operation. + * + * @return The model identifier + */ + public String model() { + return model; + } + + /** + * Gets the seed value used for randomization in the cached operation. + * + * @return The seed value + */ + public int seed() { + return seed; + } + + /** + * Gets the temperature setting used in the cached operation. + * + * @return The temperature value + */ + public double temperature() { + return temperature; + } + + /** + * Gets the content that was processed in the cached operation. 
+ * + * @return The content + */ + public String content() { + return content; + } + + @Override + @JsonIgnore + public String localKey() { + return localKey; + } + + @Override + public boolean equals(Object obj) { + if (obj == this) return true; + if (obj == null || obj.getClass() != this.getClass()) return false; + var that = (ClassifierCacheKey) obj; + return Objects.equals(this.model, that.model) + && this.seed == that.seed + && Double.doubleToLongBits(this.temperature) == Double.doubleToLongBits(that.temperature) + && Objects.equals(this.mode, that.mode) + && Objects.equals(this.content, that.content) + && Objects.equals(this.localKey, that.localKey); + } + + @Override + public int hashCode() { + return Objects.hash(model, seed, temperature, mode, content, localKey); + } + + @Override + public String toString() { + return "ClassifierCacheKey[" + "model=" + model + ", " + "seed=" + seed + ", " + "temperature=" + temperature + + ", " + "mode=" + mode + ", " + "content=" + content + ", " + "localKey=" + localKey + ']'; + } +} diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/classifier/ClassifierCacheParameter.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/classifier/ClassifierCacheParameter.java new file mode 100644 index 00000000..63fb2f28 --- /dev/null +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/classifier/ClassifierCacheParameter.java @@ -0,0 +1,31 @@ +/* Licensed under MIT 2025-2026. */ +package edu.kit.kastel.sdq.lissa.ratlr.cache.classifier; + +import edu.kit.kastel.sdq.lissa.ratlr.cache.CacheParameter; + +/** + * Cache parameters for classifier operations. + * This record encapsulates the configuration parameters that define a unique classifier cache, + * including the model name, random seed, and temperature setting. 
+ * + * @param modelName The name of the language model used for classification + * @param seed The random seed for reproducible results + * @param temperature The temperature parameter for controlling randomness in model outputs + */ +public record ClassifierCacheParameter(String modelName, int seed, double temperature) + implements CacheParameter { + @Override + public String parameters() { + // For backward compatibility, omit temperature if it is 0.0 + if (temperature == 0.0) { + return String.join("_", modelName, String.valueOf(seed)); + } else { + return String.join("_", modelName, String.valueOf(seed), String.valueOf(temperature)); + } + } + + @Override + public ClassifierCacheKey createCacheKey(String content) { + return ClassifierCacheKey.of(this, content); + } +} diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/embedding/EmbeddingCacheKey.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/embedding/EmbeddingCacheKey.java new file mode 100644 index 00000000..c7fae674 --- /dev/null +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/embedding/EmbeddingCacheKey.java @@ -0,0 +1,139 @@ +/* Licensed under MIT 2025-2026. */ +package edu.kit.kastel.sdq.lissa.ratlr.cache.embedding; + +import java.util.Objects; + +import com.fasterxml.jackson.annotation.JsonAutoDetect; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; + +import edu.kit.kastel.sdq.lissa.ratlr.cache.CacheKey; +import edu.kit.kastel.sdq.lissa.ratlr.cache.LargeLanguageModelCacheMode; +import edu.kit.kastel.sdq.lissa.ratlr.utils.KeyGenerator; + +/** + * Represents a key for embedding caching operations in the LiSSA framework. + * This class is used to uniquely identify cached values based on various parameters + * such as the model used, seed value, operation mode, and content. + *

+ * The key can be serialized to JSON for storage and retrieval from the cache. + */ +@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY) +@JsonInclude(JsonInclude.Include.NON_NULL) +public final class EmbeddingCacheKey implements CacheKey { + private final String model; + private final int seed; + private final double temperature; + private final LargeLanguageModelCacheMode mode; + private final String content; + + @JsonIgnore + private final String localKey; + + /** + * Creates a new embedding cache key with the specified parameters. + * + * @param model The identifier of the model used for the cached operation + * @param seed The seed value used for randomization in the cached operation (-1 for backward compatibility) + * @param temperature The temperature setting used in the cached operation (-1 for backward compatibility) + * @param mode The mode of operation that was cached (embedding generation for backward compatibility) + * @param content The content that was processed in the cached operation + * @param localKey A local key for additional identification, not included in JSON serialization + */ + private EmbeddingCacheKey( + String model, + int seed, + double temperature, + LargeLanguageModelCacheMode mode, + String content, + String localKey) { + this.model = model; + this.seed = seed; + this.temperature = temperature; + this.mode = mode; + this.content = content; + this.localKey = localKey; + } + + /** + * Creates an embedding cache key from the given cache parameter and content. + * This is the preferred way to create cache keys. 
+ * + * @param cacheParameter The cache parameter containing model configuration + * @param content The content to be cached + * @return A new embedding cache key + */ + static EmbeddingCacheKey of(EmbeddingCacheParameter cacheParameter, String content) { + return new EmbeddingCacheKey( + cacheParameter.modelName(), + -1, + -1, + LargeLanguageModelCacheMode.EMBEDDING, + content, + KeyGenerator.generateKey(content)); + } + + /** + * Creates an embedding cache key with a custom local key. + * Only use this method if you want to use a custom local key. You mostly do not want to do this. + * Only for special handling of embeddings. You should always prefer the {@link #of(EmbeddingCacheParameter, String)} method. + * + * @param model The identifier of the model + * @param content The content to be cached + * @param localKey The custom local key + * @return A new embedding cache key with the specified local key + * @deprecated Please use {@link #of(EmbeddingCacheParameter, String)} instead + */ + @Deprecated(forRemoval = false) + public static EmbeddingCacheKey ofRaw(String model, String content, String localKey) { + return new EmbeddingCacheKey(model, -1, -1, LargeLanguageModelCacheMode.EMBEDDING, content, localKey); + } + + /** + * Gets the identifier of the model used for the cached operation. + * + * @return The model identifier + */ + public String model() { + return model; + } + + /** + * Gets the content that was processed in the cached operation. 
+ * + * @return The content + */ + public String content() { + return content; + } + + @Override + @JsonIgnore + public String localKey() { + return localKey; + } + + @Override + public boolean equals(Object obj) { + if (obj == this) return true; + if (obj == null || obj.getClass() != this.getClass()) return false; + var that = (EmbeddingCacheKey) obj; + return Objects.equals(this.model, that.model) + && this.seed == that.seed + && Double.doubleToLongBits(this.temperature) == Double.doubleToLongBits(that.temperature) + && Objects.equals(this.mode, that.mode) + && Objects.equals(this.content, that.content) + && Objects.equals(this.localKey, that.localKey); + } + + @Override + public int hashCode() { + return Objects.hash(model, seed, temperature, mode, content, localKey); + } + + @Override + public String toString() { + return "EmbeddingCacheKey[" + "model=" + model + ", " + "seed=" + seed + ", " + "temperature=" + temperature + + ", " + "mode=" + mode + ", " + "content=" + content + ", " + "localKey=" + localKey + ']'; + } +} diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/embedding/EmbeddingCacheParameter.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/embedding/EmbeddingCacheParameter.java new file mode 100644 index 00000000..9afa7258 --- /dev/null +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/embedding/EmbeddingCacheParameter.java @@ -0,0 +1,23 @@ +/* Licensed under MIT 2025-2026. */ +package edu.kit.kastel.sdq.lissa.ratlr.cache.embedding; + +import edu.kit.kastel.sdq.lissa.ratlr.cache.CacheParameter; + +/** + * Cache parameters for embedding operations. + * This record encapsulates the configuration parameters that define a unique embedding cache. + * For embeddings, only the model name is required as embeddings are deterministic. 
+ * + * @param modelName The name of the embedding model used for generating embeddings + */ +public record EmbeddingCacheParameter(String modelName) implements CacheParameter { + @Override + public String parameters() { + return modelName; + } + + @Override + public EmbeddingCacheKey createCacheKey(String content) { + return EmbeddingCacheKey.of(this, content); + } +} diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ChatLanguageModelProvider.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ChatLanguageModelProvider.java index 0e1e9ded..487f94c1 100644 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ChatLanguageModelProvider.java +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ChatLanguageModelProvider.java @@ -1,4 +1,4 @@ -/* Licensed under MIT 2025. */ +/* Licensed under MIT 2025-2026. */ package edu.kit.kastel.sdq.lissa.ratlr.classifier; import java.nio.charset.StandardCharsets; @@ -6,6 +6,8 @@ import java.util.Base64; import java.util.Map; +import edu.kit.kastel.sdq.lissa.ratlr.cache.CacheParameter; +import edu.kit.kastel.sdq.lissa.ratlr.cache.classifier.ClassifierCacheParameter; import edu.kit.kastel.sdq.lissa.ratlr.configuration.ModuleConfiguration; import edu.kit.kastel.sdq.lissa.ratlr.utils.Environment; @@ -302,14 +304,9 @@ private static ChatModel createOpenWebUIChatModel(String model, int seed, double * This method is used to identify the cache uniquely. 
* * @return An array of strings representing the cache parameters - * @see edu.kit.kastel.sdq.lissa.ratlr.cache.CacheManager#getCache(Object, String[]) + * @see edu.kit.kastel.sdq.lissa.ratlr.cache.CacheManager#getCache(Object, CacheParameter) */ - public String[] getCacheParameters() { - if (temperature == 0.0) { - // Backwards compatibility with the old mode that did not have temperature - return new String[] {modelName(), String.valueOf(seed())}; - } else { - return new String[] {modelName(), String.valueOf(seed()), String.valueOf(temperature())}; - } + public ClassifierCacheParameter cacheParameters() { + return new ClassifierCacheParameter(modelName, seed, temperature); } } diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ReasoningClassifier.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ReasoningClassifier.java index 5e300ca7..7b4637e3 100644 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ReasoningClassifier.java +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ReasoningClassifier.java @@ -1,4 +1,4 @@ -/* Licensed under MIT 2025. */ +/* Licensed under MIT 2025-2026. */ package edu.kit.kastel.sdq.lissa.ratlr.classifier; import static dev.langchain4j.internal.Utils.quoted; @@ -11,7 +11,7 @@ import edu.kit.kastel.sdq.lissa.ratlr.cache.Cache; import edu.kit.kastel.sdq.lissa.ratlr.cache.CacheManager; -import edu.kit.kastel.sdq.lissa.ratlr.cache.ClassifierCacheKey; +import edu.kit.kastel.sdq.lissa.ratlr.cache.classifier.ClassifierCacheKey; import edu.kit.kastel.sdq.lissa.ratlr.configuration.ModuleConfiguration; import edu.kit.kastel.sdq.lissa.ratlr.context.ContextStore; import edu.kit.kastel.sdq.lissa.ratlr.knowledge.Element; @@ -36,7 +36,7 @@ public class ReasoningClassifier extends Classifier { */ private static final String CLASSIFICATION_PROMPT_KEY = "prompt"; - private final Cache cache; + private final Cache cache; /** * Provider for the language model used in classification. 
@@ -72,7 +72,7 @@ public class ReasoningClassifier extends Classifier { public ReasoningClassifier(ModuleConfiguration configuration, ContextStore contextStore) { super(ChatLanguageModelProvider.threads(configuration), contextStore); this.provider = new ChatLanguageModelProvider(configuration); - this.cache = CacheManager.getDefaultInstance().getCache(this, provider.getCacheParameters()); + this.cache = CacheManager.getDefaultInstance().getCache(this, provider.cacheParameters()); this.prompt = configuration.argumentAsStringByEnumIndex( CLASSIFICATION_PROMPT_KEY, 0, @@ -96,7 +96,7 @@ public ReasoningClassifier(ModuleConfiguration configuration, ContextStore conte */ private ReasoningClassifier( int threads, - Cache cache, + Cache cache, ChatLanguageModelProvider provider, String prompt, boolean useOriginalArtifacts, @@ -200,10 +200,8 @@ private String classifyIntern(Element source, Element target) { messages.add(new UserMessage(request)); String messageString = getRepresentation(messages); - ClassifierCacheKey cacheKey = - ClassifierCacheKey.of(provider.modelName(), provider.seed(), provider.temperature(), messageString); - String cachedResponse = cache.get(cacheKey, String.class); + String cachedResponse = cache.get(messageString, String.class); if (cachedResponse != null) { return cachedResponse; } else { @@ -214,7 +212,7 @@ private String classifyIntern(Element source, Element target) { target.getIdentifier()); ChatResponse response = llm.chat(messages); String responseText = response.aiMessage().text(); - cache.put(cacheKey, responseText); + cache.put(messageString, responseText); return responseText; } } diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/SimpleClassifier.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/SimpleClassifier.java index 43b5b5c4..77bdd36d 100644 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/SimpleClassifier.java +++ 
b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/SimpleClassifier.java @@ -1,11 +1,11 @@ -/* Licensed under MIT 2025. */ +/* Licensed under MIT 2025-2026. */ package edu.kit.kastel.sdq.lissa.ratlr.classifier; import java.util.Optional; import edu.kit.kastel.sdq.lissa.ratlr.cache.Cache; import edu.kit.kastel.sdq.lissa.ratlr.cache.CacheManager; -import edu.kit.kastel.sdq.lissa.ratlr.cache.ClassifierCacheKey; +import edu.kit.kastel.sdq.lissa.ratlr.cache.classifier.ClassifierCacheKey; import edu.kit.kastel.sdq.lissa.ratlr.configuration.ModuleConfiguration; import edu.kit.kastel.sdq.lissa.ratlr.context.ContextStore; import edu.kit.kastel.sdq.lissa.ratlr.knowledge.Element; @@ -29,8 +29,7 @@ public class SimpleClassifier extends Classifier { * The default template for classification requests. * This template presents two artifacts and asks if they are related. */ - private static final String DEFAULT_TEMPLATE = - """ + private static final String DEFAULT_TEMPLATE = """ Question: Here are two parts of software development artifacts. {source_type}: '''{source_content}''' @@ -48,7 +47,7 @@ public class SimpleClassifier extends Classifier { /** * The cache used for storing classification results. */ - private final Cache cache; + private final Cache cache; /** * Provider for the language model used in classification. 
@@ -75,7 +74,7 @@ public SimpleClassifier(ModuleConfiguration configuration, ContextStore contextS super(ChatLanguageModelProvider.threads(configuration), contextStore); this.provider = new ChatLanguageModelProvider(configuration); this.template = configuration.argumentAsString(PROMPT_TEMPLATE_KEY, DEFAULT_TEMPLATE); - this.cache = CacheManager.getDefaultInstance().getCache(this, provider.getCacheParameters()); + this.cache = CacheManager.getDefaultInstance().getCache(this, provider.cacheParameters()); this.llm = provider.createChatModel(); } @@ -89,7 +88,11 @@ public SimpleClassifier(ModuleConfiguration configuration, ContextStore contextS * @param template The template to use for classification requests */ private SimpleClassifier( - int threads, Cache cache, ChatLanguageModelProvider provider, String template, ContextStore contextStore) { + int threads, + Cache cache, + ChatLanguageModelProvider provider, + String template, + ContextStore contextStore) { super(threads, contextStore); this.cache = cache; this.provider = provider; @@ -160,9 +163,7 @@ private String classifyIntern(Element source, Element target) { .replace("{target_type}", target.getType()) .replace("{target_content}", target.getContent()); - ClassifierCacheKey cacheKey = - ClassifierCacheKey.of(provider.modelName(), provider.seed(), provider.temperature(), request); - String cachedResponse = cache.get(cacheKey, String.class); + String cachedResponse = cache.get(request, String.class); if (cachedResponse != null) { return cachedResponse; } else { @@ -172,7 +173,7 @@ private String classifyIntern(Element source, Element target) { source.getIdentifier(), target.getIdentifier()); String response = llm.chat(request); - cache.put(cacheKey, response); + cache.put(request, response); return response; } } diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/embeddingcreator/CachedEmbeddingCreator.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/embeddingcreator/CachedEmbeddingCreator.java index 
a3c40fbc..d2423d51 100644 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/embeddingcreator/CachedEmbeddingCreator.java +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/embeddingcreator/CachedEmbeddingCreator.java @@ -1,4 +1,4 @@ -/* Licensed under MIT 2025. */ +/* Licensed under MIT 2025-2026. */ package edu.kit.kastel.sdq.lissa.ratlr.embeddingcreator; import java.util.*; @@ -11,10 +11,9 @@ import com.knuddels.jtokkit.api.Encoding; import com.knuddels.jtokkit.api.EncodingRegistry; -import edu.kit.kastel.sdq.lissa.ratlr.cache.Cache; -import edu.kit.kastel.sdq.lissa.ratlr.cache.CacheKey; -import edu.kit.kastel.sdq.lissa.ratlr.cache.CacheManager; -import edu.kit.kastel.sdq.lissa.ratlr.cache.EmbeddingCacheKey; +import edu.kit.kastel.sdq.lissa.ratlr.cache.*; +import edu.kit.kastel.sdq.lissa.ratlr.cache.embedding.EmbeddingCacheKey; +import edu.kit.kastel.sdq.lissa.ratlr.cache.embedding.EmbeddingCacheParameter; import edu.kit.kastel.sdq.lissa.ratlr.context.ContextStore; import edu.kit.kastel.sdq.lissa.ratlr.knowledge.Element; import edu.kit.kastel.sdq.lissa.ratlr.utils.Futures; @@ -41,10 +40,11 @@ abstract class CachedEmbeddingCreator extends EmbeddingCreator { private static final Logger STATIC_LOGGER = LoggerFactory.getLogger(CachedEmbeddingCreator.class); protected final Logger logger = LoggerFactory.getLogger(this.getClass()); - private final Cache cache; + private final Cache cache; private final EmbeddingModel embeddingModel; private final String rawNameOfModel; private final int threads; + private final EmbeddingCacheParameter embeddingCacheParameter; /** * Creates a new cached embedding creator with the specified model and thread count. @@ -56,7 +56,8 @@ abstract class CachedEmbeddingCreator extends EmbeddingCreator { */ protected CachedEmbeddingCreator(ContextStore contextStore, String model, int threads, String... 
params) { super(contextStore); - this.cache = CacheManager.getDefaultInstance().getCache(this, new String[] {model}); + this.embeddingCacheParameter = new EmbeddingCacheParameter(model); + this.cache = CacheManager.getDefaultInstance().getCache(this, embeddingCacheParameter); this.embeddingModel = Objects.requireNonNull(createEmbeddingModel(model, params)); this.rawNameOfModel = model; this.threads = Math.max(1, threads); @@ -138,7 +139,7 @@ private List calculateEmbeddingsSequential(List elements) { private List calculateEmbeddingsSequential(EmbeddingModel embeddingModel, List elements) { List embeddings = new ArrayList<>(); for (Element element : elements) { - embeddings.add(calculateFinalEmbedding(embeddingModel, cache, rawNameOfModel, element)); + embeddings.add(calculateFinalEmbedding(embeddingModel, cache, embeddingCacheParameter, element)); } return embeddings; } @@ -170,30 +171,32 @@ private List calculateEmbeddingsSequential(EmbeddingModel embeddingMode * * @param embeddingModel The model to use for embedding generation * @param cache The cache to use for storing and retrieving embeddings - * @param rawNameOfModel The name of the model being used + * @param embeddingCacheParameter The EmbeddingCacheParameter of the model being used * @param element The element to create an embedding for * @return The vector embedding of the element, either from cache or newly generated */ private static float[] calculateFinalEmbedding( - EmbeddingModel embeddingModel, Cache cache, String rawNameOfModel, Element element) { - - EmbeddingCacheKey cacheKey = EmbeddingCacheKey.of(rawNameOfModel, element.getContent()); + EmbeddingModel embeddingModel, + Cache cache, + EmbeddingCacheParameter embeddingCacheParameter, + Element element) { - float[] cachedEmbedding = cache.get(cacheKey, float[].class); + String elementContent = element.getContent(); + float[] cachedEmbedding = cache.get(elementContent, float[].class); if (cachedEmbedding != null) { return cachedEmbedding; } else { 
STATIC_LOGGER.info("Calculating embedding for: {}", element.getIdentifier()); try { float[] embedding = - embeddingModel.embed(element.getContent()).content().vector(); - cache.put(cacheKey, embedding); + embeddingModel.embed(elementContent).content().vector(); + cache.put(elementContent, embedding); return embedding; } catch (Exception e) { STATIC_LOGGER.error( "Error while calculating embedding for .. try to fix ..: {}", element.getIdentifier()); // Probably the length was too long .. check that - return tryToFixWithLength(embeddingModel, cache, rawNameOfModel, cacheKey, element.getContent()); + return tryToFixWithLength(embeddingModel, cache, embeddingCacheParameter.modelName(), elementContent); } } } @@ -206,24 +209,25 @@ private static float[] calculateFinalEmbedding( * @param embeddingModel The model to use for embedding generation * @param cache The cache to use for storing and retrieving embeddings * @param rawNameOfModel The name of the model being used - * @param key The original cache key * @param content The content that exceeded the token limit * @return The vector embedding of the truncated content * @throws IllegalArgumentException If the token length was not the cause of the failure */ private static float[] tryToFixWithLength( - EmbeddingModel embeddingModel, Cache cache, String rawNameOfModel, CacheKey key, String content) { - String newKey = key.localKey() + "_fixed_" + MAX_TOKEN_LENGTH; + EmbeddingModel embeddingModel, Cache cache, String rawNameOfModel, String content) { + EmbeddingCacheKey originalKey = cache.getCacheParameter().createCacheKey(content); + String newKey = originalKey.localKey() + "_fixed_" + MAX_TOKEN_LENGTH; // We need the old keys for backwards compatibility @SuppressWarnings("deprecation") EmbeddingCacheKey newCacheKey = EmbeddingCacheKey.ofRaw(rawNameOfModel, "(FIXED::%d): %s".formatted(MAX_TOKEN_LENGTH, content), newKey); - float[] cachedEmbedding = cache.get(newCacheKey, float[].class); + @SuppressWarnings("deprecation") 
+ float[] cachedEmbedding = cache.getViaInternalKey(newCacheKey, float[].class); if (cachedEmbedding != null) { if (STATIC_LOGGER.isInfoEnabled()) { - STATIC_LOGGER.info("using fixed embedding for: {}", key.localKey()); + STATIC_LOGGER.info("using fixed embedding for: {}", originalKey.localKey()); } return cachedEmbedding; } @@ -252,9 +256,9 @@ private static float[] tryToFixWithLength( String fixedContent = content.substring(0, left); float[] embedding = embeddingModel.embed(fixedContent).content().vector(); if (STATIC_LOGGER.isInfoEnabled()) { - STATIC_LOGGER.info("using fixed embedding for: {}", key.localKey()); + STATIC_LOGGER.info("using fixed embedding for: {}", originalKey.localKey()); } - cache.put(newCacheKey, embedding); + cache.putViaInternalKey(newCacheKey, embedding); return embedding; } } diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/SummarizePreprocessor.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/SummarizePreprocessor.java index 3710b555..3d72dde8 100644 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/SummarizePreprocessor.java +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/SummarizePreprocessor.java @@ -1,4 +1,4 @@ -/* Licensed under MIT 2025. */ +/* Licensed under MIT 2025-2026. 
*/ package edu.kit.kastel.sdq.lissa.ratlr.preprocessor; import java.util.ArrayList; @@ -7,7 +7,7 @@ import edu.kit.kastel.sdq.lissa.ratlr.cache.Cache; import edu.kit.kastel.sdq.lissa.ratlr.cache.CacheManager; -import edu.kit.kastel.sdq.lissa.ratlr.cache.ClassifierCacheKey; +import edu.kit.kastel.sdq.lissa.ratlr.cache.classifier.ClassifierCacheKey; import edu.kit.kastel.sdq.lissa.ratlr.classifier.ChatLanguageModelProvider; import edu.kit.kastel.sdq.lissa.ratlr.configuration.ModuleConfiguration; import edu.kit.kastel.sdq.lissa.ratlr.context.ContextStore; @@ -51,7 +51,7 @@ public class SummarizePreprocessor extends Preprocessor { /** Number of threads to use for parallel processing */ private final int threads; /** Cache for storing and retrieving summaries */ - private final Cache cache; + private final Cache cache; /** * Creates a new summarize preprocessor with the specified configuration and context store. @@ -64,7 +64,7 @@ public SummarizePreprocessor(ModuleConfiguration moduleConfiguration, ContextSto this.template = moduleConfiguration.argumentAsString("template", "Summarize the following {type}: {content}"); this.provider = new ChatLanguageModelProvider(moduleConfiguration); this.threads = ChatLanguageModelProvider.threads(moduleConfiguration); - this.cache = CacheManager.getDefaultInstance().getCache(this, provider.getCacheParameters()); + this.cache = CacheManager.getDefaultInstance().getCache(this, provider.cacheParameters()); } /** @@ -107,17 +107,14 @@ public List preprocess(List artifacts) { List> tasks = new ArrayList<>(); for (String request : requests) { tasks.add(() -> { - ClassifierCacheKey cacheKey = - ClassifierCacheKey.of(provider.modelName(), provider.seed(), provider.temperature(), request); - - String cachedResponse = cache.get(cacheKey, String.class); + String cachedResponse = cache.get(request, String.class); if (cachedResponse != null) { return cachedResponse; } ChatModel chatModel = threads > 1 ? 
provider.createChatModel() : llmInstance; String response = chatModel.chat(request); - cache.put(cacheKey, response); + cache.put(request, response); return response; }); } diff --git a/src/test/java/edu/kit/kastel/sdq/lissa/ratlr/ArchitectureTest.java b/src/test/java/edu/kit/kastel/sdq/lissa/ratlr/ArchitectureTest.java index e6413c75..a61f7ae8 100644 --- a/src/test/java/edu/kit/kastel/sdq/lissa/ratlr/ArchitectureTest.java +++ b/src/test/java/edu/kit/kastel/sdq/lissa/ratlr/ArchitectureTest.java @@ -1,21 +1,31 @@ -/* Licensed under MIT 2025. */ +/* Licensed under MIT 2025-2026. */ package edu.kit.kastel.sdq.lissa.ratlr; +import static com.tngtech.archunit.lang.SimpleConditionEvent.violated; import static com.tngtech.archunit.lang.syntax.ArchRuleDefinition.*; import java.util.List; +import java.util.Set; import java.util.UUID; import java.util.concurrent.Future; import java.util.function.Consumer; +import java.util.stream.Collectors; import java.util.stream.Stream; import com.tngtech.archunit.base.DescribedPredicate; +import com.tngtech.archunit.core.domain.JavaClass; import com.tngtech.archunit.core.domain.JavaConstructorCall; +import com.tngtech.archunit.core.domain.JavaModifier; import com.tngtech.archunit.junit.AnalyzeClasses; import com.tngtech.archunit.junit.ArchTest; +import com.tngtech.archunit.lang.ArchCondition; import com.tngtech.archunit.lang.ArchRule; +import com.tngtech.archunit.lang.ConditionEvents; import edu.kit.kastel.sdq.lissa.ratlr.cache.CacheKey; +import edu.kit.kastel.sdq.lissa.ratlr.cache.CacheParameter; +import edu.kit.kastel.sdq.lissa.ratlr.cache.classifier.ClassifierCacheParameter; +import edu.kit.kastel.sdq.lissa.ratlr.cache.embedding.EmbeddingCacheParameter; import edu.kit.kastel.sdq.lissa.ratlr.utils.Environment; import edu.kit.kastel.sdq.lissa.ratlr.utils.Futures; import edu.kit.kastel.sdq.lissa.ratlr.utils.KeyGenerator; @@ -84,23 +94,27 @@ class ArchitectureTest { .because("Lambdas should be functional. 
ForEach is typically used for side-effects."); /** - * CacheKeys should only be created using the #of method of the CacheKey class. + * Rule that enforces that CacheKey implementations should only be created via static factory methods. + *
<p>
+ * External code should not directly instantiate CacheKey implementations. Instead, they should use + * the static factory methods (typically 'of()') provided by each CacheKey implementation or let the + * CacheParameter.createCacheKey() method handle key creation. + *
<p>
+ * This rule checks that constructors of classes implementing CacheKey are not called from outside + * those classes themselves (constructors are private and only called from static factory methods). */ @ArchTest static final ArchRule cacheKeysShouldBeCreatedUsingKeyGenerator = noClasses() .that() - .haveNameNotMatching(CacheKey.class.getName()) + .doNotImplement(CacheKey.class) // Exclude CacheKey implementations themselves .should() - .callConstructorWhere(new DescribedPredicate("calls CacheKey constructor") { - @Override - public boolean test(JavaConstructorCall javaConstructorCall) { - return javaConstructorCall - .getTarget() - .getOwner() - .getFullName() - .equals(CacheKey.class.getName()); - } - }); + .callConstructorWhere( + new DescribedPredicate("calls CacheKey implementation constructor") { + @Override + public boolean test(JavaConstructorCall javaConstructorCall) { + return javaConstructorCall.getTarget().getOwner().isAssignableTo(CacheKey.class); + } + }); /** * Futures should be opened with a logger. @@ -113,4 +127,154 @@ public boolean test(JavaConstructorCall javaConstructorCall) { .callMethod(Future.class, "get") .orShould() .callMethod(Future.class, "resultNow"); + + /** + * Rule that enforces that each CacheKey implementation has a static of() method. + *
<p>
+ * Each class implementing CacheKey must provide a static factory method named 'of' + * that takes a specific CacheParameter and a String as parameters. The method must + * access all record components (accessor methods) of the corresponding CacheParameter. + *
<p>
+ * This ensures that all configuration parameters (model name, seed, temperature, etc.) + * are properly used when creating cache keys, making the cache keys complete and unique. + * + * @see edu.kit.kastel.sdq.lissa.ratlr.cache.classifier.ClassifierCacheKey#of(ClassifierCacheParameter, String) + * @see edu.kit.kastel.sdq.lissa.ratlr.cache.embedding.EmbeddingCacheKey#of(EmbeddingCacheParameter, String) + */ + @ArchTest + static final ArchRule cacheKeysMustHaveOfMethodWithCacheParameter = classes() + .that() + .implement(CacheKey.class) + .and() + .areNotInterfaces() + .should( + new ArchCondition<>( + "have a static 'of' method that takes a CacheParameter and String, and reads all CacheParameter attributes") { + @Override + public void check(JavaClass javaClass, ConditionEvents events) { + // Check for static 'of' method + var ofMethods = javaClass.getMethods().stream() + .filter(m -> m.getName().equals("of")) + .filter(m -> m.getModifiers().contains(JavaModifier.STATIC)) + .filter(m -> m.getRawParameterTypes().size() == 2) + .filter(m -> m.getRawParameterTypes() + .get(0) + .isAssignableTo(edu.kit.kastel.sdq.lissa.ratlr.cache.CacheParameter.class)) + .filter(m -> m.getRawParameterTypes().get(1).isAssignableTo(String.class)) + .toList(); + + if (ofMethods.isEmpty()) { + String message = String.format( + "Class %s does not have a static 'of' method with signature: of(CacheParameter, String)", + javaClass.getFullName()); + events.add(violated(javaClass, message)); + return; + } + + // Check that the 'of' method reads all CacheParameter attributes + for (var ofMethod : ofMethods) { + var cacheParameterType = + ofMethod.getRawParameterTypes().get(0); + + // Get all accessor methods of the CacheParameter record components + // Exclude inherited methods, utility methods, and factory methods + var parameterMethods = cacheParameterType.getMethods().stream() + .filter(m -> !m.getOwner().isEquivalentTo(Object.class)) + // parameters() generates cache file name, not used in 
key creation + .filter(m -> !m.getName().equals("parameters")) + // createCacheKey() is the factory method called by Cache, not by of() + .filter(m -> !m.getName().equals("createCacheKey")) + // Default methods from Object + .filter(m -> !m.getName().equals("equals")) + .filter(m -> !m.getName().equals("hashCode")) + .filter(m -> !m.getName().equals("toString")) + .toList(); + + // Get all method calls in the 'of' method + var methodCallsInOf = ofMethod.getMethodCallsFromSelf(); + Set calledMethodNames = methodCallsInOf.stream() + .map(call -> call.getTarget().getName()) + .collect(Collectors.toSet()); + + // Check if all parameter methods are called + for (var paramMethod : parameterMethods) { + boolean isCalled = calledMethodNames.contains(paramMethod.getName()); + + if (!isCalled) { + String message = String.format( + "Method %s.of() does not read CacheParameter attribute '%s'", + javaClass.getSimpleName(), paramMethod.getName()); + events.add(violated(javaClass, message)); + } + } + } + } + }); + + /** + * Rule that enforces that the parameters() method in each CacheParameter implementation accesses all fields. + *
<p>
+ * Each class implementing CacheParameter must have a parameters() method that uses all record components/fields + * to ensure the cache key is unique and complete. + */ + @ArchTest + static final ArchRule cacheParametersMustUseAllFieldsInParametersMethod = classes() + .that() + .implement(CacheParameter.class) + .and() + .areNotInterfaces() + .should(new ArchCondition<>("have a parameters() method that accesses all fields") { + @Override + public void check(JavaClass javaClass, ConditionEvents events) { + // Find the parameters() method + var parametersMethod = javaClass.getMethods().stream() + .filter(m -> m.getName().equals("parameters")) + .filter(m -> m.getRawParameterTypes().isEmpty()) + .findFirst(); + + if (parametersMethod.isEmpty()) { + String message = + String.format("Class %s does not have a parameters() method", javaClass.getFullName()); + events.add(violated(javaClass, message)); + return; + } + + // Get all fields of the CacheParameter (record components) + var fields = javaClass.getAllFields().stream() + .filter(f -> !f.getModifiers().contains(JavaModifier.STATIC)) + .toList(); + + if (fields.isEmpty()) { + return; // No fields to check + } + + var method = parametersMethod.get(); + + // Get all field accesses in the parameters() method + var fieldAccesses = method.getFieldAccesses(); + Set accessedFieldNames = fieldAccesses.stream() + .map(access -> access.getTarget().getName()) + .collect(Collectors.toSet()); + + // Also check for method calls (record accessor methods) + var methodCalls = method.getMethodCallsFromSelf(); + Set calledMethodNames = methodCalls.stream() + .map(call -> call.getTarget().getName()) + .collect(Collectors.toSet()); + + // Check if all fields are accessed (either directly or via accessor methods) + for (var field : fields) { + String fieldName = field.getName(); + boolean isAccessed = + accessedFieldNames.contains(fieldName) || calledMethodNames.contains(fieldName); + + if (!isAccessed) { + String message = 
String.format( + "Method %s.parameters() does not access field '%s'", + javaClass.getSimpleName(), fieldName); + events.add(violated(javaClass, message)); + } + } + } + }); }