Skip to content

Commit a394a41

Browse files
authored
Merge pull request #159 from BESSER-PEARL/hotfix/initial_message_not_sent
Hotfix/initial message not sent
2 parents 04a2491 + c3b8b58 commit a394a41

File tree

10 files changed

+111
-51
lines changed

10 files changed

+111
-51
lines changed

besser/agent/core/session.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from asyncio import TimerHandle
66
from collections import deque
77
from typing import Any, TYPE_CHECKING
8+
from datetime import datetime
89

910
from pandas import DataFrame
1011
from websocket import WebSocketApp
@@ -142,7 +143,7 @@ def _stop_event_thread(self) -> None:
142143
self._event_loop = None
143144
self._event_thread = None
144145

145-
def get_chat_history(self, n: int = None) -> list[Message]:
146+
def get_chat_history(self, n: int = None, until_timestamp: datetime = None) -> list[Message]:
146147
"""Get the history of messages between this session and its agent.
147148
148149
Args:
@@ -154,7 +155,7 @@ def get_chat_history(self, n: int = None) -> list[Message]:
154155
"""
155156
chat_history: list[Message] = []
156157
if self._agent.get_property(DB_MONITORING) and self._agent._monitoring_db.connected:
157-
chat_df: DataFrame = self._agent._monitoring_db.select_chat(self, n=n)
158+
chat_df: DataFrame = self._agent._monitoring_db.select_chat(self, n=n, until_timestamp=until_timestamp)
158159
for i, row in chat_df.iterrows():
159160
t = get_message_type(row['type'])
160161
chat_history.append(Message(t=t, content=row['content'], is_user=row['is_user'], timestamp=row['timestamp']))

besser/agent/db/monitoring_db.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from datetime import datetime
2-
from typing import TYPE_CHECKING, Any
2+
from typing import TYPE_CHECKING, Any, Optional
33

