diff --git a/pyproject.toml b/pyproject.toml index 91279205620..47b400ceccc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ dependencies = [ [project.optional-dependencies] aws = ["boto3==1.38.27", "fsspec<=2024.9.0", "aiobotocore>2,<3"] +dax = ["boto3>=1.26.0", "amazon-dax-client>=2.0.0,<3"] azure = [ "azure-storage-blob>=0.37.0", "azure-identity>=1.6.1", diff --git a/sdk/python/feast/infra/online_stores/dynamodb.py b/sdk/python/feast/infra/online_stores/dynamodb.py index 814058c77e5..cad3404cc9f 100644 --- a/sdk/python/feast/infra/online_stores/dynamodb.py +++ b/sdk/python/feast/infra/online_stores/dynamodb.py @@ -21,7 +21,7 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union from aiobotocore.config import AioConfig -from pydantic import StrictBool, StrictStr +from pydantic import StrictBool, StrictStr, model_validator from feast import Entity, FeatureView, utils from feast.infra.online_stores.helpers import compute_entity_id @@ -62,6 +62,23 @@ class DynamoDBOnlineStoreConfig(FeastConfigBaseModel): """DynamoDB endpoint URL. Use for local development (e.g., http://localhost:8000) or VPC endpoints for improved latency.""" + use_dax: bool = False + """Enable DAX (DynamoDB Accelerator) for sub-millisecond read latency. + Requires amazon-dax-client package and a running DAX cluster. + + IMPORTANT: DAX is only supported for synchronous operations. When using + Feast feature server (which uses async), DAX will NOT be used and requests + will fall back to direct DynamoDB access. DAX works with: + - Direct Python SDK usage (FeatureStore.get_online_features) + - Batch operations via CLI + + For async feature server workloads, consider using DynamoDB VPC endpoints + with endpoint_url configuration instead.""" + + dax_endpoint: Union[str, None] = None + """DAX cluster endpoint URL (e.g., dax://my-cluster.xxx.dax-clusters.us-east-1.amazonaws.com). + Required when use_dax is True. Supports both 'dax://' and 'daxs://' (TLS) schemes.""" + region: StrictStr """AWS Region Name""" @@ -110,6 +127,16 @@ class DynamoDBOnlineStoreConfig(FeastConfigBaseModel): 'adaptive' mode provides intelligent retry with client-side rate limiting. """ + @model_validator(mode="after") + def _validate_dax_config(self): + """Validate that dax_endpoint is provided when use_dax is True.""" + if self.use_dax and not self.dax_endpoint: + raise ValueError( + "dax_endpoint is required when use_dax is True. " + "Provide the DAX cluster endpoint URL (e.g., dax://my-cluster.xxx.dax-clusters.us-east-1.amazonaws.com)" + ) + return self + class DynamoDBOnlineStore(OnlineStore): """ @@ -141,6 +168,14 @@ def __init__(self): async def initialize(self, config: RepoConfig): online_config = config.online_store + # Warn if DAX is enabled but async mode is being used + if online_config.use_dax and online_config.dax_endpoint: + logger.warning( + "DAX is enabled but async mode (feature server) does not support DAX. " + "Requests will use direct DynamoDB access. DAX only works with " + "synchronous operations (e.g., FeatureStore.get_online_features())." + ) + await self._get_aiodynamodb_client( online_config.region, online_config.max_pool_connections, @@ -270,15 +305,22 @@ def update( """ online_config = config.online_store assert isinstance(online_config, DynamoDBOnlineStoreConfig) - dynamodb_client = self._get_dynamodb_client( + # Table operations (describe, create, delete) are NOT supported by DAX. + # Create fresh non-DAX clients directly (don't use cached _get_dynamodb_* methods + # as those may cache DAX clients for data operations). + dynamodb_client = _initialize_dynamodb_client( online_config.region, online_config.endpoint_url, online_config.session_based_auth, + use_dax=False, + dax_endpoint=None, ) - dynamodb_resource = self._get_dynamodb_resource( + dynamodb_resource = _initialize_dynamodb_resource( online_config.region, online_config.endpoint_url, online_config.session_based_auth, + use_dax=False, + dax_endpoint=None, ) do_tag_updates = defaultdict(bool) @@ -369,10 +411,15 @@ def teardown( """ online_config = config.online_store assert isinstance(online_config, DynamoDBOnlineStoreConfig) - dynamodb_resource = self._get_dynamodb_resource( + # Table operations (delete) are NOT supported by DAX. + # Create fresh non-DAX client directly (don't use cached _get_dynamodb_resource + # as it may cache DAX client for data operations). + dynamodb_resource = _initialize_dynamodb_resource( online_config.region, online_config.endpoint_url, online_config.session_based_auth, + use_dax=False, + dax_endpoint=None, ) for table in tables: @@ -410,6 +457,8 @@ def online_write_batch( online_config.region, online_config.endpoint_url, online_config.session_based_auth, + online_config.use_dax, + online_config.dax_endpoint, ) table_instance = dynamodb_resource.Table( @@ -483,6 +532,8 @@ def online_read( online_config.region, online_config.endpoint_url, online_config.session_based_auth, + online_config.use_dax, + online_config.dax_endpoint, ) table_name = _get_table_name(online_config, config, table) @@ -516,6 +567,8 @@ def online_read( online_config.region, online_config.endpoint_url, online_config.session_based_auth, + online_config.use_dax, + online_config.dax_endpoint, ) def fetch_batch(batch: List[str]) -> Dict[str, Any]: @@ -640,10 +693,12 @@ def _get_dynamodb_client( region: str, endpoint_url: Optional[str] = None, session_based_auth: Optional[bool] = False, + use_dax: bool = False, + dax_endpoint: Optional[str] = None, ): if self._dynamodb_client is None: self._dynamodb_client = _initialize_dynamodb_client( - region, endpoint_url, session_based_auth + region, endpoint_url, session_based_auth, use_dax, dax_endpoint ) return self._dynamodb_client @@ -652,10 +707,12 @@ def _get_dynamodb_resource( region: str, endpoint_url: Optional[str] = None, session_based_auth: Optional[bool] = False, + use_dax: bool = False, + dax_endpoint: Optional[str] = None, ): if self._dynamodb_resource is None: self._dynamodb_resource = _initialize_dynamodb_resource( - region, endpoint_url, session_based_auth + region, endpoint_url, session_based_auth, use_dax, dax_endpoint ) return self._dynamodb_resource @@ -811,6 +868,8 @@ def update_online_store( online_config.region, online_config.endpoint_url, online_config.session_based_auth, + online_config.use_dax, + online_config.dax_endpoint, ) table_instance = dynamodb_resource.Table( @@ -1119,7 +1178,32 @@ def _initialize_dynamodb_client( region: str, endpoint_url: Optional[str] = None, session_based_auth: Optional[bool] = False, + use_dax: bool = False, + dax_endpoint: Optional[str] = None, ): + """ + Initialize DynamoDB client, optionally using DAX for caching. + + When use_dax=True, returns a DAX client that is API-compatible with + the boto3 DynamoDB client but routes requests through DAX cluster. + """ + if use_dax and dax_endpoint: + try: + from amazondax import AmazonDaxClient + + logger.info(f"Initializing DAX client with endpoint: {dax_endpoint}") + # AmazonDaxClient() constructor creates a client (not .client() method) + # endpoint_url should be in format: dax://cluster.xxx.dax-clusters.region.amazonaws.com + return AmazonDaxClient( + endpoint_url=dax_endpoint, + region_name=region, + ) + except ImportError: + logger.warning( + "amazon-dax-client not installed. Install with: pip install amazon-dax-client. " + "Falling back to standard DynamoDB client." + ) + if session_based_auth: return boto3.Session().client( "dynamodb", @@ -1140,7 +1224,32 @@ def _initialize_dynamodb_resource( region: str, endpoint_url: Optional[str] = None, session_based_auth: Optional[bool] = False, + use_dax: bool = False, + dax_endpoint: Optional[str] = None, ): + """ + Initialize DynamoDB resource, optionally using DAX for caching. + + When use_dax=True, returns a DAX resource that is API-compatible with + the boto3 DynamoDB resource but routes requests through DAX cluster. + """ + if use_dax and dax_endpoint: + try: + from amazondax import AmazonDaxClient + + logger.info(f"Initializing DAX resource with endpoint: {dax_endpoint}") + # AmazonDaxClient.resource() creates a resource interface + # endpoint_url should be in format: dax://cluster.xxx.dax-clusters.region.amazonaws.com + return AmazonDaxClient.resource( + endpoint_url=dax_endpoint, + region_name=region, + ) + except ImportError: + logger.warning( + "amazon-dax-client not installed. Install with: pip install amazon-dax-client. " + "Falling back to standard DynamoDB resource." + ) + if session_based_auth: return boto3.Session().resource( "dynamodb", region_name=region, endpoint_url=endpoint_url diff --git a/sdk/python/tests/integration/online_store/test_dax_integration.py b/sdk/python/tests/integration/online_store/test_dax_integration.py new file mode 100644 index 00000000000..8560df901c3 --- /dev/null +++ b/sdk/python/tests/integration/online_store/test_dax_integration.py @@ -0,0 +1,299 @@ +""" +Test script for DAX integration with DynamoDB online store. + +This script can be used to verify DAX client initialization and basic operations. +Run with: pytest -v tests/integration/online_store/test_dax_integration.py + +Prerequisites: +1. pip install amazon-dax-client +2. A running DAX cluster (or use mock for unit testing) +3. AWS credentials configured +""" + +import os +import pytest +from unittest.mock import MagicMock, patch + + +def _dax_client_available() -> bool: + """Check if amazon-dax-client package is installed.""" + try: + import amazondax + return True + except ImportError: + return False + + +class TestDaxConfiguration: + """Test DAX configuration parsing.""" + + def test_dax_config_defaults(self): + """Test that DAX config defaults are correct.""" + from feast.infra.online_stores.dynamodb import DynamoDBOnlineStoreConfig + + config = DynamoDBOnlineStoreConfig(region="us-east-1") + + assert config.use_dax is False + assert config.dax_endpoint is None + + def test_dax_config_enabled(self): + """Test DAX config when enabled.""" + from feast.infra.online_stores.dynamodb import DynamoDBOnlineStoreConfig + + config = DynamoDBOnlineStoreConfig( + region="us-east-1", + use_dax=True, + dax_endpoint="dax://my-cluster.xxx.dax-clusters.us-east-1.amazonaws.com", + ) + + assert config.use_dax is True + assert config.dax_endpoint == "dax://my-cluster.xxx.dax-clusters.us-east-1.amazonaws.com" + + def test_dax_config_validation_missing_endpoint(self): + """Test that validation fails when use_dax=True but dax_endpoint is missing.""" + from feast.infra.online_stores.dynamodb import DynamoDBOnlineStoreConfig + + with pytest.raises(ValueError, match="dax_endpoint is required when use_dax is True"): + DynamoDBOnlineStoreConfig( + region="us-east-1", + use_dax=True, + # dax_endpoint intentionally missing + ) + + def test_dax_config_with_tls_endpoint(self): + """Test DAX config with TLS (daxs://) endpoint.""" + from feast.infra.online_stores.dynamodb import DynamoDBOnlineStoreConfig + + config = DynamoDBOnlineStoreConfig( + region="us-east-1", + use_dax=True, + dax_endpoint="daxs://my-cluster.xxx.dax-clusters.us-east-1.amazonaws.com", + ) + + assert config.use_dax is True + assert config.dax_endpoint.startswith("daxs://") + + def test_dax_disabled_with_endpoint_set(self): + """Test that having dax_endpoint set but use_dax=False is valid (endpoint ignored).""" + from feast.infra.online_stores.dynamodb import DynamoDBOnlineStoreConfig + + config = DynamoDBOnlineStoreConfig( + region="us-east-1", + use_dax=False, + dax_endpoint="dax://my-cluster.xxx.dax-clusters.us-east-1.amazonaws.com", + ) + + assert config.use_dax is False + assert config.dax_endpoint is not None # Set but will be ignored + + +class TestDaxClientInitialization: + """Test DAX client initialization.""" + + def test_client_init_without_dax(self): + """Test that regular boto3 client is created when DAX is disabled.""" + from feast.infra.online_stores.dynamodb import _initialize_dynamodb_client + + with patch("feast.infra.online_stores.dynamodb.boto3") as mock_boto3: + mock_boto3.client.return_value = MagicMock() + + client = _initialize_dynamodb_client( + region="us-east-1", + use_dax=False, + ) + + mock_boto3.client.assert_called_once() + assert "dynamodb" in str(mock_boto3.client.call_args) + + def test_resource_init_without_dax(self): + """Test that regular boto3 resource is created when DAX is disabled.""" + from feast.infra.online_stores.dynamodb import _initialize_dynamodb_resource + + with patch("feast.infra.online_stores.dynamodb.boto3") as mock_boto3: + mock_boto3.resource.return_value = MagicMock() + + resource = _initialize_dynamodb_resource( + region="us-east-1", + use_dax=False, + ) + + mock_boto3.resource.assert_called_once() + assert "dynamodb" in str(mock_boto3.resource.call_args) + + def test_client_init_with_dax_package_missing(self): + """Test fallback to boto3 when amazon-dax-client is not installed.""" + from feast.infra.online_stores.dynamodb import _initialize_dynamodb_client + + with patch("feast.infra.online_stores.dynamodb.boto3") as mock_boto3: + mock_boto3.client.return_value = MagicMock() + + # Simulate ImportError for amazondax + with patch.dict("sys.modules", {"amazondax": None}): + client = _initialize_dynamodb_client( + region="us-east-1", + use_dax=True, + dax_endpoint="dax://test.xxx.dax-clusters.us-east-1.amazonaws.com", + ) + + # Should fall back to boto3 + mock_boto3.client.assert_called_once() + + def test_resource_init_with_dax_package_missing(self): + """Test fallback to boto3 resource when amazon-dax-client is not installed.""" + from feast.infra.online_stores.dynamodb import _initialize_dynamodb_resource + + with patch("feast.infra.online_stores.dynamodb.boto3") as mock_boto3: + mock_boto3.resource.return_value = MagicMock() + + # Simulate ImportError for amazondax + with patch.dict("sys.modules", {"amazondax": None}): + resource = _initialize_dynamodb_resource( + region="us-east-1", + use_dax=True, + dax_endpoint="dax://test.xxx.dax-clusters.us-east-1.amazonaws.com", + ) + + # Should fall back to boto3 + mock_boto3.resource.assert_called_once() + + def test_client_init_with_session_auth(self): + """Test client initialization with session-based auth.""" + from feast.infra.online_stores.dynamodb import _initialize_dynamodb_client + + with patch("feast.infra.online_stores.dynamodb.boto3") as mock_boto3: + mock_session = MagicMock() + mock_boto3.Session.return_value = mock_session + mock_session.client.return_value = MagicMock() + + client = _initialize_dynamodb_client( + region="us-east-1", + session_based_auth=True, + ) + + mock_boto3.Session.assert_called_once() + mock_session.client.assert_called_once() + + @pytest.mark.skipif( + not _dax_client_available(), + reason="amazon-dax-client not installed", + ) + def test_client_init_with_dax_enabled(self): + """Test DAX client initialization when package is available.""" + from feast.infra.online_stores.dynamodb import _initialize_dynamodb_client + + with patch("amazondax.AmazonDaxClient") as mock_dax: + mock_dax.return_value = MagicMock() + + client = _initialize_dynamodb_client( + region="us-east-1", + use_dax=True, + dax_endpoint="dax://test.xxx.dax-clusters.us-east-1.amazonaws.com", + ) + + mock_dax.assert_called_once_with( + endpoint_url="dax://test.xxx.dax-clusters.us-east-1.amazonaws.com", + region_name="us-east-1", + ) + + @pytest.mark.skipif( + not _dax_client_available(), + reason="amazon-dax-client not installed", + ) + def test_resource_init_with_dax_enabled(self): + """Test DAX resource initialization when package is available.""" + from feast.infra.online_stores.dynamodb import _initialize_dynamodb_resource + + with patch("amazondax.AmazonDaxClient") as mock_dax: + mock_dax.resource.return_value = MagicMock() + + resource = _initialize_dynamodb_resource( + region="us-east-1", + use_dax=True, + dax_endpoint="dax://test.xxx.dax-clusters.us-east-1.amazonaws.com", + ) + + mock_dax.resource.assert_called_once_with( + endpoint_url="dax://test.xxx.dax-clusters.us-east-1.amazonaws.com", + region_name="us-east-1", + ) + + +class TestDaxEndToEnd: + """End-to-end tests requiring actual DAX cluster (skipped by default).""" + + @pytest.mark.skipif( + not os.environ.get("DAX_ENDPOINT"), + reason="DAX_ENDPOINT environment variable not set", + ) + @pytest.mark.skipif( + not _dax_client_available(), + reason="amazon-dax-client not installed", + ) + def test_dax_get_item(self): + """ + Test actual DAX GetItem operation. + + Set environment variables before running: + - DAX_ENDPOINT: Your DAX cluster endpoint + - AWS_REGION: AWS region (default: us-east-1) + """ + from amazondax import AmazonDaxClient + + endpoint = os.environ["DAX_ENDPOINT"] + region = os.environ.get("AWS_REGION", "us-east-1") + + # Create DAX resource + dax = AmazonDaxClient.resource( + endpoint_url=endpoint, + region_name=region, + ) + + # This will fail if table doesn't exist, but tests the connection + try: + table = dax.Table("test_table") + # Just verify we can create table reference + assert table is not None + finally: + # DAX client cleanup + pass + + +if __name__ == "__main__": + # Quick smoke test + print("Testing DAX configuration...") + + from feast.infra.online_stores.dynamodb import DynamoDBOnlineStoreConfig + + # Test 1: Default config + config = DynamoDBOnlineStoreConfig(region="us-east-1") + assert config.use_dax is False + print("✓ Default config: use_dax=False") + + # Test 2: DAX enabled config + config = DynamoDBOnlineStoreConfig( + region="us-east-1", + use_dax=True, + dax_endpoint="dax://test.xxx.dax-clusters.us-east-1.amazonaws.com", + ) + assert config.use_dax is True + print("✓ DAX enabled config: use_dax=True") + + # Test 3: Validation - use_dax=True without endpoint should fail + try: + config = DynamoDBOnlineStoreConfig( + region="us-east-1", + use_dax=True, + # Missing dax_endpoint + ) + print("✗ Validation test failed - should have raised ValueError") + except ValueError as e: + assert "dax_endpoint is required" in str(e) + print("✓ Validation: use_dax=True without endpoint raises ValueError") + + # Test 4: Check if amazon-dax-client is available + if _dax_client_available(): + print("✓ amazon-dax-client is installed") + else: + print("⚠ amazon-dax-client not installed (install with: pip install amazon-dax-client)") + + print("\nAll smoke tests passed!") diff --git a/sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py b/sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py index 7e5558e19d7..3020457e583 100644 --- a/sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py +++ b/sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py @@ -1050,3 +1050,59 @@ def tracking_client(*args, **kwargs): f"Expected 1 shared client for thread-safety, " f"got {len(set(dynamodb_clients))} unique clients" ) + + +# ============================================================================ +# DAX (DynamoDB Accelerator) Configuration Tests +# ============================================================================ + + +def test_dynamodb_online_store_config_dax_defaults(): + """Test DynamoDBOnlineStoreConfig DAX defaults.""" + config = DynamoDBOnlineStoreConfig(region="us-west-2") + assert config.use_dax is False + assert config.dax_endpoint is None + + +def test_dynamodb_online_store_config_dax_enabled(): + """Test DynamoDBOnlineStoreConfig with DAX enabled.""" + config = DynamoDBOnlineStoreConfig( + region="us-west-2", + use_dax=True, + dax_endpoint="dax://my-cluster.xxx.dax-clusters.us-west-2.amazonaws.com", + ) + assert config.use_dax is True + assert config.dax_endpoint == "dax://my-cluster.xxx.dax-clusters.us-west-2.amazonaws.com" + + +def test_dynamodb_online_store_config_dax_validation_missing_endpoint(): + """Test that use_dax=True without dax_endpoint raises ValueError.""" + with pytest.raises(ValueError, match="dax_endpoint is required when use_dax is True"): + DynamoDBOnlineStoreConfig( + region="us-west-2", + use_dax=True, + # dax_endpoint intentionally missing + ) + + +def test_dynamodb_online_store_config_dax_with_tls(): + """Test DynamoDBOnlineStoreConfig with TLS DAX endpoint (daxs://).""" + config = DynamoDBOnlineStoreConfig( + region="us-west-2", + use_dax=True, + dax_endpoint="daxs://my-cluster.xxx.dax-clusters.us-west-2.amazonaws.com", + ) + assert config.use_dax is True + assert config.dax_endpoint.startswith("daxs://") + + +def test_dynamodb_online_store_config_dax_disabled_with_endpoint(): + """Test that dax_endpoint is ignored when use_dax=False.""" + config = DynamoDBOnlineStoreConfig( + region="us-west-2", + use_dax=False, + dax_endpoint="dax://my-cluster.xxx.dax-clusters.us-west-2.amazonaws.com", + ) + assert config.use_dax is False + # Endpoint is set but will be ignored since use_dax=False + assert config.dax_endpoint is not None