66import numpy as np
77import pandas as pd
88
9+ from graphrag .cache .noop_pipeline_cache import NoopPipelineCache
910from graphrag .callbacks .noop_workflow_callbacks import NoopWorkflowCallbacks
1011from graphrag .config .models .graph_rag_config import GraphRagConfig
1112from graphrag .index .input .factory import create_input
13+ from graphrag .index .operations .embed_text .strategies .openai import (
14+ run as run_embed_text ,
15+ )
1216from graphrag .index .workflows .create_base_text_units import create_base_text_units
13- from graphrag .language_model .manager import ModelManager
14- from graphrag .language_model .protocol .base import EmbeddingModel
1517from graphrag .logger .base import ProgressLogger
1618from graphrag .prompt_tune .defaults import (
1719 LIMIT ,
2123from graphrag .prompt_tune .types import DocSelectionType
2224
2325
async def _embed_chunks(
    text_chunks: pd.DataFrame,
    embedding_llm: EmbeddingModel,
    n_subset_max: int = N_SUBSET_MAX,
) -> tuple[pd.DataFrame, np.ndarray]:
    """Embed a random subset of *text_chunks*.

    At most ``n_subset_max`` rows are sampled (without replacement) from
    ``text_chunks`` and only the sampled rows' ``text`` column is embedded.

    Returns
    -------
    A pair of (the full, unsampled input DataFrame, the embedding matrix for
    the sampled rows only). NOTE(review): the returned frame and the returned
    embeddings therefore do not cover the same rows — presumably the caller
    only uses the embeddings for sampling; confirm against the call site.
    """
    sample_size = min(n_subset_max, len(text_chunks))
    subset = text_chunks.sample(n=sample_size)
    vectors = await embedding_llm.aembed_batch(subset["text"].tolist())
    return text_chunks, np.array(vectors)
34-
3526def _sample_chunks_from_embeddings (
3627 text_chunks : pd .DataFrame ,
37- embeddings ,
28+ embeddings : np . ndarray [ float , np . dtype [ np . float_ ]] ,
3829 k : int = K ,
3930) -> pd .DataFrame :
4031 """Sample text chunks from embeddings."""
@@ -60,7 +51,6 @@ async def load_docs_in_chunks(
6051 embeddings_llm_settings = config .get_language_model_config (
6152 config .embed_text .model_id
6253 )
63-
6454 dataset = await create_input (config .input , logger , root )
6555 chunk_config = config .chunks
6656 chunks_df = create_base_text_units (
@@ -88,18 +78,29 @@ async def load_docs_in_chunks(
8878 if k is None or k <= 0 :
8979 msg = "k must be an integer > 0"
9080 raise ValueError (msg )
91- embedding_llm = ModelManager ().register_embedding (
92- name = "prompt_tuning_embeddings" ,
93- model_type = embeddings_llm_settings .type ,
94- config = embeddings_llm_settings ,
95- callbacks = NoopWorkflowCallbacks (),
96- cache = None ,
97- )
9881
99- chunks_df , embeddings = await _embed_chunks (
100- chunks_df , embedding_llm , n_subset_max = n_subset_max
82+ """Convert text chunks into dense text embeddings."""
83+ sampled_text_chunks = chunks_df .sample (n = min (n_subset_max , len (chunks_df )))[
84+ "text"
85+ ].tolist ()
86+
87+ embedding_results = await run_embed_text (
88+ sampled_text_chunks ,
89+ callbacks = NoopWorkflowCallbacks (),
90+ cache = NoopPipelineCache (),
91+ args = {
92+ "llm" : embeddings_llm_settings .model_dump (),
93+ "num_threads" : embeddings_llm_settings .concurrent_requests ,
94+ "batch_size" : config .embed_text .batch_size ,
95+ "batch_max_tokens" : config .embed_text .batch_max_tokens ,
96+ },
10197 )
98+ embeddings = np .array (embedding_results .embeddings )
10299 chunks_df = _sample_chunks_from_embeddings (chunks_df , embeddings , k = k )
103100
104101 # Convert the dataset to list form, so we have a list of documents
105- return chunks_df ["text" ].tolist ()
102+ return [
103+ # need this to prevent the str.format() function from breaking when parsing LaTeX from markdown files
104+ i .replace ("{" , "{{" ).replace ("}" , "}}" )
105+ for i in chunks_df ["text" ]
106+ ]
0 commit comments