Skip to content

Commit 8875575

Browse files
committed
Add MACD and WMA indicators
1 parent 26c12e9 commit 8875575

File tree

10 files changed

+2585
-3
lines changed

10 files changed

+2585
-3
lines changed

README.md

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ PyIndicators is a powerful and user-friendly Python library for technical analys
88
<picture>
99
<source media="(prefers-color-scheme: dark)" srcset="static/sponsors/finterion-dark.png">
1010
<source media="(prefers-color-scheme: light)" srcset="static/sponsors/finterion-light.png">
11-
<img src="static/sponsors/finterion-light.svg" alt="Finterion Logo" style="height: 55px;">
11+
<img src="static/sponsors/finterion-light.svg" alt="Finterion Logo" style="height: 40px;">
1212
</picture>
1313
</a>
1414

@@ -26,8 +26,10 @@ pip install pyindicators
2626
* Dataframe first approach, with support for both pandas dataframes and polars dataframes
2727
* Supports python version 3.9 and above.
2828
* [Trend indicators](#trend-indicators)
29+
* [Weighted Moving Average (WMA)](#weighted-moving-average-wma)
2930
* [Simple Moving Average (SMA)](#simple-moving-average-sma)
3031
* [Exponential Moving Average (EMA)](#exponential-moving-average-ema)
32+
* [Moving Average Convergence Divergence (MACD)](#moving-average-convergence-divergence-macd)
3133
* [Momentum indicators](#momentum-indicators)
3234
* [Relative Strength Index (RSI)](#relative-strength-index-rsi)
3335
* [Relative Strength Index Wilders method (Wilders RSI)](#wilders-relative-strength-index-wilders-rsi)
@@ -39,6 +41,31 @@ pip install pyindicators
3941

4042
### Trend Indicators
4143

44+
#### Weighted Moving Average (WMA)
45+
46+
```python
47+
from investing_algorithm_framework import CSVOHLCVMarketDataSource
48+
49+
from pyindicators import wma
50+
51+
# For this example the investing algorithm framework is used for dataframe creation,
52+
csv_path = "./tests/test_data/OHLCV_BTC-EUR_BINANCE_15m_2023-12-01:00:00_2023-12-25:00:00.csv"
53+
data_source = CSVOHLCVMarketDataSource(csv_file_path=csv_path)
54+
55+
pl_df = data_source.get_data()
56+
pd_df = data_source.get_data(pandas=True)
57+
58+
# Calculate SMA for Polars DataFrame
59+
pl_df = wma(pl_df, source_column="Close", period=200, result_column="SMA_200")
60+
pl_df.show(10)
61+
62+
# Calculate SMA for Pandas DataFrame
63+
pd_df = wma(pd_df, source_column="Close", period=200, result_column="SMA_200")
64+
pd_df.tail(10)
65+
```
66+
67+
![WMA](https://github.com/coding-kitties/PyIndicators/blob/main/static/images/indicators/wma.png)
68+
4269
#### Simple Moving Average (SMA)
4370

4471
Smooth out price data to identify trend direction.
@@ -93,6 +120,33 @@ pd_df.tail(10)
93120

94121
![EMA](https://github.com/coding-kitties/PyIndicators/blob/main/static/images/indicators/ema.png)
95122

123+
#### Moving Average Convergence Divergence (MACD)
124+
125+
```python
126+
from investing_algorithm_framework import CSVOHLCVMarketDataSource
127+
128+
from pyindicators import macd
129+
130+
# For this example the investing algorithm framework is used for dataframe creation,
131+
csv_path = "./tests/test_data/OHLCV_BTC-EUR_BINANCE_15m_2023-12-01:00:00_2023-12-25:00:00.csv"
132+
data_source = CSVOHLCVMarketDataSource(csv_file_path=csv_path)
133+
134+
pl_df = data_source.get_data()
135+
pd_df = data_source.get_data(pandas=True)
136+
137+
# Calculate MACD for Polars DataFrame
138+
pl_df = macd(pl_df, source_column="Close", short_period=12, long_period=26, signal_period=9)
139+
140+
# Calculate MACD for Pandas DataFrame
141+
pd_df = macd(pd_df, source_column="Close", short_period=12, long_period=26, signal_period=9)
142+
143+
pl_df.show(10)
144+
pd_df.tail(10)
145+
```
146+
147+
![EMA](https://github.com/coding-kitties/PyIndicators/blob/main/static/images/indicators/macd.png)
148+
149+
96150
### Momentum Indicators
97151

98152
#### Relative Strength Index (RSI)

pyindicators/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from .indicators import sma, rsi, is_crossover, crossunder, ema, wilders_rsi, \
2-
crossover, is_crossover
2+
crossover, is_crossover, wma, macd
33

44
__all__ = [
55
'sma',
6+
'wma',
67
'is_crossover',
78
'crossunder',
89
'crossover',
910
'is_crossover',
1011
'ema',
1112
'rsi',
12-
"wilders_rsi"
13+
"wilders_rsi",
14+
'macd'
1315
]
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
from .simple_moving_average import sma
2+
from .weighted_moving_average import wma
23
from .crossover import is_crossover, crossover
34
from .crossunder import crossunder
45
from .exponential_moving_average import ema
56
from .rsi import rsi, wilders_rsi
7+
from .macd import macd
68

79
__all__ = [
810
'sma',
11+
"wma",
912
'is_crossover',
1013
"crossover",
1114
'crossunder',
1215
'ema',
1316
'rsi',
1417
'wilders_rsi',
18+
'macd'
1519
]

pyindicators/indicators/macd.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from typing import Union
2+
3+
import numpy as np
4+
from pandas import DataFrame as PdDataFrame
5+
from polars import DataFrame as PlDataFrame
6+
import pandas as pd
7+
import polars as pl
8+
9+
from pyindicators.exceptions import PyIndicatorException
10+
from pyindicators.indicators import ema
11+
12+
13+
def macd(
14+
data: Union[PdDataFrame, PlDataFrame],
15+
source_column: str,
16+
short_period: int = 12,
17+
long_period: int = 26,
18+
signal_period: int = 9,
19+
macd_column: str = "macd",
20+
signal_column: str = "macd_signal",
21+
histogram_column: str = "macd_histogram"
22+
) -> Union[PdDataFrame, PlDataFrame]:
23+
"""
24+
Calculate the MACD (Moving Average Convergence Divergence) for a given DataFrame.
25+
26+
Args:
27+
data (Union[pd.DataFrame, pl.DataFrame]): Input data containing the price series.
28+
source_column (str): Column name for the price series.
29+
short_period (int, optional): Period for the short-term EMA (default: 12).
30+
long_period (int, optional): Period for the long-term EMA (default: 26).
31+
signal_period (int, optional): Period for the Signal Line EMA (default: 9).
32+
macd_column (str, optional): Column name to store the MACD line.
33+
signal_column (str, optional): Column name to store the Signal line.
34+
histogram_column (str, optional): Column name to store the MACD histogram.
35+
36+
Returns:
37+
Union[pd.DataFrame, pl.DataFrame]: DataFrame with MACD, Signal Line, and Histogram.
38+
"""
39+
if source_column not in data.columns:
40+
raise PyIndicatorException(
41+
f"Column '{source_column}' not found in DataFrame"
42+
)
43+
44+
if isinstance(data, PdDataFrame):
45+
# Calculate the short-term and long-term EMAs
46+
data = ema(data, source_column, short_period, f"EMA_{short_period}")
47+
data = ema(data, source_column, long_period, f"EMA_{long_period}")
48+
49+
# Calculate the MACD line
50+
data[macd_column] = \
51+
data[f"EMA_{short_period}"] - data[f"EMA_{long_period}"]
52+
53+
# Calculate the Signal Line
54+
data = ema(data, macd_column, signal_period, signal_column)
55+
56+
# Calculate the MACD Histogram
57+
data[histogram_column] = data[macd_column] - data[signal_column]
58+
return data
59+
elif isinstance(data, pl.DataFrame):
60+
return None
61+
else:
62+
raise PyIndicatorException(
63+
"Unsupported DataFrame type. Use Pandas or Polars."
64+
)
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
from typing import Union
2+
3+
import numpy as np
4+
from pandas import DataFrame as PdDataFrame
5+
from polars import DataFrame as PlDataFrame
6+
import pandas as pd
7+
import polars as pl
8+
9+
from pyindicators.exceptions import PyIndicatorException
10+
11+
12+
def wma(
13+
data: Union[PdDataFrame, PlDataFrame],
14+
source_column: str,
15+
period: int,
16+
result_column: str = None,
17+
) -> Union[PdDataFrame, PlDataFrame]:
18+
"""
19+
Function to calculate the weighted moving average of a series.
20+
21+
Args:
22+
data (Union[PdDataFrame, PlDataFrame]): The input data.
23+
source_column (str): The name of the series.
24+
period (int): The period for the simple moving average.
25+
result_column (str, optional): The name of the column to store the
26+
simple moving average. Defaults to None.
27+
28+
Returns:
29+
Union[PdDataFrame, PlDataFrame]: Returns a DataFrame
30+
with the weighted moving average of the series.
31+
"""
32+
if len(data) < period:
33+
raise PyIndicatorException(
34+
"The data must be larger than the period " +
35+
f"{period} to calculate the WMA. The data " +
36+
f"only contains {len(data)} data points."
37+
)
38+
if result_column is None:
39+
result_column = f"WMA_{period}"
40+
41+
weights = np.arange(1, period + 1)
42+
43+
if isinstance(data, pd.DataFrame):
44+
if source_column not in data.columns:
45+
raise PyIndicatorException(
46+
f"Column '{source_column}' not found in DataFrame"
47+
)
48+
49+
data[result_column] = (
50+
data[source_column]
51+
.rolling(window=period)
52+
.apply(lambda x: np.dot(x, weights) / weights.sum(), raw=True)
53+
)
54+
return data
55+
56+
elif isinstance(data, pl.DataFrame):
57+
if source_column not in data.columns:
58+
raise PyIndicatorException(
59+
f"Column '{source_column}' not found in DataFrame"
60+
)
61+
62+
wma_values = (
63+
data[source_column]
64+
.rolling_mean(window_size=period, weights=weights.tolist())
65+
)
66+
67+
data = data.with_columns(pl.Series(result_column, wma_values))
68+
return data
69+
70+
else:
71+
raise PyIndicatorException(
72+
"Unsupported DataFrame type. Use Pandas or Polars."
73+
)

static/images/indicators/macd.png

91.9 KB
Loading

static/images/indicators/wma.png

73.8 KB
Loading

tests/indicators/test_macd.py

Whitespace-only changes.

tests/indicators/test_wma.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import pandas as pd
2+
import polars as pl
3+
import pandas.testing as pdt
4+
from polars.testing import assert_frame_equal
5+
6+
from tests.resources import TestBaseline
7+
from pyindicators import wma
8+
9+
10+
class Test(TestBaseline):
11+
correct_output_csv_filename = \
12+
"WMA_200_BTC-EUR_BINANCE_15m_2023-12-01-00-00_2023-12-25-00-00.csv"
13+
14+
def generate_pandas_df(self, polars_source_df):
15+
polars_source_df = wma(
16+
data=polars_source_df,
17+
period=200,
18+
result_column="WMA_200",
19+
source_column="Close"
20+
)
21+
return polars_source_df
22+
23+
def generate_polars_df(self, pandas_source_df):
24+
pandas_source_df = wma(
25+
data=pandas_source_df,
26+
period=200,
27+
result_column="WMA_200",
28+
source_column="Close"
29+
)
30+
return pandas_source_df
31+
32+
def test_comparison_pandas(self):
33+
34+
# Load the correct output in a pandas dataframe
35+
correct_output_pd = pd.read_csv(self.get_correct_output_csv_path())
36+
37+
# Load the source in a pandas dataframe
38+
source = pd.read_csv(self.get_source_csv_path())
39+
40+
# Generate the pandas dataframe
41+
output = self.generate_pandas_df(source)
42+
output = output[correct_output_pd.columns]
43+
output["Datetime"] = \
44+
pd.to_datetime(output["Datetime"]).dt.tz_localize(None)
45+
correct_output_pd["Datetime"] = \
46+
pd.to_datetime(correct_output_pd["Datetime"]).dt.tz_localize(None)
47+
48+
pdt.assert_frame_equal(correct_output_pd, output)
49+
50+
def test_comparison_polars(self):
51+
52+
# Load the correct output in a polars dataframe
53+
correct_output_pl = pl.read_csv(
54+
self.get_correct_output_csv_path(),
55+
schema_overrides={"WMA_200": pl.Float64}
56+
)
57+
58+
# Load the source in a polars dataframe
59+
source = pl.read_csv(self.get_source_csv_path())
60+
61+
# Generate the polars dataframe
62+
output = self.generate_polars_df(source)
63+
64+
# Convert the datetime columns to datetime
65+
# Convert the 'Datetime' column in both DataFrames to datetime
66+
output = output.with_columns(
67+
pl.col("Datetime").str.strptime(pl.Datetime).alias("Datetime")
68+
)
69+
70+
correct_output_pl = correct_output_pl.with_columns(
71+
pl.col("Datetime").str.strptime(pl.Datetime).alias("Datetime")
72+
)
73+
output = output[correct_output_pl.columns]
74+
output = self.make_polars_column_datetime_naive(output, "Datetime")
75+
correct_output_pl = self.make_polars_column_datetime_naive(
76+
correct_output_pl, "Datetime"
77+
)
78+
79+
assert_frame_equal(correct_output_pl, output)

0 commit comments

Comments
 (0)