Skip to content

Commit 6c4981a

Browse files
committed
allow cuda as a device name [no ci]
1 parent 9f2ca1b commit 6c4981a

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

pytabkit/models/sklearn/sklearn_base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,9 @@ def fit(self, X, y, X_val: Optional = None, y_val: Optional = None, val_idxs: Op
347347
n_physical_threads = max(1, n_logical_threads//2)
348348

349349
device = params.get('device', None)
350+
if device == 'cuda':
351+
device = 'cuda:0' # 'cuda' doesn't work with some of the code
352+
350353
n_threads = params.get('n_threads', n_physical_threads)
351354
self.n_threads_ = n_threads
352355

0 commit comments

Comments
 (0)