Skip to content

Commit f7a45a1

Browse files
authored
[BUG FIX] Fix sensor IMU accelerometer signal. (#1962)
1 parent a8cfe73 commit f7a45a1

File tree

5 files changed

+129
-131
lines changed

5 files changed

+129
-131
lines changed

examples/sensors/imu_franka.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ def main():
5454
link_idx_local=end_effector.idx_local,
5555
pos_offset=(0.0, 0.0, 0.15),
5656
# noise parameters
57-
acc_axes_skew=(0.0, 0.01, 0.02),
58-
gyro_axes_skew=(0.03, 0.04, 0.05),
57+
acc_cross_axis_coupling=(0.0, 0.01, 0.02),
58+
gyro_cross_axis_coupling=(0.03, 0.04, 0.05),
5959
acc_noise=(0.01, 0.01, 0.01),
6060
gyro_noise=(0.01, 0.01, 0.01),
6161
acc_random_walk=(0.001, 0.001, 0.001),

genesis/engine/sensors/base_sensor.py

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
from genesis.utils.misc import concat_with_tensor, make_tensor_field
1313

1414
if TYPE_CHECKING:
15-
from genesis.engine.solvers import RigidSolver
1615
from genesis.engine.entities.rigid_entity.rigid_link import RigidLink
16+
from genesis.engine.solvers import RigidSolver
1717
from genesis.recorders.base_recorder import Recorder, RecorderOptions
1818
from genesis.utils.ring_buffer import TensorRingBuffer
1919
from genesis.vis.rasterizer_context import RasterizerContext
@@ -292,8 +292,37 @@ def _get_formatted_data(self, tensor: torch.Tensor, envs_idx=None) -> torch.Tens
292292
return self._return_data_class(*return_values)
293293

294294
def _sanitize_envs_idx(self, envs_idx) -> torch.Tensor:
295+
if self._manager._sim.n_envs == 0:
296+
return torch.tensor([0], device=gs.device, dtype=gs.tc_int)
295297
return self._manager._sim._scene._sanitize_envs_idx(envs_idx)
296298

299+
def _set_metadata_field(self, input, field, field_size, envs_idx):
300+
envs_idx = self._sanitize_envs_idx(envs_idx)
301+
if field.ndim == 2:
302+
# flat field structure
303+
idx = self._idx * field_size
304+
index_slice = slice(idx, idx + field_size)
305+
else:
306+
# per sensor field structure
307+
index_slice = self._idx
308+
field[:, index_slice] = self._sanitize_for_metadata_tensor(
309+
input, shape=(len(envs_idx), field_size), dtype=field.dtype
310+
)
311+
312+
def _sanitize_for_metadata_tensor(self, input, shape, dtype) -> torch.Tensor:
313+
if not isinstance(input, Sequence):
314+
input = [input]
315+
tensor_input = torch.tensor(input, dtype=dtype, device=gs.device)
316+
if tensor_input.ndim == len(shape) - 1:
317+
# Batch dimension is missing
318+
tensor_input = tensor_input.unsqueeze(0)
319+
if tensor_input.shape[0] != shape[0]:
320+
tensor_input = tensor_input.expand((shape[0], *tensor_input.shape[1:]))
321+
assert (
322+
tensor_input.shape == shape
323+
), f"Input shape {tensor_input.shape} for setting sensor metadata does not match shape {shape}"
324+
return tensor_input
325+
297326

298327
@dataclass
299328
class RigidSensorMetadataMixin:
@@ -345,6 +374,16 @@ def build(self):
345374
dim=1,
346375
)
347376

377+
@gs.assert_built
378+
def set_pos_offset(self, pos_offset, envs_idx=None):
379+
envs_idx = self._sanitize_envs_idx(envs_idx)
380+
self._set_metadata_field(pos_offset, self._shared_metadata.offsets_pos, field_size=3, envs_idx=envs_idx)
381+
382+
@gs.assert_built
383+
def set_quat_offset(self, quat_offset, envs_idx=None):
384+
envs_idx = self._sanitize_envs_idx(envs_idx)
385+
self._set_metadata_field(quat_offset, self._shared_metadata.offsets_quat, field_size=4, envs_idx=envs_idx)
386+
348387

