|
| 1 | +import csv |
| 2 | +import os |
| 3 | +from typing import TYPE_CHECKING, NamedTuple |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +from typing_extensions import override |
| 7 | + |
| 8 | +import genesis as gs |
| 9 | +import genesis.utils.geom as gu |
| 10 | +import genesis.vis.keybindings as kb |
| 11 | +from genesis.utils.misc import tensor_to_array |
| 12 | +from genesis.vis.viewer_plugins import EVENT_HANDLE_STATE, EVENT_HANDLED, RaycasterViewerPlugin |
| 13 | + |
| 14 | +if TYPE_CHECKING: |
| 15 | + from genesis.engine.entities.rigid_entity import RigidLink |
| 16 | + from genesis.engine.scene import Scene |
| 17 | + from genesis.ext.pyrender.node import Node |
| 18 | + |
| 19 | + |
| 20 | +class SelectedPoint(NamedTuple): |
| 21 | + """ |
| 22 | + Represents a selected point on a rigid mesh surface. |
| 23 | +
|
| 24 | + Attributes |
| 25 | + ---------- |
| 26 | + link : RigidLink |
| 27 | + The rigid link that the point belongs to. |
| 28 | + local_position : np.ndarray, shape (3,) |
| 29 | + The position of the point in the link's local coordinate frame. |
| 30 | + local_normal : np.ndarray, shape (3,) |
| 31 | + The surface normal at the point in the link's local coordinate frame. |
| 32 | + """ |
| 33 | + |
| 34 | + link: "RigidLink" |
| 35 | + local_position: np.ndarray # shape (3,) |
| 36 | + local_normal: np.ndarray # shape (3,) |
| 37 | + |
| 38 | + |
| 39 | +class MeshPointSelectorPlugin(RaycasterViewerPlugin): |
| 40 | + """ |
| 41 | + Interactive viewer plugin that enables using mouse clicks to select points on rigid meshes. |
| 42 | + Selected points are stored in local coordinates relative to their link's frame. |
| 43 | + """ |
| 44 | + |
| 45 | + def __init__( |
| 46 | + self, |
| 47 | + sphere_radius: float = 0.005, |
| 48 | + sphere_color: tuple = (0.1, 0.3, 1.0, 1.0), |
| 49 | + hover_color: tuple = (0.3, 0.5, 1.0, 1.0), |
| 50 | + grid_snap: tuple[float, float, float] = (-1.0, -1.0, -1.0), |
| 51 | + output_file: str = "selected_points.csv", |
| 52 | + ) -> None: |
| 53 | + super().__init__() |
| 54 | + self.sphere_radius = sphere_radius |
| 55 | + self.sphere_color = sphere_color |
| 56 | + self.hover_color = hover_color |
| 57 | + self.grid_snap = grid_snap |
| 58 | + self.output_file = output_file |
| 59 | + |
| 60 | + self.selected_points: dict[int, SelectedPoint] = {} |
| 61 | + self._prev_mouse_pos: tuple[int, int] = (0, 0) |
| 62 | + |
| 63 | + def build(self, viewer, camera: "Node", scene: "Scene"): |
| 64 | + super().build(viewer, camera, scene) |
| 65 | + self._prev_mouse_pos: tuple[int, int] = (self.viewer._viewport_size[0] // 2, self.viewer._viewport_size[1] // 2) |
| 66 | + |
| 67 | + def _get_pos_hash(self, pos: np.ndarray) -> int: |
| 68 | + """ |
| 69 | + Generate a hash for a given position to use as a unique identifier. |
| 70 | +
|
| 71 | + Parameters |
| 72 | + ---------- |
| 73 | + pos : np.ndarray, shape (3,) |
| 74 | + The position to hash. |
| 75 | +
|
| 76 | + Returns |
| 77 | + ------- |
| 78 | + int |
| 79 | + The hash of the position. |
| 80 | + """ |
| 81 | + return hash((round(pos[0], 6), round(pos[1], 6), round(pos[2], 6))) |
| 82 | + |
| 83 | + def _snap_to_grid(self, point: np.ndarray) -> np.ndarray: |
| 84 | + """ |
| 85 | + Snap a point to the grid based on grid_snap settings. |
| 86 | +
|
| 87 | + Parameters |
| 88 | + ---------- |
| 89 | + point : np.ndarray, shape (3,) |
| 90 | + The point to snap. |
| 91 | +
|
| 92 | + Returns |
| 93 | + ------- |
| 94 | + np.ndarray, shape (3,) |
| 95 | + The point snapped to the grid. |
| 96 | + """ |
| 97 | + grid_snap = np.array(self.grid_snap) |
| 98 | + # Snap each axis if the snap value is non-negative |
| 99 | + return np.where(grid_snap >= 0, np.round(point / grid_snap) * grid_snap, point) |
| 100 | + |
| 101 | + @override |
| 102 | + def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> EVENT_HANDLE_STATE: |
| 103 | + self._prev_mouse_pos = (x, y) |
| 104 | + |
| 105 | + @override |
| 106 | + def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: |
| 107 | + if button == 1: # left click |
| 108 | + ray = self._screen_position_to_ray(x, y) |
| 109 | + ray_hit = self._raycaster.cast(*ray) |
| 110 | + |
| 111 | + if ray_hit is not None and ray_hit.geom: |
| 112 | + link = ray_hit.geom.link |
| 113 | + world_pos = ray_hit.position |
| 114 | + world_normal = ray_hit.normal |
| 115 | + |
| 116 | + # Get link pose |
| 117 | + link_pos = tensor_to_array(link.get_pos()) |
| 118 | + link_quat = tensor_to_array(link.get_quat()) |
| 119 | + |
| 120 | + local_pos = gu.inv_transform_by_trans_quat(world_pos, link_pos, link_quat) |
| 121 | + local_normal = gu.inv_transform_by_quat(world_normal, link_quat) |
| 122 | + |
| 123 | + # Apply grid snapping to local position |
| 124 | + local_pos = self._snap_to_grid(local_pos) |
| 125 | + |
| 126 | + pos_hash = self._get_pos_hash(local_pos) |
| 127 | + if pos_hash in self.selected_points: |
| 128 | + # Deselect point if already selected |
| 129 | + del self.selected_points[pos_hash] |
| 130 | + else: |
| 131 | + selected_point = SelectedPoint(link, local_pos, local_normal) |
| 132 | + self.selected_points[pos_hash] = selected_point |
| 133 | + |
| 134 | + return EVENT_HANDLED |
| 135 | + return None |
| 136 | + |
| 137 | + @override |
| 138 | + def on_draw(self) -> None: |
| 139 | + super().on_draw() |
| 140 | + if self.scene._visualizer is not None and self.scene._visualizer.is_built: |
| 141 | + self.scene.clear_debug_objects() |
| 142 | + mouse_ray = self._screen_position_to_ray(*self._prev_mouse_pos) |
| 143 | + |
| 144 | + closest_hit = self._raycaster.cast(*mouse_ray) |
| 145 | + if closest_hit is not None: |
| 146 | + snap_pos = self._snap_to_grid(closest_hit.position) |
| 147 | + |
| 148 | + # Draw hover preview |
| 149 | + self.scene.draw_debug_sphere( |
| 150 | + snap_pos, |
| 151 | + self.sphere_radius, |
| 152 | + self.hover_color, |
| 153 | + ) |
| 154 | + self.scene.draw_debug_arrow( |
| 155 | + snap_pos, |
| 156 | + tuple(n * 0.05 for n in closest_hit.normal), |
| 157 | + self.sphere_radius / 2, |
| 158 | + self.hover_color, |
| 159 | + ) |
| 160 | + |
| 161 | + if self.selected_points: |
| 162 | + world_positions = [] |
| 163 | + for point in self.selected_points.values(): |
| 164 | + link_pos = tensor_to_array(point.link.get_pos()) |
| 165 | + link_quat = tensor_to_array(point.link.get_quat()) |
| 166 | + local_pos_arr = np.array(point.local_position, dtype=np.float32) |
| 167 | + current_world_pos = gu.transform_by_trans_quat(local_pos_arr, link_pos, link_quat) |
| 168 | + world_positions.append(current_world_pos) |
| 169 | + |
| 170 | + if len(world_positions) == 1: |
| 171 | + self.scene.draw_debug_sphere( |
| 172 | + world_positions[0], |
| 173 | + self.sphere_radius, |
| 174 | + self.sphere_color, |
| 175 | + ) |
| 176 | + else: |
| 177 | + positions_array = np.array(world_positions) |
| 178 | + self.scene.draw_debug_spheres(positions_array, self.sphere_radius, self.sphere_color) |
| 179 | + |
| 180 | + @override |
| 181 | + def on_close(self) -> None: |
| 182 | + super().on_close() |
| 183 | + |
| 184 | + if not self.selected_points: |
| 185 | + print("[MeshPointSelectorPlugin] No points selected.") |
| 186 | + return |
| 187 | + |
| 188 | + output_file = self.output_file |
| 189 | + try: |
| 190 | + with open(output_file, "w", newline="") as csvfile: |
| 191 | + writer = csv.writer(csvfile) |
| 192 | + |
| 193 | + writer.writerow( |
| 194 | + [ |
| 195 | + "point_idx", |
| 196 | + "link_idx", |
| 197 | + "local_pos_x", |
| 198 | + "local_pos_y", |
| 199 | + "local_pos_z", |
| 200 | + "local_normal_x", |
| 201 | + "local_normal_y", |
| 202 | + "local_normal_z", |
| 203 | + ] |
| 204 | + ) |
| 205 | + |
| 206 | + for i, point in enumerate(self.selected_points.values(), 1): |
| 207 | + writer.writerow( |
| 208 | + [ |
| 209 | + i, |
| 210 | + point.link.idx, |
| 211 | + point.local_position[0], |
| 212 | + point.local_position[1], |
| 213 | + point.local_position[2], |
| 214 | + point.local_normal[0], |
| 215 | + point.local_normal[1], |
| 216 | + point.local_normal[2], |
| 217 | + ] |
| 218 | + ) |
| 219 | + |
| 220 | + gs.logger.info( |
| 221 | + f"[MeshPointSelectorPlugin] Wrote {len(self.selected_points)} selected points to '{output_file}'" |
| 222 | + ) |
| 223 | + |
| 224 | + except Exception as e: |
| 225 | + gs.logger.error(f"[MeshPointSelectorPlugin] Error writing to '{output_file}': {e}") |
| 226 | + |
| 227 | + |
| 228 | +if __name__ == "__main__": |
| 229 | + gs.init(backend=gs.gpu) |
| 230 | + |
| 231 | + scene = gs.Scene( |
| 232 | + sim_options=gs.options.SimOptions( |
| 233 | + gravity=(0.0, 0.0, 0.0), |
| 234 | + ), |
| 235 | + viewer_options=gs.options.ViewerOptions( |
| 236 | + camera_pos=(0.6, 0.6, 0.6), |
| 237 | + camera_lookat=(0.0, 0.0, 0.2), |
| 238 | + camera_fov=40, |
| 239 | + ), |
| 240 | + vis_options=gs.options.VisOptions( |
| 241 | + show_world_frame=True, |
| 242 | + ), |
| 243 | + profiling_options=gs.options.ProfilingOptions( |
| 244 | + show_FPS=False, |
| 245 | + ), |
| 246 | + show_viewer=True, |
| 247 | + ) |
| 248 | + |
| 249 | + hand = scene.add_entity( |
| 250 | + morph=gs.morphs.URDF( |
| 251 | + file="urdf/shadow_hand/shadow_hand.urdf", |
| 252 | + collision=True, |
| 253 | + pos=(0.0, 0.0, 0.0), |
| 254 | + euler=(0.0, 0.0, 180.0), |
| 255 | + fixed=True, |
| 256 | + merge_fixed_links=False, |
| 257 | + ), |
| 258 | + ) |
| 259 | + |
| 260 | + scene.viewer.add_plugin( |
| 261 | + MeshPointSelectorPlugin( |
| 262 | + sphere_radius=0.004, |
| 263 | + grid_snap=(-1.0, 0.01, 0.01), |
| 264 | + output_file="selected_points.csv", |
| 265 | + ) |
| 266 | + ) |
| 267 | + |
| 268 | + scene.build() |
| 269 | + |
| 270 | + is_running = True |
| 271 | + |
| 272 | + def stop(): |
| 273 | + global is_running |
| 274 | + is_running = False |
| 275 | + |
| 276 | + scene.viewer.register_keybinds( |
| 277 | + kb.Keybind("quit", kb.Key.ESCAPE, kb.KeyAction.PRESS, callback=stop), |
| 278 | + ) |
| 279 | + |
| 280 | + try: |
| 281 | + while is_running: |
| 282 | + scene.step() |
| 283 | + |
| 284 | + if "PYTEST_VERSION" in os.environ: |
| 285 | + break |
| 286 | + except KeyboardInterrupt: |
| 287 | + gs.logger.info("Simulation interrupted, exiting.") |
| 288 | + finally: |
| 289 | + gs.logger.info("Simulation finished.") |
0 commit comments