From 681e5f2d5708fb17936453029f968ed13780f9fe Mon Sep 17 00:00:00 2001 From: finn Date: Fri, 15 Aug 2025 10:35:24 +0200 Subject: [PATCH 1/2] Added mjx ray_hfield implementation. Matches the C version exactly, but requires static compilation of hfield size. --- mjx/mujoco/mjx/_src/ray.py | 167 ++++++++++++++++++++++++++++++++++--- 1 file changed, 157 insertions(+), 10 deletions(-) diff --git a/mjx/mujoco/mjx/_src/ray.py b/mjx/mujoco/mjx/_src/ray.py index 2e2055cf04..32e215699d 100644 --- a/mjx/mujoco/mjx/_src/ray.py +++ b/mjx/mujoco/mjx/_src/ray.py @@ -14,10 +14,12 @@ # ============================================================================== """Functions for ray interesection testing.""" +from functools import partial from typing import Sequence, Tuple import jax from jax import numpy as jp +from jax.scipy.ndimage import map_coordinates import mujoco from mujoco.mjx._src import math # pylint: disable=g-importing-member @@ -154,6 +156,29 @@ def _ray_box( return jp.min(jp.where(valid, x, jp.inf)) +def _ray_box_6( + size: jax.Array, pnt: jax.Array, vec: jax.Array +) -> jax.Array: + """Returns intersection distances for all 6 faces of a box.""" + # replace zero vec components with small number to avoid division by zero + safe_vec = jp.where(jp.abs(vec) < mujoco.mjMINVAL, mujoco.mjMINVAL, vec) + iface = jp.array([(1, 2), (1, 2), (0, 2), (0, 2), (0, 1), (0, 1)]) + + # distances to planes for each of the 6 faces (+x, -x, +y, -y, +z, -z) + x = jp.concatenate([(size - pnt) / safe_vec, (-size - pnt) / safe_vec]) + + # check if intersection points are within face bounds + p_intersect = pnt + x[:, None] * vec + p_check_dim1 = jp.abs(p_intersect[jp.arange(6), iface[:, 0]]) + p_check_dim2 = jp.abs(p_intersect[jp.arange(6), iface[:, 1]]) + valid = (p_check_dim1 <= size[iface[:, 0]]) & ( + p_check_dim2 <= size[iface[:, 1]] + ) + valid &= x >= 0 + + return jp.where(valid, x, jp.inf) + + def _ray_triangle( vert: jax.Array, pnt: jax.Array, @@ -219,6 +244,112 @@ def _ray_mesh( return dist, id_ +@partial(jax.jit, static_argnames=('nrow', 'ncol')) +def _ray_hfield_static( + hfield_data_flat: jax.Array, + size: jax.Array, + pnt: jax.Array, + vec: jax.Array, + adr: jax.Array, + nrow: int, + ncol: int, +) -> jax.Array: + """JIT-compiled kernel to raycast against a single hfield size. + """ + # size: (xy_size_x, xy_size_y, height_range, base_thickness) + # Intersection with base box + base_size = jp.array([size[0], size[1], size[3] / 2.0]) + base_pos = jp.array([0, 0, -size[3] / 2.0]) + dist = _ray_box(base_size, pnt - base_pos, vec) + + # Intersection with top box (containing terrain) + top_size = jp.array([size[0], size[1], size[2] / 2.0]) + top_pos = jp.array([0, 0, size[2] / 2.0]) + top_dists_all = _ray_box_6(top_size, pnt - top_pos, vec) + top_dist_min = jp.min(top_dists_all, initial=jp.inf) + + def _intersect_surface() -> jax.Array: + r_idx_grid, c_idx_grid = jp.meshgrid( + jp.arange(nrow), jp.arange(ncol), indexing='ij' + ) + flat_indices = (adr + r_idx_grid * ncol + c_idx_grid).flatten() + hfield_data = hfield_data_flat[flat_indices].reshape((nrow, ncol)) + + # 1. Test against all triangles in the grid. + # Use initial=jp.inf for safety against empty arrays if nrow/ncol <= 1 + min_tri_dist = jp.inf + if nrow > 1 and ncol > 1: + dx = 2.0 * size[0] / (ncol - 1) + dy = 2.0 * size[1] / (nrow - 1) + x_coords = c_idx_grid * dx - size[0] + y_coords = r_idx_grid * dy - size[1] + z_coords = hfield_data * size[2] + v00 = jp.stack([x_coords[:-1, :-1], y_coords[:-1, :-1], z_coords[:-1, :-1]], + axis=-1) + v10 = jp.stack([x_coords[1:, :-1], y_coords[1:, :-1], z_coords[1:, :-1]], + axis=-1) + v01 = jp.stack([x_coords[:-1, 1:], y_coords[:-1, 1:], z_coords[:-1, 1:]], + axis=-1) + v11 = jp.stack([x_coords[1:, 1:], y_coords[1:, 1:], z_coords[1:, 1:]], + axis=-1) + tri1_verts = jp.stack([v00, v11, v10], axis=-2).reshape(-1, 3, 3) + tri2_verts = jp.stack([v00, v11, v01], axis=-2).reshape(-1, 3, 3) + verts = jp.concatenate([tri1_verts, tri2_verts]) + basis = jp.array(math.orthogonals(math.normalize(vec))).T + tri_dists = jax.vmap(_ray_triangle, in_axes=(0, None, None, None))( + verts, pnt, vec, basis + ) + min_tri_dist = jp.min(tri_dists, initial=jp.inf) + + # 2. Test against the four vertical side faces of the top box. + # Replicates the C-code's 1D linear interpolation + # for side hits to ensure identical behavior at the boundaries. + d_sides = top_dists_all[0:4] # Distances for +x, -x, +y, -y faces + p_sides = pnt + d_sides[:, None] * vec + + safe_dx = jp.where(ncol > 1, 2.0 * size[0] / (ncol - 1), 1.0) + safe_dy = jp.where(nrow > 1, 2.0 * size[1] / (nrow - 1), 1.0) + + # Handle sides normal to X-axis (+x, -x faces) + y_float_x_sides = (p_sides[:2, 1] + size[1]) / safe_dy + y0_x_sides = jp.clip(jp.floor(y_float_x_sides).astype(int), 0, nrow - 2) + + y0_rounded = jp.round(y0_x_sides).astype(int) + y1_rounded = jp.round(y0_x_sides + 1).astype(int) + x_indices = jp.array([ncol - 1, 0]) # Grid indices for +x and -x edges + z0_x = hfield_data[y0_rounded, x_indices] + z1_x = hfield_data[y1_rounded, x_indices] + interp_h_x = z0_x * (y0_x_sides + 1 - y_float_x_sides) + z1_x * ( + y_float_x_sides - y0_x_sides + ) + + # Handle sides normal to Y-axis (+y, -y faces) + x_float_y_sides = (p_sides[2:, 0] + size[0]) / safe_dx + x0_y_sides = jp.clip(jp.floor(x_float_y_sides).astype(int), 0, ncol - 2) + + + x0_rounded = jp.round(x0_y_sides).astype(int) + x1_rounded = jp.round(x0_y_sides + 1).astype(int) + y_indices = jp.array([nrow - 1, 0]) # Grid indices for +y and -y edges + z0_y = hfield_data[y_indices, x0_rounded] + z1_y = hfield_data[y_indices, x1_rounded] + interp_h_y = z0_y * (x0_y_sides + 1 - x_float_y_sides) + z1_y * ( + x_float_y_sides - x0_y_sides + ) + + # Combine interpolated heights, scale to world units, and check validity + interp_h_norm = jp.concatenate([interp_h_x, interp_h_y]) + interp_h = interp_h_norm * size[2] + valid_side_hit = p_sides[:, 2] < interp_h + side_dists = jp.where(valid_side_hit, d_sides, jp.inf) + min_side_dist = jp.min(side_dists, initial=jp.inf) + + return jp.minimum(min_tri_dist, min_side_dist) + + dist_surface = jax.lax.cond( + jp.isinf(top_dist_min), lambda: jp.inf, _intersect_surface + ) + return jp.minimum(dist, dist_surface) _RAY_FUNC = { GeomType.PLANE: _ray_plane, @@ -227,6 +358,7 @@ def _ray_mesh( GeomType.ELLIPSOID: _ray_ellipsoid, GeomType.BOX: _ray_box, GeomType.MESH: _ray_mesh, + GeomType.HFIELD: _ray_hfield_static, } @@ -254,7 +386,6 @@ def ray( dist: distance from ray origin to geom surface (or -1.0 for no intersection) id: id of intersected geom (or -1 for no intersection) """ - dists, ids = [], [] geom_filter = m.geom_bodyid != bodyexclude geom_filter &= flg_static | (m.body_weldid[m.geom_bodyid] != 0) @@ -270,32 +401,48 @@ def ray( geom_filter_dyn &= (m.geom_matid == -1) | (m.mat_rgba[m.geom_matid, 3] != 0) for geom_type, fn in _RAY_FUNC.items(): (id_,) = np.nonzero(geom_filter & (m.geom_type == geom_type)) - if id_.size == 0: continue - - args = m.geom_size[id_], geom_pnts[id_], geom_vecs[id_] - + if geom_type == GeomType.MESH: - dist, id_ = fn(m, id_, *args) + args = (m, id_, m.geom_size[id_], geom_pnts[id_], geom_vecs[id_]) + dist, id_ = fn(*args) + elif geom_type == GeomType.HFIELD: + hfield_dataid = m.geom_dataid[id_] + nrow = m.hfield_nrow[hfield_dataid][0] + ncol = m.hfield_ncol[hfield_dataid][0] + hfield_data_flat = jp.asarray(m.hfield_data) + hfield_id_for_geoms = m.geom_dataid[id_] + + args = ( + hfield_data_flat, + m.hfield_size[hfield_id_for_geoms][0], + geom_pnts[id_][0], + geom_vecs[id_][0], + jp.asarray(m.hfield_adr)[hfield_id_for_geoms][0], + nrow, + ncol, + ) + dist = _ray_hfield_static(*args) else: + # remove model and id from args for primitive functions + args = (m.geom_size[id_], geom_pnts[id_], geom_vecs[id_]) dist = jax.vmap(fn)(*args) dist = jp.where(geom_filter_dyn[id_], dist, jp.inf) dists, ids = dists + [dist], ids + [id_] if not ids: - return jp.array(-1), jp.array(-1.0) + return jp.array(-1.0), jp.array(-1) dists = jp.concatenate(dists) ids = jp.concatenate(ids) min_id = jp.argmin(dists) - dist = jp.where(jp.isinf(dists[min_id]), -1, dists[min_id]) + dist = jp.where(jp.isinf(dists[min_id]), -1.0, dists[min_id]) id_ = jp.where(jp.isinf(dists[min_id]), -1, ids[min_id]) return dist, id_ - def ray_geom( size: jax.Array, pnt: jax.Array, vec: jax.Array, geomtype: GeomType ) -> jax.Array: @@ -310,4 +457,4 @@ def ray_geom( Returns: dist: distance from ray origin to geom surface """ - return _RAY_FUNC[geomtype](size, pnt, vec) + return _RAY_FUNC[geomtype](size, pnt, vec) \ No newline at end of file From 1835e6ed2144503ff25779b55b3292e313223c76 Mon Sep 17 00:00:00 2001 From: finn Date: Fri, 15 Aug 2025 12:05:47 +0200 Subject: [PATCH 2/2] Added tests for ray_hfield. --- mjx/mujoco/mjx/_src/ray_test.py | 99 ++++++++++++++++++++++++++++++++ mjx/mujoco/mjx/test_data/ray.xml | 2 + 2 files changed, 101 insertions(+) diff --git a/mjx/mujoco/mjx/_src/ray_test.py b/mjx/mujoco/mjx/_src/ray_test.py index 505878d32e..28e59723c5 100644 --- a/mjx/mujoco/mjx/_src/ray_test.py +++ b/mjx/mujoco/mjx/_src/ray_test.py @@ -250,6 +250,105 @@ def test_ray_invisible(self): _assert_eq(geomid, -1, 'geom_id') _assert_eq(dist, -1, 'dist') + def test_ray_hfield(self): + """Tests that MJX ray<>hfield matches MuJoCo.""" + m = test_util.load_test_file('ray.xml') + d = mujoco.MjData(m) + mujoco.mj_forward(m, d) + mx, dx = mjx.put_model(m), mjx.put_data(m, d) + ray_fn = jax.jit(mjx.ray, static_argnums=(4,)) + # Find hfield geom ID to test only hfield intersections + hfield_geom_id = -1 + for i in range(m.geom_type.shape[0]): + if m.geom_type[i] == mujoco.mjx._src.types.GeomType.HFIELD: + hfield_geom_id = i + break + + # Set all groups to False except the one containing hfield + geomgroup = np.zeros(mujoco.mjNGROUP, dtype=bool) + if hfield_geom_id >= 0: + hfield_group = m.geom_group[hfield_geom_id] + geomgroup[hfield_group] = True + geomgroup = tuple(geomgroup.tolist()) + + # Test 1: Single ray hitting hfield from directly above + # Hfield is at [20, 20, 20] with size [.6, .4, .1, .1] + pnt, vec = jp.array([20.0, 20.0, 25.0]), jp.array([0.0, 0.0, -1.0]) + dist, geomid = ray_fn(mx, dx, pnt, vec, geomgroup) + + mj_geomid = np.zeros(1, dtype=np.int32) + mj_dist = mujoco.mj_ray(m, d, pnt, vec, geomgroup, True, -1, mj_geomid) + + assert mj_dist != -1, 'MuJoCo ray should hit hfield, test case setup might be wrong.' + _assert_eq(dist, mj_dist, 'hfield_single_dist') + _assert_eq(geomid, mj_geomid[0], 'hfield_single_geomid') + + # Test 2: Sample random ray origin + # Hfield is at [20, 20, 20] with size [.6, .4, .1, .1] + x_positions = jp.linspace(19.5, 20.5, 5) # Within hfield x bounds + y_positions = jp.linspace(19.7, 20.3, 5) # Within hfield y bounds + z_start = 25.0 # Well above hfield + + ray_directions = [ + jp.array([0.0, 0.0, -1.0]), # Straight down + jp.array([0.1, 0.0, -1.0]), # Slight angle in x + jp.array([0.0, 0.1, -1.0]), # Slight angle in y + jp.array([0.1, 0.1, -1.0]), # Diagonal angle + ] + + test_count = 0 + for x in x_positions: + for y in y_positions: + for vec_unnorm in ray_directions: + vec = vec_unnorm / jp.linalg.norm(vec_unnorm) # Normalize + pnt = jp.array([x, y, z_start]) + + # MJX ray (hfield only) + dist_mjx, geomid_mjx = ray_fn(mx, dx, pnt, vec, geomgroup) + + # MuJoCo ground truth (hfield only) + mj_geomid = np.zeros(1, dtype=np.int32) + mj_dist = mujoco.mj_ray(m, d, pnt, vec, geomgroup, True, -1, mj_geomid) + + # Assert equality + _assert_eq(dist_mjx, mj_dist, f'grid_dist_{test_count}') + _assert_eq(geomid_mjx, mj_geomid[0], f'grid_geomid_{test_count}') + test_count += 1 + + # Test 3: Rays that should miss hfield (outside bounds) + miss_tests = [ + (jp.array([25.0, 20.0, 25.0]), jp.array([0.0, 0.0, -1.0])), # Outside x bounds + (jp.array([20.0, 25.0, 25.0]), jp.array([0.0, 0.0, -1.0])), # Outside y bounds + (jp.array([20.0, 20.0, 25.0]), jp.array([1.0, 0.0, 0.0])), # Horizontal ray + ] + + for i, (pnt_miss, vec_miss) in enumerate(miss_tests): + dist_miss, geomid_miss = ray_fn(mx, dx, pnt_miss, vec_miss, geomgroup) + + mj_geomid_miss = np.zeros(1, dtype=np.int32) + mj_dist_miss = mujoco.mj_ray(m, d, pnt_miss, vec_miss, geomgroup, True, -1, mj_geomid_miss) + + _assert_eq(dist_miss, mj_dist_miss, f'miss_dist_{i}') + _assert_eq(dist_miss, -1, f'miss_dist_{i}_check') + _assert_eq(geomid_miss, mj_geomid_miss[0], f'miss_geomid_{i}') + + # Test 4: Angular rays from different positions + center_pos = jp.array([20.0, 20.0, 22.0]) # Above hfield center + + # Random angles + angles = jp.linspace(0, jp.pi/4, 5) # 0 to 45 degrees from vertical + for i, angle in enumerate(angles): + # Create angled ray (rotating around x-axis) + vec_angled = jp.array([0.0, jp.sin(angle), -jp.cos(angle)]) + + dist_angled, geomid_angled = ray_fn(mx, dx, center_pos, vec_angled, geomgroup) + + mj_geomid_angled = np.zeros(1, dtype=np.int32) + mj_dist_angled = mujoco.mj_ray(m, d, center_pos, vec_angled, geomgroup, True, -1, mj_geomid_angled) + + _assert_eq(dist_angled, mj_dist_angled, f'angle_dist_{i}') + _assert_eq(geomid_angled, mj_geomid_angled[0], f'angle_geomid_{i}') + if __name__ == '__main__': absltest.main() diff --git a/mjx/mujoco/mjx/test_data/ray.xml b/mjx/mujoco/mjx/test_data/ray.xml index bbf42b1e46..e920be3da8 100644 --- a/mjx/mujoco/mjx/test_data/ray.xml +++ b/mjx/mujoco/mjx/test_data/ray.xml @@ -2,6 +2,7 @@ + @@ -14,5 +15,6 @@ +