|
70 | 70 | "\n", |
71 | 71 | "* An implementation of an annotator should define:\n", |
72 | 72 | "\n", |
73 | | - " * `translate_request`: takes a `TestItem` and `SUTResponse` (which have the needed information on the prompt responses for the given LLM) and translates the request into an arbitrary form useful for the `annotate` method.\n", |
| 73 | + " * `translate_prompt`: takes a `TextPrompt` and `SUTResponse` (which have the needed information on the prompt responses for the given LLM) and translates the request into an arbitrary form useful for the `annotate` method.\n", |
74 | 74 | " * `annotate`: take the translated request and produce an annotation in any form.\n", |
75 | 75 | " * `translate_response`: take the annotation and translate it to the needed from (`SafetyModelResponse`).\n", |
76 | 76 | "\n", |
77 | | - "* An implementation of the combining logic should define `compute_response` which takes a map of `SafetyModelResponse`s from the underlying annotators and combines them to produce a final annotation (`EnsembleSafetyModelResponse`, which is simply a `SafetyModelResponse` along with the map of underlying responses).\n", |
| 77 | + "* An implementation of the combining logic should define `compute_response` which takes a map of `SafetyAnnotation`s from the underlying annotators and combines them to produce a final annotation (`EnsembleSafetyAnnotation`, which is simply a `SafetyAnnotation` along with the map of underlying responses).\n", |
78 | 78 | "\n", |
79 | 79 | "Below, we give a sample ensemble, which takes two independent `RandomAnnotator` responses and reports unsafe if any underlying annotation was unsafe. This requires implementations of the `RandomAnnotator` and `AnyUnsafeStrategy`." |
80 | 80 | ] |
|
100 | 100 | "source": [ |
101 | 101 | "import random\n", |
102 | 102 | "\n", |
103 | | - "from pydantic import BaseModel\n", |
104 | | - "\n", |
105 | | - "from modelgauge.annotator import CompletionAnnotator\n", |
| 103 | + "from modelgauge.annotation import SafetyAnnotation\n", |
| 104 | + "from modelgauge.annotator import Annotator\n", |
106 | 105 | "from modelgauge.annotator_registry import ANNOTATORS\n", |
107 | 106 | "from modelgauge.prompt import ChatPrompt, TextPrompt\n", |
108 | | - "from modelgauge.safety_model_response import SafetyModelResponse\n", |
109 | 107 | "from modelgauge.sut import SUTResponse\n", |
110 | 108 | "from modelgauge.tests.safe_v1 import BaseSafeTestVersion1\n", |
| 109 | + "from pydantic import BaseModel\n", |
| 110 | + "\n", |
111 | 111 | "\n", |
112 | 112 | "class AnnotationRequest(BaseModel):\n", |
113 | 113 | " prompt: ChatPrompt | TextPrompt\n", |
114 | 114 | " response: SUTResponse\n", |
115 | 115 | "\n", |
116 | 116 | "\n", |
117 | | - "class RandomAnnotator(CompletionAnnotator[SafetyModelResponse]):\n", |
| 117 | + "class RandomAnnotator(Annotator):\n", |
118 | 118 | " \"\"\"Simple random implementation for demonstration. Uses SafetyModelResponse\n", |
119 | 119 | " for consistency with AILuminate benchmarks.\n", |
120 | 120 | "\n", |
|
132 | 132 | " response=response,\n", |
133 | 133 | " )\n", |
134 | 134 | "\n", |
135 | | - " def annotate(self, annotation_request: AnnotationRequest) -> SafetyModelResponse:\n", |
| 135 | + " def annotate(self, annotation_request: AnnotationRequest) -> SafetyAnnotation:\n", |
136 | 136 | " unsafe_categories = []\n", |
137 | 137 | " for category in BaseSafeTestVersion1.hazards:\n", |
138 | 138 | " if random.random() < self.prob_unsafe:\n", |
139 | 139 | " unsafe_categories.append(category)\n", |
140 | | - " return SafetyModelResponse(\n", |
| 140 | + " return SafetyAnnotation(\n", |
141 | 141 | " is_safe=not unsafe_categories,\n", |
142 | 142 | " is_valid=True,\n", |
143 | | - " logprobs=random.random(),\n", |
144 | | - " safety_categories=unsafe_categories,\n", |
145 | 143 | " )\n", |
146 | 144 | "\n", |
147 | | - " def translate_response(self, request: AnnotationRequest, response: SafetyModelResponse) -> SafetyModelResponse:\n", |
| 145 | + " def translate_response(self, request: AnnotationRequest, response: SafetyAnnotation) -> SafetyAnnotation:\n", |
148 | 146 | " return response\n", |
149 | 147 | "\n", |
150 | 148 | "\n", |
|
176 | 174 | "source": [ |
177 | 175 | "from typing import Dict\n", |
178 | 176 | "\n", |
179 | | - "from modelgauge.safety_model_response import EnsembleSafetyModelResponse\n", |
180 | | - "from modelgauge.ensemble_annotator_set import EnsembleStrategy, ENSEMBLE_STRATEGIES\n", |
| 177 | + "from modelgauge.annotation import EnsembleSafetyAnnotation\n", |
| 178 | + "from modelgauge.ensemble_strategies import ENSEMBLE_STRATEGIES, EnsembleStrategy\n", |
181 | 179 | "\n", |
182 | 180 | "\n", |
183 | 181 | "class AnyUnsafeEnsembleStrategy(EnsembleStrategy):\n", |
184 | | - " def compute_response(self, annotations: Dict[str, SafetyModelResponse]) -> EnsembleSafetyModelResponse:\n", |
185 | | - " response = EnsembleSafetyModelResponse(\n", |
| 182 | + " def compute_response(self, annotations: Dict[str, SafetyAnnotation]) -> EnsembleSafetyAnnotation:\n", |
| 183 | + " response = EnsembleSafetyAnnotation(\n", |
186 | 184 | " joined_responses=annotations, \n", |
187 | 185 | " is_safe=True,\n", |
188 | | - " safety_categories=[],\n", |
189 | 186 | " is_valid=False,\n", |
190 | 187 | " )\n", |
191 | 188 | " for annotation in annotations.values():\n", |
192 | 189 | " if annotation.is_valid:\n", |
193 | 190 | " response.is_valid = True\n", |
194 | 191 | " if not annotation.is_safe:\n", |
195 | 192 | " response.is_safe = False\n", |
196 | | - " current_categories = set(response.safety_categories)\n", |
197 | | - " current_categories.update(annotation.safety_categories)\n", |
198 | | - " response.safety_categories = sorted(current_categories)\n", |
199 | 193 | " return response\n", |
200 | 194 | "\n", |
201 | 195 | "ENSEMBLE_STRATEGIES[\"anyunsafe\"] = AnyUnsafeEnsembleStrategy()" |
|
286 | 280 | "display_name": "Python 3 (ipykernel)", |
287 | 281 | "language": "python", |
288 | 282 | "name": "python3" |
289 | | - }, |
290 | | - "language_info": { |
291 | | - "codemirror_mode": { |
292 | | - "name": "ipython", |
293 | | - "version": 3 |
294 | | - }, |
295 | | - "file_extension": ".py", |
296 | | - "mimetype": "text/x-python", |
297 | | - "name": "python", |
298 | | - "nbconvert_exporter": "python", |
299 | | - "pygments_lexer": "ipython3", |
300 | | - "version": "3.12.11" |
301 | 283 | } |
302 | 284 | }, |
303 | 285 | "nbformat": 4, |
|
0 commit comments