-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_cifar.py
More file actions
40 lines (35 loc) · 1.27 KB
/
train_cifar.py
File metadata and controls
40 lines (35 loc) · 1.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import csv
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from model_cifar_resnet import CIFAR_ResNet
from pytorch_lightning.callbacks import TQDMProgressBar
from matplotlib_inline.backend_inline import set_matplotlib_formats
from pytorch_lightning.callbacks import LearningRateMonitor
from tqdm import tqdm
import torch
# augment can be set to "rand", "rect" or "none"
# Module-level augmentation mode: read by train_rand() (passed to the model)
# and embedded in the log directory / weight file paths.
augment = "rand"
def train_rand(max_perturb, p_box, max_box):
    """Train a CIFAR_ResNet with the module-level `augment` mode and save its weights.

    Trains for 90 epochs on a single GPU (DDP strategy, fp16), logging CSV
    metrics under ``logs/cifar/<augment>`` and writing the final state dict to
    ``weights/cifar/<augment>.pt``.

    Args:
        max_perturb: Maximum perturbation magnitude forwarded to CIFAR_ResNet.
        p_box: Box probability forwarded to CIFAR_ResNet
            (presumably the chance of applying a box augmentation — confirm
            against model_cifar_resnet).
        max_box: Maximum box size forwarded to CIFAR_ResNet.
    """
    # Local import keeps the fix self-contained; pathlib is stdlib.
    from pathlib import Path

    model = CIFAR_ResNet(
        learning_rate=0.01,
        batch_size=64,
        augment=augment,
        max_perturb=max_perturb,
        p_box=p_box,
        max_box=max_box,
    )
    # NOTE(review): savename does not encode max_perturb/p_box/max_box, so
    # successive calls with different hyperparameters overwrite each other's
    # logs and weights — kept as-is to preserve existing output paths.
    savename = f"cifar/{augment}"
    logger = CSVLogger("logs", name=savename)
    trainer = Trainer(
        gpus=1,
        strategy="ddp_find_unused_parameters_false",
        sync_batchnorm=True,
        max_epochs=90,
        callbacks=[
            LearningRateMonitor(logging_interval="step"),
            TQDMProgressBar(refresh_rate=10),
        ],
        logger=logger,
        precision=16,
    )
    trainer.fit(model)
    # Bug fix: torch.save does not create missing directories — without this,
    # saving to weights/cifar/<augment>.pt raises FileNotFoundError when the
    # nested directory does not already exist.
    out_path = Path("weights") / f"{savename}.pt"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), str(out_path))
train_rand(0.25, 0.1, 3)