Skip to content

Commit b0cede5

Browse files
committed
updated device types
1 parent e6db943 commit b0cede5

File tree

8 files changed

+48
-38
lines changed

8 files changed

+48
-38
lines changed

cli/rr_cam_swarm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import argparse
22
import os
3+
from typing import Union
34

45
import cv2
56
import numpy as np
@@ -179,7 +180,7 @@ def instantiate_particles(
179180
eye_min_dist: float,
180181
eye_max_dist: float,
181182
angle_interval: float,
182-
device: torch.device = torch.device("cuda"),
183+
device: Union[torch.device, str] = "cuda",
183184
) -> torch.Tensor:
184185
r"""Instantiate the particles for the optimization randomly under field of view constraints.
185186
Particles (camera poses) are represented using eye space coordinates (eye, center, angle).
@@ -193,7 +194,7 @@ def instantiate_particles(
193194
eye_min_dist (float): The minimum distance of the eye from the origin.
194195
eye_max_dist (float): The maximum distance of the eye from the origin.
195196
angle_interval (float): The angle interval in which to sample the rotation angle.
196-
device (torch.device): The device to instantiate the particles on.
197+
device (Union[torch.device, str]): The device to instantiate the particles on.
197198
198199
Returns:
199200
torch.Tensor: The particles of shape (n_particles, 7).

roboreg/differentiable/kinematics.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict
1+
from typing import Dict, Union
22

33
import pytorch_kinematics as pk
44
import torch
@@ -17,7 +17,7 @@ def __init__(
1717
urdf: str,
1818
root_link_name: str,
1919
end_link_name: str,
20-
device: torch.device = torch.device("cuda"),
20+
device: Union[torch.device, str] = "cuda",
2121
) -> None:
2222
self._root_link_name = root_link_name
2323
self._end_link_name = end_link_name
@@ -26,7 +26,7 @@ def __init__(
2626
root_link_name=self._root_link_name,
2727
end_link_name=self._end_link_name,
2828
)
29-
self._device = device
29+
self._device = torch.device(device) if isinstance(device, str) else device
3030
self.to(device=self._device)
3131

3232
def _build_serial_chain_from_urdf(
@@ -36,9 +36,9 @@ def _build_serial_chain_from_urdf(
3636
urdf, end_link_name=end_link_name, root_link_name=root_link_name
3737
)
3838

39-
def to(self, device: torch.device) -> None:
39+
def to(self, device: Union[torch.device, str]) -> None:
4040
self._chain.to(device=device)
41-
self._device = device
41+
self._device = torch.device(device) if isinstance(device, str) else device
4242

4343
def forward_kinematics(self, q: torch.Tensor) -> Dict[str, torch.Tensor]:
4444
ht_lookup = {

roboreg/differentiable/rendering.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Tuple
1+
from typing import List, Tuple, Union
22

33
import nvdiffrast.torch as dr
44
import torch
@@ -13,11 +13,11 @@ class NVDiffRastRenderer:
1313

1414
__slots__ = ["_device", "_ctx"]
1515

16-
def __init__(self, device: torch.device = torch.device("cuda")) -> None:
16+
def __init__(self, device: Union[torch.device, str] = "cuda") -> None:
1717
super().__init__()
1818
if not torch.cuda.is_available():
1919
raise ValueError("CUDA is not available.")
20-
self._device = device
20+
self._device = torch.device(device) if isinstance(device, str) else device
2121
self._ctx = dr.RasterizeCudaContext(device=self._device)
2222

2323
def scale_clip_vertices(

roboreg/differentiable/robot.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Union
2+
13
import torch
24

35
from roboreg.io import URDFParser
@@ -13,12 +15,12 @@ def __init__(
1315
self,
1416
mesh_container: TorchMeshContainer,
1517
kinematics: TorchKinematics,
16-
device: torch.device = torch.device("cuda"),
18+
device: Union[torch.device, str] = "cuda",
1719
) -> None:
1820
self._mesh_container = mesh_container
1921
self._kinematics = kinematics
2022
self._configured_vertices = self.mesh_container.vertices.clone()
21-
self._device = device
23+
self._device = torch.device(device) if isinstance(device, str) else device
2224
self.to(device=self._device)
2325

2426
@classmethod
@@ -29,7 +31,7 @@ def from_urdf_parser(
2931
end_link_name: str,
3032
collision: bool = False,
3133
batch_size: int = 1,
32-
device: torch.device = torch.device("cuda"),
34+
device: Union[torch.device, str] = "cuda",
3335
target_reduction: float = 0.0,
3436
) -> "Robot":
3537
from roboreg.io import apply_mesh_origins, load_meshes, simplify_meshes
@@ -104,11 +106,11 @@ def configure(
104106
ht_root.transpose(-1, -2),
105107
)
106108

107-
def to(self, device: torch.device) -> None:
109+
def to(self, device: Union[torch.device, str]) -> None:
108110
self._mesh_container.to(device=device)
109111
self._kinematics.to(device=device)
110112
self._configured_vertices = self._configured_vertices.to(device=device)
111-
self._device = device
113+
self._device = torch.device(device) if isinstance(device, str) else device
112114

113115
@property
114116
def kinematics(self) -> TorchKinematics:

roboreg/differentiable/structs.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __init__(
3434
self,
3535
meshes: Dict[str, Mesh],
3636
batch_size: int = 1,
37-
device: torch.device = torch.device("cuda"),
37+
device: Union[torch.device, str] = "cuda",
3838
) -> None:
3939
self._names = []
4040
self._vertices = []
@@ -47,7 +47,7 @@ def __init__(
4747
self._upper_face_index_lookup = {}
4848

4949
# populate this container
50-
self._populate_container(meshes, device)
50+
self._populate_container(meshes, device=device)
5151

5252
# add batch dim
5353
self._batch_size = batch_size
@@ -56,13 +56,13 @@ def __init__(
5656
# populate index lookups
5757
self._populate_index_lookups()
5858

59-
self._device = device
59+
self._device = torch.device(device) if isinstance(device, str) else device
6060

6161
@abc.abstractmethod
6262
def _populate_container(
6363
self,
6464
meshes: Dict[str, Mesh],
65-
device: torch.device("cuda"),
65+
device: Union[torch.device, str] = "cuda",
6666
) -> None:
6767
offset = 0
6868
for name, mesh in meshes.items():
@@ -161,10 +161,10 @@ def device(self) -> torch.device:
161161
def batch_size(self) -> int:
162162
return self._batch_size
163163

164-
def to(self, device: torch.device) -> None:
164+
def to(self, device: Union[torch.device, str]) -> None:
165165
self._vertices = self._vertices.to(device=device)
166166
self._faces = self._faces.to(device=device)
167-
self._device = device
167+
self._device = torch.device(device) if isinstance(device, str) else device
168168

169169

170170
class Camera:
@@ -184,7 +184,7 @@ def __init__(
184184
resolution: Tuple[int, int],
185185
intrinsics: Optional[Union[torch.FloatTensor, np.ndarray]] = None,
186186
extrinsics: Optional[Union[torch.FloatTensor, np.ndarray]] = None,
187-
device: torch.device = torch.device("cuda"),
187+
device: Union[torch.device, str] = "cuda",
188188
name: str = "camera",
189189
) -> None:
190190
if intrinsics is None:
@@ -211,7 +211,7 @@ def __init__(
211211
self._intrinsics = intrinsics
212212
self._extrinsics = extrinsics
213213
self._resolution = resolution
214-
self._device = device
214+
self._device = torch.device(device) if isinstance(device, str) else device
215215
ht_optical_shape = (
216216
(1,) + extrinsics.shape[-2:]
217217
if extrinsics.dim() == 3
@@ -229,11 +229,11 @@ def __init__(
229229
self._name = name
230230

231231
@abc.abstractmethod
232-
def to(self, device: torch.device) -> None:
232+
def to(self, device: Union[torch.device, str]) -> None:
233233
self._intrinsics = self._intrinsics.to(device=device)
234234
self._extrinsics = self._extrinsics.to(device=device)
235235
self._ht_optical = self._ht_optical.to(device=device)
236-
self._device = device
236+
self._device = torch.device(device) if isinstance(device, str) else device
237237

238238
@property
239239
def intrinsics(self) -> torch.FloatTensor:
@@ -312,7 +312,7 @@ def __init__(
312312
extrinsics: Optional[Union[torch.FloatTensor, np.ndarray]] = None,
313313
zmin: float = 0.1,
314314
zmax: float = 100.0,
315-
device: torch.device = torch.device("cuda"),
315+
device: Union[torch.device, str] = "cuda",
316316
) -> None:
317317
super().__init__(resolution, intrinsics, extrinsics, device)
318318

@@ -347,7 +347,7 @@ def __init__(
347347
self._perspective_projection[..., 2, 3] = 2.0 * zmax * zmin / (zmin - zmax)
348348
self._perspective_projection[..., 3, 2] = 1.0
349349

350-
def to(self, device: torch.device) -> None:
350+
def to(self, device: Union[torch.device, str]) -> None:
351351
self._perspective_projection = self._perspective_projection.to(device=device)
352352
super().to(device=device)
353353

roboreg/io/parsers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def __init__(self, urdf: str) -> None:
1818

1919
@classmethod
2020
def from_file(cls, path: Union[Path, str]) -> None:
21-
r"""Instantiate URDF parser from URDF string.
21+
r"""Instantiate URDF parser path to URDF file.
2222
2323
Args:
2424
path (Union[Path, str]): Path to URDF file.

roboreg/segmentor.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any
1+
from typing import Any, Union
22

33
import numpy as np
44
import torch
@@ -8,9 +8,11 @@
88
class Segmentor(object):
99
__slots__ = ["_model", "_pth", "_device"]
1010

11-
def __init__(self, pth: float = 0.5, device: str = "cuda") -> None:
11+
def __init__(
12+
self, pth: float = 0.5, device: Union[torch.device, str] = "cuda"
13+
) -> None:
1214
self._pth = pth
13-
self._device = device
15+
self._device = torch.device(device) if isinstance(device, str) else device
1416

1517
@property
1618
def pth(self) -> float:
@@ -28,7 +30,7 @@ def __init__(
2830
self,
2931
model_id: str = "facebook/sam2-hiera-large",
3032
pth: float = 0.5,
31-
device: str = "cuda",
33+
device: Union[torch.device, str] = "cuda",
3234
) -> None:
3335
super().__init__(pth=pth, device=device)
3436
self._model: SAM2ImagePredictor = SAM2ImagePredictor.from_pretrained(model_id)
@@ -37,7 +39,10 @@ def __call__(
3739
self, img: np.ndarray, input_points: np.ndarray, input_labels: np.ndarray
3840
) -> np.ndarray:
3941
self._model.set_image(img)
40-
with torch.inference_mode(), torch.autocast(self._device, dtype=torch.bfloat16):
42+
with (
43+
torch.inference_mode(),
44+
torch.autocast(device_type=self._device.type, dtype=torch.bfloat16),
45+
):
4146
mask_logits, _, _ = self._model.predict(
4247
point_coords=input_points,
4348
point_labels=input_labels,

test/differentiable/test_structs.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,14 @@ def test_torch_mesh_container() -> None:
2121
assert container.vertices.size()[1] == sum(
2222
list(container.per_mesh_vertex_count.values())
2323
), "Expected vertex count to match."
24-
assert container.device == device, f"Expected container on '{device}' device."
25-
assert (
26-
container.vertices[list(paths.keys())[0]].device == device
24+
assert container.device == torch.device(
25+
device
26+
), f"Expected container on '{device}' device."
27+
assert container.vertices.device == torch.device(
28+
device
2729
), f"Expected vertices on '{device}' device."
28-
assert (
29-
container.faces[list(paths.keys())[0]].device == device
30+
assert container.faces.device == torch.device(
31+
device
3032
), f"Expected faces on '{device}' device."
3133

3234

0 commit comments

Comments
 (0)