@@ -13,7 +13,6 @@ def __init__(
13
13
end : torch .Tensor ,
14
14
num_nodes : int = 5 ,
15
15
requires_grad : bool = True ,
16
- device = None ,
17
16
* args ,
18
17
** kwargs ,
19
18
) -> None :
@@ -201,7 +200,7 @@ def forward(self, t: torch.Tensor) -> torch.Tensor:
201
200
(torch .floor (tt * num_edges ).clamp (min = 0 , max = num_edges - 1 ).long ()) # Bx|t|
202
201
.unsqueeze (2 )
203
202
.repeat (1 , 1 , D )
204
- ) # Bx|t|xD, this assumes that nodes are equi-distant
203
+ ). to ( self . device ) # Bx|t|xD, this assumes that nodes are equi-distant
205
204
result = torch .gather (a , 1 , idx ) * tt .unsqueeze (2 ) + torch .gather (b , 1 , idx ) # Bx|t|xD
206
205
if B == 1 :
207
206
result = result .squeeze (0 ) # |t|xD
@@ -221,7 +220,7 @@ def __getitem__(self, indices: int) -> "DiscreteCurve":
221
220
return C
222
221
223
222
def __setitem__ (self , indices , curves ) -> None :
224
- self .params [indices ] = curves .params .squeeze ()
223
+ self .params [indices ]. data = curves .params .squeeze ()
225
224
226
225
def constant_speed (
227
226
self , metric = None , t : Optional [torch .Tensor ] = None
@@ -399,7 +398,7 @@ def __getitem__(self, indices: int) -> "CubicSpline":
399
398
return C
400
399
401
400
def __setitem__ (self , indices , curves ) -> None :
402
- self .params [indices ] = curves .params
401
+ self .params [indices ]. data = curves .params
403
402
404
403
def deriv (self , t : Optional [torch .Tensor ] = None ) -> torch .Tensor :
405
404
"""
0 commit comments