1818from .math import closest_segment_to_segment_points
1919from .math import make_frame
2020from .math import normalize_with_norm
21- from .support import group_key
2221from .types import Data
2322from .types import GeomType
2423from .types import Model
2524
2625
2726@wp .struct
28- class GeomPlane :
27+ class Geom :
2928 pos : wp .vec3
3029 rot : wp .mat33
3130 normal : wp .vec3
32-
33-
34- @wp .struct
35- class GeomSphere :
36- pos : wp .vec3
37- rot : wp .mat33
38- radius : float
39-
40-
41- @wp .struct
42- class GeomCapsule :
43- pos : wp .vec3
44- rot : wp .mat33
45- radius : float
46- halfsize : float
47-
48-
49- @wp .struct
50- class GeomEllipsoid :
51- pos : wp .vec3
52- rot : wp .mat33
5331 size : wp .vec3
32+ # TODO(team): mesh fields: vertadr, vertnum
5433
5534
56- @wp .struct
57- class GeomCylinder :
58- pos : wp .vec3
59- rot : wp .mat33
60- radius : float
61- halfsize : float
62-
63-
64- @wp .struct
65- class GeomBox :
66- pos : wp .vec3
67- rot : wp .mat33
68- size : wp .vec3
69-
70-
71- @wp .struct
72- class GeomMesh :
73- pos : wp .vec3
74- rot : wp .mat33
75- vertadr : int
76- vertnum : int
77-
78-
79- def get_info (t ):
80- @wp .func
81- def _get_info (
82- gid : int ,
83- m : Model ,
84- geom_xpos : wp .array (dtype = wp .vec3 ),
85- geom_xmat : wp .array (dtype = wp .mat33 ),
86- ):
87- pos = geom_xpos [gid ]
88- rot = geom_xmat [gid ]
89- size = m .geom_size [gid ]
90- if wp .static (t == GeomType .SPHERE .value ):
91- sphere = GeomSphere ()
92- sphere .pos = pos
93- sphere .rot = rot
94- sphere .radius = size [0 ]
95- return sphere
96- elif wp .static (t == GeomType .BOX .value ):
97- box = GeomBox ()
98- box .pos = pos
99- box .rot = rot
100- box .size = size
101- return box
102- elif wp .static (t == GeomType .PLANE .value ):
103- plane = GeomPlane ()
104- plane .pos = pos
105- plane .rot = rot
106- plane .normal = wp .vec3 (rot [0 , 2 ], rot [1 , 2 ], rot [2 , 2 ])
107- return plane
108- elif wp .static (t == GeomType .CAPSULE .value ):
109- capsule = GeomCapsule ()
110- capsule .pos = pos
111- capsule .rot = rot
112- capsule .radius = size [0 ]
113- capsule .halfsize = size [1 ]
114- return capsule
115- elif wp .static (t == GeomType .ELLIPSOID .value ):
116- ellipsoid = GeomEllipsoid ()
117- ellipsoid .pos = pos
118- ellipsoid .rot = rot
119- ellipsoid .size = size
120- return ellipsoid
121- elif wp .static (t == GeomType .CYLINDER .value ):
122- cylinder = GeomCylinder ()
123- cylinder .pos = pos
124- cylinder .rot = rot
125- cylinder .radius = size [0 ]
126- cylinder .halfsize = size [1 ]
127- return cylinder
128- elif wp .static (t == GeomType .MESH .value ):
129- mesh = GeomMesh ()
130- mesh .pos = pos
131- mesh .rot = rot
132- dataid = m .geom_dataid [gid ]
133- if dataid >= 0 :
134- mesh .vertadr = m .mesh_vertadr [dataid ]
135- mesh .vertnum = m .mesh_vertnum [dataid ]
136- else :
137- mesh .vertadr = 0
138- mesh .vertnum = 0
139- return mesh
140- else :
141- wp .static (RuntimeError ("Unsupported type" , t ))
142-
143- return _get_info
35+ @wp .func
36+ def _geom (
37+ gid : int ,
38+ m : Model ,
39+ geom_xpos : wp .array (dtype = wp .vec3 ),
40+ geom_xmat : wp .array (dtype = wp .mat33 ),
41+ ) -> Geom :
42+ geom = Geom ()
43+ geom .pos = geom_xpos [gid ]
44+ rot = geom_xmat [gid ]
45+ geom .rot = rot
46+ geom .size = m .geom_size [gid ]
47+ geom .normal = wp .vec3 (rot [0 , 2 ], rot [1 , 2 ], rot [2 , 2 ]) # plane
48+
49+ return geom
14450
14551
14652@wp .func
@@ -175,14 +81,14 @@ def _plane_sphere(
17581
17682@wp .func
17783def plane_sphere (
178- plane : GeomPlane ,
179- sphere : GeomSphere ,
84+ plane : Geom ,
85+ sphere : Geom ,
18086 worldid : int ,
18187 d : Data ,
18288 margin : float ,
18389 geom_indices : wp .vec2i ,
18490):
185- dist , pos = _plane_sphere (plane .normal , plane .pos , sphere .pos , sphere .radius )
91+ dist , pos = _plane_sphere (plane .normal , plane .pos , sphere .pos , sphere .size [ 0 ] )
18692
18793 write_contact (d , dist , pos , make_frame (plane .normal ), margin , geom_indices , worldid )
18894
@@ -212,18 +118,18 @@ def _sphere_sphere(
212118
213119@wp .func
214120def sphere_sphere (
215- sphere1 : GeomSphere ,
216- sphere2 : GeomSphere ,
121+ sphere1 : Geom ,
122+ sphere2 : Geom ,
217123 worldid : int ,
218124 d : Data ,
219125 margin : float ,
220126 geom_indices : wp .vec2i ,
221127):
222128 _sphere_sphere (
223129 sphere1 .pos ,
224- sphere1 .radius ,
130+ sphere1 .size [ 0 ] ,
225131 sphere2 .pos ,
226- sphere2 .radius ,
132+ sphere2 .size [ 0 ] ,
227133 worldid ,
228134 d ,
229135 margin ,
@@ -233,17 +139,17 @@ def sphere_sphere(
233139
234140@wp .func
235141def capsule_capsule (
236- cap1 : GeomCapsule ,
237- cap2 : GeomCapsule ,
142+ cap1 : Geom ,
143+ cap2 : Geom ,
238144 worldid : int ,
239145 d : Data ,
240146 margin : float ,
241147 geom_indices : wp .vec2i ,
242148):
243149 axis1 = wp .vec3 (cap1 .rot [0 , 2 ], cap1 .rot [1 , 2 ], cap1 .rot [2 , 2 ])
244150 axis2 = wp .vec3 (cap2 .rot [0 , 2 ], cap2 .rot [1 , 2 ], cap2 .rot [2 , 2 ])
245- length1 = cap1 .halfsize
246- length2 = cap2 .halfsize
151+ length1 = cap1 .size [ 1 ]
152+ length2 = cap2 .size [ 1 ]
247153 seg1 = axis1 * length1
248154 seg2 = axis2 * length2
249155
@@ -254,13 +160,13 @@ def capsule_capsule(
254160 cap2 .pos + seg2 ,
255161 )
256162
257- _sphere_sphere (pt1 , cap1 .radius , pt2 , cap2 .radius , worldid , d , margin , geom_indices )
163+ _sphere_sphere (pt1 , cap1 .size [ 0 ] , pt2 , cap2 .size [ 0 ] , worldid , d , margin , geom_indices )
258164
259165
260166@wp .func
261167def plane_capsule (
262- plane : GeomPlane ,
263- cap : GeomCapsule ,
168+ plane : Geom ,
169+ cap : Geom ,
264170 worldid : int ,
265171 d : Data ,
266172 margin : float ,
@@ -280,19 +186,19 @@ def plane_capsule(
280186
281187 c = wp .cross (n , b )
282188 frame = wp .mat33 (n [0 ], n [1 ], n [2 ], b [0 ], b [1 ], b [2 ], c [0 ], c [1 ], c [2 ])
283- segment = axis * cap .halfsize
189+ segment = axis * cap .size [ 1 ]
284190
285- dist1 , pos1 = _plane_sphere (n , plane .pos , cap .pos + segment , cap .radius )
191+ dist1 , pos1 = _plane_sphere (n , plane .pos , cap .pos + segment , cap .size [ 0 ] )
286192 write_contact (d , dist1 , pos1 , frame , margin , geom_indices , worldid )
287193
288- dist2 , pos2 = _plane_sphere (n , plane .pos , cap .pos - segment , cap .radius )
194+ dist2 , pos2 = _plane_sphere (n , plane .pos , cap .pos - segment , cap .size [ 0 ] )
289195 write_contact (d , dist2 , pos2 , frame , margin , geom_indices , worldid )
290196
291197
292198@wp .func
293199def plane_box (
294- plane : GeomPlane ,
295- box : GeomBox ,
200+ plane : Geom ,
201+ box : Geom ,
296202 worldid : int ,
297203 d : Data ,
298204 margin : float ,
@@ -326,72 +232,43 @@ def plane_box(
326232 break
327233
328234
329- _collision_functions = {
330- (GeomType .PLANE .value , GeomType .SPHERE .value ): plane_sphere ,
331- (GeomType .SPHERE .value , GeomType .SPHERE .value ): sphere_sphere ,
332- (GeomType .PLANE .value , GeomType .CAPSULE .value ): plane_capsule ,
333- (GeomType .PLANE .value , GeomType .BOX .value ): plane_box ,
334- (GeomType .CAPSULE .value , GeomType .CAPSULE .value ): capsule_capsule ,
335- }
336-
337-
338- def create_collision_function_kernel (type1 , type2 ):
339- key = group_key (type1 , type2 )
340-
341- @wp .kernel
342- def _collision_function_kernel (
343- m : Model ,
344- d : Data ,
345- ):
346- tid = wp .tid ()
347-
348- if tid >= d .ncollision [0 ] or d .collision_type [tid ] != key :
349- return
350-
351- geoms = d .collision_pair [tid ]
352- worldid = d .collision_worldid [tid ]
353-
354- # TODO(team): per-world maximum number of collisions?
355-
356- g1 = geoms [0 ]
357- g2 = geoms [1 ]
235+ @wp .kernel
236+ def _narrowphase (
237+ m : Model ,
238+ d : Data ,
239+ ):
240+ tid = wp .tid ()
358241
359- geom1 = wp .static (get_info (type1 ))(
360- g1 ,
361- m ,
362- d .geom_xpos [worldid ],
363- d .geom_xmat [worldid ],
364- )
365- geom2 = wp .static (get_info (type2 ))(
366- g2 ,
367- m ,
368- d .geom_xpos [worldid ],
369- d .geom_xmat [worldid ],
370- )
242+ if tid >= d .ncollision [0 ]:
243+ return
371244
372- margin = wp .max (m .geom_margin [g1 ], m .geom_margin [g2 ])
245+ geoms = d .collision_pair [tid ]
246+ worldid = d .collision_worldid [tid ]
373247
374- wp .static (_collision_functions [(type1 , type2 )])(
375- geom1 , geom2 , worldid , d , margin , geoms
376- )
248+ g1 = geoms [0 ]
249+ g2 = geoms [1 ]
250+ type1 = m .geom_type [g1 ]
251+ type2 = m .geom_type [g2 ]
377252
378- return _collision_function_kernel
253+ geom1 = _geom (g1 , m , d .geom_xpos [worldid ], d .geom_xmat [worldid ])
254+ geom2 = _geom (g2 , m , d .geom_xpos [worldid ], d .geom_xmat [worldid ])
379255
256+ margin = wp .max (m .geom_margin [g1 ], m .geom_margin [g2 ])
380257
381- _collision_kernels = {}
258+ # TODO(team): static loop unrolling to remove unnecessary branching
259+ if type1 == int (GeomType .PLANE .value ) and type2 == int (GeomType .SPHERE .value ):
260+ plane_sphere (geom1 , geom2 , worldid , d , margin , geoms )
261+ elif type1 == int (GeomType .SPHERE .value ) and type2 == int (GeomType .SPHERE .value ):
262+ sphere_sphere (geom1 , geom2 , worldid , d , margin , geoms )
263+ elif type1 == int (GeomType .PLANE .value ) and type2 == int (GeomType .CAPSULE .value ):
264+ plane_capsule (geom1 , geom2 , worldid , d , margin , geoms )
265+ elif type1 == int (GeomType .PLANE .value ) and type2 == int (GeomType .BOX .value ):
266+ plane_box (geom1 , geom2 , worldid , d , margin , geoms )
267+ elif type1 == int (GeomType .CAPSULE .value ) and type2 == int (GeomType .CAPSULE .value ):
268+ capsule_capsule (geom1 , geom2 , worldid , d , margin , geoms )
382269
383270
384271def narrowphase (m : Model , d : Data ):
385272 # we need to figure out how to keep the overhead of this small - not launching anything
386273 # for pair types without collisions, as well as updating the launch dimensions.
387-
388- # TODO(team): investigate a single kernel launch for all collision functions
389- # TODO only generate collision kernels we actually need
390- if len (_collision_kernels ) == 0 :
391- for type1 , type2 in _collision_functions .keys ():
392- _collision_kernels [(type1 , type2 )] = create_collision_function_kernel (
393- type1 , type2
394- )
395-
396- for collision_kernel in _collision_kernels .values ():
397- wp .launch (collision_kernel , dim = d .nconmax , inputs = [m , d ])
274+ wp .launch (_narrowphase , dim = d .nconmax , inputs = [m , d ])
0 commit comments