diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 410f2807c..3fd8ec35b 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1,18 +1,10 @@ # Description -Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change. - -Related to issue #(issue) - -## Type of change - -- [ ] Bug fix (non-breaking change which fixes an issue) -- [ ] New feature (non-breaking change which adds functionality) -- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) -- [ ] This change requires a documentation update +Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. # Checklist: -- [ ] My code follows the style guidelines of this project +- [ ] I submitted my PR to branch `develop` - [ ] I have performed a self-review of my own code - [ ] I have commented my code, particularly in hard-to-understand areas +- [ ] I have run or added tests to cover my contribution diff --git a/core/cat/env.py b/core/cat/env.py index 24037e788..75c7e00c1 100644 --- a/core/cat/env.py +++ b/core/cat/env.py @@ -24,6 +24,7 @@ def get_supported_env_variables(): "CCAT_CORS_ENABLED": "true", "CCAT_CACHE_TYPE": "in_memory", "CCAT_CACHE_DIR": "/tmp", + "CCAT_QDRANT_CLIENT_TIMEOUT": None, } diff --git a/core/cat/looking_glass/cheshire_cat.py b/core/cat/looking_glass/cheshire_cat.py index c5f030b49..67caece47 100644 --- a/core/cat/looking_glass/cheshire_cat.py +++ b/core/cat/looking_glass/cheshire_cat.py @@ -7,7 +7,7 @@ from langchain_core.runnables import RunnableLambda from langchain_core.prompts import ChatPromptTemplate from langchain_core.output_parsers.string import StrOutputParser -from langchain_community.llms import Cohere +from langchain_cohere import ChatCohere from langchain_openai import ChatOpenAI, OpenAI from langchain_google_genai import ChatGoogleGenerativeAI @@ -208,7 +208,7 @@ def load_language_embedder(self) -> embedders.EmbedderSettings: # For Azure avoid automatic embedder selection # Cohere - elif type(self._llm) in [Cohere]: + elif type(self._llm) in [ChatCohere]: embedder = embedders.EmbedderCohereConfig.get_embedder_from_config( { "cohere_api_key": self._llm.cohere_api_key, diff --git a/core/cat/looking_glass/stray_cat.py b/core/cat/looking_glass/stray_cat.py index dc0f04adf..b71093fd9 100644 --- a/core/cat/looking_glass/stray_cat.py +++ b/core/cat/looking_glass/stray_cat.py @@ -625,7 +625,7 @@ def classify( Allowed classes are: {labels_list}{examples_list} -"{sentence}" -> """ +Just output the class, nothing else.""" response = self.llm(prompt) diff --git a/core/cat/memory/vector_memory.py b/core/cat/memory/vector_memory.py index 18b3c703a..aef685fd3 100644 --- a/core/cat/memory/vector_memory.py +++ b/core/cat/memory/vector_memory.py @@ -65,6 +65,9 @@ def connect_to_vector_memory(self) -> None: qdrant_https = is_https(qdrant_host) qdrant_host = extract_domain_from_url(qdrant_host) qdrant_api_key = get_env("CCAT_QDRANT_API_KEY") + + qdrant_client_timeout = get_env("CCAT_QDRANT_CLIENT_TIMEOUT") + qdrant_client_timeout = int(qdrant_client_timeout) if qdrant_client_timeout is not None else None try: s = socket.socket() @@ -81,6 +84,7 @@ def connect_to_vector_memory(self) -> None: port=qdrant_port, https=qdrant_https, api_key=qdrant_api_key, + timeout=qdrant_client_timeout ) def delete_collection(self, collection_name: str): diff --git a/core/cat/memory/vector_memory_collection.py b/core/cat/memory/vector_memory_collection.py index 92b77afeb..e7cb5b356 100644 --- a/core/cat/memory/vector_memory_collection.py +++ b/core/cat/memory/vector_memory_collection.py @@ -59,11 +59,11 @@ def check_embedding_size(self): == self.embedder_size ) alias = self.embedder_name + "_" + self.collection_name - if ( - alias - == self.client.get_collection_aliases(self.collection_name) - .aliases[0] - .alias_name + + existing_aliases = self.client.get_collection_aliases(self.collection_name).aliases + + if ( len(existing_aliases) > 0 and + alias == existing_aliases[0].alias_name and same_size ): log.debug(f'Collection "{self.collection_name}" has the same embedder') @@ -94,31 +94,48 @@ def create_db_collection_if_not_exists(self): # create collection def create_collection(self): - log.warning(f'Creating collection "{self.collection_name}" ...') - self.client.create_collection( - collection_name=self.collection_name, - vectors_config=VectorParams( - size=self.embedder_size, distance=Distance.COSINE - ), - # hybrid mode: original vector on Disk, quantized vector in RAM - optimizers_config=OptimizersConfigDiff(memmap_threshold=20000), - quantization_config=ScalarQuantization( - scalar=ScalarQuantizationConfig( - type=ScalarType.INT8, quantile=0.95, always_ram=True - ) - ), - ) - - self.client.update_collection_aliases( - change_aliases_operations=[ - CreateAliasOperation( - create_alias=CreateAlias( - collection_name=self.collection_name, - alias_name=self.embedder_name + "_" + self.collection_name, + try: + log.warning(f'Creating collection "{self.collection_name}" ...') + self.client.create_collection( + collection_name=self.collection_name, + vectors_config=VectorParams( + size=self.embedder_size, distance=Distance.COSINE + ), + # hybrid mode: original vector on Disk, quantized vector in RAM + optimizers_config=OptimizersConfigDiff(memmap_threshold=20000), + quantization_config=ScalarQuantization( + scalar=ScalarQuantizationConfig( + type=ScalarType.INT8, quantile=0.95, always_ram=True ) - ) - ] - ) + ), + ) + except Exception as e: + log.error(f"Error creating collection {self.collection_name}. Try setting a higher timeout value in CCAT_QDRANT_CLIENT_TIMEOUT: {e}") + self.client.delete_collection(self.collection_name) + raise + + try: + alias_name=self.embedder_name + "_" + self.collection_name + log.warning(f'Creating alias {alias_name} for collection "{self.collection_name}" ...') + + self.client.update_collection_aliases( + change_aliases_operations=[ + CreateAliasOperation( + create_alias=CreateAlias( + collection_name=self.collection_name, + alias_name=alias_name, + ) + ) + ] + ) + + log.warning(f'Created alias {alias_name} for collection "{self.collection_name}" ...') + except Exception as e: + log.error(f"Error creating collection alias {alias_name} for collection {self.collection_name}: {e}") + self.client.delete_collection(self.collection_name) + log.error(f" collection {self.collection_name} deleted") + raise + # adapted from https://github.com/langchain-ai/langchain/blob/bfc12a4a7644cfc4d832cc4023086a7a5374f46a/libs/langchain/langchain/vectorstores/qdrant.py#L1965 def _qdrant_filter_from_dict(self, filter: dict) -> Filter: diff --git a/core/cat/routes/plugins.py b/core/cat/routes/plugins.py index 8e189d145..bd63c6dc0 100644 --- a/core/cat/routes/plugins.py +++ b/core/cat/routes/plugins.py @@ -28,7 +28,18 @@ async def get_available_plugins( # index registry plugins by url registry_plugins_index = {} for p in registry_plugins: - plugin_url = p["url"] + plugin_url = p.get("plugin_url", None) + if plugin_url is None: + log.warning(f"Plugin {p.get('name')} has no `plugin_url`. It will be skipped from the registry list.") + continue + # url = p.get("url", None) + # if url and url != plugin_url: + # log.info(f"Plugin {p.get('name')} has `url` {url} different from `plugin_url` {plugin_url}. please check the plugin.") + if plugin_url in registry_plugins_index: + current = registry_plugins_index[plugin_url] + log.warning(f"duplicate plugin_url {plugin_url} found in registry. Plugins {p.get('name')} has same url than {current.get('name')}. Skipping.") + continue + registry_plugins_index[plugin_url] = p # get active plugins @@ -53,19 +64,20 @@ async def get_available_plugins( manifest["endpoints"] = [{"name": endpoint.name, "tags": endpoint.tags} for endpoint in p.endpoints] manifest["forms"] = [{"name": form.name} for form in p.forms] + # do not show already installed plugins among registry plugins + r = registry_plugins_index.pop(manifest["plugin_url"], None) + # filter by query plugin_text = [str(field) for field in manifest.values()] plugin_text = " ".join(plugin_text).lower() + if (query is None) or (query.lower() in plugin_text): - for r in registry_plugins: - if r["plugin_url"] == p.manifest["plugin_url"]: - if r["version"] != p.manifest["version"]: - manifest["upgrade"] = r["version"] + if r is not None: + r_version = r.get("version", None) + if r_version is not None and r_version != p.manifest.get("version"): + manifest["upgrade"] = r["version"] installed_plugins.append(manifest) - # do not show already installed plugins among registry plugins - registry_plugins_index.pop(manifest["plugin_url"], None) - return { "filters": { "query": query, @@ -298,4 +310,4 @@ async def delete_plugin( # remove folder, hooks and tools ccat.mad_hatter.uninstall_plugin(plugin_id) - return {"deleted": plugin_id} + return {"deleted": plugin_id} \ No newline at end of file diff --git a/core/cat/routes/websocket/websocket.py b/core/cat/routes/websocket/websocket.py index 576852c76..298c218cf 100644 --- a/core/cat/routes/websocket/websocket.py +++ b/core/cat/routes/websocket/websocket.py @@ -43,9 +43,9 @@ async def websocket_endpoint( except WebSocketDisconnect: log.info(f"WebSocket connection closed for user {cat.user_id}") finally: - + # cat's working memory in this scope has not been updated - #cat.load_working_memory_from_cache() - + cat.load_working_memory_from_cache() + # Remove connection on disconnect websocket_manager.remove_connection(cat.user_id) \ No newline at end of file diff --git a/core/cat/utils.py b/core/cat/utils.py index 8609a4013..a096a8457 100644 --- a/core/cat/utils.py +++ b/core/cat/utils.py @@ -7,7 +7,7 @@ from typing import Dict, Tuple from pydantic import BaseModel, ConfigDict -from langchain.evaluation import StringDistance, load_evaluator, EvaluatorType +from rapidfuzz.distance import Levenshtein from langchain_core.output_parsers import JsonOutputParser from langchain_core.prompts import PromptTemplate from langchain_core.utils import get_colored_text @@ -145,7 +145,7 @@ def explicit_error_message(e): def deprecation_warning(message: str, skip=3): """Log a deprecation warning with caller's information. "skip" is the number of stack levels to go back to the caller info.""" - + caller = get_caller_info(skip, return_short=False) # Format and log the warning message @@ -155,15 +155,8 @@ def deprecation_warning(message: str, skip=3): def levenshtein_distance(prediction: str, reference: str) -> int: - jaro_evaluator = load_evaluator( - EvaluatorType.STRING_DISTANCE, distance=StringDistance.LEVENSHTEIN - ) - result = jaro_evaluator.evaluate_strings( - prediction=prediction, - reference=reference, - ) - return result["score"] - + res = Levenshtein.normalized_distance(prediction, reference) + return res def parse_json(json_string: str, pydantic_model: BaseModel = None) -> dict: # instantiate parser @@ -171,8 +164,8 @@ def parse_json(json_string: str, pydantic_model: BaseModel = None) -> dict: # clean to help small LLMs replaces = { - "\_": "_", - "\-": "-", + "\\_": "_", + "\\-": "-", "None": "null", "{{": "{", "}}": "}", @@ -185,7 +178,7 @@ def parse_json(json_string: str, pydantic_model: BaseModel = None) -> dict: # parse parsed = parser.parse(json_string[start_index:]) - + if pydantic_model: return pydantic_model(**parsed) return parsed @@ -213,7 +206,7 @@ def match_prompt_variables( prompt_template = \ prompt_template.replace("{" + m + "}", "") log.debug(f"Placeholder '{m}' not found in prompt variables, removed") - + return prompt_variables, prompt_template @@ -255,7 +248,7 @@ def get_caller_info(skip=2, return_short=True, return_string=True): start = 0 + skip if len(stack) < start + 1: return None - + parentframe = stack[start][0] # module and packagename. @@ -347,7 +340,7 @@ class BaseModelDict(BaseModel): def __getitem__(self, key): # deprecate dictionary usage deprecation_warning( - f'To get `{key}` use dot notation instead of dictionary keys, example:' + f'To get `{key}` use dot notation instead of dictionary keys, example:' f'`obj.{key}` instead of `obj["{key}"]`' ) diff --git a/core/pyproject.toml b/core/pyproject.toml index 093133ff3..bdfef7e73 100644 --- a/core/pyproject.toml +++ b/core/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "Cheshire-Cat" description = "Production ready AI assistant framework" -version = "1.9.1" +version = "1.9.2" requires-python = ">=3.10" license = { file = "LICENSE" } authors = [ diff --git a/core/tests/cache/test_core_caches.py b/core/tests/cache/test_core_caches.py index a80c6bc7c..a77361228 100644 --- a/core/tests/cache/test_core_caches.py +++ b/core/tests/cache/test_core_caches.py @@ -16,18 +16,23 @@ def create_cache(cache_type): assert False + @pytest.mark.parametrize("cache_type", ["in_memory", "file_system"]) def test_cache_creation(cache_type): - - cache = create_cache(cache_type) - - if cache_type == "in_memory": - assert cache.items == {} - assert cache.max_items == 100 - else: - assert cache.cache_dir == "/tmp_cache" - assert os.path.exists("/tmp_cache") - assert os.listdir("/tmp_cache") == [] + try: + cache = create_cache(cache_type) + + if cache_type == "in_memory": + assert cache.items == {} + assert cache.max_items == 100 + else: + assert cache.cache_dir == "/tmp_cache" + assert os.path.exists("/tmp_cache") + assert os.listdir("/tmp_cache") == [] + finally: + import shutil + if os.path.exists("/tmp_cache"): + shutil.rmtree("/tmp_cache") @pytest.mark.parametrize("cache_type", ["in_memory", "file_system"]) @@ -36,13 +41,13 @@ def test_cache_get_insert(cache_type): cache = create_cache(cache_type) assert cache.get_item("a") is None - - c1 = CacheItem("a", []) + + c1 = CacheItem("a", []) cache.insert(c1) assert cache.get_item("a").value == [] assert cache.get_value("a") == [] - + c1.value = [0] cache.insert(c1) # will be overwritten assert cache.get_item("a").value == [0] @@ -64,7 +69,7 @@ def test_cache_delete(cache_type): c1 = CacheItem("a", []) cache.insert(c1) - + cache.delete("a") assert cache.get_item("a") is None diff --git a/core/tests/routes/test_session.py b/core/tests/routes/test_session.py index 0d1790534..0f1e6da4d 100644 --- a/core/tests/routes/test_session.py +++ b/core/tests/routes/test_session.py @@ -11,7 +11,7 @@ # only valid for in_memory cache def test_no_sessions_at_startup(client): - + for username in ["admin", "user", "Alice"]: wm = client.app.state.ccat.cache.get_value(f"{username}_working_memory") assert wm is None @@ -128,7 +128,7 @@ def test_session_sync_between_protocols(client, cache_type): def test_session_sync_while_websocket_is_open(client): - + mex = {"text": "Oh dear!"} # keep open a websocket connection @@ -167,6 +167,44 @@ def test_session_sync_while_websocket_is_open(client): wm = client.app.state.ccat.cache.get_value("Alice_working_memory") assert len(wm.history) == 0 +@pytest.mark.parametrize("cache_type", ["in_memory", "file_system"]) +def test_session_sync_when_websocket_gets_closed_and_reopened(client, cache_type): + mex = {"text": "Oh dear!"} + + try: + os.environ["CCAT_CACHE_TYPE"] = cache_type + client.app.state.ccat.cache = CacheManager().cache + + # keep open a websocket connection + with client.websocket_connect("/ws/Alice") as websocket: + # send ws message + websocket.send_json(mex) + # get reply + res = websocket.receive_json() + + # checks + wm = client.app.state.ccat.cache.get_value("Alice_working_memory") + assert res["user_id"] == "Alice" + assert len(wm.history) == 2 + + # clear convo history via http while nw connection is open + res = client.delete("/memory/conversation_history", headers={"user_id": "Alice"}) + # checks + assert res.status_code == 200 + wm = client.app.state.ccat.cache.get_value("Alice_working_memory") + assert len(wm.history) == 0 + + time.sleep(0.5) + + # at connection closed, reopen a new connection and rerun checks + with client.websocket_connect("/ws/Alice") as websocket: + wm = client.app.state.ccat.cache.get_value("Alice_working_memory") + assert len(wm.history) == 0 + + finally: + del os.environ["CCAT_CACHE_TYPE"] + + # in_memory cache can store max 100 sessions def test_sessions_are_deleted_from_in_memory_cache(client): @@ -179,7 +217,7 @@ def test_sessions_are_deleted_from_in_memory_cache(client): assert len(cache.items) <= cache.max_items - + # TODO: how do we test that: # - streaming happens