Skip to content

Commit c3b128b

Browse files
HNicolasNicolas Hervé
andauthored
feat(lab-2766): support llm projects in append_many_to_dataset (#1680)
Co-authored-by: Nicolas Hervé <[email protected]>
1 parent 50ea73b commit c3b128b

File tree

8 files changed

+147
-4
lines changed

8 files changed

+147
-4
lines changed

src/kili/core/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"Pdf": "application/pdf",
1616
"Text": "text/plain",
1717
"TimeSeries": "text/csv",
18+
"LLM": "application/json",
1819
}
1920

2021
mime_extensions_for_IV2 = {
@@ -27,6 +28,7 @@
2728
"URL": "",
2829
"VIDEO": mime_extensions["Video"],
2930
"VIDEO_LEGACY": "",
31+
"LLM_RLHF": mime_extensions["LLM"],
3032
}
3133

3234
mime_extensions_for_py_scripts = ["text/x-python"]

src/kili/domain/project.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from .tag import TagId
1111

1212
ProjectId = NewType("ProjectId", str)
13-
InputType = Literal["IMAGE", "PDF", "TEXT", "VIDEO"]
13+
InputType = Literal["IMAGE", "PDF", "TEXT", "VIDEO", "LLM_RLHF"]
1414

1515

1616
class InputTypeEnum(str, Enum):
@@ -20,6 +20,7 @@ class InputTypeEnum(str, Enum):
2020
PDF = "PDF"
2121
TEXT = "TEXT"
2222
VIDEO = "VIDEO"
23+
LLM_RLHF = "LLM_RLHF"
2324

2425

2526
ComplianceTag = Literal["PHI", "PII"]

src/kili/entrypoints/mutations/asset/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class MutationsAsset(BaseOperationEntrypointMixin):
4040
def append_many_to_dataset(
4141
self,
4242
project_id: str,
43-
content_array: Optional[List[str]] = None,
43+
content_array: Optional[Union[List[str], List[dict]]] = None,
4444
multi_layer_content_array: Optional[List[List[dict]]] = None,
4545
external_id_array: Optional[List[str]] = None,
4646
id_array: Optional[List[str]] = None,
@@ -67,7 +67,9 @@ def append_many_to_dataset(
6767
- For a VIDEO project, the content can be either URLs pointing to videos hosted on a web server or paths to
6868
existing video files on your computer. If you want to import video from frames, look at the json_content
6969
section below.
70-
- For an `VIDEO_LEGACY` project, the content can be only be URLs
70+
- For an `VIDEO_LEGACY` project, the content can be only be URLs.
71+
- For an `LLM_RLHF` project, the content can be dicts with the keys `prompt` and `completions`,
72+
paths to local json files or URLs to json files.
7173
multi_layer_content_array: List containing multiple lists of paths.
7274
Each path correspond to a layer of a geosat asset. Should be used only for `IMAGE` projects.
7375
external_id_array: List of external ids given to identify the assets.

src/kili/services/asset_import/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
ProjectParams,
1515
)
1616
from .image import ImageDataImporter
17+
from .llm import LLMDataImporter
1718
from .pdf import PdfDataImporter
1819
from .text import TextDataImporter
1920
from .types import AssetLike
@@ -28,6 +29,7 @@
2829
"TEXT": TextDataImporter,
2930
"VIDEO": VideoDataImporter,
3031
"VIDEO_LEGACY": VideoDataImporter,
32+
"LLM_RLHF": LLMDataImporter,
3133
}
3234

3335

src/kili/services/asset_import/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,7 @@ def _get_organization(self, email: str, options: QueryOptions) -> Dict:
486486
)
487487

