Skip to content

Commit 58e4461

Browse files
authored
sort with skiplist (#67)
1 parent 997c4b7 commit 58e4461

File tree

13 files changed

+766
-100
lines changed

13 files changed

+766
-100
lines changed

KunQuant/Driver.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def compileit(f: Function, module_name: str, partition_factor = 3, dtype = "floa
106106
blocking_len = suggested_len[dtype]
107107
if element_size[dtype] * blocking_len not in simd_len:
108108
raise RuntimeError(f"Blocking length {blocking_len} is not supported for {dtype} on {_cpu_arch}")
109+
options['blocking_len'] = blocking_len
109110
if output_layout not in ["STs", "TS", "STREAM"]:
110111
raise RuntimeError("Bad output_layout name " + output_layout)
111112
if input_layout not in ["STs", "TS", "STREAM"]:

KunQuant/Op.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ def traverse_replace_map(op: 'OpBase', replace_map: Dict['OpBase', 'OpBase']) ->
7474
return traverse_replace_map(found, replace_map)
7575

7676
class AcceptSingleValueInputTrait(ABC):
77+
'''
78+
The ops that accept a time_len=1 array as input
79+
'''
7780
@abstractmethod
7881
def get_single_value_input_id() -> int:
7982
pass
@@ -479,6 +482,18 @@ class GloablStatefulOpTrait(StatefulOpTrait):
479482
'''
480483
pass
481484

485+
class GlobalStatefulProducerTrait(GloablStatefulOpTrait):
486+
'''
487+
The ops that have an internal state, and the state is carried between different time steps, and the state must be consumed by a StateConsumerTrait
488+
'''
489+
pass
490+
491+
class StateConsumerTrait:
492+
'''
493+
The ops that consume a state from a GlobalStatefulProducerTrait
494+
'''
495+
pass
496+
482497
class ReductionOp(OpBase, StatefulOpTrait):
483498
'''
484499
Base class of all reduction ops. A reduction op takes inputs that is originated from a IterValue. The input must be in a loop (v.get_parent() is a loop). The data produced

KunQuant/ops/CompOp.py

Lines changed: 96 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
from .ReduceOp import ReduceAdd, ReduceMul, ReduceArgMax, ReduceRank, ReduceMin, ReduceMax, ReduceDecayLinear, ReduceArgMin
22
from KunQuant.Op import ConstantOp, OpBase, CompositiveOp, WindowedTrait, ForeachBackWindow, WindowedTempOutput, Builder, IterValue, WindowLoopIndex
33
from .ElewiseOp import And, DivConst, GreaterThan, LessThan, Or, Select, SetInfOrNanToValue, Sub, Mul, Sqrt, SubConst, Div, CmpOp, Exp, Log, Min, Max, Equals, Abs
4-
from .MiscOp import Accumulator, BackRef, SetAccumulator, WindowedLinearRegression, WindowedLinearRegressionResiImpl, WindowedLinearRegressionRSqaureImpl, WindowedLinearRegressionSlopeImpl, ReturnFirstValue
4+
from .MiscOp import Accumulator, BackRef, SetAccumulator, WindowedLinearRegression, WindowedLinearRegressionResiImpl,\
5+
WindowedLinearRegressionRSqaureImpl, WindowedLinearRegressionSlopeImpl, ReturnFirstValue, SkipListState, SkipListQuantile, SkipListRank, SkipListMin, SkipListMax,\
6+
SkipListArgMin
57
from collections import OrderedDict
68
from typing import Union, List, Tuple, Dict
79
import math
810

911
def _is_fast_stat(opt: dict, attrs: dict) -> bool:
1012
return not opt.get("no_fast_stat", True) and not attrs.get("no_fast_stat", False)
1113

14+
def _decide_use_skip_list(window: int, blocking_len: int) -> bool:
15+
naive_cost = window
16+
skip_list_cost = math.log2(window) * blocking_len * 5
17+
return skip_list_cost < naive_cost
18+
1219
class WindowedCompositiveOp(CompositiveOp, WindowedTrait):
1320
def __init__(self, v: OpBase, window: int, v2 = None) -> None:
1421
inputs = [v]
@@ -17,7 +24,7 @@ def __init__(self, v: OpBase, window: int, v2 = None) -> None:
1724
super().__init__(inputs, [("window", window)])
1825

1926
class WindowedReduce(WindowedCompositiveOp):
20-
def make_reduce(self, v: OpBase) -> OpBase:
27+
def make_reduce(self, v: OpBase, newest: OpBase) -> OpBase:
2128
raise RuntimeError("Not implemented")
2229

2330
def decompose(self, options: dict) -> List[OpBase]:
@@ -26,7 +33,7 @@ def decompose(self, options: dict) -> List[OpBase]:
2633
v0 = WindowedTempOutput(self.inputs[0], self.attrs["window"])
2734
v1 = ForeachBackWindow(v0, self.attrs["window"])
2835
itr = IterValue(v1, v0)
29-
v2 = self.make_reduce(itr)
36+
v2 = self.make_reduce(itr, self.inputs[0])
3037
return b.ops
3138

3239
class WindowedSum(WindowedReduce):
@@ -35,7 +42,7 @@ class WindowedSum(WindowedReduce):
3542
For indices < window-1, the output will be NaN
3643
similar to pandas.DataFrame.rolling(n).sum()
3744
'''
38-
def make_reduce(self, v: OpBase) -> OpBase:
45+
def make_reduce(self, v: OpBase, newest: OpBase) -> OpBase:
3946
return ReduceAdd(v)
4047

4148
class WindowedProduct(WindowedReduce):
@@ -44,26 +51,55 @@ class WindowedProduct(WindowedReduce):
4451
For indices < window-1, the output will be NaN
4552
similar to pandas.DataFrame.rolling(n).product()
4653
'''
47-
def make_reduce(self, v: OpBase) -> OpBase:
54+
def make_reduce(self, v: OpBase, newest: OpBase) -> OpBase:
4855
return ReduceMul(v)
4956

50-
class WindowedMin(WindowedReduce):
57+
class _WindowedMinMaxBase(WindowedReduce):
58+
'''
59+
Base class for windowed min/max ops that can use skip list. If the window is small enough, use naive linear scan. Otherwise, use skip list
60+
with Log(window) complexity.
61+
'''
62+
def on_skip_list(self, skplist: SkipListState, cur: OpBase) -> OpBase:
63+
raise RuntimeError("Not implemented")
64+
65+
def decompose(self, options: dict) -> List[OpBase]:
66+
window = self.attrs["window"]
67+
blocking_len = options["blocking_len"]
68+
if _decide_use_skip_list(window, blocking_len):
69+
b = Builder(self.get_parent())
70+
with b:
71+
newv = self.inputs[0]
72+
oldv = BackRef(newv, window)
73+
v2 = SkipListState(oldv, newv, window)
74+
self.on_skip_list(v2, newv)
75+
return b.ops
76+
else:
77+
return super().decompose(options)
78+
79+
class WindowedMin(_WindowedMinMaxBase):
5180
'''
5281
Min of a rolling look back window, including the current newest data.
5382
For indices < window-1, the output will be NaN
5483
similar to pandas.DataFrame.rolling(n).min()
5584
'''
56-
def make_reduce(self, v: OpBase) -> OpBase:
85+
def make_reduce(self, v: OpBase, newest: OpBase) -> OpBase:
5786
return ReduceMin(v)
5887

59-
class WindowedMax(WindowedReduce):
88+
def on_skip_list(self, skplist: SkipListState, cur: OpBase) -> OpBase:
89+
return SkipListMin(skplist)
90+
91+
92+
class WindowedMax(_WindowedMinMaxBase):
6093
'''
6194
Max of a rolling look back window, including the current newest data.
6295
For indices < window-1, the output will be NaN
6396
similar to pandas.DataFrame.rolling(n).max()
6497
'''
65-
def make_reduce(self, v: OpBase) -> OpBase:
98+
def make_reduce(self, v: OpBase, newest: OpBase) -> OpBase:
6699
return ReduceMax(v)
100+
101+
def on_skip_list(self, skplist: SkipListState, cur: OpBase) -> OpBase:
102+
return SkipListMax(skplist)
67103

68104
# small - sum
69105
def _kahan_sub(mask: OpBase, sum: OpBase, small: OpBase, compensation: OpBase) -> Union[OpBase, OpBase]:
@@ -424,51 +460,54 @@ def decompose(self, options: dict) -> List[OpBase]:
424460
return b.ops
425461

426462

427-
class TsArgMax(WindowedCompositiveOp):
463+
class TsArgMax(WindowedReduce):
428464
'''
429465
ArgMax in a rolling look back window, including the current newest data.
430466
The result should be the index of the max element in the rolling window. The index of the oldest element of the rolling window is 1.
431467
Similar to df.rolling(window).apply(np.argmax) + 1
432468
'''
433469
def decompose(self, options: dict) -> List[OpBase]:
434-
b = Builder(self.get_parent())
435-
with b:
436-
v0 = WindowedTempOutput(self.inputs[0], self.attrs["window"])
437-
v1 = ForeachBackWindow(v0, self.attrs["window"])
438-
v2 = ReduceArgMax(IterValue(v1, v0))
439-
v3 = SubConst(v2, self.attrs["window"], True)
440-
return b.ops
470+
window = self.attrs["window"]
471+
blocking_len = options["blocking_len"]
472+
if _decide_use_skip_list(window, blocking_len):
473+
b = Builder(self.get_parent())
474+
with b:
475+
TsArgMin(0-self.inputs[0], window)
476+
return b.ops
477+
else:
478+
return super().decompose(options)
479+
480+
def make_reduce(self, v: OpBase, newest: OpBase) -> OpBase:
481+
v2 = ReduceArgMax(v)
482+
return self.attrs["window"] - v2
441483

442-
class TsArgMin(WindowedCompositiveOp):
484+
class TsArgMin(_WindowedMinMaxBase):
443485
'''
444486
ArgMin in a rolling look back window, including the current newest data.
445487
The result should be the index of the min element in the rolling window. The index of the oldest element of the rolling window is 1.
446488
Similar to df.rolling(window).apply(np.argmin) + 1
447489
'''
448-
def decompose(self, options: dict) -> List[OpBase]:
449-
b = Builder(self.get_parent())
450-
with b:
451-
v0 = WindowedTempOutput(self.inputs[0], self.attrs["window"])
452-
v1 = ForeachBackWindow(v0, self.attrs["window"])
453-
v2 = ReduceArgMin(IterValue(v1, v0))
454-
v3 = SubConst(v2, self.attrs["window"], True)
455-
return b.ops
490+
def on_skip_list(self, skplist: SkipListState, cur: OpBase) -> OpBase:
491+
return SkipListArgMin([skplist], [])
492+
493+
def make_reduce(self, v: OpBase, newest: OpBase) -> OpBase:
494+
v2 = ReduceArgMin(v)
495+
return self.attrs["window"] - v2
496+
456497

457-
class TsRank(WindowedCompositiveOp):
498+
class TsRank(_WindowedMinMaxBase):
458499
'''
459500
Time series rank of the newest data in a rolling look back window, including the current newest data.
460501
Let num_values_less = the number of values in rolling window that is less than the current newest data.
461502
Let num_values_eq = the number of values in rolling window that is equal to the current newest data.
462503
rank = num_values_less + (num_values_eq + 1) / 2
463504
Similar to df.rolling(window).rank()
464505
'''
465-
def decompose(self, options: dict) -> List[OpBase]:
466-
b = Builder(self.get_parent())
467-
with b:
468-
v0 = WindowedTempOutput(self.inputs[0], self.attrs["window"])
469-
v1 = ForeachBackWindow(v0, self.attrs["window"])
470-
v2 = ReduceRank(IterValue(v1, v0), self.inputs[0])
471-
return b.ops
506+
def on_skip_list(self, skplist: SkipListState, cur: OpBase) -> OpBase:
507+
return SkipListRank(skplist, cur)
508+
509+
def make_reduce(self, v: OpBase, newest: OpBase) -> OpBase:
510+
return ReduceRank(v, newest)
472511

473512
class Clip(CompositiveOp):
474513
'''
@@ -636,4 +675,27 @@ def decompose(self, options: dict) -> List[OpBase]:
636675
filtered = Select(index >= max_bar_index, IterValue(each, v), inf)
637676
trough = ReduceMin(filtered)
638677
(peak - trough) / peak
678+
return b.ops
679+
680+
681+
class WindowedQuantile(CompositiveOp, WindowedTrait):
682+
'''
683+
Quantile in `window` rows ago.
684+
Similar to pd.rolling(window).quantile(q, interpolation='linear')
685+
'''
686+
def __init__(self, v: OpBase, window: int, q: float) -> None:
687+
super().__init__([v], [("window", window), ("q", q)])
688+
689+
def required_input_window(self) -> int:
690+
return self.attrs["window"] + 1
691+
692+
def decompose(self, options: dict) -> List[OpBase]:
693+
b = Builder(self.get_parent())
694+
window = self.attrs["window"]
695+
v = self.inputs[0]
696+
q = self.attrs["q"]
697+
with b:
698+
old = BackRef(v, window)
699+
v2 = SkipListState(old, v, window)
700+
v3 = SkipListQuantile(v2, q)
639701
return b.ops

KunQuant/ops/MiscOp.py

Lines changed: 51 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import KunQuant
2-
from KunQuant.Op import AcceptSingleValueInputTrait, Input, OpBase, WindowedTrait, SinkOpTrait, CrossSectionalOp, GloablStatefulOpTrait, UnaryElementwiseOp, BinaryElementwiseOp
2+
from KunQuant.Op import AcceptSingleValueInputTrait, Input, OpBase, WindowedTrait, SinkOpTrait, CrossSectionalOp, GlobalStatefulProducerTrait, GloablStatefulOpTrait, StateConsumerTrait, UnaryElementwiseOp, BinaryElementwiseOp
33
from typing import List, Union
44

55
class BackRef(OpBase, WindowedTrait):
@@ -12,16 +12,6 @@ def __init__(self, v: OpBase, window: int) -> None:
1212
def required_input_window(self) -> int:
1313
return self.attrs["window"] + 1
1414

15-
class WindowedQuantile(OpBase, WindowedTrait):
16-
'''
17-
Quantile in `window` rows ago.
18-
Similar to pd.rolling(window).quantile(q, interpolation='linear')
19-
'''
20-
def __init__(self, v: OpBase, window: int, q: float) -> None:
21-
super().__init__([v], [("window", window), ("q", q)])
22-
23-
def required_input_window(self) -> int:
24-
return self.attrs["window"] + 1
2515

2616
class FastWindowedSum(OpBase, WindowedTrait, GloablStatefulOpTrait):
2717
'''
@@ -39,7 +29,7 @@ def get_state_variable_name_prefix(self) -> str:
3929
def generate_step_code(self, idx: str, time_idx: str, inputs: List[str], buf_name: str) -> str:
4030
return f"auto v{idx} = sum_{idx}.step({buf_name}, {inputs[0]}, {time_idx});"
4131

42-
class Accumulator(OpBase, GloablStatefulOpTrait):
32+
class Accumulator(OpBase, GlobalStatefulProducerTrait):
4333
'''
4434
Accumulator is a stateful op that accumulates the input value over time.
4535
It can be used to compute running totals, moving averages, etc.'''
@@ -64,7 +54,7 @@ def verify(self, func) -> None:
6454
raise RuntimeError(f"Accumulator {self.attrs['name']} is not used with any SetAccumulator")
6555
return super().verify(func)
6656

67-
class SetAccumulator(OpBase):
57+
class SetAccumulator(OpBase, StateConsumerTrait):
6858
'''
6959
Set the value of an Accumulator to a value, if mask is set. Otherwise, it does nothing.
7060
'''
@@ -120,7 +110,7 @@ def generate_init_code(self, idx: str, elem_type: str, simd_lanes: int, inputs:
120110
def generate_step_code(self, idx: str, time_idx: str, inputs: List[str]) -> str:
121111
return f"auto v{idx} = ema_{idx}.step({inputs[0]}, {time_idx});"
122112

123-
class WindowedLinearRegression(OpBase, WindowedTrait, GloablStatefulOpTrait):
113+
class WindowedLinearRegression(OpBase, WindowedTrait, GlobalStatefulProducerTrait):
124114
'''
125115
Compute states of Windowed Linear Regression
126116
'''
@@ -135,37 +125,74 @@ def get_state_variable_name_prefix(self) -> str:
135125

136126
def generate_step_code(self, idx: str, time_idx: str, inputs: List[str], buf_name: str) -> str:
137127
return f"const auto& v{idx} = linear_{idx}.step({buf_name}, {inputs[0]}, {time_idx});"
138-
139-
class WindowedLinearRegressionImplBase(OpBase):
140-
def __init__(self, v: OpBase) -> None:
141-
super().__init__([v])
142-
128+
129+
130+
class WindowedLinearRegressionConsumerTrait(StateConsumerTrait):
143131
def verify(self, func: 'KunQuant.Stage.Function') -> None:
144132
if len(self.inputs) < 1 or not isinstance(self.inputs[0], WindowedLinearRegression):
145133
raise RuntimeError("WindowedLinearRegressionImpl expects WindowedLinearRegression Op as input")
146-
return super().verify(func)
134+
return OpBase.verify(self, func)
135+
136+
class WindowedLinearRegressionImplUnaryBase(WindowedLinearRegressionConsumerTrait, UnaryElementwiseOp):
137+
pass
147138

148-
class WindowedLinearRegressionConsumerTrait:
139+
class WindowedLinearRegressionImplBinaryBase(WindowedLinearRegressionConsumerTrait, BinaryElementwiseOp):
149140
pass
150141

151-
class WindowedLinearRegressionRSqaureImpl(UnaryElementwiseOp, WindowedLinearRegressionConsumerTrait):
142+
class WindowedLinearRegressionRSqaureImpl(WindowedLinearRegressionImplUnaryBase):
152143
'''
153144
Compute RSqaure of Windowed Linear Regression
154145
'''
155146
pass
156147

157-
class WindowedLinearRegressionSlopeImpl(UnaryElementwiseOp, WindowedLinearRegressionConsumerTrait):
148+
class WindowedLinearRegressionSlopeImpl(WindowedLinearRegressionImplUnaryBase):
158149
'''
159150
Compute RSqaure of Windowed Linear Regression
160151
'''
161152
pass
162153

163-
class WindowedLinearRegressionResiImpl(BinaryElementwiseOp, WindowedLinearRegressionConsumerTrait):
154+
class WindowedLinearRegressionResiImpl(WindowedLinearRegressionImplBinaryBase):
164155
'''
165156
Compute RSqaure of Windowed Linear Regression
166157
'''
167158
pass
168159

160+
class SkipListState(OpBase, GlobalStatefulProducerTrait):
161+
'''
162+
SkipListState is a stateful op that maintains a skip list of the input values.
163+
'''
164+
def __init__(self, oldvalue: OpBase, value: OpBase, window: int) -> None:
165+
super().__init__([oldvalue, value], [("window", window)])
166+
167+
def get_state_variable_name_prefix(self) -> str:
168+
return "skip_list_"
169+
170+
def generate_step_code(self, idx: str, time_idx: str, inputs: List[str]) -> str:
171+
return f"auto& v{idx} = skip_list_{idx}.step({inputs[0]}, {inputs[1]}, {time_idx});"
172+
173+
class SkipListConsumerOp(StateConsumerTrait):
174+
def verify(self, func: 'KunQuant.Stage.Function') -> None:
175+
if len(self.inputs) < 1 or not isinstance(self.inputs[0], SkipListState):
176+
raise RuntimeError("SkipListConsumerOp expects SkipListState Op as input")
177+
return super().verify(func)
178+
179+
class SkipListQuantile(SkipListConsumerOp, OpBase):
180+
def __init__(self, v: OpBase, q: float) -> None:
181+
super().__init__([v], [("q", q)])
182+
183+
class SkipListRank(SkipListConsumerOp, BinaryElementwiseOp):
184+
pass
185+
186+
class SkipListMin(SkipListConsumerOp, UnaryElementwiseOp):
187+
pass
188+
189+
class SkipListMax(SkipListConsumerOp, UnaryElementwiseOp):
190+
pass
191+
192+
class SkipListArgMin(SkipListConsumerOp, OpBase):
193+
pass
194+
195+
169196
class GenericCrossSectionalOp(CrossSectionalOp):
170197
'''
171198
Cross sectional op with customized C++ implementation.

0 commit comments

Comments
 (0)