|
26 | 26 | from mujoco.mjx._src.types import DataJAX
|
27 | 27 | from mujoco.mjx._src.types import DisableBit
|
28 | 28 | from mujoco.mjx._src.types import EqType
|
| 29 | +from mujoco.mjx._src.types import Impl |
29 | 30 | from mujoco.mjx._src.types import JointType
|
30 | 31 | from mujoco.mjx._src.types import Model
|
31 | 32 | from mujoco.mjx._src.types import ModelJAX
|
32 | 33 | from mujoco.mjx._src.types import ObjType
|
33 | 34 | from mujoco.mjx._src.types import TrnType
|
34 | 35 | from mujoco.mjx._src.types import WrapType
|
35 | 36 | # pylint: enable=g-importing-member
|
| 37 | +import mujoco.mjx.warp as mjxw |
36 | 38 | import numpy as np
|
37 | 39 |
|
38 | 40 |
|
39 | 41 | def kinematics(m: Model, d: Data) -> Data:
|
40 | 42 | """Converts position/velocity from generalized coordinates to maximal."""
|
| 43 | + if m.impl == Impl.WARP and d.impl == Impl.WARP and mjxw.WARP_INSTALLED: |
| 44 | + from mujoco.mjx.warp import smooth as mjxw_smooth # pylint: disable=g-import-not-at-top # pytype: disable=import-error |
| 45 | + return mjxw_smooth.kinematics(m, d) |
| 46 | + |
41 | 47 | def fn(carry, jnt_typs, jnt_pos, jnt_axis, qpos, qpos0, pos, quat):
|
42 | 48 | # calculate joint anchors, axes, body pos and quat in global frame
|
43 | 49 | # also normalize qpos while we're at it
|
@@ -844,6 +850,10 @@ def _forward(carry, cfrc_ext, cinert, cvel, body_dofadr, body_dofnum):
|
844 | 850 |
|
845 | 851 | def tendon(m: Model, d: Data) -> Data:
|
846 | 852 | """Computes tendon lengths and moments."""
|
| 853 | + if m.impl == Impl.WARP and d.impl == Impl.WARP and mjxw.WARP_INSTALLED: |
| 854 | + from mujoco.mjx.warp import smooth as mjxw_smooth # pylint: disable=g-import-not-at-top # pytype: disable=import-error |
| 855 | + return mjxw_smooth.tendon(m, d) |
| 856 | + |
847 | 857 | if not isinstance(m._impl, ModelJAX) or not isinstance(d._impl, DataJAX):
|
848 | 858 | raise ValueError('tendon requires JAX backend implementation.')
|
849 | 859 |
|
@@ -1091,7 +1101,7 @@ def _distance(p0, p1):
|
1091 | 1101 |
|
1092 | 1102 | # assemble length and moment
|
1093 | 1103 | ten_length = (
|
1094 |
| - jp.zeros_like(d._impl.ten_length).at[tendon_id_jnt].set(length_jnt) |
| 1104 | + jp.zeros_like(d.ten_length).at[tendon_id_jnt].set(length_jnt) |
1095 | 1105 | )
|
1096 | 1106 | ten_length = ten_length.at[tendon_id_site].add(length_site)
|
1097 | 1107 | ten_length = ten_length.at[tendon_id_geom].add(length_geom)
|
@@ -1161,7 +1171,7 @@ def _distance(p0, p1):
|
1161 | 1171 | ).reshape((m.nwrap, 2))
|
1162 | 1172 |
|
1163 | 1173 | return d.tree_replace({
|
1164 |
| - '_impl.ten_length': ten_length, |
| 1174 | + 'ten_length': ten_length, |
1165 | 1175 | '_impl.ten_J': ten_moment,
|
1166 | 1176 | '_impl.ten_wrapadr': jp.array(ten_wrapadr, dtype=int),
|
1167 | 1177 | '_impl.ten_wrapnum': jp.array(ten_wrapnum, dtype=int),
|
@@ -1263,7 +1273,7 @@ def fn(
|
1263 | 1273 | wrench = jp.concatenate((frame_xmat @ gear[:3], frame_xmat @ gear[3:]))
|
1264 | 1274 | moment = jac @ wrench
|
1265 | 1275 | elif trntype == TrnType.TENDON:
|
1266 |
| - length = d._impl.ten_length[trnid[0]] * gear[:1] |
| 1276 | + length = d.ten_length[trnid[0]] * gear[:1] |
1267 | 1277 | moment = d._impl.ten_J[trnid[0]] * gear[0]
|
1268 | 1278 | else:
|
1269 | 1279 | raise RuntimeError(f'unrecognized trntype: {TrnType(trntype)}')
|
|
0 commit comments