def enrichment_with_tecton():
  """Enrich three user rows with features from a Tecton feature service.

  Builds a small in-memory collection of `beam.Row(user_id=...)` elements,
  enriches each one through `TectonFeatureStoreEnrichmentHandler` and prints
  the enriched rows to stdout. The `[START]`/`[END]` comments delimit the
  snippet region and must be kept.
  """
  # [START enrichment_with_tecton]
  import apache_beam as beam
  from apache_beam.transforms.enrichment import Enrichment
  from apache_beam.transforms.enrichment_handlers.tecton_feature_store import (
      TectonConnectionConfig,
      TectonFeaturesRetrievalConfig,
      TectonFeatureStoreEnrichmentHandler)

  data = [
      beam.Row(user_id='user_1990251765'),
      beam.Row(user_id='user_1284832379'),
      beam.Row(user_id='user_9979340926'),
  ]

  # NOTE(review): hard-coded credential. Presumably this is a demo key for
  # the public explore.tecton.ai workspace -- confirm. Real API keys should
  # be loaded from a secret store or an environment variable instead.
  connection_config = TectonConnectionConfig(
      url='https://explore.tecton.ai',
      default_workspace_name='prod',
      api_key='101142fd7d775e0a1bd9e343cca2a44d')

  features_config = TectonFeaturesRetrievalConfig(
      feature_service_name='fraud_detection_feature_service',
      entity_id='user_id')

  tecton_handler = TectonFeatureStoreEnrichmentHandler(
      connection_config=connection_config,
      features_retrieval_config=features_config)

  with beam.Pipeline() as p:
    _ = (
        p
        | "Create" >> beam.Create(data)
        | "Enrich W/ Tecton" >> Enrichment(tecton_handler)
        | "Print" >> beam.Map(print))
  # [END enrichment_with_tecton]


# Only run the snippet when this file is executed as a script. An unguarded
# call here would launch the pipeline (and its network requests) every time
# the module is imported, e.g. by the snippet tests.
if __name__ == '__main__':
  enrichment_with_tecton()
def validate_enrichment_with_tecton():
  """Return the rows the Tecton enrichment snippet is expected to print.

  The expected output is kept as one block framed by `[START ...]` and
  `[END ...]` marker lines; only the `Row(...)` lines between the markers
  are returned, one string per row.
  """
  snippet_output = '''[START enrichment_with_tecton]
Row(user_id='user_9979340926', user_transaction_metrics.amount_count_1d_1d=1, user_transaction_metrics.amount_count_3d_1d=3, user_transaction_metrics.amount_count_7d_1d=7, user_transaction_metrics.amount_mean_1d_1d=65.05, user_transaction_metrics.amount_mean_3d_1d=42.72333333333333, user_transaction_metrics.amount_mean_7d_1d=32.955714285714286)
Row(user_id='user_1990251765', user_transaction_metrics.amount_count_1d_1d=None, user_transaction_metrics.amount_count_3d_1d=2, user_transaction_metrics.amount_count_7d_1d=3, user_transaction_metrics.amount_mean_1d_1d=None, user_transaction_metrics.amount_mean_3d_1d=25.880000000000003, user_transaction_metrics.amount_mean_7d_1d=27.796666666666667)
Row(user_id='user_1284832379', user_transaction_metrics.amount_count_1d_1d=2, user_transaction_metrics.amount_count_3d_1d=6, user_transaction_metrics.amount_count_7d_1d=12, user_transaction_metrics.amount_mean_1d_1d=111.465, user_transaction_metrics.amount_mean_3d_1d=61.961666666666666, user_transaction_metrics.amount_mean_7d_1d=171.5625)
  [END enrichment_with_tecton]'''
  all_lines = snippet_output.splitlines()
  # Strip the leading [START] and trailing [END] marker lines.
  return all_lines[1:-1]
