model.bind() with a slice of length 1 unexpectedly squeezes the batch dimension #3128
Description
Intro
Hi!
My setup
python -c "import mujoco; print(mujoco.__version__)"
3.5.0
What's happening? What did you expect?
When using the .bind() method in the Python bindings (either for standard MuJoCo or MJX) on a list/slice of objects of length one, the binding unexpectedly squeezes the batch dimension.
For example, model.bind(geoms[0:2]).size returns an array of shape (2, 3), which is correct, but model.bind(geoms[0:1]).size returns an array of shape (3,), which surprised me. I expected the latter to return an array of shape (1, 3), preserving the batch dimension and staying consistent with slices of length > 1. This is an easy source of silent errors: output[0] yields the first geom's size for a length-2 slice but only the first size component for a length-1 slice. Avoiding it requires the user to either check the input length (with len(geoms)) or explicitly restore the batch dimension (e.g. with output.reshape(-1, 3)).
Note that this happens with both model and data fields.
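A slightly more general form of the reshape workaround reshapes to `(len(elements), *field_shape)`, so the same code handles slices of any length and keeps working if `bind()` starts returning `(1, N)`. This is only a sketch: `with_batch_dim` is a hypothetical helper, not part of the MuJoCo API, and the caller has to supply the per-element field shape, e.g. `(3,)` for a geom's `size`.

```python
import numpy as np


def with_batch_dim(value, num_elements, elem_shape):
    """Reshape a bound field to (num_elements, *elem_shape).

    Hypothetical workaround helper, not part of the MuJoCo API. It restores
    the leading batch axis that bind() drops for single-element slices and
    leaves the shape unchanged when the axis is already present.
    """
    target_shape = (num_elements, *elem_shape)
    # np.reshape raises ValueError if the total size does not match, so a
    # wrong elem_shape fails loudly instead of silently mis-shaping data.
    return np.reshape(np.asarray(value), target_shape)
```

For example, with `sel = geoms[0:1]`, `with_batch_dim(m.bind(sel).size, len(sel), (3,))` should have shape `(1, 3)`. For MJX results, using `jax.numpy.reshape` in place of `np.reshape` keeps the value a JAX array instead of copying it to a NumPy array.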
Steps for reproduction
- Load a model with at least two bodies/geoms.
- Bind a subset using a slice of length > 1 and check the shape (preserves batch dim).
- Bind a subset using a slice of length == 1 and check the shape.
- Observe that the result is 1D with shape (N,) rather than the expected 2D shape (1, N).
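The reproduction steps above can also be run over every prefix of an element list, comparing each bound shape with the expected `(n, *field_shape)`. The sketch below assumes nothing about MuJoCo internals: `find_squeezed_lengths` is a hypothetical helper, and `bind_field` is any callable that binds a list of spec elements and returns a single field.

```python
import numpy as np


def find_squeezed_lengths(bind_field, elements, elem_shape):
    """Return (n, actual_shape) for every prefix length n whose bound field
    does not have the expected shape (n, *elem_shape).

    Hypothetical helper; bind_field maps a list of elements to one field,
    e.g. lambda els: m.bind(els).size.
    """
    mismatches = []
    for n in range(1, len(elements) + 1):
        # np.shape works for NumPy arrays, JAX arrays and plain scalars.
        actual = np.shape(bind_field(elements[:n]))
        expected = (n, *elem_shape)
        if actual != expected:
            mismatches.append((n, actual))
    return mismatches
```

With `s` and `m` from the reproduction code, `find_squeezed_lengths(lambda els: m.bind(els).size, s.geoms, (3,))` should return an empty list once length-1 slices keep their batch dimension; with the behavior reported here, it would instead flag `n == 1`.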
Minimal model for reproduction
Minimal XML
<mujoco model="test_bind_model">
<worldbody>
<body pos="10 20 30" name="body1">
<joint axis="1 0 0" type="ball" name="joint1"/>
<geom size="1 2 3" type="box" name="geom1" pos="0 1 0"/>
<geom size="4 5 6" type="box" name="geom2" pos="0 2 0"/>
</body>
</worldbody>
</mujoco>
Code required for reproduction
import mujoco
from mujoco import mjx
import numpy as np
xml = """
<mujoco model="test_bind_model">
<worldbody>
<body pos="10 20 30" name="body1">
<joint axis="1 0 0" type="ball" name="joint1"/>
<geom size="1 2 3" type="box" name="geom1" pos="0 1 0"/>
<geom size="4 5 6" type="box" name="geom2" pos="0 2 0"/>
</body>
</worldbody>
</mujoco>
"""
s = mujoco.MjSpec.from_string(xml)
m = s.compile()
mx = mjx.put_model(m)
geoms = s.geoms[0:2]
print("--- MuJoCo (m.bind) ---")
print("m.bind(geoms[0:2]).size.shape:", np.array(m.bind(geoms[0:2]).size).shape) # Expected (2, 3)
# Actual (2, 3)
print("m.bind(geoms[0:1]).size.shape:", np.array(m.bind(geoms[0:1]).size).shape) # Expected (1, 3)
# Actual (3,)
print("\n--- MJX (mx.bind) ---")
print("mx.bind(geoms[0:2]).size.shape:", mx.bind(geoms[0:2]).size.shape) # Expected (2, 3)
# Actual (2, 3)
print("mx.bind(geoms[0:1]).size.shape:", mx.bind(geoms[0:1]).size.shape) # Expected (1, 3)
# Actual (3,)
main...hartikainen:mujoco:bind-squeeze also includes a simple test case for this.
Confirmations
- I searched the latest documentation thoroughly before posting.
- I searched previous Issues and Discussions, I am certain this has not been raised before.