44
import json
55
import pandas as pd
@@ -432,13 +432,18 @@ def get_last_state_of_session(self, agent_name: str, platform_name: str, session
432432
result_transition = self.conn.execute(stmt_transition).first()
433433
return result_transition[0] if result_transition else None
434434

435-
def select_chat(self, session: Session, n: int) -> pd.DataFrame:
436-
"""Retrieves a conversation history from the chat table of the database.
435+
def select_chat(
436+
self,
437+
session: Session,
438+
n: Optional[int] = None,
439+
until_timestamp: Optional[datetime] = None,
440+
) -> pd.DataFrame:
441+
"""Retrieve chat messages for a session, optionally capped by count or timestamp.
437442
438443
Args:
439-
session (Session): the session to get from the database
440-
n (int or None): the number of messages to get (from the most recents). If none is provided, gets all the
441-
messages
444+
session (Session): Session whose messages should be returned.
445+
n (int | None): Optional cap for the most recent messages; None returns all.
446+
until_timestamp (datetime | None): Optional inclusive upper bound on message timestamps.
442447
Returns:
443448
pandas.DataFrame: the session record, should be a 1 row DataFrame
444449
@@ -448,6 +453,9 @@ def select_chat(self, session: Session, n: int) -> pd.DataFrame:
448453
stmt = (select(table).where(
449454
table.c.session_id == int(session_entry['id'][0])
450455
))
456+
if until_timestamp is not None:
457+
stmt = stmt.where(table.c.timestamp <= until_timestamp)
458+
451459
if n:
452460
stmt = stmt.order_by(desc(table.c.id)).limit(n)
453461
return pd.read_sql_query(stmt, self.conn).sort_values(by='id')

besser/agent/platforms/payload.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,18 +96,16 @@ def decode(payload_str):
9696
payload_dict = json.loads(payload_str)
9797
payload_action = payload_dict['action']
9898
payload_message = payload_dict['message']
99-
user_id = payload_dict.get('user_id')
10099
history = payload_dict.get('history', False)
101100

102101
for action in PayloadAction:
103102
if action.value == payload_action:
104-
return Payload(action, payload_message, user_id=user_id, history=history)
103+
return Payload(action, payload_message, history=history)
105104
return None
106105

107-
def __init__(self, action: PayloadAction, message: str or dict = None, user_id: str = None, history: bool = False):
106+
def __init__(self, action: PayloadAction, message: str or dict = None, history: bool = False):
108107
self.action: str = action.value
109108
self.message: str or dict = message
110-
self.user_id: str = user_id
111109
self.history: bool = history
112110

113111

@@ -135,7 +133,6 @@ def default(self, obj):
135133
payload_dict = {
136134
'action': obj.action,
137135
'message': obj.message,
138-
'user_id': getattr(obj, 'user_id', None),
139136
'history': getattr(obj, 'history', None),
140137
}
141138
return payload_dict

besser/agent/platforms/websocket/streamlit_ui/chat.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ def send_option():
7373
payload = Payload(
7474
action=PayloadAction.USER_MESSAGE,
7575
message=option,
76-
user_id=st.session_state.get("username", "Guest"),
7776
)
7877
ws = ensure_websocket_connection()
7978
if not ws:
@@ -140,7 +139,6 @@ def load_chat():
140139
payload = Payload(
141140
action=PayloadAction.FETCH_USER_MESSAGES,
142141
message=None,
143-
user_id=username,
144142
)
145143
try:
146144
ws.send(json.dumps(payload, cls=PayloadEncoder))

besser/agent/platforms/websocket/streamlit_ui/initialization.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,27 @@ def _resolve_host_port():
4242

4343
def _start_websocket(host: str, port: str):
4444
try:
45-
ws = websocket.WebSocketApp(
46-
f"ws://{host}:{port}/",
47-
on_open=on_open,
48-
on_message=on_message,
49-
on_error=on_error,
50-
on_close=on_close,
51-
on_ping=on_ping,
52-
on_pong=on_pong,
53-
)
45+
if st.session_state.get("username"):
46+
ws = websocket.WebSocketApp(
47+
f"ws://{host}:{port}/",
48+
header={"X-User-ID": st.session_state["username"]},
49+
on_open=on_open,
50+
on_message=on_message,
51+
on_error=on_error,
52+
on_close=on_close,
53+
on_ping=on_ping,
54+
on_pong=on_pong,
55+
)
56+
else:
57+
ws = websocket.WebSocketApp(
58+
f"ws://{host}:{port}/",
59+
on_open=on_open,
60+
on_message=on_message,
61+
on_error=on_error,
62+
on_close=on_close,
63+
on_ping=on_ping,
64+
on_pong=on_pong,
65+
)
5466
websocket_thread = threading.Thread(target=ws.run_forever)
5567
add_script_run_ctx(websocket_thread)
5668
websocket_thread.start()

besser/agent/platforms/websocket/streamlit_ui/message_input.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,7 @@ def submit_text():
2323
st.write(user_input)
2424
message = Message(t=MessageType.STR, content=user_input, is_user=True, timestamp=datetime.now())
2525
st.session_state.history.append(message)
26-
if st.session_state.get("authenticated", False):
27-
payload = Payload(action=PayloadAction.USER_MESSAGE,
28-
message=user_input, user_id=st.session_state.get("username"))
29-
else:
30-
payload = Payload(action=PayloadAction.USER_MESSAGE,
31-
message=user_input)
26+
payload = Payload(action=PayloadAction.USER_MESSAGE, message=user_input)
3227
try:
3328
ws = st.session_state[WEBSOCKET]
3429
ws.send(json.dumps(payload, cls=PayloadEncoder))

besser/agent/platforms/websocket/websocket_platform.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import os
77
import time
88
from datetime import datetime
9+
from urllib.parse import parse_qs, urlsplit
910

1011
import numpy as np
1112
import subprocess
@@ -36,6 +37,28 @@
3637
DB_STREAMLIT
3738
)
3839

40+
def _extract_user_id_from_request(request) -> str | None:
41+
if not request:
42+
return None
43+
for attr in ("path", "raw_path", "uri"):
44+
value = getattr(request, attr, None)
45+
if not value:
46+
continue
47+
if isinstance(value, bytes):
48+
# Prefer UTF-8 for URL paths; fall back to latin-1 to remain robust to non-UTF-8 bytes.
49+
try:
50+
value = value.decode("utf-8")
51+
except UnicodeDecodeError:
52+
value = value.decode("latin-1", errors="replace")
53+
query = urlsplit(value).query
54+
if not query:
55+
continue
56+
params = parse_qs(query)
57+
user_values = params.get("user_id")
58+
if user_values:
59+
return user_values[0]
60+
return None
61+
3962
if TYPE_CHECKING:
4063
from besser.agent.core.agent import Agent
4164

