@@ -38,9 +38,11 @@ class Agent:
3838
3939 Args:
4040 name (str): The agent's name
41+ persist_sessions (bool): whether to persist sessions or not after restarting the agent
4142
4243 Attributes:
4344 _name (str): The agent name
45+ _persist_sessions (bool): whether to persist sessions or not after restarting the agent
4446 _platforms (list[Platform]): The agent platforms
4547 _platforms_threads (list[threading.Thread]): The threads where the platforms are run
4648 _event_loop (asyncio.AbstractEventLoop): The event loop managing external events
@@ -62,8 +64,9 @@ class Agent:
6264 processors (list[Processors]): List of processors used by the agent
6365 """
6466
65- def __init__ (self , name : str ):
67+ def __init__ (self , name : str , persist_sessions : bool = False ):
6668 self ._name : str = name
69+ self ._persist_sessions : bool = persist_sessions
6770 self ._platforms : list [Platform ] = []
6871 self ._platforms_threads : list [threading .Thread ] = []
6972 self ._nlp_engine = NLPEngine (self )
@@ -295,8 +298,15 @@ def _run_platforms(self) -> None:
295298 def _stop_platforms (self ) -> None :
296299 """Stop the execution of the agent platforms"""
297300 for platform , thread in zip (self ._platforms , self ._platforms_threads ):
298- platform .stop ()
299- thread .join ()
301+ try :
302+ platform .stop ()
303+ except KeyboardInterrupt :
304+ logger .warning ('Keyboard interrupt while stopping %s; forcing shutdown.' , platform .__class__ .__name__ )
305+ except Exception as exc :
306+ logger .error ('Error while stopping %s: %s' , platform .__class__ .__name__ , exc )
307+ finally :
308+ if thread .is_alive ():
309+ thread .join (timeout = 5 )
300310 self ._platforms_threads = []
301311
302312 def run (self , train : bool = True , sleep : bool = True ) -> None :
@@ -316,6 +326,9 @@ def run(self, train: bool = True, sleep: bool = True) -> None:
316326 self ._monitoring_db .connect_to_db (self )
317327 if self ._monitoring_db .connected :
318328 self ._monitoring_db .initialize_db ()
329+ if not self ._monitoring_db .connected and self ._persist_sessions :
330+ logger .warning (f'Agent { self ._name } persistence of sessions is enabled, but the monitoring database is not connected. Sessions will not be persisted.' )
331+ self ._persist_sessions = False
319332 self ._run_platforms ()
320333 # self._run_event_thread()
321334 if sleep :
@@ -334,8 +347,9 @@ def stop(self) -> None:
334347 self ._stop_platforms ()
335348 if self .get_property (DB_MONITORING ) and self ._monitoring_db .connected :
336349 self ._monitoring_db .close_connection ()
350+
337351 for session_id in list (self ._sessions .keys ()):
338- self .delete_session (session_id )
352+ self .close_session (session_id )
339353
340354 def reset (self , session_id : str ) -> Session or None :
341355 """Reset the agent current state and memory for the specified session. Then, restart the agent again for this session.
@@ -348,12 +362,13 @@ def reset(self, session_id: str) -> Session or None:
348362 """
349363 if session_id not in self ._sessions :
350364 return None
351- session = self ._sessions [session_id ]
352- new_session = Session (session_id , self , session .platform )
353- self ._sessions [session_id ] = new_session
365+ else :
366+ session = self ._sessions [session_id ]
367+ self .delete_session (session_id )
368+ self .get_or_create_session (session_id , session .platform )
354369 logger .info (f'{ self ._name } restarted by user { session_id } ' )
355- new_session . current_state . run ( new_session )
356- return new_session
370+
371+ return self . _sessions [ session_id ]
357372
358373 def receive_event (self , event : Event ) -> None :
359374 """Receive an external event from a platform.
@@ -363,7 +378,7 @@ def receive_event(self, event: Event) -> None:
363378 Args:
364379 event (Event): the received event
365380 """
366- session = None
381+ session : Session = None
367382 if event .is_broadcasted ():
368383 for session in self ._sessions .values ():
369384 session .events .appendleft (event )
@@ -463,9 +478,20 @@ def _new_session(self, session_id: str, platform: Platform) -> Session:
463478 raise ValueError (f"Platform { platform .__class__ .__name__ } not found in agent '{ self .name } '" )
464479 session = Session (session_id , self , platform )
465480 self ._sessions [session_id ] = session
466- self ._monitoring_db_insert_session (session )
481+ if self ._persist_sessions and self ._monitoring_db_session_exists (session_id , platform ):
482+ dest_state = self ._monitoring_db_get_last_state_of_session (session_id , platform )
483+ if dest_state :
484+ for state in self .states :
485+ if state .name == dest_state :
486+ session ._current_state = state
487+ self ._monitoring_db_load_session_variables (session )
488+ break
489+
490+ else :
491+ self ._monitoring_db_insert_session (session )
492+ session .current_state .run (session )
493+
467494 # ADD LOOP TO CHECK TRANSITIONS HERE
468- session .current_state .run (session )
469495 session ._run_event_thread ()
470496 return session
471497
@@ -475,6 +501,18 @@ def get_or_create_session(self, session_id: str, platform: Platform) -> Session:
475501 session = self ._new_session (session_id , platform )
476502 return session
477503
504+ def close_session (self , session_id : str ) -> None :
505+ """Delete an existing agent session.
506+
507+ Args:
508+ session_id (str): the session id
509+ """
510+ while self ._sessions [session_id ]._agent_connections :
511+ agent_connection = next (iter (self ._sessions [session_id ]._agent_connections .values ()))
512+ agent_connection .close ()
513+ self ._sessions [session_id ]._stop_event_thread ()
514+ del self ._sessions [session_id ]
515+
478516 def delete_session (self , session_id : str ) -> None :
479517 """Delete an existing agent session.
480518
@@ -485,18 +523,20 @@ def delete_session(self, session_id: str) -> None:
485523 agent_connection = next (iter (self ._sessions [session_id ]._agent_connections .values ()))
486524 agent_connection .close ()
487525 self ._sessions [session_id ]._stop_event_thread ()
526+ self ._monitoring_db_delete_session (self ._sessions [session_id ])
488527 del self ._sessions [session_id ]
489528
490- def use_websocket_platform (self , use_ui : bool = True ) -> WebSocketPlatform :
529+ def use_websocket_platform (self , use_ui : bool = True , authenticate_users : bool = False ) -> WebSocketPlatform :
491530 """Use the :class:`~besser.agent.platforms.websocket.websocket_platform.WebSocketPlatform` on this agent.
492531
493532 Args:
494533 use_ui (bool): if true, the default UI will be run to use this platform
495-
534+ authenticate_users (bool): whether to enable user persistence and authentication.
535+ Requires streamlit database configuration. Default is False
496536 Returns:
497537 WebSocketPlatform: the websocket platform
498538 """
499- websocket_platform = WebSocketPlatform (self , use_ui )
539+ websocket_platform = WebSocketPlatform (self , use_ui , authenticate_users )
500540 self ._platforms .append (websocket_platform )
501541 return websocket_platform
502542
@@ -550,6 +590,74 @@ def _monitoring_db_insert_session(self, session: Session) -> None:
550590 # Not in thread since we must ensure it is added before running a state (the chat table needs the session)
551591 self ._monitoring_db .insert_session (session )
552592
593+ def _monitoring_db_session_exists (self , session_id : str , platform : Platform ) -> bool :
594+ """
595+ Check if a session with the given session_id exists in the monitoring database.
596+
597+ Args:
598+ session_id (str): The session ID to check.
599+ platform (Platform): The platform to check.
600+
601+ Returns:
602+ bool: True if the session exists in the database, False otherwise
603+ """
604+ if self .get_property (DB_MONITORING ) and self ._monitoring_db and self ._monitoring_db .connected :
605+ result = self ._monitoring_db .session_exists (self .name , platform .__class__ .__name__ , session_id )
606+ return result
607+ return False
608+
609+ def _monitoring_db_get_last_state_of_session (
610+ self ,
611+ session_id : str ,
612+ platform : Platform
613+ ) -> str | None :
614+ """Get the last state of a session from the monitoring database.
615+
616+ Args:
617+ session_id (str): The session ID to check.
618+ platform (Platform): The platform to check.
619+
620+ Returns:
621+ str | None: The last state of the session if it exists, None otherwise.
622+ """
623+ if self .get_property (DB_MONITORING ) and self ._monitoring_db and self ._monitoring_db .connected :
624+ return self ._monitoring_db .get_last_state_of_session (self .name , platform .__class__ .__name__ , session_id )
625+ return None
626+
627+ def _monitoring_db_store_session_variables (
628+ self ,
629+ session : Session
630+ ) -> None :
631+ """Store the session variables (private data dictionary) in the monitoring database.
632+
633+ Args:
634+ session (Session): The session to store the variables for.
635+ """
636+ if self .get_property (DB_MONITORING ) and self ._monitoring_db .connected :
637+ self ._monitoring_db .store_session_variables (session )
638+
639+ def _monitoring_db_load_session_variables (
640+ self ,
641+ session : Session
642+ ) -> None :
643+ """Load the session variables (private data dictionary) from the monitoring database.
644+
645+ Args:
646+ session (Session): The session to load the variables for.
647+ """
648+ if self .get_property (DB_MONITORING ) and self ._monitoring_db .connected :
649+ self ._monitoring_db .load_session_variables (session )
650+
651+ def _monitoring_db_delete_session (self , session : Session ) -> None :
652+ """Delete a session record from the monitoring database.
653+
654+ Args:
655+ session (Session): the session of the current user
656+ """
657+ if self .get_property (DB_MONITORING ) and self ._monitoring_db .connected :
658+ # Not in thread since we must ensure it is deleted before removing the session
659+ self ._monitoring_db .delete_session (session )
660+
553661 def _monitoring_db_insert_intent_prediction (
554662 self ,
555663 session : Session ,
0 commit comments