Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 40 additions & 20 deletions genesis/engine/sensors/base_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from genesis.utils.misc import concat_with_tensor, make_tensor_field

if TYPE_CHECKING:
from genesis.engine.solvers import RigidSolver
from genesis.engine.entities.rigid_entity.rigid_link import RigidLink
from genesis.engine.solvers import RigidSolver
from genesis.recorders.base_recorder import Recorder, RecorderOptions
from genesis.utils.ring_buffer import TensorRingBuffer
from genesis.vis.rasterizer_context import RasterizerContext
Expand Down Expand Up @@ -292,8 +292,37 @@ def _get_formatted_data(self, tensor: torch.Tensor, envs_idx=None) -> torch.Tens
return self._return_data_class(*return_values)

def _sanitize_envs_idx(self, envs_idx) -> torch.Tensor:
if self._manager._sim.n_envs == 0:
return torch.tensor([0], device=gs.device, dtype=gs.tc_int)
return self._manager._sim._scene._sanitize_envs_idx(envs_idx)

def _set_metadata_field(self, input, field, field_size, envs_idx):
envs_idx = self._sanitize_envs_idx(envs_idx)
if field.ndim == 2:
# flat field structure
idx = self._idx * field_size
index_slice = slice(idx, idx + field_size)
else:
# per sensor field structure
index_slice = self._idx
field[:, index_slice] = self._sanitize_for_metadata_tensor(
input, shape=(len(envs_idx), field_size), dtype=field.dtype
)

def _sanitize_for_metadata_tensor(self, input, shape, dtype) -> torch.Tensor:
if not isinstance(input, Sequence):
input = [input]
tensor_input = torch.tensor(input, dtype=dtype, device=gs.device)
if tensor_input.ndim == len(shape) - 1:
# Batch dimension is missing
tensor_input = tensor_input.unsqueeze(0)
if tensor_input.shape[0] != shape[0]:
tensor_input = tensor_input.expand((shape[0], *tensor_input.shape[1:]))
assert (
tensor_input.shape == shape
), f"Input shape {tensor_input.shape} for setting sensor metadata does not match shape {shape}"
return tensor_input


@dataclass
class RigidSensorMetadataMixin:
Expand Down Expand Up @@ -345,6 +374,16 @@ def build(self):
dim=1,
)

@gs.assert_built
def set_pos_offset(self, pos_offset, envs_idx=None):
envs_idx = self._sanitize_envs_idx(envs_idx)
self._set_metadata_field(pos_offset, self._shared_metadata.offsets_pos, field_size=3, envs_idx=envs_idx)

@gs.assert_built
def set_quat_offset(self, quat_offset, envs_idx=None):
envs_idx = self._sanitize_envs_idx(envs_idx)
self._set_metadata_field(quat_offset, self._shared_metadata.offsets_quat, field_size=4, envs_idx=envs_idx)


@dataclass
class NoisySensorMetadataMixin:
Expand Down Expand Up @@ -372,25 +411,6 @@ class NoisySensorMixin(Generic[NoisySensorMetadataMixinT]):
Base sensor class for analog sensors that are attached to a RigidEntity.
"""

def _set_metadata_field(self, input, field, field_size, envs_idx):
envs_idx = self._sanitize_envs_idx(envs_idx)
idx = self._idx * field_size
field[envs_idx, idx : idx + field_size] = self._sanitize_for_metadata_tensor(
input, shape=(len(envs_idx), field_size), dtype=field.dtype
)

def _sanitize_for_metadata_tensor(self, input, shape, dtype) -> torch.Tensor:
if not isinstance(input, Sequence):
input = [input]
tensor_input = torch.tensor(input, dtype=dtype, device=gs.device)
if tensor_input.ndim == len(shape) - 1:
# Batch dimension is missing
tensor_input = tensor_input.unsqueeze(0).expand((shape[0], *tensor_input.shape))
assert (
tensor_input.shape == shape
), f"Input shape {tensor_input.shape} for setting sensor metadata does not match shape {shape}"
return tensor_input

@gs.assert_built
def set_resolution(self, resolution, envs_idx=None):
self._set_metadata_field(resolution, self._shared_metadata.resolution, self._cache_size, envs_idx)
Expand Down
94 changes: 33 additions & 61 deletions genesis/engine/sensors/imu.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
import torch

import genesis as gs
from genesis.options.sensors import (
MaybeMatrix3x3Type,
IMU as IMUOptions,
from genesis.options.sensors import IMU as IMUOptions
from genesis.options.sensors import MaybeMatrix3x3Type
from genesis.utils.geom import (
inv_transform_by_quat,
transform_by_quat,
transform_quat_by_quat,
)
from genesis.utils.geom import inv_transform_by_trans_quat, transform_by_quat, transform_quat_by_quat
from genesis.utils.misc import concat_with_tensor, make_tensor_field, tensor_to_array

from .base_sensor import (
Expand All @@ -30,21 +32,9 @@
from genesis.vis.rasterizer_context import RasterizerContext


def _view_metadata_as_acc_gyro(metadata_tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Get views of the metadata tensor (B, n_imus * 6) as a tuple of acc and gyro metadata tensors (B, n_imus * 3).
"""
batch_shape, n_data = metadata_tensor.shape[:-1], metadata_tensor.shape[-1]
n_imus = n_data // 6
metadata_tensor_per_sensor = metadata_tensor.reshape((*batch_shape, n_imus, 2, 3))

