Skip to content

Commit 724cdb0

Browse files
authored
Use shape calculation in _fit to optimize QUANTTransformer (#2727)
1 parent 69f769e commit 724cdb0

File tree

1 file changed

+10
-9
lines changed
  • aeon/transformations/collection/interval_based

1 file changed

+10
-9
lines changed

aeon/transformations/collection/interval_based/_quant.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ def __init__(self, interval_depth=6, quantile_divisor=4):
7676

7777
def _fit(self, X, y=None):
7878
import torch
79-
import torch.nn.functional as F
8079

8180
X = torch.tensor(X).float()
8281

@@ -85,17 +84,19 @@ def _fit(self, X, y=None):
8584
if self.interval_depth < 1:
8685
raise ValueError("interval_depth must be >= 1")
8786

87+
in_length = X.shape[-1]
88+
8889
representation_functions = (
89-
lambda X: X,
90-
lambda X: F.avg_pool1d(F.pad(X.diff(), (2, 2), "replicate"), 5, 1),
91-
lambda X: X.diff(n=2),
92-
lambda X: torch.fft.rfft(X).abs(),
90+
in_length, # lambda X: X
91+
in_length
92+
- 1, # lambda X: F.avg_pool1d(F.pad(X.diff(), (2, 2), "replicate"), 5, 1)
93+
in_length - 2, # lambda X: X.diff(n=2)
94+
in_length // 2 + 1, # lambda X: torch.fft.rfft(X).abs()
9395
)
94-
9596
self.intervals_ = []
96-
for function in representation_functions:
97-
Z = function(X)
98-
self.intervals_.append(self._make_intervals(input_length=Z.shape[-1]))
97+
98+
for length in representation_functions:
99+
self.intervals_.append(self._make_intervals(input_length=length))
99100

100101
return self
101102

0 commit comments

Comments
 (0)