@@ -76,7 +76,6 @@ def __init__(self, interval_depth=6, quantile_divisor=4):
7676
7777 def _fit (self , X , y = None ):
7878 import torch
79- import torch .nn .functional as F
8079
8180 X = torch .tensor (X ).float ()
8281
@@ -85,17 +84,19 @@ def _fit(self, X, y=None):
8584 if self .interval_depth < 1 :
8685 raise ValueError ("interval_depth must be >= 1" )
8786
87+ in_length = X .shape [- 1 ]
88+
8889 representation_functions = (
89- lambda X : X ,
90- lambda X : F .avg_pool1d (F .pad (X .diff (), (2 , 2 ), "replicate" ), 5 , 1 ),
91- lambda X : X .diff (n = 2 ),
92- lambda X : torch .fft .rfft (X ).abs (),
90+ in_length , # lambda X: X
91+ in_length
92+ - 1 , # lambda X: F.avg_pool1d(F.pad(X.diff(), (2, 2), "replicate"), 5, 1)
93+ in_length - 2 , # lambda X: X.diff(n=2)
94+ in_length // 2 + 1 , # lambda X: torch.fft.rfft(X).abs()
9395 )
94-
9596 self .intervals_ = []
96- for function in representation_functions :
97- Z = function ( X )
98- self .intervals_ .append (self ._make_intervals (input_length = Z . shape [ - 1 ] ))
97+
98+ for length in representation_functions :
99+ self .intervals_ .append (self ._make_intervals (input_length = length ))
99100
100101 return self
101102
0 commit comments