Skip to content

Commit 9461fad

Browse files
committed
Add condition for finalize
1 parent 3ac5c23 commit 9461fad

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

sklbench/benchmarks/sklearn_estimator.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -347,27 +347,31 @@ def ndarray_function(x, y):
347347
x[i * batch_size : (i + 1) * batch_size],
348348
y[i * batch_size : (i + 1) * batch_size],
349349
)
350-
estimator_instance._onedal_finalize_fit()
350+
if hasattr(estimator_instance, "_onedal_finalize_fit"):
351+
estimator_instance._onedal_finalize_fit()
351352

352353
def dataframe_function(x, y):
353354
for i in range(n_batches):
354355
method_instance(
355356
x.iloc[i * batch_size : (i + 1) * batch_size],
356357
y.iloc[i * batch_size : (i + 1) * batch_size],
357358
)
358-
estimator_instance._onedal_finalize_fit()
359+
if hasattr(estimator_instance, "_onedal_finalize_fit"):
360+
estimator_instance._onedal_finalize_fit()
359361

360362
else:
361363

362364
def ndarray_function(x):
363365
for i in range(n_batches):
364366
method_instance(x[i * batch_size : (i + 1) * batch_size])
365-
estimator_instance._onedal_finalize_fit()
367+
if hasattr(estimator_instance, "_onedal_finalize_fit"):
368+
estimator_instance._onedal_finalize_fit()
366369

367370
def dataframe_function(x):
368371
for i in range(n_batches):
369372
method_instance(x.iloc[i * batch_size : (i + 1) * batch_size])
370-
estimator_instance._onedal_finalize_fit()
373+
if hasattr(estimator_instance, "_onedal_finalize_fit"):
374+
estimator_instance._onedal_finalize_fit()
371375

372376
if "ndarray" in str(type(data_args[0])):
373377
return ndarray_function

0 commit comments

Comments
 (0)