Skip to content

Commit a0cc176

Browse files
committed
Add wilders_rsi indicator
1 parent 2388836 commit a0cc176

9 files changed

+9407
-70
lines changed

pyindicators/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
from .indicators import sma, rsi, is_crossover, crossunder, ema
1+
from .indicators import sma, rsi, is_crossover, crossunder, ema, wilders_rsi
22

33
__all__ = [
44
'sma',
55
'is_crossover',
66
'crossunder',
77
'ema',
88
'rsi',
9+
"wilders_rsi"
910
]

pyindicators/indicators/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
from .crossover import is_crossover
33
from .crossunder import crossunder
44
from .exponential_moving_average import ema
5-
from .rsi import rsi
5+
from .rsi import rsi, wilders_rsi
66

77
__all__ = [
88
'sma',
99
'is_crossover',
1010
'crossunder',
1111
'ema',
1212
'rsi',
13+
'wilders_rsi',
1314
]

pyindicators/indicators/rsi.py

Lines changed: 101 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,32 @@
11
from typing import Union
2-
from pandas import DataFrame as PdDataFrame
3-
from polars import DataFrame as PlDataFrame
4-
5-
from pyindicators.exceptions import PyIndicatorException
2+
import pandas as pd
3+
import polars as pl
64

75

86
def rsi(
9-
data: Union[PdDataFrame, PlDataFrame],
7+
data: Union[pd.DataFrame, pl.DataFrame],
108
source_column: str,
119
period: int,
1210
result_column: str = None,
13-
) -> Union[PdDataFrame, PlDataFrame]:
11+
) -> Union[pd.DataFrame, pl.DataFrame]:
1412
"""
15-
Function to calculate the RSI of a series.
13+
Function to calculate the RSI (Relative Strength Index) of a series.
1614
1715
Args:
18-
data (Union[PdDataFrame, PlDataFrame]): The input data.
16+
data (Union[pd.DataFrame, pl.DataFrame]): The input data.
1917
source_column (str): The name of the series.
20-
period (int): The period for the exponential moving average.
21-
result_column (str, optional): The name of the column to store the
22-
exponential moving average. Defaults to None.
18+
period (int): The period for the RSI calculation.
19+
result_column (str, optional): The name of the column to store the RSI values.
20+
Defaults to None, which means it will be named "RSI_{period}".
2321
2422
Returns:
25-
Union[PdDataFrame, PlDataFrame]: Returns a DataFrame with
26-
the RSI of the series.
23+
Union[pd.DataFrame, pl.DataFrame]: The DataFrame with the RSI column added.
2724
"""
2825

2926
if result_column is None:
3027
result_column = f"RSI_{period}"
3128

32-
if source_column not in data.columns:
33-
raise PyIndicatorException(
34-
f"The column {source_column} does not exist in the DataFrame."
35-
)
36-
37-
if isinstance(data, PdDataFrame):
29+
if isinstance(data, pd.DataFrame):
3830
# Compute price changes
3931
delta = data[source_column].diff()
4032

