@@ -89,7 +89,7 @@ longer training time. Inspired by https://github.com/BlackHC/toma.
# Autoscale batch size
trainer = Trainer(auto_scale_batch_size=None | "power" | "binsearch")

- # find the batch size
+ # Find the batch size
trainer.tune(model)

Currently, this feature supports two modes ``'power'`` scaling and ``'binsearch'``
@@ -105,18 +105,48 @@ search for batch sizes larger than the size of the training dataset.

This feature expects that a ``batch_size`` field is either located as a model attribute
i.e. ``model.batch_size`` or as a field in your ``hparams`` i.e. ``model.hparams.batch_size``.
- The field should exist and will be overridden by the results of this algorithm.
- Additionally, your ``train_dataloader()`` method should depend on this field
- for this feature to work i.e.
+ Similarly, it can work with datamodules too. The field should exist and will be updated by
+ the results of this algorithm. Additionally, your ``train_dataloader()`` method should depend
+ on this field for this feature to work, i.e.

.. code-block:: python

- def train_dataloader(self):
-     return DataLoader(train_dataset, batch_size=self.batch_size | self.hparams.batch_size)
+ # using LightningModule
+ class LitModel(LightningModule):
+     def __init__(self, batch_size):
+         super().__init__()
+         self.save_hyperparameters()
+         # or
+         self.batch_size = batch_size
+
+     def train_dataloader(self):
+         return DataLoader(train_dataset, batch_size=self.batch_size | self.hparams.batch_size)
+
+
+ trainer = Trainer(...)
+ model = LitModel(batch_size=32)
+ trainer.tune(model)
+
+ # using LightningDataModule
+ class LitDataModule(LightningDataModule):
+     def __init__(self, batch_size):
+         super().__init__()
+         self.save_hyperparameters()
+         # or
+         self.batch_size = batch_size
+
+     def train_dataloader(self):
+         return DataLoader(train_dataset, batch_size=self.batch_size | self.hparams.batch_size)
+
+
+ trainer = Trainer(...)
+ model = MyModel()
+ datamodule = LitDataModule(batch_size=32)
+ trainer.tune(model, datamodule=datamodule)
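
Once ``trainer.tune(...)`` has run, the tuned value is written back into the ``batch_size``
field described above. As a minimal sketch (reusing the ``LitModel``/``LitDataModule``
examples from this section), the result can be read back afterwards:

.. code-block:: python

    trainer.tune(model)
    print(model.batch_size)  # or model.hparams.batch_size if save_hyperparameters() was used

    trainer.tune(model, datamodule=datamodule)
    print(datamodule.batch_size)  # or datamodule.hparams.batch_size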

.. warning::

- Due to these constraints, this features does *NOT* work when passing dataloaders directly
+ Due to the constraints listed above, this feature does *NOT* work when passing dataloaders directly
to ``.fit()``.
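
For illustration, a minimal sketch of the unsupported pattern this warning refers to
(``model`` and ``train_dataset`` are placeholders):

.. code-block:: python

    trainer = Trainer(auto_scale_batch_size="binsearch")
    # Not supported: this dataloader does not read from a ``batch_size`` field on the
    # model or datamodule, so the finder has no way to adjust its batch size
    trainer.fit(model, DataLoader(train_dataset, batch_size=32))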

The scaling algorithm has a number of parameters that the user can control by
@@ -178,7 +208,7 @@ rate, a `learning rate finder` can be used. As described in `this paper <https:/
a learning rate finder does a small run where the learning rate is increased
after each processed batch and the corresponding loss is logged. The result of
this is a ``lr`` vs. ``loss`` plot that can be used as guidance for choosing an optimal
- initial lr.
+ initial learning rate.

.. warning::

@@ -189,16 +219,21 @@ initial lr.
Using Lightning's built-in LR finder
====================================

- To enable the learning rate finder, your :doc:`lightning module <../common/lightning_module>` needs to have a ``learning_rate`` or ``lr`` property.
- Then, set ``Trainer(auto_lr_find=True)`` during trainer construction,
- and then call ``trainer.tune(model)`` to run the LR finder. The suggested ``learning_rate``
- will be written to the console and will be automatically set to your :doc:`lightning module <../common/lightning_module>`,
- which can be accessed via ``self.learning_rate`` or ``self.lr``.
+ To enable the learning rate finder, your :doc:`lightning module <../common/lightning_module>` needs to
+ have a ``learning_rate`` or ``lr`` attribute (or a corresponding field in your ``hparams``, i.e.
+ ``hparams.learning_rate`` or ``hparams.lr``). Then, set ``Trainer(auto_lr_find=True)``
+ during trainer construction and call ``trainer.tune(model)`` to run the LR finder.
+ The suggested ``learning_rate`` will be written to the console and automatically
+ set on your :doc:`lightning module <../common/lightning_module>`, where it can be accessed
+ via ``self.learning_rate`` or ``self.lr``.
+
+ .. seealso:: :ref:`trainer.tune <common/trainer:tune>`.

.. code-block:: python

class LitModel(LightningModule):
    def __init__(self, learning_rate):
+         super().__init__()
        self.learning_rate = learning_rate
        self.model = Model(...)

@@ -225,7 +260,6 @@ If your model is using an arbitrary value instead of ``self.lr`` or ``self.learn

trainer.tune(model)

-
You can also inspect the results of the learning rate finder or just play around
with the parameters of the algorithm. This can be done by invoking the
:meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find` method. A typical example of this would look like:
@@ -239,7 +273,7 @@ with the parameters of the algorithm. This can be done by invoking the
lr_finder = trainer.tuner.lr_find(model)

# Results can be found in
- lr_finder.results
+ print(lr_finder.results)

# Plot with
fig = lr_finder.plot(suggest=True)
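
To act on the result, a typical follow-up (a minimal sketch, assuming the returned
``lr_finder`` object exposes the usual ``suggestion()`` helper and that the model reads
its learning rate from ``hparams``) looks like:

.. code-block:: python

    # display the plot with the suggested point marked
    fig.show()

    # fetch the suggested learning rate and write it back to the attribute
    # that the model's optimizer configuration reads from
    new_lr = lr_finder.suggestion()
    model.hparams.lr = new_lr  # or model.learning_rate = new_lr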