-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathcommon.py
More file actions
214 lines (173 loc) · 7.25 KB
/
common.py
File metadata and controls
214 lines (173 loc) · 7.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import torch
import numpy as np
import logging
# Module-level logger, named after this module via __name__.
logger_py = logging.getLogger(__name__)
def arange_pixels(resolution=(128, 128), batch_size=1, image_range=(-1., 1.),
                  subsample_to=None, invert_y_axis=False):
    ''' Arranges pixels for given resolution in range image_range.

    The function returns the unscaled pixel locations as integers and the
    scaled float values. Points are ordered with the x index varying
    slowest and the y index varying fastest; the last dimension holds
    (x, y).

    Args:
        resolution (tuple): image resolution as (height, width)
        batch_size (int): batch size
        image_range (tuple): (min, max) range of output points
            (default [-1, 1])
        subsample_to (int): if integer and > 0 (and smaller than the number
            of pixels), the points are randomly subsampled to this value;
            the same random subset is used for every batch element
        invert_y_axis (bool): whether to flip the y axis of both outputs;
            only supported for image_range (-1, 1)

    Returns:
        tuple: (pixel_locations, pixel_scaled), an integer tensor and a
            float tensor, both of size batch_size x n_points x 2
    '''
    h, w = resolution
    n_points = h * w

    # Build the integer grid by broadcasting instead of torch.meshgrid; this
    # reproduces the 'ij' layout (x outer, y inner) without the deprecation
    # warning newer PyTorch versions emit for meshgrid without `indexing`.
    x_idx = torch.arange(0, w).view(w, 1).expand(w, h)
    y_idx = torch.arange(0, h).view(1, h).expand(w, h)
    pixel_locations = torch.stack([x_idx, y_idx], dim=-1).long().view(
        1, -1, 2).repeat(batch_size, 1, 1)
    pixel_scaled = pixel_locations.clone().float()

    # Shift and scale points to match image_range. The offset is the lower
    # bound of the range (not -scale / 2), so ranges that are not centred
    # on zero are honoured. max(..., 1) avoids 0 / 0 = NaN for an axis that
    # has a single pixel.
    range_min = image_range[0]
    scale = image_range[1] - range_min
    pixel_scaled[:, :, 0] = (
        scale * pixel_scaled[:, :, 0] / max(w - 1, 1) + range_min)
    pixel_scaled[:, :, 1] = (
        scale * pixel_scaled[:, :, 1] / max(h - 1, 1) + range_min)

    # Subsample points if subsample_to is not None and > 0
    if (subsample_to is not None and subsample_to > 0 and
            subsample_to < n_points):
        idx = np.random.choice(pixel_scaled.shape[1], size=(subsample_to,),
                               replace=False)
        pixel_scaled = pixel_scaled[:, idx]
        pixel_locations = pixel_locations[:, idx]

    if invert_y_axis:
        # Negating y is only a valid flip for a range symmetric around 0.
        # tuple() lets a list such as [-1, 1] pass the check as well.
        assert tuple(image_range) == (-1, 1), \
            'invert_y_axis requires image_range == (-1, 1)'
        pixel_scaled[..., -1] *= -1.
        pixel_locations[..., -1] = (h - 1) - pixel_locations[..., -1]

    return pixel_locations, pixel_scaled
def to_pytorch(tensor, return_type=False):
    ''' Converts input tensor to pytorch.

    The returned tensor is always a copy (clone), so modifying it in place
    does not affect the input.

    Args:
        tensor (tensor): Numpy or Pytorch tensor
        return_type (bool): whether to return input type

    Returns:
        The cloned Pytorch tensor; if return_type is True, a tuple
        (tensor, is_numpy) where is_numpy tells whether the input was a
        Numpy array.
    '''
    # isinstance (rather than an exact type comparison) also recognises
    # ndarray subclasses, which would otherwise reach .clone() and fail.
    is_numpy = isinstance(tensor, np.ndarray)
    if is_numpy:
        tensor = torch.from_numpy(tensor)
    # from_numpy shares memory with the array, so clone to decouple.
    tensor = tensor.clone()
    if return_type:
        return tensor, is_numpy
    return tensor