def test_enrichment_with_tecton(self, mock_stdout):
  """Run the Tecton snippet and compare its printed rows to the expected ones.

  Args:
    mock_stdout: object whose `getvalue()` returns what the snippet printed
      (presumably a patched `sys.stdout` -- the patching is not visible here).
  """
  # NOTE(review): the snippet appears to query a live Tecton instance and
  # the expected rows contain time-windowed aggregates -- confirm these
  # values are stable, otherwise this comparison will drift over time.
  enrichment_with_tecton()
  printed_rows = mock_stdout.getvalue().splitlines()
  expected_rows = validate_enrichment_with_tecton()

  # Same number of rows first, then an order-independent comparison keyed
  # on `user_id`.
  self.assertEqual(len(printed_rows), len(expected_rows))
  actual_by_user = std_out_to_dict(printed_rows, 'user_id')
  expected_by_user = std_out_to_dict(expected_rows, 'user_id')
  self.assertEqual(actual_by_user, expected_by_user)
#

import logging
from collections.abc import Callable
from collections.abc import Mapping
from dataclasses import dataclass
from dataclasses import field
from typing import Any
from typing import Dict
from typing import Optional

from tecton_client import MetadataOptions
from tecton_client import RequestOptions
from tecton_client import TectonClient

import apache_beam as beam
from apache_beam.transforms.enrichment import EnrichmentSourceHandler
from apache_beam.transforms.enrichment_handlers.utils import ExceptionLevel

# The two config dataclasses are required to construct the handler, so they
# are exported alongside it.
__all__ = [
    'TectonConnectionConfig',
    'TectonFeaturesRetrievalConfig',
    'TectonFeatureStoreEnrichmentHandler',
]

# Callable that maps an input row to the entity (join key) mapping.
EntityRowFn = Callable[[beam.Row], Mapping[str, Any]]

_LOGGER = logging.getLogger(__name__)


@dataclass
class TectonConnectionConfig:
  """Configuration dataclass for Tecton connection parameters.

  This dataclass contains the essential connection parameters needed to
  establish a connection with a Tecton feature store instance. All fields
  (with `kwargs` flattened in) are forwarded as keyword arguments to
  `TectonClient`.

  Attributes:
    url: The URL of the Tecton instance to connect to.
      Example: 'https://your-instance.tecton.ai'
    default_workspace_name: The name of the workspace containing the feature
      service. This is the workspace where your feature definitions are stored.
    api_key: The API key for authenticating with the Tecton instance.
      This should be a valid API key with appropriate permissions.
    kwargs: Additional keyword arguments passed to `TectonClient`. Enables
      forward compatibility with future Tecton connection parameters.

  Raises:
    ValueError: if `url`, `default_workspace_name` or `api_key` is empty.
  """
  url: str
  default_workspace_name: str
  api_key: str
  kwargs: Dict[str, Any] = field(default_factory=dict)

  def __post_init__(self):
    if not self.url:
      raise ValueError('Please provide a Tecton instance URL (`url`).')

    if not self.default_workspace_name:
      raise ValueError(
          'Please provide a workspace name (`default_workspace_name`).')

    if not self.api_key:
      raise ValueError('Please provide an API key (`api_key`).')


