Skip to content

Commit a27eeda

Browse files
feat: numba implemented, benchmarking provided #143
closes #92
1 parent 2c67c8f commit a27eeda

File tree

6 files changed

+577
-0
lines changed

6 files changed

+577
-0
lines changed

.github/workflows/benchmark.yml

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Runs the backtest benchmark script and keeps its console output as a
# downloadable artifact. Triggered manually, or by pushes to main that touch
# the backtest code, the benchmarks, or this workflow file.
name: Performance Benchmarks

on:
  workflow_dispatch:
  push:
    branches:
      - main
    paths:
      - 'src/quant_research_starter/backtest/**'
      - 'src/quant_research_starter/benchmarks/**'
      - '.github/workflows/benchmark.yml'

jobs:
  benchmark:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install numpy pandas numba
          pip install -e .

      - name: Run benchmarks
        # `|| true` keeps the job green even when the benchmark script fails;
        # stdout and stderr both end up in benchmark_results.txt.
        run: |
          cd src/quant_research_starter/benchmarks
          python bench_opt.py > benchmark_results.txt 2>&1 || true

      - name: Upload benchmark results
        # v3 of upload-artifact has been retired; v4 accepts the same inputs.
        uses: actions/upload-artifact@v4
        if: always()
        with:
          name: benchmark-results
          path: src/quant_research_starter/benchmarks/benchmark_results.txt
          retention-days: 30
42+
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"""Cython-optimized backtest operations (skeleton)."""
2+
3+
cimport cython
4+
import numpy as np
5+
cimport numpy as np
6+
7+
DTYPE = np.float64
8+
ctypedef np.float64_t DTYPE_t
9+
10+
11+
@cython.boundscheck(False)
@cython.wraparound(False)
def compute_strategy_returns_cython(
    np.ndarray[DTYPE_t, ndim=2] weights_prev,
    np.ndarray[DTYPE_t, ndim=2] returns,
    np.ndarray[DTYPE_t, ndim=1] turnover,
    DTYPE_t transaction_cost
):
    """Compute strategy returns with transaction costs (Cython version).

    For every row ``i`` of ``weights_prev`` the result is
    ``sum_j(weights_prev[i, j] * returns[i, j]) - turnover[i] * transaction_cost``.

    Args:
        weights_prev: 2-D float64 array; its shape defines the iteration space.
        returns: 2-D float64 array covering at least the shape of
            ``weights_prev``.
        turnover: 1-D float64 array with at least one entry per row of
            ``weights_prev``.
        transaction_cost: Cost multiplier applied to each turnover entry.

    Returns:
        1-D float64 array with one value per row of ``weights_prev``.

    Raises:
        ValueError: If ``returns`` or ``turnover`` is too small for the
            iteration space.
    """
    # Py_ssize_t is the native index type; a C int could overflow on very
    # large arrays.
    cdef Py_ssize_t n_days = weights_prev.shape[0]
    cdef Py_ssize_t n_assets = weights_prev.shape[1]
    cdef np.ndarray[DTYPE_t, ndim=1] strat_ret
    cdef Py_ssize_t i, j
    cdef DTYPE_t ret_sum

    # Bounds checking is switched off for this function, so an undersized
    # input would be read past its end instead of raising. Validate up front.
    if returns.shape[0] < n_days or returns.shape[1] < n_assets:
        raise ValueError("returns must cover the shape of weights_prev")
    if turnover.shape[0] < n_days:
        raise ValueError("turnover must have one entry per row of weights_prev")

    strat_ret = np.zeros(n_days, dtype=DTYPE)

    for i in range(n_days):
        ret_sum = 0.0
        for j in range(n_assets):
            ret_sum += weights_prev[i, j] * returns[i, j]
        strat_ret[i] = ret_sum - (turnover[i] * transaction_cost)

    return strat_ret
