diff --git a/lib/core/function.py b/lib/core/function.py
index 474abb5f..38f2c00f 100644
--- a/lib/core/function.py
+++ b/lib/core/function.py
@@ -92,8 +92,7 @@ def validate(config, testloader, model, writer_dict, device):
     world_size = get_world_size()
     model.eval()
     ave_loss = AverageMeter()
-    confusion_matrix = np.zeros(
-        (config.DATASET.NUM_CLASSES, config.DATASET.NUM_CLASSES))
+    confusion_matrix = torch.zeros([config.DATASET.NUM_CLASSES, config.DATASET.NUM_CLASSES], device=device)
 
     with torch.no_grad():
         for _, batch in enumerate(testloader):
@@ -109,22 +108,18 @@ def validate(config, testloader, model, writer_dict, device):
             reduced_loss = reduce_tensor(loss)
             ave_loss.update(reduced_loss.item())
 
-            confusion_matrix += get_confusion_matrix(
-                label,
-                pred,
-                size,
-                config.DATASET.NUM_CLASSES,
-                config.TRAIN.IGNORE_LABEL)
+            confusion_matrix += get_confusion_matrix_gpu(label, pred, size,
+                                                         config.DATASET.NUM_CLASSES,
+                                                         config.TRAIN.IGNORE_LABEL,
+                                                         device)
 
-    confusion_matrix = torch.from_numpy(confusion_matrix).to(device)
     reduced_confusion_matrix = reduce_tensor(confusion_matrix)
 
-    confusion_matrix = reduced_confusion_matrix.cpu().numpy()
-    pos = confusion_matrix.sum(1)
-    res = confusion_matrix.sum(0)
-    tp = np.diag(confusion_matrix)
-    IoU_array = (tp / np.maximum(1.0, pos + res - tp))
-    mean_IoU = IoU_array.mean()
+    pos = torch.sum(reduced_confusion_matrix, 1)
+    res = torch.sum(reduced_confusion_matrix, 0)
+    tp = torch.diag(reduced_confusion_matrix)
+    IoU_array = tp / torch.maximum(torch.ones_like(tp), pos + res - tp)
+    mean_IoU = torch.mean(IoU_array)
     print_loss = ave_loss.average()/world_size
 
     if rank == 0:
diff --git a/lib/utils/utils.py b/lib/utils/utils.py
index 00a98713..f426ca63 100644
--- a/lib/utils/utils.py
+++ b/lib/utils/utils.py
@@ -139,8 +139,41 @@ def get_confusion_matrix(label, pred, size, num_class, ignore=-1):
                                  i_pred] = label_count[cur_index]
     return confusion_matrix
 
+def get_confusion_matrix_gpu(label, pred, size, num_class, ignore=-1, device=None):
+    """
+    GPU version of get_confusion_matrix.
+    The original version computes the confusion matrix in a numpy array,
+    which forces a GPU-to-CPU transfer for every batch and makes
+    validation in each epoch slow. This version computes the confusion
+    matrix with torch ops on the given device, eliminating the
+    GPU-to-CPU transfer.
+    """
+    n, c, h, w = pred.shape
+
+    output_gpu = pred.reshape([n, c, -1])  # [n, c, h, w] -> [n, c, h*w]
+    output_gpu = torch.transpose(output_gpu, 1, 2)
+    output_gpu = torch.reshape(output_gpu, [n, h, w, c])  # -> [n, h, w, c]
+
+    seg_pred_gpu = torch.argmax(output_gpu, dim=3)  # per-pixel predicted class
+    seg_gt_gpu = label[:, :h, :w]
+
+    ignore_index_gpu = ~seg_gt_gpu.eq(ignore)  # mask of pixels to keep
+
+    seg_gt_gpu = seg_gt_gpu[ignore_index_gpu]
+    seg_pred_gpu = seg_pred_gpu[ignore_index_gpu]
+
+    index_gpu = (seg_gt_gpu * num_class + seg_pred_gpu).int()  # flatten (gt, pred) pairs
+    label_count_gpu = torch.bincount(index_gpu)
+
+    label_count_gpu_len = label_count_gpu.shape[0]
+    confusion_matrix_gpu = torch.zeros([num_class * num_class], device=device)
+    confusion_matrix_gpu[0:label_count_gpu_len] += label_count_gpu  # pad counts to num_class**2 bins
+    confusion_matrix_gpu = torch.reshape(confusion_matrix_gpu, [num_class, num_class])
+
+    return confusion_matrix_gpu
+
 def adjust_learning_rate(optimizer, base_lr, max_iters,
         cur_iters, power=0.9):
     lr = base_lr*((1-float(cur_iters)/max_iters)**(power))
     optimizer.param_groups[0]['lr'] = lr
-    return lr
\ No newline at end of file
+    return lr
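
Not part of the patch: a minimal standalone sketch that sanity-checks the bincount-based confusion matrix used by the new `get_confusion_matrix_gpu` against the original numpy logic. The helper names (`confusion_matrix_numpy`, `confusion_matrix_torch`), the shapes, class count, and ignore value are illustrative assumptions; the sketch mirrors the patched logic rather than importing it from the repo.

```python
# Sanity-check sketch (assumptions: logits of shape [N, C, H, W],
# integer labels of shape [N, H, W], ignore label 255, 19 classes).
import numpy as np
import torch


def confusion_matrix_numpy(label, pred, num_class, ignore=-1):
    """Reference mirroring the original numpy logic."""
    seg_pred = np.argmax(pred, axis=1)              # [N, H, W] predicted classes
    valid = label != ignore                         # drop ignored pixels
    index = label[valid] * num_class + seg_pred[valid]
    count = np.bincount(index, minlength=num_class * num_class)
    return count.reshape(num_class, num_class).astype(np.float64)


def confusion_matrix_torch(label, pred, num_class, ignore=-1, device=None):
    """Same computation with torch ops, staying on `device` throughout."""
    seg_pred = torch.argmax(pred, dim=1)            # [N, H, W]
    valid = ~label.eq(ignore)
    index = (label[valid] * num_class + seg_pred[valid]).long()
    count = torch.bincount(index, minlength=num_class * num_class)
    return count.reshape(num_class, num_class).to(device=device, dtype=torch.float64)


if __name__ == "__main__":
    num_class, ignore = 19, 255
    rng = np.random.default_rng(0)
    label_np = rng.integers(0, num_class, size=(2, 64, 64))
    label_np[0, :8, :8] = ignore                    # some ignored pixels
    pred_np = rng.standard_normal((2, num_class, 64, 64)).astype(np.float32)

    cm_np = confusion_matrix_numpy(label_np, pred_np, num_class, ignore)
    cm_t = confusion_matrix_torch(torch.from_numpy(label_np),
                                  torch.from_numpy(pred_np),
                                  num_class, ignore)
    assert np.allclose(cm_np, cm_t.cpu().numpy())
    print("confusion matrices match,", int(cm_np.sum()), "pixels counted")
```

Running this on CPU is enough to check the arithmetic; when `device` is a CUDA device, the torch version keeps every intermediate on the GPU, which is the transfer the patch is meant to avoid.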