diff --git a/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py b/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py
index dd7d3f7b..5df269e8 100644
--- a/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py
+++ b/lib/idp_common_pkg/idp_common/extraction/agentic_idp.py
@@ -16,14 +16,8 @@
 from pathlib import Path
 from typing import (
     Any,
-    Dict,
-    List,
-    Optional,
-    Tuple,
-    Type,
     TypedDict,
     TypeVar,
-    Union,
 )
 
 import jsonpatch
@@ -95,14 +89,14 @@ class BedrockUsage(TypedDict, total=False):
 class BedrockMessageContent(TypedDict):
     """Content item in a Bedrock message."""
 
-    text: Optional[str]
+    text: str | None
 
 
 class BedrockMessage(TypedDict):
     """Message structure in Bedrock response."""
 
     role: str
-    content: List[BedrockMessageContent]
+    content: list[BedrockMessageContent]
 
 
 class BedrockOutput(TypedDict):
@@ -116,8 +110,8 @@ class BedrockResponse(TypedDict, total=False):
 
     output: BedrockOutput
     usage: BedrockUsage
-    stopReason: Optional[str]
-    metrics: Optional[Dict[str, Any]]
+    stopReason: str | None
+    metrics: dict[str, Any] | None
 
 
 class BedrockInvokeModelResponse(TypedDict):
@@ -132,7 +126,7 @@ class BedrockInvokeModelResponse(TypedDict):
     """
 
     response: BedrockResponse
-    metering: Dict[str, BedrockUsage]  # Key format: "{context}/bedrock/{model_id}"
+    metering: dict[str, BedrockUsage]  # Key format: "{context}/bedrock/{model_id}"
 
 
 # Data Models for structured extraction
@@ -146,7 +140,7 @@ class BoolResponseModel(BaseModel):
 class JsonPatchModel(BaseModel):
     """Model for JSON patch operations."""
 
-    patches: List[Dict[str, Any]] = Field(
+    patches: list[dict[str, Any]] = Field(
         ...,
         description="JSON patch operations to apply. Each patch should follow RFC 6902 format with 'op', 'path', and optionally 'value' keys.",
     )
@@ -157,9 +151,9 @@ class JsonPatchModel(BaseModel):
 
 
 def apply_patches_to_data(
-    existing_data: Dict[str, Any],
-    patches: List[Dict[str, Any]],
-) -> Dict[str, Any]:
+    existing_data: dict[str, Any],
+    patches: list[dict[str, Any]],
+) -> dict[str, Any]:
     """
     Apply JSON patches to existing data and validate the result.
@@ -179,7 +173,7 @@ def apply_patches_to_data(
     return patched_dict
 
 
-def create_dynamic_extraction_tool_and_patch_tool(model_class: Type[TargetModel]):
+def create_dynamic_extraction_tool_and_patch_tool(model_class: type[TargetModel]):
     """
     Create a dynamic tool function that extracts data according to a Pydantic model.
@@ -215,9 +209,9 @@ def extraction_tool(
 
     @tool
     def apply_json_patches(
-        patches: List[Dict[str, Any]],
+        patches: list[dict[str, Any]],
         agent: Agent,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """
         Apply JSON patches to fix or update the extracted data.
@@ -225,7 +219,7 @@ def apply_json_patches(
             patches: List of JSON patch operations (RFC 6902 format)
             reasoning: Explanation of what the patches fix
         """
-        current_data: Dict | None = agent.state.get("current_extraction")
+        current_data: dict[str, Any] | None = agent.state.get("current_extraction")
         logger.info("Patch tool called", extra={"patch_request": patches})
 
         if not current_data:
@@ -268,6 +262,8 @@ def write_buffer_date(data: dict[str, Any], agent: Agent) -> str:
     """
     Use this tool when the extraction is too large to do in a single step, this is a buffer where you can save intermediate data that wouldn't pass validation yet.
+    IMPORTANT: The data you save here must eventually match the extraction schema structure. Plan your buffer data structure to align with the required schema fields and types.
+    Review the extraction schema before using this tool to ensure compatibility.
""" agent.state.set("intermediate_extraction", data) logger.info("Saving intermediate data", extra={"intermediate_extraction": data}) @@ -276,11 +272,114 @@ def write_buffer_date(data: dict[str, Any], agent: Agent) -> str: @tool def view_buffer_data(agent: Agent) -> str: - """View the intermediate buffer data with this tool, this data is not a validated extraction, but intermediate state for you to work with.""" + """View the intermediate buffer data with this tool, this data is not a validated extraction, but intermediate state for you to work with. + + WARNING: This returns the ENTIRE buffer which can be very large. For large extractions, prefer using view_buffer_data_section or view_buffer_data_stats.""" return agent.state.get("intermediate_extraction") +@tool +def view_buffer_data_section(path: str, agent: Agent) -> Any: + """View a specific section of the intermediate buffer data by JSON Pointer path (RFC 6901). Token-efficient way to inspect parts of large data. + + Args: + path: JSON Pointer path (same format as JSON Patch). Must start with "/" for nested paths, or use "" for root. + + Examples: + - path="/table_rows" -> returns the entire table_rows array + - path="/table_rows/0" -> returns first item in table_rows array + - path="/table_rows/0/fund_name" -> returns fund_name field of first row + - path="/document_name" -> returns just the document_name field + - path="" -> returns entire buffer (same as view_buffer_data) + + Note: Uses same path format as JSON Patch operations for consistency. + """ + data = agent.state.get("intermediate_extraction") + + if not data: + return {"error": "No intermediate data in buffer"} + + # Handle empty path (root) + if path == "": + return data + + # Remove leading slash and parse path + if not path.startswith("/"): + return { + "error": "Path must start with '/' (e.g., '/table_rows/0') or be empty string for root" + } + + parts = path[1:].split("/") if path != "/" else [] + current = data + + try: + for part in parts: + if not part: # Skip empty parts from double slashes + continue + + if isinstance(current, list): + # Try to convert to int for list indexing + current = current[int(part)] + elif isinstance(current, dict): + current = current[part] + else: + return { + "error": f"Cannot navigate path '{path}' - reached non-dict/list at '{part}'" + } + + return current + except (KeyError, IndexError, ValueError) as e: + return {"error": f"Path '{path}' not found: {str(e)}"} + + +@tool +def get_extraction_schema_reminder(agent: Agent) -> str: + """Use this tool during long extractions to review the expected data schema and field requirements. Helps ensure you stay aligned with the required structure. + + RECOMMENDED: Call this tool every 100-200 rows in large extractions to verify you're maintaining the correct structure.""" + + schema = agent.state.get("extraction_schema_json") + if not schema: + return "Schema not available" + + return f"Remember: Your extraction must match this schema:\n\n{schema}\n\nEnsure all field names, types, and required fields are correct." + + +@tool +def view_buffer_data_stats(agent: Agent) -> dict[str, Any]: + """View overview statistics of intermediate buffer data. Token-efficient alternative to viewing full data. Use this for progress checks during large extractions. 
+
+    TIP: For large extractions (500+ items), consider calling get_extraction_schema_reminder every 100-200 items to stay aligned with requirements."""
+
+    data = agent.state.get("intermediate_extraction")
+
+    if not data:
+        return {"status": "empty", "message": "No intermediate data in buffer"}
+
+    # Build statistics
+    stats = {}
+
+    if isinstance(data, dict):
+        stats["keys"] = list(data.keys())
+        stats["field_count"] = len(data)
+        # Sample nested structure for arrays
+        for key, value in data.items():
+            if isinstance(value, list):
+                stats[f"{key}_length"] = len(value)
+                if value:
+                    stats[f"{key}_sample_type"] = type(value[0]).__name__
+    # Add estimated token count (rough approximation)
+    data_str = str(data)
+
+    return {
+        "status": "contains_data",
+        "structure": stats,
+        "estimated_size_chars": len(data_str),
+        "tip": "Use patch_buffer_data to update specific fields; use make_buffer_data_final_extraction to complete or get detailed guidance on missing requirements.",
+    }
+
+
 @tool
 def patch_buffer_data(patches: list[dict[str, Any]], agent: Agent) -> str:
     """Update the intermediate_extraction data inside the buffer, this is not validated yet
@@ -293,28 +392,104 @@ def patch_buffer_data(patches: list[dict[str, Any]], agent: Agent) -> str:
     """
+    logger.info("Buffer Patch tool called", extra={"patch_request": patches})
     patched_data = apply_patches_to_data(
         existing_data=agent.state.get("intermediate_extraction"), patches=patches
     )
     agent.state.set("intermediate_extraction", patched_data)
+    logger.info(f"Current length of buffer data {len(patched_data)}")
+
     return f"Successfully patched {str(patched_data)[100:]}...."
 
 
+@tool
+def create_todo_list(todos: list[str], agent: Agent) -> str:
+    """Create a new todo list to track your extraction tasks. Use this to plan your work, especially for large documents.
+
+    Args:
+        todos: List of task descriptions to track (e.g., ["Extract rows 1-100", "Extract rows 101-200"])
+
+    Example:
+        create_todo_list(["Extract first 100 rows", "Extract rows 101-200", "Extract rows 201-300", "Validate and finalize"], agent)
+    """
+    todo_list = [{"task": task, "completed": False} for task in todos]
+    agent.state.set("todo_list", todo_list)
+    logger.info("Created todo list", extra={"todo_count": len(todo_list)})
+    return f"Created todo list with {len(todo_list)} tasks:\n" + "\n".join(
+        f"{i + 1}. [ ] {item['task']}" for i, item in enumerate(todo_list)
+    )
+
+
+@tool
+def update_todo(task_index: int, completed: bool, agent: Agent) -> str:
+    """Mark a todo item as completed or not completed.
+
+    Args:
+        task_index: Index of the task to update (1-based, matching the list display)
+        completed: True to mark as completed, False to mark as incomplete
+
+    Example:
+        update_todo(1, True, agent)  # Mark first task as completed
+    """
+    todo_list: list[dict[str, Any]] | None = agent.state.get("todo_list")
+
+    if not todo_list:
+        return "No todo list found. Create one first using create_todo_list."
+
+    # Convert to 0-based index
+    index = task_index - 1
+
+    if index < 0 or index >= len(todo_list):
+        return f"Invalid task index {task_index}. Valid range: 1-{len(todo_list)}"
+
+    todo_list[index]["completed"] = completed
+    agent.state.set("todo_list", todo_list)
+
+    status = "completed" if completed else "incomplete"
+    logger.info(
+        f"Updated todo {task_index}",
+        extra={"task": todo_list[index]["task"], "completed": completed},
+    )
+    return f"Task {task_index} marked as {status}: {todo_list[index]['task']}"
+
+
+@tool
+def view_todo_list(agent: Agent) -> str:
+    """View your current todo list with completion status."""
+    todo_list: list[dict[str, Any]] | None = agent.state.get("todo_list")
+
+    if not todo_list:
+        return "No todo list found. Create one using create_todo_list to track your extraction tasks."
+
+    completed_count = sum(1 for item in todo_list if item["completed"])
+    total_count = len(todo_list)
+
+    result = f"Todo List ({completed_count}/{total_count} completed):\n"
+    result += "\n".join(
+        f"{i + 1}. [{'✓' if item['completed'] else ' '}] {item['task']}"
+        for i, item in enumerate(todo_list)
+    )
+
+    return result
+
+
 SYSTEM_PROMPT = """
 You are a useful assistant that helps turn unstructured data into structured data using the provided tools.
 
 EXTRACTION APPROACH:
-1. Use the extraction_tool for fresh data extraction
+1. Use the extraction_tool for fresh data extraction - this validates data against the schema immediately
 2. When updating existing data or fixing validation errors, use JSON patch operations via the apply_json_patches tool
 3. JSON patches allow precise, targeted updates without losing correct data
-4. If the document is large and the extraction request can't be done in one go, create a valid extraction object and interate with jsonpatch until you completed the entire extraction!
-5. Use intermediate data buffer if you can't extract a valid data object in a single step.
+4. If the document is large and the extraction request can't be done in one go, create a valid extraction object and iterate with JSON patches until you have completed the entire extraction!
+5. Use the intermediate data buffer if you can't extract a valid data object in a single step.
 
 IMPORTANT: YOU MUST perform a batched extraction if there are more than 50 fields to extract.
-batched extraction is when you create a viable format with extraction tool and then you expand it with jsonpatches. You can pass up to 50 records in a single patch operation.
+Batched extraction is when you create a viable format with the extraction tool and then expand it with JSON patches.
+You can pass up to 100 records in a single patch operation.
+When using batched extraction, plan it out and make a todo list with target batch sizes based on the document and other key tasks.
 
 NEVER STOP early on large documents, always extract all the data.
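# An illustrative sketch, not part of the diff above: the batched-extraction flow the
# system prompt describes, growing a valid extraction by appending records with
# RFC 6902 "add" patches via the jsonpatch library this module imports. The
# "table_rows"/"fund_name" field names are borrowed from the examples in the tool
# docstrings and are assumptions for illustration, not a fixed contract.
import jsonpatch

current = {"document_name": "report.pdf", "document_text": "...", "table_rows": []}

# One batch: the "/-" pointer appends each record to the end of the array.
batch = [
    {"op": "add", "path": "/table_rows/-", "value": {"fund_name": "Fund A"}},
    {"op": "add", "path": "/table_rows/-", "value": {"fund_name": "Fund B"}},
]

current = jsonpatch.apply_patch(current, batch)  # returns a new, patched dict
# Repeating this per batch (up to 100 records at a time, per the prompt) is what
# the prompt calls "batched extraction".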
@@ -351,9 +526,9 @@ def patch_buffer_data(patches: list[dict[str, Any]], agent: Agent) -> str: async def structured_output_async( model_id: str, - data_format: Type[TargetModel], - prompt: Union[str, Message, Image.Image], - existing_data: Optional[BaseModel] = None, + data_format: type[TargetModel], + prompt: str | Message | Image.Image, + existing_data: BaseModel | None = None, system_prompt: str | None = None, custom_instruction: str | None = None, review_agent: bool = False, @@ -361,8 +536,8 @@ async def structured_output_async( max_retries: int = 7, connect_timeout: float = 10.0, read_timeout: float = 300.0, - max_tokens: Optional[int] = None, -) -> Tuple[TargetModel, BedrockInvokeModelResponse]: + max_tokens: int | None = None, +) -> tuple[TargetModel, BedrockInvokeModelResponse]: """ Extract structured data using Strands agents with tool-based validation. @@ -450,7 +625,13 @@ async def structured_output_async( view_existing_extraction, patch_buffer_data, view_buffer_data, + view_buffer_data_section, + view_buffer_data_stats, write_buffer_date, + get_extraction_schema_reminder, + create_todo_list, + update_todo, + view_todo_list, ] # Create agent with system prompt and tools @@ -553,9 +734,11 @@ async def structured_output_async( else: logger.debug("Caching not supported for model", extra={"model_id": model_id}) + final_system_prompt = SYSTEM_PROMPT + if custom_instruction: - final_system_prompt = f"{system_prompt}\n\nCustom Instructions for this specific task: {custom_instruction}" - logger.debug("Running extraction", extra={"system_prompt": final_system_prompt}) + final_system_prompt = f"{final_system_prompt}\n\nCustom Instructions for this specific task: {custom_instruction}" + agent = Agent( model=BedrockModel(**model_config), # pyright: ignore[reportArgumentType] tools=tools, @@ -564,9 +747,10 @@ async def structured_output_async( "current_extraction": None, "images": {}, "existing_data": existing_data.model_dump() if existing_data else None, + "extraction_schema_json": schema_json, # Store for schema reminder tool }, conversation_manager=SummarizingConversationManager( - summary_ratio=0.8, preserve_recent_messages=3 + summary_ratio=0.8, preserve_recent_messages=2 ), ) @@ -816,9 +1000,9 @@ async def structured_output_async( def structured_output( model_id: str, - data_format: Type[BaseModel], - prompt: Union[str, Message, Image.Image], - existing_data: Optional[BaseModel] = None, + data_format: type[BaseModel], + prompt: str | Message | Image.Image, + existing_data: BaseModel | None = None, system_prompt: str | None = None, custom_instruction: str | None = None, review_agent: bool = False, @@ -826,7 +1010,7 @@ def structured_output( max_retries: int = 7, connect_timeout: float = 10.0, read_timeout: float = 300.0, -) -> Tuple[BaseModel, BedrockInvokeModelResponse]: +) -> tuple[BaseModel, BedrockInvokeModelResponse]: """ Synchronous version of structured_output_async. @@ -970,7 +1154,7 @@ class DocumentRow(BaseModel): class DocumentFormat(BaseModel): document_name: str document_text: str - table_rows: list[DocumentRow] = Field(gt=500) + table_rows: list[DocumentRow] = Field(min_length=500) with open(file_path, "rb") as f: data = f.read()