Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 28 additions & 25 deletions docs/api/data.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,23 +51,25 @@ aapl_data = data['AAPL']
print(f"AAPL data shape: {aapl_data.shape}")
```

### `validate_data(data_dict: Dict[str, pd.DataFrame]) -> bool`

Validates the fetched data meets requirements.

**Parameters:**
- `data_dict` (Dict[str, pd.DataFrame]): Dictionary of DataFrames with OHLCV data

**Returns:**
- `bool`: True if data is valid, False otherwise

**Example:**
```python
# Validate fetched data
is_valid = loader.validate_data(data_dict)
if not is_valid:
print("Data validation failed")
```
### `validate_data(data_dict: Dict[str, pd.DataFrame]) -> tuple[bool, dict]`

Validates the fetched data meets requirements and returns a per-symbol report with
missing columns, date span, NaN ratios, and pass/fail flags.

**Parameters:**
- `data_dict` (Dict[str, pd.DataFrame]): Dictionary of DataFrames with OHLCV data

**Returns:**
- `Tuple[bool, dict]`: Overall validity flag and a detailed report

**Example:**
```python
# Validate fetched data
is_valid, report = loader.validate_data(data_dict)
if not is_valid:
print("Data validation failed")
print(report)
```

### `save_data(data_dict: Dict[str, pd.DataFrame], path: Optional[str] = None)`

Expand Down Expand Up @@ -282,14 +284,15 @@ df = df.fillna(method='ffill')

### Validation Errors
```python
try:
# Validate data
is_valid = loader.validate_data(data_dict)
if not is_valid:
print("Data validation failed")
except Exception as e:
print(f"Validation error: {e}")
```
try:
# Validate data
is_valid, report = loader.validate_data(data_dict)
if not is_valid:
print("Data validation failed")
print(report)
except Exception as e:
print(f"Validation error: {e}")
```

## Related Documentation

Expand Down
18 changes: 12 additions & 6 deletions docs/quick-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ poetry run quanttradeai --help
poetry run quanttradeai fetch-data
poetry run quanttradeai fetch-data --refresh

# Train models
poetry run quanttradeai train
# Train models
poetry run quanttradeai train
poetry run quanttradeai train --skip-validation # bypass data-quality gate

