Skip to content

Commit d110000

Browse files
authored
Merge pull request #66 from AKKI0511/analyze-quanttradeai-for-end-to-end-feature
feat: add test-window coverage reporting
2 parents 2852274 + dd3c66a commit d110000

File tree

5 files changed: +189 additions, −13 deletions

quanttradeai/cli.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,20 @@ def cmd_train(
5353
),
5454
):
5555
"""Run full training pipeline."""
56+
pipeline_result = run_pipeline(
57+
config, skip_validation=skip_validation, include_coverage=True
58+
)
59+
coverage_info = None
60+
if isinstance(pipeline_result, tuple) and len(pipeline_result) == 2:
61+
_, coverage_info = pipeline_result
5662

57-
run_pipeline(config, skip_validation=skip_validation)
63+
if coverage_info:
64+
fallback = coverage_info.get("fallback_symbols") or []
65+
path = coverage_info.get("path")
66+
summary = f"Test-window coverage report saved to {path}."
67+
if fallback:
68+
summary += " Fallback chronological split used for: " + ", ".join(fallback)
69+
typer.echo(summary, err=True)
5870

5971

6072
@app.command("evaluate")
@@ -164,6 +176,16 @@ def cmd_backtest_model(
164176
liquidity_max_participation=liquidity_max_participation,
165177
skip_validation=skip_validation,
166178
)
179+
coverage_info = (
180+
summary.get("coverage_report") if isinstance(summary, dict) else None
181+
)
182+
if coverage_info:
183+
fallback = coverage_info.get("fallback_symbols") or []
184+
path = coverage_info.get("path")
185+
message = f"Test-window coverage report saved to {path}."
186+
if fallback:
187+
message += " Fallback chronological split used for: " + ", ".join(fallback)
188+
typer.echo(message, err=True)
167189
typer.echo(json.dumps(summary, indent=2))
168190

169191

quanttradeai/main.py

Lines changed: 97 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,23 @@ def _write_validation_report(report_path: Path, report: dict) -> None:
7575
logger.warning("Failed to write CSV validation report: %s", exc)
7676

7777

