Skip to content

Commit 4e77323

Browse files
authored
Merge pull request #20 from TexasCoding/cosine/feature/testing-suite-indicators-an8kgy
Create Comprehensive Testing Suite for Indicators Module
2 parents 9daf244 + 65a0d20 commit 4e77323

File tree

4 files changed

+160
-0
lines changed

4 files changed

+160
-0
lines changed

tests/indicators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Marker for indicators test package

tests/indicators/conftest.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import pytest
2+
import polars as pl
3+
4+
@pytest.fixture
5+
def sample_ohlcv_df():
6+
"""
7+
Returns a deterministic polars DataFrame with 120 rows and standard OHLCV columns.
8+
Values are monotonically increasing for easy/deterministic indicator output.
9+
"""
10+
n = 120
11+
return pl.DataFrame({
12+
"open": [float(i) for i in range(n)],
13+
"high": [float(i) + 1 for i in range(n)],
14+
"low": [float(i) - 1 for i in range(n)],
15+
"close": [float(i) + 0.5 for i in range(n)],
16+
"volume": [100 + i for i in range(n)],
17+
})
18+
19+
@pytest.fixture
20+
def small_ohlcv_df():
21+
"""
22+
Returns a polars DataFrame with 5 rows to trigger insufficient data paths.
23+
"""
24+
n = 5
25+
return pl.DataFrame({
26+
"open": [float(i) for i in range(n)],
27+
"high": [float(i) + 1 for i in range(n)],
28+
"low": [float(i) - 1 for i in range(n)],
29+
"close": [float(i) + 0.5 for i in range(n)],
30+
"volume": [100 + i for i in range(n)],
31+
})
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import pytest
2+
import polars as pl
3+
import inspect
4+
import importlib
5+
import pkgutil
6+
from project_x_py.indicators.base import BaseIndicator
7+
8+
def _concrete_indicator_classes():
9+
# Recursively discover all non-abstract subclasses of BaseIndicator in project_x_py.indicators.*
10+
import project_x_py.indicators
11+
seen = set()
12+
result = []
13+
14+
def onclass(cls):
15+
if cls in seen:
16+
return
17+
seen.add(cls)
18+
# Must be subclass of BaseIndicator but not the base class itself
19+
if not issubclass(cls, BaseIndicator) or cls is BaseIndicator:
20+
return
21+
# Skip abstract classes (those with any abstractmethods)
22+
if getattr(cls, "__abstractmethods__", None):
23+
return
24+
# Only include classes defined in project_x_py.indicators.*
25+
if not cls.__module__.startswith("project_x_py.indicators."):
26+
return
27+
result.append(cls)
28+
29+
# Walk all modules in project_x_py.indicators package
30+
for finder, name, ispkg in pkgutil.walk_packages(project_x_py.indicators.__path__, project_x_py.indicators.__name__ + "."):
31+
try:
32+
mod = importlib.import_module(name)
33+
except Exception:
34+
continue # If import fails, skip that module
35+
for _, obj in inspect.getmembers(mod, inspect.isclass):
36+
onclass(obj)
37+
# Remove duplicates, sort by class name for determinism
38+
return sorted(set(result), key=lambda cls: cls.__name__)
39+
40+
@pytest.mark.parametrize("indicator_cls", _concrete_indicator_classes(), ids=lambda cls: cls.__name__)
41+
def test_indicator_calculate_adds_new_column(indicator_cls, sample_ohlcv_df):
42+
"""
43+
For every indicator class: instantiate with default ctor, call .calculate() or __call__ on sample data.
44+
- No exception is raised.
45+
- Result is a polars.DataFrame with same row count.
46+
- At least one new column is present.
47+
"""
48+
instance = indicator_cls()
49+
input_cols = set(sample_ohlcv_df.columns)
50+
# Try __call__ first (uses caching), then fallback to .calculate
51+
try:
52+
out_df = instance(sample_ohlcv_df)
53+
except Exception:
54+
out_df = instance.calculate(sample_ohlcv_df)
55+
56+
assert isinstance(out_df, pl.DataFrame), f"{indicator_cls.__name__} output is not a polars.DataFrame"
57+
assert out_df.height == sample_ohlcv_df.height, (
58+
f"{indicator_cls.__name__} output row count {out_df.height} != input {sample_ohlcv_df.height}"
59+
)
60+
new_cols = set(out_df.columns) - input_cols
61+
assert new_cols, f"{indicator_cls.__name__} did not add any new columns"
62+
63+
def _get_new_column_names(indicator_cls, input_cols, df):
64+
return set(df.columns) - set(input_cols)
65+
66+
@pytest.mark.parametrize("indicator_cls", _concrete_indicator_classes(), ids=lambda cls: cls.__name__)
67+
def test_indicator_caching_returns_same_object(indicator_cls, sample_ohlcv_df):
68+
"""
69+
Calling the indicator twice with the same df on the same instance should return the exact same DataFrame object (proves internal cache).
70+
"""
71+
instance = indicator_cls()
72+
# Use __call__ to trigger cache logic
73+
out1 = instance(sample_ohlcv_df)
74+
out2 = instance(sample_ohlcv_df)
75+
assert out1 is out2, (
76+
f"{indicator_cls.__name__} did not return identical object on repeated call (cache broken?)"
77+
)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import pytest
2+
import polars as pl
3+
4+
from project_x_py.indicators.base import (
5+
BaseIndicator,
6+
safe_division,
7+
IndicatorError,
8+
)
9+
from project_x_py.indicators.overlap import calculate_sma, SMA
10+
from project_x_py.indicators.momentum import calculate_rsi
11+
from project_x_py.indicators.volatility import calculate_atr
12+
from project_x_py.indicators.volume import calculate_obv
13+
14+
def test_validate_data_missing_column(sample_ohlcv_df):
15+
df_missing = sample_ohlcv_df.drop("open")
16+
sma = SMA()
17+
with pytest.raises(IndicatorError, match="Missing required columns?"):
18+
sma.validate_data(df_missing, required_cols=["open", "close"])
19+
20+
def test_validate_data_length_too_short(small_ohlcv_df):
21+
sma = SMA()
22+
with pytest.raises(IndicatorError, match="at least"):
23+
sma.validate_data_length(small_ohlcv_df, min_length=10)
24+
25+
def test_validate_period_negative_or_zero():
26+
sma = SMA()
27+
for val in [0, -1, -10]:
28+
with pytest.raises(IndicatorError, match="period"):
29+
sma.validate_period(val)
30+
31+
def test_safe_division_behavior():
32+
df = pl.DataFrame({"numerator": [1, 2], "denominator": [0, 2]})
33+
out = df.with_columns(
34+
result=safe_division(pl.col("numerator"), pl.col("denominator"), default=-1)
35+
)
36+
# Should be Series [-1, 1]
37+
assert out["result"].to_list() == [-1, 1], f"safe_division gave {out['result'].to_list()}"
38+
39+
@pytest.mark.parametrize("func, kwargs, exp_col", [
40+
(calculate_sma, {"period": 5}, "sma_5"),
41+
(calculate_rsi, {"period": 14}, "rsi_14"),
42+
(calculate_atr, {"period": 14}, "atr_14"),
43+
(calculate_obv, {}, "obv"),
44+
])
45+
def test_convenience_functions_expected_column_and_shape(sample_ohlcv_df, func, kwargs, exp_col):
46+
"""
47+
Convenience functions (calculate_sma etc) add expected columns and preserve row count.
48+
"""
49+
df_func = func(sample_ohlcv_df, **kwargs)
50+
assert exp_col in df_func.columns, f"{func.__name__} did not add expected column '{exp_col}'"
51+
assert df_func.height == sample_ohlcv_df.height

0 commit comments

Comments
 (0)