Skip to content

Commit e67d2e4

Browse files
Fix workflow stuck in refining/evaluation loops (#120)
* fix: prevent evaluation loop from causing stuck workflows - Make evaluation informational-only when run_assessment=False - Add 15s LLM call timeout via request_timeout on ChatLiteLLM - Default evaluation parsing to ACCEPT when ambiguous - Derive max_total_iterations from max_validation_attempts + 1 - Add per-node timing to all workflow nodes - Switch eval model default to openai/gpt-oss-120b on groq - Lower recursion_limit from 100 to 50 - Update default max_validation_attempts from 5 to 3 Closes #119 * fix: address review findings and ty type checking - Move re import to module level in evaluation_agent.py - Add missing "Entering assess node" log for consistency - Centralize max_total_iterations derivation in state.py and workflow.py (was duplicated 3x in main.py, now defaults to max_validation_attempts + 1) - Update create_initial_state defaults (was stale at 5/10) - Fix ty warnings: remove unused type: ignore comments - Fix ty errors: add type: ignore for LangGraph/Starlette typing limitations - Fix return type on get_default_path (-> str | None) - Update test_state to match new default * Add error handling for LLM timeouts and rate limits - Add try/except with logging to evaluation, assessment, and feedback agents - Map timeouts to HTTP 504, rate limits to HTTP 429 in API endpoints - Add error_type field to streaming SSE error events - Sanitize error messages to avoid leaking internal details - Add debug log for silent ACCEPT fallback in evaluation parsing
1 parent f97815b commit e67d2e4

File tree

10 files changed

+201
-103
lines changed

10 files changed

+201
-103
lines changed

.env.example

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,10 @@ OPENROUTER_API_KEY=your-openrouter-api-key-here
6767
ANNOTATION_MODEL=mistralai/mistral-small-3.2-24b-instruct
6868
ANNOTATION_PROVIDER=mistral
6969

70-
# Evaluation/Assessment Model (consistent quality checks: Qwen3-235B via DeepInfra)
70+
# Evaluation/Assessment Model (fast quality checks: GPT-OSS-120B via Groq)
7171
# Used for evaluation, assessment, and feedback agents
72-
# Leave EVALUATION_PROVIDER empty to let OpenRouter auto-route
73-
EVALUATION_MODEL=qwen/qwen3-235b-a22b-2507
74-
EVALUATION_PROVIDER=deepinfra/fp8
72+
EVALUATION_MODEL=openai/gpt-oss-120b
73+
EVALUATION_PROVIDER=groq
7574

7675
# Vision Model (image description: Qwen3-VL via deepinfra)
7776
VISION_MODEL=qwen/qwen3-vl-30b-a3b-instruct
@@ -142,8 +141,8 @@ API_WORKERS=4
142141
# ============================================================================
143142
# Workflow Configuration
144143
# ============================================================================
145-
MAX_VALIDATION_ATTEMPTS=5
146-
MAX_TOTAL_ITERATIONS=10
144+
MAX_VALIDATION_ATTEMPTS=3
145+
MAX_TOTAL_ITERATIONS=4
147146

148147
# ============================================================================
149148
# Logging

src/agents/assessment_agent.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,16 @@
44
elements or dimensions in the HED annotation.
55
"""
66

7+
import logging
78
from pathlib import Path
89

910
from langchain_core.language_models import BaseChatModel
1011
from langchain_core.messages import HumanMessage, SystemMessage
1112

1213
from src.agents.state import HedAnnotationState
1314

15+
logger = logging.getLogger(__name__)
16+
1417

1518
class AssessmentAgent:
1619
"""Agent that performs final assessment of HED annotations.
@@ -104,7 +107,11 @@ async def assess(self, state: HedAnnotationState) -> dict:
104107
HumanMessage(content=user_prompt),
105108
]
106109

107-
response = await self.llm.ainvoke(messages)
110+
try:
111+
response = await self.llm.ainvoke(messages)
112+
except Exception as e:
113+
logger.error("Assessment LLM invocation failed: %s", e, exc_info=True)
114+
raise
108115
content = response.content
109116
feedback = content.strip() if isinstance(content, str) else str(content)
110117

src/agents/evaluation_agent.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
the original natural language event description.
55
"""
66

7+
import logging
8+
import re
79
from pathlib import Path
810

911
from langchain_core.language_models import BaseChatModel
@@ -12,6 +14,8 @@
1214
from src.agents.state import HedAnnotationState
1315
from src.utils.json_schema_loader import HedJsonSchemaLoader, load_latest_schema
1416

17+
logger = logging.getLogger(__name__)
18+
1519

1620
class EvaluationAgent:
1721
"""Agent that evaluates the faithfulness of HED annotations.
@@ -163,7 +167,11 @@ async def evaluate(self, state: HedAnnotationState) -> dict:
163167
HumanMessage(content=user_prompt),
164168
]
165169

