Skip to content

Commit a6f7948

Browse files
authored
Add files via upload
1 parent f303699 commit a6f7948

File tree

5 files changed

+242
-35
lines changed

5 files changed

+242
-35
lines changed

script/check_corr.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
import pykitti
2+
import numpy as np
3+
import argparse
4+
import os
5+
import cv2
6+
import open3d as o3d
7+
import copy
8+
# import matplotlib
9+
# matplotlib.use('agg')
10+
from matplotlib import pyplot as plt
11+
from tools import *
12+
13+
os.chdir(os.path.dirname(__file__))
14+
15+
def str2bool(s:str) -> bool:
16+
if s.lower == "false":
17+
return False
18+
else:
19+
return True
20+
21+
def options():
22+
parser = argparse.ArgumentParser()
23+
kitti_parser = parser.add_argument_group()
24+
kitti_parser.add_argument("--base_dir",type=str,default="/data/DATA/data_odometry/dataset/")
25+
kitti_parser.add_argument("--seq",type=int,default=0,choices=[i for i in range(11)])
26+
kitti_parser.add_argument("--index_i",type=int,default=105)
27+
kitti_parser.add_argument("--index_j",type=int,default=107)
28+
29+
io_parser = parser.add_argument_group()
30+
io_parser.add_argument("--Twc_file",type=str,default="../Twc.txt")
31+
io_parser.add_argument("--Twl_file",type=str,default="../Twl.txt")
32+
33+
arg_parser = parser.add_argument_group()
34+
arg_parser.add_argument("--tsl_perturb",type=float,nargs=3,default=[0.1,-0.15,0.1])
35+
arg_parser.add_argument("--rot_perturb",type=float,nargs=3,default=[0.1,0.08,-0.3])
36+
arg_parser.add_argument("--fps_sample",type=int,default=100)
37+
arg_parser.add_argument("--epipolar_threshold",type=float,default=2)
38+
arg_parser.add_argument("--icp_radius",type=float,default=0.6)
39+
arg_parser.add_argument("--corr_threshold",type=float,default=0.05)
40+
arg_parser.add_argument("--view",type=str2bool,default=False)
41+
args = parser.parse_args()
42+
args.seq_id = "%02d"%args.seq
43+
return args
44+
45+
def draw_registration_result(source, target, transformation):
46+
source_temp = copy.deepcopy(source)
47+
target_temp = copy.deepcopy(target)
48+
source_temp.paint_uniform_color([1, 0.706, 0])
49+
target_temp.paint_uniform_color([0, 0.651, 0.929])
50+
source_temp.transform(transformation)
51+
o3d.visualization.draw_geometries([source_temp, target_temp])
52+
53+
54+
def drawlines(img1,img2,lines,pts1,pts2):
55+
''' Copied from: https://docs.opencv.org/4.5.2/da/de9/tutorial_py_epipolar_geometry.html\n
56+
img1 - image on which we draw the epilines for the points in img2
57+
lines - corresponding epilines '''
58+
r,c = img1.shape
59+
img1 = cv2.cvtColor(img1,cv2.COLOR_GRAY2BGR)
60+
img2 = cv2.cvtColor(img2,cv2.COLOR_GRAY2BGR)
61+
for r,pt1,pt2 in zip(lines,pts1,pts2):
62+
color = tuple(np.random.randint(0,255,3).tolist())
63+
x0,y0 = map(int, [0, -r[2]/r[1] ])
64+
x1,y1 = map(int, [c, -(r[2]+r[0]*c)/r[1] ])
65+
img1 = cv2.line(img1, (x0,y0), (x1,y1), color,1)
66+
img1 = cv2.circle(img1,tuple(pt1),5,color,-1)
67+
img2 = cv2.circle(img2,tuple(pt2),5,color,-1)
68+
return img1,img2
69+
70+
def drawcorrpoints(img1,img2,pts1,pts2):
71+
img1 = cv2.cvtColor(img1,cv2.COLOR_GRAY2BGR)
72+
img2 = cv2.cvtColor(img2,cv2.COLOR_GRAY2BGR)
73+
for pt1,pt2 in zip(pts1,pts2):
74+
color = tuple(np.random.randint(0,255,3).tolist())
75+
img1 = cv2.circle(img1,tuple(pt1),5,color,-1)
76+
img2 = cv2.circle(img2,tuple(pt2),5,color,-1)
77+
return img1,img2
78+
79+
80+
if __name__ == "__main__":
81+
args = options()
82+
augT = toMat(args.rot_perturb, args.tsl_perturb)
83+
print("augT:\n{}".format(augT))
84+
Tcw_list = read_pose_file(args.Twc_file)
85+
Tlw_list = read_pose_file(args.Twl_file)
86+
dataStruct = pykitti.odometry(args.base_dir, args.seq_id)
87+
calibStruct = dataStruct.calib
88+
extran = calibStruct.T_cam0_velo # [4,4]
89+
aug_extran = augT @ extran
90+
intran = calibStruct.K_cam0
91+
motion:np.ndarray = dataStruct.poses[args.index_j] @ inv_pose(dataStruct.poses[args.index_i])
92+
print("Motion:\n{}".format(motion))
93+
F = computeF(motion, intran)
94+
print("Fundamental Matrix:\n{}".format(F))
95+
src_pcd_arr = dataStruct.get_velo(args.index_i)[:,:3] # [N, 3]
96+
tgt_pcd_arr = dataStruct.get_velo(args.index_j)[:,:3] # [N, 3]
97+
src_img = np.array(dataStruct.get_cam0(args.index_i)) # [H, W, 3]
98+
tgt_img = np.array(dataStruct.get_cam0(args.index_j)) # [H, W, 3]
99+
img_shape = src_img.shape[:2]
100+
src_pcd = o3d.geometry.PointCloud()
101+
tgt_pcd = o3d.geometry.PointCloud()
102+
src_pcd.points = o3d.utility.Vector3dVector(src_pcd_arr)
103+
tgt_pcd.points = o3d.utility.Vector3dVector(tgt_pcd_arr)
104+
src_pcd.transform(Tlw_list[args.index_i])
105+
tgt_pcd.transform(Tlw_list[args.index_j])
106+
reg_p2p = o3d.pipelines.registration.registration_icp(
107+
src_pcd, tgt_pcd, args.icp_radius, np.eye(4),
108+
o3d.pipelines.registration.TransformationEstimationPointToPoint())
109+
src_pcd.transform(reg_p2p.transformation)
110+
# check correspondences
111+
corr_set = np.array(reg_p2p.correspondence_set)
112+
print("Raw Correspondences:{}".format(corr_set.shape[0]))
113+
src_pcd_corr_transformed = np.array(src_pcd.points)[corr_set[:,0]]
114+
tgt_pcd_corr_transformed = np.array(tgt_pcd.points)[corr_set[:,1]]
115+
corr_rev = np.sum((src_pcd_corr_transformed-tgt_pcd_corr_transformed)**2,axis=1) < args.corr_threshold ** 2 # (N,) bool
116+
corr_set = corr_set[corr_rev]
117+
src_pcd_corr_transformed = src_pcd_corr_transformed[corr_rev]
118+
tgt_pcd_corr_transformed = tgt_pcd_corr_transformed[corr_rev]
119+
# project LiDAR points onto the images
120+
src_pcd_corr = src_pcd_arr[corr_set[:,0]]
121+
tgt_pcd_corr = tgt_pcd_arr[corr_set[:,1]]
122+
print("Selected Correspondences:{}".format(src_pcd_corr.shape[0]))
123+
src_proj_pts, tgt_proj_pts, rev_idx = project_corr_pts_idx(src_pcd_corr, tgt_pcd_corr, extran, intran, img_shape)
124+
print("Projected Points:{}".format(src_proj_pts.shape[0]))
125+
if(args.fps_sample > 0):
126+
src_proj_pts_sampled, tgt_proj_pts_sampled = fps_sample_corr_pts(src_proj_pts, tgt_proj_pts, args.fps_sample) # can be replaced by farthest_point_down_sample of Open3D
127+
proj2src, proj2tgt = drawcorrpoints(src_img, tgt_img, src_proj_pts_sampled, tgt_proj_pts_sampled)
128+
129+
else:
130+
proj2src, proj2tgt = drawcorrpoints(src_img, tgt_img, src_proj_pts, tgt_proj_pts)
131+
E, mask = cv2.findEssentialMat(src_proj_pts, tgt_proj_pts, intran, cv2.RANSAC)
132+
_, R, t, _ = cv2.recoverPose(E, src_proj_pts, tgt_proj_pts, intran, mask=mask)
133+
recover_gt_pose = np.eye(4)
134+
recover_gt_pose[:3,:3] = R
135+
recover_gt_pose[:3,3] = t.flatten()
136+
print("GT Recover Pose:{}".format(recover_gt_pose))
137+
plt.figure(figsize=[12,3.5],dpi=200)
138+
plt.subplot(2,2,1)
139+
plt.imshow(proj2src)
140+
plt.subplot(2,2,2)
141+
plt.imshow(proj2tgt)
142+
src_proj_pts, tgt_proj_pts, rev_idx = project_corr_pts_idx(src_pcd_corr, tgt_pcd_corr, aug_extran, intran, img_shape)
143+
if(args.fps_sample > 0):
144+
src_proj_pts_sampled, tgt_proj_pts_sampled = fps_sample_corr_pts(src_proj_pts, tgt_proj_pts, args.fps_sample) # can be replaced by farthest_point_down_sample of Open3D
145+
augproj2src, augproj2tgt = drawcorrpoints(src_img, tgt_img, src_proj_pts_sampled, tgt_proj_pts_sampled)
146+
147+
else:
148+
augproj2src, augproj2tgt = drawcorrpoints(src_img, tgt_img, src_proj_pts, tgt_proj_pts)
149+
E, mask = cv2.findEssentialMat(src_proj_pts, tgt_proj_pts, intran, cv2.RANSAC)
150+
_, R, t, _ = cv2.recoverPose(E, src_proj_pts, tgt_proj_pts, intran, mask=mask)
151+
recover_pred_pose = np.eye(4)
152+
recover_pred_pose[:3,:3] = R
153+
recover_pred_pose[:3,3] = t.flatten()
154+
print("Augmented Recover Pose:{}".format(recover_pred_pose))
155+
err_pose = inv_pose(recover_pred_pose) @ recover_gt_pose
156+
ervec, etvec = toVec(err_pose)
157+
err_3d_pose = aug_extran @ motion @ inv_pose(aug_extran) @ extran @ inv_pose(motion) @ inv_pose(extran)
158+
e3drvec, e3dtvec = toVec(err_3d_pose)
159+
print("Error:\n{}".format(err_pose))
160+
print("Error rvec:\n{}".format(ervec))
161+
print("Error tvec:\n{}".format(etvec))
162+
print("3D Error:\n{}".format(err_3d_pose))
163+
print("3D Error rvec:\n{}".format(e3drvec))
164+
print("3D Error tvec:\n{}".format(e3dtvec))
165+
166+
plt.subplot(2,2,3)
167+
plt.imshow(augproj2src)
168+
plt.subplot(2,2,4)
169+
plt.imshow(augproj2tgt)
170+
plt.subplots_adjust(hspace=0.4)
171+
if args.view:
172+
plt.show()
173+
else:
174+
plt.savefig("../demo/ep_{:06d}_{:06d}.pdf".format(args.index_i, args.index_j))
175+
if args.view:
176+
draw_registration_result(src_pcd, tgt_pcd, np.eye(4))
177+
178+
179+

