Skip to content

Commit 462906c

Browse files
committed
update imu test
1 parent a876fdf commit 462906c

File tree

4 files changed

+259
-200
lines changed

4 files changed

+259
-200
lines changed

genesis/engine/sensors/base_sensor.py

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
from genesis.utils.misc import concat_with_tensor, make_tensor_field
1313

1414
if TYPE_CHECKING:
15-
from genesis.engine.solvers import RigidSolver
1615
from genesis.engine.entities.rigid_entity.rigid_link import RigidLink
16+
from genesis.engine.solvers import RigidSolver
1717
from genesis.recorders.base_recorder import Recorder, RecorderOptions
1818
from genesis.utils.ring_buffer import TensorRingBuffer
1919
from genesis.vis.rasterizer_context import RasterizerContext
@@ -292,8 +292,38 @@ def _get_formatted_data(self, tensor: torch.Tensor, envs_idx=None) -> torch.Tens
292292
return self._return_data_class(*return_values)
293293

294294
def _sanitize_envs_idx(self, envs_idx) -> torch.Tensor:
295+
if self._manager._sim.n_envs == 0:
296+
return torch.tensor([0], device=gs.device, dtype=gs.tc_int)
295297
return self._manager._sim._scene._sanitize_envs_idx(envs_idx)
296298

299+
def _set_metadata_field(self, input, field, field_size, envs_idx):
300+
envs_idx = self._sanitize_envs_idx(envs_idx)
301+
if field.ndim == 2:
302+
# flat field structure
303+
idx = self._idx * field_size
304+
index_slice = slice(idx, idx + field_size)
305+
else:
306+
# per sensor field structure
307+
index_slice = self._idx
308+
# field[envs_idx, index_slice] = self._sanitize_for_metadata_tensor(
309+
field[:, index_slice] = self._sanitize_for_metadata_tensor(
310+
input, shape=(len(envs_idx), field_size), dtype=field.dtype
311+
)
312+
313+
def _sanitize_for_metadata_tensor(self, input, shape, dtype) -> torch.Tensor:
314+
if not isinstance(input, Sequence):
315+
input = [input]
316+
tensor_input = torch.tensor(input, dtype=dtype, device=gs.device)
317+
if tensor_input.ndim == len(shape) - 1:
318+
# Batch dimension is missing
319+
tensor_input = tensor_input.unsqueeze(0)
320+
if tensor_input.shape[0] != shape[0]:
321+
tensor_input = tensor_input.expand((shape[0], *tensor_input.shape[1:]))
322+
assert (
323+
tensor_input.shape == shape
324+
), f"Input shape {tensor_input.shape} for setting sensor metadata does not match shape {shape}"
325+
return tensor_input
326+
297327

298328
@dataclass
299329
class RigidSensorMetadataMixin:
@@ -345,6 +375,16 @@ def build(self):
345375
dim=1,
346376
)
347377

378+
@gs.assert_built
379+
def set_pos_offset(self, pos_offset, envs_idx=None):
380+
envs_idx = self._sanitize_envs_idx(envs_idx)
381+
self._set_metadata_field(pos_offset, self._shared_metadata.offsets_pos, field_size=3, envs_idx=envs_idx)
382+
383+
@gs.assert_built
384+
def set_quat_offset(self, quat_offset, envs_idx=None):
385+
envs_idx = self._sanitize_envs_idx(envs_idx)
386+
self._set_metadata_field(quat_offset, self._shared_metadata.offsets_quat, field_size=4, envs_idx=envs_idx)
387+
348388

