Skip to content

Commit 06a3106

Browse files
committed
ATR=TrueRange with WilderMA
1 parent ed5e0e1 commit 06a3106

File tree

8 files changed

+147
-34
lines changed

8 files changed

+147
-34
lines changed

talipp/indicators/ATR.py

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from typing import List, Any
22

3-
from talipp.indicator_util import has_valid_values
43
from talipp.indicators.Indicator import Indicator, InputModifierType
4+
from talipp.indicators.TrueRange import TrueRange
55
from talipp.input import SamplingPeriodType
66
from talipp.ohlcv import OHLCV
7+
from talipp.ma import MAType, MAFactory
78

89

910
class ATR(Indicator):
@@ -19,40 +20,24 @@ class ATR(Indicator):
1920
input_indicator: Input indicator.
2021
input_modifier: Input modifier.
2122
input_sampling: Input sampling type.
22-
"""
23-
23+
"""
2424
def __init__(self, period: int,
2525
input_values: List[OHLCV] = None,
2626
input_indicator: Indicator = None,
2727
input_modifier: InputModifierType = None,
28+
ma_type: MAType = MAType.WilderMA,
2829
input_sampling: SamplingPeriodType = None):
2930
super(ATR, self).__init__(input_modifier=input_modifier,
3031
input_sampling=input_sampling)
3132

3233
self.period = period
33-
self.tr = []
3434

35-
self.add_managed_sequence(self.tr)
35+
self._tr = TrueRange()
36+
self.add_sub_indicator(self._tr)
37+
38+
self._ma_tr = MAFactory.get_ma(ma_type, period, input_indicator=self._tr)
3639

3740
self.initialize(input_values, input_indicator)
3841

3942
def _calculate_new_value(self) -> Any:
40-
high = self.input_values[-1].high
41-
low = self.input_values[-1].low
42-
43-
if has_valid_values(self.input_values, 1, exact=True):
44-
self.tr.append(high - low)
45-
else:
46-
close2 = self.input_values[-2].close
47-
self.tr.append(max(
48-
high - low,
49-
abs(high - close2),
50-
abs(low - close2),
51-
))
52-
53-
if len(self.input_values) < self.period:
54-
return None
55-
elif len(self.input_values) == self.period:
56-
return sum(self.tr) / self.period
57-
else:
58-
return (self.output_values[-1] * (self.period - 1) + self.tr[-1]) / self.period
43+
return self._ma_tr.output_values[-1]

talipp/indicators/CHOP.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import List, Any
33

