Skip to content

Commit a49c037

Browse files
committed
feat(Vector Stores): Move the indexer and the searcher to use pluggable stores.
1 parent b605134 commit a49c037

File tree

5 files changed

+56
-185
lines changed

5 files changed

+56
-185
lines changed

wiki_rag/index/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,3 @@
22
# SPDX-License-Identifier: BSD-3-Clause
33

44
"""wiki_rag.index package."""
5-
6-
milvus_url: str = "" # Default Milvus URL, to be shared across the package.

wiki_rag/index/main.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from dotenv import load_dotenv
1313

14-
import wiki_rag.index as index
14+
import wiki_rag.vector as vector
1515

1616
from wiki_rag import LOG_LEVEL, ROOT_DIR, __version__
1717
from wiki_rag.index.util import (
@@ -21,6 +21,7 @@
2121
replace_previous_collection,
2222
)
2323
from wiki_rag.util import setup_logging
24+
from wiki_rag.vector import load_vector_store
2425

2526

2627
def main():
@@ -70,10 +71,10 @@ def main():
7071
logger.error("Collection name not found in environment. Exiting.")
7172
sys.exit(1)
7273

73-
index.milvus_url = os.getenv("MILVUS_URL")
74-
if not index.milvus_url:
75-
logger.error("Milvus URL not found in environment. Exiting.")
76-
sys.exit(1)
74+
index_vendor = os.getenv("INDEX_VENDOR")
75+
if not index_vendor:
76+
logger.warning("Index vendor (INDEX_VENDOR) not found in environment. Defaulting to 'milvus'.")
77+
index_vendor = "milvus"
7778

7879
user_agent = os.getenv("USER_AGENT")
7980
if not user_agent:
@@ -92,6 +93,8 @@ def main():
9293
sys.exit(1)
9394
embedding_dimensions = int(embedding_dimensions)
9495

96+
vector.store = load_vector_store(index_vendor) # Set up the global wiki_rag.vector.store to be used elsewhere.
97+
9598
input_candidate = ""
9699
# TODO: Implement CLI argument to accept the input file here.
97100

wiki_rag/index/util.py

Lines changed: 28 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Copyright (c) 2025, Moodle HQ - Research
22
# SPDX-License-Identifier: BSD-3-Clause
33

4-
"""Util functions to proceed to index the information to Milvus collection."""
4+
"""Util functions to proceed to index to some collection in a vector store / index."""
55

66
import json
77
import logging
@@ -11,17 +11,9 @@
1111

1212
from jsonschema import ValidationError, validate
1313
from langchain_openai import OpenAIEmbeddings
14-
from pymilvus import (
15-
CollectionSchema,
16-
DataType,
17-
FieldSchema,
18-
Function,
19-
FunctionType,
20-
MilvusClient,
21-
)
2214
from tqdm import tqdm
2315

24-
import wiki_rag.index as index
16+
import wiki_rag.vector as vector
2517

2618
from wiki_rag import ROOT_DIR
2719

@@ -72,52 +64,7 @@ def load_parsed_information(input_file: Path) -> dict:
7264

