-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path training.py
More file actions
33 lines (24 loc) · 779 Bytes
/
training.py
File metadata and controls
33 lines (24 loc) · 779 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
def train(model, dataloader, accumulation, criterion, optimizer, scheduler):
    """Train ``model`` for one epoch using gradient accumulation.

    Gradients from ``accumulation`` consecutive batches are summed before each
    optimizer update. Each batch loss is divided by ``accumulation``, so a full
    group's summed gradient equals the mean over that group.

    Args:
        model: Module to train; it is switched to training mode.
        dataloader: Iterable yielding ``(inputs, targets)`` pairs. Inputs are
            passed to ``model`` unchanged; no device transfer happens here.
        accumulation: Number of batches per optimizer update; must be >= 1.
        criterion: Callable ``criterion(outputs, targets)`` returning a scalar
            loss tensor.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Object with a ``step()`` method, advanced once per
            optimizer update, or ``None`` to skip scheduling.

    Returns:
        Average unscaled loss over all batches.

    Raises:
        ValueError: If ``accumulation`` < 1 or ``dataloader`` yields no batches.
    """
    if accumulation < 1:
        raise ValueError(f"accumulation must be >= 1, got {accumulation}")

    def _update():
        # Apply the accumulated gradients, clear them, then advance the schedule.
        optimizer.step()
        optimizer.zero_grad()
        if scheduler is not None:
            scheduler.step()

    model.train()
    optimizer.zero_grad()
    running_loss = 0.0
    num_batches = 0
    for num_batches, (inputs, targets) in enumerate(dataloader, start=1):
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        # Record the unscaled loss for reporting.
        running_loss += loss.item()
        # Scale so the gradients summed over a full group equal their mean.
        (loss / accumulation).backward()
        if num_batches % accumulation == 0:
            _update()

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches")
    # Flush a trailing partial group. Otherwise its gradients would be
    # discarded by the next call's zero_grad(). They are still scaled by
    # 1/accumulation, so this final update is proportionally smaller.
    if num_batches % accumulation != 0:
        _update()
    return running_loss / num_batches