Skip to content

Commit ef2cc0f

Browse files
authored
Added a minimal batched sweep and prune broadphase (#29)
Adds a sweep-and-prune broad phase that assumes axis aligned bounding boxes for geometries per environment.
1 parent 867c1da commit ef2cc0f

File tree

5 files changed

+727
-0
lines changed

5 files changed

+727
-0
lines changed

mujoco/mjx/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
"""Public API for MJX."""
1717

18+
from ._src.collision_driver import broad_phase
1819
from ._src.constraint import make_constraint
1920
from ._src.forward import euler
2021
from ._src.forward import forward
Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
# Copyright 2025 The Physics-Next Project Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
"""Tests for broad phase functions."""
17+
18+
from absl.testing import absltest
19+
from absl.testing import parameterized
20+
import mujoco
21+
from mujoco import mjx
22+
import numpy as np
23+
import warp as wp
24+
25+
from . import test_util
26+
27+
BoxType = wp.types.matrix(shape=(2, 3), dtype=wp.float32)
28+
29+
# Helper function to initialize a box
30+
def init_box(min_x, min_y, min_z, max_x, max_y, max_z):
31+
center = wp.vec3((min_x + max_x) / 2, (min_y + max_y) / 2, (min_z + max_z) / 2)
32+
size = wp.vec3(max_x - min_x, max_y - min_y, max_z - min_z)
33+
box = wp.types.matrix(shape=(2, 3), dtype=wp.float32)(
34+
[center.x, center.y, center.z, size.x, size.y, size.z]
35+
)
36+
return box
37+
38+
39+
def overlap(
40+
a: wp.types.matrix(shape=(2, 3), dtype=wp.float32),
41+
b: wp.types.matrix(shape=(2, 3), dtype=wp.float32),
42+
) -> bool:
43+
# Extract centers and sizes
44+
a_center = a[0]
45+
a_size = a[1]
46+
b_center = b[0]
47+
b_size = b[1]
48+
49+
# Calculate min/max from center and size
50+
a_min = a_center - 0.5 * a_size
51+
a_max = a_center + 0.5 * a_size
52+
b_min = b_center - 0.5 * b_size
53+
b_max = b_center + 0.5 * b_size
54+
55+
return not (
56+
a_min.x > b_max.x
57+
or b_min.x > a_max.x
58+
or a_min.y > b_max.y
59+
or b_min.y > a_max.y
60+
or a_min.z > b_max.z
61+
or b_min.z > a_max.z
62+
)
63+
64+
65+
def transform_aabb(
66+
aabb: wp.types.matrix(shape=(2, 3), dtype=wp.float32),
67+
pos: wp.vec3,
68+
rot: wp.mat33,
69+
) -> wp.types.matrix(shape=(2, 3), dtype=wp.float32):
70+
# Extract center and half-extents from AABB
71+
center = aabb[0]
72+
half_extents = aabb[1] * 0.5
73+
74+
# Get absolute values of rotation matrix columns
75+
right = wp.vec3(wp.abs(rot[0, 0]), wp.abs(rot[0, 1]), wp.abs(rot[0, 2]))
76+
up = wp.vec3(wp.abs(rot[1, 0]), wp.abs(rot[1, 1]), wp.abs(rot[1, 2]))
77+
forward = wp.vec3(wp.abs(rot[2, 0]), wp.abs(rot[2, 1]), wp.abs(rot[2, 2]))
78+
79+
# Compute world space half-extents
80+
world_extents = (
81+
right * half_extents.x + up * half_extents.y + forward * half_extents.z
82+
)
83+
84+
# Transform center
85+
new_center = rot @ center + pos
86+
87+
# Return new AABB as matrix with center and full size
88+
result = BoxType()
89+
result[0] = wp.vec3(new_center.x, new_center.y, new_center.z)
90+
result[1] = wp.vec3(
91+
world_extents.x * 2.0, world_extents.y * 2.0, world_extents.z * 2.0
92+
)
93+
return result
94+
95+
96+
def find_overlaps_brute_force(worldId: int, num_boxes_per_world: int, boxes, pos, rot):
97+
"""
98+
Finds overlapping bounding boxes using the brute-force O(n^2) algorithm.
99+
100+
Returns:
101+
List of tuples [(idx1, idx2)] where idx1 and idx2 are indices of overlapping boxes.
102+
"""
103+
overlaps = []
104+
105+
for i in range(num_boxes_per_world):
106+
box_a = boxes[i]
107+
box_a = transform_aabb(box_a, pos[worldId, i], rot[worldId, i])
108+
109+
for j in range(i + 1, num_boxes_per_world):
110+
box_b = boxes[j]
111+
box_b = transform_aabb(box_b, pos[worldId, j], rot[worldId, j])
112+
113+
# Use the overlap function to check for overlap
114+
if overlap(box_a, box_b):
115+
overlaps.append((i, j)) # Store indices of overlapping boxes
116+
117+
return overlaps
118+
119+
120+
def find_overlaps_brute_force_batched(
121+
num_worlds: int, num_boxes_per_world: int, boxes, pos, rot
122+
):
123+
"""
124+
Finds overlapping bounding boxes using the brute-force O(n^2) algorithm.
125+
126+
Returns:
127+
List of tuples [(idx1, idx2)] where idx1 and idx2 are indices of overlapping boxes.
128+
"""
129+
130+
overlaps = []
131+
132+
for worldId in range(num_worlds):
133+
overlaps.append(
134+
find_overlaps_brute_force(worldId, num_boxes_per_world, boxes, pos, rot)
135+
)
136+
137+
# Show progress bar for brute force computation
138+
# from tqdm import tqdm
139+
140+
# for worldId in tqdm(range(num_worlds), desc="Computing overlaps"):
141+
# overlaps.append(find_overlaps_brute_force(worldId, num_boxes_per_world, boxes))
142+
143+
return overlaps
144+
145+
146+
class MultiIndexList:
147+
def __init__(self):
148+
self.data = {}
149+
150+
def __setitem__(self, key, value):
151+
worldId, i = key
152+
if worldId not in self.data:
153+
self.data[worldId] = []
154+
if i >= len(self.data[worldId]):
155+
self.data[worldId].extend([None] * (i - len(self.data[worldId]) + 1))
156+
self.data[worldId][i] = value
157+
158+
def __getitem__(self, key):
159+
worldId, i = key
160+
return self.data[worldId][i] # Raises KeyError if not found
161+
162+
163+
class BroadPhaseTest(parameterized.TestCase):
164+
def test_broad_phase(self):
165+
"""Tests broad phase."""
166+
_, mjd, m, d = test_util.fixture("humanoid/humanoid.xml")
167+
168+
# Create some test boxes
169+
num_worlds = d.nworld
170+
num_boxes_per_world = m.ngeom
171+
# print(f"num_worlds: {num_worlds}, num_boxes_per_world: {num_boxes_per_world}")
172+
173+
# Parameters for random box generation
174+
sample_space_origin = wp.vec3(-10.0, -10.0, -10.0) # Origin of the bounding volume
175+
sample_space_size = wp.vec3(20.0, 20.0, 20.0) # Size of the bounding volume
176+
min_edge_length = 0.5 # Minimum edge length of random boxes
177+
max_edge_length = 5.0 # Maximum edge length of random boxes
178+
179+
boxes_list = []
180+
181+
# Set random seed for reproducibility
182+
import random
183+
184+
random.seed(11)
185+
186+
# Generate random boxes for each world
187+
for _ in range(num_boxes_per_world):
188+
# Generate random position within bounding volume
189+
pos_x = sample_space_origin.x + random.random() * sample_space_size.x
190+
pos_y = sample_space_origin.y + random.random() * sample_space_size.y
191+
pos_z = sample_space_origin.z + random.random() * sample_space_size.z
192+
193+
# Generate random box dimensions between min and max edge lengths
194+
size_x = min_edge_length + random.random() * (max_edge_length - min_edge_length)
195+
size_y = min_edge_length + random.random() * (max_edge_length - min_edge_length)
196+
size_z = min_edge_length + random.random() * (max_edge_length - min_edge_length)
197+
198+
# Create box with random position and size
199+
boxes_list.append(
200+
init_box(pos_x, pos_y, pos_z, pos_x + size_x, pos_y + size_y, pos_z + size_z)
201+
)
202+
203+
# Generate random positions and orientations for each box
204+
pos = []
205+
rot = []
206+
for _ in range(num_worlds * num_boxes_per_world):
207+
# Random position within bounding volume
208+
pos_x = sample_space_origin.x + random.random() * sample_space_size.x
209+
pos_y = sample_space_origin.y + random.random() * sample_space_size.y
210+
pos_z = sample_space_origin.z + random.random() * sample_space_size.z
211+
pos.append(wp.vec3(pos_x, pos_y, pos_z))
212+
# pos.append(wp.vec3(0, 0, 0))
213+
214+
# Random rotation matrix
215+
rx = random.random() * 6.28318530718 # 2*pi
216+
ry = random.random() * 6.28318530718
217+
rz = random.random() * 6.28318530718
218+
axis = wp.vec3(rx, ry, rz)
219+
axis = axis / wp.length(axis) # normalize axis
220+
angle = random.random() * 6.28318530718 # random angle between 0 and 2*pi
221+
rot.append(wp.quat_to_matrix(wp.quat_from_axis_angle(axis, angle)))
222+
# rot.append(wp.quat_to_matrix(wp.quat_from_axis_angle(wp.vec3(1, 0, 0), float(0))))
223+
224+
# Convert pos and rot to MultiIndexList format
225+
pos_multi = MultiIndexList()
226+
rot_multi = MultiIndexList()
227+
228+
# Populate the MultiIndexLists using pos and rot data
229+
idx = 0
230+
for world_idx in range(num_worlds):
231+
for i in range(num_boxes_per_world):
232+
pos_multi[world_idx, i] = pos[idx]
233+
rot_multi[world_idx, i] = rot[idx]
234+
idx += 1
235+
236+
brute_force_overlaps = find_overlaps_brute_force_batched(
237+
num_worlds, num_boxes_per_world, boxes_list, pos_multi, rot_multi
238+
)
239+
240+
# Test the broad phase by setting custom aabb data
241+
m.geom_aabb = wp.array(
242+
boxes_list, dtype=wp.types.matrix(shape=(2, 3), dtype=wp.float32)
243+
)
244+
m.geom_aabb = m.geom_aabb.reshape((num_boxes_per_world))
245+
d.geom_xpos = wp.array(pos, dtype=wp.vec3)
246+
d.geom_xpos = d.geom_xpos.reshape((num_worlds, num_boxes_per_world))
247+
d.geom_xmat = wp.array(rot, dtype=wp.mat33)
248+
d.geom_xmat = d.geom_xmat.reshape((num_worlds, num_boxes_per_world))
249+
250+
mjx.broad_phase(m, d)
251+
252+
result = d.broadphase_pairs
253+
result_count = d.result_count
254+
255+
# Get numpy arrays from result and result_count
256+
result_np = result.numpy()
257+
result_count_np = result_count.numpy()
258+
259+
# Iterate over each world
260+
for world_idx in range(num_worlds):
261+
# Get number of collisions for this world
262+
num_collisions = result_count_np[world_idx]
263+
print(f"Number of collisions for world {world_idx}: {num_collisions}")
264+
265+
list = brute_force_overlaps[world_idx]
266+
assert len(list) == num_collisions, "Number of collisions does not match"
267+
268+
# Print each collision pair
269+
for i in range(num_collisions):
270+
pair = result_np[world_idx][i]
271+
272+
# Convert pair to tuple for comparison
273+
pair_tuple = (int(pair[0]), int(pair[1]))
274+
assert pair_tuple in list, (
275+
f"Collision pair {pair_tuple} not found in brute force results"
276+
)
277+
278+
279+
if __name__ == "__main__":
280+
wp.init()
281+
absltest.main()

0 commit comments

Comments
 (0)