2020INPUT_ENTITY_SPEC_KEY = "entity_specs"
2121INPUT_CLAIM_DESCRIPTION_KEY = "claim_description"
2222INPUT_RESOLVED_ENTITIES_KEY = "resolved_entities"
23- TUPLE_DELIMITER_KEY = "tuple_delimiter"
2423RECORD_DELIMITER_KEY = "record_delimiter"
2524COMPLETION_DELIMITER_KEY = "completion_delimiter"
26- DEFAULT_TUPLE_DELIMITER = "<|>"
27- DEFAULT_RECORD_DELIMITER = "##"
28- DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
25+ TUPLE_DELIMITER = "<|>"
26+ RECORD_DELIMITER = "##"
27+ COMPLETION_DELIMITER = "<|COMPLETE|>"
2928logger = logging .getLogger (__name__ )
3029
3130
@@ -71,20 +70,13 @@ async def __call__(
7170 ) -> ClaimExtractorResult :
7271 """Call method definition."""
7372 source_doc_map = {}
74-
75- prompt_args = {
76- INPUT_ENTITY_SPEC_KEY : entity_spec ,
77- INPUT_CLAIM_DESCRIPTION_KEY : claim_description ,
78- TUPLE_DELIMITER_KEY : DEFAULT_TUPLE_DELIMITER ,
79- RECORD_DELIMITER_KEY : DEFAULT_RECORD_DELIMITER ,
80- COMPLETION_DELIMITER_KEY : DEFAULT_COMPLETION_DELIMITER ,
81- }
82-
8373 all_claims : list [dict ] = []
8474 for doc_index , text in enumerate (texts ):
8575 document_id = f"d{ doc_index } "
8676 try :
87- claims = await self ._process_document (prompt_args , text )
77+ claims = await self ._process_document (
78+ text , claim_description , entity_spec
79+ )
8880 all_claims += [
8981 self ._clean_claim (c , document_id , resolved_entities ) for c in claims
9082 ]
@@ -117,15 +109,18 @@ def _clean_claim(
117109 claim ["subject_id" ] = subject
118110 return claim
119111
120- async def _process_document (self , prompt_args : dict , doc ) -> list [dict ]:
112+ async def _process_document (
113+ self , text : str , claim_description : str , entity_spec : dict
114+ ) -> list [dict ]:
121115 response = await self ._model .achat (
122116 self ._extraction_prompt .format (** {
123- INPUT_TEXT_KEY : doc ,
124- ** prompt_args ,
117+ INPUT_TEXT_KEY : text ,
118+ INPUT_CLAIM_DESCRIPTION_KEY : claim_description ,
119+ INPUT_ENTITY_SPEC_KEY : entity_spec ,
125120 })
126121 )
127122 results = response .output .content or ""
128- claims = results .strip ().removesuffix (DEFAULT_COMPLETION_DELIMITER )
123+ claims = results .strip ().removesuffix (COMPLETION_DELIMITER )
129124
130125 # if gleanings are specified, enter a loop to extract more claims
131126 # there are two exit criteria: (a) we hit the configured max, (b) the model says there are no more claims
@@ -137,8 +132,8 @@ async def _process_document(self, prompt_args: dict, doc) -> list[dict]:
137132 history = response .history ,
138133 )
139134 extension = response .output .content or ""
140- claims += DEFAULT_RECORD_DELIMITER + extension .strip ().removesuffix (
141- DEFAULT_COMPLETION_DELIMITER
135+ claims += RECORD_DELIMITER + extension .strip ().removesuffix (
136+ COMPLETION_DELIMITER
142137 )
143138
144139 # If this isn't the last loop, check to see if we should continue
@@ -164,18 +159,16 @@ def pull_field(index: int, fields: list[str]) -> str | None:
164159
165160 result : list [dict [str , Any ]] = []
166161 claims_values = (
167- claims .strip ()
168- .removesuffix (DEFAULT_COMPLETION_DELIMITER )
169- .split (DEFAULT_RECORD_DELIMITER )
162+ claims .strip ().removesuffix (COMPLETION_DELIMITER ).split (RECORD_DELIMITER )
170163 )
171164 for claim in claims_values :
172165 claim = claim .strip ().removeprefix ("(" ).removesuffix (")" )
173166
174167 # Ignore the completion delimiter
175- if claim == DEFAULT_COMPLETION_DELIMITER :
168+ if claim == COMPLETION_DELIMITER :
176169 continue
177170
178- claim_fields = claim .split (DEFAULT_TUPLE_DELIMITER )
171+ claim_fields = claim .split (TUPLE_DELIMITER )
179172 result .append ({
180173 "subject_id" : pull_field (0 , claim_fields ),
181174 "object_id" : pull_field (1 , claim_fields ),
0 commit comments