Skip to content

Commit e874783

Browse files
authored
feat: ability to download datapoint embeddings (#779)
* draft - ability to get embedding keys + download embeddings * skip getting embedding key column, do not modify df when uploading * basic integ tests * raise not found error if embedding key does not exist + integ tests * add api ref to docs * add docs for split_part * remove text_summarization workflow test
1 parent 7b85eb8 commit e874783

File tree

9 files changed

+278
-56
lines changed

9 files changed

+278
-56
lines changed

.circleci/continue_config.yml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -442,12 +442,6 @@ workflows:
442442
subproject: [ age_estimation, automatic_speech_recognition, classification, keypoint_detection, question_answering, rain_forecast, semantic_segmentation, speaker_diarization, semantic_textual_similarity, person_detection, crossing_pedestrian_detection, named_entity_recognition ]
443443
resource-class: [ small ]
444444
python-version: [ "3.9.18" ]
445-
- example-test-workflow:
446-
matrix:
447-
parameters:
448-
subproject: [ text_summarization ]
449-
resource-class: [ large ]
450-
python-version: [ "3.9.18" ]
451445
- example-test-workflow:
452446
context:
453447
- aws

docs/dataset/advanced-usage/custom-queries.md

Lines changed: 27 additions & 26 deletions
Large diffs are not rendered by default.

docs/reference/dataset/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
::: kolena.dataset.evaluation
1010
::: kolena.dataset.embeddings
1111
options:
12-
members: ["upload_dataset_embeddings"]
12+
members: ["upload_dataset_embeddings", "get_dataset_embedding_keys", "download_dataset_embeddings"]
1313
show_root_heading: false
1414
::: kolena._api.v2.dataset
1515
options:

kolena/_api/v1/event.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ class Event(str, Enum):
7373

7474
# dataset search
7575
UPLOAD_DATASET_EMBEDDINGS = "sdk-dataset-embeddings-uploaded"
76+
FETCH_DATASET_EMBEDDINGS = "sdk-dataset-embeddings-fetched"
77+
GET_DATASET_EMBEDDING_KEYS = "sdk-dataset-embedding-keys-fetched"
7678

7779
@dataclass(frozen=True)
7880
class RecordEventRequest:

kolena/_api/v2/search.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,16 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from enum import Enum
15+
from typing import List
1516

1617
from kolena._api.v1.batched_load import BatchedLoad
1718
from kolena._utils.pydantic_v1.dataclasses import dataclass
1819

1920

2021
class Path(str, Enum):
2122
EMBEDDINGS = "/search/embeddings"
23+
GET_EMBEDDING_KEYS = "/search/get-embedding-model-keys"
24+
LOAD_EMBEDDINGS = "/search/load-embeddings"
2225

2326

2427
@dataclass(frozen=True)
@@ -30,3 +33,19 @@ class UploadDatasetEmbeddingsRequest(BatchedLoad.WithLoadUUID):
3033
@dataclass(frozen=True)
3134
class UploadDatasetEmbeddingsResponse:
3235
n_datapoints: int
36+
37+
38+
@dataclass(frozen=True)
39+
class DownloadDatasetEmbeddingsRequest(BatchedLoad.BaseInitDownloadRequest):
40+
dataset: str
41+
model_key: str
42+
43+
44+
@dataclass(frozen=True)
45+
class GetEmbeddingKeysRequest:
46+
dataset_identifier: str
47+
48+
49+
@dataclass(frozen=True)
50+
class GetEmbeddingKeysResponse:
51+
model_keys: List[str]

kolena/dataset/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from kolena.dataset.dataset import DatasetEntity
2323
from kolena.dataset.evaluation import ModelEntity
2424
from kolena.dataset.evaluation import get_models
25+
from kolena.dataset.embeddings import download_dataset_embeddings
26+
from kolena.dataset.embeddings import get_dataset_embedding_keys
2527
from kolena.dataset.embeddings import upload_dataset_embeddings
2628
from kolena._api.v2.dataset import Filters
2729
from kolena._api.v2.dataset import GeneralFieldFilter
@@ -40,4 +42,6 @@
4042
"ModelEntity",
4143
"get_models",
4244
"upload_dataset_embeddings",
45+
"get_dataset_embedding_keys",
46+
"download_dataset_embeddings",
4347
]

