Skip to content

Commit aeb3e45

Browse files
committed
moved callback to constructor
1 parent d7bc984 commit aeb3e45

File tree

3 files changed

+117
-42
lines changed

3 files changed

+117
-42
lines changed

docs/how-to-guides/client-callback-function.md

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,27 +10,29 @@ A callback function is a function that is sent to another function as an argumen
1010

1111
## How to use callback functions
1212

13-
Currently the below functions in feathr client support passing a callback as an argument:
13+
We can pass a callback function when initializing the feathr client.
14+
15+
```python
16+
client = FeathrClient(config_path, callback)
17+
```
18+
19+
The below functions accept an optional parameter named **params**. **params** is a dictionary where the user can pass the arguments for the callback function.
1420

1521
- get_online_features
1622
- multi_get_online_features
1723
- get_offline_features
1824
- monitor_features
1925
- materialize_features
2026

21-
These functions accept two optional parameters named **callback** and **params**.
22-
callback is of type function and params is a dictionary where user can pass the arguments for the callback function.
23-
2427
An example on how to use it:
2528

2629
```python
2730
# inside notebook
28-
client = FeathrClient(config_path)
29-
client.get_offline_features(observation_settings,feature_query,output_path, callback, params)
30-
31-
# users can define their own callback function and params
31+
client = FeathrClient(config_path, callback)
3232
params = {"param1":"value1", "param2":"value2"}
33+
client.get_offline_features(observation_settings,feature_query,output_path, params)
3334

35+
# users can define their own callback function
3436
async def callback(params):
3537
import httpx
3638
async with httpx.AsyncClient() as requestHandler:

feathr_project/feathr/client.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,13 @@ class FeathrClient(object):
8383
local_workspace_dir (str, optional): set where is the local work space dir. If not set, Feathr will create a temporary folder to store local workspace related files.
8484
credential (optional): credential to access cloud resources, most likely to be the returned result of DefaultAzureCredential(). If not set, Feathr will initialize DefaultAzureCredential() inside the __init__ function to get credentials.
8585
project_registry_tag (Dict[str, str]): adding tags for project in Feathr registry. This might be useful if you want to tag your project as deprecated, or allow certain customizations on project level. Default is empty
86+
callback: an async callback function that will be called after execution of the original logic. This callback should not block the thread. This is optional.
8687
8788
Raises:
8889
RuntimeError: Fail to create the client since necessary environment variables are not set for Redis
8990
client creation.
9091
"""
91-
def __init__(self, config_path:str = "./feathr_config.yaml", local_workspace_dir: str = None, credential=None, project_registry_tag: Dict[str, str]=None):
92+
def __init__(self, config_path:str = "./feathr_config.yaml", local_workspace_dir: str = None, credential=None, project_registry_tag: Dict[str, str]=None, callback:callable = None):
9293
self.logger = logging.getLogger(__name__)
9394
# Redis key separator
9495
self._KEY_SEPARATOR = ':'
@@ -183,6 +184,7 @@ def __init__(self, config_path:str = "./feathr_config.yaml", local_workspace_dir
183184
'feature_registry', 'purview', 'purview_name')
184185
# initialize the registry no matter whether we set purview name or not, given some of the methods are used there.
185186
self.registry = _FeatureRegistry(self.project_name, self.azure_purview_name, self.registry_delimiter, project_registry_tag, config_path = config_path, credential=self.credential)
187+
self.callback = callback
186188

187189
def _check_required_environment_variables_exist(self):
188190
"""Checks if the required environment variables (from feathr_config.yaml) are set.
@@ -265,15 +267,14 @@ def _get_registry_client(self):
265267
"""
266268
return self.registry._get_registry_client()
267269

268-
def get_online_features(self, feature_table, key, feature_names, callback: callable = None, params: dict = None):
270+
def get_online_features(self, feature_table, key, feature_names, params: dict = None):
269271
"""Fetches feature value for a certain key from an online feature table. There is an optional callback function
270272
and the params to extend this function's capability. For e.g. a consumer of the features.
271273
272274
Args:
273275
feature_table: the name of the feature table.
274276
key: the key of the entity
275277
feature_names: list of feature names to fetch
276-
callback: an async callback function that will be called after execution of the original logic. This callback should not block the thread.
277278
params: a dictionary of parameters for the callback function
278279
279280
Return:
@@ -288,19 +289,18 @@ def get_online_features(self, feature_table, key, feature_names, callback: calla
288289
redis_key = self._construct_redis_key(feature_table, key)
289290
res = self.redis_clint.hmget(redis_key, *feature_names)
290291
feature_values = self._decode_proto(res)
291-
if (callback is not None) and (params is not None):
292+
if (self.callback is not None) and (params is not None):
292293
event_loop = asyncio.get_event_loop()
293-
event_loop.create_task(callback(params))
294+
event_loop.create_task(self.callback(params))
294295
return feature_values
295296

296-
def multi_get_online_features(self, feature_table, keys, feature_names, callback: callable = None, params: dict = None):
297+
def multi_get_online_features(self, feature_table, keys, feature_names, params: dict = None):
297298
"""Fetches feature value for a list of keys from an online feature table. This is the batch version of the get API.
298299
299300
Args:
300301
feature_table: the name of the feature table.
301302
keys: list of keys for the entities
302303
feature_names: list of feature names to fetch
303-
callback: an async callback function that will be called after execution of the original logic. This callback should not block the thread.
304304
params: a dictionary of parameters for the callback function
305305
306306
Return:
@@ -322,9 +322,9 @@ def multi_get_online_features(self, feature_table, keys, feature_names, callback
322322
for feature_list in pipeline_result:
323323
decoded_pipeline_result.append(self._decode_proto(feature_list))
324324

325-
if (callback is not None) and (params is not None):
325+
if (self.callback is not None) and (params is not None):
326326
event_loop = asyncio.get_event_loop()
327-
event_loop.create_task(callback(params))
327+
event_loop.create_task(self.callback(params))
328328

329329
return dict(zip(keys, decoded_pipeline_result))
330330

@@ -427,7 +427,6 @@ def get_offline_features(self,
427427
execution_configurations: Union[SparkExecutionConfiguration ,Dict[str,str]] = {},
428428
udf_files = None,
429429
verbose: bool = False,
430-
callback: callable = None,
431430
params: dict = None
432431
):
433432
"""
@@ -438,7 +437,6 @@ def get_offline_features(self,
438437
feature_query: features that are requested to add onto the observation data
439438
output_path: output path of job, i.e. the observation data with features attached.
440439
execution_configuratons: a dict that will be passed to spark job when the job starts up, i.e. the "spark configurations". Note that not all of the configuration will be honored since some of the configurations are managed by the Spark platform, such as Databricks or Azure Synapse. Refer to the [spark documentation](https://spark.apache.org/docs/latest/configuration.html) for a complete list of spark configurations.
441-
callback: an async callback function that will be called after execution of the original logic. This callback should not block the thread.
442440
params: a dictionary of parameters for the callback function
443441
"""
444442
feature_queries = feature_query if isinstance(feature_query, List) else [feature_query]
@@ -476,10 +474,10 @@ def get_offline_features(self,
476474
FeaturePrinter.pretty_print_feature_query(feature_query)
477475

478476
write_to_file(content=config, full_file_name=config_file_path)
479-
job_info = self._get_offline_features_with_config(config_file_path, execution_configuratons, udf_files=udf_files)
480-
if (callback is not None) and (params is not None):
477+
job_info = self._get_offline_features_with_config(config_file_path, execution_configurations, udf_files=udf_files)
478+
if (self.callback is not None) and (params is not None):
481479
event_loop = asyncio.get_event_loop()
482-
event_loop.create_task(callback(params))
480+
event_loop.create_task(self.callback(params))
483481
return job_info
484482

485483
def _get_offline_features_with_config(self, feature_join_conf_path='feature_join_conf/feature_join.conf', execution_configurations: Dict[str,str] = {}, udf_files=[]):
@@ -557,29 +555,27 @@ def wait_job_to_finish(self, timeout_sec: int = 300):
557555
else:
558556
raise RuntimeError('Spark job failed.')
559557

560-
def monitor_features(self, settings: MonitoringSettings, execution_configuratons: Union[SparkExecutionConfiguration ,Dict[str,str]] = {}, verbose: bool = False, callback: callable = None, params: dict = None):
558+
def monitor_features(self, settings: MonitoringSettings, execution_configuratons: Union[SparkExecutionConfiguration ,Dict[str,str]] = {}, verbose: bool = False, params: dict = None):
561559
"""Create an offline job to generate statistics to monitor feature data. There is an optional
562560
callback function and the params to extend this function's capability. For e.g. a consumer of the features.
563561
564562
Args:
565563
settings: Feature monitoring settings
566564
execution_configuratons: a dict that will be passed to spark job when the job starts up, i.e. the "spark configurations". Note that not all of the configuration will be honored since some of the configurations are managed by the Spark platform, such as Databricks or Azure Synapse. Refer to the [spark documentation](https://spark.apache.org/docs/latest/configuration.html) for a complete list of spark configurations.
567-
callback: an async callback function that will be called after execution of the original logic. This callback should not block the thread.
568565
params: a dictionary of parameters for the callback function.
569566
"""
570567
self.materialize_features(settings, execution_configuratons, verbose)
571-
if (callback is not None) and (params is not None):
568+
if (self.callback is not None) and (params is not None):
572569
event_loop = asyncio.get_event_loop()
573-
event_loop.create_task(callback(params))
570+
event_loop.create_task(self.callback(params))
574571

575-
def materialize_features(self, settings: MaterializationSettings, execution_configuratons: Union[SparkExecutionConfiguration ,Dict[str,str]] = {}, verbose: bool = False, callback: callable = None, params: dict = None):
572+
def materialize_features(self, settings: MaterializationSettings, execution_configurations: Union[SparkExecutionConfiguration ,Dict[str,str]] = {}, verbose: bool = False, params: dict = None):
576573
"""Materialize feature data. There is an optional callback function and the params
577574
to extend this function's capability. For e.g. a consumer of the feature store.
578575
579576
Args:
580577
settings: Feature materialization settings
581578
execution_configuratons: a dict that will be passed to spark job when the job starts up, i.e. the "spark configurations". Note that not all of the configuration will be honored since some of the configurations are managed by the Spark platform, such as Databricks or Azure Synapse. Refer to the [spark documentation](https://spark.apache.org/docs/latest/configuration.html) for a complete list of spark configurations.
582-
callback: an async callback function that will be called after execution of the original logic. This callback should not block the thread.
583579
params: a dictionary of parameters for the callback function
584580
"""
585581
# produce materialization config
@@ -608,9 +604,9 @@ def materialize_features(self, settings: MaterializationSettings, execution_conf
608604
if verbose and settings:
609605
FeaturePrinter.pretty_print_materialize_features(settings)
610606

611-
if (callback is not None) and (params is not None):
607+
if (self.callback is not None) and (params is not None):
612608
event_loop = asyncio.get_event_loop()
613-
event_loop.create_task(callback(params))
609+
event_loop.create_task(self.callback(params))
614610

615611
def _materialize_features_with_config(self, feature_gen_conf_path: str = 'feature_gen_conf/feature_gen.conf',execution_configurations: Dict[str,str] = {}, udf_files=[]):
616612
"""Materializes feature data based on the feature generation config. The feature

0 commit comments

Comments
 (0)