Skip to content

Commit c89dcac

Browse files
fix: address review findings and ty type checking
- Move `re` import to module level in evaluation_agent.py
- Add missing "Entering assess node" log for consistency
- Centralize max_total_iterations derivation in state.py and workflow.py (was duplicated 4x in main.py; now defaults to max_validation_attempts + 1)
- Update create_initial_state defaults (was stale at 5/10)
- Fix ty warnings: remove unused `type: ignore` comments
- Fix ty errors: add `type: ignore` for LangGraph/Starlette typing limitations
- Fix return type on get_default_path (-> str | None)
- Update test_state to match new default
1 parent ce55059 commit c89dcac

File tree

6 files changed

+25
-29
lines changed

6 files changed

+25
-29
lines changed

src/agents/evaluation_agent.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
the original natural language event description.
55
"""
66

7+
import re
78
from pathlib import Path
89

910
from langchain_core.language_models import BaseChatModel
@@ -186,8 +187,6 @@ def _parse_decision(self, feedback: str) -> bool:
186187
Returns:
187188
True if annotation should be accepted, False if needs refinement
188189
"""
189-
import re
190-
191190
feedback_lower = feedback.lower()
192191

193192
# Check for explicit DECISION line

src/agents/state.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -81,8 +81,8 @@ class HedAnnotationState(TypedDict):
8181
def create_initial_state(
8282
input_description: str,
8383
schema_version: str = "8.4.0",
84-
max_validation_attempts: int = 5,
85-
max_total_iterations: int = 10,
84+
max_validation_attempts: int = 3,
85+
max_total_iterations: int | None = None,
8686
run_assessment: bool = False,
8787
extracted_keywords: list[str] | None = None,
8888
semantic_hints: list[dict] | None = None,
@@ -93,8 +93,8 @@ def create_initial_state(
9393
Args:
9494
input_description: Natural language event description to annotate
9595
schema_version: HED schema version to use (default: "8.4.0")
96-
max_validation_attempts: Maximum validation retry attempts (default: 5)
97-
max_total_iterations: Maximum total iterations to prevent infinite loops (default: 10)
96+
max_validation_attempts: Maximum validation retry attempts (default: 3)
97+
max_total_iterations: Maximum total iterations (default: max_validation_attempts + 1)
9898
run_assessment: Whether to run final assessment (default: False)
9999
extracted_keywords: Pre-extracted keywords from description (optional)
100100
semantic_hints: Pre-computed semantic search hints (optional)
@@ -103,6 +103,9 @@ def create_initial_state(
103103
Returns:
104104
Initial HedAnnotationState
105105
"""
106+
if max_total_iterations is None:
107+
max_total_iterations = max_validation_attempts + 1
108+
106109
return HedAnnotationState(
107110
messages=[],
108111
input_description=input_description,

src/agents/workflow.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -105,7 +105,7 @@ def _build_graph(self) -> StateGraph:
105105
Compiled StateGraph
106106
"""
107107
# Create graph
108-
workflow = StateGraph(HedAnnotationState)
108+
workflow = StateGraph(HedAnnotationState) # type: ignore[arg-type] # LangGraph typing limitation
109109

110110
# Add nodes
111111
if self.enable_semantic_search:
@@ -292,6 +292,7 @@ async def _assess_node(self, state: HedAnnotationState) -> dict:
292292
Returns:
293293
State update
294294
"""
295+
print("[WORKFLOW] Entering assess node")
295296
t0 = time.monotonic()
296297
result = await self.assessment_agent.assess(state)
297298
elapsed = time.monotonic() - t0
@@ -388,7 +389,7 @@ async def run(
388389
input_description: str,
389390
schema_version: str = "8.4.0",
390391
max_validation_attempts: int = 3,
391-
max_total_iterations: int = 4,
392+
max_total_iterations: int | None = None,
392393
run_assessment: bool = False,
393394
no_extend: bool = False,
394395
config: dict | None = None,
@@ -399,7 +400,7 @@ async def run(
399400
input_description: Natural language event description
400401
schema_version: HED schema version to use
401402
max_validation_attempts: Maximum validation retry attempts
402-
max_total_iterations: Maximum total iterations to prevent infinite loops
403+
max_total_iterations: Maximum total iterations (default: max_validation_attempts + 1)
403404
run_assessment: Whether to run final assessment (default: False)
404405
no_extend: If True, prohibit tag extensions (use only existing vocabulary)
405406
config: Optional LangGraph config (e.g., recursion_limit)
@@ -409,6 +410,9 @@ async def run(
409410
"""
410411
from src.agents.state import create_initial_state
411412

413+
if max_total_iterations is None:
414+
max_total_iterations = max_validation_attempts + 1
415+
412416
# Create initial state
413417
initial_state = create_initial_state(
414418
input_description,
@@ -422,4 +426,4 @@ async def run(
422426
# Run workflow
423427
final_state = await self.graph.ainvoke(initial_state, config=config) # type: ignore[attr-defined]
424428

425-
return final_state # type: ignore[no-any-return]
429+
return final_state

src/api/main.py

Lines changed: 7 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -283,7 +283,7 @@ async def lifespan(app: FastAPI):
283283
print("Initializing HEDit annotation workflow...")
284284

285285
# Auto-detect environment (Docker vs local)
286-
def get_default_path(docker_path: str, local_path: str) -> str:
286+
def get_default_path(docker_path: str, local_path: str) -> str | None:
287287
"""Get default path based on environment.
288288
289289
Args:
@@ -473,7 +473,7 @@ def get_default_path(docker_path: str, local_path: str) -> str:
473473

474474
# Add CORS middleware
475475
app.add_middleware(
476-
CORSMiddleware,
476+
CORSMiddleware, # type: ignore[arg-type] # Starlette typing limitation
477477
allow_origins=allowed_origins,
478478
allow_credentials=True,
479479
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
@@ -642,15 +642,11 @@ async def annotate(
642642
try:
643643
config = {"recursion_limit": 50}
644644

645-
# Derive total iteration cap from validation attempts (+1 for evaluation refinement)
646-
max_total_iterations = request.max_validation_attempts + 1
647-
648645
start_time = time.time()
649646
final_state = await active_workflow.run(
650647
input_description=request.description,
651648
schema_version=request.schema_version,
652649
max_validation_attempts=request.max_validation_attempts,
653-
max_total_iterations=max_total_iterations,
654650
run_assessment=request.run_assessment,
655651
config=config,
656652
)
@@ -842,13 +838,11 @@ async def annotate_from_image(
842838

843839
# Step 2: Pass description through HED annotation workflow
844840
config = {"recursion_limit": 50}
845-
img_max_total_iters = request.max_validation_attempts + 1
846841

847842
final_state = await active_workflow.run(
848843
input_description=image_description,
849844
schema_version=request.schema_version,
850845
max_validation_attempts=request.max_validation_attempts,
851-
max_total_iterations=img_max_total_iters,
852846
run_assessment=request.run_assessment,
853847
config=config,
854848
)
@@ -996,14 +990,12 @@ async def annotate_stream(
996990
raise HTTPException(status_code=503, detail="Workflow not initialized")
997991
active_workflow = workflow
998992

999-
# Create initial state with iteration cap derived from validation attempts
1000-
max_total_iterations = request.max_validation_attempts + 1
993+
# Create initial state (max_total_iterations derived from max_validation_attempts + 1)
1001994
initial_state = create_initial_state(
1002995
request.description,
1003996
request.schema_version,
1004997
request.max_validation_attempts,
1005-
max_total_iterations,
1006-
request.run_assessment,
998+
run_assessment=request.run_assessment,
1007999
)
10081000

10091001
# Node name to user-friendly stage mapping
@@ -1037,7 +1029,7 @@ def send_event(event_type: str, data: dict) -> str:
10371029

10381030
# Use LangGraph's astream_events for real-time streaming
10391031
config = {"recursion_limit": 50}
1040-
async for event in active_workflow.graph.astream_events(
1032+
async for event in active_workflow.graph.astream_events( # type: ignore[union-attr]
10411033
initial_state, config=config, version="v2"
10421034
):
10431035
event_type = event.get("event")
@@ -1292,13 +1284,11 @@ def send_event(event_type: str, data: dict) -> str:
12921284
)
12931285

12941286
# Step 2: Create initial state for annotation workflow
1295-
img_max_total_iterations = request.max_validation_attempts + 1
12961287
initial_state = create_initial_state(
12971288
image_description,
12981289
request.schema_version,
12991290
request.max_validation_attempts,
1300-
img_max_total_iterations,
1301-
request.run_assessment,
1291+
run_assessment=request.run_assessment,
13021292
)
13031293

13041294
# Track state and progress
@@ -1308,7 +1298,7 @@ def send_event(event_type: str, data: dict) -> str:
13081298

13091299
# Use LangGraph's astream_events for real-time streaming
13101300
config = {"recursion_limit": 50}
1311-
async for event in active_workflow.graph.astream_events(
1301+
async for event in active_workflow.graph.astream_events( # type: ignore[union-attr]
13121302
initial_state, config=config, version="v2"
13131303
):
13141304
event_type = event.get("event")

src/utils/openrouter_llm.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -101,7 +101,7 @@ class CachingLLMWrapper(BaseChatModel):
101101

102102
model_config = {"arbitrary_types_allowed": True}
103103

104-
def __init__(self, llm: BaseChatModel, **kwargs) -> None: # type: ignore[no-untyped-def]
104+
def __init__(self, llm: BaseChatModel, **kwargs) -> None:
105105
super().__init__(llm=llm, **kwargs) # type: ignore[call-arg]
106106

107107
@property

tests/test_state.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@ def test_create_initial_state():
1515
assert state["is_valid"] is False
1616
assert state["is_faithful"] is False
1717
assert state["is_complete"] is False
18-
assert state["max_validation_attempts"] == 5
18+
assert state["max_validation_attempts"] == 3
1919
assert state["schema_version"] == "8.4.0"
2020
assert state["no_extend"] is False
2121
assert state["tag_suggestions"] == {}

0 commit comments

Comments
 (0)