Skip to content

Commit eff6784

Browse files
authored
Merge pull request #213 from veithly/fix/graph
Enhance GraphAgent memory initialization by including storage path.
2 parents f89439d + d7ae3ef commit eff6784

File tree

3 files changed

+78
-41
lines changed

3 files changed

+78
-41
lines changed

spoon_ai/graph/agent.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,11 @@ def load_session(self, session_id: str):
359359
memory_class = self.memory.__class__
360360
if memory_class == Memory or memory_class == MockMemory:
361361
old_session = self.memory.session_id
362-
new_memory = memory_class(session_id=session_id)
362+
storage_path = getattr(self.memory, "storage_path", None)
363+
new_memory = memory_class(
364+
storage_path=str(storage_path) if storage_path else None,
365+
session_id=session_id,
366+
)
363367
self.memory = new_memory
364368
print(f"Switched from session '{old_session}' to '{session_id}'")
365369
else:

spoon_ai/graph/engine.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ async def __call__(self, state: State, config: Optional[Dict[str, Any]] = None)
8787
return {"result": result}
8888

8989
except InterruptError:
90+
# Human-in-the-loop: allow interrupts to propagate to the graph engine.
9091
# Let InterruptError propagate to be handled in invoke()
9192
raise
9293
except Exception as e:
@@ -157,6 +158,9 @@ async def __call__(self, state: State, config: Optional[Dict[str, Any]] = None)
157158
result = await self.condition_func(state)
158159

