11from typing import List , Any
22
3- from talipp .indicator_util import has_valid_values
43from talipp .indicators .Indicator import Indicator , InputModifierType
4+ from talipp .indicators .TrueRange import TrueRange
55from talipp .input import SamplingPeriodType
66from talipp .ohlcv import OHLCV
7+ from talipp .ma import MAType , MAFactory
78
89
910class ATR (Indicator ):
@@ -19,40 +20,24 @@ class ATR(Indicator):
1920 input_indicator: Input indicator.
2021 input_modifier: Input modifier.
2122 input_sampling: Input sampling type.
22- """
23-
23+ """
2424 def __init__ (self , period : int ,
2525 input_values : List [OHLCV ] = None ,
2626 input_indicator : Indicator = None ,
2727 input_modifier : InputModifierType = None ,
28+ ma_type : MAType = MAType .WilderMA ,
2829 input_sampling : SamplingPeriodType = None ):
2930 super (ATR , self ).__init__ (input_modifier = input_modifier ,
3031 input_sampling = input_sampling )
3132
3233 self .period = period
33- self .tr = []
3434
35- self .add_managed_sequence (self .tr )
35+ self ._tr = TrueRange ()
36+ self .add_sub_indicator (self ._tr )
37+
38+ self ._ma_tr = MAFactory .get_ma (ma_type , period , input_indicator = self ._tr )
3639
3740 self .initialize (input_values , input_indicator )
3841
3942 def _calculate_new_value (self ) -> Any :
40- high = self .input_values [- 1 ].high
41- low = self .input_values [- 1 ].low
42-
43- if has_valid_values (self .input_values , 1 , exact = True ):
44- self .tr .append (high - low )
45- else :
46- close2 = self .input_values [- 2 ].close
47- self .tr .append (max (
48- high - low ,
49- abs (high - close2 ),
50- abs (low - close2 ),
51- ))
52-
53- if len (self .input_values ) < self .period :
54- return None
55- elif len (self .input_values ) == self .period :
56- return sum (self .tr ) / self .period
57- else :
58- return (self .output_values [- 1 ] * (self .period - 1 ) + self .tr [- 1 ]) / self .period
43+ return self ._ma_tr .output_values [- 1 ]
0 commit comments