-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
52 lines (41 loc) · 1.94 KB
/
train.py
File metadata and controls
52 lines (41 loc) · 1.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os
import torch
from torch.utils.data import DataLoader
import argparse
from training_loop import get_model, train_model
from helpers.load import load_dataset
from helpers.plot import plot_prediction
############################
# Parse Arguments
############################
def parse_args(argv=None):
    """Parse the command-line options for a training run.

    Args:
        argv: List of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, i.e. the real command line.

    Returns:
        An ``argparse.Namespace`` with ``model_name``, ``dataset_size``,
        ``num_epochs``, ``batch_size_train`` and ``num_workers``.
    """
    parser = argparse.ArgumentParser(description="Train a model with given parameters.")
    parser.add_argument("--model_name", type=str, default="EfficientNetB7", help="Name of the model to use")
    parser.add_argument("--dataset_size", type=str, default="10k", help="Size of the dataset (e.g., '1k')")
    parser.add_argument("--num_epochs", type=int, default=20, help="Number of epochs")
    parser.add_argument("--batch_size_train", type=int, default=8, help="Batch size for training")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of DataLoader worker processes")
    return parser.parse_args(argv)


############################
# Parameters
############################
def build_data_dirs(dataset_size, root=None):
    """Return ``(image_dir, label_dir)`` for the dataset of the given size.

    Both directories live under ``<root>/data/<dataset_size>/``. ``root``
    defaults to the current working directory, so (as before) the script is
    expected to be launched from the directory that contains ``data/``.
    """
    if root is None:
        root = os.getcwd()
    data_dir = os.path.join(root, "data", dataset_size)
    return os.path.join(data_dir, "images"), os.path.join(data_dir, "labels")


############################
# Training Setup & Loop
############################
def main(argv=None):
    """Train the selected model, then plot its prediction for one sample image.

    Args:
        argv: Optional argument list forwarded to :func:`parse_args`;
            ``None`` uses the real command line.
    """
    args = parse_args(argv)
    model_name = args.model_name
    dataset_size = args.dataset_size
    num_epochs = args.num_epochs
    batch_size_train = args.batch_size_train
    num_workers = args.num_workers

    image_dir, label_dir = build_data_dirs(dataset_size)

    model = get_model(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    train_dataset, val_dataset = load_dataset(image_dir, label_dir, test_size=0.2)
    # Pinned host memory only helps host-to-GPU copies; on a CPU-only machine
    # requesting it merely triggers a DataLoader warning.
    pin_memory = device.type == "cuda"
    train_loader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory)
    # Validation uses twice the training batch size (original choice).
    val_loader = DataLoader(val_dataset, batch_size=2 * batch_size_train, shuffle=False,
                            num_workers=num_workers, pin_memory=pin_memory)

    print(f"Model name : {model_name} ---- Dataset size : {dataset_size} ---- Epochs : {num_epochs}")
    train_model(model, train_loader, val_loader, device, num_epochs, model_name)

    # Ensure the plot parent directory exists.
    # NOTE(review): unclear from here whether plot_prediction treats its last
    # argument as a directory or a file prefix; only the parent is created.
    plot_dir = os.path.join("outputs", "plots")
    os.makedirs(plot_dir, exist_ok=True)
    # NOTE(review): the plotted sample is hard-coded to index 0100, and the
    # image/label names differ ("image_" vs "labels_") — confirm they match
    # the dataset's file naming.
    plot_prediction(
        model,
        os.path.join(image_dir, "image_0100.png"),
        os.path.join(label_dir, "labels_0100.png"),
        device,
        os.path.join(plot_dir, model_name),
    )


# Guard the entry point: DataLoader workers (num_workers > 0) started with the
# "spawn" method (default on Windows/macOS) re-import the main module, which
# would otherwise re-run argument parsing and the whole training run.
if __name__ == "__main__":
    main()