159160
return {"condition_result": result, "next_node": result}
161+
except InterruptError:
162+
# Human-in-the-loop: allow interrupts to propagate to the graph engine.
163+
raise
160164
except Exception as e:
161165
logger.error(f"Condition node {self.name} failed: {e}")
162166
raise NodeExecutionError(f"Condition '{self.name}' failed", node_name=self.name, original_error=e, state=state) from e
@@ -511,7 +515,7 @@ def set_llm_router(self, router_func: Optional[Callable[[Dict[str, Any], str], s
511515

512516
return self
513517

514-
def _create_default_llm_router(self, state: Dict[str, Any], query: str) -> str:
518+
async def _create_default_llm_router(self, state: Dict[str, Any], query: str) -> str:
515519
"""Create and use default LLM router for natural language routing"""
516520
try:
517521
# Lazy import to avoid circular dependencies
@@ -533,10 +537,10 @@ def _create_default_llm_router(self, state: Dict[str, Any], query: str) -> str:
533537
Return ONLY the step name (lowercase, no explanation).
534538
"""
535539

536-
messages = [{"role": "user", "content": routing_prompt}]
540+
messages = [Message(role="user", content=routing_prompt)]
537541

538542
# Use LLM to determine next step
539-
response = llm_manager.chat(messages, **config)
543+
response = await llm_manager.chat(messages, provider=None, **config)
540544
next_step = response.content.strip().lower()
541545

542546
# Validate the response
@@ -594,6 +598,14 @@ def enable_llm_routing(self, config: Optional[Dict[str, Any]] = None) -> "StateG
594598
if config:
595599
self.llm_router_config.update(config)
596600

601+
# Enabling LLM routing implies allowing the router to call the LLM.
602+
# Keep this best-effort to avoid surprising failures if config is not a GraphConfig.
603+
try:
604+
if isinstance(self.config, GraphConfig):
605+
self.config.router.allow_llm = True
606+
except Exception:
607+
pass
608+
597609
# Set LLM router as the intelligent router
598610
self.set_llm_router()
599611
return self
@@ -841,6 +853,11 @@ async def invoke(self, initial_state: Optional[Dict[str, Any]] = None, config: O
841853
try:
842854
if callable(self.graph.state_validator):
843855
self.graph.state_validator(state)
856+
# Also support GraphConfig.state_validators (list of callables)
857+
if isinstance(self.graph.config, GraphConfig):
858+
for validator in self.graph.config.state_validators:
859+
if callable(validator):
860+
validator(state)
844861
except Exception as e:
845862
raise GraphExecutionError(f"State validation failed: {e}", node=current_node, iteration=iteration)
846863
except InterruptError as e:
@@ -911,6 +928,7 @@ async def _execute_node(self, node_name: str, state: State, config: Optional[Dic
911928
pass
912929
return result if isinstance(result, dict) else {"result": result}
913930
except InterruptError:
931+
# Human-in-the-loop: allow interrupts to propagate to the graph engine.
914932
# Let InterruptError propagate to be handled in invoke()
915933
raise
916934
except Exception as e:
@@ -1202,11 +1220,21 @@ async def stream(self, initial_state: Optional[Dict[str, Any]] = None, config: O
12021220
self._update_state_with_reducers(state, result)
12031221
if stream_mode == "values":
12041222
yield state.copy()
1223+
# optional validation (mirrors invoke)
1224+
try:
1225+
if callable(self.graph.state_validator):
1226+
self.graph.state_validator(state)
1227+
if isinstance(self.graph.config, GraphConfig):
1228+
for validator in self.graph.config.state_validators:
1229+
if callable(validator):
1230+
validator(state)
1231+
except Exception as e:
1232+
raise GraphExecutionError(f"State validation failed: {e}", node=current_node, iteration=iteration)
12051233
except InterruptError as e:
12061234
yield {"type": "interrupt", "node": current_node, "interrupt_id": e.interrupt_id, "interrupt_data": e.interrupt_data, "state": state.copy()}
12071235
return
12081236
next_node = await self._determine_next_node(current_node, state)
1209-
if next_node == "END" or next_node is None:
1237+
if next_node == END or next_node is None:
12101238
if stream_mode == "values":
12111239
yield state.copy()
12121240
break

spoon_ai/llm/manager.py

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@ def can_retry_initialization(self) -> bool:
4141
"""Check if provider initialization can be retried."""
4242
if self.initialization_attempts >= self.max_attempts:
4343
return False
44-
44+
4545
if self.backoff_until and datetime.now() < self.backoff_until:
4646
return False
47-
47+
4848
return True
4949

5050
def record_initialization_failure(self, error: Exception) -> None:
@@ -54,7 +54,7 @@ def record_initialization_failure(self, error: Exception) -> None:
5454
self.last_error_time = datetime.now()
5555
self.is_initialized = False
5656
self.is_initializing = False
57-
57+
5858
# Exponential backoff: 2^attempts seconds
5959
backoff_seconds = min(2 ** self.initialization_attempts, 300) # Max 5 minutes
6060
self.backoff_until = datetime.now() + timedelta(seconds=backoff_seconds)
@@ -214,7 +214,7 @@ def __init__(self,
214214
self.provider_cleanup_tasks: Set[asyncio.Task] = set()
215215
self._manager_lock = asyncio.Lock()
216216
self._shutdown_event = asyncio.Event()
217-
217+
218218
# Existing configuration
219219
self.fallback_chain: List[str] = []
220220
self.default_provider: Optional[str] = None
@@ -231,7 +231,7 @@ def _register_cleanup(self) -> None:
231231
"""Register cleanup callback for graceful shutdown."""
232232
import atexit
233233
import signal
234-
234+
235235
def cleanup_sync():
236236
"""Synchronous cleanup wrapper."""
237237
try:
@@ -252,7 +252,7 @@ def cleanup_sync():
252252
except Exception as e:
253253
# Silently skip cleanup on shutdown errors
254254
logger.debug(f"Cleanup error (safe to ignore at shutdown): {e}")
255-
255+
256256
atexit.register(cleanup_sync)
257257

258258
def _get_provider_state(self, provider_name: str) -> ProviderState:
@@ -263,16 +263,16 @@ def _get_provider_state(self, provider_name: str) -> ProviderState:
263263

264264
async def _ensure_provider_initialized(self, provider_name: str) -> bool:
265265
"""Ensure provider is properly initialized with thread safety.
266-
266+
267267
Returns:
268268
bool: True if provider is initialized, False if initialization failed
269269
"""
270270
state = self._get_provider_state(provider_name)
271-
271+
272272
# Fast path: already initialized
273273
if state.is_initialized:
274274
return True
275-
275+
276276
# Check if initialization is possible
277277
if not state.can_retry_initialization():
278278
logger.error(f"Provider {provider_name} initialization blocked: "
@@ -285,14 +285,14 @@ async def _ensure_provider_initialized(self, provider_name: str) -> bool:
285285
# Double-check after acquiring lock (another thread might have initialized)
286286
if state.is_initialized:
287287
return True
288-
288+
289289
# Check if already initializing in another coroutine
290290
if state.is_initializing:
291291
logger.info(f"Provider {provider_name} is already being initialized, waiting...")
292292
# Wait for initialization to complete (with timeout)
293293
try:
294294
await asyncio.wait_for(
295-
self._wait_for_initialization(state),
295+
self._wait_for_initialization(state),
296296
timeout=30.0
297297
)
298298
return state.is_initialized
@@ -302,34 +302,34 @@ async def _ensure_provider_initialized(self, provider_name: str) -> bool:
302302

303303
# Mark as initializing
304304
state.is_initializing = True
305-
305+
306306
try:
307307
# Get provider configuration
308308
config = self.config_manager.load_provider_config(provider_name)
309309
provider_instance = self.registry.get_provider(provider_name, config.model_dump())
310310

311311
logger.info(f"Initializing provider: {provider_name}")
312-
312+
313313
# Initialize the provider
314314
await provider_instance.initialize(config.model_dump())
315-
315+
316316
# Mark as successfully initialized
317317
state.record_initialization_success()
318-
318+
319319
logger.info(f"Successfully initialized provider: {provider_name}")
320320
return True
321321

322322
except Exception as e:
323323
# Record the failure
324324
state.record_initialization_failure(e)
325-
325+
326326
logger.error(f"Failed to initialize provider {provider_name} "
327327
f"(attempt {state.initialization_attempts}): {e}")
328-
328+
329329
# If this was the last attempt, mark provider as unhealthy
330330
if not state.can_retry_initialization():
331331
self.load_balancer.update_provider_health(provider_name, False)
332-
332+
333333
return False
334334

335335
async def _wait_for_initialization(self, state: ProviderState) -> None:
@@ -364,7 +364,7 @@ async def _execute_provider_operation(self, provider_name: str, method: str, *ar
364364
operation = getattr(provider_instance, method)
365365
if not callable(operation):
366366
raise ProviderError(provider_name, f"Method {method} not available on provider")
367-
367+
368368
response = await operation(*args, **kwargs)
369369

370370
# Calculate duration and add metadata
@@ -414,20 +414,20 @@ def _is_critical_error(self, error: Exception) -> bool:
414414
critical_error_patterns = [
415415
"connection",
416416
"authentication",
417-
"unauthorized",
417+
"unauthorized",
418418
"invalid_api_key",
419419
"token",
420420
"timeout"
421421
]
422-
422+
423423
error_str = str(error).lower()
424424
return any(pattern in error_str for pattern in critical_error_patterns)
425425

426426
async def cleanup(self) -> None:
427427
"""Enhanced cleanup with proper resource management."""
428428
if self._shutdown_event.is_set():
429429
return # Already cleaning up
430-
430+
431431
logger.info("Starting LLM Manager cleanup...")
432432
self._shutdown_event.set()
433433

@@ -447,15 +447,15 @@ async def cleanup(self) -> None:
447447
cleanup_errors = []
448448
# Only cleanup providers that have been initialized (have instances or states)
449449
providers_to_cleanup = set()
450-
450+
451451
# Add providers that have instances (safely access private attribute)
452452
registry_instances = getattr(self.registry, '_instances', {})
453453
if registry_instances:
454454
providers_to_cleanup.update(registry_instances.keys())
455-
455+
456456
# Add providers that have states (initialized)
457457
providers_to_cleanup.update(self.provider_states.keys())
458-
458+
459459
for provider_name in providers_to_cleanup:
460460
try:
461461
# Only try to cleanup if provider has an instance
@@ -466,8 +466,13 @@ async def cleanup(self) -> None:
466466
logger.debug(f"Cleaned up provider: {provider_name}")
467467
except Exception as e:
468468
# Only track and log non-configuration errors
469-
if "Configuration error" in str(e) or "not configured" in str(e).lower():
469+
error_str = str(e).lower()
470+
if "configuration error" in error_str or "not configured" in error_str:
470471
logger.debug(f"Skipping cleanup for unconfigured provider {provider_name}")
472+
elif "event loop is closed" in error_str:
473+
# During interpreter shutdown, the event loop may already be closed.
474+
# Cleanup is best-effort; don't spam warnings in this case.
475+
logger.debug(f"Skipping cleanup for provider {provider_name}: event loop is closed")
471476
else:
472477
cleanup_errors.append(f"{provider_name}: {e}")
473478
logger.warning(f"Cleanup failed for {provider_name}: {e}")
@@ -483,7 +488,7 @@ async def cleanup(self) -> None:
483488
def get_provider_status(self) -> Dict[str, Dict[str, Any]]:
484489
"""Get detailed status of all providers."""
485490
status = {}
486-
491+
487492
for provider_name in self.registry.list_providers():
488493
state = self._get_provider_state(provider_name)
489494
status[provider_name] = {
@@ -496,22 +501,22 @@ def get_provider_status(self) -> Dict[str, Dict[str, Any]]:
496501
"backoff_until": state.backoff_until.isoformat() if state.backoff_until else None,
497502
"health_status": self.load_balancer.provider_health.get(provider_name, True)
498503
}
499-
504+
500505
return status
501506

502507
async def reset_provider(self, provider_name: str) -> bool:
503508
"""Reset a provider's state and force reinitialization.
504-
509+
505510
Args:
506511
provider_name: Name of provider to reset
507-
512+
508513
Returns:
509514
bool: True if reset successful
510515
"""
511516
if provider_name not in self.provider_states:
512517
logger.warning(f"Provider {provider_name} not found in states")
513518
return False
514-
519+
515520
async with self._manager_lock:
516521
# Reset state
517522
state = self.provider_states[provider_name]
@@ -522,9 +527,9 @@ async def reset_provider(self, provider_name: str) -> bool:
522527
state.last_error = None
523528
state.last_error_time = None
524529
state.backoff_until = None
525-
530+
526531
logger.info(f"Reset provider state: {provider_name}")
527-
532+
528533
# Try to reinitialize
529534
return await self._ensure_provider_initialized(provider_name)
530535
def _initialize_providers(self) -> None:
@@ -603,7 +608,7 @@ async def chat_operation(provider_name: str) -> LLMResponse:
603608
return self.response_normalizer.normalize_response(response)
604609

605610
async def chat_stream(self,messages: List[Message],provider: Optional[str] = None,callbacks: Optional[List[BaseCallbackHandler]] = None,**kwargs) -> AsyncGenerator[LLMResponseChunk, None]:
606-
"""Send streaming chat request with callback support.
611+
"""Send streaming chat request with callback support.
607612
Args:
608613
messages: List of conversation messages
609614
provider: Specific provider to use (optional)
@@ -617,7 +622,7 @@ async def chat_stream(self,messages: List[Message],provider: Optional[str] = Non
617622
providers = self._get_providers_for_request(provider)
618623
provider_name = providers[0]
619624
await self._ensure_provider_initialized(provider_name)
620-
625+
621626
# Get provider instance
622627
provider_instance = self.registry.get_provider(provider_name)
623628

@@ -649,7 +654,7 @@ async def chat_stream(self,messages: List[Message],provider: Optional[str] = Non
649654
provider_name, 'chat_stream', duration, False, error=str(e)
650655
)
651656
raise
652-
657+
653658
def _get_internal_callbacks(self) -> List[BaseCallbackHandler]:
654659
"""Get internal monitoring callbacks."""
655660
# For now, return empty list

0 commit comments

Comments
 (0)