Skip to content

Commit d2789b3

Browse files
authored
Merge pull request #154 from BESSER-PEARL/release/4.2.0
Release/4.2.0
2 parents ff63639 + 232a3ff commit d2789b3

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+1052
-97
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
<img src="./docs/source/_static/baf_logo_readme.svg" alt="BESSER Agentic Framework" width="500"/>
33
</div>
44

5-
[![Python](https://img.shields.io/badge/python-3.10%20%7C%203.11-blue?logo=python&logoColor=gold)](https://pypi.org/project/besser-agentic-framework/)
5+
[![Python](https://img.shields.io/badge/python-3.10%20%7C%203.11%20%7C%203.12-blue?logo=python&logoColor=gold)](https://pypi.org/project/besser-agentic-framework/)
66
[![PyPI version](https://img.shields.io/pypi/v/besser-agentic-framework?logo=pypi&logoColor=white)](https://pypi.org/project/besser-agentic-framework/)
77
[![PyPI - Downloads](https://static.pepy.tech/badge/besser-agentic-framework)](https://pypi.org/project/besser-agentic-framework/)
88
[![Documentation Status](https://readthedocs.org/projects/besser-agentic-framework/badge/?version=latest)](https://besser-agentic-framework.readthedocs.io/latest/?badge=latest)
@@ -20,7 +20,7 @@ to make the design and implementation of agents, bots and chatbots easier and ac
2020

2121
### Requirements
2222

23-
- Python 3.11
23+
- Python >=3.10
2424
- Recommended: Create a virtual environment
2525
(e.g. [venv](https://docs.python.org/3/library/venv.html),
2626
[conda](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html))

besser/agent/core/agent.py

Lines changed: 123 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,11 @@ class Agent:
3838
3939
Args:
4040
name (str): The agent's name
41+
persist_sessions (bool): whether to persist sessions or not after restarting the agent
4142
4243
Attributes:
4344
_name (str): The agent name
45+
_persist_sessions (bool): whether to persist sessions or not after restarting the agent
4446
_platforms (list[Platform]): The agent platforms
4547
_platforms_threads (list[threading.Thread]): The threads where the platforms are run
4648
_event_loop (asyncio.AbstractEventLoop): The event loop managing external events
@@ -62,8 +64,9 @@ class Agent:
6264
processors (list[Processors]): List of processors used by the agent
6365
"""
6466

65-
def __init__(self, name: str):
67+
def __init__(self, name: str, persist_sessions: bool = False):
6668
self._name: str = name
69+
self._persist_sessions: bool = persist_sessions
6770
self._platforms: list[Platform] = []
6871
self._platforms_threads: list[threading.Thread] = []
6972
self._nlp_engine = NLPEngine(self)
@@ -295,8 +298,15 @@ def _run_platforms(self) -> None:
295298
def _stop_platforms(self) -> None:
296299
"""Stop the execution of the agent platforms"""
297300
for platform, thread in zip(self._platforms, self._platforms_threads):
298-
platform.stop()
299-
thread.join()
301+
try:
302+
platform.stop()
303+
except KeyboardInterrupt:
304+
logger.warning('Keyboard interrupt while stopping %s; forcing shutdown.', platform.__class__.__name__)
305+
except Exception as exc:
306+
logger.error('Error while stopping %s: %s', platform.__class__.__name__, exc)
307+
finally:
308+
if thread.is_alive():
309+
thread.join(timeout=5)
300310
self._platforms_threads = []
301311

302312
def run(self, train: bool = True, sleep: bool = True) -> None:
@@ -316,6 +326,9 @@ def run(self, train: bool = True, sleep: bool = True) -> None:
316326
self._monitoring_db.connect_to_db(self)
317327
if self._monitoring_db.connected:
318328
self._monitoring_db.initialize_db()
329+
if not self._monitoring_db.connected and self._persist_sessions:
330+
logger.warning(f'Agent {self._name} persistence of sessions is enabled, but the monitoring database is not connected. Sessions will not be persisted.')
331+
self._persist_sessions = False
319332
self._run_platforms()
320333
# self._run_event_thread()
321334
if sleep:
@@ -334,8 +347,9 @@ def stop(self) -> None:
334347
self._stop_platforms()
335348
if self.get_property(DB_MONITORING) and self._monitoring_db.connected:
336349
self._monitoring_db.close_connection()
350+
337351
for session_id in list(self._sessions.keys()):
338-
self.delete_session(session_id)
352+
self.close_session(session_id)
339353

340354
def reset(self, session_id: str) -> Session or None:
341355
"""Reset the agent current state and memory for the specified session. Then, restart the agent again for this session.
@@ -348,12 +362,13 @@ def reset(self, session_id: str) -> Session or None:
348362
"""
349363
if session_id not in self._sessions:
350364
return None
351-
session = self._sessions[session_id]
352-
new_session = Session(session_id, self, session.platform)
353-
self._sessions[session_id] = new_session
365+
else:
366+
session = self._sessions[session_id]
367+
self.delete_session(session_id)
368+
self.get_or_create_session(session_id, session.platform)
354369
logger.info(f'{self._name} restarted by user {session_id}')
355-
new_session.current_state.run(new_session)
356-
return new_session
370+
371+
return self._sessions[session_id]
357372

358373
def receive_event(self, event: Event) -> None:
359374
"""Receive an external event from a platform.
@@ -363,7 +378,7 @@ def receive_event(self, event: Event) -> None:
363378
Args:
364379
event (Event): the received event
365380
"""
366-
session = None
381+
session: Session = None
367382
if event.is_broadcasted():
368383
for session in self._sessions.values():
369384
session.events.appendleft(event)
@@ -463,9 +478,20 @@ def _new_session(self, session_id: str, platform: Platform) -> Session:
463478
raise ValueError(f"Platform {platform.__class__.__name__} not found in agent '{self.name}'")
464479
session = Session(session_id, self, platform)
465480
self._sessions[session_id] = session
466-
self._monitoring_db_insert_session(session)
481+
if self._persist_sessions and self._monitoring_db_session_exists(session_id, platform):
482+
dest_state = self._monitoring_db_get_last_state_of_session(session_id, platform)
483+
if dest_state:
484+
for state in self.states:
485+
if state.name == dest_state:
486+
session._current_state = state
487+
self._monitoring_db_load_session_variables(session)
488+
break
489+
490+
else:
491+
self._monitoring_db_insert_session(session)
492+
session.current_state.run(session)
493+
467494
# ADD LOOP TO CHECK TRANSITIONS HERE
468-
session.current_state.run(session)
469495
session._run_event_thread()
470496
return session
471497

@@ -475,6 +501,18 @@ def get_or_create_session(self, session_id: str, platform: Platform) -> Session:
475501
session = self._new_session(session_id, platform)
476502
return session
477503

504+
def close_session(self, session_id: str) -> None:
505+
"""Delete an existing agent session.
506+
507+
Args:
508+
session_id (str): the session id
509+
"""
510+
while self._sessions[session_id]._agent_connections:
511+
agent_connection = next(iter(self._sessions[session_id]._agent_connections.values()))
512+
agent_connection.close()
513+
self._sessions[session_id]._stop_event_thread()
514+
del self._sessions[session_id]
515+
478516
def delete_session(self, session_id: str) -> None:
479517
"""Delete an existing agent session.
480518
@@ -485,18 +523,20 @@ def delete_session(self, session_id: str) -> None:
485523
agent_connection = next(iter(self._sessions[session_id]._agent_connections.values()))
486524
agent_connection.close()
487525
self._sessions[session_id]._stop_event_thread()
526+
self._monitoring_db_delete_session(self._sessions[session_id])
488527
del self._sessions[session_id]
489528

490-
def use_websocket_platform(self, use_ui: bool = True) -> WebSocketPlatform:
529+
def use_websocket_platform(self, use_ui: bool = True, authenticate_users: bool = False) -> WebSocketPlatform:
491530
"""Use the :class:`~besser.agent.platforms.websocket.websocket_platform.WebSocketPlatform` on this agent.
492531
493532
Args:
494533
use_ui (bool): if true, the default UI will be run to use this platform
495-
534+
authenticate_users (bool): whether to enable user persistence and authentication.
535+
Requires streamlit database configuration. Default is False
496536
Returns:
497537
WebSocketPlatform: the websocket platform
498538
"""
499-
websocket_platform = WebSocketPlatform(self, use_ui)
539+
websocket_platform = WebSocketPlatform(self, use_ui, authenticate_users)
500540
self._platforms.append(websocket_platform)
501541
return websocket_platform
502542

@@ -550,6 +590,74 @@ def _monitoring_db_insert_session(self, session: Session) -> None:
550590
# Not in thread since we must ensure it is added before running a state (the chat table needs the session)
551591
self._monitoring_db.insert_session(session)
552592

593+
def _monitoring_db_session_exists(self, session_id: str, platform: Platform) -> bool:
594+
"""
595+
Check if a session with the given session_id exists in the monitoring database.
596+
597+
Args:
598+
session_id (str): The session ID to check.
599+
platform (Platform): The platform to check.
600+
601+
Returns:
602+
bool: True if the session exists in the database, False otherwise
603+
"""
604+
if self.get_property(DB_MONITORING) and self._monitoring_db and self._monitoring_db.connected:
605+
result = self._monitoring_db.session_exists(self.name, platform.__class__.__name__, session_id)
606+
return result
607+
return False
608+
609+
def _monitoring_db_get_last_state_of_session(
610+
self,
611+
session_id: str,
612+
platform: Platform
613+
) -> str | None:
614+
"""Get the last state of a session from the monitoring database.
615+
616+
Args:
617+
session_id (str): The session ID to check.
618+
platform (Platform): The platform to check.
619+
620+
Returns:
621+
str | None: The last state of the session if it exists, None otherwise.
622+
"""
623+
if self.get_property(DB_MONITORING) and self._monitoring_db and self._monitoring_db.connected:
624+
return self._monitoring_db.get_last_state_of_session(self.name, platform.__class__.__name__, session_id)
625+
return None
626+
627+
def _monitoring_db_store_session_variables(
628+
self,
629+
session: Session
630+
) -> None:
631+
"""Store the session variables (private data dictionary) in the monitoring database.
632+
633+
Args:
634+
session (Session): The session to store the variables for.
635+
"""
636+
if self.get_property(DB_MONITORING) and self._monitoring_db.connected:
637+
self._monitoring_db.store_session_variables(session)
638+
639+
def _monitoring_db_load_session_variables(
640+
self,
641+
session: Session
642+
) -> None:
643+
"""Load the session variables (private data dictionary) from the monitoring database.
644+
645+
Args:
646+
session (Session): The session to load the variables for.
647+
"""
648+
if self.get_property(DB_MONITORING) and self._monitoring_db.connected:
649+
self._monitoring_db.load_session_variables(session)
650+
651+
def _monitoring_db_delete_session(self, session: Session) -> None:
652+
"""Delete a session record from the monitoring database.
653+
654+
Args:
655+
session (Session): the session of the current user
656+
"""
657+
if self.get_property(DB_MONITORING) and self._monitoring_db.connected:
658+
# Not in thread since we must ensure it is deleted before removing the session
659+
self._monitoring_db.delete_session(session)
660+
553661
def _monitoring_db_insert_intent_prediction(
554662
self,
555663
session: Session,

besser/agent/core/session.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,10 @@ def set(self, key: str, value: Any) -> None:
178178
value (Any): the entry value
179179
"""
180180
self._dictionary[key] = value
181+
try:
182+
self._agent._monitoring_db_store_session_variables(self)
183+
except Exception as e:
184+
logger.error(f"Failed to store session variables to the database for session {self.id}: {e}", exc_info=True)
181185

182186
def get(self, key: str, default: Any = None) -> Any:
183187
"""Get an entry of the session private data storage.
@@ -190,7 +194,7 @@ def get(self, key: str, default: Any = None) -> Any:
190194
Any: the entry value, default or None if the key does not exist
191195
"""
192196
if key not in self._dictionary:
193-
if default:
197+
if default is not None:
194198
return default
195199
return None
196200
return self._dictionary[key]
@@ -204,7 +208,16 @@ def delete(self, key: str) -> None:
204208
try:
205209
del self._dictionary[key]
206210
except Exception as e:
211+
logger.error(f"Failed to delete key '{key}' from session {self.id}: {e}", exc_info=True)
207212
return None
213+
def get_dictionary(self) -> dict[str, Any]:
214+
"""
215+
Returns the private data dictionary for this session.
216+
217+
Returns:
218+
dict[str, Any]: The session's private data storage.
219+
"""
220+
return self._dictionary
208221

209222
def move(self, transition: Transition) -> None:
210223
"""Move to another agent state.

besser/agent/core/state.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from besser.agent.core.transition.transition import Transition
88
from besser.agent.core.transition.transition_builder import TransitionBuilder
99
from besser.agent.library.intent.intent_library import fallback_intent
10-
from besser.agent.library.transition.events.base_events import ReceiveTextEvent, ReceiveFileEvent, WildcardEvent
10+
from besser.agent.library.transition.events.base_events import ReceiveTextEvent, ReceiveFileEvent, WildcardEvent, ReceiveJSONEvent
1111
from besser.agent.library.transition.conditions import IntentMatcher, VariableOperationMatcher
1212
from besser.agent.core.transition.condition import Condition
1313
from besser.agent.core.intent.intent import Intent
@@ -303,16 +303,18 @@ def check_transitions(self, session: Session) -> None:
303303
session.event = session.events.pop()
304304
if isinstance(session.event, ReceiveTextEvent):
305305
session.event.predict_intent(session)
306+
elif isinstance(session.event, ReceiveJSONEvent) and session.event.contains_message:
307+
session.event.predict_intent(session)
306308
if next_transition.evaluate(session, session.event):
307309
session.move(next_transition)
308310
# TODO: Make this configurable (we can consider remove all the previously checked events)
309311
session.events.extend(fallback_deque) # We restore the queue but with the matched event removed
310312
return
311-
if isinstance(session.event, ReceiveTextEvent) and session.event.human:
312-
# There is a ReceiveTextEvent and we couldn't match any transition so far
313+
if (isinstance(session.event, ReceiveTextEvent) and session.event.human) or (isinstance(session.event, ReceiveJSONEvent) and session.event.contains_message and session.event.human):
314+
# There is a ReceiveTextEvent or ReceiveJSONEvent (with message) and we couldn't match any transition so far
313315
run_fallback = True
314316
if i < len(self.transitions)-1:
315-
# We only append ReceiveTextEvent (human) if we didn't finish checking all transitions
317+
# We only append ReceiveTextEvent or ReceiveJSONEvent (human with message) if we didn't finish checking all transitions
316318
fallback_deque.appendleft(session.event)
317319
else:
318320
fallback_deque.appendleft(session.event)

0 commit comments

Comments
 (0)