Skip to content

Commit ce55059

Browse files
fix: prevent evaluation loop from causing stuck workflows
- Make evaluation informational-only when run_assessment=False - Add 15s LLM call timeout via request_timeout on ChatLiteLLM - Default evaluation parsing to ACCEPT when ambiguous - Derive max_total_iterations from max_validation_attempts + 1 - Add per-node timing to all workflow nodes - Switch eval model default to openai/gpt-oss-120b on groq - Lower recursion_limit from 100 to 50 - Update default max_validation_attempts from 5 to 3 Closes #119
1 parent f97815b commit ce55059

File tree

6 files changed

+81
-77
lines changed

6 files changed

+81
-77
lines changed

.env.example

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,10 @@ OPENROUTER_API_KEY=your-openrouter-api-key-here
6767
ANNOTATION_MODEL=mistralai/mistral-small-3.2-24b-instruct
6868
ANNOTATION_PROVIDER=mistral
6969

70-
# Evaluation/Assessment Model (consistent quality checks: Qwen3-235B via DeepInfra)
70+
# Evaluation/Assessment Model (fast quality checks: GPT-OSS-120B via Groq)
7171
# Used for evaluation, assessment, and feedback agents
72-
# Leave EVALUATION_PROVIDER empty to let OpenRouter auto-route
73-
EVALUATION_MODEL=qwen/qwen3-235b-a22b-2507
74-
EVALUATION_PROVIDER=deepinfra/fp8
72+
EVALUATION_MODEL=openai/gpt-oss-120b
73+
EVALUATION_PROVIDER=groq
7574

7675
# Vision Model (image description: Qwen3-VL via deepinfra)
7776
VISION_MODEL=qwen/qwen3-vl-30b-a3b-instruct
@@ -142,8 +141,8 @@ API_WORKERS=4
142141
# ============================================================================
143142
# Workflow Configuration
144143
# ============================================================================
145-
MAX_VALIDATION_ATTEMPTS=5
146-
MAX_TOTAL_ITERATIONS=10
144+
MAX_VALIDATION_ATTEMPTS=3
145+
MAX_TOTAL_ITERATIONS=4
147146

148147
# ============================================================================
149148
# Logging

src/agents/evaluation_agent.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -201,19 +201,13 @@ def _parse_decision(self, feedback: str) -> bool:
201201
result = faithful_match.group(1)
202202
return result in ["yes", "partial"] # Accept partial as good enough!
203203

204-
# Fallback: look for positive indicators
205-
positive_indicators = ["accept", "good", "sufficient", "adequate", "captures well"]
206-
negative_indicators = ["refine", "missing", "incorrect", "inaccurate", "lacks"]
204+
# Fallback: look for explicit refine indicators only
205+
refine_indicators = ["refine", "incorrect", "inaccurate", "wrong"]
206+
if any(indicator in feedback_lower for indicator in refine_indicators):
207+
return False
207208

208-
positive_score = sum(1 for indicator in positive_indicators if indicator in feedback_lower)
209-
negative_score = sum(1 for indicator in negative_indicators if indicator in feedback_lower)
210-
211-
# If more positive than negative, accept
212-
if positive_score > negative_score:
213-
return True
214-
215-
# Default to refine if ambiguous (conservative)
216-
return False
209+
# Default to accept if ambiguous -- avoid unnecessary refinement loops
210+
return True
217211

