Skip to content

Commit 0fb15fe

Browse files
committed
fixing typing errors
1 parent cbd7ae1 commit 0fb15fe

File tree

7 files changed: +47 additions, -33 deletions

stringsight/_public/sync_api.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -391,6 +391,11 @@ def explain(
391391
from ..prompts.expansion.trace_based import expand_task_description
392392
from ..formatters.traces import format_single_trace_from_row, format_side_by_side_trace_from_row
393393

394+
if task_description is None:
395+
raise ValueError(
396+
"task_description must be provided when prompt_expansion=True and use_dynamic_prompts=False."
397+
)
398+
394399
if verbose:
395400
logger.info("[DEPRECATED] Using old prompt_expansion. Consider use_dynamic_prompts instead.")
396401
logger.info("Expanding task description using example traces...")

stringsight/api.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212

1313
from __future__ import annotations
1414

15-
from typing import Any, Dict, List, Literal
15+
from typing import Any, Dict, List, Literal, cast
1616
import asyncio
1717
import io
1818
import os
@@ -561,8 +561,10 @@ async def _run_cluster_job_async(job: ClusterJob, req: ClusterRunRequest):
561561

562562
# Create minimal conversations that match the properties
563563
conversations: List[ConversationRecord] = []
564-
all_models = set()
565-
property_keys = {(prop.question_id, prop.model) for prop in properties}
564+
all_models: set[str] = set()
565+
property_keys: set[tuple[str, str]] = {
566+
(prop.question_id, cast(str, prop.model)) for prop in properties
567+
}
566568

567569
logger.info(f"Found {len(property_keys)} unique (question_id, model) pairs from {len(properties)} properties")
568570

@@ -695,9 +697,11 @@ async def _run_cluster_job_async(job: ClusterJob, req: ClusterRunRequest):
695697
meta["winner"] = matching_row["score"]["winner"]
696698

697699
# Create SxS conversation record
700+
model_a_str = model_a if isinstance(model_a, str) else str(model_a)
701+
model_b_str = model_b if isinstance(model_b, str) else str(model_b)
698702
conv = ConversationRecord(
699703
question_id=qid,
700-
model=[model_a, model_b],
704+
model=[model_a_str, model_b_str],
701705
prompt=matching_row.get("prompt", ""),
702706
responses=[matching_row.get("model_a_response", ""), matching_row.get("model_b_response", "")],
703707
scores=[score_a, score_b],

stringsight/postprocess/parser.py

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -47,10 +47,12 @@ def __init__(
4747

4848
def _parse_single_property(self, index: int, prop: Property, total_props: int) -> Dict[str, Any]:
4949
"""Parse a single property response. Returns dict with results and errors."""
50-
result = {
50+
parsed_properties: List[Property] = []
51+
errors: List[Dict[str, Any]] = []
52+
result: Dict[str, Any] = {
5153
'index': index,
52-
'parsed_properties': [],
53-
'errors': [],
54+
'parsed_properties': parsed_properties,
55+
'errors': errors,
5456
'parse_failed': False,
5557
'empty_response': False
5658
}
@@ -63,7 +65,7 @@ def _parse_single_property(self, index: int, prop: Property, total_props: int) -
6365
if parsed_json is None:
6466
result['parse_failed'] = True
6567
error_details = self._analyze_json_parsing_error(prop.raw_response)
66-
result['errors'].append({
68+
errors.append({
6769
'property_id': prop.id,
6870
'question_id': prop.question_id,
6971
'model': prop.model,
@@ -84,7 +86,7 @@ def _parse_single_property(self, index: int, prop: Property, total_props: int) -
8486
else:
8587
result['parse_failed'] = True
8688
error_details = f"Parsed JSON has unsupported type: {type(parsed_json)}. Expected dict, list, or dict with 'properties' key."
87-
result['errors'].append({
89+
errors.append({
8890
'property_id': prop.id,
8991
'question_id': prop.question_id,
9092
'model': prop.model,
@@ -98,7 +100,7 @@ def _parse_single_property(self, index: int, prop: Property, total_props: int) -
98100

99101
# Process property dicts
100102
if not prop_dicts or (isinstance(prop_dicts, list) and len(prop_dicts) == 0):
101-
result['errors'].append({
103+
errors.append({
102104
'property_id': prop.id,
103105
'question_id': prop.question_id,
104106
'model': prop.model,
@@ -123,7 +125,7 @@ def _parse_single_property(self, index: int, prop: Property, total_props: int) -
123125
contains_errors=prop_dict.get("contains_errors"),
124126
raw_response=prop.raw_response,
125127
)
126-
result['parsed_properties'].append(new_prop)
128+
parsed_properties.append(new_prop)
127129

128130
return result
129131

stringsight/prompts/dynamic/discovery_generator.py

Lines changed: 6 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -4,10 +4,9 @@
44
This module generates custom discovery prompt sections tailored to specific tasks.
55
"""
66

7-
import json
87
import litellm
98
import logging
10-
from typing import Dict, Any
9+
from typing import Dict, Any, cast
1110
from concurrent.futures import ThreadPoolExecutor
1211
from ...core.caching import UnifiedCache, CacheKeyBuilder
1312

@@ -49,10 +48,7 @@ def generate(
4948
# Check cache
5049
cached = self.cache.get_completion(cache_key)
5150
if cached is not None:
52-
try:
53-
return json.loads(cached)
54-
except json.JSONDecodeError:
55-
logger.warning("Invalid JSON in cache, regenerating...")
51+
return cast(Dict[str, str], cached)
5652

5753
# Generate custom sections in parallel
5854
try:
@@ -88,7 +84,7 @@ def generate(
8884
# Keep: json_schema, model_naming_rule, reasoning_suffix from base
8985

9086
# Cache result
91-
self.cache.set_completion(cache_key, json.dumps(custom_config))
87+
self.cache.set_completion(cache_key, custom_config)
9288
return custom_config
9389

9490
def _generate_intro_task(
@@ -210,7 +206,7 @@ def _build_cache_key(
210206
expanded_description: str,
211207
method: str,
212208
model: str
213-
) -> str:
209+
) -> CacheKeyBuilder:
214210
"""Build cache key for discovery prompt generation.
215211
216212
Args:
@@ -219,7 +215,7 @@ def _build_cache_key(
219215
model: LLM model.
220216
221217
Returns:
222-
Cache key string.
218+
CacheKeyBuilder for use with UnifiedCache.
223219
"""
224220
from .meta_prompts import (
225221
INTRO_TASK_GENERATION_TEMPLATE,
@@ -244,5 +240,4 @@ def _build_cache_key(
244240
"analysis_process": ANALYSIS_PROCESS_GENERATION_TEMPLATE,
245241
}).get_key(),
246242
}
247-
builder = CacheKeyBuilder(cache_data)
248-
return builder.get_key()
243+
return CacheKeyBuilder(cache_data)

stringsight/prompts/dynamic/task_expander.py

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77

88
import random
99
import tiktoken
10-
from typing import List, Dict, Any
10+
from typing import List, Dict, Any, cast
1111
from ...core.data_objects import ConversationRecord
1212
from ...core.caching import UnifiedCache, CacheKeyBuilder
1313
from ..expansion.trace_based import expand_task_description
@@ -60,7 +60,7 @@ def expand(
6060
# Check cache
6161
cached = self.cache.get_completion(cache_key)
6262
if cached is not None:
63-
return cached
63+
return cast(str, cached["expanded_task_description"])
6464

6565
# Convert to trace format and truncate
6666
traces = []
@@ -79,7 +79,7 @@ def expand(
7979
)
8080

8181
# Cache result
82-
self.cache.set_completion(cache_key, expanded)
82+
self.cache.set_completion(cache_key, {"expanded_task_description": expanded})
8383
return expanded
8484

8585
def _sample_conversations(
@@ -218,7 +218,7 @@ def _build_cache_key(
218218
task_description: str,
219219
sample_ids: List[str],
220220
model: str
221-
) -> str:
221+
) -> CacheKeyBuilder:
222222
"""Build cache key for task expansion.
223223
224224
Args:
@@ -227,7 +227,7 @@ def _build_cache_key(
227227
model: LLM model used for expansion.
228228
229229
Returns:
230-
Cache key string.
230+
CacheKeyBuilder for use with UnifiedCache.
231231
"""
232232
cache_data = {
233233
"type": "task_expansion",
@@ -237,5 +237,4 @@ def _build_cache_key(
237237
"max_tokens_per_sample": self.max_tokens,
238238
"version": "1.0",
239239
}
240-
builder = CacheKeyBuilder(cache_data)
241-
return builder.get_key()
240+
return CacheKeyBuilder(cache_data)

stringsight/routers/prompts.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -155,6 +155,11 @@ async def generate_prompts_endpoint(req: GeneratePromptsRequest) -> Dict[str, An
155155
logger.info(f"Prompt generation completed in {generation_time:.2f}s")
156156

157157
# Return metadata
158+
if prompts_metadata is None:
159+
raise HTTPException(
160+
status_code=500,
161+
detail="Prompt generation succeeded but returned no metadata."
162+
)
158163
return {
159164
"prompts": prompts_metadata.dict(),
160165
"generation_time_seconds": generation_time

stringsight/workers/tasks.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import logging
3-
from typing import Dict, Any, List
3+
from typing import Dict, Any, List, cast
44
from datetime import datetime
55
from pathlib import Path
66
import pandas as pd
@@ -384,8 +384,10 @@ def update_progress(progress: float):
384384
# Phase 2: Create conversations (10%)
385385
update_progress(0.10)
386386
conversations: List[ConversationRecord] = []
387-
all_models = set()
388-
property_keys = {(prop.question_id, prop.model) for prop in properties}
387+
all_models: set[str] = set()
388+
property_keys: set[tuple[str, str]] = {
389+
(prop.question_id, cast(str, prop.model)) for prop in properties
390+
}
389391

390392
for question_id, model in property_keys:
391393
all_models.add(model)
@@ -477,9 +479,11 @@ def update_progress(progress: float):
477479
elif "score" in matching_row and isinstance(matching_row["score"], dict) and "winner" in matching_row["score"]:
478480
meta["winner"] = matching_row["score"]["winner"]
479481

482+
model_a_str = model_a if isinstance(model_a, str) else str(model_a)
483+
model_b_str = model_b if isinstance(model_b, str) else str(model_b)
480484
conv = ConversationRecord(
481485
question_id=qid,
482-
model=[model_a, model_b],
486+
model=[model_a_str, model_b_str],
483487
prompt=matching_row.get("prompt", ""),
484488
responses=[matching_row.get("model_a_response", ""), matching_row.get("model_b_response", "")],
485489
scores=[score_a, score_b],

0 commit comments

Comments
 (0)