diff --git a/ai_ta_backend/beam/ingest.py b/ai_ta_backend/beam/ingest.py index 16fb9e2b..a68fb7d0 100644 --- a/ai_ta_backend/beam/ingest.py +++ b/ai_ta_backend/beam/ingest.py @@ -1205,7 +1205,7 @@ def split_and_upload(self, texts: List[str], metadatas: List[Dict[str, Any]], ** context.metadata['doc_groups'] = kwargs.get('groups', []) openai_embeddings_key = os.getenv('VLADS_OPENAI_KEY') - if metadatas[0].get('course_name') == 'cropwizard-1.5': + if metadatas[0].get('course_name') in ['cropwizard-1.5', 'ag-research']: print("Using Cropwizard OpenAI key") openai_embeddings_key = os.getenv('CROPWIZARD_OPENAI_KEY') @@ -1247,6 +1247,12 @@ def split_and_upload(self, texts: List[str], metadatas: List[Dict[str, Any]], ** collection_name='cropwizard', points=vectors, ) + elif metadatas[0].get('course_name') == 'ag-research': + print("Uploading to ag-research collection...") + self.cropwizard_qdrant_client.upsert( + collection_name='ag-research', + points=vectors, + ) else: self.qdrant_client.upsert( collection_name=os.environ['QDRANT_COLLECTION_NAME'], @@ -1465,6 +1471,17 @@ def delete_data(self, course_name: str, s3_path: str, source_url: str): ), ]), ) + elif course_name == 'ag-research': + print("Deleting from ag-research collection...") + self.cropwizard_qdrant_client.delete( + collection_name='ag-research', + points_selector=models.Filter(must=[ + models.FieldCondition( + key="s3_path", + match=models.MatchValue(value=s3_path), + ), + ]), + ) else: self.qdrant_client.delete( collection_name=os.environ['QDRANT_COLLECTION_NAME'], diff --git a/ai_ta_backend/database/vector.py b/ai_ta_backend/database/vector.py index ef7fda33..c96f32e2 100644 --- a/ai_ta_backend/database/vector.py +++ b/ai_ta_backend/database/vector.py @@ -69,14 +69,24 @@ def cropwizard_vector_search(self, search_query, course_name, doc_groups: List[s Search the vector database for a given query. """ top_n = 120 - - search_results = self.cropwizard_qdrant_client.search( - collection_name='cropwizard', - query_filter=self._create_search_filter(course_name, doc_groups, disabled_doc_groups, public_doc_groups), - with_vectors=False, - query_vector=user_query_embedding, - limit=top_n, # Return n closest points - ) + if course_name == 'cropwizard-1.5': + search_results = self.cropwizard_qdrant_client.search( + collection_name='cropwizard', + query_filter=self._create_search_filter(course_name, doc_groups, disabled_doc_groups, public_doc_groups), + with_vectors=False, + query_vector=user_query_embedding, + limit=top_n, # Return n closest points + ) + elif course_name == 'ag-research': + search_results = self.cropwizard_qdrant_client.search( + collection_name='ag-research', + query_filter=self._create_search_filter(course_name, doc_groups, disabled_doc_groups, public_doc_groups), + with_vectors=False, + query_vector=user_query_embedding, + limit=top_n, # Return n closest points + ) + else: + raise ValueError(f"Invalid course name: {course_name}") return search_results @@ -246,12 +256,12 @@ def delete_data(self, collection_name: str, key: str, value: str): ]), ) - def delete_data_cropwizard(self, key: str, value: str): + def delete_data_cropwizard(self, collection_name: str, key: str, value: str): """ Delete data from the vector database. """ return self.cropwizard_qdrant_client.delete( - collection_name='cropwizard', + collection_name=collection_name, wait=True, points_selector=models.Filter(must=[ models.FieldCondition( diff --git a/ai_ta_backend/service/retrieval_service.py b/ai_ta_backend/service/retrieval_service.py index c6f2ad07..70caff97 100644 --- a/ai_ta_backend/service/retrieval_service.py +++ b/ai_ta_backend/service/retrieval_service.py @@ -443,7 +443,9 @@ def delete_from_qdrant(self, identifier_key: str, identifier_value: str, course_ print("Deleting from Qdrant") if course_name == 'cropwizard-1.5': # delete from cw db - response = self.vdb.delete_data_cropwizard(identifier_key, identifier_value) + response = self.vdb.delete_data_cropwizard('cropwizard', identifier_key, identifier_value) + elif course_name == 'ag-research': + response = self.vdb.delete_data_cropwizard('ag-research', identifier_key, identifier_value) else: response = self.vdb.delete_data(os.environ['QDRANT_COLLECTION_NAME'], identifier_key, identifier_value) print(f"Qdrant response: {response}") @@ -610,6 +612,9 @@ def vector_search(self, elif course_name == "cropwizard": search_results = self.vdb.cropwizard_vector_search(search_query, course_name, doc_groups, user_query_embedding, top_n, disabled_doc_groups, public_doc_groups) + elif course_name == "ag-research": + search_results = self.vdb.cropwizard_vector_search(search_query, course_name, doc_groups, user_query_embedding, + top_n, disabled_doc_groups, public_doc_groups) elif course_name == "pubmed": search_results = self.vdb.pubmed_vector_search(search_query, course_name, doc_groups, user_query_embedding, top_n, disabled_doc_groups, public_doc_groups)