|
6 | 6 | import os |
7 | 7 | import time |
8 | 8 | from datetime import datetime |
| 9 | +from urllib.parse import parse_qs, urlsplit |
9 | 10 |
|
10 | 11 | import numpy as np |
11 | 12 | import subprocess |
|
36 | 37 | DB_STREAMLIT |
37 | 38 | ) |
38 | 39 |
|
| 40 | +def _extract_user_id_from_request(request) -> str | None: |
| 41 | + if not request: |
| 42 | + return None |
| 43 | + for attr in ("path", "raw_path", "uri"): |
| 44 | + value = getattr(request, attr, None) |
| 45 | + if not value: |
| 46 | + continue |
| 47 | + if isinstance(value, bytes): |
| 48 | + # Prefer UTF-8 for URL paths; fall back to latin-1 to remain robust to non-UTF-8 bytes. |
| 49 | + try: |
| 50 | + value = value.decode("utf-8") |
| 51 | + except UnicodeDecodeError: |
| 52 | + value = value.decode("latin-1", errors="replace") |
| 53 | + query = urlsplit(value).query |
| 54 | + if not query: |
| 55 | + continue |
| 56 | + params = parse_qs(query) |
| 57 | + user_values = params.get("user_id") |
| 58 | + if user_values: |
| 59 | + return user_values[0] |
| 60 | + return None |
| 61 | + |
39 | 62 | if TYPE_CHECKING: |
40 | 63 | from besser.agent.core.agent import Agent |
41 | 64 |
|
@@ -101,33 +124,34 @@ def message_handler(conn: ServerConnection) -> None: |
101 | 124 | conn (ServerConnection): the user connection |
102 | 125 | """ |
103 | 126 | session: Session = None |
| 127 | + current_time = datetime.now() |
| 128 | + request = getattr(conn, "request", None) |
| 129 | + headers = getattr(request, "headers", {}) if request else {} |
| 130 | + header_user = headers.get("X-User-ID") if hasattr(headers, "get") else None |
| 131 | + query_user = _extract_user_id_from_request(request) |
| 132 | + session_key = header_user or query_user or str(conn.id) |
| 133 | + self._connections[str(session_key)] = conn |
| 134 | + session = self._agent.get_or_create_session(session_key, self) |
104 | 135 | try: |
| 136 | + |
105 | 137 | for payload_str in conn: |
106 | 138 | if not self.running: |
107 | 139 | raise ConnectionClosedError(None, None) |
108 | 140 | payload: Payload = Payload.decode(payload_str) |
109 | | - if session is None: |
110 | | - if payload.user_id: |
111 | | - session = self._agent.get_or_create_session(payload.user_id, self) |
112 | | - self._connections[str(payload.user_id)] = conn |
113 | | - else: |
114 | | - session = self._agent.get_or_create_session(str(conn.id), self) |
115 | | - self._connections[str(conn.id)] = conn |
| 141 | + |
116 | 142 | if payload.action == PayloadAction.FETCH_USER_MESSAGES.value: |
117 | 143 | try: |
118 | | - chat_history = session.get_chat_history() |
| 144 | + chat_history = session.get_chat_history(until_timestamp=current_time) |
119 | 145 | for message in chat_history: |
120 | 146 | history_payload = None |
121 | 147 | if message.is_user: |
122 | 148 | history_payload = Payload(action=PayloadAction.USER_MESSAGE, |
123 | 149 | message=message.content, |
124 | | - user_id=session.id, |
125 | 150 | history=True |
126 | 151 | ) |
127 | 152 | else: |
128 | 153 | history_payload = Payload(action=PayloadAction.AGENT_REPLY_STR, |
129 | 154 | message=message.content, |
130 | | - user_id=session.id, |
131 | 155 | history=True |
132 | 156 | ) |
133 | 157 | self._send(session.id, history_payload) |
@@ -271,7 +295,6 @@ def reply(self, session: Session, message: str) -> None: |
271 | 295 | session.save_message(Message(t=MessageType.STR, content=message, is_user=False, timestamp=datetime.now())) |
272 | 296 | payload = Payload(action=PayloadAction.AGENT_REPLY_STR, |
273 | 297 | message=message, |
274 | | - user_id=session.id |
275 | 298 | ) |
276 | 299 | payload.message = self._agent.process(session=session, message=payload.message, is_user_message=False) |
277 | 300 | self._send(session.id, payload) |
|
0 commit comments