Skip to content

Commit 793fec7

Browse files
authored
Merge pull request #61 from AKKI0511/conduct-deep-analysis-of-quanttradeai
Add data validation gate and reporting
2 parents 8592251 + 220414c commit 793fec7

File tree

9 files changed

+325
-85
lines changed

9 files changed

+325
-85
lines changed

docs/api/data.md

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -51,23 +51,25 @@ aapl_data = data['AAPL']
5151
print(f"AAPL data shape: {aapl_data.shape}")
5252
```
5353

54-
### `validate_data(data_dict: Dict[str, pd.DataFrame]) -> bool`
55-
56-
Validates the fetched data meets requirements.
57-
58-
**Parameters:**
59-
- `data_dict` (Dict[str, pd.DataFrame]): Dictionary of DataFrames with OHLCV data
60-
61-
**Returns:**
62-
- `bool`: True if data is valid, False otherwise
63-
64-
**Example:**
65-
```python
66-
# Validate fetched data
67-
is_valid = loader.validate_data(data_dict)
68-
if not is_valid:
69-
print("Data validation failed")
70-
```
54+
### `validate_data(data_dict: Dict[str, pd.DataFrame]) -> tuple[bool, dict]`
55+
56+
Validates the fetched data meets requirements and returns a per-symbol report with
57+
missing columns, date span, NaN ratios, and pass/fail flags.
58+
59+
**Parameters:**
60+
- `data_dict` (Dict[str, pd.DataFrame]): Dictionary of DataFrames with OHLCV data
61+
62+
**Returns:**
63+
- `Tuple[bool, dict]`: Overall validity flag and a detailed report
64+
65+
**Example:**
66+
```python
67+
# Validate fetched data
68+
is_valid, report = loader.validate_data(data_dict)
69+
if not is_valid:
70+
print("Data validation failed")
71+
print(report)
72+
```
7173

7274
### `save_data(data_dict: Dict[str, pd.DataFrame], path: Optional[str] = None)`
7375

@@ -282,14 +284,15 @@ df = df.fillna(method='ffill')
282284

283285
### Validation Errors
284286
```python
285-
try:
286-
# Validate data
287-
is_valid = loader.validate_data(data_dict)
288-
if not is_valid:
289-
print("Data validation failed")
290-
except Exception as e:
291-
print(f"Validation error: {e}")
292-
```
287+
try:
288+
# Validate data
289+
is_valid, report = loader.validate_data(data_dict)
290+
if not is_valid:
291+
print("Data validation failed")
292+
print(report)
293+
except Exception as e:
294+
print(f"Validation error: {e}")
295+
```
293296

294297
## Related Documentation
295298

docs/quick-reference.md

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@ poetry run quanttradeai --help
1212
poetry run quanttradeai fetch-data
1313
poetry run quanttradeai fetch-data --refresh
1414

15-
# Train models
16-
poetry run quanttradeai train
15+
# Train models
16+
poetry run quanttradeai train
17+
poetry run quanttradeai train --skip-validation # bypass data-quality gate
1718

1819
# Evaluate model
1920
poetry run quanttradeai evaluate -m models/trained/AAPL
@@ -25,7 +26,8 @@ poetry run quanttradeai backtest --cost-bps 5 --slippage-bps 10
2526
# Backtest a saved model (end-to-end)
2627
poetry run quanttradeai backtest-model -m models/experiments/<timestamp>/<SYMBOL> \
2728
-c config/model_config.yaml -b config/backtest_config.yaml \
28-
--cost-bps 5 --slippage-fixed 0.01 --liquidity-max-participation 0.25
29+
--cost-bps 5 --slippage-fixed 0.01 --liquidity-max-participation 0.25 \
30+
--skip-validation # optional
2931
```
3032

3133
## 📊 Python API Patterns
@@ -38,9 +40,13 @@ from quanttradeai import DataLoader
3840
loader = DataLoader("config/model_config.yaml")
3941
data = loader.fetch_data(symbols=['AAPL', 'TSLA'], refresh=True)
4042

41-
# Validate data
42-
is_valid = loader.validate_data(data)
43-
```
43+
# Validate data and inspect per-symbol report
44+
is_valid, report = loader.validate_data(data)
45+
if not is_valid:
46+
print(report)
47+
# CLI commands write this report to models/experiments/<timestamp>/validation.json
48+
# or reports/backtests/<timestamp>/validation.json
49+
```
4450

4551
### Feature Engineering
4652
```python

quanttradeai/cli.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,16 @@ def cmd_fetch_data(
4444
def cmd_train(
4545
config: str = typer.Option(
4646
"config/model_config.yaml", "-c", "--config", help="Path to config file"
47-
)
47+
),
48+
skip_validation: bool = typer.Option(
49+
False,
50+
"--skip-validation",
51+
help="Skip data-quality validation before training",
52+
),
4853
):
4954
"""Run full training pipeline."""
5055

51-
run_pipeline(config)
56+
run_pipeline(config, skip_validation=skip_validation)
5257

5358

