Skip to content

Commit 0805924

Browse files
Fix/drift n depth (#1676)
* Fix n_depth param * Semver * Change smoke tests params for drift * Reduce log printing for expected exceptions
1 parent a4d35bc commit 0805924

File tree

7 files changed

+31
-17
lines changed

7 files changed

+31
-17
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "Fix proper use of n_depth for drift search"
4+
}

graphrag/config/defaults.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@
187187
# DRIFT Search
188188
DRIFT_SEARCH_LLM_TEMPERATURE = 0
189189
DRIFT_SEARCH_LLM_TOP_P = 1
190-
DRIFT_SEARCH_LLM_N = 3
190+
DRIFT_SEARCH_LLM_N = 1
191191
DRIFT_SEARCH_MAX_TOKENS = 12_000
192192
DRIFT_SEARCH_DATA_MAX_TOKENS = 12_000
193193
DRIFT_SEARCH_CONCURRENCY = 32

graphrag/query/llm/text_utils.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def chunk_text(
5050
yield from (token_encoder.decode(list(chunk)) for chunk in chunk_iterator)
5151

5252

53-
def try_parse_json_object(input: str) -> tuple[str, dict]:
53+
def try_parse_json_object(input: str, verbose: bool = True) -> tuple[str, dict]:
5454
"""JSON cleaning and formatting utilities."""
5555
# Sometimes, the LLM returns a json string with some extra description, this function will clean it up.
5656

@@ -59,7 +59,8 @@ def try_parse_json_object(input: str) -> tuple[str, dict]:
5959
# Try parse first
6060
result = json.loads(input)
6161
except json.JSONDecodeError:
62-
log.info("Warning: Error decoding faulty json, attempting repair")
62+
if verbose:
63+
log.info("Warning: Error decoding faulty json, attempting repair")
6364

6465
if result:
6566
return input, result
@@ -97,11 +98,13 @@ def try_parse_json_object(input: str) -> tuple[str, dict]:
9798
try:
9899
result = json.loads(input)
99100
except json.JSONDecodeError:
100-
log.exception("error loading json, json=%s", input)
101+
if verbose:
102+
log.exception("error loading json, json=%s", input)
101103
return input, {}
102104
else:
103105
if not isinstance(result, dict):
104-
log.exception("not expected dict type. type=%s:", type(result))
106+
if verbose:
107+
log.exception("not expected dict type. type=%s:", type(result))
105108
return input, {}
106109
return input, result
107110
else:

graphrag/query/structured_search/drift_search/action.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import logging
88
from typing import Any
99

10+
from graphrag.query.llm.text_utils import try_parse_json_object
11+
1012
log = logging.getLogger(__name__)
1113

1214

@@ -72,17 +74,12 @@ async def asearch(self, search_engine: Any, global_query: str, scorer: Any = Non
7274
drift_query=global_query, query=self.query
7375
)
7476

75-
try:
76-
response = json.loads(search_result.response)
77-
except json.JSONDecodeError:
78-
error_message = "Failed to parse search response"
79-
log.exception("%s: %s", error_message, search_result.response)
80-
# Do not launch exception as it will roll up with other steps
81-
# Instead return an empty response and let score -inf handle it
82-
response = {}
77+
# Do not launch exception as it will roll up with other steps
78+
# Instead return an empty response and let score -inf handle it
79+
_, response = try_parse_json_object(search_result.response, verbose=False)
8380

8481
self.answer = response.pop("response", None)
85-
self.score = response.pop("score", float("-inf"))
82+
self.score = float(response.pop("score", "-inf"))
8683
self.metadata.update({"context_data": search_result.context_data})
8784

8885
if self.answer is None:

graphrag/query/structured_search/drift_search/search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ async def asearch(
219219
# Main loop
220220
epochs = 0
221221
llm_call_offset = 0
222-
while epochs < self.context_builder.config.n:
222+
while epochs < self.context_builder.config.n_depth:
223223
actions = self.query_state.rank_incomplete_actions()
224224
if len(actions) == 0:
225225
log.info("No more actions to take. Exiting DRIFT loop.")

tests/fixtures/min-csv/settings.yml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,9 @@ reporting:
5050
base_dir: "logs"
5151

5252
snapshots:
53-
embeddings: True
53+
embeddings: True
54+
55+
drift_search:
56+
n_depth: 1
57+
k_follow_ups: 3
58+
primer_folds: 3

tests/fixtures/text/settings.yml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,9 @@ reporting:
5555
base_dir: "logs"
5656

5757
snapshots:
58-
embeddings: True
58+
embeddings: True
59+
60+
drift_search:
61+
n_depth: 1
62+
k_follow_ups: 3
63+
primer_folds: 3

0 commit comments

Comments
 (0)