Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions core/cat/routes/websocket/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ async def websocket_endpoint(
except WebSocketDisconnect:
log.info(f"WebSocket connection closed for user {cat.user_id}")
finally:

# cat's working memory in this scope has not been updated
#cat.load_working_memory_from_cache()
cat.load_working_memory_from_cache()

# Remove connection on disconnect
websocket_manager.remove_connection(cat.user_id)
11 changes: 5 additions & 6 deletions core/cat/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing import Dict, Tuple
from pydantic import BaseModel, ConfigDict

from rapidfuzz.fuzz import ratio
from rapidfuzz.distance import Levenshtein
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import PromptTemplate
Expand Down Expand Up @@ -146,7 +145,7 @@ def explicit_error_message(e):
def deprecation_warning(message: str, skip=3):
"""Log a deprecation warning with caller's information.
"skip" is the number of stack levels to go back to the caller info."""

caller = get_caller_info(skip, return_short=False)

# Format and log the warning message
Expand Down Expand Up @@ -179,7 +178,7 @@ def parse_json(json_string: str, pydantic_model: BaseModel = None) -> dict:

# parse
parsed = parser.parse(json_string[start_index:])

if pydantic_model:
return pydantic_model(**parsed)
return parsed
Expand Down Expand Up @@ -207,7 +206,7 @@ def match_prompt_variables(
prompt_template = \
prompt_template.replace("{" + m + "}", "")
log.debug(f"Placeholder '{m}' not found in prompt variables, removed")

return prompt_variables, prompt_template


Expand Down Expand Up @@ -249,7 +248,7 @@ def get_caller_info(skip=2, return_short=True, return_string=True):
start = 0 + skip
if len(stack) < start + 1:
return None

parentframe = stack[start][0]

# module and packagename.
Expand Down Expand Up @@ -341,7 +340,7 @@ class BaseModelDict(BaseModel):
def __getitem__(self, key):
# deprecate dictionary usage
deprecation_warning(
f'To get `{key}` use dot notation instead of dictionary keys, example:'
f'To get `{key}` use dot notation instead of dictionary keys, example:'
f'`obj.{key}` instead of `obj["{key}"]`'
)

Expand Down
33 changes: 19 additions & 14 deletions core/tests/cache/test_core_caches.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,23 @@ def create_cache(cache_type):
assert False



@pytest.mark.parametrize("cache_type", ["in_memory", "file_system"])
def test_cache_creation(cache_type):

cache = create_cache(cache_type)

if cache_type == "in_memory":
assert cache.items == {}
assert cache.max_items == 100
else:
assert cache.cache_dir == "/tmp_cache"
assert os.path.exists("/tmp_cache")
assert os.listdir("/tmp_cache") == []
try:
cache = create_cache(cache_type)

if cache_type == "in_memory":
assert cache.items == {}
assert cache.max_items == 100
else:
assert cache.cache_dir == "/tmp_cache"
assert os.path.exists("/tmp_cache")
assert os.listdir("/tmp_cache") == []
finally:
import shutil
if os.path.exists("/tmp_cache"):
shutil.rmtree("/tmp_cache")


@pytest.mark.parametrize("cache_type", ["in_memory", "file_system"])
Expand All @@ -36,13 +41,13 @@ def test_cache_get_insert(cache_type):
cache = create_cache(cache_type)

assert cache.get_item("a") is None
c1 = CacheItem("a", [])

c1 = CacheItem("a", [])
cache.insert(c1)
assert cache.get_item("a").value == []
assert cache.get_value("a") == []


c1.value = [0]
cache.insert(c1) # will be overwritten
assert cache.get_item("a").value == [0]
Expand All @@ -64,7 +69,7 @@ def test_cache_delete(cache_type):

c1 = CacheItem("a", [])
cache.insert(c1)

cache.delete("a")
assert cache.get_item("a") is None

Expand Down
44 changes: 41 additions & 3 deletions core/tests/routes/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

# only valid for in_memory cache
def test_no_sessions_at_startup(client):

for username in ["admin", "user", "Alice"]:
wm = client.app.state.ccat.cache.get_value(f"{username}_working_memory")
assert wm is None
Expand Down Expand Up @@ -128,7 +128,7 @@ def test_session_sync_between_protocols(client, cache_type):


def test_session_sync_while_websocket_is_open(client):

mex = {"text": "Oh dear!"}

# keep open a websocket connection
Expand Down Expand Up @@ -167,6 +167,44 @@ def test_session_sync_while_websocket_is_open(client):
wm = client.app.state.ccat.cache.get_value("Alice_working_memory")
assert len(wm.history) == 0

@pytest.mark.parametrize("cache_type", ["in_memory", "file_system"])
def test_session_sync_when_websocket_gets_closed_and_reopened(client, cache_type):
mex = {"text": "Oh dear!"}

try:
os.environ["CCAT_CACHE_TYPE"] = cache_type
client.app.state.ccat.cache = CacheManager().cache

# keep open a websocket connection
with client.websocket_connect("/ws/Alice") as websocket:
# send ws message
websocket.send_json(mex)
# get reply
res = websocket.receive_json()

# checks
wm = client.app.state.ccat.cache.get_value("Alice_working_memory")
assert res["user_id"] == "Alice"
assert len(wm.history) == 2

# clear convo history via http while nw connection is open
res = client.delete("/memory/conversation_history", headers={"user_id": "Alice"})
# checks
assert res.status_code == 200
wm = client.app.state.ccat.cache.get_value("Alice_working_memory")
assert len(wm.history) == 0

time.sleep(0.5)

# at connection closed, reopen a new connection and rerun checks
with client.websocket_connect("/ws/Alice") as websocket:
wm = client.app.state.ccat.cache.get_value("Alice_working_memory")
assert len(wm.history) == 0

finally:
del os.environ["CCAT_CACHE_TYPE"]



# in_memory cache can store max 100 sessions
def test_sessions_are_deleted_from_in_memory_cache(client):
Expand All @@ -179,7 +217,7 @@ def test_sessions_are_deleted_from_in_memory_cache(client):
assert len(cache.items) <= cache.max_items




# TODO: how do we test that:
# - streaming happens
Expand Down