Commit 31be799

krshrimali, rohitgr7, and akihironitta authored
Doc Fix: Passing datamodule argument to trainer.tune is supported (#12406)
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Akihiro Nitta <[email protected]>
1 parent 59b0ccb commit 31be799

File tree

1 file changed (+49, -15 lines changed)


docs/source/advanced/training_tricks.rst

Lines changed: 49 additions & 15 deletions
@@ -89,7 +89,7 @@ longer training time. Inspired by https://github.com/BlackHC/toma.
    # Autoscale batch size
    trainer = Trainer(auto_scale_batch_size=None | "power" | "binsearch")

-    # find the batch size
+    # Find the batch size
    trainer.tune(model)

Currently, this feature supports two modes ``'power'`` scaling and ``'binsearch'``
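
For reference, a minimal runnable version of the snippet this hunk edits (the ``None | "power" | "binsearch"`` argument above is shorthand for the accepted values, not literal Python; ``LitModel`` is a placeholder module that defines a ``batch_size`` field):

.. code-block:: python

    from pytorch_lightning import Trainer

    # Enable the batch size finder; "binsearch" refines the power-of-two ramp-up
    # with a binary search.
    trainer = Trainer(auto_scale_batch_size="binsearch")

    # Find the batch size; the result is written back to model.batch_size.
    model = LitModel(batch_size=2)
    trainer.tune(model)
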
@@ -105,18 +105,48 @@ search for batch sizes larger than the size of the training dataset.

This feature expects that a ``batch_size`` field is either located as a model attribute
i.e. ``model.batch_size`` or as a field in your ``hparams`` i.e. ``model.hparams.batch_size``.
-The field should exist and will be overridden by the results of this algorithm.
-Additionally, your ``train_dataloader()`` method should depend on this field
-for this feature to work i.e.
+Similarly it can work with datamodules too. The field should exist and will be updated by
+the results of this algorithm. Additionally, your ``train_dataloader()`` method should depend
+on this field for this feature to work i.e.

.. code-block:: python

-    def train_dataloader(self):
-        return DataLoader(train_dataset, batch_size=self.batch_size | self.hparams.batch_size)
+    # using LightningModule
+    class LitModel(LightningModule):
+        def __init__(self, batch_size):
+            super().__init__()
+            self.save_hyperparameters()
+            # or
+            self.batch_size = batch_size
+
+        def train_dataloader(self):
+            return DataLoader(train_dataset, batch_size=self.batch_size | self.hparams.batch_size)
+
+
+    trainer = Trainer(...)
+    model = LitModel(batch_size=32)
+    trainer.tune(model)
+
+    # using LightningDataModule
+    class LitDataModule(LightningDataModule):
+        def __init__(self, batch_size):
+            super().__init__()
+            self.save_hyperparameters()
+            # or
+            self.batch_size = batch_size
+
+        def train_dataloader(self):
+            return DataLoader(train_dataset, batch_size=self.batch_size | self.hparams.batch_size)
+
+
+    trainer = Trainer(...)
+    model = MyModel()
+    datamodule = LitDataModule(batch_size=32)
+    trainer.tune(model, datamodule=datamodule)

.. warning::

-    Due to these constraints, this features does *NOT* work when passing dataloaders directly
+    Due to the constraints listed above, this features does *NOT* work when passing dataloaders directly
    to ``.fit()``.

The scaling algorithm has a number of parameters that the user can control by
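
The context line above is cut off by the hunk boundary; it presumably continues into a direct call to the tuner. A hedged sketch of that call, assuming the ``Tuner.scale_batch_size`` parameters of this release (``mode``, ``init_val``, ``max_trials``, ``steps_per_trial``, ``batch_arg_name``):

.. code-block:: python

    trainer = Trainer()

    # Run the search directly: start at batch size 2, try at most 25 candidates,
    # and run 3 training steps per candidate (parameter names assumed as above).
    new_batch_size = trainer.tuner.scale_batch_size(
        model,
        mode="binsearch",
        init_val=2,
        max_trials=25,
        steps_per_trial=3,
        batch_arg_name="batch_size",
    )

    # The attribute named by batch_arg_name is updated with the found value.
    trainer.fit(model)
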
@@ -178,7 +208,7 @@ rate, a `learning rate finder` can be used. As described in `this paper <https:/
a learning rate finder does a small run where the learning rate is increased
after each processed batch and the corresponding loss is logged. The result of
this is a ``lr`` vs. ``loss`` plot that can be used as guidance for choosing an optimal
-initial lr.
+initial learning rate.

.. warning::

@@ -189,16 +219,21 @@ initial lr.
Using Lightning's built-in LR finder
====================================

-To enable the learning rate finder, your :doc:`lightning module <../common/lightning_module>` needs to have a ``learning_rate`` or ``lr`` property.
-Then, set ``Trainer(auto_lr_find=True)`` during trainer construction,
-and then call ``trainer.tune(model)`` to run the LR finder. The suggested ``learning_rate``
-will be written to the console and will be automatically set to your :doc:`lightning module <../common/lightning_module>`,
-which can be accessed via ``self.learning_rate`` or ``self.lr``.
+To enable the learning rate finder, your :doc:`lightning module <../common/lightning_module>` needs to
+have a ``learning_rate`` or ``lr`` attribute (or as a field in your ``hparams`` i.e.
+``hparams.learning_rate`` or ``hparams.lr``). Then, set ``Trainer(auto_lr_find=True)``
+during trainer construction, and then call ``trainer.tune(model)`` to run the LR finder.
+The suggested ``learning_rate`` will be written to the console and will be automatically
+set to your :doc:`lightning module <../common/lightning_module>`, which can be accessed
+via ``self.learning_rate`` or ``self.lr``.
+
+.. seealso:: :ref:`trainer.tune <common/trainer:tune>`.

.. code-block:: python

    class LitModel(LightningModule):
        def __init__(self, learning_rate):
+            super().__init__()
            self.learning_rate = learning_rate
            self.model = Model(...)

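
To complement the ``hparams`` wording added above, a small sketch of that variant; the ``datamodule`` instance and the ``(x, y)`` batch shape are assumptions for illustration:

.. code-block:: python

    import torch
    from pytorch_lightning import LightningModule, Trainer


    class LitModel(LightningModule):
        def __init__(self, learning_rate=1e-3):
            super().__init__()
            # Stored as self.hparams.learning_rate, which the LR finder reads and updates.
            self.save_hyperparameters()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.cross_entropy(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)


    trainer = Trainer(auto_lr_find=True)
    trainer.tune(LitModel(), datamodule=datamodule)  # datamodule assumed to yield (x, y) batches
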
@@ -225,7 +260,6 @@ If your model is using an arbitrary value instead of ``self.lr`` or ``self.learn

    trainer.tune(model)

-
You can also inspect the results of the learning rate finder or just play around
with the parameters of the algorithm. This can be done by invoking the
:meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find` method. A typical example of this would look like:
@@ -239,7 +273,7 @@ with the parameters of the algorithm. This can be done by invoking the
    lr_finder = trainer.tuner.lr_find(model)

    # Results can be found in
-    lr_finder.results
+    print(lr_finder.results)

    # Plot with
    fig = lr_finder.plot(suggest=True)
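
Beyond printing and plotting the results shown in this hunk, the finder object also exposes a suggested value; a short sketch, assuming ``lr_finder.suggestion()`` is available in this release:

.. code-block:: python

    # Pick the learning rate suggested by the finder and apply it manually.
    new_lr = lr_finder.suggestion()
    model.hparams.lr = new_lr

    # Train with the chosen value.
    trainer.fit(model)
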
