Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions mjx/mujoco/mjx/_src/collision_convex.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def plane_convex(plane: GeomInfo, convex: ConvexInfo) -> Collision:

frame = jp.stack([math.make_frame(n)] * 4, axis=0)
unique = jp.tril(idx == idx[:, None]).sum(axis=1) == 1
dist = jp.where(unique, -support[idx], 1)
dist = jp.where(unique, -support[idx], jp.abs(support[idx]).max())
pos = pos - 0.5 * dist[:, None] * n
return dist, pos, frame

Expand Down Expand Up @@ -214,7 +214,7 @@ def get_support(faces, normal):
d *= sign

spt = sphere_pos + n * sphere.size[0]
dist = jp.where(has_separating_axis, 1.0, d - sphere.size[0])
dist = d - sphere.size[0]
pos = (pt + spt) * 0.5

# Go back to world frame.
Expand Down Expand Up @@ -280,9 +280,8 @@ def get_support(face, normal):
# Create variables for the face contact.
pos = (cap_pts_clipped + face_pts) * 0.5
contact_normal = -jp.stack([normal] * 2, 0)
face_penetration = jp.where(
mask & has_support, jp.dot(face_pts - cap_pts_clipped, normal), -1
)
face_dist = jp.dot(face_pts - cap_pts_clipped, normal)
face_penetration = jp.where(mask & has_support, face_dist, jp.minimum(face_dist, 0))

# Pick a potential shallow edge contact.
def get_edge_axis(edge):
Expand Down Expand Up @@ -316,7 +315,7 @@ def get_edge_axis(edge):
edge_face_normals = edge_face_normal[e_idx]
edge_voronoi_front = ((edge_face_normals @ edge_axis) < 0).all()
shallow = ~degenerate_edge_dir & edge_voronoi_front
edge_penetration = jp.where(shallow, cap.size[0] - edge_dist, -1)
edge_penetration = cap.size[0] - edge_dist

# Determine edge contact position.
edge_pos = (
Expand All @@ -327,7 +326,8 @@ def get_edge_axis(edge):
) & ~degenerate_edge_dir
min_face_penetration = face_penetration.min()
has_edge_contact = (
(edge_penetration > 0)
shallow
& (edge_penetration > 0)
# prefer edge contact if the edge is smaller than face penetration
& jp.where(
min_face_penetration > 0,
Expand All @@ -351,7 +351,9 @@ def get_edge_axis(edge):
n = n @ convex.mat.T

dist = -jp.where(
has_edge_contact, jp.array([edge_penetration, -1]), face_penetration
has_edge_contact,
jp.array([edge_penetration, -edge_penetration]),
face_penetration,
)
return dist, pos, n

Expand Down Expand Up @@ -577,7 +579,7 @@ def _create_contact_manifold(
penetration_dir = jp.take(poly_incident, best, axis=0) - contact_pts
penetration = penetration_dir.dot(-clipping_norm)

dist = jp.where(mask_pts, -penetration, jp.ones_like(penetration))
dist = jp.where(mask_pts, -penetration, jp.abs(penetration))
pos = contact_pts
normal = -jp.stack([sep_axis] * 4, 0)
return dist, pos, normal
Expand Down Expand Up @@ -677,7 +679,7 @@ def get_support(axis, is_degenerate):
idx = dist.argmin()
dist = jp.where(
is_edge_contact,
jp.array([dist[idx], 1, 1, 1]),
jp.array([dist[idx], jp.abs(dist[idx]), jp.abs(dist[idx]), jp.abs(dist[idx])]),
dist,
)
pos = jp.where(is_edge_contact, jp.tile(pos[idx], (4, 1)), pos)
Expand Down Expand Up @@ -806,7 +808,7 @@ def get_support(axis):
incident_face_norm,
-best_axis,
)
dist = jp.where(is_face_separating, 1.0, dist)
dist = jp.where(is_face_separating, jp.abs(dist), dist)

# Handle edge separating axes by checking all edge pairs.
a_idx = jp.tile(jp.arange(edges_a.shape[0]), reps=edges_b.shape[0])
Expand Down Expand Up @@ -856,7 +858,7 @@ def get_normals(a_dir, a_pt, b_dir):
normal = jp.where(is_edge_contact, edge_axes[best_edge_idx], normal)
dist = jp.where(
is_edge_contact,
jp.array([best_edge_dist, 1, 1, 1]),
jp.array([best_edge_dist, jp.abs(best_edge_dist), jp.abs(best_edge_dist), jp.abs(best_edge_dist)]),
dist,
)
a_closest, b_closest = math.closest_segment_to_segment_points(
Expand Down Expand Up @@ -1060,7 +1062,7 @@ def hfield_sphere(

# zero out non-unique contacts
unique = jp.tril(idx == idx[:, None]).sum(axis=1) == 1
dist = jp.where(unique, dist, 1)
dist = jp.where(unique, dist, jp.max(dist))

# back to world frame, _hfield_collision returns collision in hfield frame
pos = jax.vmap(lambda p: h.mat @ p + h.pos)(pos)
Expand All @@ -1084,7 +1086,7 @@ def hfield_capsule(

# zero out non-unique contacts
unique = jp.tril(idx == idx[:, None]).sum(axis=1) == 1
dist = jp.where(unique, dist, 1)
dist = jp.where(unique, dist, jp.max(dist))

# back to world frame, _hfield_collision returns collision in hfield frame
pos = jax.vmap(lambda p: h.mat @ p + h.pos)(pos)
Expand All @@ -1108,7 +1110,7 @@ def hfield_convex(

# zero out non-unique contacts
unique = jp.tril(idx == idx[:, None]).sum(axis=1) == 1
dist = jp.where(unique, dist, 1)
dist = jp.where(unique, dist, jp.max(dist))

# back to world frame, _hfield_collision returns collision in hfield frame
pos = jax.vmap(lambda p: h.mat @ p + h.pos)(pos)
Expand Down