Skip to content

Commit 833696c

Browse files
committed
fix: Fix list saved dataset api
Signed-off-by: ntkathole <[email protected]>
1 parent be004ef commit 833696c

File tree

4 files changed

+137
-2
lines changed

4 files changed

+137
-2
lines changed

sdk/python/feast/api/registry/rest/saved_datasets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,6 @@ def list_saved_datasets(
3232
tags=tags,
3333
)
3434
response = grpc_call(grpc_handler.ListSavedDatasets, req)
35-
return {"saved_datasets": response.get("saved_datasets", [])}
35+
return {"saved_datasets": response.get("savedDatasets", [])}
3636

3737
return router

sdk/python/feast/infra/registry/proto_registry_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def list_saved_datasets(
320320
saved_datasets = []
321321
for saved_dataset in registry_proto.saved_datasets:
322322
if saved_dataset.spec.project == project and utils.has_all_tags(
323-
saved_dataset.tags, tags
323+
saved_dataset.spec.tags, tags
324324
):
325325
saved_datasets.append(SavedDataset.from_proto(saved_dataset))
326326
return saved_datasets

sdk/python/tests/unit/api/test_api_rest_registry.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
from feast import Entity, FeatureService, FeatureView, Field, FileSource
88
from feast.api.registry.rest.rest_registry_server import RestRegistryServer
99
from feast.feature_store import FeatureStore
10+
from feast.infra.offline_stores.file_source import SavedDatasetFileStorage
1011
from feast.repo_config import RepoConfig
12+
from feast.saved_dataset import SavedDataset
1113
from feast.types import Float64, Int64
1214
from feast.value_type import ValueType
1315

@@ -67,8 +69,19 @@ def fastapi_test_app():
6769
features=[user_profile_feature_view],
6870
)
6971

72+
# Create a saved dataset for testing
73+
saved_dataset_storage = SavedDatasetFileStorage(path=parquet_file_path)
74+
test_saved_dataset = SavedDataset(
75+
name="test_saved_dataset",
76+
features=["user_profile:age", "user_profile:income"],
77+
join_keys=["user_id"],
78+
storage=saved_dataset_storage,
79+
tags={"environment": "test", "version": "1.0"},
80+
)
81+
7082
# Apply objects
7183
store.apply([user_id_entity, user_profile_feature_view, user_feature_service])
84+
store._registry.apply_saved_dataset(test_saved_dataset, "demo_project")
7285

