diff --git a/sdk/python/feast/feature_server.py b/sdk/python/feast/feature_server.py index 0cc90b294d..fee7e56e9c 100644 --- a/sdk/python/feast/feature_server.py +++ b/sdk/python/feast/feature_server.py @@ -120,27 +120,26 @@ class SaveDocumentRequest(BaseModel): data: dict -def _get_features( +async def _get_features( request: Union[GetOnlineFeaturesRequest, GetOnlineDocumentsRequest], store: "feast.FeatureStore", ): if request.feature_service: - feature_service = store.get_feature_service( - request.feature_service, allow_cache=True + feature_service = await run_in_threadpool( + store.get_feature_service, request.feature_service, allow_cache=True ) assert_permissions( resource=feature_service, actions=[AuthzedAction.READ_ONLINE] ) features = feature_service # type: ignore else: - all_feature_views, all_on_demand_feature_views = ( - utils._get_feature_views_to_use( - store.registry, - store.project, - request.features, - allow_cache=True, - hide_dummy_entity=False, - ) + all_feature_views, all_on_demand_feature_views = await run_in_threadpool( + utils._get_feature_views_to_use, + store.registry, + store.project, + request.features, + allow_cache=True, + hide_dummy_entity=False, ) for feature_view in all_feature_views: assert_permissions( @@ -230,7 +229,7 @@ async def lifespan(app: FastAPI): ) async def get_online_features(request: GetOnlineFeaturesRequest) -> Dict[str, Any]: # Initialize parameters for FeatureStore.get_online_features(...) call - features = await run_in_threadpool(_get_features, request, store) + features = await _get_features(request, store) read_params = dict( features=features, @@ -265,7 +264,7 @@ async def retrieve_online_documents( "This endpoint is in alpha and will be moved to /get-online-features when stable." ) # Initialize parameters for FeatureStore.retrieve_online_documents_v2(...) call - features = await run_in_threadpool(_get_features, request, store) + features = await _get_features(request, store) read_params = dict(features=features, query=request.query, top_k=request.top_k) if request.api_version == 2 and request.query_string is not None: @@ -342,26 +341,31 @@ async def push(request: PushFeaturesRequest) -> None: else: store.push(**push_params) - def _get_feast_object( + async def _get_feast_object( feature_view_name: str, allow_registry_cache: bool ) -> FeastObject: try: - return store.get_stream_feature_view( # type: ignore - feature_view_name, allow_registry_cache=allow_registry_cache + return await run_in_threadpool( + store.get_stream_feature_view, + feature_view_name, + allow_registry_cache=allow_registry_cache, ) except FeatureViewNotFoundException: - return store.get_feature_view( # type: ignore - feature_view_name, allow_registry_cache=allow_registry_cache + return await run_in_threadpool( + store.get_feature_view, + feature_view_name, + allow_registry_cache=allow_registry_cache, ) @app.post("/write-to-online-store", dependencies=[Depends(inject_user_details)]) - def write_to_online_store(request: WriteToFeatureStoreRequest) -> None: + async def write_to_online_store(request: WriteToFeatureStoreRequest) -> None: df = pd.DataFrame(request.df) feature_view_name = request.feature_view_name allow_registry_cache = request.allow_registry_cache - resource = _get_feast_object(feature_view_name, allow_registry_cache) + resource = await _get_feast_object(feature_view_name, allow_registry_cache) assert_permissions(resource=resource, actions=[AuthzedAction.WRITE_ONLINE]) - store.write_to_online_store( + await run_in_threadpool( + store.write_to_online_store, feature_view_name=feature_view_name, df=df, allow_registry_cache=allow_registry_cache, @@ -428,10 +432,11 @@ async def chat_ui(): return Response(content=content, media_type="text/html") @app.post("/materialize", dependencies=[Depends(inject_user_details)]) - def materialize(request: MaterializeRequest) -> None: + async def materialize(request: MaterializeRequest) -> None: for feature_view in request.feature_views or []: + resource = await _get_feast_object(feature_view, True) assert_permissions( - resource=_get_feast_object(feature_view, True), + resource=resource, actions=[AuthzedAction.WRITE_ONLINE], ) @@ -450,7 +455,8 @@ def materialize(request: MaterializeRequest) -> None: start_date = utils.make_tzaware(parser.parse(request.start_ts)) end_date = utils.make_tzaware(parser.parse(request.end_ts)) - store.materialize( + await run_in_threadpool( + store.materialize, start_date, end_date, request.feature_views, @@ -458,14 +464,17 @@ def materialize(request: MaterializeRequest) -> None: ) @app.post("/materialize-incremental", dependencies=[Depends(inject_user_details)]) - def materialize_incremental(request: MaterializeIncrementalRequest) -> None: + async def materialize_incremental(request: MaterializeIncrementalRequest) -> None: for feature_view in request.feature_views or []: + resource = await _get_feast_object(feature_view, True) assert_permissions( - resource=_get_feast_object(feature_view, True), + resource=resource, actions=[AuthzedAction.WRITE_ONLINE], ) - store.materialize_incremental( - utils.make_tzaware(parser.parse(request.end_ts)), request.feature_views + await run_in_threadpool( + store.materialize_incremental, + utils.make_tzaware(parser.parse(request.end_ts)), + request.feature_views, ) @app.exception_handler(Exception) diff --git a/sdk/python/feast/infra/online_stores/dynamodb.py b/sdk/python/feast/infra/online_stores/dynamodb.py index a03695fd1b..c577159884 100644 --- a/sdk/python/feast/infra/online_stores/dynamodb.py +++ b/sdk/python/feast/infra/online_stores/dynamodb.py @@ -108,15 +108,24 @@ class DynamoDBOnlineStore(OnlineStore): Attributes: _dynamodb_client: Boto3 DynamoDB client. _dynamodb_resource: Boto3 DynamoDB resource. + _aioboto_session: Async boto session. + _aioboto_client: Async boto client. + _aioboto_context_stack: Async context stack. """ _dynamodb_client = None _dynamodb_resource = None + def __init__(self): + super().__init__() + self._aioboto_session = None + self._aioboto_client = None + self._aioboto_context_stack = None + async def initialize(self, config: RepoConfig): online_config = config.online_store - await _get_aiodynamodb_client( + await self._get_aiodynamodb_client( online_config.region, online_config.max_pool_connections, online_config.keepalive_timeout, @@ -127,7 +136,59 @@ async def initialize(self, config: RepoConfig): ) async def close(self): - await _aiodynamodb_close() + await self._aiodynamodb_close() + + def _get_aioboto_session(self): + if self._aioboto_session is None: + logger.debug("initializing the aiobotocore session") + self._aioboto_session = session.get_session() + return self._aioboto_session + + async def _get_aiodynamodb_client( + self, + region: str, + max_pool_connections: int, + keepalive_timeout: float, + connect_timeout: Union[int, float], + read_timeout: Union[int, float], + total_max_retry_attempts: Union[int, None], + retry_mode: Union[Literal["legacy", "standard", "adaptive"], None], + ): + if self._aioboto_client is None: + logger.debug("initializing the aiobotocore dynamodb client") + + retries: Dict[str, Any] = {} + if total_max_retry_attempts is not None: + retries["total_max_attempts"] = total_max_retry_attempts + if retry_mode is not None: + retries["mode"] = retry_mode + + client_context = self._get_aioboto_session().create_client( + "dynamodb", + region_name=region, + config=AioConfig( + max_pool_connections=max_pool_connections, + connect_timeout=connect_timeout, + read_timeout=read_timeout, + retries=retries if retries else None, + connector_args={"keepalive_timeout": keepalive_timeout}, + ), + ) + self._aioboto_context_stack = contextlib.AsyncExitStack() + self._aioboto_client = ( + await self._aioboto_context_stack.enter_async_context(client_context) + ) + return self._aioboto_client + + async def _aiodynamodb_close(self): + if self._aioboto_client: + await self._aioboto_client.close() + self._aioboto_client = None + if self._aioboto_context_stack: + await self._aioboto_context_stack.aclose() + self._aioboto_context_stack = None + if self._aioboto_session: + self._aioboto_session = None @property def async_supported(self) -> SupportedAsyncMethods: @@ -362,7 +423,7 @@ async def online_write_batch_async( _to_client_write_item(config, entity_key, features, timestamp) for entity_key, features, timestamp, _ in _latest_data_to_write(data) ] - client = await _get_aiodynamodb_client( + client = await self._get_aiodynamodb_client( online_config.region, online_config.max_pool_connections, online_config.keepalive_timeout, @@ -473,7 +534,7 @@ def to_tbl_resp(raw_client_response): batches.append(batch) entity_id_batches.append(entity_id_batch) - client = await _get_aiodynamodb_client( + client = await self._get_aiodynamodb_client( online_config.region, online_config.max_pool_connections, online_config.keepalive_timeout, @@ -627,66 +688,7 @@ def _to_client_batch_get_payload(online_config, table_name, batch): } -_aioboto_session = None -_aioboto_client = None -_aioboto_context_stack = None - - -def _get_aioboto_session(): - global _aioboto_session - if _aioboto_session is None: - logger.debug("initializing the aiobotocore session") - _aioboto_session = session.get_session() - return _aioboto_session - - -async def _get_aiodynamodb_client( - region: str, - max_pool_connections: int, - keepalive_timeout: float, - connect_timeout: Union[int, float], - read_timeout: Union[int, float], - total_max_retry_attempts: Union[int, None], - retry_mode: Union[Literal["legacy", "standard", "adaptive"], None], -): - global _aioboto_client, _aioboto_context_stack - if _aioboto_client is None: - logger.debug("initializing the aiobotocore dynamodb client") - - retries: Dict[str, Any] = {} - if total_max_retry_attempts is not None: - retries["total_max_attempts"] = total_max_retry_attempts - if retry_mode is not None: - retries["mode"] = retry_mode - - client_context = _get_aioboto_session().create_client( - "dynamodb", - region_name=region, - config=AioConfig( - max_pool_connections=max_pool_connections, - connect_timeout=connect_timeout, - read_timeout=read_timeout, - retries=retries if retries else None, - connector_args={"keepalive_timeout": keepalive_timeout}, - ), - ) - _aioboto_context_stack = contextlib.AsyncExitStack() - _aioboto_client = await _aioboto_context_stack.enter_async_context( - client_context - ) - return _aioboto_client - - -async def _aiodynamodb_close(): - global _aioboto_client, _aioboto_session, _aioboto_context_stack - if _aioboto_client: - await _aioboto_client.close() - _aioboto_client = None - if _aioboto_context_stack: - await _aioboto_context_stack.aclose() - _aioboto_context_stack = None - if _aioboto_session: - _aioboto_session = None +# Global async client functions removed - now using instance methods def _initialize_dynamodb_client( diff --git a/sdk/python/tests/unit/test_feature_server_async.py b/sdk/python/tests/unit/test_feature_server_async.py new file mode 100644 index 0000000000..641a3c5327 --- /dev/null +++ b/sdk/python/tests/unit/test_feature_server_async.py @@ -0,0 +1,28 @@ +from unittest.mock import AsyncMock, MagicMock + +from fastapi.testclient import TestClient + +from feast.feature_server import get_app +from feast.online_response import OnlineResponse +from feast.protos.feast.serving.ServingService_pb2 import GetOnlineFeaturesResponse + + +def test_async_get_online_features(): + """Test that async get_online_features endpoint works correctly""" + fs = MagicMock() + fs._get_provider.return_value.async_supported.online.read = True + fs.get_online_features_async = AsyncMock( + return_value=OnlineResponse(GetOnlineFeaturesResponse()) + ) + fs.get_feature_service = MagicMock() + fs.initialize = AsyncMock() + fs.close = AsyncMock() + + client = TestClient(get_app(fs)) + response = client.post( + "/get-online-features", + json={"features": ["test:feature"], "entities": {"entity_id": [123]}}, + ) + + assert response.status_code == 200 + assert fs.get_online_features_async.await_count == 1