Skip to content

Commit b82d772

Browse files
committed
Fix num_batches usage
1 parent 9461fad commit b82d772

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

sklbench/benchmarks/sklearn_estimator.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -337,12 +337,11 @@ def verify_patching(stream: io.StringIO, function_name) -> bool:
337337
def create_online_function(
338338
estimator_instance, method_instance, data_args, num_batches, batch_size
339339
):
340-
n_batches = data_args[0].shape[0] // batch_size
341340

342341
if "y" in list(inspect.signature(method_instance).parameters):
343342

344343
def ndarray_function(x, y):
345-
for i in range(n_batches):
344+
for i in range(num_batches):
346345
method_instance(
347346
x[i * batch_size : (i + 1) * batch_size],
348347
y[i * batch_size : (i + 1) * batch_size],
@@ -351,7 +350,7 @@ def ndarray_function(x, y):
351350
estimator_instance._onedal_finalize_fit()
352351

353352
def dataframe_function(x, y):
354-
for i in range(n_batches):
353+
for i in range(num_batches):
355354
method_instance(
356355
x.iloc[i * batch_size : (i + 1) * batch_size],
357356
y.iloc[i * batch_size : (i + 1) * batch_size],
@@ -362,13 +361,13 @@ def dataframe_function(x, y):
362361
else:
363362

364363
def ndarray_function(x):
365-
for i in range(n_batches):
364+
for i in range(num_batches):
366365
method_instance(x[i * batch_size : (i + 1) * batch_size])
367366
if hasattr(estimator_instance, "_onedal_finalize_fit"):
368367
estimator_instance._onedal_finalize_fit()
369368

370369
def dataframe_function(x):
371-
for i in range(n_batches):
370+
for i in range(num_batches):
372371
method_instance(x.iloc[i * batch_size : (i + 1) * batch_size])
373372
if hasattr(estimator_instance, "_onedal_finalize_fit"):
374373
estimator_instance._onedal_finalize_fit()

0 commit comments

Comments
 (0)