-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathagent.py
More file actions
152 lines (124 loc) · 5.4 KB
/
agent.py
File metadata and controls
152 lines (124 loc) · 5.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import torch
import glob
import torch.optim as optim
from pathlib import Path
from dataloaders import PickleLoader
from models import ResNetUNet
import argparse
import tqdm
from typing import List
import numpy as np
from metric import Metric
class Agent:
    """Owns one ResNetUNet model and trains/evaluates it on a list of tasks.

    Each task is a ``Path`` to a pickled dataset consumed by ``PickleLoader``.
    The model output is treated as a per-pixel 3-way classification problem
    (output ``Bx3xNxN`` vs. integer target ``BxNxN``).
    """

    def __init__(self, agent_idx: int, task_list: List[Path], num_epochs: int):
        self.agent_idx = agent_idx
        self.task_list = task_list
        self.num_epochs = num_epochs
        # Prefer GPU when available; every tensor is moved to this device.
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        # 3 input channels -> 3 output classes.
        self.model = ResNetUNet(3, 3).to(self.device)

    def _validate_task_idx(self, task_idx: int) -> Path:
        """Return the dataset path for ``task_idx``.

        Raises:
            ValueError: if ``task_idx`` is outside ``[0, len(task_list))``.
                A raise (not ``assert``) so the check survives ``python -O``.
        """
        if not 0 <= task_idx < len(self.task_list):
            raise ValueError(
                f"task_idx out of range. Must be between 0 and "
                f"{len(self.task_list) - 1}. Got {task_idx}")
        return self.task_list[task_idx]

    def learn(self, task_idx: int) -> List[float]:
        """Train the model on one task for ``self.num_epochs`` epochs.

        Args:
            task_idx: index into ``self.task_list`` selecting the dataset.

        Returns:
            The per-mini-batch training losses, in the order they occurred.

        Raises:
            ValueError: if ``task_idx`` is out of range.
        """
        data_path = self._validate_task_idx(task_idx)
        optimizer = optim.Adam(self.model.parameters(), lr=1e-3)
        criterion = torch.nn.CrossEntropyLoss()
        # Load data
        dataset = PickleLoader(data_path)
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=2,
                                                 shuffle=True)
        # Ensure training-mode layers (batch-norm, dropout) are active, e.g.
        # after a preceding evaluate() switched the model to eval mode.
        self.model.train()
        batch_loss_list = []
        # Train loop
        pbar = tqdm.tqdm(range(self.num_epochs),
                         desc=f'Agent {self.agent_idx} {data_path.name} Epoch')
        for epoch in pbar:
            running_loss = 0.0
            total_loss = 0.0
            for i, data in enumerate(dataloader, 0):
                # Get input and target from the dataset
                inputs, targets = data['input'].to(
                    self.device), data['target'].to(self.device)
                # Zero the parameter gradients
                optimizer.zero_grad()
                # Forward pass
                outputs = self.model(inputs)
                # Output image is Bx3xNxN, target is BxNxN. Output and target
                # must be flattened across their spatial axes to be treated as
                # a per-pixel classification problem.
                B, C, _, _ = outputs.shape
                outputs = outputs.reshape(B, C, -1)
                targets = targets.reshape(B, -1)
                # Compute loss
                loss = criterion(outputs, targets)
                # Backward pass
                loss.backward()
                # Update the weights
                optimizer.step()
                running_loss += loss.item()
                total_loss += loss.item()
                # Record this mini-batch's loss. (The original appended
                # running_loss / len(data), which divided the *cumulative*
                # loss by the number of keys in the batch dict — not a
                # meaningful quantity.)
                batch_loss_list.append(loss.item())
                if i % 1000 == 999:  # Print every 1000 mini-batches
                    print(
                        f"[{epoch + 1}, {i + 1}] loss: {running_loss / 1000:.3f}"
                    )
                    running_loss = 0.0
            # Update tqdm bar with the epoch's mean loss
            pbar.set_postfix({'loss': total_loss / len(dataloader)})
        return batch_loss_list

    @torch.no_grad()
    def evaluate(self, task_idx: int) -> "Metric":
        """Evaluate the model on one task.

        Args:
            task_idx: index into ``self.task_list`` selecting the dataset.

        Returns:
            A ``Metric`` with the mean loss, mean per-image pixel accuracy
            and mean calibration error over the dataset.

        Raises:
            ValueError: if ``task_idx`` is out of range.
        """
        data_path = self._validate_task_idx(task_idx)
        criterion = torch.nn.CrossEntropyLoss()
        # Load data; batch_size=1 so each metric sample is one image.
        dataset = PickleLoader(data_path)
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=1,
                                                 shuffle=False)
        # Use inference behavior: running batch-norm stats, dropout disabled.
        self.model.eval()
        pixel_accuracies = []
        calibration_errors = []
        losses = []
        for data in dataloader:
            inputs, targets = data['input'].to(self.device), data['target'].to(
                self.device)
            outputs = self.model(inputs)
            B, C, _, _ = outputs.shape
            # Flatten outputs and targets to get rid of spatial dimensions.
            outputs = outputs.reshape(B, C, -1)
            targets = targets.reshape(B, -1)
            # Accuracy between 0 and 1 across all pixels of this image.
            correct_pixels = (outputs.argmax(dim=1) == targets).sum()
            total_pixels = targets.numel()
            pixel_accuracy = correct_pixels / total_pixels
            # Calibration error: one minus the mean predicted probability of
            # the correct class (0 == perfectly confident and correct).
            probability_of_correct_class = outputs.softmax(dim=1).gather(
                1, targets.unsqueeze(1)).squeeze(1)
            calibration_error = 1 - probability_of_correct_class.mean()
            pixel_accuracies.append(pixel_accuracy.item())
            calibration_errors.append(calibration_error.item())
            # Compute loss
            losses.append(criterion(outputs, targets).item())
        return Metric(loss=np.mean(losses),
                      pixel_average_accuracy=np.mean(pixel_accuracies),
                      calibration_error=np.mean(calibration_errors))

    def get_weights(self) -> torch.Tensor:
        """Return all model parameters flattened into one 1-D tensor."""
        return torch.cat([p.data.view(-1) for p in self.model.parameters()],
                         -1)

    def load_weights(self, weights: torch.Tensor) -> None:
        """Load a flat weight vector (as produced by ``get_weights``).

        Slices ``weights`` in ``model.parameters()`` order and copies each
        slice into the corresponding parameter in place.
        """
        beg = 0
        for p in self.model.parameters():
            p.data.copy_(weights[beg:beg + p.numel()].reshape(*p.data.shape))
            beg += p.numel()