Skip to content

Commit 330c2a2

Browse files
authored
Merge pull request #95 from thowell/narrowphase
modify narrowphase
2 parents 5a3f85d + 6bfc2a0 commit 330c2a2

File tree

5 files changed

+65
-204
lines changed

5 files changed

+65
-204
lines changed

mujoco_warp/_src/collision_driver.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import warp as wp
1717

18-
from .support import group_key
1918
from .support import where
2019
from .types import MJ_MINVAL
2120
from .types import Data
@@ -294,9 +293,7 @@ def broadphase_sweep_and_prune_kernel(
294293
return
295294

296295
pair = _geom_pair(m, idx1, idx2)
297-
key = group_key(m.geom_type[idx1], m.geom_type[idx2])
298296
d.collision_pair[pairid] = pair
299-
d.collision_type[pairid] = key
300297
d.collision_worldid[pairid] = worldId
301298

302299
threadId += num_threads
@@ -506,9 +503,7 @@ def _nxn_broadphase(m: Model, d: Data):
506503
return
507504

508505
pair = _geom_pair(m, geom1, geom2)
509-
key = group_key(type1, type2)
510506
d.collision_pair[pairid] = pair
511-
d.collision_type[pairid] = key
512507
d.collision_worldid[pairid] = worldid
513508

514509
wp.launch(

mujoco_warp/_src/collision_functions.py

Lines changed: 65 additions & 188 deletions
Original file line numberDiff line numberDiff line change
@@ -18,129 +18,35 @@
1818
from .math import closest_segment_to_segment_points
1919
from .math import make_frame
2020
from .math import normalize_with_norm
21-
from .support import group_key
2221
from .types import Data
2322
from .types import GeomType
2423
from .types import Model
2524

2625

2726
@wp.struct
28-
class GeomPlane:
27+
class Geom:
2928
pos: wp.vec3
3029
rot: wp.mat33
3130
normal: wp.vec3
32-
33-
34-
@wp.struct
35-
class GeomSphere:
36-
pos: wp.vec3
37-
rot: wp.mat33
38-
radius: float
39-
40-
41-
@wp.struct
42-
class GeomCapsule:
43-
pos: wp.vec3
44-
rot: wp.mat33
45-
radius: float
46-
halfsize: float
47-
48-
49-
@wp.struct
50-
class GeomEllipsoid:
51-
pos: wp.vec3
52-
rot: wp.mat33
5331
size: wp.vec3
32+
# TODO(team): mesh fields: vertadr, vertnum
5433

5534

56-
@wp.struct
57-
class GeomCylinder:
58-
pos: wp.vec3
59-
rot: wp.mat33
60-
radius: float
61-
halfsize: float
62-
63-
64-
@wp.struct
65-
class GeomBox:
66-
pos: wp.vec3
67-
rot: wp.mat33
68-
size: wp.vec3
69-
70-
71-
@wp.struct
72-
class GeomMesh:
73-
pos: wp.vec3
74-
rot: wp.mat33
75-
vertadr: int
76-
vertnum: int
77-
78-
79-
def get_info(t):
80-
@wp.func
81-
def _get_info(
82-
gid: int,
83-
m: Model,
84-
geom_xpos: wp.array(dtype=wp.vec3),
85-
geom_xmat: wp.array(dtype=wp.mat33),
86-
):
87-
pos = geom_xpos[gid]
88-
rot = geom_xmat[gid]
89-
size = m.geom_size[gid]
90-
if wp.static(t == GeomType.SPHERE.value):
91-
sphere = GeomSphere()
92-
sphere.pos = pos
93-
sphere.rot = rot
94-
sphere.radius = size[0]
95-
return sphere
96-
elif wp.static(t == GeomType.BOX.value):
97-
box = GeomBox()
98-
box.pos = pos
99-
box.rot = rot
100-
box.size = size
101-
return box
102-
elif wp.static(t == GeomType.PLANE.value):
103-
plane = GeomPlane()
104-
plane.pos = pos
105-
plane.rot = rot
106-
plane.normal = wp.vec3(rot[0, 2], rot[1, 2], rot[2, 2])
107-
return plane
108-
elif wp.static(t == GeomType.CAPSULE.value):
109-
capsule = GeomCapsule()
110-
capsule.pos = pos
111-
capsule.rot = rot
112-
capsule.radius = size[0]
113-
capsule.halfsize = size[1]
114-
return capsule
115-
elif wp.static(t == GeomType.ELLIPSOID.value):
116-
ellipsoid = GeomEllipsoid()
117-
ellipsoid.pos = pos
118-
ellipsoid.rot = rot
119-
ellipsoid.size = size
120-
return ellipsoid
121-
elif wp.static(t == GeomType.CYLINDER.value):
122-
cylinder = GeomCylinder()
123-
cylinder.pos = pos
124-
cylinder.rot = rot
125-
cylinder.radius = size[0]
126-
cylinder.halfsize = size[1]
127-
return cylinder
128-
elif wp.static(t == GeomType.MESH.value):
129-
mesh = GeomMesh()
130-
mesh.pos = pos
131-
mesh.rot = rot
132-
dataid = m.geom_dataid[gid]
133-
if dataid >= 0:
134-
mesh.vertadr = m.mesh_vertadr[dataid]
135-
mesh.vertnum = m.mesh_vertnum[dataid]
136-
else:
137-
mesh.vertadr = 0
138-
mesh.vertnum = 0
139-
return mesh
140-
else:
141-
wp.static(RuntimeError("Unsupported type", t))
142-
143-
return _get_info
35+
@wp.func
36+
def _geom(
37+
gid: int,
38+
m: Model,
39+
geom_xpos: wp.array(dtype=wp.vec3),
40+
geom_xmat: wp.array(dtype=wp.mat33),
41+
) -> Geom:
42+
geom = Geom()
43+
geom.pos = geom_xpos[gid]
44+
rot = geom_xmat[gid]
45+
geom.rot = rot
46+
geom.size = m.geom_size[gid]
47+
geom.normal = wp.vec3(rot[0, 2], rot[1, 2], rot[2, 2]) # plane
48+
49+
return geom
14450

14551

14652
@wp.func
@@ -175,14 +81,14 @@ def _plane_sphere(
17581

17682
@wp.func
17783
def plane_sphere(
178-
plane: GeomPlane,
179-
sphere: GeomSphere,
84+
plane: Geom,
85+
sphere: Geom,
18086
worldid: int,
18187
d: Data,
18288
margin: float,
18389
geom_indices: wp.vec2i,
18490
):
185-
dist, pos = _plane_sphere(plane.normal, plane.pos, sphere.pos, sphere.radius)
91+
dist, pos = _plane_sphere(plane.normal, plane.pos, sphere.pos, sphere.size[0])
18692

18793
write_contact(d, dist, pos, make_frame(plane.normal), margin, geom_indices, worldid)
18894

@@ -212,18 +118,18 @@ def _sphere_sphere(
212118

213119
@wp.func
214120
def sphere_sphere(
215-
sphere1: GeomSphere,
216-
sphere2: GeomSphere,
121+
sphere1: Geom,
122+
sphere2: Geom,
217123
worldid: int,
218124
d: Data,
219125
margin: float,
220126
geom_indices: wp.vec2i,
221127
):
222128
_sphere_sphere(
223129
sphere1.pos,
224-
sphere1.radius,
130+
sphere1.size[0],
225131
sphere2.pos,
226-
sphere2.radius,
132+
sphere2.size[0],
227133
worldid,
228134
d,
229135
margin,
@@ -233,17 +139,17 @@ def sphere_sphere(
233139

234140
@wp.func
235141
def capsule_capsule(
236-
cap1: GeomCapsule,
237-
cap2: GeomCapsule,
142+
cap1: Geom,
143+
cap2: Geom,
238144
worldid: int,
239145
d: Data,
240146
margin: float,
241147
geom_indices: wp.vec2i,
242148
):
243149
axis1 = wp.vec3(cap1.rot[0, 2], cap1.rot[1, 2], cap1.rot[2, 2])
244150
axis2 = wp.vec3(cap2.rot[0, 2], cap2.rot[1, 2], cap2.rot[2, 2])
245-
length1 = cap1.halfsize
246-
length2 = cap2.halfsize
151+
length1 = cap1.size[1]
152+
length2 = cap2.size[1]
247153
seg1 = axis1 * length1
248154
seg2 = axis2 * length2
249155

@@ -254,13 +160,13 @@ def capsule_capsule(
254160
cap2.pos + seg2,
255161
)
256162

257-
_sphere_sphere(pt1, cap1.radius, pt2, cap2.radius, worldid, d, margin, geom_indices)
163+
_sphere_sphere(pt1, cap1.size[0], pt2, cap2.size[0], worldid, d, margin, geom_indices)
258164

259165

260166
@wp.func
261167
def plane_capsule(
262-
plane: GeomPlane,
263-
cap: GeomCapsule,
168+
plane: Geom,
169+
cap: Geom,
264170
worldid: int,
265171
d: Data,
266172
margin: float,
@@ -280,19 +186,19 @@ def plane_capsule(
280186

281187
c = wp.cross(n, b)
282188
frame = wp.mat33(n[0], n[1], n[2], b[0], b[1], b[2], c[0], c[1], c[2])
283-
segment = axis * cap.halfsize
189+
segment = axis * cap.size[1]
284190

285-
dist1, pos1 = _plane_sphere(n, plane.pos, cap.pos + segment, cap.radius)
191+
dist1, pos1 = _plane_sphere(n, plane.pos, cap.pos + segment, cap.size[0])
286192
write_contact(d, dist1, pos1, frame, margin, geom_indices, worldid)
287193

288-
dist2, pos2 = _plane_sphere(n, plane.pos, cap.pos - segment, cap.radius)
194+
dist2, pos2 = _plane_sphere(n, plane.pos, cap.pos - segment, cap.size[0])
289195
write_contact(d, dist2, pos2, frame, margin, geom_indices, worldid)
290196

291197

292198
@wp.func
293199
def plane_box(
294-
plane: GeomPlane,
295-
box: GeomBox,
200+
plane: Geom,
201+
box: Geom,
296202
worldid: int,
297203
d: Data,
298204
margin: float,
@@ -326,72 +232,43 @@ def plane_box(
326232
break
327233

328234

329-
_collision_functions = {
330-
(GeomType.PLANE.value, GeomType.SPHERE.value): plane_sphere,
331-
(GeomType.SPHERE.value, GeomType.SPHERE.value): sphere_sphere,
332-
(GeomType.PLANE.value, GeomType.CAPSULE.value): plane_capsule,
333-
(GeomType.PLANE.value, GeomType.BOX.value): plane_box,
334-
(GeomType.CAPSULE.value, GeomType.CAPSULE.value): capsule_capsule,
335-
}
336-
337-
338-
def create_collision_function_kernel(type1, type2):
339-
key = group_key(type1, type2)
340-
341-
@wp.kernel
342-
def _collision_function_kernel(
343-
m: Model,
344-
d: Data,
345-
):
346-
tid = wp.tid()
347-
348-
if tid >= d.ncollision[0] or d.collision_type[tid] != key:
349-
return
350-
351-
geoms = d.collision_pair[tid]
352-
worldid = d.collision_worldid[tid]
353-
354-
# TODO(team): per-world maximum number of collisions?
355-
356-
g1 = geoms[0]
357-
g2 = geoms[1]
235+
@wp.kernel
236+
def _narrowphase(
237+
m: Model,
238+
d: Data,
239+
):
240+
tid = wp.tid()
358241

359-
geom1 = wp.static(get_info(type1))(
360-
g1,
361-
m,
362-
d.geom_xpos[worldid],
363-
d.geom_xmat[worldid],
364-
)
365-
geom2 = wp.static(get_info(type2))(
366-
g2,
367-
m,
368-
d.geom_xpos[worldid],
369-
d.geom_xmat[worldid],
370-
)
242+
if tid >= d.ncollision[0]:
243+
return
371244

372-
margin = wp.max(m.geom_margin[g1], m.geom_margin[g2])
245+
geoms = d.collision_pair[tid]
246+
worldid = d.collision_worldid[tid]
373247

374-
wp.static(_collision_functions[(type1, type2)])(
375-
geom1, geom2, worldid, d, margin, geoms
376-
)
248+
g1 = geoms[0]
249+
g2 = geoms[1]
250+
type1 = m.geom_type[g1]
251+
type2 = m.geom_type[g2]
377252

378-
return _collision_function_kernel
253+
geom1 = _geom(g1, m, d.geom_xpos[worldid], d.geom_xmat[worldid])
254+
geom2 = _geom(g2, m, d.geom_xpos[worldid], d.geom_xmat[worldid])
379255

256+
margin = wp.max(m.geom_margin[g1], m.geom_margin[g2])
380257

381-
_collision_kernels = {}
258+
# TODO(team): static loop unrolling to remove unnecessary branching
259+
if type1 == int(GeomType.PLANE.value) and type2 == int(GeomType.SPHERE.value):
260+
plane_sphere(geom1, geom2, worldid, d, margin, geoms)
261+
elif type1 == int(GeomType.SPHERE.value) and type2 == int(GeomType.SPHERE.value):
262+
sphere_sphere(geom1, geom2, worldid, d, margin, geoms)
263+
elif type1 == int(GeomType.PLANE.value) and type2 == int(GeomType.CAPSULE.value):
264+
plane_capsule(geom1, geom2, worldid, d, margin, geoms)
265+
elif type1 == int(GeomType.PLANE.value) and type2 == int(GeomType.BOX.value):
266+
plane_box(geom1, geom2, worldid, d, margin, geoms)
267+
elif type1 == int(GeomType.CAPSULE.value) and type2 == int(GeomType.CAPSULE.value):
268+
capsule_capsule(geom1, geom2, worldid, d, margin, geoms)
382269

383270

384271
def narrowphase(m: Model, d: Data):
385272
# we need to figure out how to keep the overhead of this small - not launching anything
386273
# for pair types without collisions, as well as updating the launch dimensions.
387-
388-
# TODO(team): investigate a single kernel launch for all collision functions
389-
# TODO only generate collision kernels we actually need
390-
if len(_collision_kernels) == 0:
391-
for type1, type2 in _collision_functions.keys():
392-
_collision_kernels[(type1, type2)] = create_collision_function_kernel(
393-
type1, type2
394-
)
395-
396-
for collision_kernel in _collision_kernels.values():
397-
wp.launch(collision_kernel, dim=d.nconmax, inputs=[m, d])
274+
wp.launch(_narrowphase, dim=d.nconmax, inputs=[m, d])

mujoco_warp/_src/io.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,6 @@ def make_data(
475475

476476
# collision driver
477477
d.collision_pair = wp.empty(nconmax, dtype=wp.vec2i, ndim=1)
478-
d.collision_type = wp.empty(nconmax, dtype=wp.int32, ndim=1)
479478
d.collision_worldid = wp.empty(nconmax, dtype=wp.int32, ndim=1)
480479
d.ncollision = wp.zeros(1, dtype=wp.int32, ndim=1)
481480

@@ -699,7 +698,6 @@ def tile(x):
699698

700699
# collision driver
701700
d.collision_pair = wp.empty(nconmax, dtype=wp.vec2i, ndim=1)
702-
d.collision_type = wp.empty(nconmax, dtype=wp.int32, ndim=1)
703701
d.collision_worldid = wp.empty(nconmax, dtype=wp.int32, ndim=1)
704702
d.ncollision = wp.zeros(1, dtype=wp.int32, ndim=1)
705703

0 commit comments

Comments
 (0)