 from typing import Sequence, List, Dict, Any, Union
 from urllib.parse import urlparse

+from multiprocessing.pool import AsyncResult
+import signal
 import attr
 import pandas as pd
 from pandas import DataFrame

+import boto3
+from botocore.config import Config
+from pathos.multiprocessing import ProcessingPool
+
 from sagemaker import Session
 from sagemaker.feature_store.feature_definition import (
     FeatureDefinition,
@@ -150,23 +156,27 @@ class IngestionManagerPandas:

     Attributes:
         feature_group_name (str): name of the Feature Group.
-        sagemaker_session (Session): instance of the Session class to perform boto calls.
+        sagemaker_fs_runtime_client_config (Config): instance of the Config class
+            for boto calls.
         data_frame (DataFrame): pandas DataFrame to be ingested to the given feature group.
         max_workers (int): number of threads to create.
+        max_processes (int): number of processes to create. Each process spawns
+            ``max_workers`` threads.
     """

     feature_group_name: str = attr.ib()
-    sagemaker_session: Session = attr.ib()
-    data_frame: DataFrame = attr.ib()
+    sagemaker_fs_runtime_client_config: Config = attr.ib()
     max_workers: int = attr.ib(default=1)
-    _futures: Dict[Any, Any] = attr.ib(init=False, factory=dict)
+    max_processes: int = attr.ib(default=1)
+    _async_result: AsyncResult = attr.ib(default=None)
+    _processing_pool: ProcessingPool = attr.ib(default=None)
     _failed_indices: List[int] = attr.ib(factory=list)

     @staticmethod
     def _ingest_single_batch(
         data_frame: DataFrame,
         feature_group_name: str,
-        sagemaker_session: Session,
+        client_config: Config,
         start_index: int,
         end_index: int,
     ) -> List[int]:
@@ -175,13 +185,18 @@ def _ingest_single_batch(
         Args:
             data_frame (DataFrame): source DataFrame to be ingested.
             feature_group_name (str): name of the Feature Group.
-            sagemaker_session (Session): session instance to perform boto calls.
+            client_config (Config): Configuration for the sagemaker feature store runtime
+                client to perform boto calls.
             start_index (int): starting position to ingest in this batch.
             end_index (int): ending position to ingest in this batch.

         Returns:
             List of row indices that failed to be ingested.
         """
+        sagemaker_featurestore_runtime_client = boto3.Session().client(
+            service_name="sagemaker-featurestore-runtime", config=client_config
+        )
+
         logger.info("Started ingesting index %d to %d", start_index, end_index)
         failed_rows = list()
         for row in data_frame[start_index:end_index].itertuples():
@@ -193,9 +208,9 @@ def _ingest_single_batch(
193208 if pd .notna (row [index ])
194209 ]
195210 try :
196- sagemaker_session .put_record (
197- feature_group_name = feature_group_name ,
198- record = [value .to_dict () for value in record ],
211+ sagemaker_featurestore_runtime_client .put_record (
212+ FeatureGroupName = feature_group_name ,
213+ Record = [value .to_dict () for value in record ],
199214 )
200215 except Exception as e : # pylint: disable=broad-except
201216 logger .error ("Failed to ingest row %d: %s" , row [0 ], e )
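
The hunk above replaces the pickled `Session` with a `botocore.config.Config`: boto3 clients are not picklable, so each worker process must construct its own featurestore-runtime client from the shared, picklable config. A minimal sketch of that pattern, not the SDK's code (the retry settings here are an illustrative assumption):

```python
import boto3
from botocore.config import Config

# Config objects are plain data, so they survive pickling across processes.
client_config = Config(retries={"max_attempts": 10, "mode": "standard"})

def make_runtime_client(config: Config):
    # Called inside each worker process to get a process-local client.
    return boto3.Session().client(
        service_name="sagemaker-featurestore-runtime", config=config
    )
```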
@@ -204,7 +219,7 @@ def _ingest_single_batch(

     @property
     def failed_rows(self) -> List[int]:
-        """Get rows that failed to ingest
+        """Get rows that failed to ingest.

         Returns:
             List of row indices that failed to be ingested.
@@ -218,52 +233,134 @@ def wait(self, timeout=None):
             timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
                 if timeout is reached.
         """
-        self._failed_indices = list()
-        for future in as_completed(self._futures, timeout=timeout):
-            start, end = self._futures[future]
-            result = future.result()
-            if result:
-                logger.error("Failed to ingest row %d to %d", start, end)
-            else:
-                logger.info("Successfully ingested row %d to %d", start, end)
-            self._failed_indices += result
+        try:
+            results = self._async_result.get(timeout=timeout)
+        except KeyboardInterrupt as i:
+            # terminate workers abruptly on keyboard interrupt.
+            self._processing_pool.terminate()
+            self._processing_pool.close()
+            self._processing_pool.clear()
+            raise i
+        else:
+            # terminate normally
+            self._processing_pool.close()
+            self._processing_pool.clear()
+
+        self._failed_indices = [
+            failed_index for failed_indices in results for failed_index in failed_indices
+        ]

         if len(self._failed_indices) > 0:
-            raise RuntimeError(
-                f"Failed to ingest some data into FeatureGroup {self.feature_group_name}"
+            raise IngestionError(
+                self._failed_indices,
+                f"Failed to ingest some data into FeatureGroup {self.feature_group_name}",
             )

-    def run(self, wait=True, timeout=None):
+    def _run_multi_process(self, data_frame: DataFrame, wait=True, timeout=None):
+        """Start the ingestion process with the specified number of processes.
+
+        Args:
+            data_frame (DataFrame): source DataFrame to be ingested.
+            wait (bool): whether to wait for the ingestion to finish or not.
+            timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
+                if timeout is reached.
+        """
+        batch_size = math.ceil(data_frame.shape[0] / self.max_processes)
+
+        args = []
+        for i in range(self.max_processes):
+            start_index = min(i * batch_size, data_frame.shape[0])
+            end_index = min(i * batch_size + batch_size, data_frame.shape[0])
+            args += [(data_frame[start_index:end_index], start_index, timeout)]
+
+        def init_worker():
+            # ignore keyboard interrupts in child processes.
+            signal.signal(signal.SIGINT, signal.SIG_IGN)
+
+        self._processing_pool = ProcessingPool(self.max_processes, init_worker)
+        self._processing_pool.restart(force=True)
+
+        f = lambda x: self._run_multi_threaded(*x)  # noqa: E731
+        self._async_result = self._processing_pool.amap(f, args)
+
+        if wait:
+            self.wait(timeout=timeout)
+
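
The ceil-divide-and-clamp arithmetic above gives each process a contiguous slice, with the `min` clamp handing empty trailing slices to any surplus workers. A quick worked sketch of the boundaries it produces, assuming 10 rows and 4 processes:

```python
import math

rows, max_processes = 10, 4  # assumed example values
batch_size = math.ceil(rows / max_processes)  # ceil(10 / 4) == 3
for i in range(max_processes):
    start = min(i * batch_size, rows)
    end = min(i * batch_size + batch_size, rows)
    print(start, end)  # -> (0, 3), (3, 6), (6, 9), (9, 10)
```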
+    def _run_multi_threaded(self, data_frame: DataFrame, row_offset=0, timeout=None) -> List[int]:
         """Start the ingestion process.

         Args:
+            data_frame (DataFrame): source DataFrame to be ingested.
+            row_offset (int): if ``data_frame`` is a partition of a parent DataFrame, then the
+                index of the parent where ``data_frame`` starts. Otherwise, 0.
             wait (bool): whether to wait for the ingestion to finish or not.
             timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
-                if timeout is reached.
+                if timeout is reached.
+
+        Returns:
+            List of row indices that failed to be ingested.
         """
         executor = ThreadPoolExecutor(max_workers=self.max_workers)
-        batch_size = math.ceil(self.data_frame.shape[0] / self.max_workers)
+        batch_size = math.ceil(data_frame.shape[0] / self.max_workers)

         futures = {}
         for i in range(self.max_workers):
-            start_index = min(i * batch_size, self.data_frame.shape[0])
-            end_index = min(i * batch_size + batch_size, self.data_frame.shape[0])
+            start_index = min(i * batch_size, data_frame.shape[0])
+            end_index = min(i * batch_size + batch_size, data_frame.shape[0])
             futures[
                 executor.submit(
                     self._ingest_single_batch,
                     feature_group_name=self.feature_group_name,
-                    sagemaker_session=self.sagemaker_session,
-                    data_frame=self.data_frame,
+                    data_frame=data_frame,
                     start_index=start_index,
                     end_index=end_index,
+                    client_config=self.sagemaker_fs_runtime_client_config,
                 )
-            ] = (start_index, end_index)
+            ] = (start_index + row_offset, end_index + row_offset)
+
+        failed_indices = list()
+        for future in as_completed(futures, timeout=timeout):
+            start, end = futures[future]
+            result = future.result()
+            if result:
+                logger.error("Failed to ingest row %d to %d", start, end)
+            else:
+                logger.info("Successfully ingested row %d to %d", start, end)
+            failed_indices += result

-        self._futures = futures
-        if wait:
-            self.wait(timeout=timeout)
         executor.shutdown(wait=False)

+        return failed_indices
+
+    def run(self, data_frame: DataFrame, wait=True, timeout=None):
+        """Start the ingestion process.
+
+        Args:
+            data_frame (DataFrame): source DataFrame to be ingested.
+            wait (bool): whether to wait for the ingestion to finish or not.
+            timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
+                if timeout is reached.
+        """
+        self._run_multi_process(data_frame=data_frame, wait=wait, timeout=timeout)
+
+
+class IngestionError(Exception):
+    """Exception raised for errors during ingestion.
+
+    Attributes:
+        failed_rows: list of indices from the data frame for which ingestion failed.
+        message: explanation of the error
+    """
+
+    def __init__(self, failed_rows, message):
+        super(IngestionError, self).__init__(message)
+        self.failed_rows = failed_rows
+        self.message = message
+
+    def __str__(self) -> str:
+        """String representation of the error."""
+        return f"{self.failed_rows} -> {self.message}"
+

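Because `IngestionError` carries the failed row positions, a caller can retry just the rows that failed. A hedged sketch (`feature_group` and `df` are hypothetical, and a default RangeIndex is assumed so positions line up with labels):

```python
try:
    feature_group.ingest(data_frame=df, max_workers=4, max_processes=2)
except IngestionError as e:
    # e.failed_rows holds zero-based positions into the original frame.
    retry_df = df.iloc[e.failed_rows]
    feature_group.ingest(data_frame=retry_df, max_workers=4)
```
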
 @attr.s
 class FeatureGroup:
@@ -447,6 +544,7 @@ def ingest(
         self,
         data_frame: DataFrame,
         max_workers: int = 1,
+        max_processes: int = 1,
         wait: bool = True,
         timeout: Union[int, float] = None,
     ) -> IngestionManagerPandas:
@@ -455,23 +553,45 @@ def ingest(
         ``max_worker`` number of thread will be created to work on different partitions of
         the ``data_frame`` in parallel.

+        ``max_processes`` number of processes will be created to work on different partitions
+        of the ``data_frame`` in parallel, each with ``max_worker`` threads.
+
+        The ingest function will attempt to ingest all records in the data frame. If ``wait``
+        is True, then an exception is thrown after all records have been processed. If ``wait``
+        is False, then a later call to the returned instance IngestionManagerPandas' ``wait()``
+        function will throw an exception.
+
+        Zero based indices of rows that failed to be ingested can be found in the exception.
+        They can also be found from the IngestionManagerPandas' ``failed_rows`` function after
+        the exception is thrown.
+
         Args:
             data_frame (DataFrame): data_frame to be ingested to feature store.
             max_workers (int): number of threads to be created.
+            max_processes (int): number of processes to be created. Each process spawns
+                ``max_worker`` number of threads.
             wait (bool): whether to wait for the ingestion to finish or not.
             timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
                 if timeout is reached.

         Returns:
             An instance of IngestionManagerPandas.
         """
+        if max_processes <= 0:
+            raise RuntimeError("max_processes must be greater than 0.")
+
+        if max_workers <= 0:
+            raise RuntimeError("max_workers must be greater than 0.")
+
         manager = IngestionManagerPandas(
             feature_group_name=self.name,
-            sagemaker_session=self.sagemaker_session,
-            data_frame=data_frame,
+            sagemaker_fs_runtime_client_config=self.sagemaker_session.sagemaker_featurestore_runtime_client.meta.config,
             max_workers=max_workers,
+            max_processes=max_processes,
         )
-        manager.run(wait=wait, timeout=timeout)
+
+        manager.run(data_frame=data_frame, wait=wait, timeout=timeout)
+
         return manager

     def athena_query(self) -> AthenaQuery:
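
As a usage sketch of the new API surface (`feature_group` and `df` are hypothetical), ingestion can run in the background and be joined later:

```python
manager = feature_group.ingest(
    data_frame=df, max_workers=4, max_processes=2, wait=False
)
# ... other work while the worker processes ingest ...
try:
    manager.wait(timeout=900)
except IngestionError as e:
    print("Rows that failed to ingest:", e.failed_rows)
```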