Skip to content

Commit ff6b628

Browse files
author
Telsho
committed
feat: pass sqlmodel objects as example
1 parent ac4dbd1 commit ff6b628

File tree

4 files changed

+123
-0
lines changed

4 files changed

+123
-0
lines changed

docs/getting_started.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,19 @@ Now, we can feed unstructured text to the orchestrator. The `synthesize_and_save
115115
if __name__ == "__main__":
116116
asyncio.run(main())
117117
118+
.. tip::
119+
120+
**Improving Results with Examples**: If you have existing data (e.g., a "Product" object fetched from your database), you can pass it to the orchestrator to help the LLM understand the output format. Use the `extraction_example_object` parameter in `synthesize_and_save`:
121+
122+
.. code-block:: python
123+
124+
# existing_product is a SQLModel instance
125+
await orchestrator.synthesize_and_save(
126+
input_strings=[text],
127+
db_session=session,
128+
extraction_example_object=existing_product
129+
)
130+
118131
Step 5: See the Results
119132
-----------------------
120133

docs/workflow_orchestrator.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ Once the orchestrator is configured, you can start processing documents using on
163163
* ``input_strings`` (``List[str]``): A list of strings, where each string is a document to be processed.
164164
* ``db_session_for_hydration`` (``Optional[Session]``): An optional SQLAlchemy session. If provided, the hydrator will use it to resolve relationships. If not, a temporary in-memory session is created.
165165
* ``extraction_example_json`` (``str``, optional): A JSON string that provides a few-shot example to the LLM, guiding it to produce a better-structured output. If not provided, the orchestrator will attempt to auto-generate one.
166+
* ``extraction_example_object`` (``Optional[Union[SQLModel, List[SQLModel]]]``, optional): An existing SQLModel object or a list of them to be used as the few-shot example. This is an alternative to providing the example as a raw JSON string.
166167
* ``custom_extraction_process`` (``str``, optional): Custom, step-by-step instructions for the LLM on how to perform the extraction.
167168
* ``custom_extraction_guidelines`` (``str``, optional): A list of rules or guidelines for the LLM to follow.
168169
* ``custom_final_checklist`` (``str``, optional): A final checklist for the LLM to review before finalizing its output.

src/extrai/core/workflow_orchestrator.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ async def synthesize(
256256
input_strings: List[str],
257257
db_session_for_hydration: Optional[Session],
258258
extraction_example_json: str = "",
259+
extraction_example_object: Optional[Union[SQLModel, List[SQLModel]]] = None,
259260
custom_extraction_process: str = "",
260261
custom_extraction_guidelines: str = "",
261262
custom_final_checklist: str = "",
@@ -267,6 +268,7 @@ async def synthesize(
267268
input_strings: A list of input strings for data extraction.
268269
db_session_for_hydration: SQLAlchemy session for the hydrator.
269270
extraction_example_json: Optional JSON string for few-shot prompting.
271+
extraction_example_object: Optional SQLModel object or list of objects to use as example.
270272
custom_extraction_process: Optional custom instructions for LLM extraction process.
271273
custom_extraction_guidelines: Optional custom guidelines for LLM extraction.
272274
custom_final_checklist: Optional custom final checklist for LLM.
@@ -281,6 +283,25 @@ async def synthesize(
281283
if not input_strings:
282284
raise ValueError("Input strings list cannot be empty.")
283285

286+
if extraction_example_object and not extraction_example_json:
287+
objects_to_process = (
288+
extraction_example_object
289+
if isinstance(extraction_example_object, list)
290+
else [extraction_example_object]
291+
)
292+
processed_objects = []
293+
for obj in objects_to_process:
294+
if isinstance(obj, SQLModel):
295+
processed_objects.append(obj.model_dump(mode="json"))
296+
else:
297+
self.logger.warning(
298+
f"Skipping unsupported object type in extraction_example_object: {type(obj)}"
299+
)
300+
if processed_objects:
301+
extraction_example_json = json.dumps(
302+
processed_objects, default=str, indent=2
303+
)
304+
284305
self.logger.info(
285306
f"Starting synthesis for {self.root_sqlmodel_class.__name__}..."
286307
)
@@ -524,6 +545,7 @@ async def synthesize_and_save(
524545
input_strings: List[str],
525546
db_session: Session,
526547
extraction_example_json: str = "",
548+
extraction_example_object: Optional[Union[SQLModel, List[SQLModel]]] = None,
527549
custom_extraction_process: str = "",
528550
custom_extraction_guidelines: str = "",
529551
custom_final_checklist: str = "",
@@ -536,6 +558,7 @@ async def synthesize_and_save(
536558
input_strings=input_strings,
537559
db_session_for_hydration=db_session,
538560
extraction_example_json=extraction_example_json,
561+
extraction_example_object=extraction_example_object,
539562
custom_extraction_process=custom_extraction_process,
540563
custom_extraction_guidelines=custom_extraction_guidelines,
541564
custom_final_checklist=custom_final_checklist,

tests/core/workflow_orchestrator/test_workflow_orchestrator_execution.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -772,5 +772,91 @@ async def test_synthesize_and_save_persist_raises_generic_exception(
772772
mock_rollback.assert_called_once()
773773

774774

775+
async def test_synthesize_with_extraction_example_parameters(self):
776+
"""Test different scenarios for extraction_example_object and extraction_example_json."""
777+
dept1 = DepartmentModel(name="Dept 1")
778+
dept2 = DepartmentModel(name="Dept 2")
779+
780+
test_cases = [
781+
{
782+
"name": "single_object",
783+
"object": dept1,
784+
"json_arg": "",
785+
"expected_json_in_prepare": lambda j: len(json.loads(j)) == 1
786+
and json.loads(j)[0]["name"] == "Dept 1",
787+
"expect_warning": False,
788+
},
789+
{
790+
"name": "list_of_objects",
791+
"object": [dept1, dept2],
792+
"json_arg": "",
793+
"expected_json_in_prepare": lambda j: len(json.loads(j)) == 2
794+
and json.loads(j)[1]["name"] == "Dept 2",
795+
"expect_warning": False,
796+
},
797+
{
798+
"name": "priority_json_over_object",
799+
"object": dept1,
800+
"json_arg": '[{"name": "Override"}]',
801+
"expected_json_in_prepare": lambda j: j == '[{"name": "Override"}]',
802+
"expect_warning": False,
803+
},
804+
{
805+
"name": "unsupported_type",
806+
"object": ["unsupported"],
807+
"json_arg": "",
808+
"expected_json_in_prepare": lambda j: j == "",
809+
"expect_warning": True,
810+
},
811+
]
812+
813+
for case in test_cases:
814+
with self.subTest(case=case["name"]):
815+
with (
816+
mock.patch.object(
817+
self.orchestrator,
818+
"_prepare_extraction_example",
819+
new_callable=AsyncMock,
820+
) as mock_prepare,
821+
mock.patch.object(
822+
self.orchestrator,
823+
"_execute_standard_extraction",
824+
AsyncMock(return_value=[]),
825+
),
826+
mock.patch.object(
827+
self.orchestrator,
828+
"_hydrate_results",
829+
mock.MagicMock(return_value=[]),
830+
),
831+
mock.patch.object(
832+
self.orchestrator.logger, "warning"
833+
) as mock_logger_warning,
834+
):
835+
mock_prepare.return_value = "{}"
836+
837+
await self.orchestrator.synthesize(
838+
input_strings=["test"],
839+
db_session_for_hydration=self.db_session,
840+
extraction_example_object=case["object"],
841+
extraction_example_json=case["json_arg"],
842+
)
843+
844+
# Verify warning
845+
if case["expect_warning"]:
846+
mock_logger_warning.assert_called()
847+
args, _ = mock_logger_warning.call_args
848+
self.assertIn("Skipping unsupported object type", args[0])
849+
else:
850+
mock_logger_warning.assert_not_called()
851+
852+
# Verify _prepare_extraction_example argument
853+
args, _ = mock_prepare.call_args
854+
actual_json = args[0]
855+
self.assertTrue(
856+
case["expected_json_in_prepare"](actual_json),
857+
f"Failed for case {case['name']}: actual json {actual_json}",
858+
)
859+
860+
775861
if __name__ == "__main__":
776862
unittest.main()

0 commit comments

Comments
 (0)