
Commit ffd8db7

nievespg1 and Gabriel Nieves authored
Gnievesponce prompt tune embedd chunking (#1826)
* Added support for embeddings chunking as defined by the config.
* Ran semversioner -t patch.
* Eliminated redundant code by using the embed_text strategy directly.
* Added a fix to support brackets within the corpus text; for example, inline LaTeX within a markdown file.

---------

Co-authored-by: Gabriel Nieves <[email protected]>
1 parent b7b2b56 commit ffd8db7

File tree

3 files changed: +31 -26 lines changed

.semversioner change file (new)

Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Added batching logic to the prompt tuning autoselection embeddings workflow"
+}
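This appears to be the change-file format that semversioner generates (the "ran semversioner -t patch" step in the commit message): each file records one patch/minor/major entry that is later rolled up into the release notes.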

graphrag/api/prompt_tune.py

Lines changed: 1 addition & 1 deletion

@@ -111,7 +111,7 @@ async def generate_indexing_prompts(
 
     # if max_retries is not set, inject a dynamically assigned value based on the number of expected LLM calls
     # to be made or fallback to a default value in the worst case
-    if default_llm_settings.max_retries == -1:
+    if default_llm_settings.max_retries < -1:
         default_llm_settings.max_retries = min(
             len(doc_list), language_model_defaults.max_retries
         )
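As a minimal sketch of the fallback the comment describes — the names `Settings` and `DEFAULT_MAX_RETRIES` are hypothetical stand-ins for graphrag's settings objects, and the "unset" sentinel test is written generically rather than mirroring the exact comparison above:

```python
# Hypothetical stand-ins; only the min() fallback mirrors the diff above.
class Settings:
    max_retries: int = -1  # sentinel meaning "not set"

doc_list = ["doc-1", "doc-2", "doc-3"]  # one expected LLM call per document
DEFAULT_MAX_RETRIES = 10  # stand-in for language_model_defaults.max_retries

default_llm_settings = Settings()
if default_llm_settings.max_retries == -1:  # max_retries was never configured
    # retry budget scales with the number of expected LLM calls, capped by the default
    default_llm_settings.max_retries = min(len(doc_list), DEFAULT_MAX_RETRIES)

print(default_llm_settings.max_retries)  # 3
```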

graphrag/prompt_tune/loader/input.py

Lines changed: 26 additions & 25 deletions

@@ -6,12 +6,14 @@
 import numpy as np
 import pandas as pd
 
+from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
 from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.index.input.factory import create_input
+from graphrag.index.operations.embed_text.strategies.openai import (
+    run as run_embed_text,
+)
 from graphrag.index.workflows.create_base_text_units import create_base_text_units
-from graphrag.language_model.manager import ModelManager
-from graphrag.language_model.protocol.base import EmbeddingModel
 from graphrag.logger.base import ProgressLogger
 from graphrag.prompt_tune.defaults import (
     LIMIT,
@@ -21,20 +23,9 @@
 from graphrag.prompt_tune.types import DocSelectionType
 
 
-async def _embed_chunks(
-    text_chunks: pd.DataFrame,
-    embedding_llm: EmbeddingModel,
-    n_subset_max: int = N_SUBSET_MAX,
-) -> tuple[pd.DataFrame, np.ndarray]:
-    """Convert text chunks into dense text embeddings."""
-    sampled_text_chunks = text_chunks.sample(n=min(n_subset_max, len(text_chunks)))
-    embeddings = await embedding_llm.aembed_batch(sampled_text_chunks["text"].tolist())
-    return text_chunks, np.array(embeddings)
-
-
 def _sample_chunks_from_embeddings(
     text_chunks: pd.DataFrame,
-    embeddings,
+    embeddings: np.ndarray[float, np.dtype[np.float_]],
     k: int = K,
 ) -> pd.DataFrame:
     """Sample text chunks from embeddings."""
@@ -60,7 +51,6 @@ async def load_docs_in_chunks(
     embeddings_llm_settings = config.get_language_model_config(
         config.embed_text.model_id
     )
-
     dataset = await create_input(config.input, logger, root)
     chunk_config = config.chunks
     chunks_df = create_base_text_units(
@@ -88,18 +78,29 @@ async def load_docs_in_chunks(
     if k is None or k <= 0:
         msg = "k must be an integer > 0"
         raise ValueError(msg)
-    embedding_llm = ModelManager().register_embedding(
-        name="prompt_tuning_embeddings",
-        model_type=embeddings_llm_settings.type,
-        config=embeddings_llm_settings,
-        callbacks=NoopWorkflowCallbacks(),
-        cache=None,
-    )
 
-    chunks_df, embeddings = await _embed_chunks(
-        chunks_df, embedding_llm, n_subset_max=n_subset_max
+    """Convert text chunks into dense text embeddings."""
+    sampled_text_chunks = chunks_df.sample(n=min(n_subset_max, len(chunks_df)))[
+        "text"
+    ].tolist()
+
+    embedding_results = await run_embed_text(
+        sampled_text_chunks,
+        callbacks=NoopWorkflowCallbacks(),
+        cache=NoopPipelineCache(),
+        args={
+            "llm": embeddings_llm_settings.model_dump(),
+            "num_threads": embeddings_llm_settings.concurrent_requests,
+            "batch_size": config.embed_text.batch_size,
+            "batch_max_tokens": config.embed_text.batch_max_tokens,
+        },
     )
+    embeddings = np.array(embedding_results.embeddings)
     chunks_df = _sample_chunks_from_embeddings(chunks_df, embeddings, k=k)
 
     # Convert the dataset to list form, so we have a list of documents
-    return chunks_df["text"].tolist()
+    return [
+        # need this to prevent the str.format() function from breaking when parsing LaTeX from markdown files
+        i.replace("{", "{{").replace("}", "}}")
+        for i in chunks_df["text"]
+    ]
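The replacement above routes embeddings through the embed_text strategy, whose `batch_size` and `batch_max_tokens` arguments bound each embeddings request. Below is a generic sketch of that batching concept, not graphrag's `run_embed_text` implementation; the length-based token estimate stands in for a real tokenizer:

```python
def make_batches(
    texts: list[str], batch_size: int, batch_max_tokens: int
) -> list[list[str]]:
    """Group texts into batches bounded by item count and an approximate token budget."""
    batches: list[list[str]] = []
    current: list[str] = []
    current_tokens = 0
    for text in texts:
        n_tokens = max(1, len(text) // 4)  # crude estimate; real code would tokenize
        # start a new batch when adding this text would exceed either limit
        if current and (
            len(current) >= batch_size or current_tokens + n_tokens > batch_max_tokens
        ):
            batches.append(current)
            current, current_tokens = [], 0
        current.append(text)
        current_tokens += n_tokens
    if current:
        batches.append(current)
    return batches

# Three 40-character texts against a 2-item / 15-token budget -> three batches
print(make_batches(["a" * 40, "b" * 40, "c" * 40], batch_size=2, batch_max_tokens=15))
```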

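The new return value doubles braces because the tuned prompt templates are later rendered with `str.format()`, which treats any single `{...}` carried in from the corpus (inline LaTeX, set notation) as a replacement field. A minimal sketch of the failure and the fix; the template string here is hypothetical:

```python
chunk = "set notation such as {x, y} or inline LaTeX like \\mathcal{L}"

# Without escaping: the chunk's braces are parsed as format() replacement fields
template = "Example document:\n" + chunk + "\n\nReal input: {input_text}"
try:
    template.format(input_text="hello")
except KeyError as err:
    print("format() broke on corpus braces:", err)  # KeyError: 'x, y'

# With the doubling from the diff: {{ and }} render back to literal braces
safe_chunk = chunk.replace("{", "{{").replace("}", "}}")
template = "Example document:\n" + safe_chunk + "\n\nReal input: {input_text}"
print(template.format(input_text="hello"))
```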