Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/indicators/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Marker for indicators test package
31 changes: 31 additions & 0 deletions tests/indicators/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest
import polars as pl

@pytest.fixture
def sample_ohlcv_df():
"""
Returns a deterministic polars DataFrame with 120 rows and standard OHLCV columns.
Values are monotonically increasing for easy/deterministic indicator output.
"""
n = 120
return pl.DataFrame({
"open": [float(i) for i in range(n)],
"high": [float(i) + 1 for i in range(n)],
"low": [float(i) - 1 for i in range(n)],
"close": [float(i) + 0.5 for i in range(n)],
"volume": [100 + i for i in range(n)],
})

@pytest.fixture
def small_ohlcv_df():
"""
Returns a polars DataFrame with 5 rows to trigger insufficient data paths.
"""
n = 5
return pl.DataFrame({
"open": [float(i) for i in range(n)],
"high": [float(i) + 1 for i in range(n)],
"low": [float(i) - 1 for i in range(n)],
"close": [float(i) + 0.5 for i in range(n)],
"volume": [100 + i for i in range(n)],
})
77 changes: 77 additions & 0 deletions tests/indicators/test_all_indicators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import pytest
import polars as pl
import inspect
import importlib
import pkgutil
from project_x_py.indicators.base import BaseIndicator

def _concrete_indicator_classes():
# Recursively discover all non-abstract subclasses of BaseIndicator in project_x_py.indicators.*
import project_x_py.indicators
seen = set()
result = []

def onclass(cls):
if cls in seen:
return
seen.add(cls)
# Must be subclass of BaseIndicator but not the base class itself
if not issubclass(cls, BaseIndicator) or cls is BaseIndicator:
return
# Skip abstract classes (those with any abstractmethods)
if getattr(cls, "__abstractmethods__", None):
return
# Only include classes defined in project_x_py.indicators.*
if not cls.__module__.startswith("project_x_py.indicators."):
return
result.append(cls)

# Walk all modules in project_x_py.indicators package
for finder, name, ispkg in pkgutil.walk_packages(project_x_py.indicators.__path__, project_x_py.indicators.__name__ + "."):
try:
mod = importlib.import_module(name)
except Exception:
continue # If import fails, skip that module
for _, obj in inspect.getmembers(mod, inspect.isclass):
onclass(obj)
# Remove duplicates, sort by class name for determinism
return sorted(set(result), key=lambda cls: cls.__name__)

@pytest.mark.parametrize("indicator_cls", _concrete_indicator_classes(), ids=lambda cls: cls.__name__)
def test_indicator_calculate_adds_new_column(indicator_cls, sample_ohlcv_df):
"""
For every indicator class: instantiate with default ctor, call .calculate() or __call__ on sample data.
- No exception is raised.
- Result is a polars.DataFrame with same row count.
- At least one new column is present.
"""
instance = indicator_cls()
input_cols = set(sample_ohlcv_df.columns)
# Try __call__ first (uses caching), then fallback to .calculate
try:
out_df = instance(sample_ohlcv_df)
except Exception:
out_df = instance.calculate(sample_ohlcv_df)

assert isinstance(out_df, pl.DataFrame), f"{indicator_cls.__name__} output is not a polars.DataFrame"
assert out_df.height == sample_ohlcv_df.height, (
f"{indicator_cls.__name__} output row count {out_df.height} != input {sample_ohlcv_df.height}"
)
new_cols = set(out_df.columns) - input_cols
assert new_cols, f"{indicator_cls.__name__} did not add any new columns"

def _get_new_column_names(indicator_cls, input_cols, df):
return set(df.columns) - set(input_cols)

@pytest.mark.parametrize("indicator_cls", _concrete_indicator_classes(), ids=lambda cls: cls.__name__)
def test_indicator_caching_returns_same_object(indicator_cls, sample_ohlcv_df):
"""
Calling the indicator twice with the same df on the same instance should return the exact same DataFrame object (proves internal cache).
"""
instance = indicator_cls()
# Use __call__ to trigger cache logic
out1 = instance(sample_ohlcv_df)
out2 = instance(sample_ohlcv_df)
assert out1 is out2, (
f"{indicator_cls.__name__} did not return identical object on repeated call (cache broken?)"
)
51 changes: 51 additions & 0 deletions tests/indicators/test_base_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import pytest
import polars as pl

from project_x_py.indicators.base import (
BaseIndicator,
safe_division,
IndicatorError,
)
from project_x_py.indicators.overlap import calculate_sma, SMA
from project_x_py.indicators.momentum import calculate_rsi
from project_x_py.indicators.volatility import calculate_atr
from project_x_py.indicators.volume import calculate_obv

def test_validate_data_missing_column(sample_ohlcv_df):
df_missing = sample_ohlcv_df.drop("open")
sma = SMA()
with pytest.raises(IndicatorError, match="Missing required columns?"):
sma.validate_data(df_missing, required_cols=["open", "close"])

def test_validate_data_length_too_short(small_ohlcv_df):
sma = SMA()
with pytest.raises(IndicatorError, match="at least"):
sma.validate_data_length(small_ohlcv_df, min_length=10)

def test_validate_period_negative_or_zero():
sma = SMA()
for val in [0, -1, -10]:
with pytest.raises(IndicatorError, match="period"):
sma.validate_period(val)

def test_safe_division_behavior():
df = pl.DataFrame({"numerator": [1, 2], "denominator": [0, 2]})
out = df.with_columns(
result=safe_division(pl.col("numerator"), pl.col("denominator"), default=-1)
)
# Should be Series [-1, 1]
assert out["result"].to_list() == [-1, 1], f"safe_division gave {out['result'].to_list()}"

@pytest.mark.parametrize("func, kwargs, exp_col", [
(calculate_sma, {"period": 5}, "sma_5"),
(calculate_rsi, {"period": 14}, "rsi_14"),
(calculate_atr, {"period": 14}, "atr_14"),
(calculate_obv, {}, "obv"),
])
def test_convenience_functions_expected_column_and_shape(sample_ohlcv_df, func, kwargs, exp_col):
"""
Convenience functions (calculate_sma etc) add expected columns and preserve row count.
"""
df_func = func(sample_ohlcv_df, **kwargs)
assert exp_col in df_func.columns, f"{func.__name__} did not add expected column '{exp_col}'"
assert df_func.height == sample_ohlcv_df.height