5459
@app.command("evaluate")
@@ -57,10 +62,15 @@ def cmd_evaluate(
5762
config: str = typer.Option(
5863
"config/model_config.yaml", "-c", "--config", help="Path to config file"
5964
),
65+
skip_validation: bool = typer.Option(
66+
False,
67+
"--skip-validation",
68+
help="Skip data-quality validation before evaluation",
69+
),
6070
):
6171
"""Evaluate a saved model on current dataset."""
6272

63-
evaluate_model(config, model_path)
73+
evaluate_model(config, model_path, skip_validation=skip_validation)
6474

6575

6676
@app.command("backtest")
@@ -133,6 +143,11 @@ def cmd_backtest_model(
133143
liquidity_max_participation: Optional[float] = typer.Option(
134144
None, help="Liquidity max participation"
135145
),
146+
skip_validation: bool = typer.Option(
147+
False,
148+
"--skip-validation",
149+
help="Skip data-quality validation before backtesting",
150+
),
136151
):
137152
"""Backtest a saved model on the configured test window with execution costs."""
138153

@@ -146,6 +161,7 @@ def cmd_backtest_model(
146161
slippage_bps=slippage_bps,
147162
slippage_fixed=slippage_fixed,
148163
liquidity_max_participation=liquidity_max_participation,
164+
skip_validation=skip_validation,
149165
)
150166
typer.echo(json.dumps(summary, indent=2))
151167

quanttradeai/data/loader.py

