diff --git a/mjx/training_apg.ipynb b/mjx/training_apg.ipynb index fa9bc8a57b..ae890ce49c 100644 --- a/mjx/training_apg.ipynb +++ b/mjx/training_apg.ipynb @@ -231,7 +231,7 @@ "from jax import config # Analytical gradients work much better with double precision.\n", "config.update(\"jax_debug_nans\", True)\n", "config.update(\"jax_enable_x64\", True)\n", - "config.update('jax_default_matmul_precision', jax.lax.Precision.HIGH)\n", + "config.update('jax_default_matmul_precision', 'high')\n", "from brax import math\n", "\n", "# Sim\n", @@ -715,7 +715,8 @@ " offset = data.xpos[1:, :] - data.subtree_com[self.sys.body_rootid[1:]]\n", " offset = Transform.create(pos=offset)\n", " xd = offset.vmap().do(cvel)\n", - " data = _reformat_contact(self.sys, data)\n", + " mjx_contact = data._impl.contact if hasattr(data, '_impl') else data.contact\n", + " data = data.replace(contact=_reformat_contact(self.sys, mjx_contact))\n", " return data.replace(q=q, qd=qd, x=x, xd=xd)\n", "\n", "\n",