diff --git a/README.md b/README.md index 53c8417..5198790 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ If you have any questions, please let us know: ## Instructions This code has been tested on -- Python 3.8.5, PyTorch 1.7.1, CUDA 11.2, gcc 9.3.0, GeForce RTX 3090/GeForce GTX 1080Ti +- Python 3.10, PyTorch 2.0.1, CUDA 11.7, gcc 9.3.0, GeForce RTX 3090/GeForce GTX 1080Ti **Note**: We observe random data loader crashes due to memory issues, if you observe similar issues, please consider reducing the number of workers or increasing CPU RAM. We now released a sparse convolution-based Predator, have a look [here](https://github.com/ShengyuH/OverlapPredator.Mink.git)! diff --git a/lib/benchmark.py b/lib/benchmark.py index a827eef..19acc5c 100755 --- a/lib/benchmark.py +++ b/lib/benchmark.py @@ -6,10 +6,11 @@ """ import numpy as np -import os,sys,glob,torch,math +import os, sys, glob, torch, math from collections import defaultdict import nibabel.quaternions as nq + def rotation_error(R1, R2): """ Torch batch implementation of the rotation error between the estimated and the ground truth rotatiom matrix. @@ -23,7 +24,7 @@ def rotation_error(R1, R2): ae (torch tensor): Rotation error in angular degreees [b,1] """ - R_ = torch.matmul(R1.transpose(1,2), R2) + R_ = torch.matmul(R1.transpose(1, 2), R2) e = torch.stack([(torch.trace(R_[_, :, :]) - 1) / 2 for _ in range(R_.shape[0])], dim=0).unsqueeze(1) # Clamp the errors to the valid range (otherwise torch.acos() is nan) @@ -49,7 +50,8 @@ def translation_error(t1, t2): te (torch tensor): translation error in meters [b,1] """ - return torch.norm(t1-t2, dim=(1, 2)) + return torch.norm(t1 - t2, dim=(1, 2)) + def computeTransformationErr(trans, info): """ @@ -63,15 +65,16 @@ def computeTransformationErr(trans, info): Returns: p (float): transformation error """ - + t = trans[:3, 3] r = trans[:3, :3] q = nq.mat2quat(r) er = np.concatenate([t, q[1:]], axis=0) p = er.reshape(1, 6) @ info @ er.reshape(6, 1) / info[0, 0] - + return p.item() + def read_trajectory(filename, dim=4): """ Function that reads a trajectory saved in the 3DMatch/Redwood format to a numpy array. @@ -90,7 +93,7 @@ def read_trajectory(filename, dim=4): lines = f.readlines() # Extract the point cloud pairs - keys = lines[0::(dim+1)] + keys = lines[0::(dim + 1)] temp_keys = [] for i in range(len(keys)): temp_keys.append(keys[i].split('\t')[0:3]) @@ -99,14 +102,13 @@ def read_trajectory(filename, dim=4): for i in range(len(temp_keys)): final_keys.append([temp_keys[i][0].strip(), temp_keys[i][1].strip(), temp_keys[i][2].strip()]) - traj = [] for i in range(len(lines)): if i % 5 != 0: traj.append(lines[i].split('\t')[0:dim]) - traj = np.asarray(traj, dtype=np.float).reshape(-1,dim,dim) - + traj = np.asarray(traj, dtype=float).reshape(-1, dim, dim) + final_keys = np.asarray(final_keys) return final_keys, traj @@ -115,16 +117,16 @@ def read_trajectory(filename, dim=4): def read_trajectory_info(filename, dim=6): """ Function that reads the trajectory information saved in the 3DMatch/Redwood format to a numpy array. - Information file contains the variance-covariance matrix of the transformation paramaters. + Information file contains the variance-covariance matrix of the transformation paramaters. 
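For readers who don't need the batched Torch version, the rotation error used above is the standard geodesic angle between two rotations: with R_ = R1ᵀR2, the angle is arccos((trace(R_) - 1) / 2). A single-pair NumPy sketch of the same computation (the function name is illustrative, not part of the patch):

import numpy as np

def rotation_error_single(R1, R2):
    # Geodesic angle between two 3x3 rotation matrices, in degrees.
    cos_angle = (np.trace(R1.T @ R2) - 1.0) / 2.0
    # Clamp to the valid arccos domain, exactly as the batched version does.
    return np.degrees(np.arccos(np.clip(cos_angle, -1.0, 1.0)))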
Format specification can be found at http://redwood-data.org/indoor/fileformat.html - + Args: filename (str): path to the '.txt' file containing the trajectory information data dim (int): dimension of the transformation matrix (4x4 for 3D data) Returns: n_frame (int): number of fragments in the scene - cov_matrix (numpy array): covariance matrix of the transformation matrices for n pairs[n,dim, dim] + cov_matrix (numpy array): covariance matrix of the transformation matrices for n pairs[n,dim, dim] """ with open(filename) as fid: @@ -139,40 +141,42 @@ def read_trajectory_info(filename, dim=6): info_matrix = np.concatenate( [np.fromstring(item, sep='\t').reshape(1, -1) for item in contents[i * 7 + 1:i * 7 + 7]], axis=0) info_list.append(info_matrix) - - cov_matrix = np.asarray(info_list, dtype=np.float).reshape(-1,dim,dim) - + + cov_matrix = np.asarray(info_list, dtype=float).reshape(-1, dim, dim) + return n_frame, cov_matrix -def extract_corresponding_trajectors(est_pairs,gt_pairs, gt_traj): + +def extract_corresponding_trajectors(est_pairs, gt_pairs, gt_traj): """ Extract only those transformation matrices from the ground truth trajectory that are also in the estimated trajectory. - + Args: est_pairs (numpy array): indices of point cloud pairs with enough estimated overlap [m, 3] gt_pairs (numpy array): indices of gt overlaping point cloud pairs [n,3] gt_traj (numpy array): 3d array of the gt transformation parameters [n,4,4] Returns: - ext_traj (numpy array): gt transformation parameters for the point cloud pairs from est_pairs [m,4,4] + ext_traj (numpy array): gt transformation parameters for the point cloud pairs from est_pairs [m,4,4] """ ext_traj = np.zeros((len(est_pairs), 4, 4)) for est_idx, pair in enumerate(est_pairs): pair[2] = gt_pairs[0][2] gt_idx = np.where((gt_pairs == pair).all(axis=1))[0] - - ext_traj[est_idx,:,:] = gt_traj[gt_idx,:,:] + + ext_traj[est_idx, :, :] = gt_traj[gt_idx, :, :] return ext_traj -def write_trajectory(traj,metadata, filename, dim=4): + +def write_trajectory(traj, metadata, filename, dim=4): """ - Writes the trajectory into a '.txt' file in 3DMatch/Redwood format. + Writes the trajectory into a '.txt' file in 3DMatch/Redwood format. 
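Since read_trajectory and write_trajectory round-trip the same layout, one record of a Redwood-style .log file is worth showing: a tab-separated header of source id, target id, and fragment count, followed by dim rows of the 4x4 transformation (12 decimal places, matching the '{0:.12f}' formatting below). The numbers here are purely illustrative:

0	1	57
0.971000000000	-0.132000000000	0.199000000000	0.385000000000
0.135000000000	0.991000000000	0.000000000000	-0.021000000000
-0.197000000000	0.027000000000	0.980000000000	0.120000000000
0.000000000000	0.000000000000	0.000000000000	1.000000000000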
Format specification can be found at http://redwood-data.org/indoor/fileformat.html Args: - traj (numpy array): trajectory for n pairs[n,dim, dim] + traj (numpy array): trajectory for n pairs[n,dim, dim] metadata (numpy array): file containing metadata about fragment numbers [n,3] filename (str): path where to save the '.txt' file containing trajectory data dim (int): dimension of the transformation matrix (4x4 for 3D data) @@ -182,39 +186,39 @@ def write_trajectory(traj,metadata, filename, dim=4): for idx in range(traj.shape[0]): # Only save the transfromation parameters for which the overlap threshold was satisfied if metadata[idx][2]: - p = traj[idx,:,:].tolist() + p = traj[idx, :, :].tolist() f.write('\t'.join(map(str, metadata[idx])) + '\n') f.write('\n'.join('\t'.join(map('{0:.12f}'.format, p[i])) for i in range(dim))) f.write('\n') -def read_pairs(src_path,tgt_path,n_points): +def read_pairs(src_path, tgt_path, n_points): # get pointcloud src = torch.load(src_path) tgt = torch.load(tgt_path) - src_pcd, src_embedding = src['coords'],src['feats'] + src_pcd, src_embedding = src['coords'], src['feats'] tgt_pcd, tgt_embedding = tgt['coords'], tgt['feats'] - - #permute and randomly select 2048/1024 points - if(src_pcd.shape[0]>n_points): - src_permute=np.random.permutation(src_pcd.shape[0])[:n_points] + + # permute and randomly select 2048/1024 points + if (src_pcd.shape[0] > n_points): + src_permute = np.random.permutation(src_pcd.shape[0])[:n_points] else: - src_permute=np.random.choice(src_pcd.shape[0],n_points) - if(tgt_pcd.shape[0]>n_points): - tgt_permute=np.random.permutation(tgt_pcd.shape[0])[:n_points] + src_permute = np.random.choice(src_pcd.shape[0], n_points) + if (tgt_pcd.shape[0] > n_points): + tgt_permute = np.random.permutation(tgt_pcd.shape[0])[:n_points] else: - tgt_permute=np.random.choice(tgt_pcd.shape[0],n_points) + tgt_permute = np.random.choice(tgt_pcd.shape[0], n_points) - src_pcd,src_embedding = src_pcd[src_permute],src_embedding[src_permute] - tgt_pcd,tgt_embedding = tgt_pcd[tgt_permute],tgt_embedding[tgt_permute] - return src_pcd,src_embedding,tgt_pcd,tgt_embedding + src_pcd, src_embedding = src_pcd[src_permute], src_embedding[src_permute] + tgt_pcd, tgt_embedding = tgt_pcd[tgt_permute], tgt_embedding[tgt_permute] + return src_pcd, src_embedding, tgt_pcd, tgt_embedding def evaluate_registration(num_fragment, result, result_pairs, gt_pairs, gt, gt_info, err2=0.2): """ Evaluates the performance of the registration algorithm according to the evaluation protocol defined by the 3DMatch/Redwood datasets. 
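As an aside on read_pairs above: it downsamples each cloud to exactly n_points, using a random permutation (no repeats) when the cloud is large enough and sampling with replacement otherwise. That decision isolated into a helper, as a sketch (the helper name is hypothetical):

import numpy as np

def sample_exact(n_available, n_points):
    # Without replacement when the cloud is large enough; with replacement otherwise.
    if n_available > n_points:
        return np.random.permutation(n_available)[:n_points]
    return np.random.choice(n_available, n_points)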
The evaluation protocol can be found at http://redwood-data.org/indoor/registration.html - + Args: num_fragment (int): path to the '.txt' file containing the trajectory information data result (numpy array): estimated transformation matrices [n,4,4] @@ -230,14 +234,14 @@ def evaluate_registration(num_fragment, result, result_pairs, gt_pairs, gt, gt_i """ err2 = err2 ** 2 - gt_mask = np.zeros((num_fragment, num_fragment), dtype=np.int) - flags=[] + gt_mask = np.zeros((num_fragment, num_fragment), dtype=int) + flags = [] for idx in range(gt_pairs.shape[0]): - i = int(gt_pairs[idx,0]) - j = int(gt_pairs[idx,1]) + i = int(gt_pairs[idx, 0]) + j = int(gt_pairs[idx, 1]) - # Only non consecutive pairs are tested + # Only non-consecutive pairs are tested if j - i > 1: gt_mask[i, j] = idx @@ -246,14 +250,14 @@ def evaluate_registration(num_fragment, result, result_pairs, gt_pairs, gt, gt_i good = 0 n_res = 0 for idx in range(result_pairs.shape[0]): - i = int(result_pairs[idx,0]) - j = int(result_pairs[idx,1]) - pose = result[idx,:,:] + i = int(result_pairs[idx, 0]) + j = int(result_pairs[idx, 1]) + pose = result[idx, :, :] if gt_mask[i, j] > 0: n_res += 1 gt_idx = gt_mask[i, j] - p = computeTransformationErr(np.linalg.inv(gt[gt_idx,:,:]) @ pose, gt_info[gt_idx,:,:]) + p = computeTransformationErr(np.linalg.inv(gt[gt_idx, :, :]) @ pose, gt_info[gt_idx, :, :]) if p <= err2: good += 1 flags.append(0) @@ -268,70 +272,80 @@ def evaluate_registration(num_fragment, result, result_pairs, gt_pairs, gt, gt_i return precision, recall, flags -def benchmark(est_folder,gt_folder): + +def benchmark(est_folder, gt_folder): scenes = sorted(os.listdir(gt_folder)) - scene_names = [os.path.join(gt_folder,ele) for ele in scenes] + scene_names = [os.path.join(gt_folder, ele) for ele in scenes] re_per_scene = defaultdict(list) te_per_scene = defaultdict(list) re_all, te_all, precision, recall = [], [], [], [] - n_valids= [] + n_valids = [] - short_names=['Kitchen','Home 1','Home 2','Hotel 1','Hotel 2','Hotel 3','Study','MIT Lab'] - with open(f'{est_folder}/result','w') as f: - f.write(("Scene\t¦ prec.\t¦ rec.\t¦ re\t¦ te\t¦ samples\t¦\n")) + short_names = ['Kitchen', 'Home 1', 'Home 2', 'Hotel 1', 'Hotel 2', 'Hotel 3', 'Study', 'MIT Lab'] + with open(f'{est_folder}/result.txt', 'w') as f: + f.write("Scene\t¦ prec.\t¦ rec.\t¦ re\t¦ te\t¦ samples\t¦\n") - for idx,scene in enumerate(scene_names): + for idx, scene in enumerate(scene_names): # ground truth info gt_pairs, gt_traj = read_trajectory(os.path.join(scene, "gt.log")) - n_valid=0 + n_valid = 0 for ele in gt_pairs: - diff=abs(int(ele[0])-int(ele[1])) - n_valid+=diff>1 + diff = abs(int(ele[0]) - int(ele[1])) + n_valid += diff > 1 n_valids.append(n_valid) - n_fragments, gt_traj_cov = read_trajectory_info(os.path.join(scene,"gt.info")) + n_fragments, gt_traj_cov = read_trajectory_info(os.path.join(scene, "gt.info")) # estimated info - est_pairs, est_traj = read_trajectory(os.path.join(est_folder,scenes[idx],'est.log')) + est_pairs, est_traj = read_trajectory(os.path.join(est_folder, scenes[idx], 'est.log')) + temp_precision, temp_recall, c_flag = evaluate_registration(n_fragments, est_traj, est_pairs, gt_pairs, + gt_traj, gt_traj_cov) - temp_precision, temp_recall,c_flag = evaluate_registration(n_fragments, est_traj, est_pairs, gt_pairs, gt_traj, gt_traj_cov) - # Filter out the estimated rotation matrices - ext_gt_traj = extract_corresponding_trajectors(est_pairs,gt_pairs, gt_traj) + ext_gt_traj = extract_corresponding_trajectors(est_pairs, gt_pairs, gt_traj) - re = 
rotation_error(torch.from_numpy(ext_gt_traj[:,0:3,0:3]), torch.from_numpy(est_traj[:,0:3,0:3])).cpu().numpy()[np.array(c_flag)==0] - te = translation_error(torch.from_numpy(ext_gt_traj[:,0:3,3:4]), torch.from_numpy(est_traj[:,0:3,3:4])).cpu().numpy()[np.array(c_flag)==0] + # Todo: check the case when rotation_error() and translation_error() return an empty list which causes + # errors (happens with sun3d-hote-uc-scan3d idx 4 and 'configs/benchmarks/3DMatch\\sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika' idx 7) + re = rotation_error(torch.from_numpy(ext_gt_traj[:, 0:3, 0:3]), + torch.from_numpy(est_traj[:, 0:3, 0:3])).cpu().numpy()[np.array(c_flag) == 0] + te = translation_error(torch.from_numpy(ext_gt_traj[:, 0:3, 3:4]), + torch.from_numpy(est_traj[:, 0:3, 3:4])).cpu().numpy()[np.array(c_flag) == 0] + if re.size <= 0: + re = np.array([0]) + if te.size <= 0: + te = np.array([0]) re_per_scene['mean'].append(np.mean(re)) re_per_scene['median'].append(np.median(re)) re_per_scene['min'].append(np.min(re)) re_per_scene['max'].append(np.max(re)) - te_per_scene['mean'].append(np.mean(te)) te_per_scene['median'].append(np.median(te)) te_per_scene['min'].append(np.min(te)) te_per_scene['max'].append(np.max(te)) - re_all.extend(re.reshape(-1).tolist()) te_all.extend(te.reshape(-1).tolist()) precision.append(temp_precision) recall.append(temp_recall) - f.write("{}\t¦ {:.3f}\t¦ {:.3f}\t¦ {:.3f}\t¦ {:.3f}\t¦ {:3d}¦\n".format(short_names[idx], temp_precision, temp_recall, np.median(re), np.median(te), n_valid)) - np.save(f'{est_folder}/{scenes[idx]}/flag.npy',c_flag) - + f.write("{}\t¦ {:.3f}\t¦ {:.3f}\t¦ {:.3f}\t¦ {:.3f}\t¦ {:3d}¦\n".format(short_names[idx], temp_precision, + temp_recall, np.median(re), + np.median(te), n_valid)) + np.save(f'{est_folder}/{scenes[idx]}/flag.npy', c_flag) + weighted_precision = (np.array(n_valids) * np.array(precision)).sum() / np.sum(n_valids) - f.write("Mean precision: {:.3f}: +- {:.3f}\n".format(np.mean(precision),np.std(precision))) + f.write("Mean precision: {:.3f}: +- {:.3f}\n".format(np.mean(precision), np.std(precision))) f.write("Weighted precision: {:.3f}\n".format(weighted_precision)) - f.write("Mean median RRE: {:.3f}: +- {:.3f}\n".format(np.mean(re_per_scene['median']), np.std(re_per_scene['median']))) - f.write("Mean median RTE: {:.3F}: +- {:.3f}\n".format(np.mean(te_per_scene['median']),np.std(te_per_scene['median']))) + f.write("Mean median RRE: {:.3f}: +- {:.3f}\n".format(np.mean(re_per_scene['median']), + np.std(re_per_scene['median']))) + f.write("Mean median RTE: {:.3F}: +- {:.3f}\n".format(np.mean(te_per_scene['median']), + np.std(te_per_scene['median']))) f.close() - \ No newline at end of file diff --git a/lib/benchmark_utils.py b/lib/benchmark_utils.py index f18d221..f44493a 100755 --- a/lib/benchmark_utils.py +++ b/lib/benchmark_utils.py @@ -5,7 +5,7 @@ Last modified: 30.11.2020 """ -import os,re,sys,json,yaml,random, glob, argparse, torch, pickle +import os, re, sys, json, yaml, random, glob, argparse, torch, pickle from tqdm import tqdm import numpy as np from scipy.spatial.transform import Rotation @@ -15,14 +15,14 @@ _EPS = 1e-7 # To prevent division by zero -def fmr_wrt_distance(data,split,inlier_ratio_threshold=0.05): +def fmr_wrt_distance(data, split, inlier_ratio_threshold=0.05): """ calculate feature match recall wrt distance threshold """ - fmr_wrt_distance =[] - for distance_threshold in range(1,21): - inlier_ratios =[] - distance_threshold /=100.0 + fmr_wrt_distance = [] + for distance_threshold in range(1, 21): + inlier_ratios = 
[] + distance_threshold /= 100.0 for idx in range(data.shape[0]): inlier_ratio = (data[idx] < distance_threshold).mean() inlier_ratios.append(inlier_ratio) @@ -30,26 +30,27 @@ def fmr_wrt_distance(data,split,inlier_ratio_threshold=0.05): for ele in split: fmr += (np.array(inlier_ratios[ele[0]:ele[1]]) > inlier_ratio_threshold).mean() fmr /= 8 - fmr_wrt_distance.append(fmr*100) + fmr_wrt_distance.append(fmr * 100) return fmr_wrt_distance + def fmr_wrt_inlier_ratio(data, split, distance_threshold=0.1): """ calculate feature match recall wrt inlier ratio threshold """ - fmr_wrt_inlier =[] - for inlier_ratio_threshold in range(1,21): - inlier_ratios =[] - inlier_ratio_threshold /=100.0 + fmr_wrt_inlier = [] + for inlier_ratio_threshold in range(1, 21): + inlier_ratios = [] + inlier_ratio_threshold /= 100.0 for idx in range(data.shape[0]): inlier_ratio = (data[idx] < distance_threshold).mean() inlier_ratios.append(inlier_ratio) - + fmr = 0 for ele in split: fmr += (np.array(inlier_ratios[ele[0]:ele[1]]) > inlier_ratio_threshold).mean() fmr /= 8 - fmr_wrt_inlier.append(fmr*100) + fmr_wrt_inlier.append(fmr * 100) return fmr_wrt_inlier @@ -58,48 +59,52 @@ def write_est_trajectory(gt_folder, exp_dir, tsfm_est): """ Write the estimated trajectories """ - scene_names=sorted(os.listdir(gt_folder)) - count=0 + scene_names = sorted(os.listdir(gt_folder)) + count = 0 for scene_name in scene_names: - gt_pairs, gt_traj = read_trajectory(os.path.join(gt_folder,scene_name,'gt.log')) + gt_pairs, gt_traj = read_trajectory(os.path.join(gt_folder, scene_name, 'gt.log')) est_traj = [] for i in range(len(gt_pairs)): est_traj.append(tsfm_est[count]) - count+=1 + count += 1 # write the trajectory - c_directory=os.path.join(exp_dir,scene_name) + c_directory = os.path.join(exp_dir, scene_name) os.makedirs(c_directory) - write_trajectory(np.array(est_traj),gt_pairs,os.path.join(c_directory,'est.log')) + write_trajectory(np.array(est_traj), gt_pairs, os.path.join(c_directory, 'est.log')) def to_tensor(array): """ Convert array to tensor """ - if(not isinstance(array,torch.Tensor)): - return torch.from_numpy(array).float() + if not isinstance(array, torch.Tensor): + # Make a copy of the array to ensure it is writable + return torch.from_numpy(array.copy()).float() else: return array + def to_array(tensor): """ Conver tensor to array """ - if(not isinstance(tensor,np.ndarray)): - if(tensor.device == torch.device('cpu')): + if not isinstance(tensor, np.ndarray): + if tensor.device == torch.device('cpu'): return tensor.numpy() else: return tensor.cpu().numpy() else: return tensor -def to_tsfm(rot,trans): + +def to_tsfm(rot, trans): tsfm = np.eye(4) - tsfm[:3,:3]=rot - tsfm[:3,3]=trans.flatten() + tsfm[:3, :3] = rot + tsfm[:3, 3] = trans.flatten() return tsfm - + + def to_o3d_pcd(xyz): """ Convert tensor/array to open3d PointCloud @@ -109,15 +114,17 @@ def to_o3d_pcd(xyz): pcd.points = o3d.utility.Vector3dVector(to_array(xyz)) return pcd + def to_o3d_feats(embedding): """ Convert tensor/array to open3d features embedding: [N, 3] """ - feats = o3d.registration.Feature() + feats = o3d.pipelines.registration.Feature() feats.data = to_array(embedding).T return feats + def get_correspondences(src_pcd, tgt_pcd, trans, search_voxel_size, K=None): src_pcd.transform(trans) pcd_tree = o3d.geometry.KDTreeFlann(tgt_pcd) @@ -129,32 +136,35 @@ def get_correspondences(src_pcd, tgt_pcd, trans, search_voxel_size, K=None): idx = idx[:K] for j in idx: correspondences.append([i, j]) - + correspondences = np.array(correspondences) 
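Both fmr_* helpers above implement the same statistic: per-scene feature match recall, averaged over the eight 3DMatch test scenes (hence the division by 8). A compact standalone sketch of that shared core, assuming inlier_ratios is a flat list of per-pair ratios and split holds [start, end) indices per scene:

import numpy as np

def feature_match_recall(inlier_ratios, split, threshold=0.05):
    # Fraction of pairs per scene whose inlier ratio clears the threshold,
    # then the mean over scenes, reported as a percentage.
    per_scene = [(np.array(inlier_ratios[s:e]) > threshold).mean() for s, e in split]
    return 100.0 * float(np.mean(per_scene))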
correspondences = torch.from_numpy(correspondences) return correspondences + def get_blue(): """ Get color blue for rendering """ return [0, 0.651, 0.929] + def get_yellow(): """ Get color yellow for rendering """ return [1, 0.706, 0] + def random_sample(pcd, feats, N): """ Do random sampling to get exact N points and associated features pcd: [N,3] feats: [N,C] """ - if(isinstance(pcd,torch.Tensor)): + if (isinstance(pcd, torch.Tensor)): n1 = pcd.size(0) - elif(isinstance(pcd, np.ndarray)): + elif (isinstance(pcd, np.ndarray)): n1 = pcd.shape[0] if n1 == N: @@ -166,64 +176,67 @@ def random_sample(pcd, feats, N): choice = np.random.choice(n1, N) return pcd[choice], feats[choice] - -def get_angle_deviation(R_pred,R_gt): + + +def get_angle_deviation(R_pred, R_gt): """ Calculate the angle deviation between two rotaion matrice The rotation error is between [0,180] Input: R_pred: [B,3,3] R_gt : [B,3,3] - Return: + Return: degs: [B] """ - R=np.matmul(R_pred,R_gt.transpose(0,2,1)) - tr=np.trace(R,0,1,2) - rads=np.arccos(np.clip((tr-1)/2,-1,1)) # clip to valid range - degs=rads/np.pi*180 + R = np.matmul(R_pred, R_gt.transpose(0, 2, 1)) + tr = np.trace(R, 0, 1, 2) + rads = np.arccos(np.clip((tr - 1) / 2, -1, 1)) # clip to valid range + degs = rads / np.pi * 180 return degs -def ransac_pose_estimation(src_pcd, tgt_pcd, src_feat, tgt_feat, mutual = False, distance_threshold = 0.05, ransac_n = 3): + +def ransac_pose_estimation(src_pcd, tgt_pcd, src_feat, tgt_feat, mutual=False, distance_threshold=0.05, ransac_n=3): """ RANSAC pose estimation with two checkers - We follow D3Feat to set ransac_n = 3 for 3DMatch and ransac_n = 4 for KITTI. + We follow D3Feat to set ransac_n = 3 for 3DMatch and ransac_n = 4 for KITTI. For 3DMatch dataset, we observe significant improvement after changing ransac_n from 4 to 3. 
""" - if(mutual): - if(torch.cuda.device_count()>=1): + if (mutual): + if (torch.cuda.device_count() >= 1): device = torch.device('cuda') else: device = torch.device('cpu') src_feat, tgt_feat = to_tensor(src_feat), to_tensor(tgt_feat) - scores = torch.matmul(src_feat.to(device), tgt_feat.transpose(0,1).to(device)).cpu() - selection = mutual_selection(scores[None,:,:])[0] + scores = torch.matmul(src_feat.to(device), tgt_feat.transpose(0, 1).to(device)).cpu() + selection = mutual_selection(scores[None, :, :])[0] row_sel, col_sel = np.where(selection) - corrs = o3d.utility.Vector2iVector(np.array([row_sel,col_sel]).T) + corrs = o3d.utility.Vector2iVector(np.array([row_sel, col_sel]).T) src_pcd = to_o3d_pcd(src_pcd) tgt_pcd = to_o3d_pcd(tgt_pcd) - result_ransac = o3d.registration.registration_ransac_based_on_correspondence( - source=src_pcd, target=tgt_pcd,corres=corrs, + result_ransac = o3d.pipelines.registration.registration_ransac_based_on_correspondence( + source=src_pcd, target=tgt_pcd, corres=corrs, max_correspondence_distance=distance_threshold, - estimation_method=o3d.registration.TransformationEstimationPointToPoint(False), + estimation_method=o3d.pipelines.registration.TransformationEstimationPointToPoint(False), ransac_n=4, - criteria=o3d.registration.RANSACConvergenceCriteria(50000, 1000)) + criteria=o3d.pipelines.registration.RANSACConvergenceCriteria(50000, 1000)) else: src_pcd = to_o3d_pcd(src_pcd) tgt_pcd = to_o3d_pcd(tgt_pcd) src_feats = to_o3d_feats(src_feat) tgt_feats = to_o3d_feats(tgt_feat) - result_ransac = o3d.registration.registration_ransac_based_on_feature_matching( - src_pcd, tgt_pcd, src_feats, tgt_feats,distance_threshold, - o3d.registration.TransformationEstimationPointToPoint(False), ransac_n, - [o3d.registration.CorrespondenceCheckerBasedOnEdgeLength(0.9), - o3d.registration.CorrespondenceCheckerBasedOnDistance(distance_threshold)], - o3d.registration.RANSACConvergenceCriteria(50000, 1000)) - + result_ransac = o3d.pipelines.registration.registration_ransac_based_on_feature_matching( + src_pcd, tgt_pcd, src_feats, tgt_feats, True, distance_threshold, + o3d.pipelines.registration.TransformationEstimationPointToPoint(False), ransac_n, + [o3d.pipelines.registration.CorrespondenceCheckerBasedOnEdgeLength(0.9), + o3d.pipelines.registration.CorrespondenceCheckerBasedOnDistance(distance_threshold)], + o3d.pipelines.registration.RANSACConvergenceCriteria(50000, 1000)) + return result_ransac.transformation -def get_inlier_ratio(src_pcd, tgt_pcd, src_feat, tgt_feat, rot, trans, inlier_distance_threshold = 0.1): + +def get_inlier_ratio(src_pcd, tgt_pcd, src_feat, tgt_feat, rot, trans, inlier_distance_threshold=0.1): """ Compute inlier ratios with and without mutual check, return both """ @@ -233,22 +246,22 @@ def get_inlier_ratio(src_pcd, tgt_pcd, src_feat, tgt_feat, rot, trans, inlier_di tgt_feat = to_tensor(tgt_feat) rot, trans = to_tensor(rot), to_tensor(trans) - results =dict() - results['w']=dict() - results['wo']=dict() + results = dict() + results['w'] = dict() + results['wo'] = dict() - if(torch.cuda.device_count()>=1): + if (torch.cuda.device_count() >= 1): device = torch.device('cuda') else: device = torch.device('cpu') - src_pcd = (torch.matmul(rot, src_pcd.transpose(0,1)) + trans).transpose(0,1) - scores = torch.matmul(src_feat.to(device), tgt_feat.transpose(0,1).to(device)).cpu() + src_pcd = (torch.matmul(rot, src_pcd.transpose(0, 1)) + trans).transpose(0, 1) + scores = torch.matmul(src_feat.to(device), tgt_feat.transpose(0, 1).to(device)).cpu() 
######################################## # 1. calculate inlier ratios wo mutual check _, idx = scores.max(-1) - dist = torch.norm(src_pcd- tgt_pcd[idx],dim=1) + dist = torch.norm(src_pcd - tgt_pcd[idx], dim=1) results['wo']['distance'] = dist.numpy() c_inlier_ratio = (dist < inlier_distance_threshold).float().mean() @@ -256,9 +269,9 @@ def get_inlier_ratio(src_pcd, tgt_pcd, src_feat, tgt_feat, rot, trans, inlier_di ######################################## # 2. calculate inlier ratios w mutual check - selection = mutual_selection(scores[None,:,:])[0] + selection = mutual_selection(scores[None, :, :])[0] row_sel, col_sel = np.where(selection) - dist = torch.norm(src_pcd[row_sel]- tgt_pcd[col_sel],dim=1) + dist = torch.norm(src_pcd[row_sel] - tgt_pcd[col_sel], dim=1) results['w']['distance'] = dist.numpy() c_inlier_ratio = (dist < inlier_distance_threshold).float().mean() @@ -270,42 +283,42 @@ def get_inlier_ratio(src_pcd, tgt_pcd, src_feat, tgt_feat, rot, trans, inlier_di def mutual_selection(score_mat): """ Return a {0,1} matrix, the element is 1 if and only if it's maximum along both row and column - + Args: np.array() score_mat: [B,N,N] Return: - mutuals: [B,N,N] + mutuals: [B,N,N] """ - score_mat=to_array(score_mat) - if(score_mat.ndim==2): - score_mat=score_mat[None,:,:] - - mutuals=np.zeros_like(score_mat) - for i in range(score_mat.shape[0]): # loop through the batch - c_mat=score_mat[i] - flag_row=np.zeros_like(c_mat) - flag_column=np.zeros_like(c_mat) - - max_along_row=np.argmax(c_mat,1)[:,None] - max_along_column=np.argmax(c_mat,0)[None,:] - np.put_along_axis(flag_row,max_along_row,1,1) - np.put_along_axis(flag_column,max_along_column,1,0) - mutuals[i]=(flag_row.astype(np.bool)) & (flag_column.astype(np.bool)) - return mutuals.astype(np.bool) + score_mat = to_array(score_mat) + if (score_mat.ndim == 2): + score_mat = score_mat[None, :, :] + + mutuals = np.zeros_like(score_mat) + for i in range(score_mat.shape[0]): # loop through the batch + c_mat = score_mat[i] + flag_row = np.zeros_like(c_mat) + flag_column = np.zeros_like(c_mat) + + max_along_row = np.argmax(c_mat, 1)[:, None] + max_along_column = np.argmax(c_mat, 0)[None, :] + np.put_along_axis(flag_row, max_along_row, 1, 1) + np.put_along_axis(flag_column, max_along_column, 1, 0) + mutuals[i] = (flag_row.astype(bool)) & (flag_column.astype(bool)) + return mutuals.astype(bool) def get_scene_split(whichbenchmark): """ Just to check how many valid fragments each scene has """ - assert whichbenchmark in ['3DMatch','3DLoMatch'] + assert whichbenchmark in ['3DMatch', '3DLoMatch'] folder = f'configs/benchmarks/{whichbenchmark}/*/gt.log' - scene_files=sorted(glob.glob(folder)) - split=[] - count=0 + scene_files = sorted(glob.glob(folder)) + split = [] + count = 0 for eachfile in scene_files: gt_pairs, gt_traj = read_trajectory(eachfile) - split.append([count,count+len(gt_pairs)]) - count+=len(gt_pairs) + split.append([count, count + len(gt_pairs)]) + count += len(gt_pairs) return split diff --git a/lib/tester.py b/lib/tester.py index 6a2070a..6ef2da2 100644 --- a/lib/tester.py +++ b/lib/tester.py @@ -12,81 +12,87 @@ from collections import defaultdict import coloredlogs + class IndoorTester(Trainer): """ 3DMatch tester """ - def __init__(self,args): - Trainer.__init__(self,args) - + + def __init__(self, args): + Trainer.__init__(self, args) + def test(self): print('Start to evaluate on test datasets...') - os.makedirs(f'{self.snapshot_dir}/{self.config.benchmark}',exist_ok=True) + 
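mutual_selection above marks (i, j) as a match only when j is the argmax of row i and i is the argmax of column j. For a single unbatched score matrix the same test reduces to a round-trip check; a hedged sketch that returns index pairs instead of a boolean mask:

import numpy as np

def mutual_nn_pairs(scores):
    j_best = scores.argmax(axis=1)   # best target for every source
    i_best = scores.argmax(axis=0)   # best source for every target
    rows = np.arange(scores.shape[0])
    keep = i_best[j_best] == rows    # consistent in both directions
    return rows[keep], j_best[keep]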
os.makedirs(f'{self.snapshot_dir}/{self.config.benchmark}', exist_ok=True) num_iter = int(len(self.loader['test'].dataset) // self.loader['test'].batch_size) c_loader_iter = self.loader['test'].__iter__() self.model.eval() with torch.no_grad(): - for idx in tqdm(range(num_iter)): # loop through this epoch - inputs = c_loader_iter.next() - ################################## - # load inputs to device. - for k, v in inputs.items(): - if type(v) == list: - inputs[k] = [item.to(self.device) for item in v] - else: - inputs[k] = v.to(self.device) - ############################################### - # forward pass - feats, scores_overlap, scores_saliency = self.model(inputs) #[N1, C1], [N2, C2] - pcd = inputs['points'][0] - len_src = inputs['stack_lengths'][0][0] - c_rot, c_trans = inputs['rot'], inputs['trans'] - correspondence = inputs['correspondences'] - - src_pcd, tgt_pcd = pcd[:len_src], pcd[len_src:] - src_feats, tgt_feats = feats[:len_src], feats[len_src:] + try: + for idx in tqdm(range(num_iter)): # loop through this epoch + inputs = next(c_loader_iter) + ################################## + # load inputs to device. + for k, v in inputs.items(): + if type(v) == list: + inputs[k] = [item.to(self.device) for item in v] + else: + inputs[k] = v.to(self.device) + ############################################### + # forward pass + feats, scores_overlap, scores_saliency = self.model(inputs) # [N1, C1], [N2, C2] + pcd = inputs['points'][0] + len_src = inputs['stack_lengths'][0][0] + c_rot, c_trans = inputs['rot'], inputs['trans'] + correspondence = inputs['correspondences'] - data = dict() - data['pcd'] = pcd.cpu() - data['feats'] = feats.detach().cpu() - data['overlaps'] = scores_overlap.detach().cpu() - data['saliency'] = scores_saliency.detach().cpu() - data['len_src'] = len_src - data['rot'] = c_rot.cpu() - data['trans'] = c_trans.cpu() + src_pcd, tgt_pcd = pcd[:len_src], pcd[len_src:] + src_feats, tgt_feats = feats[:len_src], feats[len_src:] - torch.save(data,f'{self.snapshot_dir}/{self.config.benchmark}/{idx}.pth') + data = dict() + data['pcd'] = pcd.cpu() + data['feats'] = feats.detach().cpu() + data['overlaps'] = scores_overlap.detach().cpu() + data['saliency'] = scores_saliency.detach().cpu() + data['len_src'] = len_src + data['rot'] = c_rot.cpu() + data['trans'] = c_trans.cpu() + torch.save(data, f'{self.snapshot_dir}/{self.config.benchmark}/{idx}.pth') + except StopIteration: + # Handle the end of the iteration if necessary + pass class KITTITester(Trainer): """ KITTI tester """ - def __init__(self,args): - Trainer.__init__(self,args) - + + def __init__(self, args): + Trainer.__init__(self, args) + def test(self): print('Start to evaluate on test datasets...') tsfm_est = [] num_iter = int(len(self.loader['test'].dataset) // self.loader['test'].batch_size) c_loader_iter = self.loader['test'].__iter__() - + self.model.eval() - rot_gt, trans_gt =[],[] + rot_gt, trans_gt = [], [] with torch.no_grad(): - for _ in tqdm(range(num_iter)): # loop through this epoch + for _ in tqdm(range(num_iter)): # loop through this epoch inputs = c_loader_iter.next() ############################################### # forward pass - for k, v in inputs.items(): + for k, v in inputs.items(): if type(v) == list: inputs[k] = [item.to(self.device) for item in v] else: inputs[k] = v.to(self.device) - feats, scores_overlap, scores_saliency = self.model(inputs) #[N1, C1], [N2, C2] + feats, scores_overlap, scores_saliency = self.model(inputs) # [N1, C1], [N2, C2] scores_overlap = scores_overlap.detach().cpu() 
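Both testers drive the loader by hand with __iter__() and next(), which is why the indoor path needs a StopIteration guard. Iterating the DataLoader directly is an equivalent, shorter pattern that stops cleanly on its own; model and test_loader below stand in for self.model and self.loader['test']:

from tqdm import tqdm
import torch

model.eval()
with torch.no_grad():
    for idx, inputs in enumerate(tqdm(test_loader)):
        ...  # same per-batch body: move inputs to device, forward pass, save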
scores_saliency = scores_saliency.detach().cpu() @@ -95,7 +101,7 @@ def test(self): rot_gt.append(c_rot.cpu().numpy()) trans_gt.append(c_trans.cpu().numpy()) src_feats, tgt_feats = feats[:len_src], feats[len_src:] - src_pcd , tgt_pcd = inputs['src_pcd_raw'], inputs['tgt_pcd_raw'] + src_pcd, tgt_pcd = inputs['src_pcd_raw'], inputs['tgt_pcd_raw'] src_overlap, tgt_overlap = scores_overlap[:len_src], scores_overlap[len_src:] src_saliency, tgt_saliency = scores_saliency[:len_src], scores_saliency[len_src:] @@ -108,59 +114,59 @@ def test(self): src_scores = src_overlap * src_saliency tgt_scores = tgt_overlap * tgt_saliency - if(src_pcd.size(0) > n_points): + if (src_pcd.size(0) > n_points): idx = np.arange(src_pcd.size(0)) probs = (src_scores / src_scores.sum()).numpy().flatten() - idx = np.random.choice(idx, size= n_points, replace=False, p=probs) + idx = np.random.choice(idx, size=n_points, replace=False, p=probs) src_pcd, src_feats = src_pcd[idx], src_feats[idx] - if(tgt_pcd.size(0) > n_points): + if (tgt_pcd.size(0) > n_points): idx = np.arange(tgt_pcd.size(0)) probs = (tgt_scores / tgt_scores.sum()).numpy().flatten() - idx = np.random.choice(idx, size= n_points, replace=False, p=probs) + idx = np.random.choice(idx, size=n_points, replace=False, p=probs) tgt_pcd, tgt_feats = tgt_pcd[idx], tgt_feats[idx] ######################################## - # run ransac + # run ransac distance_threshold = 0.3 - ts_est = ransac_pose_estimation(src_pcd, tgt_pcd, src_feats, tgt_feats, mutual=False, distance_threshold=distance_threshold, ransac_n = 4) + ts_est = ransac_pose_estimation(src_pcd, tgt_pcd, src_feats, tgt_feats, mutual=False, + distance_threshold=distance_threshold, ransac_n=4) tsfm_est.append(ts_est) - + tsfm_est = np.array(tsfm_est) - rot_est = tsfm_est[:,:3,:3] - trans_est = tsfm_est[:,:3,3] + rot_est = tsfm_est[:, :3, :3] + trans_est = tsfm_est[:, :3, 3] rot_gt = np.array(rot_gt) - trans_gt = np.array(trans_gt)[:,:,0] + trans_gt = np.array(trans_gt)[:, :, 0] rot_threshold = 5 trans_threshold = 2 - np.savez(f'{self.snapshot_dir}/results',rot_est=rot_est, rot_gt=rot_gt, trans_est = trans_est, trans_gt = trans_gt) + np.savez(f'{self.snapshot_dir}/results', rot_est=rot_est, rot_gt=rot_gt, trans_est=trans_est, trans_gt=trans_gt) r_deviation = get_angle_deviation(rot_est, rot_gt) - translation_errors = np.linalg.norm(trans_est-trans_gt,axis=-1) + translation_errors = np.linalg.norm(trans_est - trans_gt, axis=-1) - flag_1=r_deviation n_points): + if (src_pcd.size(0) > n_points): idx = np.arange(src_pcd.size(0)) probs = (src_scores / src_scores.sum()).numpy().flatten() - idx = np.random.choice(idx, size= n_points, replace=False, p=probs) + idx = np.random.choice(idx, size=n_points, replace=False, p=probs) src_pcd, src_feats = src_pcd[idx], src_feats[idx] - if(tgt_pcd.size(0) > n_points): + if (tgt_pcd.size(0) > n_points): idx = np.arange(tgt_pcd.size(0)) probs = (tgt_scores / tgt_scores.sum()).numpy().flatten() - idx = np.random.choice(idx, size= n_points, replace=False, p=probs) + idx = np.random.choice(idx, size=n_points, replace=False, p=probs) tgt_pcd, tgt_feats = tgt_pcd[idx], tgt_feats[idx] ######################################## - # run ransac + # run ransac distance_threshold = 0.025 - ts_est = ransac_pose_estimation(src_pcd, tgt_pcd, src_feats, tgt_feats, mutual=False, distance_threshold=distance_threshold, ransac_n = 3) - except: # sometimes we left over with too few points in the bottleneck and our k-nn graph breaks + ts_est = ransac_pose_estimation(src_pcd, tgt_pcd, src_feats, 
tgt_feats, mutual=False, + distance_threshold=distance_threshold, ransac_n=3) + except: # sometimes we left over with too few points in the bottleneck and our k-nn graph breaks ts_est = np.eye(4) pred_transforms.append(ts_est) - total_rotation = np.concatenate(total_rotation, axis=0) - _logger.info(('Rotation range in data: {}(avg), {}(max)'.format(np.mean(total_rotation), np.max(total_rotation)))) + _logger.info( + ('Rotation range in data: {}(avg), {}(max)'.format(np.mean(total_rotation), np.max(total_rotation)))) + + pred_transforms = torch.from_numpy(np.array(pred_transforms)).float()[:, None, :, :] - pred_transforms = torch.from_numpy(np.array(pred_transforms)).float()[:,None,:,:] - c_loader_iter = self.loader['test'].__iter__() num_processed, num_total = 0, len(pred_transforms) metrics_for_iter = [defaultdict(list) for _ in range(pred_transforms.shape[1])] - + with torch.no_grad(): - for idx in tqdm(range(num_iter)): # loop through this epoch + for idx in tqdm(range(num_iter)): # loop through this epoch inputs = c_loader_iter.next() - + batch_size = 1 for i_iter in range(pred_transforms.shape[1]): - cur_pred_transforms = pred_transforms[num_processed:num_processed+batch_size, i_iter, :, :] + cur_pred_transforms = pred_transforms[num_processed:num_processed + batch_size, i_iter, :, :] metrics = compute_metrics(inputs['sample'], cur_pred_transforms) for k in metrics: metrics_for_iter[i_iter][k].append(metrics[k]) @@ -392,15 +403,14 @@ def test(self): for k in metrics_for_iter[i_iter]} summary_metrics = summarize_metrics(metrics_for_iter[i_iter]) print_metrics(_logger, summary_metrics, title='Evaluation result (iter {})'.format(i_iter)) - - + def get_trainer(config): - if(config.dataset == 'indoor'): + if (config.dataset == 'indoor'): return IndoorTester(config) - elif(config.dataset == 'kitti'): + elif (config.dataset == 'kitti'): return KITTITester(config) - elif(config.dataset == 'modelnet'): + elif (config.dataset == 'modelnet'): return ModelnetTester(config) else: raise NotImplementedError diff --git a/lib/trainer.py b/lib/trainer.py index bbf93b7..22f6ff0 100644 --- a/lib/trainer.py +++ b/lib/trainer.py @@ -1,9 +1,9 @@ -import time, os, torch,copy +import time, os, torch, copy import numpy as np import torch.nn as nn from tensorboardX import SummaryWriter from lib.timer import Timer, AverageMeter -from lib.utils import Logger,validate_gradient +from lib.utils import Logger, validate_gradient from tqdm import tqdm import torch.nn.functional as F @@ -26,35 +26,34 @@ def __init__(self, args): self.scheduler = args.scheduler self.scheduler_freq = args.scheduler_freq self.snapshot_freq = args.snapshot_freq - self.snapshot_dir = args.snapshot_dir + self.snapshot_dir = args.snapshot_dir self.benchmark = args.benchmark self.iter_size = args.iter_size - self.verbose_freq= args.verbose_freq + self.verbose_freq = args.verbose_freq self.w_circle_loss = args.w_circle_loss self.w_overlap_loss = args.w_overlap_loss - self.w_saliency_loss = args.w_saliency_loss + self.w_saliency_loss = args.w_saliency_loss self.desc_loss = args.desc_loss self.best_loss = 1e5 self.best_recall = -1e5 self.writer = SummaryWriter(log_dir=args.tboard_dir) self.logger = Logger(args.snapshot_dir) - self.logger.write(f'#parameters {sum([x.nelement() for x in self.model.parameters()])/1000000.} M\n') - + self.logger.write(f'#parameters {sum([x.nelement() for x in self.model.parameters()]) / 1000000.} M\n') - if (args.pretrain !=''): + if (args.pretrain != ''): self._load_pretrain(args.pretrain) - - self.loader =dict() - 
self.loader['train']=args.train_loader - self.loader['val']=args.val_loader + + self.loader = dict() + self.loader['train'] = args.train_loader + self.loader['val'] = args.val_loader self.loader['test'] = args.test_loader - with open(f'{args.snapshot_dir}/model','w') as f: + with open(f'{args.snapshot_dir}/model', 'w') as f: f.write(str(self.model)) f.close() - + def _snapshot(self, epoch, name=None): state = { 'epoch': epoch, @@ -80,7 +79,7 @@ def _load_pretrain(self, resume): self.optimizer.load_state_dict(state['optimizer']) self.best_loss = state['best_loss'] self.best_recall = state['best_recall'] - + self.logger.write(f'Successfully load pretrained model from {resume}!\n') self.logger.write(f'Current best loss {self.best_loss}\n') self.logger.write(f'Current best recall {self.best_recall}\n') @@ -91,34 +90,33 @@ def _get_lr(self, group=0): return self.optimizer.param_groups[group]['lr'] def stats_dict(self): - stats=dict() - stats['circle_loss']=0. - stats['recall']=0. # feature match recall, divided by number of ground truth pairs + stats = dict() + stats['circle_loss'] = 0. + stats['recall'] = 0. # feature match recall, divided by number of ground truth pairs stats['saliency_loss'] = 0. stats['saliency_recall'] = 0. stats['saliency_precision'] = 0. stats['overlap_loss'] = 0. - stats['overlap_recall']=0. - stats['overlap_precision']=0. + stats['overlap_recall'] = 0. + stats['overlap_precision'] = 0. return stats def stats_meter(self): - meters=dict() - stats=self.stats_dict() - for key,_ in stats.items(): - meters[key]=AverageMeter() + meters = dict() + stats = self.stats_dict() + for key, _ in stats.items(): + meters[key] = AverageMeter() return meters - def inference_one_batch(self, inputs, phase): - assert phase in ['train','val','test'] + assert phase in ['train', 'val', 'test'] ################################## # training - if(phase == 'train'): + if (phase == 'train'): self.model.train() ############################################### # forward pass - feats, scores_overlap, scores_saliency = self.model(inputs) #[N1, C1], [N2, C2] + feats, scores_overlap, scores_saliency = self.model(inputs) # [N1, C1], [N2, C2] pcd = inputs['points'][0] len_src = inputs['stack_lengths'][0][0] c_rot, c_trans = inputs['rot'], inputs['trans'] @@ -129,9 +127,11 @@ def inference_one_batch(self, inputs, phase): ################################################### # get loss - stats= self.desc_loss(src_pcd, tgt_pcd, src_feats, tgt_feats,correspondence, c_rot, c_trans, scores_overlap, scores_saliency) + stats = self.desc_loss(src_pcd, tgt_pcd, src_feats, tgt_feats, correspondence, c_rot, c_trans, + scores_overlap, scores_saliency) - c_loss = stats['circle_loss'] * self.w_circle_loss + stats['overlap_loss'] * self.w_overlap_loss + stats['saliency_loss'] * self.w_saliency_loss + c_loss = stats['circle_loss'] * self.w_circle_loss + stats['overlap_loss'] * self.w_overlap_loss + stats[ + 'saliency_loss'] * self.w_saliency_loss c_loss.backward() @@ -140,8 +140,8 @@ def inference_one_batch(self, inputs, phase): with torch.no_grad(): ############################################### # forward pass - feats, scores_overlap, scores_saliency = self.model(inputs) #[N1, C1], [N2, C2] - pcd = inputs['points'][0] + feats, scores_overlap, scores_saliency = self.model(inputs) # [N1, C1], [N2, C2] + pcd = inputs['points'][0] len_src = inputs['stack_lengths'][0][0] c_rot, c_trans = inputs['rot'], inputs['trans'] correspondence = inputs['correspondences'] @@ -151,112 +151,123 @@ def inference_one_batch(self, inputs, phase): 
################################################### # get loss - stats= self.desc_loss(src_pcd, tgt_pcd, src_feats, tgt_feats,correspondence, c_rot, c_trans, scores_overlap, scores_saliency) - + stats = self.desc_loss(src_pcd, tgt_pcd, src_feats, tgt_feats, correspondence, c_rot, c_trans, + scores_overlap, scores_saliency) - ################################## + ################################## # detach the gradients for loss terms stats['circle_loss'] = float(stats['circle_loss'].detach()) stats['overlap_loss'] = float(stats['overlap_loss'].detach()) stats['saliency_loss'] = float(stats['saliency_loss'].detach()) - - return stats + return stats - def inference_one_epoch(self,epoch, phase): + def inference_one_epoch(self, epoch, phase): gc.collect() - assert phase in ['train','val','test'] + assert phase in ['train', 'val', 'test'] # init stats meter stats_meter = self.stats_meter() num_iter = int(len(self.loader[phase].dataset) // self.loader[phase].batch_size) c_loader_iter = self.loader[phase].__iter__() - + self.optimizer.zero_grad() - for c_iter in tqdm(range(num_iter)): # loop through this epoch - ################################## - # load inputs to device. - inputs = c_loader_iter.next() - for k, v in inputs.items(): - if type(v) == list: - inputs[k] = [item.to(self.device) for item in v] - elif type(v) == dict: - pass - else: - inputs[k] = v.to(self.device) + for c_iter in tqdm(range(num_iter)): # loop through this epoch + try: + # load inputs to device. + inputs = next(c_loader_iter) + for k, v in inputs.items(): + if type(v) == list: + inputs[k] = [item.to(self.device) for item in v] + elif type(v) == dict: + pass + else: + inputs[k] = v.to(self.device) + except StopIteration: + # Handle the case where the data loader iterator is exhausted. 
+ break + except Exception as e: + print(f"An error occurred: {e}") + try: ################################## # forward pass # with torch.autograd.detect_anomaly(): stats = self.inference_one_batch(inputs, phase) - + ################################################### # run optimisation - if((c_iter+1) % self.iter_size == 0 and phase == 'train'): + if ((c_iter + 1) % self.iter_size == 0 and phase == 'train'): gradient_valid = validate_gradient(self.model) - if(gradient_valid): + if (gradient_valid): self.optimizer.step() else: self.logger.write('gradient not valid\n') self.optimizer.zero_grad() - + ################################ # update to stats_meter - for key,value in stats.items(): + for key, value in stats.items(): stats_meter[key].update(value) except Exception as inst: print(inst) - + torch.cuda.empty_cache() - + if (c_iter + 1) % self.verbose_freq == 0 and self.verbose: curr_iter = num_iter * (epoch - 1) + c_iter for key, value in stats_meter.items(): self.writer.add_scalar(f'{phase}/{key}', value.avg, curr_iter) - - message = f'{phase} Epoch: {epoch} [{c_iter+1:4d}/{num_iter}]' - for key,value in stats_meter.items(): + + message = f'{phase} Epoch: {epoch} [{c_iter + 1:4d}/{num_iter}]' + for key, value in stats_meter.items(): message += f'{key}: {value.avg:.2f}\t' self.logger.write(message + '\n') message = f'{phase} Epoch: {epoch}' - for key,value in stats_meter.items(): + for key, value in stats_meter.items(): message += f'{key}: {value.avg:.2f}\t' - self.logger.write(message+'\n') + self.logger.write(message + '\n') return stats_meter - def train(self): print('start training...') - for epoch in range(self.start_epoch, self.max_epoch): - self.inference_one_epoch(epoch,'train') - self.scheduler.step() - - stats_meter = self.inference_one_epoch(epoch,'val') - - if stats_meter['circle_loss'].avg < self.best_loss: - self.best_loss = stats_meter['circle_loss'].avg - self._snapshot(epoch,'best_loss') - if stats_meter['recall'].avg > self.best_recall: - self.best_recall = stats_meter['recall'].avg - self._snapshot(epoch,'best_recall') - - # we only add saliency loss when we get descent point-wise features - if(stats_meter['recall'].avg>0.3): - self.w_saliency_loss = 1. - else: - self.w_saliency_loss = 0. - + try: + for epoch in range(self.start_epoch, self.max_epoch): + self.inference_one_epoch(epoch, 'train') + self.scheduler.step() + + stats_meter = self.inference_one_epoch(epoch, 'val') + + if stats_meter['circle_loss'].avg < self.best_loss: + self.best_loss = stats_meter['circle_loss'].avg + self._snapshot(epoch, 'best_loss') + if stats_meter['recall'].avg > self.best_recall: + self.best_recall = stats_meter['recall'].avg + self._snapshot(epoch, 'best_recall') + + # we only add saliency loss when we get descent point-wise features + if (stats_meter['recall'].avg > 0.3): + self.w_saliency_loss = 1. + else: + self.w_saliency_loss = 0. + # finish all epoch - print("Training finish!") + except Exception as e: + print(f"An unexpected error occurred: {e}") + # Here you can add any additional logging or cleanup if necessary. + finally: + print("Cleaning up resources...") + # Perform any final cleanup here, such as closing files or releasing CUDA memory. 
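The (c_iter + 1) % self.iter_size test in inference_one_epoch is plain gradient accumulation: backward() runs on every batch, while the optimizer steps once per iter_size batches. Stripped of the logging and gradient validation, the pattern is the following sketch, where loader, compute_loss, and iter_size stand in for the trainer's attributes:

optimizer.zero_grad()
for i, batch in enumerate(loader):
    loss = compute_loss(batch)   # stand-in for inference_one_batch
    loss.backward()
    if (i + 1) % iter_size == 0:
        optimizer.step()         # one step per iter_size accumulated batches
        optimizer.zero_grad()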
+ print("Training finished or interrupted.") def eval(self): print('Start to evaluate on validation datasets...') - stats_meter = self.inference_one_epoch(0,'val') - + stats_meter = self.inference_one_epoch(0, 'val') + for key, value in stats_meter.items(): print(key, value.avg) diff --git a/main.py b/main.py index 0b8bf6b..67a19cb 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,4 @@ -import os, torch, time, shutil, json,glob, argparse, shutil +import os, torch, time, shutil, json, glob, argparse, shutil import numpy as np from easydict import EasyDict as edict @@ -9,15 +9,21 @@ from lib.loss import MetricLoss from configs.models import architectures +import warnings +from sklearn.exceptions import UndefinedMetricWarning + +# Suppress only the UndefinedMetricWarning +warnings.filterwarnings("ignore", category=UndefinedMetricWarning) + from torch import optim from torch import nn -setup_seed(0) +setup_seed(0) if __name__ == '__main__': # load configs parser = argparse.ArgumentParser() - parser.add_argument('config', type=str, help= 'Path to the config file.') + parser.add_argument('config', type=str, help='Path to the config file.') args = parser.parse_args() config = load_config(args.config) config['snapshot_dir'] = 'snapshot/%s' % config['exp_dir'] @@ -33,69 +39,70 @@ open(os.path.join(config.snapshot_dir, 'config.json'), 'w'), indent=4, ) + # Your existing configuration check if config.gpu_mode: - config.device = torch.device('cuda') + # Checks if CUDA is available and then sets to CUDA device, else uses CPU + config.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') else: config.device = torch.device('cpu') - + # backup the files - os.system(f'cp -r models {config.snapshot_dir}') - os.system(f'cp -r datasets {config.snapshot_dir}') - os.system(f'cp -r lib {config.snapshot_dir}') - shutil.copy2('main.py',config.snapshot_dir) - - + shutil.copytree('models', os.path.join(config.snapshot_dir, 'models'), dirs_exist_ok=True) + shutil.copytree('datasets', os.path.join(config.snapshot_dir, 'datasets'), dirs_exist_ok=True) + shutil.copytree('lib', os.path.join(config.snapshot_dir, 'lib'), dirs_exist_ok=True) + shutil.copy2('main.py', config.snapshot_dir) + # model initialization config.architecture = architectures[config.dataset] - config.model = KPFCNN(config) + config.model = KPFCNN(config) - # create optimizer + # create optimizer if config.optimizer == 'SGD': config.optimizer = optim.SGD( - config.model.parameters(), + config.model.parameters(), lr=config.lr, momentum=config.momentum, weight_decay=config.weight_decay, - ) + ) elif config.optimizer == 'ADAM': config.optimizer = optim.Adam( - config.model.parameters(), + config.model.parameters(), lr=config.lr, betas=(0.9, 0.999), weight_decay=config.weight_decay, ) - + # create learning rate scheduler config.scheduler = optim.lr_scheduler.ExponentialLR( config.optimizer, gamma=config.scheduler_gamma, ) - + # create dataset and dataloader train_set, val_set, benchmark_set = get_datasets(config) config.train_loader, neighborhood_limits = get_dataloader(dataset=train_set, - batch_size=config.batch_size, - shuffle=True, - num_workers=config.num_workers, - ) + batch_size=config.batch_size, + shuffle=True, + num_workers=config.num_workers, + ) config.val_loader, _ = get_dataloader(dataset=val_set, - batch_size=config.batch_size, - shuffle=False, - num_workers=1, - neighborhood_limits=neighborhood_limits - ) + batch_size=config.batch_size, + shuffle=False, + num_workers=1, + neighborhood_limits=neighborhood_limits + ) 
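One portability note on the backup block in main.py: shutil.copytree accepts dirs_exist_ok only from Python 3.8 onward, in line with the Python 3.10 toolchain the README now lists. For older interpreters, a rough equivalent (not identical: it replaces an existing backup rather than merging into it):

import os, shutil

def backup_tree(src, dst_root):
    dst = os.path.join(dst_root, os.path.basename(src))
    if os.path.isdir(dst):
        shutil.rmtree(dst)   # emulate dirs_exist_ok on Python < 3.8
    shutil.copytree(src, dst)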
config.test_loader, _ = get_dataloader(dataset=benchmark_set, - batch_size=config.batch_size, - shuffle=False, - num_workers=1, - neighborhood_limits=neighborhood_limits) - + batch_size=config.batch_size, + shuffle=False, + num_workers=1, + neighborhood_limits=neighborhood_limits) + # create evaluation metrics config.desc_loss = MetricLoss(config) trainer = get_trainer(config) - if(config.mode=='train'): + if (config.mode == 'train'): trainer.train() - elif(config.mode =='val'): + elif (config.mode == 'val'): trainer.eval() else: trainer.test() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 865c21c..f511573 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,19 +1,17 @@ -matplotlib==3.3.3 -numpy==1.19.4 -torch == 1.7.1 -torchvision==0.8.2 -torchaudio==0.7.2 -nibabel==3.2.1 -tqdm==4.38.0 -open3d==0.10.0.0 -easydict==1.9 -scipy==1.5.4 -coloredlogs==15.0 -PyYAML==5.4.1 -scikit_learn==0.24.1 -tensorboardX==2.1 -vtk_visualizer==0.9.6 -nibabel==3.2.1 -h5py==3.2.1 -coloredlogs==15.0 -gitpython==3.1.17 \ No newline at end of file +PyYAML==6.0.1 +GitPython==3.1.40 +coloredlogs==15.0.1 +easydict==1.11 +h5py==3.10.0 +matplotlib==3.8.1 +nibabel==5.1.0 +networkx==3.0 +open3d==0.17.0 +scikit-learn==1.3.2 +scipy==1.11.3 +tensorboardX==2.6.2.2 +torch==2.0.1 +torchaudio==2.0.2 +torchvision==0.15.2 +tqdm==4.66.1 +vtk-visualizer==0.9.6 diff --git a/scripts/demo.py b/scripts/demo.py index a3fbc48..80eb756 100644 --- a/scripts/demo.py +++ b/scripts/demo.py @@ -145,7 +145,7 @@ def main(config, demo_loader): config.model.eval() c_loader_iter = demo_loader.__iter__() with torch.no_grad(): - inputs = c_loader_iter.next() + inputs = next(c_loader_iter) ################################## # load inputs to device. for k, v in inputs.items():
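The final demo.py hunk is the same Python 3 fix applied in the testers: iterator objects no longer expose a .next() method, and the next() built-in replaces it. It also takes an optional default that suppresses StopIteration, which is handy at the end of a loader:

it = iter([{'pts': 1}, {'pts': 2}])
first = next(it)         # {'pts': 1}
second = next(it)        # {'pts': 2}
empty = next(it, None)   # None: the default suppresses StopIteration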