Skip to content

Commit d66bef2

Browse files
committed
Add crossunder tests
1 parent 2d87068 commit d66bef2

File tree

5 files changed

+201
-66
lines changed

5 files changed

+201
-66
lines changed

pyindicators/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from .indicators import sma, rsi, crossunder, ema, wilders_rsi, \
2-
crossover, is_crossover, wma, macd, willr
2+
crossover, is_crossover, wma, macd, willr, is_crossunder, crossunder
33

44
__all__ = [
55
'sma',
66
'wma',
77
'is_crossover',
88
'crossunder',
9+
'is_crossunder',
910
'crossover',
1011
'is_crossover',
1112
'ema',

pyindicators/indicators/crossover.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,12 @@ def is_crossover(
108108
second_column=second_column,
109109
result_column=crossover_column,
110110
number_of_data_points=number_of_data_points,
111-
strics=strict
111+
strict=strict
112112
)
113113

114+
if number_of_data_points is None:
115+
number_of_data_points = len(data)
116+
114117
# If crossunder_column is set, check for a value of 1
115118
# in the last data points
116119
if isinstance(data, PdDataFrame):

pyindicators/indicators/crossunder.py

Lines changed: 13 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -14,58 +14,31 @@ def crossunder(
1414
number_of_data_points: int = None,
1515
strict: bool = True,
1616
) -> Union[PdDataFrame, PlDataFrame]:
17-
"""
18-
Identifies crossunder points where `first_column` crosses below
19-
or below `second_column`.
20-
21-
Args:
22-
data: Pandas or Polars DataFrame
23-
first_column: Name of the first column
24-
second_column: Name of the second column
25-
result_column (optional): Name of the column to
26-
store the crossunder points
27-
number_of_data_points (optional):
28-
Number of recent rows to consider
29-
strict (optional): If True, requires exact crossunders; otherwise,
30-
detects when one surpasses the other.
31-
32-
Returns:
33-
A DataFrame with crossunder points marked.
34-
"""
35-
36-
# Restrict data to the last `data_points` rows if specified
17+
3718
if number_of_data_points is not None:
38-
data = data.tail(number_of_data_points) \
39-
if isinstance(data, PdDataFrame) \
40-
else data.slice(-number_of_data_points)
19+
data = data.tail(number_of_data_points).copy() if isinstance(data, PdDataFrame) else data.slice(-number_of_data_points)
4120

42-
# Pandas Implementation
4321
if isinstance(data, PdDataFrame):
4422
col1, col2 = data[first_column], data[second_column]
4523
prev_col1, prev_col2 = col1.shift(1), col2.shift(1)
4624

4725
if strict:
48-
crossunder_mask = (
49-
(prev_col1 > prev_col2) & (col1 < col2)
50-
)
26+
crossunder_mask = (prev_col1 > prev_col2) & (col1 < col2)
5127
else:
52-
crossunder_mask = (col1 > col2) & (prev_col1 <= prev_col2)
28+
crossunder_mask = (col1 > col2) & (prev_col1 <= prev_col2) | (col1 >= col2) & (prev_col1 < prev_col2)
5329

54-
data[result_column] = crossunder_mask.astype(int)
30+
data.loc[:, result_column] = crossunder_mask.astype(int)
5531

56-
# Polars Implementation
5732
elif isinstance(data, PlDataFrame):
5833
col1, col2 = data[first_column], data[second_column]
5934
prev_col1, prev_col2 = col1.shift(1), col2.shift(1)
6035

6136
if strict:
62-
crossunder_mask = ((prev_col1 > prev_col2) & (col1 < col2))
37+
crossunder_mask = (prev_col1 > prev_col2) & (col1 < col2)
6338
else:
64-
crossunder_mask = (col1 > col2) & (prev_col1 <= prev_col2)
39+
crossunder_mask = (col1 > col2) & (prev_col1 <= prev_col2) | (col1 >= col2) & (prev_col1 < prev_col2)
6540

66-
# Convert boolean mask to 1s and 0s
67-
data = data.with_columns(pl.when(crossunder_mask).then(1)
68-
.otherwise(0).alias(result_column))
41+
data = data.with_columns(pl.when(crossunder_mask).then(1).otherwise(0).alias(result_column))
6942

7043
return data
7144

@@ -78,24 +51,6 @@ def is_crossunder(
7851
number_of_data_points: int = None,
7952
strict: bool = True,
8053
) -> bool:
81-
"""
82-
Returns a boolean when the first series crosses below the second series
83-
at any point or within the last n data points.
84-
85-
Args:
86-
data (Union[pd.DataFrame, pl.DataFrame]): The input data.
87-
first_column (str): The name of the first series.
88-
second_column (str): The name of the second series.
89-
crossunder_column (str) (optional):
90-
The name of the column to store the crossunder points.
91-
number_of_data_points (int) (optional):
92-
The number of data points to consider. Defaults to None.
93-
strict (bool) (optional): If True, requires a strict
94-
crossunder. Defaults to True.
95-
96-
Returns:
97-
bool: True if a crossunder occurs, False otherwise.
98-
"""
9954

10055
if len(data) < 2:
10156
return False
@@ -108,16 +63,16 @@ def is_crossunder(
10863
second_column=second_column,
10964
result_column=crossunder_column,
11065
number_of_data_points=number_of_data_points,
111-
strics=strict
66+
strict=strict
11267
)
11368

114-
# If crossunder_column is set, check for a value of 1
115-
# in the last data points
69+
if number_of_data_points is None:
70+
number_of_data_points = len(data)
71+
11672
if isinstance(data, PdDataFrame):
11773
return data[crossunder_column].tail(number_of_data_points).eq(1).any()
11874
elif isinstance(data, pl.DataFrame):
119-
return data[crossunder_column][-number_of_data_points:]\
120-
.to_list().count(1) > 0
75+
return data[crossunder_column][-number_of_data_points:].to_list().count(1) > 0
12176

12277
raise PyIndicatorException(
12378
"Data type not supported. Please provide a Pandas or Polars DataFrame."

tests/indicators/test_crossover.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def test_crossover_pandas(self):
1818
df.set_index("DateTime", inplace=True)
1919

2020
self.assertTrue(is_crossover(df, first_column="EMA_50", second_column="EMA_200"))
21-
self.assertTrue(is_crossover(df, first_column="EMA_50", second_column="EMA_200", data_points=3))
21+
self.assertTrue(is_crossover(df, first_column="EMA_50", second_column="EMA_200", number_of_data_points=3))
2222

2323
df = pd.DataFrame({
2424
"EMA_50": [200, 201, 202, 203, 204, 205, 206, 208, 210, 210],
@@ -37,7 +37,7 @@ def test_crossover_pandas(self):
3737
df,
3838
first_column="EMA_50",
3939
second_column="EMA_200",
40-
data_points=3
40+
number_of_data_points=3
4141
)
4242
)
4343

@@ -57,7 +57,7 @@ def test_crossover_pandas(self):
5757
df,
5858
first_column="EMA_50",
5959
second_column="EMA_200",
60-
data_points=3
60+
number_of_data_points=3
6161
)
6262
)
6363

