|
15 | 15 | from torch import nn |
16 | 16 | from torch.utils.data import Dataset, DataLoader |
17 | 17 |
|
18 | | -from pytorch_lightning.core.step_result import TrainResult, EvalResult |
19 | 18 | from pytorch_lightning.core.lightning import LightningModule |
20 | 19 |
|
21 | 20 |
|
@@ -111,235 +110,6 @@ def training_epoch_end_scalar(self, outputs): |
111 | 110 | assert batch_out.grad_fn is None |
112 | 111 | assert isinstance(batch_out, torch.Tensor) |
113 | 112 |
|
114 | | - def training_step_no_default_callbacks_for_train_loop(self, batch, batch_idx): |
115 | | - """ |
116 | | - Early stop and checkpoint only on these values |
117 | | - """ |
118 | | - acc = self.step(batch, batch_idx) |
119 | | - result = TrainResult(minimize=acc) |
120 | | - assert 'early_step_on' not in result |
121 | | - assert 'checkpoint_on' in result |
122 | | - return result |
123 | | - |
124 | | - def training_step_no_callbacks_result_obj(self, batch, batch_idx): |
125 | | - """ |
126 | | - Early stop and checkpoint only on these values |
127 | | - """ |
128 | | - acc = self.step(batch, batch_idx) |
129 | | - result = TrainResult(minimize=acc, checkpoint_on=False) |
130 | | - assert 'early_step_on' not in result |
131 | | - assert 'checkpoint_on' not in result |
132 | | - return result |
133 | | - |
134 | | - def training_step_result_log_epoch_and_step_for_callbacks(self, batch, batch_idx): |
135 | | - """ |
136 | | - Early stop and checkpoint only on these values |
137 | | - """ |
138 | | - acc = self.step(batch, batch_idx) |
139 | | - |
140 | | - self.assert_backward = False |
141 | | - losses = [20, 19, 18, 10, 15, 14, 9, 11, 11, 20] |
142 | | - idx = self.current_epoch |
143 | | - loss = acc + losses[idx] |
144 | | - result = TrainResult(minimize=loss, early_stop_on=loss, checkpoint_on=loss) |
145 | | - return result |
146 | | - |
147 | | - def training_step_result_log_step_only(self, batch, batch_idx): |
148 | | - acc = self.step(batch, batch_idx) |
149 | | - result = TrainResult(minimize=acc) |
150 | | - |
151 | | - # step only metrics |
152 | | - result.log(f'step_log_and_pbar_acc1_b{batch_idx}', torch.tensor(11).type_as(acc), prog_bar=True) |
153 | | - result.log(f'step_log_acc2_b{batch_idx}', torch.tensor(12).type_as(acc)) |
154 | | - result.log(f'step_pbar_acc3_b{batch_idx}', torch.tensor(13).type_as(acc), logger=False, prog_bar=True) |
155 | | - |
156 | | - self.training_step_called = True |
157 | | - return result |
158 | | - |
159 | | - def training_step_result_log_epoch_only(self, batch, batch_idx): |
160 | | - acc = self.step(batch, batch_idx) |
161 | | - result = TrainResult(minimize=acc) |
162 | | - |
163 | | - result.log(f'epoch_log_and_pbar_acc1_e{self.current_epoch}', torch.tensor(14).type_as(acc), |
164 | | - on_epoch=True, prog_bar=True, on_step=False) |
165 | | - result.log(f'epoch_log_acc2_e{self.current_epoch}', torch.tensor(15).type_as(acc), |
166 | | - on_epoch=True, on_step=False) |
167 | | - result.log(f'epoch_pbar_acc3_e{self.current_epoch}', torch.tensor(16).type_as(acc), |
168 | | - on_epoch=True, logger=False, prog_bar=True, on_step=False) |
169 | | - |
170 | | - self.training_step_called = True |
171 | | - return result |
172 | | - |
173 | | - def training_step_result_log_epoch_and_step(self, batch, batch_idx): |
174 | | - acc = self.step(batch, batch_idx) |
175 | | - result = TrainResult(minimize=acc) |
176 | | - |
177 | | - val_1 = (5 + batch_idx) * (self.current_epoch + 1) |
178 | | - val_2 = (6 + batch_idx) * (self.current_epoch + 1) |
179 | | - val_3 = (7 + batch_idx) * (self.current_epoch + 1) |
180 | | - result.log('step_epoch_log_and_pbar_acc1', torch.tensor(val_1).type_as(acc), |
181 | | - on_epoch=True, prog_bar=True) |
182 | | - result.log('step_epoch_log_acc2', torch.tensor(val_2).type_as(acc), |
183 | | - on_epoch=True) |
184 | | - result.log('step_epoch_pbar_acc3', torch.tensor(val_3).type_as(acc), |
185 | | - on_epoch=True, logger=False, prog_bar=True) |
186 | | - |
187 | | - self.training_step_called = True |
188 | | - return result |
189 | | - |
190 | | - def training_epoch_end_return_for_log_epoch_and_step(self, result): |
191 | | - """ |
192 | | - There should be an array of scalars without graphs that are all 171 (4 of them) |
193 | | - """ |
194 | | - self.training_epoch_end_called = True |
195 | | - |
196 | | - if self.use_dp or self.use_ddp2: |
197 | | - pass |
198 | | - else: |
199 | | - # only saw 4 batches |
200 | | - assert isinstance(result, TrainResult) |
201 | | - |
202 | | - result.step_epoch_log_acc2 = result.step_epoch_log_acc2_step.prod() |
203 | | - result.step_epoch_pbar_acc3 = result.step_epoch_pbar_acc3_step.prod() |
204 | | - result.step_epoch_log_and_pbar_acc1 = result.step_epoch_log_and_pbar_acc1_step.prod() |
205 | | - result.minimize = result.minimize.mean() |
206 | | - result.checkpoint_on = result.checkpoint_on.mean() |
207 | | - |
208 | | - result.step_epoch_log_and_pbar_acc1_step = result.step_epoch_log_and_pbar_acc1_step.prod() |
209 | | - result.step_epoch_log_and_pbar_acc1_epoch = result.step_epoch_log_and_pbar_acc1_epoch.prod() |
210 | | - result.step_epoch_log_acc2_step = result.step_epoch_log_acc2_step.prod() |
211 | | - result.step_epoch_log_acc2_epoch = result.step_epoch_log_acc2_epoch.prod() |
212 | | - result.step_epoch_pbar_acc3_step = result.step_epoch_pbar_acc3_step.prod() |
213 | | - result.step_epoch_pbar_acc3_epoch = result.step_epoch_pbar_acc3_epoch.prod() |
214 | | - result.log('epoch_end_log_acc', torch.tensor(1212).type_as(result.step_epoch_log_acc2_epoch), |
215 | | - logger=True, on_epoch=True) |
216 | | - result.log('epoch_end_pbar_acc', torch.tensor(1213).type_as(result.step_epoch_log_acc2_epoch), |
217 | | - logger=False, prog_bar=True, on_epoch=True) |
218 | | - result.log('epoch_end_log_pbar_acc', torch.tensor(1214).type_as(result.step_epoch_log_acc2_epoch), |
219 | | - logger=True, prog_bar=True, on_epoch=True) |
220 | | - return result |
221 | | - |
222 | | - # -------------------------- |
223 | | - # EvalResults |
224 | | - # -------------------------- |
225 | | - def validation_step_result_callbacks(self, batch, batch_idx): |
226 | | - acc = self.step(batch, batch_idx) |
227 | | - |
228 | | - self.assert_backward = False |
229 | | - losses = [20, 19, 20, 21, 22, 23] |
230 | | - idx = self.current_epoch |
231 | | - loss = acc + losses[idx] |
232 | | - result = EvalResult(early_stop_on=loss, checkpoint_on=loss) |
233 | | - |
234 | | - self.validation_step_called = True |
235 | | - return result |
236 | | - |
237 | | - def validation_step_result_no_callbacks(self, batch, batch_idx): |
238 | | - acc = self.step(batch, batch_idx) |
239 | | - |
240 | | - self.assert_backward = False |
241 | | - losses = [20, 19, 20, 21, 22, 23, 50, 50, 50, 50, 50, 50] |
242 | | - idx = self.current_epoch |
243 | | - loss = acc + losses[idx] |
244 | | - |
245 | | - result = EvalResult(checkpoint_on=loss) |
246 | | - |
247 | | - self.validation_step_called = True |
248 | | - return result |
249 | | - |
250 | | - def validation_step_result_only_epoch_metrics(self, batch, batch_idx): |
251 | | - """ |
252 | | - Only track epoch level metrics |
253 | | - """ |
254 | | - acc = self.step(batch, batch_idx) |
255 | | - result = EvalResult(checkpoint_on=acc, early_stop_on=acc) |
256 | | - |
257 | | - # step only metrics |
258 | | - result.log('no_val_no_pbar', torch.tensor(11 + batch_idx).type_as(acc), prog_bar=False, logger=False) |
259 | | - result.log('val_step_log_acc', torch.tensor(11 + batch_idx).type_as(acc), prog_bar=False, logger=True) |
260 | | - result.log('val_step_log_pbar_acc', torch.tensor(12 + batch_idx).type_as(acc), prog_bar=True, logger=True) |
261 | | - result.log('val_step_pbar_acc', torch.tensor(13 + batch_idx).type_as(acc), prog_bar=True, logger=False) |
262 | | - |
263 | | - self.validation_step_called = True |
264 | | - return result |
265 | | - |
266 | | - def validation_step_result_only_step_metrics(self, batch, batch_idx): |
267 | | - """ |
268 | | - Only track epoch level metrics |
269 | | - """ |
270 | | - acc = self.step(batch, batch_idx) |
271 | | - result = EvalResult(checkpoint_on=acc, early_stop_on=acc) |
272 | | - |
273 | | - # step only metrics |
274 | | - result.log('no_val_no_pbar', torch.tensor(11 + batch_idx).type_as(acc), |
275 | | - prog_bar=False, logger=False, on_epoch=False, on_step=True) |
276 | | - result.log('val_step_log_acc', torch.tensor(11 + batch_idx).type_as(acc), |
277 | | - prog_bar=False, logger=True, on_epoch=False, on_step=True) |
278 | | - result.log('val_step_log_pbar_acc', torch.tensor(12 + batch_idx).type_as(acc), |
279 | | - prog_bar=True, logger=True, on_epoch=False, on_step=True) |
280 | | - result.log('val_step_pbar_acc', torch.tensor(13 + batch_idx).type_as(acc), |
281 | | - prog_bar=True, logger=False, on_epoch=False, on_step=True) |
282 | | - result.log('val_step_batch_idx', torch.tensor(batch_idx).type_as(acc), |
283 | | - prog_bar=True, logger=True, on_epoch=False, on_step=True) |
284 | | - |
285 | | - self.validation_step_called = True |
286 | | - return result |
287 | | - |
288 | | - def validation_step_result_epoch_step_metrics(self, batch, batch_idx): |
289 | | - """ |
290 | | - Only track epoch level metrics |
291 | | - """ |
292 | | - acc = self.step(batch, batch_idx) |
293 | | - result = EvalResult(checkpoint_on=acc, early_stop_on=acc) |
294 | | - |
295 | | - # step only metrics |
296 | | - result.log('no_val_no_pbar', torch.tensor(11 + batch_idx).type_as(acc), |
297 | | - prog_bar=False, logger=False, on_epoch=True, on_step=True) |
298 | | - result.log('val_step_log_acc', torch.tensor(11 + batch_idx).type_as(acc), |
299 | | - prog_bar=False, logger=True, on_epoch=True, on_step=True) |
300 | | - result.log('val_step_log_pbar_acc', torch.tensor(12 + batch_idx).type_as(acc), |
301 | | - prog_bar=True, logger=True, on_epoch=True, on_step=True) |
302 | | - result.log('val_step_pbar_acc', torch.tensor(13 + batch_idx).type_as(acc), |
303 | | - prog_bar=True, logger=False, on_epoch=True, on_step=True) |
304 | | - result.log('val_step_batch_idx', torch.tensor(batch_idx).type_as(acc), |
305 | | - prog_bar=True, logger=True, on_epoch=True, on_step=True) |
306 | | - |
307 | | - self.validation_step_called = True |
308 | | - return result |
309 | | - |
310 | | - def validation_step_for_epoch_end_result(self, batch, batch_idx): |
311 | | - """ |
312 | | - EvalResult flows to epoch end (without step_end) |
313 | | - """ |
314 | | - acc = self.step(batch, batch_idx) |
315 | | - result = EvalResult(checkpoint_on=acc, early_stop_on=acc) |
316 | | - |
317 | | - # step only metrics |
318 | | - result.log('val_step_metric', torch.tensor(batch_idx).type_as(acc), |
319 | | - prog_bar=True, logger=True, on_epoch=True, on_step=False) |
320 | | - result.log('batch_idx', torch.tensor(batch_idx).type_as(acc), |
321 | | - prog_bar=True, logger=True, on_epoch=True, on_step=False) |
322 | | - |
323 | | - self.validation_step_called = True |
324 | | - return result |
325 | | - |
326 | | - def validation_epoch_end_result(self, result): |
327 | | - self.validation_epoch_end_called = True |
328 | | - |
329 | | - if self.trainer.running_sanity_check: |
330 | | - assert len(result.batch_idx) == 2 |
331 | | - else: |
332 | | - assert len(result.batch_idx) == self.trainer.limit_val_batches |
333 | | - |
334 | | - expected_val = result.val_step_metric.sum() / len(result.batch_idx) |
335 | | - result.val_step_metric = result.val_step_metric.mean() |
336 | | - result.batch_idx = result.batch_idx.mean() |
337 | | - assert result.val_step_metric == expected_val |
338 | | - |
339 | | - result.log('val_epoch_end_metric', torch.tensor(189).type_as(result.val_step_metric), prog_bar=True) |
340 | | - |
341 | | - return result |
342 | | - |
343 | 113 | # -------------------------- |
344 | 114 | # dictionary returns |
345 | 115 | # -------------------------- |
|
0 commit comments