Skip to content

Commit 385e348

Browse files
authored
refactor: fix critical code inconsistencies (types, protocols, logging) (#3)
- Replace ABC with Protocol for structural typing (cache, similarity, secrets) - Add GraphValidationError for workflow validation errors - Fix Any types to specific types (logger, base_dir, response_cache) - Convert f-strings to %-formatting in logging for deferred evaluation - Replace logger.error with logger.exception in except blocks - Convert relative imports to absolute in tournament.py - Standardize logger naming to __name__
1 parent 2ac3507 commit 385e348

File tree

11 files changed

+110
-80
lines changed

11 files changed

+110
-80
lines changed

.github/workflows/diff-context.yml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
name: Diff Context
2+
on: [pull_request]
3+
4+
permissions:
5+
contents: read
6+
pull-requests: write
7+
8+
jobs:
9+
diff:
10+
uses: nikolay-e/treemapper-action/.github/workflows/diff-context.yml@v1

src/arbitrium_core/application/execution/executor.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from arbitrium_core.application.workflow.nodes.base import BaseNode
88
from arbitrium_core.application.workflow.registry import registry
9+
from arbitrium_core.domain.errors import GraphValidationError
910
from arbitrium_core.shared.logging import get_contextual_logger
1011

1112
logger = get_contextual_logger(__name__)
@@ -158,7 +159,9 @@ def _build_graph(
158159
errors.append(f"Unknown node type: {node_type}")
159160

160161
if errors:
161-
raise ValueError(f"Graph build errors: {'; '.join(errors)}")
162+
raise GraphValidationError(
163+
f"Graph build errors: {'; '.join(errors)}"
164+
)
162165

163166
dependencies: dict[str, list[str]] = defaultdict(list)
164167
connections: dict[str, list[tuple[str, str, str]]] = defaultdict(list)
@@ -178,11 +181,11 @@ def _build_graph(
178181
)
179182

180183
if source not in node_instances:
181-
raise ValueError(
184+
raise GraphValidationError(
182185
f"Edge references unknown source node: {source}"
183186
)
184187
if target not in node_instances:
185-
raise ValueError(
188+
raise GraphValidationError(
186189
f"Edge references unknown target node: {target}"
187190
)
188191

@@ -226,7 +229,9 @@ def _topological_sort(
226229
queue.append(other_id)
227230

228231
if len(result) != len(nodes):
229-
raise ValueError("Graph contains a cycle - check node connections")
232+
raise GraphValidationError(
233+
"Graph contains a cycle - check node connections"
234+
)
230235

231236
return result
232237

@@ -260,7 +265,7 @@ def _build_execution_layers(
260265
]
261266

262267
if not current_layer:
263-
raise ValueError(
268+
raise GraphValidationError(
264269
"Graph contains a cycle - check node connections"
265270
)
266271

@@ -358,7 +363,7 @@ def _validate_workflow(
358363
node_instances: dict[str, BaseNode],
359364
) -> None:
360365
if not node_instances:
361-
raise ValueError("No valid nodes in graph")
366+
raise GraphValidationError("No valid nodes in graph")
362367

363368
def _build_execution_graph(
364369
self,

src/arbitrium_core/domain/errors.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"ExceptionClassifier",
1717
"FatalError",
1818
"FileSystemError",
19+
"GraphValidationError",
1920
"InputError",
2021
"ModelError",
2122
"ModelResponseError",
@@ -112,6 +113,23 @@ def __init__(
112113
super().__init__(enhanced_message, *args)
113114

114115

116+
class GraphValidationError(ArbitriumError):
117+
"""Exception for workflow graph validation errors."""
118+
119+
def __init__(
120+
self,
121+
message: str,
122+
node_id: str | None = None,
123+
*args: object,
124+
**kwargs: object,
125+
) -> None:
126+
self.node_id = node_id
127+
enhanced_message = message
128+
if node_id:
129+
enhanced_message = f"[Node: {node_id}] {enhanced_message}"
130+
super().__init__(enhanced_message, *args)
131+
132+
115133
class ErrorClassification:
116134
__slots__ = ("error_type", "is_retryable", "message")
117135

src/arbitrium_core/domain/knowledge/bank.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,16 @@ def _determine_extractor_model_key(
6363
else None
6464
)
6565
self.logger.warning(
66-
f"Leader not determined yet, using fallback model: {extractor_model_key}"
66+
"Leader not determined yet, using fallback model: %s",
67+
extractor_model_key,
6768
)
6869
else:
6970
leader_display = self.comparison.anon_mapping.get(
7071
extractor_model_key, extractor_model_key
7172
)
7273
self.logger.info(
73-
f"Using tournament leader {leader_display} for insight extraction"
74+
"Using tournament leader %s for insight extraction",
75+
leader_display,
7476
)
7577
return extractor_model_key
7678
else:
@@ -129,13 +131,16 @@ def _parse_claims_from_response(
129131
self, response_content: str, extractor_model_key: str
130132
) -> list[str]:
131133
self.logger.debug(
132-
f"[{extractor_model_key}] Raw insight extraction response: {response_content}"
134+
"[%s] Raw insight extraction response: %s",
135+
extractor_model_key,
136+
response_content,
133137
)
134138

135139
if detect_apology_or_refusal(response_content):
136140
self.logger.error(
137-
f"[{extractor_model_key}] Model returned apology/refusal instead of insight extraction. "
138-
f"Response: {response_content}"
141+
"[%s] Model returned apology/refusal instead of insight extraction. Response: %s",
142+
extractor_model_key,
143+
response_content,
139144
)
140145
return []
141146

@@ -145,7 +150,9 @@ def _parse_claims_from_response(
145150

146151
if not claims:
147152
self.logger.warning(
148-
f"[{extractor_model_key}] No valid insights found in response. Response: {response_content}"
153+
"[%s] No valid insights found in response. Response: %s",
154+
extractor_model_key,
155+
response_content,
149156
)
150157

151158
return claims

src/arbitrium_core/domain/tournament/tournament.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
"Core comparison functionality for Arbitrium Framework."
22

33
import asyncio
4+
import logging
45
import re
56
import statistics
67
from abc import ABC, abstractmethod
78
from collections.abc import Callable
89
from datetime import datetime
9-
from typing import Any
10+
from pathlib import Path
11+
from typing import TYPE_CHECKING, Any
12+
13+
if TYPE_CHECKING:
14+
from arbitrium_core.shared.logging.structured import ContextualLogger
1015

1116
from arbitrium_core.domain.errors import (
1217
BudgetExceededError,
@@ -19,6 +24,9 @@
1924
PromptBuilder,
2025
PromptFormatter,
2126
)
27+
from arbitrium_core.domain.tournament.anonymizer import ModelAnonymizer
28+
from arbitrium_core.domain.tournament.report import ReportGenerator
29+
from arbitrium_core.domain.tournament.scoring import ScoreExtractor
2230
from arbitrium_core.ports.llm import BaseModel, ModelResponse
2331
from arbitrium_core.ports.similarity import SimilarityEngine
2432
from arbitrium_core.shared.constants import (
@@ -29,10 +37,6 @@
2937
from arbitrium_core.shared.logging import get_contextual_logger
3038
from arbitrium_core.shared.text import indent_text, strip_meta_commentary
3139

32-
from .anonymizer import ModelAnonymizer
33-
from .report import ReportGenerator
34-
from .scoring import ScoreExtractor
35-
3640

3741
# Internal interfaces for ModelComparison
3842
class EventHandler(ABC):
@@ -42,7 +46,7 @@ def publish(self, _event_name: str, _data: dict[str, Any]) -> None:
4246

4347

4448
class HostEnvironment(ABC):
45-
base_dir: Any # Output directory path (required by implementations)
49+
base_dir: Path | str # Output directory path (required by implementations)
4650

4751
@abstractmethod
4852
async def read_file(self, path: str) -> str:
@@ -116,7 +120,9 @@ def reset(self) -> None:
116120

117121

118122
class CostTracker:
119-
def __init__(self, logger: Any = None):
123+
def __init__(
124+
self, logger: "logging.Logger | ContextualLogger | None" = None
125+
):
120126
self.total_cost = 0.0
121127
self.cost_by_model: dict[str, float] = {}
122128
self.logger = logger or get_contextual_logger("arbitrium.cost_tracker")
@@ -129,7 +135,10 @@ def add_cost(self, model_name: str, cost: float) -> None:
129135
self.cost_by_model[model_name] += cost
130136

131137
self.logger.debug(
132-
f"💰 Added ${cost:.4f} for {model_name}, total now: ${self.total_cost:.4f}"
138+
"Added $%.4f for %s, total now: $%.4f",
139+
cost,
140+
model_name,
141+
self.total_cost,
133142
)
134143

135144
def get_summary(self) -> dict[str, Any]:
@@ -171,7 +180,7 @@ def display_summary(self) -> None:
171180
extra={"display_type": "colored_text"},
172181
)
173182

174-
self.logger.info(f"💰 Tournament total cost: ${self.total_cost:.4f}")
183+
self.logger.info("Tournament total cost: $%.4f", self.total_cost)
175184

176185

177186
class TournamentRunner:
@@ -196,7 +205,7 @@ async def run(self, initial_question: str) -> str:
196205
)
197206

198207
self.logger.info(
199-
f"Starting model comparison tournament: {initial_question}"
208+
"Starting model comparison tournament: %s", initial_question
200209
)
201210
self.comp.previous_answers = []
202211
self.comp.eliminated_models = []
@@ -210,9 +219,7 @@ async def run(self, initial_question: str) -> str:
210219
self.logger.warning("Process interrupted by user.")
211220
return "Process interrupted by user."
212221
except Exception as e:
213-
self.logger.error(
214-
f"Unexpected error in tournament: {e!s}", exc_info=True
215-
)
222+
self.logger.exception("Unexpected error in tournament: %s", e)
216223
return f"Tournament error: {e!s}"
217224

218225
async def _run_initial_phase(self, initial_question: str) -> bool:

src/arbitrium_core/engine.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from arbitrium_core.ports.similarity import SimilarityEngine
1010
from arbitrium_core.shared.logging import get_contextual_logger
1111

12-
logger = get_contextual_logger("arbitrium")
12+
logger = get_contextual_logger(__name__)
1313

1414

1515
class _InternalEventHandler:
@@ -157,7 +157,7 @@ async def run_tournament(
157157
f"Failed models: {list(self._failed_models.keys())}"
158158
)
159159

160-
logger.info(f"Starting tournament with {len(models)} models")
160+
logger.info("Starting tournament with %d models", len(models))
161161

162162
comparison = self._create_comparison(models)
163163
self._last_comparison = comparison
@@ -198,16 +198,16 @@ async def run_all_models(self, prompt: str) -> dict[str, ModelResponse]:
198198
raise RuntimeError("No healthy models available")
199199

200200
logger.info(
201-
f"Running prompt through {len(self._healthy_models)} models"
201+
"Running prompt through %d models", len(self._healthy_models)
202202
)
203203

204204
results = {}
205205
for model_key in self._healthy_models:
206206
try:
207207
response = await self.run_single_model(model_key, prompt)
208208
results[model_key] = response
209-
except Exception as e:
210-
logger.error(f"Failed to run {model_key}: {e}")
209+
except Exception:
210+
logger.exception("Failed to run %s", model_key)
211211
# Continue with other models
212212
continue
213213

src/arbitrium_core/infrastructure/llm/model_factory.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
from arbitrium_core.ports.llm import BaseModel
44
from arbitrium_core.shared.logging import get_contextual_logger
55

6-
logger = get_contextual_logger("arbitrium.infrastructure.llm.model_factory")
6+
logger = get_contextual_logger(__name__)
77

88

99
async def ensure_model_instances(
1010
models: dict[str, Any],
11-
) -> dict[str, Any]:
11+
) -> dict[str, BaseModel]:
1212
from arbitrium_core.infrastructure.llm.litellm_adapter import LiteLLMModel
1313

14-
result = {}
14+
result: dict[str, BaseModel] = {}
1515
for key, model in models.items():
1616
if isinstance(model, BaseModel):
1717
result[key] = model
@@ -22,14 +22,16 @@ async def ensure_model_instances(
2222
model_config=model,
2323
)
2424
result[key] = instance
25-
except Exception as e:
26-
logger.error(f"Failed to create model from config {key}: {e}")
25+
except Exception:
26+
logger.exception("Failed to create model from config %s", key)
2727
else:
28-
logger.warning(f"Unknown model type for {key}: {type(model)}")
28+
logger.warning("Unknown model type for %s: %s", key, type(model))
2929
return result
3030

3131

32-
async def ensure_single_model_instance(model: Any, key: str = "model") -> Any:
32+
async def ensure_single_model_instance(
33+
model: Any, key: str = "model"
34+
) -> BaseModel | None:
3335
from arbitrium_core.infrastructure.llm.litellm_adapter import LiteLLMModel
3436

3537
if isinstance(model, BaseModel):
@@ -40,9 +42,9 @@ async def ensure_single_model_instance(model: Any, key: str = "model") -> Any:
4042
model_key=model.get("name", key),
4143
model_config=model,
4244
)
43-
except Exception as e:
44-
logger.error(f"Failed to create model from config {key}: {e}")
45+
except Exception:
46+
logger.exception("Failed to create model from config %s", key)
4547
return None
4648
else:
47-
logger.warning(f"Unknown model type for {key}: {type(model)}")
49+
logger.warning("Unknown model type for %s: %s", key, type(model))
4850
return None

src/arbitrium_core/ports/cache.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
1-
from abc import ABC, abstractmethod
1+
from typing import Protocol
22

33

4-
class CacheProtocol(ABC):
5-
@abstractmethod
4+
class CacheProtocol(Protocol):
65
def get(
76
self, model_name: str, prompt: str, temperature: float, max_tokens: int
8-
) -> tuple[str, float] | None:
9-
pass
7+
) -> tuple[str, float] | None: ...
108

11-
@abstractmethod
129
def set(
1310
self,
1411
model_name: str,
@@ -17,13 +14,8 @@ def set(
1714
max_tokens: int,
1815
response: str,
1916
cost: float,
20-
) -> None:
21-
pass
17+
) -> None: ...
2218

23-
@abstractmethod
24-
def clear(self) -> None:
25-
pass
19+
def clear(self) -> None: ...
2620

27-
@abstractmethod
28-
def close(self) -> None:
29-
pass
21+
def close(self) -> None: ...

src/arbitrium_core/ports/llm.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from abc import ABC, abstractmethod
22
from collections.abc import Awaitable
3-
from typing import Any, Protocol
3+
from typing import TYPE_CHECKING, Any, Protocol
4+
5+
if TYPE_CHECKING:
6+
from arbitrium_core.ports.cache import CacheProtocol
47

58

69
class ModelResponse:
@@ -94,5 +97,5 @@ def from_config(
9497
cls,
9598
model_key: str,
9699
config: dict[str, Any],
97-
response_cache: Any | None = None,
100+
response_cache: "CacheProtocol | None" = None,
98101
) -> Awaitable[BaseModel]: ...

0 commit comments

Comments
 (0)