diff --git a/mjx/mujoco/mjx/_src/ray.py b/mjx/mujoco/mjx/_src/ray.py
index 2e2055cf04..32e215699d 100644
--- a/mjx/mujoco/mjx/_src/ray.py
+++ b/mjx/mujoco/mjx/_src/ray.py
@@ -14,10 +14,12 @@
# ==============================================================================
"""Functions for ray interesection testing."""
+from functools import partial
from typing import Sequence, Tuple
import jax
from jax import numpy as jp
+from jax.scipy.ndimage import map_coordinates
import mujoco
from mujoco.mjx._src import math
# pylint: disable=g-importing-member
@@ -154,6 +156,29 @@ def _ray_box(
return jp.min(jp.where(valid, x, jp.inf))
+def _ray_box_6(
+ size: jax.Array, pnt: jax.Array, vec: jax.Array
+) -> jax.Array:
+ """Returns intersection distances for all 6 faces of a box."""
+ # replace zero vec components with small number to avoid division by zero
+ safe_vec = jp.where(jp.abs(vec) < mujoco.mjMINVAL, mujoco.mjMINVAL, vec)
+ iface = jp.array([(1, 2), (1, 2), (0, 2), (0, 2), (0, 1), (0, 1)])
+
+ # distances to planes for each of the 6 faces (+x, -x, +y, -y, +z, -z)
+ x = jp.concatenate([(size - pnt) / safe_vec, (-size - pnt) / safe_vec])
+
+ # check if intersection points are within face bounds
+ p_intersect = pnt + x[:, None] * vec
+ p_check_dim1 = jp.abs(p_intersect[jp.arange(6), iface[:, 0]])
+ p_check_dim2 = jp.abs(p_intersect[jp.arange(6), iface[:, 1]])
+ valid = (p_check_dim1 <= size[iface[:, 0]]) & (
+ p_check_dim2 <= size[iface[:, 1]]
+ )
+ valid &= x >= 0
+
+ return jp.where(valid, x, jp.inf)
+
+
def _ray_triangle(
vert: jax.Array,
pnt: jax.Array,
@@ -219,6 +244,112 @@ def _ray_mesh(
return dist, id_
+@partial(jax.jit, static_argnames=('nrow', 'ncol'))
+def _ray_hfield_static(
+ hfield_data_flat: jax.Array,
+ size: jax.Array,
+ pnt: jax.Array,
+ vec: jax.Array,
+ adr: jax.Array,
+ nrow: int,
+ ncol: int,
+) -> jax.Array:
+ """JIT-compiled kernel to raycast against a single hfield size.
+ """
+ # size: (xy_size_x, xy_size_y, height_range, base_thickness)
+ # Intersection with base box
+ base_size = jp.array([size[0], size[1], size[3] / 2.0])
+ base_pos = jp.array([0, 0, -size[3] / 2.0])
+ dist = _ray_box(base_size, pnt - base_pos, vec)
+
+ # Intersection with top box (containing terrain)
+ top_size = jp.array([size[0], size[1], size[2] / 2.0])
+ top_pos = jp.array([0, 0, size[2] / 2.0])
+ top_dists_all = _ray_box_6(top_size, pnt - top_pos, vec)
+ top_dist_min = jp.min(top_dists_all, initial=jp.inf)
+
+ def _intersect_surface() -> jax.Array:
+ r_idx_grid, c_idx_grid = jp.meshgrid(
+ jp.arange(nrow), jp.arange(ncol), indexing='ij'
+ )
+ flat_indices = (adr + r_idx_grid * ncol + c_idx_grid).flatten()
+ hfield_data = hfield_data_flat[flat_indices].reshape((nrow, ncol))
+
+ # 1. Test against all triangles in the grid.
+ # Use initial=jp.inf for safety against empty arrays if nrow/ncol <= 1
+ min_tri_dist = jp.inf
+ if nrow > 1 and ncol > 1:
+ dx = 2.0 * size[0] / (ncol - 1)
+ dy = 2.0 * size[1] / (nrow - 1)
+ x_coords = c_idx_grid * dx - size[0]
+ y_coords = r_idx_grid * dy - size[1]
+ z_coords = hfield_data * size[2]
+ v00 = jp.stack([x_coords[:-1, :-1], y_coords[:-1, :-1], z_coords[:-1, :-1]],
+ axis=-1)
+ v10 = jp.stack([x_coords[1:, :-1], y_coords[1:, :-1], z_coords[1:, :-1]],
+ axis=-1)
+ v01 = jp.stack([x_coords[:-1, 1:], y_coords[:-1, 1:], z_coords[:-1, 1:]],
+ axis=-1)
+ v11 = jp.stack([x_coords[1:, 1:], y_coords[1:, 1:], z_coords[1:, 1:]],
+ axis=-1)
+ tri1_verts = jp.stack([v00, v11, v10], axis=-2).reshape(-1, 3, 3)
+ tri2_verts = jp.stack([v00, v11, v01], axis=-2).reshape(-1, 3, 3)
+ verts = jp.concatenate([tri1_verts, tri2_verts])
+ basis = jp.array(math.orthogonals(math.normalize(vec))).T
+ tri_dists = jax.vmap(_ray_triangle, in_axes=(0, None, None, None))(
+ verts, pnt, vec, basis
+ )
+ min_tri_dist = jp.min(tri_dists, initial=jp.inf)
+
+ # 2. Test against the four vertical side faces of the top box.
+ # Replicates the C-code's 1D linear interpolation
+ # for side hits to ensure identical behavior at the boundaries.
+ d_sides = top_dists_all[0:4] # Distances for +x, -x, +y, -y faces
+ p_sides = pnt + d_sides[:, None] * vec
+
+ safe_dx = jp.where(ncol > 1, 2.0 * size[0] / (ncol - 1), 1.0)
+ safe_dy = jp.where(nrow > 1, 2.0 * size[1] / (nrow - 1), 1.0)
+
+ # Handle sides normal to X-axis (+x, -x faces)
+ y_float_x_sides = (p_sides[:2, 1] + size[1]) / safe_dy
+ y0_x_sides = jp.clip(jp.floor(y_float_x_sides).astype(int), 0, nrow - 2)
+
+ y0_rounded = jp.round(y0_x_sides).astype(int)
+ y1_rounded = jp.round(y0_x_sides + 1).astype(int)
+ x_indices = jp.array([ncol - 1, 0]) # Grid indices for +x and -x edges
+ z0_x = hfield_data[y0_rounded, x_indices]
+ z1_x = hfield_data[y1_rounded, x_indices]
+ interp_h_x = z0_x * (y0_x_sides + 1 - y_float_x_sides) + z1_x * (
+ y_float_x_sides - y0_x_sides
+ )
+
+ # Handle sides normal to Y-axis (+y, -y faces)
+ x_float_y_sides = (p_sides[2:, 0] + size[0]) / safe_dx
+ x0_y_sides = jp.clip(jp.floor(x_float_y_sides).astype(int), 0, ncol - 2)
+
+
+ x0_rounded = jp.round(x0_y_sides).astype(int)
+ x1_rounded = jp.round(x0_y_sides + 1).astype(int)
+ y_indices = jp.array([nrow - 1, 0]) # Grid indices for +y and -y edges
+ z0_y = hfield_data[y_indices, x0_rounded]
+ z1_y = hfield_data[y_indices, x1_rounded]
+ interp_h_y = z0_y * (x0_y_sides + 1 - x_float_y_sides) + z1_y * (
+ x_float_y_sides - x0_y_sides
+ )
+
+ # Combine interpolated heights, scale to world units, and check validity
+ interp_h_norm = jp.concatenate([interp_h_x, interp_h_y])
+ interp_h = interp_h_norm * size[2]
+ valid_side_hit = p_sides[:, 2] < interp_h
+ side_dists = jp.where(valid_side_hit, d_sides, jp.inf)
+ min_side_dist = jp.min(side_dists, initial=jp.inf)
+
+ return jp.minimum(min_tri_dist, min_side_dist)
+
+ dist_surface = jax.lax.cond(
+ jp.isinf(top_dist_min), lambda: jp.inf, _intersect_surface
+ )
+ return jp.minimum(dist, dist_surface)
_RAY_FUNC = {
GeomType.PLANE: _ray_plane,
@@ -227,6 +358,7 @@ def _ray_mesh(
GeomType.ELLIPSOID: _ray_ellipsoid,
GeomType.BOX: _ray_box,
GeomType.MESH: _ray_mesh,
+ GeomType.HFIELD: _ray_hfield_static,
}
@@ -254,7 +386,6 @@ def ray(
dist: distance from ray origin to geom surface (or -1.0 for no intersection)
id: id of intersected geom (or -1 for no intersection)
"""
-
dists, ids = [], []
geom_filter = m.geom_bodyid != bodyexclude
geom_filter &= flg_static | (m.body_weldid[m.geom_bodyid] != 0)
@@ -270,32 +401,48 @@ def ray(
geom_filter_dyn &= (m.geom_matid == -1) | (m.mat_rgba[m.geom_matid, 3] != 0)
for geom_type, fn in _RAY_FUNC.items():
(id_,) = np.nonzero(geom_filter & (m.geom_type == geom_type))
-
if id_.size == 0:
continue
-
- args = m.geom_size[id_], geom_pnts[id_], geom_vecs[id_]
-
+
if geom_type == GeomType.MESH:
- dist, id_ = fn(m, id_, *args)
+ args = (m, id_, m.geom_size[id_], geom_pnts[id_], geom_vecs[id_])
+ dist, id_ = fn(*args)
+ elif geom_type == GeomType.HFIELD:
+ hfield_dataid = m.geom_dataid[id_]
+ nrow = m.hfield_nrow[hfield_dataid][0]
+ ncol = m.hfield_ncol[hfield_dataid][0]
+ hfield_data_flat = jp.asarray(m.hfield_data)
+ hfield_id_for_geoms = m.geom_dataid[id_]
+
+ args = (
+ hfield_data_flat,
+ m.hfield_size[hfield_id_for_geoms][0],
+ geom_pnts[id_][0],
+ geom_vecs[id_][0],
+ jp.asarray(m.hfield_adr)[hfield_id_for_geoms][0],
+ nrow,
+ ncol,
+ )
+ dist = _ray_hfield_static(*args)
else:
+ # remove model and id from args for primitive functions
+ args = (m.geom_size[id_], geom_pnts[id_], geom_vecs[id_])
dist = jax.vmap(fn)(*args)
dist = jp.where(geom_filter_dyn[id_], dist, jp.inf)
dists, ids = dists + [dist], ids + [id_]
if not ids:
- return jp.array(-1), jp.array(-1.0)
+ return jp.array(-1.0), jp.array(-1)
dists = jp.concatenate(dists)
ids = jp.concatenate(ids)
min_id = jp.argmin(dists)
- dist = jp.where(jp.isinf(dists[min_id]), -1, dists[min_id])
+ dist = jp.where(jp.isinf(dists[min_id]), -1.0, dists[min_id])
id_ = jp.where(jp.isinf(dists[min_id]), -1, ids[min_id])
return dist, id_
-
def ray_geom(
size: jax.Array, pnt: jax.Array, vec: jax.Array, geomtype: GeomType
) -> jax.Array:
@@ -310,4 +457,4 @@ def ray_geom(
Returns:
dist: distance from ray origin to geom surface
"""
- return _RAY_FUNC[geomtype](size, pnt, vec)
+ return _RAY_FUNC[geomtype](size, pnt, vec)
\ No newline at end of file
diff --git a/mjx/mujoco/mjx/_src/ray_test.py b/mjx/mujoco/mjx/_src/ray_test.py
index 505878d32e..28e59723c5 100644
--- a/mjx/mujoco/mjx/_src/ray_test.py
+++ b/mjx/mujoco/mjx/_src/ray_test.py
@@ -250,6 +250,105 @@ def test_ray_invisible(self):
_assert_eq(geomid, -1, 'geom_id')
_assert_eq(dist, -1, 'dist')
+ def test_ray_hfield(self):
+ """Tests that MJX ray<>hfield matches MuJoCo."""
+ m = test_util.load_test_file('ray.xml')
+ d = mujoco.MjData(m)
+ mujoco.mj_forward(m, d)
+ mx, dx = mjx.put_model(m), mjx.put_data(m, d)
+ ray_fn = jax.jit(mjx.ray, static_argnums=(4,))
+ # Find hfield geom ID to test only hfield intersections
+ hfield_geom_id = -1
+ for i in range(m.geom_type.shape[0]):
+ if m.geom_type[i] == mujoco.mjx._src.types.GeomType.HFIELD:
+ hfield_geom_id = i
+ break
+
+ # Set all groups to False except the one containing hfield
+ geomgroup = np.zeros(mujoco.mjNGROUP, dtype=bool)
+ if hfield_geom_id >= 0:
+ hfield_group = m.geom_group[hfield_geom_id]
+ geomgroup[hfield_group] = True
+ geomgroup = tuple(geomgroup.tolist())
+
+ # Test 1: Single ray hitting hfield from directly above
+ # Hfield is at [20, 20, 20] with size [.6, .4, .1, .1]
+ pnt, vec = jp.array([20.0, 20.0, 25.0]), jp.array([0.0, 0.0, -1.0])
+ dist, geomid = ray_fn(mx, dx, pnt, vec, geomgroup)
+
+ mj_geomid = np.zeros(1, dtype=np.int32)
+ mj_dist = mujoco.mj_ray(m, d, pnt, vec, geomgroup, True, -1, mj_geomid)
+
+ assert mj_dist != -1, 'MuJoCo ray should hit hfield, test case setup might be wrong.'
+ _assert_eq(dist, mj_dist, 'hfield_single_dist')
+ _assert_eq(geomid, mj_geomid[0], 'hfield_single_geomid')
+
+ # Test 2: Sample random ray origin
+ # Hfield is at [20, 20, 20] with size [.6, .4, .1, .1]
+ x_positions = jp.linspace(19.5, 20.5, 5) # Within hfield x bounds
+ y_positions = jp.linspace(19.7, 20.3, 5) # Within hfield y bounds
+ z_start = 25.0 # Well above hfield
+
+ ray_directions = [
+ jp.array([0.0, 0.0, -1.0]), # Straight down
+ jp.array([0.1, 0.0, -1.0]), # Slight angle in x
+ jp.array([0.0, 0.1, -1.0]), # Slight angle in y
+ jp.array([0.1, 0.1, -1.0]), # Diagonal angle
+ ]
+
+ test_count = 0
+ for x in x_positions:
+ for y in y_positions:
+ for vec_unnorm in ray_directions:
+ vec = vec_unnorm / jp.linalg.norm(vec_unnorm) # Normalize
+ pnt = jp.array([x, y, z_start])
+
+ # MJX ray (hfield only)
+ dist_mjx, geomid_mjx = ray_fn(mx, dx, pnt, vec, geomgroup)
+
+ # MuJoCo ground truth (hfield only)
+ mj_geomid = np.zeros(1, dtype=np.int32)
+ mj_dist = mujoco.mj_ray(m, d, pnt, vec, geomgroup, True, -1, mj_geomid)
+
+ # Assert equality
+ _assert_eq(dist_mjx, mj_dist, f'grid_dist_{test_count}')
+ _assert_eq(geomid_mjx, mj_geomid[0], f'grid_geomid_{test_count}')
+ test_count += 1
+
+ # Test 3: Rays that should miss hfield (outside bounds)
+ miss_tests = [
+ (jp.array([25.0, 20.0, 25.0]), jp.array([0.0, 0.0, -1.0])), # Outside x bounds
+ (jp.array([20.0, 25.0, 25.0]), jp.array([0.0, 0.0, -1.0])), # Outside y bounds
+ (jp.array([20.0, 20.0, 25.0]), jp.array([1.0, 0.0, 0.0])), # Horizontal ray
+ ]
+
+ for i, (pnt_miss, vec_miss) in enumerate(miss_tests):
+ dist_miss, geomid_miss = ray_fn(mx, dx, pnt_miss, vec_miss, geomgroup)
+
+ mj_geomid_miss = np.zeros(1, dtype=np.int32)
+ mj_dist_miss = mujoco.mj_ray(m, d, pnt_miss, vec_miss, geomgroup, True, -1, mj_geomid_miss)
+
+ _assert_eq(dist_miss, mj_dist_miss, f'miss_dist_{i}')
+ _assert_eq(dist_miss, -1, f'miss_dist_{i}_check')
+ _assert_eq(geomid_miss, mj_geomid_miss[0], f'miss_geomid_{i}')
+
+ # Test 4: Angular rays from different positions
+ center_pos = jp.array([20.0, 20.0, 22.0]) # Above hfield center
+
+ # Random angles
+ angles = jp.linspace(0, jp.pi/4, 5) # 0 to 45 degrees from vertical
+ for i, angle in enumerate(angles):
+ # Create angled ray (rotating around x-axis)
+ vec_angled = jp.array([0.0, jp.sin(angle), -jp.cos(angle)])
+
+ dist_angled, geomid_angled = ray_fn(mx, dx, center_pos, vec_angled, geomgroup)
+
+ mj_geomid_angled = np.zeros(1, dtype=np.int32)
+ mj_dist_angled = mujoco.mj_ray(m, d, center_pos, vec_angled, geomgroup, True, -1, mj_geomid_angled)
+
+ _assert_eq(dist_angled, mj_dist_angled, f'angle_dist_{i}')
+ _assert_eq(geomid_angled, mj_geomid_angled[0], f'angle_geomid_{i}')
+
if __name__ == '__main__':
absltest.main()
diff --git a/mjx/mujoco/mjx/test_data/ray.xml b/mjx/mujoco/mjx/test_data/ray.xml
index bbf42b1e46..e920be3da8 100644
--- a/mjx/mujoco/mjx/test_data/ray.xml
+++ b/mjx/mujoco/mjx/test_data/ray.xml
@@ -2,6 +2,7 @@
+
@@ -14,5 +15,6 @@
+