11from typing import Union
2- from pandas import DataFrame as PdDataFrame
3- from polars import DataFrame as PlDataFrame
4-
5- from pyindicators .exceptions import PyIndicatorException
2+ import pandas as pd
3+ import polars as pl
64
75
86def rsi (
9- data : Union [PdDataFrame , PlDataFrame ],
7+ data : Union [pd . DataFrame , pl . DataFrame ],
108 source_column : str ,
119 period : int ,
1210 result_column : str = None ,
13- ) -> Union [PdDataFrame , PlDataFrame ]:
11+ ) -> Union [pd . DataFrame , pl . DataFrame ]:
1412 """
15- Function to calculate the RSI of a series.
13+ Function to calculate the RSI (Relative Strength Index) of a series.
1614
1715 Args:
18- data (Union[PdDataFrame, PlDataFrame ]): The input data.
16+ data (Union[pd.DataFrame, pl.DataFrame ]): The input data.
1917 source_column (str): The name of the series.
20- period (int): The period for the exponential moving average .
21- result_column (str, optional): The name of the column to store the
22- exponential moving average. Defaults to None.
18+ period (int): The period for the RSI calculation .
19+ result_column (str, optional): The name of the column to store the RSI values.
20+ Defaults to None, which means it will be named "RSI_{period}" .
2321
2422 Returns:
25- Union[PdDataFrame, PlDataFrame]: Returns a DataFrame with
26- the RSI of the series.
23+ Union[pd.DataFrame, pl.DataFrame]: The DataFrame with the RSI column added.
2724 """
2825
2926 if result_column is None :
3027 result_column = f"RSI_{ period } "
3128
32- if source_column not in data .columns :
33- raise PyIndicatorException (
34- f"The column { source_column } does not exist in the DataFrame."
35- )
36-
37- if isinstance (data , PdDataFrame ):
29+ if isinstance (data , pd .DataFrame ):
3830 # Compute price changes
3931 delta = data [source_column ].diff ()
4032
@@ -43,14 +35,20 @@ def rsi(
4335 loss = - delta .where (delta < 0 , 0 )
4436
4537 # Compute the rolling average of gains and losses
46- avg_gain = gain .rolling (window = period , min_periods = 1 ).mean ()
47- avg_loss = loss .rolling (window = period , min_periods = 1 ).mean ()
38+ avg_gain = gain .rolling (window = period , min_periods = period ).mean ()
39+ avg_loss = loss .rolling (window = period , min_periods = period ).mean ()
4840
4941 # Compute RSI
5042 rs = avg_gain / avg_loss
51- data [ result_column ] = 100 - (100 / (1 + rs ))
43+ rsi_values = 100 - (100 / (1 + rs ))
5244
53- elif isinstance (data , PlDataFrame ):
45+ # Ensure first `period` rows are NaN
46+ rsi_values [:period ] = pd .NA
47+
48+ # Assign to DataFrame
49+ data [result_column ] = rsi_values
50+
51+ elif isinstance (data , pl .DataFrame ):
5452 # Compute price changes
5553 delta = data [source_column ].diff ().fill_null (0 )
5654
@@ -59,13 +57,91 @@ def rsi(
5957 loss = (- delta ).clip_min (0 )
6058
6159 # Compute rolling averages of gains and losses
62- avg_gain = gain .rolling_mean (window_size = period )
63- avg_loss = loss .rolling_mean (window_size = period )
60+ avg_gain = gain .rolling_mean (window_size = period , min_periods = period )
61+ avg_loss = loss .rolling_mean (window_size = period , min_periods = period )
6462
6563 # Compute RSI
6664 rs = avg_gain / avg_loss
6765 rsi_values = 100 - (100 / (1 + rs ))
6866
67+ # Replace first `period` values with nulls (polars uses `None`)
68+ rsi_values = rsi_values .set_at_idx (list (range (period )), None )
69+
70+ # Add column to DataFrame
71+ data = data .with_columns (rsi_values .alias (result_column ))
72+
73+ else :
74+ raise TypeError ("Input data must be a pandas or polars DataFrame." )
75+
76+ return data
77+
78+
79+ def wilders_rsi (
80+ data : Union [pd .DataFrame , pl .DataFrame ],
81+ source_column : str ,
82+ period : int ,
83+ result_column : str = None ,
84+ ) -> Union [pd .DataFrame , pl .DataFrame ]:
85+ """
86+ Compute RSI using wilders method (Wilder’s Smoothing).
87+
88+ Args:
89+ data (Union[pd.DataFrame, pl.DataFrame]): Input DataFrame.
90+ source_column (str): Name of the column with price data.
91+ period (int): RSI period (e.g., 14).
92+ result_column (str, optional): Name for the output column.
93+
94+ Returns:
95+ Union[pd.DataFrame, pl.DataFrame]: DataFrame with RSI values.
96+ """
97+
98+ if result_column is None :
99+ result_column = f"RSI_{ period } "
100+
101+ if isinstance (data , pd .DataFrame ):
102+ delta = data [source_column ].diff ()
103+
104+ gain = delta .where (delta > 0 , 0 )
105+ loss = - delta .where (delta < 0 , 0 )
106+
107+ # Compute the initial SMA (first `period` rows)
108+ avg_gain = gain .rolling (window = period , min_periods = period ).mean ()
109+ avg_loss = loss .rolling (window = period , min_periods = period ).mean ()
110+
111+ # Apply Wilder's Smoothing for the remaining values
112+ for i in range (period , len (data )):
113+ avg_gain .iloc [i ] = (avg_gain .iloc [i - 1 ] * (period - 1 ) + gain .iloc [i ]) / period
114+ avg_loss .iloc [i ] = (avg_loss .iloc [i - 1 ] * (period - 1 ) + loss .iloc [i ]) / period
115+
116+ rs = avg_gain / avg_loss
117+ data [result_column ] = 100 - (100 / (1 + rs ))
118+
119+ # Ensure first `period` rows are NaN
120+ data .iloc [:period , data .columns .get_loc (result_column )] = pd .NA
121+
122+ elif isinstance (data , pl .DataFrame ):
123+ delta = data [source_column ].diff ().fill_null (0 )
124+ gain = delta .clip_min (0 )
125+ loss = (- delta ).clip_min (0 )
126+
127+ # Compute initial SMA (first `period` rows)
128+ avg_gain = gain .rolling_mean (window_size = period , min_periods = period )
129+ avg_loss = loss .rolling_mean (window_size = period , min_periods = period )
130+
131+ # Apply Wilder's Smoothing
132+ smoothed_gain = [None ] * period
133+ smoothed_loss = [None ] * period
134+ for i in range (period , len (data )):
135+ smoothed_gain .append ((smoothed_gain [- 1 ] * (period - 1 ) + gain [i ]) / period )
136+ smoothed_loss .append ((smoothed_loss [- 1 ] * (period - 1 ) + loss [i ]) / period )
137+
138+ # Compute RSI
139+ rs = pl .Series (smoothed_gain ) / pl .Series (smoothed_loss )
140+ rsi_values = 100 - (100 / (1 + rs ))
141+
142+ # Replace first `period` values with None
143+ rsi_values = rsi_values .set_at_idx (list (range (period )), None )
144+
69145 # Add column to DataFrame
70146 data = data .with_columns (rsi_values .alias (result_column ))
71147
0 commit comments