Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 157 additions & 10 deletions mjx/mujoco/mjx/_src/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
# ==============================================================================
"""Functions for ray interesection testing."""

from functools import partial
from typing import Sequence, Tuple

import jax
from jax import numpy as jp
from jax.scipy.ndimage import map_coordinates
import mujoco
from mujoco.mjx._src import math
# pylint: disable=g-importing-member
Expand Down Expand Up @@ -154,6 +156,29 @@ def _ray_box(
return jp.min(jp.where(valid, x, jp.inf))


def _ray_box_6(
size: jax.Array, pnt: jax.Array, vec: jax.Array
) -> jax.Array:
"""Returns intersection distances for all 6 faces of a box."""
# replace zero vec components with small number to avoid division by zero
safe_vec = jp.where(jp.abs(vec) < mujoco.mjMINVAL, mujoco.mjMINVAL, vec)
iface = jp.array([(1, 2), (1, 2), (0, 2), (0, 2), (0, 1), (0, 1)])

# distances to planes for each of the 6 faces (+x, -x, +y, -y, +z, -z)
x = jp.concatenate([(size - pnt) / safe_vec, (-size - pnt) / safe_vec])

# check if intersection points are within face bounds
p_intersect = pnt + x[:, None] * vec
p_check_dim1 = jp.abs(p_intersect[jp.arange(6), iface[:, 0]])
p_check_dim2 = jp.abs(p_intersect[jp.arange(6), iface[:, 1]])
valid = (p_check_dim1 <= size[iface[:, 0]]) & (
p_check_dim2 <= size[iface[:, 1]]
)
valid &= x >= 0

return jp.where(valid, x, jp.inf)


def _ray_triangle(
vert: jax.Array,
pnt: jax.Array,
Expand Down Expand Up @@ -219,6 +244,112 @@ def _ray_mesh(

return dist, id_

@partial(jax.jit, static_argnames=('nrow', 'ncol'))
def _ray_hfield_static(
hfield_data_flat: jax.Array,
size: jax.Array,
pnt: jax.Array,
vec: jax.Array,
adr: jax.Array,
nrow: int,
ncol: int,
) -> jax.Array:
"""JIT-compiled kernel to raycast against a single hfield size.
"""
# size: (xy_size_x, xy_size_y, height_range, base_thickness)
# Intersection with base box
base_size = jp.array([size[0], size[1], size[3] / 2.0])
base_pos = jp.array([0, 0, -size[3] / 2.0])
dist = _ray_box(base_size, pnt - base_pos, vec)

# Intersection with top box (containing terrain)
top_size = jp.array([size[0], size[1], size[2] / 2.0])
top_pos = jp.array([0, 0, size[2] / 2.0])
top_dists_all = _ray_box_6(top_size, pnt - top_pos, vec)
top_dist_min = jp.min(top_dists_all, initial=jp.inf)

def _intersect_surface() -> jax.Array:
r_idx_grid, c_idx_grid = jp.meshgrid(
jp.arange(nrow), jp.arange(ncol), indexing='ij'
)
flat_indices = (adr + r_idx_grid * ncol + c_idx_grid).flatten()
hfield_data = hfield_data_flat[flat_indices].reshape((nrow, ncol))

# 1. Test against all triangles in the grid.
# Use initial=jp.inf for safety against empty arrays if nrow/ncol <= 1
min_tri_dist = jp.inf
if nrow > 1 and ncol > 1:
dx = 2.0 * size[0] / (ncol - 1)
dy = 2.0 * size[1] / (nrow - 1)
x_coords = c_idx_grid * dx - size[0]
y_coords = r_idx_grid * dy - size[1]
z_coords = hfield_data * size[2]
v00 = jp.stack([x_coords[:-1, :-1], y_coords[:-1, :-1], z_coords[:-1, :-1]],
axis=-1)
v10 = jp.stack([x_coords[1:, :-1], y_coords[1:, :-1], z_coords[1:, :-1]],
axis=-1)
v01 = jp.stack([x_coords[:-1, 1:], y_coords[:-1, 1:], z_coords[:-1, 1:]],
axis=-1)
v11 = jp.stack([x_coords[1:, 1:], y_coords[1:, 1:], z_coords[1:, 1:]],
axis=-1)
tri1_verts = jp.stack([v00, v11, v10], axis=-2).reshape(-1, 3, 3)
tri2_verts = jp.stack([v00, v11, v01], axis=-2).reshape(-1, 3, 3)
verts = jp.concatenate([tri1_verts, tri2_verts])
basis = jp.array(math.orthogonals(math.normalize(vec))).T
tri_dists = jax.vmap(_ray_triangle, in_axes=(0, None, None, None))(
verts, pnt, vec, basis
)
min_tri_dist = jp.min(tri_dists, initial=jp.inf)

# 2. Test against the four vertical side faces of the top box.
# Replicates the C-code's 1D linear interpolation
# for side hits to ensure identical behavior at the boundaries.
d_sides = top_dists_all[0:4] # Distances for +x, -x, +y, -y faces
p_sides = pnt + d_sides[:, None] * vec

safe_dx = jp.where(ncol > 1, 2.0 * size[0] / (ncol - 1), 1.0)
safe_dy = jp.where(nrow > 1, 2.0 * size[1] / (nrow - 1), 1.0)

# Handle sides normal to X-axis (+x, -x faces)
y_float_x_sides = (p_sides[:2, 1] + size[1]) / safe_dy
y0_x_sides = jp.clip(jp.floor(y_float_x_sides).astype(int), 0, nrow - 2)

y0_rounded = jp.round(y0_x_sides).astype(int)
y1_rounded = jp.round(y0_x_sides + 1).astype(int)
x_indices = jp.array([ncol - 1, 0]) # Grid indices for +x and -x edges
z0_x = hfield_data[y0_rounded, x_indices]
z1_x = hfield_data[y1_rounded, x_indices]
interp_h_x = z0_x * (y0_x_sides + 1 - y_float_x_sides) + z1_x * (
y_float_x_sides - y0_x_sides
)

# Handle sides normal to Y-axis (+y, -y faces)
x_float_y_sides = (p_sides[2:, 0] + size[0]) / safe_dx
x0_y_sides = jp.clip(jp.floor(x_float_y_sides).astype(int), 0, ncol - 2)


x0_rounded = jp.round(x0_y_sides).astype(int)
x1_rounded = jp.round(x0_y_sides + 1).astype(int)
y_indices = jp.array([nrow - 1, 0]) # Grid indices for +y and -y edges
z0_y = hfield_data[y_indices, x0_rounded]
z1_y = hfield_data[y_indices, x1_rounded]
interp_h_y = z0_y * (x0_y_sides + 1 - x_float_y_sides) + z1_y * (
x_float_y_sides - x0_y_sides
)

# Combine interpolated heights, scale to world units, and check validity
interp_h_norm = jp.concatenate([interp_h_x, interp_h_y])
interp_h = interp_h_norm * size[2]
valid_side_hit = p_sides[:, 2] < interp_h
side_dists = jp.where(valid_side_hit, d_sides, jp.inf)
min_side_dist = jp.min(side_dists, initial=jp.inf)

return jp.minimum(min_tri_dist, min_side_dist)

dist_surface = jax.lax.cond(
jp.isinf(top_dist_min), lambda: jp.inf, _intersect_surface
)
return jp.minimum(dist, dist_surface)

_RAY_FUNC = {
GeomType.PLANE: _ray_plane,
Expand All @@ -227,6 +358,7 @@ def _ray_mesh(
GeomType.ELLIPSOID: _ray_ellipsoid,
GeomType.BOX: _ray_box,
GeomType.MESH: _ray_mesh,
GeomType.HFIELD: _ray_hfield_static,
}


Expand Down Expand Up @@ -254,7 +386,6 @@ def ray(
dist: distance from ray origin to geom surface (or -1.0 for no intersection)
id: id of intersected geom (or -1 for no intersection)
"""

dists, ids = [], []
geom_filter = m.geom_bodyid != bodyexclude
geom_filter &= flg_static | (m.body_weldid[m.geom_bodyid] != 0)
Expand All @@ -270,32 +401,48 @@ def ray(
geom_filter_dyn &= (m.geom_matid == -1) | (m.mat_rgba[m.geom_matid, 3] != 0)
for geom_type, fn in _RAY_FUNC.items():
(id_,) = np.nonzero(geom_filter & (m.geom_type == geom_type))

if id_.size == 0:
continue

args = m.geom_size[id_], geom_pnts[id_], geom_vecs[id_]


if geom_type == GeomType.MESH:
dist, id_ = fn(m, id_, *args)
args = (m, id_, m.geom_size[id_], geom_pnts[id_], geom_vecs[id_])
dist, id_ = fn(*args)
elif geom_type == GeomType.HFIELD:
hfield_dataid = m.geom_dataid[id_]
nrow = m.hfield_nrow[hfield_dataid][0]
ncol = m.hfield_ncol[hfield_dataid][0]
hfield_data_flat = jp.asarray(m.hfield_data)
hfield_id_for_geoms = m.geom_dataid[id_]

args = (
hfield_data_flat,
m.hfield_size[hfield_id_for_geoms][0],
geom_pnts[id_][0],
geom_vecs[id_][0],
jp.asarray(m.hfield_adr)[hfield_id_for_geoms][0],
nrow,
ncol,
)
dist = _ray_hfield_static(*args)
else:
# remove model and id from args for primitive functions
args = (m.geom_size[id_], geom_pnts[id_], geom_vecs[id_])
dist = jax.vmap(fn)(*args)

dist = jp.where(geom_filter_dyn[id_], dist, jp.inf)
dists, ids = dists + [dist], ids + [id_]

if not ids:
return jp.array(-1), jp.array(-1.0)
return jp.array(-1.0), jp.array(-1)

dists = jp.concatenate(dists)
ids = jp.concatenate(ids)
min_id = jp.argmin(dists)
dist = jp.where(jp.isinf(dists[min_id]), -1, dists[min_id])
dist = jp.where(jp.isinf(dists[min_id]), -1.0, dists[min_id])
id_ = jp.where(jp.isinf(dists[min_id]), -1, ids[min_id])

return dist, id_


def ray_geom(
size: jax.Array, pnt: jax.Array, vec: jax.Array, geomtype: GeomType
) -> jax.Array:
Expand All @@ -310,4 +457,4 @@ def ray_geom(
Returns:
dist: distance from ray origin to geom surface
"""
return _RAY_FUNC[geomtype](size, pnt, vec)
return _RAY_FUNC[geomtype](size, pnt, vec)
99 changes: 99 additions & 0 deletions mjx/mujoco/mjx/_src/ray_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,105 @@ def test_ray_invisible(self):
_assert_eq(geomid, -1, 'geom_id')
_assert_eq(dist, -1, 'dist')

def test_ray_hfield(self):
"""Tests that MJX ray<>hfield matches MuJoCo."""
m = test_util.load_test_file('ray.xml')
d = mujoco.MjData(m)
mujoco.mj_forward(m, d)
mx, dx = mjx.put_model(m), mjx.put_data(m, d)
ray_fn = jax.jit(mjx.ray, static_argnums=(4,))
# Find hfield geom ID to test only hfield intersections
hfield_geom_id = -1
for i in range(m.geom_type.shape[0]):
if m.geom_type[i] == mujoco.mjx._src.types.GeomType.HFIELD:
hfield_geom_id = i
break

# Set all groups to False except the one containing hfield
geomgroup = np.zeros(mujoco.mjNGROUP, dtype=bool)
if hfield_geom_id >= 0:
hfield_group = m.geom_group[hfield_geom_id]
geomgroup[hfield_group] = True
geomgroup = tuple(geomgroup.tolist())

# Test 1: Single ray hitting hfield from directly above
# Hfield is at [20, 20, 20] with size [.6, .4, .1, .1]
pnt, vec = jp.array([20.0, 20.0, 25.0]), jp.array([0.0, 0.0, -1.0])
dist, geomid = ray_fn(mx, dx, pnt, vec, geomgroup)

mj_geomid = np.zeros(1, dtype=np.int32)
mj_dist = mujoco.mj_ray(m, d, pnt, vec, geomgroup, True, -1, mj_geomid)

assert mj_dist != -1, 'MuJoCo ray should hit hfield, test case setup might be wrong.'
_assert_eq(dist, mj_dist, 'hfield_single_dist')
_assert_eq(geomid, mj_geomid[0], 'hfield_single_geomid')

# Test 2: Sample random ray origin
# Hfield is at [20, 20, 20] with size [.6, .4, .1, .1]
x_positions = jp.linspace(19.5, 20.5, 5) # Within hfield x bounds
y_positions = jp.linspace(19.7, 20.3, 5) # Within hfield y bounds
z_start = 25.0 # Well above hfield

ray_directions = [
jp.array([0.0, 0.0, -1.0]), # Straight down
jp.array([0.1, 0.0, -1.0]), # Slight angle in x
jp.array([0.0, 0.1, -1.0]), # Slight angle in y
jp.array([0.1, 0.1, -1.0]), # Diagonal angle
]

test_count = 0
for x in x_positions:
for y in y_positions:
for vec_unnorm in ray_directions:
vec = vec_unnorm / jp.linalg.norm(vec_unnorm) # Normalize
pnt = jp.array([x, y, z_start])

# MJX ray (hfield only)
dist_mjx, geomid_mjx = ray_fn(mx, dx, pnt, vec, geomgroup)

# MuJoCo ground truth (hfield only)
mj_geomid = np.zeros(1, dtype=np.int32)
mj_dist = mujoco.mj_ray(m, d, pnt, vec, geomgroup, True, -1, mj_geomid)

# Assert equality
_assert_eq(dist_mjx, mj_dist, f'grid_dist_{test_count}')
_assert_eq(geomid_mjx, mj_geomid[0], f'grid_geomid_{test_count}')
test_count += 1

# Test 3: Rays that should miss hfield (outside bounds)
miss_tests = [
(jp.array([25.0, 20.0, 25.0]), jp.array([0.0, 0.0, -1.0])), # Outside x bounds
(jp.array([20.0, 25.0, 25.0]), jp.array([0.0, 0.0, -1.0])), # Outside y bounds
(jp.array([20.0, 20.0, 25.0]), jp.array([1.0, 0.0, 0.0])), # Horizontal ray
]

for i, (pnt_miss, vec_miss) in enumerate(miss_tests):
dist_miss, geomid_miss = ray_fn(mx, dx, pnt_miss, vec_miss, geomgroup)

mj_geomid_miss = np.zeros(1, dtype=np.int32)
mj_dist_miss = mujoco.mj_ray(m, d, pnt_miss, vec_miss, geomgroup, True, -1, mj_geomid_miss)

_assert_eq(dist_miss, mj_dist_miss, f'miss_dist_{i}')
_assert_eq(dist_miss, -1, f'miss_dist_{i}_check')
_assert_eq(geomid_miss, mj_geomid_miss[0], f'miss_geomid_{i}')

# Test 4: Angular rays from different positions
center_pos = jp.array([20.0, 20.0, 22.0]) # Above hfield center

# Random angles
angles = jp.linspace(0, jp.pi/4, 5) # 0 to 45 degrees from vertical
for i, angle in enumerate(angles):
# Create angled ray (rotating around x-axis)
vec_angled = jp.array([0.0, jp.sin(angle), -jp.cos(angle)])

dist_angled, geomid_angled = ray_fn(mx, dx, center_pos, vec_angled, geomgroup)

mj_geomid_angled = np.zeros(1, dtype=np.int32)
mj_dist_angled = mujoco.mj_ray(m, d, center_pos, vec_angled, geomgroup, True, -1, mj_geomid_angled)

_assert_eq(dist_angled, mj_dist_angled, f'angle_dist_{i}')
_assert_eq(geomid_angled, mj_geomid_angled[0], f'angle_geomid_{i}')


if __name__ == '__main__':
absltest.main()
2 changes: 2 additions & 0 deletions mjx/mujoco/mjx/test_data/ray.xml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
<asset>
<mesh name="tetrahedron" file="meshes/tetrahedron.stl" scale="0.4 0.4 0.4" />
<mesh name="dodecahedron" file="meshes/dodecahedron.stl" scale="0.04 0.04 0.04" />
<hfield name="hfield" nrow="3" ncol="2" elevation="1 2 3 3 2 1" size=".6 .4 .1 .1"/>
<texture builtin="checker" height="100" name="texplane" rgb1="0 0 0" rgb2="0.8 0.8 0.8" type="2d" width="100"/>
<material name="MatPlane" reflectance="0.5" shininess="1" specular="1" texrepeat="60 60" texture="texplane"/>
</asset>
Expand All @@ -14,5 +15,6 @@
<geom name="box" pos="1 0 1" quat="0 0.3826834 0 0.9238795" size="0.5 0.25 0.3" type="box" rgba="0 0 1 1"/>
<geom name="mesh" pos="1 1 1" quat="0 0 0.3826834 0.9238795" type="mesh" mesh="tetrahedron" rgba="1 1 0 1"/>
<geom name="mesh2" pos="2 1 1" type="mesh" mesh="dodecahedron" rgba="1 0 1 1"/>
<geom name="hfield" pos="20 20 20" type="hfield" hfield="hfield" rgba=".2 .4 .6 1"/>
</worldbody>
</mujoco>