@dataclass
class TectonFeaturesRetrievalConfig:
  """Configuration dataclass for Tecton feature retrieval parameters.

  This dataclass contains the parameters needed to retrieve features from
  a Tecton feature store, including entity identification and feature
  service configuration. Every field except `entity_id` and `entity_row_fn`
  (with `kwargs` flattened in) is forwarded to `TectonClient.get_features`.

  Attributes:
    feature_service_name: The name of the feature service containing the
      features to fetch from the online Tecton feature store. This should
      match a feature service defined in your Tecton workspace.
    entity_id: The entity name for the entity associated with the features.
      The `entity_id` is used to extract the entity value from the input row.
      Please provide exactly one of `entity_id` or `entity_row_fn`.
    entity_row_fn: A lambda function that takes an input `beam.Row` and
      returns a dictionary with a mapping from the entity key column name to
      entity key value. It is used to build/extract the entity dict for
      feature retrieval. Please provide exactly one of `entity_id` or
      `entity_row_fn`.
    request_context_map: Optional mapping of request context parameters
      to pass to Tecton for feature computation. These are typically used
      for real-time features that depend on request-time data.
    workspace_name: Optional workspace name override. If not provided,
      uses the workspace from the connection config.
    allow_partial_results: Whether to allow partial results if some features
      fail to compute. Defaults to False.
    request_options: Optional RequestOptions for controlling request behavior.
      Defaults to None.
    metadata_options: Optional MetadataOptions for controlling what metadata
      is returned. Defaults to
      MetadataOptions(include_names=True, include_data_types=True).
    kwargs: Additional keyword arguments for feature retrieval. Enables forward
      compatibility with future Tecton feature retrieval parameters.

  Raises:
    ValueError: if `feature_service_name` is empty, or if not exactly one of
      `entity_id` and `entity_row_fn` is provided.
  """
  feature_service_name: str
  entity_id: str = ""
  entity_row_fn: Optional[EntityRowFn] = None
  request_context_map: Optional[Mapping[str, Any]] = None
  workspace_name: Optional[str] = None
  allow_partial_results: bool = False
  request_options: Optional[RequestOptions] = None
  metadata_options: Optional[MetadataOptions] = field(
      default_factory=lambda: MetadataOptions(
          include_names=True, include_data_types=True))
  kwargs: Dict[str, Any] = field(default_factory=dict)

  def __post_init__(self):
    if not self.feature_service_name:
      raise ValueError(
          'Please provide a feature service name for the Tecton '
          'online feature store (`feature_service_name`).')

    # Exactly one of the two entity sources must be set: reject both the
    # "neither" and the "both" case.
    if bool(self.entity_row_fn) == bool(self.entity_id):
      raise ValueError(
          "Please specify exactly one of a `entity_id` or a lambda "
          "function with `entity_row_fn` to extract the entity id "
          "from the input row.")