@@ -69,7 +69,7 @@ def test_crossover_polars(self):
6969
})
7070

7171
self.assertTrue(is_crossover(df, first_column="EMA_50", second_column="EMA_200"))
72-
self.assertTrue(is_crossover(df, first_column="EMA_50", second_column="EMA_200", data_points=3))
72+
self.assertTrue(is_crossover(df, first_column="EMA_50", second_column="EMA_200", number_of_data_points=3))
7373

7474
df = pl.DataFrame({
7575
"EMA_50": [200, 201, 202, 203, 204, 205, 206, 208, 210, 210],
@@ -85,7 +85,7 @@ def test_crossover_polars(self):
8585
df,
8686
first_column="EMA_50",
8787
second_column="EMA_200",
88-
data_points=3
88+
number_of_data_points=3
8989
)
9090
)
9191

@@ -103,6 +103,6 @@ def test_crossover_polars(self):
103103
df,
104104
first_column="EMA_50",
105105
second_column="EMA_200",
106-
data_points=3
106+
number_of_data_points=3
107107
)
108108
)
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
import pandas as pd
2+
import polars as pl
3+
from unittest import TestCase
4+
5+
from pyindicators import is_crossunder
6+
7+
8+
class TestCrossover(TestCase):
9+
10+
def test_crossunder_pandas(self):
11+
df = pd.DataFrame({
12+
"EMA_50": [200, 201, 202, 203, 204, 205, 207, 205, 204, 205],
13+
"EMA_200": [200, 201, 202, 203, 204, 205, 206, 207, 209, 209],
14+
"DateTime": pd.date_range("2021-01-01", periods=10, freq="D")
15+
})
16+
17+
# Set index to DateTime
18+
df.set_index("DateTime", inplace=True)
19+
20+
self.assertTrue(
21+
is_crossunder(df, first_column="EMA_50", second_column="EMA_200")
22+
)
23+
self.assertFalse(
24+
is_crossunder(
25+
df,
26+
first_column="EMA_50",
27+
second_column="EMA_200",
28+
number_of_data_points=3
29+
)
30+
)
31+
self.assertTrue(
32+
is_crossunder(
33+
df,
34+
first_column="EMA_50",
35+
second_column="EMA_200",
36+
number_of_data_points=5
37+
)
38+
)
39+
40+
df = pd.DataFrame({
41+
"EMA_50": [200, 201, 202, 203, 204, 205, 207, 205, 209, 209],
42+
"EMA_200": [200, 201, 202, 203, 204, 205, 206, 209, 209, 209],
43+
"DateTime": pd.date_range("2021-01-01", periods=10, freq="D")
44+
})
45+
46+
# Set index to DateTime
47+
df.set_index("DateTime", inplace=True)
48+
49+
self.assertTrue(
50+
is_crossunder(df, first_column="EMA_50", second_column="EMA_200")
51+
)
52+
self.assertFalse(
53+
is_crossunder(
54+
df,
55+
first_column="EMA_50",
56+
second_column="EMA_200",
57+
number_of_data_points=3
58+
)
59+
)
60+
self.assertFalse(
61+
is_crossunder(
62+
df,
63+
first_column="EMA_50",
64+
second_column="EMA_200",
65+
number_of_data_points=3
66+
)
67+
)
68+
69+
# Check how strict works
70+
df = pd.DataFrame({
71+
"EMA_50": [200, 201, 202, 203, 204, 205, 207, 206, 205, 209],
72+
"EMA_200": [200, 201, 202, 203, 204, 205, 206, 206, 206, 209],
73+
"DateTime": pd.date_range("2021-01-01", periods=10, freq="D")
74+
})
75+
76+
# Set index to DateTime
77+
df.set_index("DateTime", inplace=True)
78+
79+
self.assertFalse(
80+
is_crossunder(
81+
df,
82+
first_column="EMA_50",
83+
second_column="EMA_200",
84+
number_of_data_points=4,
85+
strict=True
86+
)
87+
)
88+
self.assertTrue(
89+
is_crossunder(
90+
df,
91+
first_column="EMA_50",
92+
second_column="EMA_200",
93+
number_of_data_points=4,
94+
strict=False
95+
)
96+
)
97+
98+
99+
def test_crossunder_polars(self):
100+
df = pl.DataFrame({
101+
"EMA_50": [200, 201, 202, 203, 204, 205, 207, 205, 204, 205],
102+
"EMA_200": [200, 201, 202, 203, 204, 205, 206, 207, 209, 209],
103+
"DateTime": pd.date_range("2021-01-01", periods=10, freq="D")
104+
})
105+
106+
self.assertTrue(
107+
is_crossunder(df, first_column="EMA_50", second_column="EMA_200")
108+
)
109+
self.assertFalse(
110+
is_crossunder(
111+
df,
112+
first_column="EMA_50",
113+
second_column="EMA_200",
114+
number_of_data_points=3
115+
)
116+
)
117+
self.assertTrue(
118+
is_crossunder(
119+
df,
120+
first_column="EMA_50",
121+
second_column="EMA_200",
122+
number_of_data_points=5
123+
)
124+
)
125+
126+
df = pl.DataFrame({
127+
"EMA_50": [200, 201, 202, 203, 204, 205, 207, 205, 209, 209],
128+
"EMA_200": [200, 201, 202, 203, 204, 205, 206, 209, 209, 209],
129+
"DateTime": pd.date_range("2021-01-01", periods=10, freq="D")
130+
})
131+
132+
self.assertTrue(
133+
is_crossunder(df, first_column="EMA_50", second_column="EMA_200")
134+
)
135+
self.assertFalse(
136+
is_crossunder(
137+
df,
138+
first_column="EMA_50",
139+
second_column="EMA_200",
140+
number_of_data_points=3
141+
)
142+
)
143+
self.assertFalse(
144+
is_crossunder(
145+
df,
146+
first_column="EMA_50",
147+
second_column="EMA_200",
148+
number_of_data_points=3
149+
)
150+
)
151+
152+
# Check how strict works
153+
df = pl.DataFrame({
154+
"EMA_50": [200, 201, 202, 203, 204, 205, 207, 206, 205, 209],
155+
"EMA_200": [200, 201, 202, 203, 204, 205, 206, 206, 206, 209],
156+
"DateTime": pd.date_range("2021-01-01", periods=10, freq="D")
157+
})
158+
159+
self.assertFalse(
160+
is_crossunder(
161+
df,
162+
first_column="EMA_50",
163+
second_column="EMA_200",
164+
number_of_data_points=4,
165+
strict=True
166+
)
167+
)
168+
self.assertTrue(
169+
is_crossunder(
170+
df,
171+
first_column="EMA_50",
172+
second_column="EMA_200",
173+
number_of_data_points=4,
174+
strict=False
175+
)
176+
)

0 commit comments

Comments
 (0)