Skip to content

Commit cd25491

Browse files
authored
Rename lagmat to sliding_window (#115)
1 parent 28cd562 commit cd25491

File tree

2 files changed

+25
-25
lines changed

2 files changed

+25
-25
lines changed

tests/test_lagmat.py renamed to tests/test_sliding_window.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22
import torch
33

4-
from torchts.utils.data import lagmat
4+
from torchts.utils.data import sliding_window
55

66

77
@pytest.fixture
@@ -13,7 +13,7 @@ def tensor():
1313
@pytest.mark.parametrize("lag", [2, 5, [1, 2, 3], {1, 2, 3}, [1, 3, 5]])
1414
@pytest.mark.parametrize("horizon", [1, 2])
1515
def test_shape(tensor, lag, horizon):
16-
x, y = lagmat(tensor, lag, horizon=horizon)
16+
x, y = sliding_window(tensor, lag, horizon=horizon)
1717

1818
if isinstance(lag, int):
1919
rows = len(tensor) - lag - horizon + 1
@@ -33,7 +33,7 @@ def test_shape(tensor, lag, horizon):
3333
@pytest.mark.parametrize("lag", [2, 5, [1, 2, 3], {1, 2, 3}, [1, 3, 5]])
3434
@pytest.mark.parametrize("horizon", [1, 2])
3535
def test_value(tensor, lag, horizon):
36-
x, y = lagmat(tensor, lag, horizon=horizon)
36+
x, y = sliding_window(tensor, lag, horizon=horizon)
3737

3838
if isinstance(lag, int):
3939
for i in range(x.shape[0]):
@@ -49,10 +49,10 @@ def test_value(tensor, lag, horizon):
4949
@pytest.mark.parametrize("lag", ["1", 1.0, ["1"], [1, "2", 3], {1, 2.0, 3}])
5050
def test_non_int(tensor, lag):
5151
with pytest.raises(TypeError):
52-
lagmat(tensor, lag)
52+
sliding_window(tensor, lag)
5353

5454

5555
@pytest.mark.parametrize("lag", [-1, 0, [0, 1, 2], {0, 1, 2}, [-1, 1, 2]])
5656
def test_non_positive(tensor, lag):
5757
with pytest.raises(ValueError):
58-
lagmat(tensor, lag)
58+
sliding_window(tensor, lag)

torchts/utils/data.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -37,26 +37,6 @@ def concat(a, b):
3737
return torch.cat([a, b.unsqueeze(0)], dim=0)
3838

3939

40-
def lagmat(tensor, lags, horizon=1, dim=0, step=1):
41-
is_int = isinstance(lags, int)
42-
is_iter = isinstance(lags, Iterable) and all(isinstance(lag, int) for lag in lags)
43-
44-
if not is_int and not is_iter:
45-
raise TypeError("lags must be of type int or Iterable[int]")
46-
47-
if (is_int and lags < 1) or (is_iter and any(lag < 1 for lag in lags)):
48-
raise ValueError(f"lags must be positive but found {lags}")
49-
50-
if is_int:
51-
data = tensor.unfold(dim, lags + horizon, step)
52-
x, y = data[:, :lags], data[:, -1]
53-
else:
54-
data = tensor.unfold(dim, max(lags) + horizon, step)
55-
x, y = data[:, [lag - 1 for lag in lags]], data[:, -1]
56-
57-
return x, y
58-
59-
6040
def load_dataset(dataset_dir, batch_size, val_batch_size=None, test_batch_size=None):
6141
if val_batch_size is None:
6242
val_batch_size = batch_size
@@ -108,3 +88,23 @@ def load_pickle(pickle_file):
10888
raise e
10989

11090
return pickle_data
91+
92+
93+
def sliding_window(tensor, lags, horizon=1, dim=0, step=1):
94+
is_int = isinstance(lags, int)
95+
is_iter = isinstance(lags, Iterable) and all(isinstance(lag, int) for lag in lags)
96+
97+
if not is_int and not is_iter:
98+
raise TypeError("lags must be of type int or Iterable[int]")
99+
100+
if (is_int and lags < 1) or (is_iter and any(lag < 1 for lag in lags)):
101+
raise ValueError(f"lags must be positive but found {lags}")
102+
103+
if is_int:
104+
data = tensor.unfold(dim, lags + horizon, step)
105+
x, y = data[:, :lags], data[:, -1]
106+
else:
107+
data = tensor.unfold(dim, max(lags) + horizon, step)
108+
x, y = data[:, [lag - 1 for lag in lags]], data[:, -1]
109+
110+
return x, y

0 commit comments

Comments
 (0)