Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,10 @@ OPENROUTER_API_KEY=your-openrouter-api-key-here
ANNOTATION_MODEL=mistralai/mistral-small-3.2-24b-instruct
ANNOTATION_PROVIDER=mistral

# Evaluation/Assessment Model (consistent quality checks: Qwen3-235B via DeepInfra)
# Evaluation/Assessment Model (fast quality checks: GPT-OSS-120B via Groq)
# Used for evaluation, assessment, and feedback agents
# Leave EVALUATION_PROVIDER empty to let OpenRouter auto-route
EVALUATION_MODEL=qwen/qwen3-235b-a22b-2507
EVALUATION_PROVIDER=deepinfra/fp8
EVALUATION_MODEL=openai/gpt-oss-120b
EVALUATION_PROVIDER=groq

# Vision Model (image description: Qwen3-VL via DeepInfra)
VISION_MODEL=qwen/qwen3-vl-30b-a3b-instruct
Expand Down Expand Up @@ -142,8 +141,8 @@ API_WORKERS=4
# ============================================================================
# Workflow Configuration
# ============================================================================
MAX_VALIDATION_ATTEMPTS=5
MAX_TOTAL_ITERATIONS=10
MAX_VALIDATION_ATTEMPTS=3
MAX_TOTAL_ITERATIONS=4

# ============================================================================
# Logging
Expand Down
9 changes: 8 additions & 1 deletion src/agents/assessment_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
elements or dimensions in the HED annotation.
"""

import logging
from pathlib import Path

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage, SystemMessage

from src.agents.state import HedAnnotationState

logger = logging.getLogger(__name__)


class AssessmentAgent:
"""Agent that performs final assessment of HED annotations.
Expand Down Expand Up @@ -104,7 +107,11 @@ async def assess(self, state: HedAnnotationState) -> dict:
HumanMessage(content=user_prompt),
]

response = await self.llm.ainvoke(messages)
try:
response = await self.llm.ainvoke(messages)
except Exception as e:
logger.error("Assessment LLM invocation failed: %s", e, exc_info=True)
raise
content = response.content
feedback = content.strip() if isinstance(content, str) else str(content)

Expand Down
33 changes: 18 additions & 15 deletions src/agents/evaluation_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
the original natural language event description.
"""

import logging
import re
from pathlib import Path

from langchain_core.language_models import BaseChatModel
Expand All @@ -12,6 +14,8 @@
from src.agents.state import HedAnnotationState
from src.utils.json_schema_loader import HedJsonSchemaLoader, load_latest_schema

logger = logging.getLogger(__name__)


class EvaluationAgent:
"""Agent that evaluates the faithfulness of HED annotations.
Expand Down Expand Up @@ -163,7 +167,11 @@ async def evaluate(self, state: HedAnnotationState) -> dict:
HumanMessage(content=user_prompt),
]

response = await self.llm.ainvoke(messages)
try:
response = await self.llm.ainvoke(messages)
except Exception as e:
logger.error("Evaluation LLM invocation failed: %s", e, exc_info=True)
raise
content = response.content
feedback = content.strip() if isinstance(content, str) else str(content)

Expand All @@ -186,8 +194,6 @@ def _parse_decision(self, feedback: str) -> bool:
Returns:
True if annotation should be accepted, False if needs refinement
"""
import re

feedback_lower = feedback.lower()

# Check for explicit DECISION line
Expand All @@ -201,19 +207,16 @@ def _parse_decision(self, feedback: str) -> bool:
result = faithful_match.group(1)
return result in ["yes", "partial"] # Accept partial as good enough!

# Fallback: look for positive indicators
positive_indicators = ["accept", "good", "sufficient", "adequate", "captures well"]
negative_indicators = ["refine", "missing", "incorrect", "inaccurate", "lacks"]

positive_score = sum(1 for indicator in positive_indicators if indicator in feedback_lower)
negative_score = sum(1 for indicator in negative_indicators if indicator in feedback_lower)
# Fallback: look for explicit refine indicators only
refine_indicators = ["refine", "incorrect", "inaccurate", "wrong"]
if any(indicator in feedback_lower for indicator in refine_indicators):
return False

# If more positive than negative, accept
if positive_score > negative_score:
return True

# Default to refine if ambiguous (conservative)
return False
# Default to accept if ambiguous -- avoid unnecessary refinement loops
logger.debug(
"Evaluation parsing: no explicit DECISION/FAITHFUL/refine indicator found; defaulting to ACCEPT"
)
return True

def _check_tags_and_suggest(self, annotation: str) -> str:
"""Check annotation for invalid tags and suggest alternatives.
Expand Down
10 changes: 9 additions & 1 deletion src/agents/feedback_summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@
into concise, actionable points for the annotation agent.
"""

import logging

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage, SystemMessage

from src.agents.state import HedAnnotationState

logger = logging.getLogger(__name__)


class FeedbackSummarizer:
"""Agent that summarizes validation errors and feedback.
Expand Down Expand Up @@ -112,7 +116,11 @@ async def summarize(self, state: HedAnnotationState) -> dict:
HumanMessage(content=user_prompt),
]

