11from .ReduceOp import ReduceAdd , ReduceMul , ReduceArgMax , ReduceRank , ReduceMin , ReduceMax , ReduceDecayLinear , ReduceArgMin
22from KunQuant .Op import ConstantOp , OpBase , CompositiveOp , WindowedTrait , ForeachBackWindow , WindowedTempOutput , Builder , IterValue , WindowLoopIndex
33from .ElewiseOp import And , DivConst , GreaterThan , LessThan , Or , Select , SetInfOrNanToValue , Sub , Mul , Sqrt , SubConst , Div , CmpOp , Exp , Log , Min , Max , Equals , Abs
4- from .MiscOp import Accumulator , BackRef , SetAccumulator , WindowedLinearRegression , WindowedLinearRegressionResiImpl , WindowedLinearRegressionRSqaureImpl , WindowedLinearRegressionSlopeImpl , ReturnFirstValue
4+ from .MiscOp import Accumulator , BackRef , SetAccumulator , WindowedLinearRegression , WindowedLinearRegressionResiImpl ,\
5+ WindowedLinearRegressionRSqaureImpl , WindowedLinearRegressionSlopeImpl , ReturnFirstValue , SkipListState , SkipListQuantile , SkipListRank , SkipListMin , SkipListMax ,\
6+ SkipListArgMin
57from collections import OrderedDict
68from typing import Union , List , Tuple , Dict
79import math
810
911def _is_fast_stat (opt : dict , attrs : dict ) -> bool :
1012 return not opt .get ("no_fast_stat" , True ) and not attrs .get ("no_fast_stat" , False )
1113
14+ def _decide_use_skip_list (window : int , blocking_len : int ) -> bool :
15+ naive_cost = window
16+ skip_list_cost = math .log2 (window ) * blocking_len * 5
17+ return skip_list_cost < naive_cost
18+
1219class WindowedCompositiveOp (CompositiveOp , WindowedTrait ):
1320 def __init__ (self , v : OpBase , window : int , v2 = None ) -> None :
1421 inputs = [v ]
@@ -17,7 +24,7 @@ def __init__(self, v: OpBase, window: int, v2 = None) -> None:
1724 super ().__init__ (inputs , [("window" , window )])
1825
1926class WindowedReduce (WindowedCompositiveOp ):
20- def make_reduce (self , v : OpBase ) -> OpBase :
27+ def make_reduce (self , v : OpBase , newest : OpBase ) -> OpBase :
2128 raise RuntimeError ("Not implemented" )
2229
2330 def decompose (self , options : dict ) -> List [OpBase ]:
@@ -26,7 +33,7 @@ def decompose(self, options: dict) -> List[OpBase]:
2633 v0 = WindowedTempOutput (self .inputs [0 ], self .attrs ["window" ])
2734 v1 = ForeachBackWindow (v0 , self .attrs ["window" ])
2835 itr = IterValue (v1 , v0 )
29- v2 = self .make_reduce (itr )
36+ v2 = self .make_reduce (itr , self . inputs [ 0 ] )
3037 return b .ops
3138
3239class WindowedSum (WindowedReduce ):
@@ -35,7 +42,7 @@ class WindowedSum(WindowedReduce):
3542 For indices < window-1, the output will be NaN
3643 similar to pandas.DataFrame.rolling(n).sum()
3744 '''
38- def make_reduce (self , v : OpBase ) -> OpBase :
45+ def make_reduce (self , v : OpBase , newest : OpBase ) -> OpBase :
3946 return ReduceAdd (v )
4047
4148class WindowedProduct (WindowedReduce ):
@@ -44,26 +51,55 @@ class WindowedProduct(WindowedReduce):
4451 For indices < window-1, the output will be NaN
4552 similar to pandas.DataFrame.rolling(n).product()
4653 '''
47- def make_reduce (self , v : OpBase ) -> OpBase :
54+ def make_reduce (self , v : OpBase , newest : OpBase ) -> OpBase :
4855 return ReduceMul (v )
4956
50- class WindowedMin (WindowedReduce ):
57+ class _WindowedMinMaxBase (WindowedReduce ):
58+ '''
59+ Base class for windowed min/max ops that can use skip list. If the window is small enough, use naive linear scan. Otherwise, use skip list
60+ with Log(window) complexity.
61+ '''
62+ def on_skip_list (self , skplist : SkipListState , cur : OpBase ) -> OpBase :
63+ raise RuntimeError ("Not implemented" )
64+
65+ def decompose (self , options : dict ) -> List [OpBase ]:
66+ window = self .attrs ["window" ]
67+ blocking_len = options ["blocking_len" ]
68+ if _decide_use_skip_list (window , blocking_len ):
69+ b = Builder (self .get_parent ())
70+ with b :
71+ newv = self .inputs [0 ]
72+ oldv = BackRef (newv , window )
73+ v2 = SkipListState (oldv , newv , window )
74+ self .on_skip_list (v2 , newv )
75+ return b .ops
76+ else :
77+ return super ().decompose (options )
78+
79+ class WindowedMin (_WindowedMinMaxBase ):
5180 '''
5281 Min of a rolling look back window, including the current newest data.
5382 For indices < window-1, the output will be NaN
5483 similar to pandas.DataFrame.rolling(n).min()
5584 '''
56- def make_reduce (self , v : OpBase ) -> OpBase :
85+ def make_reduce (self , v : OpBase , newest : OpBase ) -> OpBase :
5786 return ReduceMin (v )
5887
59- class WindowedMax (WindowedReduce ):
88+ def on_skip_list (self , skplist : SkipListState , cur : OpBase ) -> OpBase :
89+ return SkipListMin (skplist )
90+
91+
92+ class WindowedMax (_WindowedMinMaxBase ):
6093 '''
6194 Max of a rolling look back window, including the current newest data.
6295 For indices < window-1, the output will be NaN
6396 similar to pandas.DataFrame.rolling(n).max()
6497 '''
65- def make_reduce (self , v : OpBase ) -> OpBase :
98+ def make_reduce (self , v : OpBase , newest : OpBase ) -> OpBase :
6699 return ReduceMax (v )
100+
101+ def on_skip_list (self , skplist : SkipListState , cur : OpBase ) -> OpBase :
102+ return SkipListMax (skplist )
67103
68104# small - sum
69105def _kahan_sub (mask : OpBase , sum : OpBase , small : OpBase , compensation : OpBase ) -> Union [OpBase , OpBase ]:
@@ -424,51 +460,54 @@ def decompose(self, options: dict) -> List[OpBase]:
424460 return b .ops
425461
426462
427- class TsArgMax (WindowedCompositiveOp ):
463+ class TsArgMax (WindowedReduce ):
428464 '''
429465 ArgMax in a rolling look back window, including the current newest data.
430466 The result should be the index of the max element in the rolling window. The index of the oldest element of the rolling window is 1.
431467 Similar to df.rolling(window).apply(np.argmax) + 1
432468 '''
433469 def decompose (self , options : dict ) -> List [OpBase ]:
434- b = Builder (self .get_parent ())
435- with b :
436- v0 = WindowedTempOutput (self .inputs [0 ], self .attrs ["window" ])
437- v1 = ForeachBackWindow (v0 , self .attrs ["window" ])
438- v2 = ReduceArgMax (IterValue (v1 , v0 ))
439- v3 = SubConst (v2 , self .attrs ["window" ], True )
440- return b .ops
470+ window = self .attrs ["window" ]
471+ blocking_len = options ["blocking_len" ]
472+ if _decide_use_skip_list (window , blocking_len ):
473+ b = Builder (self .get_parent ())
474+ with b :
475+ TsArgMin (0 - self .inputs [0 ], window )
476+ return b .ops
477+ else :
478+ return super ().decompose (options )
479+
480+ def make_reduce (self , v : OpBase , newest : OpBase ) -> OpBase :
481+ v2 = ReduceArgMax (v )
482+ return self .attrs ["window" ] - v2
441483
442- class TsArgMin (WindowedCompositiveOp ):
484+ class TsArgMin (_WindowedMinMaxBase ):
443485 '''
444486 ArgMin in a rolling look back window, including the current newest data.
445487 The result should be the index of the min element in the rolling window. The index of the oldest element of the rolling window is 1.
446488 Similar to df.rolling(window).apply(np.argmin) + 1
447489 '''
448- def decompose (self , options : dict ) -> List [OpBase ]:
449- b = Builder (self .get_parent ())
450- with b :
451- v0 = WindowedTempOutput (self .inputs [0 ], self .attrs ["window" ])
452- v1 = ForeachBackWindow (v0 , self .attrs ["window" ])
453- v2 = ReduceArgMin (IterValue (v1 , v0 ))
454- v3 = SubConst (v2 , self .attrs ["window" ], True )
455- return b .ops
490+ def on_skip_list (self , skplist : SkipListState , cur : OpBase ) -> OpBase :
491+ return SkipListArgMin ([skplist ], [])
492+
493+ def make_reduce (self , v : OpBase , newest : OpBase ) -> OpBase :
494+ v2 = ReduceArgMin (v )
495+ return self .attrs ["window" ] - v2
496+
456497
457- class TsRank (WindowedCompositiveOp ):
498+ class TsRank (_WindowedMinMaxBase ):
458499 '''
459500 Time series rank of the newest data in a rolling look back window, including the current newest data.
460501 Let num_values_less = the number of values in rolling window that is less than the current newest data.
461502 Let num_values_eq = the number of values in rolling window that is equal to the current newest data.
462503 rank = num_values_less + (num_values_eq + 1) / 2
463504 Similar to df.rolling(window).rank()
464505 '''
465- def decompose (self , options : dict ) -> List [OpBase ]:
466- b = Builder (self .get_parent ())
467- with b :
468- v0 = WindowedTempOutput (self .inputs [0 ], self .attrs ["window" ])
469- v1 = ForeachBackWindow (v0 , self .attrs ["window" ])
470- v2 = ReduceRank (IterValue (v1 , v0 ), self .inputs [0 ])
471- return b .ops
506+ def on_skip_list (self , skplist : SkipListState , cur : OpBase ) -> OpBase :
507+ return SkipListRank (skplist , cur )
508+
509+ def make_reduce (self , v : OpBase , newest : OpBase ) -> OpBase :
510+ return ReduceRank (v , newest )
472511
473512class Clip (CompositiveOp ):
474513 '''
@@ -636,4 +675,27 @@ def decompose(self, options: dict) -> List[OpBase]:
636675 filtered = Select (index >= max_bar_index , IterValue (each , v ), inf )
637676 trough = ReduceMin (filtered )
638677 (peak - trough ) / peak
678+ return b .ops
679+
680+
681+ class WindowedQuantile (CompositiveOp , WindowedTrait ):
682+ '''
683+ Quantile in `window` rows ago.
684+ Similar to pd.rolling(window).quantile(q, interpolation='linear')
685+ '''
686+ def __init__ (self , v : OpBase , window : int , q : float ) -> None :
687+ super ().__init__ ([v ], [("window" , window ), ("q" , q )])
688+
689+ def required_input_window (self ) -> int :
690+ return self .attrs ["window" ] + 1
691+
692+ def decompose (self , options : dict ) -> List [OpBase ]:
693+ b = Builder (self .get_parent ())
694+ window = self .attrs ["window" ]
695+ v = self .inputs [0 ]
696+ q = self .attrs ["q" ]
697+ with b :
698+ old = BackRef (v , window )
699+ v2 = SkipListState (old , v , window )
700+ v3 = SkipListQuantile (v2 , q )
639701 return b .ops
0 commit comments