166-
response = await self.llm.ainvoke(messages)
170+
try:
171+
response = await self.llm.ainvoke(messages)
172+
except Exception as e:
173+
logger.error("Evaluation LLM invocation failed: %s", e, exc_info=True)
174+
raise
167175
content = response.content
168176
feedback = content.strip() if isinstance(content, str) else str(content)
169177

@@ -186,8 +194,6 @@ def _parse_decision(self, feedback: str) -> bool:
186194
Returns:
187195
True if annotation should be accepted, False if needs refinement
188196
"""
189-
import re
190-
191197
feedback_lower = feedback.lower()
192198

193199
# Check for explicit DECISION line
@@ -201,19 +207,16 @@ def _parse_decision(self, feedback: str) -> bool:
201207
result = faithful_match.group(1)
202208
return result in ["yes", "partial"] # Accept partial as good enough!
203209

204-
# Fallback: look for positive indicators
205-
positive_indicators = ["accept", "good", "sufficient", "adequate", "captures well"]
206-
negative_indicators = ["refine", "missing", "incorrect", "inaccurate", "lacks"]
207-
208-
positive_score = sum(1 for indicator in positive_indicators if indicator in feedback_lower)
209-
negative_score = sum(1 for indicator in negative_indicators if indicator in feedback_lower)
210+
# Fallback: look for explicit refine indicators only
211+
refine_indicators = ["refine", "incorrect", "inaccurate", "wrong"]
212+
if any(indicator in feedback_lower for indicator in refine_indicators):
213+
return False
210214

211-
# If more positive than negative, accept
212-
if positive_score > negative_score:
213-
return True
214-
215-
# Default to refine if ambiguous (conservative)
216-
return False
215+
# Default to accept if ambiguous -- avoid unnecessary refinement loops
216+
logger.debug(
217+
"Evaluation parsing: no explicit DECISION/FAITHFUL/refine indicator found; defaulting to ACCEPT"
218+
)
219+
return True
217220

218221
def _check_tags_and_suggest(self, annotation: str) -> str:
219222
"""Check annotation for invalid tags and suggest alternatives.

src/agents/feedback_summarizer.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,15 @@
44
into concise, actionable points for the annotation agent.
55
"""
66

7+
import logging
8+
79
from langchain_core.language_models import BaseChatModel
810
from langchain_core.messages import HumanMessage, SystemMessage
911

1012
from src.agents.state import HedAnnotationState
1113

14+
logger = logging.getLogger(__name__)
15+
1216

1317
class FeedbackSummarizer:
1418
"""Agent that summarizes validation errors and feedback.
@@ -112,7 +116,11 @@ async def summarize(self, state: HedAnnotationState) -> dict:
112116
HumanMessage(content=user_prompt),
113117
]
114118

115-
response = await self.llm.ainvoke(messages)
119+
try:
120+
response = await self.llm.ainvoke(messages)
121+
except Exception as e:
122+
logger.error("Feedback summarization LLM invocation failed: %s", e, exc_info=True)
123+
raise
116124
content = response.content
117125
summarized_feedback = content.strip() if isinstance(content, str) else str(content)
118126

