Skip to content

Commit b0d38d5

Browse files
committed
updated docs
1 parent 4562580 commit b0d38d5

File tree

2 files changed

+64
-16
lines changed

2 files changed

+64
-16
lines changed

pytorch_lightning/root_module/root_module.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@
88

99
class LightningModule(GradInformation, ModelIO, ModelHooks):
1010

11-
def __init__(self, hparams):
11+
def __init__(self):
1212
super(LightningModule, self).__init__()
13-
self.hparams = hparams
1413

1514
self.dtype = torch.FloatTensor
1615
self.exp_save_path = None
@@ -64,18 +63,6 @@ def configure_optimizers(self):
6463
"""
6564
raise NotImplementedError
6665

67-
def summarize(self):
68-
model_summary = ModelSummary(self)
69-
print(model_summary)
70-
71-
def freeze(self):
72-
for param in self.parameters():
73-
param.requires_grad = False
74-
75-
def unfreeze(self):
76-
for param in self.parameters():
77-
param.requires_grad = True
78-
7966
@data_loader
8067
def tng_dataloader(self):
8168
"""
@@ -128,5 +115,17 @@ def load_from_metrics(cls, weights_path, tags_csv, on_gpu, map_location=None):
128115
model.load_state_dict(checkpoint['state_dict'], strict=False)
129116
return model
130117

118+
def summarize(self):
119+
model_summary = ModelSummary(self)
120+
print(model_summary)
121+
122+
def freeze(self):
123+
for param in self.parameters():
124+
param.requires_grad = False
125+
126+
def unfreeze(self):
127+
for param in self.parameters():
128+
param.requires_grad = True
129+
131130

132131

tests/debug.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,55 @@
1111
import shutil
1212
import pdb
1313

14+
import pytorch_lightning as ptl
15+
import torch
16+
from torch.nn import functional as F
17+
from torch.utils.data import DataLoader
18+
from torchvision.datasets import MNIST
19+
20+
21+
class CoolModel(ptl.LightningModule):
22+
23+
def __init(self):
24+
super(CoolModel, self).__init__()
25+
# not the best model...
26+
self.l1 = torch.nn.Linear(28 * 28, 10)
27+
28+
def forward(self, x):
29+
return torch.relu(self.l1(x))
30+
31+
def my_loss(self, y_hat, y):
32+
return F.cross_entropy(y_hat, y)
33+
34+
def training_step(self, batch, batch_nb):
35+
x, y = batch
36+
y_hat = self.forward(x)
37+
return {'tng_loss': self.my_loss(y_hat, y)}
38+
39+
def validation_step(self, batch, batch_nb):
40+
x, y = batch
41+
y_hat = self.forward(x)
42+
return {'val_loss': self.my_loss(y_hat, y)}
43+
44+
def validation_end(self, outputs):
45+
avg_loss = torch.stack([x for x in outputs['val_loss']]).mean()
46+
return avg_loss
47+
48+
def configure_optimizers(self):
49+
return [torch.optim.Adam(self.parameters(), lr=0.02)]
50+
51+
@ptl.data_loader
52+
def tng_dataloader(self):
53+
return DataLoader(MNIST('path/to/save', train=True), batch_size=32)
54+
55+
@ptl.data_loader
56+
def val_dataloader(self):
57+
return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
58+
59+
@ptl.data_loader
60+
def test_dataloader(self):
61+
return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
62+
1463

1564
def get_model():
1665
# set up model with these hyperparams
@@ -94,11 +143,9 @@ def run_prediction(dataloader, trained_model):
94143
def main():
95144

96145
save_dir = init_save_dir()
97-
model, hparams = get_model()
98146

99147
# exp file to get meta
100148
exp = get_exp(False)
101-
exp.argparse(hparams)
102149
exp.save()
103150

104151
# exp file to get weights
@@ -113,6 +160,8 @@ def main():
113160
distributed_backend='dp',
114161
)
115162

163+
model = CoolModel()
164+
116165
result = trainer.fit(model)
117166

118167
# correct result and ok accuracy

0 commit comments

Comments
 (0)