diff --git a/.edf/README.md b/.edf/README.md
new file mode 100644
index 0000000..2316e95
--- /dev/null
+++ b/.edf/README.md
@@ -0,0 +1,31 @@
+run:
+```
+export EDF_PATH=`pwd`/.edf
+```
+This adds the repository's .edf directory to the EDF search path.
+
+run:
+```
+srun -A a-a122 --environment=ubuntu2 cat /etc/os-release
+```
+
+# local development: interactive shell in the hirad-ci environment
+srun --environment $PWD/.edf/hirad-ci.toml -A a-a122 -p debug --pty bash
+
+
+
+# list current images
+podman images
+
+# build the Dockerfile into an image tagged tmpv1, using the current directory as context
+podman build -f ci/docker/Dockerfile -t tmpv1 .
+
+# run the built image interactively
+podman run -it localhost/tmpv1
+
+mkdir /capstor/scratch/cscs/mmcgloho/images
+
+# export the image into a sqsh file so it is available outside the interactive shell
+enroot import -x mount -o /capstor/scratch/cscs/mmcgloho/images/hirad-pytorch-25.01-py3.sqsh podman://localhost/tmpv1
+
+ls /capstor/scratch/cscs/mmcgloho/images
\ No newline at end of file
diff --git a/.edf/gemma-pytorch.toml b/.edf/gemma-pytorch.toml
new file mode 100644
index 0000000..3c4723a
--- /dev/null
+++ b/.edf/gemma-pytorch.toml
@@ -0,0 +1,14 @@
+image = "/iopsstor/scratch/cscs/${USER}/pytorch-24.01-py3-venv/pytorch-24.01-py3-venv.sqsh"
+
+mounts = ["/capstor", "/users","/iopsstor/scratch/cscs/mmcgloho"]
+
+writable = true
+
+[annotations]
+com.hooks.aws_ofi_nccl.enabled = "true"
+com.hooks.aws_ofi_nccl.variant = "cuda12"
+
+[env]
+FI_CXI_DISABLE_HOST_REGISTER = "1"
+FI_MR_CACHE_MONITOR = "userfaultfd"
+NCCL_DEBUG = "INFO"
diff --git a/.edf/hirad-ci.toml b/.edf/hirad-ci.toml
new file mode 100644
index 0000000..2ff1c6b
--- /dev/null
+++ b/.edf/hirad-ci.toml
@@ -0,0 +1,14 @@
+image = "/capstor/scratch/cscs/${USER}/images/hirad-pytorch-25.01-py3.sqsh"
+
+mounts = ["/capstor","/iopsstor"]
+
+writable = true
+
+[annotations]
+com.hooks.aws_ofi_nccl.enabled = "true"
+com.hooks.aws_ofi_nccl.variant = "cuda12"
+
+[env]
+FI_CXI_DISABLE_HOST_REGISTER = "1"
+FI_MR_CACHE_MONITOR = "userfaultfd"
+NCCL_DEBUG = "INFO"
diff --git a/.edf/ngc-pytorch.toml b/.edf/ngc-pytorch.toml
new file mode 100644
index 0000000..4f790ca
--- /dev/null
+++ b/.edf/ngc-pytorch.toml
@@ -0,0 +1,6 @@
+# https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch
+image = "nvcr.io#nvidia/pytorch:22.01-py3"
+mounts = ["/capstor/scratch/cscs/${USER}:/capstor/scratch/cscs/${USER}"]
+workdir = "/capstor/scratch/cscs/${USER}"
+
+# Maybe above should be iopsstor
\ No newline at end of file
diff --git a/.edf/ubuntu.toml b/.edf/ubuntu.toml
new file mode 100644
index 0000000..4f14ee9
--- /dev/null
+++ b/.edf/ubuntu.toml
@@ -0,0 +1,3 @@
+image = "library/ubuntu:24.04"
+mounts = ["/capstor/scratch/cscs/mmcgloho:/capstor/scratch/cscs/mmcgloho"]
+workdir = "/capstor/scratch/cscs/mmcgloho"
diff --git a/.edf/ubuntu2.toml b/.edf/ubuntu2.toml
new file mode 100644
index 0000000..22dba65
--- /dev/null
+++ b/.edf/ubuntu2.toml
@@ -0,0 +1,5 @@
+
+
+image = "library/ubuntu:24.04"
+mounts = ["/capstor/scratch/cscs/${USER}:/capstor/scratch/cscs/${USER}"]
+workdir = "/capstor/scratch/cscs/${USER}"
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..c7d62e7
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,188 @@
+### Python ###
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files
are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + +# ruff +.ruff_cache/ + +# LSP config files +pyrightconfig.json + +# output files +*.out +*.torch +plots/* +*.npz + +# conda +.conda/* + +# temp +temp.* + +# local script +interpolate.sh +core_clariden-ln002_241188 + diff --git a/README.md b/README.md new file mode 100644 index 0000000..b0dbd2e --- /dev/null +++ b/README.md @@ -0,0 +1,167 @@ +# HiRAD-Gen + +HiRAD-Gen is short for high-resolution atmospheric downscaling using generative models. This repository contains the code and configuration required to train and use the model. 
+
+## Installation (Alps)
+
+To set up the environment for **HiRAD-Gen** on the Alps supercomputer, follow these steps:
+
+1. **Start the PyTorch user environment**:
+   ```bash
+   uenv start pytorch/v2.6.0:v1 --view=default
+   ```
+
+2. **Create a Python virtual environment** (replace `{env_name}` with your desired environment name):
+   ```bash
+   python -m venv ./{env_name}
+   ```
+
+3. **Activate the virtual environment**:
+   ```bash
+   source ./{env_name}/bin/activate
+   ```
+
+4. **Install project dependencies**:
+   ```bash
+   pip install -e .
+   ```
+
+This sets up the environment needed to run HiRAD-Gen on the Alps infrastructure.
+
+## Training
+
+### Run regression model training (Alps)
+
+1. The script for running regression model training is `src/hirad/train_regression.sh`.
+Inside this script, set the following:
+```bash
+### OUTPUT ###
+#SBATCH --output=your_path_to_output_log
+#SBATCH --error=your_path_to_output_error
+```
+```bash
+#SBATCH -A your_compute_group
+```
+```bash
+srun bash -c "
+    . ./{your_env_name}/bin/activate
+    python src/hirad/training/train.py --config-name=training_era_cosmo_regression.yaml
+"
+```
+
+2. Set up the following config files in `src/hirad/conf`:
+
+- In `training_era_cosmo_regression.yaml` set:
+```
+hydra:
+  run:
+    dir: your_path_to_save_training_output
+```
+- In `training/era_cosmo_regression.yaml` set:
+```
+hp:
+  training_duration: number of samples to train for (set to 4 for debugging; 512 fits into 30 minutes on 1 GPU with total_batch_size: 4)
+```
+- In `dataset/era_cosmo.yaml` set the `dataset_path` if different from default.
+
+3. Submit the job with:
+```bash
+sbatch src/hirad/train_regression.sh
+```
+
+### Run diffusion model training (Alps)
+Before training the diffusion model, a regression model checkpoint has to exist.
+
+1. The script for running diffusion model training is `src/hirad/train_diffusion.sh`.
+Inside this script, set the following:
+```bash
+### OUTPUT ###
+#SBATCH --output=your_path_to_output_log
+#SBATCH --error=your_path_to_output_error
+```
+```bash
+#SBATCH -A your_compute_group
+```
+```bash
+srun bash -c "
+    . ./{your_env_name}/bin/activate
+    python src/hirad/training/train.py --config-name=training_era_cosmo_diffusion.yaml
+"
+```
+
+2. Set up the following config files in `src/hirad/conf`:
+
+- In `training_era_cosmo_diffusion.yaml` set:
+```
+hydra:
+  run:
+    dir: your_path_to_save_training_output
+```
+- In `training/era_cosmo_diffusion.yaml` set:
+```
+hp:
+  training_duration: number of samples to train for (set to 4 for debugging; 512 fits into 30 minutes on 1 GPU with total_batch_size: 4)
+io:
+  regression_checkpoint_path: path_to_directory_containing_regression_training_model_checkpoints
+```
+- In `dataset/era_cosmo.yaml` set the `dataset_path` if different from default.
+
+3. Submit the job with:
+```bash
+sbatch src/hirad/train_diffusion.sh
+```
+
+## Inference
+
+### Running inference on Alps
+
+1. The script for running inference is `src/hirad/generate.sh`.
+Inside this script, set the following:
+```bash
+### OUTPUT ###
+#SBATCH --output=your_path_to_output_log
+#SBATCH --error=your_path_to_output_error
+```
+```bash
+#SBATCH -A your_compute_group
+```
+```bash
+srun bash -c "
+    . ./{your_env_name}/bin/activate
+    python src/hirad/inference/generate.py --config-name=generate_era_cosmo.yaml
+"
+```
+
+2. Set up the following config files in `src/hirad/conf`:
+
+- In `generate_era_cosmo.yaml` set:
+```
+hydra:
+  run:
+    dir: your_path_to_save_inference_output
+```
+- In `generation/era_cosmo.yaml` choose the inference mode:
+```
+inference_mode: all/regression/diffusion
+```
+By default, `all` runs both regression and diffusion. Depending on the mode, pretrained weights for the regression and/or diffusion model have to be provided:
+```
+io:
+  res_ckpt_path: path_to_directory_containing_diffusion_training_model_checkpoints
+  reg_ckpt_path: path_to_directory_containing_regression_training_model_checkpoints
+```
+Finally, a subset of the dataset's time steps can be selected for inference (see the example after this list).
+
+One way is to list the steps under `times:` in the format `%Y%m%d-%H%M` (for the era5_cosmo dataset).
+
+The other way is to specify `times_range:` with three items: the first time step (`%Y%m%d-%H%M`), the last time step (`%Y%m%d-%H%M`), and the hour shift (int). The hour shift is the distance in hours between consecutive time steps of the dataset (6 for era_cosmo).
+
+By default, inference is done for a single time step, `20160101-0000`.
+
+- In `dataset/era_cosmo.yaml` set the `dataset_path` if different from default.
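+
+For example (illustrative values only; `times_range` is written as the three-item list described above):
+```
+# run two specific time steps
+times:
+  - 20160101-0000
+  - 20160101-0600
+```
+or
+```
+# run every 6-hourly step from the first to the last entry
+times_range:
+  - 20160101-0000
+  - 20160103-1800
+  - 6
+```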
+
+3. Submit the job with:
+```bash
+sbatch src/hirad/generate.sh
+```
\ No newline at end of file
diff --git a/ci/cscs.yml b/ci/cscs.yml
index fc92645..85579a0 100644
--- a/ci/cscs.yml
+++ b/ci/cscs.yml
@@ -12,14 +12,24 @@ build_job:
   stage: build
   extends: .container-builder-cscs-gh200
   variables:
-    DOCKERFILE: ci/docker/Dockerfile
+    DOCKERFILE: ci/docker/Dockerfile.ci
+
+test_job:
+  stage: test
+  extends: .container-runner-clariden-gh200
+  image: $PERSIST_IMAGE_NAME
+  script:
+    - env
+    - cd /src
+    - python src/hirad/training/train.py --config-name=training_era_cosmo_regression.yaml
+    #- python src/hirad/eval/run_scoring.py
+    #- MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) MASTER_PORT=29500 RANK=${SLURM_PROCID} LOCAL_RANK=${SLURM_LOCALID} WORLD_SIZE=${SLURM_NPROCS} python -c "import os, torch; import torch.distributed as dist; local_rank = int(os.environ['LOCAL_RANK']); torch.cuda.set_device(local_rank); dist.init_process_group('nccl', init_method='env://'); rank = dist.get_rank(); print(f'Hello from rank {rank}'); t = torch.tensor([rank]).to('cuda'); dist.all_reduce(t); print(f'The sum of ranks is {t}.'); dist.destroy_process_group()"
+  variables:
+    SLURM_JOB_NUM_NODES: 2
+    SLURM_NTASKS: 4
+    USE_NCCL: cuda12
+    FI_CXI_DISABLE_HOST_REGISTER: 1
+    FI_MR_CACHE_MONITOR: userfaultfd
+    NCCL_DEBUG: INFO
+
 
-#test_job:
-#  stage: test
-#  extends: .container-runner-clariden-gh200
-#  image: $PERSIST_IMAGE_NAME
-#  script:
-#    - /opt/helloworld/bin/hello
-#  variables:
-#    SLURM_JOB_NUM_NODES: 2
-#    SLURM_NTASKS: 2
diff --git a/ci/docker/Dockerfile b/ci/docker/Dockerfile
deleted file mode 100644
index 4772d76..0000000
--- a/ci/docker/Dockerfile
+++ /dev/null
@@ -1,41 +0,0 @@
-# Following some suggestions in https://meteoswiss.atlassian.net/wiki/spaces/APN/pages/719684202/Clariden+Alps+environment+setup
-
-#FROM ubuntu:22.04 as builder
-FROM nvcr.io/nvidia/pytorch:25.01-py3
-
-COPY .
/src - -# setup -RUN apt-get update && apt-get install python3-pip python3-venv -y -RUN pip install --upgrade \ - pip - #ninja - #wheel - #packaging - #setuptools - -# update flash-attn -RUN MAX_JOBS=16 pip install --upgrade --no-build-isolation \ - flash-attn==2.7.4.post1 -v - -# install the rest of dependencies -# TODO: Factor pydeps into a separate file(s) -# TODO: Add versions for things -RUN pip install \ - anemoi-datasets \ - cartopy \ - matplotlib \ - numpy \ - pandas \ - scipy \ - torch - - -# replace pynvml with nvidia-ml-py -RUN pip uninstall -y pynvml && pip install nvidia-ml-py - -#CMD ["python3.11" "src/input_data/interpolate_basic_test.py"] - - - - diff --git a/ci/docker/Dockerfile.ci b/ci/docker/Dockerfile.ci new file mode 100644 index 0000000..3378585 --- /dev/null +++ b/ci/docker/Dockerfile.ci @@ -0,0 +1,49 @@ +# Following some suggestions in https://meteoswiss.atlassian.net/wiki/spaces/APN/pages/719684202/Clariden+Alps+environment+setup + +#FROM ubuntu:22.04 as builder +#FROM nvcr.io/nvidia/pytorch:25.01-py3 +FROM nvcr.io/nvidia/physicsnemo/physicsnemo:25.03 + +# setup +#RUN apt-get update && apt-get install python3-pip python3-venv -y +RUN pip install --upgrade \ + pip + #ninja + #wheel + #packaging + #setuptools + +# update flash-attn +#RUN MAX_JOBS=16 pip install --upgrade --no-build-isolation \ +# flash-attn==2.7.4.post1 -v + +# install the rest of dependencies +# TODO: Factor pydeps into a separate file(s) +# TODO: Add versions for things +RUN pip install \ + anemoi-datasets \ + cartopy + #matplotlib \ + #numpy \ + #pandas \ + #scipy \ + #torch + + +# replace pynvml with nvidia-ml-py +#RUN pip uninstall -y pynvml && pip install nvidia-ml-py + +#CMD ["python3.11" "src/input_data/interpolate_basic_test.py"] + +COPY . /src +WORKDIR /src + +ENV NCCL_TESTS_VERSION=2.15.0 + +RUN wget -O nccl-tests-${NCCL_TESTS_VERSION}.tar.gz https://github.com/NVIDIA/nccl-tests/archive/refs/tags/v${NCCL_TESTS_VERSION}.tar.gz \ + && tar xf nccl-tests-${NCCL_TESTS_VERSION}.tar.gz \ + && cd nccl-tests-${NCCL_TESTS_VERSION} \ + && MPI=1 MPI_HOME=/opt/hpcx/ompi make -j$(nproc) \ + && cd .. \ + && rm -rf nccl-tests-${NCCL_TESTS_VERSION}.tar.gz + diff --git a/ci/docker/Dockerfile.dev b/ci/docker/Dockerfile.dev new file mode 100644 index 0000000..a37e84d --- /dev/null +++ b/ci/docker/Dockerfile.dev @@ -0,0 +1,22 @@ +FROM nvcr.io/nvidia/pytorch:25.01-py3 + +# setup +#RUN apt-get update && apt-get install python3-pip python3-venv -y +RUN pip install --upgrade \ + pip + +# install the rest of dependencies +# TODO: Factor pydeps into a separate file(s) +# TODO: Add versions for things +RUN pip install \ + anemoi-datasets \ + cartopy + + +COPY . /src +WORKDIR /src + +# Useful utilities for performance monitoring/analysis/debugging +RUN apt-get update \ + && apt-get install -yqq --no-install-recommends strace valgrind htop nvtop atop ioping fio \ + && rm -rf /var/lib/apt/lists/* \ No newline at end of file diff --git a/config/containers/storage.conf b/config/containers/storage.conf new file mode 100644 index 0000000..740925d --- /dev/null +++ b/config/containers/storage.conf @@ -0,0 +1,17 @@ +# https://confluence.cscs.ch/spaces/KB/pages/868834153/Building+container+images+on+Alps +# TOML format + +[storage] +driver = "overlay" +runroot = "/dev/shm/$USER/runroot" +graphroot = "/dev/shm/$USER/root" + +[storage.options.overlay] +mount_program = "/usr/bin/fuse-overlayfs-1.13" + +# In the above configuration, /dev/shm is used to store the container images. 
+# /dev/shm is the mount point of a tmpfs filesystem and is compatible with the
+# user namespaces used by Podman. The limitation of this approach is that
+# container images created during a job allocation are deleted when the job
+# ends. Therefore, the image needs to either be pushed to a container registry
+# or imported by the Container Engine before the job allocation finishes.
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..1477899
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,32 @@
+[build-system]
+requires = ["setuptools>=61.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "hirad-gen"
+version = "0.1.0"
+description = "High resolution atmospheric downscaling using generative machine learning"
+authors = [
+    { name="Petar Stamenkovic", email="petar.stamenkovic@meteoswiss.ch" }
+]
+readme = "README.md"
+requires-python = ">=3.12"
+license = {file = "LICENSE"}
+
+dependencies = [
+    "cartopy>=0.24.1",
+    "cftime>=1.6.4",
+    "hydra-core>=1.3.2",
+    "matplotlib>=3.10.1",
+    "omegaconf>=2.3.0",
+    "tensorboard>=2.19.0",
+    "termcolor>=3.1.0",
+    "torchinfo>=1.8.0",
+    "treelib>=1.7.1"
+]
+
+[tool.setuptools]
+package-dir = {"" = "src"}
+
+[tool.setuptools.packages.find]
+where = ["src"]
\ No newline at end of file
diff --git a/src/hirad/conf/dataset/era_cosmo.yaml b/src/hirad/conf/dataset/era_cosmo.yaml
new file mode 100644
index 0000000..63d7361
--- /dev/null
+++ b/src/hirad/conf/dataset/era_cosmo.yaml
@@ -0,0 +1,2 @@
+type: era5_cosmo
+dataset_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/trim_19_overfit
\ No newline at end of file
diff --git a/src/hirad/conf/generate_era_cosmo.yaml b/src/hirad/conf/generate_era_cosmo.yaml
new file mode 100644
index 0000000..5d7649d
--- /dev/null
+++ b/src/hirad/conf/generate_era_cosmo.yaml
@@ -0,0 +1,20 @@
+hydra:
+  job:
+    chdir: true
+    name: generation_full
+  run:
+    dir: /iopsstor/scratch/cscs/pstamenk/outputs/${hydra:job.name}
+
+# Get defaults
+defaults:
+  - _self_
+  # Dataset
+  - dataset/era_cosmo
+
+  # Sampler
+  - sampler/stochastic
+  #- sampler/deterministic
+
+  # Generation
+  - generation/era_cosmo
+  #- generation/patched_based
\ No newline at end of file
diff --git a/src/hirad/conf/generation/era_cosmo.yaml b/src/hirad/conf/generation/era_cosmo.yaml
new file mode 100644
index 0000000..be4219d
--- /dev/null
+++ b/src/hirad/conf/generation/era_cosmo.yaml
@@ -0,0 +1,44 @@
+num_ensembles: 8
+  # Number of ensembles to generate per input
+seed_batch_size: 4
+  # Size of the batched inference
+inference_mode: all
+  # Choose between "all" (regression + diffusion), "regression" or "diffusion"
+  # Patch size. Patch-based sampling will be utilized if these dimensions differ from
+  # img_shape_x and img_shape_y
+# overlap_pixels: 0
+  # Number of overlapping pixels between adjacent patches
+# boundary_pixels: 0
+  # Number of boundary pixels to be cropped out. 2 is recommended to address the boundary
+  # artifact.
+patching: False
+hr_mean_conditioning: True
+# sample_res: full
+  # Sampling resolution
+times_range: null
+times:
+  - 20160101-0000
+  # - 20160101-0600
+  # - 20160101-1200
+has_laed_time: False
+
+perf:
+  force_fp16: False
+    # Whether to force fp16 precision for the model. If false, it'll use the precision
+    # specified upon training.
+ use_torch_compile: False + # whether to use torch.compile on the diffusion model + # this will make the first time stamp generation very slow due to compilation overheads + # but will significantly speed up subsequent inference runs + num_writer_workers: 1 + # number of workers to use for writing file + # To support multiple workers a threadsafe version of the netCDF library must be used + +io: + res_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/diffusion_refactoring/checkpoints_diffusion + # res_ckpt_path: null + # Checkpoint filename for the diffusion model + reg_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_refactoring/checkpoints_regression + # reg_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_test/checkpoints_regression + # Checkpoint filename for the mean predictor model + output_path: ./images \ No newline at end of file diff --git a/src/hirad/conf/model/era_cosmo_diffusion.yaml b/src/hirad/conf/model/era_cosmo_diffusion.yaml new file mode 100644 index 0000000..441239e --- /dev/null +++ b/src/hirad/conf/model/era_cosmo_diffusion.yaml @@ -0,0 +1,15 @@ +name: diffusion + # Name of the preconditioner +hr_mean_conditioning: True + # High-res mean (regression's output) as additional condition + +# Standard model parameters. +model_args: + gridtype: "sinusoidal" + # Type of positional grid to use: 'sinusoidal', 'learnable', 'linear'. + # Controls how positional information is encoded. + N_grid_channels: 4 + # Number of channels for positional grid embeddings + embedding_type: "zero" + # Type of timestep embedding: 'positional' for DDPM++, 'fourier' for NCSN++, + # 'zero' for none \ No newline at end of file diff --git a/src/hirad/conf/model/era_cosmo_regression.yaml b/src/hirad/conf/model/era_cosmo_regression.yaml new file mode 100644 index 0000000..29b43e8 --- /dev/null +++ b/src/hirad/conf/model/era_cosmo_regression.yaml @@ -0,0 +1,10 @@ +name: regression +hr_mean_conditioning: False + +# Default regression model parameters. Do not modify. +model_args: + "N_grid_channels": 4 + # Number of channels for positional grid embeddings + "embedding_type": "zero" + # Type of timestep embedding: 'positional' for DDPM++, 'fourier' for NCSN++, + # 'zero' for none \ No newline at end of file diff --git a/src/hirad/conf/model_size/mini.yaml b/src/hirad/conf/model_size/mini.yaml new file mode 100644 index 0000000..2eb8f8a --- /dev/null +++ b/src/hirad/conf/model_size/mini.yaml @@ -0,0 +1,26 @@ +# @package _global_.model + +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +model_args: + # Base multiplier for the number of channels across the network. + model_channels: 64 + # Per-resolution multipliers for the number of channels. + channel_mult: [1, 2, 2] + # Resolutions at which self-attention layers are applied. 
+ attn_resolutions: [16] \ No newline at end of file diff --git a/src/hirad/conf/model_size/normal.yaml b/src/hirad/conf/model_size/normal.yaml new file mode 100644 index 0000000..b81fe15 --- /dev/null +++ b/src/hirad/conf/model_size/normal.yaml @@ -0,0 +1,26 @@ +# @package _global_.model + +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +model_args: + # Base multiplier for the number of channels across the network. + model_channels: 128 + # Per-resolution multipliers for the number of channels. + channel_mult: [1, 2, 2, 2, 2] + # Resolutions at which self-attention layers are applied. + attn_resolutions: [28] \ No newline at end of file diff --git a/src/hirad/conf/sampler/deterministic.yaml b/src/hirad/conf/sampler/deterministic.yaml new file mode 100644 index 0000000..35bc0f6 --- /dev/null +++ b/src/hirad/conf/sampler/deterministic.yaml @@ -0,0 +1,4 @@ +type: deterministic +num_steps: 9 + # Number of denoising steps +solver: euler \ No newline at end of file diff --git a/src/hirad/conf/sampler/stochastic.yaml b/src/hirad/conf/sampler/stochastic.yaml new file mode 100644 index 0000000..2481cd3 --- /dev/null +++ b/src/hirad/conf/sampler/stochastic.yaml @@ -0,0 +1,3 @@ +type: stochastic +# boundary_pix: 2 +# overlap_pix: 4 \ No newline at end of file diff --git a/src/hirad/conf/training/era_cosmo_diffusion.yaml b/src/hirad/conf/training/era_cosmo_diffusion.yaml new file mode 100644 index 0000000..f8d19e6 --- /dev/null +++ b/src/hirad/conf/training/era_cosmo_diffusion.yaml @@ -0,0 +1,41 @@ +# Hyperparameters +hp: + training_duration: 16 + # Training duration based on the number of processed samples + total_batch_size: 4 + # Total batch size + batch_size_per_gpu: "auto" + # Batch size per GPU + lr: 0.0002 + # Learning rate + grad_clip_threshold: null + # no gradient clipping for defualt non-patch-based training + lr_decay: 1 + # LR decay rate + lr_rampup: 0 + # Rampup for learning rate, in number of samples + +# Performance +perf: + fp_optimizations: amp-bf16 + # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"] + # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16} + dataloader_workers: 8 + # DataLoader worker processes + songunet_checkpoint_level: 0 # 0 means no checkpointing + # Gradient checkpointing level, value is number of layers to checkpoint + +# I/O +io: + regression_checkpoint_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_overfit/checkpoints_regression + # Where to load the regression checkpoint + print_progress_freq: 128 + # How often to print progress + save_checkpoint_freq: 5000 + # How often to save the checkpoints, measured in number of processed samples + validation_freq: 5000 + # how often to record the validation loss, measured in number of processed samples + validation_steps: 10 + # how many loss evaluations are used to compute the 
validation loss per checkpoint
+    checkpoint_dir: .
\ No newline at end of file
diff --git a/src/hirad/conf/training/era_cosmo_regression.yaml b/src/hirad/conf/training/era_cosmo_regression.yaml
new file mode 100644
index 0000000..76bdc4e
--- /dev/null
+++ b/src/hirad/conf/training/era_cosmo_regression.yaml
@@ -0,0 +1,39 @@
+# Hyperparameters
+hp:
+  training_duration: 8
+  # Training duration based on the number of processed samples
+  total_batch_size: 4
+  # Total batch size
+  batch_size_per_gpu: "auto"
+  # Batch size per GPU
+  lr: 0.001
+  #0.0002
+  # Learning rate
+  grad_clip_threshold: null
+  # no gradient clipping for default non-patch-based training
+  lr_decay: 1
+  # LR decay rate
+  lr_rampup: 0
+  # Rampup for learning rate, in number of samples
+
+# Performance
+perf:
+  fp_optimizations: amp-bf16
+  # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"]
+  # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16}
+  dataloader_workers: 4
+  # DataLoader worker processes
+  songunet_checkpoint_level: 0 # 0 means no checkpointing
+  # Gradient checkpointing level, value is number of layers to checkpoint
+
+# I/O
+io:
+  print_progress_freq: 128
+  # How often to print progress
+  save_checkpoint_freq: 5000
+  # How often to save the checkpoints, measured in number of processed samples
+  validation_freq: 5000
+  # how often to record the validation loss, measured in number of processed samples
+  validation_steps: 10
+  # how many loss evaluations are used to compute the validation loss per checkpoint
+  checkpoint_dir: .
\ No newline at end of file
diff --git a/src/hirad/conf/training_era_cosmo_diffusion.yaml b/src/hirad/conf/training_era_cosmo_diffusion.yaml
new file mode 100644
index 0000000..0a069e9
--- /dev/null
+++ b/src/hirad/conf/training_era_cosmo_diffusion.yaml
@@ -0,0 +1,21 @@
+hydra:
+  job:
+    chdir: true
+    name: diffusion
+  run:
+    dir: /iopsstor/scratch/cscs/pstamenk/outputs/${hydra:job.name}
+
+# Get defaults
+defaults:
+  - _self_
+
+  # Dataset
+  - dataset/era_cosmo
+
+  # Model
+  - model/era_cosmo_diffusion
+
+  - model_size/normal
+
+  # Training
+  - training/era_cosmo_diffusion
\ No newline at end of file
diff --git a/src/hirad/conf/training_era_cosmo_regression.yaml b/src/hirad/conf/training_era_cosmo_regression.yaml
new file mode 100644
index 0000000..1de83d9
--- /dev/null
+++ b/src/hirad/conf/training_era_cosmo_regression.yaml
@@ -0,0 +1,21 @@
+hydra:
+  job:
+    chdir: true
+    name: regression
+  run:
+    dir: /iopsstor/scratch/cscs/pstamenk/outputs/${hydra:job.name}
+
+# Get defaults
+defaults:
+  - _self_
+
+  # Dataset
+  - dataset/era_cosmo
+
+  # Model
+  - model/era_cosmo_regression
+
+  - model_size/normal
+
+  # Training
+  - training/era_cosmo_regression
\ No newline at end of file
diff --git a/src/hirad/datasets/__init__.py b/src/hirad/datasets/__init__.py
new file mode 100644
index 0000000..706284e
--- /dev/null
+++ b/src/hirad/datasets/__init__.py
@@ -0,0 +1,3 @@
+from .dataset import init_train_valid_datasets_from_config, init_dataset_from_config
+from .era5_cosmo import ERA5_COSMO
+from .base import DownscalingDataset
\ No newline at end of file
diff --git a/src/hirad/datasets/base.py b/src/hirad/datasets/base.py
new file mode 100644
index 0000000..22b00d2
--- /dev/null
+++ b/src/hirad/datasets/base.py
@@ -0,0 +1,85 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import List, Tuple + +import numpy as np +import torch + + +@dataclass +class ChannelMetadata: + """Metadata describing a data channel.""" + + name: str + level: str = "" + auxiliary: bool = False + + +class DownscalingDataset(torch.utils.data.Dataset, ABC): + """An abstract class that defines the interface for downscaling datasets.""" + + @abstractmethod + def longitude(self) -> np.ndarray: + """Get longitude values from the dataset.""" + pass + + @abstractmethod + def latitude(self) -> np.ndarray: + """Get latitude values from the dataset.""" + pass + + @abstractmethod + def input_channels(self) -> List[ChannelMetadata]: + """Metadata for the input channels. A list of ChannelMetadata, one for each channel""" + pass + + @abstractmethod + def output_channels(self) -> List[ChannelMetadata]: + """Metadata for the output channels. A list of ChannelMetadata, one for each channel""" + pass + + @abstractmethod + def time(self) -> List: + """Get time values from the dataset.""" + pass + + @abstractmethod + def image_shape(self) -> Tuple[int, int]: + """Get the (height, width) of the data (same for input and output).""" + pass + + def normalize_input(self, x: np.ndarray) -> np.ndarray: + """Convert input from physical units to normalized data.""" + return x + + def denormalize_input(self, x: np.ndarray) -> np.ndarray: + """Convert input from normalized data to physical units.""" + return x + + def normalize_output(self, x: np.ndarray) -> np.ndarray: + """Convert output from physical units to normalized data.""" + return x + + def denormalize_output(self, x: np.ndarray) -> np.ndarray: + """Convert output from normalized data to physical units.""" + return x + + def info(self) -> dict: + """Get information about the dataset.""" + return {} diff --git a/src/hirad/datasets/dataset.py b/src/hirad/datasets/dataset.py new file mode 100644 index 0000000..7ba8833 --- /dev/null +++ b/src/hirad/datasets/dataset.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
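+# This module maps dataset type names (e.g. "era5_cosmo") to their dataset classes and
+# builds infinite, distributed-aware dataloader iterators from the Hydra dataset config.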
+
+from typing import Iterable, Tuple, Union
+import copy
+import torch
+
+from hirad.utils.function_utils import InfiniteSampler
+from hirad.distributed import DistributedManager
+
+from .era5_cosmo import ERA5_COSMO
+from .base import DownscalingDataset
+
+
+# this maps all known dataset types to the corresponding init function
+known_datasets = {
+    "era5_cosmo": ERA5_COSMO,
+}
+
+
+def init_train_valid_datasets_from_config(
+    dataset_cfg: dict,
+    dataloader_cfg: Union[dict, None] = None,
+    batch_size: int = 1,
+    seed: int = 0,
+    train_test_split: bool = True,
+) -> Tuple[
+    DownscalingDataset,
+    Iterable,
+    Union[DownscalingDataset, None],
+    Union[Iterable, None],
+]:
+    """
+    A wrapper function for managing the train-test split for the configured dataset.
+
+    Parameters:
+    - dataset_cfg (dict): Configuration for the dataset.
+    - dataloader_cfg (dict, optional): Configuration for the dataloader. Defaults to None.
+    - batch_size (int): The number of samples in each batch of data. Defaults to 1.
+    - seed (int): The random seed for dataset shuffling. Defaults to 0.
+    - train_test_split (bool): A flag to determine whether to create a validation dataset. Defaults to True.
+
+    Returns:
+    - Tuple[base.DownscalingDataset, Iterable, Optional[base.DownscalingDataset], Optional[Iterable]]: A tuple containing the training dataset and iterator, and optionally the validation dataset and iterator if train_test_split is True.
+    """
+
+    config = copy.deepcopy(dataset_cfg)
+    # drop the validation-only key before building the training dataset
+    if 'validation_path' in config:
+        del config['validation_path']
+    (dataset, dataset_iter) = init_dataset_from_config(
+        config, dataloader_cfg, batch_size=batch_size, seed=seed
+    )
+    if train_test_split:
+        valid_dataset_cfg = copy.deepcopy(dataset_cfg)
+        valid_dataset_cfg["dataset_path"] = valid_dataset_cfg["validation_path"]
+        del valid_dataset_cfg['validation_path']
+        (valid_dataset, valid_dataset_iter) = init_dataset_from_config(
+            valid_dataset_cfg, dataloader_cfg, batch_size=batch_size, seed=seed
+        )
+    else:
+        valid_dataset = valid_dataset_iter = None
+
+    return dataset, dataset_iter, valid_dataset, valid_dataset_iter
+
+
+def init_dataset_from_config(
+    dataset_cfg: dict,
+    dataloader_cfg: Union[dict, None] = None,
+    batch_size: int = 1,
+    seed: int = 0,
+) -> Tuple[DownscalingDataset, Iterable]:
+    """Initialize a dataset and an infinite dataloader iterator from a config dict."""
+    dataset_cfg = copy.deepcopy(dataset_cfg)
+    dataset_type = dataset_cfg.pop("type", "era5_cosmo")
+    if "validation_path" in dataset_cfg:
+        del dataset_cfg['validation_path']
+    if "train_test_split" in dataset_cfg:
+        # handled by init_train_valid_datasets_from_config
+        del dataset_cfg["train_test_split"]
+    dataset_init_func = known_datasets[dataset_type]
+
+    dataset_obj = dataset_init_func(**dataset_cfg)
+    if dataloader_cfg is None:
+        dataloader_cfg = {}
+
+    dist = DistributedManager()
+    dataset_sampler = InfiniteSampler(
+        dataset=dataset_obj, rank=dist.rank, num_replicas=dist.world_size, seed=seed
+    )
+
+    dataset_iterator = iter(
+        torch.utils.data.DataLoader(
+            dataset=dataset_obj,
+            sampler=dataset_sampler,
+            batch_size=batch_size,
+            worker_init_fn=None,
+            **dataloader_cfg,
+        )
+    )
+
+    return (dataset_obj, dataset_iterator)
diff --git a/src/hirad/datasets/era5_cosmo.py b/src/hirad/datasets/era5_cosmo.py
new file mode 100644
index 0000000..f97dbc6
--- /dev/null
+++ b/src/hirad/datasets/era5_cosmo.py
@@ -0,0 +1,130 @@
+from .base import DownscalingDataset, ChannelMetadata
+import os
+import numpy as np
+import torch
+from typing import List, Tuple
+import yaml
+import torch.nn.functional as F
+
+class ERA5_COSMO(DownscalingDataset):
+    def __init__(self, dataset_path: str):
+        super().__init__()
+
+        #TODO switch handling paths to Path rather than pure strings
+        self._dataset_path = dataset_path
+        self._era5_path = os.path.join(dataset_path, 'era-interpolated')
+        self._cosmo_path = os.path.join(dataset_path, 'cosmo')
+        self._info_path = os.path.join(dataset_path, 'info')
+
+        # load file list (each file is one date-time state)
+        self._file_list = os.listdir(self._cosmo_path)
+
+        # Load cosmo info and channel names
+        with open(os.path.join(self._info_path,'cosmo.yaml'), 'r') as file:
+            self._cosmo_info = yaml.safe_load(file)
+        self._cosmo_channels = [ChannelMetadata(name) for name in self._cosmo_info['select']]
+
+        # Load era5 info and channel names
+        with open(os.path.join(self._info_path,'era.yaml'), 'r') as file:
+            self._era_info = yaml.safe_load(file)
+        self._era_channels = [ChannelMetadata(name) if len(name.split('_'))==1
+                              else ChannelMetadata(name.split('_')[0],name.split('_')[1])
+                              for name in self._era_info['select']]
+
+        # Load stats for normalizing channels of input and output
+
+        cosmo_stats = torch.load(os.path.join(self._info_path,'cosmo-stats'), weights_only=False)
+        self.output_mean = cosmo_stats['mean']
+        self.output_std = cosmo_stats['stdev']
+
+        era_stats = torch.load(os.path.join(self._info_path,'era-stats'), weights_only=False)
+        self.input_mean = era_stats['mean']
+        self.input_std = era_stats['stdev']
+
+
+    def __getitem__(self, idx):
+        """Return the (cosmo, era5) pair: COSMO target and ERA5 input interpolated to the COSMO grid, both normalized."""
+        # get era5 data point
+        # squeeze the ensemble dimension
+        # reshape to image_shape
+        # flip so that it starts in top-left corner (by default it is bottom left)
+        # orig_shape = [350,542] #TODO currently padding to be divisible by 16
+        orig_shape = self.image_shape()
+        era5_data = np.flip(torch.load(os.path.join(self._era5_path,self._file_list[idx]), weights_only=False)\
+                    .squeeze() \
+                    .reshape(-1,*orig_shape),
+                    1)
+        era5_data = self.normalize_input(era5_data)
+        # get cosmo data point
+        cosmo_data = np.flip(torch.load(os.path.join(self._cosmo_path,self._file_list[idx]), weights_only=False)\
+                    .squeeze() \
+                    .reshape(-1,*orig_shape),
+                    1)
+        cosmo_data = self.normalize_output(cosmo_data)
+        # return samples
+        return torch.tensor(cosmo_data),\
+               torch.tensor(era5_data),
+        # return F.pad(torch.tensor(cosmo_data), pad=(1,1,1,1), mode='constant', value=0), \
+        #        F.pad(torch.tensor(era5_data), pad=(1,1,1,1), mode='constant', value=0), \
+        #        0
+
+    def __len__(self):
+        return len(self._file_list)
+
+
+    def longitude(self) -> np.ndarray:
+        """Get longitude values from the dataset."""
+        lat_lon = torch.load(os.path.join(self._info_path,'cosmo-lat-lon'), weights_only=False)
+        return lat_lon[:,1]
+
+
+    def latitude(self) -> np.ndarray:
+        """Get latitude values from the dataset."""
+        lat_lon = torch.load(os.path.join(self._info_path,'cosmo-lat-lon'), weights_only=False)
+        return lat_lon[:,0]
+
+
+    def input_channels(self) -> List[ChannelMetadata]:
+        """Metadata for the input channels. A list of ChannelMetadata, one for each channel"""
+        return self._era_channels
+
+
+    def output_channels(self) -> List[ChannelMetadata]:
+        """Metadata for the output channels.
A list of ChannelMetadata, one for each channel""" + return self._cosmo_channels + + + def time(self) -> List: + """Get time values from the dataset.""" + #TODO Choose the time format and convert to that, currently it's a string from a filename + return [file.split('.')[0] for file in self._file_list] + + + def image_shape(self) -> Tuple[int, int]: + """Get the (height, width) of the data (same for input and output).""" + #TODO load from info, I hardcode it for now (cosmo from anemoi-datasets minus trim-edge=20) + return 352,544 #TODO 350,542 is orig size, UNet requires dimenions divisible by 16, for now, I just add zeros to orig images + + + def normalize_input(self, x: np.ndarray) -> np.ndarray: + """Convert input from physical units to normalized data.""" + return (x - self.input_mean.reshape((self.input_mean.shape[0],1,1))) \ + / self.input_std.reshape((self.input_std.shape[0],1,1)) + + + def denormalize_input(self, x: np.ndarray) -> np.ndarray: + """Convert input from normalized data to physical units.""" + return x * self.input_std.reshape((self.input_std.shape[0],1,1)) \ + + self.input_mean.reshape((self.input_mean.shape[0],1,1)) + + + def normalize_output(self, x: np.ndarray) -> np.ndarray: + """Convert output from physical units to normalized data.""" + return (x - self.output_mean.reshape((self.output_mean.shape[0],1,1))) \ + / self.output_std.reshape((self.output_std.shape[0],1,1)) + + + def denormalize_output(self, x: np.ndarray) -> np.ndarray: + """Convert output from normalized data to physical units.""" + return x * self.output_std.reshape((self.output_std.shape[0],1,1)) \ + + self.output_mean.reshape((self.output_mean.shape[0],1,1)) \ No newline at end of file diff --git a/src/hirad/distributed/__init__.py b/src/hirad/distributed/__init__.py new file mode 100644 index 0000000..0da01f3 --- /dev/null +++ b/src/hirad/distributed/__init__.py @@ -0,0 +1 @@ +from .manager import DistributedManager \ No newline at end of file diff --git a/src/hirad/distributed/config.py b/src/hirad/distributed/config.py new file mode 100644 index 0000000..2808d92 --- /dev/null +++ b/src/hirad/distributed/config.py @@ -0,0 +1,247 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Dict, List, Optional, Union + +from treelib import Tree + + +class ProcessGroupNode: + """ + Class to store the attributes of a distributed process group + + Attributes + ---------- + name : str + Name of the process group + size : Optional[int] + Optional, number of processes in the process group + """ + + def __init__( + self, + name: str, + size: Optional[int] = None, + ): + """ + Constructor for the ProcessGroupNode class + + Parameters + ---------- + name : str + Name of the process group + size : Optional[int] + Optional, size of the process group + """ + self.name = name + self.size = size + + def __str__(self): + """ + String representation of the process group node + + Returns + ------- + str + String representation of the process group node + """ + return "ProcessGroupNode(" f"name={self.name}, " f"size={self.size}, " + + def __repr__(self): + """ + String representation of the process group node + + Returns + ------- + str + String representation of the process group node + """ + return self.__str__() + + +class ProcessGroupConfig: + """ + Class to define the configuration of a model's parallel process group structure as a + tree. Each node of the tree is of type `ProcessGroupNode`. + + Once the process group config structure (i.e, the tree structure) is set, it is + sufficient to set only the sizes for each leaf process group. Then, the size of + every parent group can be automatically computed as the product reduction of the + sub-tree of that parent group node. + + Examples + -------- + >>> from hirad.distributed import ProcessGroupNode, ProcessGroupConfig + >>> + >>> # Create world group that contains all processes that are part of this job + >>> world = ProcessGroupNode("world") + >>> + >>> # Create the process group config with the highest level process group + >>> config = ProcessGroupConfig(world) + >>> + >>> # Create model and data parallel sub-groups + >>> # Sub-groups of a single node are guaranteed to be orthogonal by construction + >>> # Nodes can be added with either the name of the node or the node itself + >>> config.add_node(ProcessGroupNode("model_parallel"), parent=world) + >>> config.add_node(ProcessGroupNode("data_parallel"), parent="world") + >>> + >>> # Create spatial and channel parallel sub-groups + >>> config.add_node(ProcessGroupNode("spatial_parallel"), parent="model_parallel") + >>> config.add_node(ProcessGroupNode("channel_parallel"), parent="model_parallel") + >>> + >>> config.leaf_groups() + ['data_parallel', 'spatial_parallel', 'channel_parallel'] + >>> + >>> # Set leaf group sizes + >>> # Note: product of all leaf-node sizes should be the world size + >>> group_sizes = {"channel_parallel": 3, "spatial_parallel": 2, "data_parallel": 4} + >>> config.set_leaf_group_sizes(group_sizes) # Update all parent group sizes too + >>> config.get_node("model_parallel").size + 6 + """ + + def __init__(self, node: ProcessGroupNode): + """ + Constructor to the ProcessGroupConfig class + + Parameters + ---------- + node : ProcessGroupNode + Root node of the tree, typically would be 'world' + Note, it is generally recommended to set the child groups for 'world' + to 'model_parallel' and 'data_parallel' to aid with distributed + data parallel training unless there is a specific reason to choose a + different structure + """ + self.root = node + self.root_id = node.name + self.tree = Tree() + self.tree.create_node(node.name, node.name, data=node) + + def add_node(self, node: ProcessGroupNode, parent=Union[str, ProcessGroupNode]): + """ + Add a 
node to the process group config + + Parameters + ---------- + node : ProcessGroupNode + The new node to be added to the config + parent : Union[str, ProcessGroupNode] + Parent node of the node to be added. Should already be in the config. + If str, it is the name of the parent node. Otherwise, the parent + ProcessGroupNode itself. + """ + if isinstance(parent, ProcessGroupNode): + parent = parent.name + self.tree.create_node(node.name, node.name, data=node, parent=parent) + + def get_node(self, name: str) -> ProcessGroupNode: + """ + Method to get the node given the name of the node + + Parameters + ---------- + name : str + Name of the node to retrieve + + Returns + ------- + ProcessGroupNode + Node with the given name from the config + """ + return self.tree.get_node(name).data + + def update_parent_sizes(self, verbose: bool = False) -> int: + """ + Method to update parent node sizes after setting the sizes for each leaf node + + Parameters + ---------- + verbose : bool + If True, print a message each time a parent node size was updated + + Returns + ------- + int + Size of the root node + """ + return _tree_product_reduction(self.tree, self.root_id, verbose=verbose) + + def leaf_groups(self) -> List[str]: + """ + Get a list of all leaf group names + + Returns + ------- + List[str] + List of all leaf node names + """ + return [n.identifier for n in self.tree.leaves()] + + def set_leaf_group_sizes( + self, group_sizes: Dict[str, int], update_parent_sizes: bool = True + ): + """ + Set process group sizes for all leaf groups + + Parameters + ---------- + group_sizes : Dict[str, int] + Dictionary with a mapping of each leaf group name to its size + update_parent_sizes : bool + Update all parent group sizes based on the leaf group if True + If False, only set the leaf group sizes. + """ + for id, size in group_sizes.items(): + if not self.tree.contains(id): + raise AssertionError( + f"Process group {id} is not in this process group config" + ) + node = self.tree.get_node(id) + if not node.is_leaf(): + raise AssertionError(f"Process group {id} is not a leaf group") + node.data.size = size + + if update_parent_sizes: + self.update_parent_sizes() + + +def _tree_product_reduction(tree, node_id, verbose=False): + """ + Function to traverse a tree and compute the product reduction of + the sub-tree for each node starting from `node_id` + """ + children = tree.children(node_id) + node = tree.get_node(node_id) + if not children: + if node.data.size is None: + raise AssertionError("Leaf nodes should have a valid size set") + return node.data.size + + product = 1 + + for child in children: + product *= _tree_product_reduction(tree, child.identifier) + + if node.data.size != product: + if verbose: + print( + "Updating size of node " + f"{node.data.name} from {node.data.size} to {product}" + ) + node.data.size = product + + return product diff --git a/src/hirad/distributed/manager.py b/src/hirad/distributed/manager.py new file mode 100644 index 0000000..eca46c6 --- /dev/null +++ b/src/hirad/distributed/manager.py @@ -0,0 +1,779 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import atexit +import os +import queue +import warnings +from typing import Optional, Tuple +from warnings import warn + +import numpy as np +import torch +import torch.distributed as dist + +from .config import ProcessGroupConfig, ProcessGroupNode + +warnings.simplefilter("default", DeprecationWarning) + + +class UndefinedGroupError(Exception): + """Exception for querying an undefined process group using the PhysicsNeMo DistributedManager""" + + def __init__(self, name: str): + """ + + Parameters + ---------- + name : str + Name of the process group being queried. + + """ + message = ( + f"Cannot query process group '{name}' before it is explicitly created." + ) + super().__init__(message) + + +class UninitializedDistributedManagerWarning(Warning): + """Warning to indicate usage of an uninitialized DistributedManager""" + + def __init__(self): + message = ( + "A DistributedManager object is being instantiated before " + + "this singleton class has been initialized. Instantiating a manager before " + + "initialization can lead to unexpected results where processes fail " + + "to communicate. Initialize the distributed manager via " + + "DistributedManager.initialize() before instantiating." + ) + super().__init__(message) + + +class DistributedManager(object): + """Distributed Manager for setting up distributed training environment. + + This is a singleton that creates a persistance class instance for storing parallel + environment information through out the life time of the program. This should be + used to help set up Distributed Data Parallel and parallel datapipes. 
+ + Note + ---- + One should call `DistributedManager.initialize()` prior to constructing a manager + object + + Example + ------- + >>> DistributedManager.initialize() + >>> manager = DistributedManager() + >>> manager.rank + 0 + >>> manager.world_size + 1 + """ + + _shared_state = {} + + def __new__(cls): + obj = super(DistributedManager, cls).__new__(cls) + obj.__dict__ = cls._shared_state + + # Set the defaults + if not hasattr(obj, "_rank"): + obj._rank = 0 + if not hasattr(obj, "_world_size"): + obj._world_size = 1 + if not hasattr(obj, "_local_rank"): + obj._local_rank = 0 + if not hasattr(obj, "_distributed"): + obj._distributed = False + if not hasattr(obj, "_device"): + obj._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + if not hasattr(obj, "_cuda"): + obj._cuda = torch.cuda.is_available() + if not hasattr(obj, "_broadcast_buffers"): + obj._broadcast_buffers = False + if not hasattr(obj, "_find_unused_parameters"): + obj._find_unused_parameters = False + if not hasattr(obj, "_initialization_method"): + obj._initialization_method = "None" + if not hasattr(obj, "_groups"): + obj._groups = {} + if not hasattr(obj, "_group_ranks"): + obj._group_ranks = {} + if not hasattr(obj, "_group_names"): + obj._group_names = {} + if not hasattr(obj, "_is_initialized"): + obj._is_initialized = False + if not hasattr(obj, "_global_mesh"): + obj._global_mesh = None # Lazy initialized right when it's first needed + if not hasattr(obj, "_mesh_dims"): + obj._mesh_dims = {} # Dictionary mapping axis names to sizes + + return obj + + def __init__(self): + if not self._is_initialized: + raise UninitializedDistributedManagerWarning() + super().__init__() + + @property + def rank(self): + """Process rank""" + return self._rank + + @property + def local_rank(self): + """Process rank on local machine""" + return self._local_rank + + @property + def world_size(self): + """Number of processes in distributed environment""" + return self._world_size + + @property + def device(self): + """Process device""" + return self._device + + @property + def distributed(self): + """Distributed environment""" + return self._distributed + + @property + def cuda(self): + """If cuda is available""" + return self._cuda + + @property + def mesh_dims(self): + """Mesh Dimensions as dictionary (axis name : size)""" + return self._mesh_dims + + @property + def group_names(self): + """ + Returns a list of all named process groups created + """ + return self._groups.keys() + + @property + def global_mesh(self): + """ + Returns the global mesh. If it's not initialized, it will be created when this is called. + """ + if self._global_mesh is None: + # Fully flat mesh (1D) by default: + self.initialize_mesh(mesh_shape=(-1,), mesh_dim_names=("world",)) + + return self._global_mesh + + def mesh_names(self): + """ + Return mesh axis names + """ + return self._mesh_dims.keys() + + def mesh_sizes(self): + """ + Return mesh axis sizes + """ + return self._mesh_dims.values() + + def group(self, name=None): + """ + Returns a process group with the given name + If name is None, group is also None indicating the default process group + If named group does not exist, UndefinedGroupError exception is raised + """ + if name in self._groups.keys(): + return self._groups[name] + elif name is None: + return None + else: + raise UndefinedGroupError(name) + + def mesh(self, name=None): + """ + Return a device_mesh with the given name. + Does not initialize. 
If the mesh is not created + already, will raise and error + + Parameters + ---------- + name : str, optional + Name of desired mesh, by default None + """ + + if name in self._global_mesh.axis_names: + return self._global_mesh[name] + elif name is None: + return self._global_mesh + else: + raise UndefinedGroupError(f"Mesh axis {name} not defined") + + def group_size(self, name=None): + """ + Returns the size of named process group + """ + if name is None: + return self._world_size + group = self.group(name) + return dist.get_world_size(group=group) + + def group_rank(self, name=None): + """ + Returns the rank in named process group + """ + if name is None: + return self._rank + group = self.group(name) + return dist.get_rank(group=group) + + def group_name(self, group=None): + """ + Returns the name of process group + """ + if group is None: + return None + return self._group_names[group] + + @property + def broadcast_buffers(self): + """broadcast_buffers in PyTorch DDP""" + return self._broadcast_buffers + + @broadcast_buffers.setter + def broadcast_buffers(self, broadcast: bool): + """Setter for broadcast_buffers""" + self._broadcast_buffers = broadcast + + @property + def find_unused_parameters(self): + """find_unused_parameters in PyTorch DDP""" + return self._find_unused_parameters + + @find_unused_parameters.setter + def find_unused_parameters(self, find_params: bool): + """Setter for find_unused_parameters""" + if find_params: + warn( + "Setting `find_unused_parameters` in DDP to true, " + "use only if necessary." + ) + self._find_unused_parameters = find_params + + def __str__(self): + output = ( + f"Initialized process {self.rank} of {self.world_size} using " + f"method '{self._initialization_method}'. Device set to {str(self.device)}" + ) + return output + + @classmethod + def is_initialized(cls) -> bool: + """If manager singleton has been initialized""" + return cls._shared_state.get("_is_initialized", False) + + @staticmethod + def get_available_backend(): + """Get communication backend""" + if torch.cuda.is_available() and torch.distributed.is_nccl_available(): + return "nccl" + else: + return "gloo" + + @staticmethod + def initialize_env(): + """Setup method using generic initialization""" + rank = int(os.environ.get("RANK")) + world_size = int(os.environ.get("WORLD_SIZE")) + if "LOCAL_RANK" in os.environ: + local_rank = os.environ.get("LOCAL_RANK") + if local_rank is not None: + local_rank = int(local_rank) + else: + local_rank = rank % torch.cuda.device_count() + + else: + local_rank = rank % torch.cuda.device_count() + + # Read env variables + addr = os.environ.get("MASTER_ADDR") + port = os.environ.get("MASTER_PORT") + + DistributedManager.setup( + rank=rank, + world_size=world_size, + local_rank=local_rank, + addr=addr, + port=port, + backend=DistributedManager.get_available_backend(), + ) + + @staticmethod + def initialize_open_mpi(addr, port): + """Setup method using OpenMPI initialization""" + rank = int(os.environ.get("OMPI_COMM_WORLD_RANK")) + world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE")) + local_rank = int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")) + + DistributedManager.setup( + rank=rank, + world_size=world_size, + local_rank=local_rank, + addr=addr, + port=port, + backend=DistributedManager.get_available_backend(), + method="openmpi", + ) + + @staticmethod + def initialize_slurm(port): + """Setup method using SLURM initialization""" + rank = int(os.environ.get("SLURM_PROCID")) + world_size = int(os.environ.get("SLURM_NPROCS")) + local_rank = 
int(os.environ.get("SLURM_LOCALID")) + addr = os.environ.get("MASTER_ADDR") + + DistributedManager.setup( + rank=rank, + world_size=world_size, + local_rank=local_rank, + addr=addr, + port=port, + backend=DistributedManager.get_available_backend(), + method="slurm", + ) + + @staticmethod + def initialize(): + """ + Initialize distributed manager + + Current supported initialization methods are: + `ENV`: PyTorch environment variable initialization + https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization + `SLURM`: Initialization on SLURM systems. + Uses `SLURM_PROCID`, `SLURM_NPROCS`, `SLURM_LOCALID` and + `SLURM_LAUNCH_NODE_IPADDR` environment variables. + `OPENMPI`: Initialization for OpenMPI launchers. + Uses `OMPI_COMM_WORLD_RANK`, `OMPI_COMM_WORLD_SIZE` and + `OMPI_COMM_WORLD_LOCAL_RANK` environment variables. + + Initialization by default is done using the first valid method in the order + listed above. Initialization method can also be explicitly controlled using the + `PHYSICSNEMO_DISTRIBUTED_INITIALIZATION_METHOD` environment variable and setting it + to one of the options above. + """ + if DistributedManager.is_initialized(): + warn("Distributed manager is already intialized") + return + + addr = os.getenv("MASTER_ADDR", "localhost") + port = os.getenv("MASTER_PORT", "12355") + # https://pytorch.org/docs/master/notes/cuda.html#id5 + # was changed in version 2.2 + #TODO why is setting this important? + if torch.__version__ < (2, 2): + os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" + else: + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0" + initialization_method = os.getenv( + "DISTRIBUTED_INITIALIZATION_METHOD" + ) + if initialization_method is None: + try: + DistributedManager.initialize_env() + except TypeError: + if "SLURM_PROCID" in os.environ: + DistributedManager.initialize_slurm(port) + elif "OMPI_COMM_WORLD_RANK" in os.environ: + DistributedManager.initialize_open_mpi(addr, port) + else: + warn( + "Could not initialize using ENV, SLURM or OPENMPI methods. Assuming this is a single process job" + ) + DistributedManager._shared_state["_is_initialized"] = True + elif initialization_method == "ENV": + DistributedManager.initialize_env() + elif initialization_method == "SLURM": + DistributedManager.initialize_slurm(port) + elif initialization_method == "OPENMPI": + DistributedManager.initialize_open_mpi(addr, port) + else: + raise RuntimeError( + "Unknown initialization method " + f"{initialization_method}. " + "Supported values for " + "DISTRIBUTED_INITIALIZATION_METHOD are " + "ENV, SLURM and OPENMPI" + ) + + # Set per rank numpy random seed for data sampling + np.random.seed(seed=DistributedManager().rank) + + def initialize_mesh( + self, mesh_shape: Tuple[int, ...], mesh_dim_names: Tuple[str, ...] + ) -> dist.DeviceMesh: + """ + Initialize a global device mesh over the entire distributed job. + + Creates a multi-dimensional mesh of processes that can be used for distributed + operations. The mesh shape must multiply to equal the total world size, with + one dimension optionally being flexible (-1). + + Parameters + ---------- + mesh_shape : Tuple[int, ...] + Tuple of ints describing the size of each mesh dimension. Product must equal + world_size. One dimension can be -1 to be automatically calculated. + + mesh_dim_names : Tuple[str, ...] + Names for each mesh dimension. Must match length of mesh_shape. 
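+
+        Example
+        -------
+        A minimal illustrative sketch (not a doctest), assuming a job launched
+        with 8 processes; the `-1` entry is filled in from the world size:
+
+        >>> manager = DistributedManager()
+        >>> mesh = manager.initialize_mesh(mesh_shape=(2, -1), mesh_dim_names=("data", "model"))
+        >>> manager.mesh_dims
+        {'data': 2, 'model': 4}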
+ + Returns + ------- + torch.distributed.DeviceMesh + The initialized device mesh + + Raises + ------ + RuntimeError + If mesh dimensions are invalid or don't match world size + AssertionError + If distributed environment is not available + """ + + manager = DistributedManager() + if not manager.distributed: + raise AssertionError( + "torch.distributed is unavailable. " + "Check pytorch build to ensure the distributed package is available. " + "If building PyTorch from source, set `USE_DISTRIBUTED=1` " + "to enable the distributed package" + ) + + # Assert basic properties: + if len(mesh_shape) == 0: + raise RuntimeError( + "Device Mesh requires at least one mesh dimension in `mesh_shape`" + ) + if len(mesh_shape) != len(mesh_dim_names): + raise RuntimeError( + "mesh_shape and mesh_dim_names must have the same length, but found " + f"{len(mesh_shape)} and {len(mesh_dim_names)} respectively." + ) + if len(set(mesh_dim_names)) != len(mesh_dim_names): + raise RuntimeError("Mesh dimension names must be unique") + + # Check against the total mesh shape vs. world size: + total_mesh_shape = np.prod(mesh_shape) + + # Allow one shape to be -1 + if -1 in mesh_shape: + residual_shape = int(self.world_size / (-1 * total_mesh_shape)) + + # Replace -1 with the computed size: + mesh_shape = [residual_shape if m == -1 else m for m in mesh_shape] + # Recompute total shape: + total_mesh_shape = np.prod(mesh_shape) + + if total_mesh_shape != self.world_size: + raise RuntimeError( + "Device Mesh num elements must equal world size of " + f"{total_mesh_shape} but was configured by user with " + f"global size of {self.world_size}." + ) + + # Actually create the mesh: + self._global_mesh = dist.init_device_mesh( + "cuda" if self.cuda else "cpu", + mesh_shape, + mesh_dim_names=mesh_dim_names, + ) + + # Finally, upon success, cache the mesh dimensions: + self._mesh_dims = {key: val for key, val in zip(mesh_dim_names, mesh_shape)} + + return self._global_mesh + + @staticmethod + def setup( + rank=0, + world_size=1, + local_rank=None, + addr="localhost", + port="12355", + backend="nccl", + method="env", + ): + """Set up PyTorch distributed process group and update manager attributes""" + os.environ["MASTER_ADDR"] = addr + os.environ["MASTER_PORT"] = str(port) + + DistributedManager._shared_state["_is_initialized"] = True + manager = DistributedManager() + + manager._distributed = torch.distributed.is_available() + if manager._distributed: + # Update rank and world_size if using distributed + manager._rank = rank + manager._world_size = world_size + if local_rank is None: + manager._local_rank = rank % torch.cuda.device_count() + else: + manager._local_rank = local_rank + + manager._device = torch.device( + f"cuda:{manager.local_rank}" if torch.cuda.is_available() else "cpu" + ) + + #TODO device_id makes the init hang, couldn't figure out why + if manager._distributed: + # Setup distributed process group + # try: + dist.init_process_group( + backend, + rank=manager.rank, + world_size=manager.world_size, + ) + # rank=manager.rank, + # world_size=manager.world_size, + # device_id=manager.device, + # except TypeError: + # # device_id only introduced in PyTorch 2.3 + # dist.init_process_group( + # backend, + # rank=manager.rank, + # world_size=manager.world_size, + # ) + + if torch.cuda.is_available(): + # Set device for this process and empty cache to optimize memory usage + torch.cuda.set_device(manager.device) + torch.cuda.device(manager.device) + torch.cuda.empty_cache() + + manager._initialization_method = method + 
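+    # Illustrative usage sketch (not part of the class API): on a SLURM cluster the
+    # manager is typically driven from the launch script, e.g.
+    #
+    #   export DISTRIBUTED_INITIALIZATION_METHOD=SLURM
+    #   srun python src/hirad/training/train.py --config-name=<training config>
+    #
+    # and inside the entry point:
+    #
+    #   DistributedManager.initialize()   # reads SLURM_PROCID / SLURM_NPROCS / SLURM_LOCALID
+    #   manager = DistributedManager()
+    #   model = model.to(manager.device)  # `model` is hypothetical; work goes on the per-rank device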
+ @staticmethod + def create_process_subgroup( + name: str, size: int, group_name: Optional[str] = None, verbose: bool = False + ): # pragma: no cover + """ + Create a process subgroup of a parent process group. This must be a collective + call by all processes participating in this application. + + Parameters + ---------- + name : str + Name of the process subgroup to be created. + + size : int + Size of the process subgroup to be created. This must be an integer factor of + the parent group's size. + + group_name : Optional[str] + Name of the parent process group, optional. If None, the default process group + will be used. Default None. + + verbose : bool + Print out ranks of each created process group, default False. + + """ + manager = DistributedManager() + if not manager.distributed: + raise AssertionError( + "torch.distributed is unavailable. " + "Check pytorch build to ensure the distributed package is available. " + "If building PyTorch from source, set `USE_DISTRIBUTED=1` " + "to enable the distributed package" + ) + + if name in manager._groups: + raise AssertionError(f"Group with name {name} already exists") + + # Get parent group's params + group = manager._groups[group_name] if group_name else None + group_size = dist.get_world_size(group=group) + num_groups = manager.world_size // group_size + + # Get number of sub-groups per parent group + if group_size % size != 0: + raise AssertionError( + f"Cannot divide group size {group_size} evenly into subgroups of" + f" size {size}" + ) + num_subgroups = group_size // size + + # Create all the sub-groups + # Note: all ranks in the job need to create all sub-groups in + # the same order even if a rank is not part of a sub-group + manager._group_ranks[name] = [] + for g in range(num_groups): + for i in range(num_subgroups): + # Get global ranks that are part of this sub-group + start = i * size + end = start + size + if group_name: + ranks = manager._group_ranks[group_name][g][start:end] + else: + ranks = list(range(start, end)) + # Create sub-group and keep track of ranks + tmp_group = dist.new_group(ranks=ranks) + manager._group_ranks[name].append(ranks) + if manager.rank in ranks: + # Set group in manager only if this rank is part of the group + manager._groups[name] = tmp_group + manager._group_names[tmp_group] = name + + if verbose and manager.rank == 0: + print(f"Process group '{name}':") + for grp in manager._group_ranks[name]: + print(" ", grp) + + @staticmethod + def create_orthogonal_process_group( + orthogonal_group_name: str, group_name: str, verbose: bool = False + ): # pragma: no cover + """ + Create a process group that is orthogonal to the specified process group. + + Parameters + ---------- + orthogonal_group_name : str + Name of the orthogonal process group to be created. + + group_name : str + Name of the existing process group. + + verbose : bool + Print out ranks of each created process group, default False. + + """ + manager = DistributedManager() + if not manager.distributed: + raise AssertionError( + "torch.distributed is unavailable. " + "Check pytorch build to ensure the distributed package is available. 
" + "If building PyTorch from source, set `USE_DISTRIBUTED=1` " + "to enable the distributed package" + ) + + if group_name not in manager._groups: + raise ValueError(f"Group with name {group_name} does not exist") + if orthogonal_group_name in manager._groups: + raise ValueError(f"Group with name {orthogonal_group_name} already exists") + + group_ranks = manager._group_ranks[group_name] + orthogonal_ranks = [list(i) for i in zip(*group_ranks)] + + for ranks in orthogonal_ranks: + tmp_group = dist.new_group(ranks=ranks) + if manager.rank in ranks: + # Set group in manager only if this rank is part of the group + manager._groups[orthogonal_group_name] = tmp_group + manager._group_names[tmp_group] = orthogonal_group_name + + manager._group_ranks[orthogonal_group_name] = orthogonal_ranks + + if verbose and manager.rank == 0: + print(f"Process group '{orthogonal_group_name}':") + for grp in manager._group_ranks[orthogonal_group_name]: + print(" ", grp) + + @staticmethod + def create_group_from_node( + node: ProcessGroupNode, + parent: Optional[str] = None, + verbose: bool = False, + ): # pragma: no cover + if node.size is None: + raise AssertionError( + "Cannot create groups from a ProcessGroupNode that is not fully" + " populated. Ensure that config.set_leaf_group_sizes is called first" + " with `update_parent_sizes = True`" + ) + + DistributedManager.create_process_subgroup( + node.name, node.size, group_name=parent, verbose=verbose + ) + # Create orthogonal process group + orthogonal_group = f"__orthogonal_to_{node.name}" + DistributedManager.create_orthogonal_process_group( + orthogonal_group, node.name, verbose=verbose + ) + return orthogonal_group + + @staticmethod + def create_groups_from_config( + config: ProcessGroupConfig, verbose: bool = False + ): # pragma: no cover + + warnings.warn( + "DistributedManager.create_groups_from_config is no longer the most simple " + "way to organize process groups. 
Please switch to DeviceMesh, " + "and DistributedManager.initialize_mesh", + category=DeprecationWarning, + stacklevel=2, + ) + + # Traverse process group tree in breadth first order + # to create nested process groups + q = queue.Queue() + q.put(config.root_id) + DistributedManager.create_group_from_node(config.root) + + while not q.empty(): + node_id = q.get() + if verbose: + print(f"Node ID: {node_id}") + + children = config.tree.children(node_id) + if verbose: + print(f" Children: {children}") + + parent_group = node_id + for child in children: + # Create child group and replace parent group by orthogonal group so + # that each child forms an independent block of processes + parent_group = DistributedManager.create_group_from_node( + child.data, + parent=parent_group, + ) + + # Add child ids to the queue + q.put(child.identifier) + + @atexit.register + @staticmethod + def cleanup(): + """Clean up distributed group and singleton""" + # Destroying group.WORLD is enough for all process groups to get destroyed + if ( + "_is_initialized" in DistributedManager._shared_state + and DistributedManager._shared_state["_is_initialized"] + and "_distributed" in DistributedManager._shared_state + and DistributedManager._shared_state["_distributed"] + ): + if torch.cuda.is_available(): + dist.barrier(device_ids=[DistributedManager().local_rank]) + else: + dist.barrier() + dist.destroy_process_group() + DistributedManager._shared_state = {} diff --git a/src/hirad/eval/metrics.py b/src/hirad/eval/metrics.py new file mode 100644 index 0000000..1170ea1 --- /dev/null +++ b/src/hirad/eval/metrics.py @@ -0,0 +1,57 @@ +import logging + +import numpy as np +import torch + +from scipy.signal import periodogram + + +# set up MAE calculation to be run for each channel for a given date/time (for target COSMO, prediction, and ERA interpolated) + +# input will be a 2D tensor of values with the COSMO lat/lon. + +# Extracted from physicsnemo/examples/weather/regen/paper_figures/score_inference.py + +def absolute_error(pred, target) -> tuple[float, np.ndarray]: + return np.abs(pred-target) + +def compute_mae(pred, target): + # Exclude any target NaNs (not expected, but precautionary) + # TODO: Fix the deprecated warning (index with dtype torch.bool instead of torch.uint8) + mask = ~np.isnan(target) + pred = pred[mask] + target = target[mask] + + ae = absolute_error(pred, target) + + # TODO, consider adding axis=-1 to choose what axis to average + return np.mean(absolute_error(pred, target)), ae + +def average_power_spectrum(data: np.ndarray, d=2.0): # d=2km by default + """ + Compute the average power spectrum of a data array. + + This function calculates the power spectrum for each row of the input data and + then averages them to obtain the overall power spectrum, repeating until + dimensionality is reduced to 1D. + The power spectrum represents the distribution of signal power as a function of frequency. + + Parameters: + data (numpy.ndarray): Input data array. + d (float): Sampling interval (time between data points). + + Returns: + tuple: A tuple containing the frequency values and the average power spectrum. + - freqs (numpy.ndarray): Frequency values corresponding to the power spectrum. + - power_spectra (numpy.ndarray): Average power spectrum of the input data. 
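+
+    Example:
+    Illustrative sketch on random data sampled on a 2 km grid. `periodogram`
+    yields one value per non-negative frequency of the last axis
+    (544 // 2 + 1 = 273 bins here) and the leading axes are averaged away.
+
+    >>> rng = np.random.default_rng(0)
+    >>> freqs, spectrum = average_power_spectrum(rng.normal(size=(4, 352, 544)), d=2.0)
+    >>> freqs.shape, spectrum.shape
+    ((273,), (273,))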
+ """ + # Compute the power spectrum along the highest dimension for each row + freqs, power_spectra = periodogram(data, fs=1 / d, axis=-1) + logging.info(f'freqs.shape={freqs.shape}, power_spectra.shape={power_spectra.shape}') + + # Average along the first dimension + while power_spectra.ndim > 1: + power_spectra = power_spectra.mean(axis=0) + logging.info(f'power spectra shape={power_spectra.shape}') + + return freqs, power_spectra diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py new file mode 100644 index 0000000..1ca11c2 --- /dev/null +++ b/src/hirad/eval/plotting.py @@ -0,0 +1,30 @@ +import logging + +import cartopy.crs as ccrs +import matplotlib.pyplot as plt +import numpy as np + +def plot_error_projection(values: np.array, latitudes: np.array, longitudes: np.array, filename: str): + fig = plt.figure() + fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) + logging.info(f'plotting values to {filename}') + p = ax.scatter(x=longitudes, y=latitudes, c=values) + ax.coastlines() + ax.gridlines(draw_labels=True) + plt.colorbar(p, label="absolute error", orientation="horizontal") + plt.savefig(filename) + plt.close('all') + +def plot_power_spectra(freqs: dict, spec: dict, channel_name, filename): + fig = plt.figure() + for k in freqs.keys(): + plt.loglog(freqs[k], spec[k], label=k) + plt.title(channel_name) + plt.legend() + plt.xlabel("Frequency (1/km)") + plt.ylabel("Power Spectrum") + plt.ylim(bottom=1e-1) + #plt.psd(x) + logging.info(f'plotting values to {filename}') + plt.savefig(filename) + plt.close('all') \ No newline at end of file diff --git a/src/hirad/eval/run_scoring.py b/src/hirad/eval/run_scoring.py new file mode 100644 index 0000000..4f2fcd8 --- /dev/null +++ b/src/hirad/eval/run_scoring.py @@ -0,0 +1,129 @@ +import os +import sys + +import metrics +import numpy as np +import plotting +import torch +import yaml + +X = 352 # length of grid from N-S +Y = 544 # length of grid from E-W + +def main(): + # TODO: Better arg parsing. 
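+    # Illustrative invocation (paths are hypothetical):
+    #   python run_scoring.py /path/to/input_data /path/to/predictions /path/to/plots
+    # The input data directory is expected to contain 'info/', 'cosmo/' and
+    # 'era-interpolated/' subdirectories; the predictions directory holds one tensor
+    # file per datetime, named like the files under 'cosmo/'.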
+ if len(sys.argv) < 3: + raise ValueError('Expected call run_scoring.py [input data directory] [predictions directory] [output plot directory]') + + input_directory = sys.argv[1] + predictions_directory = sys.argv[2] + output_directory = sys.argv[3] + + with open(os.path.join(input_directory, 'info', 'cosmo.yaml')) as cosmo_file: + cosmo_config = yaml.safe_load(cosmo_file) + target_channels = cosmo_config['select'] + + with open(os.path.join(input_directory, 'info', 'era.yaml')) as era_file: + era_config = yaml.safe_load(era_file) + input_channels = era_config['select'] + + lat_lon = torch.load(os.path.join(input_directory, 'info', 'cosmo-lat-lon'), weights_only=False) + latitudes = lat_lon[:,0] + longitudes = lat_lon[:,1] + + # Iterate over all files in the ground truth directory + files = os.listdir(os.path.join(input_directory, 'cosmo')) + files = sorted(files) + + + # Plot power spectra + # TODO: Handle ensembles + prediction_tensor = np.ndarray([len(files), len(target_channels), X, Y]) + baseline_tensor = np.ndarray([len(files), len(input_channels), X, Y]) + target_tensor = np.ndarray([len(files), len(target_channels), X, Y]) + + for i in range(len(files)): + datetime = files[i] + target = torch.load(os.path.join(input_directory, 'cosmo', datetime), weights_only=False) + baseline = torch.load(os.path.join(input_directory, 'era-interpolated', datetime), weights_only=False) + prediction = torch.load(os.path.join(predictions_directory, datetime), weights_only=False) + + # TODO: Handle ensembles + prediction_1d = prediction.reshape(prediction.shape[0], X*Y) + prediction_2d = prediction.reshape(prediction.shape[0], X, Y) + + baseline_1d = baseline.reshape(baseline.shape[0], X*Y) + baseline_2d = baseline.reshape(baseline.shape[0], X, Y) + + target_1d = target.reshape(target.shape[0], X*Y) + target_2d = target.reshape(target.shape[0], X, Y) + + baseline_tensor[i, :] = baseline_2d + prediction_tensor[i, :] = prediction_2d + target_tensor[i,:] = target_2d + + + # Calc spectra + for t_c in range(len(target_channels)): + b_c = input_channels.index(target_channels[t_c]) + freqs = {} + power = {} + if b_c > -1: + b_freq, b_power = metrics.average_power_spectrum(baseline_tensor[:,b_c,:,:].squeeze(), 2.0) + freqs['baseline'] = b_freq + power['baseline'] = b_power + #plotting.plot_power_spectrum(b_freq, b_power, target_channels[t_c], os.path.join('plots/spectra/baseline2dt', target_channels[t_c] + '-all_dates')) + t_freq, t_power = metrics.average_power_spectrum(target_tensor[:,t_c,:,:].squeeze(), 2.0) + freqs['target'] = t_freq + power['target'] = t_power + p_freq, p_power = metrics.average_power_spectrum(prediction_tensor[:,t_c,:,:].squeeze(), 2.0) + # TODO: Uncomment when we have predictions + #freqs['prediction'] = p_freq + #power['prediction'] = p_power + plotting.plot_power_spectra(freqs, power, target_channels[t_c], os.path.join(output_directory, 'spectra', target_channels[t_c] + '-alldates')) + + # store MAE as tensor of date:channel:ensembles:points + # TODO: Handle ensembles + baseline_absolute_error = np.ndarray([len(files),len(target_channels),1,X*Y]) + prediction_absolute_error = np.ndarray([len(files),len(target_channels),1,X*Y]) + + for i in range(len(files)): + datetime = files[i] + target = torch.load(os.path.join(input_directory, 'cosmo', datetime), weights_only=False) + baseline = torch.load(os.path.join(input_directory, 'era-interpolated', datetime), weights_only=False) + prediction = torch.load(os.path.join(predictions_directory, datetime), weights_only=False) + + + 
prediction_1d = prediction.reshape(prediction.shape[0], 1, X*Y) + prediction_2d = prediction.reshape(prediction.shape[0], 1, X, Y) + + # Get MAE + for t_c in range(len(target_channels)): + b_c = input_channels.index(target_channels[t_c]) + if b_c > -1: + _, baseline_errors = metrics.compute_mae(baseline[b_c,:,:], target[t_c,:,:]) + baseline_absolute_error[i, t_c, :, :] = baseline_errors + _, prediction_errors = metrics.compute_mae(prediction_1d[t_c,:,:], target[t_c,:,:]) + prediction_absolute_error[i, t_c, :, :] = prediction_errors + + + print(f'baseline_absolute_error.shape={baseline_absolute_error.shape}, prediction_absolute_error.shape={prediction_absolute_error.shape}') + # Average errors over ensembles + baseline_mae = np.mean(baseline_absolute_error, axis=2) + prediction_mae = np.mean(prediction_absolute_error, axis=2) + + # Average errors over time + baseline_mae = np.mean(baseline_mae, axis=0) + prediction_mae = np.mean(prediction_mae, axis = 0) + + print(f'baseline mean error = {np.mean(baseline_mae, axis=-1)}') + print(f'prediction mean error = {np.mean(prediction_mae, axis=-1)}') + + # Plot the mean error onto the grid. + for t_c in range(len(target_channels)): + plotting.plot_error_projection(baseline_mae[t_c,:], latitudes, longitudes, os.path.join(output_directory, 'baseline-error' + target_channels[t_c] + '-' + 'average_over_time')) + plotting.plot_error_projection(prediction_mae[t_c,:], latitudes, longitudes, os.path.join(output_directory, 'prediction-error' + target_channels[t_c] + '-' + 'average_over_time')) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/hirad/generate.sh b/src/hirad/generate.sh new file mode 100644 index 0000000..87c8979 --- /dev/null +++ b/src/hirad/generate.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +#SBATCH --job-name="testrun" + +### HARDWARE ### +#SBATCH --partition=debug +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=1 +#SBATCH --cpus-per-task=72 +#SBATCH --time=00:30:00 +#SBATCH --no-requeue +#SBATCH --exclusive + +### OUTPUT ### +#SBATCH --output=/iopsstor/scratch/cscs/pstamenk/logs/full_generation.log +#SBATCH --error=/iopsstor/scratch/cscs/pstamenk/logs/full_generation.err + +### ENVIRONMENT #### +#SBATCH --uenv=pytorch/v2.6.0:/user-environment +#SBATCH --view=default +#SBATCH -A a-a122 + +# Choose method to initialize dist in pythorch +export DISTRIBUTED_INITIALIZATION_METHOD=SLURM + +MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +echo "Master node : $MASTER_ADDR" +# Get IP for hostname. +MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" +echo "Master address : $MASTER_ADDR" +export MASTER_ADDR +export MASTER_PORT=29500 +echo "Master port: $MASTER_PORT" + +# Get number of physical cores using Python +PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +# Use SLURM_NTASKS (number of processes to be launched by torchrun) +LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# Compute threads per process +OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +export OMP_NUM_THREADS=$OMP_THREADS +echo "Physical cores: $PHYSICAL_CORES" +echo "Local processes: $LOCAL_PROCS" +echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" + +# python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml +srun bash -c " + . 
./train_env/bin/activate
+    python src/hirad/inference/generate.py --config-name=generate_era_cosmo.yaml
+"
\ No newline at end of file
diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py
new file mode 100644
index 0000000..ec385dc
--- /dev/null
+++ b/src/hirad/inference/generate.py
@@ -0,0 +1,477 @@
+import hydra
+import os
+import json
+from omegaconf import OmegaConf, DictConfig
+import torch
+import torch._dynamo
+import nvtx
+import numpy as np
+import contextlib
+
+from hirad.distributed import DistributedManager
+from hirad.utils.console import PythonLogger, RankZeroLoggingWrapper
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+
+import cartopy.crs as ccrs
+from matplotlib import pyplot as plt
+from einops import rearrange
+from torch.distributed import gather
+
+
+from hydra.utils import to_absolute_path
+from hirad.models import EDMPrecondSuperResolution, UNet
+from hirad.utils.patching import GridPatching2D
+from hirad.utils.stochastic_sampler import stochastic_sampler
+from hirad.utils.deterministic_sampler import deterministic_sampler
+from hirad.utils.inference_utils import (
+    get_time_from_range,
+    regression_step,
+    diffusion_step,
+)
+from hirad.utils.checkpoint import load_checkpoint
+
+
+from hirad.utils.generate_utils import (
+    get_dataset_and_sampler
+)
+
+from hirad.utils.train_helpers import set_patch_shape
+
+from hirad.eval import compute_mae, average_power_spectrum, plot_error_projection, plot_power_spectra
+
+@hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate")
+def main(cfg: DictConfig) -> None:
+    """Generate random downscaled atmospheric states using the techniques described in the paper
+    "Elucidating the Design Space of Diffusion-Based Generative Models".
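+
+    Illustrative invocation (this mirrors the srun wrapper in generate.sh; the
+    config name is one example from this repository):
+
+        python src/hirad/inference/generate.py --config-name=generate_era_cosmo.yaml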
+ """ + torch.backends.cudnn.enabled = False + # Initialize distributed manager + DistributedManager.initialize() + dist = DistributedManager() + device = dist.device + + # Initialize logger + logger = PythonLogger("generate") # General python logger + logger0 = RankZeroLoggingWrapper(logger, dist) + + # Handle the batch size + seeds = list(np.arange(cfg.generation.num_ensembles)) + num_batches = ( + (len(seeds) - 1) // (cfg.generation.seed_batch_size * dist.world_size) + 1 + ) * dist.world_size + all_batches = torch.as_tensor(seeds).tensor_split(num_batches) + rank_batches = all_batches[dist.rank :: dist.world_size] + + # Synchronize + if dist.world_size > 1: + torch.distributed.barrier() + + # Parse the inference input times + if cfg.generation.times_range and cfg.generation.times: + raise ValueError("Either times_range or times must be provided, but not both") + if cfg.generation.times_range: + times = get_time_from_range(cfg.generation.times_range, time_format="%Y%m%d-%H%M") #TODO check what time formats we are using and adapt + else: + times = cfg.generation.times + + # Create dataset object + dataset_cfg = OmegaConf.to_container(cfg.dataset) + if "has_lead_time" in cfg.generation: + has_lead_time = cfg.generation["has_lead_time"] + else: + has_lead_time = False + dataset, sampler = get_dataset_and_sampler( + dataset_cfg=dataset_cfg, times=times, has_lead_time=has_lead_time + ) + img_shape = dataset.image_shape() + img_out_channels = len(dataset.output_channels()) + + # Parse the patch shape + if cfg.generation.patching: + patch_shape_x = cfg.generation.patch_shape_x + patch_shape_y = cfg.generation.patch_shape_y + else: + patch_shape_x, patch_shape_y = None, None + patch_shape = (patch_shape_y, patch_shape_x) + use_patching, img_shape, patch_shape = set_patch_shape(img_shape, patch_shape) + if use_patching: + patching = GridPatching2D( + img_shape=img_shape, + patch_shape=patch_shape, + boundary_pix=cfg.generation.boundary_pix, + overlap_pix=cfg.generation.overlap_pix, + ) + logger0.info("Patch-based training enabled") + else: + patching = None + logger0.info("Patch-based training disabled") + + # Parse the inference mode + if cfg.generation.inference_mode == "regression": + load_net_reg, load_net_res = True, False + elif cfg.generation.inference_mode == "diffusion": + load_net_reg, load_net_res = False, True + elif cfg.generation.inference_mode == "all": + load_net_reg, load_net_res = True, True + else: + raise ValueError(f"Invalid inference mode {cfg.generation.inference_mode}") + + # Load diffusion network, move to device, change precision + if load_net_res: + res_ckpt_path = cfg.generation.io.res_ckpt_path + logger0.info(f'Loading correction network from "{res_ckpt_path}"...') + + diffusion_model_args_path = os.path.join(res_ckpt_path, 'model_args.json') + if not os.path.isfile(diffusion_model_args_path): + raise FileNotFoundError(f"Missing config file at '{diffusion_model_args_path}'.") + with open(diffusion_model_args_path, 'r') as f: + diffusion_model_args = json.load(f) + + net_res = EDMPrecondSuperResolution(**diffusion_model_args) + + _ = load_checkpoint( + path=res_ckpt_path, + model=net_res, + device=dist.device + ) + + #TODO fix to use channels_last which is optimal for H100 + net_res = net_res.eval().to(device).to(memory_format=torch.channels_last) + if cfg.generation.perf.force_fp16: + net_res.use_fp16 = True + + # Disable AMP for inference (even if model is trained with AMP) + if hasattr(net_res, "amp_mode"): + net_res.amp_mode = False + else: + net_res = None + + # 
load regression network, move to device, change precision + if load_net_reg: + reg_ckpt_path = cfg.generation.io.reg_ckpt_path + logger0.info(f'Loading regression network from "{reg_ckpt_path}"...') + + + regression_model_args_path = os.path.join(reg_ckpt_path, 'model_args.json') + if not os.path.isfile(regression_model_args_path): + raise FileNotFoundError(f"Missing config file at '{regression_model_args_path}'.") + with open(regression_model_args_path, 'r') as f: + regression_model_args = json.load(f) + + net_reg = UNet(**regression_model_args) + + _ = load_checkpoint( + path=reg_ckpt_path, + model=net_reg, + device=dist.device + ) + + net_reg = net_reg.eval().to(device).to(memory_format=torch.channels_last) + if cfg.generation.perf.force_fp16: + net_reg.use_fp16 = True + + # Disable AMP for inference (even if model is trained with AMP) + if hasattr(net_reg, "amp_mode"): + net_reg.amp_mode = False + else: + net_reg = None + + # Reset since we are using a different mode. + if cfg.generation.perf.use_torch_compile: + torch._dynamo.reset() + # Only compile residual network + # Overhead of compiling regression network outweights any benefits + if net_res: + net_res = torch.compile(net_res, mode="reduce-overhead") + + # Partially instantiate the sampler based on the configs + if cfg.sampler.type == "deterministic": + if cfg.generation.hr_mean_conditioning: + raise NotImplementedError( + "High-res mean conditioning is not yet implemented for the deterministic sampler" + ) + sampler_fn = partial( + deterministic_sampler, + num_steps=cfg.sampler.num_steps, + # num_ensembles=cfg.generation.num_ensembles, + solver=cfg.sampler.solver, + ) + elif cfg.sampler.type == "stochastic": + sampler_fn = partial(stochastic_sampler, patching=patching) + else: + raise ValueError(f"Unknown sampling method {cfg.sampling.type}") + + + # Main generation definition + def generate_fn(image_lr, lead_time_label): + with nvtx.annotate("generate_fn", color="green"): + # (1, C, H, W) + image_lr = image_lr.to(memory_format=torch.channels_last) + + if net_reg: + with nvtx.annotate("regression_model", color="yellow"): + image_reg = regression_step( + net=net_reg, + img_lr=image_lr, + latents_shape=( + cfg.generation.seed_batch_size, + img_out_channels, + img_shape[0], + img_shape[1], + ), # (batch_size, C, H, W) + lead_time_label=lead_time_label, + ) + if net_res: + if cfg.generation.hr_mean_conditioning: + mean_hr = image_reg[0:1] + else: + mean_hr = None + with nvtx.annotate("diffusion model", color="purple"): + image_res = diffusion_step( + net=net_res, + sampler_fn=sampler_fn, + img_shape=img_shape, + img_out_channels=img_out_channels, + rank_batches=rank_batches, + img_lr=image_lr.expand( + cfg.generation.seed_batch_size, -1, -1, -1 + ), #.to(memory_format=torch.channels_last), + rank=dist.rank, + device=device, + mean_hr=mean_hr, + lead_time_label=lead_time_label, + ) + if cfg.generation.inference_mode == "regression": + image_out = image_reg + elif cfg.generation.inference_mode == "diffusion": + image_out = image_res + else: + image_out = image_reg[0:1,::] + image_res + + # Gather tensors on rank 0 + if dist.world_size > 1: + if dist.rank == 0: + gathered_tensors = [ + torch.zeros_like( + image_out, dtype=image_out.dtype, device=image_out.device + ) + for _ in range(dist.world_size) + ] + else: + gathered_tensors = None + + torch.distributed.barrier() + gather( + image_out, + gather_list=gathered_tensors if dist.rank == 0 else None, + dst=0, + ) + + if dist.rank == 0: + if cfg.generation.inference_mode != 
"regression": + return torch.cat(gathered_tensors), image_reg[0:1,::] + return torch.cat(gathered_tensors) + else: + return None, None + else: + #TODO do this for multi-gpu setting above too + if cfg.generation.inference_mode != "regression": + return image_out, image_reg + return image_out, None + + # generate images + output_path = getattr(cfg.generation.io, "output_path", "./outputs") + logger0.info(f"Generating images, saving results to {output_path}...") + batch_size = 1 + warmup_steps = min(len(times) - 1, 2) + # Generates model predictions from the input data using the specified + # `generate_fn`, and save the predictions to the provided NetCDF file. It iterates + # through the dataset using a data loader, computes predictions, and saves them along + # with associated metadata. + + torch_cuda_profiler = ( + torch.cuda.profiler.profile() + if torch.cuda.is_available() + else contextlib.nullcontext() + ) + torch_nvtx_profiler = ( + torch.autograd.profiler.emit_nvtx() + if torch.cuda.is_available() + else contextlib.nullcontext() + ) + with torch_cuda_profiler: + with torch_nvtx_profiler: + + data_loader = torch.utils.data.DataLoader( + dataset=dataset, sampler=sampler, batch_size=1, pin_memory=True + ) + time_index = -1 + if dist.rank == 0: + writer_executor = ThreadPoolExecutor( + max_workers=cfg.generation.perf.num_writer_workers + ) + writer_threads = [] + + # Create timer objects only if CUDA is available + use_cuda_timing = torch.cuda.is_available() + if use_cuda_timing: + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + else: + # Dummy no-op functions for CPU case + class DummyEvent: + def record(self): + pass + + def synchronize(self): + pass + + def elapsed_time(self, _): + return 0 + + start = end = DummyEvent() + + times = dataset.time() + for index, (image_tar, image_lr, *lead_time_label) in enumerate( + iter(data_loader) + ): + time_index += 1 + if dist.rank == 0: + logger0.info(f"starting index: {time_index}") + + if time_index == warmup_steps: + start.record() + + # continue + if lead_time_label: + lead_time_label = lead_time_label[0].to(dist.device).contiguous() + else: + lead_time_label = None + image_lr = ( + image_lr.to(device=device) + .to(torch.float32) + .to(memory_format=torch.channels_last) + ) + image_tar = image_tar.to(device=device).to(torch.float32) + image_out, image_reg = generate_fn(image_lr,lead_time_label) + if dist.rank == 0: + batch_size = image_out.shape[0] + # write out data in a seperate thread so we don't hold up inferencing + writer_threads.append( + writer_executor.submit( + save_images, + output_path, + times[sampler[time_index]], + dataset, + image_out.cpu().numpy(), + image_tar.cpu().numpy(), + image_lr.cpu().numpy(), + image_reg.cpu().numpy() if image_reg is not None else None, + ) + ) + end.record() + end.synchronize() + elapsed_time = ( + start.elapsed_time(end) / 1000.0 if use_cuda_timing else 0 + ) # Convert ms to s + timed_steps = time_index + 1 - warmup_steps + if dist.rank == 0 and use_cuda_timing: + average_time_per_batch_element = elapsed_time / timed_steps / batch_size + logger.info( + f"Total time to run {timed_steps} steps and {batch_size} members = {elapsed_time} s" + ) + logger.info( + f"Average time per batch element = {average_time_per_batch_element} s" + ) + + # make sure all the workers are done writing + if dist.rank == 0: + for thread in list(writer_threads): + thread.result() + writer_threads.remove(thread) + writer_executor.shutdown() + + if dist.rank == 0: + f.close() + 
logger0.info("Generation Completed.") + +def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, mean_pred): + + os.makedirs(output_path, exist_ok=True) + + longitudes = dataset.longitude() + latitudes = dataset.latitude() + input_channels = dataset.input_channels() + output_channels = dataset.output_channels() + + target = np.flip(dataset.denormalize_output(image_hr[0,::].squeeze()),1) #.reshape(len(output_channels),-1) + prediction = np.flip(dataset.denormalize_output(image_pred[-1,::].squeeze()),1) #.reshape(len(output_channels),-1) + baseline = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze()),1)# .reshape(len(input_channels),-1) + + freqs = {} + power = {} + for idx, channel in enumerate(output_channels): + input_channel_idx = input_channels.index(channel) + _, baseline_errors = compute_mae(baseline[input_channel_idx,:,:], target[idx,:,:]) + _, prediction_errors = compute_mae(prediction[idx,:,:], target[idx,:,:]) + + plot_error_projection(baseline_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path, f'{time_step}-{channel.name}-baseline-error.jpg')) + plot_error_projection(prediction_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path, f'{time_step}-{channel.name}-prediction-error.jpg')) + + b_freq, b_power = average_power_spectrum(baseline[input_channel_idx,:,:].squeeze(), 2.0) + freqs['baseline'] = b_freq + power['baseline'] = b_power + #plotting.plot_power_spectrum(b_freq, b_power, target_channels[t_c], os.path.join('plots/spectra/baseline2dt', target_channels[t_c] + '-all_dates')) + t_freq, t_power = average_power_spectrum(target[idx,:,:].squeeze(), 2.0) + freqs['target'] = t_freq + power['target'] = t_power + p_freq, p_power = average_power_spectrum(prediction[idx,:,:].squeeze(), 2.0) + freqs['prediction'] = p_freq + power['prediction'] = p_power + plot_power_spectra(freqs, power, channel.name, os.path.join(output_path, f'{time_step}-{channel.name}-spectra.jpg')) + +# def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, mean_pred): +# longitudes = dataset.longitude() +# latitudes = dataset.latitude() +# input_channels = dataset.input_channels() +# output_channels = dataset.output_channels() +# image_pred = image_pred.numpy() +# image_pred_final = np.flip(dataset.denormalize_output(image_pred[-1,::].squeeze()),1).reshape(len(output_channels),-1) +# if image_pred.shape[0]>1: +# image_pred_mean = np.flip(dataset.denormalize_output(image_pred.mean(axis=0)),1).reshape(len(output_channels),-1) +# image_pred_first_step = np.flip(dataset.denormalize_output(image_pred[0,::].squeeze()),1).reshape(len(output_channels),-1) +# image_pred_mid_step = np.flip(dataset.denormalize_output(image_pred[image_pred.shape[0]//2,::].squeeze()),1).reshape(len(output_channels),-1) +# image_hr = np.flip(dataset.denormalize_output(image_hr[0,::].squeeze().numpy()),1).reshape(len(output_channels),-1) +# image_lr = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze().numpy()),1).reshape(len(input_channels),-1) +# if mean_pred is not None: +# mean_pred = np.flip(dataset.denormalize_output(mean_pred[0,::].squeeze().numpy()),1).reshape(len(output_channels),-1) +# os.makedirs(output_path, exist_ok=True) +# for idx, channel in enumerate(output_channels): +# input_channel_idx = input_channels.index(channel) +# _plot_projection(longitudes,latitudes,image_lr[input_channel_idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-lr.jpg')) +# 
_plot_projection(longitudes,latitudes,image_hr[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr.jpg')) +# _plot_projection(longitudes,latitudes,image_pred_final[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred.jpg')) +# if image_pred.shape[0]>1: +# _plot_projection(longitudes,latitudes,image_pred_mean[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred-mean.jpg')) +# _plot_projection(longitudes,latitudes,image_pred_first_step[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred-0.jpg')) +# _plot_projection(longitudes,latitudes,image_pred_mid_step[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred-mid.jpg')) +# if mean_pred is not None: +# _plot_projection(longitudes,latitudes,mean_pred[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-mean-pred.jpg')) + +# def _plot_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, cmap=None, vmin = None, vmax = None): + +# """Plot observed or interpolated data in a scatter plot.""" +# # TODO: Refactor this somehow, it's not really generalizing well across variables. +# fig = plt.figure() +# fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) +# p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax) +# ax.coastlines() +# ax.gridlines(draw_labels=True) +# plt.colorbar(p, label="K", orientation="horizontal") +# plt.savefig(filename) +# plt.close('all') + +if __name__ == "__main__": + main() + \ No newline at end of file diff --git a/src/hirad/losses/__init__.py b/src/hirad/losses/__init__.py new file mode 100644 index 0000000..868ffdf --- /dev/null +++ b/src/hirad/losses/__init__.py @@ -0,0 +1 @@ +from .loss import ResidualLoss, RegressionLoss, RegressionLossCE \ No newline at end of file diff --git a/src/hirad/losses/loss.py b/src/hirad/losses/loss.py new file mode 100644 index 0000000..fb65960 --- /dev/null +++ b/src/hirad/losses/loss.py @@ -0,0 +1,1032 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Loss functions used in the paper +"Elucidating the Design Space of Diffusion-Based Generative Models".""" + +from typing import Callable, Optional, Tuple, Union + +import numpy as np +import torch + +from hirad.utils.patching import RandomPatching2D + +class VPLoss: + """ + Loss function corresponding to the variance preserving (VP) formulation. + + Parameters + ---------- + beta_d: float, optional + Coefficient for the diffusion process, by default 19.9. + beta_min: float, optional + Minimum bound, by defaults 0.1. + epsilon_t: float, optional + Small positive value, by default 1e-5. + + Note: + ----- + Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. 
arXiv preprint arXiv:2011.13456. + + """ + + def __init__( + self, beta_d: float = 19.9, beta_min: float = 0.1, epsilon_t: float = 1e-5 + ): + self.beta_d = beta_d + self.beta_min = beta_min + self.epsilon_t = epsilon_t + + def __call__( + self, + net: torch.nn.Module, + images: torch.Tensor, + labels: torch.Tensor, + augment_pipe: Optional[Callable] = None, + ): + """ + Calculate and return the loss corresponding to the variance preserving (VP) + formulation. + + The method adds random noise to the input images and calculates the loss as the + square difference between the network's predictions and the input images. + The noise level is determined by 'sigma', which is computed as a function of + 'epsilon_t' and random values. The calculated loss is weighted based on the + inverse of 'sigma^2'. + + Parameters: + ---------- + net: torch.nn.Module + The neural network model that will make predictions. + + images: torch.Tensor + Input images to the neural network. + + labels: torch.Tensor + Ground truth labels for the input images. + + augment_pipe: callable, optional + An optional data augmentation function that takes images as input and + returns augmented images. If not provided, no data augmentation is applied. + + Returns: + ------- + torch.Tensor + A tensor representing the loss calculated based on the network's + predictions. + """ + rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device) + sigma = self.sigma(1 + rnd_uniform * (self.epsilon_t - 1)) + weight = 1 / sigma**2 + y, augment_labels = ( + augment_pipe(images) if augment_pipe is not None else (images, None) + ) + n = torch.randn_like(y) * sigma + D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) + loss = weight * ((D_yn - y) ** 2) + return loss + + def sigma( + self, t: Union[float, torch.Tensor] + ): # NOTE: also exists in preconditioning + """ + Compute the sigma(t) value for a given t based on the VP formulation. + + The function calculates the noise level schedule for the diffusion process based + on the given parameters `beta_d` and `beta_min`. + + Parameters + ---------- + t : Union[float, torch.Tensor] + The timestep or set of timesteps for which to compute sigma(t). + + Returns + ------- + torch.Tensor + The computed sigma(t) value(s). + """ + t = torch.as_tensor(t) + return ((0.5 * self.beta_d * (t**2) + self.beta_min * t).exp() - 1).sqrt() + + +class VELoss: + """ + Loss function corresponding to the variance exploding (VE) formulation. + + Parameters + ---------- + sigma_min : float + Minimum supported noise level, by default 0.02. + sigma_max : float + Maximum supported noise level, by default 100.0. + + Note: + ----- + Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + """ + + def __init__(self, sigma_min: float = 0.02, sigma_max: float = 100.0): + self.sigma_min = sigma_min + self.sigma_max = sigma_max + + def __call__(self, net, images, labels, augment_pipe=None): + """ + Calculate and return the loss corresponding to the variance exploding (VE) + formulation. + + The method adds random noise to the input images and calculates the loss as the + square difference between the network's predictions and the input images. + The noise level is determined by 'sigma', which is computed as a function of + 'sigma_min' and 'sigma_max' and random values. The calculated loss is weighted + based on the inverse of 'sigma^2'. 
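+
+        As a minimal sketch of the computation performed below (u is uniform on
+        [0, 1) and D denotes the network's denoised prediction):
+
+            sigma = sigma_min * (sigma_max / sigma_min) ** u
+            n     ~ N(0, sigma**2 * I)
+            loss  = (1 / sigma**2) * (D(y + n, sigma) - y) ** 2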
+ + Parameters: + ---------- + net: torch.nn.Module + The neural network model that will make predictions. + + images: torch.Tensor + Input images to the neural network. + + labels: torch.Tensor + Ground truth labels for the input images. + + augment_pipe: callable, optional + An optional data augmentation function that takes images as input and + returns augmented images. If not provided, no data augmentation is applied. + + Returns: + ------- + torch.Tensor + A tensor representing the loss calculated based on the network's + predictions. + """ + rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device) + sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rnd_uniform) + weight = 1 / sigma**2 + y, augment_labels = ( + augment_pipe(images) if augment_pipe is not None else (images, None) + ) + n = torch.randn_like(y) * sigma + D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) + loss = weight * ((D_yn - y) ** 2) + return loss + + +class EDMLoss: + """ + Loss function proposed in the EDM paper. + + Parameters + ---------- + P_mean: float, optional + Mean value for `sigma` computation, by default -1.2. + P_std: float, optional: + Standard deviation for `sigma` computation, by default 1.2. + sigma_data: float, optional + Standard deviation for data, by default 0.5. + + Note + ---- + Reference: Karras, T., Aittala, M., Aila, T. and Laine, S., 2022. Elucidating the + design space of diffusion-based generative models. Advances in Neural Information + Processing Systems, 35, pp.26565-26577. + """ + + def __init__( + self, P_mean: float = -1.2, P_std: float = 1.2, sigma_data: float = 0.5 + ): + self.P_mean = P_mean + self.P_std = P_std + self.sigma_data = sigma_data + + def __call__(self, net, images, condition=None, labels=None, augment_pipe=None): + """ + Calculate and return the loss corresponding to the EDM formulation. + + The method adds random noise to the input images and calculates the loss as the + square difference between the network's predictions and the input images. + The noise level is determined by 'sigma', which is computed as a function of + 'P_mean' and 'P_std' random values. The calculated loss is weighted as a + function of 'sigma' and 'sigma_data'. + + Parameters: + ---------- + net: torch.nn.Module + The neural network model that will make predictions. + + images: torch.Tensor + Input images to the neural network. + + labels: torch.Tensor + Ground truth labels for the input images. + + augment_pipe: callable, optional + An optional data augmentation function that takes images as input and + returns augmented images. If not provided, no data augmentation is applied. + + Returns: + ------- + torch.Tensor + A tensor representing the loss calculated based on the network's + predictions. + """ + rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device) + sigma = (rnd_normal * self.P_std + self.P_mean).exp() + weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + y, augment_labels = ( + augment_pipe(images) if augment_pipe is not None else (images, None) + ) + n = torch.randn_like(y) * sigma + if condition is not None: + D_yn = net( + y + n, + sigma, + condition=condition, + class_labels=labels, + augment_labels=augment_labels, + ) + else: + D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) + loss = weight * ((D_yn - y) ** 2) + return loss + + +class EDMLossSR: + """ + Variation of the loss function proposed in the EDM paper for Super-Resolution. 
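+
+    As in `EDMLoss`, the noise level is drawn log-normally,
+    sigma = exp(P_mean + P_std * n) with n ~ N(0, 1), and each squared error is
+    weighted by (sigma**2 + sigma_data**2) / (sigma * sigma_data)**2. The
+    low-resolution conditioning image is channel-wise concatenated with the clean
+    image before optional augmentation, then split again and passed to the network
+    as a separate conditioning input.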
+ + Parameters + ---------- + P_mean: float, optional + Mean value for `sigma` computation, by default -1.2. + P_std: float, optional: + Standard deviation for `sigma` computation, by default 1.2. + sigma_data: float, optional + Standard deviation for data, by default 0.5. + + Note + ---- + Reference: Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., + Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. + arXiv preprint arXiv:2309.15214. + """ + + def __init__( + self, P_mean: float = -1.2, P_std: float = 1.2, sigma_data: float = 0.5 + ): + self.P_mean = P_mean + self.P_std = P_std + self.sigma_data = sigma_data + + def __call__(self, net, img_clean, img_lr, labels=None, augment_pipe=None): + """ + Calculate and return the loss corresponding to the EDM formulation. + + The method adds random noise to the input images and calculates the loss as the + square difference between the network's predictions and the input images. + The noise level is determined by 'sigma', which is computed as a function of + 'P_mean' and 'P_std' random values. The calculated loss is weighted as a + function of 'sigma' and 'sigma_data'. + + Parameters: + ---------- + net: torch.nn.Module + The neural network model that will make predictions. + + images: torch.Tensor + Input images to the neural network. + + labels: torch.Tensor + Ground truth labels for the input images. + + augment_pipe: callable, optional + An optional data augmentation function that takes images as input and + returns augmented images. If not provided, no data augmentation is applied. + + Returns: + ------- + torch.Tensor + A tensor representing the loss calculated based on the network's + predictions. + """ + rnd_normal = torch.randn([img_clean.shape[0], 1, 1, 1], device=img_clean.device) + sigma = (rnd_normal * self.P_std + self.P_mean).exp() + weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + + # augment for conditional generation + img_tot = torch.cat((img_clean, img_lr), dim=1) + y_tot, augment_labels = ( + augment_pipe(img_tot) if augment_pipe is not None else (img_tot, None) + ) + y = y_tot[:, : img_clean.shape[1], :, :] + y_lr = y_tot[:, img_clean.shape[1] :, :, :] + + n = torch.randn_like(y) * sigma + D_yn = net(y + n, y_lr, sigma, labels, augment_labels=augment_labels) + loss = weight * ((D_yn - y) ** 2) + return loss + + +class RegressionLoss: + """ + Regression loss function for the deterministic predictions. + Note: this loss does not apply any reduction. + + Attributes + ---------- + sigma_data: float + Standard deviation for data. Deprecated and ignored. + + Note + ---- + Reference: Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., + Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. + arXiv preprint arXiv:2309.15214. + """ + + def __init__(self): + """ + Arguments + ---------- + """ + return + + def __call__( + self, + net: torch.nn.Module, + img_clean: torch.Tensor, + img_lr: torch.Tensor, + augment_pipe: Optional[ + Callable[[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]] + ] = None, + ) -> torch.Tensor: + """ + Calculate and return the regression loss for + deterministic predictions. + + Parameters + ---------- + net : torch.nn.Module + The neural network model that will make predictions. 
+ Expected signature: `net(x, img_lr, + augment_labels=augment_labels, force_fp32=False)`, where: + x (torch.Tensor): Tensor of shape (B, C_hr, H, W). Is zero-filled. + img_lr (torch.Tensor): Low-resolution input of shape (B, C_lr, H, W) + augment_labels (torch.Tensor, optional): Optional augmentation + labels, returned by `augment_pipe`. + force_fp32 (bool, optional): Whether to force the model to use + fp32, by default False. + Returns: + torch.Tensor: Predictions of shape (B, C_hr, H, W) + + img_clean : torch.Tensor + High-resolution input images of shape (B, C_hr, H, W). + Used as ground truth and for data augmentation if 'augment_pipe' is provided. + + img_lr : torch.Tensor + Low-resolution input images of shape (B, C_lr, H, W). + Used as input to the neural network. + + augment_pipe : callable, optional + An optional data augmentation function. + Expected signature: + img_tot (torch.Tensor): Concatenated high and low resolution + images of shape (B, C_hr+C_lr, H, W) + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - Augmented images of shape (B, C_hr+C_lr, H, W) + - Optional augmentation labels + + Returns + ------- + torch.Tensor + A tensor representing the per-sample element-wise squared + difference between the network's predictions and the high + resolution images `img_clean` (possibly data-augmented by + `augment_pipe`). + Shape: (B, C_hr, H, W), same as `img_clean`. + """ + weight = ( + 1.0 # (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 + ) + + img_tot = torch.cat((img_clean, img_lr), dim=1) + y_tot, augment_labels = ( + augment_pipe(img_tot) if augment_pipe is not None else (img_tot, None) + ) + y = y_tot[:, : img_clean.shape[1], :, :] + y_lr = y_tot[:, img_clean.shape[1] :, :, :] + + zero_input = torch.zeros_like(y, device=img_clean.device) + D_yn = net(zero_input, y_lr, force_fp32=False, augment_labels=augment_labels) + loss = weight * ((D_yn - y) ** 2) + + return loss + + +class ResidualLoss: + """ + Mixture loss function for denoising score matching. + + This class implements a loss function that combines deterministic + regression with denoising score matching. It uses a pre-trained regression + network to compute residuals before applying the diffusion process. + + Attributes + ---------- + regression_net : torch.nn.Module + The regression network used for computing residuals. + P_mean : float + Mean value for noise level computation. + P_std : float + Standard deviation for noise level computation. + sigma_data : float + Standard deviation for data weighting. + hr_mean_conditioning : bool + Flag indicating whether to use high-resolution mean for conditioning. + + Note + ---- + Reference: Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., + Liu, C.C., Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric + Downscaling. arXiv preprint arXiv:2309.15214. + """ + + def __init__( + self, + regression_net: torch.nn.Module, + P_mean: float = 0.0, + P_std: float = 1.2, + sigma_data: float = 0.5, + hr_mean_conditioning: bool = False, + ): + """ + Arguments + ---------- + regression_net : torch.nn.Module + Pre-trained regression network used to compute residuals. 
+ Expected signature: `net(zero_input, y_lr, + lead_time_label=lead_time_label, augment_labels=augment_labels)` or + `net(zero_input, y_lr, augment_labels=augment_labels)`, where: + zero_input (torch.Tensor): Zero tensor of shape (B, C_hr, H, W) + y_lr (torch.Tensor): Low-resolution input of shape (B, C_lr, H, W) + lead_time_label (torch.Tensor, optional): Optional lead time labels + augment_labels (torch.Tensor, optional): Optional augmentation labels + Returns: + torch.Tensor: Predictions of shape (B, C_hr, H, W) + + P_mean : float, optional + Mean value for noise level computation, by default 0.0. + + P_std : float, optional + Standard deviation for noise level computation, by default 1.2. + + sigma_data : float, optional + Standard deviation for data weighting, by default 0.5. + + hr_mean_conditioning : bool, optional + Whether to use high-resolution mean for conditioning predicted, by default False. + When True, the mean prediction from `regression_net` is channel-wise + concatenated with `img_lr` for conditioning. + """ + self.regression_net = regression_net + self.P_mean = P_mean + self.P_std = P_std + self.sigma_data = sigma_data + self.hr_mean_conditioning = hr_mean_conditioning + self.y_mean = None + + def __call__( + self, + net: torch.nn.Module, + img_clean: torch.Tensor, + img_lr: torch.Tensor, + patching: Optional[RandomPatching2D] = None, + lead_time_label: Optional[torch.Tensor] = None, + augment_pipe: Optional[ + Callable[[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]] + ] = None, + use_patch_grad_acc: bool = False, + ) -> torch.Tensor: + """ + Calculate and return the loss for denoising score matching. + + This method computes a mixture loss that combines deterministic + regression with denoising score matching. It first computes residuals + using the regression network, then applies the diffusion process to + these residuals. + + In addition to the standard denoising score matching loss, this method + also supports optional patching for multi-diffusion. In this case, the spatial + dimensions of the input are decomposed into `P` smaller patches of shape + (H_patch, W_patch), that are grouped along the batch dimension, and the + model is applied to each patch individually. In the following, if `patching` + is not provided, then the input is not patched and `P=1` and `(H_patch, + W_patch) = (H, W)`. When patching is used, the original non-patched conditioning is + interpolated onto a spatial grid of shape `(H_patch, W_patch)` and channel-wise + concatenated to the patched conditioning. This ensures that each patch + maintains global information from the entire domain. + + The diffusion model `net` is expected to be conditioned on an input with + `C_cond` channels, which should be: + - `C_cond = C_lr` if `hr_mean_conditioning` is `False` and + `patching` is None. + - `C_cond = C_hr + C_lr` if `hr_mean_conditioning` is `True` and + `patching` is None. + - `C_cond = C_hr + 2*C_lr` if `hr_mean_conditioning` is `True` and + `patching` is not None. + - `C_cond = 2*C_lr` if `hr_mean_conditioning` is `False` and + `patching` is not None. + Additionally, `C_cond` should also include any embedding channels, + such as positional embeddings or time embeddings. + + Note: this loss function does not apply any reduction. + + Parameters + ---------- + net : torch.nn.Module + The neural network model for the diffusion process. 
+ Expected signature: `net(latent, y_lr, sigma, + embedding_selector=embedding_selector, lead_time_label=lead_time_label, + augment_labels=augment_labels)`, where: + latent (torch.Tensor): Noisy input of shape (B[*P], C_hr, H_patch, W_patch) + y_lr (torch.Tensor): Conditioning of shape (B[*P], C_cond, H_patch, W_patch) + sigma (torch.Tensor): Noise level of shape (B[*P], 1, 1, 1) + embedding_selector (callable, optional): Function to select + positional embeddings. Only used if `patching` is provided. + lead_time_label (torch.Tensor, optional): Lead time labels. + augment_labels (torch.Tensor, optional): Augmentation labels + Returns: + torch.Tensor: Predictions of shape (B[*P], C_hr, H_patch, W_patch) + + img_clean : torch.Tensor + High-resolution input images of shape (B, C_hr, H, W). + Used as ground truth and for data augmentation if 'augment_pipe' is provided. + + img_lr : torch.Tensor + Low-resolution input images of shape (B, C_lr, H, W). + Used as input to the regression network and conditioning for the + diffusion process. + + patching : Optional[RandomPatching2D], optional + Patching strategy for processing large images, by default None. See + :class:`physicsnemo.utils.patching.RandomPatching2D` for details. + When provided, the patching strategy is used for both image patches + and positional embeddings selection in the diffusion model `net`. + Transforms tensors from shape (B, C, H, W) to (B*P, C, H_patch, + W_patch). + + lead_time_label : Optional[torch.Tensor], optional + Labels for lead-time aware predictions, by default None. + Shape can vary based on model requirements, typically (B,) or scalar. + + augment_pipe : Optional[Callable[[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]]] + Data augmentation function. + Expected signature: + img_tot (torch.Tensor): Concatenated high and low resolution images + of shape (B, C_hr+C_lr, H, W) + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - Augmented images of shape (B, C_hr+C_lr, H, W) + - Optional augmentation labels + use_patch_grad_acc: bool, optional + A boolean flag indicating whether to enable multi-iterations of patching accumulations + for amortizing regression cost. Default False. + + Returns + ------- + torch.Tensor + If patching is not used: + A tensor of shape (B, C_hr, H, W) representing the per-sample loss. + If patching is used: + A tensor of shape (B*P, C_hr, H_patch, W_patch) representing + the per-patch loss. + + Raises + ------ + ValueError + If patching is provided but is not an instance of RandomPatching2D. + If shapes of img_clean and img_lr are incompatible. + """ + + # Safety check: enforce patching object + if patching and not isinstance(patching, RandomPatching2D): + raise ValueError("patching must be a 'RandomPatching2D' object.") + # Safety check: enforce shapes + if ( + img_clean.shape[0] != img_lr.shape[0] + or img_clean.shape[2:] != img_lr.shape[2:] + ): + raise ValueError( + f"Shape mismatch between img_clean {img_clean.shape} and " + f"img_lr {img_lr.shape}. " + f"Batch size, height and width must match." 
+ ) + + # augment for conditional generation + img_tot = torch.cat((img_clean, img_lr), dim=1) + y_tot, augment_labels = ( + augment_pipe(img_tot) if augment_pipe is not None else (img_tot, None) + ) + y = y_tot[:, : img_clean.shape[1], :, :] + y_lr = y_tot[:, img_clean.shape[1] :, :, :] + y_lr_res = y_lr + batch_size = y.shape[0] + + # if using multi-iterations of patching, switch to optimized version + if use_patch_grad_acc: + # form residual + if self.y_mean is None: + if lead_time_label is not None: + y_mean = self.regression_net( + torch.zeros_like(y, device=img_clean.device), + y_lr_res, + lead_time_label=lead_time_label, + augment_labels=augment_labels, + ) + else: + y_mean = self.regression_net( + torch.zeros_like(y, device=img_clean.device), + y_lr_res, + augment_labels=augment_labels, + ) + self.y_mean = y_mean + + # if on full domain, or if using patching without multi-iterations + else: + # form residual + if lead_time_label is not None: + y_mean = self.regression_net( + torch.zeros_like(y, device=img_clean.device), + y_lr_res, + lead_time_label=lead_time_label, + augment_labels=augment_labels, + ) + else: + y_mean = self.regression_net( + torch.zeros_like(y, device=img_clean.device), + y_lr_res, + augment_labels=augment_labels, + ) + + self.y_mean = y_mean + + y = y - self.y_mean + + if self.hr_mean_conditioning: + y_lr = torch.cat((self.y_mean, y_lr), dim=1) + + # patchified training + # conditioning: cat(y_mean, y_lr, input_interp, pos_embd), 4+12+100+4 + # removed patch_embedding_selector due to compilation issue with dynamo. + if patching: + # Patched residual + # (batch_size * patch_num, c_out, patch_shape_y, patch_shape_x) + y_patched = patching.apply(input=y) + # Patched conditioning on y_lr and interp(img_lr) + # (batch_size * patch_num, 2*c_in, patch_shape_y, patch_shape_x) + y_lr_patched = patching.apply(input=y_lr, additional_input=img_lr) + + y = y_patched + y_lr = y_lr_patched + + # Noise + rnd_normal = torch.randn([y.shape[0], 1, 1, 1], device=img_clean.device) + sigma = (rnd_normal * self.P_std + self.P_mean).exp() + weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + + # Input + noise + latent = y + torch.randn_like(y) * sigma + + if lead_time_label is not None: + D_yn = net( + latent, + y_lr, + sigma, + embedding_selector=None, + global_index=patching.global_index(batch_size, img_clean.device) + if patching is not None + else None, + lead_time_label=lead_time_label, + augment_labels=augment_labels, + ) + else: + D_yn = net( + latent, + y_lr, + sigma, + embedding_selector=None, + global_index=patching.global_index(batch_size, img_clean.device) + if patching is not None + else None, + augment_labels=augment_labels, + ) + loss = weight * ((D_yn - y) ** 2) + + return loss + + + +class VELoss_dfsr: + """ + Loss function for dfsr model, modified from class VELoss. + + Parameters + ---------- + beta_start : float + Noise level at the initial step of the forward diffusion process, by default 0.0001. + beta_end : float + Noise level at the Final step of the forward diffusion process, by default 0.02. + num_diffusion_timesteps : int + Total number of forward/backward diffusion steps, by default 1000. + + + Note: + ----- + Reference: Ho J, Jain A, Abbeel P. Denoising diffusion probabilistic models. + Advances in neural information processing systems. 2020;33:6840-51. 
+ """ + + def __init__( + self, + beta_start: float = 0.0001, + beta_end: float = 0.02, + num_diffusion_timesteps: int = 1000, + ): + # scheduler for diffusion: + self.beta_schedule = "linear" + self.beta_start = beta_start + self.beta_end = beta_end + self.num_diffusion_timesteps = num_diffusion_timesteps + betas = self.get_beta_schedule( + beta_schedule=self.beta_schedule, + beta_start=self.beta_start, + beta_end=self.beta_end, + num_diffusion_timesteps=self.num_diffusion_timesteps, + ) + self.betas = torch.from_numpy(betas).float() + self.num_timesteps = betas.shape[0] + + def get_beta_schedule( + self, beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps + ): + """ + Compute the variance scheduling parameters {beta(0), ..., beta(t), ..., beta(T)} + based on the VP formulation. + + beta_schedule: str + Method to construct the sequence of beta(t)'s. + beta_start: float + Noise level at the initial step of the forward diffusion process, e.g., beta(0) + beta_end: float + Noise level at the final step of the forward diffusion process, e.g., beta(T) + num_diffusion_timesteps: int + Total number of forward/backward diffusion steps + """ + + def sigmoid(x): + return 1 / (np.exp(-x) + 1) + + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start**0.5, + beta_end**0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace( + beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 + ) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace( + num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 + ) + elif beta_schedule == "sigmoid": + betas = np.linspace(-6, 6, num_diffusion_timesteps) + betas = sigmoid(betas) * (beta_end - beta_start) + beta_start + else: + raise NotImplementedError(beta_schedule) + if betas.shape != (num_diffusion_timesteps,): + raise ValueError( + f"Expected betas to have shape ({num_diffusion_timesteps},), " + f"but got {betas.shape}" + ) + return betas + + def __call__(self, net, images, labels, augment_pipe=None): + """ + Calculate and return the loss corresponding to the variance preserving + formulation. + + The method adds random noise to the input images and calculates the loss as the + square difference between the network's predictions and the noise samples added + to the t-th step of the diffusion process. + The noise level is determined by 'beta_t' based on the given parameters 'beta_start', + 'beta_end' and the current diffusion timestep t. + + Parameters: + ---------- + net: torch.nn.Module + The neural network model that will make predictions. + + images: torch.Tensor + Input fluid flow data samples to the neural network. + + labels: torch.Tensor + Ground truth labels for the input fluid flow data samples. Not required for dfsr. + + augment_pipe: callable, optional + An optional data augmentation function that takes images as input and + returns augmented images. If not provided, no data augmentation is applied. + + Returns: + ------- + torch.Tensor + A tensor representing the loss calculated based on the network's + predictions. 
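+
+        Note
+        ----
+        For each sample, a timestep `t` is drawn and the noisy input is
+        formed as `x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps`,
+        where `alpha_bar_t` is the cumulative product of `(1 - beta_s)` up to
+        step `t`. The returned loss is the element-wise squared error
+        `(eps - net(x_t, t, labels)) ** 2`; no reduction is applied.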
+ """ + t = torch.randint( + low=0, high=self.num_timesteps, size=(images.size(0) // 2 + 1,) + ).to(images.device) + t = torch.cat([t, self.num_timesteps - t - 1], dim=0)[: images.size(0)] + e = torch.randn_like(images) + b = self.betas.to(images.device) + a = (1 - b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1) + x = images * a.sqrt() + e * (1.0 - a).sqrt() + + output = net(x, t, labels) + loss = (e - output).square() + + return loss + + +class RegressionLossCE: + """ + A regression loss function for deterministic predictions with probability + channels and lead time labels. Adapted from + :class:`physicsnemo.metrics.diffusion.loss.RegressionLoss`. In this version, + probability channels are evaluated using CrossEntropyLoss instead of + squared error. + Note: this loss does not apply any reduction. + + Attributes + ---------- + entropy : torch.nn.CrossEntropyLoss + Cross entropy loss function used for probability channels. + prob_channels : list[int] + List of channel indices to be treated as probability channels. + + Note + ---- + Reference: Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., + Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. + arXiv preprint arXiv:2309.15214. + """ + + def __init__( + self, + prob_channels: list[int] = [4, 5, 6, 7, 8], + ): + """ + Arguments + ---------- + prob_channels: list[int], optional + List of channel indices from the target tensor to be treated as + probability channels. Cross entropy loss is computed over these + channels, while the remaining channels are treated as scalar + channels and the squared error loss is computed over them. By + default, [4, 5, 6, 7, 8]. + """ + self.entropy = torch.nn.CrossEntropyLoss(reduction="none") + self.prob_channels = prob_channels + + def __call__( + self, + net: torch.nn.Module, + img_clean: torch.Tensor, + img_lr: torch.Tensor, + lead_time_label: Optional[torch.Tensor] = None, + augment_pipe: Optional[ + Callable[[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]] + ] = None, + ) -> torch.Tensor: + """ + Calculate and return the loss for deterministic + predictions, treating specific channels as probability distributions. + + Parameters + ---------- + net : torch.nn.Module + The neural network model that will make predictions. + Expected signature: `net(input, img_lr, lead_time_label=lead_time_label, augment_labels=augment_labels)`, + where: + input (torch.Tensor): Tensor of shape (B, C_hr, H, W). Zero-filled. + y_lr (torch.Tensor): Low-resolution input of shape (B, C_lr, H, W) + lead_time_label (torch.Tensor, optional): Optional lead time + labels. If provided, should be of shape (B,). + augment_labels (torch.Tensor, optional): Optional augmentation + labels, returned by `augment_pipe`. + Returns: + torch.Tensor: Predictions of shape (B, C_hr, H, W) + + img_clean : torch.Tensor + High-resolution input images of shape (B, C_hr, H, W). + Used as ground truth and for data augmentation if `augment_pipe` is provided. + + img_lr : torch.Tensor + Low-resolution input images of shape (B, C_lr, H, W). + Used as input to the neural network. + + lead_time_label : Optional[torch.Tensor], optional + Lead time labels for temporal predictions, by default None. + Shape can vary based on model requirements, typically (B,) or scalar. + + augment_pipe : Optional[Callable[[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]]] + Data augmentation function. 
+            Expected signature:
+                img_tot (torch.Tensor): Concatenated high and low resolution
+                images of shape (B, C_hr+C_lr, H, W).
+                Returns:
+                    Tuple[torch.Tensor, Optional[torch.Tensor]]:
+                        - Augmented images of shape (B, C_hr+C_lr, H, W)
+                        - Optional augmentation labels
+
+        Returns
+        -------
+        torch.Tensor
+            A tensor of shape (B, C_loss, H, W) representing the pixel-wise
+            loss, where `C_loss = C_hr - len(prob_channels) + 1`. More
+            specifically, the last channel of the output tensor corresponds to
+            the cross-entropy loss computed over the channels specified in
+            `prob_channels`, while the first `C_hr - len(prob_channels)`
+            channels of the output tensor correspond to the squared error loss.
+        """
+        all_channels = list(range(img_clean.shape[1]))  # [0, 1, 2, ..., 10]
+        scalar_channels = [
+            item for item in all_channels if item not in self.prob_channels
+        ]
+        weight = (
+            1.0  # (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
+        )
+
+        img_tot = torch.cat((img_clean, img_lr), dim=1)
+        y_tot, augment_labels = (
+            augment_pipe(img_tot) if augment_pipe is not None else (img_tot, None)
+        )
+        y = y_tot[:, : img_clean.shape[1], :, :]
+        y_lr = y_tot[:, img_clean.shape[1] :, :, :]
+
+        input = torch.zeros_like(y, device=img_clean.device)
+
+        if lead_time_label is not None:
+            D_yn = net(
+                input,
+                y_lr,
+                lead_time_label=lead_time_label,
+                augment_labels=augment_labels,
+            )
+        else:
+            D_yn = net(
+                input,
+                y_lr,
+                augment_labels=augment_labels,
+            )
+        loss1 = weight * (D_yn[:, scalar_channels] - y[:, scalar_channels]) ** 2
+        loss2 = (
+            weight
+            * self.entropy(D_yn[:, self.prob_channels], y[:, self.prob_channels])[
+                :, None
+            ]
+        )
+        loss = torch.cat((loss1, loss2), dim=1)
+        return loss
\ No newline at end of file
diff --git a/src/hirad/models/__init__.py b/src/hirad/models/__init__.py
new file mode 100644
index 0000000..b00a477
--- /dev/null
+++ b/src/hirad/models/__init__.py
@@ -0,0 +1,14 @@
+from .layers import (
+    Linear,
+    Conv2d,
+    GroupNorm,
+    AttentionOp,
+    UNetBlock,
+    PositionalEmbedding,
+    FourierEmbedding
+)
+from .meta import ModelMetaData
+from .song_unet import SongUNet, SongUNetPosEmbd, SongUNetPosLtEmbd
+from .dhariwal_unet import DhariwalUNet
+from .unet import UNet
+from .preconditioning import EDMPrecondSuperResolution, EDMPrecondSR, EDMPrecond
diff --git a/src/hirad/models/dhariwal_unet.py b/src/hirad/models/dhariwal_unet.py
new file mode 100644
index 0000000..3880cd0
--- /dev/null
+++ b/src/hirad/models/dhariwal_unet.py
@@ -0,0 +1,259 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Model architectures used in the paper "Elucidating the Design Space of
+Diffusion-Based Generative Models".
+""" + +from dataclasses import dataclass +from typing import List + +import numpy as np +import torch +from torch.nn.functional import silu +import torch.nn as nn + +from .layers import ( + Conv2d, + GroupNorm, + Linear, + PositionalEmbedding, + UNetBlock, +) +from .meta import ModelMetaData + + +@dataclass +class MetaData(ModelMetaData): + name: str = "DhariwalUNet" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = True + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class DhariwalUNet(nn.Module): + """ + Reimplementation of the ADM architecture, a U-Net variant, with optional + self-attention. + + This model supports conditional and unconditional setups, as well as several + options for various internal architectural choices such as encoder and decoder + type, embedding type, etc., making it flexible and adaptable to different tasks + and configurations. + + Parameters + ----------- + img_resolution : int + The resolution of the input/output image. + in_channels : int + Number of channels in the input image. + out_channels : int + Number of channels in the output image. + label_dim : int, optional + Number of class labels; 0 indicates an unconditional model. By default 0. + augment_dim : int, optional + Dimensionality of augmentation labels; 0 means no augmentation. By default 0. + model_channels : int, optional + Base multiplier for the number of channels across the network, by default 192. + channel_mult : List[int], optional + Per-resolution multipliers for the number of channels. By default [1,2,3,4]. + channel_mult_emb : int, optional + Multiplier for the dimensionality of the embedding vector. By default 4. + num_blocks : int, optional + Number of residual blocks per resolution. By default 3. + attn_resolutions : List[int], optional + Resolutions at which self-attention layers are applied. By default [32, 16, 8]. + dropout : float, optional + Dropout probability applied to intermediate activations. By default 0.10. + label_dropout : float, optional + Dropout probability of class labels for classifier-free guidance. By default 0.0. + + Reference + ---------- + Reference: Dhariwal, P. and Nichol, A., 2021. Diffusion models beat gans on image + synthesis. Advances in neural information processing systems, 34, pp.8780-8794. 
+ + Note + ----- + Equivalent to the original implementation by Dhariwal and Nichol, available at + https://github.com/openai/guided-diffusion + + Example + -------- + >>> model = DhariwalUNet(img_resolution=16, in_channels=2, out_channels=2) + >>> noise_labels = torch.randn([1]) + >>> class_labels = torch.randint(0, 1, (1, 1)) + >>> input_image = torch.ones([1, 2, 16, 16]) + >>> output_image = model(input_image, noise_labels, class_labels) + """ + + def __init__( + self, + img_resolution: int, + in_channels: int, + out_channels: int, + label_dim: int = 0, + augment_dim: int = 0, + model_channels: int = 192, + channel_mult: List[int] = [1, 2, 3, 4], + channel_mult_emb: int = 4, + num_blocks: int = 3, + attn_resolutions: List[int] = [32, 16, 8], + dropout: float = 0.10, + label_dropout: float = 0.0, + ): + super().__init__(meta=MetaData()) + self.label_dropout = label_dropout + emb_channels = model_channels * channel_mult_emb + init = dict( + init_mode="kaiming_uniform", + init_weight=np.sqrt(1 / 3), + init_bias=np.sqrt(1 / 3), + ) + init_zero = dict(init_mode="kaiming_uniform", init_weight=0, init_bias=0) + block_kwargs = dict( + emb_channels=emb_channels, + channels_per_head=64, + dropout=dropout, + init=init, + init_zero=init_zero, + ) + + # Mapping. + self.map_noise = PositionalEmbedding(num_channels=model_channels) + self.map_augment = ( + Linear( + in_features=augment_dim, + out_features=model_channels, + bias=False, + **init_zero, + ) + if augment_dim + else None + ) + self.map_layer0 = Linear( + in_features=model_channels, out_features=emb_channels, **init + ) + self.map_layer1 = Linear( + in_features=emb_channels, out_features=emb_channels, **init + ) + self.map_label = ( + Linear( + in_features=label_dim, + out_features=emb_channels, + bias=False, + init_mode="kaiming_normal", + init_weight=np.sqrt(label_dim), + ) + if label_dim + else None + ) + + # Encoder. + self.enc = torch.nn.ModuleDict() + cout = in_channels + for level, mult in enumerate(channel_mult): + res = img_resolution >> level + if level == 0: + cin = cout + cout = model_channels * mult + self.enc[f"{res}x{res}_conv"] = Conv2d( + in_channels=cin, out_channels=cout, kernel=3, **init + ) + else: + self.enc[f"{res}x{res}_down"] = UNetBlock( + in_channels=cout, out_channels=cout, down=True, **block_kwargs + ) + for idx in range(num_blocks): + cin = cout + cout = model_channels * mult + self.enc[f"{res}x{res}_block{idx}"] = UNetBlock( + in_channels=cin, + out_channels=cout, + attention=(res in attn_resolutions), + **block_kwargs, + ) + skips = [block.out_channels for block in self.enc.values()] + + # Decoder. 
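+        # The decoder mirrors the encoder from the lowest resolution upward.
+        # Each `{res}x{res}_block{idx}` entry pops one value from `skips` (the
+        # encoder output channels recorded above), so that the corresponding
+        # encoder activation can be concatenated to its input in `forward`.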
+ self.dec = torch.nn.ModuleDict() + for level, mult in reversed(list(enumerate(channel_mult))): + res = img_resolution >> level + if level == len(channel_mult) - 1: + self.dec[f"{res}x{res}_in0"] = UNetBlock( + in_channels=cout, out_channels=cout, attention=True, **block_kwargs + ) + self.dec[f"{res}x{res}_in1"] = UNetBlock( + in_channels=cout, out_channels=cout, **block_kwargs + ) + else: + self.dec[f"{res}x{res}_up"] = UNetBlock( + in_channels=cout, out_channels=cout, up=True, **block_kwargs + ) + for idx in range(num_blocks + 1): + cin = cout + skips.pop() + cout = model_channels * mult + self.dec[f"{res}x{res}_block{idx}"] = UNetBlock( + in_channels=cin, + out_channels=cout, + attention=(res in attn_resolutions), + **block_kwargs, + ) + self.out_norm = GroupNorm(num_channels=cout) + self.out_conv = Conv2d( + in_channels=cout, out_channels=out_channels, kernel=3, **init_zero + ) + + def forward(self, x, noise_labels, class_labels, augment_labels=None): + # Mapping. + emb = self.map_noise(noise_labels) + if self.map_augment is not None and augment_labels is not None: + emb = emb + self.map_augment(augment_labels) + emb = silu(self.map_layer0(emb)) + emb = self.map_layer1(emb) + if self.map_label is not None: + tmp = class_labels + if self.training and self.label_dropout: + tmp = tmp * ( + torch.rand([x.shape[0], 1], device=x.device) >= self.label_dropout + ).to(tmp.dtype) + emb = emb + self.map_label(tmp) + emb = silu(emb) + + # Encoder. + skips = [] + for block in self.enc.values(): + x = block(x, emb) if isinstance(block, UNetBlock) else block(x) + skips.append(x) + + # Decoder. + for block in self.dec.values(): + if x.shape[1] != block.in_channels: + x = torch.cat([x, skips.pop()], dim=1) + x = block(x, emb) + x = self.out_conv(silu(self.out_norm(x))) + return x diff --git a/src/hirad/models/layers.py b/src/hirad/models/layers.py new file mode 100644 index 0000000..d7e63d7 --- /dev/null +++ b/src/hirad/models/layers.py @@ -0,0 +1,790 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Model architecture layers used in the paper "Elucidating the Design Space of +Diffusion-Based Generative Models". +""" + +import contextlib +import importlib +from typing import Any, Dict, List + +import numpy as np +import nvtx +import torch +import torch.cuda.amp as amp +from einops import rearrange +from torch.nn.functional import elu, gelu, leaky_relu, relu, sigmoid, silu, tanh + +from hirad.utils.model_utils import weight_init + +_is_apex_available = False +if torch.cuda.is_available(): + try: + apex_gn_module = importlib.import_module("apex.contrib.group_norm") + ApexGroupNorm = getattr(apex_gn_module, "GroupNorm") + _is_apex_available = True + except ImportError: + pass + +class Linear(torch.nn.Module): + """ + A fully connected (dense) layer implementation. 
The layer's weights and biases can + be initialized using custom initialization strategies like "kaiming_normal", + and can be further scaled by factors `init_weight` and `init_bias`. + + Parameters + ---------- + in_features : int + Size of each input sample. + out_features : int + Size of each output sample. + bias : bool, optional + The biases of the layer. If set to `None`, the layer will not learn an additive + bias. By default True. + init_mode : str, optional (default="kaiming_normal") + The mode/type of initialization to use for weights and biases. Supported modes + are: + - "xavier_uniform": Xavier (Glorot) uniform initialization. + - "xavier_normal": Xavier (Glorot) normal initialization. + - "kaiming_uniform": Kaiming (He) uniform initialization. + - "kaiming_normal": Kaiming (He) normal initialization. + By default "kaiming_normal". + init_weight : float, optional + A scaling factor to multiply with the initialized weights. By default 1. + init_bias : float, optional + A scaling factor to multiply with the initialized biases. By default 0. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + init_mode: str = "kaiming_normal", + init_weight: int = 1, + init_bias: int = 0, + amp_mode: bool = False, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.amp_mode = amp_mode + init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features) + self.weight = torch.nn.Parameter( + weight_init([out_features, in_features], **init_kwargs) * init_weight + ) + self.bias = ( + torch.nn.Parameter(weight_init([out_features], **init_kwargs) * init_bias) + if bias + else None + ) + + def forward(self, x): + weight, bias = self.weight, self.bias + # pdb.set_trace() + if not self.amp_mode: + if self.weight is not None and self.weight.dtype != x.dtype: + weight = self.weight.to(x.dtype) + if self.bias is not None and self.bias.dtype != x.dtype: + bias = self.bias.to(x.dtype) + x = x @ weight.t() + if self.bias is not None: + x = x.add_(bias) + return x + + +class Conv2d(torch.nn.Module): + """ + A custom 2D convolutional layer implementation with support for up-sampling, + down-sampling, and custom weight and bias initializations. The layer's weights + and biases canbe initialized using custom initialization strategies like + "kaiming_normal", and can be further scaled by factors `init_weight` and + `init_bias`. + + Parameters + ---------- + in_channels : int + Number of channels in the input image. + out_channels : int + Number of channels produced by the convolution. + kernel : int + Size of the convolving kernel. + bias : bool, optional + The biases of the layer. If set to `None`, the layer will not learn an + additive bias. By default True. + up : bool, optional + Whether to perform up-sampling. By default False. + down : bool, optional + Whether to perform down-sampling. By default False. + resample_filter : List[int], optional + Filter to be used for resampling. By default [1, 1]. + fused_resample : bool, optional + If True, performs fused up-sampling and convolution or fused down-sampling + and convolution. By default False. + init_mode : str, optional (default="kaiming_normal") + init_mode : str, optional (default="kaiming_normal") + The mode/type of initialization to use for weights and biases. 
Supported modes + are: + - "xavier_uniform": Xavier (Glorot) uniform initialization. + - "xavier_normal": Xavier (Glorot) normal initialization. + - "kaiming_uniform": Kaiming (He) uniform initialization. + - "kaiming_normal": Kaiming (He) normal initialization. + By default "kaiming_normal". + init_weight : float, optional + A scaling factor to multiply with the initialized weights. By default 1.0. + init_bias : float, optional + A scaling factor to multiply with the initialized biases. By default 0.0. + fused_conv_bias: bool, optional + A boolean flag indicating whether bias will be passed as a parameter of conv2d. By default False. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel: int, + bias: bool = True, + up: bool = False, + down: bool = False, + resample_filter: List[int] = [1, 1], + fused_resample: bool = False, + init_mode: str = "kaiming_normal", + init_weight: float = 1.0, + init_bias: float = 0.0, + fused_conv_bias: bool = False, + amp_mode: bool = False, + ): + if up and down: + raise ValueError("Both 'up' and 'down' cannot be true at the same time.") + if not kernel and fused_conv_bias: + print( + "Warning: Kernel is required when fused_conv_bias is enabled. Setting fused_conv_bias to False." + ) + fused_conv_bias = False + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.up = up + self.down = down + self.fused_resample = fused_resample + self.fused_conv_bias = fused_conv_bias + self.amp_mode = amp_mode + init_kwargs = dict( + mode=init_mode, + fan_in=in_channels * kernel * kernel, + fan_out=out_channels * kernel * kernel, + ) + self.weight = ( + torch.nn.Parameter( + weight_init([out_channels, in_channels, kernel, kernel], **init_kwargs) + * init_weight + ) + if kernel + else None + ) + self.bias = ( + torch.nn.Parameter(weight_init([out_channels], **init_kwargs) * init_bias) + if kernel and bias + else None + ) + f = torch.as_tensor(resample_filter, dtype=torch.float32) + f = f.ger(f).unsqueeze(0).unsqueeze(1) / f.sum().square() + self.register_buffer("resample_filter", f if up or down else None) + + def forward(self, x): + weight, bias, resample_filter = self.weight, self.bias, self.resample_filter + if not self.amp_mode: + if self.weight is not None and self.weight.dtype != x.dtype: + weight = self.weight.to(x.dtype) + if self.bias is not None and self.bias.dtype != x.dtype: + bias = self.bias.to(x.dtype) + if ( + self.resample_filter is not None + and self.resample_filter.dtype != x.dtype + ): + resample_filter = self.resample_filter.to(x.dtype) + + w = weight if weight is not None else None + b = bias if bias is not None else None + f = resample_filter if resample_filter is not None else None + w_pad = w.shape[-1] // 2 if w is not None else 0 + f_pad = (f.shape[-1] - 1) // 2 if f is not None else 0 + + if self.fused_resample and self.up and w is not None: + x = torch.nn.functional.conv_transpose2d( + x, + f.mul(4).tile([self.in_channels, 1, 1, 1]), + groups=self.in_channels, + stride=2, + padding=max(f_pad - w_pad, 0), + ) + if self.fused_conv_bias: + x = torch.nn.functional.conv2d( + x, w, padding=max(w_pad - f_pad, 0), bias=b + ) + else: + x = torch.nn.functional.conv2d(x, w, padding=max(w_pad - f_pad, 0)) + elif self.fused_resample and self.down and w is not None: + x = torch.nn.functional.conv2d(x, w, padding=w_pad + f_pad) + if self.fused_conv_bias: + x = 
torch.nn.functional.conv2d( + x, + f.tile([self.out_channels, 1, 1, 1]), + groups=self.out_channels, + stride=2, + bias=b, + ) + else: + x = torch.nn.functional.conv2d( + x, + f.tile([self.out_channels, 1, 1, 1]), + groups=self.out_channels, + stride=2, + ) + else: + if self.up: + x = torch.nn.functional.conv_transpose2d( + x, + f.mul(4).tile([self.in_channels, 1, 1, 1]), + groups=self.in_channels, + stride=2, + padding=f_pad, + ) + if self.down: + x = torch.nn.functional.conv2d( + x, + f.tile([self.in_channels, 1, 1, 1]), + groups=self.in_channels, + stride=2, + padding=f_pad, + ) + + #TODO during inference, model breaks here for some reason + # current fix is to disable torch.backends.cudnn.enabled = False + if w is not None: # ask in corrdiff channel whether w will ever be none + if self.fused_conv_bias: + x = torch.nn.functional.conv2d(x, w, padding=w_pad, bias=b) + else: + x = torch.nn.functional.conv2d(x, w, padding=w_pad) + if b is not None and not self.fused_conv_bias: + x = x.add_(b.reshape(1, -1, 1, 1)) + return x + + +class GroupNorm(torch.nn.Module): + """ + A custom Group Normalization layer implementation. + + Group Normalization (GN) divides the channels of the input tensor into groups and + normalizes the features within each group independently. It does not require the + batch size as in Batch Normalization, making itsuitable for batch sizes of any size + or even for batch-free scenarios. + + Parameters + ---------- + num_channels : int + Number of channels in the input tensor. + num_groups : int, optional + Desired number of groups to divide the input channels, by default 32. + This might be adjusted based on the `min_channels_per_group`. + min_channels_per_group : int, optional + Minimum channels required per group. This ensures that no group has fewer + channels than this number. By default 4. + eps : float, optional + A small number added to the variance to prevent division by zero, by default + 1e-5. + use_apex_gn : bool, optional + A boolean flag indicating whether we want to use Apex GroupNorm for NHWC layout. + Need to set this as False on cpu. Defaults to False. + fused_act : bool, optional + Whether to fuse the activation function with GroupNorm. Defaults to False. + act : str, optional + The activation function to use when fusing activation with GroupNorm. Defaults to None. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. + Notes + ----- + If `num_channels` is not divisible by `num_groups`, the actual number of groups + might be adjusted to satisfy the `min_channels_per_group` condition. 
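+    For example, with `num_channels=48`, `num_groups=32` and
+    `min_channels_per_group=4`, the layer uses `min(32, 48 // 4) = 12` groups.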
+ """ + + def __init__( + self, + num_channels: int, + num_groups: int = 32, + min_channels_per_group: int = 4, + eps: float = 1e-5, + use_apex_gn: bool = False, + fused_act: bool = False, + act: str = None, + amp_mode: bool = False, + ): + if fused_act and act is None: + raise ValueError("'act' must be specified when 'fused_act' is set to True.") + + super().__init__() + self.num_groups = min(num_groups, num_channels // min_channels_per_group) + self.eps = eps + self.weight = torch.nn.Parameter(torch.ones(num_channels)) + self.bias = torch.nn.Parameter(torch.zeros(num_channels)) + if use_apex_gn and not _is_apex_available: + raise ValueError("'apex' is not installed, set `use_apex_gn=False`") + self.use_apex_gn = use_apex_gn + self.fused_act = fused_act + self.act = act.lower() if act else act + self.act_fn = None + self.amp_mode = amp_mode + if self.use_apex_gn: + if self.act: + self.gn = ApexGroupNorm( + num_groups=self.num_groups, + num_channels=num_channels, + eps=self.eps, + affine=True, + act=self.act, + ) + + else: + self.gn = ApexGroupNorm( + num_groups=self.num_groups, + num_channels=num_channels, + eps=self.eps, + affine=True, + ) + if self.fused_act: + self.act_fn = self.get_activation_function() + + def forward(self, x): + weight, bias = self.weight, self.bias + if not self.amp_mode: + if not self.use_apex_gn: + if weight.dtype != x.dtype: + weight = self.weight.to(x.dtype) + if bias.dtype != x.dtype: + bias = self.bias.to(x.dtype) + if self.use_apex_gn: + x = self.gn(x) + elif self.training: + # Use default torch implementation of GroupNorm for training + # This does not support channels last memory format + x = torch.nn.functional.group_norm( + x, + num_groups=self.num_groups, + weight=weight, + bias=bias, + eps=self.eps, + ) + if self.fused_act: + x = self.act_fn(x) + else: + # Use custom GroupNorm implementation that supports channels last + # memory layout for inference + x = x.float() + x = rearrange(x, "b (g c) h w -> b g c h w", g=self.num_groups) + + mean = x.mean(dim=[2, 3, 4], keepdim=True) + var = x.var(dim=[2, 3, 4], keepdim=True) + + x = (x - mean) * (var + self.eps).rsqrt() + x = rearrange(x, "b g c h w -> b (g c) h w") + + weight = rearrange(weight, "c -> 1 c 1 1") + bias = rearrange(bias, "c -> 1 c 1 1") + x = x * weight + bias + + if self.fused_act: + x = self.act_fn(x) + return x + + def get_activation_function(self): + """ + Get activation function given string input + """ + + activation_map = { + "silu": silu, + "relu": relu, + "leaky_relu": leaky_relu, + "sigmoid": sigmoid, + "tanh": tanh, + "gelu": gelu, + "elu": elu, + } + + act_fn = activation_map.get(self.act, None) + if act_fn is None: + raise ValueError(f"Unknown activation function: {self.act}") + return act_fn + + +class AttentionOp(torch.autograd.Function): + """ + Attention weight computation, i.e., softmax(Q^T * K). + Performs all computation using FP32, but uses the original datatype for + inputs/outputs/gradients to conserve memory. 
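+    The weights are computed as `softmax(Q^T K / sqrt(C))` over the key
+    dimension, where `C` is the number of channels per attention head.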
+ """ + + @staticmethod + def forward(ctx, q, k): + w = ( + torch.einsum( + "ncq,nck->nqk", + q.to(torch.float32), + (k / torch.sqrt(torch.tensor(k.shape[1]))).to(torch.float32), + ) + .softmax(dim=2) + .to(q.dtype) + ) + ctx.save_for_backward(q, k, w) + return w + + @staticmethod + def backward(ctx, dw): + q, k, w = ctx.saved_tensors + db = torch._softmax_backward_data( + grad_output=dw.to(torch.float32), + output=w.to(torch.float32), + dim=2, + input_dtype=torch.float32, + ) + + dq = torch.einsum("nck,nqk->ncq", k.to(torch.float32), db).to( + q.dtype + ) / np.sqrt(k.shape[1]) + dk = torch.einsum("ncq,nqk->nck", q.to(torch.float32), db).to( + k.dtype + ) / np.sqrt(k.shape[1]) + return dq, dk + + +class UNetBlock(torch.nn.Module): + """ + Unified U-Net block with optional up/downsampling and self-attention. Represents + the union of all features employed by the DDPM++, NCSN++, and ADM architectures. + + Parameters: + ----------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + emb_channels : int + Number of embedding channels. + up : bool, optional + If True, applies upsampling in the forward pass. By default False. + down : bool, optional + If True, applies downsampling in the forward pass. By default False. + attention : bool, optional + If True, enables the self-attention mechanism in the block. By default False. + num_heads : int, optional + Number of attention heads. If None, defaults to `out_channels // 64`. + channels_per_head : int, optional + Number of channels per attention head. By default 64. + dropout : float, optional + Dropout probability. By default 0.0. + skip_scale : float, optional + Scale factor applied to skip connections. By default 1.0. + eps : float, optional + Epsilon value used for normalization layers. By default 1e-5. + resample_filter : List[int], optional + Filter for resampling layers. By default [1, 1]. + resample_proj : bool, optional + If True, resampling projection is enabled. By default False. + adaptive_scale : bool, optional + If True, uses adaptive scaling in the forward pass. By default True. + init : dict, optional + Initialization parameters for convolutional and linear layers. + init_zero : dict, optional + Initialization parameters with zero weights for certain layers. By default + {'init_weight': 0}. + init_attn : dict, optional + Initialization parameters specific to attention mechanism layers. + Defaults to 'init' if not provided. + use_apex_gn : bool, optional + A boolean flag indicating whether we want to use Apex GroupNorm for NHWC layout. + Need to set this as False on cpu. Defaults to False. + act : str, optional + The activation function to use when fusing activation with GroupNorm. Defaults to None. + fused_conv_bias: bool, optional + A boolean flag indicating whether bias will be passed as a parameter of conv2d. By default False. + profile_mode: + A boolean flag indicating whether to enable all nvtx annotations during profiling. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. 
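+
+    Example
+    --------
+    A minimal usage sketch (shapes chosen arbitrarily for illustration):
+
+    >>> block = UNetBlock(in_channels=64, out_channels=64, emb_channels=256)
+    >>> x = torch.randn([1, 64, 32, 32])
+    >>> emb = torch.randn([1, 256])
+    >>> block(x, emb).shape
+    torch.Size([1, 64, 32, 32])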
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + emb_channels: int, + up: bool = False, + down: bool = False, + attention: bool = False, + num_heads: int = None, + channels_per_head: int = 64, + dropout: float = 0.0, + skip_scale: float = 1.0, + eps: float = 1e-5, + resample_filter: List[int] = [1, 1], + resample_proj: bool = False, + adaptive_scale: bool = True, + init: Dict[str, Any] = dict(), + init_zero: Dict[str, Any] = dict(init_weight=0), + init_attn: Any = None, + use_apex_gn: bool = False, + act: str = "silu", + fused_conv_bias: bool = False, + profile_mode: bool = False, + amp_mode: bool = False, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.emb_channels = emb_channels + self.num_heads = ( + 0 + if not attention + else num_heads + if num_heads is not None + else out_channels // channels_per_head + ) + self.dropout = dropout + self.skip_scale = skip_scale + self.adaptive_scale = adaptive_scale + + self.profile_mode = profile_mode + self.amp_mode = amp_mode + self.norm0 = GroupNorm( + num_channels=in_channels, + eps=eps, + use_apex_gn=use_apex_gn, + fused_act=True, + act=act, + amp_mode=amp_mode, + ) + self.conv0 = Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel=3, + up=up, + down=down, + resample_filter=resample_filter, + fused_conv_bias=fused_conv_bias, + amp_mode=amp_mode, + **init, + ) + self.affine = Linear( + in_features=emb_channels, + out_features=out_channels * (2 if adaptive_scale else 1), + amp_mode=amp_mode, + **init, + ) + if self.adaptive_scale: + self.norm1 = GroupNorm( + num_channels=out_channels, + eps=eps, + use_apex_gn=use_apex_gn, + amp_mode=amp_mode, + ) + else: + self.norm1 = GroupNorm( + num_channels=out_channels, + eps=eps, + use_apex_gn=use_apex_gn, + act=act, + fused_act=True, + amp_mode=amp_mode, + ) + self.conv1 = Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel=3, + fused_conv_bias=fused_conv_bias, + amp_mode=amp_mode, + **init_zero, + ) + + self.skip = None + if out_channels != in_channels or up or down: + kernel = 1 if resample_proj or out_channels != in_channels else 0 + fused_conv_bias = fused_conv_bias if kernel != 0 else False + self.skip = Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel=kernel, + up=up, + down=down, + resample_filter=resample_filter, + fused_conv_bias=fused_conv_bias, + amp_mode=amp_mode, + **init, + ) + + if self.num_heads: + self.norm2 = GroupNorm( + num_channels=out_channels, + eps=eps, + use_apex_gn=use_apex_gn, + amp_mode=amp_mode, + ) + self.qkv = Conv2d( + in_channels=out_channels, + out_channels=out_channels * 3, + kernel=1, + fused_conv_bias=fused_conv_bias, + amp_mode=amp_mode, + **(init_attn if init_attn is not None else init), + ) + self.proj = Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel=1, + fused_conv_bias=fused_conv_bias, + amp_mode=amp_mode, + **init_zero, + ) + + def forward(self, x, emb): + with nvtx.annotate( + message="UNetBlock", color="purple" + ) if self.profile_mode else contextlib.nullcontext(): + orig = x + x = self.conv0(self.norm0(x)) + params = self.affine(emb).unsqueeze(2).unsqueeze(3) + if not self.amp_mode: + if params.dtype != x.dtype: + params = params.to(x.dtype) + + if self.adaptive_scale: + scale, shift = params.chunk(chunks=2, dim=1) + x = silu(torch.addcmul(shift, self.norm1(x), scale + 1)) + else: + x = self.norm1(x.add_(params)) + + x = self.conv1( + torch.nn.functional.dropout(x, p=self.dropout, 
training=self.training) + ) + x = x.add_(self.skip(orig) if self.skip is not None else orig) + x = x * self.skip_scale + + if self.num_heads: + q, k, v = ( + self.qkv(self.norm2(x)) + .reshape( + x.shape[0], self.num_heads, x.shape[1] // self.num_heads, 3, -1 + ) + .unbind(3) + ) + # w = AttentionOp.apply(q, k) + # a = torch.einsum("nqk,nck->ncq", w, v) + # Compute attention in one step + with amp.autocast(enabled=self.amp_mode): + attn = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = self.proj(attn.reshape(*x.shape)).add_(x) + x = x * self.skip_scale + + return x + + +class PositionalEmbedding(torch.nn.Module): + """ + A module for generating positional embeddings based on timesteps. + This embedding technique is employed in the DDPM++ and ADM architectures. + + Parameters: + ----------- + num_channels : int + Number of channels for the embedding. + max_positions : int, optional + Maximum number of positions for the embeddings, by default 10000. + endpoint : bool, optional + If True, the embedding considers the endpoint. By default False. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. + + """ + + def __init__( + self, + num_channels: int, + max_positions: int = 10000, + endpoint: bool = False, + amp_mode: bool = False, + ): + super().__init__() + self.num_channels = num_channels + self.max_positions = max_positions + self.endpoint = endpoint + self.amp_mode = amp_mode + + def forward(self, x): + freqs = torch.arange( + start=0, end=self.num_channels // 2, dtype=torch.float32, device=x.device + ) + freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0)) + freqs = (1 / self.max_positions) ** freqs + if not self.amp_mode: + if freqs.dtype != x.dtype: + freqs = freqs.to(x.dtype) + x = x.ger(freqs) + x = torch.cat([x.cos(), x.sin()], dim=1) + return x + + +class FourierEmbedding(torch.nn.Module): + """ + Generates Fourier embeddings for timesteps, primarily used in the NCSN++ + architecture. + + This class generates embeddings by first multiplying input tensor `x` and + internally stored random frequencies, and then concatenating the cosine and sine of + the resultant. + + Parameters: + ----------- + num_channels : int + The number of channels in the embedding. The final embedding size will be + 2 * num_channels because of concatenation of cosine and sine results. + scale : int, optional + A scale factor applied to the random frequencies, controlling their range + and thereby the frequency of oscillations in the embedding space. By default 16. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. + """ + + def __init__(self, num_channels: int, scale: int = 16, amp_mode: bool = False): + super().__init__() + self.register_buffer("freqs", torch.randn(num_channels // 2) * scale) + self.amp_mode = amp_mode + + def forward(self, x): + freqs = self.freqs + if not self.amp_mode: + if x.dtype != self.freqs.dtype: + freqs = self.freqs.to(x.dtype) + + x = x.ger((2 * np.pi * freqs)) + x = torch.cat([x.cos(), x.sin()], dim=1) + return x diff --git a/src/hirad/models/meta.py b/src/hirad/models/meta.py new file mode 100644 index 0000000..aab8e45 --- /dev/null +++ b/src/hirad/models/meta.py @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + + +@dataclass +class ModelMetaData: + """Data class for storing essential meta data needed for all Hirad Models""" + + # Model info + name: str = "HiradModule" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp: bool = False + amp_cpu: bool = None + amp_gpu: bool = None + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + onnx_gpu: bool = None + onnx_cpu: bool = None + onnx_runtime: bool = False + trt: bool = False + # Physics informed + var_dim: int = -1 + func_torch: bool = False + auto_grad: bool = False + + def __post_init__(self): + self.amp_cpu = self.amp if self.amp_cpu is None else self.amp_cpu + self.amp_gpu = self.amp if self.amp_gpu is None else self.amp_gpu + self.onnx_cpu = self.onnx if self.onnx_cpu is None else self.onnx_cpu + self.onnx_gpu = self.onnx if self.onnx_gpu is None else self.onnx_gpu diff --git a/src/hirad/models/preconditioning.py b/src/hirad/models/preconditioning.py new file mode 100644 index 0000000..74496a5 --- /dev/null +++ b/src/hirad/models/preconditioning.py @@ -0,0 +1,1388 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Preconditioning schemes used in the paper"Elucidating the Design Space of +Diffusion-Based Generative Models". +""" + +import importlib +import warnings +from dataclasses import dataclass +from typing import List, Literal, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn + +from .meta import ModelMetaData + +network_module = importlib.import_module("hirad.models") + + +@dataclass +class VPPrecondMetaData(ModelMetaData): + """VPPrecond meta data""" + + name: str = "VPPrecond" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class VPPrecond(nn.Module): + """ + Preconditioning corresponding to the variance preserving (VP) formulation. + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. 
+ label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + beta_d : float + Extent of the noise level schedule, by default 19.9. + beta_min : float + Initial slope of the noise level schedule, by default 0.1. + M : int + Original number of timesteps in the DDPM formulation, by default 1000. + epsilon_t : float + Minimum t-value used during training, by default 1e-5. + model_type :str + Class name of the underlying model, by default "SongUNet". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + """ + + def __init__( + self, + img_resolution: int, + img_channels: int, + label_dim: int = 0, + use_fp16: bool = False, + beta_d: float = 19.9, + beta_min: float = 0.1, + M: int = 1000, + epsilon_t: float = 1e-5, + model_type: str = "SongUNet", + **model_kwargs: dict, + ): + super().__init__() #meta=VPPrecondMetaData + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.beta_d = beta_d + self.beta_min = beta_min + self.M = M + self.epsilon_t = epsilon_t + self.sigma_min = float(self.sigma(epsilon_t)) + self.sigma_max = float(self.sigma(1)) + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_channels, + out_channels=img_channels, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = 1 + c_out = -sigma + c_in = 1 / (sigma**2 + 1).sqrt() + c_noise = (self.M - 1) * self.sigma_inv(sigma) + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + def sigma(self, t: Union[float, torch.Tensor]): + """ + Compute the sigma(t) value for a given t based on the VP formulation. + + The function calculates the noise level schedule for the diffusion process based + on the given parameters `beta_d` and `beta_min`. + + Parameters + ---------- + t : Union[float, torch.Tensor] + The timestep or set of timesteps for which to compute sigma(t). + + Returns + ------- + torch.Tensor + The computed sigma(t) value(s). + """ + t = torch.as_tensor(t) + return ((0.5 * self.beta_d * (t**2) + self.beta_min * t).exp() - 1).sqrt() + + def sigma_inv(self, sigma: Union[float, torch.Tensor]): + """ + Compute the inverse of the sigma function for a given sigma. + + This function effectively calculates t from a given sigma(t) based on the + parameters `beta_d` and `beta_min`. 
+ + Parameters + ---------- + sigma : Union[float, torch.Tensor] + The sigma(t) value or set of sigma(t) values for which to compute the + inverse. + + Returns + ------- + torch.Tensor + The computed t value(s) corresponding to the provided sigma(t). + """ + sigma = torch.as_tensor(sigma) + return ( + (self.beta_min**2 + 2 * self.beta_d * (1 + sigma**2).log()).sqrt() + - self.beta_min + ) / self.beta_d + + def round_sigma(self, sigma: Union[float, List, torch.Tensor]): + """ + Convert a given sigma value(s) to a tensor representation. + + Parameters + ---------- + sigma : Union[float list, torch.Tensor] + The sigma value(s) to convert. + + Returns + ------- + torch.Tensor + The tensor representation of the provided sigma value(s). + """ + return torch.as_tensor(sigma) + + +@dataclass +class VEPrecondMetaData(ModelMetaData): + """VEPrecond meta data""" + + name: str = "VEPrecond" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class VEPrecond(nn.Module): + """ + Preconditioning corresponding to the variance exploding (VE) formulation. + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + sigma_min : float + Minimum supported noise level, by default 0.02. + sigma_max : float + Maximum supported noise level, by default 100.0. + model_type :str + Class name of the underlying model, by default "SongUNet". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. 
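+
+    Example
+    -------
+    A minimal usage sketch (illustrative; assumes the ``SongUNet`` backbone
+    shipped in this package):
+
+    >>> import torch
+    >>> model = VEPrecond(img_resolution=16, img_channels=2, model_type="SongUNet")
+    >>> x = torch.ones([1, 2, 16, 16])
+    >>> sigma = torch.ones([1])
+    >>> model(x, sigma).shape
+    torch.Size([1, 2, 16, 16])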
+ """ + + def __init__( + self, + img_resolution: int, + img_channels: int, + label_dim: int = 0, + use_fp16: bool = False, + sigma_min: float = 0.02, + sigma_max: float = 100.0, + model_type: str = "SongUNet", + **model_kwargs: dict, + ): + super().__init__() #meta=VEPrecondMetaData + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.sigma_min = sigma_min + self.sigma_max = sigma_max + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_channels, + out_channels=img_channels, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = 1 + c_out = sigma + c_in = 1 + c_noise = (0.5 * sigma).log() + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + def round_sigma(self, sigma: Union[float, List, torch.Tensor]): + """ + Convert a given sigma value(s) to a tensor representation. + + Parameters + ---------- + sigma : Union[float list, torch.Tensor] + The sigma value(s) to convert. + + Returns + ------- + torch.Tensor + The tensor representation of the provided sigma value(s). + """ + return torch.as_tensor(sigma) + + +@dataclass +class iDDPMPrecondMetaData(ModelMetaData): + """iDDPMPrecond meta data""" + + name: str = "iDDPMPrecond" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class iDDPMPrecond(nn.Module): + """ + Preconditioning corresponding to the improved DDPM (iDDPM) formulation. + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + C_1 : float + Timestep adjustment at low noise levels., by default 0.001. + C_2 : float + Timestep adjustment at high noise levels., by default 0.008. + M: int + Original number of timesteps in the DDPM formulation, by default 1000. + model_type :str + Class name of the underlying model, by default "DhariwalUNet". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: Nichol, A.Q. and Dhariwal, P., 2021, July. Improved denoising diffusion + probabilistic models. In International Conference on Machine Learning + (pp. 8162-8171). PMLR. 
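+
+    The constructor precomputes a buffer ``u`` of ``M + 1`` discrete noise levels
+    from ``alpha_bar``; ``round_sigma`` snaps continuous sigma values onto this
+    grid, and ``sigma_min`` / ``sigma_max`` are taken as ``u[M - 1]`` / ``u[0]``.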
+ """ + + def __init__( + self, + img_resolution, + img_channels, + label_dim=0, + use_fp16=False, + C_1=0.001, + C_2=0.008, + M=1000, + model_type="DhariwalUNet", + **model_kwargs, + ): + super().__init__() #meta=iDDPMPrecondMetaData + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.C_1 = C_1 + self.C_2 = C_2 + self.M = M + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_channels, + out_channels=img_channels * 2, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + u = torch.zeros(M + 1) + for j in range(M, 0, -1): # M, ..., 1 + u[j - 1] = ( + (u[j] ** 2 + 1) + / (self.alpha_bar(j - 1) / self.alpha_bar(j)).clip(min=C_1) + - 1 + ).sqrt() + self.register_buffer("u", u) + self.sigma_min = float(u[M - 1]) + self.sigma_max = float(u[0]) + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = 1 + c_out = -sigma + c_in = 1 / (sigma**2 + 1).sqrt() + c_noise = ( + self.M - 1 - self.round_sigma(sigma, return_index=True).to(torch.float32) + ) + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + D_x = c_skip * x + c_out * F_x[:, : self.img_channels].to(torch.float32) + return D_x + + def alpha_bar(self, j): + """ + Compute the alpha_bar(j) value for a given j based on the iDDPM formulation. + + Parameters + ---------- + j : Union[int, torch.Tensor] + The timestep or set of timesteps for which to compute alpha_bar(j). + + Returns + ------- + torch.Tensor + The computed alpha_bar(j) value(s). + """ + j = torch.as_tensor(j) + return (0.5 * np.pi * j / self.M / (self.C_2 + 1)).sin() ** 2 + + def round_sigma(self, sigma, return_index=False): + """ + Round the provided sigma value(s) to the nearest value(s) in a + pre-defined set `u`. + + Parameters + ---------- + sigma : Union[float, list, torch.Tensor] + The sigma value(s) to round. + return_index : bool, optional + Whether to return the index/indices of the rounded value(s) in `u` instead + of the rounded value(s) themselves, by default False. + + Returns + ------- + torch.Tensor + The rounded sigma value(s) or their index/indices in `u`, depending on the + value of `return_index`. 
+ """ + sigma = torch.as_tensor(sigma) + index = torch.cdist( + sigma.to(self.u.device).to(torch.float32).reshape(1, -1, 1), + self.u.reshape(1, -1, 1), + ).argmin(2) + result = index if return_index else self.u[index.flatten()].to(sigma.dtype) + return result.reshape(sigma.shape).to(sigma.device) + + +@dataclass +class EDMPrecondMetaData(ModelMetaData): + """EDMPrecond meta data""" + + name: str = "EDMPrecond" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class EDMPrecond(nn.Module): + """ + Improved preconditioning proposed in the paper "Elucidating the Design Space of + Diffusion-Based Generative Models" (EDM) + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels (for both input and output). If your model + requires a different number of input or output chanels, + override this by passing either of the optional + img_in_channels or img_out_channels args + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + sigma_min : float + Minimum supported noise level, by default 0.0. + sigma_max : float + Maximum supported noise level, by default inf. + sigma_data : float + Expected standard deviation of the training data, by default 0.5. + model_type :str + Class name of the underlying model, by default "DhariwalUNet". + img_in_channels: int + Optional setting for when number of input channels =/= number of output + channels. If set, will override img_channels for the input + This is useful in the case of additional (conditional) channels + img_out_channels: int + Optional setting for when number of input channels =/= number of output + channels. If set, will override img_channels for the output + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: Karras, T., Aittala, M., Aila, T. and Laine, S., 2022. Elucidating the + design space of diffusion-based generative models. Advances in Neural Information + Processing Systems, 35, pp.26565-26577. 
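+
+    Example
+    -------
+    A minimal usage sketch (illustrative; assumes the ``SongUNet`` backbone
+    shipped in this package; the default ``model_type`` is ``DhariwalUNet``):
+
+    >>> import torch
+    >>> model = EDMPrecond(img_resolution=16, img_channels=2, model_type="SongUNet")
+    >>> x = torch.ones([1, 2, 16, 16])
+    >>> sigma = torch.tensor([1.0])
+    >>> model(x, sigma).shape
+    torch.Size([1, 2, 16, 16])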
+ """ + + def __init__( + self, + img_resolution, + img_channels, + label_dim=0, + use_fp16=False, + sigma_min=0.0, + sigma_max=float("inf"), + sigma_data=0.5, + model_type="DhariwalUNet", + img_in_channels=None, + img_out_channels=None, + **model_kwargs, + ): + super().__init__() #meta=EDMPrecondMetaData + self.img_resolution = img_resolution + if img_in_channels is not None: + img_in_channels = img_in_channels + else: + img_in_channels = img_channels + if img_out_channels is not None: + img_out_channels = img_out_channels + else: + img_out_channels = img_channels + + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.sigma_data = sigma_data + + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_in_channels, + out_channels=img_out_channels, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + def forward( + self, + x, + sigma, + condition=None, + class_labels=None, + force_fp32=False, + **model_kwargs, + ): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() + c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() + c_noise = sigma.log() / 4 + + arg = c_in * x + + if condition is not None: + arg = torch.cat([arg, condition], dim=1) + + F_x = self.model( + arg.to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + @staticmethod + def round_sigma(sigma: Union[float, List, torch.Tensor]): + """ + Convert a given sigma value(s) to a tensor representation. + + Parameters + ---------- + sigma : Union[float list, torch.Tensor] + The sigma value(s) to convert. + + Returns + ------- + torch.Tensor + The tensor representation of the provided sigma value(s). + """ + return torch.as_tensor(sigma) + +@dataclass +class EDMPrecondSuperResolutionMetaData(ModelMetaData): + """EDMPrecondSR meta data""" + + name: str = "EDMPrecondSuperResolution" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class EDMPrecondSuperResolution(nn.Module): + """ + Improved preconditioning proposed in the paper "Elucidating the Design Space of + Diffusion-Based Generative Models" (EDM). + + This is a variant of `EDMPrecond` that is specifically designed for super-resolution + tasks. It wraps a neural network that predicts the denoised high-resolution image + given a noisy high-resolution image, and additional conditioning that includes a + low-resolution image, and a noise level. + + Parameters + ---------- + img_resolution : Union[int, Tuple[int, int]] + Spatial resolution `(H, W)` of the image. 
If a single int is provided, + the image is assumed to be square. + img_in_channels : int + Number of input channels in the low-resolution input image. + img_out_channels : int + Number of output channels in the high-resolution output image. + use_fp16 : bool, optional + Whether to use half-precision floating point (FP16) for model execution, + by default False. + model_type : str, optional + Class name of the underlying model. Must be one of the following: + 'SongUNet', 'SongUNetPosEmbd', 'SongUNetPosLtEmbd', 'DhariwalUNet'. + Defaults to 'SongUNetPosEmbd'. + sigma_data : float, optional + Expected standard deviation of the training data, by default 0.5. + sigma_min : float, optional + Minimum supported noise level, by default 0.0. + sigma_max : float, optional + Maximum supported noise level, by default inf. + **model_kwargs : dict + Keyword arguments passed to the underlying model `__init__` method. + + Note + ---- + References: + - Karras, T., Aittala, M., Aila, T. and Laine, S., 2022. Elucidating the + design space of diffusion-based generative models. Advances in Neural Information + Processing Systems, 35, pp.26565-26577. + - Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., + Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. + arXiv preprint arXiv:2309.15214. + """ + + def __init__( + self, + img_resolution: Union[int, Tuple[int, int]], + img_in_channels: int, + img_out_channels: int, + use_fp16: bool = False, + model_type: Literal[ + "SongUNetPosEmbd", "SongUNetPosLtEmbd", "SongUNet", "DhariwalUNet" + ] = "SongUNetPosEmbd", + sigma_data: float = 0.5, + sigma_min=0.0, + sigma_max=float("inf"), + **model_kwargs: dict, + ): + super().__init__() #meta=EDMPrecondSRMetaData + self.img_resolution = img_resolution + self.img_in_channels = img_in_channels + self.img_out_channels = img_out_channels + self.use_fp16 = use_fp16 + self.sigma_data = sigma_data + self.sigma_min = sigma_min + self.sigma_max = sigma_max + + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_in_channels + img_out_channels, + out_channels=img_out_channels, + **model_kwargs, + ) # TODO needs better handling + self.scaling_fn = self._scaling_fn + + @staticmethod + def _scaling_fn( + x: torch.Tensor, img_lr: torch.Tensor, c_in: torch.Tensor + ) -> torch.Tensor: + """ + Scale input tensors by first scaling the high-resolution tensor and then + concatenating with the low-resolution tensor. + + Parameters + ---------- + x : torch.Tensor + Noisy high-resolution image of shape (B, C_hr, H, W). + img_lr : torch.Tensor + Low-resolution image of shape (B, C_lr, H, W). + c_in : torch.Tensor + Scaling factor of shape (B, 1, 1, 1). + + Returns + ------- + torch.Tensor + Scaled and concatenated tensor of shape (B, C_in+C_out, H, W). + """ + return torch.cat([c_in * x, img_lr.to(x.dtype)], dim=1) + + def forward( + self, + x: torch.Tensor, + img_lr: torch.Tensor, + sigma: torch.Tensor, + force_fp32: bool = False, + **model_kwargs: dict, + ) -> torch.Tensor: + """ + Forward pass of the EDMPrecondSuperResolution model wrapper. + + This method applies the EDM preconditioning to compute the denoised image + from a noisy high-resolution image and low-resolution conditioning image. + + Parameters + ---------- + x : torch.Tensor + Noisy high-resolution image of shape (B, C_hr, H, W). 
The number of + channels `C_hr` should be equal to `img_out_channels`. + img_lr : torch.Tensor + Low-resolution conditioning image of shape (B, C_lr, H, W). The number + of channels `C_lr` should be equal to `img_in_channels`. + sigma : torch.Tensor + Noise level of shape (B) or (B, 1) or (B, 1, 1, 1). + force_fp32 : bool, optional + Whether to force FP32 precision regardless of the `use_fp16` attribute, + by default False. + **model_kwargs : dict + Additional keyword arguments to pass to the underlying model + `self.model` forward method. + + Returns + ------- + torch.Tensor + Denoised high-resolution image of shape (B, C_hr, H, W). + + Raises + ------ + ValueError + If the model output dtype doesn't match the expected dtype. + """ + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() + c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() + c_noise = sigma.log() / 4 + + if img_lr is None: + arg = c_in * x + else: + arg = self.scaling_fn(x, img_lr, c_in) + arg = arg.to(dtype) + + F_x = self.model( + arg, + c_noise.flatten(), + class_labels=None, + **model_kwargs, + ) + + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + @staticmethod + def round_sigma(sigma: Union[float, List, torch.Tensor]) -> torch.Tensor: + """ + Convert a given sigma value(s) to a tensor representation. + + Parameters + ---------- + sigma : Union[float, List, torch.Tensor] + Sigma value(s) to convert. + + Returns + ------- + torch.Tensor + Tensor representation of sigma values. + + See Also + -------- + EDMPrecond.round_sigma + """ + return EDMPrecond.round_sigma(sigma) + + @property + def amp_mode(self): + """ + Return the *amp_mode* flag of the wrapped model or *None*. + """ + return getattr(self.model, "amp_mode", None) + + @amp_mode.setter + def amp_mode(self, value: bool): + """ + Propagate *amp_mode* to the model and all its sub-modules. + """ + + if not isinstance(value, bool): + raise TypeError("amp_mode must be a boolean value.") + + if hasattr(self.model, "amp_mode"): + self.model.amp_mode = value + + for sub_module in self.model.modules(): + if hasattr(sub_module, "amp_mode"): + sub_module.amp_mode = value + +# NOTE: This is a deprecated version of the EDMPrecondSuperResolution model. +# This was used to maintain backwards compatibility and allow loading old models. +@dataclass +class EDMPrecondSRMetaData(ModelMetaData): + """EDMPrecondSR meta data""" + + name: str = "EDMPrecondSR" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class EDMPrecondSR(EDMPrecondSuperResolution): + """ + Improved preconditioning proposed in the paper "Elucidating the Design Space of + Diffusion-Based Generative Models" (EDM) for super-resolution tasks + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + img_in_channels : int + Number of input color channels. 
+ img_out_channels : int + Number of output color channels. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + sigma_min : float + Minimum supported noise level, by default 0.0. + sigma_max : float + Maximum supported noise level, by default inf. + sigma_data : float + Expected standard deviation of the training data, by default 0.5. + model_type :str + Class name of the underlying model, by default "SongUNetPosEmbd". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + References: + - Karras, T., Aittala, M., Aila, T. and Laine, S., 2022. Elucidating the + design space of diffusion-based generative models. Advances in Neural Information + Processing Systems, 35, pp.26565-26577. + - Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., + Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. + arXiv preprint arXiv:2309.15214. + """ + + def __init__( + self, + img_resolution, + img_channels, #deprecated + img_in_channels, + img_out_channels, + use_fp16=False, + sigma_min=0.0, + sigma_max=float("inf"), + sigma_data=0.5, + model_type="SongUNetPosEmbd", + scale_cond_input=True, #deprecated + **model_kwargs, + ): + warnings.warn( + "EDMPrecondSR is deprecated and will be removed in a future version. " + "Please use EDMPrecondSuperResolution instead.", + DeprecationWarning, + stacklevel=2, + ) + + if scale_cond_input: + warnings.warn( + "scale_cond_input=True does not properly scale the conditional input. " + "(see https://github.com/NVIDIA/modulus/issues/229). " + "This setup will be deprecated. " + "Please set scale_cond_input=False.", + DeprecationWarning, + ) + + super().__init__( + img_resolution=img_resolution, + img_in_channels=img_in_channels, + img_out_channels=img_out_channels, + use_fp16=use_fp16, + sigma_min=sigma_min, + sigma_max=sigma_max, + sigma_data=sigma_data, + model_type=model_type, + **model_kwargs, + ) + + # Store deprecated parameters for backward compatibility + self.img_channels = img_channels + self.scale_cond_input = scale_cond_input + + def forward( + self, + x, + img_lr, + sigma, + force_fp32=False, + **model_kwargs, + ): + """ + Forward pass of the EDMPrecondSR model wrapper. + + Parameters + ---------- + x : torch.Tensor + Noisy high-resolution image of shape (B, C_hr, H, W). + img_lr : torch.Tensor + Low-resolution conditioning image of shape (B, C_lr, H, W). + sigma : torch.Tensor + Noise level of shape (B) or (B, 1) or (B, 1, 1, 1). + force_fp32 : bool, optional + Whether to force FP32 precision regardless of the `use_fp16` attribute, + by default False. + **model_kwargs : dict + Additional keyword arguments to pass to the underlying model. + + Returns + ------- + torch.Tensor + Denoised high-resolution image of shape (B, C_hr, H, W). + """ + return super().forward( + x=x, img_lr=img_lr, sigma=sigma, force_fp32=force_fp32, **model_kwargs + ) + +class VEPrecond_dfsr(nn.Module): + """ + Preconditioning for dfsr model, modified from class VEPrecond, where the input + argument 'sigma' in forward propagation function is used to receive the timestep + of the backward diffusion process. + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. 
+ sigma_min : float + Minimum supported noise level, by default 0.02. + sigma_max : float + Maximum supported noise level, by default 100.0. + model_type :str + Class name of the underlying model, by default "SongUNet". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: Ho J, Jain A, Abbeel P. Denoising diffusion probabilistic models. + Advances in neural information processing systems. 2020;33:6840-51. + """ + + def __init__( + self, + img_resolution: int, + img_channels: int, + label_dim: int = 0, + use_fp16: bool = False, + sigma_min: float = 0.02, + sigma_max: float = 100.0, + dataset_mean: float = 5.85e-05, + dataset_scale: float = 4.79, + model_type: str = "SongUNet", + **model_kwargs: dict, + ): + super().__init__() + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=self.img_channels, + out_channels=img_channels, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + # print("sigma: ", sigma) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_in = 1 + c_noise = sigma # Change the definitation of c_noise to avoid -inf values for zero sigma + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + + if F_x.dtype != dtype: + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + return F_x + + +class VEPrecond_dfsr_cond(nn.Module): + """ + Preconditioning for dfsr model with physics-informed conditioning input, modified + from class VEPrecond, where the input argument 'sigma' in forward propagation function + is used to receive the timestep of the backward diffusion process. The gradient of PDE + residual with respect to the vorticity in the governing Navier-Stokes equation is computed + as the physics-informed conditioning variable and is combined with the backward diffusion + timestep before being sent to the underlying model for noise prediction. + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + sigma_min : float + Minimum supported noise level, by default 0.02. + sigma_max : float + Maximum supported noise level, by default 100.0. + model_type :str + Class name of the underlying model, by default "SongUNet". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: + [1] Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + [2] Shu D, Li Z, Farimani AB. A physics-informed diffusion model for high-fidelity + flow field reconstruction. 
Journal of Computational Physics. 2023 Apr 1;478:111972. + """ + + def __init__( + self, + img_resolution: int, + img_channels: int, + label_dim: int = 0, + use_fp16: bool = False, + sigma_min: float = 0.02, + sigma_max: float = 100.0, + dataset_mean: float = 5.85e-05, + dataset_scale: float = 4.79, + model_type: str = "SongUNet", + **model_kwargs: dict, + ): + super().__init__() + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=model_kwargs["model_channels"] * 2, + out_channels=img_channels, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + # modules to embed residual loss + self.conv_in = torch.nn.Conv2d( + img_channels, + model_kwargs["model_channels"], + kernel_size=3, + stride=1, + padding=1, + padding_mode="circular", + ) + self.emb_conv = torch.nn.Sequential( + torch.nn.Conv2d( + img_channels, + model_kwargs["model_channels"], + kernel_size=1, + stride=1, + padding=0, + ), + torch.nn.GELU(), + torch.nn.Conv2d( + model_kwargs["model_channels"], + model_kwargs["model_channels"], + kernel_size=3, + stride=1, + padding=1, + padding_mode="circular", + ), + ) + self.dataset_mean = dataset_mean + self.dataset_scale = dataset_scale + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_in = 1 + c_noise = sigma + + # Compute physics-informed conditioning information using vorticity residual + dx = ( + self.voriticity_residual((x * self.dataset_scale + self.dataset_mean)) + / self.dataset_scale + ) + x = self.conv_in(x) + cond_emb = self.emb_conv(dx) + x = torch.cat((x, cond_emb), dim=1) + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + + if F_x.dtype != dtype: + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + return F_x + + def voriticity_residual(self, w, re=1000.0, dt=1 / 32): + """ + Compute the gradient of PDE residual with respect to a given vorticity w using the + spectrum method. + + Parameters + ---------- + w: torch.Tensor + The fluid flow data sample (vorticity). + re: float + The value of Reynolds number used in the governing Navier-Stokes equation. + dt: float + Time step used to compute the time-derivative of vorticity included in the governing + Navier-Stokes equation. + + Returns + ------- + torch.Tensor + The computed vorticity gradient. 
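+
+        Note
+        ----
+        The residual assembled below is that of the forced 2D vorticity equation,
+        ``w_t + u . grad(w) - (1/re) * lap(w) + 0.1 * w - f`` with forcing
+        ``f = -4 * cos(4 * y)``; the returned tensor is the gradient of the mean
+        squared residual with respect to ``w``.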
+ """ + + # w [b t h w] + w = w.clone() + w.requires_grad_(True) + nx = w.size(2) + device = w.device + + w_h = torch.fft.fft2(w[:, 1:-1], dim=[2, 3]) + # Wavenumbers in y-direction + k_max = nx // 2 + N = nx + k_x = ( + torch.cat( + ( + torch.arange(start=0, end=k_max, step=1, device=device), + torch.arange(start=-k_max, end=0, step=1, device=device), + ), + 0, + ) + .reshape(N, 1) + .repeat(1, N) + .reshape(1, 1, N, N) + ) + k_y = ( + torch.cat( + ( + torch.arange(start=0, end=k_max, step=1, device=device), + torch.arange(start=-k_max, end=0, step=1, device=device), + ), + 0, + ) + .reshape(1, N) + .repeat(N, 1) + .reshape(1, 1, N, N) + ) + # Negative Laplacian in Fourier space + lap = k_x**2 + k_y**2 + lap[..., 0, 0] = 1.0 + psi_h = w_h / lap + + u_h = 1j * k_y * psi_h + v_h = -1j * k_x * psi_h + wx_h = 1j * k_x * w_h + wy_h = 1j * k_y * w_h + wlap_h = -lap * w_h + + u = torch.fft.irfft2(u_h[..., :, : k_max + 1], dim=[2, 3]) + v = torch.fft.irfft2(v_h[..., :, : k_max + 1], dim=[2, 3]) + wx = torch.fft.irfft2(wx_h[..., :, : k_max + 1], dim=[2, 3]) + wy = torch.fft.irfft2(wy_h[..., :, : k_max + 1], dim=[2, 3]) + wlap = torch.fft.irfft2(wlap_h[..., :, : k_max + 1], dim=[2, 3]) + advection = u * wx + v * wy + + wt = (w[:, 2:, :, :] - w[:, :-2, :, :]) / (2 * dt) + + # establish forcing term + x = torch.linspace(0, 2 * np.pi, nx + 1, device=device) + x = x[0:-1] + X, Y = torch.meshgrid(x, x) + f = -4 * torch.cos(4 * Y) + + residual = wt + (advection - (1.0 / re) * wlap + 0.1 * w[:, 1:-1]) - f + residual_loss = (residual**2).mean() + dw = torch.autograd.grad(residual_loss, w)[0] + + return dw diff --git a/src/hirad/models/song_unet.py b/src/hirad/models/song_unet.py new file mode 100644 index 0000000..a56f861 --- /dev/null +++ b/src/hirad/models/song_unet.py @@ -0,0 +1,1250 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Model architectures used in the paper "Elucidating the Design Space of +Diffusion-Based Generative Models". 
+""" + +import contextlib +from dataclasses import dataclass +from typing import Callable, List, Optional, Union + +import numpy as np +import nvtx +import torch +from torch.nn.functional import silu +from torch.utils.checkpoint import checkpoint +import torch.nn as nn + +from .layers import ( + Conv2d, + FourierEmbedding, + GroupNorm, + Linear, + PositionalEmbedding, + UNetBlock, +) +from .meta import ModelMetaData + + +@dataclass +class MetaData(ModelMetaData): + name: str = "SongUNet" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = True + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class SongUNet(nn.Module): + """ + Reimplementation of the DDPM++ and NCSN++ architectures, U-Net variants with + optional self-attention, embeddings, and encoder-decoder components. + + This model supports conditional and unconditional setups, as well as several + options for various internal architectural choices such as encoder and decoder + type, embedding type, etc., making it flexible and adaptable to different tasks + and configurations. + + Parameters + ----------- + img_resolution : Union[List[int], int] + The resolution of the input/output image. Can be a single int for square images + or a list [height, width] for rectangular images. + in_channels : int + Number of channels in the input image. + out_channels : int + Number of channels in the output image. + label_dim : int, optional + Number of class labels; 0 indicates an unconditional model. By default 0. + augment_dim : int, optional + Dimensionality of augmentation labels; 0 means no augmentation. By default 0. + model_channels : int, optional + Base multiplier for the number of channels across the network. By default 128. + channel_mult : List[int], optional + Per-resolution multipliers for the number of channels. By default [1,2,2,2]. + channel_mult_emb : int, optional + Multiplier for the dimensionality of the embedding vector. By default 4. + num_blocks : int, optional + Number of residual blocks per resolution. By default 4. + attn_resolutions : List[int], optional + Resolutions at which self-attention layers are applied. By default [16]. + dropout : float, optional + Dropout probability applied to intermediate activations. By default 0.10. + label_dropout : float, optional + Dropout probability of class labels for classifier-free guidance. By default 0.0. + embedding_type : str, optional + Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++, 'zero' for none. + By default 'positional'. + channel_mult_noise : int, optional + Timestep embedding size: 1 for DDPM++, 2 for NCSN++. By default 1. + encoder_type : str, optional + Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++. , 'skip' for skip connections. + By default 'standard'. + decoder_type : str, optional + Decoder architecture: 'standard' or 'skip' for skip connections. By default 'standard'. + resample_filter : List[int], optional + Resampling filter coefficients: [1,1] for DDPM++, [1,3,3,1] for NCSN++. By default [1,1]. + checkpoint_level : int, optional + Number of layers that should use gradient checkpointing (0 disables checkpointing). + Higher values trade memory for computation. By default 0. + additive_pos_embed : bool, optional + If True, adds a learned positional embedding after the first convolution layer. + Used in StormCast model. By default False. 
+ use_apex_gn : bool, optional + A boolean flag indicating whether we want to use Apex GroupNorm for NHWC layout. + Need to set this as False on cpu. Defaults to False. + act : str, optional + The activation function to use when fusing activation with GroupNorm. Defaults to None. + profile_mode: + A boolean flag indicating whether to enable all nvtx annotations during profiling. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. + + + Reference + ---------- + Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + + Note + ----- + Equivalent to the original implementation by Song et al., available at + https://github.com/yang-song/score_sde_pytorch + + Example + -------- + >>> model = SongUNet(img_resolution=16, in_channels=2, out_channels=2) + >>> noise_labels = torch.randn([1]) + >>> class_labels = torch.randint(0, 1, (1, 1)) + >>> input_image = torch.ones([1, 2, 16, 16]) + >>> output_image = model(input_image, noise_labels, class_labels) + >>> output_image.shape + torch.Size([1, 2, 16, 16]) + """ + + def __init__( + self, + img_resolution: Union[List[int], int], + in_channels: int, + out_channels: int, + label_dim: int = 0, + augment_dim: int = 0, + model_channels: int = 128, + channel_mult: List[int] = [1, 2, 2, 2], + channel_mult_emb: int = 4, + num_blocks: int = 4, + attn_resolutions: List[int] = [16], + dropout: float = 0.10, + label_dropout: float = 0.0, + embedding_type: str = "positional", + channel_mult_noise: int = 1, + encoder_type: str = "standard", + decoder_type: str = "standard", + resample_filter: List[int] = [1, 1], + checkpoint_level: int = 0, + additive_pos_embed: bool = False, + use_apex_gn: bool = False, + act: str = "silu", + profile_mode: bool = False, + amp_mode: bool = False, + ): + valid_embedding_types = ["fourier", "positional", "zero"] + if embedding_type not in valid_embedding_types: + raise ValueError( + f"Invalid embedding_type: {embedding_type}. Must be one of {valid_embedding_types}." + ) + + valid_encoder_types = ["standard", "skip", "residual"] + if encoder_type not in valid_encoder_types: + raise ValueError( + f"Invalid encoder_type: {encoder_type}. Must be one of {valid_encoder_types}." + ) + + valid_decoder_types = ["standard", "skip"] + if decoder_type not in valid_decoder_types: + raise ValueError( + f"Invalid decoder_type: {decoder_type}. Must be one of {valid_decoder_types}." 
+ ) + + super().__init__() #meta=MetaData() + self.label_dropout = label_dropout + self.embedding_type = embedding_type + emb_channels = model_channels * channel_mult_emb + self.emb_channels = emb_channels + noise_channels = model_channels * channel_mult_noise + init = dict(init_mode="xavier_uniform") + init_zero = dict(init_mode="xavier_uniform", init_weight=1e-5) + init_attn = dict(init_mode="xavier_uniform", init_weight=np.sqrt(0.2)) + block_kwargs = dict( + emb_channels=emb_channels, + num_heads=1, + dropout=dropout, + skip_scale=np.sqrt(0.5), + eps=1e-6, + resample_filter=resample_filter, + resample_proj=True, + adaptive_scale=False, + init=init, + init_zero=init_zero, + init_attn=init_attn, + use_apex_gn=use_apex_gn, + act=act, + fused_conv_bias=True, + profile_mode=profile_mode, + amp_mode=amp_mode, + ) + self.profile_mode = profile_mode + self.amp_mode = amp_mode + + # for compatibility with older versions that took only 1 dimension + self.img_resolution = img_resolution + if isinstance(img_resolution, int): + self.img_shape_y = self.img_shape_x = img_resolution + else: + self.img_shape_y = img_resolution[0] + self.img_shape_x = img_resolution[1] + + # set the threshold for checkpointing based on image resolution + self.checkpoint_threshold = (self.img_shape_y >> checkpoint_level) + 1 + + # Optional additive learned positition embed after the first conv + self.additive_pos_embed = additive_pos_embed + if self.additive_pos_embed: + self.spatial_emb = torch.nn.Parameter( + torch.randn(1, model_channels, self.img_shape_y, self.img_shape_x) + ) + torch.nn.init.trunc_normal_(self.spatial_emb, std=0.02) + + # Mapping. + if self.embedding_type != "zero": + self.map_noise = ( + PositionalEmbedding( + num_channels=noise_channels, endpoint=True, amp_mode=amp_mode + ) + if embedding_type == "positional" + else FourierEmbedding(num_channels=noise_channels, amp_mode=amp_mode) + ) + self.map_label = ( + Linear( + in_features=label_dim, + out_features=noise_channels, + amp_mode=amp_mode, + **init, + ) + if label_dim + else None + ) + self.map_augment = ( + Linear( + in_features=augment_dim, + out_features=noise_channels, + bias=False, + amp_mode=amp_mode, + **init, + ) + if augment_dim + else None + ) + self.map_layer0 = Linear( + in_features=noise_channels, + out_features=emb_channels, + amp_mode=amp_mode, + **init, + ) + self.map_layer1 = Linear( + in_features=emb_channels, + out_features=emb_channels, + amp_mode=amp_mode, + **init, + ) + + # Encoder. 
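+        # Per level `level` the feature map lives at res = img_shape_y >> level.
+        # Level 0 opens with a 3x3 conv (in_channels -> model_channels); deeper
+        # levels open with a downsampling UNetBlock. Each level then stacks
+        # `num_blocks` residual blocks with cout = model_channels * mult, using
+        # self-attention whenever `res` appears in `attn_resolutions`. The 'skip'
+        # and 'residual' encoder types also carry an auxiliary path of the
+        # progressively downsampled input image (`caux` channels).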
+ self.enc = torch.nn.ModuleDict() + cout = in_channels + caux = in_channels + for level, mult in enumerate(channel_mult): + res = self.img_shape_y >> level + if level == 0: + cin = cout + cout = model_channels + self.enc[f"{res}x{res}_conv"] = Conv2d( + in_channels=cin, + out_channels=cout, + kernel=3, + fused_conv_bias=True, + amp_mode=amp_mode, + **init, + ) + else: + self.enc[f"{res}x{res}_down"] = UNetBlock( + in_channels=cout, out_channels=cout, down=True, **block_kwargs + ) + if encoder_type == "skip": + self.enc[f"{res}x{res}_aux_down"] = Conv2d( + in_channels=caux, + out_channels=caux, + kernel=0, + down=True, + resample_filter=resample_filter, + amp_mode=amp_mode, + ) + self.enc[f"{res}x{res}_aux_skip"] = Conv2d( + in_channels=caux, + out_channels=cout, + kernel=1, + fused_conv_bias=True, + amp_mode=amp_mode, + **init, + ) + if encoder_type == "residual": + self.enc[f"{res}x{res}_aux_residual"] = Conv2d( + in_channels=caux, + out_channels=cout, + kernel=3, + down=True, + resample_filter=resample_filter, + fused_resample=True, + fused_conv_bias=True, + amp_mode=amp_mode, + **init, + ) + caux = cout + for idx in range(num_blocks): + cin = cout + cout = model_channels * mult + attn = res in attn_resolutions + self.enc[f"{res}x{res}_block{idx}"] = UNetBlock( + in_channels=cin, out_channels=cout, attention=attn, **block_kwargs + ) + skips = [ + block.out_channels for name, block in self.enc.items() if "aux" not in name + ] + + # Decoder. + self.dec = torch.nn.ModuleDict() + for level, mult in reversed(list(enumerate(channel_mult))): + res = self.img_shape_y >> level + if level == len(channel_mult) - 1: + self.dec[f"{res}x{res}_in0"] = UNetBlock( + in_channels=cout, out_channels=cout, attention=True, **block_kwargs + ) + self.dec[f"{res}x{res}_in1"] = UNetBlock( + in_channels=cout, out_channels=cout, **block_kwargs + ) + else: + self.dec[f"{res}x{res}_up"] = UNetBlock( + in_channels=cout, out_channels=cout, up=True, **block_kwargs + ) + for idx in range(num_blocks + 1): + cin = cout + skips.pop() + cout = model_channels * mult + attn = idx == num_blocks and res in attn_resolutions + self.dec[f"{res}x{res}_block{idx}"] = UNetBlock( + in_channels=cin, out_channels=cout, attention=attn, **block_kwargs + ) + if decoder_type == "skip" or level == 0: + if decoder_type == "skip" and level < len(channel_mult) - 1: + self.dec[f"{res}x{res}_aux_up"] = Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel=0, + up=True, + resample_filter=resample_filter, + amp_mode=amp_mode, + ) + self.dec[f"{res}x{res}_aux_norm"] = GroupNorm( + num_channels=cout, + eps=1e-6, + use_apex_gn=use_apex_gn, + amp_mode=amp_mode, + ) + self.dec[f"{res}x{res}_aux_conv"] = Conv2d( + in_channels=cout, + out_channels=out_channels, + kernel=3, + fused_conv_bias=True, + amp_mode=amp_mode, + **init_zero, + ) + + def forward(self, x, noise_labels, class_labels, augment_labels=None): + with nvtx.annotate( + message="SongUNet", color="blue" + ) if self.profile_mode else contextlib.nullcontext(): + if self.embedding_type != "zero": + # Mapping. 
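+                # noise_labels -> positional/Fourier embedding (sin/cos halves
+                # swapped below), plus optional class-label and augmentation
+                # embeddings, then two Linear+SiLU layers up to emb_channels.
+                # With embedding_type == "zero" a zero embedding of the same
+                # width is used instead (see the `else` branch).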
+ emb = self.map_noise(noise_labels) + emb = ( + emb.reshape(emb.shape[0], 2, -1).flip(1).reshape(*emb.shape) + ) # swap sin/cos + if self.map_label is not None: + tmp = class_labels + if self.training and self.label_dropout: + tmp = tmp * ( + torch.rand([x.shape[0], 1], device=x.device) + >= self.label_dropout + ).to(tmp.dtype) + emb = emb + self.map_label( + tmp * np.sqrt(self.map_label.in_features) + ) + if self.map_augment is not None and augment_labels is not None: + emb = emb + self.map_augment(augment_labels) + emb = silu(self.map_layer0(emb)) + emb = silu(self.map_layer1(emb)) + else: + emb = torch.zeros( + (noise_labels.shape[0], self.emb_channels), device=x.device + ) + + # Encoder. + skips = [] + aux = x + for name, block in self.enc.items(): + with nvtx.annotate( + f"SongUNet encoder: {name}", color="blue" + ) if self.profile_mode else contextlib.nullcontext(): + if "aux_down" in name: + aux = block(aux) + elif "aux_skip" in name: + x = skips[-1] = x + block(aux) + elif "aux_residual" in name: + x = skips[-1] = aux = (x + block(aux)) / np.sqrt(2) + elif "_conv" in name: + x = block(x) + if self.additive_pos_embed: + x = x + self.spatial_emb.to(dtype=x.dtype) + skips.append(x) + else: + # For UNetBlocks check if we should use gradient checkpointing + if isinstance(block, UNetBlock): + if x.shape[-1] > self.checkpoint_threshold: + # self.checkpoint = checkpoint? + # else: self.checkpoint = lambda(block,x,emb:block(x,emb)) + x = checkpoint(block, x, emb, use_reentrant=False) + else: + # AssertionError: Only support NHWC layout. + x = block(x, emb) + else: + x = block(x) + skips.append(x) + + # Decoder. + aux = None + tmp = None + for name, block in self.dec.items(): + with nvtx.annotate( + f"SongUNet decoder: {name}", color="blue" + ) if self.profile_mode else contextlib.nullcontext(): + if "aux_up" in name: + aux = block(aux) + elif "aux_norm" in name: + tmp = block(x) + elif "aux_conv" in name: + tmp = block(silu(tmp)) + aux = tmp if aux is None else tmp + aux + else: + if x.shape[1] != block.in_channels: + x = torch.cat([x, skips.pop()], dim=1) + # check for checkpointing on decoder blocks and up sampling blocks + if ( + x.shape[-1] > self.checkpoint_threshold and "_block" in name + ) or ( + x.shape[-1] > (self.checkpoint_threshold / 2) + and "_up" in name + ): + x = checkpoint(block, x, emb, use_reentrant=False) + else: + x = block(x, emb) + return aux + + +class SongUNetPosEmbd(SongUNet): + """Extends SongUNet with positional embeddings. + + This model supports conditional and unconditional setups, as well as several + options for various internal architectural choices such as encoder and decoder + type, embedding type, etc., making it flexible and adaptable to different tasks + and configurations. + + This model adds positional embeddings to the base SongUNet architecture. The embeddings + can be selected using either a selector function or global indices, with the selector + approach being more computationally efficient. + + The model provides two methods for selecting positional embeddings: + + 1. Using a selector function (preferred method). See + :meth:`positional_embedding_selector` for details. + 2. Using global indices. See :meth:`positional_embedding_indexing` for + details. + + Parameters + ----------- + img_resolution : Union[List[int], int] + The resolution of the input/output image. Can be a single int for square images + or a list [height, width] for rectangular images. + in_channels : int + Number of channels in the input image. 
+ out_channels : int + Number of channels in the output image. + label_dim : int, optional + Number of class labels; 0 indicates an unconditional model. By default 0. + augment_dim : int, optional + Dimensionality of augmentation labels; 0 means no augmentation. By default 0. + model_channels : int, optional + Base multiplier for the number of channels across the network. By default 128. + channel_mult : List[int], optional + Per-resolution multipliers for the number of channels. By default [1,2,2,2,2]. + channel_mult_emb : int, optional + Multiplier for the dimensionality of the embedding vector. By default 4. + num_blocks : int, optional + Number of residual blocks per resolution. By default 4. + attn_resolutions : List[int], optional + Resolutions at which self-attention layers are applied. By default [28]. + dropout : float, optional + Dropout probability applied to intermediate activations. By default 0.13. + label_dropout : float, optional + Dropout probability of class labels for classifier-free guidance. By default 0.0. + embedding_type : str, optional + Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++. + By default 'positional'. + channel_mult_noise : int, optional + Timestep embedding size: 1 for DDPM++, 2 for NCSN++. By default 1. + encoder_type : str, optional + Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++. , 'skip' for skip connections. + By default'standard'. + decoder_type : str, optional + Decoder architecture: 'standard' or 'skip' for skip connections. By default 'standard'. + resample_filter : List[int], optional + Resampling filter coefficients: [1,1] for DDPM++, [1,3,3,1] for NCSN++. By default [1,1]. + gridtype : str, optional + Type of positional grid to use: 'sinusoidal', 'learnable', 'linear', or 'test'. + Controls how positional information is encoded. By default 'sinusoidal'. + N_grid_channels : int, optional + Number of channels in the positional embedding grid. For 'sinusoidal' must be 4 or + multiple of 4. For 'linear' must be 2. By default 4. + checkpoint_level : int, optional + Number of layers that should use gradient checkpointing (0 disables checkpointing). + Higher values trade memory for computation. By default 0. + additive_pos_embed : bool, optional + If True, adds a learned positional embedding after the first convolution layer. + Used in StormCast model. By default False. + use_apex_gn : bool, optional + A boolean flag indicating whether we want to use Apex GroupNorm for NHWC layout. + Need to set this as False on cpu. Defaults to False. + act : str, optional + The activation function to use when fusing activation with GroupNorm. Defaults to None. + profile_mode: + A boolean flag indicating whether to enable all nvtx annotations during profiling. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. + lead_time_mode : bool, optional + A boolean flag indicating whether we are running SongUNet with lead time embedding. Defaults to False. + lead_time_channels : int, optional + Number of channels in the lead time embedding. These are learned embeddings that + encode temporal forecast information. By default None. + lead_time_steps : int, optional + Number of discrete lead time steps to support. Each step gets its own learned + embedding vector. By default 9. + prob_channels : List[int], optional + Indices of probability output channels that should use softmax activation. + Used for classification outputs. By default empty list. 
+ + Note + ----- + Equivalent to the original implementation by Song et al., available at + https://github.com/yang-song/score_sde_pytorch + + Example + -------- + >>> import torch + >>> from physicsnemo.models.diffusion.song_unet import SongUNetPosEmbd + >>> from physicsnemo.utils.patching import GridPatching2D + >>> + >>> # Model initialization - in_channels must include both original input channels (2) + >>> # and the positional embedding channels (N_grid_channels=4 by default) + >>> model = SongUNetPosEmbd(img_resolution=16, in_channels=2+4, out_channels=2) + >>> noise_labels = torch.randn([1]) + >>> class_labels = torch.randint(0, 1, (1, 1)) + >>> # The input has only the original 2 channels - positional embeddings are + >>> # added automatically inside the forward method + >>> input_image = torch.ones([1, 2, 16, 16]) + >>> output_image = model(input_image, noise_labels, class_labels) + >>> output_image.shape + torch.Size([1, 2, 16, 16]) + >>> + >>> # Using a global index to select all positional embeddings + >>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(16, 16)) + >>> global_index = patching.global_index(batch_size=1) + >>> output_image = model( + ... input_image, noise_labels, class_labels, + ... global_index=global_index + ... ) + >>> output_image.shape + torch.Size([1, 2, 16, 16]) + >>> + >>> # Using a custom embedding selector to select all positional embeddings + >>> def patch_embedding_selector(emb): + ... return patching.apply(emb[None].expand(1, -1, -1, -1)) + >>> output_image = model( + ... input_image, noise_labels, class_labels, + ... embedding_selector=patch_embedding_selector + ... ) + >>> output_image.shape + torch.Size([1, 2, 16, 16]) + """ + + def __init__( + self, + img_resolution: Union[List[int], int], + in_channels: int, + out_channels: int, + label_dim: int = 0, + augment_dim: int = 0, + model_channels: int = 128, + channel_mult: List[int] = [1, 2, 2, 2, 2], + channel_mult_emb: int = 4, + num_blocks: int = 4, + attn_resolutions: List[int] = [28], + dropout: float = 0.13, + label_dropout: float = 0.0, + embedding_type: str = "positional", + channel_mult_noise: int = 1, + encoder_type: str = "standard", + decoder_type: str = "standard", + resample_filter: List[int] = [1, 1], + gridtype: str = "sinusoidal", + N_grid_channels: int = 4, + checkpoint_level: int = 0, + additive_pos_embed: bool = False, + use_apex_gn: bool = False, + act: str = "silu", + profile_mode: bool = False, + amp_mode: bool = False, + lead_time_mode: bool = False, + lead_time_channels: int = None, + lead_time_steps: int = 9, + prob_channels: List[int] = [], + ): + super().__init__( + img_resolution, + in_channels, + out_channels, + label_dim, + augment_dim, + model_channels, + channel_mult, + channel_mult_emb, + num_blocks, + attn_resolutions, + dropout, + label_dropout, + embedding_type, + channel_mult_noise, + encoder_type, + decoder_type, + resample_filter, + checkpoint_level, + additive_pos_embed, + use_apex_gn, + act, + profile_mode, + amp_mode, + ) + + self.gridtype = gridtype + self.N_grid_channels = N_grid_channels + if self.gridtype == "learnable": + self.pos_embd = self._get_positional_embedding() + else: + self.register_buffer("pos_embd", self._get_positional_embedding().float()) + self.lead_time_mode = lead_time_mode + if self.lead_time_mode: + self.lead_time_channels = lead_time_channels + self.lead_time_steps = lead_time_steps + self.lt_embd = self._get_lead_time_embedding() + self.prob_channels = prob_channels + if self.prob_channels: + self.scalar = 
torch.nn.Parameter( + torch.ones((1, len(self.prob_channels), 1, 1)) + ) + + def forward( + self, + x, + noise_labels, + class_labels, + global_index: Optional[torch.Tensor] = None, + embedding_selector: Optional[Callable] = None, + augment_labels=None, + lead_time_label=None, + ): + with nvtx.annotate( + message="SongUNetPosEmbd", color="blue" + ) if self.profile_mode else contextlib.nullcontext(): + if embedding_selector is not None and global_index is not None: + raise ValueError( + "Cannot provide both embedding_selector and global_index. " + "embedding_selector is the preferred approach for better efficiency." + ) + + if x.dtype != self.pos_embd.dtype: + self.pos_embd = self.pos_embd.to(x.dtype) + + # Append positional embedding to input conditioning + if self.pos_embd is not None: + # Select positional embeddings with a selector function + if embedding_selector is not None: + selected_pos_embd = self.positional_embedding_selector( + x, embedding_selector + ) + # Select positional embeddings using global indices (selects all + # embeddings if global_index is None) + else: + selected_pos_embd = self.positional_embedding_indexing( + x, global_index=global_index, lead_time_label=lead_time_label + ) + x = torch.cat((x, selected_pos_embd), dim=1) + + out = super().forward(x, noise_labels, class_labels, augment_labels) + + if self.lead_time_mode: + # if training mode, let crossEntropyLoss do softmax. The model outputs logits. + # if eval mode, the model outputs probability + all_channels = list(range(out.shape[1])) # [0, 1, 2, ..., 10] + scalar_channels = [ + item for item in all_channels if item not in self.prob_channels + ] + if self.prob_channels and (not self.training): + out_final = torch.cat( + ( + out[:, scalar_channels], + (out[:, self.prob_channels] * self.scalar).softmax(dim=1), + ), + dim=1, + ) + elif self.prob_channels and self.training: + out_final = torch.cat( + ( + out[:, scalar_channels], + (out[:, self.prob_channels] * self.scalar), + ), + dim=1, + ) + else: + out_final = out + return out_final + + return out + + def positional_embedding_indexing( + self, + x: torch.Tensor, + global_index: Optional[torch.Tensor] = None, + lead_time_label=None, + ) -> torch.Tensor: + """Select positional embeddings using global indices. + + This method either uses global indices to select specific embeddings or expands + the embeddings for the full input when no indices are provided. + + Typically used in patch-based training, where the batch dimension + contains multiple patches extracted from a larger image. + + Arguments + --------- + x : torch.Tensor + Input tensor of shape (B, C, H, W), used to determine batch size + and device. + global_index : Optional[torch.Tensor] + Optional tensor of indices for selecting embeddings. These should + correspond to the spatial indices of the batch elements in the + input tensor x. When provided, should have shape (P, 2, H, W) where + the second dimension contains y,x coordinates (indices of the + positional embedding grid). + + Returns + ------- + torch.Tensor + Selected positional embeddings with shape: + - If global_index provided: (B, N_pe, H, W) + - If global_index is None: (B, N_pe, H_pe, W_pe) + where N_pe is the number of positional embedding channels, and H_pe + and W_pe are the height and width of the positional embedding grid. 
+ + Example + ------- + >>> # Create global indices using patching utility: + >>> from physicsnemo.utils.patching import GridPatching2D + >>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(8, 8)) + >>> global_index = patching.global_index(batch_size=3) + >>> print(global_index.shape) + torch.Size([4, 2, 8, 8]) + + See Also + -------- + :meth:`physicsnemo.utils.patching.RandomPatching2D.global_index` + For generating random patch indices. + :meth:`physicsnemo.utils.patching.GridPatching2D.global_index` + For generating deterministic grid-based patch indices. + See these methods for possible ways to generate the global_index parameter. + """ + # If no global indices are provided, select all embeddings and expand + # to match the batch size of the input + if x.dtype != self.pos_embd.dtype: + self.pos_embd = self.pos_embd.to(x.dtype) + + if global_index is None: + if self.lead_time_mode: + selected_pos_embd = [] + if self.pos_embd is not None: + selected_pos_embd.append( + self.pos_embd[None].expand((x.shape[0], -1, -1, -1)) + ) + if self.lt_embd is not None: + selected_pos_embd.append( + torch.reshape( + self.lt_embd[lead_time_label.int()], + ( + x.shape[0], + self.lead_time_channels, + self.img_shape_y, + self.img_shape_x, + ), + ) + ) + if len(selected_pos_embd) > 0: + selected_pos_embd = torch.cat(selected_pos_embd, dim=1) + else: + selected_pos_embd = self.pos_embd[None].expand( + (x.shape[0], -1, -1, -1) + ) # (B, N_pe, H, W) + + else: + P = global_index.shape[0] + B = x.shape[0] // P + H = global_index.shape[2] + W = global_index.shape[3] + + global_index = torch.reshape( + torch.permute(global_index, (1, 0, 2, 3)), (2, -1) + ) # (P, 2, X, Y) to (2, P*X*Y) + selected_pos_embd = self.pos_embd[ + :, global_index[0], global_index[1] + ] # (N_pe, P*X*Y) + selected_pos_embd = torch.permute( + torch.reshape(selected_pos_embd, (self.pos_embd.shape[0], P, H, W)), + (1, 0, 2, 3), + ) # (P, N_pe, X, Y) + + selected_pos_embd = selected_pos_embd.repeat( + B, 1, 1, 1 + ) # (B*P, N_pe, X, Y) + + # Append positional and lead time embeddings to input conditioning + if self.lead_time_mode: + embeds = [] + if self.pos_embd is not None: + embeds.append(selected_pos_embd) # reuse code below + if self.lt_embd is not None: + lt_embds = self.lt_embd[ + lead_time_label.int() + ] # (B, self.lead_time_channels, self.img_shape_y, self.img_shape_x), + + selected_lt_pos_embd = lt_embds[ + :, :, global_index[0], global_index[1] + ] # (B, N_lt, P*X*Y) + selected_lt_pos_embd = torch.reshape( + torch.permute( + torch.reshape( + selected_lt_pos_embd, + (B, self.lead_time_channels, P, H, W), + ), + (0, 2, 1, 3, 4), + ).contiguous(), + (B * P, self.lead_time_channels, H, W), + ) # (B*P, N_pe, X, Y) + embeds.append(selected_lt_pos_embd) + + if len(embeds) > 0: + selected_pos_embd = torch.cat(embeds, dim=1) + + return selected_pos_embd + + def positional_embedding_selector( + self, + x: torch.Tensor, + embedding_selector: Callable[[torch.Tensor], torch.Tensor], + ) -> torch.Tensor: + """Select positional embeddings using a selector function. + + Similar to positional_embedding_indexing, but uses a selector function + to select the embeddings. This method provides a more efficient way to + select embeddings for batches of data. + Typically used with patch-based processing, where the batch dimension + contains multiple patches extracted from a larger image. + + Arguments + --------- + x : torch.Tensor + Input tensor of shape (B, C, H, W) only used to determine dtype and + device. 
+ embedding_selector : Callable + Function that takes as input an embedding tensor of shape (N_pe, + H_pe, W_pe) and returns selected embeddings with shape (batch_size, N_pe, H, W). + Each selected embedding should correspond to the positional + information of each batch element in x. + For patch-based processing, typically this should be based on + :meth:`physicsnemo.utils.patching.BasePatching2D.apply` method to + maintain consistency with patch extraction. + embeds : Optional[torch.Tensor] + Optional tensor for combined positional and lead time embeddings tensor + + Returns + ------- + torch.Tensor + Selected positional embeddings with shape (B, N_pe, H, W) + where N_pe is the number of positional embedding channels. + + Example + ------- + >>> # Define a selector function with a patching utility: + >>> from physicsnemo.utils.patching import GridPatching2D + >>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(8, 8)) + >>> batch_size = 4 + >>> def embedding_selector(emb): + ... return patching.apply(emb[None].expand(batch_size, -1, -1, -1)) + >>> + + See Also + -------- + :meth:`physicsnemo.utils.patching.BasePatching2D.apply` + For the base patching method typically used in embedding_selector. + """ + if x.dtype != self.pos_embd.dtype: + self.pos_embd = self.pos_embd.to(x.dtype) + + return embedding_selector(self.pos_embd) # (B, N_pe, H, W) + + def _get_positional_embedding(self): + if self.N_grid_channels == 0: + return None + elif self.gridtype == "learnable": + grid = torch.nn.Parameter( + torch.randn(self.N_grid_channels, self.img_shape_y, self.img_shape_x) + ) # (N_grid_channels, img_shape_y, img_shape_x) + elif self.gridtype == "linear": + if self.N_grid_channels != 2: + raise ValueError("N_grid_channels must be set to 2 for gridtype linear") + x = np.meshgrid(np.linspace(-1, 1, self.img_shape_y)) + y = np.meshgrid(np.linspace(-1, 1, self.img_shape_x)) + grid_x, grid_y = np.meshgrid(y, x) + grid = torch.from_numpy( + np.stack((grid_x, grid_y), axis=0) + ) # (2, img_shape_y, img_shape_x) + grid.requires_grad = False + elif self.gridtype == "sinusoidal" and self.N_grid_channels == 4: + # print('sinusuidal grid added ......') + x1 = np.meshgrid(np.sin(np.linspace(0, 2 * np.pi, self.img_shape_y))) + x2 = np.meshgrid(np.cos(np.linspace(0, 2 * np.pi, self.img_shape_y))) + y1 = np.meshgrid(np.sin(np.linspace(0, 2 * np.pi, self.img_shape_x))) + y2 = np.meshgrid(np.cos(np.linspace(0, 2 * np.pi, self.img_shape_x))) + grid_x1, grid_y1 = np.meshgrid(y1, x1) + grid_x2, grid_y2 = np.meshgrid(y2, x2) + grid = torch.squeeze( + torch.from_numpy( + np.expand_dims( + np.stack((grid_x1, grid_y1, grid_x2, grid_y2), axis=0), axis=0 + ) + ) + ) # (4, img_shape_y, img_shape_x) + grid.requires_grad = False + elif self.gridtype == "sinusoidal" and self.N_grid_channels != 4: + if self.N_grid_channels % 4 != 0: + raise ValueError("N_grid_channels must be a factor of 4") + num_freq = self.N_grid_channels // 4 + freq_bands = 2.0 ** np.linspace(0.0, num_freq, num=num_freq) + grid_list = [] + grid_x, grid_y = np.meshgrid( + np.linspace(0, 2 * np.pi, self.img_shape_x), + np.linspace(0, 2 * np.pi, self.img_shape_y), + ) + for freq in freq_bands: + for p_fn in [np.sin, np.cos]: + grid_list.append(p_fn(grid_x * freq)) + grid_list.append(p_fn(grid_y * freq)) + grid = torch.from_numpy( + np.stack(grid_list, axis=0) + ) # (N_grid_channels, img_shape_y, img_shape_x) + grid.requires_grad = False + elif self.gridtype == "test" and self.N_grid_channels == 2: + idx_x = torch.arange(self.img_shape_y) + idx_y = 
torch.arange(self.img_shape_x) + mesh_x, mesh_y = torch.meshgrid(idx_x, idx_y) + grid = torch.stack((mesh_x, mesh_y), dim=0) # (2, img_shape_y, img_shape_x) + else: + raise ValueError("Gridtype not supported.") + return grid + + def _get_lead_time_embedding(self): + if (self.lead_time_steps is None) or (self.lead_time_channels is None): + return None + grid = torch.nn.Parameter( + torch.randn( + self.lead_time_steps, + self.lead_time_channels, + self.img_shape_y, + self.img_shape_x, + ) + ) # (lead_time_steps, lead_time_channels, img_shape_y, img_shape_x) + return grid + + +class SongUNetPosLtEmbd(SongUNetPosEmbd): + """ + This model is adapted from SongUNetPosEmbd, with the incorporation of lead-time aware + embeddings. The lead-time embedding is activated by setting the + `lead_time_channels` and `lead_time_steps` parameters. + + Like SongUNetPosEmbd, this model provides two methods for selecting positional embeddings: + 1. Using a selector function (preferred method). See + :meth:`positional_embedding_selector` for details. + 2. Using global indices. See :meth:`positional_embedding_indexing` for + details. + + Parameters + ----------- + img_resolution : Union[List[int], int] + The resolution of the input/output image. Can be a single int for square images + or a list [height, width] for rectangular images. + in_channels : int + Number of channels in the input image. + out_channels : int + Number of channels in the output image. + label_dim : int, optional + Number of class labels; 0 indicates an unconditional model. By default 0. + augment_dim : int, optional + Dimensionality of augmentation labels; 0 means no augmentation. By default 0. + model_channels : int, optional + Base multiplier for the number of channels across the network. By default 128. + channel_mult : List[int], optional + Per-resolution multipliers for the number of channels. By default [1,2,2,2,2]. + channel_mult_emb : int, optional + Multiplier for the dimensionality of the embedding vector. By default 4. + num_blocks : int, optional + Number of residual blocks per resolution. By default 4. + attn_resolutions : List[int], optional + Resolutions at which self-attention layers are applied. By default [28]. + dropout : float, optional + Dropout probability applied to intermediate activations. By default 0.13. + label_dropout : float, optional + Dropout probability of class labels for classifier-free guidance. By default 0.0. + embedding_type : str, optional + Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++. + By default 'positional'. + channel_mult_noise : int, optional + Timestep embedding size: 1 for DDPM++, 2 for NCSN++. By default 1. + encoder_type : str, optional + Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++, 'skip' for skip connections. + By default 'standard'. + decoder_type : str, optional + Decoder architecture: 'standard' or 'skip' for skip connections. By default 'standard'. + resample_filter : List[int], optional + Resampling filter coefficients: [1,1] for DDPM++, [1,3,3,1] for NCSN++. By default [1,1]. + gridtype : str, optional + Type of positional grid to use: 'sinusoidal', 'learnable', 'linear', or 'test'. + Controls how positional information is encoded. By default 'sinusoidal'. + N_grid_channels : int, optional + Number of channels in the positional embedding grid. For 'sinusoidal' must be 4 or + multiple of 4. For 'linear' must be 2. By default 4. + lead_time_channels : int, optional + Number of channels in the lead time embedding. 
These are learned embeddings that + encode temporal forecast information. By default None. + lead_time_steps : int, optional + Number of discrete lead time steps to support. Each step gets its own learned + embedding vector. By default 9. + prob_channels : List[int], optional + Indices of probability output channels that should use softmax activation. + Used for classification outputs. By default empty list. + checkpoint_level : int, optional + Number of layers that should use gradient checkpointing (0 disables checkpointing). + Higher values trade memory for computation. By default 0. + additive_pos_embed : bool, optional + If True, adds a learned positional embedding after the first convolution layer. + Used in StormCast model. By default False. + use_apex_gn : bool, optional + A boolean flag indicating whether we want to use Apex GroupNorm for NHWC layout. + Need to set this as False on cpu. Defaults to False. + act : str, optional + The activation function to use when fusing activation with GroupNorm. Defaults to None. + profile_mode: + A boolean flag indicating whether to enable all nvtx annotations during profiling. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. + + + Note + ----- + Equivalent to the original implementation by Song et al., available at + https://github.com/yang-song/score_sde_pytorch + + Example + -------- + >>> import torch + >>> from physicsnemo.models.diffusion.song_unet import SongUNetPosLtEmbd + >>> from physicsnemo.utils.patching import GridPatching2D + >>> + >>> # Model initialization - in_channels must include original input channels (2), + >>> # positional embedding channels (N_grid_channels=4 by default) and + >>> # lead time embedding channels (4) + >>> model = SongUNetPosLtEmbd( + ... img_resolution=16, in_channels=2+4+4, out_channels=2, + ... lead_time_channels=4, lead_time_steps=9 + ... ) + >>> noise_labels = torch.randn([1]) + >>> class_labels = torch.randint(0, 1, (1, 1)) + >>> # The input has only the original 2 channels - positional embeddings and + >>> # lead time embeddings are added automatically inside the forward method + >>> input_image = torch.ones([1, 2, 16, 16]) + >>> lead_time_label = torch.tensor([3]) + >>> output_image = model( + ... input_image, noise_labels, class_labels, + ... lead_time_label=lead_time_label + ... ) + >>> output_image.shape + torch.Size([1, 2, 16, 16]) + >>> + >>> # Using global_index to select all the positional and lead time embeddings + >>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(16, 16)) + >>> global_index = patching.global_index(batch_size=1) + >>> output_image = model( + ... input_image, noise_labels, class_labels, + ... lead_time_label=lead_time_label, + ... global_index=global_index + ... ) + >>> output_image.shape + torch.Size([1, 2, 16, 16]) + + # NOTE: commented out doctest for embedding_selector due to compatibility issue + # >>> + # >>> # Using custom embedding selector to select all the positional and lead time embeddings + # >>> def patch_embedding_selector(emb): + # ... return patching.apply(emb[None].expand(1, -1, -1, -1)) + # >>> output_image = model( + # ... input_image, noise_labels, class_labels, + # ... lead_time_label=lead_time_label, + # ... embedding_selector=patch_embedding_selector + # ... 
) + # >>> output_image.shape + # torch.Size([1, 2, 16, 16]) + + """ + + def __init__( + self, + img_resolution: Union[List[int], int], + in_channels: int, + out_channels: int, + label_dim: int = 0, + augment_dim: int = 0, + model_channels: int = 128, + channel_mult: List[int] = [1, 2, 2, 2, 2], + channel_mult_emb: int = 4, + num_blocks: int = 4, + attn_resolutions: List[int] = [28], + dropout: float = 0.13, + label_dropout: float = 0.0, + embedding_type: str = "positional", + channel_mult_noise: int = 1, + encoder_type: str = "standard", + decoder_type: str = "standard", + resample_filter: List[int] = [1, 1], + gridtype: str = "sinusoidal", + N_grid_channels: int = 4, + lead_time_channels: int = None, + lead_time_steps: int = 9, + prob_channels: List[int] = [], + checkpoint_level: int = 0, + additive_pos_embed: bool = False, + use_apex_gn: bool = False, + act: str = "silu", + profile_mode: bool = False, + amp_mode: bool = False, + ): + super().__init__( + img_resolution, + in_channels, + out_channels, + label_dim, + augment_dim, + model_channels, + channel_mult, + channel_mult_emb, + num_blocks, + attn_resolutions, + dropout, + label_dropout, + embedding_type, + channel_mult_noise, + encoder_type, + decoder_type, + resample_filter, + gridtype, + N_grid_channels, + checkpoint_level, + additive_pos_embed, + use_apex_gn, + act, + profile_mode, + amp_mode, + True, # Note: lead_time_mode=True is enforced here + lead_time_channels, + lead_time_steps, + prob_channels, + ) + + def forward( + self, + x, + noise_labels, + class_labels, + lead_time_label=None, + global_index: Optional[torch.Tensor] = None, + embedding_selector: Optional[Callable] = None, + augment_labels=None, + ): + return super().forward( + x=x, + noise_labels=noise_labels, + class_labels=class_labels, + global_index=global_index, + embedding_selector=embedding_selector, + augment_labels=augment_labels, + lead_time_label=lead_time_label, + ) + + # Nothing else is re-implemented, because everything is already in the parent SongUNetPosEmb \ No newline at end of file diff --git a/src/hirad/models/unet.py b/src/hirad/models/unet.py new file mode 100644 index 0000000..e0a447a --- /dev/null +++ b/src/hirad/models/unet.py @@ -0,0 +1,356 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
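+
+# Illustrative usage sketch (the channel counts below are arbitrary assumptions,
+# not taken from any configuration). The `UNet` wrapper defined in this module
+# expects `img_in_channels` to already include every conditioning channel the
+# underlying network will see, e.g. the positional-embedding channels appended
+# by SongUNetPosEmbd (N_grid_channels=4 by default):
+#
+#     import torch
+#     model = UNet(
+#         img_resolution=(128, 128),
+#         img_in_channels=12 + 4,   # 12 conditioning channels + 4 positional-embedding channels
+#         img_out_channels=4,
+#         model_type="SongUNetPosEmbd",
+#     )
+#     x = torch.zeros(1, 4, 128, 128)        # zero-filled target placeholder
+#     img_lr = torch.randn(1, 12, 128, 128)  # low-resolution conditioning
+#     y = model(x, img_lr)                   # -> torch.Size([1, 4, 128, 128])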
+ +import importlib +from dataclasses import dataclass +from typing import Any, Dict, List, Literal, Tuple, Union + +import torch +import torch.nn as nn + +from .meta import ModelMetaData + +network_module = importlib.import_module("hirad.models") + + +@dataclass +class MetaData(ModelMetaData): + name: str = "UNet" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = True + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class UNet(nn.Module): # TODO a lot of redundancy, need to clean up + """ + U-Net Wrapper for CorrDiff deterministic regression model. + + Parameters + ----------- + img_resolution : Union[int, Tuple[int, int]] + The resolution of the input/output image. If a single int is provided, + then the image is assumed to be square. + img_in_channels : int + Number of channels in the input image. + img_out_channels : int + Number of channels in the output image. + use_fp16: bool, optional + Execute the underlying model at FP16 precision, by default False. + model_type: str, optional + Class name of the underlying model. Must be one of the following: + 'SongUNet', 'SongUNetPosEmbd', 'SongUNetPosLtEmbd', 'DhariwalUNet'. + Defaults to 'SongUNetPosEmbd'. + **model_kwargs : dict + Keyword arguments passed to the underlying model `__init__` method. + + See Also + -------- + For information on model types and their usage: + :class:`~physicsnemo.models.diffusion.SongUNet`: Basic U-Net for diffusion models + :class:`~physicsnemo.models.diffusion.SongUNetPosEmbd`: U-Net with positional embeddings + :class:`~physicsnemo.models.diffusion.SongUNetPosLtEmbd`: U-Net with positional and lead-time embeddings + + Please refer to the documentation of these classes for details on how to call + and use these models directly. + + References + ---------- + Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., + Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. + arXiv preprint arXiv:2309.15214. + """ + + @classmethod + def _backward_compat_arg_mapper( + cls, version: str, args: Dict[str, Any] + ) -> Dict[str, Any]: + """Map arguments from older versions to current version format. 
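+
+        For example, checkpoints produced by version "0.1.0" stored unused
+        ``img_channels`` and ``sigma_min``/``sigma_max``/``sigma_data`` arguments;
+        these are dropped here so the checkpoint can be loaded with the current
+        signature.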
+ + Parameters + ---------- + version : str + Version of the checkpoint being loaded + args : Dict[str, Any] + Arguments dictionary from the checkpoint + + Returns + ------- + Dict[str, Any] + Updated arguments dictionary compatible with current version + """ + # Call parent class method first + args = super()._backward_compat_arg_mapper(version, args) + + if version == "0.1.0": + # In version 0.1.0, img_channels was unused + if "img_channels" in args: + _ = args.pop("img_channels") + + # Sigma parameters are also unused + if "sigma_min" in args: + _ = args.pop("sigma_min") + if "sigma_max" in args: + _ = args.pop("sigma_max") + if "sigma_data" in args: + _ = args.pop("sigma_data") + + return args + + def __init__( + self, + img_resolution: Union[int, Tuple[int, int]], + img_in_channels: int, + img_out_channels: int, + use_fp16: bool = False, + model_type: Literal[ + "SongUNetPosEmbd", "SongUNetPosLtEmbd", "SongUNet", "DhariwalUNet" + ] = "SongUNetPosEmbd", + **model_kwargs: dict, + ): + super().__init__() #meta=MetaData + + # for compatibility with older versions that took only 1 dimension + if isinstance(img_resolution, int): + self.img_shape_x = self.img_shape_y = img_resolution + else: + self.img_shape_y = img_resolution[0] + self.img_shape_x = img_resolution[1] + + self.img_in_channels = img_in_channels + self.img_out_channels = img_out_channels + + self.use_fp16 = use_fp16 + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_in_channels + img_out_channels, + out_channels=img_out_channels, + **model_kwargs, + ) + + def forward( + self, + x: torch.Tensor, + img_lr: torch.Tensor, + force_fp32: bool = False, + **model_kwargs: dict, + ) -> torch.Tensor: + """ + Forward pass of the UNet wrapper model. + + This method concatenates the input tensor with the low-resolution conditioning tensor + and passes the result through the underlying model. + + Parameters + ---------- + x : torch.Tensor + The input tensor, typically zero-filled, of shape (B, C_hr, H, W). + img_lr : torch.Tensor + Low-resolution conditioning image of shape (B, C_lr, H, W). + force_fp32 : bool, optional + Whether to force FP32 precision regardless of the `use_fp16` attribute, + by default False. + **model_kwargs : dict + Additional keyword arguments to pass to the underlying model + `self.model` forward method. + + Returns + ------- + torch.Tensor + Output tensor (prediction) of shape (B, C_hr, H, W). + + Raises + ------ + ValueError + If the model output dtype doesn't match the expected dtype. + """ + + # SR: concatenate input channels + if img_lr is not None: + x = torch.cat((x, img_lr), dim=1) + + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + F_x = self.model( + x.to(dtype), # (c_in * x).to(dtype), + torch.zeros(x.shape[0], dtype=dtype, device=x.device), # c_noise.flatten() + class_labels=None, + **model_kwargs, + ) + + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, " f"but got {F_x.dtype} instead." + ) + + # skip connection + D_x = F_x.to(torch.float32) + return D_x + + def round_sigma(self, sigma: Union[float, List, torch.Tensor]) -> torch.Tensor: + """ + Convert a given sigma value(s) to a tensor representation. + + Parameters + ---------- + sigma : Union[float, List, torch.Tensor] + The sigma value(s) to convert. 
+ + Returns + ------- + torch.Tensor + The tensor representation of the provided sigma value(s). + """ + return torch.as_tensor(sigma) + + @property + def amp_mode(self): + """ + Return the *amp_mode* flag of the underlying model if present. + """ + return getattr(self.model, "amp_mode", None) + + @amp_mode.setter + def amp_mode(self, value: bool): + """ + Update *amp_mode* on the wrapped model and its sub-modules. + """ + if not isinstance(value, bool): + raise TypeError("amp_mode must be a boolean value.") + + if hasattr(self.model, "amp_mode"): + self.model.amp_mode = value + + # Recursively update sub-modules that define *amp_mode*. + for sub_module in self.model.modules(): + if hasattr(sub_module, "amp_mode"): + sub_module.amp_mode = value + +# TODO: implement amp_mode property for StormCastUNet (same as UNet) +class StormCastUNet(nn.Module): + """ + U-Net wrapper for StormCast; used so the same Song U-Net network can be re-used for this model. + + Parameters + ----------- + img_resolution : int or List[int] + The resolution of the input/output image. + img_channels : int + Number of color channels. + img_in_channels : int + Number of input color channels. + img_out_channels : int + Number of output color channels. + use_fp16: bool, optional + Execute the underlying model at FP16 precision?, by default False. + sigma_min: float, optional + Minimum supported noise level, by default 0. + sigma_max: float, optional + Maximum supported noise level, by default float('inf'). + sigma_data: float, optional + Expected standard deviation of the training data, by default 0.5. + model_type: str, optional + Class name of the underlying model, by default 'SongUNet'. + **model_kwargs : dict + Keyword arguments for the underlying model. + + """ + + def __init__( + self, + img_resolution, + img_in_channels, + img_out_channels, + use_fp16=False, + sigma_min=0, + sigma_max=float("inf"), + sigma_data=0.5, + model_type="SongUNet", + **model_kwargs, + ): + super().__init__() #meta=MetaData("StormCastUNet") + + if isinstance(img_resolution, int): + self.img_shape_x = self.img_shape_y = img_resolution + else: + self.img_shape_x = img_resolution[0] + self.img_shape_y = img_resolution[1] + + self.img_in_channels = img_in_channels + self.img_out_channels = img_out_channels + + self.use_fp16 = use_fp16 + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.sigma_data = sigma_data + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_in_channels, + out_channels=img_out_channels, + **model_kwargs, + ) + + def forward(self, x, force_fp32=False, **model_kwargs): + """Run a forward pass of the StormCast regression U-Net. + + Args: + x (torch.Tensor): input to the U-Net + force_fp32 (bool, optional): force casting to fp_32 if True. Defaults to False. + + Raises: + ValueError: If input data type is a mismatch with provided options + + Returns: + D_x (torch.Tensor): Output (prediction) of the U-Net + """ + + x = x.to(torch.float32) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + F_x = self.model( + x.to(dtype), + torch.zeros(x.shape[0], dtype=x.dtype, device=x.device), + class_labels=None, + **model_kwargs, + ) + + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." 
+ ) + + D_x = F_x.to(torch.float32) + return D_x diff --git a/src/hirad/models/utils.py b/src/hirad/models/utils.py new file mode 100644 index 0000000..e1cde9d --- /dev/null +++ b/src/hirad/models/utils.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import torch + + +def weight_init(shape: tuple, mode: str, fan_in: int, fan_out: int): + """ + Unified routine for initializing weights and biases. + This function provides a unified interface for various weight initialization + strategies like Xavier (Glorot) and Kaiming (He) initializations. + + Parameters + ---------- + shape : tuple + The shape of the tensor to initialize. It could represent weights or biases + of a layer in a neural network. + mode : str + The mode/type of initialization to use. Supported values are: + - "xavier_uniform": Xavier (Glorot) uniform initialization. + - "xavier_normal": Xavier (Glorot) normal initialization. + - "kaiming_uniform": Kaiming (He) uniform initialization. + - "kaiming_normal": Kaiming (He) normal initialization. + fan_in : int + The number of input units in the weight tensor. For convolutional layers, + this typically represents the number of input channels times the kernel height + times the kernel width. + fan_out : int + The number of output units in the weight tensor. For convolutional layers, + this typically represents the number of output channels times the kernel height + times the kernel width. + + Returns + ------- + torch.Tensor + The initialized tensor based on the specified mode. + + Raises + ------ + ValueError + If the provided `mode` is not one of the supported initialization modes. 
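+
+    Example
+    -------
+    An illustrative sketch; the layer shape and fan values below are arbitrary
+    and not taken from any particular network:
+
+    >>> w = weight_init((64, 32, 3, 3), "kaiming_normal", fan_in=32 * 3 * 3, fan_out=64 * 3 * 3)
+    >>> w.shape
+    torch.Size([64, 32, 3, 3])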
+ """ + if mode == "xavier_uniform": + return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1) + if mode == "xavier_normal": + return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape) + if mode == "kaiming_uniform": + return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1) + if mode == "kaiming_normal": + return np.sqrt(1 / fan_in) * torch.randn(*shape) + raise ValueError(f'Invalid init mode "{mode}"') diff --git a/src/hirad/train_diffusion.sh b/src/hirad/train_diffusion.sh new file mode 100644 index 0000000..cf2f88f --- /dev/null +++ b/src/hirad/train_diffusion.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +#SBATCH --job-name="testrun" + +### HARDWARE ### +#SBATCH --partition=debug +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=1 +#SBATCH --cpus-per-task=72 +#SBATCH --time=00:30:00 +#SBATCH --no-requeue +#SBATCH --exclusive + +### OUTPUT ### +#SBATCH --output=/iopsstor/scratch/cscs/pstamenk/logs/diffusion.log +#SBATCH --error=/iopsstor/scratch/cscs/pstamenk/logs/diffusion.err + +### ENVIRONMENT #### +#SBATCH --uenv=pytorch/v2.6.0:/user-environment +#SBATCH --view=default +#SBATCH -A a-a122 + +# Choose method to initialize dist in pythorch +export DISTRIBUTED_INITIALIZATION_METHOD=SLURM + +# Get master node. +MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +# Get IP for hostname. +MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" +export MASTER_ADDR +export MASTER_PORT=29500 + +# Get number of physical cores using Python +PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# Compute cores per process +OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +export OMP_NUM_THREADS=$OMP_THREADS + +# python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml +srun bash -c " + . ./train_env/bin/activate + python src/hirad/training/train.py --config-name=training_era_cosmo_diffusion.yaml +" \ No newline at end of file diff --git a/src/hirad/train_regression.sh b/src/hirad/train_regression.sh new file mode 100644 index 0000000..c065477 --- /dev/null +++ b/src/hirad/train_regression.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +#SBATCH --job-name="testrun" + +### HARDWARE ### +#SBATCH --partition=debug +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=1 +#SBATCH --cpus-per-task=72 +#SBATCH --time=00:30:00 +#SBATCH --no-requeue +#SBATCH --exclusive + +### OUTPUT ### +#SBATCH --output=/iopsstor/scratch/cscs/pstamenk/logs/regression.log +#SBATCH --error=/iopsstor/scratch/cscs/pstamenk/logs/regression.err + +### ENVIRONMENT #### +#SBATCH --uenv=pytorch/v2.6.0:/user-environment +#SBATCH --view=default +#SBATCH -A a-a122 + +# Choose method to initialize dist in pythorch +export DISTRIBUTED_INITIALIZATION_METHOD=SLURM + +# Get master node. +MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +# Get IP for hostname. +MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" +export MASTER_ADDR +export MASTER_PORT=29500 + +# Get number of physical cores using Python +PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# Compute cores per process +OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +export OMP_NUM_THREADS=$OMP_THREADS + +# python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml +srun bash -c " + . 
./train_env/bin/activate + python src/hirad/training/train.py --config-name=training_era_cosmo_regression.yaml +" \ No newline at end of file diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py new file mode 100755 index 0000000..3a2fe2e --- /dev/null +++ b/src/hirad/training/train.py @@ -0,0 +1,716 @@ +import os +import time + +import psutil +import hydra +from omegaconf import DictConfig, OmegaConf +import json +from contextlib import nullcontext +import nvtx +import torch +from hydra.utils import to_absolute_path +from torch.utils.tensorboard import SummaryWriter +from torch.nn.parallel import DistributedDataParallel +from torchinfo import summary + +from hirad.distributed import DistributedManager +from hirad.utils.console import PythonLogger, RankZeroLoggingWrapper +from hirad.utils.train_helpers import set_seed, configure_cuda_for_consistent_precision, \ + set_patch_shape, compute_num_accumulation_rounds, \ + is_time_for_periodic_task, handle_and_clip_gradients +from hirad.utils.checkpoint import load_checkpoint, save_checkpoint +from hirad.utils.patching import RandomPatching2D +from hirad.models import UNet, EDMPrecondSuperResolution, EDMPrecondSR +from hirad.losses import ResidualLoss, RegressionLoss, RegressionLossCE +from hirad.datasets import init_train_valid_datasets_from_config + +from matplotlib import pyplot as plt + +torch._dynamo.reset() +# Increase the cache size limit +torch._dynamo.config.cache_size_limit = 264 # Set to a higher value +torch._dynamo.config.verbose = True # Enable verbose logging +torch._dynamo.config.suppress_errors = False # Forces the error to show all details +torch._logging.set_logs(recompiles=True, graph_breaks=True) + +# Define safe CUDA profiler tools that fallback to no-ops when CUDA is not available +def cuda_profiler(): + if torch.cuda.is_available(): + return torch.cuda.profiler.profile() + else: + return nullcontext() + + +def cuda_profiler_start(): + if torch.cuda.is_available(): + torch.cuda.profiler.start() + + +def cuda_profiler_stop(): + if torch.cuda.is_available(): + torch.cuda.profiler.stop() + + +def profiler_emit_nvtx(): + if torch.cuda.is_available(): + return torch.autograd.profiler.emit_nvtx() + else: + return nullcontext() + +@hydra.main(version_base=None, config_path="../conf", config_name="training") +def main(cfg: DictConfig) -> None: + # Initialize distributed environment for training + DistributedManager.initialize() + dist = DistributedManager() + + if dist.rank==0: + writer = SummaryWriter(log_dir='tensorboard') + logger = PythonLogger("main") # general logger + logger0 = RankZeroLoggingWrapper(logger, dist) # rank 0 logger + + OmegaConf.resolve(cfg) + dataset_cfg = OmegaConf.to_container(cfg.dataset) + if hasattr(cfg.dataset, "validation_path"): + train_test_split = True + else: + train_test_split = False + fp_optimizations = cfg.training.perf.fp_optimizations + songunet_checkpoint_level = cfg.training.perf.songunet_checkpoint_level + fp16 = fp_optimizations == "fp16" + enable_amp = fp_optimizations.startswith("amp") + amp_dtype = torch.float16 if (fp_optimizations == "amp-fp16") else torch.bfloat16 + logger0.info(f"Saving the outputs in {os.getcwd()}") + checkpoint_dir = os.path.join( + cfg.training.io.get("checkpoint_dir", "."), f"checkpoints_{cfg.model.name}" + ) + if dist.rank==0 and not os.path.exists(checkpoint_dir): + os.makedirs(checkpoint_dir) # added creating checkpoint dir + if cfg.training.hp.batch_size_per_gpu == "auto": + cfg.training.hp.batch_size_per_gpu = ( + 
cfg.training.hp.total_batch_size // dist.world_size + ) + + set_seed(dist.rank) + configure_cuda_for_consistent_precision() + + # Instantiate the dataset + data_loader_kwargs = { + "pin_memory": True, + "num_workers": cfg.training.perf.dataloader_workers, + "prefetch_factor": 2 if cfg.training.perf.dataloader_workers > 0 else None, + } + ( + dataset, + dataset_iterator, + validation_dataset, + validation_dataset_iterator, + ) = init_train_valid_datasets_from_config( + dataset_cfg, + data_loader_kwargs, + batch_size=cfg.training.hp.batch_size_per_gpu, + seed=0, + train_test_split=train_test_split, + ) + logger0.info(f"Training on dataset with size {len(dataset)}") + + # Parse image configuration & update model args + dataset_channels = len(dataset.input_channels()) + img_in_channels = dataset_channels + img_shape = dataset.image_shape() + img_out_channels = len(dataset.output_channels()) + if cfg.model.hr_mean_conditioning: + img_in_channels += img_out_channels + + + if cfg.model.name == "lt_aware_ce_regression": + prob_channels = dataset.get_prob_channel_index() #TODO figure out what prob_channel are and update dataloader + else: + prob_channels = None + + # Parse the patch shape + #TODO figure out patched diffusion and how to use it + if ( + cfg.model.name == "patched_diffusion" + or cfg.model.name == "lt_aware_patched_diffusion" + ): + patch_shape_x = cfg.training.hp.patch_shape_x + patch_shape_y = cfg.training.hp.patch_shape_y + else: + patch_shape_x = None + patch_shape_y = None + if ( + patch_shape_x + and patch_shape_y + and patch_shape_y >= img_shape[0] + and patch_shape_x >= img_shape[1] + ): + logger0.warning( + f"Patch shape {patch_shape_y}x{patch_shape_x} is larger than \ + the image shape {img_shape[0]}x{img_shape[1]}. Patching will not be used." + ) + patch_shape = (patch_shape_y, patch_shape_x) + use_patching, img_shape, patch_shape = set_patch_shape(img_shape, patch_shape) + if use_patching: + # Utility to perform patches extraction and batching + patching = RandomPatching2D( + img_shape=img_shape, + patch_shape=patch_shape, + patch_num=getattr(cfg.training.hp, "patch_num", 1), + ) + logger0.info("Patch-based training enabled") + else: + patching = None + logger0.info("Patch-based training disabled") + # interpolate global channel if patch-based model is used + if use_patching: + img_in_channels += dataset_channels + + # Instantiate the model and move to device. 
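+    # The `img_in_channels` passed to the wrappers below is augmented with the
+    # positional-embedding channels (model_args["N_grid_channels"]) and, for the
+    # lead-time-aware variants, with `lead_time_channels`, because those embeddings
+    # are concatenated to the conditioning inside the network's forward pass.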
+ model_args = { # default parameters for all networks + "img_out_channels": img_out_channels, + "img_resolution": list(img_shape), + "use_fp16": fp16, + "checkpoint_level": songunet_checkpoint_level, + } + if cfg.model.name == "lt_aware_ce_regression": + model_args["prob_channels"] = prob_channels + + if hasattr(cfg.model, "model_args"): # override defaults from config file + model_args.update(OmegaConf.to_container(cfg.model.model_args)) + + use_torch_compile = False + use_apex_gn = False + profile_mode = False + + if hasattr(cfg.training.perf, "torch_compile"): + use_torch_compile = cfg.training.perf.torch_compile + if hasattr(cfg.training.perf, "use_apex_gn"): + use_apex_gn = cfg.training.perf.use_apex_gn + model_args["use_apex_gn"] = use_apex_gn + + if hasattr(cfg.training.perf, "profile_mode"): + profile_mode = cfg.training.perf.profile_mode + model_args["profile_mode"] = profile_mode + + if enable_amp: + model_args["amp_mode"] = enable_amp + + + if cfg.model.name == "regression": + model = UNet( + img_in_channels=img_in_channels + model_args["N_grid_channels"], + **model_args, + ) + model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] + elif cfg.model.name == "lt_aware_ce_regression": + model = UNet( + img_in_channels=img_in_channels + + model_args["N_grid_channels"] + + model_args["lead_time_channels"], + **model_args, + ) + model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"] + elif cfg.model.name == "lt_aware_patched_diffusion": + model = EDMPrecondSuperResolution( + img_in_channels=img_in_channels + + model_args["N_grid_channels"] + + model_args["lead_time_channels"], + **model_args, + ) + model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"] + else: # diffusion or patched diffusion + model = EDMPrecondSuperResolution( + img_in_channels=img_in_channels + model_args["N_grid_channels"], + **model_args, + ) + model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] + + model.train().requires_grad_(True).to(dist.device) + + if dist.rank==0 and not os.path.exists(os.path.join(checkpoint_dir, 'model_args.json')): + with open(os.path.join(checkpoint_dir, f'model_args.json'), 'w') as f: + json.dump(model_args, f) + + if use_apex_gn: + model.to(memory_format=torch.channels_last) + + # Check if regression model is used with patching + if ( + cfg.model.name in ["regression", "lt_aware_ce_regression"] + and patching is not None + ): + raise ValueError( + f"Regression model ({cfg.model.name}) cannot be used with patch-based training. 
" + ) + + # Enable distributed data parallel if applicable + if dist.world_size > 1: + model = DistributedDataParallel( + model, + device_ids=[dist.local_rank], + broadcast_buffers=True, + output_device=dist.device, + find_unused_parameters=True, # dist.find_unused_parameters, + bucket_cap_mb=35, + gradient_as_bucket_view=True, + ) + + # Load the regression checkpoint if applicable #TODO test when training correction + if hasattr(cfg.training.io, "regression_checkpoint_path"): + regression_checkpoint_path = to_absolute_path( + cfg.training.io.regression_checkpoint_path + ) + if not os.path.isdir(regression_checkpoint_path): + raise FileNotFoundError( + f"Expected this regression checkpoint but not found: {regression_checkpoint_path}" + ) + #regression_net = torch.nn.Module() #TODO Module.from_checkpoint(regression_checkpoint_path) figure out how to save and load models, also, some basic functions like num_params, device + #TODO make regression model loading more robust (model type is both in rergession_checkpoint_path and regression_name) + #TODO add the option to choose epoch to load from / regression_checkpoint_path is now a folder + regression_model_args_path = os.path.join(regression_checkpoint_path, 'model_args.json') + if not os.path.isfile(regression_model_args_path): + raise FileNotFoundError(f"Missing config file at '{regression_model_args_path}'.") + + with open(regression_model_args_path, 'r') as f: + regression_model_args = json.load(f) + + regression_model_args.update({ + "use_apex_gn": use_apex_gn, + "profile_mode": profile_mode, + "amp_mode": enable_amp, + }) + + regression_net = UNet(**regression_model_args) + + _ = load_checkpoint( + path=regression_checkpoint_path, + model=regression_net, + device=dist.device + ) + regression_net.eval().requires_grad_(False).to(dist.device) + if use_apex_gn: + regression_net.to(memory_format=torch.channels_last) + logger0.success("Loaded the pre-trained regression model") + else: + regression_net = None + + # Compile the model and regression net if applicable + if use_torch_compile: + model = torch.compile(model) + if regression_net: + regression_net = torch.compile(regression_net) + + + # Compute the number of required gradient accumulation rounds + # It is automatically used if batch_size_per_gpu * dist.world_size < total_batch_size + batch_gpu_total, num_accumulation_rounds = compute_num_accumulation_rounds( + cfg.training.hp.total_batch_size, + cfg.training.hp.batch_size_per_gpu, + dist.world_size, + ) + batch_size_per_gpu = cfg.training.hp.batch_size_per_gpu + logger0.info(f"Using {num_accumulation_rounds} gradient accumulation rounds") + + patch_num = getattr(cfg.training.hp, "patch_num", 1) + max_patch_per_gpu = getattr(cfg.training.hp, "max_patch_per_gpu", 1) + + # calculate patch per iter + if hasattr(cfg.training.hp, "max_patch_per_gpu") and max_patch_per_gpu > 1: + max_patch_num_per_iter = min( + patch_num, (max_patch_per_gpu // batch_size_per_gpu) + ) # Ensure at least 1 patch per iter + patch_iterations = ( + patch_num + max_patch_num_per_iter - 1 + ) // max_patch_num_per_iter + patch_nums_iter = [ + min(max_patch_num_per_iter, patch_num - i * max_patch_num_per_iter) + for i in range(patch_iterations) + ] + print( + f"max_patch_num_per_iter is {max_patch_num_per_iter}, patch_iterations is {patch_iterations}, patch_nums_iter is {patch_nums_iter}" + ) + else: + patch_nums_iter = [patch_num] + + # Set patch gradient accumulation only for patched diffusion models + if cfg.model.name in { + "patched_diffusion", + 
"lt_aware_patched_diffusion", + }: + if len(patch_nums_iter) > 1: + if not patching: + logger0.info( + "Patching is not enabled: patch gradient accumulation automatically disabled." + ) + use_patch_grad_acc = False + else: + use_patch_grad_acc = True + else: + use_patch_grad_acc = False + # Automatically disable patch gradient accumulation for non-patched models + else: + logger0.info( + "Training a non-patched model: patch gradient accumulation automatically disabled." + ) + use_patch_grad_acc = None + + + # Instantiate the loss function + if cfg.model.name in ( + "diffusion", + "patched_diffusion", + "lt_aware_patched_diffusion", + ): + loss_fn = ResidualLoss( + regression_net=regression_net, + hr_mean_conditioning=cfg.model.hr_mean_conditioning, + ) + elif cfg.model.name == "regression": + loss_fn = RegressionLoss() + elif cfg.model.name == "lt_aware_ce_regression": + loss_fn = RegressionLossCE(prob_channels=prob_channels) + + # Instantiate the optimizer + optimizer = torch.optim.Adam( + params=model.parameters(), + lr=cfg.training.hp.lr, + betas=[0.9, 0.999], + eps=1e-8, + fused=True, + ) + + # Record the current time to measure the duration of subsequent operations. + start_time = time.time() + + # Load optimizer checkpoint if it exists + if dist.world_size > 1: + torch.distributed.barrier() + try: + cur_nimg = load_checkpoint( + path=checkpoint_dir, + model=model, + optimizer=optimizer, + device=dist.device, + ) + except: + cur_nimg = 0 + + ############################################################################ + # MAIN TRAINING LOOP # + ############################################################################ + + logger0.info(f"Training for {cfg.training.hp.training_duration} images...") + done = False + + # init variables to monitor running mean of average loss since last periodic + average_loss_running_mean = 0 + n_average_loss_running_mean = 1 + start_nimg = cur_nimg + input_dtype = torch.float32 + if enable_amp: + input_dtype = torch.float32 + elif fp16: + input_dtype = torch.float16 + + # enable profiler: + with cuda_profiler(): + with profiler_emit_nvtx(): + while not done: + tick_start_nimg = cur_nimg + tick_start_time = time.time() + + if cur_nimg - start_nimg == 24 * cfg.training.hp.total_batch_size: + logger0.info(f"Starting Profiler at {cur_nimg}") + cuda_profiler_start() + + if cur_nimg - start_nimg == 25 * cfg.training.hp.total_batch_size: + logger0.info(f"Stopping Profiler at {cur_nimg}") + cuda_profiler_stop() + + with nvtx.annotate("Training iteration", color="green"): + # Compute & accumulate gradients + optimizer.zero_grad(set_to_none=True) + loss_accum = 0 + for n_i in range(num_accumulation_rounds): + with nvtx.annotate( + f"accumulation round {n_i}", color="Magenta" + ): + with nvtx.annotate("loading data", color="green"): + img_clean, img_lr, *lead_time_label = next( + dataset_iterator + ) + if use_apex_gn: + img_clean = img_clean.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + img_lr = img_lr.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + else: + img_clean = ( + img_clean.to(dist.device) + .to(input_dtype) + .contiguous() + ) + img_lr = ( + img_lr.to(dist.device) + .to(input_dtype) + .contiguous() + ) + loss_fn_kwargs = { + "net": model, + "img_clean": img_clean, + "img_lr": img_lr, + "augment_pipe": None, + } + if use_patch_grad_acc is not None: + loss_fn_kwargs[ + "use_patch_grad_acc" + ] = use_patch_grad_acc + + if lead_time_label: + 
lead_time_label = ( + lead_time_label[0].to(dist.device).contiguous() + ) + loss_fn_kwargs.update( + {"lead_time_label": lead_time_label} + ) + else: + lead_time_label = None + if use_patch_grad_acc: + loss_fn.y_mean = None + + for patch_num_per_iter in patch_nums_iter: + if patching is not None: + patching.set_patch_num(patch_num_per_iter) + loss_fn_kwargs.update({"patching": patching}) + with nvtx.annotate(f"loss forward", color="green"): + with torch.autocast( + "cuda", dtype=amp_dtype, enabled=enable_amp + ): + loss = loss_fn(**loss_fn_kwargs) + + loss = loss.sum() / batch_size_per_gpu + loss_accum += loss / num_accumulation_rounds + with nvtx.annotate(f"loss backward", color="yellow"): + loss.backward() + + + with nvtx.annotate(f"loss aggregate", color="green"): + loss_sum = torch.tensor([loss_accum], device=dist.device) + if dist.world_size > 1: + torch.distributed.barrier() + torch.distributed.all_reduce( + loss_sum, op=torch.distributed.ReduceOp.SUM + ) + average_loss = (loss_sum / dist.world_size).cpu().item() + + # update running mean of average loss since last periodic task + average_loss_running_mean += ( + average_loss - average_loss_running_mean + ) / n_average_loss_running_mean + n_average_loss_running_mean += 1 + + if dist.rank == 0: + writer.add_scalar("training_loss", average_loss, cur_nimg) + writer.add_scalar( + "training_loss_running_mean", + average_loss_running_mean, + cur_nimg, + ) + + ptt = is_time_for_periodic_task( + cur_nimg, + cfg.training.io.print_progress_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ) + if ptt: + # reset running mean of average loss + average_loss_running_mean = 0 + n_average_loss_running_mean = 1 + + # Update weights. + with nvtx.annotate("update weights", color="blue"): + + lr_rampup = cfg.training.hp.lr_rampup # ramp up the learning rate + for g in optimizer.param_groups: + if lr_rampup > 0: + g["lr"] = cfg.training.hp.lr * min(cur_nimg / lr_rampup, 1) + if cur_nimg >= lr_rampup: + g["lr"] *= cfg.training.hp.lr_decay ** ((cur_nimg - lr_rampup) // 5e6) + current_lr = g["lr"] + if dist.rank == 0: + writer.add_scalar("learning_rate", current_lr, cur_nimg) + handle_and_clip_gradients( + model, grad_clip_threshold=cfg.training.hp.grad_clip_threshold + ) + with nvtx.annotate("optimizer step", color="blue"): + optimizer.step() + + cur_nimg += cfg.training.hp.total_batch_size + done = cur_nimg >= cfg.training.hp.training_duration + + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.print_progress_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ): + # Print stats if we crossed the printing threshold with this batch + tick_end_time = time.time() + fields = [] + fields += [f"samples {cur_nimg:<9.1f}"] + fields += [f"training_loss {average_loss:<7.2f}"] + fields += [f"training_loss_running_mean {average_loss_running_mean:<7.2f}"] + fields += [f"learning_rate {current_lr:<7.8f}"] + fields += [f"total_sec {(tick_end_time - start_time):<7.1f}"] + fields += [f"sec_per_tick {(tick_end_time - tick_start_time):<7.1f}"] + fields += [ + f"sec_per_sample {((tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg)):<7.2f}" + ] + fields += [ + f"cpu_mem_gb {(psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}" + ] + if torch.cuda.is_available(): + fields += [ + f"peak_gpu_mem_gb {(torch.cuda.max_memory_allocated(dist.device) / 2**30):<6.2f}" + ] + fields += [ + f"peak_gpu_mem_reserved_gb {(torch.cuda.max_memory_reserved(dist.device) / 2**30):<6.2f}" + ] + 
torch.cuda.reset_peak_memory_stats() + logger0.info(" ".join(fields)) + + with nvtx.annotate("validation", color="red"): + # Validation + if validation_dataset_iterator is not None: + valid_loss_accum = 0 + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.validation_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + ): + with torch.no_grad(): + for _ in range(cfg.training.io.validation_steps): + ( + img_clean_valid, + img_lr_valid, + *lead_time_label_valid, + ) = next(validation_dataset_iterator) + + if use_apex_gn: + img_clean_valid = img_clean_valid.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + img_lr_valid = img_lr_valid.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + + else: + img_clean_valid = ( + img_clean_valid.to(dist.device) + .to(input_dtype) + .contiguous() + ) + img_lr_valid = ( + img_lr_valid.to(dist.device) + .to(input_dtype) + .contiguous() + ) + + loss_valid_kwargs = { + "net": model, + "img_clean": img_clean_valid, + "img_lr": img_lr_valid, + "augment_pipe": None, + } + if use_patch_grad_acc is not None: + loss_valid_kwargs[ + "use_patch_grad_acc" + ] = use_patch_grad_acc + if lead_time_label_valid: + lead_time_label_valid = ( + lead_time_label_valid[0] + .to(dist.device) + .contiguous() + ) + loss_valid_kwargs.update( + {"lead_time_label": lead_time_label_valid} + ) + if use_patch_grad_acc: + loss_fn.y_mean = None + + for patch_num_per_iter in patch_nums_iter: + if patching is not None: + patching.set_patch_num(patch_num_per_iter) + loss_fn_kwargs.update( + {"patching": patching} + ) + with torch.autocast( + "cuda", dtype=amp_dtype, enabled=enable_amp + ): + loss_valid = loss_fn(**loss_valid_kwargs) + + loss_valid = ( + (loss_valid.sum() / batch_size_per_gpu) + .cpu() + .item() + ) + valid_loss_accum += ( + loss_valid + / cfg.training.io.validation_steps + ) + valid_loss_sum = torch.tensor( + [valid_loss_accum], device=dist.device + ) + if dist.world_size > 1: + torch.distributed.barrier() + torch.distributed.all_reduce( + valid_loss_sum, op=torch.distributed.ReduceOp.SUM + ) + average_valid_loss = valid_loss_sum / dist.world_size + if dist.rank == 0: + writer.add_scalar( + "validation_loss", average_valid_loss, cur_nimg + ) + + + # Save checkpoints + if dist.world_size > 1: + torch.distributed.barrier() + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.save_checkpoint_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ): + save_checkpoint( + path=checkpoint_dir, + model=model, + optimizer=optimizer, + epoch=cur_nimg, + ) + + # Done. + logger0.info("Training Completed.") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/hirad/utils/__init__.py b/src/hirad/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/hirad/utils/capture.py b/src/hirad/utils/capture.py new file mode 100644 index 0000000..9c38d5a --- /dev/null +++ b/src/hirad/utils/capture.py @@ -0,0 +1,513 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import logging +import os +import time +from contextlib import nullcontext +from logging import Logger +from typing import Any, Callable, Dict, NewType, Optional, Union + +import torch + +from hirad.distributed import DistributedManager + +float16 = NewType("float16", torch.float16) +bfloat16 = NewType("bfloat16", torch.bfloat16) +optim = NewType("optim", torch.optim) + + +class _StaticCapture(object): + """Base class for StaticCapture decorator. + + This class should not be used, rather StaticCaptureTraining and StaticCaptureEvaluate + should be used instead for training and evaluation functions. + """ + + # Grad scaler and checkpoint class variables use for checkpoint saving and loading + # Since an instance of Static capture does not exist for checkpoint functions + # one must use class functions to access state dicts + _amp_scalers = {} + _amp_scaler_checkpoints = {} + _logger = logging.getLogger("capture") + + def __new__(cls, *args, **kwargs): + obj = super(_StaticCapture, cls).__new__(cls) + obj.amp_scalers = cls._amp_scalers + obj.amp_scaler_checkpoints = cls._amp_scaler_checkpoints + obj.logger = cls._logger + return obj + + def __init__( + self, + model: "physicsnemo.Module", + optim: Optional[optim] = None, + logger: Optional[Logger] = None, + use_graphs: bool = True, + use_autocast: bool = True, + use_gradscaler: bool = True, + compile: bool = False, + cuda_graph_warmup: int = 11, + amp_type: Union[float16, bfloat16] = torch.float16, + gradient_clip_norm: Optional[float] = None, + label: Optional[str] = None, + ): + self.logger = logger if logger else self.logger + # Checkpoint label (used for gradscaler) + self.label = label if label else f"scaler_{len(self.amp_scalers.keys())}" + + # DDP fix + if not isinstance(model, physicsnemo.models.Module) and hasattr( + model, "module" + ): + model = model.module + + if not isinstance(model, physicsnemo.models.Module): + self.logger.error("Model not a PhysicsNeMo Module!") + raise ValueError("Model not a PhysicsNeMo Module!") + if compile: + model = torch.compile(model) + + self.model = model + + self.optim = optim + self.eval = False + self.no_grad = False + self.gradient_clip_norm = gradient_clip_norm + + # Set up toggles for optimizations + if not (amp_type == torch.float16 or amp_type == torch.bfloat16): + raise ValueError("AMP type must be torch.float16 or torch.bfloat16") + # CUDA device + if "cuda" in str(self.model.device): + # CUDA graphs + if use_graphs and not self.model.meta.cuda_graphs: + self.logger.warning( + f"Model {model.meta.name} does not support CUDA graphs, turning off" + ) + use_graphs = False + self.cuda_graphs_enabled = use_graphs + + # AMP GPU + if not self.model.meta.amp_gpu: + self.logger.warning( + f"Model {model.meta.name} does not support AMP on GPUs, turning off" + ) + use_autocast = False + use_gradscaler = False + self.use_gradscaler = use_gradscaler + self.use_autocast = use_autocast + + self.amp_device = "cuda" + # Check if bfloat16 is suppored on the GPU + if amp_type == torch.bfloat16 and not torch.cuda.is_bf16_supported(): + self.logger.warning( + "Current CUDA device 
does not support bfloat16, falling back to float16" + ) + amp_type = torch.float16 + self.amp_dtype = amp_type + # Gradient Scaler + scaler_enabled = self.use_gradscaler and amp_type == torch.float16 + self.scaler = self._init_amp_scaler(scaler_enabled, self.logger) + + self.replay_stream = torch.cuda.Stream(self.model.device) + # CPU device + else: + self.cuda_graphs_enabled = False + # AMP CPU + if use_autocast and not self.model.meta.amp_cpu: + self.logger.warning( + f"Model {model.meta.name} does not support AMP on CPUs, turning off" + ) + use_autocast = False + + self.use_autocast = use_autocast + self.amp_device = "cpu" + # Only float16 is supported on CPUs + # https://pytorch.org/docs/stable/amp.html#cpu-op-specific-behavior + if amp_type == torch.float16 and use_autocast: + self.logger.warning( + "torch.float16 not supported for CPU AMP, switching to torch.bfloat16" + ) + amp_type = torch.bfloat16 + self.amp_dtype = torch.bfloat16 + # Gradient Scaler (not enabled) + self.scaler = self._init_amp_scaler(False, self.logger) + self.replay_stream = None + + if self.cuda_graphs_enabled: + self.graph = torch.cuda.CUDAGraph() + + self.output = None + self.iteration = 0 + self.cuda_graph_warmup = cuda_graph_warmup # Default for DDP = 11 + + def __call__(self, fn: Callable) -> Callable: + self.function = fn + + @functools.wraps(fn) + def decorated(*args: Any, **kwds: Any) -> Any: + """Training step decorator function""" + + with torch.no_grad() if self.no_grad else nullcontext(): + if self.cuda_graphs_enabled: + self._cuda_graph_forward(*args, **kwds) + else: + self._zero_grads() + self.output = self._amp_forward(*args, **kwds) + + if not self.eval: + # Update model parameters + self.scaler.step(self.optim) + self.scaler.update() + + return self.output + + return decorated + + def _cuda_graph_forward(self, *args: Any, **kwargs: Any) -> Any: + """Forward training step with CUDA graphs + + Returns + ------- + Any + Output of neural network forward + """ + # Graph warm up + if self.iteration < self.cuda_graph_warmup: + self.replay_stream.wait_stream(torch.cuda.current_stream()) + self._zero_grads() + with torch.cuda.stream(self.replay_stream): + output = self._amp_forward(*args, **kwargs) + self.output = output.detach() + torch.cuda.current_stream().wait_stream(self.replay_stream) + # CUDA Graphs + else: + # Graph record + if self.iteration == self.cuda_graph_warmup: + self.logger.warning(f"Recording graph of '{self.function.__name__}'") + self._zero_grads() + torch.cuda.synchronize() + if DistributedManager().distributed: + torch.distributed.barrier() + # TODO: temporary workaround till this issue is fixed: + # https://github.com/pytorch/pytorch/pull/104487#issuecomment-1638665876 + delay = os.environ.get("PHYSICSNEMO_CUDA_GRAPH_CAPTURE_DELAY", "10") + time.sleep(int(delay)) + with torch.cuda.graph(self.graph): + output = self._amp_forward(*args, **kwargs) + self.output = output.detach() + # Graph replay + self.graph.replay() + + self.iteration += 1 + return self.output + + def _zero_grads(self): + """Zero gradients + + Default to `set_to_none` since this will in general have lower memory + footprint, and can modestly improve performance. + + Note + ---- + Zeroing gradients can potentially cause an invalid CUDA memory access in another + graph. However if your graph involves gradients, you much set your gradients to none. + If there is already a graph recorded that includes these gradients, this will error. + Use the `NoGrad` version of capture to avoid this issue for inferencers / validators. 
+ """ + # Skip zeroing if no grad is being used + if self.no_grad: + return + + try: + self.optim.zero_grad(set_to_none=True) + except Exception: + if self.optim: + self.optim.zero_grad() + # For apex optim support and eval mode (need to reset model grads) + self.model.zero_grad(set_to_none=True) + + def _amp_forward(self, *args, **kwargs) -> Any: + """Compute loss and gradients (if training) with AMP + + Returns + ------- + Any + Output of neural network forward + """ + with torch.autocast( + self.amp_device, enabled=self.use_autocast, dtype=self.amp_dtype + ): + output = self.function(*args, **kwargs) + + if not self.eval: + # In training mode output should be the loss + self.scaler.scale(output).backward() + if self.gradient_clip_norm is not None: + self.scaler.unscale_(self.optim) + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.gradient_clip_norm + ) + + return output + + def _init_amp_scaler( + self, scaler_enabled: bool, logger: Logger + ) -> torch.cuda.amp.GradScaler: + # Create gradient scaler + scaler = torch.cuda.amp.GradScaler(enabled=scaler_enabled) + # Store scaler in class variable + self.amp_scalers[self.label] = scaler + logging.debug(f"Created gradient scaler {self.label}") + + # If our checkpoint dictionary has weights for this scaler lets load + if self.label in self.amp_scaler_checkpoints: + try: + scaler.load_state_dict(self.amp_scaler_checkpoints[self.label]) + del self.amp_scaler_checkpoints[self.label] + self.logger.info(f"Loaded grad scaler state dictionary {self.label}.") + except Exception as e: + self.logger.error( + f"Failed to load grad scaler {self.label} state dict from saved " + + "checkpoints. Did you switch the ordering of declared static captures?" + ) + raise ValueError(e) + return scaler + + @classmethod + def state_dict(cls) -> Dict[str, Any]: + """Class method for accsessing the StaticCapture state dictionary. + Use this in a training checkpoint function. + + Returns + ------- + Dict[str, Any] + Dictionary of states to save for file + """ + scaler_states = {} + for key, value in cls._amp_scalers.items(): + scaler_states[key] = value.state_dict() + + return scaler_states + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any]) -> None: + """Class method for loading a StaticCapture state dictionary. + Use this in a training checkpoint function. + + Returns + ------- + Dict[str, Any] + Dictionary of states to save for file + """ + for key, value in state_dict.items(): + # If scaler has been created already load the weights + if key in cls._amp_scalers: + try: + cls._amp_scalers[key].load_state_dict(value) + cls._logger.info(f"Loaded grad scaler state dictionary {key}.") + except Exception as e: + cls._logger.error( + f"Failed to load grad scaler state dict with id {key}." + + " Something went wrong!" + ) + raise ValueError(e) + # Otherwise store in checkpoints for later use + else: + cls._amp_scaler_checkpoints[key] = value + + @classmethod + def reset_state(cls): + cls._amp_scalers = {} + cls._amp_scaler_checkpoints = {} + + +class StaticCaptureTraining(_StaticCapture): + """A performance optimization decorator for PyTorch training functions. + + This class should be initialized as a decorator on a function that computes the + forward pass of the neural network and loss function. The user should only call the + defind training step function. This will apply optimizations including: AMP and + Cuda Graphs. 
+ + Parameters + ---------- + model : physicsnemo.models.Module + PhysicsNeMo Model + optim : torch.optim + Optimizer + logger : Optional[Logger], optional + PhysicsNeMo Launch Logger, by default None + use_graphs : bool, optional + Toggle CUDA graphs if supported by model, by default True + use_amp : bool, optional + Toggle AMP if supported by mode, by default True + cuda_graph_warmup : int, optional + Number of warmup steps for cuda graphs, by default 11 + amp_type : Union[float16, bfloat16], optional + Auto casting type for AMP, by default torch.float16 + gradient_clip_norm : Optional[float], optional + Threshold for gradient clipping + label : Optional[str], optional + Static capture checkpoint label, by default None + + Raises + ------ + ValueError + If the model provided is not a physicsnemo.models.Module. I.e. has no meta data. + + Example + ------- + >>> # Create model + >>> model = physicsnemo.models.mlp.FullyConnected(2, 64, 2) + >>> input = torch.rand(8, 2) + >>> output = torch.rand(8, 2) + >>> # Create optimizer + >>> optim = torch.optim.Adam(model.parameters(), lr=0.001) + >>> # Create training step function with optimization wrapper + >>> @StaticCaptureTraining(model=model, optim=optim) + ... def training_step(model, invar, outvar): + ... predvar = model(invar) + ... loss = torch.sum(torch.pow(predvar - outvar, 2)) + ... return loss + ... + >>> # Sample training loop + >>> for i in range(3): + ... loss = training_step(model, input, output) + ... + + Note + ---- + Static captures must be checkpointed when training using the `state_dict()` if AMP + is being used with gradient scaler. By default, this requires static captures to be + instantiated in the same order as when they were checkpointed. The label parameter + can be used to relax/circumvent this ordering requirement. + + Note + ---- + Capturing multiple cuda graphs in a single program can lead to potential invalid CUDA + memory access errors on some systems. Prioritize capturing training graphs when this + occurs. + """ + + def __init__( + self, + model: "physicsnemo.Module", + optim: torch.optim, + logger: Optional[Logger] = None, + use_graphs: bool = True, + use_amp: bool = True, + compile: bool = False, + cuda_graph_warmup: int = 11, + amp_type: Union[float16, bfloat16] = torch.float16, + gradient_clip_norm: Optional[float] = None, + label: Optional[str] = None, + ): + super().__init__( + model, + optim, + logger, + use_graphs, + use_amp, + use_amp, + compile, + cuda_graph_warmup, + amp_type, + gradient_clip_norm, + label, + ) + + +class StaticCaptureEvaluateNoGrad(_StaticCapture): + + """An performance optimization decorator for PyTorch no grad evaluation. + + This class should be initialized as a decorator on a function that computes run the + forward pass of the model that does not require gradient calculations. This is the + recommended method to use for inference and validation methods. 
+ + Parameters + ---------- + model : physicsnemo.models.Module + PhysicsNeMo Model + logger : Optional[Logger], optional + PhysicsNeMo Launch Logger, by default None + use_graphs : bool, optional + Toggle CUDA graphs if supported by model, by default True + use_amp : bool, optional + Toggle AMP if supported by mode, by default True + cuda_graph_warmup : int, optional + Number of warmup steps for cuda graphs, by default 11 + amp_type : Union[float16, bfloat16], optional + Auto casting type for AMP, by default torch.float16 + label : Optional[str], optional + Static capture checkpoint label, by default None + + Raises + ------ + ValueError + If the model provided is not a physicsnemo.models.Module. I.e. has no meta data. + + Example + ------- + >>> # Create model + >>> model = physicsnemo.models.mlp.FullyConnected(2, 64, 2) + >>> input = torch.rand(8, 2) + >>> # Create evaluate function with optimization wrapper + >>> @StaticCaptureEvaluateNoGrad(model=model) + ... def eval_step(model, invar): + ... predvar = model(invar) + ... return predvar + ... + >>> output = eval_step(model, input) + >>> output.size() + torch.Size([8, 2]) + + Note + ---- + Capturing multiple cuda graphs in a single program can lead to potential invalid CUDA + memory access errors on some systems. Prioritize capturing training graphs when this + occurs. + """ + + def __init__( + self, + model: "physicsnemo.Module", + logger: Optional[Logger] = None, + use_graphs: bool = True, + use_amp: bool = True, + compile: bool = False, + cuda_graph_warmup: int = 11, + amp_type: Union[float16, bfloat16] = torch.float16, + label: Optional[str] = None, + ): + super().__init__( + model, + None, + logger, + use_graphs, + use_amp, + compile, + False, + cuda_graph_warmup, + amp_type, + None, + label, + ) + self.eval = True # No optimizer/scaler calls + self.no_grad = True # No grad context and no grad zeroing diff --git a/src/hirad/utils/checkpoint.py b/src/hirad/utils/checkpoint.py new file mode 100644 index 0000000..a346b16 --- /dev/null +++ b/src/hirad/utils/checkpoint.py @@ -0,0 +1,330 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
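The two decorators defined in capture.py above are designed to wrap user-supplied step functions: `StaticCaptureTraining` runs the backward pass, gradient scaling/clipping, and the optimizer step inside the wrapper (with optional CUDA graphs), while `StaticCaptureEvaluateNoGrad` wraps a gradient-free forward pass. A minimal sketch of combining the two, following the usage shown in their docstrings; the model, shapes, and loss here are illustrative placeholders only:

```python
import torch
from physicsnemo.models.mlp import FullyConnected
from hirad.utils.capture import StaticCaptureTraining, StaticCaptureEvaluateNoGrad

# Placeholder model taken from the class docstrings; any physicsnemo.Module works.
model = FullyConnected(2, 64, 2)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

@StaticCaptureTraining(model=model, optim=optim)
def training_step(model, invar, outvar):
    # Return the loss; backward, scaling, and optimizer.step() happen in the wrapper.
    return torch.sum(torch.pow(model(invar) - outvar, 2))

@StaticCaptureEvaluateNoGrad(model=model)
def eval_step(model, invar):
    return model(invar)

for _ in range(3):
    invar, outvar = torch.rand(8, 2), torch.rand(8, 2)
    loss = training_step(model, invar, outvar)

pred = eval_step(model, torch.rand(8, 2))
```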
+ +import glob +import re +from pathlib import Path +from typing import Any, Dict, List, NewType, Optional, Union, Tuple + +import torch +from torch.cuda.amp import GradScaler +from torch.optim.lr_scheduler import _LRScheduler + +from hirad.distributed import DistributedManager +from .console import PythonLogger + +optimizer = NewType("optimizer", torch.optim.Optimizer) +scheduler = NewType("scheduler", _LRScheduler) +scaler = NewType("scaler", GradScaler) + +checkpoint_logging = PythonLogger("checkpoint") + + +def _get_checkpoint_filename( + path: str, + base_name: str = "checkpoint", + index: Union[int, None] = None, + saving: bool = False, + model_type: str = "pt", +) -> str: + """Gets the file name /path of checkpoint + + This function has three different ways of providing a checkout filename: + - If supplied an index this will return the checkpoint name using that index. + - If index is None and saving is false, this will get the checkpoint with the + largest index (latest save). + - If index is None and saving is true, it will return the next valid index file name + which is calculated by indexing the largest checkpoint index found by one. + + Parameters + ---------- + path : str + Path to checkpoints + base_name: str, optional + Base file name, by default checkpoint + index : Union[int, None], optional + Checkpoint index, by default None + saving : bool, optional + Get filename for saving a new checkpoint, by default False + model_type : str + Model type, by default "mdlus" for PhysicsNeMo models and "pt" for PyTorch models + + + Returns + ------- + str + Checkpoint file name + """ + # Get model parallel rank so all processes in the first model parallel group + # can save their checkpoint. In the case without model parallelism, + # model_parallel_rank should be the same as the process rank itself and + # only rank 0 saves + if not DistributedManager.is_initialized(): + checkpoint_logging.warning( + "`DistributedManager` not initialized already. 
Initializing now, but this might lead to unexpected errors" + ) + DistributedManager.initialize() + manager = DistributedManager() + model_parallel_rank = ( + manager.group_rank("model_parallel") + if "model_parallel" in manager.group_names + else 0 + ) + + # Input file name + checkpoint_filename = str( + Path(path).resolve() / f"{base_name}.{model_parallel_rank}" + ) + + # File extension for PhysicsNeMo models or PyTorch models + file_extension = "."+model_type + + # If epoch is provided load that file + if index is not None: + checkpoint_filename = checkpoint_filename + f".{index}" + checkpoint_filename += file_extension + # Otherwise try loading the latest epoch or rolling checkpoint + else: + file_names = [ + Path(fname).name + for fname in glob.glob( + checkpoint_filename + "*" + file_extension, recursive=False + ) + ] + + if len(file_names) > 0: + # If checkpoint from a null index save exists load that + # This is the most likely line to error since it will fail with + # invalid checkpoint names + file_idx = [ + int( + re.sub( + f"^{base_name}.{model_parallel_rank}.|" + file_extension, + "", + fname, + ) + ) + for fname in file_names + ] + file_idx.sort() + # If we are saving index by 1 to get the next free file name + if saving: + checkpoint_filename = checkpoint_filename + f".{file_idx[-1]+1}" + else: + checkpoint_filename = checkpoint_filename + f".{file_idx[-1]}" + checkpoint_filename += file_extension + else: + checkpoint_filename += ".0" + file_extension + + return checkpoint_filename + + +def save_checkpoint( + path: str, + model: Union[torch.nn.Module, None] = None, + optimizer: Union[optimizer, None] = None, + scheduler: Union[scheduler, None] = None, + scaler: Union[scaler, None] = None, + epoch: Union[int, None] = None, + metadata: Optional[Dict[str, Any]] = None, +) -> None: + """Training checkpoint saving utility + + This will save a training checkpoint in the provided path following the file naming + convention "checkpoint.{model parallel id}.{epoch/index}.pt". The load checkpoint + method can then be used to read this file. + + Parameters + ---------- + path : str + Path to save the training checkpoint + models : Union[torch.nn.Module, List[torch.nn.Module], None], optional + A single or list of PyTorch models, by default None + optimizer : Union[optimizer, None], optional + Optimizer, by default None + scheduler : Union[scheduler, None], optional + Learning rate scheduler, by default None + scaler : Union[scaler, None], optional + AMP grad scaler. Will attempt to save on in static capture if none provided, by + default None + epoch : Union[int, None], optional + Epoch checkpoint to load. 
If none this will save the checkpoint in the next + valid index, by default None + metadata : Optional[Dict[str, Any]], optional + Additional metadata to save, by default None + """ + # Create checkpoint directory if it does not exist + if not Path(path).is_dir(): + checkpoint_logging.warning( + f"Output directory {path} does not exist, will " "attempt to create" + ) + Path(path).mkdir(parents=True, exist_ok=True) + + # == Saving model checkpoint == + if model: + if hasattr(model, "module"): + # Strip out DDP layer + model = model.module + # Base name of model is meta.name unless pytorch model + name = model.__class__.__name__ + # Get full file path / name + file_name = _get_checkpoint_filename( + path, name, index=epoch, saving=True, model_type="pt" + ) + + # Save state dictionary + torch.save(model.state_dict(), file_name) + checkpoint_logging.success(f"Saved model state dictionary: {file_name}") + + # == Saving training checkpoint == + checkpoint_dict = {} + # Optimizer state dict + if optimizer: + checkpoint_dict["optimizer_state_dict"] = optimizer.state_dict() + + # Scheduler state dict + if scheduler: + checkpoint_dict["scheduler_state_dict"] = scheduler.state_dict() + + # Scaler state dict + if scaler: + checkpoint_dict["scaler_state_dict"] = scaler.state_dict() + + # Output file name + output_filename = _get_checkpoint_filename( + path, index=epoch, saving=True, model_type="pt" + ) + if epoch: + checkpoint_dict["epoch"] = epoch + if metadata: + checkpoint_dict["metadata"] = metadata + + # Save checkpoint to memory + if bool(checkpoint_dict): + torch.save( + checkpoint_dict, + output_filename, + ) + checkpoint_logging.success(f"Saved training checkpoint: {output_filename}") + + +def load_checkpoint( + path: str, + model: torch.nn.Module, + optimizer: Union[optimizer, None] = None, + scheduler: Union[scheduler, None] = None, + scaler: Union[scaler, None] = None, + epoch: Union[int, None] = None, + metadata_dict: Optional[Dict[str, Any]] = {}, + device: Union[str, torch.device] = "cpu", +) -> int: + """Checkpoint loading utility + + This loader is designed to be used with the save checkpoint utility. Given a path, this method will try to find a checkpoint and load state + dictionaries into the provided training objects. + + Parameters + ---------- + path : str + Path to training checkpoint + models : Union[torch.nn.Module, List[torch.nn.Module], None], optional + A single or list of PyTorch models, by default None + optimizer : Union[optimizer, None], optional + Optimizer, by default None + scheduler : Union[scheduler, None], optional + Learning rate scheduler, by default None + scaler : Union[scaler, None], optional + AMP grad scaler, by default None + epoch : Union[int, None], optional + Epoch checkpoint to load. 
If none is provided this will attempt to load the + checkpoint with the largest index, by default None + metadata_dict: Optional[Dict[str, Any]], optional + Dictionary to store metadata from the checkpoint, by default None + device : Union[str, torch.device], optional + Target device, by default "cpu" + + Returns + ------- + int + Loaded epoch + """ + # Check if checkpoint directory exists + if not Path(path).is_dir(): + checkpoint_logging.warning( + f"Provided checkpoint directory {path} does not exist, skipping load" + ) + return 0 + + # == Loading model checkpoint == + if hasattr(model, "module"): + # Strip out DDP layer + model = model.module + # Base name of model is meta.name unless pytorch model + name = model.__class__.__name__ + # Get full file path / name + file_name = _get_checkpoint_filename( + path, name, index=epoch, + ) + if not Path(file_name).exists(): + checkpoint_logging.error( + f"Could not find valid model file {file_name}, skipping load" + ) + else: + # Load state dictionary + model.load_state_dict(torch.load(file_name, map_location=device)) + + checkpoint_logging.success( + f"Loaded model state dictionary {file_name} to device {device}" + ) + + # == Loading training checkpoint == + checkpoint_filename = _get_checkpoint_filename(path, index=epoch, model_type="pt") + if not Path(checkpoint_filename).is_file(): + checkpoint_logging.warning( + f"Could not find valid checkpoint file {checkpoint_filename} skipping load" + ) + return 0 + + checkpoint_dict = torch.load(checkpoint_filename, map_location=device) + checkpoint_logging.success( + f"Loaded checkpoint file {checkpoint_filename} to device {device}" + ) + + # Optimizer state dict + if optimizer and "optimizer_state_dict" in checkpoint_dict: + optimizer.load_state_dict(checkpoint_dict["optimizer_state_dict"]) + checkpoint_logging.success("Loaded optimizer state dictionary") + + # Scheduler state dict + if scheduler and "scheduler_state_dict" in checkpoint_dict: + scheduler.load_state_dict(checkpoint_dict["scheduler_state_dict"]) + checkpoint_logging.success("Loaded scheduler state dictionary") + + # Scaler state dict + if scaler and "scaler_state_dict" in checkpoint_dict: + scaler.load_state_dict(checkpoint_dict["scaler_state_dict"]) + checkpoint_logging.success("Loaded grad scaler state dictionary") + + epoch = 0 + if "epoch" in checkpoint_dict: + epoch = checkpoint_dict["epoch"] + + # Update metadata if exists and the dictionary object is provided + metadata = checkpoint_dict.get("metadata", {}) + for key, value in metadata.items(): + metadata_dict[key] = value + + return epoch diff --git a/src/hirad/utils/console.py b/src/hirad/utils/console.py new file mode 100644 index 0000000..4231576 --- /dev/null +++ b/src/hirad/utils/console.py @@ -0,0 +1,88 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
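The two checkpoint helpers above form a matched pair: `save_checkpoint` writes the model weights as `{ModelClassName}.{model_parallel_rank}.{index}.pt` and the remaining training state as `checkpoint.{model_parallel_rank}.{index}.pt`, and `load_checkpoint` restores a requested index or, by default, the latest one. A minimal round-trip sketch with a placeholder model and an illustrative directory name (the helpers initialize `DistributedManager` themselves, with a warning, if it is not already set up):

```python
import torch
from hirad.utils.checkpoint import save_checkpoint, load_checkpoint

model = torch.nn.Linear(4, 4)                     # placeholder model
optimizer = torch.optim.Adam(model.parameters())
checkpoint_dir = "checkpoints"                    # illustrative path

# Writes Linear.<mp_rank>.10.pt (weights) and checkpoint.<mp_rank>.10.pt (optimizer state, epoch).
save_checkpoint(checkpoint_dir, model=model, optimizer=optimizer, epoch=10)

# With epoch=None the latest index is loaded; the stored epoch (10 here) is returned.
loaded_epoch = load_checkpoint(checkpoint_dir, model=model, optimizer=optimizer, device="cpu")
```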
+ +import logging +import os + +from termcolor import colored + + +class PythonLogger: + """Simple console logger for DL training + This is a WIP + """ + + def __init__(self, name: str = "launch"): + self.logger = logging.getLogger(name) + + def file_logging(self, file_name: str = "launch.log"): + """Log to file""" + if os.path.exists(file_name): + try: + os.remove(file_name) + except FileNotFoundError: + # ignore if already removed (can happen with multiple processes) + pass + formatter = logging.Formatter( + "[%(asctime)s - %(name)s - %(levelname)s] %(message)s", + datefmt="%H:%M:%S", + ) + filehandler = logging.FileHandler(file_name) + filehandler.setFormatter(formatter) + filehandler.setLevel(logging.DEBUG) + self.logger.addHandler(filehandler) + + def log(self, message: str): + """Log message""" + self.logger.info(message) + + def info(self, message: str): + """Log info""" + self.logger.info(colored(message, "light_blue")) + + def success(self, message: str): + """Log success""" + self.logger.info(colored(message, "light_green")) + + def warning(self, message: str): + """Log warning""" + self.logger.warning(colored(message, "light_yellow")) + + def error(self, message: str): + """Log error""" + self.logger.error(colored(message, "light_red")) + + +class RankZeroLoggingWrapper: + """Wrapper class to only log from rank 0 process in distributed training.""" + + def __init__(self, obj, dist): + self.obj = obj + self.dist = dist + + def __getattr__(self, name): + attr = getattr(self.obj, name) + if callable(attr): + + def wrapper(*args, **kwargs): + if self.dist.rank == 0: + return attr(*args, **kwargs) + else: + return None + + return wrapper + else: + return attr diff --git a/src/hirad/utils/deterministic_sampler.py b/src/hirad/utils/deterministic_sampler.py new file mode 100644 index 0000000..e502875 --- /dev/null +++ b/src/hirad/utils/deterministic_sampler.py @@ -0,0 +1,341 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
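console.py above defines a thin color-coded wrapper around the standard `logging` module, plus `RankZeroLoggingWrapper`, which silently drops calls on every rank except rank 0. A small illustrative sketch; the `_Dist` stand-in only mimics the `.rank` attribute that the wrapper actually reads:

```python
import logging
from hirad.utils.console import PythonLogger, RankZeroLoggingWrapper

logging.basicConfig(level=logging.INFO)   # make INFO-level messages visible

logger = PythonLogger("train")
logger.file_logging("train.log")          # optionally mirror messages to a file
logger.info("starting run")
logger.warning("no config supplied, using defaults")

class _Dist:
    rank = 0                              # stand-in for DistributedManager

logger0 = RankZeroLoggingWrapper(logger, _Dist())
logger0.success("checkpoint written")     # emitted only when dist.rank == 0
```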
+ +from typing import Callable, Literal, Optional + +import numpy as np +import nvtx +import torch + +from hirad.models import EDMPrecond + +# ruff: noqa: E731 + + +@nvtx.annotate(message="deterministic_sampler", color="red") +def deterministic_sampler( + net: torch.nn.Module, + latents: torch.Tensor, + img_lr: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + randn_like: Callable = torch.randn_like, + num_steps: int = 18, + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, + rho: float = 7.0, + solver: Literal["heun", "euler"] = "heun", + discretization: Literal["vp", "ve", "iddpm", "edm"] = "edm", + schedule: Literal["vp", "ve", "linear"] = "linear", + scaling: Literal["vp", "none"] = "none", + epsilon_s: float = 1e-3, + C_1: float = 0.001, + C_2: float = 0.008, + M: int = 1000, + alpha: float = 1.0, + S_churn: int = 0, + S_min: float = 0.0, + S_max: float = float("inf"), + S_noise: float = 1.0, +) -> torch.Tensor: + """ + Generalized sampler, representing the superset of all sampling methods + discussed in the paper "Elucidating the Design Space of Diffusion-Based + Generative Models" (EDM). + - https://arxiv.org/abs/2206.00364 + + This function integrates an ODE (probability flow) or SDE over multiple + time-steps to generate samples from the diffusion model provided by the + argument 'net'. It can be used to combine multiple choices to + design a custom sampler, including multiple integration solver, + discretization method, noise schedule, and so on. + + Parameters: + ----------- + net : torch.nn.Module + The diffusion model to use in the sampling process. + latents : torch.Tensor + The latent random noise used as the initial condition for the + stochastic ODE. + img_lr : torch.Tensor + Low-resolution input image for conditioning the diffusion process. + Passed as a keywork argument to the model 'net'. + class_labels : Optional[torch.Tensor] + Labels of the classes used as input to a class-conditionned + diffusion model. Passed as a keyword argument to the model 'net'. + If provided, it must be a tensor containing integer values. + Defaults to None, in which case it is ignored. + randn_like: Callable + Random Number Generator to generate random noise that is added + during the stochastic sampling. Must have the same signature as + torch.randn_like and return torch.Tensor. Defaults to + torch.randn_like. + num_steps : Optional[int] + Number of time-steps for the stochastic ODE integration. Defaults + to 18. + sigma_min : Optional[float] + Minimum noise level for the diffusion process. 'sigma_min', + 'sigma_max', and 'rho' are used to compute the time-step + discretization, based on the choice of discretization. For the + default choice ("discretization='heun'"), the noise level schedule + is computed as: + :math:`\sigma_i = (\sigma_{max}^{1/\rho} + i / (num_steps - 1) * (\sigma_{min}^{1/\rho} - \sigma_{max}^{1/\rho}))^{rho}`. + For other choices of 'discretization', see details in the EDM + paper. Defaults to None, in which case defaults values depending + of the specified discretization are used. + sigma_max : Optional[float] + Maximum noise level for the diffusion process. See sigma_min for + details. Defaults to None, in which case defaults values depending + of the specified discretization are used. + rho : float, optional + Exponent used in the noise schedule. See sigma_min for details. + Only used when 'discretization' is 'heun'. Values in the range [5, + 10] produce better images. 
Lower values lead to truncation errors + equalized over all time steps. Defaults to 7. + solver : Literal["heun", "euler"] + The numerical method used to integrate the stochastic ODE. "euler" + is 1st order solver, which is faster but produces lower-quality + images. "heun" is 2nd order, more expensive, but produces + higher-quality images. Defaults to "heun". + discretization : Literal["vp", "ve", "iddpm", "edm"] + The method to discretize time-steps :math:`t_i` in the + diffusion process. See the EDM papper for details. Defaults to + "edm". + schedule : Literal["vp", "ve", "linear"] + The type of noise level schedule. Defaults to "linear". If + schedule='ve', then :math:`\sigma(t) = \sqrt{t}`. If + schedule='linear', then :math:`\sigma(t) = t`. If schedule='vp', + see EDM paper for details. Defaults to "linear". + scaling : Literal["vp", "none"] + The type of time-dependent signal scaling :math:`s(t)`, such that + :math:`x = s(t) \hat{x}`. See EDM paper for details on the 'vp' + scaling. Defaults to 'none', in which case :math:`s(t)=1`. + epsilon_s : float, optional + Parameter to compute both the noise level schedule and the + time-step discetization. Only used when discretization='vp' or + schedule='vp'. Ignored in other cases. Defaults to 1e-3. + C_1 : float, optional + Parameters to compute the time-step discetization. Only used when + discretization='iddpm'. Defaults to 0.001. + C_2 : float, optional + Same as for C_1. Only used when discretization='iddpm'. Defaults to + 0.008. + M : int, optional + Same as for C_1 and C_2. Only used when discretization='iddpm'. + Defaults to 1000. + alpha : float, optional + Controls (i.e. multiplies) the step size :math:`t_{i+1} - + \hat{t}_i` in the stochastic sampler, where :math:`\hat{t}_i` is + the temporarily increased noise level. Defaults to 1.0, which is + the recommended value. + S_churn : int, optional + Controls the amount of stochasticty injected in the SDE in the + stochatsic sampler. Larger values of S_churn lead to larger values + of :math:`\hat{t}_i`, which in turn lead to injecting more + stochasticity in the SDE by Defaults to 0, which means no + stochasticity is injected. + S_min : float, optional + S_min and S_max control the time-step range obver which + stochasticty is injected in the SDE. Stochasticity is injected + through `\hat{t}_i` for time-steps :math:`t_i` such that + :math:`S_{min} \leq t_i \leq S_{max}`. Defaults to 0.0. + S_max : float, optional + See S_min. Defaults to float("inf"). + S_noise : float, optional + Controls the amount of stochasticty injected in the SDE in the + stochatsic sampler. Added signal noise is proportinal to + :math:`\epsilon_i` where `\epsilon_i ~ N(0, S_{noise}^2)`. Defaults + to 1.0. + + Returns + ------- + torch.Tensor: + Generated batch of samples. Same shape as the input 'latents'. + """ + + # conditioning + x_lr = img_lr + + if solver not in ["euler", "heun"]: + raise ValueError(f"Unknown solver {solver}") + if discretization not in ["vp", "ve", "iddpm", "edm"]: + raise ValueError(f"Unknown discretization {discretization}") + if schedule not in ["vp", "ve", "linear"]: + raise ValueError(f"Unknown schedule {schedule}") + if scaling not in ["vp", "none"]: + raise ValueError(f"Unknown scaling {scaling}") + + # Helper functions for VP & VE noise level schedules. 
+ vp_sigma = ( + lambda beta_d, beta_min: lambda t: ( + np.e ** (0.5 * beta_d * (t**2) + beta_min * t) - 1 + ) + ** 0.5 + ) + vp_sigma_deriv = ( + lambda beta_d, beta_min: lambda t: 0.5 + * (beta_min + beta_d * t) + * (sigma(t) + 1 / sigma(t)) + ) + vp_sigma_inv = ( + lambda beta_d, beta_min: lambda sigma: ( + (beta_min**2 + 2 * beta_d * (sigma**2 + 1).log()).sqrt() - beta_min + ) + / beta_d + ) + ve_sigma = lambda t: t.sqrt() + ve_sigma_deriv = lambda t: 0.5 / t.sqrt() + ve_sigma_inv = lambda sigma: sigma**2 + + # Select default noise level range based on the specified time step discretization. + if sigma_min is None: + vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=epsilon_s) + sigma_min = {"vp": vp_def, "ve": 0.02, "iddpm": 0.002, "edm": 0.002}[ + discretization + ] + if sigma_max is None: + vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=1) + sigma_max = {"vp": vp_def, "ve": 100, "iddpm": 81, "edm": 80}[discretization] + + # Adjust noise levels based on what's supported by the network. + sigma_min = max(sigma_min, net.sigma_min) + sigma_max = min(sigma_max, net.sigma_max) + + # Compute corresponding betas for VP. + vp_beta_d = ( + 2 + * (np.log(sigma_min**2 + 1) / epsilon_s - np.log(sigma_max**2 + 1)) + / (epsilon_s - 1) + ) + vp_beta_min = np.log(sigma_max**2 + 1) - 0.5 * vp_beta_d + + # Define time steps in terms of noise level. + step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) + if discretization == "vp": + orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1) + sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps) + elif discretization == "ve": + orig_t_steps = (sigma_max**2) * ( + (sigma_min**2 / sigma_max**2) ** (step_indices / (num_steps - 1)) + ) + sigma_steps = ve_sigma(orig_t_steps) + elif discretization == "iddpm": + u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device) + alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2 + for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1 + u[j - 1] = ( + (u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1 + ).sqrt() + u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)] + sigma_steps = u_filtered[ + ((len(u_filtered) - 1) / (num_steps - 1) * step_indices) + .round() + .to(torch.int64) + ] + else: + sigma_steps = ( + sigma_max ** (1 / rho) + + step_indices + / (num_steps - 1) + * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) + ) ** rho + + # Define noise level schedule. + if schedule == "vp": + sigma = vp_sigma(vp_beta_d, vp_beta_min) + sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min) + sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min) + elif schedule == "ve": + sigma = ve_sigma + sigma_deriv = ve_sigma_deriv + sigma_inv = ve_sigma_inv + else: + sigma = lambda t: t + sigma_deriv = lambda t: 1 + sigma_inv = lambda sigma: sigma + + # Define scaling schedule. + if scaling == "vp": + s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt() + s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3) + else: + s = lambda t: 1 + s_deriv = lambda t: 0 + + # Compute final time steps based on the corresponding noise levels. + t_steps = sigma_inv(net.round_sigma(sigma_steps)) + t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 + + # Main sampling loop. + t_next = t_steps[0] + x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next)) + for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 + x_cur = x_next + + # Increase noise temporarily. 
+ gamma = ( + min(S_churn / num_steps, np.sqrt(2) - 1) + if S_min <= sigma(t_cur) <= S_max + else 0 + ) + t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur))) + x_hat = s(t_hat) / s(t_cur) * x_cur + ( + sigma(t_hat) ** 2 - sigma(t_cur) ** 2 + ).clip(min=0).sqrt() * s(t_hat) * S_noise * randn_like(x_cur) + + # Euler step. + h = t_next - t_hat + if isinstance(net, EDMPrecond): + # Conditioning info is passed as keyword arg + denoised = net( + x_hat / s(t_hat), + sigma(t_hat), + condition=x_lr, + class_labels=class_labels, + ).to(torch.float64) + else: + denoised = net(x_hat / s(t_hat), x_lr, sigma(t_hat), class_labels).to( + torch.float64 + ) + d_cur = ( + sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat) + ) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised + x_prime = x_hat + alpha * h * d_cur + t_prime = t_hat + alpha * h + + # Apply 2nd order correction. + if solver == "euler" or i == num_steps - 1: + x_next = x_hat + h * d_cur + else: + if isinstance(net, EDMPrecond): + # Conditioning info is passed as keyword arg + denoised = net( + x_prime / s(t_prime), + sigma(t_prime), + condition=x_lr, + class_labels=class_labels, + ).to(torch.float64) + else: + denoised = net( + x_prime / s(t_prime), x_lr, sigma(t_prime), class_labels + ).to(torch.float64) + d_prime = ( + sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime) + ) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised + x_next = x_hat + h * ( + (1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime + ) + + return x_next diff --git a/src/hirad/utils/function_utils.py b/src/hirad/utils/function_utils.py new file mode 100644 index 0000000..347457c --- /dev/null +++ b/src/hirad/utils/function_utils.py @@ -0,0 +1,798 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Miscellaneous utility classes and functions.""" + +import contextlib +import ctypes +import datetime +import fnmatch +import importlib +import inspect +import os +import re +import shutil +import sys +import types +import warnings +from typing import Any, Iterator, List, Tuple, Union + +import cftime +import numpy as np +import torch + +# ruff: noqa: E722 PERF203 S110 E713 S324 + + +class EasyDict(dict): # pragma: no cover + """ + Convenience class that behaves like a dict but allows access with the attribute + syntax. + """ + + def __getattr__(self, name: str) -> Any: + try: + return self[name] + except KeyError: + raise AttributeError(name) + + def __setattr__(self, name: str, value: Any) -> None: + self[name] = value + + def __delattr__(self, name: str) -> None: + del self[name] + + +class StackedRandomGenerator: # pragma: no cover + """ + Wrapper for torch.Generator that allows specifying a different random seed + for each sample in a minibatch. 
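    Example
    -------
    A purely illustrative use: one generator per batch element, so sample ``i``
    depends only on ``seeds[i]`` rather than on the batch composition.

    >>> rnd = StackedRandomGenerator(device="cpu", seeds=[0, 1, 2, 3])
    >>> latents = rnd.randn([4, 3, 8, 8])
    >>> latents.shape
    torch.Size([4, 3, 8, 8])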
+ """ + + def __init__(self, device, seeds): + super().__init__() + self.generators = [ + torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds + ] + + def randn(self, size, **kwargs): + if size[0] != len(self.generators): + raise ValueError( + f"Expected first dimension of size {len(self.generators)}, got {size[0]}" + ) + return torch.stack( + [torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators] + ) + + def randn_like(self, input): + return self.randn( + input.shape, dtype=input.dtype, layout=input.layout, device=input.device + ) + + def randint(self, *args, size, **kwargs): + if size[0] != len(self.generators): + raise ValueError( + f"Expected first dimension of size {len(self.generators)}, got {size[0]}" + ) + return torch.stack( + [ + torch.randint(*args, size=size[1:], generator=gen, **kwargs) + for gen in self.generators + ] + ) + + +def parse_int_list(s): # pragma: no cover + """ + Parse a comma separated list of numbers or ranges and return a list of ints. + Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10] + """ + if isinstance(s, list): + return s + ranges = [] + range_re = re.compile(r"^(\d+)-(\d+)$") + for p in s.split(","): + m = range_re.match(p) + if m: + ranges.extend(range(int(m.group(1)), int(m.group(2)) + 1)) + else: + ranges.append(int(p)) + return ranges + + +# Small util functions +# ------------------------------------------------------------------------------------- +def convert_datetime_to_cftime( + time: datetime.datetime, cls=cftime.DatetimeGregorian +) -> cftime.DatetimeGregorian: + """Convert a Python datetime object to a cftime DatetimeGregorian object.""" + return cls(time.year, time.month, time.day, time.hour, time.minute, time.second) + + +def time_range( + start_time: datetime.datetime, + end_time: datetime.datetime, + step: datetime.timedelta, + inclusive: bool = False, +): + """Like the Python `range` iterator, but with datetimes.""" + t = start_time + while (t <= end_time) if inclusive else (t < end_time): + yield t + t += step + + +def format_time(seconds: Union[int, float]) -> str: # pragma: no cover + """Convert the seconds to human readable string with days, hours, minutes and seconds.""" + s = int(np.rint(seconds)) + + if s < 60: + return "{0}s".format(s) + elif s < 60 * 60: + return "{0}m {1:02}s".format(s // 60, s % 60) + elif s < 24 * 60 * 60: + return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) + else: + return "{0}d {1:02}h {2:02}m".format( + s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60 + ) + + +def format_time_brief(seconds: Union[int, float]) -> str: # pragma: no cover + """Convert the seconds to human readable string with days, hours, minutes and seconds.""" + s = int(np.rint(seconds)) + + if s < 60: + return "{0}s".format(s) + elif s < 60 * 60: + return "{0}m {1:02}s".format(s // 60, s % 60) + elif s < 24 * 60 * 60: + return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60) + else: + return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24) + + +def tuple_product(t: Tuple) -> Any: # pragma: no cover + """Calculate the product of the tuple elements.""" + result = 1 + + for v in t: + result *= v + + return result + + +_str_to_ctype = { + "uint8": ctypes.c_ubyte, + "uint16": ctypes.c_uint16, + "uint32": ctypes.c_uint32, + "uint64": ctypes.c_uint64, + "int8": ctypes.c_byte, + "int16": ctypes.c_int16, + "int32": ctypes.c_int32, + "int64": ctypes.c_int64, + "float32": ctypes.c_float, + "float64": ctypes.c_double, +} + + +def 
get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: # pragma: no cover + """ + Given a type name string (or an object having a __name__ attribute), return + matching Numpy and ctypes types that have the same size in bytes. + """ + type_str = None + + if isinstance(type_obj, str): + type_str = type_obj + elif hasattr(type_obj, "__name__"): + type_str = type_obj.__name__ + elif hasattr(type_obj, "name"): + type_str = type_obj.name + else: + raise RuntimeError("Cannot infer type name from input") + + if type_str not in _str_to_ctype.keys(): + raise ValueError("Unknown type name: " + type_str) + + my_dtype = np.dtype(type_str) + my_ctype = _str_to_ctype[type_str] + + if my_dtype.itemsize != ctypes.sizeof(my_ctype): + raise ValueError( + "Numpy and ctypes types for '{}' have different sizes!".format(type_str) + ) + + return my_dtype, my_ctype + + +# Functionality to import modules/objects by name, and call functions by name +# ------------------------------------------------------------------------------------- + + +def get_module_from_obj_name( + obj_name: str, +) -> Tuple[types.ModuleType, str]: # pragma: no cover + """ + Searches for the underlying module behind the name to some python object. + Returns the module and the object name (original name with module part removed). + """ + + # allow convenience shorthands, substitute them by full names + obj_name = re.sub("^np.", "numpy.", obj_name) + obj_name = re.sub("^tf.", "tensorflow.", obj_name) + + # list alternatives for (module_name, local_obj_name) + parts = obj_name.split(".") + name_pairs = [ + (".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1) + ] + + # try each alternative in turn + for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + return module, local_obj_name + except: + pass + + # maybe some of the modules themselves contain errors? + for module_name, _local_obj_name in name_pairs: + try: + importlib.import_module(module_name) # may raise ImportError + except ImportError: + if not str(sys.exc_info()[1]).startswith( + "No module named '" + module_name + "'" + ): + raise + + # maybe the requested attribute is missing? + for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + except ImportError: + pass + + # we are out of luck, but we have no idea why + raise ImportError(obj_name) + + +def get_obj_from_module( + module: types.ModuleType, obj_name: str +) -> Any: # pragma: no cover + """ + Traverses the object name and returns the last (rightmost) python object. + """ + if obj_name == "": + return module + obj = module + for part in obj_name.split("."): + obj = getattr(obj, part) + return obj + + +def get_obj_by_name(name: str) -> Any: # pragma: no cover + """ + Finds the python object with the given name. + """ + module, obj_name = get_module_from_obj_name(name) + return get_obj_from_module(module, obj_name) + + +def call_func_by_name( + *args, func_name: str = None, **kwargs +) -> Any: # pragma: no cover + """ + Finds the python object with the given name and calls it as a function. 
+ """ + if func_name is None: + raise ValueError("func_name must be specified") + func_obj = get_obj_by_name(func_name) + if not callable(func_obj): + raise ValueError(func_name + " is not callable") + return func_obj(*args, **kwargs) + + +def construct_class_by_name( + *args, class_name: str = None, **kwargs +) -> Any: # pragma: no cover + """ + Finds the python class with the given name and constructs it with the given + arguments. + """ + return call_func_by_name(*args, func_name=class_name, **kwargs) + + +def get_module_dir_by_obj_name(obj_name: str) -> str: # pragma: no cover + """ + Get the directory path of the module containing the given object name. + """ + module, _ = get_module_from_obj_name(obj_name) + return os.path.dirname(inspect.getfile(module)) + + +def is_top_level_function(obj: Any) -> bool: # pragma: no cover + """ + Determine whether the given object is a top-level function, i.e., defined at module + scope using 'def'. + """ + return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ + + +def get_top_level_function_name(obj: Any) -> str: # pragma: no cover + """ + Return the fully-qualified name of a top-level function. + """ + if not is_top_level_function(obj): + raise ValueError("Object is not a top-level function") + module = obj.__module__ + if module == "__main__": + module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0] + return module + "." + obj.__name__ + + +# File system helpers +# ------------------------------------------------------------------------------------------ + + +def list_dir_recursively_with_ignore( + dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False +) -> List[Tuple[str, str]]: # pragma: no cover + """ + List all files recursively in a given directory while ignoring given file and + directory names. Returns list of tuples containing both absolute and relative paths. + """ + if not os.path.isdir(dir_path): + raise RuntimeError(f"Directory does not exist: {dir_path}") + base_name = os.path.basename(os.path.normpath(dir_path)) + + if ignores is None: + ignores = [] + + result = [] + + for root, dirs, files in os.walk(dir_path, topdown=True): + for ignore_ in ignores: + dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] + + # dirs need to be edited in-place + for d in dirs_to_remove: + dirs.remove(d) + + files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] + + absolute_paths = [os.path.join(root, f) for f in files] + relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] + + if add_base_to_relative: + relative_paths = [os.path.join(base_name, p) for p in relative_paths] + + if len(absolute_paths) != len(relative_paths): + raise ValueError("Number of absolute and relative paths do not match") + result += zip(absolute_paths, relative_paths) + + return result + + +def copy_files_and_create_dirs( + files: List[Tuple[str, str]] +) -> None: # pragma: no cover + """ + Takes in a list of tuples of (src, dst) paths and copies files. + Will create all necessary directories. + """ + for file in files: + target_dir_name = os.path.dirname(file[1]) + + # will create all intermediate-level directories + if not os.path.exists(target_dir_name): + os.makedirs(target_dir_name) + + shutil.copyfile(file[0], file[1]) + + +# ---------------------------------------------------------------------------- +# Cached construction of constant tensors. Avoids CPU=>GPU copy when the +# same constant is used multiple times. 
+ +_constant_cache = dict() + + +def constant( + value, shape=None, dtype=None, device=None, memory_format=None +): # pragma: no cover + """Cached construction of constant tensors""" + value = np.asarray(value) + if shape is not None: + shape = tuple(shape) + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch.device("cpu") + if memory_format is None: + memory_format = torch.contiguous_format + + key = ( + value.shape, + value.dtype, + value.tobytes(), + shape, + dtype, + device, + memory_format, + ) + tensor = _constant_cache.get(key, None) + if tensor is None: + tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) + if shape is not None: + tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) + tensor = tensor.contiguous(memory_format=memory_format) + _constant_cache[key] = tensor + return tensor + + +# ---------------------------------------------------------------------------- +# Replace NaN/Inf with specified numerical values. + +try: + nan_to_num = torch.nan_to_num # 1.8.0a0 +except AttributeError: + + def nan_to_num( + input, nan=0.0, posinf=None, neginf=None, *, out=None + ): # pylint: disable=redefined-builtin # pragma: no cover + """Replace NaN/Inf with specified numerical values""" + if not isinstance(input, torch.Tensor): + raise TypeError("input should be a Tensor") + if posinf is None: + posinf = torch.finfo(input.dtype).max + if neginf is None: + neginf = torch.finfo(input.dtype).min + if nan != 0: + raise ValueError("nan_to_num only supports nan=0") + return torch.clamp( + input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out + ) + + +# ---------------------------------------------------------------------------- +# Symbolic assert. + +try: + symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access +except AttributeError: + symbolic_assert = torch.Assert # 1.7.0 + +# ---------------------------------------------------------------------------- +# Context manager to temporarily suppress known warnings in torch.jit.trace(). +# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 + + +@contextlib.contextmanager +def suppress_tracer_warnings(): # pragma: no cover + """ + Context manager to temporarily suppress known warnings in torch.jit.trace(). + Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 + """ + flt = ("ignore", None, torch.jit.TracerWarning, None, 0) + warnings.filters.insert(0, flt) + yield + warnings.filters.remove(flt) + + +# ---------------------------------------------------------------------------- +# Assert that the shape of a tensor matches the given list of integers. +# None indicates that the size of a dimension is allowed to vary. +# Performs symbolic assertion when used in torch.jit.trace(). + + +def assert_shape(tensor, ref_shape): # pragma: no cover + """ + Assert that the shape of a tensor matches the given list of integers. + None indicates that the size of a dimension is allowed to vary. + Performs symbolic assertion when used in torch.jit.trace(). 
+ """ + if tensor.ndim != len(ref_shape): + raise AssertionError( + f"Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}" + ) + for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): + if ref_size is None: + pass + elif isinstance(ref_size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert( + torch.equal(torch.as_tensor(size), ref_size), + f"Wrong size for dimension {idx}", + ) + elif isinstance(size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert( + torch.equal(size, torch.as_tensor(ref_size)), + f"Wrong size for dimension {idx}: expected {ref_size}", + ) + elif size != ref_size: + raise AssertionError( + f"Wrong size for dimension {idx}: got {size}, expected {ref_size}" + ) + + +# ---------------------------------------------------------------------------- +# Function decorator that calls torch.autograd.profiler.record_function(). + + +def profiled_function(fn): # pragma: no cover + """Function decorator that calls torch.autograd.profiler.record_function().""" + + def decorator(*args, **kwargs): + with torch.autograd.profiler.record_function(fn.__name__): + return fn(*args, **kwargs) + + decorator.__name__ = fn.__name__ + return decorator + + +# ---------------------------------------------------------------------------- +# Sampler for torch.utils.data.DataLoader that loops over the dataset +# indefinitely, shuffling items as it goes. + + +class InfiniteSampler(torch.utils.data.Sampler[int]): # pragma: no cover + """Sampler for torch.utils.data.DataLoader that loops over the dataset indefinitely. + + This sampler yields indices indefinitely, optionally shuffling items as it goes. + It can also perform distributed sampling when rank and num_replicas are specified. + + Parameters + ---------- + dataset : torch.utils.data.Dataset + The dataset to sample from + rank : int, default=0 + The rank of the current process within num_replicas processes + num_replicas : int, default=1 + The number of processes participating in distributed sampling + shuffle : bool, default=True + Whether to shuffle the indices + seed : int, default=0 + Random seed for reproducibility when shuffling + window_size : float, default=0.5 + Fraction of dataset to use as window for shuffling. Must be between 0 and 1. + A larger window means more thorough shuffling but slower iteration. 
+ """ + + def __init__( + self, + dataset: torch.utils.data.Dataset, + rank: int = 0, + num_replicas: int = 1, + shuffle: bool = True, + seed: int = 0, + window_size: float = 0.5, + ): + if not len(dataset) > 0: + raise ValueError("Dataset must contain at least one item") + if not num_replicas > 0: + raise ValueError("num_replicas must be positive") + if not 0 <= rank < num_replicas: + raise ValueError("rank must be non-negative and less than num_replicas") + if not 0 <= window_size <= 1: + raise ValueError("window_size must be between 0 and 1") + super().__init__() + self.dataset = dataset + self.rank = rank + self.num_replicas = num_replicas + self.shuffle = shuffle + self.seed = seed + self.window_size = window_size + + def __iter__(self) -> Iterator[int]: + order = np.arange(len(self.dataset)) + rnd = None + window = 0 + if self.shuffle: + rnd = np.random.RandomState(self.seed) + rnd.shuffle(order) + window = int(np.rint(order.size * self.window_size)) + + idx = 0 + while True: + i = idx % order.size + if idx % self.num_replicas == self.rank: + yield order[i] + if window >= 2: + j = (i - rnd.randint(window)) % order.size + order[i], order[j] = order[j], order[i] + idx += 1 + + +# ---------------------------------------------------------------------------- +# Utilities for operating with torch.nn.Module parameters and buffers. + + +def params_and_buffers(module): # pragma: no cover + """Get parameters and buffers of a nn.Module""" + if not isinstance(module, torch.nn.Module): + raise TypeError("module must be a torch.nn.Module instance") + return list(module.parameters()) + list(module.buffers()) + + +def named_params_and_buffers(module): # pragma: no cover + """Get named parameters and buffers of a nn.Module""" + if not isinstance(module, torch.nn.Module): + raise TypeError("module must be a torch.nn.Module instance") + return list(module.named_parameters()) + list(module.named_buffers()) + + +@torch.no_grad() +def copy_params_and_buffers( + src_module, dst_module, require_all=False +): # pragma: no cover + """Copy parameters and buffers from a source module to target module""" + if not isinstance(src_module, torch.nn.Module): + raise TypeError("src_module must be a torch.nn.Module instance") + if not isinstance(dst_module, torch.nn.Module): + raise TypeError("dst_module must be a torch.nn.Module instance") + src_tensors = dict(named_params_and_buffers(src_module)) + for name, tensor in named_params_and_buffers(dst_module): + if not ((name in src_tensors) or (not require_all)): + raise ValueError(f"Missing source tensor for {name}") + if name in src_tensors: + tensor.copy_(src_tensors[name]) + + +# ---------------------------------------------------------------------------- +# Context manager for easily enabling/disabling DistributedDataParallel +# synchronization. + + +@contextlib.contextmanager +def ddp_sync(module, sync): # pragma: no cover + """ + Context manager for easily enabling/disabling DistributedDataParallel + synchronization. + """ + if not isinstance(module, torch.nn.Module): + raise TypeError("module must be a torch.nn.Module instance") + if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): + yield + else: + with module.no_sync(): + yield + + +# ---------------------------------------------------------------------------- +# Check DistributedDataParallel consistency across processes. 
+ + +def check_ddp_consistency(module, ignore_regex=None): # pragma: no cover + """Check DistributedDataParallel consistency across processes.""" + if not isinstance(module, torch.nn.Module): + raise TypeError("module must be a torch.nn.Module instance") + for name, tensor in named_params_and_buffers(module): + fullname = type(module).__name__ + "." + name + if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): + continue + tensor = tensor.detach() + if tensor.is_floating_point(): + tensor = nan_to_num(tensor) + other = tensor.clone() + torch.distributed.broadcast(tensor=other, src=0) + if not (tensor == other).all(): + raise RuntimeError(f"DDP consistency check failed for {fullname}") + + +# ---------------------------------------------------------------------------- +# Print summary table of module hierarchy. + + +def print_module_summary( + module, inputs, max_nesting=3, skip_redundant=True +): # pragma: no cover + """Print summary table of module hierarchy.""" + if not isinstance(module, torch.nn.Module): + raise TypeError("module must be a torch.nn.Module instance") + if isinstance(module, torch.jit.ScriptModule): + raise TypeError("module must not be a torch.jit.ScriptModule instance") + if not isinstance(inputs, (tuple, list)): + raise TypeError("inputs must be a tuple or list") + + # Register hooks. + entries = [] + nesting = [0] + + def pre_hook(_mod, _inputs): + nesting[0] += 1 + + def post_hook(mod, _inputs, outputs): + nesting[0] -= 1 + if nesting[0] <= max_nesting: + outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] + outputs = [t for t in outputs if isinstance(t, torch.Tensor)] + entries.append(EasyDict(mod=mod, outputs=outputs)) + + hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] + hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] + + # Run module. + outputs = module(*inputs) + for hook in hooks: + hook.remove() + + # Identify unique outputs, parameters, and buffers. + tensors_seen = set() + for e in entries: + e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] + e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] + e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] + tensors_seen |= { + id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs + } + + # Filter out redundant entries. + if skip_redundant: + entries = [ + e + for e in entries + if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs) + ] + + # Construct table. 
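+    # One row per retained entry: submodule name, parameter count, buffer count,
+    # and the shape/dtype of its first output; modules with multiple outputs get
+    # extra ":idx" rows, and grand totals are appended after a separator row.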
+ rows = [ + [type(module).__name__, "Parameters", "Buffers", "Output shape", "Datatype"] + ] + rows += [["---"] * len(rows[0])] + param_total = 0 + buffer_total = 0 + submodule_names = {mod: name for name, mod in module.named_modules()} + for e in entries: + name = "" if e.mod is module else submodule_names[e.mod] + param_size = sum(t.numel() for t in e.unique_params) + buffer_size = sum(t.numel() for t in e.unique_buffers) + output_shapes = [str(list(t.shape)) for t in e.outputs] + output_dtypes = [str(t.dtype).split(".")[-1] for t in e.outputs] + rows += [ + [ + name + (":0" if len(e.outputs) >= 2 else ""), + str(param_size) if param_size else "-", + str(buffer_size) if buffer_size else "-", + (output_shapes + ["-"])[0], + (output_dtypes + ["-"])[0], + ] + ] + for idx in range(1, len(e.outputs)): + rows += [ + [name + f":{idx}", "-", "-", output_shapes[idx], output_dtypes[idx]] + ] + param_total += param_size + buffer_total += buffer_size + rows += [["---"] * len(rows[0])] + rows += [["Total", str(param_total), str(buffer_total), "-", "-"]] + + # Print table. + widths = [max(len(cell) for cell in column) for column in zip(*rows)] + for row in rows: + print( + " ".join( + cell + " " * (width - len(cell)) for cell, width in zip(row, widths) + ) + ) + return outputs + + +# ---------------------------------------------------------------------------- diff --git a/src/hirad/utils/generate_utils.py b/src/hirad/utils/generate_utils.py new file mode 100644 index 0000000..43f83b6 --- /dev/null +++ b/src/hirad/utils/generate_utils.py @@ -0,0 +1,22 @@ +import datetime +from hirad.datasets import init_dataset_from_config +from .function_utils import convert_datetime_to_cftime + + +def get_dataset_and_sampler(dataset_cfg, times, has_lead_time=False): + """ + Get a dataset and sampler for generation. + """ + (dataset, _) = init_dataset_from_config(dataset_cfg, batch_size=1) + # if has_lead_time: + # plot_times = times + # else: + # plot_times = [ + # datetime.datetime.strptime(time, "%Y-%m-%dT%H:%M:%S") + # for time in times + # ] + all_times = dataset.time() + time_indices = [all_times.index(t) for t in times] + sampler = time_indices + + return dataset, sampler \ No newline at end of file diff --git a/src/hirad/utils/inference_utils.py b/src/hirad/utils/inference_utils.py new file mode 100644 index 0000000..8665536 --- /dev/null +++ b/src/hirad/utils/inference_utils.py @@ -0,0 +1,319 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
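+
+# Inference-time helpers for CorrDiff-style downscaling: regression-based
+# ensemble-mean prediction, the diffusion sampling loop, a NetCDF writer for
+# inputs, truth and predictions, and time-range utilities.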
+ +import datetime +from typing import Optional + +import cftime +import nvtx +import torch +import tqdm + +from .function_utils import StackedRandomGenerator, time_range + +from .stochastic_sampler import stochastic_sampler +from .deterministic_sampler import deterministic_sampler + +############################################################################ +# CorrDiff Generation Utilities # +############################################################################ + + +def regression_step( + net: torch.nn.Module, + img_lr: torch.Tensor, + latents_shape: torch.Size, + lead_time_label: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Perform a regression step to produce ensemble mean prediction. + + This function takes a low-resolution input and performs a regression step to produce + an ensemble mean prediction. It processes a single instance and then replicates + the results across the batch dimension if needed. + + Parameters + ---------- + net : torch.nn.Module + U-Net model for regression. + img_lr : torch.Tensor + Low-resolution input to the network with shape (1, channels, height, width). + Must have a batch dimension of 1. + latents_shape : torch.Size + Shape of the latent representation with format + (batch_size, out_channels, image_shape_y, image_shape_x). + lead_time_label : Optional[torch.Tensor], optional + Lead time label tensor for lead time conditioning, + with shape (1, lead_time_dims). Default is None. + + Returns + ------- + torch.Tensor + Predicted ensemble mean at the next time step with shape matching latents_shape. + + Raises + ------ + ValueError + If img_lr has a batch size greater than 1. + """ + # Create a tensor of zeros with the given shape and move it to the appropriate device + x_hat = torch.zeros(latents_shape, dtype=img_lr.dtype, device=img_lr.device) + + # Safety check: avoid silently ignoring batch elements in img_lr + if img_lr.shape[0] > 1: + raise ValueError( + f"Expected img_lr to have a batch size of 1, " + f"but found {img_lr.shape[0]}." + ) + + # Perform regression on a single batch element + with torch.inference_mode(): + if lead_time_label is not None: + x = net(x=x_hat[0:1], img_lr=img_lr, lead_time_label=lead_time_label) + else: + x = net(x=x_hat[0:1], img_lr=img_lr) + + # If the batch size is greater than 1, repeat the prediction + if x_hat.shape[0] > 1: + x = x.repeat([d if i == 0 else 1 for i, d in enumerate(x_hat.shape)]) + + return x + + +def diffusion_step( + net: torch.nn.Module, + sampler_fn: callable, + img_shape: tuple, + img_out_channels: int, + rank_batches: list, + img_lr: torch.Tensor, + rank: int, + device: torch.device, + mean_hr: torch.Tensor = None, + lead_time_label: torch.Tensor = None, +) -> torch.Tensor: + + """ + Generate images using diffusion techniques as described in the relevant paper. + + This function applies a diffusion model to generate high-resolution images based on + low-resolution inputs. It supports optional conditioning on high-resolution mean + predictions and lead time labels. + + For each low-resolution sample in `img_lr`, the function generates multiple + high-resolution samples, with different random seeds, specified in `rank_batches`. + The function then concatenates these high-resolution samples across the batch dimension. + + Parameters + ---------- + net : torch.nn.Module + The diffusion model network. + sampler_fn : callable + Function used to sample images from the diffusion model. + img_shape : tuple + Shape of the images, (height, width). 
+ img_out_channels : int + Number of output channels for the image. + rank_batches : list + List of batches of seeds to process. + img_lr : torch.Tensor + Low-resolution input image with shape (seed_batch_size, channels_lr, height, width). + rank : int, optional + Rank of the current process for distributed processing. + device : torch.device, optional + Device to perform computations. + mean_hr : torch.Tensor, optional + High-resolution mean tensor to be used as an additional input, + with shape (1, channels_hr, height, width). Default is None. + lead_time_label : torch.Tensor, optional + Lead time label tensor for temporal conditioning, + with shape (batch_size, lead_time_dims). Default is None. + + Returns + ------- + torch.Tensor + Generated images concatenated across batches with shape + (seed_batch_size * len(rank_batches), out_channels, height, width). + """ + + # Check img_lr dimensions match expected shape + if img_lr.shape[2:] != img_shape: + raise ValueError( + f"img_lr shape {img_lr.shape[2:]} does not match expected shape img_shape {img_shape}" + ) + + # Check mean_hr dimensions if provided + if mean_hr is not None: + if mean_hr.shape[2:] != img_shape: + raise ValueError( + f"mean_hr shape {mean_hr.shape[2:]} does not match expected shape img_shape {img_shape}" + ) + if mean_hr.shape[0] != 1: + raise ValueError(f"mean_hr must have batch size 1, got {mean_hr.shape[0]}") + + img_lr = img_lr.to(memory_format=torch.channels_last) + + # Handling of the high-res mean + additional_args = {} + if mean_hr is not None: + additional_args["mean_hr"] = mean_hr + if lead_time_label is not None: + additional_args["lead_time_label"] = lead_time_label + + # Loop over batches + all_images = [] + for batch_seeds in tqdm.tqdm(rank_batches, unit="batch", disable=(rank != 0)): + with nvtx.annotate(f"generate {len(all_images)}", color="rapids"): + batch_size = len(batch_seeds) + if batch_size == 0: + continue + + # Initialize random generator, and generate latents + rnd = StackedRandomGenerator(device, batch_seeds) + latents = rnd.randn( + [ + img_lr.shape[0], + img_out_channels, + img_shape[0], + img_shape[1], + ], + device=device, + )#.to(memory_format=torch.channels_last) + + with torch.inference_mode(): + images = sampler_fn( + net, latents, img_lr, randn_like=rnd.randn_like, **additional_args + ) + all_images.append(images) + return torch.cat(all_images) + + +def generate(): + pass + +############################################################################ +# CorrDiff writer utilities # +############################################################################ + + +class NetCDFWriter: + """NetCDF Writer""" + + def __init__( + self, f, lat, lon, input_channels, output_channels, has_lead_time=False + ): + self._f = f + self.has_lead_time = has_lead_time + # create unlimited dimensions + f.createDimension("time") + f.createDimension("ensemble") + + if lat.shape != lon.shape: + raise ValueError("lat and lon must have the same shape") + ny, nx = lat.shape + + # create lat/lon grid + f.createDimension("x", nx) + f.createDimension("y", ny) + + v = f.createVariable("lat", "f", dimensions=("y", "x")) + # NOTE rethink this for datasets whose samples don't have constant lat-lon. 
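+        # lat/lon are stored as full 2D (y, x) fields rather than 1D axes, so
+        # curvilinear (non-regular) grids can be represented.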
+ v[:] = lat + v.standard_name = "latitude" + v.units = "degrees_north" + + v = f.createVariable("lon", "f", dimensions=("y", "x")) + v[:] = lon + v.standard_name = "longitude" + v.units = "degrees_east" + + # create time dimension + if has_lead_time: + v = f.createVariable("time", "str", ("time")) + else: + v = f.createVariable("time", "i8", ("time")) + v.calendar = "standard" + v.units = "hours since 1990-01-01 00:00:00" + + self.truth_group = f.createGroup("truth") + self.prediction_group = f.createGroup("prediction") + self.input_group = f.createGroup("input") + + for variable in output_channels: + name = variable.name + variable.level + self.truth_group.createVariable(name, "f", dimensions=("time", "y", "x")) + self.prediction_group.createVariable( + name, "f", dimensions=("ensemble", "time", "y", "x") + ) + + # setup input data in netCDF + + for variable in input_channels: + name = variable.name + variable.level + self.input_group.createVariable(name, "f", dimensions=("time", "y", "x")) + + def write_input(self, channel_name, time_index, val): + """Write input data to NetCDF file.""" + self.input_group[channel_name][time_index] = val + + def write_truth(self, channel_name, time_index, val): + """Write ground truth data to NetCDF file.""" + self.truth_group[channel_name][time_index] = val + + def write_prediction(self, channel_name, time_index, ensemble_index, val): + """Write prediction data to NetCDF file.""" + self.prediction_group[channel_name][ensemble_index, time_index] = val + + def write_time(self, time_index, time): + """Write time information to NetCDF file.""" + if self.has_lead_time: + self._f["time"][time_index] = time + else: + time_v = self._f["time"] + self._f["time"][time_index] = cftime.date2num( + time, time_v.units, time_v.calendar + ) + + +############################################################################ +# CorrDiff time utilities # +############################################################################ + + +def get_time_from_range(times_range, time_format="%Y-%m-%dT%H:%M:%S"): + """Generates a list of times within a given range. + + Args: + times_range: A list containing start time, end time, and optional interval (hours). + time_format: The format of the input times (default: "%Y-%m-%dT%H:%M:%S"). + + Returns: + A list of times within the specified range. + """ + + start_time = datetime.datetime.strptime(times_range[0], time_format) + end_time = datetime.datetime.strptime(times_range[1], time_format) + interval = ( + datetime.timedelta(hours=times_range[2]) + if len(times_range) > 2 + else datetime.timedelta(hours=1) + ) + + times = [ + t.strftime(time_format) + for t in time_range(start_time, end_time, interval, inclusive=True) + ] + return times diff --git a/src/hirad/utils/model_utils.py b/src/hirad/utils/model_utils.py new file mode 100644 index 0000000..e1cde9d --- /dev/null +++ b/src/hirad/utils/model_utils.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import torch + + +def weight_init(shape: tuple, mode: str, fan_in: int, fan_out: int): + """ + Unified routine for initializing weights and biases. + This function provides a unified interface for various weight initialization + strategies like Xavier (Glorot) and Kaiming (He) initializations. + + Parameters + ---------- + shape : tuple + The shape of the tensor to initialize. It could represent weights or biases + of a layer in a neural network. + mode : str + The mode/type of initialization to use. Supported values are: + - "xavier_uniform": Xavier (Glorot) uniform initialization. + - "xavier_normal": Xavier (Glorot) normal initialization. + - "kaiming_uniform": Kaiming (He) uniform initialization. + - "kaiming_normal": Kaiming (He) normal initialization. + fan_in : int + The number of input units in the weight tensor. For convolutional layers, + this typically represents the number of input channels times the kernel height + times the kernel width. + fan_out : int + The number of output units in the weight tensor. For convolutional layers, + this typically represents the number of output channels times the kernel height + times the kernel width. + + Returns + ------- + torch.Tensor + The initialized tensor based on the specified mode. + + Raises + ------ + ValueError + If the provided `mode` is not one of the supported initialization modes. + """ + if mode == "xavier_uniform": + return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1) + if mode == "xavier_normal": + return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape) + if mode == "kaiming_uniform": + return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1) + if mode == "kaiming_normal": + return np.sqrt(1 / fan_in) * torch.randn(*shape) + raise ValueError(f'Invalid init mode "{mode}"') diff --git a/src/hirad/utils/patching.py b/src/hirad/utils/patching.py new file mode 100644 index 0000000..6f4bc4d --- /dev/null +++ b/src/hirad/utils/patching.py @@ -0,0 +1,767 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +import random +import warnings +from abc import ABC, abstractmethod +from typing import List, Optional, Tuple, Union + +import torch +from einops import rearrange +from torch import Tensor + +""" +This module defines utilities, including classes and functions, for domain +decomposition. +""" + + +class BasePatching2D(ABC): + """ + Abstract base class for 2D image patching operations. + + This class provides a foundation for implementing various image patching + strategies. + It handles basic validation and provides abstract methods that must be + implemented by subclasses. 
+ + Parameters + ---------- + img_shape : Tuple[int, int] + The height and width of the input images (img_shape_y, img_shape_x). + patch_shape : Tuple[int, int] + The height and width of the patches (patch_shape_y, patch_shape_x) to + extract. + """ + + def __init__( + self, img_shape: Tuple[int, int], patch_shape: Tuple[int, int] + ) -> None: + # Check that img_shape and patch_shape are 2D + if len(img_shape) != 2: + raise ValueError(f"img_shape must be 2D, got {len(img_shape)}D") + if len(patch_shape) != 2: + raise ValueError(f"patch_shape must be 2D, got {len(patch_shape)}D") + + # Make sure patches fit within the image + if any(p > i for p, i in zip(patch_shape, img_shape)): + warnings.warn( + f"Patch shape {patch_shape} is larger than " + f"image shape {img_shape}. " + f"Patches will be cropped to fit within the image." + ) + self.img_shape = img_shape + self.patch_shape = tuple(min(p, i) for p, i in zip(patch_shape, img_shape)) + + @abstractmethod + def apply(self, input: Tensor, **kwargs) -> Tensor: + """ + Apply the patching operation to the input tensor. + + Parameters + ---------- + input : Tensor + Input tensor of shape (batch_size, channels, img_shape_y, + img_shape_x). + **kwargs : dict + Additional keyword arguments specific to the patching + implementation. + + Returns + ------- + Tensor + Patched tensor, shape depends on specific implementation. + """ + pass + + def fuse(self, input: Tensor, **kwargs) -> Tensor: + """ + Fuse patches back into a complete image. + + Parameters + ---------- + input : Tensor + Input tensor containing patches. + **kwargs : dict + Additional keyword arguments specific to the fusion implementation. + + Returns + ------- + Tensor + Fused tensor, shape depends on specific implementation. + + Raises + ------ + NotImplementedError + If the subclass does not implement this method. + """ + raise NotImplementedError("'fuse' method must be implemented in subclasses.") + + def global_index( + self, batch_size: int, device: Union[torch.device, str] = "cpu" + ) -> Tensor: + """ + Returns a tensor containing the global indices for each patch. + + Global indices correspond to (y, x) global grid coordinates of each + element within the original image (before patching). It is typically + used to keep track of the original position of each patch in the + original image. + + Parameters + ---------- + batch_size : int + The size of the batch of images to patch. + device : Union[torch.device, str] + Proper device to initialize global_index on. Default to `cpu` + + Returns + ------- + Tensor + A tensor of shape (self.patch_num, 2, patch_shape_y, + patch_shape_x). `global_index[:, 0, :, :]` contains the + y-coordinate (height), and `global_index[:, 1, :, :]` contains the + x-coordinate (width). + """ + Ny = torch.arange(self.img_shape[0], device=device).int() + Nx = torch.arange(self.img_shape[1], device=device).int() + grid = torch.stack(torch.meshgrid(Ny, Nx, indexing="ij"), dim=0).unsqueeze(0) + global_index = self.apply(grid).long() + return global_index + + +class RandomPatching2D(BasePatching2D): + """ + Class for randomly extracting patches from 2D images. + + This class provides utilities to randomly extract patches from images + represented as 4D tensors. It maintains a list of random patch indices + that can be reset as needed. + + Parameters + ---------- + img_shape : Tuple[int, int] + The height and width of the input images (img_shape_y, img_shape_x). 
+ patch_shape : Tuple[int, int] + The height and width of the patches (patch_shape_y, patch_shape_x) to + extract. + patch_num : int + The number of patches to extract. + + Attributes + ---------- + patch_indices : List[Tuple[int, int]] + The indices of the patches to extract from the images. These indices + correspond to the (y, x) coordinates of the lower left corner of each + patch. + + See Also + -------- + :class:`physicsnemo.utils.patching.BasePatching2D` + The base class providing the patching interface. + :class:`physicsnemo.utils.patching.GridPatching2D` + Alternative patching strategy using deterministic patch locations. + """ + + def __init__( + self, img_shape: Tuple[int, int], patch_shape: Tuple[int, int], patch_num: int + ) -> None: + """ + Initialize the RandomPatching2D object with the provided image shape, + patch shape, and number of patches to extract. + + Parameters + ---------- + img_shape : Tuple[int, int] + The height and width of the input images (img_shape_y, + img_shape_x). + patch_shape : Tuple[int, int] + The height and width of the patches (patch_shape_y, patch_shape_x) + to extract. + patch_num : int + The number of patches to extract. + + Returns + ------- + None + """ + super().__init__(img_shape, patch_shape) + self._patch_num = patch_num + # Generate the indices of the patches to extract + self.reset_patch_indices() + + @property + def patch_num(self) -> int: + """ + Get the number of patches to extract. + + Returns + ------- + int + The number of patches to extract. + """ + return self._patch_num + + def set_patch_num(self, value: int) -> None: + """ + Set the number of patches to extract and reset patch indices. + This is the only way to modify the patch_num value. + + Parameters + ---------- + value : int + The new number of patches to extract. + """ + self._patch_num = value + self.reset_patch_indices() + + def reset_patch_indices(self) -> None: + """ + Generate new random indices for the patches to extract. These are the + starting indices of the patches to extract (upper left corner). + + Returns + ------- + None + """ + self.patch_indices = [ + ( + random.randint(0, self.img_shape[0] - self.patch_shape[0]), + random.randint(0, self.img_shape[1] - self.patch_shape[1]), + ) + for _ in range(self.patch_num) + ] + return + + def get_patch_indices(self) -> List[Tuple[int, int]]: + """ + Get the current list of patch starting indices. + + These are the upper-left coordinates of each extracted patch + from the full image. + + Returns + ------- + List[Tuple[int, int]] + A list of (row, column) tuples representing patch starting positions. + """ + return self.patch_indices + + def apply( + self, + input: Tensor, + additional_input: Optional[Tensor] = None, + ) -> Tensor: + """ + Applies the patching operation by extracting patches specified by + `self.patch_indices` from the `input` Tensor. Extracted patches are + batched along the first dimension of the output. The layout of the + output assumes that for any i, `out[B * i: B * (i + 1)]` + corresponds to the same patch exacted from each batch element of + `input`. + + Arguments + --------- + input : Tensor + The input tensor representing the full image with shape + (batch_size, channels_in, img_shape_y, img_shape_x). + additional_input : Optional[Tensor], optional + If provided, it is concatenated to each patch along `dim=1`. + Must have same batch size as `input`. Bilinear interpolation + is used to interpolate `additional_input` onto a 2D grid of shape + (patch_shape_y, patch_shape_x). 
+ + Returns + ------- + Tensor + A tensor of shape (batch_size * self.patch_num, channels [+ + additional_channels], patch_shape_y, patch_shape_x). If + `additional_input` is provided, its channels are concatenated + along the channel dimension. + """ + B = input.shape[0] + out = torch.zeros( + B * self.patch_num, + ( + input.shape[1] + + (additional_input.shape[1] if additional_input is not None else 0) + ), + self.patch_shape[0], + self.patch_shape[1], + device=input.device, + ) + out = out.to( + memory_format=torch.channels_last + if input.is_contiguous(memory_format=torch.channels_last) + else torch.contiguous_format + ) + if additional_input is not None: + add_input_interp = torch.nn.functional.interpolate( + input=additional_input, size=self.patch_shape, mode="bilinear" + ) + + for i, (py, px) in enumerate(self.patch_indices): + if additional_input is not None: + out[B * i : B * (i + 1),] = torch.cat( + ( + input[ + :, + :, + py : py + self.patch_shape[0], + px : px + self.patch_shape[1], + ], + add_input_interp, + ), + dim=1, + ) + else: + out[B * i : B * (i + 1),] = input[ + :, + :, + py : py + self.patch_shape[0], + px : px + self.patch_shape[1], + ] + return out + + +class GridPatching2D(BasePatching2D): + """ + Class for deterministically extracting patches from 2D images in a grid pattern. + + This class provides utilities to extract patches from images in a + deterministic manner, with configurable overlap and boundary pixels. + The patches are extracted in a grid-like pattern covering the entire image. + + Parameters + ---------- + img_shape : Tuple[int, int] + The height and width of the input images (img_shape_y, img_shape_x). + patch_shape : Tuple[int, int] + The height and width of the patches (patch_shape_y, patch_shape_x) to + extract. + overlap_pix : int, optional + Number of pixels to overlap between adjacent patches, by default 0. + boundary_pix : int, optional + Number of pixels to crop as boundary from each patch, by default 0. + + Attributes + ---------- + patch_num : int + Total number of patches that will be extracted from the image, + calculated as patch_num_x * patch_num_y. + + See Also + -------- + :class:`physicsnemo.utils.patching.BasePatching2D` + The base class providing the patching interface. + :class:`physicsnemo.utils.patching.RandomPatching2D` + Alternative patching strategy using random patch locations. + """ + + def __init__( + self, + img_shape: Tuple[int, int], + patch_shape: Tuple[int, int], + overlap_pix: int = 0, + boundary_pix: int = 0, + ): + super().__init__(img_shape, patch_shape) + self.overlap_pix = overlap_pix + self.boundary_pix = boundary_pix + patch_num_x = math.ceil( + img_shape[1] / (patch_shape[1] - overlap_pix - boundary_pix) + ) + patch_num_y = math.ceil( + img_shape[0] / (patch_shape[0] - overlap_pix - boundary_pix) + ) + self.patch_num = patch_num_x * patch_num_y + + def apply( + self, + input: Tensor, + additional_input: Optional[Tensor] = None, + ) -> Tensor: + """ + Apply deterministic patching to the input tensor. + + Splits the input tensor into patches in a grid-like pattern. Can + optionally concatenate additional interpolated data to each patch. + Extracted patches are batched along the first dimension of the output. + The layout of the output assumes that for any i, `out[B * i: B * (i + 1)]` + corresponds to the same patch exacted from each batch element of + `input`. The patches can be reconstructed back into the original image + using the fuse method. 
+ + Parameters + ---------- + input : Tensor + Input tensor of shape (batch_size, channels, img_shape_y, + img_shape_x). + additional_input : Optional[Tensor], optional + Additional data to concatenate to each patch. Will be interpolated + to match patch dimensions. Shape must be (batch_size, + additional_channels, H, W), by default None. + + Returns + ------- + Tensor + Tensor containing patches with shape (batch_size * patch_num, + channels [+ additional_channels], patch_shape_y, patch_shape_x). + If additional_input is provided, its channels are concatenated + along the channel dimension. + + See Also + -------- + :func:`physicsnemo.utils.patching.image_batching` + The underlying function used to perform the patching operation. + """ + if additional_input is not None: + add_input_interp = torch.nn.functional.interpolate( + input=additional_input, size=self.patch_shape, mode="bilinear" + ) + else: + add_input_interp = None + out = image_batching( + input=input, + patch_shape_y=self.patch_shape[0], + patch_shape_x=self.patch_shape[1], + overlap_pix=self.overlap_pix, + boundary_pix=self.boundary_pix, + input_interp=add_input_interp, + ) + return out + + def fuse(self, input: Tensor, batch_size: int) -> Tensor: + """ + Fuse patches back into a complete image. + + Reconstructs the original image by stitching together patches, + accounting for overlapping regions and boundary pixels. In overlapping + regions, values are averaged. + + Parameters + ---------- + input : Tensor + Input tensor containing patches with shape (batch_size * patch_num, + channels, patch_shape_y, patch_shape_x). + batch_size : int + The original batch size before patching. + + Returns + ------- + Tensor + Reconstructed image tensor with shape (batch_size, channels, + img_shape_y, img_shape_x). + + See Also + -------- + :func:`physicsnemo.utils.patching.image_fuse` + The underlying function used to perform the fusion operation. + """ + out = image_fuse( + input=input, + img_shape_y=self.img_shape[0], + img_shape_x=self.img_shape[1], + batch_size=batch_size, + overlap_pix=self.overlap_pix, + boundary_pix=self.boundary_pix, + ) + return out + + +def image_batching( + input: Tensor, + patch_shape_y: int, + patch_shape_x: int, + overlap_pix: int, + boundary_pix: int, + input_interp: Optional[Tensor] = None, +) -> Tensor: + """ + Splits a full image into a batch of patched images. + + This function takes a full image and splits it into patches, adding padding + where necessary. It can also concatenate additional interpolated data to + each patch if provided. + + Parameters + ---------- + input : Tensor + The input tensor representing the full image with shape (batch_size, + channels, img_shape_y, img_shape_x). + patch_shape_y : int + The height (y-dimension) of each image patch. + patch_shape_x : int + The width (x-dimension) of each image patch. + overlap_pix : int + The number of overlapping pixels between adjacent patches. + boundary_pix : int + The number of pixels to crop as a boundary from each patch. + input_interp : Optional[Tensor], optional + Optional additional data to concatenate to each patch with shape + (batch_size, interp_channels, patch_shape_y, patch_shape_x). + By default None. + + Returns + ------- + Tensor + A tensor containing the image patches, with shape (total_patches * + batch_size, channels [+ interp_channels], patch_shape_x, + patch_shape_y). 
+ """ + # Infer sizes from input image + batch_size, _, img_shape_y, img_shape_x = input.shape + + # Safety check: make sure patch_shapes are large enough to accommodate + # overlaps and boundaries pixels + if (patch_shape_x - overlap_pix - boundary_pix) < 1: + raise ValueError( + f"patch_shape_x must verify patch_shape_x ({patch_shape_x}) >= " + f"1 + overlap_pix ({overlap_pix}) + boundary_pix ({boundary_pix})" + ) + if (patch_shape_y - overlap_pix - boundary_pix) < 1: + raise ValueError( + f"patch_shape_y must verify patch_shape_y ({patch_shape_y}) >= " + f"1 + overlap_pix ({overlap_pix}) + boundary_pix ({boundary_pix})" + ) + # Safety check: validate input_interp dimensions if provided + if input_interp is not None: + if input_interp.shape[0] != batch_size: + raise ValueError( + f"input_interp batch size ({input_interp.shape[0]}) must match " + f"input batch size ({batch_size})" + ) + if (input_interp.shape[2] != patch_shape_y) or ( + input_interp.shape[3] != patch_shape_x + ): + raise ValueError( + f"input_interp patch shape ({input_interp.shape[2]}, {input_interp.shape[3]}) " + f"must match specified patch shape ({patch_shape_y}, {patch_shape_x})" + ) + + # Safety check: make sure patch_shape is large enough in comparison to + # overlap_pix and boundary_pix. Otherwise, number of patches extracted by + # unfold differs from the expected number of patches. + if patch_shape_x <= overlap_pix + 2 * boundary_pix: + raise ValueError( + f"patch_shape_x ({patch_shape_x}) must verify " + f"patch_shape_x ({patch_shape_x}) > " + f"overlap_pix ({overlap_pix}) + 2 * boundary_pix ({boundary_pix})" + ) + if patch_shape_y <= overlap_pix + 2 * boundary_pix: + raise ValueError( + f"patch_shape_y ({patch_shape_y}) must verify " + f"patch_shape_y ({patch_shape_y}) > " + f"overlap_pix ({overlap_pix}) + 2 * boundary_pix ({boundary_pix})" + ) + + patch_num_x = math.ceil(img_shape_x / (patch_shape_x - overlap_pix - boundary_pix)) + patch_num_y = math.ceil(img_shape_y / (patch_shape_y - overlap_pix - boundary_pix)) + padded_shape_x = ( + (patch_shape_x - overlap_pix - boundary_pix) * (patch_num_x - 1) + + patch_shape_x + + boundary_pix + ) + padded_shape_y = ( + (patch_shape_y - overlap_pix - boundary_pix) * (patch_num_y - 1) + + patch_shape_y + + boundary_pix + ) + pad_x_right = padded_shape_x - img_shape_x - boundary_pix + pad_y_right = padded_shape_y - img_shape_y - boundary_pix + image_padding = torch.nn.ReflectionPad2d( + (boundary_pix, pad_x_right, boundary_pix, pad_y_right) + ).to( + input.device + ) # (padding_left,padding_right,padding_top,padding_bottom) + input_padded = image_padding(input) + patch_num = patch_num_x * patch_num_y + x_unfold = torch.nn.functional.unfold( + input=input_padded.view(_cast_type(input_padded)), # Cast to float + kernel_size=(patch_shape_y, patch_shape_x), + stride=( + patch_shape_y - overlap_pix - boundary_pix, + patch_shape_x - overlap_pix - boundary_pix, + ), + ).to(input_padded.dtype) + x_unfold = rearrange( + x_unfold, + "b (c p_h p_w) (nb_p_h nb_p_w) -> (nb_p_w nb_p_h b) c p_h p_w", + p_h=patch_shape_y, + p_w=patch_shape_x, + nb_p_h=patch_num_y, + nb_p_w=patch_num_x, + ) + if input_interp is not None: + input_interp_repeated = rearrange( + torch.repeat_interleave( + input=input_interp, + repeats=patch_num, + dim=0, + output_size=x_unfold.shape[0], + ), + "(b p) c h w -> (p b) c h w", + p=patch_num, + ) + return torch.cat((x_unfold, input_interp_repeated), dim=1) + else: + return x_unfold + + +def image_fuse( + input: Tensor, + img_shape_y: int, + img_shape_x: int, + 
batch_size: int, + overlap_pix: int, + boundary_pix: int, +) -> Tensor: + """ + Reconstructs a full image from a batch of patched images. Reverts the patching + operation performed by image_batching(). + + This function takes a batch of image patches and reconstructs the full + image by stitching the patches together. The function accounts for + overlapping and boundary pixels, ensuring that overlapping areas are + averaged. + + Parameters + ---------- + input : Tensor + The input tensor containing the image patches with shape (patch_num * batch_size, channels, patch_shape_y, patch_shape_x). + img_shape_y : int + The height (y-dimension) of the original full image. + img_shape_x : int + The width (x-dimension) of the original full image. + batch_size : int + The original batch size before patching. + overlap_pix : int + The number of overlapping pixels between adjacent patches. + boundary_pix : int + The number of pixels to crop as a boundary from each patch. + + Returns + ------- + Tensor + The reconstructed full image tensor with shape (batch_size, channels, + img_shape_y, img_shape_x). + + See Also + -------- + :func:`physicsnemo.utils.patching.image_batching` + The function this reverses, which splits images into patches. + """ + + # Infer sizes from input image shape + patch_shape_y, patch_shape_x = input.shape[2], input.shape[3] + + # Calculate the number of patches in each dimension + patch_num_x = math.ceil(img_shape_x / (patch_shape_x - overlap_pix - boundary_pix)) + patch_num_y = math.ceil(img_shape_y / (patch_shape_y - overlap_pix - boundary_pix)) + + # Calculate the shape of the input after padding + padded_shape_x = ( + (patch_shape_x - overlap_pix - boundary_pix) * (patch_num_x - 1) + + patch_shape_x + + boundary_pix + ) + padded_shape_y = ( + (patch_shape_y - overlap_pix - boundary_pix) * (patch_num_y - 1) + + patch_shape_y + + boundary_pix + ) + # Calculate the shape of the padding to add to input + pad_x_right = padded_shape_x - img_shape_x - boundary_pix + pad_y_right = padded_shape_y - img_shape_y - boundary_pix + pad = (boundary_pix, pad_x_right, boundary_pix, pad_y_right) + + # Count local overlaps between patches + input_ones = torch.ones( + (batch_size, input.shape[1], padded_shape_y, padded_shape_x), + device=input.device, + ) + overlap_count = torch.nn.functional.unfold( + input=input_ones, + kernel_size=(patch_shape_y, patch_shape_x), + stride=( + patch_shape_y - overlap_pix - boundary_pix, + patch_shape_x - overlap_pix - boundary_pix, + ), + ) + overlap_count = torch.nn.functional.fold( + input=overlap_count, + output_size=(padded_shape_y, padded_shape_x), + kernel_size=(patch_shape_y, patch_shape_x), + stride=( + patch_shape_y - overlap_pix - boundary_pix, + patch_shape_x - overlap_pix - boundary_pix, + ), + ) + + # Reshape input to make it 3D to apply fold + x = rearrange( + input, + "(nb_p_w nb_p_h b) c p_h p_w -> b (c p_h p_w) (nb_p_h nb_p_w)", + p_h=patch_shape_y, + p_w=patch_shape_x, + nb_p_h=patch_num_y, + nb_p_w=patch_num_x, + ) + # Stitch patches together (by summing over overlapping patches) + x_folded = torch.nn.functional.fold( + input=x, + output_size=(padded_shape_y, padded_shape_x), + kernel_size=(patch_shape_y, patch_shape_x), + stride=( + patch_shape_y - overlap_pix - boundary_pix, + patch_shape_x - overlap_pix - boundary_pix, + ), + ) + + # Remove padding + x_no_padding = x_folded[ + ..., pad[2] : pad[2] + img_shape_y, pad[0] : pad[0] + img_shape_x + ] + overlap_count_no_padding = overlap_count[ + ..., pad[2] : pad[2] + img_shape_y, pad[0] : 
pad[0] + img_shape_x + ] + + # Normalize by overlap count + return x_no_padding / overlap_count_no_padding + + +def _cast_type(input: Tensor) -> torch.dtype: + """Return float type based on input tensor type. + + Parameters + ---------- + input : Tensor + Input tensor to determine float type from + + Returns + ------- + torch.dtype + Float type corresponding to input tensor type for int32/64, + otherwise returns original dtype + """ + if input.dtype == torch.int32: + return torch.float32 + elif input.dtype == torch.int64: + return torch.float64 + else: + return input.dtype diff --git a/src/hirad/utils/stochastic_sampler.py b/src/hirad/utils/stochastic_sampler.py new file mode 100644 index 0000000..198fde4 --- /dev/null +++ b/src/hirad/utils/stochastic_sampler.py @@ -0,0 +1,277 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Callable, Optional + +import torch +from torch import Tensor + +from hirad.utils.patching import GridPatching2D + + +def stochastic_sampler( + net: torch.nn.Module, + latents: torch.Tensor, + img_lr: torch.Tensor, + class_labels: Optional[Tensor] = None, + randn_like: Callable[[Tensor], Tensor] = torch.randn_like, + patching: Optional[GridPatching2D] = None, + mean_hr: Optional[torch.Tensor] = None, + lead_time_label: Optional[torch.Tensor] = None, + num_steps: int = 18, + sigma_min: float = 0.002, + sigma_max: float = 800, + rho: float = 7, + S_churn: float = 0, + S_min: float = 0, + S_max: float = float("inf"), + S_noise: float = 1, +) -> torch.Tensor: + """ + Proposed EDM sampler (Algorithm 2) with minor changes to enable + super-resolution and patch-based diffusion. + + Parameters + ---------- + net : torch.nn.Module + The neural network model that generates denoised images from noisy + inputs. + Expected signature: `net(x, x_lr, t_hat, class_labels, + lead_time_label=lead_time_label, embedding_selector=embedding_selector)`, + where: + x (torch.Tensor): Noisy input of shape (batch_size, C_out, H, W) + x_lr (torch.Tensor): Conditioning input of shape (batch_size, C_cond, H, W) + t_hat (torch.Tensor): Noise level of shape (batch_size, 1, 1, 1) or scalar + class_labels (torch.Tensor, optional): Optional class labels + lead_time_label (torch.Tensor, optional): Optional lead time labels + embedding_selector (callable, optional): Function to select + positional embeddings. Used for patch-based diffusion. + Returns: + torch.Tensor: Denoised prediction of shape (batch_size, C_out, H, W) + + Required attributes: + sigma_min (float): Minimum supported noise level for the model + sigma_max (float): Maximum supported noise level for the model + round_sigma (callable): Method to convert sigma values to tensor representation + latents : Tensor + The latent variables (e.g., noise) used as the initial input for the + sampler. Has shape (batch_size, C_out, img_shape_y, img_shape_x). 
+ img_lr : Tensor + Low-resolution input image for conditioning the super-resolution + process. Must have shape (batch_size, C_lr, img_lr_ shape_y, + img_lr_shape_x). + class_labels : Optional[Tensor], optional + Class labels for conditional generation, if required by the model. By + default None. + randn_like : Callable[[Tensor], Tensor] + Function to generate random noise with the same shape as the input + tensor. + By default torch.randn_like. + patching : Optional[GridPatching2D], optional + A patching utility for patch-based diffusion. Implements methods to + extract patches from an image and batch the patches along `dim=0`. + Should also implement a `fuse` method to reconstruct the original image + from a batch of patches. See + :class:`physicsnemo.utils.patching.GridPatching2D` for details. By + default None, in which case non-patched diffusion is used. + mean_hr : Optional[Tensor], optional + Optional tensor containing mean high-resolution images for + conditioning. Must have same height and width as `img_lr`, with shape + (B_hr, C_hr, img_lr_shape_y, img_lr_shape_x) where the batch dimension + B_hr can be either 1, either equal to batch_size, or can be omitted. If + B_hr = 1 or is omitted, `mean_hr` will be expanded to match the shape + of `img_lr`. By default None. + lead_time_label : Optional[Tensor], optional + Optional lead time labels. By default None. + num_steps : int + Number of time steps for the sampler. By default 18. + sigma_min : float + Minimum noise level. By default 0.002. + sigma_max : float + Maximum noise level. By default 800. + rho : float + Exponent used in the time step discretization. By default 7. + S_churn : float + Churn parameter controlling the level of noise added in each step. By + default 0. + S_min : float + Minimum time step for applying churn. By default 0. + S_max : float + Maximum time step for applying churn. By default float("inf"). + S_noise : float + Noise scaling factor applied during the churn step. By default 1. + + Returns + ------- + Tensor + The final denoised image produced by the sampler. Same shape as + `latents`: (batch_size, C_out, img_shape_y, img_shape_x). + + See Also + -------- + :class:`physicsnemo.models.diffusion.EDMPrecondSuperResolution`: A model + wrapper that provides preconditioning for super-resolution diffusion + models and implements the required interface for this sampler. + """ + + # Adjust noise levels based on what's supported by the network. + # Proposed EDM sampler (Algorithm 2) with minor changes to enable super-resolution. + sigma_min = max(sigma_min, net.sigma_min) + sigma_max = min(sigma_max, net.sigma_max) + + if patching is not None and not isinstance(patching, GridPatching2D): + raise ValueError("patching must be an instance of GridPatching2D.") + + # Safety check: if patching is used then img_lr and latents must have same + # height and width, otherwise there is mismatch in the number + # of patches extracted to form the final batch_size. + if patching: + if img_lr.shape[-2:] != latents.shape[-2:]: + raise ValueError( + f"img_lr and latents must have the same height and width, " + f"but found {img_lr.shape[-2:]} vs {latents.shape[-2:]}. " + ) + # img_lr and latents must also have the same batch_size, otherwise mismatch + # when processed by the network + if img_lr.shape[0] != latents.shape[0]: + raise ValueError( + f"img_lr and latents must have the same batch size, but found " + f"{img_lr.shape[0]} vs {latents.shape[0]}." + ) + + # Time step discretization. 
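+    # EDM noise schedule (Karras et al., 2022): sigma_i = (sigma_max^(1/rho)
+    # + i/(N-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho for i = 0..N-1,
+    # decreasing from sigma_max to sigma_min; a final sigma_N = 0 is appended.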
+ step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) + t_steps = ( + sigma_max ** (1 / rho) + + step_indices + / (num_steps - 1) + * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) + ) ** rho + t_steps = torch.cat( + [net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])] + ) # t_N = 0 + + batch_size = img_lr.shape[0] + + # conditioning = [mean_hr, img_lr, global_lr, pos_embd] + x_lr = img_lr + if mean_hr is not None: + if mean_hr.shape[-2:] != img_lr.shape[-2:]: + raise ValueError( + f"mean_hr and img_lr must have the same height and width, " + f"but found {mean_hr.shape[-2:]} vs {img_lr.shape[-2:]}." + ) + x_lr = torch.cat((mean_hr.expand(x_lr.shape[0], -1, -1, -1), x_lr), dim=1) + + # input and position padding + patching + if patching: + # Patched conditioning [x_lr, mean_hr] + # (batch_size * patch_num, C_in + C_out, patch_shape_y, patch_shape_x) + x_lr = patching.apply(input=x_lr, additional_input=img_lr) + + # Function to select the correct positional embedding for each patch + def patch_embedding_selector(emb): + # emb: (N_pe, image_shape_y, image_shape_x) + # return: (batch_size * patch_num, N_pe, patch_shape_y, patch_shape_x) + return patching.apply(emb[None].expand(batch_size, -1, -1, -1)) + + else: + patch_embedding_selector = None + + # Main sampling loop. + x_next = latents.to(torch.float64) * t_steps[0] + for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 + x_cur = x_next + # Increase noise temporarily. + gamma = S_churn / num_steps if S_min <= t_cur <= S_max else 0 + t_hat = net.round_sigma(t_cur + gamma * t_cur) + + x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur) + + # Euler step. Perform patching operation on score tensor if patch-based + # generation is used denoised = net(x_hat, t_hat, + # class_labels,lead_time_label=lead_time_label).to(torch.float64) + + x_hat_batch = (patching.apply(input=x_hat) if patching else x_hat).to( + latents.device + ) + x_lr = x_lr.to(latents.device) + + if lead_time_label is not None: + denoised = net( + x_hat_batch, + x_lr, + t_hat, + class_labels, + lead_time_label=lead_time_label, + embedding_selector=patch_embedding_selector, + ).to(torch.float64) + else: + # print("Sizes") + # print(x_hat_batch.shape) + # print(x_lr.shape) + # print(t_hat) + # print(class_labels) + # print(global_index) + denoised = net( + x_hat_batch, + x_lr, + t_hat, + class_labels, + embedding_selector=patch_embedding_selector, + ).to(torch.float64) + if patching: + # Un-patch the denoised image + # (batch_size, C_out, img_shape_y, img_shape_x) + denoised = patching.fuse(input=denoised, batch_size=batch_size) + + d_cur = (x_hat - denoised) / t_hat + x_next = x_hat + (t_next - t_hat) * d_cur + + # Apply 2nd order correction. 
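+        # Heun step: re-evaluate the denoiser at (x_next, t_next) and replace the
+        # Euler update with the average of the two slopes,
+        # x_next = x_hat + (t_next - t_hat) * (d_cur + d_prime) / 2.
+        # Skipped on the last step, where t_next = 0.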
+ if i < num_steps - 1: + # Patched input + # (batch_size * patch_num, C_out, patch_shape_y, patch_shape_x) + x_next_batch = (patching.apply(input=x_next) if patching else x_next).to( + latents.device + ) + + if lead_time_label is not None: + denoised = net( + x_next_batch, + x_lr, + t_next, + class_labels, + lead_time_label=lead_time_label, + embedding_selector=patch_embedding_selector, + ).to(torch.float64) + else: + denoised = net( + x_next_batch, + x_lr, + t_next, + class_labels, + embedding_selector=patch_embedding_selector, + ).to(torch.float64) + if patching: + # Un-patch the denoised image + # (batch_size, C_out, img_shape_y, img_shape_x) + denoised = patching.fuse(input=denoised, batch_size=batch_size) + + d_prime = (x_next - denoised) / t_next + x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) + return x_next diff --git a/src/hirad/utils/train_helpers.py b/src/hirad/utils/train_helpers.py new file mode 100644 index 0000000..218d6f1 --- /dev/null +++ b/src/hirad/utils/train_helpers.py @@ -0,0 +1,117 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import numpy as np +from omegaconf import ListConfig +import warnings + + +def set_patch_shape(img_shape, patch_shape): + img_shape_y, img_shape_x = img_shape + patch_shape_y, patch_shape_x = patch_shape + if (patch_shape_x is None) or (patch_shape_x > img_shape_x): + patch_shape_x = img_shape_x + if (patch_shape_y is None) or (patch_shape_y > img_shape_y): + patch_shape_y = img_shape_y + if patch_shape_x == img_shape_x and patch_shape_y == img_shape_y: + use_patching = False + else: + use_patching = True + if use_patching: + if patch_shape_x != patch_shape_y: + warnings.warn( + f"You are using rectangular patches " + f"of shape {(patch_shape_y, patch_shape_x)}, " + f"which are an experimental feature." + ) + raise NotImplementedError("Rectangular patch not supported yet") + if patch_shape_x % 32 != 0 or patch_shape_y % 32 != 0: + raise ValueError("Patch shape needs to be a multiple of 32") + return use_patching, (img_shape_y, img_shape_x), (patch_shape_y, patch_shape_x) + + +def set_seed(rank): + """ + Set seeds for NumPy and PyTorch to ensure reproducibility in distributed settings + """ + np.random.seed(rank % (1 << 31)) + torch.manual_seed(np.random.randint(1 << 31)) + + +def configure_cuda_for_consistent_precision(): + """ + Configures CUDA and cuDNN settings to ensure consistent precision by + disabling TensorFloat-32 (TF32) and reduced precision settings. 
+ """ + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False + + +def compute_num_accumulation_rounds(total_batch_size, batch_size_per_gpu, world_size): + """ + Calculate the total batch size per GPU in a distributed setting, log the batch size per GPU, ensure it's within valid limits, + determine the number of accumulation rounds, and validate that the global batch size matches the expected value. + """ + batch_gpu_total = total_batch_size // world_size + batch_size_per_gpu = batch_size_per_gpu + if batch_size_per_gpu is None or batch_size_per_gpu > batch_gpu_total: + batch_size_per_gpu = batch_gpu_total + num_accumulation_rounds = batch_gpu_total // batch_size_per_gpu + if total_batch_size != batch_size_per_gpu * num_accumulation_rounds * world_size: + raise ValueError( + "total_batch_size must be equal to batch_size_per_gpu * num_accumulation_rounds * world_size" + ) + return batch_gpu_total, num_accumulation_rounds + + +def handle_and_clip_gradients(model, grad_clip_threshold=None): + """ + Handles NaNs and infinities in the gradients and optionally clips the gradients. + + Parameters: + - model (torch.nn.Module): The model whose gradients need to be processed. + - grad_clip_threshold (float, optional): The threshold for gradient clipping. If None, no clipping is performed. + """ + # Replace NaNs and infinities in gradients + for param in model.parameters(): + if param.grad is not None: + torch.nan_to_num( + param.grad, nan=0.0, posinf=1e5, neginf=-1e5, out=param.grad + ) + + # Clip gradients if a threshold is provided + if grad_clip_threshold is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_threshold) + + +def parse_model_args(args): + """Convert ListConfig values in args to tuples.""" + return {k: tuple(v) if isinstance(v, ListConfig) else v for k, v in args.items()} + + +def is_time_for_periodic_task( + cur_nimg, freq, done, batch_size, rank, rank_0_only=False +): + """Should we perform a task that is done every `freq` samples?""" + if rank_0_only and rank != 0: + return False + elif done: # Run periodic tasks also at the end of training + return True + else: + return cur_nimg % freq < batch_size
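
The helpers in `train_helpers.py` are meant to be composed inside a training loop. Below is a minimal sketch of how `compute_num_accumulation_rounds`, `handle_and_clip_gradients`, and `is_time_for_periodic_task` might fit together; the model, data, and batch sizes are hypothetical placeholders, not values from this repository.

```
# Minimal sketch; model, data, and sizes are hypothetical placeholders.
import torch
from hirad.utils.train_helpers import (
    compute_num_accumulation_rounds,
    handle_and_clip_gradients,
    is_time_for_periodic_task,
)

world_size, rank = 4, 0                          # from the distributed launcher
total_batch_size, batch_size_per_gpu = 256, 16

# 256 samples per optimizer step -> 64 per GPU -> 4 accumulation rounds of 16.
batch_gpu_total, num_rounds = compute_num_accumulation_rounds(
    total_batch_size, batch_size_per_gpu, world_size
)

model = torch.nn.Linear(8, 8)                    # placeholder network
optimizer = torch.optim.Adam(model.parameters())

cur_nimg, done = 0, False
while not done:
    optimizer.zero_grad()
    for _ in range(num_rounds):
        x = torch.randn(batch_size_per_gpu, 8)
        loss = model(x).square().mean()
        (loss / num_rounds).backward()
    # Zero out NaN/inf gradients, then clip to the given norm before stepping.
    handle_and_clip_gradients(model, grad_clip_threshold=1.0)
    optimizer.step()

    cur_nimg += total_batch_size
    done = cur_nimg >= 10 * total_batch_size
    # Run a periodic task roughly every 5 * total_batch_size samples, rank 0 only.
    if is_time_for_periodic_task(
        cur_nimg, 5 * total_batch_size, done, total_batch_size, rank, rank_0_only=True
    ):
        print(f"periodic task at {cur_nimg} samples")
```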