11import pytest
22import torch
33
4- from torchts .utils .data import lagmat
4+ from torchts .utils .data import sliding_window
55
66
77@pytest .fixture
@@ -13,7 +13,7 @@ def tensor():
1313@pytest .mark .parametrize ("lag" , [2 , 5 , [1 , 2 , 3 ], {1 , 2 , 3 }, [1 , 3 , 5 ]])
1414@pytest .mark .parametrize ("horizon" , [1 , 2 ])
1515def test_shape (tensor , lag , horizon ):
16- x , y = lagmat (tensor , lag , horizon = horizon )
16+ x , y = sliding_window (tensor , lag , horizon = horizon )
1717
1818 if isinstance (lag , int ):
1919 rows = len (tensor ) - lag - horizon + 1
@@ -33,7 +33,7 @@ def test_shape(tensor, lag, horizon):
3333@pytest .mark .parametrize ("lag" , [2 , 5 , [1 , 2 , 3 ], {1 , 2 , 3 }, [1 , 3 , 5 ]])
3434@pytest .mark .parametrize ("horizon" , [1 , 2 ])
3535def test_value (tensor , lag , horizon ):
36- x , y = lagmat (tensor , lag , horizon = horizon )
36+ x , y = sliding_window (tensor , lag , horizon = horizon )
3737
3838 if isinstance (lag , int ):
3939 for i in range (x .shape [0 ]):
@@ -49,10 +49,10 @@ def test_value(tensor, lag, horizon):
4949@pytest .mark .parametrize ("lag" , ["1" , 1.0 , ["1" ], [1 , "2" , 3 ], {1 , 2.0 , 3 }])
5050def test_non_int (tensor , lag ):
5151 with pytest .raises (TypeError ):
52- lagmat (tensor , lag )
52+ sliding_window (tensor , lag )
5353
5454
5555@pytest .mark .parametrize ("lag" , [- 1 , 0 , [0 , 1 , 2 ], {0 , 1 , 2 }, [- 1 , 1 , 2 ]])
5656def test_non_positive (tensor , lag ):
5757 with pytest .raises (ValueError ):
58- lagmat (tensor , lag )
58+ sliding_window (tensor , lag )
0 commit comments