33+
34+
35+
@cython.boundscheck(False)
@cython.wraparound(False)
def compute_turnover_cython(
    np.ndarray[DTYPE_t, ndim=2] weights,
    np.ndarray[DTYPE_t, ndim=2] weights_prev
):
    """Compute turnover (L1 change / 2) (Cython version).

    For every row ``i`` of ``weights`` the result is
    ``0.5 * sum_j(|weights[i, j] - weights_prev[i, j]|)``.

    Args:
        weights: 2-D float64 array; its shape defines the iteration space.
        weights_prev: 2-D float64 array covering at least the shape of
            ``weights``.

    Returns:
        1-D float64 array with one turnover value per row of ``weights``.

    Raises:
        ValueError: If ``weights_prev`` is smaller than ``weights``.
    """
    cdef Py_ssize_t n_days = weights.shape[0]
    cdef Py_ssize_t n_assets = weights.shape[1]
    cdef np.ndarray[DTYPE_t, ndim=1] turnover
    cdef Py_ssize_t i, j
    cdef DTYPE_t total_change

    # Bounds checking is switched off for this function, so an undersized
    # weights_prev would be read past its end instead of raising.
    if weights_prev.shape[0] < n_days or weights_prev.shape[1] < n_assets:
        raise ValueError("weights_prev must cover the shape of weights")

    turnover = np.zeros(n_days, dtype=DTYPE)

    for i in range(n_days):
        total_change = 0.0
        for j in range(n_assets):
            total_change += abs(weights[i, j] - weights_prev[i, j])
        turnover[i] = total_change * 0.5

    return turnover
55+
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
"""Numba-accelerated backtest operations."""
2+
3+
import numpy as np
4+
5+
try:
    from numba import jit, prange

    NUMBA_AVAILABLE = True
except ImportError:
    NUMBA_AVAILABLE = False

    def jit(*args, **kwargs):
        """No-op stand-in for ``numba.jit`` used when Numba is not installed.

        Supports both decorator spellings: ``@jit`` (the function arrives as
        the only positional argument) and ``@jit(...)`` (options first, the
        function is passed to the returned decorator). In either case the
        original function is returned unchanged.
        """
        # Bare ``@jit``: hand the function straight back instead of returning
        # the inner decorator in its place.
        if len(args) == 1 and not kwargs and callable(args[0]):
            return args[0]

        def decorator(func):
            return func

        return decorator

    # Without Numba there is nothing to parallelise; plain range is equivalent.
    prange = range
19+
20+
21+
@jit(nopython=True, cache=True)
def compute_strategy_returns(
    weights_prev: np.ndarray,
    returns: np.ndarray,
    turnover: np.ndarray,
    transaction_cost: float,
) -> np.ndarray:
    """Compute strategy returns with transaction costs.

    For every row ``i`` of ``returns`` the result is
    ``sum_j(weights_prev[i, j] * returns[i, j]) - turnover[i] * transaction_cost``.

    Args:
        weights_prev: 2-D array covering at least the shape of ``returns``.
        returns: 2-D array; its shape defines the iteration space.
        turnover: 1-D array with at least one entry per row of ``returns``.
        transaction_cost: Cost multiplier applied to each turnover entry.

    Returns:
        1-D array with one net return per row of ``returns``.

    Raises:
        ValueError: If ``weights_prev`` or ``turnover`` is too small for the
            iteration space.
    """
    n_days, n_assets = returns.shape

    # Numba does not bounds-check array indexing by default, so an undersized
    # input would silently read past its end. Reject it explicitly.
    if weights_prev.shape[0] < n_days or weights_prev.shape[1] < n_assets:
        raise ValueError("weights_prev must cover the shape of returns")
    if turnover.shape[0] < n_days:
        raise ValueError("turnover must have one entry per row of returns")

    strat_ret = np.zeros(n_days)

    # Rows are independent, so the outer loop keeps prange. The inner loop is
    # a sequential accumulation and uses plain range.
    for i in prange(n_days):
        ret_sum = 0.0
        for j in range(n_assets):
            ret_sum += weights_prev[i, j] * returns[i, j]
        strat_ret[i] = ret_sum - (turnover[i] * transaction_cost)

    return strat_ret