7365
def create_temp_collection_schema(collection_name: str, embedding_dimension: int) -> None:
7466
"""Create a temporary schema for the collection."""
75-
milvus = MilvusClient(index.milvus_url)
76-
if milvus.has_collection(collection_name):
77-
milvus.drop_collection(collection_name)
78-
79-
fields = [
80-
FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100),
81-
FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=1000),
82-
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=5000, enable_analyzer=True,
83-
analyzer_params={"type": "english"}, enable_match=True, ),
84-
FieldSchema(name="source", dtype=DataType.VARCHAR, max_length=1000),
85-
FieldSchema(name="dense_vector", dtype=DataType.FLOAT_VECTOR, dim=embedding_dimension),
86-
FieldSchema(name="sparse_vector", dtype=DataType.SPARSE_FLOAT_VECTOR),
87-
FieldSchema(name="parent", dtype=DataType.VARCHAR, max_length=100, nullable=True),
88-
FieldSchema(name="children", dtype=DataType.ARRAY, element_type=DataType.VARCHAR, max_length=4000,
89-
max_capacity=100, is_array=True),
90-
FieldSchema(name="previous", dtype=DataType.ARRAY, element_type=DataType.VARCHAR, max_length=4000,
91-
max_capacity=100, is_array=True),
92-
FieldSchema(name="next", dtype=DataType.ARRAY, element_type=DataType.VARCHAR, max_length=4000,
93-
max_capacity=100, is_array=True),
94-
FieldSchema(name="relations", dtype=DataType.ARRAY, element_type=DataType.VARCHAR, max_length=4000,
95-
max_capacity=100, is_array=True),
96-
FieldSchema(name="page_id", dtype=DataType.INT32),
97-
FieldSchema(name="doc_id", dtype=DataType.VARCHAR, max_length=100),
98-
FieldSchema(name="doc_title", dtype=DataType.VARCHAR, max_length=1000),
99-
FieldSchema(name="doc_hash", dtype=DataType.VARCHAR, max_length=100),
100-
]
101-
schema = CollectionSchema(fields)
102-
103-
bm25_function = Function(
104-
name="text_bm25_emb",
105-
input_field_names=["text"], # Input text field
106-
output_field_names=["sparse_vector"], # Internal mapping sparse vector field
107-
function_type=FunctionType.BM25, # Model for processing mapping relationship
108-
)
109-
110-
schema.add_function(bm25_function)
111-
112-
index_params = milvus.prepare_index_params()
113-
index_params.add_index(field_name="dense_vector", index_type="HNSW", metric_type="IP",
114-
params={"M": 64, "efConstruction": 100})
115-
index_params.add_index(field_name="sparse_vector", index_type="SPARSE_INVERTED_INDEX", metric_type="BM25",
116-
params={"inverted_index_algo": "DAAT_WAND", "drop_ratio_build": 0.2})
117-
118-
milvus.create_collection(collection_name, schema=schema, index_params=index_params)
119-
120-
milvus.close()
67+
vector.store.create_collection(collection_name, embedding_dimension)
12168

12269

