Skip to content

Commit 04f8b63

Browse files
feat(LAB-3098): dynamic llm export completion level classification (#1828)
Co-authored-by: paulruelle <[email protected]>
1 parent 966223f commit 04f8b63

File tree

2 files changed

+275
-9
lines changed

2 files changed

+275
-9
lines changed

src/kili/llm/services/export/dynamic.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
"modelName",
2929
]
3030

31+
DEFAULT_JOB_LEVEL = "round"
32+
3133

3234
class LLMDynamicExporter:
3335
"""Handle exports of LLM_RLHF projects."""
@@ -72,8 +74,10 @@ def export(
7274
"label_type": label["labelType"],
7375
"label": {},
7476
}
75-
if formatted_response["turn"]:
76-
label_data["label"]["turn"] = formatted_response["turn"]
77+
if formatted_response["round"]:
78+
label_data["label"]["round"] = formatted_response["round"]
79+
if formatted_response["completion"]:
80+
label_data["label"]["completion"] = formatted_response["completion"]
7781
if step == total_rounds - 1 and formatted_response["conversation"]:
7882
label_data["label"]["conversation"] = formatted_response["conversation"]
7983

@@ -238,7 +242,7 @@ def _format_comparison_annotation(annotation, completions, job, obfuscated_model
238242
def _format_json_response(
239243
jobs_config: Dict, annotations: List[Dict], completions: List[Dict], obfuscated_models: Dict
240244
) -> Dict[str, Dict[str, Union[str, List[str]]]]:
241-
result = {"turn": {}, "conversation": {}}
245+
result = {"round": {}, "conversation": {}, "completion": {}}
242246
for annotation in annotations:
243247
formatted_response = None
244248
job = jobs_config[annotation["job"]]
@@ -251,14 +255,20 @@ def _format_json_response(
251255
annotation, completions, job, obfuscated_models
252256
)
253257

258+
job_level = job.get("level", DEFAULT_JOB_LEVEL)
259+
254260
if formatted_response is None:
255261
logging.warning(
256262
f"Annotation with job {annotation['job']} with mlTask {job['mlTask']} not supported. Ignored in the export."
257263
)
258-
elif "level" in job and job["level"] == "conversation":
259-
result["conversation"][annotation["job"]] = formatted_response
264+
265+
elif job_level == "completion":
266+
result.setdefault(job_level, {}).setdefault(annotation["job"], {})[
267+
annotation["chatItemId"]
268+
] = formatted_response
269+
260270
else:
261-
result["turn"][annotation["job"]] = formatted_response
271+
result[job_level][annotation["job"]] = formatted_response
262272

263273
return result
264274

tests/unit/llm/services/export/test_dynamic.py

Lines changed: 259 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,10 @@
281281
"created_at": "2024-08-06T12:30:42.122Z",
282282
"label_type": "DEFAULT",
283283
"label": {
284-
"turn": {"COMPARISON_JOB": "A_3", "CLASSIFICATION_JOB": ["BOTH_ARE_GOOD"]},
284+
"round": {
285+
"COMPARISON_JOB": "A_3",
286+
"CLASSIFICATION_JOB": ["BOTH_ARE_GOOD"],
287+
},
285288
},
286289
}
287290
],
@@ -358,7 +361,7 @@
358361
"created_at": "2024-08-06T12:30:42.122Z",
359362
"label_type": "DEFAULT",
360363
"label": {
361-
"turn": {"COMPARISON_JOB": "B_1"},
364+
"round": {"COMPARISON_JOB": "B_1"},
362365
},
363366
}
364367
],
@@ -449,7 +452,7 @@
449452
"created_at": "2024-08-06T12:30:42.122Z",
450453
"label_type": "DEFAULT",
451454
"label": {
452-
"turn": {"COMPARISON_JOB": "A_2"},
455+
"round": {"COMPARISON_JOB": "A_2"},
453456
},
454457
}
455458
],
@@ -709,3 +712,256 @@ def test_export_dynamic_with_conversation_level(mocker):
709712
project_id="project_id",
710713
)
711714
assert result == updated_expected_export
715+
716+
717+
def test_export_dynamic_with_completion_level(mocker):
718+
updated_mock_json_interface = copy.deepcopy(mock_json_interface)
719+
720+
updated_mock_json_interface["jobs"].update(
721+
{
722+
"CLASSIFICATION_JOB_AT_COMPLETION_LEVEL": {
723+
"content": {
724+
"categories": {
725+
"TOO_SHORT": {"children": [], "name": "Too short", "id": "category1"},
726+
"JUST_RIGHT": {"children": [], "name": "Just right", "id": "category2"},
727+
"TOO_VERBOSE": {"children": [], "name": "Too verbose", "id": "category3"},
728+
},
729+
"input": "radio",
730+
},
731+
"instruction": "Verbosity",
732+
"level": "completion",
733+
"mlTask": "CLASSIFICATION",
734+
"required": 0,
735+
"isChild": False,
736+
"isNew": False,
737+
},
738+
"CLASSIFICATION_JOB_AT_COMPLETION_LEVEL_1": {
739+
"content": {
740+
"categories": {
741+
"NO_ISSUES": {"children": [], "name": "No issues", "id": "category4"},
742+
"MINOR_ISSUES": {
743+
"children": [],
744+
"name": "Minor issue(s)",
745+
"id": "category5",
746+
},
747+
"MAJOR_ISSUES": {
748+
"children": [],
749+
"name": "Major issue(s)",
750+
"id": "category6",
751+
},
752+
},
753+
"input": "radio",
754+
},
755+
"instruction": "Instructions Following",
756+
"level": "completion",
757+
"mlTask": "CLASSIFICATION",
758+
"required": 0,
759+
"isChild": False,
760+
"isNew": False,
761+
},
762+
"CLASSIFICATION_JOB_AT_COMPLETION_LEVEL_2": {
763+
"content": {
764+
"categories": {
765+
"NO_ISSUES": {"children": [], "name": "No issues", "id": "category7"},
766+
"MINOR_INACCURACY": {
767+
"children": [],
768+
"name": "Minor inaccuracy",
769+
"id": "category8",
770+
},
771+
"MAJOR_INACCURACY": {
772+
"children": [],
773+
"name": "Major inaccuracy",
774+
"id": "category9",
775+
},
776+
},
777+
"input": "radio",
778+
},
779+
"instruction": "Truthfulness",
780+
"level": "completion",
781+
"mlTask": "CLASSIFICATION",
782+
"required": 0,
783+
"isChild": False,
784+
"isNew": False,
785+
},
786+
"CLASSIFICATION_JOB_AT_COMPLETION_LEVEL_3": {
787+
"content": {
788+
"categories": {
789+
"NO_ISSUES": {"children": [], "name": "No issues", "id": "category10"},
790+
"MINOR_SAFETY_CONCERN": {
791+
"children": [],
792+
"name": "Minor safety concern",
793+
"id": "category11",
794+
},
795+
"MAJOR_SAFETY_CONCERN": {
796+
"children": [],
797+
"name": "Major safety concern",
798+
"id": "category12",
799+
},
800+
},
801+
"input": "radio",
802+
},
803+
"instruction": "Harmlessness/Safety",
804+
"level": "completion",
805+
"mlTask": "CLASSIFICATION",
806+
"required": 0,
807+
"isChild": False,
808+
"isNew": False,
809+
},
810+
}
811+
)
812+
813+
updated_mock_fetch_assets = copy.deepcopy(mock_fetch_assets)
814+
updated_mock_fetch_assets[0]["labels"][0]["annotations"].extend(
815+
[
816+
{
817+
"id": "20241209092703759-1",
818+
"job": "CLASSIFICATION_JOB_AT_COMPLETION_LEVEL",
819+
"path": [],
820+
"labelId": "clzief6q2003e7tc91jm46uii",
821+
"chatItemId": "clzieuhlc005a7tc9bx6f0mb5",
822+
"annotationValue": {
823+
"categories": ["TOO_SHORT"],
824+
"id": "20241209092703759-1",
825+
"isPrediction": False,
826+
"__typename": "ClassificationAnnotationValue",
827+
},
828+
"__typename": "ClassificationAnnotation",
829+
},
830+
{
831+
"id": "20241209092704576-2",
832+
"job": "CLASSIFICATION_JOB_AT_COMPLETION_LEVEL_1",
833+
"path": [],
834+
"labelId": "clzief6q2003e7tc91jm46uii",
835+
"chatItemId": "clzieuhlc005a7tc9bx6f0mb5",
836+
"annotationValue": {
837+
"categories": ["MINOR_ISSUES"],
838+
"id": "20241209092704576-2",
839+
"isPrediction": False,
840+
"__typename": "ClassificationAnnotationValue",
841+
},
842+
"__typename": "ClassificationAnnotation",
843+
},
844+
{
845+
"id": "20241209092705314-3",
846+
"job": "CLASSIFICATION_JOB_AT_COMPLETION_LEVEL_2",
847+
"path": [],
848+
"labelId": "clzief6q2003e7tc91jm46uii",
849+
"chatItemId": "clzieuhlc005a7tc9bx6f0mb5",
850+
"annotationValue": {
851+
"categories": ["MAJOR_INACCURACY"],
852+
"id": "20241209092705314-3",
853+
"isPrediction": False,
854+
"__typename": "ClassificationAnnotationValue",
855+
},
856+
"__typename": "ClassificationAnnotation",
857+
},
858+
{
859+
"id": "20241209092706381-4",
860+
"job": "CLASSIFICATION_JOB_AT_COMPLETION_LEVEL_3",
861+
"path": [],
862+
"labelId": "clzief6q2003e7tc91jm46uii",
863+
"chatItemId": "clzieuhlc005a7tc9bx6f0mb5",
864+
"annotationValue": {
865+
"categories": ["MAJOR_SAFETY_CONCERN"],
866+
"id": "20241209092706381-4",
867+
"isPrediction": False,
868+
"__typename": "ClassificationAnnotationValue",
869+
},
870+
"__typename": "ClassificationAnnotation",
871+
},
872+
{
873+
"id": "20241209092707543-5",
874+
"job": "CLASSIFICATION_JOB_AT_COMPLETION_LEVEL",
875+
"path": [],
876+
"labelId": "clzief6q2003e7tc91jm46uii",
877+
"chatItemId": "clzieuhm1005b7tc9b747clxw",
878+
"annotationValue": {
879+
"categories": ["JUST_RIGHT"],
880+
"id": "20241209092707543-5",
881+
"isPrediction": False,
882+
"__typename": "ClassificationAnnotationValue",
883+
},
884+
"__typename": "ClassificationAnnotation",
885+
},
886+
{
887+
"id": "20241209092710361-6",
888+
"job": "CLASSIFICATION_JOB_AT_COMPLETION_LEVEL_1",
889+
"path": [],
890+
"labelId": "clzief6q2003e7tc91jm46uii",
891+
"chatItemId": "clzieuhm1005b7tc9b747clxw",
892+
"annotationValue": {
893+
"categories": ["NO_ISSUES"],
894+
"id": "20241209092710361-6",
895+
"isPrediction": False,
896+
"__typename": "ClassificationAnnotationValue",
897+
},
898+
"__typename": "ClassificationAnnotation",
899+
},
900+
{
901+
"id": "20241209092711511-7",
902+
"job": "CLASSIFICATION_JOB_AT_COMPLETION_LEVEL_2",
903+
"path": [],
904+
"labelId": "clzief6q2003e7tc91jm46uii",
905+
"chatItemId": "clzieuhm1005b7tc9b747clxw",
906+
"annotationValue": {
907+
"categories": ["NO_ISSUES"],
908+
"id": "20241209092711511-7",
909+
"isPrediction": False,
910+
"__typename": "ClassificationAnnotationValue",
911+
},
912+
"__typename": "ClassificationAnnotation",
913+
},
914+
{
915+
"id": "20241209092713123-8",
916+
"job": "CLASSIFICATION_JOB_AT_COMPLETION_LEVEL_3",
917+
"path": [],
918+
"labelId": "clzief6q2003e7tc91jm46uii",
919+
"chatItemId": "clzieuhm1005b7tc9b747clxw",
920+
"annotationValue": {
921+
"categories": ["NO_ISSUES"],
922+
"id": "20241209092713123-8",
923+
"isPrediction": False,
924+
"__typename": "ClassificationAnnotationValue",
925+
},
926+
"__typename": "ClassificationAnnotation",
927+
},
928+
]
929+
)
930+
931+
updated_expected_export = copy.deepcopy(expected_export)
932+
updated_expected_export[0]["2"]["labels"][0]["label"]["completion"] = {
933+
"CLASSIFICATION_JOB_AT_COMPLETION_LEVEL": {
934+
"clzieuhlc005a7tc9bx6f0mb5": ["TOO_SHORT"],
935+
"clzieuhm1005b7tc9b747clxw": ["JUST_RIGHT"],
936+
},
937+
"CLASSIFICATION_JOB_AT_COMPLETION_LEVEL_1": {
938+
"clzieuhlc005a7tc9bx6f0mb5": ["MINOR_ISSUES"],
939+
"clzieuhm1005b7tc9b747clxw": ["NO_ISSUES"],
940+
},
941+
"CLASSIFICATION_JOB_AT_COMPLETION_LEVEL_2": {
942+
"clzieuhlc005a7tc9bx6f0mb5": ["MAJOR_INACCURACY"],
943+
"clzieuhm1005b7tc9b747clxw": ["NO_ISSUES"],
944+
},
945+
"CLASSIFICATION_JOB_AT_COMPLETION_LEVEL_3": {
946+
"clzieuhlc005a7tc9bx6f0mb5": ["MAJOR_SAFETY_CONCERN"],
947+
"clzieuhm1005b7tc9b747clxw": ["NO_ISSUES"],
948+
},
949+
}
950+
get_project_return_val = {
951+
"jsonInterface": updated_mock_json_interface,
952+
"inputType": "LLM_INSTR_FOLLOWING",
953+
"title": "Test project with classifications at completion level",
954+
"id": "project_id",
955+
"dataConnections": None,
956+
}
957+
kili_api_gateway = mocker.MagicMock()
958+
kili_api_gateway.count_assets.return_value = 3
959+
kili_api_gateway.get_project.return_value = get_project_return_val
960+
kili_api_gateway.list_assets.return_value = updated_mock_fetch_assets
961+
962+
kili_llm = LlmClientMethods(kili_api_gateway)
963+
964+
result = kili_llm.export(
965+
project_id="project_id",
966+
)
967+
assert result == updated_expected_export

0 commit comments

Comments
 (0)