Skip to content

Commit 95f6c6d

Browse files
author
Carlos Hernandez
committed
add skorch compat
1 parent af9f1a7 commit 95f6c6d

File tree

3 files changed

+55
-2
lines changed

3 files changed

+55
-2
lines changed
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
estimator:
2+
eval: Pipeline([
3+
('scale', RobustScaler()),
4+
('classifier', NeuralNetClassifier(nn.Sequential(nn.Linear(64, 32),
5+
nn.ReLU(),
6+
nn.Linear(32, 10),
7+
nn.Softmax(dim=1)),
8+
max_epochs=10)),
9+
])
10+
eval_scope: ['sklearn', 'torch']
11+
12+
scoring: accuracy
13+
14+
strategy:
15+
name: gp
16+
params:
17+
seeds: 5
18+
19+
search_space:
20+
classifier__lr:
21+
min: 1e-3
22+
max: 1e-1
23+
num: 10
24+
type: jump
25+
var_type: float
26+
warp: log
27+
28+
cv: 5
29+
30+
dataset_loader:
31+
name: sklearn_dataset
32+
params:
33+
method: load_digits
34+
35+
trials:
36+
uri: sqlite:///osprey-trials.db
37+
38+
random_seed: 42

osprey/eval_scopes.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from sklearn.base import BaseEstimator
99

1010

11-
__all__ = ['msmbuilder', 'import_all_estimators']
11+
__all__ = ['msmbuilder', 'torch', 'import_all_estimators']
1212

1313

1414
def msmbuilder():
@@ -22,6 +22,20 @@ def msmbuilder():
2222
return scope
2323

2424

25+
def torch():
26+
with warnings.catch_warnings():
27+
warnings.filterwarnings("ignore", category=DeprecationWarning)
28+
import torch
29+
from torch import nn
30+
import skorch
31+
from sklearn.pipeline import Pipeline
32+
33+
scope = import_all_estimators(skorch)
34+
scope.update({'nn': nn})
35+
scope['Pipeline'] = Pipeline
36+
return scope
37+
38+
2539
def import_all_estimators(pkg):
2640

2741
def estimator_in_module(mod):

osprey/execute_skeleton.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
'random_example': 'random_example.yaml',
99
'gp_example': 'sklearn_skeleton_config.yaml',
1010
'grid_example': 'grid_example.yaml',
11-
'msmb_feat_select': 'msmb_feat_select_skeleton_config.yaml'}
11+
'msmb_feat_select': 'msmb_feat_select_skeleton_config.yaml',
12+
'torch': 'torch_skeleton_config.yaml'}
1213

1314

1415
def execute(args, parser):

0 commit comments

Comments
 (0)