
Commit 185f513

Basic search implementation (#1563)

* basic search implementation
* basic streaming functionality
* format check
* check fix
* release change
* Chore/gleanings any encoding (#1569)
* Make claims and entities independent of encoding
* Semver
* Change semver release type

Co-authored-by: Alonso Guevara <[email protected]>

1 parent 5f9ad0d · commit 185f513

File tree: 22 files changed (+915, -198 lines)

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+{
+    "type": "minor",
+    "description": "new search implemented as a new option for the api"
+}

graphrag/api/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -10,6 +10,8 @@
 from graphrag.api.index import build_index
 from graphrag.api.prompt_tune import generate_indexing_prompts
 from graphrag.api.query import (
+    basic_search,
+    basic_search_streaming,
     drift_search,
     global_search,
     global_search_streaming,
@@ -27,6 +29,8 @@
     "local_search",
     "local_search_streaming",
     "drift_search",
+    "basic_search",
+    "basic_search_streaming",
     # prompt tuning API
     "DocSelectionType",
     "generate_indexing_prompts",

graphrag/api/query.py

Lines changed: 105 additions & 0 deletions
@@ -28,9 +28,11 @@
 from graphrag.index.config.embeddings import (
     community_full_content_embedding,
     entity_description_embedding,
+    text_unit_text_embedding,
 )
 from graphrag.logger.print_progress import PrintProgressLogger
 from graphrag.query.factory import (
+    get_basic_search_engine,
     get_drift_search_engine,
     get_global_search_engine,
     get_local_search_engine,
@@ -423,6 +425,109 @@ async def drift_search(
     return response, context_data


+@validate_call(config={"arbitrary_types_allowed": True})
+async def basic_search(
+    config: GraphRagConfig,
+    text_units: pd.DataFrame,
+    query: str,
+) -> tuple[
+    str | dict[str, Any] | list[dict[str, Any]],
+    str | list[pd.DataFrame] | dict[str, pd.DataFrame],
+]:
+    """Perform a basic search and return the context data and response.
+
+    Parameters
+    ----------
+    - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
+    - text_units (pd.DataFrame): A DataFrame containing the final text units (from create_final_text_units.parquet)
+    - query (str): The user query to search for.
+
+    Returns
+    -------
+    TODO: Document the search response type and format.
+
+    Raises
+    ------
+    TODO: Document any exceptions to expect.
+    """
+    vector_store_args = config.embeddings.vector_store
+    logger.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore # noqa
+
+    description_embedding_store = _get_embedding_store(
+        config_args=vector_store_args,  # type: ignore
+        embedding_name=text_unit_text_embedding,
+    )
+
+    prompt = _load_search_prompt(config.root_dir, config.basic_search.prompt)
+
+    search_engine = get_basic_search_engine(
+        config=config,
+        text_units=read_indexer_text_units(text_units),
+        text_unit_embeddings=description_embedding_store,
+        system_prompt=prompt,
+    )
+
+    result: SearchResult = await search_engine.asearch(query=query)
+    response = result.response
+    context_data = _reformat_context_data(result.context_data)  # type: ignore
+    return response, context_data
+
+
+@validate_call(config={"arbitrary_types_allowed": True})
+async def basic_search_streaming(
+    config: GraphRagConfig,
+    text_units: pd.DataFrame,
+    query: str,
+) -> AsyncGenerator:
+    """Perform a basic search and return the context data and response via a generator.
+
+    Parameters
+    ----------
+    - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
+    - text_units (pd.DataFrame): A DataFrame containing the final text units (from create_final_text_units.parquet)
+    - query (str): The user query to search for.
+
+    Returns
+    -------
+    TODO: Document the search response type and format.
+
+    Raises
+    ------
+    TODO: Document any exceptions to expect.
+    """
+    vector_store_args = config.embeddings.vector_store
+    logger.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore # noqa
+
+    description_embedding_store = _get_embedding_store(
+        config_args=vector_store_args,  # type: ignore
+        embedding_name=text_unit_text_embedding,
+    )
+
+    prompt = _load_search_prompt(config.root_dir, config.basic_search.prompt)
+
+    search_engine = get_basic_search_engine(
+        config=config,
+        text_units=read_indexer_text_units(text_units),
+        text_unit_embeddings=description_embedding_store,
+        system_prompt=prompt,
+    )
+
+    search_result = search_engine.astream_search(query=query)
+
+    # when streaming results, a context data object is returned as the first result
+    # and the query response comes back in subsequent tokens
+    context_data = None
+    get_context_data = True
+    async for stream_chunk in search_result:
+        if get_context_data:
+            context_data = _reformat_context_data(stream_chunk)  # type: ignore
+            yield context_data
+            get_context_data = False
+        else:
+            yield stream_chunk
+
+
 def _get_embedding_store(
     config_args: dict,
     embedding_name: str,
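Taken together, basic_search and basic_search_streaming mirror the existing local/global/drift API entry points but only need the text unit table and its embeddings. A minimal usage sketch, not part of this diff: it assumes the index has already been built, that create_final_text_units.parquet sits under the project's output/ directory, and that load_config is importable from graphrag.config.load_config the way the CLI resolves it; paths and the query string are placeholders.

import asyncio
from pathlib import Path

import pandas as pd

import graphrag.api as api
from graphrag.config.load_config import load_config  # assumed import path

root = Path("./ragtest")  # hypothetical project root containing settings.yaml
config = load_config(root, None)
text_units = pd.read_parquet(root / "output" / "create_final_text_units.parquet")  # hypothetical output location

# one-shot search: returns the full response plus the context tables that grounded it
response, context_data = asyncio.run(
    api.basic_search(config=config, text_units=text_units, query="What are the main topics?")
)
print(response)

# streaming search: the first yielded item is the context data, later items are response tokens
async def stream() -> None:
    first = True
    async for chunk in api.basic_search_streaming(
        config=config, text_units=text_units, query="What are the main topics?"
    ):
        if first:
            first = False  # skip the context payload, as the CLI handler below does
            continue
        print(chunk, end="", flush=True)
    print()

asyncio.run(stream())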

graphrag/cli/initialize.py

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,7 @@
 )
 from graphrag.prompts.index.entity_extraction import GRAPH_EXTRACTION_PROMPT
 from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT
+from graphrag.prompts.query.basic_search_system_prompt import BASIC_SEARCH_SYSTEM_PROMPT
 from graphrag.prompts.query.drift_search_system_prompt import DRIFT_LOCAL_SYSTEM_PROMPT
 from graphrag.prompts.query.global_search_knowledge_system_prompt import (
     GENERAL_KNOWLEDGE_INSTRUCTION,
@@ -60,6 +61,7 @@ def initialize_project_at(path: Path) -> None:
         "global_search_reduce_system_prompt": REDUCE_SYSTEM_PROMPT,
         "global_search_knowledge_system_prompt": GENERAL_KNOWLEDGE_INSTRUCTION,
         "local_search_system_prompt": LOCAL_SEARCH_SYSTEM_PROMPT,
+        "basic_search_system_prompt": BASIC_SEARCH_SYSTEM_PROMPT,
         "question_gen_system_prompt": QUESTION_SYSTEM_PROMPT,
     }

graphrag/cli/main.py

Lines changed: 15 additions & 1 deletion
@@ -88,6 +88,7 @@ class SearchType(Enum):
     LOCAL = "local"
     GLOBAL = "global"
     DRIFT = "drift"
+    BASIC = "basic"

     def __str__(self):
         """Return the string representation of the enum value."""
@@ -424,7 +425,12 @@ def _query_cli(
     ] = False,
 ):
     """Query a knowledge graph index."""
-    from graphrag.cli.query import run_drift_search, run_global_search, run_local_search
+    from graphrag.cli.query import (
+        run_basic_search,
+        run_drift_search,
+        run_global_search,
+        run_local_search,
+    )

     match method:
         case SearchType.LOCAL:
@@ -457,5 +463,13 @@ def _query_cli(
                 streaming=False,  # Drift search does not support streaming (yet)
                 query=query,
             )
+        case SearchType.BASIC:
+            run_basic_search(
+                config_filepath=config,
+                data_dir=data,
+                root_dir=root,
+                streaming=streaming,
+                query=query,
+            )
         case _:
             raise ValueError(INVALID_METHOD_ERROR)
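With BASIC added to SearchType and dispatched in _query_cli, a basic search should be runnable straight from the query CLI. A hypothetical invocation (the graphrag query entry point, flag names, and the ./ragtest root are carried over from the existing CLI, not introduced by this diff):

graphrag query --root ./ragtest --method basic --query "What are the main themes in these documents?"

Note that the streaming flag is forwarded unchanged (streaming=streaming above), unlike drift search, which still pins it to False.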

graphrag/cli/query.py

Lines changed: 63 additions & 0 deletions
@@ -257,6 +257,69 @@ def run_drift_search(
     return response, context_data


+def run_basic_search(
+    config_filepath: Path | None,
+    data_dir: Path | None,
+    root_dir: Path,
+    streaming: bool,
+    query: str,
+):
+    """Perform a basic search with a given query.
+
+    Loads index files required for basic search and calls the Query API.
+    """
+    root = root_dir.resolve()
+    config = load_config(root, config_filepath)
+    config.storage.base_dir = str(data_dir) if data_dir else config.storage.base_dir
+    resolve_paths(config)
+
+    dataframe_dict = _resolve_output_files(
+        config=config,
+        output_list=[
+            "create_final_text_units.parquet",
+        ],
+    )
+    final_text_units: pd.DataFrame = dataframe_dict["create_final_text_units"]
+
+    # call the Query API
+    if streaming:
+
+        async def run_streaming_search():
+            full_response = ""
+            context_data = None
+            get_context_data = True
+            async for stream_chunk in api.basic_search_streaming(
+                config=config,
+                text_units=final_text_units,
+                query=query,
+            ):
+                if get_context_data:
+                    context_data = stream_chunk
+                    get_context_data = False
+                else:
+                    full_response += stream_chunk
+                    print(stream_chunk, end="")  # noqa: T201
+                    sys.stdout.flush()  # flush output buffer to display text immediately
+            print()  # noqa: T201
+            return full_response, context_data
+
+        return asyncio.run(run_streaming_search())
+    # not streaming
+    response, context_data = asyncio.run(
+        api.basic_search(
+            config=config,
+            text_units=final_text_units,
+            query=query,
+        )
+    )
+    logger.success(f"Basic Search Response:\n{response}")
+    # NOTE: we return the response and context data here purely as a complete demonstration of the API.
+    # External users should use the API directly to get the response and context data.
+    return response, context_data
+
+
 def _resolve_output_files(
     config: GraphRagConfig,
     output_list: list[str],

graphrag/config/create_graphrag_config.py

Lines changed: 25 additions & 0 deletions
@@ -30,6 +30,7 @@
 )
 from graphrag.config.input_models.graphrag_config_input import GraphRagConfigInput
 from graphrag.config.input_models.llm_config_input import LLMConfigInput
+from graphrag.config.models.basic_search_config import BasicSearchConfig
 from graphrag.config.models.cache_config import CacheConfig
 from graphrag.config.models.chunking_config import ChunkingConfig, ChunkStrategyType
 from graphrag.config.models.claim_extraction_config import ClaimExtractionConfig
@@ -636,6 +637,28 @@ def hydrate_parallelization_params(
             or defs.DRIFT_LOCAL_SEARCH_LLM_MAX_TOKENS,
        )

+        with (
+            reader.use(values.get("basic_search")),
+            reader.envvar_prefix(Section.basic_search),
+        ):
+            basic_search_model = BasicSearchConfig(
+                prompt=reader.str("prompt") or None,
+                text_unit_prop=reader.float("text_unit_prop")
+                or defs.BASIC_SEARCH_TEXT_UNIT_PROP,
+                conversation_history_max_turns=reader.int(
+                    "conversation_history_max_turns"
+                )
+                or defs.BASIC_SEARCH_CONVERSATION_HISTORY_MAX_TURNS,
+                temperature=reader.float("llm_temperature")
+                or defs.BASIC_SEARCH_LLM_TEMPERATURE,
+                top_p=reader.float("llm_top_p") or defs.BASIC_SEARCH_LLM_TOP_P,
+                n=reader.int("llm_n") or defs.BASIC_SEARCH_LLM_N,
+                max_tokens=reader.int(Fragment.max_tokens)
+                or defs.BASIC_SEARCH_MAX_TOKENS,
+                llm_max_tokens=reader.int("llm_max_tokens")
+                or defs.BASIC_SEARCH_LLM_MAX_TOKENS,
+            )
+
         skip_workflows = reader.list("skip_workflows") or []

     return GraphRagConfig(
@@ -663,6 +686,7 @@ def hydrate_parallelization_params(
         local_search=local_search_model,
         global_search=global_search_model,
         drift_search=drift_search_model,
+        basic_search=basic_search_model,
     )


@@ -731,6 +755,7 @@ class Section(str, Enum):
     local_search = "LOCAL_SEARCH"
     global_search = "GLOBAL_SEARCH"
     drift_search = "DRIFT_SEARCH"
+    basic_search = "BASIC_SEARCH"


 def _is_azure(llm_type: LLMType | None) -> bool:
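For reference, a sketch of the model this block should hydrate when settings.yaml carries only the bare basic_search section written by graphrag init; the field names come from the BasicSearchConfig construction above and the values from the new defaults in graphrag/config/defaults.py below. The snippet is illustrative, not part of the diff.

from graphrag.config.models.basic_search_config import BasicSearchConfig

basic_search_model = BasicSearchConfig(
    prompt="prompts/basic_search_system_prompt.txt",  # as emitted by `graphrag init` (see init_content.py below)
    text_unit_prop=0.5,                # BASIC_SEARCH_TEXT_UNIT_PROP
    conversation_history_max_turns=5,  # BASIC_SEARCH_CONVERSATION_HISTORY_MAX_TURNS
    temperature=0,                     # BASIC_SEARCH_LLM_TEMPERATURE
    top_p=1,                           # BASIC_SEARCH_LLM_TOP_P
    n=1,                               # BASIC_SEARCH_LLM_N
    max_tokens=12_000,                 # BASIC_SEARCH_MAX_TOKENS
    llm_max_tokens=2_000,              # BASIC_SEARCH_LLM_MAX_TOKENS
)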

graphrag/config/defaults.py

Lines changed: 9 additions & 0 deletions
@@ -161,3 +161,12 @@
 DRIFT_LOCAL_SEARCH_LLM_MAX_TOKENS = 2000

 DRIFT_N_DEPTH = 3
+
+# Basic Search
+BASIC_SEARCH_TEXT_UNIT_PROP = 0.5
+BASIC_SEARCH_CONVERSATION_HISTORY_MAX_TURNS = 5
+BASIC_SEARCH_MAX_TOKENS = 12_000
+BASIC_SEARCH_LLM_TEMPERATURE = 0
+BASIC_SEARCH_LLM_TOP_P = 1
+BASIC_SEARCH_LLM_N = 1
+BASIC_SEARCH_LLM_MAX_TOKENS = 2000

graphrag/config/init_content.py

Lines changed: 3 additions & 0 deletions
@@ -132,6 +132,9 @@

 drift_search:
   prompt: "prompts/drift_search_system_prompt.txt"
+
+basic_search:
+  prompt: "prompts/basic_search_system_prompt.txt"
 """

 INIT_DOTENV = """\
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from typing_extensions import NotRequired, TypedDict
+
+
+class BasicSearchConfigInput(TypedDict):
+    """The default configuration section for Basic Search."""
+
+    text_unit_prop: NotRequired[float | str | None]
+    conversation_history_max_turns: NotRequired[int | str | None]
+    max_tokens: NotRequired[int | str | None]
+    llm_max_tokens: NotRequired[int | str | None]
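A minimal sketch of a value matching this input shape; every key is optional, the numbers simply echo the new defaults, and the variable name is illustrative (the module path of BasicSearchConfigInput is not shown here, so no import is spelled out):

basic_search_input: "BasicSearchConfigInput" = {
    "text_unit_prop": 0.5,
    "conversation_history_max_turns": 5,
    "max_tokens": 12_000,
    "llm_max_tokens": 2_000,
}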

0 commit comments