12370
def index_pages(
@@ -127,9 +74,7 @@ def index_pages(
12774
embedding_dimension: int
12875
) -> list[int]:
12976
"""Index the pages to the collection."""
130-
milvus = MilvusClient(index.milvus_url)
131-
132-
logging.getLogger("httpx").setLevel(logging.WARNING)
77+
logging.getLogger("httpx").setLevel(logging.WARNING) # Don't log (INFO) all http requests.
13378

13479
embeddings = OpenAIEmbeddings(model=embedding_model, dimensions=embedding_dimension)
13580

@@ -142,62 +87,55 @@ def index_pages(
14287
text_preamble = section["doc_title"]
14388
if section["title"] != section["doc_title"]:
14489
text_preamble = text_preamble + f" / {section['title']}"
145-
text_preamble = text_preamble.strip() + "\n\n"
90+
text_preamble = text_preamble.strip()
14691

14792
# Calculate the complete text (preamble + text, if existing).
14893
text_content = section["text"] if section["text"] else ""
14994
if len(text_content) > 5000:
15095
# TODO: We need to split the text in smaller chunks here, say 2500 max or so. For now, just trim.
15196
text_content = text_content[:5000].strip()
15297
logger.warning(f'Text too long for section "{text_preamble}", trimmed to 5000 characters.')
153-
complete_text = text_preamble + text_content
98+
complete_text = text_preamble + "\n\n" + text_content
15499
logger.debug(f"Embedding {text_preamble}, text len {len(text_content)}")
155100

156101
dense_embedding = embeddings.embed_documents([complete_text])
157102
logger.debug(f"Embedding for {text_preamble}, dim len {len(dense_embedding[0])}")
158-
data = [
159-
{
160-
"id": str(section["id"]),
161-
"title": section["title"],
162-
"text": text_content,
163-
"source": section["source"],
164-
"dense_vector": dense_embedding[0],
165-
"parent": str(section["parent"]) if section["parent"] else None,
166-
"children": [str(child) for child in section["children"]],
167-
"previous": [str(prv) for prv in section["previous"]],
168-
"next": [str(nxt) for nxt in section["next"]],
169-
"relations": [str(rel) for rel in section["relations"]],
170-
"page_id": int(section["page_id"]),
171-
"doc_id": str(section["doc_id"]),
172-
"doc_title": section["doc_title"],
173-
"doc_hash": str(section["doc_hash"]),
174-
}
175-
]
103+
record = {
104+
"id": str(section["id"]),
105+
"title": section["title"],
106+
"text": text_content,
107+
"source": section["source"],
108+
"dense_vector": dense_embedding[0],
109+
"parent": str(section["parent"]) if section["parent"] else None,
110+
"children": [str(child) for child in section["children"]],
111+
"previous": [str(prv) for prv in section["previous"]],
112+
"next": [str(nxt) for nxt in section["next"]],
113+
"relations": [str(rel) for rel in section["relations"]],
114+
"page_id": int(section["page_id"]),
115+
"doc_id": str(section["doc_id"]),
116+
"doc_title": section["doc_title"],
117+
"doc_hash": str(section["doc_hash"]),
118+
}
176119
try:
177-
milvus.insert(collection_name, data)
120+
vector.store.insert_batch(collection_name, [record])
178121
num_sections += 1
179122
except Exception as e:
180123
logger.error(f"Failed to insert data: {e}")
181124
num_pages += 1
182125

183-
milvus.close()
184126
return [num_pages, num_sections]
185127

186128

187129
def replace_previous_collection(collection_name: str, temp_collection_name: str) -> None:
188130
"""Replace the previous collection with the new one."""
189-
milvus = MilvusClient(index.milvus_url)
190-
191-
if not milvus.has_collection(temp_collection_name):
131+
if not vector.store.collection_exists(temp_collection_name):
192132
msg = f"Collection {temp_collection_name} does not exist."
193133
raise ValueError(msg)
194134

195-
if milvus.has_collection(collection_name):
196-
milvus.drop_collection(collection_name)
197-
milvus.rename_collection(temp_collection_name, collection_name)
135+
if vector.store.collection_exists(collection_name):
136+
vector.store.drop_collection(collection_name)
137+
vector.store.rename_collection(temp_collection_name, collection_name)
198138

199139
# We have inserted lots of data to the collection, let's compact it.
200140
logger.info(f"Compacting collection {collection_name}")
201-
milvus.compact(collection_name)
202-
203-
milvus.close()
141+
vector.store.compact_collection(collection_name)

wiki_rag/search/main.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717
from langchain_core.messages import AIMessageChunk
1818
from langfuse.langchain import CallbackHandler
1919

20-
import wiki_rag.index as index
20+
import wiki_rag.vector as vector
2121

2222
from wiki_rag import LOG_LEVEL, ROOT_DIR, __version__
2323
from wiki_rag.search.util import ContextSchema, build_graph
2424
from wiki_rag.util import setup_logging
25+
from wiki_rag.vector import load_vector_store
2526

2627

2728
async def run():
@@ -71,10 +72,10 @@ async def run():
7172
logger.error("Collection name not found in environment. Exiting.")
7273
sys.exit(1)
7374

74-
index.milvus_url = os.getenv("MILVUS_URL")
75-
if not index.milvus_url:
76-
logger.error("Milvus URL not found in environment. Exiting.")
77-
sys.exit(1)
75+
index_vendor = os.getenv("INDEX_VENDOR")
76+
if not index_vendor:
77+
logger.warning("Index vendor (INDEX_VENDOR) not found in environment. Defaulting to 'milvus'.")
78+
index_vendor = "milvus"
7879

7980
# If LangSmith tracing is enabled, put a name for the project and verify that all required env vars are set.
8081
if os.getenv("LANGSMITH_TRACING", "false") == "true":
@@ -155,6 +156,8 @@ async def run():
155156

156157
contextualisation_model = os.getenv("CONTEXTUALISATION_MODEL")
157158

159+
vector.store = load_vector_store(index_vendor) # Set up the global wiki_rag.vector.store to be used elsewhere.
160+
158161
# Let's accept arg[1] as the question to be asked.
159162
parser = argparse.ArgumentParser()
160163
parser.add_argument("question", nargs="+", help="The question to be asked.")

wiki_rag/search/util.py

Lines changed: 12 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
MessagesPlaceholder,
1818
SystemMessagePromptTemplate,
1919
)
20-
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
20+
from langchain_openai import ChatOpenAI
2121
from langfuse import Langfuse
2222
from langfuse.langchain import CallbackHandler
2323
from langfuse.model import TextPromptClient
@@ -26,9 +26,8 @@
2626
from langgraph.graph.state import CompiledStateGraph
2727
from langgraph.runtime import Runtime
2828
from langsmith.client import Client
29-
from pymilvus import AnnSearchRequest, MilvusClient, WeightedRanker
3029

31-
import wiki_rag.index as index
30+
import wiki_rag.vector as vector
3231

3332
from wiki_rag import LOG_LEVEL
3433

@@ -346,79 +345,18 @@ async def contextualise_question(
346345

347346

348347
async def retrieve(state: RagState, runtime: Runtime[ContextSchema]) -> dict:
349-
"""Retrieve the best matches from the indexed database.
350-
351-
Here we'll be using Milvus hybrid search that performs a vector search (dense, embeddings)
352-
and a BM25 search (sparse, full text). And then will rerank results with the weighted
353-
reranker.
354-
"""
355-
# Note that here we are using the Milvus own library instead of the LangChain one because
356-
# the LangChain one doesn't support many of the features used here.
357-
embeddings = OpenAIEmbeddings(
358-
model=runtime.context["embedding_model"],
359-
dimensions=runtime.context["embedding_dimension"]
360-
)
361-
query_embedding = embeddings.embed_query(state["question"])
362-
363-
milvus = MilvusClient(index.milvus_url)
364-
365-
# TODO: Make a bunch of the defaults used here configurable.
366-
dense_search_limit = 20
367-
sparse_search_limit = 20
368-
sparse_search_drop_ratio = 0.2
369-
hybrid_rerank_limit = 30
370-
rerank_weights = (0.7, 0.3)
371-
372-
# Define the dense search and its parameters.
373-
dense_search_params = {
374-
"metric_type": "IP",
375-
"params": {
376-
"ef": dense_search_limit,
377-
}
378-
}
379-
dense_search = AnnSearchRequest(
380-
[query_embedding], "dense_vector", dense_search_params, limit=dense_search_limit,
381-
)
382-
383-
# Define the sparse search and its parameters.
384-
sparse_search_params = {
385-
"metric_type": "BM25",
386-
"drop_ratio_search": sparse_search_drop_ratio,
387-
}
388-
sparse_search = AnnSearchRequest(
389-
[state["question"]], "sparse_vector", sparse_search_params, limit=sparse_search_limit,
390-
)
391-
392-
# Perform the hybrid search.
393-
retrieved_docs = milvus.hybrid_search(
394-
runtime.context["collection_name"],
395-
[dense_search, sparse_search],
396-
WeightedRanker(*rerank_weights),
397-
limit=hybrid_rerank_limit,
398-
output_fields=[
399-
"id",
400-
"title",
401-
"text",
402-
"source",
403-
"doc_id",
404-
"doc_title",
405-
"doc_hash",
406-
"parent",
407-
"children",
408-
"previous",
409-
"next",
410-
"relations",
411-
"page_id",
412-
]
348+
"""Retrieve the best matches from the indexed database."""
349+
results = vector.store.retrieve(
350+
collection_name=runtime.context["collection_name"],
351+
embedding_model=runtime.context["embedding_model"],
352+
embedding_dimensions=runtime.context["embedding_dimension"],
353+
query=state["question"],
413354
)
414-
milvus.close()
415355

416356
# TODO: Return only the docs whose distance is below the cutoff.
417-
# distance_cutoff = config["configurable"]["search_distance_cutoff"]
357+
# distance_cutoff = runtime.context["search_distance_cutoff"]
418358
# return {"vector_search": [doc for doc in retrieved_docs[0] if doc["distance"] >= distance_cutoff]}
419-
results = [dict(doc) for doc in retrieved_docs[0]] # Need this: Langfuse has problems with Milvus Hit objects.
420-
return {"vector_search": results} # those are UserDict objects, hence, not json-serializable.
421-
# Reported @ https://github.com/langfuse/langfuse/issues/9294 , we'll need to keep the workaround, it seems.
359+
return {"vector_search": results}
422360

423361

424362
async def optimise(state: RagState, runtime: Runtime[ContextSchema]) -> dict:
@@ -555,7 +493,7 @@ def retrieve_all_elements(retrieved_docs, context_list, collection_name: str) ->
555493
context_texts[id] = f"{retrieved[0]["entity"]["title"]}\n\n{retrieved[0]["entity"]["text"]}"
556494
else:
557495
context_texts[id] = None
558-
# If not, let's retrieve it from the milvus collection.
496+
# If not, let's accumulate it for later id-based retrieval.
559497
context_missing.append(id)
560498

561499
missing_docs = get_missing_from_vector_store(context_missing, collection_name)
@@ -573,16 +511,7 @@ def get_missing_from_vector_store(context_missing: list, collection_name: str) -
573511
if not context_missing: # No missing elements, nothing extra to retrieve.
574512
return {}
575513

576-
milvus = MilvusClient(index.milvus_url)
577-
578-
# Let's find in the collection, the missing elements and get their titles and texts.
579-
missing_docs_db = milvus.query(
580-
collection_name,
581-
ids=context_missing,
582-
output_fields=["id", "title", "text"])
583-
missing_docs = {doc["id"]: f"{doc["title"]}\n\n{doc["text"]}" for doc in missing_docs_db}
584-
milvus.close()
585-
return missing_docs
514+
return vector.store.get_documents_contents_by_id(collection_name, context_missing)
586515

587516

588517
async def generate(state: RagState, runtime: Runtime[ContextSchema]) -> dict:

0 commit comments

Comments
 (0)