     GenericFakeChatModel,
 )
 from langchain_core.messages import AIMessage
-from langchain_core.outputs import ChatGeneration
+from langchain_core.outputs import ChatGeneration, Generation
+from langchain_core.outputs.chat_result import ChatResult


 class InMemoryCache(BaseCache):
@@ -305,6 +306,97 @@ def test_llm_representation_for_serializable() -> None:
     )


+def test_cache_with_generation_objects() -> None:
+    """Test that the cache can handle Generation objects, not only ChatGeneration.
+
+    This test reproduces a bug where the cache returns Generation objects
+    but ChatResult expects ChatGeneration objects, causing validation errors.
+
+    See #22389 for more info.
+    """
+    cache = InMemoryCache()
+
+    class SimpleFakeChat:
+        """Simple fake chat model for testing."""
+
+        def __init__(self, cache: BaseCache) -> None:
+            self.cache = cache
+            self.response = "hello"
+
+        def _get_llm_string(self) -> str:
+            return "test_llm_string"
+
+        def generate_response(self, prompt: str) -> ChatResult:
+            """Simulate the cache lookup and generation logic."""
+            from langchain_core.load import dumps
+
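+            # Cache keys mirror BaseChatModel's: the serialized prompt plus
+            # a string identifying the model configuration.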
+            llm_string = self._get_llm_string()
+            prompt_str = dumps([prompt])
+
+            # Check the cache first
+            cache_val = self.cache.lookup(prompt_str, llm_string)
+            if cache_val:
+                # This is where our fix applies: plain Generation objects
+                # from the cache are converted to ChatGeneration objects.
+                converted_generations = []
+                for gen in cache_val:
+                    if isinstance(gen, Generation) and not isinstance(
+                        gen, ChatGeneration
+                    ):
+                        # Convert Generation to ChatGeneration by wrapping
+                        # the text in an AIMessage
+                        chat_gen = ChatGeneration(
+                            message=AIMessage(content=gen.text),
+                            generation_info=gen.generation_info,
+                        )
+                        converted_generations.append(chat_gen)
+                    else:
+                        converted_generations.append(gen)
+                return ChatResult(generations=converted_generations)
+
+            # Generate a new response
+            chat_gen = ChatGeneration(
+                message=AIMessage(content=self.response), generation_info={}
+            )
+            result = ChatResult(generations=[chat_gen])
+
+            # Store it in the cache
+            self.cache.update(prompt_str, llm_string, result.generations)
+            return result
+
+    model = SimpleFakeChat(cache)
+
+    # First call - normal operation
+    result1 = model.generate_response("test prompt")
+    assert result1.generations[0].message.content == "hello"
+
+    # Manually corrupt the cache by replacing ChatGeneration with Generation
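+    # InMemoryCache (defined at the top of this file) keeps entries in a
+    # private _cache dict, which the test reaches into directly.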
+    cache_key = next(iter(cache._cache.keys()))
+    cached_chat_generations = cache._cache[cache_key]
+
+    # Replace them with Generation objects (missing the message field)
+    corrupted_generations = [
+        Generation(
+            text=gen.text,
+            generation_info=gen.generation_info,
+            type="Generation",  # This is the key detail - the wrong type
+        )
+        for gen in cached_chat_generations
+    ]
+    cache._cache[cache_key] = corrupted_generations
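+    # This simulates a cache backend returning entries whose stored type
+    # information was lost, the situation reported in #22389.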
+
+    # Second call should handle the Generation objects gracefully
+    result2 = model.generate_response("test prompt")
+    assert result2.generations[0].message.content == "hello"
+    assert isinstance(result2.generations[0], ChatGeneration)
+
+
 def test_cleanup_serialized() -> None:
     cleanup_serialized = {
         "lc": 1,