@@ -43,14 +35,20 @@ def rsi(
4335
loss = -delta.where(delta < 0, 0)
4436

4537
# Compute the rolling average of gains and losses
46-
avg_gain = gain.rolling(window=period, min_periods=1).mean()
47-
avg_loss = loss.rolling(window=period, min_periods=1).mean()
38+
avg_gain = gain.rolling(window=period, min_periods=period).mean()
39+
avg_loss = loss.rolling(window=period, min_periods=period).mean()
4840

4941
# Compute RSI
5042
rs = avg_gain / avg_loss
51-
data[result_column] = 100 - (100 / (1 + rs))
43+
rsi_values = 100 - (100 / (1 + rs))
5244

53-
elif isinstance(data, PlDataFrame):
45+
# Ensure first `period` rows are NaN
46+
rsi_values[:period] = pd.NA
47+
48+
# Assign to DataFrame
49+
data[result_column] = rsi_values
50+
51+
elif isinstance(data, pl.DataFrame):
5452
# Compute price changes
5553
delta = data[source_column].diff().fill_null(0)
5654

@@ -59,13 +57,91 @@ def rsi(
5957
loss = (-delta).clip_min(0)
6058

6159
# Compute rolling averages of gains and losses
62-
avg_gain = gain.rolling_mean(window_size=period)
63-
avg_loss = loss.rolling_mean(window_size=period)
60+
avg_gain = gain.rolling_mean(window_size=period, min_periods=period)
61+
avg_loss = loss.rolling_mean(window_size=period, min_periods=period)
6462

6563
# Compute RSI
6664
rs = avg_gain / avg_loss
6765
rsi_values = 100 - (100 / (1 + rs))
6866

67+
# Replace first `period` values with nulls (polars uses `None`)
68+
rsi_values = rsi_values.set_at_idx(list(range(period)), None)
69+
70+
# Add column to DataFrame
71+
data = data.with_columns(rsi_values.alias(result_column))
72+
73+
else:
74+
raise TypeError("Input data must be a pandas or polars DataFrame.")
75+
76+
return data
77+
78+
79+
def wilders_rsi(
80+
data: Union[pd.DataFrame, pl.DataFrame],
81+
source_column: str,
82+
period: int,
83+
result_column: str = None,
84+
) -> Union[pd.DataFrame, pl.DataFrame]:
85+
"""
86+
Compute RSI using wilders method (Wilder’s Smoothing).
87+
88+
Args:
89+
data (Union[pd.DataFrame, pl.DataFrame]): Input DataFrame.
90+
source_column (str): Name of the column with price data.
91+
period (int): RSI period (e.g., 14).
92+
result_column (str, optional): Name for the output column.
93+
94+
Returns:
95+
Union[pd.DataFrame, pl.DataFrame]: DataFrame with RSI values.
96+
"""
97+
98+
if result_column is None:
99+
result_column = f"RSI_{period}"
100+
101+
if isinstance(data, pd.DataFrame):
102+
delta = data[source_column].diff()
103+
104+
gain = delta.where(delta > 0, 0)
105+
loss = -delta.where(delta < 0, 0)
106+
107+
# Compute the initial SMA (first `period` rows)
108+
avg_gain = gain.rolling(window=period, min_periods=period).mean()
109+
avg_loss = loss.rolling(window=period, min_periods=period).mean()
110+
111+
# Apply Wilder's Smoothing for the remaining values
112+
for i in range(period, len(data)):
113+
avg_gain.iloc[i] = (avg_gain.iloc[i - 1] * (period - 1) + gain.iloc[i]) / period
114+
avg_loss.iloc[i] = (avg_loss.iloc[i - 1] * (period - 1) + loss.iloc[i]) / period
115+
116+
rs = avg_gain / avg_loss
117+
data[result_column] = 100 - (100 / (1 + rs))
118+
119+
# Ensure first `period` rows are NaN
120+
data.iloc[:period, data.columns.get_loc(result_column)] = pd.NA
121+
122+
elif isinstance(data, pl.DataFrame):
123+
delta = data[source_column].diff().fill_null(0)
124+
gain = delta.clip_min(0)
125+
loss = (-delta).clip_min(0)
126+
127+
# Compute initial SMA (first `period` rows)
128+
avg_gain = gain.rolling_mean(window_size=period, min_periods=period)
129+
avg_loss = loss.rolling_mean(window_size=period, min_periods=period)
130+
131+
# Apply Wilder's Smoothing
132+
smoothed_gain = [None] * period
133+
smoothed_loss = [None] * period
134+
for i in range(period, len(data)):
135+
smoothed_gain.append((smoothed_gain[-1] * (period - 1) + gain[i]) / period)
136+
smoothed_loss.append((smoothed_loss[-1] * (period - 1) + loss[i]) / period)
137+
138+
# Compute RSI
139+
rs = pl.Series(smoothed_gain) / pl.Series(smoothed_loss)
140+
rsi_values = 100 - (100 / (1 + rs))
141+
142+
# Replace first `period` values with None
143+
rsi_values = rsi_values.set_at_idx(list(range(period)), None)
144+
69145
# Add column to DataFrame
70146
data = data.with_columns(rsi_values.alias(result_column))
71147

94.3 KB
Loading

tests/indicators/test_rsi.py

Lines changed: 78 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,78 @@
1-
# from datetime import timedelta
2-
# from unittest import TestCase
3-
4-
# import pandas as pd
5-
# import numpy as np
6-
# import tulipy as ti
7-
# from investing_algorithm_framework import CSVOHLCVMarketDataSource
8-
9-
# import pyindicators as pyi
10-
11-
12-
# class Test(TestCase):
13-
14-
# def test(self):
15-
# data_source = CSVOHLCVMarketDataSource(
16-
# csv_file_path="../test_data/OHLCV_BTC-EUR_BINANCE_15m"
17-
# "_2023-12-01:00:00_2023-12-25:00:00.csv",
18-
# )
19-
# data_source.end_date = data_source.start_date \
20-
# + timedelta(days=4, hours=4)
21-
22-
# while not data_source.empty():
23-
# data = data_source.get_data(market_credential_service=None)
24-
# df = pd.DataFrame(
25-
# data,
26-
# columns=['Date', 'Open', 'High', 'Low', 'Close', 'Volume']
27-
# )
28-
# pyi_rsi = pyi.rsi(series=df["Close"], timeperiod=14)
29-
# ta_rsi = ta.RSI(df["Close"], timeperiod=14).astype('float64')
30-
# ti_rsi = pd.Series(ti.rsi(df["Close"].to_numpy(), period=14))
31-
# # # Define a tolerance for comparison
32-
# tolerance = 1e-9
33-
# #
34-
# # # Compare the two Series with tolerance
35-
# nan_mask = ~np.isnan(pyi_rsi) & ~np.isnan(ta_rsi)
36-
# comparison_result = np.abs(
37-
# ta_rsi[nan_mask] - ti_rsi[nan_mask]) <= tolerance
38-
39-
# print(ta_rsi.iloc[-1], ti_rsi.iloc[-1])
40-
# # data_source.start_date = \
41-
# # data_source.start_date + timedelta(minutes=15)
42-
# # data_source.end_date = data_source.end_date + timedelta(minutes=15)
43-
# # self.assertTrue(all(comparison_result))
1+
import pandas as pd
2+
import polars as pl
3+
import pandas.testing as pdt
4+
from polars.testing import assert_frame_equal
5+
6+
from tests.resources import TestBaseline
7+
from pyindicators import rsi
8+
9+
10+
class Test(TestBaseline):
11+
correct_output_csv_filename = \
12+
"RSI_14_BTC-EUR_BINANCE_15m_2023-12-01:00:00_2023-12-25:00:00.csv"
13+
14+
def generate_pandas_df(self, polars_source_df):
15+
polars_source_df = rsi(
16+
data=polars_source_df,
17+
period=14,
18+
result_column="RSI_14",
19+
source_column="Close"
20+
)
21+
return polars_source_df
22+
23+
def generate_polars_df(self, pandas_source_df):
24+
pandas_source_df = rsi(
25+
data=pandas_source_df,
26+
period=14,
27+
result_column="RSI_14",
28+
source_column="Close"
29+
)
30+
return pandas_source_df
31+
32+
def test_comparison_pandas(self):
33+
34+
# Load the correct output in a pandas dataframe
35+
correct_output_pd = pd.read_csv(self.get_correct_output_csv_path())
36+
37+
# Load the source in a pandas dataframe
38+
source = pd.read_csv(self.get_source_csv_path())
39+
40+
# Generate the pandas dataframe
41+
output = self.generate_pandas_df(source)
42+
output = output[correct_output_pd.columns]
43+
output["Datetime"] = \
44+
pd.to_datetime(output["Datetime"]).dt.tz_localize(None)
45+
correct_output_pd["Datetime"] = \
46+
pd.to_datetime(correct_output_pd["Datetime"]).dt.tz_localize(None)
47+
48+
print(correct_output_pd.head(40))
49+
print(output.head(40))
50+
# pdt.assert_frame_equal(correct_output_pd, output)
51+
52+
# def test_comparison_polars(self):
53+
54+
# # Load the correct output in a polars dataframe
55+
# correct_output_pl = pl.read_csv(self.get_correct_output_csv_path())
56+
57+
# # Load the source in a polars dataframe
58+
# source = pl.read_csv(self.get_source_csv_path())
59+
60+
# # Generate the polars dataframe
61+
# output = self.generate_polars_df(source)
62+
63+
# # Convert the datetime columns to datetime
64+
# # Convert the 'Datetime' column in both DataFrames to datetime
65+
# output = output.with_columns(
66+
# pl.col("Datetime").str.strptime(pl.Datetime).alias("Datetime")
67+
# )
68+
69+
# correct_output_pl = correct_output_pl.with_columns(
70+
# pl.col("Datetime").str.strptime(pl.Datetime).alias("Datetime")
71+
# )
72+
# output = output[correct_output_pl.columns]
73+
# output = self.make_polars_column_datetime_naive(output, "Datetime")
74+
# correct_output_pl = self.make_polars_column_datetime_naive(
75+
# correct_output_pl, "Datetime"
76+
# )
77+
78+
# assert_frame_equal(correct_output_pl, output)

0 commit comments

Comments
 (0)