Skip to content

Commit 65e9a96

Browse files
committed
Merge branch 'main' into reasoning-models
2 parents b3456fc + 74ad1d4 commit 65e9a96

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

46 files changed

+7418
-44
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "Added batching logic to the prompt tuning autoselection embeddings workflow"
4+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "add vector store integration tests"
4+
}

dictionary.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,8 @@ unnavigated
200200
# Names
201201
Hochul
202202
Ashish
203+
204+
#unified-search
205+
apos
206+
dearmor
207+
venv

graphrag/api/prompt_tune.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ async def generate_indexing_prompts(
111111

112112
# if max_retries is not set, inject a dynamically assigned value based on the number of expected LLM calls
113113
# to be made or fallback to a default value in the worst case
114-
if default_llm_settings.max_retries == -1:
114+
if default_llm_settings.max_retries < -1:
115115
default_llm_settings.max_retries = min(
116116
len(doc_list), language_model_defaults.max_retries
117117
)

graphrag/prompt_tune/loader/input.py

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
import numpy as np
77
import pandas as pd
88

9+
from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
910
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
1011
from graphrag.config.models.graph_rag_config import GraphRagConfig
1112
from graphrag.index.input.factory import create_input
13+
from graphrag.index.operations.embed_text.strategies.openai import (
14+
run as run_embed_text,
15+
)
1216
from graphrag.index.workflows.create_base_text_units import create_base_text_units
13-
from graphrag.language_model.manager import ModelManager
14-
from graphrag.language_model.protocol.base import EmbeddingModel
1517
from graphrag.logger.base import ProgressLogger
1618
from graphrag.prompt_tune.defaults import (
1719
LIMIT,
@@ -21,20 +23,9 @@
2123
from graphrag.prompt_tune.types import DocSelectionType
2224

2325

24-
async def _embed_chunks(
25-
text_chunks: pd.DataFrame,
26-
embedding_llm: EmbeddingModel,
27-
n_subset_max: int = N_SUBSET_MAX,
28-
) -> tuple[pd.DataFrame, np.ndarray]:
29-
"""Convert text chunks into dense text embeddings."""
30-
sampled_text_chunks = text_chunks.sample(n=min(n_subset_max, len(text_chunks)))
31-
embeddings = await embedding_llm.aembed_batch(sampled_text_chunks["text"].tolist())
32-
return text_chunks, np.array(embeddings)
33-
34-
3526
def _sample_chunks_from_embeddings(
3627
text_chunks: pd.DataFrame,
37-
embeddings,
28+
embeddings: np.ndarray[float, np.dtype[np.float_]],
3829
k: int = K,
3930
) -> pd.DataFrame:
4031
"""Sample text chunks from embeddings."""
@@ -60,7 +51,6 @@ async def load_docs_in_chunks(
6051
embeddings_llm_settings = config.get_language_model_config(
6152
config.embed_text.model_id
6253
)
63-
6454
dataset = await create_input(config.input, logger, root)
6555
chunk_config = config.chunks
6656
chunks_df = create_base_text_units(
@@ -88,18 +78,29 @@ async def load_docs_in_chunks(
8878
if k is None or k <= 0:
8979
msg = "k must be an integer > 0"
9080
raise ValueError(msg)
91-
embedding_llm = ModelManager().register_embedding(
92-
name="prompt_tuning_embeddings",
93-
model_type=embeddings_llm_settings.type,
94-
config=embeddings_llm_settings,
95-
callbacks=NoopWorkflowCallbacks(),
96-
cache=None,
97-
)
9881

99-
chunks_df, embeddings = await _embed_chunks(
100-
chunks_df, embedding_llm, n_subset_max=n_subset_max
82+
"""Convert text chunks into dense text embeddings."""
83+
sampled_text_chunks = chunks_df.sample(n=min(n_subset_max, len(chunks_df)))[
84+
"text"
85+
].tolist()
86+
87+
embedding_results = await run_embed_text(
88+
sampled_text_chunks,
89+
callbacks=NoopWorkflowCallbacks(),
90+
cache=NoopPipelineCache(),
91+
args={
92+
"llm": embeddings_llm_settings.model_dump(),
93+
"num_threads": embeddings_llm_settings.concurrent_requests,
94+
"batch_size": config.embed_text.batch_size,
95+
"batch_max_tokens": config.embed_text.batch_max_tokens,
96+
},
10197
)
98+
embeddings = np.array(embedding_results.embeddings)
10299
chunks_df = _sample_chunks_from_embeddings(chunks_df, embeddings, k=k)
103100

104101
# Convert the dataset to list form, so we have a list of documents
105-
return chunks_df["text"].tolist()
102+
return [
103+
# need this to prevent the str.format() function from breaking when parsing LaTeX from markdown files
104+
i.replace("{", "{{").replace("}", "}}")
105+
for i in chunks_df["text"]
106+
]

graphrag/vector_stores/cosmosdb.py

Lines changed: 71 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any
88

99
from azure.cosmos import ContainerProxy, CosmosClient, DatabaseProxy
10+
from azure.cosmos.exceptions import CosmosHttpResponseError
1011
from azure.cosmos.partition_key import PartitionKey
1112
from azure.identity import DefaultAzureCredential
1213

@@ -19,7 +20,7 @@
1920
)
2021

2122

22-
class CosmosDBVectoreStore(BaseVectorStore):
23+
class CosmosDBVectorStore(BaseVectorStore):
2324
"""Azure CosmosDB vector storage implementation."""
2425

2526
_cosmos_client: CosmosClient
@@ -99,16 +100,32 @@ def _create_container(self) -> None:
99100
"automatic": True,
100101
"includedPaths": [{"path": "/*"}],
101102
"excludedPaths": [{"path": "/_etag/?"}, {"path": "/vector/*"}],
102-
"vectorIndexes": [{"path": "/vector", "type": "diskANN"}],
103103
}
104104

105-
# Create the container and container client
106-
self._database_client.create_container_if_not_exists(
107-
id=self._container_name,
108-
partition_key=partition_key,
109-
indexing_policy=indexing_policy,
110-
vector_embedding_policy=vector_embedding_policy,
111-
)
105+
# Currently, the CosmosDB emulator does not support the diskANN policy.
106+
try:
107+
# First try with the standard diskANN policy
108+
indexing_policy["vectorIndexes"] = [{"path": "/vector", "type": "diskANN"}]
109+
110+
# Create the container and container client
111+
self._database_client.create_container_if_not_exists(
112+
id=self._container_name,
113+
partition_key=partition_key,
114+
indexing_policy=indexing_policy,
115+
vector_embedding_policy=vector_embedding_policy,
116+
)
117+
except CosmosHttpResponseError:
118+
# If diskANN fails (likely in emulator), retry without vector indexes
119+
indexing_policy.pop("vectorIndexes", None)
120+
121+
# Create the container with compatible indexing policy
122+
self._database_client.create_container_if_not_exists(
123+
id=self._container_name,
124+
partition_key=partition_key,
125+
indexing_policy=indexing_policy,
126+
vector_embedding_policy=vector_embedding_policy,
127+
)
128+
112129
self._container_client = self._database_client.get_container_client(
113130
self._container_name
114131
)
@@ -157,13 +174,46 @@ def similarity_search_by_vector(
157174
msg = "Container client is not initialized."
158175
raise ValueError(msg)
159176

160-
query = f"SELECT TOP {k} c.id, c.text, c.vector, c.attributes, VectorDistance(c.vector, @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.vector, @embedding)" # noqa: S608
161-
query_params = [{"name": "@embedding", "value": query_embedding}]
162-
items = self._container_client.query_items(
163-
query=query,
164-
parameters=query_params,
165-
enable_cross_partition_query=True,
166-
)
177+
try:
178+
query = f"SELECT TOP {k} c.id, c.text, c.vector, c.attributes, VectorDistance(c.vector, @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.vector, @embedding)" # noqa: S608
179+
query_params = [{"name": "@embedding", "value": query_embedding}]
180+
items = list(
181+
self._container_client.query_items(
182+
query=query,
183+
parameters=query_params,
184+
enable_cross_partition_query=True,
185+
)
186+
)
187+
except (CosmosHttpResponseError, ValueError):
188+
# Currently, the CosmosDB emulator does not support the VectorDistance function.
189+
# For emulator or test environments - fetch all items and calculate distance locally
190+
query = "SELECT c.id, c.text, c.vector, c.attributes FROM c"
191+
items = list(
192+
self._container_client.query_items(
193+
query=query,
194+
enable_cross_partition_query=True,
195+
)
196+
)
197+
198+
# Calculate cosine similarity locally (1 - cosine distance)
199+
from numpy import dot
200+
from numpy.linalg import norm
201+
202+
def cosine_similarity(a, b):
203+
if norm(a) * norm(b) == 0:
204+
return 0.0
205+
return dot(a, b) / (norm(a) * norm(b))
206+
207+
# Calculate scores for all items
208+
for item in items:
209+
item_vector = item.get("vector", [])
210+
similarity = cosine_similarity(query_embedding, item_vector)
211+
item["SimilarityScore"] = similarity
212+
213+
# Sort by similarity score (higher is better) and take top k
214+
items = sorted(
215+
items, key=lambda x: x.get("SimilarityScore", 0.0), reverse=True
216+
)[:k]
167217

168218
return [
169219
VectorStoreSearchResult(
@@ -214,3 +264,8 @@ def search_by_id(self, id: str) -> VectorStoreDocument:
214264
text=item.get("text", ""),
215265
attributes=(json.loads(item.get("attributes", "{}"))),
216266
)
267+
268+
def clear(self) -> None:
269+
"""Clear the vector store."""
270+
self._delete_container()
271+
self._delete_database()

graphrag/vector_stores/factory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore
1010
from graphrag.vector_stores.base import BaseVectorStore
11-
from graphrag.vector_stores.cosmosdb import CosmosDBVectoreStore
11+
from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore
1212
from graphrag.vector_stores.lancedb import LanceDBVectorStore
1313

1414

@@ -44,7 +44,7 @@ def create_vector_store(
4444
case VectorStoreType.AzureAISearch:
4545
return AzureAISearchVectorStore(**kwargs)
4646
case VectorStoreType.CosmosDB:
47-
return CosmosDBVectoreStore(**kwargs)
47+
return CosmosDBVectorStore(**kwargs)
4848
case _:
4949
if vector_store_type in cls.vector_store_types:
5050
return cls.vector_store_types[vector_store_type](**kwargs)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Copyright (c) 2024 Microsoft Corporation.
2+
# Licensed under the MIT License
3+
4+
"""Integration tests for vector store implementations."""

0 commit comments

Comments
 (0)