39+
40+
41+
@jit(nopython=True, cache=True)
def compute_turnover(weights: np.ndarray, weights_prev: np.ndarray) -> np.ndarray:
    """Compute turnover (L1 change / 2).

    For every row ``i`` of ``weights`` the result is
    ``0.5 * sum_j(|weights[i, j] - weights_prev[i, j]|)``.

    Args:
        weights: 2-D array; its shape defines the iteration space.
        weights_prev: 2-D array covering at least the shape of ``weights``.

    Returns:
        1-D array with one turnover value per row of ``weights``.

    Raises:
        ValueError: If ``weights_prev`` is smaller than ``weights``.
    """
    n_days, n_assets = weights.shape

    # Numba does not bounds-check array indexing by default, so an undersized
    # weights_prev would silently be read past its end.
    if weights_prev.shape[0] < n_days or weights_prev.shape[1] < n_assets:
        raise ValueError("weights_prev must cover the shape of weights")

    turnover = np.zeros(n_days)

    # Rows are independent (prange); the per-row sum is sequential (range).
    for i in prange(n_days):
        total_change = 0.0
        for j in range(n_assets):
            total_change += abs(weights[i, j] - weights_prev[i, j])
        turnover[i] = total_change * 0.5

    return turnover
54+
55+
56+
@jit(nopython=True, cache=True)
def compute_portfolio_value(
    strategy_returns: np.ndarray, initial_capital: float
) -> np.ndarray:
    """Compute cumulative portfolio value.

    Starting from ``initial_capital``, each period's value is the previous
    value multiplied by ``1 + strategy_returns[i]``.

    Args:
        strategy_returns: 1-D sequence of per-period returns.
        initial_capital: Portfolio value before the first period.

    Returns:
        Array with one value per entry of ``strategy_returns``: the portfolio
        value after each period. The initial capital itself is not included.
    """
    n_days = len(strategy_returns)
    portfolio_value = np.zeros(n_days + 1)
    portfolio_value[0] = initial_capital

    # Every step reads the value written by the previous step, so this loop
    # must stay sequential: use range, never prange.
    for i in range(n_days):
        portfolio_value[i + 1] = portfolio_value[i] * (1.0 + strategy_returns[i])

    # Drop the seed slot that only held the initial capital.
    return portfolio_value[1:]
69+
70+
71+
@jit(nopython=True, cache=True)
def compute_returns_from_prices(prices: np.ndarray) -> np.ndarray:
    """Compute percentage returns from prices.

    ``returns[i, j] = (prices[i + 1, j] - prices[i, j]) / prices[i, j]``
    wherever ``prices[i, j] > 0``. Entries whose base price is not strictly
    positive (including NaN) are left at ``0.0``.

    Args:
        prices: 2-D array of prices, one row per period.

    Returns:
        2-D array with one row fewer than ``prices``. A ``prices`` array with
        zero or one row yields an array with zero rows.
    """
    n_days, n_assets = prices.shape
    # Clamp at zero so an input without rows produces an empty result rather
    # than asking np.zeros for a negative dimension.
    n_periods = max(n_days - 1, 0)
    returns = np.zeros((n_periods, n_assets))

    for i in prange(n_periods):
        for j in range(n_assets):
            base = prices[i, j]
            if base > 0:
                returns[i, j] = (prices[i + 1, j] - base) / base

    return returns
