Skip to content

Commit fb00595

Browse files
committed
enable camera struct tests
1 parent 78fc92b commit fb00595

File tree

1 file changed

+50
-37
lines changed

1 file changed

+50
-37
lines changed

test/differentiable/test_structs.py

Lines changed: 50 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,3 @@
1-
import os
2-
import sys
3-
4-
sys.path.append(
5-
os.path.dirname((os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
6-
)
7-
81
import numpy as np
92
import pytest
103
import torch
@@ -63,7 +56,6 @@ def test_torch_mesh_container() -> None:
6356
raise ValueError("Expected target reduction")
6457

6558

66-
@pytest.mark.skip(reason="To be fixed.")
6759
def test_batched_camera() -> None:
6860
device = "cuda" if torch.cuda.is_available() else "cpu"
6961
resolution = (480, 640)
@@ -75,12 +67,18 @@ def test_batched_camera() -> None:
7567
device=device,
7668
)
7769

78-
if camera.intrinsics.shape != (3, 3):
79-
raise ValueError(f"Expected shape (3, 3), got {camera.intrinsics.shape}")
80-
if camera.extrinsics.shape != (4, 4):
81-
raise ValueError(f"Expected shape (4, 4), got {camera.extrinsics.shape}")
82-
if camera.ht_optical.shape != (4, 4):
83-
raise ValueError(f"Expected shape (4, 4), got {camera.ht_optical.shape}")
70+
assert camera.intrinsics.shape == (
71+
3,
72+
3,
73+
), f"Expected shape (3, 3), got {camera.intrinsics.shape}"
74+
assert camera.extrinsics.shape == (
75+
4,
76+
4,
77+
), f"Expected shape (4, 4), got {camera.extrinsics.shape}"
78+
assert camera.ht_optical.shape == (
79+
4,
80+
4,
81+
), f"Expected shape (4, 4), got {camera.ht_optical.shape}"
8482

8583
# construct invalid dim
8684
try:
@@ -90,8 +88,8 @@ def test_batched_camera() -> None:
9088
extrinsics=torch.zeros(3, 3, device=device),
9189
device=device,
9290
)
93-
except ValueError as e:
94-
print(f"Expected and got error: {e}")
91+
except ValueError:
92+
pass
9593
else:
9694
raise ValueError("Expected ValueError")
9795

@@ -104,8 +102,8 @@ def test_batched_camera() -> None:
104102
extrinsics=extrinsics,
105103
device=device,
106104
)
107-
except ValueError as e:
108-
print(f"Expected and got error: {e}")
105+
except ValueError:
106+
pass
109107
else:
110108
raise ValueError("Expected ValueError")
111109

@@ -120,12 +118,21 @@ def test_batched_camera() -> None:
120118
device=device,
121119
)
122120

123-
if camera.intrinsics.shape != (1, 3, 3):
124-
raise ValueError(f"Expected shape (1, 3, 3), got {camera.intrinsics.shape}")
125-
if camera.extrinsics.shape != (1, 4, 4):
126-
raise ValueError(f"Expected shape (1, 4, 4), got {camera.extrinsics.shape}")
127-
if camera.ht_optical.shape != (1, 4, 4):
128-
raise ValueError(f"Expected shape (1, 4, 4), got {camera.ht_optical.shape}")
121+
assert camera.intrinsics.shape == (
122+
1,
123+
3,
124+
3,
125+
), f"Expected shape (1, 3, 3), got {camera.intrinsics.shape}"
126+
assert camera.extrinsics.shape == (
127+
1,
128+
4,
129+
4,
130+
), f"Expected shape (1, 4, 4), got {camera.extrinsics.shape}"
131+
assert camera.ht_optical.shape == (
132+
1,
133+
4,
134+
4,
135+
), f"Expected shape (1, 4, 4), got {camera.ht_optical.shape}"
129136

130137
# take a point in homogeneous coordinates of shape Bx4xN and project it
131138
# using the camera extrinsics / intrinsics
@@ -136,18 +143,17 @@ def test_batched_camera() -> None:
136143

137144
# project the point
138145
p_prime = camera.extrinsics @ p
139-
if p_prime.shape != shape:
140-
raise ValueError(f"Expected shape {shape}, got {p_prime.shape}")
146+
assert p_prime.shape == shape, f"Expected shape {shape}, got {p_prime.shape}"
141147

142148
p_prime = p_prime[:, :3, :] / p_prime[:, 3, :] # to homogeneous coordinates
143149

144150
projected_shape = (batch_size, 3, samples)
145151
p_prime = camera.intrinsics @ p_prime
146-
if p_prime.shape != projected_shape:
147-
raise ValueError(f"Expected shape {projected_shape}, got {p_prime.shape}")
152+
assert (
153+
p_prime.shape == projected_shape
154+
), f"Expected shape {projected_shape}, got {p_prime.shape}"
148155

149156

150-
@pytest.mark.skip(reason="To be fixed.")
151157
def test_batched_virtual_camera() -> None:
152158
device = "cuda" if torch.cuda.is_available() else "cpu"
153159
resolution = (480, 640)
@@ -158,10 +164,10 @@ def test_batched_virtual_camera() -> None:
158164
device=device,
159165
)
160166

161-
if virtual_camera.perspective_projection.shape != (4, 4):
162-
raise ValueError(
163-
f"Expected shape (3, 4), got {virtual_camera.perspective_projection.shape}"
164-
)
167+
assert virtual_camera.perspective_projection.shape == (
168+
4,
169+
4,
170+
), f"Expected shape (3, 4), got {virtual_camera.perspective_projection.shape}"
165171

166172
# construct with batch size
167173
intrinsics = torch.zeros(1, 3, 3, device=device)
@@ -171,13 +177,20 @@ def test_batched_virtual_camera() -> None:
171177
device=device,
172178
)
173179

174-
if virtual_camera.perspective_projection.shape != (1, 4, 4):
175-
raise ValueError(
176-
f"Expected shape (1, 4, 4), got {virtual_camera.perspective_projection.shape}"
177-
)
180+
assert virtual_camera.perspective_projection.shape == (
181+
1,
182+
4,
183+
4,
184+
), f"Expected shape (1, 4, 4), got {virtual_camera.perspective_projection.shape}"
178185

179186

180187
if __name__ == "__main__":
188+
import os
189+
import sys
190+
191+
os.environ["QT_QPA_PLATFORM"] = "offscreen"
192+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
193+
181194
test_torch_mesh_container()
182195
test_batched_camera()
183196
test_batched_virtual_camera()

0 commit comments

Comments
 (0)