@@ -101,33 +124,34 @@ def message_handler(conn: ServerConnection) -> None:
101124
conn (ServerConnection): the user connection
102125
"""
103126
session: Session = None
127+
current_time = datetime.now()
128+
request = getattr(conn, "request", None)
129+
headers = getattr(request, "headers", {}) if request else {}
130+
header_user = headers.get("X-User-ID") if hasattr(headers, "get") else None
131+
query_user = _extract_user_id_from_request(request)
132+
session_key = header_user or query_user or str(conn.id)
133+
self._connections[str(session_key)] = conn
134+
session = self._agent.get_or_create_session(session_key, self)
104135
try:
136+
105137
for payload_str in conn:
106138
if not self.running:
107139
raise ConnectionClosedError(None, None)
108140
payload: Payload = Payload.decode(payload_str)
109-
if session is None:
110-
if payload.user_id:
111-
session = self._agent.get_or_create_session(payload.user_id, self)
112-
self._connections[str(payload.user_id)] = conn
113-
else:
114-
session = self._agent.get_or_create_session(str(conn.id), self)
115-
self._connections[str(conn.id)] = conn
141+
116142
if payload.action == PayloadAction.FETCH_USER_MESSAGES.value:
117143
try:
118-
chat_history = session.get_chat_history()
144+
chat_history = session.get_chat_history(until_timestamp=current_time)
119145
for message in chat_history:
120146
history_payload = None
121147
if message.is_user:
122148
history_payload = Payload(action=PayloadAction.USER_MESSAGE,
123149
message=message.content,
124-
user_id=session.id,
125150
history=True
126151
)
127152
else:
128153
history_payload = Payload(action=PayloadAction.AGENT_REPLY_STR,
129154
message=message.content,
130-
user_id=session.id,
131155
history=True
132156
)
133157
self._send(session.id, history_payload)
@@ -271,7 +295,6 @@ def reply(self, session: Session, message: str) -> None:
271295
session.save_message(Message(t=MessageType.STR, content=message, is_user=False, timestamp=datetime.now()))
272296
payload = Payload(action=PayloadAction.AGENT_REPLY_STR,
273297
message=message,
274-
user_id=session.id
275298
)
276299
payload.message = self._agent.process(session=session, message=payload.message, is_user_message=False)
277300
self._send(session.id, payload)

docs/source/release_notes.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ Release Notes
33

44
.. toctree::
55

6+
release_notes/v4.2.2
67
release_notes/v4.2.1
78
release_notes/v4.2.0
89
release_notes/v4.1.0
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
Version 4.2.2
2+
=============
3+
4+
Hotfix
5+
------
6+
7+
- Restored initial agent messages when persistence is enabled by updating the user identification protocol to accept the user ID from either the WebSocket header or query parameters.

docs/source/wiki/platforms/websocket_platform.rst

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -130,19 +130,37 @@ The WebSocket platform allows the following kinds of user messages:
130130

131131
Enabling persistent user sessions
132132
---------------------------------
133-
When building your own UI on top of the WebSocket API, you need to implement a user authentication mechanism to enable session persistence.
134-
BAF relies on user identifiers to map incoming connections to the correct sessions.
135-
Thus, if your UI does not authenticate users, each new connection will be treated as a new user session.
136-
Once your platform authenticates users, you simply need to set a unique identifier for each user in the payload sent to the agent via WebSocket:
133+
When building your own UI on top of the WebSocket API, implement user authentication so every connection can be tied to a stable identifier. BAF maps connections to sessions using either the ``X-User-ID`` header or the ``user_id`` query parameter from the handshake URL. If both are provided, the header takes precedence. If neither identifier is available, the platform treats the connection as a new anonymous user.
134+
135+
Once your client authenticates users, include the identifier in the WebSocket handshake headers:
136+
137+
.. code:: python
138+
139+
ws = websocket.WebSocketApp(
140+
f"ws://{host}:{port}/",
141+
header={"X-User-ID": user_id},
142+
on_open=on_open,
143+
on_message=on_message,
144+
on_error=on_error,
145+
on_close=on_close,
146+
on_ping=on_ping,
147+
on_pong=on_pong,
148+
)
149+
150+
If you cannot control the HTTP headers, you can also attach the identifier as a query parameter on the WebSocket URL:
137151

138152
.. code:: python
139153
140-
payload_dict = {
141-
'action': obj.action,
142-
'message': obj.message,
143-
'user_id': getattr(obj, 'user_id', None),
144-
'history': getattr(obj, 'history', None),
145-
}
154+
ws = websocket.WebSocketApp(
155+
f"ws://{host}:{port}/?user_id={user_id}",
156+
on_open=on_open,
157+
on_message=on_message,
158+
on_error=on_error,
159+
on_close=on_close,
160+
on_ping=on_ping,
161+
on_pong=on_pong,
162+
)
163+
146164
147165
On the agent's side, you'll need to start the monitoring database and set the ``persist_sessions=True`` parameter when initializing the agent:
148166

0 commit comments

Comments
 (0)