script/demo_epipolar.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def options():
2424
kitti_parser.add_argument("--base_dir",type=str,default="/data/DATA/data_odometry/dataset/")
2525
kitti_parser.add_argument("--seq",type=int,default=0,choices=[i for i in range(11)])
2626
kitti_parser.add_argument("--index_i",type=int,default=105)
27-
kitti_parser.add_argument("--index_j",type=int,default=106)
27+
kitti_parser.add_argument("--index_j",type=int,default=108)
2828

2929
io_parser = parser.add_argument_group()
3030
io_parser.add_argument("--Twc_file",type=str,default="../Twc.txt")
@@ -35,9 +35,9 @@ def options():
3535
arg_parser.add_argument("--rot_perturb",type=float,nargs=3,default=[0,0,0])
3636
arg_parser.add_argument("--fps_sample",type=int,default=100)
3737
arg_parser.add_argument("--epipolar_threshold",type=float,default=2)
38-
arg_parser.add_argument("--icp_radius",type=float,default=0.6)
38+
arg_parser.add_argument("--icp_radius",type=float,default=0.1)
3939
arg_parser.add_argument("--corr_threshold",type=float,default=0.05)
40-
arg_parser.add_argument("--view",type=str2bool,default=False)
40+
arg_parser.add_argument("--view",type=str2bool,default=True)
4141
args = parser.parse_args()
4242
args.seq_id = "%02d"%args.seq
4343
return args
@@ -122,17 +122,17 @@ def EpipolarwithoutF(src_proj_pts:np.ndarray, tgt_proj_pts:np.ndarray, threshold
122122
args = options()
123123
augT = toMat(args.rot_perturb, args.tsl_perturb)
124124
print("augT:\n{}".format(augT))
125-
Tcw_list = read_pose_file(args.Twc_file)
126-
Tlw_list = read_pose_file(args.Twl_file)
127125
dataStruct = pykitti.odometry(args.base_dir, args.seq_id)
128126
calibStruct = dataStruct.calib
129127
extran = calibStruct.T_cam0_velo # [4,4]
128+
print("GT TCL:\n{}".format(extran))
129+
print("GT TCL Rvec:{}\ntvec:{}".format(*toVec(extran)))
130130
aug_extran = augT @ extran
131131
intran = calibStruct.K_cam0
132-
motion:np.ndarray = dataStruct.poses[args.index_j] @ inv_pose(dataStruct.poses[args.index_i])
133-
print("Motion:\n{}".format(motion))
134-
F = computeF(motion, intran)
135-
print("Fundamental Matrix:\n{}".format(F))
132+
camera_motion:np.ndarray = dataStruct.poses[args.index_j] @ inv_pose(dataStruct.poses[args.index_i])
133+
print("Motion:\n{}".format(camera_motion))
134+
F = computeF(camera_motion, intran)
135+
print("Motion Rvec:{}\ntvec:{}".format(*toVec(camera_motion)))
136136
src_pcd_arr = dataStruct.get_velo(args.index_i)[:,:3] # [N, 3]
137137
tgt_pcd_arr = dataStruct.get_velo(args.index_j)[:,:3] # [N, 3]
138138
src_img = np.array(dataStruct.get_cam0(args.index_i)) # [H, W, 3]
@@ -142,11 +142,12 @@ def EpipolarwithoutF(src_proj_pts:np.ndarray, tgt_proj_pts:np.ndarray, threshold
142142
tgt_pcd = o3d.geometry.PointCloud()
143143
src_pcd.points = o3d.utility.Vector3dVector(src_pcd_arr)
144144
tgt_pcd.points = o3d.utility.Vector3dVector(tgt_pcd_arr)
145-
src_pcd.transform(Tlw_list[args.index_i])
146-
tgt_pcd.transform(Tlw_list[args.index_j])
145+
src_pcd.transform(inv_pose(extran) @ dataStruct.poses[args.index_i] @ extran)
146+
tgt_pcd.transform(inv_pose(extran) @ dataStruct.poses[args.index_j] @ extran)
147147
reg_p2p = o3d.pipelines.registration.registration_icp(
148148
src_pcd, tgt_pcd, args.icp_radius, np.eye(4),
149149
o3d.pipelines.registration.TransformationEstimationPointToPoint())
150+
print(reg_p2p.transformation)
150151
src_pcd.transform(reg_p2p.transformation)
151152
# check correspondences
152153
corr_set = np.array(reg_p2p.correspondence_set)

0 commit comments

Comments
 (0)