diff --git a/dreadnode/scorers/judge.py b/dreadnode/scorers/judge.py
index df5bab28..d2205b06 100644
--- a/dreadnode/scorers/judge.py
+++ b/dreadnode/scorers/judge.py
@@ -42,6 +42,7 @@ def llm_judge(
     input: t.Any | None = None,
     expected_output: t.Any | None = None,
     model_params: rg.GenerateParams | AnyDict | None = None,
+    fallback_model: str | rg.Generator | None = None,
     passing: t.Callable[[float], bool] | None = None,
     min_score: float | None = None,
     max_score: float | None = None,
@@ -56,12 +57,30 @@ def llm_judge(
         input: The input which produced the output for context, if applicable.
         expected_output: The expected output to compare against, if applicable.
         model_params: Optional parameters for the model.
+        fallback_model: Optional fallback model to retry with if the primary model raises an exception - `model_params` is reused for it when it is given as a string identifier.
         passing: Optional callback to determine if the score is passing based on the score value - overrides any model-specified value.
         min_score: Optional minimum score for the judgement - if provided, the score will be clamped to this value.
         max_score: Optional maximum score for the judgement - if provided, the score will be clamped to this value.
         name: The name of the scorer.
     """
 
+    def _get_generator(
+        model_input: str | rg.Generator, params: rg.GenerateParams | AnyDict | None
+    ) -> rg.Generator:
+        """Build a generator from a string identifier (applying `params`) or return a Generator instance as-is (`params` unused); raise TypeError for any other type."""
+        if isinstance(model_input, str):
+            return rg.get_generator(
+                model_input,
+                params=params
+                if isinstance(params, rg.GenerateParams)
+                else rg.GenerateParams.model_validate(params)
+                if params
+                else None,
+            )
+        if isinstance(model_input, rg.Generator):
+            return model_input
+        raise TypeError("Model must be a string identifier or a Generator instance.")
+
     async def evaluate(
         data: t.Any,
         *,
@@ -72,24 +91,10 @@ async def evaluate(
         input: t.Any | None = input,
         expected_output: t.Any | None = expected_output,
         model_params: rg.GenerateParams | AnyDict | None = model_params,
+        fallback_model: str | rg.Generator | None = fallback_model,
         min_score: float | None = min_score,
         max_score: float | None = max_score,
     ) -> list[Metric]:
-        generator: rg.Generator
-        if isinstance(model, str):
-            generator = rg.get_generator(
-                model,
-                params=model_params
-                if isinstance(model_params, rg.GenerateParams)
-                else rg.GenerateParams.model_validate(model_params)
-                if model_params
-                else None,
-            )
-        elif isinstance(model, rg.Generator):
-            generator = model
-        else:
-            raise TypeError("Model must be a string identifier or a Generator instance.")
-
         input_data = JudgeInput(
             input=str(input) if input is not None else None,
             expected_output=str(expected_output) if expected_output is not None else None,
@@ -97,7 +102,15 @@ async def evaluate(
             rubric=rubric,
         )
 
-        judgement = await judge.bind(generator)(input_data)
+        # Generator construction is inside the try, so any exception from building or running the primary model leads to one fallback attempt; without a fallback it is re-raised.
+        try:
+            generator = _get_generator(model, model_params)
+            judgement = await judge.bind(generator)(input_data)
+        except Exception:
+            if fallback_model is None:
+                raise
+            generator = _get_generator(fallback_model, model_params)
+            judgement = await judge.bind(generator)(input_data)
 
         if min_score is not None:
             judgement.score = max(min_score, judgement.score)