349388
@dataclass
350389
class NoisySensorMetadataMixin:
@@ -372,25 +411,6 @@ class NoisySensorMixin(Generic[NoisySensorMetadataMixinT]):
372411
Base sensor class for analog sensors that are attached to a RigidEntity.
373412
"""
374413

375-
def _set_metadata_field(self, input, field, field_size, envs_idx):
376-
envs_idx = self._sanitize_envs_idx(envs_idx)
377-
idx = self._idx * field_size
378-
field[envs_idx, idx : idx + field_size] = self._sanitize_for_metadata_tensor(
379-
input, shape=(len(envs_idx), field_size), dtype=field.dtype
380-
)
381-
382-
def _sanitize_for_metadata_tensor(self, input, shape, dtype) -> torch.Tensor:
383-
if not isinstance(input, Sequence):
384-
input = [input]
385-
tensor_input = torch.tensor(input, dtype=dtype, device=gs.device)
386-
if tensor_input.ndim == len(shape) - 1:
387-
# Batch dimension is missing
388-
tensor_input = tensor_input.unsqueeze(0).expand((shape[0], *tensor_input.shape))
389-
assert (
390-
tensor_input.shape == shape
391-
), f"Input shape {tensor_input.shape} for setting sensor metadata does not match shape {shape}"
392-
return tensor_input
393-
394414
@gs.assert_built
395415
def set_resolution(self, resolution, envs_idx=None):
396416
self._set_metadata_field(resolution, self._shared_metadata.resolution, self._cache_size, envs_idx)

genesis/engine/sensors/imu.py

Lines changed: 33 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
import torch
77

88
import genesis as gs
9-
from genesis.options.sensors import (
10-
MaybeMatrix3x3Type,
11-
IMU as IMUOptions,
9+
from genesis.options.sensors import IMU as IMUOptions
10+
from genesis.options.sensors import MaybeMatrix3x3Type
11+
from genesis.utils.geom import (
12+
inv_transform_by_quat,
13+
transform_by_quat,
14+
transform_quat_by_quat,
1215
)
13-
from genesis.utils.geom import inv_transform_by_trans_quat, transform_by_quat, transform_quat_by_quat
1416
from genesis.utils.misc import concat_with_tensor, make_tensor_field, tensor_to_array
1517

1618
from .base_sensor import (
@@ -30,21 +32,9 @@
3032
from genesis.vis.rasterizer_context import RasterizerContext
3133

3234

33-
def _view_metadata_as_acc_gyro(metadata_tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
34-
"""
35-
Get views of the metadata tensor (B, n_imus * 6) as a tuple of acc and gyro metadata tensors (B, n_imus * 3).
36-
"""
37-
batch_shape, n_data = metadata_tensor.shape[:-1], metadata_tensor.shape[-1]
38-
n_imus = n_data // 6
39-
metadata_tensor_per_sensor = metadata_tensor.reshape((*batch_shape, n_imus, 2, 3))
40-
41-
return (
42-
metadata_tensor_per_sensor[..., 0, :].reshape(*batch_shape, n_imus * 3),
43-
metadata_tensor_per_sensor[..., 1, :].reshape(*batch_shape, n_imus * 3),
44-
)
45-
46-
47-
def _get_skew_to_alignment_matrix(input: MaybeMatrix3x3Type, out: torch.Tensor | None = None) -> torch.Tensor:
35+
def _get_cross_axis_coupling_to_alignment_matrix(
36+
input: MaybeMatrix3x3Type, out: torch.Tensor | None = None
37+
) -> torch.Tensor:
4838
"""
4939
Convert the alignment input to a matrix. Modifies in place if provided, else allocate a new matrix.
5040
"""
@@ -109,41 +99,17 @@ def __init__(
10999
self.pos_offset: torch.Tensor
110100

111101
@gs.assert_built
112-
def set_acc_axes_skew(self, axes_skew: MaybeMatrix3x3Type, envs_idx=None):
102+
def set_acc_cross_axis_coupling(self, cross_axis_coupling: MaybeMatrix3x3Type, envs_idx=None):
113103
envs_idx = self._sanitize_envs_idx(envs_idx)
114-
rot_matrix = _get_skew_to_alignment_matrix(axes_skew)
104+
rot_matrix = _get_cross_axis_coupling_to_alignment_matrix(cross_axis_coupling)
115105
self._shared_metadata.alignment_rot_matrix[envs_idx, self._idx * 2, :, :] = rot_matrix
116106

117107
@gs.assert_built
118-
def set_gyro_axes_skew(self, axes_skew: MaybeMatrix3x3Type, envs_idx=None):
108+
def set_gyro_cross_axis_coupling(self, cross_axis_coupling: MaybeMatrix3x3Type, envs_idx=None):
119109
envs_idx = self._sanitize_envs_idx(envs_idx)
120-
rot_matrix = _get_skew_to_alignment_matrix(axes_skew)
110+
rot_matrix = _get_cross_axis_coupling_to_alignment_matrix(cross_axis_coupling)
121111
self._shared_metadata.alignment_rot_matrix[envs_idx, self._idx * 2 + 1, :, :] = rot_matrix
122112

123-
@gs.assert_built
124-
def set_acc_bias(self, bias, envs_idx=None):
125-
self._set_metadata_field(bias, self._shared_metadata.acc_bias, field_size=3, envs_idx=envs_idx)
126-
127-
@gs.assert_built
128-
def set_gyro_bias(self, bias, envs_idx=None):
129-
self._set_metadata_field(bias, self._shared_metadata.gyro_bias, field_size=3, envs_idx=envs_idx)
130-
131-
@gs.assert_built
132-
def set_acc_random_walk(self, random_walk, envs_idx=None):
133-
self._set_metadata_field(random_walk, self._shared_metadata.acc_random_walk, field_size=3, envs_idx=envs_idx)
134-
135-
@gs.assert_built
136-
def set_gyro_random_walk(self, random_walk, envs_idx=None):
137-
self._set_metadata_field(random_walk, self._shared_metadata.gyro_random_walk, field_size=3, envs_idx=envs_idx)
138-
139-
@gs.assert_built
140-
def set_acc_noise(self, noise, envs_idx=None):
141-
self._set_metadata_field(noise, self._shared_metadata.acc_noise, field_size=3, envs_idx=envs_idx)
142-
143-
@gs.assert_built
144-
def set_gyro_noise(self, noise, envs_idx=None):
145-
self._set_metadata_field(noise, self._shared_metadata.gyro_noise, field_size=3, envs_idx=envs_idx)
146-
147113
# ================================ internal methods ================================
148114

149115
def build(self):
@@ -160,21 +126,12 @@ def build(self):
160126
self._options.noise = _to_tuple(self._options.acc_noise, self._options.gyro_noise, length_per_value=3)
161127
super().build() # set all shared metadata from RigidSensorBase and NoisySensorBase
162128

163-
self._shared_metadata.acc_bias, self._shared_metadata.gyro_bias = _view_metadata_as_acc_gyro(
164-
self._shared_metadata.bias
165-
)
166-
self._shared_metadata.acc_random_walk, self._shared_metadata.gyro_random_walk = _view_metadata_as_acc_gyro(
167-
self._shared_metadata.random_walk
168-
)
169-
self._shared_metadata.acc_noise, self._shared_metadata.gyro_noise = _view_metadata_as_acc_gyro(
170-
self._shared_metadata.noise
171-
)
172129
self._shared_metadata.alignment_rot_matrix = concat_with_tensor(
173130
self._shared_metadata.alignment_rot_matrix,
174131
torch.stack(
175132
[
176-
_get_skew_to_alignment_matrix(self._options.acc_axes_skew),
177-
_get_skew_to_alignment_matrix(self._options.gyro_axes_skew),
133+
_get_cross_axis_coupling_to_alignment_matrix(self._options.acc_cross_axis_coupling),
134+
_get_cross_axis_coupling_to_alignment_matrix(self._options.gyro_cross_axis_coupling),
178135
],
179136
),
180137
expand=(self._manager._sim._B, 2, 3, 3),
@@ -198,22 +155,35 @@ def _update_shared_ground_truth_cache(
198155
"""
199156
Update the current ground truth values for all IMU sensors.
200157
"""
158+
# Extract acceleration and gravity in world frame
201159
assert shared_metadata.solver is not None
202160
gravity = shared_metadata.solver.get_gravity()
203161
quats = shared_metadata.solver.get_links_quat(links_idx=shared_metadata.links_idx)
204162
acc = shared_metadata.solver.get_links_acc(links_idx=shared_metadata.links_idx)
205163
ang = shared_metadata.solver.get_links_ang(links_idx=shared_metadata.links_idx)
164+
if acc.ndim == 2:
165+
acc = acc.unsqueeze(0)
166+
ang = ang.unsqueeze(0)
206167

207168
offset_quats = transform_quat_by_quat(quats, shared_metadata.offsets_quat)
208169

170+
# Additional acceleration if offset: a_imu = a_link + α × r + ω × (ω × r)
171+
if torch.any(torch.abs(shared_metadata.offsets_pos) > gs.EPS):
172+
ang_acc = shared_metadata.solver.get_links_acc_ang(links_idx=shared_metadata.links_idx)
173+
if ang_acc.ndim == 2:
174+
ang_acc = ang_acc.unsqueeze(0)
175+
offset_pos_world = transform_by_quat(shared_metadata.offsets_pos, quats)
176+
tangential_acc = torch.cross(ang_acc, offset_pos_world, dim=-1)
177+
centripetal_acc = torch.cross(ang, torch.cross(ang, offset_pos_world, dim=-1), dim=-1)
178+
acc += tangential_acc + centripetal_acc
179+
180+
# Subtract gravity then move to local frame
209181
# acc/ang shape: (B, n_imus, 3)
210-
local_acc = inv_transform_by_trans_quat(acc, shared_metadata.offsets_pos, offset_quats)
211-
local_ang = inv_transform_by_trans_quat(ang, shared_metadata.offsets_pos, offset_quats)
212-
213-
*batch_size, n_imus, _ = local_acc.shape
214-
local_acc = local_acc - gravity.unsqueeze(-2).expand((*batch_size, n_imus, -1))
182+
local_acc = inv_transform_by_quat(acc - gravity.unsqueeze(-2), offset_quats)
183+
local_ang = inv_transform_by_quat(ang, offset_quats)
215184

216185
# cache shape: (B, n_imus * 6)
186+
*batch_size, n_imus, _ = local_acc.shape
217187
strided_ground_truth_cache = shared_ground_truth_cache.reshape((*batch_size, n_imus, 2, 3))
218188
strided_ground_truth_cache[..., 0, :].copy_(local_acc)
219189
strided_ground_truth_cache[..., 1, :].copy_(local_ang)

genesis/options/sensors/options.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
1-
from dataclasses import dataclass
21
from typing import Sequence
32

43
import numpy as np
5-
import torch
64
from pydantic import Field
75

86
import genesis as gs
97

108
from ..options import Options
11-
from .raycaster import RaycastPattern, DepthCameraPattern
12-
9+
from .raycaster import DepthCameraPattern, RaycastPattern
1310

1411
Tuple3FType = tuple[float, float, float]
1512
MaybeTuple3FType = float | Tuple3FType
@@ -190,7 +187,7 @@ class IMU(RigidSensorOptionsMixin, NoisySensorOptionsMixin, SensorOptions):
190187
acc_resolution : float, optional
191188
The measurement resolution of the accelerometer (smallest increment of change in the sensor reading).
192189
Default is 0.0, which means no quantization is applied.
193-
acc_axes_skew : float | tuple[float, float, float] | Sequence[float]
190+
acc_cross_axis_coupling : float | tuple[float, float, float] | Sequence[float]
194191
Accelerometer axes alignment as a 3x3 rotation matrix, where diagonal elements represent alignment (0.0 to 1.0)
195192
for each axis, and off-diagonal elements account for cross-axis misalignment effects.
196193
- If a scalar is provided (float), all off-diagonal elements are set to the scalar value.
@@ -205,8 +202,8 @@ class IMU(RigidSensorOptionsMixin, NoisySensorOptionsMixin, SensorOptions):
205202
gyro_resolution : float, optional
206203
The measurement resolution of the gyroscope (smallest increment of change in the sensor reading).
207204
Default is 0.0, which means no quantization is applied.
208-
gyro_axes_skew : float | tuple[float, float, float] | Sequence[float]
209-
Gyroscope axes alignment as a 3x3 rotation matrix, similar to `acc_axes_skew`.
205+
gyro_cross_axis_coupling : float | tuple[float, float, float] | Sequence[float]
206+
Gyroscope axes alignment as a 3x3 rotation matrix, similar to `acc_cross_axis_coupling`.
210207
gyro_bias : tuple[float, float, float]
211208
The constant additive bias for each axis of the gyroscope.
212209
gyro_noise : tuple[float, float, float]
@@ -225,8 +222,8 @@ class IMU(RigidSensorOptionsMixin, NoisySensorOptionsMixin, SensorOptions):
225222

226223
acc_resolution: MaybeTuple3FType = 0.0
227224
gyro_resolution: MaybeTuple3FType = 0.0
228-
acc_axes_skew: MaybeMatrix3x3Type = 0.0
229-
gyro_axes_skew: MaybeMatrix3x3Type = 0.0
225+
acc_cross_axis_coupling: MaybeMatrix3x3Type = 0.0
226+
gyro_cross_axis_coupling: MaybeMatrix3x3Type = 0.0
230227
acc_noise: MaybeTuple3FType = 0.0
231228
gyro_noise: MaybeTuple3FType = 0.0
232229
acc_bias: MaybeTuple3FType = 0.0
@@ -240,15 +237,17 @@ class IMU(RigidSensorOptionsMixin, NoisySensorOptionsMixin, SensorOptions):
240237
debug_gyro_scale: float = 0.01
241238

242239
def model_post_init(self, _):
243-
self._validate_axes_skew(self.acc_axes_skew)
244-
self._validate_axes_skew(self.gyro_axes_skew)
245-
246-
def _validate_axes_skew(self, axes_skew):
247-
axes_skew_np = np.array(axes_skew)
248-
if axes_skew_np.shape not in ((), (3,), (3, 3)):
249-
gs.raise_exception(f"axes_skew shape should be (), (3,), or (3, 3), got: {axes_skew_np.shape}")
250-
if np.any(axes_skew_np < 0.0) or np.any(axes_skew_np > 1.0):
251-
gs.raise_exception(f"axes_skew values should be between 0.0 and 1.0, got: {axes_skew}")
240+
self._validate_cross_axis_coupling(self.acc_cross_axis_coupling)
241+
self._validate_cross_axis_coupling(self.gyro_cross_axis_coupling)
242+
243+
def _validate_cross_axis_coupling(self, cross_axis_coupling):
244+
cross_axis_coupling_np = np.array(cross_axis_coupling)
245+
if cross_axis_coupling_np.shape not in ((), (3,), (3, 3)):
246+
gs.raise_exception(
247+
f"cross_axis_coupling shape should be (), (3,), or (3, 3), got: {cross_axis_coupling_np.shape}"
248+
)
249+
if np.any(cross_axis_coupling_np < 0.0) or np.any(cross_axis_coupling_np > 1.0):
250+
gs.raise_exception(f"cross_axis_coupling values should be between 0.0 and 1.0, got: {cross_axis_coupling}")
252251

253252

254253
class Raycaster(RigidSensorOptionsMixin, SensorOptions):

0 commit comments

Comments
 (0)