diff --git a/devtools/conda-recipe/meta.yaml b/devtools/conda-recipe/meta.yaml index b97d318..b035bca 100644 --- a/devtools/conda-recipe/meta.yaml +++ b/devtools/conda-recipe/meta.yaml @@ -33,6 +33,8 @@ test: - nose - nose-timer - gpy + - skorch + - pytorch - msmbuilder - msmb_data - mdtraj diff --git a/devtools/travis-ci/build_docs.sh b/devtools/travis-ci/build_docs.sh index fa59651..a3a5fb9 100755 --- a/devtools/travis-ci/build_docs.sh +++ b/devtools/travis-ci/build_docs.sh @@ -8,11 +8,9 @@ conda create --yes -n docenv python=$CONDA_PY source activate docenv conda install -yq --use-local osprey - # Install doc requirements conda install --yes --file docs/requirements.txt - # We don't use conda for these: # sphinx_rtd_theme's latest releases are not available # neither is msmb_theme @@ -20,6 +18,7 @@ conda install --yes --file docs/requirements.txt pip install -I sphinx==1.3.5 pip install -I sphinx_rtd_theme==0.1.9 msmb_theme==1.2.0 + # Make docs cd docs && make html && cd - diff --git a/osprey/data/torch_skeleton_config.yaml b/osprey/data/torch_skeleton_config.yaml new file mode 100644 index 0000000..631f39c --- /dev/null +++ b/osprey/data/torch_skeleton_config.yaml @@ -0,0 +1,38 @@ +estimator: + eval: Pipeline([ + ('scale', RobustScaler()), + ('classifier', NeuralNetClassifier(nn.Sequential(nn.Linear(64, 32), + nn.ReLU(), + nn.Linear(32, 10), + nn.Softmax(dim=1)), + max_epochs=10)), + ]) + eval_scope: ['sklearn', 'torch'] + +scoring: accuracy + +strategy: + name: gp + params: + seeds: 5 + +search_space: + classifier__lr: + min: 1e-3 + max: 1e-1 + num: 10 + type: jump + var_type: float + warp: log + +cv: 5 + +dataset_loader: + name: sklearn_dataset + params: + method: load_digits + +trials: + uri: sqlite:///osprey-trials.db + +random_seed: 42 diff --git a/osprey/eval_scopes.py b/osprey/eval_scopes.py index cc6cbd4..ac3846f 100644 --- a/osprey/eval_scopes.py +++ b/osprey/eval_scopes.py @@ -7,7 +7,7 @@ from sklearn.base import BaseEstimator -__all__ = ['msmbuilder', 'import_all_estimators'] +__all__ = ['msmbuilder', 'torch', 'import_all_estimators'] def msmbuilder(): @@ -21,6 +21,20 @@ def msmbuilder(): return scope +def torch(): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + import torch + from torch import nn + import skorch + from sklearn.pipeline import Pipeline + + scope = import_all_estimators(skorch) + scope.update({'nn': nn}) + scope['Pipeline'] = Pipeline + return scope + + def import_all_estimators(pkg): def estimator_in_module(mod): for name, obj in inspect.getmembers(mod): diff --git a/osprey/execute_skeleton.py b/osprey/execute_skeleton.py index 3376aca..99a6246 100644 --- a/osprey/execute_skeleton.py +++ b/osprey/execute_skeleton.py @@ -8,7 +8,8 @@ 'random_example': 'random_example.yaml', 'bayes_example': 'sklearn_skeleton_config.yaml', 'grid_example': 'grid_example.yaml', - 'msmb_feat_select': 'msmb_feat_select_skeleton_config.yaml'} + 'msmb_feat_select': 'msmb_feat_select_skeleton_config.yaml', + 'torch': 'torch_skeleton_config.yaml'} def execute(args, parser): diff --git a/osprey/tests/test_cli_worker_and_dump.py b/osprey/tests/test_cli_worker_and_dump.py index 47a0ca0..4f0b87e 100644 --- a/osprey/tests/test_cli_worker_and_dump.py +++ b/osprey/tests/test_cli_worker_and_dump.py @@ -17,6 +17,13 @@ except: HAVE_MSMBUILDER = False +try: + __import__('skorch') + HAVE_SKORCH = True +except: + HAVE_SKORCH = False + + OSPREY_BIN = find_executable('osprey') @@ -151,6 +158,30 @@ def test_gp_example(): shutil.rmtree(dirname) +@skipif(not HAVE_SKORCH, 'this test requires Skorch') +def test_torch_example(): + assert OSPREY_BIN is not None + cwd = os.path.abspath(os.curdir) + dirname = tempfile.mkdtemp() + + try: + os.chdir(dirname) + subprocess.check_call([OSPREY_BIN, 'skeleton', '-t', 'torch', + '-f', 'config.yaml']) + subprocess.check_call([OSPREY_BIN, 'worker', 'config.yaml', '-n', '1']) + assert os.path.exists('osprey-trials.db') + + subprocess.check_call([OSPREY_BIN, 'current_best', 'config.yaml']) + + yield _test_dump_1 + + yield _test_plot_1 + + finally: + os.chdir(cwd) + shutil.rmtree(dirname) + + def test_grid_example(): assert OSPREY_BIN is not None cwd = os.path.abspath(os.curdir)