response = await self.llm.ainvoke(messages)
try:
response = await self.llm.ainvoke(messages)
except Exception as e:
logger.error("Feedback summarization LLM invocation failed: %s", e, exc_info=True)
raise
content = response.content
summarized_feedback = content.strip() if isinstance(content, str) else str(content)

Expand Down
11 changes: 7 additions & 4 deletions src/agents/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ class HedAnnotationState(TypedDict):
def create_initial_state(
input_description: str,
schema_version: str = "8.4.0",
max_validation_attempts: int = 5,
max_total_iterations: int = 10,
max_validation_attempts: int = 3,
max_total_iterations: int | None = None,
run_assessment: bool = False,
extracted_keywords: list[str] | None = None,
semantic_hints: list[dict] | None = None,
Expand All @@ -93,8 +93,8 @@ def create_initial_state(
Args:
input_description: Natural language event description to annotate
schema_version: HED schema version to use (default: "8.4.0")
max_validation_attempts: Maximum validation retry attempts (default: 5)
max_total_iterations: Maximum total iterations to prevent infinite loops (default: 10)
max_validation_attempts: Maximum validation retry attempts (default: 3)
max_total_iterations: Maximum total iterations (default: max_validation_attempts + 1)
run_assessment: Whether to run final assessment (default: False)
extracted_keywords: Pre-extracted keywords from description (optional)
semantic_hints: Pre-computed semantic search hints (optional)
Expand All @@ -103,6 +103,9 @@ def create_initial_state(
Returns:
Initial HedAnnotationState
"""
if max_total_iterations is None:
max_total_iterations = max_validation_attempts + 1

return HedAnnotationState(
messages=[],
input_description=input_description,
Expand Down
96 changes: 52 additions & 44 deletions src/agents/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

import logging
import time
from pathlib import Path

from langchain_core.language_models import BaseChatModel
Expand Down Expand Up @@ -104,7 +105,7 @@ def _build_graph(self) -> StateGraph:
Compiled StateGraph
"""
# Create graph
workflow = StateGraph(HedAnnotationState)
workflow = StateGraph(HedAnnotationState) # type: ignore[arg-type] # LangGraph typing limitation

# Add nodes
if self.enable_semantic_search:
Expand Down Expand Up @@ -221,9 +222,13 @@ async def _annotate_node(self, state: HedAnnotationState) -> dict:
print(
f"[WORKFLOW] Entering annotate node (validation attempt {state['validation_attempts']}, total iteration {total_iters})"
)
t0 = time.monotonic()
result = await self.annotation_agent.annotate(state)
elapsed = time.monotonic() - t0
result["total_iterations"] = total_iters # Increment counter
print(f"[WORKFLOW] Annotation generated: {result.get('current_annotation', '')[:100]}...")
print(
f"[WORKFLOW] Annotation generated in {elapsed:.1f}s: {result.get('current_annotation', '')[:100]}..."
)
return result

async def _validate_node(self, state: HedAnnotationState) -> dict:
Expand All @@ -236,9 +241,11 @@ async def _validate_node(self, state: HedAnnotationState) -> dict:
State update
"""
print("[WORKFLOW] Entering validate node")
t0 = time.monotonic()
result = await self.validation_agent.validate(state)
elapsed = time.monotonic() - t0
print(
f"[WORKFLOW] Validation result: {result.get('validation_status')}, is_valid: {result.get('is_valid')}"
f"[WORKFLOW] Validation result in {elapsed:.1f}s: {result.get('validation_status')}, is_valid: {result.get('is_valid')}"
)
if not result.get("is_valid"):
print(f"[WORKFLOW] Validation errors: {result.get('validation_errors', [])}")
Expand All @@ -254,8 +261,12 @@ async def _evaluate_node(self, state: HedAnnotationState) -> dict:
State update
"""
print("[WORKFLOW] Entering evaluate node")
t0 = time.monotonic()
result = await self.evaluation_agent.evaluate(state)
print(f"[WORKFLOW] Evaluation result: is_faithful={result.get('is_faithful')}")
elapsed = time.monotonic() - t0
print(
f"[WORKFLOW] Evaluation result in {elapsed:.1f}s: is_faithful={result.get('is_faithful')}"
)

# Set default assessment values if assessment will be skipped
run_assessment = state.get("run_assessment", False)
Expand All @@ -281,7 +292,12 @@ async def _assess_node(self, state: HedAnnotationState) -> dict:
Returns:
State update
"""
return await self.assessment_agent.assess(state)
print("[WORKFLOW] Entering assess node")
t0 = time.monotonic()
result = await self.assessment_agent.assess(state)
elapsed = time.monotonic() - t0
print(f"[WORKFLOW] Assessment completed in {elapsed:.1f}s")
return result

async def _summarize_feedback_node(self, state: HedAnnotationState) -> dict:
"""Summarize feedback node: Condense errors and feedback.
Expand All @@ -293,9 +309,11 @@ async def _summarize_feedback_node(self, state: HedAnnotationState) -> dict:
State update with summarized feedback
"""
print("[WORKFLOW] Entering summarize_feedback node")
t0 = time.monotonic()
result = await self.feedback_summarizer.summarize(state)
elapsed = time.monotonic() - t0
print(
f"[WORKFLOW] Feedback summarized: {result.get('validation_errors_augmented', [''])[0][:100] if result.get('validation_errors_augmented') else 'No feedback'}..."
f"[WORKFLOW] Feedback summarized in {elapsed:.1f}s: {result.get('validation_errors_augmented', [''])[0][:100] if result.get('validation_errors_augmented') else 'No feedback'}..."
)
return result

Expand Down Expand Up @@ -327,52 +345,39 @@ def _route_after_evaluation(
self,
state: HedAnnotationState,
) -> str:
"""Route after evaluation based on faithfulness.
"""Route after evaluation based on faithfulness and assessment mode.

When run_assessment=False (default), evaluation is informational only;
the result is reported but never triggers refinement loops.
When run_assessment=True, evaluation can trigger refinement and the
assessment node runs at the end.

Args:
state: Current workflow state

Returns:
Next node name
"""
# Check if max total iterations reached
total_iters = state.get("total_iterations", 0)
max_iters = state.get("max_total_iterations", 10)
run_assessment = state.get("run_assessment", False)

# When assessment is off, evaluation is informational -- always end
if not run_assessment:
print(
f"[WORKFLOW] Evaluation complete (informational, is_faithful={state['is_faithful']}) - routing to END"
)
return "end"

# Assessment mode: allow refinement loops with iteration cap
total_iters = state.get("total_iterations", 0)
max_iters = state.get("max_total_iterations", 4)

if total_iters >= max_iters:
# Only run assessment at max iterations if explicitly requested
if run_assessment:
print(f"[WORKFLOW] Routing to assess (max total iterations {max_iters} reached)")
return "assess"
else:
print(
"[WORKFLOW] Skipping assessment (max iterations reached, assessment not requested) - routing to END"
)
return "end"
print(f"[WORKFLOW] Routing to assess (max total iterations {max_iters} reached)")
return "assess"

if state["is_faithful"]:
# Only run assessment if explicitly requested
if state.get("is_valid") and run_assessment:
print(
"[WORKFLOW] Routing to assess (annotation is valid and faithful, assessment requested)"
)
return "assess"
elif state.get("is_valid"):
print(
"[WORKFLOW] Skipping assessment (annotation is valid and faithful, assessment not requested) - routing to END"
)
return "end"
elif run_assessment:
print(
"[WORKFLOW] Routing to assess (annotation is faithful but has validation issues)"
)
return "assess"
else:
print(
"[WORKFLOW] Skipping assessment (has validation issues, assessment not requested) - routing to END"
)
return "end"
print("[WORKFLOW] Routing to assess (annotation is faithful)")
return "assess"
else:
print(
f"[WORKFLOW] Routing to summarize_feedback (annotation needs refinement, iteration {total_iters}/{max_iters})"
Expand All @@ -383,8 +388,8 @@ async def run(
self,
input_description: str,
schema_version: str = "8.4.0",
max_validation_attempts: int = 5,
max_total_iterations: int = 10,
max_validation_attempts: int = 3,
max_total_iterations: int | None = None,
run_assessment: bool = False,
no_extend: bool = False,
config: dict | None = None,
Expand All @@ -395,7 +400,7 @@ async def run(
input_description: Natural language event description
schema_version: HED schema version to use
max_validation_attempts: Maximum validation retry attempts
max_total_iterations: Maximum total iterations to prevent infinite loops
max_total_iterations: Maximum total iterations (default: max_validation_attempts + 1)
run_assessment: Whether to run final assessment (default: False)
no_extend: If True, prohibit tag extensions (use only existing vocabulary)
config: Optional LangGraph config (e.g., recursion_limit)
Expand All @@ -405,6 +410,9 @@ async def run(
"""
from src.agents.state import create_initial_state

if max_total_iterations is None:
max_total_iterations = max_validation_attempts + 1

# Create initial state
initial_state = create_initial_state(
input_description,
Expand All @@ -418,4 +426,4 @@ async def run(
# Run workflow
final_state = await self.graph.ainvoke(initial_state, config=config) # type: ignore[attr-defined]

return final_state # type: ignore[no-any-return]
return final_state
Loading