1111import shutil
1212import pdb
1313
14+ import pytorch_lightning as ptl
15+ import torch
16+ from torch .nn import functional as F
17+ from torch .utils .data import DataLoader
18+ from torchvision .datasets import MNIST
19+
20+
21+ class CoolModel (ptl .LightningModule ):
22+
23+ def __init (self ):
24+ super (CoolModel , self ).__init__ ()
25+ # not the best model...
26+ self .l1 = torch .nn .Linear (28 * 28 , 10 )
27+
28+ def forward (self , x ):
29+ return torch .relu (self .l1 (x ))
30+
31+ def my_loss (self , y_hat , y ):
32+ return F .cross_entropy (y_hat , y )
33+
34+ def training_step (self , batch , batch_nb ):
35+ x , y = batch
36+ y_hat = self .forward (x )
37+ return {'tng_loss' : self .my_loss (y_hat , y )}
38+
39+ def validation_step (self , batch , batch_nb ):
40+ x , y = batch
41+ y_hat = self .forward (x )
42+ return {'val_loss' : self .my_loss (y_hat , y )}
43+
44+ def validation_end (self , outputs ):
45+ avg_loss = torch .stack ([x for x in outputs ['val_loss' ]]).mean ()
46+ return avg_loss
47+
48+ def configure_optimizers (self ):
49+ return [torch .optim .Adam (self .parameters (), lr = 0.02 )]
50+
51+ @ptl .data_loader
52+ def tng_dataloader (self ):
53+ return DataLoader (MNIST ('path/to/save' , train = True ), batch_size = 32 )
54+
55+ @ptl .data_loader
56+ def val_dataloader (self ):
57+ return DataLoader (MNIST ('path/to/save' , train = False ), batch_size = 32 )
58+
59+ @ptl .data_loader
60+ def test_dataloader (self ):
61+ return DataLoader (MNIST ('path/to/save' , train = False ), batch_size = 32 )
62+
1463
1564def get_model ():
1665 # set up model with these hyperparams
@@ -94,11 +143,9 @@ def run_prediction(dataloader, trained_model):
94143def main ():
95144
96145 save_dir = init_save_dir ()
97- model , hparams = get_model ()
98146
99147 # exp file to get meta
100148 exp = get_exp (False )
101- exp .argparse (hparams )
102149 exp .save ()
103150
104151 # exp file to get weights
@@ -113,6 +160,8 @@ def main():
113160 distributed_backend = 'dp' ,
114161 )
115162
163+ model = CoolModel ()
164+
116165 result = trainer .fit (model )
117166
118167 # correct result and ok accuracy
0 commit comments