54 changes: 52 additions & 2 deletions backend/app/main.py
@@ -359,6 +359,10 @@ async def get_static_evaluation() -> dict[str, Any]:
dict[str, Any]
Static evaluation metrics with RMSE/MAE by horizon
"""
logger = get_logger()
logger.info("=" * 80)
logger.info("GET /evaluation/static endpoint called")

if not hasattr(app.state, "eval_storage"):
raise HTTPException(
status_code=503,
@@ -370,11 +374,20 @@ async def get_static_evaluation() -> dict[str, Any]:
start_date = datetime(2024, 2, 6, 12, 0, 0)
end_date = datetime(2024, 7, 19, 17, 0, 0)

logger.info(f"Static evaluation period: {start_date} to {end_date}")

# Compute metrics using fast SQL aggregation in BigQuery (run in thread)
logger.info("Calling compute_metrics_for_period...")
raw_metrics = await asyncio.to_thread(
app.state.eval_storage.compute_metrics_for_period, start_date, end_date
)

logger.info("Raw metrics from compute_metrics_for_period:")
logger.info(f" overall_rmse: {raw_metrics['overall_rmse']:.4f}°C")
logger.info(f" overall_mae: {raw_metrics['overall_mae']:.4f}°C")
logger.info(f" total_samples: {raw_metrics['total_samples']:,}")
logger.info(f" by_horizon keys: {list(raw_metrics['by_horizon'].keys())}")

# Restructure metrics to match frontend expectations
metrics = {
"overall": {
@@ -385,7 +398,12 @@ async def get_static_evaluation() -> dict[str, Any]:
"by_horizon": raw_metrics["by_horizon"],
}

logger.info("Restructured metrics for frontend:")
logger.info(f" metrics['overall']: {metrics['overall']}")
logger.info(f" metrics['by_horizon']: {metrics['by_horizon']}")

if raw_metrics["total_samples"] == 0:
logger.warning("No samples found for static evaluation period")
return {
"message": "No predictions found for static evaluation period",
"evaluation_period": {
@@ -396,14 +414,15 @@ async def get_static_evaluation() -> dict[str, Any]:
}

# Store computed metrics for caching (run in thread)
logger.info("Storing evaluation metrics to BigQuery...")
await asyncio.to_thread(
app.state.eval_storage.store_evaluation_metrics,
evaluation_date=datetime.now(),
metrics=raw_metrics, # Store raw format internally
eval_type="static",
)

return {
response = {
"evaluation_period": {
"start": start_date.isoformat(),
"end": end_date.isoformat(),
@@ -412,7 +431,13 @@ async def get_static_evaluation() -> dict[str, Any]:
"computed_at": datetime.now().isoformat(),
}

logger.info("Static evaluation endpoint completed successfully")
logger.info("=" * 80)
return response

except Exception as e:
logger.error(f"Failed to compute static evaluation: {str(e)}")
logger.info("=" * 80)
raise HTTPException(
status_code=500,
detail=f"Failed to compute static evaluation: {str(e)}",
@@ -480,6 +505,10 @@ async def get_dynamic_evaluation() -> dict[str, Any]:
dict[str, Any]
Dynamic evaluation metrics for last 30 days
"""
logger = get_logger()
logger.info("=" * 80)
logger.info("GET /evaluation/dynamic endpoint called")

if not hasattr(app.state, "eval_storage"):
raise HTTPException(
status_code=503,
@@ -491,11 +520,20 @@ async def get_dynamic_evaluation() -> dict[str, Any]:
end_date = datetime.now()
start_date = end_date - timedelta(days=30)

logger.info(f"Evaluation period: {start_date} to {end_date} (30 days)")

# Compute metrics using fast SQL aggregation in BigQuery (run in thread)
logger.info("Calling compute_metrics_for_period...")
raw_metrics = await asyncio.to_thread(
app.state.eval_storage.compute_metrics_for_period, start_date, end_date
)

logger.info("Raw metrics from compute_metrics_for_period:")
logger.info(f" overall_rmse: {raw_metrics['overall_rmse']:.4f}°C")
logger.info(f" overall_mae: {raw_metrics['overall_mae']:.4f}°C")
logger.info(f" total_samples: {raw_metrics['total_samples']:,}")
logger.info(f" by_horizon keys: {list(raw_metrics['by_horizon'].keys())}")

# Restructure metrics to match frontend expectations
metrics = {
"overall": {
@@ -506,7 +544,12 @@ async def get_dynamic_evaluation() -> dict[str, Any]:
"by_horizon": raw_metrics["by_horizon"],
}

logger.info("Restructured metrics for frontend:")
logger.info(f" metrics['overall']: {metrics['overall']}")
logger.info(f" metrics['by_horizon']: {metrics['by_horizon']}")

if raw_metrics["total_samples"] == 0:
logger.warning("No samples found for dynamic evaluation window")
return {
"message": "No predictions found for dynamic evaluation window",
"evaluation_window": {
@@ -518,14 +561,15 @@ async def get_dynamic_evaluation() -> dict[str, Any]:
}

# Store computed metrics (run in thread)
logger.info("Storing evaluation metrics to BigQuery...")
await asyncio.to_thread(
app.state.eval_storage.store_evaluation_metrics,
evaluation_date=datetime.now(),
metrics=raw_metrics, # Store raw format internally
eval_type="dynamic",
)

return {
response = {
"evaluation_window": {
"start": start_date.isoformat(),
"end": end_date.isoformat(),
@@ -535,7 +579,13 @@ async def get_dynamic_evaluation() -> dict[str, Any]:
"computed_at": datetime.now().isoformat(),
}

logger.info("Evaluation endpoint completed successfully")
logger.info("=" * 80)
return response

except Exception as e:
logger.error(f"Failed to compute dynamic evaluation: {str(e)}")
logger.info("=" * 80)
raise HTTPException(
status_code=500,
detail=f"Failed to compute dynamic evaluation: {str(e)}",
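Both endpoints offload the synchronous BigQuery work with asyncio.to_thread so the FastAPI event loop is not blocked while metrics are computed or cached. A minimal, self-contained sketch of that pattern follows; the function body, names, and return values are placeholders for illustration, not the project's actual storage code.

import asyncio
import time
from typing import Any


def compute_metrics_for_period(start: str, end: str) -> dict[str, Any]:
    """Stand-in for a blocking BigQuery aggregation call (hypothetical)."""
    time.sleep(2)  # simulates a synchronous network round trip
    return {"overall_rmse": 1.23, "total_samples": 1000}


async def get_evaluation() -> dict[str, Any]:
    # The blocking call runs in a worker thread; other requests keep being served.
    raw = await asyncio.to_thread(compute_metrics_for_period, "2024-02-06", "2024-07-19")
    return {"metrics": raw}


if __name__ == "__main__":
    print(asyncio.run(get_evaluation()))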
95 changes: 95 additions & 0 deletions scripts/clear_static_predictions.py
@@ -0,0 +1,95 @@
"""Clear static evaluation predictions from BigQuery.

Deletes all predictions for the static evaluation period (Feb 6, 2024 - July 19, 2024)
to allow for regeneration with proper hourly intervals.

Usage:
GCP_PROJECT_ID=coderd python scripts/clear_static_predictions.py
"""

import os
from datetime import datetime

from google.cloud import bigquery
from rich.console import Console


console = Console()


def main() -> None:
"""Clear static evaluation predictions from BigQuery."""
project_id = os.getenv("GCP_PROJECT_ID")
if not project_id:
console.print("[red]✗ GCP_PROJECT_ID environment variable not set[/red]")
return

console.print(
"\n[bold cyan]Clearing Static Evaluation Predictions from BigQuery[/bold cyan]\n"
)

# Static evaluation period
start_date = datetime(2024, 2, 6, 12, 0, 0)
end_date = datetime(2024, 7, 19, 17, 0, 0)

client = bigquery.Client(project=project_id)
table_id = f"{project_id}.gaca_evaluation.predictions"

console.print("[yellow]Target period:[/yellow]")
console.print(f" Start: {start_date}")
console.print(f" End: {end_date}")
console.print()

# Count rows to be deleted
count_query = f"""
SELECT COUNT(*) as count
FROM `{table_id}`
WHERE run_timestamp BETWEEN @start_date AND @end_date
"""

job_config = bigquery.QueryJobConfig(
query_parameters=[
bigquery.ScalarQueryParameter("start_date", "TIMESTAMP", start_date),
bigquery.ScalarQueryParameter("end_date", "TIMESTAMP", end_date),
]
)

console.print("[cyan]Counting rows in static period...[/cyan]")
count_job = client.query(count_query, job_config=job_config)
result = list(count_job.result())
row_count = result[0]["count"] if result else 0

console.print(f"[yellow]Found {row_count:,} rows to delete[/yellow]\n")

if row_count == 0:
console.print("[green]✓ No rows to delete[/green]\n")
return

# Confirm deletion
console.print(
"[bold red]WARNING:[/bold red] This will permanently delete "
f"{row_count:,} prediction rows!"
)
response = input("Continue? (yes/no): ").strip().lower()

if response != "yes":
console.print("\n[yellow]⚠ Cancelled[/yellow]\n")
return

# Delete rows
console.print(f"\n[cyan]Deleting rows from {table_id}...[/cyan]")
delete_query = f"""
DELETE FROM `{table_id}`
WHERE run_timestamp BETWEEN @start_date AND @end_date
"""

delete_job = client.query(delete_query, job_config=job_config)
delete_job.result() # Wait for completion

console.print(
f"\n[bold green]✓ Deleted {row_count:,} predictions from static period![/bold green]\n"
)


if __name__ == "__main__":
main()
25 changes: 19 additions & 6 deletions scripts/generate_historical_predictions.py
@@ -1,11 +1,24 @@
"""Generate predictions for historical validation period and store to Firestore.
"""Generate predictions for historical validation period and store to BigQuery.

This script wraps the batch-predict CLI command and stores results to Firestore
This script wraps the batch-predict CLI command and stores results to BigQuery
for evaluation purposes.

IMPORTANT: For accurate evaluation, use --interval 1 (hourly) to ensure unbiased
temporal coverage across all times of day. Larger intervals (e.g., 24h) will cause
each forecast horizon to be evaluated at only specific times, introducing diurnal
bias and making metrics unreliable for cross-horizon comparison.

Usage:
python scripts/generate_historical_predictions.py --help
python scripts/generate_historical_predictions.py --start-date "2024-02-06 12:00"
# Recommended: Hourly predictions for unbiased evaluation
python scripts/generate_historical_predictions.py \
--start-date "2024-02-06 12:00" \
--end-date "2024-07-19 17:00" \
--interval 1

# Faster but biased: Daily predictions (not recommended)
python scripts/generate_historical_predictions.py \
--start-date "2024-02-06 12:00" \
--interval 24
"""

import argparse
@@ -169,8 +182,8 @@ def main() -> None:
parser.add_argument(
"--interval",
type=int,
default=24,
help="Interval between predictions in hours",
default=1,
help="Interval between predictions in hours (default: 1 for continuous evaluation)",
)
parser.add_argument(
"--output",
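The diurnal-bias warning added to the docstring can be made concrete with a short sketch (not part of the PR): it counts the hours of day at which a given forecast horizon would be verified for a given run interval. With a 24 h interval starting at 12:00, a +6 h horizon is only ever verified at 18:00; with a 1 h interval it is verified at every hour of the day. The function and variable names here are illustrative only.

from datetime import datetime, timedelta


def verification_hours(start: datetime, end: datetime, interval_h: int, horizon_h: int) -> set[int]:
    """Hours of day at which a forecast horizon gets verified for a given run interval."""
    hours: set[int] = set()
    t = start
    while t <= end:
        hours.add((t + timedelta(hours=horizon_h)).hour)
        t += timedelta(hours=interval_h)
    return hours


start = datetime(2024, 2, 6, 12)
end = datetime(2024, 7, 19, 17)

# 24 h interval: the +6 h horizon is always scored at 18:00 -> diurnal bias.
print(sorted(verification_hours(start, end, interval_h=24, horizon_h=6)))  # [18]
# 1 h interval: the +6 h horizon is scored at all 24 hours of the day.
print(len(verification_hours(start, end, interval_h=1, horizon_h=6)))      # 24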
13 changes: 9 additions & 4 deletions src/gaca_ews/cli/main.py
@@ -294,9 +294,9 @@ def batch_predict( # noqa: PLR0912, PLR0915
int,
typer.Option(
"--interval",
help="Interval between predictions in hours",
help="Interval between predictions in hours (use 1 for continuous evaluation)",
),
] = 24,
] = 1,
output: Annotated[
Path | None,
typer.Option(
@@ -326,10 +326,15 @@ def batch_predict( # noqa: PLR0912, PLR0915
Generates predictions for multiple timestamps within a date range,
useful for validation, evaluation, and backtesting.

IMPORTANT: For accurate evaluation, use --interval 1 (hourly predictions)
to ensure unbiased temporal coverage. Larger intervals (e.g., 24h) will
cause each forecast horizon to be evaluated at only specific times of day,
introducing diurnal bias and making cross-horizon comparisons invalid.

Example:
gaca-ews batch-predict --start-date "2024-02-06 12:00" \\
--end-date "2024-02-10 12:00"
gaca-ews batch-predict --start-date "2024-02-06 12:00" --interval 12
--end-date "2024-02-10 12:00" --interval 1
gaca-ews batch-predict --start-date "2024-02-06 12:00" --interval 6
"""
try:
# Set logger level
31 changes: 31 additions & 0 deletions src/gaca_ews/evaluation/storage.py
@@ -263,9 +263,19 @@ def compute_metrics_for_period(
console.print(
f"[cyan]Computing metrics for {start_date} to {end_date}...[/cyan]"
)
console.print("[yellow]DEBUG: Query parameters:[/yellow]")
console.print(f" start_date: {start_date}")
console.print(f" end_date: {end_date}")
console.print(f" project: {self.project_id}")
console.print(f" dataset: {self.dataset_id}")

query_job = self.client.query(query, job_config=job_config)
results = query_job.result()

console.print(
"[yellow]DEBUG: BigQuery job completed, parsing results...[/yellow]"
)

# Parse results
by_horizon = {}
total_samples = 0
@@ -278,6 +288,11 @@ def compute_metrics_for_period(
mae = float(row["mae"])
count = int(row["sample_count"])

console.print(
f"[yellow]DEBUG: Horizon {horizon}h - "
f"RMSE={rmse:.4f}°C, MAE={mae:.4f}°C, samples={count:,}[/yellow]"
)

by_horizon[horizon] = {
"rmse": rmse,
"mae": mae,
@@ -289,7 +304,15 @@ def compute_metrics_for_period(
sum_squared_errors += (rmse**2) * count
sum_abs_errors += mae * count

console.print(
f"[yellow]DEBUG: Accumulating - "
f"total_samples={total_samples:,}, "
f"sum_squared_errors={sum_squared_errors:.4f}, "
f"sum_abs_errors={sum_abs_errors:.4f}[/yellow]"
)

if total_samples == 0:
console.print("[red]DEBUG: No samples found, returning zeros[/red]")
return {
"overall_rmse": 0.0,
"overall_mae": 0.0,
@@ -300,6 +323,14 @@ def compute_metrics_for_period(
overall_rmse = (sum_squared_errors / total_samples) ** 0.5
overall_mae = sum_abs_errors / total_samples

console.print("[yellow]DEBUG: Overall computation:[/yellow]")
console.print(
f" Overall RMSE = sqrt({sum_squared_errors:.4f} / {total_samples:,}) = {overall_rmse:.4f}°C"
)
console.print(
f" Overall MAE = {sum_abs_errors:.4f} / {total_samples:,} = {overall_mae:.4f}°C"
)

console.print(
f"[green]✓[/green] Computed metrics over {total_samples:,} samples "
f"(RMSE: {overall_rmse:.3f}°C, MAE: {overall_mae:.3f}°C)"
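For reference, the overall figures in compute_metrics_for_period are sample-weighted poolings of the per-horizon aggregates: sum_squared_errors accumulates rmse_h**2 * count_h, which is exactly each horizon's sum of squared errors, so (sum_squared_errors / total_samples) ** 0.5 equals the RMSE over all pooled samples, and likewise for MAE. A small sketch with made-up numbers verifying that identity (not part of the PR):

import math

# Per-horizon signed errors for two hypothetical horizons (illustrative numbers only).
errors_by_horizon = {
    6: [0.5, -1.0, 0.25],
    12: [2.0, -0.5],
}

# Per-horizon RMSE/MAE and counts, in the shape the BigQuery aggregation returns.
per_horizon = {
    h: {
        "rmse": math.sqrt(sum(e * e for e in errs) / len(errs)),
        "mae": sum(abs(e) for e in errs) / len(errs),
        "count": len(errs),
    }
    for h, errs in errors_by_horizon.items()
}

# Pooled metrics recovered from per-horizon aggregates (the accumulation in storage.py).
n = sum(m["count"] for m in per_horizon.values())
overall_rmse = math.sqrt(sum(m["rmse"] ** 2 * m["count"] for m in per_horizon.values()) / n)
overall_mae = sum(m["mae"] * m["count"] for m in per_horizon.values()) / n

# Direct computation over the union of samples gives the same values.
all_errors = [e for errs in errors_by_horizon.values() for e in errs]
assert math.isclose(overall_rmse, math.sqrt(sum(e * e for e in all_errors) / len(all_errors)))
assert math.isclose(overall_mae, sum(abs(e) for e in all_errors) / len(all_errors))
print(f"overall RMSE={overall_rmse:.4f}, overall MAE={overall_mae:.4f}")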