1414import dataclasses
1515import json
1616import pickle
17+ from base64 import b64decode
1718from base64 import b64encode
1819from typing import Any
20+ from typing import Iterator
21+ from typing import List
1922from typing import Set
2023
2124import numpy as np
2528from pandera .typing import Series
2629
2730from kolena ._api .v1 .event import EventAPI
31+ from kolena ._api .v2 .search import DownloadDatasetEmbeddingsRequest
32+ from kolena ._api .v2 .search import GetEmbeddingKeysRequest
33+ from kolena ._api .v2 .search import GetEmbeddingKeysResponse
2834from kolena ._api .v2 .search import Path as PATH_V2
2935from kolena ._api .v2 .search import UploadDatasetEmbeddingsRequest
3036from kolena ._api .v2 .search import UploadDatasetEmbeddingsResponse
3137from kolena ._utils import krequests
3238from kolena ._utils import log
39+ from kolena ._utils .batched_load import _BatchedLoader
3340from kolena ._utils .batched_load import init_upload
3441from kolena ._utils .batched_load import upload_data_frame
42+ from kolena ._utils .consts import BatchSize
3543from kolena ._utils .dataframes .validators import validate_df_schema
3644from kolena ._utils .instrumentation import with_event
3745from kolena ._utils .state import API_V2
46+ from kolena .dataset ._common import COL_DATAPOINT
3847from kolena .dataset ._common import COL_DATAPOINT_ID_OBJECT
48+ from kolena .dataset ._common import COL_EMBEDDING
49+ from kolena .dataset ._common import COL_EMBEDDING_KEY
50+ from kolena .dataset ._common import validate_batch_size
3951from kolena .dataset ._common import validate_dataframe_ids
4052from kolena .dataset .dataset import _load_dataset_metadata
53+ from kolena .dataset .dataset import _to_deserialized_dataframe
4154from kolena .dataset .dataset import _to_serialized_dataframe
4255from kolena .errors import InputValidationError
56+ from kolena .errors import NotFoundError
4357
4458# Ensure check method is registered or else would get SchemaInitError
4559# noreorder
@@ -71,6 +85,7 @@ def _upload_dataset_embeddings(
7185 df_embedding : pd .DataFrame ,
7286 run_embedding_reduction_pipeline : bool = True ,
7387) -> None :
88+ df_embedding = df_embedding .copy (deep = True )
7489 dataset_entity_data = _load_dataset_metadata (dataset_name )
7590 assert dataset_entity_data
7691 embedding_lengths : Set [int ] = set ()
@@ -82,7 +97,7 @@ def encode_embedding(embedding: Any) -> str:
8297 return b64encode (pickle .dumps (embedding .astype (np .float32 ))).decode ("utf-8" )
8398
8499 # encode embeddings to string
85- df_embedding ["embedding" ] = df_embedding ["embedding" ].apply (encode_embedding )
100+ df_embedding [COL_EMBEDDING ] = df_embedding [COL_EMBEDDING ].apply (encode_embedding )
86101 if len (embedding_lengths ) > 1 :
87102 raise InputValidationError (f"embeddings are not of the same size, found { embedding_lengths } " )
88103
@@ -95,8 +110,8 @@ def encode_embedding(embedding: Any) -> str:
95110 )
96111 df_embedding = pd .concat ([df_embedding , df_serialized_datapoint_id_object ], axis = 1 )
97112
98- df_embedding ["key" ] = key
99- df_embedding = df_embedding [[COL_DATAPOINT_ID_OBJECT , "key" , "embedding" ]]
113+ df_embedding [COL_EMBEDDING_KEY ] = key
114+ df_embedding = df_embedding [[COL_DATAPOINT_ID_OBJECT , COL_EMBEDDING_KEY , COL_EMBEDDING ]]
100115 df_validated = validate_df_schema (df_embedding , DatasetEmbeddingsDataFrameSchema )
101116
102117 log .info (f"uploading embeddings for dataset '{ dataset_name } ' and key '{ key } '" )
@@ -131,3 +146,94 @@ def upload_dataset_embeddings(dataset_name: str, key: str, df_embedding: pd.Data
131146 :raises InputValidationError: The provided input is not valid.
132147 """
133148 _upload_dataset_embeddings (dataset_name , key , df_embedding )
149+
150+
@with_event(event_name=EventAPI.Event.GET_DATASET_EMBEDDING_KEYS)
def get_dataset_embedding_keys(dataset_name: str) -> List[str]:
    """
    Get the list of embedding keys for a dataset.

    :param dataset_name: String value indicating the name of the dataset.
    :return: List of embedding keys associated with the dataset.
    :raises NotFoundError: The given dataset does not exist.
    """
    log.info(f"fetching embedding keys for dataset '{dataset_name}'")
    return _get_dataset_embedding_keys(dataset_name)
162+
163+
def _get_dataset_embedding_keys(dataset_name: str) -> List[str]:
    """Return the embedding keys registered for the named dataset via the search API."""
    # Resolve the dataset first; presumably this surfaces a NotFoundError for an
    # unknown dataset before we hit the embeddings endpoint — TODO confirm.
    _load_dataset_metadata(dataset_name)

    response = krequests.put(
        PATH_V2.GET_EMBEDDING_KEYS,
        api_version=API_V2,
        json=dataclasses.asdict(GetEmbeddingKeysRequest(dataset_identifier=dataset_name)),
    )
    krequests.raise_for_status(response)
    parsed = from_dict(GetEmbeddingKeysResponse, response.json())
    return parsed.model_keys
175+
176+
@with_event(event_name=EventAPI.Event.FETCH_DATASET_EMBEDDINGS)
def download_dataset_embeddings(dataset_name: str, key: str) -> pd.DataFrame:
    """
    Download search embeddings for a dataset.

    :param dataset_name: String value indicating the name of the dataset for which the embeddings will be
        downloaded.
    :param key: String value uniquely corresponding to the embedding vectors.
    :return: Dataframe containing id fields for identifying datapoints in the dataset and the associated
        embeddings as `numpy.typing.ArrayLike` of numeric values.
    :raises NotFoundError: The given dataset or embedding key does not exist.
    """

    log.info(f"downloading embeddings from dataset '{dataset_name}' with key '{key}'")
    existing_dataset = _load_dataset_metadata(dataset_name)
    assert existing_dataset
    id_fields = existing_dataset.id_fields

    # Validate the key up front so the caller gets a clear error instead of an empty result.
    if key not in _get_dataset_embedding_keys(dataset_name):
        raise NotFoundError(
            f"embedding key '{key}' does not exist for dataset '{dataset_name}'",
        )

    df = _fetch_embeddings(dataset_name, key)
    # NOTE(review): embeddings are base64-encoded pickles produced by the upload path in this
    # module, so pickle.loads here round-trips client-produced data — do not reuse this
    # deserialization for data from untrusted sources.
    df_embeddings = pd.concat(
        [
            _to_deserialized_dataframe(df, column=COL_DATAPOINT)[id_fields],
            df[COL_EMBEDDING].apply(lambda s: pickle.loads(b64decode(s))),
        ],
        axis=1,
    )
    return df_embeddings
208+
209+
def _iter_embeddings_raw(dataset_name: str, key: str, batch_size: int) -> Iterator[pd.DataFrame]:
    """Yield raw embedding DataFrame batches for the given dataset and embedding key."""
    validate_batch_size(batch_size)
    request = DownloadDatasetEmbeddingsRequest(
        dataset=dataset_name,
        model_key=key,
        batch_size=batch_size,
    )
    # df_class=None: batches are returned as plain DataFrames without schema coercion.
    yield from _BatchedLoader.iter_data(
        init_request=request,
        endpoint_path=PATH_V2.LOAD_EMBEDDINGS.value,
        df_class=None,
        endpoint_api_version=API_V2,
    )
223+
224+
def _fetch_embeddings(dataset_name: str, key: str) -> pd.DataFrame:
    """Collect all embedding batches for `key` into a single DataFrame."""
    batches = list(_iter_embeddings_raw(dataset_name, key, batch_size=BatchSize.LOAD_RECORDS))
    if not batches:
        # No embeddings stored: return an empty frame with the expected column layout.
        return pd.DataFrame(columns=["datapoint_id", COL_DATAPOINT, COL_EMBEDDING])
    return pd.concat(batches)
0 commit comments