# ptr_test.py: trains a PointerNet on the convex hull dataset.
from dataclasses import dataclass
from typing import Iterator, List, Optional
from random import shuffle
import dataclasses

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from tqdm import tqdm

from convex_hull_dataset import ConvexHullDataset, create_dataset
from ptr_network import PointerNet

@dataclass
class TrainingState:
    optimizer: Optimizer
    lr_scheduler: Optional[_LRScheduler]
    # number of loss.backward() calls before optimizer.step() is called
    mini_batch_size: int
    training_loss: List[float]
    validation_loss: List[float]
    # runs validation after every validation_rate epochs
    validation_rate: int
    # saves a checkpoint after every checkpoint_rate epochs
    checkpoint_rate: int
    epochs: int
    current_epoch: int

    @staticmethod
    def make_default(parameters: Iterator[nn.Parameter]) -> "TrainingState":
        optimizer = SGD(parameters, lr=0.01, momentum=0.7)
        lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.75)
        return TrainingState(
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            mini_batch_size=2,
            training_loss=[],
            validation_loss=[],
            validation_rate=1,
            checkpoint_rate=5,
            epochs=20,
            current_epoch=0,
        )
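

# A minimal sketch (illustrative, not called anywhere): because TrainingState
# is a dataclass, individual hyperparameters can be overridden with
# dataclasses.replace instead of duplicating make_default. The values here
# are assumptions, not tuned settings.
def make_long_run_state(parameters: Iterator[nn.Parameter]) -> TrainingState:
    state = TrainingState.make_default(parameters)
    return dataclasses.replace(state, epochs=50, checkpoint_rate=10)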

class Trainer:
    state: TrainingState
    model: PointerNet
    training_dataset: ConvexHullDataset
    validation_dataset: ConvexHullDataset
    name: str

    def __init__(
        self,
        model: PointerNet,
        training_dataset: ConvexHullDataset,
        validation_dataset: ConvexHullDataset,
        name: str = "pointer_net_trainer_2",
    ):
        self.model = model
        self.training_dataset = training_dataset
        self.validation_dataset = validation_dataset
        self.state = TrainingState.make_default(model.parameters())
        self.name = name

    def run_epoch(self, dataset: ConvexHullDataset, is_training: bool = True) -> float:
        epoch_loss = 0.0
        if is_training:
            # assumes the dataset behaves like a mutable sequence of samples
            shuffle(dataset)
            print(f"training with {len(dataset)} samples")
        else:
            print(f"validating with {len(dataset)} samples")
        with torch.set_grad_enabled(is_training):
            # iterate over the dataset passed in, so validation actually uses
            # the validation set rather than the training set
            for i, sample in tqdm(enumerate(dataset, 1), total=len(dataset)):
                output = self.model(
                    sample.points,
                    positions=sample.vertices,
                    teacher_forcing=True,
                )
                for loss in output.loss:
                    epoch_loss += loss.item()
                if is_training:
                    for loss in output.loss:
                        loss.backward(retain_graph=True)
                    # step and clear gradients once per mini-batch; zeroing on
                    # every sample would discard the gradients accumulated so far
                    if i % self.state.mini_batch_size == 0:
                        self.state.optimizer.step()
                        self.state.optimizer.zero_grad()
        print(f"epoch loss: {epoch_loss:8.3f}")
        return epoch_loss

    def train(self):
        try:
            for epoch in range(1, self.state.epochs + 1):
                print(f"Starting epoch: {epoch}")
                self.state.current_epoch = epoch
                epoch_loss = self.run_epoch(self.training_dataset, is_training=True)
                self.state.training_loss.append(epoch_loss)
                if epoch % self.state.validation_rate == 0:
                    # this is "smart training" (...)
                    # if the validation loss decreased, step the lr_scheduler
                    if not self.validate() and self.state.lr_scheduler is not None:
                        print("stepping lr")
                        self.state.lr_scheduler.step()
                if epoch % self.state.checkpoint_rate == 0:
                    self.checkpoint(f"epoch_{epoch}")
        except KeyboardInterrupt:
            self.checkpoint("interrupt")

    def validate(self) -> bool:
        """Returns True if the validation loss failed to improve on the
        previous run (or if there is not yet enough history to compare)."""
        validation_loss = self.run_epoch(self.validation_dataset, is_training=False)
        losses = self.state.validation_loss
        losses.append(validation_loss)
        return len(losses) < 2 or losses[-1] >= losses[-2]

    def checkpoint(self, note: str):
        # pickles the whole Trainer (model, optimizer, and loss history) in one file
        torch.save(self, f"{self.name}_{note}.pt")
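
    # A minimal sketch of the matching restore step (an assumption, not part
    # of the original file): since checkpoint() pickles the full Trainer,
    # loading it back needs weights_only=False on PyTorch >= 2.6, where
    # weights-only loading became the default.
    @staticmethod
    def load(path: str) -> "Trainer":
        return torch.load(path, weights_only=False)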


def run_new_training():
    training_dataset = create_dataset(2 ** 16)
    validation_dataset = create_dataset(2 ** 8)
    hidden_d = 128
    hidden_v = 64
    model = PointerNet(
        encoder_args={"hidden_d": hidden_d},
        decoder_args={"hidden_d": hidden_d, "hidden_v": hidden_v},
    )
    trainer = Trainer(model, training_dataset, validation_dataset)
    trainer.train()
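

# Entry-point guard so importing this module (e.g. to reuse Trainer elsewhere)
# does not kick off a training run.
if __name__ == "__main__":
    run_new_training()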