Skip to content

Commit ce961d2

Browse files
gm
1 parent 47e90ba commit ce961d2

File tree

5 files changed

+64
-2
lines changed

5 files changed

+64
-2
lines changed
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
"""
2+
Trains a PyTorch image classification model using device-agnostic code.
3+
"""
4+
5+
import os
6+
import torch
7+
# import going_modular.going_modular.data_setup, going_modular.going_modular.engine, going_modular.going_modular.model_builder, going_modular.going_modular.utils
8+
from going_modular.going_modular import *
9+
from torchvision import transforms
10+
11+
# Setup hyperparameters
12+
NUM_EPOCHS = 5
13+
BATCH_SIZE = 32
14+
HIDDEN_UNITS = 10
15+
LEARNING_RATE = 0.001
16+
17+
# Setup directories
18+
train_dir = "data/pizza_steak_sushi/train"
19+
test_dir = "data/pizza_steak_sushi/test"
20+
21+
# Setup target device
22+
device = "cuda" if torch.cuda.is_available() else "cpu"
23+
24+
# Create transforms
25+
data_transform = transforms.Compose([
26+
transforms.Resize((64, 64)),
27+
transforms.ToTensor()
28+
])
29+
30+
# Create DataLoaders with help from data_setup.py
31+
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
32+
train_dir=train_dir,
33+
test_dir=test_dir,
34+
transform=data_transform,
35+
batch_size=BATCH_SIZE
36+
)
37+
38+
# Create model with help from model_builder.py
39+
model = model_builder.TinyVGG(
40+
input_shape=3,
41+
hidden_units=HIDDEN_UNITS,
42+
output_shape=len(class_names)
43+
).to(device)
44+
45+
# Set loss and optimizer
46+
loss_fn = torch.nn.CrossEntropyLoss()
47+
optimizer = torch.optim.Adam(model.parameters(),
48+
lr=LEARNING_RATE)
49+
50+
# Start training with help from engine.py
51+
engine.train(model=model,
52+
train_dataloader=train_dataloader,
53+
test_dataloader=test_dataloader,
54+
loss_fn=loss_fn,
55+
optimizer=optimizer,
56+
epochs=NUM_EPOCHS,
57+
device=device)
58+
59+
# Save the model with help from utils.py
60+
utils.save_model(model=model,
61+
target_dir="models",
62+
model_name="05_going_modular_script_mode_tinyvgg_model.pth")
Binary file not shown.
Binary file not shown.
Binary file not shown.

going_modular/going_modular/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
import os
66
import torch
7-
import data_setup, engine, model_builder, utils
8-
7+
# import going_modular.going_modular.data_setup, going_modular.going_modular.engine, going_modular.going_modular.model_builder, going_modular.going_modular.utils
8+
from going_modular.going_modular import *
99
from torchvision import transforms
1010

1111
# Setup hyperparameters

0 commit comments

Comments
 (0)