349389
@dataclass
350390
class NoisySensorMetadataMixin:
@@ -372,25 +412,6 @@ class NoisySensorMixin(Generic[NoisySensorMetadataMixinT]):
372412
Base sensor class for analog sensors that are attached to a RigidEntity.
373413
"""
374414

375-
def _set_metadata_field(self, input, field, field_size, envs_idx):
376-
envs_idx = self._sanitize_envs_idx(envs_idx)
377-
idx = self._idx * field_size
378-
field[envs_idx, idx : idx + field_size] = self._sanitize_for_metadata_tensor(
379-
input, shape=(len(envs_idx), field_size), dtype=field.dtype
380-
)
381-
382-
def _sanitize_for_metadata_tensor(self, input, shape, dtype) -> torch.Tensor:
383-
if not isinstance(input, Sequence):
384-
input = [input]
385-
tensor_input = torch.tensor(input, dtype=dtype, device=gs.device)
386-
if tensor_input.ndim == len(shape) - 1:
387-
# Batch dimension is missing
388-
tensor_input = tensor_input.unsqueeze(0).expand((shape[0], *tensor_input.shape))
389-
assert (
390-
tensor_input.shape == shape
391-
), f"Input shape {tensor_input.shape} for setting sensor metadata does not match shape {shape}"
392-
return tensor_input
393-
394415
@gs.assert_built
395416
def set_resolution(self, resolution, envs_idx=None):
396417
self._set_metadata_field(resolution, self._shared_metadata.resolution, self._cache_size, envs_idx)

genesis/engine/sensors/imu.py

Lines changed: 24 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
import torch
77

88
import genesis as gs
9-
from genesis.options.sensors import (
10-
MaybeMatrix3x3Type,
11-
IMU as IMUOptions,
9+
from genesis.options.sensors import IMU as IMUOptions
10+
from genesis.options.sensors import MaybeMatrix3x3Type
11+
from genesis.utils.geom import (
12+
inv_transform_by_quat,
13+
transform_by_quat,
14+
transform_quat_by_quat,
1215
)
13-
from genesis.utils.geom import inv_transform_by_trans_quat, transform_by_quat, transform_quat_by_quat
1416
from genesis.utils.misc import concat_with_tensor, make_tensor_field, tensor_to_array
1517

1618
from .base_sensor import (
@@ -30,20 +32,6 @@
3032
from genesis.vis.rasterizer_context import RasterizerContext
3133

3234

33-
def _view_metadata_as_acc_gyro(metadata_tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
34-
"""
35-
Get views of the metadata tensor (B, n_imus * 6) as a tuple of acc and gyro metadata tensors (B, n_imus * 3).
36-
"""
37-
batch_shape, n_data = metadata_tensor.shape[:-1], metadata_tensor.shape[-1]
38-
n_imus = n_data // 6
39-
metadata_tensor_per_sensor = metadata_tensor.reshape((*batch_shape, n_imus, 2, 3))
40-
41-
return (
42-
metadata_tensor_per_sensor[..., 0, :].reshape(*batch_shape, n_imus * 3),
43-
metadata_tensor_per_sensor[..., 1, :].reshape(*batch_shape, n_imus * 3),
44-
)
45-
46-
4735
def _get_skew_to_alignment_matrix(input: MaybeMatrix3x3Type, out: torch.Tensor | None = None) -> torch.Tensor:
4836
"""
4937
Convert the alignment input to a matrix. Modifies in place if provided, else allocate a new matrix.
@@ -120,30 +108,6 @@ def set_gyro_axes_skew(self, axes_skew: MaybeMatrix3x3Type, envs_idx=None):
120108
rot_matrix = _get_skew_to_alignment_matrix(axes_skew)
121109
self._shared_metadata.alignment_rot_matrix[envs_idx, self._idx * 2 + 1, :, :] = rot_matrix
122110