Lines changed: 52 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -349,37 +349,58 @@ def _resample_secondary(
349349
resampled = resampled.reindex(target_index)
350350
return resampled
351351

352-
def validate_data(self, data_dict: Dict[str, pd.DataFrame]) -> bool:
353-
"""
354-
Validate the fetched data meets requirements.
355-
356-
Args:
357-
data_dict: Dictionary of DataFrames with OHLCV data.
358-
359-
Returns:
360-
bool: True if data is valid, False otherwise.
361-
"""
362-
required_columns = ["Open", "High", "Low", "Close", "Volume"]
363-
364-
for symbol, df in data_dict.items():
365-
# Check required columns
366-
if not all(col in df.columns for col in required_columns):
367-
logger.error(f"Missing required columns for {symbol}")
368-
return False
369-
370-
# Check data range
371-
date_range = (df.index.max() - df.index.min()).days
372-
if date_range < 365: # At least one year of data
373-
logger.error(f"Insufficient data range for {symbol}")
374-
return False
375-
376-
# Check for excessive missing values
377-
# Check missing value ratio per column
378-
if df.isnull().mean().max() > 0.01: # Max 1% missing values
379-
logger.error(f"Too many missing values for {symbol}")
380-
return False
381-
382-
return True
352+
def validate_data(self, data_dict: Dict[str, pd.DataFrame]) -> tuple[bool, dict]:
353+
"""
354+
Validate the fetched data meets requirements and return a detailed report.
355+
356+
Args:
357+
data_dict: Dictionary of DataFrames with OHLCV data.
358+
359+
Returns:
360+
Tuple[bool, dict]: Overall validity flag and a per-symbol report with
361+
missing column checks, date span, NaN ratios, and pass/fail status.
362+
"""
363+
required_columns = ["Open", "High", "Low", "Close", "Volume"]
364+
365+
overall_valid = True
366+
report: Dict[str, dict] = {}
367+
368+
for symbol, df in data_dict.items():
369+
missing_columns = [col for col in required_columns if col not in df.columns]
370+
date_span_days = int((df.index.max() - df.index.min()).days) if not df.empty else 0
371+
nan_ratio_by_column = {
372+
col: float(df[col].isnull().mean()) for col in df.columns
373+
}
374+
nan_ratio_required_columns = {
375+
col: nan_ratio_by_column[col]
376+
for col in required_columns
377+
if col in nan_ratio_by_column
378+
}
379+
max_nan_ratio = max(nan_ratio_required_columns.values(), default=0.0)
380+
381+
errors = []
382+
if missing_columns:
383+
errors.append(f"Missing required columns: {', '.join(missing_columns)}")
384+
if date_span_days < 365:
385+
errors.append("Insufficient data range (<365 days)")
386+
if max_nan_ratio > 0.01:
387+
errors.append(f"Too many missing values (max ratio={max_nan_ratio:.4f})")
388+
389+
passed = len(errors) == 0
390+
if not passed:
391+
overall_valid = False
392+
logger.error("Data validation failed for %s: %s", symbol, "; ".join(errors))
393+
394+
report[symbol] = {
395+
"missing_columns": missing_columns,
396+
"date_span_days": date_span_days,
397+
"nan_ratio_by_column": nan_ratio_by_column,
398+
"max_nan_ratio": max_nan_ratio,
399+
"passed": passed,
400+
"errors": errors,
401+
}
402+
403+
return overall_valid, report
383404

384405
def save_data(
385406
self, data_dict: Dict[str, pd.DataFrame], path: Optional[str] = None

quanttradeai/main.py

Lines changed: 88 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,26 @@ def setup_directories(cache_dir: str = "data/raw"):
5656
Path(dir_path).mkdir(parents=True, exist_ok=True)
5757

5858

59+
def _write_validation_report(report_path: Path, report: dict) -> None:
60+
"""Persist validation results to JSON and CSV formats."""
61+
62+
report_path.parent.mkdir(parents=True, exist_ok=True)
63+
with open(report_path, "w") as f:
64+
json.dump(report, f, indent=2, default=float)
65+
66+
try:
67+
import pandas as pd
68+
69+
rows = []
70+
for symbol, details in report.items():
71+
row = {"symbol": symbol, **details}
72+
rows.append(row)
73+
df = pd.DataFrame(rows)
74+
df.to_csv(report_path.with_suffix(".csv"), index=False)
75+
except Exception as exc: # pragma: no cover - defensive
76+
logger.warning("Failed to write CSV validation report: %s", exc)
77+
78+
5979
def _ensure_datetime_index(df: pd.DataFrame) -> pd.DataFrame:
6080
"""Ensure the DataFrame index is a DatetimeIndex.
6181
@@ -168,7 +188,46 @@ def _window_has_full_coverage(
168188
return train_df, test_df
169189

170190

171-
def run_pipeline(config_path: str = "config/model_config.yaml"):
191+
def _validate_or_raise(
192+
*,
193+
loader: DataLoader,
194+
data: dict,
195+
report_path: Path,
196+
skip_validation: bool,
197+
) -> dict:
198+
"""Validate fetched data and persist a report unless skipped."""
199+
200+
if skip_validation:
201+
logger.warning(
202+
"Skipping data validation as requested; downstream results may be unreliable."
203+
)
204+
return {}
205+
206+
validation_result = loader.validate_data(data)
207+
if isinstance(validation_result, tuple) and len(validation_result) == 2:
208+
is_valid, report = validation_result
209+
else:
210+
is_valid = bool(validation_result)
211+
report = {}
212+
_write_validation_report(report_path, report)
213+
214+
if not is_valid:
215+
failed = [symbol for symbol, res in report.items() if not res.get("passed")]
216+
raise ValueError(
217+
f"Data validation failed for symbols: {', '.join(failed) if failed else 'unknown'}"
218+
)
219+
220+
logger.info(
221+
"Data validation passed for %d symbols. Report saved to %s",
222+
len(report),
223+
report_path,
224+
)
225+
return report
226+
227+
228+
def run_pipeline(
229+
config_path: str = "config/model_config.yaml", *, skip_validation: bool = False
230+
):
172231
"""Run the end-to-end training pipeline.
173232
174233
Loads data, generates features and labels, tunes hyperparameters,
@@ -205,6 +264,15 @@ def run_pipeline(config_path: str = "config/model_config.yaml"):
205264
refresh = config.get("data", {}).get("refresh", False)
206265
data_dict = data_loader.fetch_data(refresh=refresh)
207266

267+
# Validate data quality
268+
validation_path = Path(experiment_dir) / "validation.json"
269+
_validate_or_raise(
270+
loader=data_loader,
271+
data=data_dict,
272+
report_path=validation_path,
273+
skip_validation=skip_validation,
274+
)
275+
208276
# Process each stock
209277
results = {}
210278
for symbol, df in data_dict.items():
@@ -284,7 +352,9 @@ def fetch_data_only(config_path: str, refresh: bool = False) -> None:
284352
data_loader.save_data(data)
285353

286354

287-
def evaluate_model(config_path: str, model_path: str) -> None:
355+
def evaluate_model(
356+
config_path: str, model_path: str, *, skip_validation: bool = False
357+
) -> None:
288358
"""Evaluate a persisted model on current config’s dataset.
289359
290360
Example
@@ -297,6 +367,13 @@ def evaluate_model(config_path: str, model_path: str) -> None:
297367
model.load_model(model_path)
298368

299369
data_dict = data_loader.fetch_data()
370+
validation_path = Path(model_path) / "validation.json"
371+
_validate_or_raise(
372+
loader=data_loader,
373+
data=data_dict,
374+
report_path=validation_path,
375+
skip_validation=skip_validation,
376+
)
300377
results = {}
301378
for symbol, df in data_dict.items():
302379
df_processed = data_processor.process_data(df)
@@ -382,6 +459,7 @@ def run_model_backtest(
382459
slippage_bps: float | None = None,
383460
slippage_fixed: float | None = None,
384461
liquidity_max_participation: float | None = None,
462+
skip_validation: bool = False,
385463
) -> dict:
386464
"""Backtest a saved model’s predictions on the configured test window.
387465
@@ -432,8 +510,15 @@ def run_model_backtest(
432510
risk_path,
433511
)
434512

435-
data_dict = loader.fetch_data()
436513
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
514+
data_dict = loader.fetch_data()
515+
validation_path = Path(f"reports/backtests/{timestamp}") / "validation.json"
516+
_validate_or_raise(
517+
loader=loader,
518+
data=data_dict,
519+
report_path=validation_path,
520+
skip_validation=skip_validation,
521+
)
437522
base_dir = Path(f"reports/backtests/{timestamp}")
438523
base_dir.mkdir(parents=True, exist_ok=True)
439524

0 commit comments

Comments
 (0)