218212
def _check_tags_and_suggest(self, annotation: str) -> str:
219213
"""Check annotation for invalid tags and suggest alternatives.

src/agents/workflow.py

Lines changed: 45 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66

77
import logging
8+
import time
89
from pathlib import Path
910

1011
from langchain_core.language_models import BaseChatModel
@@ -221,9 +222,13 @@ async def _annotate_node(self, state: HedAnnotationState) -> dict:
221222
print(
222223
f"[WORKFLOW] Entering annotate node (validation attempt {state['validation_attempts']}, total iteration {total_iters})"
223224
)
225+
t0 = time.monotonic()
224226
result = await self.annotation_agent.annotate(state)
227+
elapsed = time.monotonic() - t0
225228
result["total_iterations"] = total_iters # Increment counter
226-
print(f"[WORKFLOW] Annotation generated: {result.get('current_annotation', '')[:100]}...")
229+
print(
230+
f"[WORKFLOW] Annotation generated in {elapsed:.1f}s: {result.get('current_annotation', '')[:100]}..."
231+
)
227232
return result
228233

229234
async def _validate_node(self, state: HedAnnotationState) -> dict:
@@ -236,9 +241,11 @@ async def _validate_node(self, state: HedAnnotationState) -> dict:
236241
State update
237242
"""
238243
print("[WORKFLOW] Entering validate node")
244+
t0 = time.monotonic()
239245
result = await self.validation_agent.validate(state)
246+
elapsed = time.monotonic() - t0
240247
print(
241-
f"[WORKFLOW] Validation result: {result.get('validation_status')}, is_valid: {result.get('is_valid')}"
248+
f"[WORKFLOW] Validation result in {elapsed:.1f}s: {result.get('validation_status')}, is_valid: {result.get('is_valid')}"
242249
)
243250
if not result.get("is_valid"):
244251
print(f"[WORKFLOW] Validation errors: {result.get('validation_errors', [])}")
@@ -254,8 +261,12 @@ async def _evaluate_node(self, state: HedAnnotationState) -> dict:
254261
State update
255262
"""
256263
print("[WORKFLOW] Entering evaluate node")
264+
t0 = time.monotonic()
257265
result = await self.evaluation_agent.evaluate(state)
258-
print(f"[WORKFLOW] Evaluation result: is_faithful={result.get('is_faithful')}")
266+
elapsed = time.monotonic() - t0
267+
print(
268+
f"[WORKFLOW] Evaluation result in {elapsed:.1f}s: is_faithful={result.get('is_faithful')}"
269+
)
259270

260271
# Set default assessment values if assessment will be skipped
261272
run_assessment = state.get("run_assessment", False)
@@ -281,7 +292,11 @@ async def _assess_node(self, state: HedAnnotationState) -> dict:
281292
Returns:
282293
State update
283294
"""
284-
return await self.assessment_agent.assess(state)
295+
t0 = time.monotonic()
296+
result = await self.assessment_agent.assess(state)
297+
elapsed = time.monotonic() - t0
298+
print(f"[WORKFLOW] Assessment completed in {elapsed:.1f}s")
299+
return result
285300

286301
async def _summarize_feedback_node(self, state: HedAnnotationState) -> dict:
287302
"""Summarize feedback node: Condense errors and feedback.
@@ -293,9 +308,11 @@ async def _summarize_feedback_node(self, state: HedAnnotationState) -> dict:
293308
State update with summarized feedback
294309
"""
295310
print("[WORKFLOW] Entering summarize_feedback node")
311+
t0 = time.monotonic()
296312
result = await self.feedback_summarizer.summarize(state)
313+
elapsed = time.monotonic() - t0
297314
print(
298-
f"[WORKFLOW] Feedback summarized: {result.get('validation_errors_augmented', [''])[0][:100] if result.get('validation_errors_augmented') else 'No feedback'}..."
315+
f"[WORKFLOW] Feedback summarized in {elapsed:.1f}s: {result.get('validation_errors_augmented', [''])[0][:100] if result.get('validation_errors_augmented') else 'No feedback'}..."
299316
)
300317
return result
301318

