
Commit a4e52d5

GeneAIclaude authored and committed
fix: Type annotation improvements across core modules
- Fixed type annotations in 18 files for mypy compatibility
- Added explicit type hints for list/dict variables
- Fixed Optional types using | None syntax
- Added type: ignore for SDK API calls with complex overloads
- Cast values where needed for proper type narrowing

Files fixed:
- empathy_llm_toolkit: core, providers, state, session_status, code_health
- empathy_llm_toolkit/security: audit_logger
- empathy_software_plugin/wizards: base, pattern_*, code_review, ai_*
- empathy_software_plugin/wizards/testing: quality_analyzer, coverage_analyzer
- empathy_software_plugin/wizards/security: exploit_analyzer, vulnerability_scanner

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 5ddf973 commit a4e52d5

19 files changed (+1534, -65)

empathy_llm_toolkit/code_health.py

Lines changed: 1 addition & 1 deletion

@@ -315,7 +315,7 @@ async def _run_check_async(
         """Run a check handler asynchronously."""
         start_time = datetime.now()
         try:
-            result = await asyncio.to_thread(handler, config)
+            result: CheckResult = await asyncio.to_thread(handler, config)
             result.duration_ms = int((datetime.now() - start_time).total_seconds() * 1000)
             return result
         except Exception as e:
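
Context for this fix: asyncio.to_thread is generic over the wrapped callable, so when handler is loosely typed (for example Callable[..., Any]) mypy infers the awaited result as Any and stops checking attribute access on it. A minimal sketch of the pattern, with assumed names:

import asyncio
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

@dataclass
class CheckResult:
    passed: bool
    duration_ms: int = 0

async def run_check(handler: Callable[..., Any], config: dict) -> CheckResult:
    # Without the explicit annotation, result is Any and the attribute
    # assignment below would go unchecked by mypy.
    result: CheckResult = await asyncio.to_thread(handler, config)
    result.duration_ms = 0
    return result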

empathy_llm_toolkit/core.py

Lines changed: 3 additions & 3 deletions

@@ -298,10 +298,10 @@ async def interact(
         context = context or {}

         # Initialize security tracking
-        pii_detections = []
-        secrets_detections = []
+        pii_detections: list[dict] = []
+        secrets_detections: list[dict] = []
         sanitized_input = user_input
-        security_metadata = {}
+        security_metadata: dict[str, Any] = {}

         # Phase 3: Security Pipeline (Step 1 - PII Scrubbing)
         if self.enable_security and self.pii_scrubber:
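
These annotations address mypy's "Need type annotation" error: an empty list or dict literal gives the checker nothing to infer an element type from, so the type must be pinned at the binding. A minimal sketch:

from typing import Any

pii_detections: list[dict] = []          # element type pinned up front
security_metadata: dict[str, Any] = {}   # heterogeneous values allowed

pii_detections.append({"type": "email", "span": (0, 5)})
security_metadata["pii_count"] = len(pii_detections)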

empathy_llm_toolkit/providers.py

Lines changed: 8 additions & 6 deletions

@@ -161,7 +161,7 @@ async def generate(
         api_kwargs.update(kwargs)

         # Call Anthropic API
-        response = self.client.messages.create(**api_kwargs)
+        response = self.client.messages.create(**api_kwargs)  # type: ignore[call-overload]

         # Extract thinking content if present
         thinking_content = None
@@ -329,21 +329,23 @@ async def generate(
         # Call OpenAI API
         response = await self.client.chat.completions.create(
             model=self.model,
-            messages=messages,
+            messages=messages,  # type: ignore[arg-type]
             temperature=temperature,
             max_tokens=max_tokens,
             **kwargs,
         )

         # Convert to standardized format
+        content = response.choices[0].message.content or ""
+        usage = response.usage
         return LLMResponse(
-            content=response.choices[0].message.content,
+            content=content,
             model=response.model,
-            tokens_used=response.usage.total_tokens,
+            tokens_used=usage.total_tokens if usage else 0,
             finish_reason=response.choices[0].finish_reason,
             metadata={
-                "input_tokens": response.usage.prompt_tokens,
-                "output_tokens": response.usage.completion_tokens,
+                "input_tokens": usage.prompt_tokens if usage else 0,
+                "output_tokens": usage.completion_tokens if usage else 0,
                 "provider": "openai",
             },
         )
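
Two patterns are at work here. The error-code-scoped type: ignore[call-overload] suppresses only the overload-resolution failure on the SDK call rather than every error on that line, and binding response.usage to a local lets a single Optional guard cover all three token reads. A minimal sketch of the latter, with an assumed Usage shape:

from dataclasses import dataclass

@dataclass
class Usage:
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int

def total_tokens(usage: Usage | None) -> int:
    # One local, one guard: mypy narrows usage to Usage past the check,
    # instead of re-narrowing response.usage at every access.
    return usage.total_tokens if usage else 0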

empathy_llm_toolkit/security/audit_logger.py

Lines changed: 23 additions & 14 deletions

@@ -660,7 +660,7 @@ def query(
         >>> # Find patterns with high PII counts (nested filter)
         >>> events = logger.query(security__pii_detected__gt=5)
         """
-        results = []
+        results: list[dict[str, Any]] = []

         try:
             if not self.log_path.exists():
@@ -762,23 +762,32 @@ def get_violation_summary(self, user_id: str | None = None) -> dict[str, Any]:
         """
         violations = self.query(event_type="security_violation", user_id=user_id)

-        summary = {
-            "total_violations": len(violations),
-            "by_type": {},
-            "by_severity": {},
-            "by_user": {},
-        }
+        by_type: dict[str, int] = {}
+        by_severity: dict[str, int] = {}
+        by_user: dict[str, int] = {}

         for violation in violations:
-            vtype = violation.get("violation", {}).get("type", "unknown")
-            severity = violation.get("violation", {}).get("severity", "unknown")
-            vid = violation.get("user_id", "unknown")
+            viol_data = violation.get("violation", {})
+            vtype = (
+                str(viol_data.get("type", "unknown")) if isinstance(viol_data, dict) else "unknown"
+            )
+            severity = (
+                str(viol_data.get("severity", "unknown"))
+                if isinstance(viol_data, dict)
+                else "unknown"
+            )
+            vid = str(violation.get("user_id", "unknown"))

-            summary["by_type"][vtype] = summary["by_type"].get(vtype, 0) + 1
-            summary["by_severity"][severity] = summary["by_severity"].get(severity, 0) + 1
-            summary["by_user"][vid] = summary["by_user"].get(vid, 0) + 1
+            by_type[vtype] = by_type.get(vtype, 0) + 1
+            by_severity[severity] = by_severity.get(severity, 0) + 1
+            by_user[vid] = by_user.get(vid, 0) + 1

-        return summary
+        return {
+            "total_violations": len(violations),
+            "by_type": by_type,
+            "by_severity": by_severity,
+            "by_user": by_user,
+        }

     def get_compliance_report(
         self, start_date: datetime | None = None, end_date: datetime | None = None
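
The rewrite avoids a dict whose values mix int and dict: mypy infers such a literal as dict[str, object], after which summary["by_type"][vtype] fails to type-check. Separate, explicitly typed accumulators keep every increment checked, and the heterogeneous result is assembled only at return. A minimal sketch:

from typing import Any

def summarize(labels: list[str]) -> dict[str, Any]:
    by_type: dict[str, int] = {}
    for label in labels:
        by_type[label] = by_type.get(label, 0) + 1
    # Mix value types only in the final literal, where nothing is mutated.
    return {"total": len(labels), "by_type": by_type}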

empathy_llm_toolkit/session_status.py

Lines changed: 2 additions & 1 deletion

@@ -528,7 +528,8 @@ def _load_previous_snapshot(self) -> dict[str, Any] | None:
             if snapshot_file.stem != today:
                 try:
                     with open(snapshot_file, encoding="utf-8") as f:
-                        return json.load(f)
+                        data = json.load(f)
+                        return dict(data) if isinstance(data, dict) else None
                 except (json.JSONDecodeError, OSError):
                     continue

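
json.load is typed as returning Any, which mypy lets flow unchecked into the declared dict[str, Any] | None return type; the isinstance check makes the narrowing real and rejects non-dict payloads such as a top-level JSON list. A minimal sketch:

import json
from typing import Any

def load_snapshot(path: str) -> dict[str, Any] | None:
    with open(path, encoding="utf-8") as f:
        data = json.load(f)  # typed as Any by the json stubs
    return dict(data) if isinstance(data, dict) else None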

empathy_llm_toolkit/state.py

Lines changed: 1 addition & 1 deletion

@@ -170,7 +170,7 @@ def find_matching_pattern(self, trigger_text: str) -> UserPattern | None:

     def get_conversation_history(
         self, max_turns: int = 10, include_metadata: bool = False
-    ) -> list[dict[str, str]]:
+    ) -> list[dict[str, Any]]:
         """
         Get recent conversation history in LLM format.
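
The widened return type reflects what the function actually yields: once include_metadata=True attaches values such as timestamps or token counts, the dicts are no longer str-to-str, so list[dict[str, str]] was inaccurate. A minimal sketch (field names assumed):

from typing import Any

def history(include_metadata: bool = False) -> list[dict[str, Any]]:
    turn: dict[str, Any] = {"role": "user", "content": "hello"}
    if include_metadata:
        turn["tokens"] = 3  # int value is what forced the wider type
    return [turn]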

empathy_software_plugin/wizards/ai_context_wizard.py

Lines changed: 3 additions & 3 deletions

@@ -321,7 +321,7 @@ def _generate_recommendations(self, issues: list[dict], predictions: list[dict])
         # Prioritize high-impact predictions
         high_impact = sorted(
             predictions,
-            key=lambda p: {"high": 3, "medium": 2, "low": 1}.get(p.get("impact"), 0),
+            key=lambda p: {"high": 3, "medium": 2, "low": 1}.get(str(p.get("impact", "")), 0),
             reverse=True,
         )

@@ -385,9 +385,9 @@ def _get_model_limits(self, provider: str, model: str) -> dict:
     def _estimate_context_tokens(self, call: dict, context_sources: list[dict]) -> int:
         """Estimate tokens for an AI call"""
         # Simplified: ~4 chars per token
-        base_prompt = call.get("prompt_size", 1000)
+        base_prompt = int(call.get("prompt_size", 1000))
         dynamic_context = sum(
-            source.get("estimated_size", 0)
+            int(source.get("estimated_size", 0))
             for source in context_sources
             if source.get("call_id") == call.get("id")
        )
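
Values fetched from an untyped dict come back as Any, and .get() may also produce None when the key is absent; wrapping the lookup in str() or int() pins a concrete type for mypy and supplies a usable default at runtime. A minimal sketch:

def impact_rank(prediction: dict) -> int:
    # str() guarantees a hashable str key even when "impact" is missing.
    return {"high": 3, "medium": 2, "low": 1}.get(str(prediction.get("impact", "")), 0)

print(impact_rank({"impact": "high"}))  # 3
print(impact_rank({}))                  # 0 (missing key falls through safely)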

empathy_software_plugin/wizards/ai_documentation_wizard.py

Lines changed: 1 addition & 1 deletion

@@ -92,7 +92,7 @@ async def _analyze_ai_documentation_quality(

         Checks what AI needs to give good recommendations.
         """
-        issues = []
+        issues: list[dict[str, Any]] = []

         # Check for missing context that AI needs
         for doc_file in doc_files:

empathy_software_plugin/wizards/base_wizard.py

Lines changed: 22 additions & 17 deletions

@@ -115,11 +115,12 @@ def get_cached_result(self, context: dict[str, Any]) -> dict[str, Any] | None:
         Returns:
             Cached result dict, or None if not cached
         """
-        if not self.has_memory():
+        if self.short_term_memory is None:
             return None

         key = self._cache_key(context)
-        return self.short_term_memory.retrieve(key, self._credentials)
+        result = self.short_term_memory.retrieve(key, self._credentials)
+        return dict(result) if result else None

     def cache_result(self, context: dict[str, Any], result: dict[str, Any]) -> bool:
         """
@@ -132,11 +133,11 @@ def cache_result(self, context: dict[str, Any], result: dict[str, Any]) -> bool:
         Returns:
             True if cached successfully
         """
-        if not self.has_memory():
+        if self.short_term_memory is None:
             return False

         key = self._cache_key(context)
-        return self.short_term_memory.stash(key, result, self._credentials)
+        return bool(self.short_term_memory.stash(key, result, self._credentials))

     async def analyze_with_cache(self, context: dict[str, Any]) -> dict[str, Any]:
         """
@@ -183,18 +184,20 @@ def share_context(self, key: str, data: Any) -> bool:
         Returns:
             True if shared successfully
         """
-        if not self.has_memory():
+        if self.short_term_memory is None:
             return False

         # Use global credentials for shared context (accessible to all wizards)
         global_creds = AgentCredentials(
             agent_id="wizard_shared",
             tier=AccessTier.CONTRIBUTOR,
         )
-        return self.short_term_memory.stash(
-            f"shared:{key}",
-            data,
-            global_creds,
+        return bool(
+            self.short_term_memory.stash(
+                f"shared:{key}",
+                data,
+                global_creds,
+            )
         )

     def get_shared_context(self, key: str, from_wizard: str | None = None) -> Any | None:
@@ -212,7 +215,7 @@ def get_shared_context(self, key: str, from_wizard: str | None = None) -> Any |
         Returns:
             The shared data, or None if not found
         """
-        if not self.has_memory():
+        if self.short_term_memory is None:
             return None

         # Use global shared namespace by default, or specific wizard if requested
@@ -249,7 +252,7 @@ def stage_discovered_pattern(
         Returns:
             True if staged successfully
         """
-        if not self.has_memory():
+        if self.short_term_memory is None:
             return False

         pattern = StagedPattern(
@@ -263,7 +266,7 @@
             context={"wizard": self.name, "level": self.level},
         )

-        return self.short_term_memory.stage_pattern(pattern, self._credentials)
+        return bool(self.short_term_memory.stage_pattern(pattern, self._credentials))

     def send_signal(self, signal_type: str, data: dict) -> bool:
         """
@@ -278,11 +281,13 @@ def send_signal(self, signal_type: str, data: dict) -> bool:
         Returns:
             True if sent successfully
         """
-        if not self.has_memory():
+        if self.short_term_memory is None:
             return False

-        return self.short_term_memory.send_signal(
-            signal_type=signal_type,
-            data={"wizard": self.name, **data},
-            credentials=self._credentials,
+        return bool(
+            self.short_term_memory.send_signal(
+                signal_type=signal_type,
+                data={"wizard": self.name, **data},
+                credentials=self._credentials,
+            )
         )
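
The recurring has_memory() to is None change exists because mypy does not narrow an Optional attribute through an ordinary bool-returning helper; only a direct check (or a TypeGuard) narrows self.short_term_memory from Memory | None to Memory. The bool(...) wrappers likewise pin the declared bool return when the underlying call is loosely typed. A minimal sketch with assumed class names:

class Memory:
    def stash(self, key, value):  # untyped, so callers see Any
        return True

class Wizard:
    def __init__(self) -> None:
        self.short_term_memory: Memory | None = None

    def has_memory(self) -> bool:
        return self.short_term_memory is not None

    def cache(self, key: str, value: object) -> bool:
        if self.short_term_memory is None:  # narrows the attribute here
            return False
        # bool() keeps the declared return type even though stash() is Any.
        return bool(self.short_term_memory.stash(key, value))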

empathy_software_plugin/wizards/code_review_wizard.py

Lines changed: 4 additions & 4 deletions

@@ -283,7 +283,7 @@ def _bug_to_rule(self, bug: dict) -> AntiPatternRule | None:

     def _extract_safe_patterns(self, fix_code: str) -> list[str]:
         """Extract regex patterns from fix code that indicate safety."""
-        patterns = []
+        patterns: list[str] = []
         if not fix_code:
             return patterns

@@ -303,7 +303,7 @@ def _extract_safe_patterns(self, fix_code: str) -> list[str]:

     def _review_file(self, file_path: str) -> list[ReviewFinding]:
         """Review a single file for anti-patterns."""
-        findings = []
+        findings: list[ReviewFinding] = []

         try:
             path = Path(file_path)
@@ -465,9 +465,9 @@ def _calculate_confidence(self, findings: list[ReviewFinding]) -> float:
         avg_conf = sum(f.confidence for f in findings) / len(findings)
         return round(avg_conf, 2)

-    def _generate_predictions(self, findings: list[ReviewFinding]) -> list[dict]:
+    def _generate_predictions(self, findings: list[ReviewFinding]) -> list[dict[str, Any]]:
         """Generate Level 4 predictions."""
-        predictions = []
+        predictions: list[dict[str, Any]] = []

         if not findings:
             predictions.append(