class TectonFeatureStoreEnrichmentHandler(EnrichmentSourceHandler[beam.Row,
                                                                  beam.Row]):
  """Enrichment handler to interact with the Tecton feature store.

  This handler fetches features from Tecton's online feature store using
  a feature service name.

  Use this handler with :class:`apache_beam.transforms.enrichment.Enrichment`
  transform. To filter the features to enrich, use the `join_fn` param in
  :class:`apache_beam.transforms.enrichment.Enrichment`.
  """
  def __init__(
      self,
      connection_config: TectonConnectionConfig,
      features_retrieval_config: TectonFeaturesRetrievalConfig,
      *,
      exception_level: ExceptionLevel = ExceptionLevel.WARN,
  ):
    """Initializes an instance of `TectonFeatureStoreEnrichmentHandler`.

    Args:
      connection_config: A `TectonConnectionConfig` dataclass containing
        connection parameters (url, default_workspace_name, api_key).
      features_retrieval_config: A `TectonFeaturesRetrievalConfig` dataclass
        containing feature retrieval parameters (feature_service_name,
        entity_id, entity_row_fn).
      exception_level: a `enum.Enum` value from
        `apache_beam.transforms.enrichment_handlers.utils.ExceptionLevel`
        that controls what happens when fetching features from the online
        Tecton store fails: `RAISE` raises a `RuntimeError`, `WARN` logs a
        warning and `QUIET` stays silent; the latter two enrich with an
        empty row. Defaults to `ExceptionLevel.WARN`.
    """
    self._connection_config = connection_config
    self._features_retrieval_config = features_retrieval_config
    self._exception_level = exception_level
    # Created in `__enter__`; kept as None until then so that `__exit__` is
    # safe even when no connection was ever established.
    self._client = None

  def __enter__(self):
    """Connect with the Tecton feature store."""
    self._client = TectonClient(
        **unpack_dataclass_with_kwargs(self._connection_config))

  def __call__(self, request: beam.Row, *args, **kwargs):
    """Fetches feature values for an entity-id from the Tecton feature store.

    Args:
      request: the input `beam.Row` to enrich.

    Returns:
      A tuple of the unchanged `request` and a `beam.Row` holding the fetched
      feature values (an empty row when the fetch failed and the exception
      level is not `RAISE`).

    Raises:
      RuntimeError: if the fetch fails and the exception level is
        `ExceptionLevel.RAISE`.
    """
    retrieval_config = self._features_retrieval_config
    if retrieval_config.entity_row_fn:
      entity = retrieval_config.entity_row_fn(request)
    else:
      entity = {
          retrieval_config.entity_id: request._asdict()[
              retrieval_config.entity_id]
      }

    try:
      params = unpack_dataclass_with_kwargs(retrieval_config)
      # `entity_id` / `entity_row_fn` only describe how to build the join
      # keys on the Beam side; they are not passed on to `get_features`.
      params.pop('entity_id', None)
      params.pop('entity_row_fn', None)
      response = self._client.get_features(**params, join_key_map=entity)
      feature_values = response.get_features_dict()
    except Exception as e:
      if self._exception_level == ExceptionLevel.RAISE:
        # Chain the original error so its traceback is preserved.
        raise RuntimeError(
            f'Failed to fetch features from Tecton feature store: {e}') from e
      if self._exception_level == ExceptionLevel.WARN:
        _LOGGER.warning(
            'Failed to fetch features from Tecton feature store: %s', e)
      # WARN and QUIET both fall back to an empty enrichment.
      feature_values = {}

    return request, beam.Row(**feature_values)

  def __exit__(self, exc_type, exc_val, exc_tb):
    """Clean the instantiated Tecton client."""
    client, self._client = self._client, None
    if client is not None:
      # NOTE(review): reaches into the private `_client` attribute of
      # `TectonClient` (presumably the underlying HTTP client) -- confirm
      # there is no public close API to use instead.
      client._client.close()

  def get_cache_key(self, request: beam.Row) -> str:
    """Returns a string formatted with unique entity-id for the feature values.

    The key is derived from the same entity mapping that `__call__` sends to
    Tecton. With `entity_row_fn`, the values come from the mapping it returns
    rather than from the request row, because the row's column names need not
    match the entity key names. All keys of a composite entity take part in
    the cache key.

    Args:
      request: the input `beam.Row` to build the cache key for.

    Raises:
      ValueError: if `entity_row_fn` returns an empty mapping.
    """
    retrieval_config = self._features_retrieval_config
    if not retrieval_config.entity_row_fn:
      return f'entity_id: {request._asdict()[retrieval_config.entity_id]}'

    entity = retrieval_config.entity_row_fn(request)
    if not entity:
      raise ValueError(
          '`entity_row_fn` returned an empty entity mapping; cannot build '
          'a cache key.')
    if len(entity) == 1:
      (value, ) = entity.values()
      return f'entity_id: {value}'
    # Sort by key so the cache key does not depend on mapping order.
    composite = ', '.join(f'{key}={entity[key]}' for key in sorted(entity))
    return f'entity_id: {composite}'


def unpack_dataclass_with_kwargs(dataclass_instance):
  """Unpacks dataclass fields into a flat dict, merging kwargs with precedence.

  Args:
    dataclass_instance: Dataclass instance to unpack. Its `kwargs` attribute,
      if present, is flattened into the result; a missing or `None` `kwargs`
      is treated as empty.

  Returns:
    dict: Flattened dictionary with kwargs taking precedence over fields.
  """
  params: dict = dataclass_instance.__dict__.copy()
  nested_kwargs = params.pop('kwargs', None) or {}
  return {**params, **nested_kwargs}