10 changes: 10 additions & 0 deletions .github/workflows/diff-context.yml
@@ -0,0 +1,10 @@
name: Diff Context
on: [pull_request]

permissions:
contents: read
pull-requests: write

jobs:
diff:
uses: nikolay-e/treemapper-action/.github/workflows/diff-context.yml@v1
17 changes: 11 additions & 6 deletions src/arbitrium_core/application/execution/executor.py
@@ -6,6 +6,7 @@

from arbitrium_core.application.workflow.nodes.base import BaseNode
from arbitrium_core.application.workflow.registry import registry
from arbitrium_core.domain.errors import GraphValidationError
from arbitrium_core.shared.logging import get_contextual_logger

logger = get_contextual_logger(__name__)
@@ -158,7 +159,9 @@ def _build_graph(
errors.append(f"Unknown node type: {node_type}")

if errors:
raise ValueError(f"Graph build errors: {'; '.join(errors)}")
raise GraphValidationError(
f"Graph build errors: {'; '.join(errors)}"
)

dependencies: dict[str, list[str]] = defaultdict(list)
connections: dict[str, list[tuple[str, str, str]]] = defaultdict(list)
@@ -178,11 +181,11 @@ def _build_graph(
)

if source not in node_instances:
raise ValueError(
raise GraphValidationError(
f"Edge references unknown source node: {source}"
)
if target not in node_instances:
raise ValueError(
raise GraphValidationError(
f"Edge references unknown target node: {target}"
)

@@ -226,7 +229,9 @@ def _topological_sort(
queue.append(other_id)

if len(result) != len(nodes):
raise ValueError("Graph contains a cycle - check node connections")
raise GraphValidationError(
"Graph contains a cycle - check node connections"
)

return result

@@ -260,7 +265,7 @@ def _build_execution_layers(
]

if not current_layer:
raise ValueError(
raise GraphValidationError(
"Graph contains a cycle - check node connections"
)

@@ -358,7 +363,7 @@ def _validate_workflow(
node_instances: dict[str, BaseNode],
) -> None:
if not node_instances:
raise ValueError("No valid nodes in graph")
raise GraphValidationError("No valid nodes in graph")

def _build_execution_graph(
self,
18 changes: 18 additions & 0 deletions src/arbitrium_core/domain/errors.py
@@ -16,6 +16,7 @@
"ExceptionClassifier",
"FatalError",
"FileSystemError",
"GraphValidationError",
"InputError",
"ModelError",
"ModelResponseError",
@@ -112,6 +113,23 @@ def __init__(
super().__init__(enhanced_message, *args)


class GraphValidationError(ArbitriumError):
"""Exception for workflow graph validation errors."""

def __init__(
self,
message: str,
node_id: str | None = None,
*args: object,
**kwargs: object,
) -> None:
self.node_id = node_id
enhanced_message = message
if node_id:
enhanced_message = f"[Node: {node_id}] {enhanced_message}"
super().__init__(enhanced_message, *args)


class ErrorClassification:
__slots__ = ("error_type", "is_retryable", "message")

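The new GraphValidationError follows the same pattern as the other Arbitrium errors: when a node_id is supplied, the message is prefixed with the node context. Note that because it derives from ArbitriumError rather than ValueError, existing except ValueError handlers will no longer catch graph-build failures (assuming ArbitriumError itself is not a ValueError subclass). A minimal usage sketch (the helper function and node id below are illustrative, not part of this PR):

    from arbitrium_core.domain.errors import ArbitriumError, GraphValidationError

    def check_source(source: str, node_instances: dict) -> None:
        # Hypothetical helper: attach the offending node so the message reads
        # "[Node: <source>] Edge references unknown source node: <source>"
        if source not in node_instances:
            raise GraphValidationError(
                f"Edge references unknown source node: {source}",
                node_id=source,
            )

    try:
        check_source("summarize", {})
    except GraphValidationError as exc:
        print(exc.node_id, exc)   # graph-specific handling
    except ArbitriumError:
        pass                      # still reachable through the shared base class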
19 changes: 13 additions & 6 deletions src/arbitrium_core/domain/knowledge/bank.py
@@ -63,14 +63,16 @@ def _determine_extractor_model_key(
else None
)
self.logger.warning(
f"Leader not determined yet, using fallback model: {extractor_model_key}"
"Leader not determined yet, using fallback model: %s",
extractor_model_key,
)
else:
leader_display = self.comparison.anon_mapping.get(
extractor_model_key, extractor_model_key
)
self.logger.info(
f"Using tournament leader {leader_display} for insight extraction"
"Using tournament leader %s for insight extraction",
leader_display,
)
return extractor_model_key
else:
@@ -129,13 +131,16 @@ def _parse_claims_from_response(
self, response_content: str, extractor_model_key: str
) -> list[str]:
self.logger.debug(
f"[{extractor_model_key}] Raw insight extraction response: {response_content}"
"[%s] Raw insight extraction response: %s",
extractor_model_key,
response_content,
)

if detect_apology_or_refusal(response_content):
self.logger.error(
f"[{extractor_model_key}] Model returned apology/refusal instead of insight extraction. "
f"Response: {response_content}"
"[%s] Model returned apology/refusal instead of insight extraction. Response: %s",
extractor_model_key,
response_content,
)
return []

@@ -145,7 +150,9 @@ def _parse_claims_from_response(

if not claims:
self.logger.warning(
f"[{extractor_model_key}] No valid insights found in response. Response: {response_content}"
"[%s] No valid insights found in response. Response: %s",
extractor_model_key,
response_content,
)

return claims
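These bank.py changes move from eager f-string interpolation to the logging module's lazy %-style arguments, so the message is only rendered when a handler actually emits the record, and linters such as pylint's logging-fstring-interpolation or ruff's G004 stop flagging the calls. A rough before/after sketch, assuming a standard logging.Logger-compatible interface:

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("arbitrium.example")

    extractor_model_key = "model-a"
    response_content = "..."

    # Before: the f-string is built even though DEBUG records are discarded here.
    logger.debug(f"[{extractor_model_key}] Raw insight extraction response: {response_content}")

    # After: formatting is deferred until (and unless) the record is handled.
    logger.debug(
        "[%s] Raw insight extraction response: %s",
        extractor_model_key,
        response_content,
    )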
33 changes: 20 additions & 13 deletions src/arbitrium_core/domain/tournament/tournament.py
@@ -1,12 +1,17 @@
"Core comparison functionality for Arbitrium Framework."

import asyncio
import logging
import re
import statistics
from abc import ABC, abstractmethod
from collections.abc import Callable
from datetime import datetime
from typing import Any
from pathlib import Path
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from arbitrium_core.shared.logging.structured import ContextualLogger

from arbitrium_core.domain.errors import (
BudgetExceededError,
@@ -19,6 +24,9 @@
PromptBuilder,
PromptFormatter,
)
from arbitrium_core.domain.tournament.anonymizer import ModelAnonymizer
from arbitrium_core.domain.tournament.report import ReportGenerator
from arbitrium_core.domain.tournament.scoring import ScoreExtractor
from arbitrium_core.ports.llm import BaseModel, ModelResponse
from arbitrium_core.ports.similarity import SimilarityEngine
from arbitrium_core.shared.constants import (
@@ -29,10 +37,6 @@
from arbitrium_core.shared.logging import get_contextual_logger
from arbitrium_core.shared.text import indent_text, strip_meta_commentary

from .anonymizer import ModelAnonymizer
from .report import ReportGenerator
from .scoring import ScoreExtractor


# Internal interfaces for ModelComparison
class EventHandler(ABC):
@@ -42,7 +46,7 @@ def publish(self, _event_name: str, _data: dict[str, Any]) -> None:


class HostEnvironment(ABC):
base_dir: Any # Output directory path (required by implementations)
base_dir: Path | str # Output directory path (required by implementations)

@abstractmethod
async def read_file(self, path: str) -> str:
@@ -116,7 +120,9 @@ def reset(self) -> None:


class CostTracker:
def __init__(self, logger: Any = None):
def __init__(
self, logger: "logging.Logger | ContextualLogger | None" = None
):
self.total_cost = 0.0
self.cost_by_model: dict[str, float] = {}
self.logger = logger or get_contextual_logger("arbitrium.cost_tracker")
@@ -129,7 +135,10 @@ def add_cost(self, model_name: str, cost: float) -> None:
self.cost_by_model[model_name] += cost

self.logger.debug(
f"💰 Added ${cost:.4f} for {model_name}, total now: ${self.total_cost:.4f}"
"Added $%.4f for %s, total now: $%.4f",
cost,
model_name,
self.total_cost,
)

def get_summary(self) -> dict[str, Any]:
@@ -171,7 +180,7 @@ def display_summary(self) -> None:
extra={"display_type": "colored_text"},
)

self.logger.info(f"💰 Tournament total cost: ${self.total_cost:.4f}")
self.logger.info("Tournament total cost: $%.4f", self.total_cost)


class TournamentRunner:
@@ -196,7 +205,7 @@ async def run(self, initial_question: str) -> str:
)

self.logger.info(
f"Starting model comparison tournament: {initial_question}"
"Starting model comparison tournament: %s", initial_question
)
self.comp.previous_answers = []
self.comp.eliminated_models = []
@@ -210,9 +219,7 @@
self.logger.warning("Process interrupted by user.")
return "Process interrupted by user."
except Exception as e:
self.logger.error(
f"Unexpected error in tournament: {e!s}", exc_info=True
)
self.logger.exception("Unexpected error in tournament: %s", e)
return f"Tournament error: {e!s}"

async def _run_initial_phase(self, initial_question: str) -> bool:
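tournament.py now guards the ContextualLogger import behind TYPE_CHECKING and uses a string annotation, so the name is visible to mypy/pyright without importing the module at runtime, which is also the usual way to break an import cycle. A condensed sketch of the pattern; the fallback logger below is simplified from the real get_contextual_logger call:

    import logging
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated only by static type checkers, never at runtime.
        from arbitrium_core.shared.logging.structured import ContextualLogger


    class CostTracker:
        def __init__(
            self, logger: "logging.Logger | ContextualLogger | None" = None
        ) -> None:
            # The quoted annotation keeps ContextualLogger usable even though
            # the import above is skipped when the module actually runs.
            self.logger = logger or logging.getLogger("arbitrium.cost_tracker")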
10 changes: 5 additions & 5 deletions src/arbitrium_core/engine.py
@@ -9,7 +9,7 @@
from arbitrium_core.ports.similarity import SimilarityEngine
from arbitrium_core.shared.logging import get_contextual_logger

logger = get_contextual_logger("arbitrium")
logger = get_contextual_logger(__name__)


class _InternalEventHandler:
@@ -157,7 +157,7 @@ async def run_tournament(
f"Failed models: {list(self._failed_models.keys())}"
)

logger.info(f"Starting tournament with {len(models)} models")
logger.info("Starting tournament with %d models", len(models))

comparison = self._create_comparison(models)
self._last_comparison = comparison
@@ -198,16 +198,16 @@ async def run_all_models(self, prompt: str) -> dict[str, ModelResponse]:
raise RuntimeError("No healthy models available")

logger.info(
f"Running prompt through {len(self._healthy_models)} models"
"Running prompt through %d models", len(self._healthy_models)
)

results = {}
for model_key in self._healthy_models:
try:
response = await self.run_single_model(model_key, prompt)
results[model_key] = response
except Exception as e:
logger.error(f"Failed to run {model_key}: {e}")
except Exception:
logger.exception("Failed to run %s", model_key)
# Continue with other models
continue

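The engine.py handler switches to logger.exception, which is equivalent to logger.error(..., exc_info=True): it must run inside an except block, and it attaches the active traceback to the record, so the bound exception name is no longer needed. A small sketch of the pattern with a stand-in failure:

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("arbitrium.example")

    def run_model(model_key: str) -> str:
        raise RuntimeError("provider timeout")  # illustrative failure

    results: dict[str, str] = {}
    for model_key in ("model-a", "model-b"):
        try:
            results[model_key] = run_model(model_key)
        except Exception:
            # The full traceback is appended to the log record automatically,
            # matching the old logger.error(f"...: {e}") but with more detail.
            logger.exception("Failed to run %s", model_key)
            continue  # keep going with the remaining models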
22 changes: 12 additions & 10 deletions src/arbitrium_core/infrastructure/llm/model_factory.py
@@ -3,15 +3,15 @@
from arbitrium_core.ports.llm import BaseModel
from arbitrium_core.shared.logging import get_contextual_logger

logger = get_contextual_logger("arbitrium.infrastructure.llm.model_factory")
logger = get_contextual_logger(__name__)


async def ensure_model_instances(
models: dict[str, Any],
) -> dict[str, Any]:
) -> dict[str, BaseModel]:
from arbitrium_core.infrastructure.llm.litellm_adapter import LiteLLMModel

result = {}
result: dict[str, BaseModel] = {}
for key, model in models.items():
if isinstance(model, BaseModel):
result[key] = model
@@ -22,14 +22,16 @@ def ensure_model_instances(
model_config=model,
)
result[key] = instance
except Exception as e:
logger.error(f"Failed to create model from config {key}: {e}")
except Exception:
logger.exception("Failed to create model from config %s", key)
else:
logger.warning(f"Unknown model type for {key}: {type(model)}")
logger.warning("Unknown model type for %s: %s", key, type(model))
return result


async def ensure_single_model_instance(model: Any, key: str = "model") -> Any:
async def ensure_single_model_instance(
model: Any, key: str = "model"
) -> BaseModel | None:
from arbitrium_core.infrastructure.llm.litellm_adapter import LiteLLMModel

if isinstance(model, BaseModel):
@@ -40,9 +42,9 @@ async def ensure_single_model_instance(model: Any, key: str = "model") -> Any:
model_key=model.get("name", key),
model_config=model,
)
except Exception as e:
logger.error(f"Failed to create model from config {key}: {e}")
except Exception:
logger.exception("Failed to create model from config %s", key)
return None
else:
logger.warning(f"Unknown model type for {key}: {type(model)}")
logger.warning("Unknown model type for %s: %s", key, type(model))
return None
20 changes: 6 additions & 14 deletions src/arbitrium_core/ports/cache.py
@@ -1,14 +1,11 @@
from abc import ABC, abstractmethod
from typing import Protocol


class CacheProtocol(ABC):
@abstractmethod
class CacheProtocol(Protocol):
def get(
self, model_name: str, prompt: str, temperature: float, max_tokens: int
) -> tuple[str, float] | None:
pass
) -> tuple[str, float] | None: ...

@abstractmethod
def set(
self,
model_name: str,
@@ -17,13 +14,8 @@ def set(
max_tokens: int,
response: str,
cost: float,
) -> None:
pass
) -> None: ...

@abstractmethod
def clear(self) -> None:
pass
def clear(self) -> None: ...

@abstractmethod
def close(self) -> None:
pass
def close(self) -> None: ...
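Rewriting CacheProtocol as a typing.Protocol trades nominal subclassing for structural typing: any object exposing matching get/set/clear/close methods satisfies the port, with no inheritance required. A sketch of a conforming in-memory cache; the class is illustrative, and the prompt/temperature parameters of set are assumed from the matching get signature, since the hunk elides them:

    from arbitrium_core.ports.cache import CacheProtocol

    class InMemoryCache:
        # Satisfies CacheProtocol structurally; it never inherits from it.

        def __init__(self) -> None:
            self._store: dict[tuple[str, str, float, int], tuple[str, float]] = {}

        def get(
            self, model_name: str, prompt: str, temperature: float, max_tokens: int
        ) -> tuple[str, float] | None:
            return self._store.get((model_name, prompt, temperature, max_tokens))

        def set(
            self,
            model_name: str,
            prompt: str,
            temperature: float,
            max_tokens: int,
            response: str,
            cost: float,
        ) -> None:
            self._store[(model_name, prompt, temperature, max_tokens)] = (response, cost)

        def clear(self) -> None:
            self._store.clear()

        def close(self) -> None:
            self.clear()

    cache: CacheProtocol = InMemoryCache()  # accepted by type checkers via structural typing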
7 changes: 5 additions & 2 deletions src/arbitrium_core/ports/llm.py
@@ -1,6 +1,9 @@
from abc import ABC, abstractmethod
from collections.abc import Awaitable
from typing import Any, Protocol
from typing import TYPE_CHECKING, Any, Protocol

if TYPE_CHECKING:
from arbitrium_core.ports.cache import CacheProtocol


class ModelResponse:
@@ -94,5 +97,5 @@ def from_config(
cls,
model_key: str,
config: dict[str, Any],
response_cache: Any | None = None,
response_cache: "CacheProtocol | None" = None,
) -> Awaitable[BaseModel]: ...