Skip to content

Commit 5c3179e

Browse files
authored
DOC, TST: Wrapping of PyTorch models (#699)
1 parent 868ac5d commit 5c3179e

File tree

11 files changed

+193
-19
lines changed

11 files changed

+193
-19
lines changed

ci/posix.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,13 @@ jobs:
3333
- bash: conda env create --quiet --file=$(envFile) --name=dask-ml-test && conda list -n dask-ml-test
3434
displayName: "install"
3535

36+
- bash: |
37+
conda install -y -q pytorch cpuonly -c pytorch -n dask-ml-test
38+
source activate dask-ml-test
39+
pip install skorch
40+
displayName: "install PyTorch"
41+
condition: eq(variables['Build.SourceBranch'], 'refs/heads/master')
42+
3643
- script: |
3744
source activate dask-ml-test
3845
conda uninstall -y --force scikit-learn

dask_ml/model_selection/_hyperband.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ def _get_SHAs(self, brackets):
388388
return SHAs
389389

390390
async def _fit(self, X, y, **fit_params):
391-
X, y, scorer = self._validate_parameters(X, y)
391+
X, y, scorer = await self._validate_parameters(X, y)
392392

393393
brackets = _get_hyperband_params(self.max_iter, eta=self.aggressiveness)
394394
SHAs = self._get_SHAs(brackets)

dask_ml/model_selection/_incremental.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -515,21 +515,25 @@ def __init__(
515515
self.prefix = prefix
516516
super(BaseIncrementalSearchCV, self).__init__(estimator, scoring=scoring)
517517

518-
def _validate_parameters(self, X, y):
518+
async def _validate_parameters(self, X, y):
519519
if (self.max_iter is not None) and self.max_iter < 1:
520520
raise ValueError(
521521
"Received max_iter={}. max_iter < 1 is not supported".format(
522522
self.max_iter
523523
)
524524
)
525525

526-
# Make sure dask arrays are passed so error on unknown chunk size is raised
527526
kwargs = dict(accept_unknown_chunks=True, accept_dask_dataframe=True)
528527
if not isinstance(X, dd.DataFrame):
529528
X = self._check_array(X, **kwargs)
530-
if not isinstance(y, dd.Series):
529+
if not isinstance(y, (dd.DataFrame, dd.Series)):
531530
y = self._check_array(y, ensure_2d=False, **kwargs)
532-
scorer = check_scoring(self.estimator, scoring=self.scoring)
531+
estimator = self.estimator
532+
if isinstance(estimator, Future):
533+
client = default_client()
534+
scorer = await client.submit(check_scoring, estimator, scoring=self.scoring)
535+
else:
536+
scorer = check_scoring(self.estimator, scoring=self.scoring)
533537
return X, y, scorer
534538

535539
@property
@@ -640,7 +644,7 @@ async def _fit(self, X, y, **fit_params):
640644
else:
641645
context = dummy_context()
642646

643-
X, y, scorer = self._validate_parameters(X, y)
647+
X, y, scorer = await self._validate_parameters(X, y)
644648

645649
X_train, X_test, y_train, y_test = self._get_train_test_split(X, y)
646650

dask_ml/wrappers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ def _fit_for_estimator(self, estimator, X, y, **fit_kwargs):
481481
random_state=self.random_state,
482482
shuffle_blocks=self.shuffle_blocks,
483483
assume_equal_chunks=self.assume_equal_chunks,
484-
**fit_kwargs
484+
**fit_kwargs,
485485
)
486486

487487
copy_learned_attributes(result, self)

docs/source/hyper-parameter-search.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ generalized to any of the above estimators.
403403

404404
.. note::
405405

406-
These estimators require that the model implement ``partial_fit``
406+
These estimators require that the model implement ``partial_fit``.
407407

408408
By default, these class will call ``partial_fit`` on each chunk of the data.
409409
These classes can stop training any models if their score stops increasing

docs/source/index.rst

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,13 +120,20 @@ Scikit-Learn should feel at home with Dask-ML.
120120
hyper-parameter-search.rst
121121
compose.rst
122122
glm.rst
123-
joblib.rst
124123
meta-estimators.rst
125124
incremental.rst
126125
clustering.rst
127-
xgboost.rst
128126
modules/api.rst
129127

128+
.. toctree::
129+
:maxdepth: 2
130+
:hidden:
131+
:caption: Integration
132+
133+
joblib.rst
134+
xgboost.rst
135+
pytorch.rst
136+
130137
.. toctree::
131138
:maxdepth: 2
132139
:hidden:

docs/source/joblib.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
.. _joblib:
22

3-
Joblib
4-
======
3+
Scikit-Learn & Joblib
4+
=====================
55

66
Many Scikit-Learn algorithms are written for parallel execution using
77
`Joblib <http://joblib.readthedocs.io/en/latest/>`__, which natively provides

docs/source/pytorch.rst

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
PyTorch
2+
=======
3+
4+
Skorch_ brings a Scikit-learn API to PyTorch_. Skorch allows PyTorch models to
5+
be wrapped in Scikit-learn compatible estimators. So, that means that PyTorch
6+
models wrapped in Skorch can be used with the rest of the Dask-ML API. For
7+
example, using Dask-ML's :class:`~dask_ml.model_selection.HyperbandSearchCV` or
8+
:class:`~dask_ml.model_selection.Incremental` with PyTorch is possible after
9+
wrapping with Skorch.
10+
11+
We encourage looking at the Skorch documentation for complete details.
12+
13+
Example usage
14+
-------------
15+
16+
First, let's create a normal PyTorch model:
17+
18+
.. code-block:: python
19+
20+
21+
import torch.nn as nn
22+
import torch.nn.functional as F
23+
24+
class ShallowNet(nn.Module):
25+
def __init__(self, n_features=5):
26+
super().__init__()
27+
self.layer1 = nn.Linear(n_features, 1)
28+
29+
def forward(self, x):
30+
return F.relu(self.layer1(x))
31+
32+
With this, it's easy to use Skorch:
33+
34+
.. code-block:: python
35+
36+
from skorch import NeuralNetRegressor
37+
import torch.optim as optim
38+
39+
niceties = {
40+
"callbacks": False,
41+
"warm_start": False,
42+
"train_split": None,
43+
"max_epochs": 1,
44+
}
45+
46+
model = NeuralNetRegressor(
47+
module=ShallowNet,
48+
module__n_features=5,
49+
criterion=nn.MSELoss,
50+
optimizer=optim.SGD,
51+
optimizer__lr=0.1,
52+
optimizer__momentum=0.9,
53+
batch_size=64,
54+
**niceties,
55+
)
56+
57+
Each parameter that the PyTorch ``nn.Module`` takes is prefixed with ``module__``,
58+
and same for the optimizer (``optim.SGD`` takes a ``lr`` and ``momentum``
59+
parameters). The ``niceties`` make sure Skorch uses all the data for training
60+
and doesn't print excessive amounts of logs.
61+
62+
Now, this model can be used with Dask-ML. For example, it's possible to do the
63+
following:
64+
65+
* Use PyTorch with the Dask-ML's model selection, including
66+
:class:`~dask_ml.model_selection.HyperbandSearchCV`.
67+
* Use PyTorch with Dask-ML's :class:`~dask_ml.wrappers.Incremental`.
68+
69+
.. _Skorch: https://skorch.readthedocs.io/en/stable/
70+
.. _PyTorch: https://pytorch.org

docs/source/xgboost.rst

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,29 @@
1-
XGBoost
2-
=======
1+
XGBoost & LightGBM
2+
==================
33

44
.. currentmodule:: dask_ml.xgboost
55

6+
XGBoost_ is a powerful and popular library for gradient boosted trees. For
7+
larger datasets or faster training XGBoost also provides a distributed
8+
computing solution. LightGBM_ is another library similar to XGBoost; it also
9+
natively supplies native distributed training for decision trees.
10+
11+
Dask-ML can set up distributed XGBoost or LightGBM for you and hand off data
12+
from distributed dask.dataframes. This automates much of the hassle of
13+
preprocessing and setup while still letting XGBoost/LightGBM do what they do
14+
well.
15+
16+
Below, we'll refer to an example with XGBoost. Here are the relevant XGBoost
17+
classes/functions:
18+
619
.. autosummary::
720
train
821
predict
922
XGBClassifier
1023
XGBRegressor
1124

12-
XGBoost_ is a powerful and popular library for gradient boosted trees. For
13-
larger datasets or faster training XGBoost also provides a distributed
14-
computing solution. Dask-ML can set up distributed XGBoost for you and hand
15-
off data from distributed dask.dataframes. This automates much of the hassle
16-
of preprocessing and setup while still letting XGBoost do what it does well.
25+
The LightGBM implementation and documentation can be found at
26+
https://github.com/dask/dask-lightgbm.
1727

1828
Example
1929
-------
@@ -63,3 +73,4 @@ relevant GitHub issue here: `dmlc/xgboost #2032 <https://github.com/dmlc/xgboost
6373
See the ":doc:`Dask-ML examples <examples>`" for an example usage.
6474

6575
.. _XGBoost: https://xgboost.readthedocs.io/
76+
.. _LightGBM: https://lightgbm.readthedocs.io/

tests/model_selection/test_incremental.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -853,3 +853,18 @@ def test_warns_scores_per_fit(c, s, a, b):
853853
search = IncrementalSearchCV(model, params, scores_per_fit=2)
854854
with pytest.warns(UserWarning, match="deprecated since Dask-ML v1.4.0"):
855855
yield search.fit(X, y)
856+
857+
858+
@gen_cluster(client=True)
859+
async def test_model_future(c, s, a, b):
860+
X, y = make_classification(n_samples=100, n_features=5, chunks=10)
861+
862+
params = {"value": np.random.RandomState(42).rand(1000)}
863+
model = ConstantFunction()
864+
model_future = await c.scatter(model)
865+
866+
search = IncrementalSearchCV(model_future, params, max_iter=10)
867+
868+
await search.fit(X, y, classes=[0, 1])
869+
assert search.history_
870+
assert search.best_score_ > 0

0 commit comments

Comments
 (0)