|
26 | 26 | from mujoco.mjx._src.types import DataJAX
|
27 | 27 | from mujoco.mjx._src.types import DisableBit
|
28 | 28 | from mujoco.mjx._src.types import EqType
|
29 |
| -from mujoco.mjx._src.types import Impl |
30 | 29 | from mujoco.mjx._src.types import JointType
|
31 | 30 | from mujoco.mjx._src.types import Model
|
32 | 31 | from mujoco.mjx._src.types import ModelJAX
|
33 | 32 | from mujoco.mjx._src.types import ObjType
|
34 | 33 | from mujoco.mjx._src.types import TrnType
|
35 | 34 | from mujoco.mjx._src.types import WrapType
|
36 | 35 | # pylint: enable=g-importing-member
|
37 |
| -import mujoco.mjx.warp as mjxw |
38 | 36 | import numpy as np
|
39 | 37 |
|
40 | 38 |
|
41 | 39 | def kinematics(m: Model, d: Data) -> Data:
|
42 | 40 | """Converts position/velocity from generalized coordinates to maximal."""
|
43 |
| - if m.impl == Impl.WARP and d.impl == Impl.WARP and mjxw.WARP_INSTALLED: |
44 |
| - from mujoco.mjx.warp import smooth as mjxw_smooth # pylint: disable=g-import-not-at-top # pytype: disable=import-error |
45 |
| - return mjxw_smooth.kinematics(m, d) |
46 |
| - |
47 | 41 | def fn(carry, jnt_typs, jnt_pos, jnt_axis, qpos, qpos0, pos, quat):
|
48 | 42 | # calculate joint anchors, axes, body pos and quat in global frame
|
49 | 43 | # also normalize qpos while we're at it
|
@@ -850,10 +844,6 @@ def _forward(carry, cfrc_ext, cinert, cvel, body_dofadr, body_dofnum):
|
850 | 844 |
|
851 | 845 | def tendon(m: Model, d: Data) -> Data:
|
852 | 846 | """Computes tendon lengths and moments."""
|
853 |
| - if m.impl == Impl.WARP and d.impl == Impl.WARP and mjxw.WARP_INSTALLED: |
854 |
| - from mujoco.mjx.warp import smooth as mjxw_smooth # pylint: disable=g-import-not-at-top # pytype: disable=import-error |
855 |
| - return mjxw_smooth.tendon(m, d) |
856 |
| - |
857 | 847 | if not isinstance(m._impl, ModelJAX) or not isinstance(d._impl, DataJAX):
|
858 | 848 | raise ValueError('tendon requires JAX backend implementation.')
|
859 | 849 |
|
@@ -1101,7 +1091,7 @@ def _distance(p0, p1):
|
1101 | 1091 |
|
1102 | 1092 | # assemble length and moment
|
1103 | 1093 | ten_length = (
|
1104 |
| - jp.zeros_like(d.ten_length).at[tendon_id_jnt].set(length_jnt) |
| 1094 | + jp.zeros_like(d._impl.ten_length).at[tendon_id_jnt].set(length_jnt) |
1105 | 1095 | )
|
1106 | 1096 | ten_length = ten_length.at[tendon_id_site].add(length_site)
|
1107 | 1097 | ten_length = ten_length.at[tendon_id_geom].add(length_geom)
|
@@ -1171,7 +1161,7 @@ def _distance(p0, p1):
|
1171 | 1161 | ).reshape((m.nwrap, 2))
|
1172 | 1162 |
|
1173 | 1163 | return d.tree_replace({
|
1174 |
| - 'ten_length': ten_length, |
| 1164 | + '_impl.ten_length': ten_length, |
1175 | 1165 | '_impl.ten_J': ten_moment,
|
1176 | 1166 | '_impl.ten_wrapadr': jp.array(ten_wrapadr, dtype=int),
|
1177 | 1167 | '_impl.ten_wrapnum': jp.array(ten_wrapnum, dtype=int),
|
@@ -1273,7 +1263,7 @@ def fn(
|
1273 | 1263 | wrench = jp.concatenate((frame_xmat @ gear[:3], frame_xmat @ gear[3:]))
|
1274 | 1264 | moment = jac @ wrench
|
1275 | 1265 | elif trntype == TrnType.TENDON:
|
1276 |
| - length = d.ten_length[trnid[0]] * gear[:1] |
| 1266 | + length = d._impl.ten_length[trnid[0]] * gear[:1] |
1277 | 1267 | moment = d._impl.ten_J[trnid[0]] * gear[0]
|
1278 | 1268 | else:
|
1279 | 1269 | raise RuntimeError(f'unrecognized trntype: {TrnType(trntype)}')
|
|
0 commit comments