Skip to content

Commit 87ad4cd

Browse files
franciscojavierarceontkathole
authored andcommitted
chore: Improve async code for feature server
Signed-off-by: Francisco Javier Arceo <[email protected]>
1 parent eb0a86e commit 87ad4cd

File tree

2 files changed

+65
-28
lines changed

2 files changed

+65
-28
lines changed

sdk/python/feast/feature_server.py

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -120,27 +120,26 @@ class SaveDocumentRequest(BaseModel):
120120
data: dict
121121

122122

123-
def _get_features(
123+
async def _get_features(
124124
request: Union[GetOnlineFeaturesRequest, GetOnlineDocumentsRequest],
125125
store: "feast.FeatureStore",
126126
):
127127
if request.feature_service:
128-
feature_service = store.get_feature_service(
129-
request.feature_service, allow_cache=True
128+
feature_service = await run_in_threadpool(
129+
store.get_feature_service, request.feature_service, allow_cache=True
130130
)
131131
assert_permissions(
132132
resource=feature_service, actions=[AuthzedAction.READ_ONLINE]
133133
)
134134
features = feature_service # type: ignore
135135
else:
136-
all_feature_views, all_on_demand_feature_views = (
137-
utils._get_feature_views_to_use(
138-
store.registry,
139-
store.project,
140-
request.features,
141-
allow_cache=True,
142-
hide_dummy_entity=False,
143-
)
136+
all_feature_views, all_on_demand_feature_views = await run_in_threadpool(
137+
utils._get_feature_views_to_use,
138+
store.registry,
139+
store.project,
140+
request.features,
141+
allow_cache=True,
142+
hide_dummy_entity=False,
144143
)
145144
for feature_view in all_feature_views:
146145
assert_permissions(
@@ -230,7 +229,7 @@ async def lifespan(app: FastAPI):
230229
)
231230
async def get_online_features(request: GetOnlineFeaturesRequest) -> Dict[str, Any]:
232231
# Initialize parameters for FeatureStore.get_online_features(...) call
233-
features = await run_in_threadpool(_get_features, request, store)
232+
features = await _get_features(request, store)
234233

235234
read_params = dict(
236235
features=features,
@@ -265,7 +264,7 @@ async def retrieve_online_documents(
265264
"This endpoint is in alpha and will be moved to /get-online-features when stable."
266265
)
267266
# Initialize parameters for FeatureStore.retrieve_online_documents_v2(...) call
268-
features = await run_in_threadpool(_get_features, request, store)
267+
features = await _get_features(request, store)
269268

270269
read_params = dict(features=features, query=request.query, top_k=request.top_k)
271270
if request.api_version == 2 and request.query_string is not None:
@@ -342,26 +341,31 @@ async def push(request: PushFeaturesRequest) -> None:
342341
else:
343342
store.push(**push_params)
344343

345-
def _get_feast_object(
344+
async def _get_feast_object(
346345
feature_view_name: str, allow_registry_cache: bool
347346
) -> FeastObject:
348347
try:
349-
return store.get_stream_feature_view( # type: ignore
350-
feature_view_name, allow_registry_cache=allow_registry_cache
348+
return await run_in_threadpool(
349+
store.get_stream_feature_view,
350+
feature_view_name,
351+
allow_registry_cache=allow_registry_cache,
351352
)
352353
except FeatureViewNotFoundException:
353-
return store.get_feature_view( # type: ignore
354-
feature_view_name, allow_registry_cache=allow_registry_cache
354+
return await run_in_threadpool(
355+
store.get_feature_view,
356+
feature_view_name,
357+
allow_registry_cache=allow_registry_cache,
355358
)
356359

357360
@app.post("/write-to-online-store", dependencies=[Depends(inject_user_details)])
358-
def write_to_online_store(request: WriteToFeatureStoreRequest) -> None:
361+
async def write_to_online_store(request: WriteToFeatureStoreRequest) -> None:
359362
df = pd.DataFrame(request.df)
360363
feature_view_name = request.feature_view_name
361364
allow_registry_cache = request.allow_registry_cache
362-
resource = _get_feast_object(feature_view_name, allow_registry_cache)
365+
resource = await _get_feast_object(feature_view_name, allow_registry_cache)
363366
assert_permissions(resource=resource, actions=[AuthzedAction.WRITE_ONLINE])
364-
store.write_to_online_store(
367+
await run_in_threadpool(
368+
store.write_to_online_store,
365369
feature_view_name=feature_view_name,
366370
df=df,
367371
allow_registry_cache=allow_registry_cache,
@@ -428,10 +432,11 @@ async def chat_ui():
428432
return Response(content=content, media_type="text/html")
429433

430434
@app.post("/materialize", dependencies=[Depends(inject_user_details)])
431-
def materialize(request: MaterializeRequest) -> None:
435+
async def materialize(request: MaterializeRequest) -> None:
432436
for feature_view in request.feature_views or []:
437+
resource = await _get_feast_object(feature_view, True)
433438
assert_permissions(
434-
resource=_get_feast_object(feature_view, True),
439+
resource=resource,
435440
actions=[AuthzedAction.WRITE_ONLINE],
436441
)
437442

@@ -450,22 +455,26 @@ def materialize(request: MaterializeRequest) -> None:
450455
start_date = utils.make_tzaware(parser.parse(request.start_ts))
451456
end_date = utils.make_tzaware(parser.parse(request.end_ts))
452457

453-
store.materialize(
458+
await run_in_threadpool(
459+
store.materialize,
454460
start_date,
455461
end_date,
456462
request.feature_views,
457463
disable_event_timestamp=request.disable_event_timestamp,
458464
)
459465

460466
@app.post("/materialize-incremental", dependencies=[Depends(inject_user_details)])
461-
def materialize_incremental(request: MaterializeIncrementalRequest) -> None:
467+
async def materialize_incremental(request: MaterializeIncrementalRequest) -> None:
462468
for feature_view in request.feature_views or []:
469+
resource = await _get_feast_object(feature_view, True)
463470
assert_permissions(
464-
resource=_get_feast_object(feature_view, True),
471+
resource=resource,
465472
actions=[AuthzedAction.WRITE_ONLINE],
466473
)
467-
store.materialize_incremental(
468-
utils.make_tzaware(parser.parse(request.end_ts)), request.feature_views
474+
await run_in_threadpool(
475+
store.materialize_incremental,
476+
utils.make_tzaware(parser.parse(request.end_ts)),
477+
request.feature_views,
469478
)
470479

471480
@app.exception_handler(Exception)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from unittest.mock import AsyncMock, MagicMock
2+
3+
from fastapi.testclient import TestClient
4+
5+
from feast.feature_server import get_app
6+
from feast.online_response import OnlineResponse
7+
from feast.protos.feast.serving.ServingService_pb2 import GetOnlineFeaturesResponse
8+
9+
10+
def test_async_get_online_features():
11+
"""Test that async get_online_features endpoint works correctly"""
12+
fs = MagicMock()
13+
fs._get_provider.return_value.async_supported.online.read = True
14+
fs.get_online_features_async = AsyncMock(
15+
return_value=OnlineResponse(GetOnlineFeaturesResponse())
16+
)
17+
fs.get_feature_service = MagicMock()
18+
fs.initialize = AsyncMock()
19+
fs.close = AsyncMock()
20+
21+
client = TestClient(get_app(fs))
22+
response = client.post(
23+
"/get-online-features",
24+
json={"features": ["test:feature"], "entities": {"entity_id": [123]}},
25+
)
26+
27+
assert response.status_code == 200
28+
assert fs.get_online_features_async.await_count == 1

0 commit comments

Comments
 (0)