kolena/dataset/_common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
COL_EVAL_CONFIG = "eval_config"
2727
COL_RESULT = "result"
2828
COL_THRESHOLDED_OBJECT = "thresholded_object"
29+
COL_EMBEDDING = "embedding"
30+
COL_EMBEDDING_KEY = "key"
2931
_MAX_DUPLICATE_ID_REPORT = 10
3032

3133
DEFAULT_SOURCES = [dict(type="sdk")]

kolena/dataset/embeddings.py

Lines changed: 109 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,11 @@
1414
import dataclasses
1515
import json
1616
import pickle
17+
from base64 import b64decode
1718
from base64 import b64encode
1819
from typing import Any
20+
from typing import Iterator
21+
from typing import List
1922
from typing import Set
2023

2124
import numpy as np
@@ -25,21 +28,32 @@
2528
from pandera.typing import Series
2629

2730
from kolena._api.v1.event import EventAPI
31+
from kolena._api.v2.search import DownloadDatasetEmbeddingsRequest
32+
from kolena._api.v2.search import GetEmbeddingKeysRequest
33+
from kolena._api.v2.search import GetEmbeddingKeysResponse
2834
from kolena._api.v2.search import Path as PATH_V2
2935
from kolena._api.v2.search import UploadDatasetEmbeddingsRequest
3036
from kolena._api.v2.search import UploadDatasetEmbeddingsResponse
3137
from kolena._utils import krequests
3238
from kolena._utils import log
39+
from kolena._utils.batched_load import _BatchedLoader
3340
from kolena._utils.batched_load import init_upload
3441
from kolena._utils.batched_load import upload_data_frame
42+
from kolena._utils.consts import BatchSize
3543
from kolena._utils.dataframes.validators import validate_df_schema
3644
from kolena._utils.instrumentation import with_event
3745
from kolena._utils.state import API_V2
46+
from kolena.dataset._common import COL_DATAPOINT
3847
from kolena.dataset._common import COL_DATAPOINT_ID_OBJECT
48+
from kolena.dataset._common import COL_EMBEDDING
49+
from kolena.dataset._common import COL_EMBEDDING_KEY
50+
from kolena.dataset._common import validate_batch_size
3951
from kolena.dataset._common import validate_dataframe_ids
4052
from kolena.dataset.dataset import _load_dataset_metadata
53+
from kolena.dataset.dataset import _to_deserialized_dataframe
4154
from kolena.dataset.dataset import _to_serialized_dataframe
4255
from kolena.errors import InputValidationError
56+
from kolena.errors import NotFoundError
4357

4458
# Ensure check method is registered or else would get SchemaInitError
4559
# noreorder
@@ -71,6 +85,7 @@ def _upload_dataset_embeddings(
7185
df_embedding: pd.DataFrame,
7286
run_embedding_reduction_pipeline: bool = True,
7387
) -> None:
88+
df_embedding = df_embedding.copy(deep=True)
7489
dataset_entity_data = _load_dataset_metadata(dataset_name)
7590
assert dataset_entity_data
7691
embedding_lengths: Set[int] = set()
@@ -82,7 +97,7 @@ def encode_embedding(embedding: Any) -> str:
8297
return b64encode(pickle.dumps(embedding.astype(np.float32))).decode("utf-8")
8398

8499
# encode embeddings to string
85-
df_embedding["embedding"] = df_embedding["embedding"].apply(encode_embedding)
100+
df_embedding[COL_EMBEDDING] = df_embedding[COL_EMBEDDING].apply(encode_embedding)
86101
if len(embedding_lengths) > 1:
87102
raise InputValidationError(f"embeddings are not of the same size, found {embedding_lengths}")
88103

@@ -95,8 +110,8 @@ def encode_embedding(embedding: Any) -> str:
95110
)
96111
df_embedding = pd.concat([df_embedding, df_serialized_datapoint_id_object], axis=1)
97112

98-
df_embedding["key"] = key
99-
df_embedding = df_embedding[[COL_DATAPOINT_ID_OBJECT, "key", "embedding"]]
113+
df_embedding[COL_EMBEDDING_KEY] = key
114+
df_embedding = df_embedding[[COL_DATAPOINT_ID_OBJECT, COL_EMBEDDING_KEY, COL_EMBEDDING]]
100115
df_validated = validate_df_schema(df_embedding, DatasetEmbeddingsDataFrameSchema)
101116

