diff --git a/conf/datapaths/datapaths_cannon.yaml b/conf/datapaths/datapaths_cannon.yaml
index cbc2419..b7dd1aa 100644
--- a/conf/datapaths/datapaths_cannon.yaml
+++ b/conf/datapaths/datapaths_cannon.yaml
@@ -3,4 +3,5 @@ name: null
 dirs:
   input:
     lego: /n/dominici_lab/lab/lego
-  output: /n/dominici_lab/lab/lego_loader_x/output
\ No newline at end of file
+  covars: /n/dominici_lab/lab/lego_loader_x/output
+  health: /n/dominici_lab/lab/lego_loader_x/synthetic_health
\ No newline at end of file
diff --git a/conf/synthetic/config.yaml b/conf/synthetic/config.yaml
new file mode 100644
index 0000000..6301948
--- /dev/null
+++ b/conf/synthetic/config.yaml
@@ -0,0 +1,54 @@
+year: 2010
+horizons: [30, 90, 180]  # Aggregation horizons in days (a same-day horizon of 0 is always added)
+
+# conf
+var_group: health
+vg_name: synthetic_health
+
+var: diabetes
+
+spatial_res: zcta
+temporal_res: daily
+
+input_dir: data/input/
+output_dir: data/health/
+
+# var_group
+min_year: 2000
+max_year: 2015
+min_spatial_res: zcta
+min_temporal_res: daily
+lego_nm: synthetic_sparse_counts
+lego_dir: lego/synthetic/medpar_outcomes/ccw/zcta_daily
+
+# Debug options
+debug_days: 3  # Set to null or remove for full-year processing
+
+# Synthetic data parameters
+synthetic:
+  # Paths for ZCTA data
+  zcta_unique_path: data/input/lego/geoboundaries/us_geoboundaries__census/us_uniqueid__census/zcta_yearly
+  zcta_shapefile_path: data/input/lego/geoboundaries/us_geoboundaries__census/us_shapefile__census/zcta_yearly
+  population_path: data/input/lego/social/demographics__census/raw/core/zcta__dec__population.parquet
+
+  # Poisson distribution parameters for synthetic data generation
+  poisson_params:
+    base_rate: 0.11            # Base rate for the Poisson distribution (target: ~85% zeros)
+    seasonal_amplitude: 0.02   # Seasonal variation amplitude (kept small to maintain sparsity)
+    spatial_variance: 0.03     # Spatial variance across ZCTAs (kept small to maintain sparsity)
+    latitude_effect: 0.2       # Effect of latitude on incidence
+    longitude_effect: 0.1      # Effect of longitude on incidence
+    population_effect: 0.0001  # Population scaling factor (per-capita effect)
+    random_seed: 42            # For reproducibility
+
+  # Geographic constraints
+  mainland_only: true  # Filter for the continental US only
+
+  # Date range for synthetic data
+  date_range:
+    start_year: 2000
+    end_year: 2015
+
+hydra:
+  run:
+    dir: logs/synthetic/${now:%Y-%m-%d}/${now:%H-%M-%S}
\ No newline at end of file
diff --git a/conf/synthetic/snakemake.yaml b/conf/synthetic/snakemake.yaml
new file mode 100644
index 0000000..504330a
--- /dev/null
+++ b/conf/synthetic/snakemake.yaml
@@ -0,0 +1,112 @@
+years: [2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009,
+        2010, 2011, 2012, 2013, 2014]
+
+# All health outcomes (same as the real health data)
+vars: ['anemia', 'ami', 'alzh', 'alzhdmta', 'atrialfb', 'cataract',
+       'chrnkidn', 'copd', 'chf', 'diabetes', 'stroke', 'breastCancer',
+       'colorectalCancer', 'prostateCancer', 'lungCancer',
+       'endometrialCancer', 'hyperp', 'glaucoma', 'hipfrac', 'ischmcht',
+       'depressn', 'osteoprs', 'ra_oa', 'asthma', 'hyperl', 'hypert',
+       'hypoth']
+
+# Default parameters for all diseases
+default_params:
+  base_rate: 0.11
+  seasonal_amplitude: 0.02
+  spatial_variance: 0.03
+  latitude_effect: 0.2
+  longitude_effect: 0.1
+  population_effect: 0.0001
+  random_seed: 42
+
+# Disease-specific parameter variations
+disease_params:
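+  # Each disease entry below overrides only the keys it lists; any key not
+  # given here falls back to default_params (the merge is done by
+  # get_disease_params() in snakefile_synthetic_health.smk).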
+  # More common diseases
+  diabetes:
+    base_rate: 0.15
+    seasonal_amplitude: 0.015
+
+  hypert:
+    base_rate: 0.18
+    seasonal_amplitude: 0.01
+
+  hyperl:
+    base_rate: 0.16
+    seasonal_amplitude: 0.01
+
+  chf:
+    base_rate: 0.12
+    seasonal_amplitude: 0.03
+
+  # Seasonal diseases
+  asthma:
+    base_rate: 0.10
+    seasonal_amplitude: 0.04
+
+  copd:
+    base_rate: 0.09
+    seasonal_amplitude: 0.035
+
+  stroke:
+    base_rate: 0.08
+    seasonal_amplitude: 0.025
+
+  # Cancers (less seasonal, more spatial variation)
+  breastCancer:
+    base_rate: 0.05
+    seasonal_amplitude: 0.005
+    spatial_variance: 0.05
+
+  lungCancer:
+    base_rate: 0.06
+    seasonal_amplitude: 0.008
+    spatial_variance: 0.06
+
+  colorectalCancer:
+    base_rate: 0.04
+    seasonal_amplitude: 0.005
+    spatial_variance: 0.04
+
+  prostateCancer:
+    base_rate: 0.07
+    seasonal_amplitude: 0.005
+    spatial_variance: 0.05
+
+  endometrialCancer:
+    base_rate: 0.03
+    seasonal_amplitude: 0.003
+    spatial_variance: 0.04
+
+  # Age-related diseases (higher population effect)
+  alzh:
+    base_rate: 0.06
+    population_effect: 0.0002
+    spatial_variance: 0.04
+
+  alzhdmta:
+    base_rate: 0.05
+    population_effect: 0.0002
+    spatial_variance: 0.04
+
+  osteoprs:
+    base_rate: 0.08
+    population_effect: 0.00015
+    seasonal_amplitude: 0.015
+
+  cataract:
+    base_rate: 0.10
+    population_effect: 0.00015
+    seasonal_amplitude: 0.01
+
+  # Other diseases
+  ami:
+    base_rate: 0.07
+    seasonal_amplitude: 0.025
+
+  hipfrac:
+    base_rate: 0.04
+    seasonal_amplitude: 0.02
+
+  anemia:
+    base_rate: 0.09
+    seasonal_amplitude: 0.02
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index b0ce7d1..e7a9928 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,9 +1,11 @@
 numpy
 torch
-pandas==2.2.2
-pyarrow==11.0.0
+pandas>2.2.2
+pyarrow
 duckdb==0.9.2
 hydra-core==1.3.2
 snakemake==8.16
 tqdm
 ipykernel
+geopandas
+scipy
diff --git a/snakefile_synthetic_health.smk b/snakefile_synthetic_health.smk
new file mode 100644
index 0000000..65c82d4
--- /dev/null
+++ b/snakefile_synthetic_health.smk
@@ -0,0 +1,52 @@
+# Snakemake file for synthetic health data generation
+
+import zlib
+
+# Load config
+configfile: "conf/synthetic/snakemake.yaml"
+
+# Get config values
+years = config["years"]
+vars = config["vars"]
+
+def get_disease_params(var):
+    """Get disease-specific parameters for synthetic data generation"""
+    params = config["default_params"].copy()
+
+    # Apply disease-specific variations if they exist
+    if var in config["disease_params"]:
+        params.update(config["disease_params"][var])
+
+    # Derive a deterministic per-disease random seed so runs are reproducible
+    # while diseases still differ (the built-in hash() is salted per Python
+    # process, so a stable hash is used instead)
+    params["random_seed"] = zlib.crc32(var.encode("utf-8")) % 1000 + 42
+
+    return params
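+
+# Example: get_disease_params("diabetes") merges the defaults from
+# conf/synthetic/snakemake.yaml with the disease_params overrides
+# (base_rate=0.15, seasonal_amplitude=0.015), plus the per-disease
+# random_seed derived above.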
+
+# Rule: final output is one sentinel file per var/year (Dec 31)
+rule all:
+    input:
+        expand(
+            "data/health/synthetic_health/{var}/{var}__{year}1231.parquet",
+            var=vars,
+            year=years
+        )
+
+# Rule: preprocess synthetic health data for a given var and year
+rule preprocess_synthetic_health:
+    output:
+        "data/health/synthetic_health/{var}/{var}__{year}1231.parquet"
+    params:
+        disease_params = lambda wildcards: get_disease_params(wildcards.var)
+    shell:
+        """
+        python src/preprocessing_synth_health.py \
+            hydra.run.dir=. \
+            var={wildcards.var} \
+            year={wildcards.year} \
+            debug_days=null \
+            synthetic.poisson_params.base_rate={params.disease_params[base_rate]} \
+            synthetic.poisson_params.seasonal_amplitude={params.disease_params[seasonal_amplitude]} \
+            synthetic.poisson_params.spatial_variance={params.disease_params[spatial_variance]} \
+            synthetic.poisson_params.latitude_effect={params.disease_params[latitude_effect]} \
+            synthetic.poisson_params.longitude_effect={params.disease_params[longitude_effect]} \
+            synthetic.poisson_params.population_effect={params.disease_params[population_effect]} \
+            synthetic.poisson_params.random_seed={params.disease_params[random_seed]}
+        """
\ No newline at end of file
diff --git a/src/preprocessing_synth_health.py b/src/preprocessing_synth_health.py
new file mode 100644
index 0000000..4dd31c1
--- /dev/null
+++ b/src/preprocessing_synth_health.py
@@ -0,0 +1,314 @@
+from datetime import date, timedelta
+import duckdb
+import hydra
+import os
+import logging
+import calendar
+import numpy as np
+import pandas as pd
+import geopandas as gpd
+from tqdm import tqdm
+
+
+# Configure logging
+LOGGER = logging.getLogger(__name__)
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+
+def get_zcta_data_with_geo_pop(unique_fpath, shapefile_fpath, population_fpath, year, mainland_only=True):
+    """
+    Extract ZCTA IDs with geographic coordinates and population data.
+    Returns a comprehensive dataset for mainland US ZCTAs.
+    """
+    LOGGER.info(f"Loading ZCTA data for year {year}")
+
+    # Read unique ID file for the given year
+    unique_file = f"{unique_fpath}/us_uniqueid__census__zcta_yearly__{year}.parquet"
+    df_unique = pd.read_parquet(unique_file)
+
+    # Filter for mainland US if requested
+    if mainland_only and 'continental_us' in df_unique.columns:
+        df_unique = df_unique[df_unique.continental_us]
+        LOGGER.info(f"Filtered to {len(df_unique)} mainland US ZCTAs")
+
+    # Read shapefile for geographic data
+    shapefile_dir = f"{shapefile_fpath}/us_shapefile__census__zcta_yearly__{year}"
+    shapefile_path = f"{shapefile_dir}/us_shapefile__census__zcta_yearly__{year}.shp"
+
+    if os.path.exists(shapefile_path):
+        gdf = gpd.read_file(shapefile_path)
+
+        # Calculate centroids for lat/lon (project to an appropriate CRS first)
+        gdf_projected = gdf.to_crs('EPSG:3857')  # Web Mercator for accurate centroids
+        centroids = gdf_projected.geometry.centroid.to_crs(gdf.crs)  # Back to the original CRS
+
+        gdf['longitude'] = centroids.x
+        gdf['latitude'] = centroids.y
+
+        # Keep only necessary columns
+        geo_data = gdf[['zcta', 'longitude', 'latitude']].copy()
+
+        LOGGER.info(f"Loaded geographic data for {len(geo_data)} ZCTAs")
+    else:
+        raise FileNotFoundError(
+            f"Shapefile not found: {shapefile_path}. "
+            f"Cannot proceed without geographic data for year {year}."
+        )
+
+    # Read population data
+    df_pop_full = pd.read_parquet(population_fpath).reset_index()
+
+    # Map the year to the nearest census year (population data is only available for 2000, 2010, 2020)
+    available_years = sorted(df_pop_full['year'].unique())
+
+    # Find the closest census year
+    if year <= 2005:
+        census_year = 2000
+    elif year <= 2015:
+        census_year = 2010
+    else:
+        census_year = 2020
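+    # For example, 2003 maps to the 2000 census and 2012 maps to the 2010 census.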
+
+    # Use population data from the mapped census year
+    df_pop = df_pop_full[df_pop_full['year'] == census_year]
+
+    if len(df_pop) == 0:
+        raise ValueError(
+            f"No population data available for census year {census_year} "
+            f"(mapped from year {year}). Available years: {available_years}"
+        )
+
+    if census_year != year:
+        LOGGER.info(f"Using population data from census year {census_year} for year {year}")
+
+    # Merge all data together
+    zcta_data = df_unique.merge(geo_data, on='zcta', how='left')
+    zcta_data = zcta_data.merge(df_pop[['zcta', 'population']], on='zcta', how='left')
+
+    # Check for missing population data - this indicates a data consistency problem
+    missing_pop = zcta_data['population'].isna()
+    if missing_pop.any():
+        missing_count = missing_pop.sum()
+        missing_zctas = zcta_data[missing_pop]['zcta'].head(10).tolist()
+        raise ValueError(
+            f"Missing population data for {missing_count} ZCTAs. "
+            f"Example missing ZCTAs: {missing_zctas}. "
+            f"This indicates a mismatch between ZCTA lists and population data."
+        )
+
+    # Check for missing coordinates - this indicates a real problem
+    missing_lon = zcta_data['longitude'].isna()
+    missing_lat = zcta_data['latitude'].isna()
+
+    if missing_lon.any() or missing_lat.any():
+        missing_count = missing_lon.sum() + missing_lat.sum()
+        LOGGER.error(f"Found {missing_count} ZCTAs with missing coordinates!")
+        LOGGER.error("This indicates a mismatch between unique IDs and shapefile data")
+
+        # Show some examples for debugging
+        missing_zctas = zcta_data[missing_lon | missing_lat]['zcta'].head(10).tolist()
+        LOGGER.error(f"Example missing ZCTAs: {missing_zctas}")
+
+        raise ValueError(f"Missing coordinates for {missing_count} ZCTAs - check data consistency")
+
+    LOGGER.info(f"Final dataset: {len(zcta_data)} ZCTAs with geo and population data")
+
+    return zcta_data
+
+def generate_all_synthetic_data_vectorized(zcta_data, date_list, var_name, poisson_params):
+    """
+    Generate synthetic health data for ALL dates and ZCTAs at once using vectorized operations.
+    Much faster than generating one date at a time.
+    """
+    LOGGER.info(f"Generating synthetic data for {len(date_list)} dates and {len(zcta_data)} ZCTAs")
+
+    # Set random seed for reproducibility
+    np.random.seed(poisson_params['random_seed'])
+
+    # Pre-calculate all spatial effects (these don't change by date)
+    zcta_data = zcta_data.copy()
+    # Create pseudo-random spatial effects based on ZCTA IDs
+    # Convert ZCTA strings to numeric values for spatial variation
+    zcta_numeric = pd.to_numeric(zcta_data['zcta'], errors='coerce').fillna(0)
+    zcta_normalized = (zcta_numeric % 10000) / 10000.0  # Normalize to 0-1 range
+    zcta_data['base_spatial_effect'] = poisson_params['spatial_variance'] * np.sin(2 * np.pi * zcta_normalized)
+
+    # Geographic effects
+    lat_normalized = (zcta_data['latitude'] - 35) / 15
+    lon_normalized = (zcta_data['longitude'] + 95) / 30
+    zcta_data['latitude_effect'] = poisson_params['latitude_effect'] * np.sin(lat_normalized * np.pi)
+    zcta_data['longitude_effect'] = poisson_params['longitude_effect'] * np.cos(lon_normalized * np.pi)
+
+    # Population effects
+    pop_log = np.log(np.maximum(1, zcta_data['population']))
+    zcta_data['population_effect'] = poisson_params['population_effect'] * pop_log
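+
+    # The per-ZCTA, per-day Poisson rate used below is the sum of these effects,
+    # floored at 0.1: base_rate + seasonal(day) + spatial(zcta) + latitude
+    # + longitude + population.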
+
+    # Generate counts date by date (vectorized across ZCTAs within each date)
+    all_synthetic_data = []
+
+    for target_date in date_list:
+        # Calculate the seasonal effect for this date
+        day_of_year = target_date.timetuple().tm_yday
+        seasonal_effect = poisson_params['seasonal_amplitude'] * np.sin(2 * np.pi * day_of_year / 365.25)
+
+        # Calculate lambda parameters for all ZCTAs at once
+        lambda_params = np.maximum(0.1,
+            poisson_params['base_rate'] +
+            seasonal_effect +
+            zcta_data['base_spatial_effect'] +
+            zcta_data['latitude_effect'] +
+            zcta_data['longitude_effect'] +
+            zcta_data['population_effect']
+        )
+
+        # Generate all counts at once using vectorized Poisson sampling
+        counts = np.random.poisson(lambda_params)
+
+        # Create records for this date
+        date_data = {
+            'zcta': zcta_data['zcta'].values,
+            'var': var_name,
+            'date': target_date,
+            'n': counts
+        }
+        all_synthetic_data.append(pd.DataFrame(date_data))
+
+    # Concatenate all data
+    return pd.concat(all_synthetic_data, ignore_index=True)
+
+@hydra.main(config_path="../conf/synthetic", config_name="config", version_base=None)
+def main(cfg):
+    """
+    Preprocess synthetic health data for the data loader.
+    This creates synthetic data that mimics the structure of the LEGO dataset.
+    Only zcta daily data is supported (with hardcoded vars).
+    """
+
+    conn = duckdb.connect()
+
+    LOGGER.info(f"Processing synthetic data for {cfg.var}")
+
+    output_folder = f"{cfg.output_dir}/{cfg.vg_name}/{cfg.var}/"
+    os.makedirs(output_folder, exist_ok=True)
+
+    year = cfg.year
+    horizons = cfg.horizons
+    LOGGER.info(f"Processing year {year}")
+
+    # Get ZCTA data with geographic coordinates and population information
+    LOGGER.info("Loading ZCTA data with geographic and population information")
+    zcta_data = get_zcta_data_with_geo_pop(
+        cfg.synthetic.zcta_unique_path,
+        cfg.synthetic.zcta_shapefile_path,
+        cfg.synthetic.population_path,
+        year,
+        cfg.synthetic.mainland_only
+    )
+
+    LOGGER.info(f"Found {len(zcta_data)} ZCTAs for year {year} with complete data")
+
+    # Build the list of (year, month, day) tuples for every calendar day of the year
+    days_list = [(year, month, day) for month in range(1, 13) for day in range(1, calendar.monthrange(year, month)[1] + 1)]
+
+    # Debug option: limit to the first few days for testing
+    if hasattr(cfg, 'debug_days') and cfg.debug_days:
+        days_list = days_list[:cfg.debug_days]
+        LOGGER.info(f"Debug mode: processing only first {len(days_list)} days")
+
+    # Generate ALL synthetic data needed at once (much faster than per-day)
+    LOGGER.info("Pre-generating all synthetic data needed...")
+
+    # Calculate all unique dates needed for all days and horizons
+    all_needed_dates = set()
+    max_horizon = max(horizons) if horizons else 0
+
+    for day in days_list:
+        t = date(day[0], day[1], day[2])
+        for i in range(max_horizon + 1):
+            all_needed_dates.add(t + timedelta(days=i))
+
+    all_needed_dates = sorted(list(all_needed_dates))
+    LOGGER.info(f"Need synthetic data for {len(all_needed_dates)} unique dates")
+
+    # Generate all synthetic data at once using vectorized operations
+    all_synthetic_df = generate_all_synthetic_data_vectorized(
+        zcta_data, all_needed_dates, cfg.var, cfg.synthetic.poisson_params
+    )
+
+    # Create one large parquet file with all synthetic data
+    synthetic_input_dir = f"{cfg.input_dir}/{cfg.lego_dir}"
+    os.makedirs(synthetic_input_dir, exist_ok=True)
+    bulk_synthetic_file = f"{synthetic_input_dir}/{cfg.lego_nm}_bulk.parquet"
+    all_synthetic_df.to_parquet(bulk_synthetic_file, index=False)
+    LOGGER.info(f"Saved all synthetic data to {bulk_synthetic_file}")
+
+    for day in tqdm(days_list, desc="Processing days"):
+        date_str = f"{day[0]}{day[1]:02d}{day[2]:02d}"
+        output_fname = f"{output_folder}/{cfg.var}__{date_str}.parquet"
+
+        t = date(day[0], day[1], day[2])
+
+        # Use the bulk synthetic data file in queries
+        input_files = bulk_synthetic_file
+
+        # Build queries for each horizon, starting with same-day (horizon = 0)
+        queries = []
+
+        # Same-day count (horizon = 0)
+        queries.append(f"""
+            SELECT
+                zcta,
+                0 AS horizon,
+                n
+            FROM '{input_files}'
+            WHERE
+                var = '{cfg.var}' AND
+                date = DATE '{t}'
+        """)
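+
+        # Note: horizon h aggregates counts over dates t through t+h inclusive
+        # (h + 1 calendar days), per the date filter below.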
+
+        # Future horizons
+        for horizon in horizons:
+            t_end = t + timedelta(days=horizon)
+            queries.append(f"""
+                SELECT
+                    zcta,
+                    {horizon} AS horizon,
+                    SUM(n) AS n
+                FROM '{input_files}'
+                WHERE
+                    var = '{cfg.var}' AND
+                    date >= DATE '{t}' AND
+                    date <= DATE '{t_end}'
+                GROUP BY zcta
+            """)
+
+        # Combine all queries into one
+        full_query = " UNION ALL ".join(queries)
+
+        # Execute and save as Parquet
+        conn.execute(f"""
+            CREATE OR REPLACE TABLE output AS
+            {full_query}
+        """)
+
+        conn.execute(f"""
+            COPY (SELECT * FROM output ORDER BY zcta, horizon)
+            TO '{output_fname}' (FORMAT PARQUET)
+        """)
+
+    # Clean up the bulk synthetic data file
+    LOGGER.info("Cleaning up bulk synthetic data file")
+    try:
+        if os.path.exists(bulk_synthetic_file):
+            os.remove(bulk_synthetic_file)
+    except Exception as e:
+        LOGGER.warning(f"Could not remove bulk synthetic file: {e}")
+
+    conn.close()
+    LOGGER.info("Synthetic health data preprocessing completed")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file