Skip to content

Commit ca38712

Browse files
Milotrinceduburcqa
andauthored
[FEATURE] Add example viewer plugins. (#2357)
Co-authored-by: Alexis DUBURCQ <alexis.duburcq@gmail.com>
1 parent 0dcfa35 commit ca38712

File tree

15 files changed

+1455
-200
lines changed

15 files changed

+1455
-200
lines changed
Lines changed: 289 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,289 @@
1+
import csv
2+
import os
3+
from typing import TYPE_CHECKING, NamedTuple
4+
5+
import numpy as np
6+
from typing_extensions import override
7+
8+
import genesis as gs
9+
import genesis.utils.geom as gu
10+
import genesis.vis.keybindings as kb
11+
from genesis.utils.misc import tensor_to_array
12+
from genesis.vis.viewer_plugins import EVENT_HANDLE_STATE, EVENT_HANDLED, RaycasterViewerPlugin
13+
14+
if TYPE_CHECKING:
15+
from genesis.engine.entities.rigid_entity import RigidLink
16+
from genesis.engine.scene import Scene
17+
from genesis.ext.pyrender.node import Node
18+
19+
20+
class SelectedPoint(NamedTuple):
21+
"""
22+
Represents a selected point on a rigid mesh surface.
23+
24+
Attributes
25+
----------
26+
link : RigidLink
27+
The rigid link that the point belongs to.
28+
local_position : np.ndarray, shape (3,)
29+
The position of the point in the link's local coordinate frame.
30+
local_normal : np.ndarray, shape (3,)
31+
The surface normal at the point in the link's local coordinate frame.
32+
"""
33+
34+
link: "RigidLink"
35+
local_position: np.ndarray # shape (3,)
36+
local_normal: np.ndarray # shape (3,)
37+
38+
39+
class MeshPointSelectorPlugin(RaycasterViewerPlugin):
40+
"""
41+
Interactive viewer plugin that enables using mouse clicks to select points on rigid meshes.
42+
Selected points are stored in local coordinates relative to their link's frame.
43+
"""
44+
45+
def __init__(
46+
self,
47+
sphere_radius: float = 0.005,
48+
sphere_color: tuple = (0.1, 0.3, 1.0, 1.0),
49+
hover_color: tuple = (0.3, 0.5, 1.0, 1.0),
50+
grid_snap: tuple[float, float, float] = (-1.0, -1.0, -1.0),
51+
output_file: str = "selected_points.csv",
52+
) -> None:
53+
super().__init__()
54+
self.sphere_radius = sphere_radius
55+
self.sphere_color = sphere_color
56+
self.hover_color = hover_color
57+
self.grid_snap = grid_snap
58+
self.output_file = output_file
59+
60+
self.selected_points: dict[int, SelectedPoint] = {}
61+
self._prev_mouse_pos: tuple[int, int] = (0, 0)
62+
63+
def build(self, viewer, camera: "Node", scene: "Scene"):
64+
super().build(viewer, camera, scene)
65+
self._prev_mouse_pos: tuple[int, int] = (self.viewer._viewport_size[0] // 2, self.viewer._viewport_size[1] // 2)
66+
67+
def _get_pos_hash(self, pos: np.ndarray) -> int:
68+
"""
69+
Generate a hash for a given position to use as a unique identifier.
70+
71+
Parameters
72+
----------
73+
pos : np.ndarray, shape (3,)
74+
The position to hash.
75+
76+
Returns
77+
-------
78+
int
79+
The hash of the position.
80+
"""
81+
return hash((round(pos[0], 6), round(pos[1], 6), round(pos[2], 6)))
82+
83+
def _snap_to_grid(self, point: np.ndarray) -> np.ndarray:
84+
"""
85+
Snap a point to the grid based on grid_snap settings.
86+
87+
Parameters
88+
----------
89+
point : np.ndarray, shape (3,)
90+
The point to snap.
91+
92+
Returns
93+
-------
94+
np.ndarray, shape (3,)
95+
The point snapped to the grid.
96+
"""
97+
grid_snap = np.array(self.grid_snap)
98+
# Snap each axis if the snap value is non-negative
99+
return np.where(grid_snap >= 0, np.round(point / grid_snap) * grid_snap, point)
100+
101+
@override
102+
def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> EVENT_HANDLE_STATE:
103+
self._prev_mouse_pos = (x, y)
104+
105+
@override
106+
def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE:
107+
if button == 1: # left click
108+
ray = self._screen_position_to_ray(x, y)
109+
ray_hit = self._raycaster.cast(*ray)
110+
111+
if ray_hit is not None and ray_hit.geom:
112+
link = ray_hit.geom.link
113+
world_pos = ray_hit.position
114+
world_normal = ray_hit.normal
115+
116+
# Get link pose
117+
link_pos = tensor_to_array(link.get_pos())
118+
link_quat = tensor_to_array(link.get_quat())
119+
120+
local_pos = gu.inv_transform_by_trans_quat(world_pos, link_pos, link_quat)
121+
local_normal = gu.inv_transform_by_quat(world_normal, link_quat)
122+
123+
# Apply grid snapping to local position
124+
local_pos = self._snap_to_grid(local_pos)
125+
126+
pos_hash = self._get_pos_hash(local_pos)
127+
if pos_hash in self.selected_points:
128+
# Deselect point if already selected
129+
del self.selected_points[pos_hash]
130+
else:
131+
selected_point = SelectedPoint(link, local_pos, local_normal)
132+
self.selected_points[pos_hash] = selected_point
133+
134+
return EVENT_HANDLED
135+
return None
136+
137+
@override
138+
def on_draw(self) -> None:
139+
super().on_draw()
140+
if self.scene._visualizer is not None and self.scene._visualizer.is_built:
141+
self.scene.clear_debug_objects()
142+
mouse_ray = self._screen_position_to_ray(*self._prev_mouse_pos)
143+
144+
closest_hit = self._raycaster.cast(*mouse_ray)
145+
if closest_hit is not None:
146+
snap_pos = self._snap_to_grid(closest_hit.position)
147+
148+
# Draw hover preview
149+
self.scene.draw_debug_sphere(
150+
snap_pos,
151+
self.sphere_radius,
152+
self.hover_color,
153+
)
154+
self.scene.draw_debug_arrow(
155+
snap_pos,
156+
tuple(n * 0.05 for n in closest_hit.normal),
157+
self.sphere_radius / 2,
158+
self.hover_color,
159+
)
160+
161+
if self.selected_points:
162+
world_positions = []
163+
for point in self.selected_points.values():
164+
link_pos = tensor_to_array(point.link.get_pos())
165+
link_quat = tensor_to_array(point.link.get_quat())
166+
local_pos_arr = np.array(point.local_position, dtype=np.float32)
167+
current_world_pos = gu.transform_by_trans_quat(local_pos_arr, link_pos, link_quat)
168+
world_positions.append(current_world_pos)
169+
170+
if len(world_positions) == 1:
171+
self.scene.draw_debug_sphere(
172+
world_positions[0],
173+
self.sphere_radius,
174+
self.sphere_color,
175+
)
176+
else:
177+
positions_array = np.array(world_positions)
178+
self.scene.draw_debug_spheres(positions_array, self.sphere_radius, self.sphere_color)
179+
180+
@override
181+
def on_close(self) -> None:
182+
super().on_close()
183+
184+
if not self.selected_points:
185+
print("[MeshPointSelectorPlugin] No points selected.")
186+
return
187+
188+
output_file = self.output_file
189+
try:
190+
with open(output_file, "w", newline="") as csvfile:
191+
writer = csv.writer(csvfile)
192+
193+
writer.writerow(
194+
[
195+
"point_idx",
196+
"link_idx",
197+
"local_pos_x",
198+
"local_pos_y",
199+
"local_pos_z",
200+
"local_normal_x",
201+
"local_normal_y",
202+
"local_normal_z",
203+
]
204+
)
205+
206+
for i, point in enumerate(self.selected_points.values(), 1):
207+
writer.writerow(
208+
[
209+
i,
210+
point.link.idx,
211+
point.local_position[0],
212+
point.local_position[1],
213+
point.local_position[2],
214+
point.local_normal[0],
215+
point.local_normal[1],
216+
point.local_normal[2],
217+
]
218+
)
219+
220+
gs.logger.info(
221+
f"[MeshPointSelectorPlugin] Wrote {len(self.selected_points)} selected points to '{output_file}'"
222+
)
223+
224+
except Exception as e:
225+
gs.logger.error(f"[MeshPointSelectorPlugin] Error writing to '{output_file}': {e}")
226+
227+
228+
if __name__ == "__main__":
229+
gs.init(backend=gs.gpu)
230+
231+
scene = gs.Scene(
232+
sim_options=gs.options.SimOptions(
233+
gravity=(0.0, 0.0, 0.0),
234+
),
235+
viewer_options=gs.options.ViewerOptions(
236+
camera_pos=(0.6, 0.6, 0.6),
237+
camera_lookat=(0.0, 0.0, 0.2),
238+
camera_fov=40,
239+
),
240+
vis_options=gs.options.VisOptions(
241+
show_world_frame=True,
242+
),
243+
profiling_options=gs.options.ProfilingOptions(
244+
show_FPS=False,
245+
),
246+
show_viewer=True,
247+
)
248+
249+
hand = scene.add_entity(
250+
morph=gs.morphs.URDF(
251+
file="urdf/shadow_hand/shadow_hand.urdf",
252+
collision=True,
253+
pos=(0.0, 0.0, 0.0),
254+
euler=(0.0, 0.0, 180.0),
255+
fixed=True,
256+
merge_fixed_links=False,
257+
),
258+
)
259+
260+
scene.viewer.add_plugin(
261+
MeshPointSelectorPlugin(
262+
sphere_radius=0.004,
263+
grid_snap=(-1.0, 0.01, 0.01),
264+
output_file="selected_points.csv",
265+
)
266+
)
267+
268+
scene.build()
269+
270+
is_running = True
271+
272+
def stop():
273+
global is_running
274+
is_running = False
275+
276+
scene.viewer.register_keybinds(
277+
kb.Keybind("quit", kb.Key.ESCAPE, kb.KeyAction.PRESS, callback=stop),
278+
)
279+
280+
try:
281+
while is_running:
282+
scene.step()
283+
284+
if "PYTEST_VERSION" in os.environ:
285+
break
286+
except KeyboardInterrupt:
287+
gs.logger.info("Simulation interrupted, exiting.")
288+
finally:
289+
gs.logger.info("Simulation finished.")
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import argparse
2+
import math
3+
import os
4+
5+
import genesis as gs
6+
import genesis.vis.keybindings as kb
7+
8+
if __name__ == "__main__":
9+
parser = argparse.ArgumentParser(description="Mouse interaction viewer plugin example.")
10+
parser.add_argument("--use_force", action="store_true", help="Apply spring forces instead of setting position")
11+
args = parser.parse_args()
12+
13+
gs.init(backend=gs.gpu)
14+
15+
scene = gs.Scene(
16+
viewer_options=gs.options.ViewerOptions(
17+
camera_pos=(3.5, 0.0, 2.5),
18+
camera_lookat=(0.0, 0.0, 0.5),
19+
camera_fov=40,
20+
),
21+
profiling_options=gs.options.ProfilingOptions(
22+
show_FPS=False,
23+
),
24+
show_viewer=True,
25+
)
26+
27+
scene.add_entity(gs.morphs.Plane())
28+
29+
sphere = scene.add_entity(
30+
morph=gs.morphs.Sphere(
31+
pos=(-0.3, -0.3, 0),
32+
radius=0.1,
33+
),
34+
)
35+
for i in range(6):
36+
angle = i * (2 * math.pi / 6)
37+
radius = 0.5 + i * 0.1
38+
cube = scene.add_entity(
39+
morph=gs.morphs.Box(
40+
pos=(radius * math.cos(angle), radius * math.sin(angle), 0.1 + i * 0.1),
41+
size=(0.2, 0.2, 0.2),
42+
),
43+
)
44+
45+
scene.viewer.add_plugin(
46+
gs.vis.viewer_plugins.MouseInteractionPlugin(
47+
use_force=args.use_force,
48+
color=(0.1, 0.6, 0.8, 0.6),
49+
)
50+
)
51+
52+
scene.build()
53+
54+
is_running = True
55+
56+
def stop():
57+
global is_running
58+
is_running = False
59+
60+
scene.viewer.register_keybinds(
61+
kb.Keybind("quit", kb.Key.ESCAPE, kb.KeyAction.PRESS, callback=stop),
62+
)
63+
64+
try:
65+
while is_running:
66+
scene.step()
67+
68+
if "PYTEST_VERSION" in os.environ:
69+
break
70+
except KeyboardInterrupt:
71+
gs.logger.info("Simulation interrupted, exiting.")
72+
finally:
73+
gs.logger.info("Simulation finished.")

0 commit comments

Comments
 (0)