@@ -207,7 +207,8 @@ def _ingest_single_batch(
207207 for row in data_frame [start_index :end_index ].itertuples ():
208208 record = [
209209 FeatureValue (
210- feature_name = data_frame .columns [index - 1 ], value_as_string = str (row [index ])
210+ feature_name = data_frame .columns [index - 1 ],
211+ value_as_string = str (row [index ]),
211212 )
212213 for index in range (1 , len (row ))
213214 if pd .notna (row [index ])
@@ -270,13 +271,24 @@ def _run_multi_process(self, data_frame: DataFrame, wait=True, timeout=None):
270271 timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised
271272 if timeout is reached.
272273 """
274+ # pylint: disable=I1101
273275 batch_size = math .ceil (data_frame .shape [0 ] / self .max_processes )
276+ # pylint: enable=I1101
274277
275278 args = []
276279 for i in range (self .max_processes ):
277280 start_index = min (i * batch_size , data_frame .shape [0 ])
278281 end_index = min (i * batch_size + batch_size , data_frame .shape [0 ])
279- args += [(data_frame [start_index :end_index ], start_index , timeout )]
282+ args += [
283+ (
284+ self .max_workers ,
285+ self .feature_group_name ,
286+ self .sagemaker_fs_runtime_client_config ,
287+ data_frame [start_index :end_index ],
288+ start_index ,
289+ timeout ,
290+ )
291+ ]
280292
281293 def init_worker ():
282294 # ignore keyboard interrupts in child processes.
@@ -285,13 +297,21 @@ def init_worker():
285297 self ._processing_pool = ProcessingPool (self .max_processes , init_worker )
286298 self ._processing_pool .restart (force = True )
287299
288- f = lambda x : self ._run_multi_threaded (* x ) # noqa: E731
300+ f = lambda x : IngestionManagerPandas ._run_multi_threaded (* x ) # noqa: E731
289301 self ._async_result = self ._processing_pool .amap (f , args )
290302
291303 if wait :
292304 self .wait (timeout = timeout )
293305
294- def _run_multi_threaded (self , data_frame : DataFrame , row_offset = 0 , timeout = None ) -> List [int ]:
306+ @staticmethod
307+ def _run_multi_threaded (
308+ max_workers : int ,
309+ feature_group_name : str ,
310+ sagemaker_fs_runtime_client_config : Config ,
311+ data_frame : DataFrame ,
312+ row_offset = 0 ,
313+ timeout = None ,
314+ ) -> List [int ]:
295315 """Start the ingestion process.
296316
297317 Args:
@@ -305,21 +325,23 @@ def _run_multi_threaded(self, data_frame: DataFrame, row_offset=0, timeout=None)
305325 Returns:
306326 List of row indices that failed to be ingested.
307327 """
308- executor = ThreadPoolExecutor (max_workers = self .max_workers )
309- batch_size = math .ceil (data_frame .shape [0 ] / self .max_workers )
328+ executor = ThreadPoolExecutor (max_workers = max_workers )
329+ # pylint: disable=I1101
330+ batch_size = math .ceil (data_frame .shape [0 ] / max_workers )
331+ # pylint: enable=I1101
310332
311333 futures = {}
312- for i in range (self . max_workers ):
334+ for i in range (max_workers ):
313335 start_index = min (i * batch_size , data_frame .shape [0 ])
314336 end_index = min (i * batch_size + batch_size , data_frame .shape [0 ])
315337 futures [
316338 executor .submit (
317- self ._ingest_single_batch ,
318- feature_group_name = self . feature_group_name ,
339+ IngestionManagerPandas ._ingest_single_batch ,
340+ feature_group_name = feature_group_name ,
319341 data_frame = data_frame ,
320342 start_index = start_index ,
321343 end_index = end_index ,
322- client_config = self . sagemaker_fs_runtime_client_config ,
344+ client_config = sagemaker_fs_runtime_client_config ,
323345 )
324346 ] = (start_index + row_offset , end_index + row_offset )
325347
0 commit comments