83+
84+
85+
@jit(nopython=True, cache=True)
def rank_based_weights(
    signals: np.ndarray, max_leverage: float, long_pct: float, short_pct: float
) -> np.ndarray:
    """Compute rank-based portfolio weights.

    Non-NaN signals are ranked in ascending order (rank 1 = smallest). Assets
    whose rank is at or above the long threshold share a total weight of
    ``+1`` equally; of the remaining assets, those at or below the short
    threshold share a total weight of ``-1`` equally. If the resulting gross
    exposure exceeds ``max_leverage``, all weights are scaled down to match.

    Args:
        signals: 1-D array of signal values; NaN entries receive weight 0.
        max_leverage: Upper bound on the sum of absolute weights.
        long_pct: Fraction of the valid-asset count that locates the long
            threshold rank (``int(n_valid * long_pct) + 1``).
        short_pct: Fraction of the valid-asset count that locates the short
            threshold rank (``int(n_valid * short_pct) + 1``).

    Returns:
        1-D array of weights with the same length as ``signals``.
    """
    n_assets = len(signals)
    weights = np.zeros(n_assets)

    # Positions of the non-NaN signals, in their original order.
    valid_indices = np.empty(n_assets, dtype=np.int64)
    n_valid = 0
    for i in range(n_assets):
        if not np.isnan(signals[i]):
            valid_indices[n_valid] = i
            n_valid += 1

    if n_valid == 0:
        return weights

    valid_values = np.empty(n_valid)
    for k in range(n_valid):
        valid_values[k] = signals[valid_indices[k]]

    # ranks[k] is the 1-based ascending rank of the k-th valid signal.
    order = np.argsort(valid_values)
    ranks = np.empty(n_valid)
    for k in range(n_valid):
        ranks[order[k]] = k + 1.0

    # Clamp both indices into [0, n_valid - 1]; a fraction >= 1 (or < 0)
    # must not index outside the rank table.
    long_idx = min(max(int(n_valid * long_pct), 0), n_valid - 1)
    short_idx = min(max(int(n_valid * short_pct), 0), n_valid - 1)

    # The ranks are a permutation of 1..n_valid, so the idx-th smallest rank
    # is simply idx + 1; no sort is needed to look it up.
    long_threshold = long_idx + 1.0
    short_threshold = short_idx + 1.0

    long_count = 0
    short_count = 0
    for k in range(n_valid):
        rank_val = ranks[k]
        # Long takes precedence when an asset satisfies both conditions.
        if rank_val >= long_threshold:
            weights[valid_indices[k]] = 1.0
            long_count += 1
        elif rank_val <= short_threshold:
            weights[valid_indices[k]] = -1.0
            short_count += 1

    # Spread each side's unit exposure equally across its members.
    if long_count > 0:
        long_weight = 1.0 / long_count
        for i in range(n_assets):
            if weights[i] > 0:
                weights[i] = long_weight
    if short_count > 0:
        short_weight = -1.0 / short_count
        for i in range(n_assets):
            if weights[i] < 0:
                weights[i] = short_weight

    total_leverage = 0.0
    for i in range(n_assets):
        total_leverage += abs(weights[i])

    if total_leverage > max_leverage and total_leverage > 0:
        scale = max_leverage / total_leverage
        for i in range(n_assets):
            weights[i] *= scale

    return weights
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"""Simple profiler to identify hotspots in backtest."""
2+
3+
import cProfile
4+
import pstats
5+
import sys
6+
from io import StringIO
7+
from pathlib import Path
8+
9+
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
10+
11+
from quant_research_starter.backtest.vectorized import VectorizedBacktest
12+
from quant_research_starter.data import SampleDataLoader
13+
14+
15+
def _top_functions_report(profiler, sort_key, limit=20):
    """Return the pstats table for *profiler*, sorted by *sort_key*."""
    buffer = StringIO()
    stats = pstats.Stats(profiler, stream=buffer)
    stats.sort_stats(sort_key)
    stats.print_stats(limit)
    return buffer.getvalue()


def profile_backtest():
    """Profile the backtest to identify hotspots.

    Loads the sample prices, derives signals from the 20-period percentage
    change (missing values filled with 0), runs a rank-weighted
    ``VectorizedBacktest`` under ``cProfile`` and prints two reports to
    stdout: the top 20 functions by cumulative time and by total time.
    """
    loader = SampleDataLoader()
    prices = loader.load_sample_prices()

    signals = prices.pct_change(20).fillna(0)

    # Only construction and execution of the backtest are profiled; data
    # loading and signal preparation above are deliberately excluded.
    profiler = cProfile.Profile()
    profiler.enable()

    backtest = VectorizedBacktest(
        prices=prices,
        signals=signals,
        initial_capital=1_000_000,
        transaction_cost=0.001,
    )
    backtest.run(weight_scheme="rank")

    profiler.disable()

    print("Top 20 functions by cumulative time:")
    print(_top_functions_report(profiler, "cumulative"))

    print("\nTop 20 functions by total time:")
    print(_top_functions_report(profiler, "tottime"))


if __name__ == "__main__":
    profile_backtest()
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""Setup script for Cython extensions."""
2+
3+
import numpy
4+
from Cython.Build import cythonize
5+
from setuptools import Extension, setup
6+
7+
# The single extension module, compiled from the Cython source that sits
# next to this script. NumPy's headers are needed for the `cimport numpy`.
cython_opt_extension = Extension(
    name="cython_opt",
    sources=["cython_opt.pyx"],
    include_dirs=[numpy.get_include()],
    extra_compile_args=["-O3"],
)

extensions = [cython_opt_extension]

# Translate the .pyx sources with Python 3 semantics, then build them.
cythonized_modules = cythonize(
    extensions,
    compiler_directives={"language_level": "3"},
)

setup(ext_modules=cythonized_modules)

0 commit comments

Comments
 (0)