@@ -56,6 +56,26 @@ def setup_directories(cache_dir: str = "data/raw"):
5656 Path (dir_path ).mkdir (parents = True , exist_ok = True )
5757
5858
59+ def _write_validation_report (report_path : Path , report : dict ) -> None :
60+ """Persist validation results to JSON and CSV formats."""
61+
62+ report_path .parent .mkdir (parents = True , exist_ok = True )
63+ with open (report_path , "w" ) as f :
64+ json .dump (report , f , indent = 2 , default = float )
65+
66+ try :
67+ import pandas as pd
68+
69+ rows = []
70+ for symbol , details in report .items ():
71+ row = {"symbol" : symbol , ** details }
72+ rows .append (row )
73+ df = pd .DataFrame (rows )
74+ df .to_csv (report_path .with_suffix (".csv" ), index = False )
75+ except Exception as exc : # pragma: no cover - defensive
76+ logger .warning ("Failed to write CSV validation report: %s" , exc )
77+
78+
5979def _ensure_datetime_index (df : pd .DataFrame ) -> pd .DataFrame :
6080 """Ensure the DataFrame index is a DatetimeIndex.
6181
@@ -168,7 +188,46 @@ def _window_has_full_coverage(
168188 return train_df , test_df
169189
170190
171- def run_pipeline (config_path : str = "config/model_config.yaml" ):
def _validate_or_raise(
    *,
    loader: DataLoader,
    data: dict,
    report_path: Path,
    skip_validation: bool,
) -> dict:
    """Validate fetched data and persist a report unless skipped.

    Returns the per-symbol validation report (empty when validation is
    skipped or the loader reports only a bare boolean). Raises ``ValueError``
    when validation fails.
    """

    if skip_validation:
        logger.warning(
            "Skipping data validation as requested; downstream results may be unreliable."
        )
        return {}

    outcome = loader.validate_data(data)
    # Accept both loader conventions: ``(bool, report_dict)`` or a bare bool.
    if isinstance(outcome, tuple) and len(outcome) == 2:
        passed, report = outcome
    else:
        passed, report = bool(outcome), {}
    # Persist the report even on failure so the evidence survives the raise.
    _write_validation_report(report_path, report)

    if passed:
        logger.info(
            "Data validation passed for %d symbols. Report saved to %s",
            len(report),
            report_path,
        )
        return report

    failed = [symbol for symbol, res in report.items() if not res.get("passed")]
    raise ValueError(
        f"Data validation failed for symbols: {', '.join(failed) if failed else 'unknown'}"
    )
226+
227+
228+ def run_pipeline (
229+ config_path : str = "config/model_config.yaml" , * , skip_validation : bool = False
230+ ):
172231 """Run the end-to-end training pipeline.
173232
174233 Loads data, generates features and labels, tunes hyperparameters,
@@ -205,6 +264,15 @@ def run_pipeline(config_path: str = "config/model_config.yaml"):
205264 refresh = config .get ("data" , {}).get ("refresh" , False )
206265 data_dict = data_loader .fetch_data (refresh = refresh )
207266
267+ # Validate data quality
268+ validation_path = Path (experiment_dir ) / "validation.json"
269+ _validate_or_raise (
270+ loader = data_loader ,
271+ data = data_dict ,
272+ report_path = validation_path ,
273+ skip_validation = skip_validation ,
274+ )
275+
208276 # Process each stock
209277 results = {}
210278 for symbol , df in data_dict .items ():
@@ -284,7 +352,9 @@ def fetch_data_only(config_path: str, refresh: bool = False) -> None:
284352 data_loader .save_data (data )
285353
286354
287- def evaluate_model (config_path : str , model_path : str ) -> None :
355+ def evaluate_model (
356+ config_path : str , model_path : str , * , skip_validation : bool = False
357+ ) -> None :
288358 """Evaluate a persisted model on current config’s dataset.
289359
290360 Example
@@ -297,6 +367,13 @@ def evaluate_model(config_path: str, model_path: str) -> None:
297367 model .load_model (model_path )
298368
299369 data_dict = data_loader .fetch_data ()
370+ validation_path = Path (model_path ) / "validation.json"
371+ _validate_or_raise (
372+ loader = data_loader ,
373+ data = data_dict ,
374+ report_path = validation_path ,
375+ skip_validation = skip_validation ,
376+ )
300377 results = {}
301378 for symbol , df in data_dict .items ():
302379 df_processed = data_processor .process_data (df )
@@ -382,6 +459,7 @@ def run_model_backtest(
382459 slippage_bps : float | None = None ,
383460 slippage_fixed : float | None = None ,
384461 liquidity_max_participation : float | None = None ,
462+ skip_validation : bool = False ,
385463) -> dict :
386464 """Backtest a saved model’s predictions on the configured test window.
387465
@@ -432,8 +510,15 @@ def run_model_backtest(
432510 risk_path ,
433511 )
434512
435- data_dict = loader .fetch_data ()
436513 timestamp = datetime .now ().strftime ("%Y%m%d_%H%M%S" )
514+ data_dict = loader .fetch_data ()
515+ validation_path = Path (f"reports/backtests/{ timestamp } " ) / "validation.json"
516+ _validate_or_raise (
517+ loader = loader ,
518+ data = data_dict ,
519+ report_path = validation_path ,
520+ skip_validation = skip_validation ,
521+ )
437522 base_dir = Path (f"reports/backtests/{ timestamp } " )
438523 base_dir .mkdir (parents = True , exist_ok = True )
439524
0 commit comments