@@ -337,12 +337,11 @@ def verify_patching(stream: io.StringIO, function_name) -> bool:
337
337
def create_online_function (
338
338
estimator_instance , method_instance , data_args , num_batches , batch_size
339
339
):
340
- n_batches = data_args [0 ].shape [0 ] // batch_size
341
340
342
341
if "y" in list (inspect .signature (method_instance ).parameters ):
343
342
344
343
def ndarray_function (x , y ):
345
- for i in range (n_batches ):
344
+ for i in range (num_batches ):
346
345
method_instance (
347
346
x [i * batch_size : (i + 1 ) * batch_size ],
348
347
y [i * batch_size : (i + 1 ) * batch_size ],
@@ -351,7 +350,7 @@ def ndarray_function(x, y):
351
350
estimator_instance ._onedal_finalize_fit ()
352
351
353
352
def dataframe_function (x , y ):
354
- for i in range (n_batches ):
353
+ for i in range (num_batches ):
355
354
method_instance (
356
355
x .iloc [i * batch_size : (i + 1 ) * batch_size ],
357
356
y .iloc [i * batch_size : (i + 1 ) * batch_size ],
@@ -362,13 +361,13 @@ def dataframe_function(x, y):
362
361
else :
363
362
364
363
def ndarray_function (x ):
365
- for i in range (n_batches ):
364
+ for i in range (num_batches ):
366
365
method_instance (x [i * batch_size : (i + 1 ) * batch_size ])
367
366
if hasattr (estimator_instance , "_onedal_finalize_fit" ):
368
367
estimator_instance ._onedal_finalize_fit ()
369
368
370
369
def dataframe_function (x ):
371
- for i in range (n_batches ):
370
+ for i in range (num_batches ):
372
371
method_instance (x .iloc [i * batch_size : (i + 1 ) * batch_size ])
373
372
if hasattr (estimator_instance , "_onedal_finalize_fit" ):
374
373
estimator_instance ._onedal_finalize_fit ()
0 commit comments