Skip to content

Commit 94c7e34

Browse files
authored
Merge pull request #16 from MaxSchambach/add-catboost-gpu-support
Add CatBoost GPU support
2 parents 0a68604 + 538c7e0 commit 94c7e34

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

pytabkit/models/alg_interfaces/catboost_interfaces.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,12 @@ def _get_eval_metric(self, val_metric_name: Optional[str], n_classes: int) -> Un
178178
# adapted from https://github.com/catboost/benchmarks/blob/master/quality_benchmarks/catboost_experiment.py
179179
def _preprocess_params(self, params: Dict[str, Any], n_classes: int) -> Dict[str, Any]:
180180
params = copy.deepcopy(params)
181+
182+
device = params.pop('device', None)
183+
if device is not None and device.startswith('cuda:'):
184+
params['task_type'] = 'GPU'
185+
params['devices'] = device.split(':')[1]
186+
181187
if n_classes == 0:
182188
train_metric_name = self.config.get('train_metric_name', 'mse')
183189
# val_metric_name = self.config.get('val_metric_name', 'rmse')

0 commit comments

Comments
 (0)