Skip to content

Commit aa37eee

Browse files
committed
added annotations
1 parent 4ecce8b commit aa37eee

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

scripts/download_sample.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
from __future__ import annotations
12
import argparse, random, shutil
23
from collections import defaultdict
34
from pathlib import Path
45
from PIL import Image
6+
from torch.utils.data import Dataset
57
from torchvision import datasets
68

79
def parse_args():
@@ -11,8 +13,13 @@ def parse_args():
1113
p.add_argument("--test-per-class", type=int, default=4, help="images/class for test")
1214
p.add_argument("--seed", type=int, default=42)
1315
return p.parse_args()
14-
15-
def pick_and_save(ds, split, n_per_class, out, rng):
16+
17+
def pick_and_save(ds: Dataset,
18+
split: str,
19+
n_per_class: int,
20+
out: Path,
21+
rng: random.Random
22+
) -> None:
1623
"""
1724
Selects up to n_per_class images per class from a dataset and saves them to disk.
1825

src/train.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,11 @@ def build_model(num_classes: int) -> nn.Module:
186186
criterion = nn.CrossEntropyLoss() # standard multi-class loss
187187

188188
# one epoch function
189-
scaler = GradScaler('cuda') if torch.cuda.is_available() else None # if using autocast('cuda') in epoch_loop
189+
if torch.cuda.is_available(): # if using autocast('cuda') in epoch_loop
190+
scaler = GradScaler('cuda')
191+
print("Detected CUDA, using autocast...")
192+
else:
193+
scaler = None
190194

191195
def epoch_loop (phase: str,
192196
model: nn.Module,

0 commit comments

Comments
 (0)