34 changes: 23 additions & 11 deletions main.py
@@ -1,3 +1,7 @@
from datetime import datetime

from torch.utils.tensorboard import SummaryWriter

from pipnet.pipnet import PIPNet, get_network
from util.log import Log
import torch.nn as nn
@@ -56,6 +60,7 @@ def run_pipnet(args=None):
device = torch.device('cuda:'+str(device_ids[0]))
else:
device = torch.device('cpu')
device_ids.append(0)

# Log which device was actually used
print("Device used: ", device, "with id", device_ids, flush=True)
@@ -81,9 +86,12 @@ def run_pipnet(args=None):
classification_layer = classification_layer
)
net = net.to(device=device)
net = nn.DataParallel(net, device_ids = device_ids)

optimizer_net, optimizer_classifier, params_to_freeze, params_to_train, params_backbone = get_optimizer_nn(net, args)
global_epoch = 1
dirname = datetime.now().strftime("%Y_%m_%d__%H_%M_%S")
tb_writer = SummaryWriter(f"tensorboard/{dirname}")

# Initialize or load model
with torch.no_grad():
@@ -157,13 +165,16 @@ def run_pipnet(args=None):
print("\nPretrain Epoch", epoch, "with batch size", trainloader_pretraining.batch_size, flush=True)

# Pretrain prototypes
-train_info = train_pipnet(net, trainloader_pretraining, optimizer_net, optimizer_classifier, scheduler_net, None, criterion, epoch, args.epochs_pretrain, device, pretrain=True, finetune=False)
+train_info = train_pipnet(net, trainloader_pretraining, optimizer_net, optimizer_classifier, scheduler_net, None, criterion, epoch, args.epochs_pretrain, global_epoch, device, pretrain=True, finetune=False, tb_writer=tb_writer)
lrs_pretrain_net+=train_info['lrs_net']
# eval_info = eval_pipnet(net, testloader, global_epoch, device, log, tensorboard=tb_writer)

plt.clf()
plt.plot(lrs_pretrain_net)
plt.savefig(os.path.join(args.log_dir,'lr_pretrain_net.png'))
log.log_values('log_epoch_overview', epoch, "n.a.", "n.a.", "n.a.", "n.a.", "n.a.", "n.a.", "n.a.", train_info['loss'])

global_epoch += 1

if args.state_dict_dir_net == '':
net.eval()
torch.save({'model_state_dict': net.state_dict(), 'optimizer_net_state_dict': optimizer_net.state_dict()}, os.path.join(os.path.join(args.log_dir, 'checkpoints'), 'net_pretrained'))
@@ -239,11 +250,11 @@ def run_pipnet(args=None):
print("Classifier bias: ", net.module._classification.bias, flush=True)
torch.set_printoptions(profile="default")

-train_info = train_pipnet(net, trainloader, optimizer_net, optimizer_classifier, scheduler_net, scheduler_classifier, criterion, epoch, args.epochs, device, pretrain=False, finetune=finetune)
+train_info = train_pipnet(net, trainloader, optimizer_net, optimizer_classifier, scheduler_net, scheduler_classifier, criterion, epoch, args.epochs, global_epoch, device, pretrain=False, finetune=finetune, tb_writer=tb_writer)
lrs_net+=train_info['lrs_net']
lrs_classifier+=train_info['lrs_class']
# Evaluate model
-eval_info = eval_pipnet(net, testloader, epoch, device, log)
+eval_info = eval_pipnet(net, testloader, global_epoch, device, log, tensorboard=tb_writer)
log.log_values('log_epoch_overview', epoch, eval_info['top1_accuracy'], eval_info['top5_accuracy'], eval_info['almost_sim_nonzeros'], eval_info['local_size_all_classes'], eval_info['almost_nonzeros'], eval_info['num non-zero prototypes'], train_info['train_accuracy'], train_info['loss'])

with torch.no_grad():
@@ -261,6 +272,7 @@ def run_pipnet(args=None):
plt.clf()
plt.plot(lrs_classifier)
plt.savefig(os.path.join(args.log_dir,'lr_class.png'))
global_epoch += 1

net.eval()
torch.save({'model_state_dict': net.state_dict(), 'optimizer_net_state_dict': optimizer_net.state_dict(), 'optimizer_classifier_state_dict': optimizer_classifier.state_dict()}, os.path.join(os.path.join(args.log_dir, 'checkpoints'), 'net_trained_last'))
@@ -361,11 +373,11 @@ def run_pipnet(args=None):
tqdm_dir = os.path.join(args.log_dir,'tqdm.txt')
if not os.path.isdir(args.log_dir):
os.mkdir(args.log_dir)
-sys.stdout.close()
-sys.stderr.close()
-sys.stdout = open(print_dir, 'w')
-sys.stderr = open(tqdm_dir, 'w')

+# sys.stdout.close()
+# sys.stderr.close()
+# sys.stdout = open(print_dir, 'w')
+# sys.stderr = open(tqdm_dir, 'w')
run_pipnet(args)

sys.stdout.close()
⋮
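
The main.py changes above boil down to a small writer lifecycle: create one timestamped run directory per invocation, log against a monotonically increasing global_epoch, and reuse the writer across pretraining and training. A minimal sketch of that pattern (the tag and value are placeholders, not from the PR):

    from datetime import datetime
    from torch.utils.tensorboard import SummaryWriter

    # One subdirectory per run, mirroring the dirname scheme in run_pipnet
    dirname = datetime.now().strftime("%Y_%m_%d__%H_%M_%S")
    tb_writer = SummaryWriter(f"tensorboard/{dirname}")

    global_epoch = 1
    tb_writer.add_scalar("loss", 0.5, global_epoch)  # tag, scalar value, step
    tb_writer.flush()
    tb_writer.close()

Runs logged this way can then be compared side by side with: tensorboard --logdir tensorboard
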
4 changes: 2 additions & 2 deletions pipnet/pipnet.py
@@ -87,7 +87,7 @@ def get_network(num_classes: int, args: argparse.Namespace):
num_prototypes = first_add_on_layer_in_channels
print("Number of prototypes: ", num_prototypes, flush=True)
add_on_layers = nn.Sequential(
nn.Softmax(dim=1), # softmax over every prototype for each patch, so that for every location in the image, the sum over prototypes is 1
)
else:
num_prototypes = args.num_features
@@ -99,7 +99,7 @@ def get_network(num_classes: int, args: argparse.Namespace):
pool_layer = nn.Sequential(
nn.AdaptiveMaxPool2d(output_size=(1,1)), #outputs (bs, ps,1,1)
nn.Flatten() #outputs (bs, ps)
)

if args.bias:
classification_layer = NonNegLinear(num_prototypes, num_classes, bias=True)
⋮
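
The pipnet.py edits are whitespace-only, but the comments on these lines carry the shape contract worth spelling out: the add-on softmax normalizes over the prototype dimension at every spatial location, and the pool layer reduces each prototype map to a single presence score. A self-contained check of that contract, with made-up sizes:

    import torch
    import torch.nn as nn

    bs, ps, h, w = 2, 5, 7, 7                      # batch, prototypes, height, width
    features = torch.randn(bs, ps, h, w)
    add_on = nn.Softmax(dim=1)                     # per-location distribution over prototypes
    pool = nn.Sequential(nn.AdaptiveMaxPool2d(output_size=(1, 1)), nn.Flatten())

    proto_features = add_on(features)
    assert torch.allclose(proto_features.sum(dim=1), torch.ones(bs, h, w))
    pooled = pool(proto_features)                  # (bs, ps): max presence per prototype
    print(pooled.shape)                            # torch.Size([2, 5])
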
136 changes: 130 additions & 6 deletions pipnet/test.py
@@ -1,3 +1,8 @@
import itertools
from typing import List

from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import numpy as np
import torch
@@ -13,7 +18,8 @@ def eval_pipnet(net,
test_loader: DataLoader,
epoch,
device,
log: Log = None,
tensorboard: SummaryWriter|None = None,
progress_prefix: str = 'Eval Epoch'
) -> dict:

@@ -41,6 +47,7 @@ def eval_pipnet(net,
mininterval=5.,
ncols=0)
(xs, ys) = next(iter(test_loader))
inputs, results = [], []
# Iterate through the test set
for i, (xs, ys) in test_iter:
xs, ys = xs.to(device), ys.to(device)
@@ -81,14 +88,20 @@
y_preds += ys_pred_scores.detach().tolist()
y_trues += ys.detach().tolist()
y_preds_classes += ys_pred.detach().tolist()
inputs.append(xs)
results.append((pooled, out, ys_pred))


-del out
-del pooled
-del ys_pred
+# del out
+# del pooled
+# del ys_pred

print("PIP-Net abstained from a decision for", abstained.item(), "images", flush=True)
print("PIP-Net abstained from a decision for", abstained.item(), "images", flush=True)
info['abstained'] = abstained.item()
info['num non-zero prototypes'] = torch.gt(net.module._classification.weight,1e-3).any(dim=0).sum().item()
print("sparsity ratio: ", (torch.numel(net.module._classification.weight)-torch.count_nonzero(torch.nn.functional.relu(net.module._classification.weight-1e-3)).item()) / torch.numel(net.module._classification.weight), flush=True)
sparsity_ratio = (torch.numel(net.module._classification.weight)-torch.count_nonzero(torch.nn.functional.relu(net.module._classification.weight-1e-3)).item()) / torch.numel(net.module._classification.weight)
print("sparsity ratio: ", sparsity_ratio, flush=True)
info['sparsity ratio'] = sparsity_ratio
info['confusion_matrix'] = cm
info['test_accuracy'] = acc_from_cm(cm)
info['top1_accuracy'] = global_top1acc/len(test_loader.dataset)
@@ -97,6 +110,10 @@
info['local_size_all_classes'] = local_size_total / len(test_loader.dataset)
info['almost_nonzeros'] = global_anz/len(test_loader.dataset)

if tensorboard is not None:
log_to_tensorboard(info, net.module, inputs, results, classes=list(test_loader.dataset.class_to_idx.keys()), global_epoch=epoch, tb_writer=tensorboard)
tensorboard.flush()

if net.module._num_classes == 2:
tp = cm[0][0]
fn = cm[0][1]
@@ -128,6 +145,113 @@

return info

def log_to_tensorboard(info: dict, model_module, inp, preds, classes: List[str], global_epoch: int, tb_writer: SummaryWriter):
for key, value in info.items():
match key:
case 'confusion_matrix':
# figure = plot_confusion_matrix(value, class_names=classes)
# image = plot_to_image(figure)
# tb_writer.add_image(key, image, global_epoch)
pass
case _:
tb_writer.add_scalar(key, value, global_epoch)

if hasattr(model_module, '_classification'):
weights = model_module._classification.weight.flatten().cpu()
fig = plot_tensor(weights, samples=66)
tb_writer.add_figure('classification_weights', fig, global_epoch)

fig = plot_tensor(torch.relu(weights), samples=66)
tb_writer.add_figure('classification_non_neg', fig, global_epoch)

inp = inp[0]  # use the first captured batch as the example input
tb_writer.add_graph(model_module, inp)  # trace and log the model graph
# Re-assemble the per-batch (pooled, out, ys_pred) tuples into whole-test-set tensors
pooled, out, ys_pred = [torch.cat([pred[i] for pred in preds], dim=0) for i in range(3)]
tb_writer.add_figure('pooled', plot_batch(pooled), global_epoch)
tb_writer.add_figure('out', plot_batch(out), global_epoch)
tb_writer.add_histogram('ys_pred', ys_pred, global_epoch)

def plot_batch(tensor: torch.Tensor, samples = 8, title: str = None):
batch_size = tensor.shape[0]
if batch_size > samples:
indices = torch.linspace(0, batch_size - 1, steps=samples).long()
sample_tensor = tensor[indices]
else:
sample_tensor = tensor

stats = [
("Max over batch", torch.max(tensor, dim=0).values),
("Min over batch", torch.min(tensor, dim=0).values),
("Mean over batch", torch.mean(tensor, dim=0))
]

fig = plt.figure(figsize=(12, 10))
if title:
fig.suptitle(title, fontsize=16)

ax = plt.subplot(2, 2, 1)
for s in sample_tensor:
plot_tensor(s, subplot=True)
ax.set_title("Samples")

for i, (stat_title, stat_tensor) in enumerate(stats, start=2):
ax = plt.subplot(2, 2, i)
plot_tensor(stat_tensor, subplot=True)
ax.set_title(stat_title)

plt.tight_layout()
return fig


def plot_tensor(tensor: torch.Tensor, title: str = None, samples: int = 50, subplot: bool = False):
array = tensor.cpu().numpy()
array = smooth_array(array, samples)
if not subplot:
plt.figure()
plt.plot(array)
if title is not None:
plt.title(title)
return plt.gcf()

def smooth_array(array: np.ndarray, samples: int = 50):
window_size = max(1, len(array) // samples)  # at least 1, so short inputs don't yield an empty kernel
kernel = np.ones(window_size) / window_size
smoothed_tensor = np.convolve(array, kernel, mode='valid')
return smoothed_tensor

def plot_confusion_matrix(cm, class_names):
figure = plt.figure(figsize=(8, 8))
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Accent)
plt.title("Confusion matrix")
plt.colorbar()
tick_marks = np.arange(len(class_names))
plt.xticks(tick_marks, class_names, rotation=45)
plt.yticks(tick_marks, class_names)

cm = np.around(cm.astype('float') / cm.sum(axis=1)[:, np.newaxis], decimals=2)
threshold = cm.max() / 2.

for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
color = "white" if cm[i, j] > threshold else "black"
plt.text(j, i, cm[i, j], horizontalalignment="center", color=color)

plt.tight_layout()
plt.ylabel('True label')
plt.xlabel('Predicted label')

return figure

def plot_to_image(figure):
figure.canvas.draw()
data = np.frombuffer(figure.canvas.tostring_argb(), dtype=np.uint8)
data = data.reshape(figure.canvas.get_width_height()[::-1] + (4,)) # (H, W, 4)

data = data[:, :, [1, 2, 3]] # Drop the alpha channel (0th index)
tensor = torch.tensor(data)
tensor = tensor.float() / 255.0
plt.close(figure)
return tensor.permute(2, 0, 1)

def acc_from_cm(cm: np.ndarray) -> float:
"""
Compute the accuracy from the confusion matrix
⋮
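
Most of the new test.py code renders matplotlib figures and pushes them, along with histograms and scalars, through one SummaryWriter. A standalone sketch of the two less common calls, add_figure and add_histogram, with placeholder data and tags borrowed from log_to_tensorboard:

    import numpy as np
    import torch
    from matplotlib import pyplot as plt
    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter("tensorboard/demo")     # hypothetical run directory

    fig = plt.figure()
    plt.plot(np.sin(np.linspace(0.0, 6.28, 100)))
    writer.add_figure("classification_weights", fig, global_step=1)  # renders and closes fig

    writer.add_histogram("ys_pred", torch.randn(256), global_step=1)
    writer.flush()
    writer.close()

Note that add_figure rasterizes the figure internally, so a helper like plot_to_image is only needed when an explicit image tensor is wanted (e.g. for add_image).
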
20 changes: 15 additions & 5 deletions pipnet/train.py
@@ -3,9 +3,10 @@
import torch.nn.functional as F
import torch.optim
import torch.utils.data
import math
from torch.utils.tensorboard import SummaryWriter

-def train_pipnet(net, train_loader, optimizer_net, optimizer_classifier, scheduler_net, scheduler_classifier, criterion, epoch, nr_epochs, device, pretrain=False, finetune=False, progress_prefix: str = 'Train Epoch'):
+def train_pipnet(net, train_loader, optimizer_net, optimizer_classifier, scheduler_net, scheduler_classifier, criterion, epoch, nr_epochs, global_epoch, device, pretrain=False, finetune=False, progress_prefix: str = 'Train Epoch', tb_writer: SummaryWriter = None):

# Make sure the model is in train mode
net.train()
@@ -65,7 +66,7 @@ def train_pipnet(net, train_loader, optimizer_net, optimizer_classifier, schedul

# Perform a forward pass through the network
proto_features, pooled, out = net(torch.cat([xs1, xs2]))
-loss, acc = calculate_loss(proto_features, pooled, out, ys, align_pf_weight, t_weight, unif_weight, cl_weight, net.module._classification.normalization_multiplier, pretrain, finetune, criterion, train_iter, print=True, EPS=1e-8)
+loss, acc = calculate_loss(proto_features, pooled, out, ys, align_pf_weight, t_weight, unif_weight, cl_weight, net.module._classification.normalization_multiplier, pretrain, finetune, criterion, train_iter, global_epoch, print=True, EPS=1e-8, tb_writer=tb_writer)

# Compute the gradient
loss.backward()
@@ -99,7 +100,7 @@ def train_pipnet(net, train_loader, optimizer_net, optimizer_classifier, schedul

return train_info

-def calculate_loss(proto_features, pooled, out, ys1, align_pf_weight, t_weight, unif_weight, cl_weight, net_normalization_multiplier, pretrain, finetune, criterion, train_iter, print=True, EPS=1e-10):
+def calculate_loss(proto_features, pooled, out, ys1, align_pf_weight, t_weight, unif_weight, cl_weight, net_normalization_multiplier, pretrain, finetune, criterion, train_iter, global_epoch, print=True, EPS=1e-10, tb_writer=None):
ys = torch.cat([ys1,ys1])
pooled1, pooled2 = pooled.chunk(2)
pf1, pf2 = proto_features.chunk(2)
@@ -132,7 +133,16 @@ def calculate_loss(proto_features, pooled, out, ys1, align_pf_weight, t_weight,
ys_pred_max = torch.argmax(out, dim=1)
correct = torch.sum(torch.eq(ys_pred_max, ys))
acc = correct.item() / float(len(ys))
if print:
if tb_writer is not None:
scalars = {'loss': loss.item(), 'align_loss': a_loss_pf.item(), 'tanh_loss': tanh_loss.item(),
'acc': acc, 't_weight': t_weight, 'align_pf_weight': align_pf_weight, 'cl_weight': cl_weight,
'num_scores>0.1': torch.count_nonzero(torch.relu(pooled-0.1),dim=1).float().mean().item()}
if not pretrain:
scalars['class_loss'] = class_loss.item()
for key, value in scalars.items():
tb_writer.add_scalar(key, value, global_epoch)

with torch.no_grad():
if pretrain:
train_iter.set_postfix_str(
⋮
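
One design point in the train.py changes: calculate_loss runs once per batch, but every add_scalar call within an epoch uses the same global_epoch step, so all of an epoch's batches land on one x-coordinate in TensorBoard. If per-batch curves are wanted instead, a distinct step per batch does it; a hedged sketch (the helper name and arguments are ours, not part of the PR):

    from torch.utils.tensorboard import SummaryWriter

    def log_batch_scalars(tb_writer: SummaryWriter, scalars: dict,
                          epoch: int, batch_idx: int, batches_per_epoch: int) -> None:
        # One x-coordinate per batch keeps the within-epoch curve shape visible
        step = (epoch - 1) * batches_per_epoch + batch_idx
        for key, value in scalars.items():
            tb_writer.add_scalar(key, value, step)

    writer = SummaryWriter("tensorboard/example_run")  # placeholder run directory
    log_batch_scalars(writer, {"loss": 0.42, "acc": 0.88},
                      epoch=1, batch_idx=7, batches_per_epoch=100)
    writer.close()
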
2 changes: 1 addition & 1 deletion util/vis_pipnet.py
Original file line number Diff line number Diff line change
@@ -115,7 +115,7 @@ def visualize_topk(net, projectloader, num_classes, device, foldername, args: ar
max_per_prototype, max_idx_per_prototype = torch.max(softmaxes, dim=0)
max_per_prototype_h, max_idx_per_prototype_h = torch.max(max_per_prototype, dim=1)
max_per_prototype_w, max_idx_per_prototype_w = torch.max(max_per_prototype_h, dim=1) #shape (num_prototypes)

c_weight = torch.max(classification_weights[:,p]) #ignore prototypes that are not relevant to any class
if (c_weight > 1e-10) or ('pretrain' in foldername):
