diff --git a/main.py b/main.py
index 0ecf0a7..7825013 100644
--- a/main.py
+++ b/main.py
@@ -1,3 +1,7 @@
+from datetime import datetime
+
+from torch.utils.tensorboard import SummaryWriter
+
 from pipnet.pipnet import PIPNet, get_network
 from util.log import Log
 import torch.nn as nn
@@ -56,6 +60,7 @@ def run_pipnet(args=None):
         device = torch.device('cuda:'+str(device_ids[0]))
     else:
         device = torch.device('cpu')
+        device_ids.append(0)
 
     # Log which device was actually used
     print("Device used: ", device, "with id", device_ids, flush=True)
@@ -81,9 +86,12 @@ def run_pipnet(args=None):
                     classification_layer = classification_layer
                     )
     net = net.to(device=device)
-    net = nn.DataParallel(net, device_ids = device_ids)
+    net = nn.DataParallel(net, device_ids = device_ids)
-    optimizer_net, optimizer_classifier, params_to_freeze, params_to_train, params_backbone = get_optimizer_nn(net, args)
+    optimizer_net, optimizer_classifier, params_to_freeze, params_to_train, params_backbone = get_optimizer_nn(net, args)
+    global_epoch = 1
+    dirname = datetime.now().strftime("%Y_%m_%d__%H_%M_%S")
+    tb_writer = SummaryWriter(f"tensorboard/{dirname}")
 
     # Initialize or load model
     with torch.no_grad():
@@ -157,13 +165,16 @@ def run_pipnet(args=None):
             print("\nPretrain Epoch", epoch, "with batch size", trainloader_pretraining.batch_size, flush=True)
 
             # Pretrain prototypes
-            train_info = train_pipnet(net, trainloader_pretraining, optimizer_net, optimizer_classifier, scheduler_net, None, criterion, epoch, args.epochs_pretrain, device, pretrain=True, finetune=False)
+            train_info = train_pipnet(net, trainloader_pretraining, optimizer_net, optimizer_classifier, scheduler_net, None, criterion, epoch, args.epochs_pretrain, global_epoch, device, pretrain=True, finetune=False, tb_writer=tb_writer)
             lrs_pretrain_net+=train_info['lrs_net']
+            # eval_info = eval_pipnet(net, testloader, global_epoch, device, log, tensorboard=tb_writer)
+
             plt.clf()
             plt.plot(lrs_pretrain_net)
             plt.savefig(os.path.join(args.log_dir,'lr_pretrain_net.png'))
             log.log_values('log_epoch_overview', epoch, "n.a.", "n.a.", "n.a.", "n.a.", "n.a.", "n.a.", "n.a.", train_info['loss'])
-
+            global_epoch += 1
+
         if args.state_dict_dir_net == '':
             net.eval()
             torch.save({'model_state_dict': net.state_dict(), 'optimizer_net_state_dict': optimizer_net.state_dict()}, os.path.join(os.path.join(args.log_dir, 'checkpoints'), 'net_pretrained'))
@@ -239,11 +250,11 @@ def run_pipnet(args=None):
                 print("Classifier bias: ", net.module._classification.bias, flush=True)
                 torch.set_printoptions(profile="default")
 
-        train_info = train_pipnet(net, trainloader, optimizer_net, optimizer_classifier, scheduler_net, scheduler_classifier, criterion, epoch, args.epochs, device, pretrain=False, finetune=finetune)
+        train_info = train_pipnet(net, trainloader, optimizer_net, optimizer_classifier, scheduler_net, scheduler_classifier, criterion, epoch, args.epochs, global_epoch, device, pretrain=False, finetune=finetune, tb_writer=tb_writer)
         lrs_net+=train_info['lrs_net']
         lrs_classifier+=train_info['lrs_class']
 
         # Evaluate model
-        eval_info = eval_pipnet(net, testloader, epoch, device, log)
+        eval_info = eval_pipnet(net, testloader, global_epoch, device, log, tensorboard=tb_writer)
         log.log_values('log_epoch_overview', epoch, eval_info['top1_accuracy'], eval_info['top5_accuracy'], eval_info['almost_sim_nonzeros'], eval_info['local_size_all_classes'], eval_info['almost_nonzeros'], eval_info['num non-zero prototypes'], train_info['train_accuracy'], train_info['loss'])
 
         with torch.no_grad():
@@ -261,6 +272,7 @@ def run_pipnet(args=None):
         plt.clf()
         plt.plot(lrs_classifier)
         plt.savefig(os.path.join(args.log_dir,'lr_class.png'))
+        global_epoch += 1
 
         net.eval()
         torch.save({'model_state_dict': net.state_dict(), 'optimizer_net_state_dict': optimizer_net.state_dict(), 'optimizer_classifier_state_dict': optimizer_classifier.state_dict()}, os.path.join(os.path.join(args.log_dir, 'checkpoints'), 'net_trained_last'))
@@ -361,11 +373,11 @@ def run_pipnet(args=None):
     tqdm_dir = os.path.join(args.log_dir,'tqdm.txt')
     if not os.path.isdir(args.log_dir):
         os.mkdir(args.log_dir)
-
-    sys.stdout.close()
-    sys.stderr.close()
-    sys.stdout = open(print_dir, 'w')
-    sys.stderr = open(tqdm_dir, 'w')
+
+    # sys.stdout.close()
+    # sys.stderr.close()
+    # sys.stdout = open(print_dir, 'w')
+    # sys.stderr = open(tqdm_dir, 'w')
 
     run_pipnet(args)
 
     sys.stdout.close()
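For context, the writer setup added above gives each training run its own timestamped event directory under `tensorboard/`, so separate runs show up as separate entries in the TensorBoard UI. A minimal sketch of the same pattern (names mirror the diff; the scalar value is a dummy for illustration):

```python
from datetime import datetime

from torch.utils.tensorboard import SummaryWriter

# One run directory per invocation, e.g. tensorboard/2024_01_31__12_00_00
dirname = datetime.now().strftime("%Y_%m_%d__%H_%M_%S")
tb_writer = SummaryWriter(f"tensorboard/{dirname}")
tb_writer.add_scalar("loss", 0.5, global_step=1)  # dummy value, for illustration
tb_writer.flush()
# Inspect all runs with: tensorboard --logdir tensorboard
```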
diff --git a/pipnet/pipnet.py b/pipnet/pipnet.py
index 3722bd8..5f30b99 100644
--- a/pipnet/pipnet.py
+++ b/pipnet/pipnet.py
@@ -87,7 +87,7 @@ def get_network(num_classes: int, args: argparse.Namespace):
         num_prototypes = first_add_on_layer_in_channels
         print("Number of prototypes: ", num_prototypes, flush=True)
         add_on_layers = nn.Sequential(
-            nn.Softmax(dim=1), #softmax over every prototype for each patch, such that for every location in image, sum over prototypes is 1
+            nn.Softmax(dim=1), #softmax over every prototype for each patch, such that for every location in image, sum over prototypes is 1
     )
     else:
         num_prototypes = args.num_features
@@ -99,7 +99,7 @@ def get_network(num_classes: int, args: argparse.Namespace):
     pool_layer = nn.Sequential(
         nn.AdaptiveMaxPool2d(output_size=(1,1)), #outputs (bs, ps,1,1)
         nn.Flatten() #outputs (bs, ps)
-    )
+    )
 
     if args.bias:
         classification_layer = NonNegLinear(num_prototypes, num_classes, bias=True)
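The comment on `nn.Softmax(dim=1)` above can be verified directly: for a feature map of shape (batch, prototypes, H, W), softmax over dim 1 normalizes the prototype scores at each spatial location so they sum to 1. A small self-contained check (sizes are illustrative):

```python
import torch
import torch.nn as nn

x = torch.randn(2, 5, 7, 7)     # (batch, prototypes, H, W), illustrative sizes
probs = nn.Softmax(dim=1)(x)    # softmax across the prototype dimension
print(probs.sum(dim=1))         # all ones: per-location scores sum to 1
```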
diff --git a/pipnet/test.py b/pipnet/test.py
index 18d8238..71bf485 100644
--- a/pipnet/test.py
+++ b/pipnet/test.py
@@ -1,3 +1,8 @@
+import itertools
+from typing import List
+
+from matplotlib import pyplot as plt
+from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 import numpy as np
 import torch
@@ -13,7 +18,8 @@ def eval_pipnet(net,
         test_loader: DataLoader,
         epoch,
         device,
-        log: Log = None,
+        log: Log = None,
+        tensorboard: SummaryWriter|None = None,
         progress_prefix: str = 'Eval Epoch'
         ) -> dict:
@@ -41,6 +47,7 @@ def eval_pipnet(net,
                      mininterval=5.,
                      ncols=0)
     (xs, ys) = next(iter(test_loader))
+    inputs, results = [], []
     # Iterate through the test set
     for i, (xs, ys) in test_iter:
         xs, ys = xs.to(device), ys.to(device)
@@ -81,14 +88,20 @@ def eval_pipnet(net,
             y_preds += ys_pred_scores.detach().tolist()
             y_trues += ys.detach().tolist()
             y_preds_classes += ys_pred.detach().tolist()
+            inputs.append(xs)
+            results.append((pooled, out, ys_pred))
+
-            del out
-            del pooled
-            del ys_pred
+            # del out
+            # del pooled
+            # del ys_pred
 
-    print("PIP-Net abstained from a decision for", abstained.item(), "images", flush=True)
+    print("PIP-Net abstained from a decision for", abstained.item(), "images", flush=True)
+    info['abstained'] = abstained.item()
     info['num non-zero prototypes'] = torch.gt(net.module._classification.weight,1e-3).any(dim=0).sum().item()
-    print("sparsity ratio: ", (torch.numel(net.module._classification.weight)-torch.count_nonzero(torch.nn.functional.relu(net.module._classification.weight-1e-3)).item()) / torch.numel(net.module._classification.weight), flush=True)
+    sparsity_ratio = (torch.numel(net.module._classification.weight)-torch.count_nonzero(torch.nn.functional.relu(net.module._classification.weight-1e-3)).item()) / torch.numel(net.module._classification.weight)
+    print("sparsity ratio: ", sparsity_ratio, flush=True)
+    info['sparsity ratio'] = sparsity_ratio
     info['confusion_matrix'] = cm
     info['test_accuracy'] = acc_from_cm(cm)
     info['top1_accuracy'] = global_top1acc/len(test_loader.dataset)
@@ -97,6 +110,10 @@ def eval_pipnet(net,
     info['local_size_all_classes'] = local_size_total / len(test_loader.dataset)
     info['almost_nonzeros'] = global_anz/len(test_loader.dataset)
 
+    if tensorboard is not None:
+        log_to_tensorboard(info, net.module, inputs, results, classes=test_loader.dataset.class_to_idx.values(), global_epoch=epoch, tb_writer=tensorboard)
+        tensorboard.flush()
+
     if net.module._num_classes == 2:
         tp = cm[0][0]
         fn = cm[0][1]
@@ -128,6 +145,113 @@ def eval_pipnet(net,
     return info
 
 
+def log_to_tensorboard(info: dict, model_module, inp, preds, classes: List[str], global_epoch: int, tb_writer: SummaryWriter):
+    for key, value in info.items():
+        match key:
+            case 'confusion_matrix':
+                # figure = plot_confusion_matrix(value, class_names=classes)
+                # image = plot_to_image(figure)
+                # tb_writer.add_image(key, image, global_epoch)
+                pass
+            case _:
+                tb_writer.add_scalar(key, value, global_epoch)
+
+    if hasattr(model_module, '_classification'):
+        weights = model_module._classification.weight.flatten().cpu()
+        fig = plot_tensor(weights, samples=66)
+        tb_writer.add_figure('classification_weights', fig, global_epoch)
+
+        fig = plot_tensor(torch.relu(weights), samples=66)
+        tb_writer.add_figure('classification_non_neg', fig, global_epoch)
+
+    inp = inp[0]
+    tb_writer.add_graph(model_module, inp)
+    pooled, out, ys_pred = [torch.cat([pred[i] for pred in preds], dim=0) for i in range(3)]
+    tb_writer.add_figure('pooled', plot_batch(pooled), global_epoch)
+    tb_writer.add_figure('out', plot_batch(out), global_epoch)
+    tb_writer.add_histogram('ys_pred', ys_pred, global_epoch)
+
+def plot_batch(tensor: torch.Tensor, samples = 8, title: str = None):
+    batch_size = tensor.shape[0]
+    if batch_size > samples:
+        indices = torch.linspace(0, batch_size - 1, steps=samples).long()
+        sample_tensor = tensor[indices]
+    else:
+        sample_tensor = tensor
+
+    stats = [
+        ("Max over batch", torch.max(tensor, dim=0).values),
+        ("Min over batch", torch.min(tensor, dim=0).values),
+        ("Mean over batch", torch.mean(tensor, dim=0))
+    ]
+
+    fig = plt.figure(figsize=(12, 10))
+    if title:
+        fig.suptitle(title, fontsize=16)
+
+    ax = plt.subplot(2, 2, 1)
+    for s in sample_tensor:
+        plot_tensor(s, subplot=True)
+    ax.set_title("Samples")
+
+    for i, (stat_title, stat_tensor) in enumerate(stats, start=2):
+        ax = plt.subplot(2, 2, i)
+        plot_tensor(stat_tensor, subplot=True)
+        ax.set_title(stat_title)
+
+    plt.tight_layout()
+    return fig
+
+
+def plot_tensor(tensor: torch.Tensor, title: str = None, samples: int = 50, subplot: bool = False):
+    array = tensor.cpu().numpy()
+    array = smooth_array(array, samples)
+    if not subplot:
+        plt.figure()
+    plt.plot(array)
+    if title is not None:
+        plt.title(title)
+    return plt.gcf()
+
+def smooth_array(array: np.ndarray, samples: int = 50):
+    window_size = max(1, len(array) // samples)  # guard: arrays shorter than `samples` would otherwise yield an empty kernel
+    kernel = np.ones(window_size) / window_size
+    smoothed_tensor = np.convolve(array, kernel, mode='valid')
+    return smoothed_tensor
+
+def plot_confusion_matrix(cm, class_names):
+    figure = plt.figure(figsize=(8, 8))
+    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Accent)
+    plt.title("Confusion matrix")
+    plt.colorbar()
+    tick_marks = np.arange(len(class_names))
+    plt.xticks(tick_marks, class_names, rotation=45)
+    plt.yticks(tick_marks, class_names)
+
+    cm = np.around(cm.astype('float') / cm.sum(axis=1)[:, np.newaxis], decimals=2)
+    threshold = cm.max() / 2.
+
+    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
+        color = "white" if cm[i, j] > threshold else "black"
+        plt.text(j, i, cm[i, j], horizontalalignment="center", color=color)
+
+    plt.tight_layout()
+    plt.ylabel('True label')
+    plt.xlabel('Predicted label')
+
+    return figure
+
+def plot_to_image(figure):
+    figure.canvas.draw()
+    data = np.frombuffer(figure.canvas.tostring_argb(), dtype=np.uint8)
+    data = data.reshape(figure.canvas.get_width_height()[::-1] + (4,))  # (H, W, 4)
+
+    data = data[:, :, [1, 2, 3]]  # Drop the alpha channel (0th index)
+    tensor = torch.tensor(data)
+    tensor = tensor.float() / 255.0
+    plt.close(figure)
+    return tensor.permute(2, 0, 1)
+
 def acc_from_cm(cm: np.ndarray) -> float:
     """
     Compute the accuracy from the confusion matrix
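The confusion-matrix branch in `log_to_tensorboard` is left commented out, but the helpers above already fit together if it is re-enabled. A sketch of that path, assuming `classes` holds the class names; note the call site currently passes `test_loader.dataset.class_to_idx.values()` (indices), so `.keys()` may be what is actually wanted there:

```python
# Sketch only: wiring plot_confusion_matrix -> plot_to_image -> add_image,
# using the helpers defined in pipnet/test.py above.
figure = plot_confusion_matrix(info['confusion_matrix'], class_names=list(classes))
image = plot_to_image(figure)  # (3, H, W) float tensor in [0, 1]
tb_writer.add_image('confusion_matrix', image, global_epoch)
```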
diff --git a/pipnet/train.py b/pipnet/train.py
index bb013ef..5cc385f 100644
--- a/pipnet/train.py
+++ b/pipnet/train.py
@@ -3,9 +3,10 @@
 import torch.nn.functional as F
 import torch.optim
 import torch.utils.data
-import math
+from torch.utils.tensorboard import SummaryWriter
 
-def train_pipnet(net, train_loader, optimizer_net, optimizer_classifier, scheduler_net, scheduler_classifier, criterion, epoch, nr_epochs, device, pretrain=False, finetune=False, progress_prefix: str = 'Train Epoch'):
+
+def train_pipnet(net, train_loader, optimizer_net, optimizer_classifier, scheduler_net, scheduler_classifier, criterion, epoch, nr_epochs, global_epoch, device, pretrain=False, finetune=False, progress_prefix: str = 'Train Epoch', tb_writer: SummaryWriter = None):
 
     # Make sure the model is in train mode
     net.train()
@@ -65,7 +66,7 @@ def train_pipnet(net, train_loader, optimizer_net, optimizer_classifier, schedul
 
         # Perform a forward pass through the network
         proto_features, pooled, out = net(torch.cat([xs1, xs2]))
-        loss, acc = calculate_loss(proto_features, pooled, out, ys, align_pf_weight, t_weight, unif_weight, cl_weight, net.module._classification.normalization_multiplier, pretrain, finetune, criterion, train_iter, print=True, EPS=1e-8)
+        loss, acc = calculate_loss(proto_features, pooled, out, ys, align_pf_weight, t_weight, unif_weight, cl_weight, net.module._classification.normalization_multiplier, pretrain, finetune, criterion, train_iter, global_epoch, print=True, EPS=1e-8, tb_writer=tb_writer)
 
         # Compute the gradient
         loss.backward()
@@ -99,7 +100,7 @@ def train_pipnet(net, train_loader, optimizer_net, optimizer_classifier, schedul
 
     return train_info
 
-def calculate_loss(proto_features, pooled, out, ys1, align_pf_weight, t_weight, unif_weight, cl_weight, net_normalization_multiplier, pretrain, finetune, criterion, train_iter, print=True, EPS=1e-10):
+def calculate_loss(proto_features, pooled, out, ys1, align_pf_weight, t_weight, unif_weight, cl_weight, net_normalization_multiplier, pretrain, finetune, criterion, train_iter, global_epoch, print=True, EPS=1e-10, tb_writer=None):
     ys = torch.cat([ys1,ys1])
     pooled1, pooled2 = pooled.chunk(2)
     pf1, pf2 = proto_features.chunk(2)
@@ -132,7 +133,16 @@ def calculate_loss(proto_features, pooled, out, ys1, align_pf_weight, t_weight,
     ys_pred_max = torch.argmax(out, dim=1)
     correct = torch.sum(torch.eq(ys_pred_max, ys))
     acc = correct.item() / float(len(ys))
-    if print:
+    if print:
+        if tb_writer is not None:
+            scalars = {'loss': loss.item(), 'align_loss': a_loss_pf.item(), 'tanh_loss': tanh_loss.item(),
+                       'acc': acc, 't_weight': t_weight, 'align_pf_weight': align_pf_weight, 'cl_weight': cl_weight,
+                       'num_scores>0.1': torch.count_nonzero(torch.relu(pooled-0.1),dim=1).float().mean().item()}
+            if not pretrain:
+                scalars['class_loss'] = class_loss.item()
+            for key, value in scalars.items():
+                tb_writer.add_scalar(key, value, global_epoch)
+
         with torch.no_grad():
             if pretrain:
                 train_iter.set_postfix_str(
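One thing to note about the per-batch logging added above: every batch writes the same flat tags (`loss`, `acc`, ...) at the same `global_epoch` step, so TensorBoard receives many points per step for each tag. If per-phase grouping or a monotone step axis is wanted, a common alternative is namespaced tags with a per-batch step counter; a sketch, not part of the diff, where `step` is a hypothetical batch counter that would need to be threaded through like `global_epoch` is:

```python
# Alternative sketch: namespace tags per phase ("train/loss", "train/acc", ...)
# and log against a per-batch step so each curve has one point per step.
for key, value in scalars.items():
    tb_writer.add_scalar(f"train/{key}", value, step)
```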
diff --git a/util/vis_pipnet.py b/util/vis_pipnet.py
index 055db1a..1248989 100644
--- a/util/vis_pipnet.py
+++ b/util/vis_pipnet.py
@@ -115,7 +115,7 @@ def visualize_topk(net, projectloader, num_classes, device, foldername, args: ar
         max_per_prototype, max_idx_per_prototype = torch.max(softmaxes, dim=0)
         max_per_prototype_h, max_idx_per_prototype_h = torch.max(max_per_prototype, dim=1)
         max_per_prototype_w, max_idx_per_prototype_w = torch.max(max_per_prototype_h, dim=1) #shape (num_prototypes)
-
+
         c_weight = torch.max(classification_weights[:,p]) #ignore prototypes that are not relevant to any class
         if (c_weight > 1e-10) or ('pretrain' in foldername):
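The chained `torch.max` reductions in the `visualize_topk` context above compute a per-prototype maximum over the remaining axes one dimension at a time, keeping the argmax indices at each stage. When only the maxima are needed, the same values can be obtained in one call; a small check under illustrative shapes:

```python
import torch

softmaxes = torch.randn(16, 5, 7, 7).softmax(dim=1)  # (batch, prototypes, H, W), illustrative
# Step-by-step, as in the diff (also yields the argmax indices at each stage):
max_per_prototype, _ = torch.max(softmaxes, dim=0)    # (prototypes, H, W)
max_h, _ = torch.max(max_per_prototype, dim=1)        # (prototypes, W)
max_w, _ = torch.max(max_h, dim=1)                    # (prototypes,)
# Equivalent values in one reduction (no indices):
assert torch.allclose(max_w, softmaxes.amax(dim=(0, 2, 3)))
```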