123-
@gs.assert_built
124-
def set_acc_bias(self, bias, envs_idx=None):
125-
self._set_metadata_field(bias, self._shared_metadata.acc_bias, field_size=3, envs_idx=envs_idx)
126-
127-
@gs.assert_built
128-
def set_gyro_bias(self, bias, envs_idx=None):
129-
self._set_metadata_field(bias, self._shared_metadata.gyro_bias, field_size=3, envs_idx=envs_idx)
130-
131-
@gs.assert_built
132-
def set_acc_random_walk(self, random_walk, envs_idx=None):
133-
self._set_metadata_field(random_walk, self._shared_metadata.acc_random_walk, field_size=3, envs_idx=envs_idx)
134-
135-
@gs.assert_built
136-
def set_gyro_random_walk(self, random_walk, envs_idx=None):
137-
self._set_metadata_field(random_walk, self._shared_metadata.gyro_random_walk, field_size=3, envs_idx=envs_idx)
138-
139-
@gs.assert_built
140-
def set_acc_noise(self, noise, envs_idx=None):
141-
self._set_metadata_field(noise, self._shared_metadata.acc_noise, field_size=3, envs_idx=envs_idx)
142-
143-
@gs.assert_built
144-
def set_gyro_noise(self, noise, envs_idx=None):
145-
self._set_metadata_field(noise, self._shared_metadata.gyro_noise, field_size=3, envs_idx=envs_idx)
146-
147111
# ================================ internal methods ================================
148112

149113
def build(self):
@@ -160,15 +124,6 @@ def build(self):
160124
self._options.noise = _to_tuple(self._options.acc_noise, self._options.gyro_noise, length_per_value=3)
161125
super().build() # set all shared metadata from RigidSensorBase and NoisySensorBase
162126