def transform_to_world(pixels, depth, camera_mat, world_mat, scale_mat=None,
                       invert=True, use_absolute_depth=True):
    ''' Transforms pixel positions p with given depth value d to world coordinates.

    Args:
        pixels (tensor): pixel tensor of size B x N x 2
        depth (tensor): depth tensor of size B x N x 1
        camera_mat (tensor): camera matrix
        world_mat (tensor): world matrix
        scale_mat (tensor): scale matrix; if None, a batch of 4 x 4
            identity matrices is used
        invert (bool): whether to invert matrices (default: true)
        use_absolute_depth (bool): if true, the x and y coordinates are
            multiplied by the absolute depth value and only the z
            coordinate keeps the sign of the depth; if false, x, y and z
            are all multiplied by the signed depth (default: true)

    Returns:
        World points of size B x N x 3; a Numpy array if pixels was given
        as a Numpy array, otherwise a Pytorch tensor.
    '''
    assert(pixels.shape[-1] == 2)

    # Convert to pytorch first, so that the default scale matrix below can
    # take batch size, dtype and device from an actual tensor even when the
    # inputs were passed as Numpy arrays.
    pixels, is_numpy = to_pytorch(pixels, True)
    depth = to_pytorch(depth)
    camera_mat = to_pytorch(camera_mat)
    world_mat = to_pytorch(world_mat)

    if scale_mat is None:
        # Match camera_mat's dtype so the matrix product does not fail on
        # mixed precision (e.g. float64 matrices coming from Numpy).
        scale_mat = torch.eye(
            4, dtype=camera_mat.dtype, device=camera_mat.device
        ).unsqueeze(0).repeat(camera_mat.shape[0], 1, 1)
    else:
        scale_mat = to_pytorch(scale_mat)

    # Invert camera matrices
    if invert:
        camera_mat = torch.inverse(camera_mat)
        world_mat = torch.inverse(world_mat)
        scale_mat = torch.inverse(scale_mat)

    # Transform pixels to homogen coordinates: B x N x 2 -> B x 4 x N with
    # rows (x, y, 1, 1)
    pixels = pixels.permute(0, 2, 1)
    pixels = torch.cat([pixels, torch.ones_like(pixels)], dim=1)

    # Project pixels into camera space
    depth_row = depth.permute(0, 2, 1)
    if use_absolute_depth:
        pixels[:, :2] = pixels[:, :2] * depth_row.abs()
        pixels[:, 2:3] = pixels[:, 2:3] * depth_row
    else:
        pixels[:, :3] = pixels[:, :3] * depth_row

    # Transform pixels to world space
    p_world = scale_mat @ world_mat @ camera_mat @ pixels

    # Transform p_world back to 3D coordinates
    p_world = p_world[:, :3].permute(0, 2, 1)

    if is_numpy:
        p_world = p_world.numpy()
    return p_world
def transform_to_camera_space(p_world, camera_mat, world_mat, scale_mat):
    ''' Maps points given in world coordinates into camera space.

    The points are lifted to homogeneous coordinates, multiplied by
    camera_mat @ world_mat @ scale_mat and reduced to 3D again.

    Args:
        p_world (tensor): world points tensor of size B x N x 3
        camera_mat (tensor): camera matrix
        world_mat (tensor): world matrix
        scale_mat (tensor): scale matrix

    Returns:
        tensor: camera space points of size B x N x 3
    '''
    n_batch, n_pts, _ = p_world.shape

    # Append a homogeneous 1 to every point and move the coordinate axis in
    # front of the point axis: B x N x 3 -> B x 4 x N
    hom_ones = torch.ones(n_batch, n_pts, 1).to(p_world.device)
    p_hom = torch.cat([p_world, hom_ones], dim=-1).permute(0, 2, 1)

    # Compose the world-to-camera transform and apply it to all points
    world_to_cam = camera_mat @ world_mat @ scale_mat
    p_cam_hom = world_to_cam @ p_hom

    # Drop the homogeneous row and restore the B x N x 3 layout
    return p_cam_hom[:, :3].permute(0, 2, 1)