102117
log.info(f"uploading embeddings for dataset '{dataset_name}' and key '{key}'")
@@ -131,3 +146,94 @@ def upload_dataset_embeddings(dataset_name: str, key: str, df_embedding: pd.Data
131146
:raises InputValidationError: The provided input is not valid.
132147
"""
133148
_upload_dataset_embeddings(dataset_name, key, df_embedding)
149+
150+
151+
@with_event(event_name=EventAPI.Event.GET_DATASET_EMBEDDING_KEYS)
152+
def get_dataset_embedding_keys(dataset_name: str) -> List[str]:
153+
"""
154+
Get the list of embedding keys for a dataset.
155+
156+
:param dataset_name: String value indicating the name of the dataset.
157+
:return: List of embedding keys associated with the dataset.
158+
:raises NotFoundError: The given dataset does not exist.
159+
"""
160+
log.info(f"fetching embedding keys for dataset '{dataset_name}'")
161+
return _get_dataset_embedding_keys(dataset_name)
162+
163+
164+
def _get_dataset_embedding_keys(dataset_name: str) -> List[str]:
165+
_load_dataset_metadata(dataset_name)
166+
167+
request = GetEmbeddingKeysRequest(dataset_identifier=dataset_name)
168+
response = krequests.put(
169+
PATH_V2.GET_EMBEDDING_KEYS,
170+
api_version=API_V2,
171+
json=dataclasses.asdict(request),
172+
)
173+
krequests.raise_for_status(response)
174+
return from_dict(GetEmbeddingKeysResponse, response.json()).model_keys
175+
176+
177+
@with_event(event_name=EventAPI.Event.FETCH_DATASET_EMBEDDINGS)
178+
def download_dataset_embeddings(dataset_name: str, key: str) -> pd.DataFrame:
179+
"""
180+
Download search embeddings for a dataset.
181+
182+
:param dataset_name: String value indicating the name of the dataset for which the embeddings will be downloaded.
183+
:param key: String value uniquely corresponding to the embedding vectors.
184+
:return: df_embedding: Dataframe containing id fields for identifying datapoints in the dataset and the associated
185+
embeddings as `numpy.typing.ArrayLike` of numeric values.
186+
:raises NotFoundError: The given dataset or embedding key does not exist.
187+
"""
188+
189+
log.info(f"downloading embeddings from dataset '{dataset_name}' with key '{key}'")
190+
existing_dataset = _load_dataset_metadata(dataset_name)
191+
assert existing_dataset
192+
id_fields = existing_dataset.id_fields
193+
194+
if key not in _get_dataset_embedding_keys(dataset_name):
195+
raise NotFoundError(
196+
f"embedding key '{key}' does not exist for dataset '{dataset_name}'",
197+
)
198+
199+
df = _fetch_embeddings(dataset_name, key)
200+
df_embeddings = pd.concat(
201+
[
202+
_to_deserialized_dataframe(df, column=COL_DATAPOINT)[id_fields],
203+
df[COL_EMBEDDING].apply(lambda s: pickle.loads(b64decode(s))),
204+
],
205+
axis=1,
206+
)
207+
return df_embeddings
208+
209+
210+
def _iter_embeddings_raw(dataset_name: str, key: str, batch_size: int) -> Iterator[pd.DataFrame]:
211+
validate_batch_size(batch_size)
212+
init_request = DownloadDatasetEmbeddingsRequest(
213+
dataset=dataset_name,
214+
model_key=key,
215+
batch_size=batch_size,
216+
)
217+
yield from _BatchedLoader.iter_data(
218+
init_request=init_request,
219+
endpoint_path=PATH_V2.LOAD_EMBEDDINGS.value,
220+
df_class=None,
221+
endpoint_api_version=API_V2,
222+
)
223+
224+
225+
def _fetch_embeddings(dataset_name: str, key: str) -> pd.DataFrame:
226+
df_result_batch = list(
227+
_iter_embeddings_raw(
228+
dataset_name,
229+
key,
230+
batch_size=BatchSize.LOAD_RECORDS,
231+
),
232+
)
233+
return (
234+
pd.concat(df_result_batch)
235+
if df_result_batch
236+
else pd.DataFrame(
237+
columns=["datapoint_id", COL_DATAPOINT, COL_EMBEDDING],
238+
)
239+
)

0 commit comments

Comments (0)