Skip to content

Commit f66b2c9

Browse files
feat(LAB-2609): add LLM_RLHF label export (#1620)
1 parent baa6869 commit f66b2c9

File tree

9 files changed

+484
-13
lines changed

9 files changed

+484
-13
lines changed

src/kili/adapters/kili_api_gateway/label/annotation_to_json_response.py

Lines changed: 110 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
from kili.adapters.kili_api_gateway.project.common import get_project
88
from kili.core.graphql.graphql_client import GraphQLClient
99
from kili.domain.annotation import (
10+
ClassicAnnotation,
11+
ClassificationAnnotation,
12+
RankingAnnotation,
13+
TranscriptionAnnotation,
1014
Vertice,
1115
VideoAnnotation,
1216
VideoClassificationAnnotation,
@@ -54,11 +58,17 @@ def patch_label_json_response(self, label: Dict, label_id: LabelId) -> None:
5458
5559
Modifies the input label.
5660
"""
57-
if self._project_input_type == "VIDEO":
61+
if self._project_input_type in {"VIDEO", "LLM_RLHF"}:
5862
annotations = list_annotations(
5963
graphql_client=self._graphql_client,
6064
label_id=label_id,
6165
annotation_fields=("__typename", "id", "job", "path", "labelId"),
66+
classification_annotation_fields=("annotationValue.categories",),
67+
ranking_annotation_fields=(
68+
"annotationValue.orders.elements",
69+
"annotationValue.orders.rank",
70+
),
71+
transcription_annotation_fields=("annotationValue.text",),
6272
video_annotation_fields=(
6373
"frames.start",
6474
"frames.end",
@@ -79,10 +89,15 @@ def patch_label_json_response(self, label: Dict, label_id: LabelId) -> None:
7989
if not annotations and self._label_has_json_response_data(label):
8090
return
8191

82-
annotations = cast(List[VideoAnnotation], annotations)
83-
converted_json_resp = _video_label_annotations_to_json_response(
84-
annotations=annotations, json_interface=self._project_json_interface
85-
)
92+
if self._project_input_type == "VIDEO":
93+
annotations = cast(List[VideoAnnotation], annotations)
94+
converted_json_resp = _video_annotations_to_json_response(
95+
annotations=annotations, json_interface=self._project_json_interface
96+
)
97+
else:
98+
annotations = cast(List[ClassicAnnotation], annotations)
99+
converted_json_resp = _classic_annotations_to_json_response(annotations=annotations)
100+
86101
label["jsonResponse"] = converted_json_resp
87102

88103

@@ -105,7 +120,7 @@ def _fill_empty_frames(json_response: Dict) -> None:
105120
json_response.setdefault(str(frame_id), {})
106121

107122

108-
def _video_label_annotations_to_json_response(
123+
def _video_annotations_to_json_response(
109124
annotations: List[VideoAnnotation], json_interface: Dict
110125
) -> Dict[str, Dict[JobName, Dict]]:
111126
"""Convert video label annotations to a video json response."""
@@ -147,14 +162,49 @@ def _video_label_annotations_to_json_response(
147162
json_resp[frame_id] = {**json_resp[frame_id], **frame_json_resp}
148163

149164
else:
150-
raise NotImplementedError(f"Cannot convert annotation to json response: {ann}")
165+
raise NotImplementedError(f"Cannot convert video annotation to json response: {ann}")
151166

152167
_add_annotation_metadata(annotations, json_resp)
153168
_fill_empty_frames(json_resp)
154169

155170
return dict(sorted(json_resp.items(), key=lambda item: int(item[0]))) # sort by frame id
156171

157172

173+
def _classic_annotations_to_json_response(
174+
annotations: List[ClassicAnnotation],
175+
) -> Dict[str, Dict[JobName, Dict]]:
176+
"""Convert label annotations to a json response."""
177+
json_resp = defaultdict(dict)
178+
179+
for ann in annotations:
180+
if ann["__typename"] == "ClassificationAnnotation":
181+
ann = cast(ClassificationAnnotation, ann)
182+
ann_json_resp = _classification_annotation_to_json_response(ann)
183+
for job_name, job_resp in ann_json_resp.items():
184+
json_resp.setdefault(job_name, {}).setdefault("categories", []).extend(
185+
job_resp["categories"]
186+
)
187+
188+
elif ann["__typename"] == "RankingAnnotation":
189+
ann = cast(RankingAnnotation, ann)
190+
ann_json_resp = _ranking_annotation_to_json_response(ann)
191+
for job_name, job_resp in ann_json_resp.items():
192+
json_resp.setdefault(job_name, {}).setdefault("orders", []).extend(
193+
job_resp["orders"]
194+
)
195+
196+
elif ann["__typename"] == "TranscriptionAnnotation":
197+
ann = cast(TranscriptionAnnotation, ann)
198+
ann_json_resp = _transcription_annotation_to_json_response(ann)
199+
for job_name, job_resp in ann_json_resp.items():
200+
json_resp.setdefault(job_name, {}).setdefault("text", job_resp["text"])
201+
202+
else:
203+
raise NotImplementedError(f"Cannot convert classic annotation to json response: {ann}")
204+
205+
return dict(json_resp)
206+
207+
158208
@overload
159209
def _key_annotations_iterator(
160210
annotation: VideoTranscriptionAnnotation,
@@ -226,6 +276,40 @@ def _key_annotations_iterator(annotation: VideoAnnotation) -> Generator:
226276
yield key_ann, key_ann_start, key_ann_end, next_key_ann
227277

228278

279+
def _ranking_annotation_to_json_response(
280+
annotation: RankingAnnotation,
281+
) -> Dict[JobName, Dict]:
282+
"""Convert ranking annotation to a json response.
283+
284+
Ranking jobs cannot have child jobs.
285+
"""
286+
json_resp = {
287+
annotation["job"]: {
288+
"orders": sorted(
289+
annotation["annotationValue"]["orders"], key=lambda item: int(item["rank"])
290+
),
291+
}
292+
}
293+
294+
return json_resp
295+
296+
297+
def _transcription_annotation_to_json_response(
298+
annotation: TranscriptionAnnotation,
299+
) -> Dict[JobName, Dict]:
300+
"""Convert transcription annotation to a json response.
301+
302+
Transcription jobs cannot have child jobs.
303+
"""
304+
json_resp = {
305+
annotation["job"]: {
306+
"text": annotation["annotationValue"]["text"],
307+
}
308+
}
309+
310+
return json_resp
311+
312+
229313
def _video_transcription_annotation_to_json_response(
230314
annotation: VideoTranscriptionAnnotation,
231315
) -> Dict[str, Dict[JobName, Dict]]:
@@ -286,6 +370,25 @@ def _compute_children_json_resp(
286370
return children_json_resp
287371

288372

373+
def _classification_annotation_to_json_response(
374+
annotation: ClassificationAnnotation,
375+
) -> Dict[JobName, Dict]:
376+
# initialize the json response
377+
json_resp = {
378+
annotation["job"]: {
379+
"categories": [],
380+
}
381+
}
382+
383+
# a frame can have one or multiple categories
384+
categories = annotation["annotationValue"]["categories"]
385+
for category in categories:
386+
category_annotation: Dict = {"name": category}
387+
json_resp[annotation["job"]]["categories"].append(category_annotation)
388+
389+
return json_resp
390+
391+
289392
def _video_classification_annotation_to_json_response(
290393
annotation: VideoClassificationAnnotation,
291394
other_annotations: List[VideoAnnotation],

src/kili/adapters/kili_api_gateway/label/common.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ def list_annotations(
1515
label_id: LabelId,
1616
*,
1717
annotation_fields: ListOrTuple[str],
18+
classification_annotation_fields: ListOrTuple[str] = (),
19+
ranking_annotation_fields: ListOrTuple[str] = (),
20+
transcription_annotation_fields: ListOrTuple[str] = (),
1821
video_annotation_fields: ListOrTuple[str] = (),
1922
video_classification_fields: ListOrTuple[str] = (),
2023
video_object_detection_fields: ListOrTuple[str] = (),
@@ -23,6 +26,9 @@ def list_annotations(
2326
"""List annotations for a label."""
2427
query = get_annotations_query(
2528
annotation_fragment=fragment_builder(annotation_fields),
29+
classification_annotation_fragment=fragment_builder(classification_annotation_fields),
30+
ranking_annotation_fragment=fragment_builder(ranking_annotation_fields),
31+
transcription_annotation_fragment=fragment_builder(transcription_annotation_fields),
2632
video_annotation_fragment=fragment_builder(video_annotation_fields),
2733
video_classification_annotation_fragment=fragment_builder(video_classification_fields),
2834
video_object_detection_annotation_fragment=fragment_builder(video_object_detection_fields),

src/kili/adapters/kili_api_gateway/label/operations.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ def get_append_to_labels_mutation(fragment: str) -> str:
9191
def get_annotations_query(
9292
*,
9393
annotation_fragment: str,
94+
classification_annotation_fragment: str,
95+
ranking_annotation_fragment: str,
96+
transcription_annotation_fragment: str,
9497
video_annotation_fragment: str,
9598
video_object_detection_annotation_fragment: str,
9699
video_classification_annotation_fragment: str,
@@ -99,6 +102,27 @@ def get_annotations_query(
99102
"""Get the gql annotations query."""
100103
inline_fragments = ""
101104

105+
if classification_annotation_fragment.strip():
106+
inline_fragments += f"""
107+
... on ClassificationAnnotation {{
108+
{classification_annotation_fragment}
109+
}}
110+
"""
111+
112+
if ranking_annotation_fragment.strip():
113+
inline_fragments += f"""
114+
... on RankingAnnotation {{
115+
{ranking_annotation_fragment}
116+
}}
117+
"""
118+
119+
if transcription_annotation_fragment.strip():
120+
inline_fragments += f"""
121+
... on TranscriptionAnnotation {{
122+
{transcription_annotation_fragment}
123+
}}
124+
"""
125+
102126
if video_annotation_fragment.strip():
103127
inline_fragments += f"""
104128
... on VideoAnnotation {{

src/kili/adapters/kili_api_gateway/label/operations_mixin.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,9 @@ def list_annotations(
178178
label_id: LabelId,
179179
*,
180180
annotation_fields: ListOrTuple[str],
181+
classification_annotation_fields: ListOrTuple[str] = (),
182+
ranking_annotation_fields: ListOrTuple[str] = (),
183+
transcription_annotation_fields: ListOrTuple[str] = (),
181184
video_annotation_fields: ListOrTuple[str] = (),
182185
video_classification_fields: ListOrTuple[str] = (),
183186
video_object_detection_fields: ListOrTuple[str] = (),
@@ -188,6 +191,9 @@ def list_annotations(
188191
graphql_client=self.graphql_client,
189192
label_id=label_id,
190193
annotation_fields=annotation_fields,
194+
classification_annotation_fields=classification_annotation_fields,
195+
ranking_annotation_fields=ranking_annotation_fields,
196+
transcription_annotation_fields=transcription_annotation_fields,
191197
video_annotation_fields=video_annotation_fields,
192198
video_classification_fields=video_classification_fields,
193199
video_object_detection_fields=video_object_detection_fields,

src/kili/domain/annotation.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .ontology import JobName
77

88
AnnotationId = NewType("AnnotationId", str)
9+
AnnotationValueId = NewType("AnnotationValueId", str)
910
KeyAnnotationId = NewType("KeyAnnotationId", str)
1011

1112

@@ -28,12 +29,61 @@ class ClassificationAnnotationValue(TypedDict):
2829
categories: List[str]
2930

3031

32+
class ClassificationAnnotation(TypedDict):
33+
"""Classification annotation."""
34+
35+
# pylint: disable=unused-private-member
36+
__typename: Literal["ClassificationAnnotation"]
37+
id: AnnotationId
38+
labelId: LabelId
39+
job: JobName
40+
path: List[List[str]]
41+
annotationValue: ClassificationAnnotationValue
42+
43+
44+
class RankingOrderValue(TypedDict):
45+
"""Ranking order value."""
46+
47+
rank: int
48+
elements: List[str]
49+
50+
51+
class RankingAnnotationValue(TypedDict):
52+
"""Ranking annotation value."""
53+
54+
orders: List[RankingOrderValue]
55+
56+
57+
class RankingAnnotation(TypedDict):
58+
"""Ranking annotation."""
59+
60+
# pylint: disable=unused-private-member
61+
__typename: Literal["RankingAnnotation"]
62+
id: AnnotationId
63+
labelId: LabelId
64+
job: JobName
65+
path: List[List[str]]
66+
annotationValue: RankingAnnotationValue
67+
68+
3169
class TranscriptionAnnotationValue(TypedDict):
3270
"""Transcription annotation value."""
3371

3472
text: str
3573

3674

75+
class TranscriptionAnnotation(TypedDict):
76+
"""Transcription annotation."""
77+
78+
# pylint: disable=unused-private-member
79+
__typename: Literal["TranscriptionAnnotation"]
80+
id: AnnotationId
81+
labelId: LabelId
82+
job: JobName
83+
path: List[List[str]]
84+
annotationValue: TranscriptionAnnotationValue
85+
86+
3787
class Annotation(TypedDict):
3888
"""Annotation."""
3989

@@ -121,3 +171,16 @@ class VideoTranscriptionAnnotation(TypedDict):
121171
VideoClassificationAnnotation,
122172
VideoTranscriptionAnnotation,
123173
]
174+
175+
ClassicAnnotation = Union[
176+
ClassificationAnnotation,
177+
RankingAnnotation,
178+
TranscriptionAnnotation,
179+
]
180+
181+
Annotation = Union[
182+
ClassificationAnnotation,
183+
RankingAnnotation,
184+
TranscriptionAnnotation,
185+
VideoAnnotation,
186+
]

0 commit comments

Comments
 (0)