Skip to content

Commit 279a67c

Browse files
committed
Update constant_speed test to include the new output argument
1 parent aac69ed commit 279a67c

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

tests/test_curves.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,13 @@ def test_constant_speed(self, curve_class):
107107
begin = torch.randn(batch_size, dim)
108108
end = torch.randn(batch_size, dim)
109109
c = curve_class(begin, end, 20)
110-
new_t, Ct = c.constant_speed(t=torch.linspace(0, 1, timesteps))
110+
new_t, Ct, curve_length = c.constant_speed(t=torch.linspace(0, 1, timesteps))
111111
assert isinstance(new_t, torch.Tensor)
112112
assert isinstance(Ct, torch.Tensor)
113+
assert isinstance(curve_length, torch.Tensor)
113114
assert new_t.shape == (batch_size, timesteps)
114115
assert Ct.shape == (batch_size, timesteps, dim)
116+
assert curve_length.shape == (batch_size, )
115117

116118
def test_plotting_in_axis(self, curve_class):
117119
batch_size = 5

0 commit comments

Comments
 (0)