163-
self._shared_metadata.acc_bias, self._shared_metadata.gyro_bias = _view_metadata_as_acc_gyro(
164-
self._shared_metadata.bias
165-
)
166-
self._shared_metadata.acc_random_walk, self._shared_metadata.gyro_random_walk = _view_metadata_as_acc_gyro(
167-
self._shared_metadata.random_walk
168-
)
169-
self._shared_metadata.acc_noise, self._shared_metadata.gyro_noise = _view_metadata_as_acc_gyro(
170-
self._shared_metadata.noise
171-
)
172127
self._shared_metadata.alignment_rot_matrix = concat_with_tensor(
173128
self._shared_metadata.alignment_rot_matrix,
174129
torch.stack(
@@ -203,15 +158,30 @@ def _update_shared_ground_truth_cache(
203158
quats = shared_metadata.solver.get_links_quat(links_idx=shared_metadata.links_idx)
204159
acc = shared_metadata.solver.get_links_acc(links_idx=shared_metadata.links_idx)
205160
ang = shared_metadata.solver.get_links_ang(links_idx=shared_metadata.links_idx)
161+
if acc.ndim == 2:
162+
acc = acc.unsqueeze(0)
163+
ang = ang.unsqueeze(0)
206164

207165
offset_quats = transform_quat_by_quat(quats, shared_metadata.offsets_quat)
208166

167+
# additional acceleration if offset: a_imu = a_link + α × r + ω × (ω × r)
168+
if torch.any(torch.abs(shared_metadata.offsets_pos) > gs.EPS):
169+
ang_acc = shared_metadata.solver.get_links_acc_ang(links_idx=shared_metadata.links_idx)
170+
if ang_acc.ndim == 2:
171+
ang_acc = ang_acc.unsqueeze(0)
172+
offset_pos_world = transform_by_quat(shared_metadata.offsets_pos, quats)
173+
tangential_acc = torch.cross(ang_acc, offset_pos_world, dim=-1)
174+
centripetal_acc = torch.cross(ang, torch.cross(ang, offset_pos_world, dim=-1), dim=-1)
175+
acc += tangential_acc + centripetal_acc
176+
209177
# acc/ang shape: (B, n_imus, 3)
210-
local_acc = inv_transform_by_trans_quat(acc, shared_metadata.offsets_pos, offset_quats)
211-
local_ang = inv_transform_by_trans_quat(ang, shared_metadata.offsets_pos, offset_quats)
178+
local_acc = inv_transform_by_quat(acc, offset_quats)
179+
local_ang = inv_transform_by_quat(ang, offset_quats)
212180

213181
*batch_size, n_imus, _ = local_acc.shape
214-
local_acc = local_acc - gravity.unsqueeze(-2).expand((*batch_size, n_imus, -1))
182+
local_acc = local_acc - inv_transform_by_quat(
183+
gravity.unsqueeze(-2).expand((*batch_size, n_imus, -1)), offset_quats
184+
)
215185

216186
# cache shape: (B, n_imus * 6)
217187
strided_ground_truth_cache = shared_ground_truth_cache.reshape((*batch_size, n_imus, 2, 3))

tests/test_imu.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
import numpy as np
2+
import torch
3+
from utils import assert_allclose, assert_array_equal
4+
5+
import genesis as gs
6+
import genesis.utils.geom as gu
7+
8+
9+
def expand_batch_dim(values: tuple[float, ...], n_envs: int) -> tuple[float, ...] | np.ndarray:
10+
"""Helper function to expand expected values for n_envs dimension."""
11+
if n_envs == 0:
12+
return values
13+
return np.tile(np.array(values), (n_envs,) + (1,) * len(values))
14+
15+
16+
def test_imu_sensor(show_viewer, tol, n_envs):
17+
"""Test if the IMU sensor returns the correct data."""
18+
GRAVITY = -10.0
19+
DT = 1e-2
20+
BIAS = (0.1, 0.2, 0.3)
21+
DELAY_STEPS = 2
22+
23+
scene = gs.Scene(
24+
sim_options=gs.options.SimOptions(
25+
dt=DT,
26+
substeps=1,
27+
gravity=(0.0, 0.0, GRAVITY),
28+
),
29+
profiling_options=gs.options.ProfilingOptions(show_FPS=False),
30+
show_viewer=show_viewer,
31+
)
32+
33+
scene.add_entity(gs.morphs.Plane())
34+
35+
box = scene.add_entity(
36+
morph=gs.morphs.Box(
37+
size=(0.1, 0.1, 0.1),
38+
pos=(0.0, 0.0, 0.2),
39+
),
40+
)
41+
42+
imu = scene.add_sensor(
43+
gs.sensors.IMU(
44+
entity_idx=box.idx,
45+
)
46+
)
47+
imu_delayed = scene.add_sensor(
48+
gs.sensors.IMU(
49+
entity_idx=box.idx,
50+
delay=DT * DELAY_STEPS,
51+
)
52+
)
53+
imu_noisy = scene.add_sensor(
54+
gs.sensors.IMU(
55+
entity_idx=box.idx,
56+
acc_axes_skew=0.01,
57+
gyro_axes_skew=(0.02, 0.03, 0.04),
58+
acc_noise=(0.01, 0.01, 0.01),
59+
gyro_noise=(0.01, 0.01, 0.01),
60+
acc_random_walk=(0.001, 0.001, 0.001),
61+
gyro_random_walk=(0.001, 0.001, 0.001),
62+
delay=DT,
63+
jitter=DT * 0.1,
64+
interpolate=True,
65+
)
66+
)
67+
68+
scene.build(n_envs=n_envs)
69+
70+
# box is in freefall
71+
for _ in range(10):
72+
scene.step()
73+
74+
# IMU should calculate "classical linear acceleration" using the local frame without accounting for gravity
75+
# acc_classical_lin_z = - theta_dot ** 2 - cos(theta) * g
76+
assert_allclose(imu.read().lin_acc, 0.0, tol=tol)
77+
assert_allclose(imu.read().ang_vel, 0.0, tol=tol)
78+
assert_allclose(imu_noisy.read().lin_acc, 0.0, tol=1e-1)
79+
assert_allclose(imu_noisy.read().ang_vel, 0.0, tol=1e-1)
80+
81+
# shift COM to induce angular velocity
82+
com_shift = torch.tensor([[0.05, 0.05, 0.05]])
83+
box.set_COM_shift(com_shift.expand((n_envs, 1, 3)) if n_envs > 0 else com_shift)
84+
85+
# update noise and bias for accelerometer and gyroscope
86+
imu_noisy.set_noise((0.01, 0.01, 0.01, 0.02, 0.02, 0.02))
87+
imu_noisy.set_bias((0.01, 0.01, 0.01, 0.02, 0.02, 0.02))
88+
imu_noisy.set_jitter(0.001)
89+
90+
for _ in range(10 - DELAY_STEPS):
91+
scene.step()
92+
93+
true_imu_delayed_reading = imu_delayed.read_ground_truth()
94+
95+
for _ in range(DELAY_STEPS):
96+
scene.step()
97+
98+
assert_array_equal(imu_delayed.read().lin_acc, true_imu_delayed_reading.lin_acc)
99+
assert_array_equal(imu_delayed.read().ang_vel, true_imu_delayed_reading.ang_vel)
100+
101+
# check that position offset affects linear acceleration
102+
imu.set_pos_offset((0.5, 0.0, 0.0))
103+
lin_acc_no_offset = imu.read().lin_acc
104+
scene.step()
105+
lin_acc_with_offset = imu.read().lin_acc
106+
assert not np.allclose(lin_acc_no_offset, lin_acc_with_offset, atol=0.2)
107+
imu.set_pos_offset((0.0, 0.0, 0.0))
108+
109+
# let box collide with ground
110+
for _ in range(20):
111+
scene.step()
112+
113+
assert_array_equal(imu.read_ground_truth().lin_acc, imu_delayed.read_ground_truth().lin_acc)
114+
assert_array_equal(imu.read_ground_truth().ang_vel, imu_delayed.read_ground_truth().ang_vel)
115+
116+
with np.testing.assert_raises(AssertionError, msg="Angular velocity should not be zero due to COM shift"):
117+
assert_allclose(imu.read_ground_truth().ang_vel, 0.0, tol=tol)
118+
119+
with np.testing.assert_raises(AssertionError, msg="Delayed data should not be equal to the ground truth data"):
120+
assert_array_equal(imu_delayed.read().lin_acc - imu_delayed.read_ground_truth().lin_acc, 0.0)
121+
122+
zero_com_shift = torch.tensor([[0.0, 0.0, 0.0]])
123+
box.set_COM_shift(zero_com_shift.expand((n_envs, 1, 3)) if n_envs > 0 else zero_com_shift)
124+
quat_tensor = torch.tensor([0.0, 0.0, 0.0, 1.0])
125+
box.set_quat(quat_tensor.expand((n_envs, 4)) if n_envs > 0 else quat_tensor)
126+
127+
# box is stationary on ground
128+
for _ in range(80):
129+
scene.step()
130+
131+
assert_allclose(
132+
imu.read().lin_acc,
133+
expand_batch_dim((0.0, 0.0, -GRAVITY), n_envs),
134+
tol=5e-6,
135+
)
136+
assert_allclose(imu.read().ang_vel, expand_batch_dim((0.0, 0.0, 0.0), n_envs), tol=1e-5)
137+
138+
# rotate IMU 90 deg around x axis means gravity should be along -y axis
139+
imu.set_quat_offset(gu.euler_to_quat((90.0, 0.0, 0.0)))
140+
imu.set_acc_axes_skew((0.0, 1.0, 0.0))
141+
scene.step()
142+
assert_allclose(imu.read().lin_acc, GRAVITY, tol=5e-6)
143+
imu.set_quat_offset((0.0, 0.0, 0.0, 1.0))
144+
imu.set_acc_axes_skew((0.0, 0.0, 0.0))
145+
146+
scene.reset()
147+
148+
assert_allclose(imu.read().lin_acc, 0.0, tol=gs.EPS) # biased, but cache hasn't been updated yet
149+
assert_allclose(imu_delayed.read().lin_acc, 0.0, tol=gs.EPS)
150+
assert_allclose(imu_noisy.read().ang_vel, 0.0, tol=gs.EPS)
151+
152+
imu.set_bias(BIAS + (0.0, 0.0, 0.0))
153+
scene.step()
154+
assert_allclose(imu.read().lin_acc, expand_batch_dim(BIAS, n_envs), tol=tol)
155+
156+
157+
if __name__ == "__main__":
158+
gs.init(backend=gs.cpu)
159+
# test_imu_sensor(show_viewer=False, tol=1e-4, n_envs=0)
160+
test_imu_sensor(show_viewer=False, tol=1e-4, n_envs=2)

0 commit comments

Comments
 (0)