488488
def _check_upload_is_allowed(self, assets: List[AssetLike]) -> None:
489+
# TODO: avoid querying API for each asset to upload when doing this check
489490
if not self.is_hosted_content(assets) and not self._can_upload_from_local_data():
490491
raise UploadFromLocalDataForbiddenError("Cannot upload content from local data")
491492

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""Functions to import assets into a TEXT project."""
2+
3+
import json
4+
import os
5+
from enum import Enum
6+
from typing import List, Optional, Tuple
7+
8+
from kili.core.helpers import is_url
9+
10+
from .base import (
11+
BaseAbstractAssetImporter,
12+
BatchParams,
13+
ContentBatchImporter,
14+
)
15+
from .exceptions import ImportValidationError
16+
from .types import AssetLike
17+
18+
19+
class LLMDataType(Enum):
20+
"""LLM data type."""
21+
22+
DICT = "DICT"
23+
LOCAL_FILE = "LOCAL_FILE"
24+
HOSTED_FILE = "HOSTED_FILE"
25+
26+
27+
class JSONBatchImporter(ContentBatchImporter):
28+
"""Class for importing a batch of LLM assets with dict content into a LLM_RLHF project."""
29+
30+
def get_content_type_and_data_from_content(self, content: Optional[str]) -> Tuple[str, str]:
31+
"""Returns the data of the content (path) and its content type."""
32+
return content or "", "application/json"
33+
34+
35+
class LLMDataImporter(BaseAbstractAssetImporter):
36+
"""Class for importing data into a TEXT project."""
37+
38+
@staticmethod
39+
def get_data_type(assets: List[AssetLike]) -> LLMDataType:
40+
"""Determine the type of data to upload from the service payload."""
41+
content_array = [asset.get("content", None) for asset in assets]
42+
if all(is_url(content) for content in content_array):
43+
return LLMDataType.HOSTED_FILE
44+
if all(isinstance(content, str) and os.path.exists(content) for content in content_array):
45+
return LLMDataType.LOCAL_FILE
46+
if all(isinstance(content, dict) for content in content_array):
47+
return LLMDataType.DICT
48+
raise ImportValidationError("Invalid value in content for LLM project.")
49+
50+
def import_assets(self, assets: List[AssetLike]):
51+
"""Import LLM assets into Kili."""
52+
self._check_upload_is_allowed(assets)
53+
data_type = self.get_data_type(assets)
54+
assets = self.filter_duplicate_external_ids(assets)
55+
if data_type == LLMDataType.LOCAL_FILE:
56+
assets = self.filter_local_assets(assets, self.raise_error)
57+
batch_params = BatchParams(is_hosted=False, is_asynchronous=False)
58+
batch_importer = ContentBatchImporter(
59+
self.kili, self.project_params, batch_params, self.pbar
60+
)
61+
elif data_type == LLMDataType.HOSTED_FILE:
62+
batch_params = BatchParams(is_hosted=True, is_asynchronous=False)
63+
batch_importer = ContentBatchImporter(
64+
self.kili, self.project_params, batch_params, self.pbar
65+
)
66+
elif data_type == LLMDataType.DICT:
67+
for asset in assets:
68+
if "content" in asset and isinstance(asset["content"], dict):
69+
asset["content"] = json.dumps(asset["content"]).encode("utf-8")
70+
batch_params = BatchParams(is_hosted=False, is_asynchronous=False)
71+
batch_importer = JSONBatchImporter(
72+
self.kili, self.project_params, batch_params, self.pbar
73+
)
74+
else:
75+
raise ImportValidationError
76+
return self.import_assets_by_batch(assets, batch_importer)

src/kili/services/asset_import/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
class AssetLike(TypedDict, total=False):
99
"""General type of an asset object through the import functions."""
1010

11-
content: Union[str, bytes]
11+
content: Union[str, bytes, dict]
1212
multi_layer_content: Union[List[dict], None]
1313
json_content: Union[dict, str, list]
1414
external_id: str
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from unittest.mock import patch
2+
3+
from kili.services.asset_import import import_assets
4+
from tests.unit.services.asset_import.base import ImportTestCase
5+
from tests.unit.services.asset_import.mocks import (
6+
mocked_request_signed_urls,
7+
mocked_unique_id,
8+
mocked_upload_data_via_rest,
9+
)
10+
11+
12+
@patch("kili.utils.bucket.request_signed_urls", mocked_request_signed_urls)
13+
@patch("kili.utils.bucket.upload_data_via_rest", mocked_upload_data_via_rest)
14+
@patch("kili.utils.bucket.generate_unique_id", mocked_unique_id)
15+
class LLMTestCase(ImportTestCase):
16+
def test_upload_from_one_local_file(self, *_):
17+
self.kili.kili_api_gateway.get_project.return_value = {"inputType": "LLM_RLHF"}
18+
url = "https://storage.googleapis.com/label-public-staging/asset-test-sample/llm/test_llm_file.json"
19+
path = self.downloader(url)
20+
assets = [{"content": path, "external_id": "local llm file"}]
21+
import_assets(self.kili, self.project_id, assets)
22+
expected_parameters = self.get_expected_sync_call(
23+
["https://signed_url?id=id"],
24+
["local llm file"],
25+
["unique_id"],
26+
[False],
27+
[""],
28+
["{}"],
29+
)
30+
self.kili.graphql_client.execute.assert_called_with(*expected_parameters)
31+
32+
def test_upload_from_one_hosted_text_file(self, *_):
33+
self.kili.kili_api_gateway.get_project.return_value = {"inputType": "LLM_RLHF"}
34+
assets = [
35+
{"content": "https://hosted-data", "external_id": "hosted file", "id": "unique_id"}
36+
]
37+
import_assets(self.kili, self.project_id, assets)
38+
expected_parameters = self.get_expected_sync_call(
39+
["https://hosted-data"], ["hosted file"], ["unique_id"], [False], [""], ["{}"]
40+
)
41+
self.kili.graphql_client.execute.assert_called_with(*expected_parameters)
42+
43+
def test_upload_from_dict(self, *_):
44+
self.kili.kili_api_gateway.get_project.return_value = {"inputType": "LLM_RLHF"}
45+
assets = [
46+
{
47+
"content": {
48+
"prompt": "does it contain code ?",
49+
"completions": ["first completion", "second completion", "#this is markdown"],
50+
"type": "markdown",
51+
},
52+
"external_id": "dict",
53+
}
54+
]
55+
import_assets(self.kili, self.project_id, assets)
56+
expected_parameters = self.get_expected_sync_call(
57+
["https://signed_url?id=id"], ["dict"], ["unique_id"], [False], [""], ["{}"]
58+
)
59+
self.kili.graphql_client.execute.assert_called_with(*expected_parameters)

0 commit comments

Comments
 (0)