Add EddyFormer #1237
Open
mrlazy1708 wants to merge 8 commits into NVIDIA:main from mrlazy1708:main
Commits (8)
- 36ad0bd: implement EddyFormer
- 9cb780a: fix format issue
- a4d7a65: verify rope dimension
- 6bf5617: fix device and docstring
- 5ca494c: fix import and remove comments
- b839de8: use ddp; change to rel l2 loss; add checkpointing
- ff2947c: switch to physicsnemo.Module; add use_scale; separate EddyFormerConfi…
- 9ba5381: Merge branch 'main' into main
New file: README.md
# EddyFormer for 3D Isotropic Turbulence

This example demonstrates how to use the EddyFormer model to simulate
three-dimensional isotropic turbulence. It runs on a single GPU.

## Problem Overview

This example focuses on **three-dimensional homogeneous isotropic turbulence (HIT)** sustained by large-scale forcing. The flow is governed by the incompressible Navier–Stokes equations with an external forcing term:

\[
\frac{\partial \mathbf{u}}{\partial t} + \mathbf{u} \cdot \nabla \mathbf{u}
= -\nabla p + \nu \nabla^2 \mathbf{u} + \mathbf{f}(\mathbf{x}),
\qquad \nabla \cdot \mathbf{u} = 0
\]

where:

- **\(\mathbf{u}(\mathbf{x}, t)\)** — velocity field in a 3D periodic domain
- **\(\nu = 0.01\)** — kinematic viscosity
- **\(\mathbf{f}(\mathbf{x})\)** — isotropic forcing applied at the largest scales

### Forcing Mechanism

To maintain statistically steady turbulence, a **constant-power forcing** is applied to the lowest Fourier modes (\(|\mathbf{k}| \le 1\)). The forcing injects a prescribed amount of energy \(P_{\text{in}} = 1.0\) into the system:

\[
\mathbf{f}(\mathbf{x}) =
\frac{P_{\text{in}}}{E_1}
\sum_{\substack{|\mathbf{k}| \le 1 \\ \mathbf{k} \neq 0}}
\hat{\mathbf{u}}_{\mathbf{k}} e^{i \mathbf{k} \cdot \mathbf{x}}
\]

where:

\[
E_1 = \frac{1}{2}
\sum_{|\mathbf{k}| \le 1}
\hat{\mathbf{u}}_{\mathbf{k}} \cdot \hat{\mathbf{u}}_{\mathbf{k}}^{*}
\]

is the kinetic energy contained in the forced low-wavenumber modes.
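For concreteness, below is a minimal spectral sketch of this forcing term. It is illustrative only, not code from this example: it assumes `u` is a single velocity snapshot of shape `(3, N, N, N)` on a periodic grid with integer wavenumbers, and an unnormalized FFT convention for the Parseval factor.

```python
import torch

def constant_power_forcing(u: torch.Tensor, p_in: float = 1.0) -> torch.Tensor:
    """Constant-power forcing on the |k| <= 1 shell (illustrative sketch)."""
    n = u.shape[-1]
    u_hat = torch.fft.fftn(u, dim=(-3, -2, -1))
    k = torch.fft.fftfreq(n, d=1.0 / n)            # integer wavenumbers 0, 1, ..., -1
    kx, ky, kz = torch.meshgrid(k, k, k, indexing="ij")
    mask = kx**2 + ky**2 + kz**2 <= 1.0            # forced shell |k| <= 1
    mask[0, 0, 0] = False                          # exclude the k = 0 mean mode
    # Low-pass field: the sum of u_hat_k e^{ikx} over the forced modes
    u_low = torch.fft.ifftn(u_hat * mask, dim=(-3, -2, -1)).real
    # E_1 via Parseval; the 1/n^6 factor assumes torch's unnormalized FFT
    e1 = 0.5 * (u_hat * mask).abs().pow(2).sum() / n**6
    return (p_in / e1) * u_low
```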
Under this forcing, the flow reaches a **statistically steady state** with a Taylor-scale Reynolds number of \(\mathrm{Re}_\lambda \approx 94\).

### Task Description

The objective of this example is to **predict the future velocity field** of the turbulent flow. Given \(\mathbf{u}(\mathbf{x}, t)\), the task is:

> **Predict the velocity field \(\mathbf{u}(\mathbf{x}, t + \Delta t)\) with \(\Delta t = 0.5\).**

This requires modeling nonlinear, chaotic, multi-scale turbulent dynamics, including:

- energy injection at large scales
- nonlinear transfer across the inertial range
- dissipation at the smallest scales

### Dataset Summary

- **DNS resolution:** \(384^3\) (used to generate the dataset)
- **Stored dataset resolution:** \(96^3\)
- **Kolmogorov scale resolution:** ~0.5 η
- **Forcing:** applied to modes with \(|\mathbf{k}| \le 1\)
- **Viscosity:** \(\nu = 0.01\)
- **Input power:** \(P_{\text{in}} = 1.0\)
- **Flow regime:** statistically steady HIT at \(\mathrm{Re}_\lambda \approx 94\)
## Prerequisites

Install the required dependencies by running:

```bash
pip install -r requirements.txt
```

## Download the Dataset

The dataset is publicly available on [Hugging Face](https://huggingface.co/datasets/ydu11/re94).
To download it, run (you may need to install the Hugging Face CLI first):

```bash
bash download_dataset.sh
```

## Getting Started

To train the model, run:

```bash
python train_ef_isotropic.py
```
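Because the script is driven by Hydra, any entry in `config.yaml` can also be overridden from the command line without editing the file; the values below are illustrative:

```bash
python train_ef_isotropic.py training.batch_size=2 training.learning_rate=5e-4
```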
## References

- [EddyFormer: Accelerated Neural Simulations of Three-Dimensional Turbulence at Scale](https://arxiv.org/abs/2510.24173)
New file: config.yaml
```yaml
model:
  idim: 3
  odim: 3
  hdim: 32
  num_layers: 4
  use_scale: true
  layer_config:
    basis: legendre
    mesh: [8, 8, 8]
    mode: [10, 10, 10]
    mode_les: [5, 5, 5]
    kernel_size: [2, 2, 2]
    kernel_size_les: [2, 2, 2]
    ffn_dim: 128
    activation: GELU
    num_heads: 4
    heads_dim: 32

training:
  dataset: data/ns3d-re94
  result_dir: outputs/ef-re94
  t: 0.5
  batch_size: 4
  num_epochs: 1
  learning_rate: 1e-3
  ckpt_every: 1000
```
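To sanity-check edits to this file before launching a run, it can be loaded and printed with OmegaConf (the library Hydra builds on); a minimal sketch, assuming it is run from the example directory:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("config.yaml")
print(cfg.model.hdim)                   # 32
print(OmegaConf.to_yaml(cfg.training))  # resolved training block
```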
New file: download_dataset.sh
```bash
hf download --repo-type dataset ydu11/re94 --local-dir ${1:-data/ns3d-re94}
```
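The script forwards its first argument as the download directory, defaulting to `data/ns3d-re94`; for example:

```bash
bash download_dataset.sh /path/to/alternate/dir
```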
New file: requirements.txt
```text
hydra-core>=1.2.0
termcolor>=2.1.1
```
New file: examples/cfd/isotropic_eddyformer/train_ef_isotropic.py (139 additions)
```python
import os
from typing import Tuple

import numpy as np

import hydra
from omegaconf import DictConfig

import torch
from torch import Tensor
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.nn.parallel import DistributedDataParallel

from physicsnemo.models.eddyformer import EddyFormer, EddyFormerConfig
from physicsnemo.distributed import DistributedManager
from physicsnemo.utils import StaticCaptureTraining
from physicsnemo.launch.utils import save_checkpoint
from physicsnemo.launch.logging import PythonLogger, LaunchLogger


class Re94(Dataset):
    """Isotropic turbulence snapshots at Re_lambda ~ 94.

    Each file holds `n` consecutive velocity snapshots spaced `dt` apart;
    a sample is the pair (u(t), u(t + lead)) with lead time `t`.
    """

    root: str
    t: float

    n: int = 50
    dt: float = 0.1

    def __init__(self, root: str, split: str, *, t: float = 0.5) -> None:
        """Index all files in `root` whose names start with `split`."""
        super().__init__()
        self.root = root
        self.t = t

        self.file = []
        for fname in sorted(os.listdir(root)):
            if fname.startswith(split):
                self.file.append(fname)

    @property
    def stride(self) -> int:
        # Lead time expressed in snapshots; round to avoid float truncation.
        k = round(self.t / self.dt)
        assert abs(self.dt * k - self.t) < 1e-9
        return k

    @property
    def samples_per_file(self) -> int:
        return self.n - self.stride + 1

    def __len__(self) -> int:
        return len(self.file) * self.samples_per_file

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor]:
        file_idx, time_idx = divmod(idx, self.samples_per_file)

        data = np.load(f"{self.root}/{self.file[file_idx]}", allow_pickle=True).item()
        return (
            torch.from_numpy(data["u"][time_idx]),
            torch.from_numpy(data["u"][time_idx + self.stride]),
        )


@hydra.main(version_base="1.3", config_path=".", config_name="config.yaml")
def isotropic_trainer(cfg: DictConfig) -> None:
    """Train EddyFormer to advance HIT velocity fields by t = cfg.training.t."""
    DistributedManager.initialize()  # Only call this once in the entire script!
    dist = DistributedManager()  # call if required elsewhere

    # initialize monitoring
    os.makedirs(cfg.training.result_dir, exist_ok=True)
    log = PythonLogger(name="re94_ef")
    log.file_logging(f"{cfg.training.result_dir}/log.txt")
    LaunchLogger.initialize()  # PhysicsNeMo launch logger

    # define model and optimizer
    model = EddyFormer(
        idim=cfg.model.idim,
        odim=cfg.model.odim,
        hdim=cfg.model.hdim,
        num_layers=cfg.model.num_layers,
        use_scale=cfg.model.use_scale,
        cfg=EddyFormerConfig(**cfg.model.layer_config),
    ).to(dist.device)

    if dist.distributed:
        ddps = torch.cuda.Stream()
        with torch.cuda.stream(ddps):
            model = DistributedDataParallel(
                model,
                device_ids=[dist.local_rank],
                output_device=dist.device,
                broadcast_buffers=dist.broadcast_buffers,
                find_unused_parameters=dist.find_unused_parameters,
            )
        torch.cuda.current_stream().wait_stream(ddps)
        log.success("Initialized DDP training")

    optimizer = Adam(model.parameters(), lr=cfg.training.learning_rate)

    # define dataset and dataloader
    dataset = Re94(root=cfg.training.dataset, split="train", t=cfg.training.t)
    dataloader = DataLoader(dataset, cfg.training.batch_size, shuffle=True)

    # define relative l2 error as the loss function
    def loss_fun(pred: Tensor, target: Tensor) -> Tensor:
        return torch.linalg.norm(pred - target) / torch.linalg.norm(target)

    # define training step
    @StaticCaptureTraining(
        model=model,
        optim=optimizer,
        logger=log,
        use_amp=False,
        use_graphs=False,
    )
    def training_step(input: Tensor, target: Tensor) -> Tensor:
        pred = torch.vmap(model)(input)
        loss = torch.vmap(loss_fun)(pred, target)
        return torch.mean(loss)

    it = 0
    log.info("Training started")

    for epoch in range(cfg.training.num_epochs):
        for it, (input, target) in enumerate(dataloader, it):

            input = input.to(dist.device)
            target = target.to(dist.device)
            loss = training_step(input, target)

            with LaunchLogger("train", epoch=epoch) as logger:
                logger.log_minibatch({"Training loss": loss.item()})

            if it and it % cfg.training.ckpt_every == 0 and dist.rank == 0:
                save_checkpoint(f"{cfg.training.result_dir}/ckpt.pt", model, optimizer, epoch=it)

    log.success("Training completed")
    save_checkpoint(f"{cfg.training.result_dir}/ckpt.pt", model, optimizer)


if __name__ == "__main__":
    isotropic_trainer()
```
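As a quick smoke test of the data pipeline, the `Re94` wrapper can be exercised on its own. A hedged sketch, assuming the downloaded files are prefixed `train` and hold a pickled dict with key `"u"`, as the class above expects:

```python
from train_ef_isotropic import Re94

ds = Re94("data/ns3d-re94", split="train", t=0.5)
u0, u1 = ds[0]  # velocity at t and at t + 0.5 (a stride of 5 stored snapshots)
print(len(ds), u0.shape, u1.shape)
```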
New file: __init__.py (physicsnemo.models.eddyformer exports)
```python
from ._basis import Legendre
from ._datatype import SEM
from .eddyformer import EddyFormer, EddyFormerConfig
```
New file: _basis.py (112 additions)
```python
from typing import Protocol
from torch import Tensor

import torch
import torch.nn as nn

import numpy as np
import functools


class Basis(Protocol):

    grid: Tensor  # quadrature nodes on [0, 1]
    quad: Tensor  # quadrature weights

    m: int        # number of basis functions
    f: Tensor     # basis functions evaluated at the nodes

    def fn(self, xs: Tensor) -> Tensor:
        """
        Evaluate basis functions at given points.
        """

    def at(self, coef: Tensor, xs: Tensor) -> Tensor:
        """
        Evaluate basis expansion at given points.
        """
        return torch.tensordot(self.fn(xs), coef, dims=1)

    def modal(self, vals: Tensor) -> Tensor:
        """
        Convert nodal values to modal coefficients.
        """

    def nodal(self, coef: Tensor) -> Tensor:
        """
        Convert modal coefficients to nodal values.
        """


class Element(Basis):

    def __init__(self, base: Basis):
        """
        Build an element-wise basis from a 1D base basis.
        """


# ---------------------------------------------------------------------------- #
#                                   LEGENDRE                                    #
# ---------------------------------------------------------------------------- #

from numpy.polynomial import legendre


@functools.cache  # memoize instances: repeated Legendre(m) calls share buffers
class Legendre(nn.Module, Basis):

    """
    Shifted Legendre polynomials:
    - `(1 - x^2) Pn''(x) - 2 x Pn'(x) + n (n + 1) Pn(x) = 0`
    - `Pn^~(x) = Pn(2 x - 1)`
    """

    def extra_repr(self) -> str:
        return f"m={self.m}"

    def __init__(self, m: int, endpoint: bool = False):
        """
        Gauss-Legendre (or Gauss-Lobatto, if `endpoint`) nodes and weights
        on [0, 1] for the first `m` shifted Legendre polynomials.
        """
        super().__init__()
        self.m = m

        if endpoint: m -= 1
        c = (0, ) * m + (1, )
        dc = legendre.legder(c)

        x = legendre.legroots(dc if endpoint else c)
        y = legendre.legval(x, c if endpoint else dc)

        if endpoint:
            x = np.concatenate([[-1], x, [1]])
            y = np.concatenate([[1], y, [1]])

        w = 1 / y ** 2
        if endpoint: w /= m * (m + 1)
        else: w /= 1 - x ** 2

        self.register_buffer("grid", torch.tensor((1 + x) / 2, dtype=torch.float))
        self.register_buffer("quad", torch.tensor(w, dtype=torch.float))

        self.register_buffer("f", self.fn(self.grid))

    def fn(self, xs: Tensor) -> Tensor:
        """
        Evaluate the first `m` shifted Legendre polynomials at `xs` via the
        three-term (Bonnet) recurrence; returns shape (*xs.shape, m).
        """
        P = torch.ones_like(xs), 2 * xs - 1

        for i in range(2, self.m):
            a, b = (i * 2 - 1) / i, (i - 1) / i
            P += a * P[-1] * P[1] - b * P[-2],

        return torch.stack(P, dim=-1)

    # --------------------------------- TRANSFORM -------------------------------- #

    def modal(self, vals: Tensor) -> Tensor:
        """
        Project nodal values onto modal coefficients by quadrature, using
        the orthogonality norm (2n + 1) of the shifted basis.
        """
        norm = 2 * torch.arange(self.m, device=vals.device) + 1
        coef = self.f * norm * self.quad[:, None]
        return torch.tensordot(coef.T, vals, dims=1)

    def nodal(self, coef: Tensor) -> Tensor:
        """
        Evaluate the modal expansion back at the quadrature nodes.
        """
        return self.at(coef, self.grid)
```
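A hedged round-trip check of the transforms above (not part of the PR): values sampled at the Gauss nodes should survive `modal` followed by `nodal` up to floating-point error, since the m-point quadrature is exact for the degree-(m-1) interpolant.

```python
import torch
from physicsnemo.models.eddyformer import Legendre  # as exported in __init__.py

basis = Legendre(8)
vals = torch.sin(torch.pi * basis.grid)       # smooth samples at the 8 nodes
coef = basis.modal(vals)                      # nodal -> modal
back = basis.nodal(coef)                      # modal -> nodal
print(torch.allclose(vals, back, atol=1e-5))  # expect True
```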