@@ -34,7 +34,7 @@ def __init__(
3434 self ,
3535 meshes : Dict [str , Mesh ],
3636 batch_size : int = 1 ,
37- device : torch .device = torch . device ( "cuda" ) ,
37+ device : Union [ torch .device , str ] = "cuda" ,
3838 ) -> None :
3939 self ._names = []
4040 self ._vertices = []
@@ -47,7 +47,7 @@ def __init__(
4747 self ._upper_face_index_lookup = {}
4848
4949 # populate this container
50- self ._populate_container (meshes , device )
50+ self ._populate_container (meshes , device = device )
5151
5252 # add batch dim
5353 self ._batch_size = batch_size
@@ -56,13 +56,13 @@ def __init__(
5656 # populate index lookups
5757 self ._populate_index_lookups ()
5858
59- self ._device = device
59+ self ._device = torch . device ( device ) if isinstance ( device , str ) else device
6060
6161 @abc .abstractmethod
6262 def _populate_container (
6363 self ,
6464 meshes : Dict [str , Mesh ],
65- device : torch .device ( "cuda" ) ,
65+ device : Union [ torch .device , str ] = "cuda" ,
6666 ) -> None :
6767 offset = 0
6868 for name , mesh in meshes .items ():
@@ -161,10 +161,10 @@ def device(self) -> torch.device:
161161 def batch_size (self ) -> int :
162162 return self ._batch_size
163163
164- def to (self , device : torch .device ) -> None :
164+ def to (self , device : Union [ torch .device , str ] ) -> None :
165165 self ._vertices = self ._vertices .to (device = device )
166166 self ._faces = self ._faces .to (device = device )
167- self ._device = device
167+ self ._device = torch . device ( device ) if isinstance ( device , str ) else device
168168
169169
170170class Camera :
@@ -184,7 +184,7 @@ def __init__(
184184 resolution : Tuple [int , int ],
185185 intrinsics : Optional [Union [torch .FloatTensor , np .ndarray ]] = None ,
186186 extrinsics : Optional [Union [torch .FloatTensor , np .ndarray ]] = None ,
187- device : torch .device = torch . device ( "cuda" ) ,
187+ device : Union [ torch .device , str ] = "cuda" ,
188188 name : str = "camera" ,
189189 ) -> None :
190190 if intrinsics is None :
@@ -211,7 +211,7 @@ def __init__(
211211 self ._intrinsics = intrinsics
212212 self ._extrinsics = extrinsics
213213 self ._resolution = resolution
214- self ._device = device
214+ self ._device = torch . device ( device ) if isinstance ( device , str ) else device
215215 ht_optical_shape = (
216216 (1 ,) + extrinsics .shape [- 2 :]
217217 if extrinsics .dim () == 3
@@ -229,11 +229,11 @@ def __init__(
229229 self ._name = name
230230
231231 @abc .abstractmethod
232- def to (self , device : torch .device ) -> None :
232+ def to (self , device : Union [ torch .device , str ] ) -> None :
233233 self ._intrinsics = self ._intrinsics .to (device = device )
234234 self ._extrinsics = self ._extrinsics .to (device = device )
235235 self ._ht_optical = self ._ht_optical .to (device = device )
236- self ._device = device
236+ self ._device = torch . device ( device ) if isinstance ( device , str ) else device
237237
238238 @property
239239 def intrinsics (self ) -> torch .FloatTensor :
@@ -312,7 +312,7 @@ def __init__(
312312 extrinsics : Optional [Union [torch .FloatTensor , np .ndarray ]] = None ,
313313 zmin : float = 0.1 ,
314314 zmax : float = 100.0 ,
315- device : torch .device = torch . device ( "cuda" ) ,
315+ device : Union [ torch .device , str ] = "cuda" ,
316316 ) -> None :
317317 super ().__init__ (resolution , intrinsics , extrinsics , device )
318318
@@ -347,7 +347,7 @@ def __init__(
347347 self ._perspective_projection [..., 2 , 3 ] = 2.0 * zmax * zmin / (zmin - zmax )
348348 self ._perspective_projection [..., 3 , 2 ] = 1.0
349349
350- def to (self , device : torch .device ) -> None :
350+ def to (self , device : Union [ torch .device , str ] ) -> None :
351351 self ._perspective_projection = self ._perspective_projection .to (device = device )
352352 super ().to (device = device )
353353
0 commit comments