Skip to content

Commit 44055cf

Browse files
committed
Align with latest modelbench and fix tests.
1 parent 7645ad9 commit 44055cf

File tree

5 files changed: 42 additions (+42) and 87 deletions (-87)

src/modelplane/cli.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from typing import List
22

33
import click
4+
from modelgauge.data_schema import DEFAULT_ANNOTATION_SCHEMA
5+
from modelgauge.ensemble_strategies import ENSEMBLE_STRATEGIES
46

5-
from modelgauge.data_schema import DEFAULT_ANNOTATION_SCHEMA as ANNOTATION_SCHEMA
6-
from modelgauge.ensemble_annotator_set import ENSEMBLE_STRATEGIES
7-
8-
from modelplane.runways.annotator import annotate, KNOWN_ENSEMBLES
7+
from modelplane.runways.annotator import annotate
98
from modelplane.runways.lister import (
109
list_annotators,
1110
list_ensemble_strategies,
@@ -152,13 +151,6 @@ def get_sut_responses(
152151
help="The ensemble strategy to use. If set, individual annotator results will be combined using the given strategy. "
153152
"Available strategies: " + ", ".join(list(ENSEMBLE_STRATEGIES.keys())),
154153
)
155-
@click.option(
156-
"--ensemble_id",
157-
type=str,
158-
default=None,
159-
help="Use a fixed ensemble id to use a predefined ensemble strategy. Options include: "
160-
+ ", ".join(list(KNOWN_ENSEMBLES.keys())),
161-
)
162154
@click.option(
163155
"--overwrite",
164156
is_flag=True,
@@ -204,12 +196,11 @@ def get_sut_responses(
204196
@load_from_dotenv
205197
def get_annotations(
206198
experiment: str,
199+
annotator_id: List[str],
207200
dvc_repo: str | None = None,
208201
response_file: str | None = None,
209202
response_run_id: str | None = None,
210-
annotator_id: List[str] | None = None,
211203
ensemble_strategy: str | None = None,
212-
ensemble_id: str | None = None,
213204
overwrite: bool = False,
214205
disable_cache: bool = False,
215206
num_workers: int = 1,
@@ -225,7 +216,6 @@ def get_annotations(
225216
response_run_id=response_run_id,
226217
annotator_ids=annotator_id,
227218
ensemble_strategy=ensemble_strategy,
228-
ensemble_id=ensemble_id,
229219
overwrite=overwrite,
230220
disable_cache=disable_cache,
231221
num_workers=num_workers,
@@ -285,8 +275,8 @@ def score_annotations(
285275
ground_truth: str,
286276
dvc_repo: str | None = None,
287277
sample_uid_col: str | None = None,
288-
annotator_uid_col: str = ANNOTATION_SCHEMA.annotator_uid,
289-
annotation_col: str = ANNOTATION_SCHEMA.annotation,
278+
annotator_uid_col: str | None = DEFAULT_ANNOTATION_SCHEMA.annotator_uid,
279+
annotation_col: str | None = DEFAULT_ANNOTATION_SCHEMA.annotation,
290280
):
291281
return score(
292282
annotation_run_id=annotation_run_id,

src/modelplane/runways/annotator.py

Lines changed: 23 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,18 @@
1111
from matplotlib import pyplot as plt
1212
from modelgauge.annotator import Annotator
1313
from modelgauge.annotator_registry import ANNOTATORS
14-
from modelgauge.annotator_set import AnnotatorSet
1514
from modelgauge.dataset import AnnotationDataset
16-
from modelgauge.ensemble_annotator_set import ENSEMBLE_STRATEGIES, EnsembleAnnotatorSet
15+
from modelgauge.ensemble_annotator import EnsembleAnnotator
16+
from modelgauge.ensemble_strategies import ENSEMBLE_STRATEGIES
1717
from modelgauge.pipeline_runner import build_runner
1818

1919
from modelplane.mlflow.loghelpers import log_tags
20+
from modelplane.runways.data import (
21+
Artifact,
22+
BaseInput,
23+
RunArtifacts,
24+
build_and_log_input,
25+
)
2026
from modelplane.runways.utils import (
2127
CACHE_DIR,
2228
MODELGAUGE_RUN_TAG_NAME,
@@ -27,32 +33,16 @@
2733
is_debug_mode,
2834
setup_annotator_credentials,
2935
)
30-
from modelplane.runways.data import (
31-
Artifact,
32-
BaseInput,
33-
RunArtifacts,
34-
build_and_log_input,
35-
)
36-
37-
KNOWN_ENSEMBLES: Dict[str, AnnotatorSet] = {}
38-
# try to load the private ensemble
39-
try:
40-
from modelgauge.private_ensemble_annotator_set import PRIVATE_ANNOTATOR_SET
41-
42-
KNOWN_ENSEMBLES["official-1.0"] = PRIVATE_ANNOTATOR_SET
43-
except NotImplementedError:
44-
pass
4536

4637

4738
def annotate(
4839
experiment: str,
40+
annotator_ids: List[str],
4941
input_object: BaseInput | None = None,
5042
dvc_repo: str | None = None,
5143
response_file: str | None = None,
5244
response_run_id: str | None = None,
53-
annotator_ids: List[str] | None = None,
5445
ensemble_strategy: str | None = None,
55-
ensemble_id: str | None = None,
5646
overwrite: bool = False,
5747
disable_cache: bool = False,
5848
num_workers: int = 1,
@@ -65,9 +55,7 @@ def annotate(
6555
Run annotations and record measurements.
6656
"""
6757
# this will set annotator_ids and optionally ensemble
68-
pipeline_kwargs = _get_annotator_settings(
69-
annotator_ids, ensemble_strategy, ensemble_id
70-
)
58+
pipeline_kwargs = _get_annotator_settings(annotator_ids, ensemble_strategy)
7159
if not disable_cache:
7260
pipeline_kwargs["cache_dir"] = CACHE_DIR
7361
pipeline_kwargs["num_workers"] = num_workers
@@ -83,8 +71,6 @@ def annotate(
8371
)
8472
if ensemble_strategy is not None:
8573
tags["ensemble_strategy"] = ensemble_strategy
86-
if ensemble_id is not None:
87-
tags["ensemble_id"] = ensemble_id
8874

8975
experiment_id = get_experiment_id(experiment)
9076
if overwrite and response_run_id:
@@ -155,38 +141,26 @@ def annotate(
155141

156142

157143
def _get_annotator_settings(
158-
annotator_ids: List[str] | None,
144+
annotator_ids: List[str],
159145
ensemble_strategy: str | None,
160-
ensemble_id: str | None,
161146
) -> Dict[str, Any]:
162147

163148
kwargs = {}
164149

165-
if not ((annotator_ids is not None) ^ (ensemble_id is not None)):
166-
raise ValueError("Either annotator_ids or ensemble_id must be provided.")
167-
if annotator_ids is not None:
168-
kwargs["annotators"] = _get_annotators(annotator_ids)
169-
170-
if ensemble_strategy is not None:
171-
if ensemble_strategy not in ENSEMBLE_STRATEGIES:
172-
raise ValueError(
173-
f"Unknown ensemble strategy: {ensemble_strategy}. "
174-
f"Available strategies: {list(ENSEMBLE_STRATEGIES.keys())}"
175-
)
176-
kwargs["ensemble"] = EnsembleAnnotatorSet(
177-
annotators=annotator_ids,
178-
strategy=ENSEMBLE_STRATEGIES[ensemble_strategy],
179-
)
180-
return kwargs
181-
else:
182-
if ensemble_id not in KNOWN_ENSEMBLES:
150+
kwargs["annotators"] = _get_annotators(annotator_ids)
151+
152+
if ensemble_strategy is not None:
153+
if ensemble_strategy not in ENSEMBLE_STRATEGIES:
183154
raise ValueError(
184-
f"Unknown ensemble_id: {ensemble_id}. "
185-
f"Available strategies: {list(KNOWN_ENSEMBLES.keys())}"
155+
f"Unknown ensemble strategy: {ensemble_strategy}. "
156+
f"Available strategies: {list(ENSEMBLE_STRATEGIES.keys())}"
186157
)
187-
kwargs["ensemble"] = KNOWN_ENSEMBLES[ensemble_id]
188-
kwargs["annotators"] = _get_annotators(KNOWN_ENSEMBLES[ensemble_id].annotators)
189-
return kwargs
158+
kwargs["ensemble"] = EnsembleAnnotator(
159+
uid="ensemble",
160+
annotators=annotator_ids,
161+
ensemble_strategy=ensemble_strategy,
162+
)
163+
return kwargs
190164

191165

192166
def _get_annotators(annotator_ids: List[str]) -> Dict[str, Annotator]:

src/modelplane/runways/lister.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from modelgauge.annotator_registry import ANNOTATORS
2-
from modelgauge.ensemble_annotator_set import ENSEMBLE_STRATEGIES
2+
from modelgauge.ensemble_strategies import ENSEMBLE_STRATEGIES
33
from modelgauge.sut_registry import SUTS
44

55

tests/it/runways/half_safe_annotator.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,18 @@
1-
import random
2-
3-
from pydantic import BaseModel
4-
5-
from modelgauge.annotator import CompletionAnnotator
1+
from modelgauge.annotation import SafetyAnnotation
2+
from modelgauge.annotator import Annotator
63
from modelgauge.annotator_registry import ANNOTATORS
7-
from modelgauge.annotators.llama_guard_annotator import LlamaGuardAnnotation
84
from modelgauge.prompt import ChatPrompt, TextPrompt
95
from modelgauge.sut import SUTResponse
10-
from modelgauge.tests.safe_v1 import BaseSafeTestVersion1
6+
from pydantic import BaseModel
117

128

139
class AnnotationRequest(BaseModel):
1410
prompt: ChatPrompt | TextPrompt
1511
response: SUTResponse
1612

1713

18-
class TestAnnotator(CompletionAnnotator[LlamaGuardAnnotation]):
19-
"""Simple implementation for demonstration. Uses LlamaGuardAnnotation
20-
for consistency with AILuminate benchmarks. Every other item is marked safe.
14+
class TestAnnotator(Annotator):
15+
"""Simple implementation for demonstration.
2116
2217
The methods to implement are based on the idea that annotate will
2318
potentially make a call to another system and the translators are helpful
@@ -34,21 +29,16 @@ def translate_prompt(self, prompt: ChatPrompt | TextPrompt, response: SUTRespons
3429
response=response,
3530
)
3631

37-
def annotate(self, annotation_request: AnnotationRequest) -> LlamaGuardAnnotation:
38-
unsafe_categories = []
39-
if not self.next_is_safe:
40-
unsafe_categories = BaseSafeTestVersion1.hazards
41-
annotation = LlamaGuardAnnotation(
32+
def annotate(self, annotation_request: AnnotationRequest) -> SafetyAnnotation:
33+
annotation = SafetyAnnotation(
4234
is_safe=self.next_is_safe,
43-
is_safe_logprob=random.random(),
44-
violation_categories=unsafe_categories,
4535
)
4636
self.next_is_safe = not self.next_is_safe
4737
return annotation
4838

4939
def translate_response(
50-
self, request: AnnotationRequest, response: LlamaGuardAnnotation
51-
) -> LlamaGuardAnnotation:
40+
self, request: AnnotationRequest, response: SafetyAnnotation
41+
) -> SafetyAnnotation:
5242
return response
5343

5444

tests/unit/test_lister.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
from modelgauge.ensemble_strategies import ENSEMBLE_STRATEGIES
2+
13
from modelplane.runways.lister import (
24
list_annotators,
35
list_ensemble_strategies,
46
list_suts,
57
)
6-
from modelgauge.ensemble_annotator_set import ENSEMBLE_STRATEGIES
78

89

910
def test_list_annotators(capsys):
@@ -13,7 +14,7 @@ def test_list_annotators(capsys):
1314

1415

1516
def test_list_ensemble_strategies(capsys):
16-
ENSEMBLE_STRATEGIES["demo_ensemble_strategy"] = "Demo Ensemble Strategy"
17+
ENSEMBLE_STRATEGIES["demo_ensemble_strategy"] = ENSEMBLE_STRATEGIES["any_unsafe"]
1718
list_ensemble_strategies()
1819
output = capsys.readouterr().out.strip()
1920
assert "demo_ensemble_strategy" in output

0 commit comments

Comments
 (0)