78+
def _write_coverage_report(report_path: Path, coverage: dict) -> None:
79+
"""Persist test-window coverage results to JSON and CSV formats."""
80+
81+
report_path.parent.mkdir(parents=True, exist_ok=True)
82+
with open(report_path, "w") as fh:
83+
json.dump(coverage, fh, indent=2)
84+
85+
try:
86+
rows = []
87+
for symbol, details in coverage.items():
88+
row = {"symbol": symbol, **details}
89+
rows.append(row)
90+
pd.DataFrame(rows).to_csv(report_path.with_suffix(".csv"), index=False)
91+
except Exception as exc: # pragma: no cover - defensive
92+
logger.warning("Failed to write CSV coverage report: %s", exc)
93+
94+
7895
def _ensure_datetime_index(df: pd.DataFrame) -> pd.DataFrame:
7996
"""Ensure the DataFrame index is a DatetimeIndex.
8097
@@ -103,7 +120,7 @@ def _ensure_datetime_index(df: pd.DataFrame) -> pd.DataFrame:
103120
def time_aware_split(
104121
df_labeled: pd.DataFrame,
105122
cfg: dict,
106-
) -> Tuple[pd.DataFrame, pd.DataFrame]:
123+
) -> Tuple[pd.DataFrame, pd.DataFrame, dict]:
107124
"""Return chronological train/test splits using config windows.
108125
109126
Rules
@@ -150,6 +167,10 @@ def _window_has_full_coverage(
150167
return True
151168

152169
fallback_used = False
170+
coverage_ok: bool | None = True if test_start else None
171+
data_start = df.index.min()
172+
data_end = df.index.max()
173+
split_strategy = "window" if test_start else "fraction"
153174
if test_start:
154175
start_dt = pd.to_datetime(test_start)
155176
end_dt = pd.to_datetime(test_end) if test_end else None
@@ -184,7 +205,22 @@ def _window_has_full_coverage(
184205
raise ValueError(
185206
f"Invalid train/test window produced empty split when using {window_msg}. Adjust data.test_* or training.test_size."
186207
)
187-
return train_df, test_df
208+
coverage: dict = {
209+
"data_start": data_start.isoformat(),
210+
"data_end": data_end.isoformat(),
211+
"test_start": pd.to_datetime(test_start).isoformat() if test_start else None,
212+
"test_end": pd.to_datetime(test_end).isoformat() if test_end else None,
213+
"train_start": train_df.index.min().isoformat(),
214+
"train_end": train_df.index.max().isoformat(),
215+
"test_start_actual": test_df.index.min().isoformat(),
216+
"test_end_actual": test_df.index.max().isoformat(),
217+
"train_size": len(train_df),
218+
"test_size": len(test_df),
219+
"coverage_ok": coverage_ok,
220+
"fallback_used": fallback_used,
221+
"split_strategy": split_strategy if not fallback_used else "fraction_fallback",
222+
}
223+
return train_df, test_df, coverage
188224

189225

190226
def _validate_or_raise(
@@ -225,7 +261,10 @@ def _validate_or_raise(
225261

226262

227263
def run_pipeline(
228-
config_path: str = "config/model_config.yaml", *, skip_validation: bool = False
264+
config_path: str = "config/model_config.yaml",
265+
*,
266+
skip_validation: bool = False,
267+
include_coverage: bool = False,
229268
):
230269
"""Run the end-to-end training pipeline.
231270
@@ -238,6 +277,12 @@ def run_pipeline(
238277
>>> results = run_pipeline("config/model_config.yaml")
239278
>>> sorted(results.keys()) # doctest: +ELLIPSIS
240279
...
280+
281+
Set ``include_coverage=True`` to also receive coverage metadata:
282+
283+
>>> results, coverage = run_pipeline("config/model_config.yaml", include_coverage=True)
284+
>>> sorted(coverage.keys())
285+
['fallback_symbols', 'path']
241286
"""
242287

243288
# Load configuration
@@ -274,6 +319,7 @@ def run_pipeline(
274319

275320
# Process each stock
276321
results = {}
322+
coverage_report: dict[str, dict] = {}
277323
for symbol, df in data_dict.items():
278324
logger.info(f"\nProcessing {symbol}...")
279325

@@ -284,7 +330,8 @@ def run_pipeline(
284330
df_labeled = data_processor.generate_labels(df_processed)
285331

286332
# 4. Time-aware Split
287-
train_df, test_df = time_aware_split(df_labeled, config)
333+
train_df, test_df, coverage = time_aware_split(df_labeled, config)
334+
coverage_report[symbol] = coverage
288335
X_train, y_train = model.prepare_data(train_df)
289336
X_test, y_test = model.prepare_data(test_df)
290337
# Log split summary
@@ -327,11 +374,35 @@ def run_pipeline(
327374
logger.info(f"Train Metrics: {train_metrics}")
328375
logger.info(f"Test Metrics: {test_metrics}")
329376

377+
coverage_path = Path(experiment_dir) / "test_window_coverage.json"
378+
_write_coverage_report(coverage_path, coverage_report)
379+
fallback_symbols = [
380+
symbol
381+
for symbol, details in coverage_report.items()
382+
if details.get("fallback_used")
383+
]
384+
if fallback_symbols:
385+
logger.warning(
386+
"Fallback chronological split applied for symbols: %s. Coverage report: %s",
387+
", ".join(fallback_symbols),
388+
coverage_path,
389+
)
390+
logger.info("Coverage report saved to %s", coverage_path)
391+
330392
# Save experiment results
331393
with open(f"{experiment_dir}/results.json", "w") as f:
332394
json.dump(results, f, indent=4)
333395

334396
logger.info("\nPipeline completed successfully!")
397+
if include_coverage:
398+
return (
399+
results,
400+
{
401+
"path": coverage_path.as_posix(),
402+
"fallback_symbols": fallback_symbols,
403+
},
404+
)
405+
335406
return results
336407

337408
except Exception as e:
@@ -539,6 +610,7 @@ def run_model_backtest(
539610
summary: dict = {}
540611
prepared_data: dict[str, pd.DataFrame] = {}
541612
artifact_dirs: dict[str, Path] = {}
613+
coverage_report: dict[str, dict] = {}
542614

543615
trading_cfg = (cfg or {}).get("trading", {})
544616
stop_loss = trading_cfg.get("stop_loss")
@@ -595,7 +667,8 @@ def _execution_for(symbol: str) -> dict:
595667
try:
596668
df_proc = processor.process_data(df)
597669
df_lbl = processor.generate_labels(df_proc)
598-
train_df, test_df = time_aware_split(df_lbl, cfg)
670+
train_df, test_df, coverage = time_aware_split(df_lbl, cfg)
671+
coverage_report[symbol] = coverage
599672
# Build features from saved order
600673
missing = [
601674
c for c in (clf.feature_columns or []) if c not in test_df.columns
@@ -676,6 +749,25 @@ def _execution_for(symbol: str) -> dict:
676749
"output_dir": out_dir.as_posix(),
677750
}
678751

752+
coverage_path = base_dir / "test_window_coverage.json"
753+
_write_coverage_report(coverage_path, coverage_report)
754+
fallback_symbols = [
755+
symbol
756+
for symbol, details in coverage_report.items()
757+
if details.get("fallback_used")
758+
]
759+
if fallback_symbols:
760+
logger.warning(
761+
"Fallback chronological split applied for symbols: %s. Coverage report: %s",
762+
", ".join(fallback_symbols),
763+
coverage_path,
764+
)
765+
logger.info("Coverage report saved to %s", coverage_path)
766+
summary["coverage_report"] = {
767+
"path": coverage_path.as_posix(),
768+
"fallback_symbols": fallback_symbols,
769+
}
770+
679771
return summary
680772

681773

tests/data/test_validation_gate.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,11 @@ def test_run_pipeline_skip_validation_allows_progression(
8181
)
8282
monkeypatch.setattr(MomentumClassifier, "save_model", lambda self, path: None)
8383

84-
results = run_pipeline(sample_config_path, skip_validation=True)
84+
results, coverage_info = run_pipeline(
85+
sample_config_path, skip_validation=True, include_coverage=True
86+
)
8587
assert "AAPL" in results
88+
assert coverage_info["path"].endswith("test_window_coverage.json")
8689

8790

8891
def test_run_model_backtest_validates_before_execution(

tests/integration/test_pipeline_time_splits.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,30 +14,36 @@ def test_time_aware_split_with_window():
1414
idx = pd.date_range("2024-01-01", periods=10, freq="D")
1515
df = pd.DataFrame({"Close": range(10)}, index=idx)
1616
cfg = {"data": {"test_start": "2024-01-06", "test_end": "2024-01-08"}}
17-
train, test = time_aware_split(df, cfg)
17+
train, test, coverage = time_aware_split(df, cfg)
1818
assert train.index.max() < pd.to_datetime("2024-01-06")
1919
assert test.index.min() == pd.to_datetime("2024-01-06")
2020
assert test.index.max() == pd.to_datetime("2024-01-08")
2121
assert len(train) == 5 and len(test) == 3
22+
assert coverage["coverage_ok"] is True
23+
assert coverage["fallback_used"] is False
2224

2325

2426
def test_time_aware_split_with_start_only():
2527
idx = pd.date_range("2024-01-01", periods=10, freq="D")
2628
df = pd.DataFrame({"Close": range(10)}, index=idx)
2729
cfg = {"data": {"test_start": "2024-01-06"}}
28-
train, test = time_aware_split(df, cfg)
30+
train, test, coverage = time_aware_split(df, cfg)
2931
assert train.index.max() < pd.to_datetime("2024-01-06")
3032
assert test.index.min() == pd.to_datetime("2024-01-06")
3133
assert len(train) == 5 and len(test) == 5
34+
assert coverage["coverage_ok"] is True
35+
assert coverage["fallback_used"] is False
3236

3337

3438
def test_time_aware_split_fallback_fraction():
3539
idx = pd.date_range("2024-01-01", periods=10, freq="D")
3640
df = pd.DataFrame({"Close": range(10)}, index=idx)
3741
cfg = {"training": {"test_size": 0.2}}
38-
train, test = time_aware_split(df, cfg)
42+
train, test, coverage = time_aware_split(df, cfg)
3943
assert len(train) == 8 and len(test) == 2
4044
assert train.index.max() < test.index.min()
45+
assert coverage["test_start"] is None
46+
assert coverage["coverage_ok"] is None
4147

4248

4349
def test_time_aware_split_warns_and_falls_back(caplog):
@@ -49,10 +55,12 @@ def test_time_aware_split_warns_and_falls_back(caplog):
4955
}
5056

5157
with caplog.at_level(logging.WARNING):
52-
train, test = time_aware_split(df, cfg)
58+
train, test, coverage = time_aware_split(df, cfg)
5359

5460
assert len(train) == 3 and len(test) == 2
5561
assert "falling back to chronological split" in caplog.text
62+
assert coverage["coverage_ok"] is False
63+
assert coverage["fallback_used"] is True
5664

5765

5866
def test_time_aware_split_warns_on_partial_window(caplog):
@@ -64,11 +72,32 @@ def test_time_aware_split_warns_on_partial_window(caplog):
6472
}
6573

6674
with caplog.at_level(logging.WARNING):
67-
train, test = time_aware_split(df, cfg)
75+
train, test, coverage = time_aware_split(df, cfg)
6876

6977
assert len(train) == 6 and len(test) == 2
7078
assert train.index.max() < test.index.min()
7179
assert "not fully present in data; falling back" in caplog.text
80+
assert coverage["coverage_ok"] is False
81+
assert coverage["fallback_used"] is True
82+
83+
84+
def test_time_aware_split_reports_coverage_fields():
85+
idx = pd.date_range("2024-01-01", periods=6, freq="D")
86+
df = pd.DataFrame({"Close": range(6)}, index=idx)
87+
cfg = {
88+
"data": {"test_start": "2024-01-05", "test_end": "2024-01-10"},
89+
"training": {"test_size": 0.5},
90+
}
91+
92+
train, test, coverage = time_aware_split(df, cfg)
93+
94+
assert coverage["data_start"].startswith("2024-01-01")
95+
assert coverage["data_end"].startswith("2024-01-06")
96+
assert coverage["test_start"].startswith("2024-01-05")
97+
assert coverage["test_end"].startswith("2024-01-10")
98+
assert coverage["train_size"] == len(train)
99+
assert coverage["test_size"] == len(test)
100+
assert coverage["split_strategy"] == "fraction_fallback"
72101

73102

74103
def test_model_config_rejects_out_of_range_test_window():
@@ -159,9 +188,12 @@ def prepare_data(df):
159188
model_instance.evaluate.return_value = {"accuracy": 1.0}
160189
model_instance.save_model.return_value = None
161190

162-
results = run_pipeline(str(config_path))
191+
results, coverage_info = run_pipeline(
192+
str(config_path), include_coverage=True
193+
)
163194

164195
mock_loader.assert_called_once_with(str(config_path))
165196
assert "AAA" in results
166197
assert "hyperparameters" in results["AAA"]
198+
assert coverage_info["path"].endswith("test_window_coverage.json")
167199

tests/streaming/test_gateway.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import tempfile
44
import time
5+
import warnings
56
from typing import Awaitable, Callable, Dict, List, Optional
67
from unittest.mock import AsyncMock, patch
78

@@ -12,6 +13,32 @@
1213
import pytest
1314
from quanttradeai.streaming import StreamingGateway
1415

16+
pytestmark = pytest.mark.filterwarnings(
17+
"ignore:.*websockets.*:DeprecationWarning",
18+
"ignore:.*WebSocketServerProtocol is deprecated:DeprecationWarning",
19+
)
20+
21+
warnings.filterwarnings(
22+
"ignore",
23+
message=r"websockets\.legacy is deprecated",
24+
category=DeprecationWarning,
25+
)
26+
warnings.filterwarnings(
27+
"ignore",
28+
message=r"websockets\.server\.WebSocketServerProtocol is deprecated",
29+
category=DeprecationWarning,
30+
)
31+
warnings.filterwarnings(
32+
"ignore",
33+
category=DeprecationWarning,
34+
module=r"websockets.*",
35+
)
36+
warnings.filterwarnings(
37+
"ignore",
38+
category=DeprecationWarning,
39+
module=r"uvicorn\.protocols\.websockets.*",
40+
)
41+
1542

1643
class StubProviderMonitor:
1744
def __init__(

0 commit comments

Comments (0)