Skip to content

Commit a6e1cc9

Browse files
authored
[FEATURE] Add full support of tensor broadcasting in getters. (#2051)
* Remove tensor allocation warnings. * Unify tensor sanitization. * Remove 'unsafe' support now that zero-copy is inherently safe. * Raise exception if invalid markers at passed in unit test expr. * Add fully support of (enhanced) tensor broadcasting in setters (w/ and w/o zero-copy). * Rename 'verts_idx' in 'verts_local_idx' for clarity. * More robust batched camera and sensor debug draw.
1 parent eea9caf commit a6e1cc9

34 files changed

+1087
-1207
lines changed

examples/coupling/fem_cube_linked_with_arm.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,7 @@ def main():
115115

116116
print("cube init pos", cube.init_positions)
117117
pin_idx = [1, 5]
118-
cube.set_vertex_constraints(
119-
verts_idx=pin_idx,
120-
link=end_joint.link,
121-
)
118+
cube.set_vertex_constraints(verts_idx_local=pin_idx, link=end_joint.link)
122119
print("Cube initial positions:", cube.init_positions[pin_idx])
123120
scene.draw_debug_spheres(poss=cube.init_positions[pin_idx], radius=0.02, color=(1.0, 0.0, 1.0, 0.8))
124121

examples/fem_hard_and_soft_constraint.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -106,30 +106,19 @@ def get_next_circle_position():
106106
try:
107107
target_positions = blob.init_positions[pinned_idx]
108108
scene.draw_debug_spheres(poss=target_positions, radius=0.02, color=(1, 0, 1, 0.8))
109-
blob.set_vertex_constraints(
110-
verts_idx=pinned_idx,
111-
target_poss=target_positions,
112-
is_soft_constraint=True,
113-
stiffness=1e4,
114-
)
109+
blob.set_vertex_constraints(pinned_idx, target_positions, is_soft_constraint=True, stiffness=1e4)
115110

116111
target_positions = get_next_circle_position()
117112
debug_circle = scene.draw_debug_spheres(poss=target_positions, radius=0.02, color=(0, 1, 0, 0.8))
118-
cube.set_vertex_constraints(
119-
verts_idx=pinned_idx,
120-
target_poss=target_positions,
121-
)
113+
cube.set_vertex_constraints(pinned_idx, target_positions)
122114

123115
for step in tqdm(range(total_steps), total=total_steps):
124116
if debug_circle is not None:
125117
scene.clear_debug_object(debug_circle)
126118

127119
new_pos = get_next_circle_position()
128120
debug_circle = scene.draw_debug_spheres(poss=new_pos, radius=0.02, color=(0, 1, 0, 0.8))
129-
cube.update_constraint_targets(
130-
verts_idx=pinned_idx,
131-
target_poss=new_pos,
132-
)
121+
cube.update_constraint_targets(pinned_idx, new_pos)
133122

134123
scene.step()
135124

examples/sap_coupling/fem_fixed_constraint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def main():
6363
target_poss = cube.init_positions[verts_idx] + torch.tensor(
6464
(0.15 * (math.cos(0.04 * i) - 1.0), 0.15 * math.sin(0.04 * i), 0.0)
6565
)
66-
cube.set_vertex_constraints(verts_idx=verts_idx, target_poss=target_poss)
66+
cube.set_vertex_constraints(verts_idx, target_poss)
6767
scene.step(update_visualizer=False)
6868
if args.vis:
6969
scene.visualizer.context.draw_debug_sphere(

genesis/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
from .logging import Logger
3434
from .version import __version__
3535
from .utils import redirect_libc_stderr, set_random_seed, get_platform, get_device
36-
from .utils.misc import ALLOCATE_TENSOR_WARNING
3736

3837

3938
# Global state
@@ -109,9 +108,6 @@ def init(
109108
logger.info(f"~<│{wave}>~ ~~~~<Genesis>~~~~ ~<{wave}│>~")
110109
logger.info(f"~<╰{'─'*(bar_width)}╯>~")
111110

112-
# FIXME: Disable this warning for now, because it is not useful without printing the entire traceback
113-
logger.addFilter(lambda record: record.msg != ALLOCATE_TENSOR_WARNING)
114-
115111
# Get concrete device and backend
116112
global device
117113
device, device_name, total_mem, backend = get_device(backend)

genesis/engine/entities/drone_entity.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import gstaichi as ti
66

77
import genesis as gs
8-
import genesis.utils.misc as mu
8+
from genesis.utils.misc import get_assets_dir, broadcast_tensor
99

1010
from .rigid_entity import RigidEntity
1111

@@ -16,14 +16,14 @@ def _load_scene(self, morph, surface):
1616
super()._load_scene(morph, surface)
1717

1818
# additional drone specific attributes
19-
properties = ET.parse(os.path.join(mu.get_assets_dir(), morph.file)).getroot()[0].attrib
19+
properties = ET.parse(os.path.join(get_assets_dir(), morph.file)).getroot()[0].attrib
2020
self._KF = float(properties["kf"])
2121
self._KM = float(properties["km"])
2222

2323
self._n_propellers = len(morph.propellers_link_name)
2424

2525
propellers_link = gs.List([self.get_link(name) for name in morph.propellers_link_name])
26-
self._propellers_link_idxs = torch.tensor(
26+
self._propellers_link_idx = torch.tensor(
2727
[link.idx for link in propellers_link], dtype=gs.tc_int, device=gs.device
2828
)
2929
try:
@@ -64,18 +64,22 @@ def set_propellels_rpm(self, propellels_rpm):
6464
gs.raise_exception("`set_propellels_rpm` can only be called once per step.")
6565
self._prev_prop_t = self.sim.cur_step_global
6666

67-
propellels_rpm = self.solver._process_dim(
68-
torch.as_tensor(propellels_rpm, dtype=gs.tc_float, device=gs.device)
69-
).T.contiguous()
70-
if len(propellels_rpm) != len(self._propellers_link_idxs):
71-
gs.raise_exception("Last dimension of `propellels_rpm` does not match `entity.n_propellers`.")
72-
if torch.any(propellels_rpm < 0):
73-
gs.raise_exception("`propellels_rpm` cannot be negative.")
67+
assert propellels_rpm is not None
68+
propellels_rpm, *_ = self._solver._sanitize_io_variables(
69+
propellels_rpm, self._propellers_link_idx, self._n_propellers, "propellers_link_idx"
70+
)
71+
if self._scene.n_envs == 0:
72+
propellels_rpm = propellels_rpm[None]
73+
74+
# FIXME: This check is too expensive
75+
# if (propellels_rpm < 0.0).any():
76+
# gs.raise_exception("`propellels_rpm` cannot be negative.")
77+
7478
self._propellers_revs = (self._propellers_revs + propellels_rpm) % (60 / self.solver.dt)
7579

7680
self.solver.set_drone_rpm(
7781
self._n_propellers,
78-
self._propellers_link_idxs,
82+
self._propellers_link_idx,
7983
propellels_rpm,
8084
self._propellers_spin,
8185
self.KF,
@@ -122,7 +126,7 @@ def COM_link_idx(self):
122126
@property
123127
def propellers_idx(self):
124128
"""The indices of the drone's propeller links."""
125-
return self._propellers_link_idxs
129+
return self._propellers_link_idx
126130

127131
@property
128132
def propellers_spin(self):

genesis/engine/entities/fem_entity.py

Lines changed: 60 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from genesis.engine.couplers import SAPCoupler
1414
from genesis.engine.states.cache import QueriedStates
1515
from genesis.engine.states.entities import FEMEntityState
16-
from genesis.utils.misc import ALLOCATE_TENSOR_WARNING, to_gs_tensor, tensor_to_array
16+
from genesis.utils.misc import to_gs_tensor, tensor_to_array, broadcast_tensor
1717

1818
from .base_entity import Entity
1919

@@ -126,6 +126,34 @@ def __init__(self, scene, solver, material, morph, surface, idx, v_start=0, el_s
126126
# ----------------------------------- basic entity ops -------------------------------
127127
# ------------------------------------------------------------------------------------
128128

129+
def _sanitize_verts_idx_local(self, verts_idx_local=None, envs_idx=None):
130+
if verts_idx_local is None:
131+
verts_idx_local = range(self.n_vertices)
132+
133+
if envs_idx is None:
134+
verts_idx_local_ = broadcast_tensor(verts_idx_local, gs.tc_int, (-1,), ("envs_idx",))
135+
else:
136+
verts_idx_local_ = broadcast_tensor(verts_idx_local, gs.tc_int, (len(envs_idx), -1), ("envs_idx", ""))
137+
138+
# FIXME: This check is too expensive
139+
# if not (0 <= verts_idx_local_ & verts_idx_local_ < self.n_vertices).all():
140+
# gs.raise_exception("Elements of `verts_idx_local' are out-of-range.")
141+
142+
return verts_idx_local_.contiguous()
143+
144+
def _sanitize_verts_tensor(self, tensor, dtype, verts_idx=None, envs_idx=None, element_shape=(), *, batched=True):
145+
n_vertices = verts_idx.shape[-1] if verts_idx is not None else self.n_vertices
146+
if batched:
147+
assert envs_idx is not None
148+
batch_shape = (len(envs_idx), n_vertices)
149+
dim_names = ("envs_idx", "verts_idx", *("" for _ in element_shape))
150+
else:
151+
batch_shape = (n_vertices,)
152+
dim_names = ("verts_idx", *("" for _ in element_shape))
153+
tensor_shape = (*batch_shape, *element_shape)
154+
155+
return broadcast_tensor(tensor, dtype, tensor_shape, dim_names).contiguous()
156+
129157
def set_position(self, pos):
130158
"""
131159
Set the target position(s) for the FEM entity.
@@ -153,14 +181,14 @@ def set_position(self, pos):
153181
if pos.ndim == 1:
154182
if pos.shape == (3,):
155183
pos = self.init_positions_COM_offset + pos
156-
self._tgt["pos"] = pos.unsqueeze(0).tile((self._sim._B, 1, 1))
184+
self._tgt["pos"] = pos[None].tile((self._sim._B, 1, 1))
157185
is_valid = True
158186
elif pos.ndim == 2:
159187
if pos.shape == (self.n_vertices, 3):
160-
self._tgt["pos"] = pos.unsqueeze(0).tile((self._sim._B, 1, 1))
188+
self._tgt["pos"] = pos[None].tile((self._sim._B, 1, 1))
161189
is_valid = True
162190
elif pos.shape == (self._sim._B, 3):
163-
pos = self.init_positions_COM_offset.unsqueeze(0) + pos.unsqueeze(1)
191+
pos = self.init_positions_COM_offset[None] + pos[:, None]
164192
self._tgt["pos"] = pos
165193
is_valid = True
166194
elif pos.ndim == 3:
@@ -200,10 +228,10 @@ def set_velocity(self, vel):
200228
is_valid = True
201229
elif vel.ndim == 2:
202230
if vel.shape == (self.n_vertices, 3):
203-
self._tgt["vel"] = vel.unsqueeze(0).tile((self._sim._B, 1, 1))
231+
self._tgt["vel"] = vel[None].tile((self._sim._B, 1, 1))
204232
is_valid = True
205233
elif vel.shape == (self._sim._B, 3):
206-
self._tgt["vel"] = vel.unsqueeze(1).tile((1, self.n_vertices, 1))
234+
self._tgt["vel"] = vel[:, None].tile((1, self.n_vertices, 1))
207235
is_valid = True
208236
elif vel.ndim == 3:
209237
if vel.shape == (self._sim._B, self.n_vertices, 3):
@@ -241,7 +269,7 @@ def set_actuation(self, actu):
241269
is_valid = True
242270
elif actu.ndim == 1:
243271
if actu.shape == (n_groups,):
244-
self._tgt["actu"] = actu.unsqueeze(0).tile((self._sim._B, 1))
272+
self._tgt["actu"] = actu[None].tile((self._sim._B, 1))
245273
is_valid = True
246274
elif actu.shape == (self.n_elements,):
247275
gs.raise_exception("Cannot set per-element actuation.")
@@ -814,46 +842,16 @@ def set_muscle_direction(self, muscle_direction):
814842
muscle_direction=muscle_direction,
815843
)
816844

817-
def _sanitize_input_tensor(self, tensor, dtype, unbatched_ndim=1):
818-
_tensor = torch.as_tensor(tensor, dtype=dtype, device=gs.device)
819-
820-
if _tensor.ndim < unbatched_ndim + 1:
821-
_tensor = _tensor.repeat((self._sim._B, *((1,) * max(1, _tensor.ndim))))
822-
if self._sim._B > 1:
823-
gs.logger.debug(ALLOCATE_TENSOR_WARNING)
824-
else:
825-
_tensor = _tensor.contiguous()
826-
if _tensor is not tensor:
827-
gs.logger.debug(ALLOCATE_TENSOR_WARNING)
828-
829-
if len(_tensor) != self._sim._B:
830-
gs.raise_exception("Input tensor batch size must match the number of environments.")
831-
832-
if _tensor.ndim != unbatched_ndim + 1:
833-
gs.raise_exception(f"Input tensor ndim is {_tensor.ndim}, should be {unbatched_ndim + 1}.")
834-
835-
return _tensor
836-
837-
def _sanitize_input_verts_idx(self, verts_idx_local):
838-
verts_idx = self._sanitize_input_tensor(verts_idx_local, dtype=gs.tc_int, unbatched_ndim=1) + self._v_start
839-
assert ((verts_idx >= 0) & (verts_idx < self._solver.n_vertices)).all(), "Vertex indices out of bounds."
840-
return verts_idx
841-
842-
def _sanitize_input_poss(self, poss):
843-
poss = self._sanitize_input_tensor(poss, dtype=gs.tc_float, unbatched_ndim=2)
844-
assert poss.ndim == 3 and poss.shape[2] == 3, "Position tensor must have shape (B, num_verts, 3)."
845-
return poss
846-
847845
def set_vertex_constraints(
848-
self, verts_idx, target_poss=None, link=None, is_soft_constraint=False, stiffness=0.0, envs_idx=None
846+
self, verts_idx_local, target_poss=None, link=None, is_soft_constraint=False, stiffness=0.0, envs_idx=None
849847
):
850848
"""
851849
Set vertex constraints for specified vertices.
852850
853851
Parameters
854852
----------
855-
verts_idx : array_like
856-
List of vertex indices to constrain.
853+
verts_idx_local : array_like
854+
List of local vertex indices to constrain.
857855
target_poss : array_like, shape (len(verts_idx), 3), optional
858856
List of target positions [x, y, z] for each vertex. If not provided, the initial positions are used.
859857
link : RigidLink
@@ -874,30 +872,22 @@ def set_vertex_constraints(
874872
if not self._solver._constraints_initialized:
875873
self._solver.init_constraints()
876874

875+
use_current_poss = target_poss is None
877876
envs_idx = self._scene._sanitize_envs_idx(envs_idx)
878-
verts_idx = self._sanitize_input_verts_idx(verts_idx)
877+
verts_idx_local = self._sanitize_verts_idx_local(verts_idx_local, envs_idx)
878+
verts_idx = verts_idx_local + self._v_start
879+
target_poss = self._sanitize_verts_tensor(target_poss, gs.tc_float, verts_idx, envs_idx, (3,))
879880

880-
if target_poss is None:
881-
target_poss = torch.zeros(
882-
(verts_idx.shape[0], verts_idx.shape[1], 3), dtype=gs.tc_float, device=gs.device, requires_grad=False
883-
)
881+
if use_current_poss:
884882
self._kernel_get_verts_pos(self._sim.cur_substep_local, target_poss, verts_idx)
885-
target_poss = self._sanitize_input_poss(target_poss)
886-
887-
assert (
888-
len(envs_idx) == len(target_poss) == len(verts_idx)
889-
), "First dimension should match number of environments."
890-
assert target_poss.shape[1] == verts_idx.shape[1], "Target position should be provided for each vertex."
891883

892884
if link is None:
893885
link_init_pos = torch.zeros((self._sim._B, 3), dtype=gs.tc_float, device=gs.device)
894886
link_init_quat = torch.zeros((self._sim._B, 4), dtype=gs.tc_float, device=gs.device)
895887
link_idx = -1
896888
else:
897889
assert isinstance(link, RigidLink), "Only RigidLink is supported for vertex constraints."
898-
link_init_pos = self._sanitize_input_tensor(link.get_pos(), dtype=gs.tc_float)
899-
link_init_quat = self._sanitize_input_tensor(link.get_quat(), dtype=gs.tc_float)
900-
link_idx = link.idx
890+
link_init_pos, link_init_quat, link_idx = link.get_pos(), link.get_quat(), link.idx
901891

902892
self._solver._kernel_set_vertex_constraints(
903893
self._sim.cur_substep_local,
@@ -911,31 +901,36 @@ def set_vertex_constraints(
911901
envs_idx,
912902
)
913903

914-
def update_constraint_targets(self, verts_idx, target_poss, envs_idx=None):
904+
def update_constraint_targets(self, verts_idx_local, target_poss, envs_idx=None):
915905
"""Update target positions for existing constraints."""
916906
if not self._solver._constraints_initialized:
917907
gs.logger.warning("Ignoring update_constraint_targets; constraints have not been initialized.")
918908
return
919909

910+
assert target_poss is not None
920911
envs_idx = self._scene._sanitize_envs_idx(envs_idx)
921-
verts_idx = self._sanitize_input_verts_idx(verts_idx)
922-
target_poss = self._sanitize_input_poss(target_poss)
923-
assert target_poss.shape[1] == verts_idx.shape[1], "Target position should be provided for each vertex."
912+
verts_idx_local = self._sanitize_verts_idx_local(verts_idx_local, envs_idx)
913+
verts_idx = verts_idx_local + self._v_start
914+
target_poss = self._sanitize_verts_tensor(target_poss, gs.tc_float, verts_idx, envs_idx, (3,))
924915

925916
self._solver._kernel_update_constraint_targets(verts_idx, target_poss, envs_idx)
926917

927-
def remove_vertex_constraints(self, verts_idx=None, envs_idx=None):
918+
def remove_vertex_constraints(self, verts_idx_local=None, envs_idx=None):
928919
"""Remove constraints from specified vertices, or all if None."""
929920
if not self._solver._constraints_initialized:
930921
gs.logger.warning("Ignoring remove_vertex_constraints; constraints have not been initialized.")
931922
return
932923

933-
if verts_idx is None:
924+
# FIXME: GsTaichi 'fill' method is very inefficient. Try using zero-copy if possible.
925+
if verts_idx_local is None:
934926
self._solver.vertex_constraints.is_constrained.fill(0)
935-
else:
936-
verts_idx = self._sanitize_input_verts_idx(verts_idx)
937-
envs_idx = self._scene._sanitize_envs_idx(envs_idx)
938-
self._solver._kernel_remove_specific_constraints(verts_idx, envs_idx)
927+
return
928+
929+
envs_idx = self._scene._sanitize_envs_idx(envs_idx)
930+
verts_idx_local = self._sanitize_verts_idx_local(verts_idx_local, envs_idx)
931+
verts_idx = verts_idx_local + self._v_start
932+
933+
self._solver._kernel_remove_specific_constraints(verts_idx, envs_idx)
939934

940935
@ti.kernel
941936
def _kernel_get_verts_pos(self, f: ti.i32, pos: ti.types.ndarray(), verts_idx: ti.types.ndarray()):
@@ -954,14 +949,8 @@ def get_el2v(self):
954949
el2v : gs.Tensor
955950
Tensor of shape (n_elements, 4) mapping each element to its vertex indices.
956951
"""
957-
958952
el2v = gs.zeros((self.n_elements, 4), dtype=int, requires_grad=False, scene=self.scene)
959-
self._solver._kernel_get_el2v(
960-
element_el_start=self._el_start,
961-
n_elements=self.n_elements,
962-
el2v=el2v,
963-
)
964-
953+
self._solver._kernel_get_el2v(element_el_start=self._el_start, n_elements=self.n_elements, el2v=el2v)
965954
return el2v
966955

967956
@ti.kernel

0 commit comments

Comments
 (0)