|
| 1 | +""" |
| 2 | +Trains a PyTorch image classification model using device-agnostic code. |
| 3 | +""" |
| 4 | + |
| 5 | +import os |
| 6 | +import torch |
| 7 | +# import going_modular.going_modular.data_setup, going_modular.going_modular.engine, going_modular.going_modular.model_builder, going_modular.going_modular.utils |
| 8 | +from going_modular.going_modular import * |
| 9 | +from torchvision import transforms |
| 10 | + |
| 11 | +# Setup hyperparameters |
| 12 | +NUM_EPOCHS = 5 |
| 13 | +BATCH_SIZE = 32 |
| 14 | +HIDDEN_UNITS = 10 |
| 15 | +LEARNING_RATE = 0.001 |
| 16 | + |
| 17 | +# Setup directories |
| 18 | +train_dir = "data/pizza_steak_sushi/train" |
| 19 | +test_dir = "data/pizza_steak_sushi/test" |
| 20 | + |
| 21 | +# Setup target device |
| 22 | +device = "cuda" if torch.cuda.is_available() else "cpu" |
| 23 | + |
| 24 | +# Create transforms |
| 25 | +data_transform = transforms.Compose([ |
| 26 | + transforms.Resize((64, 64)), |
| 27 | + transforms.ToTensor() |
| 28 | +]) |
| 29 | + |
| 30 | +# Create DataLoaders with help from data_setup.py |
| 31 | +train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders( |
| 32 | + train_dir=train_dir, |
| 33 | + test_dir=test_dir, |
| 34 | + transform=data_transform, |
| 35 | + batch_size=BATCH_SIZE |
| 36 | +) |
| 37 | + |
| 38 | +# Create model with help from model_builder.py |
| 39 | +model = model_builder.TinyVGG( |
| 40 | + input_shape=3, |
| 41 | + hidden_units=HIDDEN_UNITS, |
| 42 | + output_shape=len(class_names) |
| 43 | +).to(device) |
| 44 | + |
| 45 | +# Set loss and optimizer |
| 46 | +loss_fn = torch.nn.CrossEntropyLoss() |
| 47 | +optimizer = torch.optim.Adam(model.parameters(), |
| 48 | + lr=LEARNING_RATE) |
| 49 | + |
| 50 | +# Start training with help from engine.py |
| 51 | +engine.train(model=model, |
| 52 | + train_dataloader=train_dataloader, |
| 53 | + test_dataloader=test_dataloader, |
| 54 | + loss_fn=loss_fn, |
| 55 | + optimizer=optimizer, |
| 56 | + epochs=NUM_EPOCHS, |
| 57 | + device=device) |
| 58 | + |
| 59 | +# Save the model with help from utils.py |
| 60 | +utils.save_model(model=model, |
| 61 | + target_dir="models", |
| 62 | + model_name="05_going_modular_script_mode_tinyvgg_model.pth") |
0 commit comments