Commit d3b1a70

add image transforms, modify argparse to store_true and add tqdm to loops
1 parent 8d3d7e0 commit d3b1a70

File tree

2 files changed: 57 additions & 25 deletions

environment.yml

Lines changed: 1 addition & 0 deletions
```diff
@@ -18,6 +18,7 @@ dependencies:
   - pytest
   - ruff
   - scalene
+  - tqdm
   - pip:
     - torch
     - torchvision
```
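Note that tqdm is added to the conda dependency list (above the `- pip:` entry), not the pip section, so an existing environment built from this file should pick it up with `conda env update -f environment.yml`.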

main.py

Lines changed: 56 additions & 25 deletions
```diff
@@ -6,6 +6,8 @@
 import torch.nn as nn
 import wandb
 from torch.utils.data import DataLoader
+from torchvision import transforms
+from tqdm import tqdm
 
 from utils import MetricWrapper, createfolders, load_data, load_model
 
```

```diff
@@ -49,15 +51,13 @@ def main():
     )
     parser.add_argument(
         "--savemodel",
-        type=bool,
-        default=False,
+        action="store_true",
         help="Whether model should be saved or not.",
     )
 
     parser.add_argument(
         "--download-data",
-        type=bool,
-        default=False,
+        action="store_true",
         help="Whether the data should be downloaded or not. Might cause code to start a bit slowly.",
     )
 
```
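The switch away from `type=bool` fixes a real argparse pitfall: the `type` callable is applied to the raw command-line string, and any non-empty string (including `"False"`) is truthy, so the old flags could never be turned off from the CLI. A minimal standalone sketch of the difference (the `--broken` flag is illustrative, not from the repo):

```python
import argparse

parser = argparse.ArgumentParser()
# Broken: bool("False") is True, since any non-empty string is truthy.
parser.add_argument("--broken", type=bool, default=False)
# Correct: the flag is False unless it appears on the command line.
parser.add_argument("--savemodel", action="store_true")

args = parser.parse_args(["--broken", "False"])
print(args.broken)     # True, despite passing "False"
print(args.savemodel)  # False, flag not given
```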

```diff
@@ -126,17 +126,27 @@ def main():
 
     metrics = MetricWrapper(*args.metric)
 
+    augmentations = transforms.Compose(
+        [
+            transforms.Resize((16, 16)),  # At least for USPS
+            transforms.ToTensor(),
+        ]
+    )
+
     # Dataset
     traindata = load_data(
         args.dataset,
         train=True,
         data_path=args.datafolder,
         download=args.download_data,
+        transform=augmentations,
     )
     validata = load_data(
         args.dataset,
         train=False,
         data_path=args.datafolder,
+        download=args.download_data,
+        transform=augmentations,
     )
 
     # Find number of channels in the dataset
```
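The `Resize((16, 16))` matches the native resolution of USPS digits (16×16 grayscale), so other datasets get downscaled to the same shape before batching. A minimal sketch of what the pipeline does to one image, assuming `load_data` forwards `transform` to a torchvision-style dataset (the `Image.new` input below is a stand-in, not repo code):

```python
from PIL import Image
from torchvision import transforms

augmentations = transforms.Compose(
    [
        transforms.Resize((16, 16)),  # USPS digits are natively 16x16
        transforms.ToTensor(),        # PIL image -> float tensor in [0, 1], CHW layout
    ]
)

img = Image.new("L", (28, 28))  # stand-in for e.g. an MNIST-sized digit
x = augmentations(img)
print(x.shape)  # torch.Size([1, 16, 16])
```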
```diff
@@ -153,34 +163,53 @@ def main():
     )
     model.to(device)
 
-    trainloader = DataLoader(traindata,
-                             batch_size=args.batchsize,
-                             shuffle=True,
-                             pin_memory=True,
-                             drop_last=True)
-    valiloader = DataLoader(validata,
-                            batch_size=args.batchsize,
-                            shuffle=False,
-                            pin_memory=True)
+    trainloader = DataLoader(
+        traindata,
+        batch_size=args.batchsize,
+        shuffle=True,
+        pin_memory=True,
+        drop_last=True,
+    )
+    valiloader = DataLoader(
+        validata, batch_size=args.batchsize, shuffle=False, pin_memory=True
+    )
 
     criterion = nn.CrossEntropyLoss()
     optimizer = th.optim.Adam(model.parameters(), lr=args.learning_rate)
 
     # This allows us to load all the components without running the training loop
     if args.dry_run:
-        print("Dry run completed")
+        dry_run_loader = DataLoader(
+            traindata,
+            batch_size=1,
+            shuffle=True,
+            pin_memory=True,
+            drop_last=True,
+        )
+
+        for x, y in tqdm(dry_run_loader, desc="Dry run", total=1):
+            x, y = x.to(device), y.to(device)
+            pred = model.forward(x)
+
+            loss = criterion(y, pred)
+            loss.backward()
+
+            optimizer.step()
+            optimizer.zero_grad(set_to_none=True)
+
+            break
+
+        print("Dry run completed successfully.")
         exit(0)
 
-    wandb.init(project='',
-               tags=[])
+    wandb.init(project="", tags=[])
     wandb.watch(model)
 
     for epoch in range(args.epoch):
-
         # Training loop start
         trainingloss = []
         model.train()
-        for x, y in trainloader:
+        for x, y in tqdm(trainloader, desc="Training"):
             x, y = x.to(device), y.to(device)
             pred = model.forward(x)
 
```
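One thing worth flagging in the dry-run block (and in the loops below): `nn.CrossEntropyLoss` expects `(input, target)` order, so `criterion(y, pred)` passes the integer labels where the logits should go; with class-index targets this raises a runtime error rather than silently computing a wrong loss. A hedged sketch of the corrected step, reusing the names from this diff:

```python
# nn.CrossEntropyLoss takes (input, target) = (logits, labels).
pred = model(x)            # logits, shape (batch, num_classes)
loss = criterion(pred, y)  # not criterion(y, pred)
loss.backward()

optimizer.step()
optimizer.zero_grad(set_to_none=True)
```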
```diff
@@ -195,18 +224,20 @@ def main():
         # Eval loop start
         model.eval()
         with th.no_grad():
-            for x, y in valiloader:
+            for x, y in tqdm(valiloader, desc="Validation"):
                 x, y = x.to(device), y.to(device)
                 pred = model.forward(x)
                 loss = criterion(y, pred)
                 evalloss.append(loss.item())
 
-        wandb.log({
-            'Epoch': epoch,
-            'Train loss': np.mean(trainingloss),
-            'Evaluation Loss': np.mean(evalloss)
-        })
+        wandb.log(
+            {
+                "Epoch": epoch,
+                "Train loss": np.mean(trainingloss),
+                "Evaluation Loss": np.mean(evalloss),
+            }
+        )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
```
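A small idiom note on the unchanged lines: the loops call `model.forward(x)` directly, which bypasses `nn.Module.__call__` and therefore any registered forward hooks (gradient logging from `wandb.watch` uses backward hooks and should be unaffected, but forward-hook-based instrumentation would be skipped). A self-contained demonstration:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
model.register_forward_hook(lambda mod, inp, out: print("hook fired"))

x = torch.randn(1, 4)
model(x)          # prints "hook fired": __call__ runs hooks, then forward()
model.forward(x)  # silent: forward hooks are skipped
```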
