1111from matplotlib import pyplot as plt
1212from modelgauge .annotator import Annotator
1313from modelgauge .annotator_registry import ANNOTATORS
14- from modelgauge .annotator_set import AnnotatorSet
1514from modelgauge .dataset import AnnotationDataset
16- from modelgauge .ensemble_annotator_set import ENSEMBLE_STRATEGIES , EnsembleAnnotatorSet
15+ from modelgauge .ensemble_annotator import EnsembleAnnotator
16+ from modelgauge .ensemble_strategies import ENSEMBLE_STRATEGIES
1717from modelgauge .pipeline_runner import build_runner
1818
1919from modelplane .mlflow .loghelpers import log_tags
20+ from modelplane .runways .data import (
21+ Artifact ,
22+ BaseInput ,
23+ RunArtifacts ,
24+ build_and_log_input ,
25+ )
2026from modelplane .runways .utils import (
2127 CACHE_DIR ,
2228 MODELGAUGE_RUN_TAG_NAME ,
2733 is_debug_mode ,
2834 setup_annotator_credentials ,
2935)
30- from modelplane .runways .data import (
31- Artifact ,
32- BaseInput ,
33- RunArtifacts ,
34- build_and_log_input ,
35- )
36-
37- KNOWN_ENSEMBLES : Dict [str , AnnotatorSet ] = {}
38- # try to load the private ensemble
39- try :
40- from modelgauge .private_ensemble_annotator_set import PRIVATE_ANNOTATOR_SET
41-
42- KNOWN_ENSEMBLES ["official-1.0" ] = PRIVATE_ANNOTATOR_SET
43- except NotImplementedError :
44- pass
4536
4637
4738def annotate (
4839 experiment : str ,
40+ annotator_ids : List [str ],
4941 input_object : BaseInput | None = None ,
5042 dvc_repo : str | None = None ,
5143 response_file : str | None = None ,
5244 response_run_id : str | None = None ,
53- annotator_ids : List [str ] | None = None ,
5445 ensemble_strategy : str | None = None ,
55- ensemble_id : str | None = None ,
5646 overwrite : bool = False ,
5747 disable_cache : bool = False ,
5848 num_workers : int = 1 ,
@@ -65,9 +55,7 @@ def annotate(
6555 Run annotations and record measurements.
6656 """
6757 # this will set annotator_ids and optionally ensemble
68- pipeline_kwargs = _get_annotator_settings (
69- annotator_ids , ensemble_strategy , ensemble_id
70- )
58+ pipeline_kwargs = _get_annotator_settings (annotator_ids , ensemble_strategy )
7159 if not disable_cache :
7260 pipeline_kwargs ["cache_dir" ] = CACHE_DIR
7361 pipeline_kwargs ["num_workers" ] = num_workers
@@ -83,8 +71,6 @@ def annotate(
8371 )
8472 if ensemble_strategy is not None :
8573 tags ["ensemble_strategy" ] = ensemble_strategy
86- if ensemble_id is not None :
87- tags ["ensemble_id" ] = ensemble_id
8874
8975 experiment_id = get_experiment_id (experiment )
9076 if overwrite and response_run_id :
@@ -155,38 +141,26 @@ def annotate(
155141
156142
157143def _get_annotator_settings (
158- annotator_ids : List [str ] | None ,
144+ annotator_ids : List [str ],
159145 ensemble_strategy : str | None ,
160- ensemble_id : str | None ,
161146) -> Dict [str , Any ]:
162147
163148 kwargs = {}
164149
165- if not ((annotator_ids is not None ) ^ (ensemble_id is not None )):
166- raise ValueError ("Either annotator_ids or ensemble_id must be provided." )
167- if annotator_ids is not None :
168- kwargs ["annotators" ] = _get_annotators (annotator_ids )
169-
170- if ensemble_strategy is not None :
171- if ensemble_strategy not in ENSEMBLE_STRATEGIES :
172- raise ValueError (
173- f"Unknown ensemble strategy: { ensemble_strategy } . "
174- f"Available strategies: { list (ENSEMBLE_STRATEGIES .keys ())} "
175- )
176- kwargs ["ensemble" ] = EnsembleAnnotatorSet (
177- annotators = annotator_ids ,
178- strategy = ENSEMBLE_STRATEGIES [ensemble_strategy ],
179- )
180- return kwargs
181- else :
182- if ensemble_id not in KNOWN_ENSEMBLES :
150+ kwargs ["annotators" ] = _get_annotators (annotator_ids )
151+
152+ if ensemble_strategy is not None :
153+ if ensemble_strategy not in ENSEMBLE_STRATEGIES :
183154 raise ValueError (
184- f"Unknown ensemble_id : { ensemble_id } . "
185- f"Available strategies: { list (KNOWN_ENSEMBLES .keys ())} "
155+ f"Unknown ensemble strategy : { ensemble_strategy } . "
156+ f"Available strategies: { list (ENSEMBLE_STRATEGIES .keys ())} "
186157 )
187- kwargs ["ensemble" ] = KNOWN_ENSEMBLES [ensemble_id ]
188- kwargs ["annotators" ] = _get_annotators (KNOWN_ENSEMBLES [ensemble_id ].annotators )
189- return kwargs
158+ kwargs ["ensemble" ] = EnsembleAnnotator (
159+ uid = "ensemble" ,
160+ annotators = annotator_ids ,
161+ ensemble_strategy = ensemble_strategy ,
162+ )
163+ return kwargs
190164
191165
192166def _get_annotators (annotator_ids : List [str ]) -> Dict [str , Annotator ]:
0 commit comments