Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion ai_ta_backend/beam/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,7 +1205,7 @@ def split_and_upload(self, texts: List[str], metadatas: List[Dict[str, Any]], **
context.metadata['doc_groups'] = kwargs.get('groups', [])

openai_embeddings_key = os.getenv('VLADS_OPENAI_KEY')
if metadatas[0].get('course_name') == 'cropwizard-1.5':
if metadatas[0].get('course_name') in ['cropwizard-1.5', 'ag-research']:
print("Using Cropwizard OpenAI key")
openai_embeddings_key = os.getenv('CROPWIZARD_OPENAI_KEY')

Expand Down Expand Up @@ -1247,6 +1247,12 @@ def split_and_upload(self, texts: List[str], metadatas: List[Dict[str, Any]], **
collection_name='cropwizard',
points=vectors,
)
elif metadatas[0].get('course_name') == 'ag-research':
print("Uploading to ag-research collection...")
self.cropwizard_qdrant_client.upsert(
collection_name='ag-research',
points=vectors,
)
else:
self.qdrant_client.upsert(
collection_name=os.environ['QDRANT_COLLECTION_NAME'],
Expand Down Expand Up @@ -1465,6 +1471,17 @@ def delete_data(self, course_name: str, s3_path: str, source_url: str):
),
]),
)
elif course_name == 'ag-research':
print("Deleting from ag-research collection...")
self.cropwizard_qdrant_client.delete(
collection_name='ag-research',
points_selector=models.Filter(must=[
models.FieldCondition(
key="s3_path",
match=models.MatchValue(value=s3_path),
),
]),
)
else:
self.qdrant_client.delete(
collection_name=os.environ['QDRANT_COLLECTION_NAME'],
Expand Down
30 changes: 20 additions & 10 deletions ai_ta_backend/database/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,24 @@ def cropwizard_vector_search(self, search_query, course_name, doc_groups: List[s
Search the vector database for a given query.
"""
top_n = 120

search_results = self.cropwizard_qdrant_client.search(
collection_name='cropwizard',
query_filter=self._create_search_filter(course_name, doc_groups, disabled_doc_groups, public_doc_groups),
with_vectors=False,
query_vector=user_query_embedding,
limit=top_n, # Return n closest points
)
if course_name == 'cropwizard-1.5':
search_results = self.cropwizard_qdrant_client.search(
collection_name='cropwizard',
query_filter=self._create_search_filter(course_name, doc_groups, disabled_doc_groups, public_doc_groups),
with_vectors=False,
query_vector=user_query_embedding,
limit=top_n, # Return n closest points
)
elif course_name == 'ag-research':
search_results = self.cropwizard_qdrant_client.search(
collection_name='ag-research',
query_filter=self._create_search_filter(course_name, doc_groups, disabled_doc_groups, public_doc_groups),
with_vectors=False,
query_vector=user_query_embedding,
limit=top_n, # Return n closest points
)
else:
raise ValueError(f"Invalid course name: {course_name}")

return search_results

Expand Down Expand Up @@ -246,12 +256,12 @@ def delete_data(self, collection_name: str, key: str, value: str):
]),
)

def delete_data_cropwizard(self, key: str, value: str):
def delete_data_cropwizard(self, collection_name: str, key: str, value: str):
"""
Delete data from the vector database.
"""
return self.cropwizard_qdrant_client.delete(
collection_name='cropwizard',
collection_name=collection_name,
wait=True,
points_selector=models.Filter(must=[
models.FieldCondition(
Expand Down
7 changes: 6 additions & 1 deletion ai_ta_backend/service/retrieval_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,9 @@ def delete_from_qdrant(self, identifier_key: str, identifier_value: str, course_
print("Deleting from Qdrant")
if course_name == 'cropwizard-1.5':
# delete from cw db
response = self.vdb.delete_data_cropwizard(identifier_key, identifier_value)
response = self.vdb.delete_data_cropwizard('cropwizard', identifier_key, identifier_value)
elif course_name == 'ag-research':
response = self.vdb.delete_data_cropwizard('ag-research', identifier_key, identifier_value)
else:
response = self.vdb.delete_data(os.environ['QDRANT_COLLECTION_NAME'], identifier_key, identifier_value)
print(f"Qdrant response: {response}")
Expand Down Expand Up @@ -610,6 +612,9 @@ def vector_search(self,
elif course_name == "cropwizard":
search_results = self.vdb.cropwizard_vector_search(search_query, course_name, doc_groups, user_query_embedding,
top_n, disabled_doc_groups, public_doc_groups)
elif course_name == "ag-research":
search_results = self.vdb.cropwizard_vector_search(search_query, course_name, doc_groups, user_query_embedding,
top_n, disabled_doc_groups, public_doc_groups)
elif course_name == "pubmed":
search_results = self.vdb.pubmed_vector_search(search_query, course_name, doc_groups, user_query_embedding, top_n,
disabled_doc_groups, public_doc_groups)
Expand Down