Skip to content

Commit 2ae33de

Browse files
chore: Improve async code for feature server (#5646)
* chore: Improve async code for feature server Signed-off-by: Francisco Javier Arceo <[email protected]> * attempting fix for dynamo test Signed-off-by: Francisco Javier Arceo <[email protected]> * attempting fix for dynamo test Signed-off-by: Francisco Javier Arceo <[email protected]> --------- Signed-off-by: Francisco Javier Arceo <[email protected]>
1 parent 0bda0f5 commit 2ae33de

File tree

3 files changed

+131
-92
lines changed

3 files changed

+131
-92
lines changed

sdk/python/feast/feature_server.py

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -120,27 +120,26 @@ class SaveDocumentRequest(BaseModel):
120120
data: dict
121121

122122

123-
def _get_features(
123+
async def _get_features(
124124
request: Union[GetOnlineFeaturesRequest, GetOnlineDocumentsRequest],
125125
store: "feast.FeatureStore",
126126
):
127127
if request.feature_service:
128-
feature_service = store.get_feature_service(
129-
request.feature_service, allow_cache=True
128+
feature_service = await run_in_threadpool(
129+
store.get_feature_service, request.feature_service, allow_cache=True
130130
)
131131
assert_permissions(
132132
resource=feature_service, actions=[AuthzedAction.READ_ONLINE]
133133
)
134134
features = feature_service # type: ignore
135135
else:
136-
all_feature_views, all_on_demand_feature_views = (
137-
utils._get_feature_views_to_use(
138-
store.registry,
139-
store.project,
140-
request.features,
141-
allow_cache=True,
142-
hide_dummy_entity=False,
143-
)
136+
all_feature_views, all_on_demand_feature_views = await run_in_threadpool(
137+
utils._get_feature_views_to_use,
138+
store.registry,
139+
store.project,
140+
request.features,
141+
allow_cache=True,
142+
hide_dummy_entity=False,
144143
)
145144
for feature_view in all_feature_views:
146145
assert_permissions(
@@ -230,7 +229,7 @@ async def lifespan(app: FastAPI):
230229
)
231230
async def get_online_features(request: GetOnlineFeaturesRequest) -> Dict[str, Any]:
232231
# Initialize parameters for FeatureStore.get_online_features(...) call
233-
features = await run_in_threadpool(_get_features, request, store)
232+
features = await _get_features(request, store)
234233

235234
read_params = dict(
236235
features=features,
@@ -265,7 +264,7 @@ async def retrieve_online_documents(
265264
"This endpoint is in alpha and will be moved to /get-online-features when stable."
266265
)
267266
# Initialize parameters for FeatureStore.retrieve_online_documents_v2(...) call
268-
features = await run_in_threadpool(_get_features, request, store)
267+
features = await _get_features(request, store)
269268

270269
read_params = dict(features=features, query=request.query, top_k=request.top_k)
271270
if request.api_version == 2 and request.query_string is not None:
@@ -342,26 +341,31 @@ async def push(request: PushFeaturesRequest) -> None:
342341
else:
343342
store.push(**push_params)
344343

345-
def _get_feast_object(
344+
async def _get_feast_object(
346345
feature_view_name: str, allow_registry_cache: bool
347346
) -> FeastObject:
348347
try:
349-
return store.get_stream_feature_view( # type: ignore
350-
feature_view_name, allow_registry_cache=allow_registry_cache
348+
return await run_in_threadpool(
349+
store.get_stream_feature_view,
350+
feature_view_name,
351+
allow_registry_cache=allow_registry_cache,
351352
)
352353
except FeatureViewNotFoundException:
353-
return store.get_feature_view( # type: ignore
354-
feature_view_name, allow_registry_cache=allow_registry_cache
354+
return await run_in_threadpool(
355+
store.get_feature_view,
356+
feature_view_name,
357+
allow_registry_cache=allow_registry_cache,
355358
)
356359

357360
@app.post("/write-to-online-store", dependencies=[Depends(inject_user_details)])
358-
def write_to_online_store(request: WriteToFeatureStoreRequest) -> None:
361+
async def write_to_online_store(request: WriteToFeatureStoreRequest) -> None:
359362
df = pd.DataFrame(request.df)
360363
feature_view_name = request.feature_view_name
361364
allow_registry_cache = request.allow_registry_cache
362-
resource = _get_feast_object(feature_view_name, allow_registry_cache)
365+
resource = await _get_feast_object(feature_view_name, allow_registry_cache)
363366
assert_permissions(resource=resource, actions=[AuthzedAction.WRITE_ONLINE])
364-
store.write_to_online_store(
367+
await run_in_threadpool(
368+
store.write_to_online_store,
365369
feature_view_name=feature_view_name,
366370
df=df,
367371
allow_registry_cache=allow_registry_cache,
@@ -428,10 +432,11 @@ async def chat_ui():
428432
return Response(content=content, media_type="text/html")
429433

430434
@app.post("/materialize", dependencies=[Depends(inject_user_details)])
431-
def materialize(request: MaterializeRequest) -> None:
435+
async def materialize(request: MaterializeRequest) -> None:
432436
for feature_view in request.feature_views or []:
437+
resource = await _get_feast_object(feature_view, True)
433438
assert_permissions(
434-
resource=_get_feast_object(feature_view, True),
439+
resource=resource,
435440
actions=[AuthzedAction.WRITE_ONLINE],
436441
)
437442

@@ -450,22 +455,26 @@ def materialize(request: MaterializeRequest) -> None:
450455
start_date = utils.make_tzaware(parser.parse(request.start_ts))
451456
end_date = utils.make_tzaware(parser.parse(request.end_ts))
452457

453-
store.materialize(
458+
await run_in_threadpool(
459+
store.materialize,
454460
start_date,
455461
end_date,
456462
request.feature_views,
457463
disable_event_timestamp=request.disable_event_timestamp,
458464
)
459465

460466
@app.post("/materialize-incremental", dependencies=[Depends(inject_user_details)])
461-
def materialize_incremental(request: MaterializeIncrementalRequest) -> None:
467+
async def materialize_incremental(request: MaterializeIncrementalRequest) -> None:
462468
for feature_view in request.feature_views or []:
469+
resource = await _get_feast_object(feature_view, True)
463470
assert_permissions(
464-
resource=_get_feast_object(feature_view, True),
471+
resource=resource,
465472
actions=[AuthzedAction.WRITE_ONLINE],
466473
)
467-
store.materialize_incremental(
468-
utils.make_tzaware(parser.parse(request.end_ts)), request.feature_views
474+
await run_in_threadpool(
475+
store.materialize_incremental,
476+
utils.make_tzaware(parser.parse(request.end_ts)),
477+
request.feature_views,
469478
)
470479

471480
@app.exception_handler(Exception)

sdk/python/feast/infra/online_stores/dynamodb.py

Lines changed: 66 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -108,15 +108,24 @@ class DynamoDBOnlineStore(OnlineStore):
108108
Attributes:
109109
_dynamodb_client: Boto3 DynamoDB client.
110110
_dynamodb_resource: Boto3 DynamoDB resource.
111+
_aioboto_session: Async boto session.
112+
_aioboto_client: Async boto client.
113+
_aioboto_context_stack: Async context stack.
111114
"""
112115

113116
_dynamodb_client = None
114117
_dynamodb_resource = None
115118

119+
def __init__(self):
120+
super().__init__()
121+
self._aioboto_session = None
122+
self._aioboto_client = None
123+
self._aioboto_context_stack = None
124+
116125
async def initialize(self, config: RepoConfig):
117126
online_config = config.online_store
118127

119-
await _get_aiodynamodb_client(
128+
await self._get_aiodynamodb_client(
120129
online_config.region,
121130
online_config.max_pool_connections,
122131
online_config.keepalive_timeout,
@@ -127,7 +136,59 @@ async def initialize(self, config: RepoConfig):
127136
)
128137

129138
async def close(self):
130-
await _aiodynamodb_close()
139+
await self._aiodynamodb_close()
140+
141+
def _get_aioboto_session(self):
142+
if self._aioboto_session is None:
143+
logger.debug("initializing the aiobotocore session")
144+
self._aioboto_session = session.get_session()
145+
return self._aioboto_session
146+
147+
async def _get_aiodynamodb_client(
148+
self,
149+
region: str,
150+
max_pool_connections: int,
151+
keepalive_timeout: float,
152+
connect_timeout: Union[int, float],
153+
read_timeout: Union[int, float],
154+
total_max_retry_attempts: Union[int, None],
155+
retry_mode: Union[Literal["legacy", "standard", "adaptive"], None],
156+
):
157+
if self._aioboto_client is None:
158+
logger.debug("initializing the aiobotocore dynamodb client")
159+
160+
retries: Dict[str, Any] = {}
161+
if total_max_retry_attempts is not None:
162+
retries["total_max_attempts"] = total_max_retry_attempts
163+
if retry_mode is not None:
164+
retries["mode"] = retry_mode
165+
166+
client_context = self._get_aioboto_session().create_client(
167+
"dynamodb",
168+
region_name=region,
169+
config=AioConfig(
170+
max_pool_connections=max_pool_connections,
171+
connect_timeout=connect_timeout,
172+
read_timeout=read_timeout,
173+
retries=retries if retries else None,
174+
connector_args={"keepalive_timeout": keepalive_timeout},
175+
),
176+
)
177+
self._aioboto_context_stack = contextlib.AsyncExitStack()
178+
self._aioboto_client = (
179+
await self._aioboto_context_stack.enter_async_context(client_context)
180+
)
181+
return self._aioboto_client
182+
183+
async def _aiodynamodb_close(self):
184+
if self._aioboto_client:
185+
await self._aioboto_client.close()
186+
self._aioboto_client = None
187+
if self._aioboto_context_stack:
188+
await self._aioboto_context_stack.aclose()
189+
self._aioboto_context_stack = None
190+
if self._aioboto_session:
191+
self._aioboto_session = None
131192

132193
@property
133194
def async_supported(self) -> SupportedAsyncMethods:
@@ -362,7 +423,7 @@ async def online_write_batch_async(
362423
_to_client_write_item(config, entity_key, features, timestamp)
363424
for entity_key, features, timestamp, _ in _latest_data_to_write(data)
364425
]
365-
client = await _get_aiodynamodb_client(
426+
client = await self._get_aiodynamodb_client(
366427
online_config.region,
367428
online_config.max_pool_connections,
368429
online_config.keepalive_timeout,
@@ -473,7 +534,7 @@ def to_tbl_resp(raw_client_response):
473534
batches.append(batch)
474535
entity_id_batches.append(entity_id_batch)
475536

476-
client = await _get_aiodynamodb_client(
537+
client = await self._get_aiodynamodb_client(
477538
online_config.region,
478539
online_config.max_pool_connections,
479540
online_config.keepalive_timeout,
@@ -627,66 +688,7 @@ def _to_client_batch_get_payload(online_config, table_name, batch):
627688
}
628689

629690

630-
_aioboto_session = None
631-
_aioboto_client = None
632-
_aioboto_context_stack = None
633-
634-
635-
def _get_aioboto_session():
636-
global _aioboto_session
637-
if _aioboto_session is None:
638-
logger.debug("initializing the aiobotocore session")
639-
_aioboto_session = session.get_session()
640-
return _aioboto_session
641-
642-
643-
async def _get_aiodynamodb_client(
644-
region: str,
645-
max_pool_connections: int,
646-
keepalive_timeout: float,
647-
connect_timeout: Union[int, float],
648-
read_timeout: Union[int, float],
649-
total_max_retry_attempts: Union[int, None],
650-
retry_mode: Union[Literal["legacy", "standard", "adaptive"], None],
651-
):
652-
global _aioboto_client, _aioboto_context_stack
653-
if _aioboto_client is None:
654-
logger.debug("initializing the aiobotocore dynamodb client")
655-
656-
retries: Dict[str, Any] = {}
657-
if total_max_retry_attempts is not None:
658-
retries["total_max_attempts"] = total_max_retry_attempts
659-
if retry_mode is not None:
660-
retries["mode"] = retry_mode
661-
662-
client_context = _get_aioboto_session().create_client(
663-
"dynamodb",
664-
region_name=region,
665-
config=AioConfig(
666-
max_pool_connections=max_pool_connections,
667-
connect_timeout=connect_timeout,
668-
read_timeout=read_timeout,
669-
retries=retries if retries else None,
670-
connector_args={"keepalive_timeout": keepalive_timeout},
671-
),
672-
)
673-
_aioboto_context_stack = contextlib.AsyncExitStack()
674-
_aioboto_client = await _aioboto_context_stack.enter_async_context(
675-
client_context
676-
)
677-
return _aioboto_client
678-
679-
680-
async def _aiodynamodb_close():
681-
global _aioboto_client, _aioboto_session, _aioboto_context_stack
682-
if _aioboto_client:
683-
await _aioboto_client.close()
684-
_aioboto_client = None
685-
if _aioboto_context_stack:
686-
await _aioboto_context_stack.aclose()
687-
_aioboto_context_stack = None
688-
if _aioboto_session:
689-
_aioboto_session = None
691+
# Global async client functions removed - now using instance methods
690692

691693

692694
def _initialize_dynamodb_client(
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from unittest.mock import AsyncMock, MagicMock
2+
3+
from fastapi.testclient import TestClient
4+
5+
from feast.feature_server import get_app
6+
from feast.online_response import OnlineResponse
7+
from feast.protos.feast.serving.ServingService_pb2 import GetOnlineFeaturesResponse
8+
9+
10+
def test_async_get_online_features():
11+
"""Test that async get_online_features endpoint works correctly"""
12+
fs = MagicMock()
13+
fs._get_provider.return_value.async_supported.online.read = True
14+
fs.get_online_features_async = AsyncMock(
15+
return_value=OnlineResponse(GetOnlineFeaturesResponse())
16+
)
17+
fs.get_feature_service = MagicMock()
18+
fs.initialize = AsyncMock()
19+
fs.close = AsyncMock()
20+
21+
client = TestClient(get_app(fs))
22+
response = client.post(
23+
"/get-online-features",
24+
json={"features": ["test:feature"], "entities": {"entity_id": [123]}},
25+
)
26+
27+
assert response.status_code == 200
28+
assert fs.get_online_features_async.await_count == 1

0 commit comments

Comments
 (0)