44
from talipp.indicator_util import has_valid_values
5-
from talipp.indicators.ATR import ATR
5+
from talipp.indicators.TrueRange import TrueRange
66
from talipp.indicators.Indicator import Indicator, InputModifierType
77
from talipp.input import SamplingPeriodType
88
from talipp.ohlcv import OHLCV
@@ -33,19 +33,19 @@ def __init__(self, period: int,
3333

3434
self.period = period
3535

36-
self.atr = ATR(1)
37-
self.add_sub_indicator(self.atr)
36+
self.tr = TrueRange()
37+
self.add_sub_indicator(self.tr)
3838

3939
self.initialize(input_values, input_indicator)
4040

4141
def _calculate_new_value(self) -> Any:
42-
if not has_valid_values(self.atr, self.period) or not has_valid_values(self.input_values, self.period):
42+
if not has_valid_values(self.tr, self.period) or not has_valid_values(self.input_values, self.period):
4343
return None
4444

4545
max_high = max(self.input_values[-self.period:], key = lambda x: x.high).high
4646
min_low = min(self.input_values[-self.period:], key = lambda x: x.low).low
4747

4848
if max_high != min_low:
49-
return 100.0 * log10(sum(self.atr[-self.period:]) / (max_high - min_low) ) / log10(self.period)
49+
return 100.0 * log10(sum(self.tr[-self.period:]) / (max_high - min_low) ) / log10(self.period)
5050
else:
5151
return None

talipp/indicators/TrueRange.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from typing import List, Any
2+
3+
from talipp.indicator_util import has_valid_values
4+
from talipp.indicators.Indicator import Indicator, InputModifierType
5+
from talipp.input import SamplingPeriodType
6+
from talipp.ohlcv import OHLCV
7+
8+
9+
class TrueRange(Indicator):
10+
"""True Range
11+
12+
Input type: [OHLCV][talipp.ohlcv.OHLCV]
13+
14+
Output type: `float`
15+
16+
Args:
17+
input_values: List of input values.
18+
input_indicator: Input indicator.
19+
input_modifier: Input modifier.
20+
input_sampling: Input sampling type.
21+
"""
22+
23+
def __init__(self,
24+
input_values: List[OHLCV] = None,
25+
input_indicator: Indicator = None,
26+
input_modifier: InputModifierType = None,
27+
input_sampling: SamplingPeriodType = None):
28+
super(TrueRange, self).__init__(input_modifier=input_modifier,
29+
input_sampling=input_sampling)
30+
31+
self.initialize(input_values, input_indicator)
32+
33+
def _calculate_new_value(self) -> Any:
34+
high = self.input_values[-1].high
35+
low = self.input_values[-1].low
36+
37+
if has_valid_values(self.input_values, 1, exact=True):
38+
return high - low
39+
else:
40+
close2 = self.input_values[-2].close
41+
return max(
42+
high - low,
43+
abs(high - close2),
44+
abs(low - close2),
45+
)

talipp/indicators/VTX.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import List, Any
33

44
from talipp.indicator_util import has_valid_values
5-
from talipp.indicators.ATR import ATR
5+
from talipp.indicators.TrueRange import TrueRange
66
from talipp.indicators.Indicator import Indicator, InputModifierType
77
from talipp.input import SamplingPeriodType
88
from talipp.ohlcv import OHLCV
@@ -53,8 +53,8 @@ def __init__(self, period: int,
5353
self.minus_vm = []
5454
self.add_managed_sequence(self.minus_vm)
5555

56-
self.atr = ATR(1)
57-
self.add_sub_indicator(self.atr)
56+
self.tr = TrueRange()
57+
self.add_sub_indicator(self.tr)
5858

5959
self.initialize(input_values, input_indicator)
6060

@@ -68,9 +68,9 @@ def _calculate_new_value(self) -> Any:
6868
self.plus_vm.append(abs(value.high - value2.low))
6969
self.minus_vm.append(abs(value.low - value2.high))
7070

71-
if not has_valid_values(self.atr, self.period) or not has_valid_values(self.plus_vm, self.period) or \
71+
if not has_valid_values(self.tr, self.period) or not has_valid_values(self.plus_vm, self.period) or \
7272
not has_valid_values(self.minus_vm, self.period):
7373
return None
7474

75-
atr_sum = float(sum(self.atr[-self.period:]))
75+
atr_sum = float(sum(self.tr[-self.period:]))
7676
return VTXVal(sum(self.plus_vm[-self.period:]) / atr_sum, sum(self.minus_vm[-self.period:]) / atr_sum)

talipp/indicators/WilderMA.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from typing import List, Any
2+
3+
from talipp.indicator_util import has_valid_values
4+
from talipp.indicators.Indicator import Indicator, InputModifierType
5+
from talipp.input import SamplingPeriodType
6+
7+
8+
class WilderMA(Indicator):
9+
"""Wilder's Moving Average.
10+
11+
Input type: `float`
12+
13+
Output type: `float`
14+
15+
Args:
16+
period: Period.
17+
input_values: List of input values.
18+
input_indicator: Input indicator.
19+
input_modifier: Input modifier.
20+
input_sampling: Input sampling type.
21+
"""
22+
23+
def __init__(self, period: int,
24+
input_values: List[float] = None,
25+
input_indicator: Indicator = None,
26+
input_modifier: InputModifierType = None,
27+
input_sampling: SamplingPeriodType = None):
28+
super().__init__(input_modifier=input_modifier,
29+
input_sampling=input_sampling)
30+
31+
self.period = period
32+
self.k = 1.0 / self.period
33+
34+
self.initialize(input_values, input_indicator)
35+
36+
def _calculate_new_value(self) -> Any:
37+
if len(self.input_values) < self.period:
38+
return None
39+
elif has_valid_values(self.input_values, self.period, exact=True):
40+
return sum(self.input_values[-self.period:]) / self.period
41+
else:
42+
return float(self.k * self.input_values[-1] + (1.0 - self.k) * self.output_values[-1])

talipp/indicators/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,14 @@
4747
from .T3 import T3 as T3
4848
from .TEMA import TEMA as TEMA
4949
from .TRIX import TRIX as TRIX
50+
from .TrueRange import TrueRange as TrueRange
5051
from .TSI import TSI as TSI
5152
from .TTM import TTM as TTM
5253
from .UO import UO as UO
5354
from .VTX import VTX as VTX
5455
from .VWAP import VWAP as VWAP
5556
from .VWMA import VWMA as VWMA
57+
from .WilderMA import WilderMA as WilderMA
5658
from .WMA import WMA as WMA
5759
from .ZigZag import ZigZag as ZigZag
5860
from .ZLEMA import ZLEMA as ZLEMA
@@ -113,6 +115,7 @@
113115
"VTX",
114116
"VWAP",
115117
"VWMA",
118+
"WilderMA",
116119
"WMA",
117120
"ZigZag",
118121
"ZLEMA"

talipp/ma.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from talipp.indicators.TEMA import TEMA
1515
from talipp.indicators.TRIX import TRIX
1616
from talipp.indicators.VWMA import VWMA
17+
from talipp.indicators.WilderMA import WilderMA
1718
from talipp.indicators.WMA import WMA
1819
from talipp.indicators.ZLEMA import ZLEMA
1920

@@ -54,6 +55,9 @@ class MAType(Enum):
5455
VWMA = auto()
5556
"""[Volume Weighted Moving Average][talipp.indicators.VWMA]"""
5657

58+
WilderMA = auto()
59+
"""[Wilder's Moving Average][talipp.indicators.WMA]"""
60+
5761
WMA = auto()
5862
"""[Weighted Moving Average][talipp.indicators.WMA]"""
5963

@@ -101,6 +105,8 @@ def get_ma(ma_type: MAType,
101105
return HMA(period=period, input_values=input_values, input_indicator=input_indicator, input_modifier=input_modifier)
102106
elif ma_type == MAType.VWMA:
103107
return VWMA(period=period, input_values=input_values, input_indicator=input_indicator, input_modifier=input_modifier)
108+
elif ma_type == MAType.WilderMA:
109+
return WilderMA(period=period, input_values=input_values, input_indicator=input_indicator, input_modifier=input_modifier)
104110
elif ma_type == MAType.WMA:
105111
return WMA(period=period, input_values=input_values, input_indicator=input_indicator, input_modifier=input_modifier)
106112
elif ma_type == MAType.T3:

test/test_WilderMA.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import unittest
2+
3+
from talipp.indicators import WilderMA
4+
5+
from TalippTest import TalippTest
6+
7+
8+
class Test(TalippTest):
9+
def setUp(self) -> None:
10+
self.input_values = list(TalippTest.CLOSE_TMPL)
11+
12+
def test_init(self):
13+
ind = WilderMA(5, self.input_values)
14+
15+
print(ind)
16+
17+
self.assertAlmostEqual(ind[-3], 9.699400, places = 5)
18+
self.assertAlmostEqual(ind[-2], 9.805521, places = 5)
19+
self.assertAlmostEqual(ind[-1], 9.844417, places = 5)
20+
21+
def test_update(self):
22+
self.assertIndicatorUpdate(WilderMA(5, self.input_values))
23+
24+
def test_delete(self):
25+
self.assertIndicatorDelete(WilderMA(5, self.input_values))
26+
27+
def test_purge_oldest(self):
28+
self.assertIndicatorPurgeOldest(WilderMA(5, self.input_values))
29+
30+
31+
if __name__ == '__main__':
32+
unittest.main()

0 commit comments

Comments
 (0)