return (
metadata_tensor_per_sensor[..., 0, :].reshape(*batch_shape, n_imus * 3),
metadata_tensor_per_sensor[..., 1, :].reshape(*batch_shape, n_imus * 3),
)


def _get_skew_to_alignment_matrix(input: MaybeMatrix3x3Type, out: torch.Tensor | None = None) -> torch.Tensor:
def _get_cross_axis_coupling_to_alignment_matrix(
input: MaybeMatrix3x3Type, out: torch.Tensor | None = None
) -> torch.Tensor:
"""
Convert the alignment input to a matrix. Modifies in place if provided, else allocate a new matrix.
"""
Expand Down Expand Up @@ -109,41 +99,17 @@ def __init__(
self.pos_offset: torch.Tensor

@gs.assert_built
def set_acc_axes_skew(self, axes_skew: MaybeMatrix3x3Type, envs_idx=None):
def set_acc_cross_axis_coupling(self, cross_axis_coupling: MaybeMatrix3x3Type, envs_idx=None):
envs_idx = self._sanitize_envs_idx(envs_idx)
rot_matrix = _get_skew_to_alignment_matrix(axes_skew)
rot_matrix = _get_cross_axis_coupling_to_alignment_matrix(cross_axis_coupling)
self._shared_metadata.alignment_rot_matrix[envs_idx, self._idx * 2, :, :] = rot_matrix

@gs.assert_built
def set_gyro_axes_skew(self, axes_skew: MaybeMatrix3x3Type, envs_idx=None):
def set_gyro_cross_axis_coupling(self, cross_axis_coupling: MaybeMatrix3x3Type, envs_idx=None):
envs_idx = self._sanitize_envs_idx(envs_idx)
rot_matrix = _get_skew_to_alignment_matrix(axes_skew)
rot_matrix = _get_cross_axis_coupling_to_alignment_matrix(cross_axis_coupling)
self._shared_metadata.alignment_rot_matrix[envs_idx, self._idx * 2 + 1, :, :] = rot_matrix

@gs.assert_built
def set_acc_bias(self, bias, envs_idx=None):
self._set_metadata_field(bias, self._shared_metadata.acc_bias, field_size=3, envs_idx=envs_idx)

@gs.assert_built
def set_gyro_bias(self, bias, envs_idx=None):
self._set_metadata_field(bias, self._shared_metadata.gyro_bias, field_size=3, envs_idx=envs_idx)

Comment on lines -123 to -130
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed the set_acc_* and set_gyro_* functions since can use set_* from the NoisySensorMixin instead; having these APIs introduce duplicity and also don't work properly in some cases because of tensor copying not modifying the original metadata field

@gs.assert_built
def set_acc_random_walk(self, random_walk, envs_idx=None):
self._set_metadata_field(random_walk, self._shared_metadata.acc_random_walk, field_size=3, envs_idx=envs_idx)

@gs.assert_built
def set_gyro_random_walk(self, random_walk, envs_idx=None):
self._set_metadata_field(random_walk, self._shared_metadata.gyro_random_walk, field_size=3, envs_idx=envs_idx)

@gs.assert_built
def set_acc_noise(self, noise, envs_idx=None):
self._set_metadata_field(noise, self._shared_metadata.acc_noise, field_size=3, envs_idx=envs_idx)

@gs.assert_built
def set_gyro_noise(self, noise, envs_idx=None):
self._set_metadata_field(noise, self._shared_metadata.gyro_noise, field_size=3, envs_idx=envs_idx)

# ================================ internal methods ================================

def build(self):
Expand All @@ -160,21 +126,12 @@ def build(self):
self._options.noise = _to_tuple(self._options.acc_noise, self._options.gyro_noise, length_per_value=3)
super().build() # set all shared metadata from RigidSensorBase and NoisySensorBase

self._shared_metadata.acc_bias, self._shared_metadata.gyro_bias = _view_metadata_as_acc_gyro(
self._shared_metadata.bias
)
self._shared_metadata.acc_random_walk, self._shared_metadata.gyro_random_walk = _view_metadata_as_acc_gyro(
self._shared_metadata.random_walk
)
self._shared_metadata.acc_noise, self._shared_metadata.gyro_noise = _view_metadata_as_acc_gyro(
self._shared_metadata.noise
)
self._shared_metadata.alignment_rot_matrix = concat_with_tensor(
self._shared_metadata.alignment_rot_matrix,
torch.stack(
[
_get_skew_to_alignment_matrix(self._options.acc_axes_skew),
_get_skew_to_alignment_matrix(self._options.gyro_axes_skew),
_get_cross_axis_coupling_to_alignment_matrix(self._options.acc_cross_axis_coupling),
_get_cross_axis_coupling_to_alignment_matrix(self._options.gyro_cross_axis_coupling),
],
),
expand=(self._manager._sim._B, 2, 3, 3),
Expand Down Expand Up @@ -203,15 +160,30 @@ def _update_shared_ground_truth_cache(
quats = shared_metadata.solver.get_links_quat(links_idx=shared_metadata.links_idx)
acc = shared_metadata.solver.get_links_acc(links_idx=shared_metadata.links_idx)
ang = shared_metadata.solver.get_links_ang(links_idx=shared_metadata.links_idx)
if acc.ndim == 2:
acc = acc.unsqueeze(0)
ang = ang.unsqueeze(0)

offset_quats = transform_quat_by_quat(quats, shared_metadata.offsets_quat)

# additional acceleration if offset: a_imu = a_link + α × r + ω × (ω × r)
if torch.any(torch.abs(shared_metadata.offsets_pos) > gs.EPS):
ang_acc = shared_metadata.solver.get_links_acc_ang(links_idx=shared_metadata.links_idx)
if ang_acc.ndim == 2:
ang_acc = ang_acc.unsqueeze(0)
offset_pos_world = transform_by_quat(shared_metadata.offsets_pos, quats)
tangential_acc = torch.cross(ang_acc, offset_pos_world, dim=-1)
centripetal_acc = torch.cross(ang, torch.cross(ang, offset_pos_world, dim=-1), dim=-1)
acc += tangential_acc + centripetal_acc

# acc/ang shape: (B, n_imus, 3)
local_acc = inv_transform_by_trans_quat(acc, shared_metadata.offsets_pos, offset_quats)
local_ang = inv_transform_by_trans_quat(ang, shared_metadata.offsets_pos, offset_quats)
local_acc = inv_transform_by_quat(acc, offset_quats)
local_ang = inv_transform_by_quat(ang, offset_quats)

*batch_size, n_imus, _ = local_acc.shape
local_acc = local_acc - gravity.unsqueeze(-2).expand((*batch_size, n_imus, -1))
local_acc = local_acc - inv_transform_by_quat(
gravity.unsqueeze(-2).expand((*batch_size, n_imus, -1)), offset_quats
)

# cache shape: (B, n_imus * 6)
strided_ground_truth_cache = shared_ground_truth_cache.reshape((*batch_size, n_imus, 2, 3))
Expand Down
35 changes: 17 additions & 18 deletions genesis/options/sensors/options.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
from dataclasses import dataclass
from typing import Sequence

import numpy as np
import torch
from pydantic import Field

import genesis as gs

from ..options import Options
from .raycaster import RaycastPattern, DepthCameraPattern

from .raycaster import DepthCameraPattern, RaycastPattern

Tuple3FType = tuple[float, float, float]
MaybeTuple3FType = float | Tuple3FType
Expand Down Expand Up @@ -190,7 +187,7 @@ class IMU(RigidSensorOptionsMixin, NoisySensorOptionsMixin, SensorOptions):
acc_resolution : float, optional
The measurement resolution of the accelerometer (smallest increment of change in the sensor reading).
Default is 0.0, which means no quantization is applied.
acc_axes_skew : float | tuple[float, float, float] | Sequence[float]
acc_cross_axis_coupling : float | tuple[float, float, float] | Sequence[float]
Accelerometer axes alignment as a 3x3 rotation matrix, where diagonal elements represent alignment (0.0 to 1.0)
for each axis, and off-diagonal elements account for cross-axis misalignment effects.
- If a scalar is provided (float), all off-diagonal elements are set to the scalar value.
Expand All @@ -205,8 +202,8 @@ class IMU(RigidSensorOptionsMixin, NoisySensorOptionsMixin, SensorOptions):
gyro_resolution : float, optional
The measurement resolution of the gyroscope (smallest increment of change in the sensor reading).
Default is 0.0, which means no quantization is applied.
gyro_axes_skew : float | tuple[float, float, float] | Sequence[float]
Gyroscope axes alignment as a 3x3 rotation matrix, similar to `acc_axes_skew`.
gyro_cross_axis_coupling : float | tuple[float, float, float] | Sequence[float]
Gyroscope axes alignment as a 3x3 rotation matrix, similar to `acc_cross_axis_coupling`.
gyro_bias : tuple[float, float, float]
The constant additive bias for each axis of the gyroscope.
gyro_noise : tuple[float, float, float]
Expand All @@ -225,8 +222,8 @@ class IMU(RigidSensorOptionsMixin, NoisySensorOptionsMixin, SensorOptions):

acc_resolution: MaybeTuple3FType = 0.0
gyro_resolution: MaybeTuple3FType = 0.0
acc_axes_skew: MaybeMatrix3x3Type = 0.0
gyro_axes_skew: MaybeMatrix3x3Type = 0.0
acc_cross_axis_coupling: MaybeMatrix3x3Type = 0.0
gyro_cross_axis_coupling: MaybeMatrix3x3Type = 0.0
acc_noise: MaybeTuple3FType = 0.0
gyro_noise: MaybeTuple3FType = 0.0
acc_bias: MaybeTuple3FType = 0.0
Expand All @@ -240,15 +237,17 @@ class IMU(RigidSensorOptionsMixin, NoisySensorOptionsMixin, SensorOptions):
debug_gyro_scale: float = 0.01

def model_post_init(self, _):
self._validate_axes_skew(self.acc_axes_skew)
self._validate_axes_skew(self.gyro_axes_skew)

def _validate_axes_skew(self, axes_skew):
axes_skew_np = np.array(axes_skew)
if axes_skew_np.shape not in ((), (3,), (3, 3)):
gs.raise_exception(f"axes_skew shape should be (), (3,), or (3, 3), got: {axes_skew_np.shape}")
if np.any(axes_skew_np < 0.0) or np.any(axes_skew_np > 1.0):
gs.raise_exception(f"axes_skew values should be between 0.0 and 1.0, got: {axes_skew}")
self._validate_cross_axis_coupling(self.acc_cross_axis_coupling)
self._validate_cross_axis_coupling(self.gyro_cross_axis_coupling)

def _validate_cross_axis_coupling(self, cross_axis_coupling):
cross_axis_coupling_np = np.array(cross_axis_coupling)
if cross_axis_coupling_np.shape not in ((), (3,), (3, 3)):
gs.raise_exception(
f"cross_axis_coupling shape should be (), (3,), or (3, 3), got: {cross_axis_coupling_np.shape}"
)
if np.any(cross_axis_coupling_np < 0.0) or np.any(cross_axis_coupling_np > 1.0):
gs.raise_exception(f"cross_axis_coupling values should be between 0.0 and 1.0, got: {cross_axis_coupling}")


class Raycaster(RigidSensorOptionsMixin, SensorOptions):
Expand Down
Loading
Loading