@@ -327,52 +344,39 @@ def _route_after_evaluation(
327344
self,
328345
state: HedAnnotationState,
329346
) -> str:
330-
"""Route after evaluation based on faithfulness.
347+
"""Route after evaluation based on faithfulness and assessment mode.
348+
349+
When run_assessment=False (default), evaluation is informational only;
350+
the result is reported but never triggers refinement loops.
351+
When run_assessment=True, evaluation can trigger refinement and the
352+
assessment node runs at the end.
331353
332354
Args:
333355
state: Current workflow state
334356
335357
Returns:
336358
Next node name
337359
"""
338-
# Check if max total iterations reached
339-
total_iters = state.get("total_iterations", 0)
340-
max_iters = state.get("max_total_iterations", 10)
341360
run_assessment = state.get("run_assessment", False)
342361

362+
# When assessment is off, evaluation is informational -- always end
363+
if not run_assessment:
364+
print(
365+
f"[WORKFLOW] Evaluation complete (informational, is_faithful={state['is_faithful']}) - routing to END"
366+
)
367+
return "end"
368+
369+
# Assessment mode: allow refinement loops with iteration cap
370+
total_iters = state.get("total_iterations", 0)
371+
max_iters = state.get("max_total_iterations", 4)
372+
343373
if total_iters >= max_iters:
344-
# Only run assessment at max iterations if explicitly requested
345-
if run_assessment:
346-
print(f"[WORKFLOW] Routing to assess (max total iterations {max_iters} reached)")
347-
return "assess"
348-
else:
349-
print(
350-
"[WORKFLOW] Skipping assessment (max iterations reached, assessment not requested) - routing to END"
351-
)
352-
return "end"
374+
print(f"[WORKFLOW] Routing to assess (max total iterations {max_iters} reached)")
375+
return "assess"
353376

354377
if state["is_faithful"]:
355-
# Only run assessment if explicitly requested
356-
if state.get("is_valid") and run_assessment:
357-
print(
358-
"[WORKFLOW] Routing to assess (annotation is valid and faithful, assessment requested)"
359-
)
360-
return "assess"
361-
elif state.get("is_valid"):
362-
print(
363-
"[WORKFLOW] Skipping assessment (annotation is valid and faithful, assessment not requested) - routing to END"
364-
)
365-
return "end"
366-
elif run_assessment:
367-
print(
368-
"[WORKFLOW] Routing to assess (annotation is faithful but has validation issues)"
369-
)
370-
return "assess"
371-
else:
372-
print(
373-
"[WORKFLOW] Skipping assessment (has validation issues, assessment not requested) - routing to END"
374-
)
375-
return "end"
378+
print("[WORKFLOW] Routing to assess (annotation is faithful)")
379+
return "assess"
376380
else:
377381
print(
378382
f"[WORKFLOW] Routing to summarize_feedback (annotation needs refinement, iteration {total_iters}/{max_iters})"
@@ -383,8 +387,8 @@ async def run(
383387
self,
384388
input_description: str,
385389
schema_version: str = "8.4.0",
386-
max_validation_attempts: int = 5,
387-
max_total_iterations: int = 10,
390+
max_validation_attempts: int = 3,
391+
max_total_iterations: int = 4,
388392
run_assessment: bool = False,
389393
no_extend: bool = False,
390394
config: dict | None = None,

src/api/main.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,8 @@ def create_openrouter_workflow(
9898
api_key: OpenRouter API key
9999
annotation_model: Model for annotation (default: ANNOTATION_MODEL env or Claude Haiku 4.5)
100100
annotation_provider: Provider for annotation model (default: ANNOTATION_PROVIDER env or "anthropic")
101-
eval_model: Model for eval/assessment/feedback (default: EVALUATION_MODEL env or Qwen3-235B)
102-
eval_provider: Provider for eval models (default: EVALUATION_PROVIDER env or auto-routed)
101+
eval_model: Model for eval/assessment/feedback (default: EVALUATION_MODEL env or GPT-OSS-120B)
102+
eval_provider: Provider for eval models (default: EVALUATION_PROVIDER env or "groq")
103103
temperature: LLM temperature (default: 0.1)
104104
user_id: User ID for cache optimization (derived from API key if not provided)
105105
schema_dir: Path to HED schemas (None = fetch from GitHub)
@@ -112,8 +112,8 @@ def create_openrouter_workflow(
112112
# Apply defaults from environment
113113
default_annotation_model = os.getenv("ANNOTATION_MODEL", "anthropic/claude-haiku-4.5")
114114
default_annotation_provider = os.getenv("ANNOTATION_PROVIDER", "anthropic")
115-
default_eval_model = os.getenv("EVALUATION_MODEL", "qwen/qwen3-235b-a22b-2507")
116-
default_eval_provider = os.getenv("EVALUATION_PROVIDER", "")
115+
default_eval_model = os.getenv("EVALUATION_MODEL", "openai/gpt-oss-120b")
116+
default_eval_provider = os.getenv("EVALUATION_PROVIDER", "groq")
117117

118118
# Resolve final values: parameter > env var > default
119119
actual_annotation_model = get_model_name(annotation_model or default_annotation_model)
@@ -640,15 +640,17 @@ async def annotate(
640640
active_workflow = workflow
641641

642642
try:
643-
# Run annotation workflow with increased recursion limit for long descriptions
644-
# LangGraph default is 25, increase to 100 for complex workflows
645-
config = {"recursion_limit": 100}
643+
config = {"recursion_limit": 50}
644+
645+
# Derive total iteration cap from validation attempts (+1 for evaluation refinement)
646+
max_total_iterations = request.max_validation_attempts + 1
646647

647648
start_time = time.time()
648649
final_state = await active_workflow.run(
649650
input_description=request.description,
650651
schema_version=request.schema_version,
651652
max_validation_attempts=request.max_validation_attempts,
653+
max_total_iterations=max_total_iterations,
652654
run_assessment=request.run_assessment,
653655
config=config,
654656
)
@@ -839,12 +841,14 @@ async def annotate_from_image(
839841
image_metadata = vision_result["metadata"]
840842

841843
# Step 2: Pass description through HED annotation workflow
842-
config = {"recursion_limit": 100}
844+
config = {"recursion_limit": 50}
845+
img_max_total_iters = request.max_validation_attempts + 1
843846

844847
final_state = await active_workflow.run(
845848
input_description=image_description,
846849
schema_version=request.schema_version,
847850
max_validation_attempts=request.max_validation_attempts,
851+
max_total_iterations=img_max_total_iters,
848852
run_assessment=request.run_assessment,
849853
config=config,
850854
)
@@ -992,12 +996,13 @@ async def annotate_stream(
992996
raise HTTPException(status_code=503, detail="Workflow not initialized")
993997
active_workflow = workflow
994998

995-
# Create initial state
999+
# Create initial state with iteration cap derived from validation attempts
1000+
max_total_iterations = request.max_validation_attempts + 1
9961001
initial_state = create_initial_state(
9971002
request.description,
9981003
request.schema_version,
9991004
request.max_validation_attempts,
1000-
10, # max_total_iterations
1005+
max_total_iterations,
10011006
request.run_assessment,
10021007
)
10031008

@@ -1031,7 +1036,7 @@ def send_event(event_type: str, data: dict) -> str:
10311036
validation_attempt = 0
10321037

10331038
# Use LangGraph's astream_events for real-time streaming
1034-
config = {"recursion_limit": 100}
1039+
config = {"recursion_limit": 50}
10351040
async for event in active_workflow.graph.astream_events(
10361041
initial_state, config=config, version="v2"
10371042
):
@@ -1287,11 +1292,12 @@ def send_event(event_type: str, data: dict) -> str:
12871292
)
12881293

12891294
# Step 2: Create initial state for annotation workflow
1295+
img_max_total_iterations = request.max_validation_attempts + 1
12901296
initial_state = create_initial_state(
12911297
image_description,
12921298
request.schema_version,
12931299
request.max_validation_attempts,
1294-
10, # max_total_iterations
1300+
img_max_total_iterations,
12951301
request.run_assessment,
12961302
)
12971303

@@ -1301,7 +1307,7 @@ def send_event(event_type: str, data: dict) -> str:
13011307
validation_attempt = 0
13021308

13031309
# Use LangGraph's astream_events for real-time streaming
1304-
config = {"recursion_limit": 100}
1310+
config = {"recursion_limit": 50}
13051311
async for event in active_workflow.graph.astream_events(
13061312
initial_state, config=config, version="v2"
13071313
):

src/api/models.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ class AnnotationRequest(BaseModel):
2828
examples=["8.3.0", "8.4.0"],
2929
)
3030
max_validation_attempts: int = Field(
31-
default=5,
32-
description="Maximum validation retry attempts",
31+
default=3,
32+
description="Maximum validation retry attempts (total iterations = this + 1)",
3333
ge=1,
3434
le=10,
3535
)
@@ -155,8 +155,8 @@ class ImageAnnotationRequest(BaseModel):
155155
examples=["8.3.0", "8.4.0"],
156156
)
157157
max_validation_attempts: int = Field(
158-
default=5,
159-
description="Maximum validation retry attempts",
158+
default=3,
159+
description="Maximum validation retry attempts (total iterations = this + 1)",
160160
ge=1,
161161
le=10,
162162
)

src/utils/openrouter_llm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,14 @@ def create_openrouter_llm(
6262
if user_id:
6363
model_kwargs["user"] = user_id
6464

65-
# Create base LLM
65+
# Create base LLM with timeout to prevent hanging on slow providers
6666
llm = ChatLiteLLM(
6767
model=litellm_model,
6868
api_key=api_key or os.getenv("OPENROUTER_API_KEY"),
6969
temperature=temperature,
7070
max_tokens=max_tokens,
7171
model_kwargs=model_kwargs,
72+
request_timeout=15,
7273
)
7374

7475
# Determine if caching should be enabled

0 commit comments

Comments (0)