model.bind() with a slice of length 1 unexpectedly squeezes the batch dimension #3128

@hartikainen

Intro

Hi!

My setup

python -c "import mujoco; print(mujoco.__version__)"
3.5.0

What's happening? What did you expect?

When .bind() is called in the Python bindings (for either standard MuJoCo or MJX) on a list or slice containing a single object, the batch dimension of the bound fields is unexpectedly squeezed.

For example, model.bind(geoms[0:2]).size returns an array of shape (2, 3), which is correct, but model.bind(geoms[0:1]).size returns an array of shape (3,), which surprised me. I expected the latter to return shape (1, 3), preserving the batch dimension and staying consistent with slices of length > 1. This is a significant potential source of errors: code that iterates over the first axis gets individual scalars instead of per-geom rows for single-object inputs, unless the user checks the input length (with len(geoms)) or explicitly restores the batch dimension (e.g. with output.reshape(-1, 3)).
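The expected behavior matches NumPy's own indexing rules: an integer index drops the indexed axis, while a slice or a list of indices keeps it, even when only one row is selected. A minimal illustration, using a plain array that holds the same values as the geom sizes in the model below (an illustrative stand-in; no MuJoCo calls are involved):

```python
import numpy as np

# Stand-in for the per-geom size field of the two-geom model (2 geoms, 3 values each).
geom_size = np.array([[1.0, 2.0, 3.0],
                      [4.0, 5.0, 6.0]])

# An integer index selects one row and drops the leading axis.
print(geom_size[0].shape)      # (3,)

# A slice keeps the leading axis, whatever its length.
print(geom_size[0:2].shape)    # (2, 3)
print(geom_size[0:1].shape)    # (1, 3)

# A list of indices (advanced indexing) also keeps it, even with a single entry.
print(geom_size[[0]].shape)    # (1, 3)
```

Under this reading, bind(geoms[0:1]) would correspond to geom_size[0:1] rather than geom_size[0].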

Note that this happens with both model and data fields.
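Until this is resolved, callers can guard against the squeeze by reshaping to an explicit (number of objects, per-object shape) layout, a generalization of the output.reshape(-1, 3) workaround. A minimal sketch, assuming the caller knows the per-object field shape; restore_batch_dim is a hypothetical helper, not part of the mujoco API, and the inputs below are simulated bind() outputs rather than real ones:

```python
import numpy as np

def restore_batch_dim(values, num_objects, field_shape):
    """Reshape a bound field to an explicit (num_objects, *field_shape) shape.

    Hypothetical helper, not part of the mujoco API. field_shape is the
    per-object shape of the field, e.g. (3,) for geom size or () for a
    scalar field. Already-batched inputs pass through unchanged, so the
    result is the same whether or not bind() squeezed the batch axis.
    np.reshape raises ValueError if the element counts do not match.
    """
    return np.reshape(np.asarray(values), (num_objects, *field_shape))

# Simulated bind() outputs: a squeezed single-geom result and a batched one.
squeezed = np.array([1.0, 2.0, 3.0])        # shape (3,), as seen for geoms[0:1]
batched = np.array([[1.0, 2.0, 3.0],
                    [4.0, 5.0, 6.0]])       # shape (2, 3), as seen for geoms[0:2]

print(restore_batch_dim(squeezed, 1, (3,)).shape)  # (1, 3)
print(restore_batch_dim(batched, 2, (3,)).shape)   # (2, 3)
print(restore_batch_dim(7.0, 1, ()).shape)         # (1,)
```

With real bindings this would be called as restore_batch_dim(m.bind(objs).size, len(objs), (3,)). Because it converts through np.asarray, it is only suitable outside jax.jit; for traced MJX values, jnp.reshape would be the analogous call.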

Steps for reproduction

  1. Load a model with at least two bodies/geoms.
  2. Bind a subset using a slice of length > 1 and check the shape (the batch dimension is preserved).
  3. Bind a subset using a slice of length 1 and check the shape.
  4. Observe that the result is 1D with shape (N,) instead of the expected 2D shape (1, N).

Minimal model for reproduction

Minimal XML
<mujoco model="test_bind_model">
    <worldbody>
      <body pos="10 20 30" name="body1">
        <joint axis="1 0 0" type="ball" name="joint1"/>
        <geom size="1 2 3" type="box" name="geom1" pos="0 1 0"/>
        <geom size="4 5 6" type="box" name="geom2" pos="0 2 0"/>
      </body>
    </worldbody>
</mujoco>

Code required for reproduction

import mujoco
from mujoco import mjx
import numpy as np

xml = """
<mujoco model="test_bind_model">
    <worldbody>
      <body pos="10 20 30" name="body1">
        <joint axis="1 0 0" type="ball" name="joint1"/>
        <geom size="1 2 3" type="box" name="geom1" pos="0 1 0"/>
        <geom size="4 5 6" type="box" name="geom2" pos="0 2 0"/>
      </body>
    </worldbody>
</mujoco>
"""

s = mujoco.MjSpec.from_string(xml)
m = s.compile()
mx = mjx.put_model(m)

geoms = s.geoms[0:2]

print("--- MuJoCo (m.bind) ---")
print("m.bind(geoms[0:2]).size.shape:", np.array(m.bind(geoms[0:2]).size).shape)  # Expected (2, 3)
# Actual (2, 3)
print("m.bind(geoms[0:1]).size.shape:", np.array(m.bind(geoms[0:1]).size).shape)  # Expected (1, 3)
# Actual (3,)

print("\n--- MJX (mx.bind) ---")
print("mx.bind(geoms[0:2]).size.shape:", mx.bind(geoms[0:2]).size.shape)          # Expected (2, 3)
# Actual (2, 3)
print("mx.bind(geoms[0:1]).size.shape:", mx.bind(geoms[0:1]).size.shape)          # Expected (1, 3)
# Actual (3,)

The comparison main...hartikainen:mujoco:bind-squeeze also adds a simple test case for this.
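One way to state the requested behavior precisely: whether the batch axis survives should depend on whether bind() receives a single element or a sequence, not on the length of that sequence. The sketch below models that rule over a plain NumPy field array; FakeGeom and bind_rows are hypothetical names, and this is not MuJoCo's actual bind() implementation:

```python
import numpy as np
from dataclasses import dataclass

@dataclass
class FakeGeom:
    """Hypothetical stand-in for a spec element; only carries a row index."""
    id: int

def bind_rows(field, target):
    """Hypothetical model of the requested bind() semantics.

    A single element selects one row (no batch axis). Any list or tuple,
    including one of length 1, is gathered with advanced indexing, which
    keeps the leading batch axis.
    """
    if isinstance(target, (list, tuple)):
        return field[[g.id for g in target]]
    return field[target.id]

geom_size = np.array([[1.0, 2.0, 3.0],
                      [4.0, 5.0, 6.0]])
geoms = [FakeGeom(0), FakeGeom(1)]

print(bind_rows(geom_size, geoms[0:2]).shape)  # (2, 3)
print(bind_rows(geom_size, geoms[0:1]).shape)  # (1, 3)
print(bind_rows(geom_size, geoms[0]).shape)    # (3,)
```

The design choice is to branch on the argument's type rather than its length, so a one-element list never collapses to the single-element case.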

Labels

bug
