Skip to content

Commit 72946ea

Browse files
committed
fix tests
1 parent 9c6cb00 commit 72946ea

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

stochman/curves.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ def __init__(
1313
end: torch.Tensor,
1414
num_nodes: int = 5,
1515
requires_grad: bool = True,
16-
device=None,
1716
*args,
1817
**kwargs,
1918
) -> None:
@@ -201,7 +200,7 @@ def forward(self, t: torch.Tensor) -> torch.Tensor:
201200
(torch.floor(tt * num_edges).clamp(min=0, max=num_edges - 1).long()) # Bx|t|
202201
.unsqueeze(2)
203202
.repeat(1, 1, D)
204-
) # Bx|t|xD, this assumes that nodes are equi-distant
203+
).to(self.device) # Bx|t|xD, this assumes that nodes are equi-distant
205204
result = torch.gather(a, 1, idx) * tt.unsqueeze(2) + torch.gather(b, 1, idx) # Bx|t|xD
206205
if B == 1:
207206
result = result.squeeze(0) # |t|xD
@@ -221,7 +220,7 @@ def __getitem__(self, indices: int) -> "DiscreteCurve":
221220
return C
222221

223222
def __setitem__(self, indices, curves) -> None:
224-
self.params[indices] = curves.params.squeeze()
223+
self.params[indices].data = curves.params.squeeze()
225224

226225
def constant_speed(
227226
self, metric=None, t: Optional[torch.Tensor] = None
@@ -399,7 +398,7 @@ def __getitem__(self, indices: int) -> "CubicSpline":
399398
return C
400399

401400
def __setitem__(self, indices, curves) -> None:
402-
self.params[indices] = curves.params
401+
self.params[indices].data = curves.params
403402

404403
def deriv(self, t: Optional[torch.Tensor] = None) -> torch.Tensor:
405404
"""

tests/test_curves.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
class TestCurves:
99
@pytest.mark.parametrize("requires_grad", [True, False])
1010
@pytest.mark.parametrize("batch_dim", [1, 5])
11-
@pytest.mark.parametrize("device", ["cpu", "cuda"])
11+
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
1212
def test_curve_evaluation(self, curve_class, requires_grad, batch_dim, device):
1313
if not torch.cuda.is_available() and device == "cuda":
1414
pytest.skip("test requires cuda")
@@ -27,7 +27,7 @@ def test_curve_evaluation(self, curve_class, requires_grad, batch_dim, device):
2727
assert c.params.device == torch.device(device)
2828

2929
eval_nodes = 10
30-
t = torch.linspace(0, 1, eval_nodes)
30+
t = torch.linspace(0, 1, eval_nodes).to(device)
3131
out = c(t)
3232
assert isinstance(out, torch.Tensor)
3333
assert out.device == torch.device(device)

0 commit comments

Comments
 (0)