
Commit c9865e6

feat: return dataset entity when creating a dataset (#768)
* return dataset id when upload
* return dataset entity instead
* fix test
* add back missing diff
1 parent 295eb8f commit c9865e6
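Before this change `upload_dataset` returned `None`; with it, the call returns the created or updated dataset's descriptor. A minimal sketch of the new usage (the dataset name and DataFrame contents below are illustrative only):

import pandas as pd
from kolena.dataset import upload_dataset

# Hypothetical datapoints; "locator" serves as the ID field.
df = pd.DataFrame([{"locator": "s3://bucket/img1.png", "label": "cat"}])

# upload_dataset now returns a DatasetEntity instead of None.
dataset = upload_dataset("example-dataset", df, id_fields=["locator"])
print(dataset.id, dataset.name, dataset.id_fields)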

File tree

8 files changed: +37 -22 lines

docs/reference/dataset/index.md

Lines changed: 1 addition & 1 deletion
@@ -13,5 +13,5 @@
       show_root_heading: false
 ::: kolena._api.v2.dataset
     options:
-      members: ["Filters", "GeneralFieldFilter"]
+      members: ["Filters", "GeneralFieldFilter", "DatasetEntity"]
       show_root_heading: false

kolena/_api/v2/dataset.py

Lines changed: 9 additions & 1 deletion
@@ -90,11 +90,19 @@ class LoadDatasetByNameRequest:
 
 
 @dataclass(frozen=True)
-class EntityData:
+class DatasetEntity:
+    """
+    The descriptor of a dataset on Kolena.
+    """
+
     id: int
+    """ID of the dataset."""
     name: str
+    """Name of the dataset."""
     description: str
+    """Description of the dataset."""
     id_fields: List[str]
+    """ID fields of the dataset."""
 
 
 @dataclass(frozen=True)
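`DatasetEntity` (the renamed `EntityData`) is a frozen dataclass, so instances are read-only once constructed. A small sketch, using field values like those in the unit tests further down:

from kolena._api.v2.dataset import DatasetEntity

entity = DatasetEntity(id=1, name="foo", description="", id_fields=["locator"])
print(entity.id_fields)  # ["locator"]
# entity.name = "bar"    # would raise dataclasses.FrozenInstanceError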

kolena/_experimental/quality_standard.py

Lines changed: 2 additions & 2 deletions
@@ -35,7 +35,7 @@
 from kolena._api.v2._testing import StratifyFieldSpec
 from kolena._api.v2._testing import TestingRequest
 from kolena._api.v2._testing import TestingResponse
-from kolena._api.v2.dataset import EntityData
+from kolena._api.v2.dataset import DatasetEntity
 from kolena._api.v2.model import ModelWithEvalConfig
 from kolena._api.v2.quality_standard import CopyQualityStandardRequest
 from kolena._api.v2.quality_standard import Path

@@ -150,7 +150,7 @@ def _download_quality_standard(
 
 def _calculate_moe_map(
     qs_result: pd.DataFrame,
-    dataset_entity: EntityData,
+    dataset_entity: DatasetEntity,
     confidence_level: float,
     qs: QualityStandardResponse,
 ) -> Dict[Tuple[str, Any], float]:

kolena/dataset/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,7 @@
 from kolena.dataset.evaluation import download_results
 from kolena.dataset.evaluation import EvalConfigResults
 from kolena.dataset.dataset import list_datasets
+from kolena.dataset.dataset import DatasetEntity
 from kolena.dataset.evaluation import ModelEntity
 from kolena.dataset.evaluation import get_models
 from kolena.dataset.embeddings import upload_dataset_embeddings

@@ -33,6 +34,7 @@
     "download_results",
     "EvalConfigResults",
     "list_datasets",
+    "DatasetEntity",
     "ModelEntity",
     "get_models",
     "upload_dataset_embeddings",

kolena/dataset/dataset.py

Lines changed: 15 additions & 12 deletions
@@ -30,7 +30,7 @@
 
 from kolena._api.v1.event import EventAPI
 from kolena._api.v2.dataset import CommitData
-from kolena._api.v2.dataset import EntityData
+from kolena._api.v2.dataset import DatasetEntity
 from kolena._api.v2.dataset import Filters
 from kolena._api.v2.dataset import ListCommitHistoryRequest
 from kolena._api.v2.dataset import ListCommitHistoryResponse

@@ -189,7 +189,7 @@ def _upload_dataset_chunk(df: pd.DataFrame, load_uuid: str, id_fields: List[str]
     upload_data_frame(df=df_serialized, load_uuid=load_uuid)
 
 
-def _load_dataset_metadata(name: str, raise_error_if_not_found: bool = True) -> Optional[EntityData]:
+def _load_dataset_metadata(name: str, raise_error_if_not_found: bool = True) -> Optional[DatasetEntity]:
     """
     Load the metadata of a given dataset.
 

@@ -210,13 +210,13 @@ def _load_dataset_metadata(name: str, raise_error_if_not_found: bool = True) ->
         return None
     response.raise_for_status()
 
-    return from_dict(EntityData, response.json())
+    return from_dict(DatasetEntity, response.json())
 
 
 def _resolve_id_fields(
     df: pd.DataFrame,
     id_fields: Optional[List[str]],
-    existing_dataset: Optional[EntityData],
+    existing_dataset: Optional[DatasetEntity],
 ) -> List[str]:
     existing_id_fields = []
     if existing_dataset:

@@ -269,7 +269,7 @@ def _send_upload_dataset_request(
     commit_tags: Optional[List[str]] = None,
     dataset_tags: Optional[List[str]] = None,
     description: Optional[str] = None,
-) -> EntityData:
+) -> DatasetEntity:
     request = RegisterRequest(
         name=name,
         id_fields=id_fields,

@@ -282,8 +282,8 @@
     )
     response = krequests.post(Path.REGISTER, json=asdict(request))
     krequests.raise_for_status(response)
-    data = from_dict(EntityData, response.json())
-    return data
+    dataset_entity = from_dict(DatasetEntity, response.json())
+    return dataset_entity
 
 
 def _upload_dataset(

@@ -296,10 +296,10 @@ def _upload_dataset(
     commit_tags: Optional[List[str]] = None,
     dataset_tags: Optional[List[str]] = None,
     description: Optional[str] = None,
-) -> None:
+) -> DatasetEntity:
     prepared_id_fields, load_uuid = _prepare_upload_dataset_request(name, df, id_fields=id_fields)
 
-    data = _send_upload_dataset_request(
+    dataset_entity = _send_upload_dataset_request(
         name,
         prepared_id_fields,
         load_uuid,

@@ -309,7 +309,8 @@ def _upload_dataset(
         dataset_tags=dataset_tags,
         description=description,
     )
-    log.info(f"uploaded dataset '{name}' ({get_dataset_url(dataset_id=data.id)})")
+    log.info(f"uploaded dataset '{name}' ({get_dataset_url(dataset_id=dataset_entity.id)})")
+    return dataset_entity
 
 
 @with_event(event_name=EventAPI.Event.REGISTER_DATASET)

@@ -322,7 +323,7 @@ def upload_dataset(
     dataset_tags: Optional[List[str]] = None,
     append_only: bool = False,
     description: Optional[str] = None,
-) -> None:
+) -> DatasetEntity:
     """
     Create or update a dataset with the contents of the provided DataFrame `df`.
 

@@ -343,8 +344,10 @@ def upload_dataset(
         datapoints from the input dataframe will be added, and existing datapoints will be modified if present in the
         input dataframe, but no datapoints will be deleted from the datasets. This behaves like an `UPSERT` operation.
     :param description: Optionally specify the description of the dataset.
+
+    :return: The dataset as a [`DatasetEntity`][kolena.dataset.DatasetEntity] object.
     """
-    _upload_dataset(
+    return _upload_dataset(
         name,
         df,
         id_fields=id_fields,
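Since `upload_dataset` now returns the dataset descriptor on every call, a caller can, for example, confirm that an `append_only` update landed on the same dataset. A sketch under the assumption that repeated uploads to the same name keep the same dataset ID (names and data are illustrative):

import pandas as pd
from kolena.dataset import upload_dataset

df_initial = pd.DataFrame([{"locator": "s3://bucket/a.png", "label": "cat"}])
df_more = pd.DataFrame([{"locator": "s3://bucket/b.png", "label": "dog"}])

created = upload_dataset("demo-dataset", df_initial, id_fields=["locator"])

# append_only behaves like an UPSERT: the existing dataset is updated in place,
# so the returned entity should carry the same dataset id.
updated = upload_dataset("demo-dataset", df_more, id_fields=["locator"], append_only=True)
assert created.id == updated.id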

tests/integration/dataset/test_dataset.py

Lines changed: 3 additions & 1 deletion
@@ -110,7 +110,9 @@ def test__upload_dataset() -> None:
     ]
     columns = ["locator", "width", "height", "city", "bboxes", "time_str", "time_num"]
 
-    upload_dataset(name, pd.DataFrame(datapoints[:10], columns=columns), id_fields=["locator"])
+    dataset_entity = upload_dataset(name, pd.DataFrame(datapoints[:10], columns=columns), id_fields=["locator"])
+    assert dataset_entity.id == _load_dataset_metadata(name).id
+    assert dataset_entity.name == name
 
     loaded_datapoints = download_dataset(name).sort_values("width", ignore_index=True).reindex(columns=columns)
     expected = pd.DataFrame(expected_datapoints[:10], columns=columns)

tests/unit/_experimental/trace/test_trace.py

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
 from unittest.mock import Mock
 from unittest.mock import patch
 
-from kolena._api.v2.dataset import EntityData
+from kolena._api.v2.dataset import DatasetEntity
 from kolena._experimental.trace import kolena_trace
 from kolena._experimental.trace.trace import _Trace
 

@@ -126,7 +126,7 @@ def predict(data, request_id): # type: ignore
     assert str(e) == "Id Field request_id cannot be None in datapoint input"
 
     with patch("kolena._experimental.trace.trace._load_dataset_metadata") as mock_load_dataset_metadata:
-        mock_load_dataset_metadata.return_value = EntityData(
+        mock_load_dataset_metadata.return_value = DatasetEntity(
             id=1,
             name=dataset_name,
             description="test",

tests/unit/dataset/test_dataset.py

Lines changed: 3 additions & 3 deletions
@@ -20,7 +20,7 @@
 import pytest
 from pandas.testing import assert_frame_equal
 
-from kolena._api.v2.dataset import EntityData
+from kolena._api.v2.dataset import DatasetEntity
 from kolena._utils.datatypes import DATA_TYPE_FIELD
 from kolena.dataset._common import COL_DATAPOINT
 from kolena.dataset._common import COL_RESULT

@@ -357,7 +357,7 @@ def test__infer_id_fields__error(input_df: pd.DataFrame) -> None:
 
 def test__resolve_id_fields() -> None:
     df = pd.DataFrame(dict(user_dp=["a", "b", "c"], new_user_dp=["d", "e", "f"]))
-    dataset = EntityData(id=1, name="foo", description="", id_fields=["user_dp"])
+    dataset = DatasetEntity(id=1, name="foo", description="", id_fields=["user_dp"])
     inferrable_df = pd.DataFrame(dict(locator=["x", "y", "z"]))
 
     # new dataset without id_fields

@@ -371,7 +371,7 @@ def test__resolve_id_fields() -> None:
     assert _resolve_id_fields(
         inferrable_df,
         None,
-        EntityData(id=1, name="foo", description="", id_fields=["locator"]),
+        DatasetEntity(id=1, name="foo", description="", id_fields=["locator"]),
     ) == ["locator"]
 
     # new dataset with explicit id_fields should resolve to explicit id_fields
