Skip to content

Commit cc5bb9e

Browse files
committed
Remove configurable prompt delimiters
1 parent 5e2013f commit cc5bb9e

File tree

6 files changed

+220
-238
lines changed

6 files changed

+220
-238
lines changed

graphrag/index/operations/extract_covariates/claim_extractor.py

Lines changed: 18 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -20,12 +20,11 @@
2020
INPUT_ENTITY_SPEC_KEY = "entity_specs"
2121
INPUT_CLAIM_DESCRIPTION_KEY = "claim_description"
2222
INPUT_RESOLVED_ENTITIES_KEY = "resolved_entities"
23-
TUPLE_DELIMITER_KEY = "tuple_delimiter"
2423
RECORD_DELIMITER_KEY = "record_delimiter"
2524
COMPLETION_DELIMITER_KEY = "completion_delimiter"
26-
DEFAULT_TUPLE_DELIMITER = "<|>"
27-
DEFAULT_RECORD_DELIMITER = "##"
28-
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
25+
TUPLE_DELIMITER = "<|>"
26+
RECORD_DELIMITER = "##"
27+
COMPLETION_DELIMITER = "<|COMPLETE|>"
2928
logger = logging.getLogger(__name__)
3029

3130

@@ -71,20 +70,13 @@ async def __call__(
7170
) -> ClaimExtractorResult:
7271
"""Call method definition."""
7372
source_doc_map = {}
74-
75-
prompt_args = {
76-
INPUT_ENTITY_SPEC_KEY: entity_spec,
77-
INPUT_CLAIM_DESCRIPTION_KEY: claim_description,
78-
TUPLE_DELIMITER_KEY: DEFAULT_TUPLE_DELIMITER,
79-
RECORD_DELIMITER_KEY: DEFAULT_RECORD_DELIMITER,
80-
COMPLETION_DELIMITER_KEY: DEFAULT_COMPLETION_DELIMITER,
81-
}
82-
8373
all_claims: list[dict] = []
8474
for doc_index, text in enumerate(texts):
8575
document_id = f"d{doc_index}"
8676
try:
87-
claims = await self._process_document(prompt_args, text)
77+
claims = await self._process_document(
78+
text, claim_description, entity_spec
79+
)
8880
all_claims += [
8981
self._clean_claim(c, document_id, resolved_entities) for c in claims
9082
]
@@ -117,15 +109,18 @@ def _clean_claim(
117109
claim["subject_id"] = subject
118110
return claim
119111

120-
async def _process_document(self, prompt_args: dict, doc) -> list[dict]:
112+
async def _process_document(
113+
self, text: str, claim_description: str, entity_spec: dict
114+
) -> list[dict]:
121115
response = await self._model.achat(
122116
self._extraction_prompt.format(**{
123-
INPUT_TEXT_KEY: doc,
124-
**prompt_args,
117+
INPUT_TEXT_KEY: text,
118+
INPUT_CLAIM_DESCRIPTION_KEY: claim_description,
119+
INPUT_ENTITY_SPEC_KEY: entity_spec,
125120
})
126121
)
127122
results = response.output.content or ""
128-
claims = results.strip().removesuffix(DEFAULT_COMPLETION_DELIMITER)
123+
claims = results.strip().removesuffix(COMPLETION_DELIMITER)
129124

130125
# if gleanings are specified, enter a loop to extract more claims
131126
# there are two exit criteria: (a) we hit the configured max, (b) the model says there are no more claims
@@ -137,8 +132,8 @@ async def _process_document(self, prompt_args: dict, doc) -> list[dict]:
137132
history=response.history,
138133
)
139134
extension = response.output.content or ""
140-
claims += DEFAULT_RECORD_DELIMITER + extension.strip().removesuffix(
141-
DEFAULT_COMPLETION_DELIMITER
135+
claims += RECORD_DELIMITER + extension.strip().removesuffix(
136+
COMPLETION_DELIMITER
142137
)
143138

144139
# If this isn't the last loop, check to see if we should continue
@@ -164,18 +159,16 @@ def pull_field(index: int, fields: list[str]) -> str | None:
164159

165160
result: list[dict[str, Any]] = []
166161
claims_values = (
167-
claims.strip()
168-
.removesuffix(DEFAULT_COMPLETION_DELIMITER)
169-
.split(DEFAULT_RECORD_DELIMITER)
162+
claims.strip().removesuffix(COMPLETION_DELIMITER).split(RECORD_DELIMITER)
170163
)
171164
for claim in claims_values:
172165
claim = claim.strip().removeprefix("(").removesuffix(")")
173166

174167
# Ignore the completion delimiter
175-
if claim == DEFAULT_COMPLETION_DELIMITER:
168+
if claim == COMPLETION_DELIMITER:
176169
continue
177170

178-
claim_fields = claim.split(DEFAULT_TUPLE_DELIMITER)
171+
claim_fields = claim.split(TUPLE_DELIMITER)
179172
result.append({
180173
"subject_id": pull_field(0, claim_fields),
181174
"object_id": pull_field(1, claim_fields),

graphrag/index/operations/extract_graph/graph_extractor.py

Lines changed: 8 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -22,12 +22,11 @@
2222

2323
INPUT_TEXT_KEY = "input_text"
2424
RECORD_DELIMITER_KEY = "record_delimiter"
25-
TUPLE_DELIMITER_KEY = "tuple_delimiter"
2625
COMPLETION_DELIMITER_KEY = "completion_delimiter"
2726
ENTITY_TYPES_KEY = "entity_types"
28-
DEFAULT_TUPLE_DELIMITER = "<|>"
29-
DEFAULT_RECORD_DELIMITER = "##"
30-
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
27+
TUPLE_DELIMITER = "<|>"
28+
RECORD_DELIMITER = "##"
29+
COMPLETION_DELIMITER = "<|COMPLETE|>"
3130
DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"]
3231

3332
logger = logging.getLogger(__name__)
@@ -72,18 +71,10 @@ async def __call__(
7271
all_records: dict[int, str] = {}
7372
source_doc_map: dict[int, str] = {}
7473

75-
# Wire defaults into the prompt variables
76-
prompt_variables = {
77-
ENTITY_TYPES_KEY: ",".join(entity_types),
78-
TUPLE_DELIMITER_KEY: DEFAULT_TUPLE_DELIMITER,
79-
RECORD_DELIMITER_KEY: DEFAULT_RECORD_DELIMITER,
80-
COMPLETION_DELIMITER_KEY: DEFAULT_COMPLETION_DELIMITER,
81-
}
82-
8374
for doc_index, text in enumerate(texts):
8475
try:
8576
# Invoke the entity extraction
86-
result = await self._process_document(text, prompt_variables)
77+
result = await self._process_document(text, entity_types)
8778
source_doc_map[doc_index] = text
8879
all_records[doc_index] = result
8980
except Exception as e:
@@ -99,22 +90,20 @@ async def __call__(
9990

10091
output = await self._process_results(
10192
all_records,
102-
DEFAULT_TUPLE_DELIMITER,
103-
DEFAULT_RECORD_DELIMITER,
93+
TUPLE_DELIMITER,
94+
RECORD_DELIMITER,
10495
)
10596

10697
return GraphExtractionResult(
10798
output=output,
10899
source_docs=source_doc_map,
109100
)
110101

111-
async def _process_document(
112-
self, text: str, prompt_variables: dict[str, str]
113-
) -> str:
102+
async def _process_document(self, text: str, entity_types: list[str]) -> str:
114103
response = await self._model.achat(
115104
self._extraction_prompt.format(**{
116-
**prompt_variables,
117105
INPUT_TEXT_KEY: text,
106+
ENTITY_TYPES_KEY: ",".join(entity_types),
118107
}),
119108
)
120109
results = response.output.content or ""

0 commit comments

Comments
 (0)