diff --git a/backend/app/main.py b/backend/app/main.py
index 6dd82f9..d2cab38 100644
--- a/backend/app/main.py
+++ b/backend/app/main.py
@@ -359,6 +359,10 @@ async def get_static_evaluation() -> dict[str, Any]:
     dict[str, Any]
         Static evaluation metrics with RMSE/MAE by horizon
     """
+    logger = get_logger()
+    logger.info("=" * 80)
+    logger.info("GET /evaluation/static endpoint called")
+
     if not hasattr(app.state, "eval_storage"):
         raise HTTPException(
             status_code=503,
@@ -370,11 +374,20 @@
         start_date = datetime(2024, 2, 6, 12, 0, 0)
         end_date = datetime(2024, 7, 19, 17, 0, 0)
 
+        logger.info(f"Static evaluation period: {start_date} to {end_date}")
+
         # Compute metrics using fast SQL aggregation in BigQuery (run in thread)
+        logger.info("Calling compute_metrics_for_period...")
         raw_metrics = await asyncio.to_thread(
             app.state.eval_storage.compute_metrics_for_period, start_date, end_date
         )
 
+        logger.info("Raw metrics from compute_metrics_for_period:")
+        logger.info(f"  overall_rmse: {raw_metrics['overall_rmse']:.4f}°C")
+        logger.info(f"  overall_mae: {raw_metrics['overall_mae']:.4f}°C")
+        logger.info(f"  total_samples: {raw_metrics['total_samples']:,}")
+        logger.info(f"  by_horizon keys: {list(raw_metrics['by_horizon'].keys())}")
+
         # Restructure metrics to match frontend expectations
         metrics = {
             "overall": {
@@ -385,7 +398,12 @@
             "by_horizon": raw_metrics["by_horizon"],
         }
 
+        logger.info("Restructured metrics for frontend:")
+        logger.info(f"  metrics['overall']: {metrics['overall']}")
+        logger.info(f"  metrics['by_horizon']: {metrics['by_horizon']}")
+
         if raw_metrics["total_samples"] == 0:
+            logger.warning("No samples found for static evaluation period")
             return {
                 "message": "No predictions found for static evaluation period",
                 "evaluation_period": {
@@ -396,6 +414,7 @@
             }
 
         # Store computed metrics for caching (run in thread)
+        logger.info("Storing evaluation metrics to BigQuery...")
         await asyncio.to_thread(
             app.state.eval_storage.store_evaluation_metrics,
             evaluation_date=datetime.now(),
@@ -403,7 +422,7 @@
             eval_type="static",
         )
 
-        return {
+        response = {
             "evaluation_period": {
                 "start": start_date.isoformat(),
                 "end": end_date.isoformat(),
@@ -412,7 +431,13 @@
             "computed_at": datetime.now().isoformat(),
         }
 
+        logger.info("Static evaluation endpoint completed successfully")
+        logger.info("=" * 80)
+        return response
+
     except Exception as e:
+        logger.error(f"Failed to compute static evaluation: {str(e)}")
+        logger.info("=" * 80)
         raise HTTPException(
             status_code=500,
             detail=f"Failed to compute static evaluation: {str(e)}",
@@ -480,6 +505,10 @@ async def get_dynamic_evaluation() -> dict[str, Any]:
     dict[str, Any]
         Dynamic evaluation metrics for last 30 days
     """
+    logger = get_logger()
+    logger.info("=" * 80)
+    logger.info("GET /evaluation/dynamic endpoint called")
+
     if not hasattr(app.state, "eval_storage"):
         raise HTTPException(
             status_code=503,
@@ -491,11 +520,20 @@
         end_date = datetime.now()
         start_date = end_date - timedelta(days=30)
 
+        logger.info(f"Evaluation period: {start_date} to {end_date} (30 days)")
+
         # Compute metrics using fast SQL aggregation in BigQuery (run in thread)
+        logger.info("Calling compute_metrics_for_period...")
         raw_metrics = await asyncio.to_thread(
             app.state.eval_storage.compute_metrics_for_period, start_date, end_date
         )
 
+        logger.info("Raw metrics from compute_metrics_for_period:")
+        logger.info(f"  overall_rmse: {raw_metrics['overall_rmse']:.4f}°C")
+        logger.info(f"  overall_mae: {raw_metrics['overall_mae']:.4f}°C")
+        logger.info(f"  total_samples: {raw_metrics['total_samples']:,}")
+        logger.info(f"  by_horizon keys: {list(raw_metrics['by_horizon'].keys())}")
+
         # Restructure metrics to match frontend expectations
         metrics = {
             "overall": {
@@ -506,7 +544,12 @@
             "by_horizon": raw_metrics["by_horizon"],
         }
 
+        logger.info("Restructured metrics for frontend:")
+        logger.info(f"  metrics['overall']: {metrics['overall']}")
+        logger.info(f"  metrics['by_horizon']: {metrics['by_horizon']}")
+
         if raw_metrics["total_samples"] == 0:
+            logger.warning("No samples found for dynamic evaluation window")
             return {
                 "message": "No predictions found for dynamic evaluation window",
                 "evaluation_window": {
@@ -518,6 +561,7 @@
             }
 
         # Store computed metrics (run in thread)
+        logger.info("Storing evaluation metrics to BigQuery...")
         await asyncio.to_thread(
             app.state.eval_storage.store_evaluation_metrics,
             evaluation_date=datetime.now(),
@@ -525,7 +569,7 @@
             eval_type="dynamic",
         )
 
-        return {
+        response = {
             "evaluation_window": {
                 "start": start_date.isoformat(),
                 "end": end_date.isoformat(),
@@ -535,7 +579,13 @@
             "computed_at": datetime.now().isoformat(),
         }
 
+        logger.info("Evaluation endpoint completed successfully")
+        logger.info("=" * 80)
+        return response
+
     except Exception as e:
+        logger.error(f"Failed to compute dynamic evaluation: {str(e)}")
+        logger.info("=" * 80)
         raise HTTPException(
             status_code=500,
             detail=f"Failed to compute dynamic evaluation: {str(e)}",
diff --git a/scripts/clear_static_predictions.py b/scripts/clear_static_predictions.py
new file mode 100644
index 0000000..9675be7
--- /dev/null
+++ b/scripts/clear_static_predictions.py
@@ -0,0 +1,95 @@
+"""Clear static evaluation predictions from BigQuery.
+
+Deletes all predictions for the static evaluation period (Feb 6, 2024 - July 19, 2024)
+to allow for regeneration with proper hourly intervals.
+
+Usage:
+    GCP_PROJECT_ID=coderd python scripts/clear_static_predictions.py
+"""
+
+import os
+from datetime import datetime
+
+from google.cloud import bigquery
+from rich.console import Console
+
+
+console = Console()
+
+
+def main() -> None:
+    """Clear static evaluation predictions from BigQuery."""
+    project_id = os.getenv("GCP_PROJECT_ID")
+    if not project_id:
+        console.print("[red]✗ GCP_PROJECT_ID environment variable not set[/red]")
+        return
+
+    console.print(
+        "\n[bold cyan]Clearing Static Evaluation Predictions from BigQuery[/bold cyan]\n"
+    )
+
+    # Static evaluation period
+    start_date = datetime(2024, 2, 6, 12, 0, 0)
+    end_date = datetime(2024, 7, 19, 17, 0, 0)
+
+    client = bigquery.Client(project=project_id)
+    table_id = f"{project_id}.gaca_evaluation.predictions"
+
+    console.print("[yellow]Target period:[/yellow]")
+    console.print(f"  Start: {start_date}")
+    console.print(f"  End: {end_date}")
+    console.print()
+
+    # Count rows to be deleted
+    count_query = f"""
+        SELECT COUNT(*) as count
+        FROM `{table_id}`
+        WHERE run_timestamp BETWEEN @start_date AND @end_date
+    """
+
+    job_config = bigquery.QueryJobConfig(
+        query_parameters=[
+            bigquery.ScalarQueryParameter("start_date", "TIMESTAMP", start_date),
+            bigquery.ScalarQueryParameter("end_date", "TIMESTAMP", end_date),
+        ]
+    )
+
+    console.print("[cyan]Counting rows in static period...[/cyan]")
+    count_job = client.query(count_query, job_config=job_config)
+    result = list(count_job.result())
+    row_count = result[0]["count"] if result else 0
+
+    console.print(f"[yellow]Found {row_count:,} rows to delete[/yellow]\n")
+
+    if row_count == 0:
+        console.print("[green]✓ No rows to delete[/green]\n")
+        return
+
+    # Confirm deletion
+    console.print(
+        "[bold red]WARNING:[/bold red] This will permanently delete "
+        f"{row_count:,} prediction rows!"
+    )
+    response = input("Continue? (yes/no): ").strip().lower()
+
+    if response != "yes":
+        console.print("\n[yellow]⚠ Cancelled[/yellow]\n")
+        return
+
+    # Delete rows
+    console.print(f"\n[cyan]Deleting rows from {table_id}...[/cyan]")
+    delete_query = f"""
+        DELETE FROM `{table_id}`
+        WHERE run_timestamp BETWEEN @start_date AND @end_date
+    """
+
+    delete_job = client.query(delete_query, job_config=job_config)
+    delete_job.result()  # Wait for completion
+
+    console.print(
+        f"\n[bold green]✓ Deleted {row_count:,} predictions from static period![/bold green]\n"
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/generate_historical_predictions.py b/scripts/generate_historical_predictions.py
index 284af72..8bb1270 100644
--- a/scripts/generate_historical_predictions.py
+++ b/scripts/generate_historical_predictions.py
@@ -1,11 +1,24 @@
-"""Generate predictions for historical validation period and store to Firestore.
+"""Generate predictions for historical validation period and store to BigQuery.
 
-This script wraps the batch-predict CLI command and stores results to Firestore
+This script wraps the batch-predict CLI command and stores results to BigQuery
 for evaluation purposes.
 
+IMPORTANT: For accurate evaluation, use --interval 1 (hourly) to ensure unbiased
+temporal coverage across all times of day. Larger intervals (e.g., 24h) will cause
+each forecast horizon to be evaluated at only specific times, introducing diurnal
+bias and making metrics unreliable for cross-horizon comparison.
+
 Usage:
-    python scripts/generate_historical_predictions.py --help
-    python scripts/generate_historical_predictions.py --start-date "2024-02-06 12:00"
+    # Recommended: Hourly predictions for unbiased evaluation
+    python scripts/generate_historical_predictions.py \
+        --start-date "2024-02-06 12:00" \
+        --end-date "2024-07-19 17:00" \
+        --interval 1
+
+    # Faster but biased: Daily predictions (not recommended)
+    python scripts/generate_historical_predictions.py \
+        --start-date "2024-02-06 12:00" \
+        --interval 24
 """
 
 import argparse
@@ -169,8 +182,8 @@ def main() -> None:
     parser.add_argument(
         "--interval",
         type=int,
-        default=24,
-        help="Interval between predictions in hours",
+        default=1,
+        help="Interval between predictions in hours (default: 1 for continuous evaluation)",
     )
     parser.add_argument(
         "--output",
diff --git a/src/gaca_ews/cli/main.py b/src/gaca_ews/cli/main.py
index c7a58c3..612c6d1 100644
--- a/src/gaca_ews/cli/main.py
+++ b/src/gaca_ews/cli/main.py
@@ -294,9 +294,9 @@ def batch_predict(  # noqa: PLR0912, PLR0915
         int,
         typer.Option(
             "--interval",
-            help="Interval between predictions in hours",
+            help="Interval between predictions in hours (use 1 for continuous evaluation)",
         ),
-    ] = 24,
+    ] = 1,
     output: Annotated[
         Path | None,
         typer.Option(
@@ -326,10 +326,15 @@
     Generates predictions for multiple timestamps within a date range,
     useful for validation, evaluation, and backtesting.
 
+    IMPORTANT: For accurate evaluation, use --interval 1 (hourly predictions)
+    to ensure unbiased temporal coverage. Larger intervals (e.g., 24h) will
+    cause each forecast horizon to be evaluated at only specific times of day,
+    introducing diurnal bias and making cross-horizon comparisons invalid.
+
     Example:
         gaca-ews batch-predict --start-date "2024-02-06 12:00" \\
-            --end-date "2024-02-10 12:00"
-        gaca-ews batch-predict --start-date "2024-02-06 12:00" --interval 12
+            --end-date "2024-02-10 12:00" --interval 1
+        gaca-ews batch-predict --start-date "2024-02-06 12:00" --interval 6
     """
     try:
         # Set logger level
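
Note on the --interval 1 default above: the docstrings' coverage argument can be made concrete. With a 24-hour interval starting at 12:00, each forecast horizon is only ever verified at one fixed time of day (a +6 h horizon always lands at 18:00), while an hourly interval exercises all 24 verification hours for every horizon. A minimal sketch of that argument, using illustrative horizons of 6/12/24/48 h rather than the project's actual configuration:

    from datetime import datetime, timedelta

    def verification_hours(start: datetime, interval_h: int, horizon_h: int, n_runs: int = 200) -> set[int]:
        """Distinct hours of day at which a given forecast horizon gets verified."""
        return {(start + timedelta(hours=i * interval_h + horizon_h)).hour for i in range(n_runs)}

    start = datetime(2024, 2, 6, 12, 0)
    for horizon in (6, 12, 24, 48):  # illustrative horizons only, not the model's actual set
        daily = verification_hours(start, 24, horizon)
        hourly = verification_hours(start, 1, horizon)
        print(f"+{horizon}h: interval=24 -> hours {sorted(daily)}, interval=1 -> {len(hourly)} distinct hours")

With interval=24 every horizon collapses onto a single verification hour, which is exactly the diurnal bias described above; with interval=1 each horizon sees all 24 hours of the day.
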
diff --git a/src/gaca_ews/evaluation/storage.py b/src/gaca_ews/evaluation/storage.py
index 01ac2e2..69a05b5 100644
--- a/src/gaca_ews/evaluation/storage.py
+++ b/src/gaca_ews/evaluation/storage.py
@@ -263,9 +263,19 @@ def compute_metrics_for_period(
         console.print(
             f"[cyan]Computing metrics for {start_date} to {end_date}...[/cyan]"
         )
+        console.print("[yellow]DEBUG: Query parameters:[/yellow]")
+        console.print(f"  start_date: {start_date}")
+        console.print(f"  end_date: {end_date}")
+        console.print(f"  project: {self.project_id}")
+        console.print(f"  dataset: {self.dataset_id}")
+
         query_job = self.client.query(query, job_config=job_config)
         results = query_job.result()
 
+        console.print(
+            "[yellow]DEBUG: BigQuery job completed, parsing results...[/yellow]"
+        )
+
         # Parse results
         by_horizon = {}
         total_samples = 0
@@ -278,6 +288,11 @@
             mae = float(row["mae"])
             count = int(row["sample_count"])
 
+            console.print(
+                f"[yellow]DEBUG: Horizon {horizon}h - "
+                f"RMSE={rmse:.4f}°C, MAE={mae:.4f}°C, samples={count:,}[/yellow]"
+            )
+
             by_horizon[horizon] = {
                 "rmse": rmse,
                 "mae": mae,
@@ -289,7 +304,15 @@
             sum_squared_errors += (rmse**2) * count
             sum_abs_errors += mae * count
 
+            console.print(
+                f"[yellow]DEBUG: Accumulating - "
+                f"total_samples={total_samples:,}, "
+                f"sum_squared_errors={sum_squared_errors:.4f}, "
+                f"sum_abs_errors={sum_abs_errors:.4f}[/yellow]"
+            )
+
         if total_samples == 0:
+            console.print("[red]DEBUG: No samples found, returning zeros[/red]")
             return {
                 "overall_rmse": 0.0,
                 "overall_mae": 0.0,
@@ -300,6 +323,14 @@
         overall_rmse = (sum_squared_errors / total_samples) ** 0.5
         overall_mae = sum_abs_errors / total_samples
 
+        console.print("[yellow]DEBUG: Overall computation:[/yellow]")
+        console.print(
+            f"  Overall RMSE = sqrt({sum_squared_errors:.4f} / {total_samples:,}) = {overall_rmse:.4f}°C"
+        )
+        console.print(
+            f"  Overall MAE = {sum_abs_errors:.4f} / {total_samples:,} = {overall_mae:.4f}°C"
+        )
+
         console.print(
             f"[green]✓[/green] Computed metrics over {total_samples:,} samples "
             f"(RMSE: {overall_rmse:.3f}°C, MAE: {overall_mae:.3f}°C)"
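
Note on the pooled metrics traced by the new DEBUG output: the overall figures are sample-weighted pools of the per-horizon BigQuery aggregates, overall RMSE = sqrt(sum_h(rmse_h^2 * n_h) / sum_h(n_h)) and overall MAE = sum_h(mae_h * n_h) / sum_h(n_h), which works because rmse_h^2 * n_h recovers each horizon's sum of squared errors. A standalone sketch of that arithmetic with made-up numbers (the horizons and values below are purely illustrative):

    # Per-horizon aggregates in the shape compute_metrics_for_period builds;
    # the two horizons and their values are invented for illustration.
    by_horizon = {
        6: {"rmse": 1.2, "mae": 0.9, "sample_count": 1000},
        24: {"rmse": 1.8, "mae": 1.4, "sample_count": 1000},
    }

    total_samples = sum(m["sample_count"] for m in by_horizon.values())
    sum_squared_errors = sum(m["rmse"] ** 2 * m["sample_count"] for m in by_horizon.values())
    sum_abs_errors = sum(m["mae"] * m["sample_count"] for m in by_horizon.values())

    overall_rmse = (sum_squared_errors / total_samples) ** 0.5
    overall_mae = sum_abs_errors / total_samples
    print(f"RMSE: {overall_rmse:.3f}°C, MAE: {overall_mae:.3f}°C")  # RMSE: 1.530°C, MAE: 1.150°C
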