diff --git a/README.md b/README.md index 8bc5e42..710c478 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,6 @@ Update NEMO ocean model restart files with machine learning predictions. ## Installation - Start by installing a virtual environment and then: ```bash @@ -13,11 +12,22 @@ pip install . ## Usage ```bash -nemo-spinup-restart --restart_path /path/to/restarts \ - --radical RESTART_NAME \ - --mask_file /path/to/mask.nc \ - --prediction_path /path/to/predictions \ - --ocean_terms ocean_terms.yaml +nemo-restart + --restart_path /path/to/restarts \ + --radical RESTART_NAME \ + --mask_file /path/to/mask.nc \ + --prediction_path /path/to/predictions \ + --ocean_terms ocean_terms.yaml + +nemo-upscale upscale \\ + --predictions-dir ./predictions \ + --coarse-template ./1deg/restart_template.nc \ + --coarse-mask ./1deg/mesh_mask.nc \ + --coarse-namelist ./1deg/namelist_cfg \ + --fine-template ./025deg/restart_template.nc \ + --fine-mask ./025deg/mesh_mask.nc \ + --output-dir ./generated \ + --name C2 ``` ### Required Arguments diff --git a/algorithm.md b/algorithm.md new file mode 100644 index 0000000..6f1492f --- /dev/null +++ b/algorithm.md @@ -0,0 +1,68 @@ + +Key insight: Pre-existing template lat and lon values do not matter. Mesh mask is the authoritative grid definition and we overwrite the coordinates with mask.glamt and mask.gphit values. + +--- + +## Regrid.py Algorithm + +### `upscale_predictions()` +1. Load numpy predictions (toce, soce, ssh) at specified time index +2. Create coarse restart: `create_restart_from_predictions()` +3. Regrid to fine resolution: `regrid_restart()` +4. Return path to fine restart file + +### `create_restart_from_predictions()` +1. Load template restart and mesh mask +2. Extract depth from mask (`gdept_0`) +3. Compute potential density using NEMO equation of state +4. Populate restart fields: tb/tn, sb/sn, sshb/sshn, rhop +5. Update metadata and save to NetCDF + +### `regrid_restart()` + +**Phase 1: Load & Prepare** +1. Load coarse/fine restarts and masks +2. Assign lon/lat coords to restarts from masks (glamt/gphit) +3. Save intermediate: `restart_lr.nc`, `restart_hr_template.nc` + +**Phase 2: Extrapolate Coarse** +4. Call `extrapolate_to_land(restart_lr, mask_lr)` +5. Save intermediate: `restart_lr_extrap.nc` + +**Phase 3: Regrid** +6. Create xESMF regridder (bilinear, nearest_s2d extrapolation) +7. Apply regridding: `restart_hr = regridder(restart_lr_extrap)` +8. Save intermediate: unmasked fine restart + +**Phase 4: Cleanup Coordinates** +9. Rename lat→nav_lat, lon→nav_lon +10. Drop x, y coordinate variables + +**Phase 5: Apply Fine Mask** +11. Prepare fine mask (align nav_lev, drop x/y/time) +12. Mask all variables (set land to 0.0) + +**Phase 6: Finalize** +13. Zero all velocities (ub/un/vb/vn) +14. Copy time metadata from coarse (kt, ndastp, adatrj, ntime) +15. Copy timestep from fine template (rdt) +16. Reorder variables to match template +17. Update file metadata and save + +### `extrapolate_to_land()` +1. Prepare mask (squeeze, drop x/y/time_counter, align nav_lev) +2. Create 2D surface mask +3. Apply mask to set land to NaN (3D and 4D variables) +4. Extrapolate along x-dimension (nearest neighbor) +5. Extrapolate along y-dimension (nearest neighbor) +6. Return extrapolated restart + +**Note:** No vertical extrapolation performed (~12k NaNs remain in deep ocean) + +### Key Design Choices +- **Coordinates from mask:** Always overwrite restart coords with glamt/gphit +- **xESMF requirements:** Needs lon/lat named coordinates for geographic interpolation +- **Dimension-based masking:** `.where()` broadcasts on dimension names (y, x), not coord values. +- **Conservative velocities:** Zero out for NEMO to recompute from density +- **Two-stage extrapolation:** Fill land before regridding to avoid NaN propagation + diff --git a/examples/diffusion.md b/examples/diffusion.md new file mode 100644 index 0000000..2401fee --- /dev/null +++ b/examples/diffusion.md @@ -0,0 +1,42 @@ +# Diffusion state restart file generation and evaluation + +## Data setup + +TODO: Host data on spiritx + +## Regrid and upscale + +We need 1-degree, 0.25-degree restart file and mesh mask references. + +```bash + +REFERENCE=./data/reference + +nemo-upscale upscale \ + --predictions-dir $REFERENCE/diffusion_states/chamon_C2_clean/ \ + --coarse-template $REFERENCE/1deg_restart/DINO_00000002_restart.nc \ + --coarse-mask $REFERENCE/1deg_restart/mesh_mask.nc \ + --coarse-namelist $REFERENCE/1deg_restart/namelist_cfg \ + --fine-template $REFERENCE/025deg_restart/restart25_arch/DINO_10800000_restart.nc \ + --fine-mask $REFERENCE/025deg_restart/restart25_arch/mesh_mask.nc \ + --name C2 \ + --time-index 4 \ + --output-dir ./generated +``` + +## Evaluate using spinup-evaluation + +Install spinup-evaluation as normal and create a gen-setup.yaml that looks as so: + +```gen-setup.yaml + +mesh_mask: $REFERENCE/025deg_restart/restart25_arch/mesh_mask.nc +restart_files: 'generated_restart_C2' +``` + +```bash +spinup-eval \ + --sim-path ./generated/ \ + --config ./gen-setup.yaml \ + --mode restart +``` diff --git a/examples/gen-setup.yaml b/examples/gen-setup.yaml new file mode 100644 index 0000000..1f30af5 --- /dev/null +++ b/examples/gen-setup.yaml @@ -0,0 +1,5 @@ +# DINO-setup.yaml + +mesh_mask: /Users/matt/work/nemo/spinup-data/Gorce-data/Dinonline/restart0/mesh_mask_new.nc + +restart_files: 'generated_restart_C2_fine' diff --git a/main.py b/main.py deleted file mode 100644 index ec0fd1f..0000000 --- a/main.py +++ /dev/null @@ -1,112 +0,0 @@ -# Adapted from code by Maud Tissot (Spinup-NEMO) -# Original source: https://github.com/maudtst/Spinup-NEMO -# Licensed under the MIT License -# -# Modifications in this version by ICCS, 2025 -import sys -import argparse - -from src.nemo_spinup_restart.restart import * -import xarray as xr - - -def update_restart_slice(restart_file, restart_name, mask_file): - # restart file "/thredds/idris/work/ues27zx/Restarts/" mask file '/thredds/idris/work/ues27zx/eORCA1.4.2_mesh_mask_modJD.nc' - """ - Update a restart file with new predictions and related variables. - - Parameters: - ----------- - restart_file (str) : Path to the existing restart file. - restart_name (str) : Name of the restart file. - file_mask (str) : Path to the mask file. - - Returns: - None - """ - restart_array = xr.open_dataset( - restart_file + restart_name, decode_times=False - ) # load restart file - mask_array = xr.open_dataset(mask_file, decode_times=False) # load mask file - zos_new, so_new, thetao_new = ( - restart.load_predictions() - ) # load ssh, so and thetao predictions - restart.update_pred( - restart_array, zos_new, so_new, thetao_new - ) # update restart with ssh, so and thetao predictions - e3t_new = restart.update_e3tm(restart_array, mask_array) # update e3tm and gete e3t - deptht_new = restart.get_deptht( - restart_array, mask_array - ) # get new deptht for density - restart.update_rhop(restart_array, mask_array, deptht_new) # update density - restart.update_v_velocity( - restart_array, mask_array, e3t_new[0] - ) # update meridional velocity - restart.update_u_velocity( - restart_array, mask_array, e3t_new[0] - ) # update zonal velocity - array = array.rename_vars( - {"xx": "x", "yy": "y"} - ) # inverse transformation of x and y vars - # Restart.to_netcdf(restart_file+restart_name) # save file - - -# PAs EU LE TEMPS D'ESSAYER -def update_Restarts(restarts_file, mask_file, jobs=10): - """ - Update multiple restart files in parallel. - - Parameters: - restarts_file (str) : Path to the directory containing restart files. - mask_file (str) : Path to the mask file. - jobs (int, optional) : Number of parallel jobs to run. default 10. - - Returns: - None - """ - restart_names = restart.get_restart_files( - restarts_file - ) # SUPER LONG PEUT ETRE LE FAIR EN BASH OU ERREUR - Parallel(jobs)( - delayed(update_restart_slice)(restarts_file, file, mask_file) - for file in restart_names - ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Update of restart files") - parser.add_argument( - "--restart_path", type=str, help="path of restart file directory" - ) - parser.add_argument("--radical", type=str, help="radical of restart filename") - parser.add_argument("--mask_file", type=str, help="adress of mask file") - parser.add_argument( - "--prediction_path", type=str, help="path of prediction directory" - ) - parser.add_argument( - "--ocean_terms", type=str, default="ocean_terms.yaml", - help="path to ocean_terms.yaml file (default: ocean_terms.yaml)" - ) - args = parser.parse_args() - - restart = xr.open_dataset( - get_restart_files(args.restart_path, args.radical), decode_times=False - ) - mask = get_mask_file(args.mask_file, restart) - restart = load_predictions(restart, dirpath=args.prediction_path, ocean_terms_file=args.ocean_terms) - restart = propagate_pred(restart, mask) - record_full_restart(args.restart_path, args.radical, restart) - record_pieced_restart(args.restart_path, args.radical, restart) - - print("""All done. Now you just need to : - - Back transform the coordinates of the pieced restart files using ncks to the original version (see bash script xarray_to_CMIP.sh) - - Rename/Overwrite the "NEW_" restart files to their old version if you’re happy with them (see other bash script rewrite.sh) - - Point to the restart directory in your simulation config.card (if all your non-NEMO restart files are also in the restart_path directory, of course). - You might need to reorganize them in a ./OCE/Restart/CM....nc structure instead of ./OCE_CM...nc (there’s the rename.sh bash script for that) but normally it should work without. - You can see the example script Jumper.sh for how to do most of that. See you soon. :) """) - - # update_Restarts(restarts_file=args.restarts_file,mask_file=args.mask_file) - - # update_restart_files - - # python SpinUp/jumper/main/main_restart.py --restart_files '/thredds/idris/work/ues27zx/eORCA1.4.2_mesh_mask_modJD.nc' --mask_file '/thredds/idris/work/ues27zx/eORCA1.4.2_mesh_mask_modJD.nc' diff --git a/pyproject.toml b/pyproject.toml index 6236377..59a4f27 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,8 @@ dependencies = [ "matplotlib", "netcdf4", "pyyaml", + "xesmf", + "f90nml", ] [project.optional-dependencies] @@ -41,7 +43,8 @@ dev = [ ] [project.scripts] -nemo-spinup-restart = "nemo_spinup_restart.cli:main" +nemo-restart = "nemo_spinup_restart.cli:main" +nemo-upscale = "nemo_spinup_restart.regrid_cli:main" [tool.setuptools.packages.find] where = ["src"] diff --git a/src/nemo_spinup_restart/regrid.py b/src/nemo_spinup_restart/regrid.py new file mode 100644 index 0000000..9746388 --- /dev/null +++ b/src/nemo_spinup_restart/regrid.py @@ -0,0 +1,422 @@ +""" +Upscale NEMO restart files from coarse to fine resolution. + +Uses standard open-source libraries: +- xarray for NetCDF handling +- xnemogcm for NEMO-specific file operations (MIT) +- xesmf for regridding (MIT) +- f90nml for namelist reading (LGPL) +""" + +import numpy as np +import xarray as xr +import xesmf as xe +import f90nml +from datetime import datetime +from pathlib import Path + + +def compute_nemo_density(temperature, salinity, depth, namelist_path): + """ + Compute potential density using NEMO's simplified linear equation of state. + + Parameters + ---------- + temperature : np.ndarray + Ocean temperature (°C), shape (nav_lev, y, x) + salinity : np.ndarray + Ocean salinity (PSU), shape (nav_lev, y, x) + depth : np.ndarray + Depth levels (m), shape (nav_lev, y, x) or (nav_lev,) + namelist_path : str + Path to NEMO namelist file containing equation of state parameters + + Returns + ------- + rhop : np.ndarray + Potential density (kg/m³), shape (nav_lev, y, x) + """ + nml = f90nml.read(namelist_path) + eos = nml["nameos"] + + rhop = ( + -eos["rn_a0"] + * (1.0 + 0.5 * eos["rn_lambda1"] * (temperature - 10.0) + eos["rn_mu1"] * depth) + * (temperature - 10.0) + + eos["rn_b0"] + * (1.0 - 0.5 * eos["rn_lambda2"] * (salinity - 35.0) - eos["rn_mu2"] * depth) + * (salinity - 35.0) + - eos["rn_nu"] * (temperature - 10.0) * (salinity - 35.0) + ) + 1026 + + return rhop + + +def create_restart_from_predictions( + restart_template_path, + mesh_mask_path, + namelist_path, + temperature, + salinity, + ssh, + output_path, +): + """ + Create a NEMO restart file from ML predictions at template resolution. + + Parameters + ---------- + restart_template_path : str + Path to template restart file (defines structure and grid) + mesh_mask_path : str + Path to mesh_mask file (for depth levels) + namelist_path : str + Path to NEMO namelist (for equation of state parameters) + temperature : np.ndarray + Temperature predictions (°C), shape (nav_lev, y, x) + salinity : np.ndarray + Salinity predictions (PSU), shape (nav_lev, y, x) + ssh : np.ndarray + Sea surface height predictions (m), shape (y, x) + output_path : str + Where to save the generated restart file + + Returns + ------- + restart : xr.Dataset + The created restart file + """ + # Load template + restart = xr.open_dataset(restart_template_path).load() + + # Load mask for depth information + mask = xr.open_dataset(mesh_mask_path) + depth = mask.gdept_0.squeeze().data # Remove extra dimensions + + # Compute density + rhop = compute_nemo_density(temperature, salinity, depth, namelist_path) + + # Populate restart file + restart["tb"] = ( + ("time_counter", "nav_lev", "y", "x"), + np.expand_dims(temperature, axis=0), + ) + restart["tn"] = ( + ("time_counter", "nav_lev", "y", "x"), + np.expand_dims(temperature, axis=0), + ) + + restart["sb"] = ( + ("time_counter", "nav_lev", "y", "x"), + np.expand_dims(salinity, axis=0), + ) + restart["sn"] = ( + ("time_counter", "nav_lev", "y", "x"), + np.expand_dims(salinity, axis=0), + ) + + restart["sshb"] = (("time_counter", "y", "x"), ssh) + restart["sshn"] = (("time_counter", "y", "x"), ssh) + + restart["rhop"] = ( + ("time_counter", "nav_lev", "y", "x"), + np.expand_dims(rhop, axis=0), + ) + + # Update metadata + restart.attrs["file_name"] = Path(output_path).name + restart.attrs["TimeStamp"] = ( + datetime.now().astimezone().strftime("%d/%m/%Y %H:%M:%S %z") + ) + + # Save + Path(output_path).unlink(missing_ok=True) + restart.to_netcdf(output_path, unlimited_dims="time_counter") + print(f"Saved {output_path}") + + return restart + + +def extrapolate_to_land(restart, mask): + """ + Extrapolate ocean values onto land points for better regridding. + + Parameters + ---------- + restart : xr.Dataset + Restart file to extrapolate + mask : xr.Dataset + NEMO mask file (mesh_mask.nc) + + Returns + ------- + restart_extrapolated : xr.Dataset + Restart with land points filled + """ + # Create 3D mask matching restart coordinates + # Squeeze removes time_counter dimension if present + # tmask_3d = mask.tmask.squeeze() + # Assign nav_lev to match restart, keep x/y dimensions aligned + # tmask_3d = tmask_3d.assign_coords({"nav_lev": restart.nav_lev}) + # Drop coordinate variables that might conflict (keep only dimension indices) + # tmask_3d = tmask_3d.drop_vars(["x", "y", "time_counter"], errors="ignore") + print(mask) + # Mask needs to match the coordinates of the restart file + + tmask_3d = mask.tmask.squeeze() + + # breakpoint() + # tmask_3d = tmask_3d.drop_vars(["x", "y", "time_counter"], errors="ignore") + + # Only align vertical coordinate - xarray will broadcast on dimension names + if "nav_lev" in restart.dims: + tmask_3d = tmask_3d.assign_coords({"nav_lev": restart.nav_lev}) + + # 2D mask for surface variables + tmask_2d = tmask_3d.isel(nav_lev=0) + + # Apply mask (set land to NaN) + # Apply mask (set land to NaN) - modify in place like grid_manipulation.py + for var_name, var_data in restart.items(): + if var_data.ndim == 3: # 2D variables (+ time) + print("Masking variable:", var_name) + restart[var_name] = var_data.where(tmask_2d == 1.0) + elif var_data.ndim == 4: # 3D variables (+ time) + print("Masking variable:", var_name) + restart[var_name] = var_data.where(tmask_3d == 1.0) + + + # Extrapolate NaN values + restart_filled_x = restart.interpolate_na( + dim="x", method="nearest", fill_value="extrapolate" + ) + restart = restart_filled_x.interpolate_na( + dim="y", method="nearest", fill_value="extrapolate" + ) + + # Extrapolate NaN values - apply to each variable individually + # for var_name, var_data in restart.items(): + # if var_data.ndim in [3, 4]: # Only extrapolate spatial variables + # print(f"Extrapolating {var_name}") + # # Fill along x + # filled_x = var_data.interpolate_na( + # dim="x", method="nearest", fill_value="extrapolate" + # ) + # # Fill along y + # filled_xy = filled_x.interpolate_na( + # dim="y", method="nearest", fill_value="extrapolate" + # ) + # restart[var_name] = filled_xy + + return restart + + +def regrid_restart( + restart_coarse_path, + restart_fine_template_path, + mesh_mask_coarse_path, + mesh_mask_fine_path, + output_path, +): + """ + Regrid a restart file from coarse to fine resolution. + + Parameters + ---------- + restart_coarse_path : str + Path to coarse resolution restart file (e.g., 1°) + restart_fine_template_path : str + Path to fine resolution restart template (e.g., 1/4°) + mesh_mask_coarse_path : str + Path to coarse resolution mesh_mask + mesh_mask_fine_path : str + Path to fine resolution mesh_mask + output_path : str + Where to save the regridded restart file + + Returns + ------- + restart_regridded : xr.Dataset + The regridded restart file at fine resolution + """ + # Load files + restart_lr = xr.open_dataset(restart_coarse_path).load() + restart_hr_template = xr.open_dataset(restart_fine_template_path).load() + mask_lr = xr.open_dataset(mesh_mask_coarse_path) + mask_hr = xr.open_dataset(mesh_mask_fine_path) + + # Extract timestep from template + timestep_fine = float(restart_hr_template["rdt"].values) + print(f"Using timestep from template: {timestep_fine}s") + + # Add lon/lat coordinates for xESMF (it needs CF-compliant coordinates) + # Get lon/lat from mesh_mask files (glamt, gphit are T-point coordinates) + restart_lr = restart_lr.assign_coords( + { + "lon": (["y", "x"], mask_lr.glamt.squeeze().values), + "lat": (["y", "x"], mask_lr.gphit.squeeze().values), + } + ) + restart_hr_template = restart_hr_template.assign_coords( + { + "lon": (["y", "x"], mask_hr.glamt.squeeze().values), + "lat": (["y", "x"], mask_hr.gphit.squeeze().values), + } + ) + + # restart_lr.to_netcdf("out_stages/restart_lr.nc", unlimited_dims="time_counter") + + # restart_hr_template.to_netcdf( + # "out_stages/restart_hr_template.nc", unlimited_dims="time_counter" + # ) + + # Extrapolate coarse restart onto land + restart_lr_extrap = extrapolate_to_land(restart_lr, mask_lr) + + print(restart_lr_extrap) + + # restart_lr_extrap.to_netcdf( + # "out_stages/restart_lr_extrap.nc", unlimited_dims="time_counter" + # ) + + # Create regridder (bilinear interpolation) + # breakpoint() + regridder = xe.Regridder( + restart_lr_extrap, + restart_hr_template, + "bilinear", + extrap_method="nearest_s2d", + ignore_degenerate=True, + ) + + # Apply regridding + restart_hr = regridder(restart_lr_extrap) + # breakpoint() + # restart_hr.to_netcdf( + # "generated_mine/generated_restart_C2_fine_unmasked-1.nc", + # unlimited_dims="time_counter", + # ) + # breakpoint() + + # Clean up coordinates: rename lat/lon (from xESMF) to nav_lat/nav_lon, drop x/y + # The lat/lon from xESMF are the true 2D grid coordinates, so we keep them as nav_lat/nav_lon + restart_hr = restart_hr.rename({"lat": "nav_lat", "lon": "nav_lon"}) + restart_hr = restart_hr.drop_vars(["x", "y"], errors="ignore") + + # Apply fine resolution mask + # Rename mask dimensions to match restart coordinates + tmask_3d = mask_hr.tmask.squeeze() # Remove time dimension if singleton + tmask_3d = tmask_3d.assign_coords({"nav_lev": restart_hr.nav_lev}) + tmask_3d = tmask_3d.drop_vars(["x", "y", "time_counter"], errors="ignore") + tmask_3d = tmask_3d.compute() + + tmask_2d = tmask_3d.isel(nav_lev=0).compute() + + # Mask regridded data (keep ocean values where mask==1, set land to 0) + for var_name, var_data in restart_hr.items(): + if var_data.ndim == 3: # 2D variables (+ time) + restart_hr[var_name] = var_data.where(tmask_2d == 1.0, 0.0) + elif var_data.ndim == 4: # 3D variables (+ time) + restart_hr[var_name] = var_data.where(tmask_3d == 1.0, 0.0) + + # Zero out velocities (will be recomputed by NEMO) + restart_hr["ub"].values[:] = 0.0 + restart_hr["un"].values[:] = 0.0 + restart_hr["vb"].values[:] = 0.0 + restart_hr["vn"].values[:] = 0.0 + + # Copy metadata from coarse restart and fine template + restart_hr["kt"] = restart_lr.kt + restart_hr["ndastp"] = restart_lr.ndastp + restart_hr["adatrj"] = restart_lr.adatrj + restart_hr["ntime"] = restart_lr.ntime + restart_hr["rdt"] = restart_hr_template["rdt"] # Use timestep from fine resolution + + # Match variable order from template + restart_hr = restart_hr[list(restart_hr_template.keys())] + + # Update metadata + restart_hr.attrs["file_name"] = Path(output_path).name + restart_hr.attrs["TimeStamp"] = ( + datetime.now().astimezone().strftime("%d/%m/%Y %H:%M:%S %z") + ) + # Save + Path(output_path).unlink(missing_ok=True) + restart_hr.to_netcdf(output_path, unlimited_dims="time_counter") + print(f"Saved {output_path}") + + return restart_hr + + +def upscale_predictions( + predictions_dir, + coarse_restart_template, + coarse_mesh_mask, + coarse_namelist, + fine_restart_template, + fine_mesh_mask, + output_dir, + generation_name, + time_index=-1, +): + """ + Complete workflow: numpy predictions → coarse restart → regrid → fine restart. + + Parameters + ---------- + predictions_dir : str + Directory containing pred_thetao.npy, pred_so.npy, pred_zos.npy + coarse_restart_template : str + Template restart file at coarse resolution + coarse_mesh_mask : str + Mesh mask at coarse resolution + coarse_namelist : str + Namelist at coarse resolution (for density calculation) + fine_restart_template : str + Template restart file at fine resolution + fine_mesh_mask : str + Mesh mask at fine resolution + output_dir : str + Directory to save generated restart files + generation_name : str + Identifier for this generation (e.g., 'C2') + time_index : int, optional + Time index to extract from prediction arrays (default: -1, last timestep) + + Returns + ------- + fine_restart_path : str + Path to the final upscaled restart file + """ + # Step 1: Load predictions (select specified time index) + temp = np.load(Path(predictions_dir) / "toce.npy")[time_index, :, :, :] + sal = np.load(Path(predictions_dir) / "soce.npy")[time_index, :, :, :] + ssh = np.load(Path(predictions_dir) / "ssh.npy")[time_index, :, :] + + print( + f"Loaded predictions at time index {time_index}: temp {temp.shape}, sal {sal.shape}, ssh {ssh.shape}" + ) + + # Step 2: Create coarse restart + coarse_output = Path(output_dir) / f"generated_restart_{generation_name}_coarse.nc" + create_restart_from_predictions( + coarse_restart_template, + coarse_mesh_mask, + coarse_namelist, + temp, + sal, + ssh, + str(coarse_output), + ) + + # Step 3: Regrid to fine resolution + fine_output = Path(output_dir) / f"generated_restart_{generation_name}_fine.nc" + regrid_restart( + str(coarse_output), + fine_restart_template, + coarse_mesh_mask, + fine_mesh_mask, + str(fine_output), + ) + + return str(fine_output) diff --git a/src/nemo_spinup_restart/regrid_cli.py b/src/nemo_spinup_restart/regrid_cli.py new file mode 100644 index 0000000..333311b --- /dev/null +++ b/src/nemo_spinup_restart/regrid_cli.py @@ -0,0 +1,190 @@ +""" +CLI for upscaling NEMO restart files from coarse to fine resolution. +""" + +import argparse +from pathlib import Path + +from nemo_spinup_restart.regrid import ( + regrid_restart, + upscale_predictions, +) + + +def main(): + """Main CLI entry point for nemo-upscale.""" + parser = argparse.ArgumentParser( + description="Upscale NEMO restart files from coarse to fine resolution", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" + Examples: + # Complete upscaling workflow (predictions -> coarse -> fine) + nemo-upscale upscale \\ + --predictions-dir ./predictions \\ + --coarse-template ./1deg/restart_template.nc \\ + --coarse-mask ./1deg/mesh_mask.nc \\ + --coarse-namelist ./1deg/namelist_cfg \\ + --fine-template ./025deg/restart_template.nc \\ + --fine-mask ./025deg/mesh_mask.nc \\ + --output-dir ./generated \\ + --name C2 + + # Regrid existing coarse restart to fine resolution + nemo-upscale regrid \\ + --coarse-restart ./generated/restart_coarse.nc \\ + --fine-template ./025deg/restart_template.nc \\ + --coarse-mask ./1deg/mesh_mask.nc \\ + --fine-mask ./025deg/mesh_mask.nc \\ + --output ./generated/restart_fine.nc + """, + ) + + subparsers = parser.add_subparsers(dest="command", help="Available commands") + + # Subcommand: upscale (complete workflow) + upscale_parser = subparsers.add_parser( + "upscale", + help="Complete workflow: numpy predictions -> coarse restart -> fine restart", + ) + upscale_parser.add_argument( + "--predictions-dir", + type=str, + required=True, + help="Directory containing pred_thetao.npy, pred_so.npy, pred_zos.npy", + ) + upscale_parser.add_argument( + "--coarse-template", + type=str, + required=True, + help="Template restart file at coarse resolution", + ) + upscale_parser.add_argument( + "--coarse-mask", + type=str, + required=True, + help="Mesh mask file at coarse resolution", + ) + upscale_parser.add_argument( + "--coarse-namelist", + type=str, + required=True, + help="NEMO namelist at coarse resolution (for density calculation)", + ) + upscale_parser.add_argument( + "--fine-template", + type=str, + required=True, + help="Template restart file at fine resolution", + ) + upscale_parser.add_argument( + "--fine-mask", type=str, required=True, help="Mesh mask file at fine resolution" + ) + upscale_parser.add_argument( + "--output-dir", + type=str, + required=True, + help="Directory to save generated restart files", + ) + upscale_parser.add_argument( + "--name", + type=str, + required=True, + help="Identifier for this generation (e.g., 'C2')", + ) + upscale_parser.add_argument( + "--time-index", + type=int, + default=-1, + help="Time index to extract from prediction arrays (default: -1, last timestep)", + ) + + # Subcommand: regrid (coarse -> fine only) + regrid_parser = subparsers.add_parser( + "regrid", help="Regrid existing restart file to fine resolution" + ) + regrid_parser.add_argument( + "--coarse-restart", + type=str, + required=True, + help="Coarse resolution restart file to regrid", + ) + regrid_parser.add_argument( + "--fine-template", + type=str, + required=True, + help="Template restart file at fine resolution", + ) + regrid_parser.add_argument( + "--coarse-mask", + type=str, + required=True, + help="Mesh mask file at coarse resolution", + ) + regrid_parser.add_argument( + "--fine-mask", type=str, required=True, help="Mesh mask file at fine resolution" + ) + regrid_parser.add_argument( + "--output", type=str, required=True, help="Path to save regridded restart file" + ) + + args = parser.parse_args() + + if args.command is None: + parser.print_help() + return + + # Execute the appropriate command + if args.command == "upscale": + print("Starting complete upscaling workflow...") + print(f"Predictions: {args.predictions_dir}") + print(f"Coarse: {args.coarse_template}") + print(f"Fine: {args.fine_template}") + print(f"Output: {args.output_dir}") + + import numpy as np + + # Verify predictions exist + pred_dir = Path(args.predictions_dir) + if not (pred_dir / "toce.npy").exists(): + raise FileNotFoundError( + f"toce.npy not found in {args.predictions_dir}" + ) + if not (pred_dir / "soce.npy").exists(): + raise FileNotFoundError(f"soce.npy not found in {args.predictions_dir}") + if not (pred_dir / "ssh.npy").exists(): + raise FileNotFoundError(f"ssh.npy not found in {args.predictions_dir}") + + output_file = upscale_predictions( + args.predictions_dir, + args.coarse_template, + args.coarse_mask, + args.coarse_namelist, + args.fine_template, + args.fine_mask, + args.output_dir, + args.name, + args.time_index, + ) + + print(f"\n✓ Upscaling complete!") + print(f" Fine resolution restart: {output_file}") + + elif args.command == "regrid": + print("Regridding restart to fine resolution...") + print(f"Coarse: {args.coarse_restart}") + print(f"Fine template: {args.fine_template}") + print(f"Output: {args.output}") + + regrid_restart( + args.coarse_restart, + args.fine_template, + args.coarse_mask, + args.fine_mask, + args.output, + ) + + print(f"\n✓ Regridding complete: {args.output}") + + +if __name__ == "__main__": + main()