# Evaluate model
poetry run quanttradeai evaluate -m models/trained/AAPL
Expand All @@ -25,7 +26,8 @@ poetry run quanttradeai backtest --cost-bps 5 --slippage-bps 10
# Backtest a saved model (end-to-end)
poetry run quanttradeai backtest-model -m models/experiments/<timestamp>/<SYMBOL> \
-c config/model_config.yaml -b config/backtest_config.yaml \
--cost-bps 5 --slippage-fixed 0.01 --liquidity-max-participation 0.25
--cost-bps 5 --slippage-fixed 0.01 --liquidity-max-participation 0.25 \
--skip-validation # optional
```

## 📊 Python API Patterns
Expand All @@ -38,9 +40,13 @@ from quanttradeai import DataLoader
loader = DataLoader("config/model_config.yaml")
data = loader.fetch_data(symbols=['AAPL', 'TSLA'], refresh=True)

# Validate data
is_valid = loader.validate_data(data)
```
# Validate data and inspect per-symbol report
is_valid, report = loader.validate_data(data)
if not is_valid:
print(report)
# CLI commands write this report to models/experiments/<timestamp>/validation.json
# or reports/backtests/<timestamp>/validation.json
```

### Feature Engineering
```python
Expand Down
22 changes: 19 additions & 3 deletions quanttradeai/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,16 @@ def cmd_fetch_data(
def cmd_train(
    config: str = typer.Option(
        "config/model_config.yaml", "-c", "--config", help="Path to config file"
    ),
    skip_validation: bool = typer.Option(
        False,
        "--skip-validation",
        help="Skip data-quality validation before training",
    ),
):
    """Run full training pipeline.

    Data quality is validated before training unless ``--skip-validation``
    is given, in which case the gate is bypassed entirely.
    """
    # Single entry point into the pipeline; the stale positional-only call
    # left over from the previous revision has been removed.
    run_pipeline(config, skip_validation=skip_validation)


@app.command("evaluate")
Expand All @@ -57,10 +62,15 @@ def cmd_evaluate(
config: str = typer.Option(
"config/model_config.yaml", "-c", "--config", help="Path to config file"
),
skip_validation: bool = typer.Option(
False,
"--skip-validation",
help="Skip data-quality validation before evaluation",
),
):
"""Evaluate a saved model on current dataset."""

evaluate_model(config, model_path)
evaluate_model(config, model_path, skip_validation=skip_validation)


@app.command("backtest")
Expand Down Expand Up @@ -133,6 +143,11 @@ def cmd_backtest_model(
liquidity_max_participation: Optional[float] = typer.Option(
None, help="Liquidity max participation"
),
skip_validation: bool = typer.Option(
False,
"--skip-validation",
help="Skip data-quality validation before backtesting",
),
):
"""Backtest a saved model on the configured test window with execution costs."""

Expand All @@ -146,6 +161,7 @@ def cmd_backtest_model(
slippage_bps=slippage_bps,
slippage_fixed=slippage_fixed,
liquidity_max_participation=liquidity_max_participation,
skip_validation=skip_validation,
)
typer.echo(json.dumps(summary, indent=2))

Expand Down
83 changes: 52 additions & 31 deletions quanttradeai/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,37 +349,58 @@ def _resample_secondary(
resampled = resampled.reindex(target_index)
return resampled

def validate_data(self, data_dict: Dict[str, pd.DataFrame]) -> bool:
"""
Validate the fetched data meets requirements.

Args:
data_dict: Dictionary of DataFrames with OHLCV data.

Returns:
bool: True if data is valid, False otherwise.
"""
required_columns = ["Open", "High", "Low", "Close", "Volume"]

for symbol, df in data_dict.items():
# Check required columns
if not all(col in df.columns for col in required_columns):
logger.error(f"Missing required columns for {symbol}")
return False

# Check data range
date_range = (df.index.max() - df.index.min()).days
if date_range < 365: # At least one year of data
logger.error(f"Insufficient data range for {symbol}")
return False

# Check for excessive missing values
# Check missing value ratio per column
if df.isnull().mean().max() > 0.01: # Max 1% missing values
logger.error(f"Too many missing values for {symbol}")
return False

return True
def validate_data(self, data_dict: Dict[str, pd.DataFrame]) -> tuple[bool, dict]:
    """
    Check every fetched OHLCV frame against the quality requirements.

    Args:
        data_dict: Mapping of symbol -> DataFrame with OHLCV data.

    Returns:
        Tuple[bool, dict]: ``(overall_valid, report)`` where ``report`` maps
        each symbol to its missing columns, date span, per-column NaN ratios,
        collected error messages, and a pass/fail flag.
    """
    required_columns = ["Open", "High", "Low", "Close", "Volume"]

    report: Dict[str, dict] = {}
    overall_valid = True

    for symbol, df in data_dict.items():
        missing_columns = [c for c in required_columns if c not in df.columns]

        if df.empty:
            date_span_days = 0
        else:
            date_span_days = int((df.index.max() - df.index.min()).days)

        nan_ratio_by_column = {c: float(df[c].isnull().mean()) for c in df.columns}
        # The NaN threshold is applied to required price columns only, so
        # sparse optional columns (e.g. attached news text) never trip it.
        required_ratios = [
            nan_ratio_by_column[c]
            for c in required_columns
            if c in nan_ratio_by_column
        ]
        max_nan_ratio = max(required_ratios, default=0.0)

        errors = []
        if missing_columns:
            errors.append(f"Missing required columns: {', '.join(missing_columns)}")
        if date_span_days < 365:
            errors.append("Insufficient data range (<365 days)")
        if max_nan_ratio > 0.01:
            errors.append(f"Too many missing values (max ratio={max_nan_ratio:.4f})")

        passed = not errors
        if not passed:
            overall_valid = False
            logger.error("Data validation failed for %s: %s", symbol, "; ".join(errors))

        report[symbol] = {
            "missing_columns": missing_columns,
            "date_span_days": date_span_days,
            "nan_ratio_by_column": nan_ratio_by_column,
            "max_nan_ratio": max_nan_ratio,
            "passed": passed,
            "errors": errors,
        }

    return overall_valid, report

def save_data(
self, data_dict: Dict[str, pd.DataFrame], path: Optional[str] = None
Expand Down
91 changes: 88 additions & 3 deletions quanttradeai/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,26 @@ def setup_directories(cache_dir: str = "data/raw"):
Path(dir_path).mkdir(parents=True, exist_ok=True)


def _write_validation_report(report_path: Path, report: dict) -> None:
"""Persist validation results to JSON and CSV formats."""

report_path.parent.mkdir(parents=True, exist_ok=True)
with open(report_path, "w") as f:
json.dump(report, f, indent=2, default=float)

try:
import pandas as pd

rows = []
for symbol, details in report.items():
row = {"symbol": symbol, **details}
rows.append(row)
df = pd.DataFrame(rows)
df.to_csv(report_path.with_suffix(".csv"), index=False)
except Exception as exc: # pragma: no cover - defensive
logger.warning("Failed to write CSV validation report: %s", exc)


def _ensure_datetime_index(df: pd.DataFrame) -> pd.DataFrame:
"""Ensure the DataFrame index is a DatetimeIndex.

Expand Down Expand Up @@ -168,7 +188,46 @@ def _window_has_full_coverage(
return train_df, test_df


def run_pipeline(config_path: str = "config/model_config.yaml"):
def _validate_or_raise(
    *,
    loader: DataLoader,
    data: dict,
    report_path: Path,
    skip_validation: bool,
) -> dict:
    """Validate fetched data and persist a report unless skipped.

    Returns the per-symbol validation report (empty when skipped).

    Raises:
        ValueError: If any symbol fails validation.
    """
    if skip_validation:
        logger.warning(
            "Skipping data validation as requested; downstream results may be unreliable."
        )
        return {}

    result = loader.validate_data(data)
    # Tolerate legacy loaders whose validate_data returns a bare bool.
    if isinstance(result, tuple) and len(result) == 2:
        is_valid, report = result
    else:
        is_valid, report = bool(result), {}

    _write_validation_report(report_path, report)

    if is_valid:
        logger.info(
            "Data validation passed for %d symbols. Report saved to %s",
            len(report),
            report_path,
        )
        return report

    failed = [sym for sym, res in report.items() if not res.get("passed")]
    raise ValueError(
        f"Data validation failed for symbols: {', '.join(failed) if failed else 'unknown'}"
    )


def run_pipeline(
config_path: str = "config/model_config.yaml", *, skip_validation: bool = False
):
"""Run the end-to-end training pipeline.

Loads data, generates features and labels, tunes hyperparameters,
Expand Down Expand Up @@ -205,6 +264,15 @@ def run_pipeline(config_path: str = "config/model_config.yaml"):
refresh = config.get("data", {}).get("refresh", False)
data_dict = data_loader.fetch_data(refresh=refresh)

# Validate data quality
validation_path = Path(experiment_dir) / "validation.json"
_validate_or_raise(
loader=data_loader,
data=data_dict,
report_path=validation_path,
skip_validation=skip_validation,
)

# Process each stock
results = {}
for symbol, df in data_dict.items():
Expand Down Expand Up @@ -284,7 +352,9 @@ def fetch_data_only(config_path: str, refresh: bool = False) -> None:
data_loader.save_data(data)


def evaluate_model(config_path: str, model_path: str) -> None:
def evaluate_model(
config_path: str, model_path: str, *, skip_validation: bool = False
) -> None:
"""Evaluate a persisted model on current config’s dataset.

Example
Expand All @@ -297,6 +367,13 @@ def evaluate_model(config_path: str, model_path: str) -> None:
model.load_model(model_path)

data_dict = data_loader.fetch_data()
validation_path = Path(model_path) / "validation.json"
_validate_or_raise(
loader=data_loader,
data=data_dict,
report_path=validation_path,
skip_validation=skip_validation,
)
results = {}
for symbol, df in data_dict.items():
df_processed = data_processor.process_data(df)
Expand Down Expand Up @@ -382,6 +459,7 @@ def run_model_backtest(
slippage_bps: float | None = None,
slippage_fixed: float | None = None,
liquidity_max_participation: float | None = None,
skip_validation: bool = False,
) -> dict:
"""Backtest a saved model’s predictions on the configured test window.

Expand Down Expand Up @@ -432,8 +510,15 @@ def run_model_backtest(
risk_path,
)

data_dict = loader.fetch_data()
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
data_dict = loader.fetch_data()
validation_path = Path(f"reports/backtests/{timestamp}") / "validation.json"
_validate_or_raise(
loader=loader,
data=data_dict,
report_path=validation_path,
skip_validation=skip_validation,
)
base_dir = Path(f"reports/backtests/{timestamp}")
base_dir.mkdir(parents=True, exist_ok=True)

Expand Down
Loading