7386
# Build REST app with registered routes
7487
rest_server = RestRegistryServer(store)
@@ -248,3 +261,69 @@ def test_lineage_endpoint_error_handling(fastapi_test_app):
248261
# Test object relationships with missing parameters
249262
response = fastapi_test_app.get("/lineage/objects/featureView/test_fv")
250263
assert response.status_code == 422 # Missing required project parameter
264+
265+
266+
def test_saved_datasets_via_rest(fastapi_test_app):
267+
# Test list saved datasets endpoint
268+
response = fastapi_test_app.get("/saved_datasets?project=demo_project")
269+
assert response.status_code == 200
270+
response_data = response.json()
271+
assert "saved_datasets" in response_data
272+
assert isinstance(response_data["saved_datasets"], list)
273+
assert len(response_data["saved_datasets"]) == 1
274+
275+
saved_dataset = response_data["saved_datasets"][0]
276+
assert saved_dataset["spec"]["name"] == "test_saved_dataset"
277+
assert "user_profile:age" in saved_dataset["spec"]["features"]
278+
assert "user_profile:income" in saved_dataset["spec"]["features"]
279+
assert "user_id" in saved_dataset["spec"]["joinKeys"]
280+
assert saved_dataset["spec"]["tags"]["environment"] == "test"
281+
assert saved_dataset["spec"]["tags"]["version"] == "1.0"
282+
283+
# Test get specific saved dataset endpoint
284+
response = fastapi_test_app.get(
285+
"/saved_datasets/test_saved_dataset?project=demo_project"
286+
)
287+
assert response.status_code == 200
288+
response_data = response.json()
289+
assert response_data["spec"]["name"] == "test_saved_dataset"
290+
assert "user_profile:age" in response_data["spec"]["features"]
291+
assert "user_profile:income" in response_data["spec"]["features"]
292+
293+
# Test with allow_cache parameter
294+
response = fastapi_test_app.get(
295+
"/saved_datasets/test_saved_dataset?project=demo_project&allow_cache=false"
296+
)
297+
assert response.status_code == 200
298+
assert response.json()["spec"]["name"] == "test_saved_dataset"
299+
300+
# Test with tags filter
301+
response = fastapi_test_app.get(
302+
"/saved_datasets?project=demo_project&tags=environment:test"
303+
)
304+
assert response.status_code == 200
305+
assert len(response.json()["saved_datasets"]) == 1
306+
307+
# Test with non-matching tags filter
308+
response = fastapi_test_app.get(
309+
"/saved_datasets?project=demo_project&tags=environment:production"
310+
)
311+
assert response.status_code == 200
312+
assert len(response.json()["saved_datasets"]) == 0
313+
314+
# Test with multiple tags filter
315+
response = fastapi_test_app.get(
316+
"/saved_datasets?project=demo_project&tags=environment:test&tags=version:1.0"
317+
)
318+
assert response.status_code == 200
319+
assert len(response.json()["saved_datasets"]) == 1
320+
321+
# Test non-existent saved dataset
322+
response = fastapi_test_app.get("/saved_datasets/non_existent?project=demo_project")
323+
assert response.status_code == 404
324+
325+
# Test missing project parameter
326+
response = fastapi_test_app.get("/saved_datasets/test_saved_dataset")
327+
assert (
328+
response.status_code == 422
329+
) # Unprocessable Entity for missing required query param
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from feast.infra.registry.proto_registry_utils import list_saved_datasets
2+
from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto
3+
from feast.protos.feast.core.SavedDataset_pb2 import SavedDataset as SavedDatasetProto
4+
from feast.protos.feast.core.SavedDataset_pb2 import (
5+
SavedDatasetSpec,
6+
SavedDatasetStorage,
7+
)
8+
from feast.saved_dataset import SavedDataset
9+
10+
11+
class TestRegistryProto:
12+
"""Test class for proto_registry_utils functions"""
13+
14+
def test_list_saved_datasets_uses_spec_tags(self):
15+
"""Test that list_saved_datasets correctly uses saved_dataset.spec.tags for tag filtering"""
16+
registry = RegistryProto()
17+
registry.version_id = "test_version"
18+
saved_dataset = SavedDatasetProto()
19+
spec = SavedDatasetSpec()
20+
spec.name = "test_dataset"
21+
spec.project = "test_project"
22+
spec.features.extend(["feature1", "feature2"])
23+
spec.join_keys.extend(["entity1"])
24+
spec.full_feature_names = False
25+
spec.tags["environment"] = "production"
26+
spec.tags["team"] = "ml-team"
27+
storage = SavedDatasetStorage()
28+
file_options = storage.file_storage
29+
file_options.uri = "test_path.parquet"
30+
spec.storage.CopyFrom(storage)
31+
32+
saved_dataset.spec.CopyFrom(spec)
33+
registry.saved_datasets.append(saved_dataset)
34+
35+
# Test that filtering by tags works correctly using spec.tags
36+
result = list_saved_datasets(
37+
registry, "test_project", {"environment": "production"}
38+
)
39+
40+
assert len(result) == 1
41+
assert isinstance(result[0], SavedDataset)
42+
assert result[0].name == "test_dataset"
43+
assert result[0].tags == {"environment": "production", "team": "ml-team"}
44+
45+
# Test that non-matching tags filter correctly
46+
result = list_saved_datasets(
47+
registry, "test_project", {"environment": "staging"}
48+
)
49+
assert len(result) == 0
50+
51+
# Test that multiple tag filtering works
52+
result = list_saved_datasets(
53+
registry, "test_project", {"environment": "production", "team": "ml-team"}
54+
)
55+
assert len(result) == 1
56+
assert result[0].name == "test_dataset"

0 commit comments

Comments
 (0)