diff --git a/core/cat/routes/websocket/websocket.py b/core/cat/routes/websocket/websocket.py index 576852c76..298c218cf 100644 --- a/core/cat/routes/websocket/websocket.py +++ b/core/cat/routes/websocket/websocket.py @@ -43,9 +43,9 @@ async def websocket_endpoint( except WebSocketDisconnect: log.info(f"WebSocket connection closed for user {cat.user_id}") finally: - + # cat's working memory in this scope has not been updated - #cat.load_working_memory_from_cache() - + cat.load_working_memory_from_cache() + # Remove connection on disconnect websocket_manager.remove_connection(cat.user_id) \ No newline at end of file diff --git a/core/cat/utils.py b/core/cat/utils.py index fae4992a5..a096a8457 100644 --- a/core/cat/utils.py +++ b/core/cat/utils.py @@ -7,7 +7,6 @@ from typing import Dict, Tuple from pydantic import BaseModel, ConfigDict -from rapidfuzz.fuzz import ratio from rapidfuzz.distance import Levenshtein from langchain_core.output_parsers import JsonOutputParser from langchain_core.prompts import PromptTemplate @@ -146,7 +145,7 @@ def explicit_error_message(e): def deprecation_warning(message: str, skip=3): """Log a deprecation warning with caller's information. "skip" is the number of stack levels to go back to the caller info.""" - + caller = get_caller_info(skip, return_short=False) # Format and log the warning message @@ -179,7 +178,7 @@ def parse_json(json_string: str, pydantic_model: BaseModel = None) -> dict: # parse parsed = parser.parse(json_string[start_index:]) - + if pydantic_model: return pydantic_model(**parsed) return parsed @@ -207,7 +206,7 @@ def match_prompt_variables( prompt_template = \ prompt_template.replace("{" + m + "}", "") log.debug(f"Placeholder '{m}' not found in prompt variables, removed") - + return prompt_variables, prompt_template @@ -249,7 +248,7 @@ def get_caller_info(skip=2, return_short=True, return_string=True): start = 0 + skip if len(stack) < start + 1: return None - + parentframe = stack[start][0] # module and packagename. @@ -341,7 +340,7 @@ class BaseModelDict(BaseModel): def __getitem__(self, key): # deprecate dictionary usage deprecation_warning( - f'To get `{key}` use dot notation instead of dictionary keys, example:' + f'To get `{key}` use dot notation instead of dictionary keys, example:' f'`obj.{key}` instead of `obj["{key}"]`' ) diff --git a/core/tests/cache/test_core_caches.py b/core/tests/cache/test_core_caches.py index a80c6bc7c..a77361228 100644 --- a/core/tests/cache/test_core_caches.py +++ b/core/tests/cache/test_core_caches.py @@ -16,18 +16,23 @@ def create_cache(cache_type): assert False + @pytest.mark.parametrize("cache_type", ["in_memory", "file_system"]) def test_cache_creation(cache_type): - - cache = create_cache(cache_type) - - if cache_type == "in_memory": - assert cache.items == {} - assert cache.max_items == 100 - else: - assert cache.cache_dir == "/tmp_cache" - assert os.path.exists("/tmp_cache") - assert os.listdir("/tmp_cache") == [] + try: + cache = create_cache(cache_type) + + if cache_type == "in_memory": + assert cache.items == {} + assert cache.max_items == 100 + else: + assert cache.cache_dir == "/tmp_cache" + assert os.path.exists("/tmp_cache") + assert os.listdir("/tmp_cache") == [] + finally: + import shutil + if os.path.exists("/tmp_cache"): + shutil.rmtree("/tmp_cache") @pytest.mark.parametrize("cache_type", ["in_memory", "file_system"]) @@ -36,13 +41,13 @@ def test_cache_get_insert(cache_type): cache = create_cache(cache_type) assert cache.get_item("a") is None - - c1 = CacheItem("a", []) + + c1 = CacheItem("a", []) cache.insert(c1) assert cache.get_item("a").value == [] assert cache.get_value("a") == [] - + c1.value = [0] cache.insert(c1) # will be overwritten assert cache.get_item("a").value == [0] @@ -64,7 +69,7 @@ def test_cache_delete(cache_type): c1 = CacheItem("a", []) cache.insert(c1) - + cache.delete("a") assert cache.get_item("a") is None diff --git a/core/tests/routes/test_session.py b/core/tests/routes/test_session.py index 0d1790534..0f1e6da4d 100644 --- a/core/tests/routes/test_session.py +++ b/core/tests/routes/test_session.py @@ -11,7 +11,7 @@ # only valid for in_memory cache def test_no_sessions_at_startup(client): - + for username in ["admin", "user", "Alice"]: wm = client.app.state.ccat.cache.get_value(f"{username}_working_memory") assert wm is None @@ -128,7 +128,7 @@ def test_session_sync_between_protocols(client, cache_type): def test_session_sync_while_websocket_is_open(client): - + mex = {"text": "Oh dear!"} # keep open a websocket connection @@ -167,6 +167,44 @@ def test_session_sync_while_websocket_is_open(client): wm = client.app.state.ccat.cache.get_value("Alice_working_memory") assert len(wm.history) == 0 +@pytest.mark.parametrize("cache_type", ["in_memory", "file_system"]) +def test_session_sync_when_websocket_gets_closed_and_reopened(client, cache_type): + mex = {"text": "Oh dear!"} + + try: + os.environ["CCAT_CACHE_TYPE"] = cache_type + client.app.state.ccat.cache = CacheManager().cache + + # keep open a websocket connection + with client.websocket_connect("/ws/Alice") as websocket: + # send ws message + websocket.send_json(mex) + # get reply + res = websocket.receive_json() + + # checks + wm = client.app.state.ccat.cache.get_value("Alice_working_memory") + assert res["user_id"] == "Alice" + assert len(wm.history) == 2 + + # clear convo history via http while nw connection is open + res = client.delete("/memory/conversation_history", headers={"user_id": "Alice"}) + # checks + assert res.status_code == 200 + wm = client.app.state.ccat.cache.get_value("Alice_working_memory") + assert len(wm.history) == 0 + + time.sleep(0.5) + + # at connection closed, reopen a new connection and rerun checks + with client.websocket_connect("/ws/Alice") as websocket: + wm = client.app.state.ccat.cache.get_value("Alice_working_memory") + assert len(wm.history) == 0 + + finally: + del os.environ["CCAT_CACHE_TYPE"] + + # in_memory cache can store max 100 sessions def test_sessions_are_deleted_from_in_memory_cache(client): @@ -179,7 +217,7 @@ def test_sessions_are_deleted_from_in_memory_cache(client): assert len(cache.items) <= cache.max_items - + # TODO: how do we test that: # - streaming happens