def origin_to_world(n_points, camera_mat, world_mat, scale_mat=None,
                    invert=False):
    ''' Transforms origin (camera location) to world coordinates.

    Args:
        n_points (int): how often the transformed origin is repeated in the
            form (batch_size, n_points, 3)
        camera_mat (tensor): camera matrix
        world_mat (tensor): world matrix
        scale_mat (tensor): scale matrix; if None, a batch of 4 x 4
            identity matrices is used
        invert (bool): whether to invert the matrices (default: false)

    Returns:
        tensor: transformed origin of size batch_size x n_points x 3
    '''
    batch_size = camera_mat.shape[0]
    device = camera_mat.device
    # Follow camera_mat's dtype so that e.g. float64 matrices do not fail
    # in the matrix product against a hard-coded float32 origin.
    dtype = camera_mat.dtype

    # Create origin in homogen coordinates: (0, 0, 0, 1) for every point
    p = torch.zeros(batch_size, 4, n_points, dtype=dtype, device=device)
    p[:, -1] = 1.

    if scale_mat is None:
        scale_mat = torch.eye(4, dtype=dtype, device=device).unsqueeze(
            0).repeat(batch_size, 1, 1)

    # Invert matrices
    if invert:
        camera_mat = torch.inverse(camera_mat)
        world_mat = torch.inverse(world_mat)
        scale_mat = torch.inverse(scale_mat)

    # Apply transformation
    p_world = scale_mat @ world_mat @ camera_mat @ p

    # Transform points back to 3D coordinates
    p_world = p_world[:, :3].permute(0, 2, 1)
    return p_world
def image_points_to_world(image_points, camera_mat, world_mat, scale_mat=None,
                          invert=False, negative_depth=True):
    ''' Lifts points on the image plane to world coordinates.

    Unlike transform_to_world, no depth tensor has to be supplied: every
    point is given a constant depth of magnitude 1, and the result of
    transform_to_world for that depth is returned.

    Args:
        image_points (tensor): image points tensor of size B x N x 2
        camera_mat (tensor): camera matrix
        world_mat (tensor): world matrix
        scale_mat (tensor): scale matrix
        invert (bool): whether to invert matrices (default: false)
        negative_depth (bool): if true the constant depth is -1, otherwise
            +1 (default: true)
    '''
    batch_size, n_pts, dim = image_points.shape
    assert(dim == 2)

    # Constant depth for every image point, on the same device as the points
    depth_value = -1. if negative_depth else 1.
    d_image = torch.full((batch_size, n_pts, 1), depth_value).to(
        image_points.device)

    return transform_to_world(image_points, d_image, camera_mat, world_mat,
                              scale_mat, invert=invert)
def interpolate_sphere(z1, z2, t):
    ''' Spherical linear interpolation between z1 and z2.

    The angle omega between z1 and z2 is computed along the last dimension
    and the result is
        sin((1 - t) * omega) / sin(omega) * z1
        + sin(t * omega) / sin(omega) * z2.
    If sin(omega) is (numerically) zero, e.g. for identical inputs, the
    weights fall back to plain linear interpolation (1 - t, t) instead of
    producing NaN from a 0 / 0 division.

    Args:
        z1 (tensor): start vectors; the last dimension is the vector
            dimension
        z2 (tensor): end vectors of the same shape as z1
        t (float or tensor): interpolation weight; 0 yields z1 and 1
            yields z2

    Returns:
        tensor: interpolated vectors
    '''
    # Threshold below which sin(omega) is treated as zero
    eps = 1e-6

    # Cosine of the angle between z1 and z2
    cos_omega = (z1 * z2).sum(dim=-1, keepdim=True)
    cos_omega = cos_omega / z1.pow(2).sum(dim=-1, keepdim=True).sqrt()
    cos_omega = cos_omega / z2.pow(2).sum(dim=-1, keepdim=True).sqrt()
    # Rounding can push the cosine slightly outside [-1, 1], for which acos
    # returns NaN.
    cos_omega = cos_omega.clamp(-1., 1.)

    omega = torch.acos(cos_omega)
    sin_omega = torch.sin(omega)

    # Replace a vanishing denominator by 1 so the division stays finite; the
    # affected entries are overwritten with the linear weights below.
    degenerate = sin_omega.abs() < eps
    safe_sin = torch.where(
        degenerate, torch.ones_like(sin_omega), sin_omega)

    slerp_s1 = torch.sin((1 - t) * omega) / safe_sin
    slerp_s2 = torch.sin(t * omega) / safe_sin
    linear_s1 = torch.ones_like(sin_omega) * (1 - t)
    linear_s2 = torch.ones_like(sin_omega) * t

    s1 = torch.where(degenerate, linear_s1, slerp_s1)
    s2 = torch.where(degenerate, linear_s2, slerp_s2)

    z = s1 * z1 + s2 * z2
    return z