src/agents/state.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ class HedAnnotationState(TypedDict):
8181
def create_initial_state(
8282
input_description: str,
8383
schema_version: str = "8.4.0",
84-
max_validation_attempts: int = 5,
85-
max_total_iterations: int = 10,
84+
max_validation_attempts: int = 3,
85+
max_total_iterations: int | None = None,
8686
run_assessment: bool = False,
8787
extracted_keywords: list[str] | None = None,
8888
semantic_hints: list[dict] | None = None,
@@ -93,8 +93,8 @@ def create_initial_state(
9393
Args:
9494
input_description: Natural language event description to annotate
9595
schema_version: HED schema version to use (default: "8.4.0")
96-
max_validation_attempts: Maximum validation retry attempts (default: 5)
97-
max_total_iterations: Maximum total iterations to prevent infinite loops (default: 10)
96+
max_validation_attempts: Maximum validation retry attempts (default: 3)
97+
max_total_iterations: Maximum total iterations (default: max_validation_attempts + 1)
9898
run_assessment: Whether to run final assessment (default: False)
9999
extracted_keywords: Pre-extracted keywords from description (optional)
100100
semantic_hints: Pre-computed semantic search hints (optional)
@@ -103,6 +103,9 @@ def create_initial_state(
103103
Returns:
104104
Initial HedAnnotationState
105105
"""
106+
if max_total_iterations is None:
107+
max_total_iterations = max_validation_attempts + 1
108+
106109
return HedAnnotationState(
107110
messages=[],
108111
input_description=input_description,

src/agents/workflow.py

Lines changed: 52 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66

77
import logging
8+
import time
89
from pathlib import Path
910

1011
from langchain_core.language_models import BaseChatModel
@@ -104,7 +105,7 @@ def _build_graph(self) -> StateGraph:
104105
Compiled StateGraph
105106
"""
106107
# Create graph
107-
workflow = StateGraph(HedAnnotationState)
108+
workflow = StateGraph(HedAnnotationState) # type: ignore[arg-type] # LangGraph typing limitation
108109

109110
# Add nodes
110111
if self.enable_semantic_search:
@@ -221,9 +222,13 @@ async def _annotate_node(self, state: HedAnnotationState) -> dict:
221222
print(
222223
f"[WORKFLOW] Entering annotate node (validation attempt {state['validation_attempts']}, total iteration {total_iters})"
223224
)
225+
t0 = time.monotonic()
224226
result = await self.annotation_agent.annotate(state)
227+
elapsed = time.monotonic() - t0
225228
result["total_iterations"] = total_iters # Increment counter
226-
print(f"[WORKFLOW] Annotation generated: {result.get('current_annotation', '')[:100]}...")
229+
print(
230+
f"[WORKFLOW] Annotation generated in {elapsed:.1f}s: {result.get('current_annotation', '')[:100]}..."
231+
)
227232
return result
228233

229234
async def _validate_node(self, state: HedAnnotationState) -> dict:
@@ -236,9 +241,11 @@ async def _validate_node(self, state: HedAnnotationState) -> dict:
236241
State update
237242
"""
238243
print("[WORKFLOW] Entering validate node")
244+
t0 = time.monotonic()
239245
result = await self.validation_agent.validate(state)
246+
elapsed = time.monotonic() - t0
240247
print(
241-
f"[WORKFLOW] Validation result: {result.get('validation_status')}, is_valid: {result.get('is_valid')}"
248+
f"[WORKFLOW] Validation result in {elapsed:.1f}s: {result.get('validation_status')}, is_valid: {result.get('is_valid')}"
242249
)
243250
if not result.get("is_valid"):
244251
print(f"[WORKFLOW] Validation errors: {result.get('validation_errors', [])}")
@@ -254,8 +261,12 @@ async def _evaluate_node(self, state: HedAnnotationState) -> dict:
254261
State update
255262
"""
256263
print("[WORKFLOW] Entering evaluate node")
264+
t0 = time.monotonic()
257265
result = await self.evaluation_agent.evaluate(state)
258-
print(f"[WORKFLOW] Evaluation result: is_faithful={result.get('is_faithful')}")
266+
elapsed = time.monotonic() - t0
267+
print(
268+
f"[WORKFLOW] Evaluation result in {elapsed:.1f}s: is_faithful={result.get('is_faithful')}"
269+
)
259270

260271
# Set default assessment values if assessment will be skipped
261272
run_assessment = state.get("run_assessment", False)
@@ -281,7 +292,12 @@ async def _assess_node(self, state: HedAnnotationState) -> dict:
281292
Returns:
282293
State update
283294
"""
284-
return await self.assessment_agent.assess(state)
295+
print("[WORKFLOW] Entering assess node")
296+
t0 = time.monotonic()
297+
result = await self.assessment_agent.assess(state)
298+
elapsed = time.monotonic() - t0
299+
print(f"[WORKFLOW] Assessment completed in {elapsed:.1f}s")
300+
return result
285301

286302
async def _summarize_feedback_node(self, state: HedAnnotationState) -> dict:
287303
"""Summarize feedback node: Condense errors and feedback.
@@ -293,9 +309,11 @@ async def _summarize_feedback_node(self, state: HedAnnotationState) -> dict:
293309
State update with summarized feedback
294310
"""
295311
print("[WORKFLOW] Entering summarize_feedback node")
312+
t0 = time.monotonic()
296313
result = await self.feedback_summarizer.summarize(state)
314+
elapsed = time.monotonic() - t0
297315
print(
298-
f"[WORKFLOW] Feedback summarized: {result.get('validation_errors_augmented', [''])[0][:100] if result.get('validation_errors_augmented') else 'No feedback'}..."
316+
f"[WORKFLOW] Feedback summarized in {elapsed:.1f}s: {result.get('validation_errors_augmented', [''])[0][:100] if result.get('validation_errors_augmented') else 'No feedback'}..."
299317
)
300318
return result
301319

@@ -327,52 +345,39 @@ def _route_after_evaluation(
327345
self,
328346
state: HedAnnotationState,
329347
) -> str:
330-
"""Route after evaluation based on faithfulness.
348+
"""Route after evaluation based on faithfulness and assessment mode.
349+
350+
When run_assessment=False (default), evaluation is informational only;
351+
the result is reported but never triggers refinement loops.
352+
When run_assessment=True, evaluation can trigger refinement and the
353+
assessment node runs at the end.
331354
332355
Args:
333356
state: Current workflow state
334357
335358
Returns:
336359
Next node name
337360
"""
338-
# Check if max total iterations reached
339-
total_iters = state.get("total_iterations", 0)
340-
max_iters = state.get("max_total_iterations", 10)
341361
run_assessment = state.get("run_assessment", False)
342362

363+
# When assessment is off, evaluation is informational -- always end
364+
if not run_assessment:
365+
print(
366+
f"[WORKFLOW] Evaluation complete (informational, is_faithful={state['is_faithful']}) - routing to END"
367+
)
368+
return "end"
369+
370+
# Assessment mode: allow refinement loops with iteration cap
371+
total_iters = state.get("total_iterations", 0)
372+
max_iters = state.get("max_total_iterations", 4)
373+
343374
if total_iters >= max_iters:
344-
# Only run assessment at max iterations if explicitly requested
345-
if run_assessment:
346-
print(f"[WORKFLOW] Routing to assess (max total iterations {max_iters} reached)")
347-
return "assess"
348-
else:
349-
print(
350-
"[WORKFLOW] Skipping assessment (max iterations reached, assessment not requested) - routing to END"
351-
)
352-
return "end"
375+
print(f"[WORKFLOW] Routing to assess (max total iterations {max_iters} reached)")
376+
return "assess"
353377

354378
if state["is_faithful"]:
355-
# Only run assessment if explicitly requested
356-
if state.get("is_valid") and run_assessment:
357-
print(
358-
"[WORKFLOW] Routing to assess (annotation is valid and faithful, assessment requested)"
359-
)
360-
return "assess"
361-
elif state.get("is_valid"):
362-
print(
363-
"[WORKFLOW] Skipping assessment (annotation is valid and faithful, assessment not requested) - routing to END"
364-
)
365-
return "end"
366-
elif run_assessment:
367-
print(
368-
"[WORKFLOW] Routing to assess (annotation is faithful but has validation issues)"
369-
)
370-
return "assess"
371-
else:
372-
print(
373-
"[WORKFLOW] Skipping assessment (has validation issues, assessment not requested) - routing to END"
374-
)
375-
return "end"
379+
print("[WORKFLOW] Routing to assess (annotation is faithful)")
380+
return "assess"
376381
else:
377382
print(
378383
f"[WORKFLOW] Routing to summarize_feedback (annotation needs refinement, iteration {total_iters}/{max_iters})"
@@ -383,8 +388,8 @@ async def run(
383388
self,
384389
input_description: str,
385390
schema_version: str = "8.4.0",
386-
max_validation_attempts: int = 5,
387-
max_total_iterations: int = 10,
391+
max_validation_attempts: int = 3,
392+
max_total_iterations: int | None = None,
388393
run_assessment: bool = False,
389394
no_extend: bool = False,
390395
config: dict | None = None,
@@ -395,7 +400,7 @@ async def run(
395400
input_description: Natural language event description
396401
schema_version: HED schema version to use
397402
max_validation_attempts: Maximum validation retry attempts
398-
max_total_iterations: Maximum total iterations to prevent infinite loops
403+
max_total_iterations: Maximum total iterations (default: max_validation_attempts + 1)
399404
run_assessment: Whether to run final assessment (default: False)
400405
no_extend: If True, prohibit tag extensions (use only existing vocabulary)
401406
config: Optional LangGraph config (e.g., recursion_limit)
@@ -405,6 +410,9 @@ async def run(
405410
"""
406411
from src.agents.state import create_initial_state
407412

413+
if max_total_iterations is None:
414+
max_total_iterations = max_validation_attempts + 1
415+
408416
# Create initial state
409417
initial_state = create_initial_state(
410418
input_description,
@@ -418,4 +426,4 @@ async def run(
418426
# Run workflow
419427
final_state = await self.graph.ainvoke(initial_state, config=config) # type: ignore[attr-defined]
420428

421-
return final_state # type: ignore[no-any-return]
429+
return final_state

0 commit comments

Comments
 (0)