Skip to content

Commit 7c3cfe0

Browse files
authored
Merge pull request #59 from adenzler-nvidia/dev/adenzler/check-dist-in-contact-functions
write_contact function that respects margin
2 parents a230c1c + 8f8ab9d commit 7c3cfe0

File tree

1 file changed

+24
-30
lines changed

1 file changed

+24
-30
lines changed

mujoco/mjx/_src/collision_functions.py

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,17 @@ def _get_info(
142142

143143
return _get_info
144144

145+
@wp.func
146+
def write_contact(d: Data, dist: float, pos: wp.vec3, frame: wp.mat33, margin: float, geoms: wp.vec2i, worldid: int):
147+
active = (dist - margin) < 0
148+
if active:
149+
index = wp.atomic_add(d.ncon, 0, 1)
150+
if index < d.nconmax:
151+
d.contact.dist[index] = dist
152+
d.contact.pos[index] = pos
153+
d.contact.frame[index] = frame
154+
d.contact.geom[index] = geoms
155+
d.contact.worldid[index] = worldid
145156

146157
@wp.func
147158
def _plane_sphere(
@@ -153,18 +164,14 @@ def _plane_sphere(
153164

154165

155166
@wp.func
156-
def plane_sphere(plane: GeomPlane, sphere: GeomSphere, worldid: int, d: Data):
167+
def plane_sphere(plane: GeomPlane, sphere: GeomSphere, worldid: int, d: Data, margin: float, geom_indices: wp.vec2i):
157168
dist, pos = _plane_sphere(plane.normal, plane.pos, sphere.pos, sphere.radius)
158169

159-
index = wp.atomic_add(d.ncon, 0, 1)
160-
d.contact.dist[index] = dist
161-
d.contact.pos[index] = pos
162-
d.contact.frame[index] = make_frame(plane.normal)
163-
return index, 1
170+
write_contact(d, dist, pos, make_frame(plane.normal), margin, geom_indices, worldid)
164171

165172

166173
@wp.func
167-
def sphere_sphere(sphere1: GeomSphere, sphere2: GeomSphere, worldid: int, d: Data):
174+
def sphere_sphere(sphere1: GeomSphere, sphere2: GeomSphere, worldid: int, d: Data, margin: float, geom_indices: wp.vec2i):
168175
dir = sphere1.pos - sphere2.pos
169176
dist = wp.length(dir)
170177
if dist == 0.0:
@@ -174,15 +181,11 @@ def sphere_sphere(sphere1: GeomSphere, sphere2: GeomSphere, worldid: int, d: Dat
174181
dist = dist - (sphere1.radius + sphere2.radius)
175182
pos = sphere1.pos + n * (sphere1.radius + 0.5 * dist)
176183

177-
index = wp.atomic_add(d.ncon, 0, 1)
178-
d.contact.dist[index] = dist
179-
d.contact.pos[index] = pos
180-
d.contact.frame[index] = make_frame(n)
181-
return index, 1
184+
write_contact(d, dist, pos, make_frame(n), margin, geom_indices, worldid)
182185

183186

184187
@wp.func
185-
def plane_capsule(plane: GeomPlane, cap: GeomCapsule, worldid: int, d: Data):
188+
def plane_capsule(plane: GeomPlane, cap: GeomCapsule, worldid: int, d: Data, margin: float, geom_indices: wp.vec2i):
186189
"""Calculates two contacts between a capsule and a plane."""
187190
n = plane.normal
188191
axis = wp.vec3(cap.rot[0, 2], cap.rot[1, 2], cap.rot[2, 2])
@@ -198,19 +201,11 @@ def plane_capsule(plane: GeomPlane, cap: GeomCapsule, worldid: int, d: Data):
198201
frame = mat33_from_cols(n, b, wp.cross(n, b))
199202
segment = axis * cap.halfsize
200203

201-
start_index = wp.atomic_add(d.ncon, 0, 2)
202-
index = start_index
203-
dist, pos = _plane_sphere(n, plane.pos, cap.pos + segment, cap.radius)
204-
d.contact.dist[index] = dist
205-
d.contact.pos[index] = pos
206-
d.contact.frame[index] = frame
207-
index += 1
204+
dist1, pos1 = _plane_sphere(n, plane.pos, cap.pos + segment, cap.radius)
205+
write_contact(d, dist1, pos1, frame, margin, geom_indices, worldid)
208206

209-
dist, pos = _plane_sphere(n, plane.pos, cap.pos - segment, cap.radius)
210-
d.contact.dist[index] = dist
211-
d.contact.pos[index] = pos
212-
d.contact.frame[index] = frame
213-
return start_index, 2
207+
dist2, pos2 = _plane_sphere(n, plane.pos, cap.pos - segment, cap.radius)
208+
write_contact(d, dist2, pos2, frame, margin, geom_indices, worldid)
214209

215210

216211
_collision_functions = {
@@ -252,12 +247,11 @@ def _collision_function_kernel(
252247
d.geom_xmat[worldid],
253248
)
254249

255-
index, ncon = wp.static(_collision_functions[(type1, type2)])(
256-
geom1, geom2, worldid, d
250+
margin = wp.max(m.geom_margin[g1], m.geom_margin[g2])
251+
252+
wp.static(_collision_functions[(type1, type2)])(
253+
geom1, geom2, worldid, d, margin, geoms
257254
)
258-
for i in range(ncon):
259-
d.contact.worldid[index + i] = worldid
260-
d.contact.geom[index + i] = geoms
261255

262256
return _collision_function_kernel
263257

0 commit comments

Comments
 (0)