66import torch
77
88import genesis as gs
9- from genesis .options .sensors import (
10- MaybeMatrix3x3Type ,
11- IMU as IMUOptions ,
9+ from genesis .options .sensors import IMU as IMUOptions
10+ from genesis .options .sensors import MaybeMatrix3x3Type
11+ from genesis .utils .geom import (
12+ inv_transform_by_quat ,
13+ transform_by_quat ,
14+ transform_quat_by_quat ,
1215)
13- from genesis .utils .geom import inv_transform_by_trans_quat , transform_by_quat , transform_quat_by_quat
1416from genesis .utils .misc import concat_with_tensor , make_tensor_field , tensor_to_array
1517
1618from .base_sensor import (
3032 from genesis .vis .rasterizer_context import RasterizerContext
3133
3234
33- def _view_metadata_as_acc_gyro (metadata_tensor : torch .Tensor ) -> tuple [torch .Tensor , torch .Tensor ]:
34- """
35- Get views of the metadata tensor (B, n_imus * 6) as a tuple of acc and gyro metadata tensors (B, n_imus * 3).
36- """
37- batch_shape , n_data = metadata_tensor .shape [:- 1 ], metadata_tensor .shape [- 1 ]
38- n_imus = n_data // 6
39- metadata_tensor_per_sensor = metadata_tensor .reshape ((* batch_shape , n_imus , 2 , 3 ))
40-
41- return (
42- metadata_tensor_per_sensor [..., 0 , :].reshape (* batch_shape , n_imus * 3 ),
43- metadata_tensor_per_sensor [..., 1 , :].reshape (* batch_shape , n_imus * 3 ),
44- )
45-
46-
47- def _get_skew_to_alignment_matrix (input : MaybeMatrix3x3Type , out : torch .Tensor | None = None ) -> torch .Tensor :
35+ def _get_cross_axis_coupling_to_alignment_matrix (
36+ input : MaybeMatrix3x3Type , out : torch .Tensor | None = None
37+ ) -> torch .Tensor :
4838 """
4939 Convert the alignment input to a matrix. Modifies in place if provided, else allocate a new matrix.
5040 """
@@ -109,41 +99,17 @@ def __init__(
10999 self .pos_offset : torch .Tensor
110100
111101 @gs .assert_built
112- def set_acc_axes_skew (self , axes_skew : MaybeMatrix3x3Type , envs_idx = None ):
102+ def set_acc_cross_axis_coupling (self , cross_axis_coupling : MaybeMatrix3x3Type , envs_idx = None ):
113103 envs_idx = self ._sanitize_envs_idx (envs_idx )
114- rot_matrix = _get_skew_to_alignment_matrix ( axes_skew )
104+ rot_matrix = _get_cross_axis_coupling_to_alignment_matrix ( cross_axis_coupling )
115105 self ._shared_metadata .alignment_rot_matrix [envs_idx , self ._idx * 2 , :, :] = rot_matrix
116106
117107 @gs .assert_built
118- def set_gyro_axes_skew (self , axes_skew : MaybeMatrix3x3Type , envs_idx = None ):
108+ def set_gyro_cross_axis_coupling (self , cross_axis_coupling : MaybeMatrix3x3Type , envs_idx = None ):
119109 envs_idx = self ._sanitize_envs_idx (envs_idx )
120- rot_matrix = _get_skew_to_alignment_matrix ( axes_skew )
110+ rot_matrix = _get_cross_axis_coupling_to_alignment_matrix ( cross_axis_coupling )
121111 self ._shared_metadata .alignment_rot_matrix [envs_idx , self ._idx * 2 + 1 , :, :] = rot_matrix
122112
123- @gs .assert_built
124- def set_acc_bias (self , bias , envs_idx = None ):
125- self ._set_metadata_field (bias , self ._shared_metadata .acc_bias , field_size = 3 , envs_idx = envs_idx )
126-
127- @gs .assert_built
128- def set_gyro_bias (self , bias , envs_idx = None ):
129- self ._set_metadata_field (bias , self ._shared_metadata .gyro_bias , field_size = 3 , envs_idx = envs_idx )
130-
131- @gs .assert_built
132- def set_acc_random_walk (self , random_walk , envs_idx = None ):
133- self ._set_metadata_field (random_walk , self ._shared_metadata .acc_random_walk , field_size = 3 , envs_idx = envs_idx )
134-
135- @gs .assert_built
136- def set_gyro_random_walk (self , random_walk , envs_idx = None ):
137- self ._set_metadata_field (random_walk , self ._shared_metadata .gyro_random_walk , field_size = 3 , envs_idx = envs_idx )
138-
139- @gs .assert_built
140- def set_acc_noise (self , noise , envs_idx = None ):
141- self ._set_metadata_field (noise , self ._shared_metadata .acc_noise , field_size = 3 , envs_idx = envs_idx )
142-
143- @gs .assert_built
144- def set_gyro_noise (self , noise , envs_idx = None ):
145- self ._set_metadata_field (noise , self ._shared_metadata .gyro_noise , field_size = 3 , envs_idx = envs_idx )
146-
147113 # ================================ internal methods ================================
148114
149115 def build (self ):
@@ -160,21 +126,12 @@ def build(self):
160126 self ._options .noise = _to_tuple (self ._options .acc_noise , self ._options .gyro_noise , length_per_value = 3 )
161127 super ().build () # set all shared metadata from RigidSensorBase and NoisySensorBase
162128
163- self ._shared_metadata .acc_bias , self ._shared_metadata .gyro_bias = _view_metadata_as_acc_gyro (
164- self ._shared_metadata .bias
165- )
166- self ._shared_metadata .acc_random_walk , self ._shared_metadata .gyro_random_walk = _view_metadata_as_acc_gyro (
167- self ._shared_metadata .random_walk
168- )
169- self ._shared_metadata .acc_noise , self ._shared_metadata .gyro_noise = _view_metadata_as_acc_gyro (
170- self ._shared_metadata .noise
171- )
172129 self ._shared_metadata .alignment_rot_matrix = concat_with_tensor (
173130 self ._shared_metadata .alignment_rot_matrix ,
174131 torch .stack (
175132 [
176- _get_skew_to_alignment_matrix (self ._options .acc_axes_skew ),
177- _get_skew_to_alignment_matrix (self ._options .gyro_axes_skew ),
133+ _get_cross_axis_coupling_to_alignment_matrix (self ._options .acc_cross_axis_coupling ),
134+ _get_cross_axis_coupling_to_alignment_matrix (self ._options .gyro_cross_axis_coupling ),
178135 ],
179136 ),
180137 expand = (self ._manager ._sim ._B , 2 , 3 , 3 ),
@@ -198,22 +155,35 @@ def _update_shared_ground_truth_cache(
198155 """
199156 Update the current ground truth values for all IMU sensors.
200157 """
158+ # Extract acceleration and gravity in world frame
201159 assert shared_metadata .solver is not None
202160 gravity = shared_metadata .solver .get_gravity ()
203161 quats = shared_metadata .solver .get_links_quat (links_idx = shared_metadata .links_idx )
204162 acc = shared_metadata .solver .get_links_acc (links_idx = shared_metadata .links_idx )
205163 ang = shared_metadata .solver .get_links_ang (links_idx = shared_metadata .links_idx )
164+ if acc .ndim == 2 :
165+ acc = acc .unsqueeze (0 )
166+ ang = ang .unsqueeze (0 )
206167
207168 offset_quats = transform_quat_by_quat (quats , shared_metadata .offsets_quat )
208169
170+ # Additional acceleration if offset: a_imu = a_link + α × r + ω × (ω × r)
171+ if torch .any (torch .abs (shared_metadata .offsets_pos ) > gs .EPS ):
172+ ang_acc = shared_metadata .solver .get_links_acc_ang (links_idx = shared_metadata .links_idx )
173+ if ang_acc .ndim == 2 :
174+ ang_acc = ang_acc .unsqueeze (0 )
175+ offset_pos_world = transform_by_quat (shared_metadata .offsets_pos , quats )
176+ tangential_acc = torch .cross (ang_acc , offset_pos_world , dim = - 1 )
177+ centripetal_acc = torch .cross (ang , torch .cross (ang , offset_pos_world , dim = - 1 ), dim = - 1 )
178+ acc += tangential_acc + centripetal_acc
179+
180+ # Subtract gravity then move to local frame
209181 # acc/ang shape: (B, n_imus, 3)
210- local_acc = inv_transform_by_trans_quat (acc , shared_metadata .offsets_pos , offset_quats )
211- local_ang = inv_transform_by_trans_quat (ang , shared_metadata .offsets_pos , offset_quats )
212-
213- * batch_size , n_imus , _ = local_acc .shape
214- local_acc = local_acc - gravity .unsqueeze (- 2 ).expand ((* batch_size , n_imus , - 1 ))
182+ local_acc = inv_transform_by_quat (acc - gravity .unsqueeze (- 2 ), offset_quats )
183+ local_ang = inv_transform_by_quat (ang , offset_quats )
215184
216185 # cache shape: (B, n_imus * 6)
186+ * batch_size , n_imus , _ = local_acc .shape
217187 strided_ground_truth_cache = shared_ground_truth_cache .reshape ((* batch_size , n_imus , 2 , 3 ))
218188 strided_ground_truth_cache [..., 0 , :].copy_ (local_acc )
219189 strided_ground_truth_cache [..., 1 , :].copy_ (local_ang )
0 commit comments