diff --git a/examples/cfd/external_aerodynamics/transolver/README.md b/examples/cfd/external_aerodynamics/transolver/README.md index 35d33ac7a5..054a490cdc 100644 --- a/examples/cfd/external_aerodynamics/transolver/README.md +++ b/examples/cfd/external_aerodynamics/transolver/README.md @@ -1,16 +1,33 @@ -# Transolver for External Aerodynamics on Irregular Meshes + +# `Transolver` and `Typhon` for External Aerodynamics on Irregular Meshes -## Transolver CFD Example: Overview +This example is an end to end training recipe for two models, both of which can +be run on surface or volume data. -This directory contains the essential components for training and evaluating a -Transolver model tailored to external aerodynamics CFD problems. The Transolver model +1. `Transolver` is a high-performance surrogate model for CFD solvers. The Transolver model adapts the Attention mechanism, encouraging the learning of meaningful representations. In each PhysicsAttention layer, input points are projected onto state vectors through learnable transformations and weights. These transformations are then used to compute self-attention among all state vectors, and the same weights are reused to project states back to each input point. -By stacking multiple PhysicsAttention layers, the Transolver model learns to map from +2. `Typhon` is an extension of the PhysicsAttention of Transolver with geometric +or global state enhancements. We call this layer "Geometry-Aware Latent Embeddings", +or `GALE`, and the model - built from sequential `GALE` +layers - is called `typhon` as a reference to the mythological Greek god of storms. + +As `typhon` is an extension to the Transolver PhysicsAttention mechanism, the training recipes +for these two models are integrated into one example. You may train either model +on surface or volume data, as described below. A publication for the `typhon` +will be released soon. + +## External Aerodynamics CFD Example: Overview + +This directory contains the essential components for training and evaluating a +model tailored to external aerodynamics CFD problems. Two models are supported, +because they are so closely related: `Transolver` and `Typhon`. + +By stacking multiple PhysicsAttention layers, the `Transolver` model learns to map from the functional input space to the output space with high fidelity. The PhysicsNeMo implementation closely follows the original Transolver architecture ([https://github.com/thuml/Transolver](https://github.com/thuml/Transolver)), but @@ -19,22 +36,31 @@ TransformerEngine. The training example for Transolver uses the [DrivaerML dataset](https://caemldatasets.org/drivaerml/). -> **Note:** Currently, training transolver in this example supports **surface** data only. -> Volumetric data is still in development. +`Typhon` is an extension to Transolver that particularly focuses on geometrical encodings +and enabling each attention layer to leverage the geometrical and global properties +of the system. As a concrete example, in this example we are training external +aerodynamics surrogate models for automobiles. `Transolver` takes as input +a point cloud on the surface or surrounding the surface, and iteratively processes +it with PhysicsAttention - and produces excellent results. -## Requirements +`Typhon` also takes as inputs the STL mesh of the car, and maps it to the same latent space +that PhysicsAttention layers use. This geometrical state is combined with the +self-attention state via cross-attention, allowing each attention layer to incorporate +self-attention between successive layers as well as attend to the overall geometry +of the problem, at every stage. -Transolver requires TransformerEngine from NVIDIA, as well as Zarr >= 3.0 and `zarrs` -for the data pipeline. Install them with `pip install -r requirements.txt` +## Requirements -> For the Transolver datapipe, zarr > 3.0 is required. If you are using an older -> container, you may need to `unset PIP_CONSTRAINT` to allow zarr 3.0 or higher. +Transolver can use TransformerEngine from NVIDIA, as well as tensorstore (for IO), +zarr, einops and a few other python packages. Install them with `pip install -r requirements.txt` +as well as physicsnemo 25.11 or higher. The `typhon` model is a prerelease model +and not available in 25.11 - please install physicsnemo from source to use it. -## Using Transolver for External Aerodynamics +## Using Transolver and Typhon for External Aerodynamics -1. Prepare the Dataset. Transolver uses the same Zarr outputs as other models with DrivaerML. +1. Prepare the Dataset. Both models uses the same Zarr outputs as other models with DrivaerML. `PhysicsNeMo` has a related project to help with data processing, called [PhysicsNeMo-Curator](https://github.com/NVIDIA/physicsnemo-curator). -Using `PhysicsNeMo-Curator`, the data needed to train Transolver can be setup easily. +Using `PhysicsNeMo-Curator`, the data needed to train can be setup easily. Please refer to [these instructions on getting started](https://github.com/NVIDIA/physicsnemo-curator?tab=readme-ov-file#what-is-physicsnemo-curator) with `PhysicsNeMo-Curator`. For specifics of preparing the dataset for this example, see the [download](https://github.com/NVIDIA/physicsnemo-curator/blob/main/examples/external_aerodynamics/domino/README.md#download-drivaerml-dataset) @@ -42,14 +68,15 @@ and [preprocessing](https://github.com/NVIDIA/physicsnemo-curator/blob/main/exam instructions from `physicsnemo-curator`. Users should apply the preprocessing steps locally to produce `zarr` output files. -2. Train your model. The model and training configuration is set in -`conf/train_surface.yaml`, where you can control both network properties +2. Train your model. The model and training configuration is configured with +`hydra`, and four configurations are available: [`transolver`, `typhon`] x [surface, volume]. +Find configurations in `src/conf`, where you can control both network properties and training properties. See below for an overview and explanation of key parameters that may be of special interest. 3. Use the trained model to perform inference. This example contains two inference examples: one for inference on the validation set, already in -Zarr format, and a second example for inference directly on .vtp files. +Zarr format. The `.vtp` inference pipeline is being updated to accomodate both models. The following sections contain further details on the training and inference recipe. @@ -69,35 +96,24 @@ in the training configuration (defaults to current directory). > this to a different normalization, however take care to update both the > preprocessing as well as inference scripts. Min/Max is another popular strategy. -To configure your training run, use `hydra` and `conf/train_surface.yaml`. The +To configure your training run, use `hydra`. The config contains sections for the model, data, optimizer, and training settings. For details on the model parameters, see the API for `physicsnemo.models.transolver`. -The data is processed with a custom Zarr dataloader, designed to use zarr 3.0 and -`zarrs` rust implementation for an optimized Codec. It also uses python's `threading` -module to open parallel reads of multiple zarr keys. You can control the number -of parallel python threads via `data.max_workers`. - -Additionally, the Zarr dataloader optimizes CPU->GPU transfers by directly -allocating pinned memory on the CPU, reading the Zarr data into that -memory buffer via a 0-copy to numpy, and moving the data to GPU via a separate -stream with non-blocking transfers. In short: you can completely overlap IO -and GPU processing as long as the IO file system can provide the next data example -fast enough. In reality, the IO latency has some variance but is not a bottleneck. - -You can disable memory pinning with `data.pin_memory=False`. Further, to fit -the training into memory, you can apply on-the-fly downsampling to the data + +To fit the training into memory, you can apply on-the-fly downsampling to the data with `data.resolution=N`, where `N` is how many points per GPU to use. This dataloader will yield the full data examples in shapes of `[1, K, f]` where `K` is the resolution of the mesh, and `f` is the feature space (3 for points, normals, etc. 4 for surface fields). Downsampling happens in the preprocessing pipeline. -> The pipeline has the ability to optimally load data from disk into `physicsnemo.ShardTensor` -> for domain parallelism - however the model support is still in development. +During training, the configuration uses a flat learning rate that decays every 100 +epochs, and bfloat16 format by default. The scheduler and learning rate +may be configured. -During training, the configuration uses the OneCycle learning rate (similar to the -original Transolver publication), and float32 format. The scheduler and learning rate -may be configured - note that the scheduler is updated every training step. For -schedulers that update every epoch, modification of the training script may be required. +The Optimizer for this training is the `Muon` optimizer - available only in +`pytorch>=2.9.0`. While not strictly required, we have found the `muon` optimizer +performs substantially better on these architectures than standard `AdamW` and +a oneCycle schedule. ### Training Precision @@ -122,7 +138,7 @@ Several other important configuration settings are available: - `checkpoint_dir` sets the directory for saving model checkpoints (defaults to `output_dir` if not specified), allowing separation of checkpoints from other outputs. -- `training.compile` will use `torch.compile` for optimized performance. It is not +- `compile` will use `torch.compile` for optimized performance. It is not compatible with `transformer_engine` (`model.use_te=True`). If TransformerEngine is not used, and half precision is, `torch.compile` is recommended for improved performance. - `training.num_epochs` controls the total number of epochs used during training. @@ -135,10 +151,10 @@ tools are checkpointed. The training script supports data-parallel training via PyTorch DDP. In a future update, we may enable domain parallelism via FSDP and ShardTensor. -The script can be launched on a single GPU with +The script can be launched on a single GPU with, for example, ```bash -python train.py --config-name train_surface +python train.py --config-name transolver_surface ``` or, for multi-GPU training, use `torchrun` or other distributed job launch tools. @@ -190,26 +206,91 @@ Epoch 47 Validation Average Metrics: ## Dataset Inference -There are two scripts provided as inference examples - it's expected that every user's -inference workloads are different, so these aim to cover common scenarios as examples. + + +The validation dataset in Zarr format can be loaded, processed, and the L2 +metrics summarized in `inference_on_zarr.py`. For surface data, this script will also +compute the drag and lift coefficients and the R^2 correlation of the predictions. + +To ensure correct calculation of drag and lift, and accurate overall metrics, +the inference script will chunk a full-resolution training example into batches, +and stitch the outputs together at the end. Output will appear as a table +with all metrics for that mode, for example: + +``` +| Batch | Loss | L2 Pressure | L2 Shear X | L2 Shear Y | L2 Shear Z | Predicted Drag Coefficient | Pred Lift Coefficient | True Drag Coefficient | True Lift Coefficient | Elapsed (s) | +|---------|--------|---------------|--------------|--------------|--------------|------------------------------|-------------------------|-------------------------|-------------------------|---------------| +| 0 | 0.0284 | 0.0583 | 0.0904 | 0.1013 | 0.1159 | 3.8533 | 3.5871 | 3.9506 | 3.4867 | 11.2388 | +| 1 | 0.023 | 0.0558 | 0.0758 | 0.0985 | 0.1056 | 3.9865 | 2.0918 | 3.9827 | 2.1996 | 10.2538 | +| 2 | 0.0457 | 0.0726 | 0.106 | 0.12 | 0.1801 | 3.9847 | 1.9193 | 3.8824 | 1.8598 | 9.859 | +| 3 | 0.045 | 0.0675 | 0.1098 | 0.1252 | 0.1391 | 6.4476 | 3.03 | 6.3828 | 2.9734 | 11.6881 | +| 4 | 0.0367 | 0.0624 | 0.1068 | 0.1152 | 0.1263 | 4.6706 | 2.2905 | 4.6637 | 2.2301 | 13.5494 | +| 5 | 0.0228 | 0.0499 | 0.0785 | 0.0941 | 0.0981 | 6.1097 | 0.7497 | 6.1664 | 0.7472 | 12.4388 | +| 6 | 0.0285 | 0.0589 | 0.0909 | 0.1059 | 0.1312 | 3.9335 | 0.8309 | 3.9136 | 0.8324 | 8.8262 | +| 7 | 0.0376 | 0.0717 | 0.1095 | 0.1236 | 0.1276 | 4.7873 | 1.9045 | 4.8402 | 2.1894 | 11.8797 | +| 8 | 0.0284 | 0.0548 | 0.0863 | 0.107 | 0.1215 | 4.229 | 1.1434 | 4.4872 | 1.0741 | 10.7116 | +| 9 | 0.0461 | 0.0767 | 0.1125 | 0.1246 | 0.139 | 5.1331 | 1.2558 | 5.0711 | 1.2379 | 12.562 | +| 10 | 0.0536 | 0.0849 | 0.1178 | 0.129 | 0.1548 | 5.0147 | 3.5343 | 5.1289 | 3.4116 | 10.4503 | +| 11 | 0.0333 | 0.0634 | 0.0965 | 0.1104 | 0.1147 | 6.5021 | 2.64 | 6.4 | 2.7209 | 13.1643 | +| 12 | 0.0238 | 0.0537 | 0.0804 | 0.0958 | 0.1032 | 5.8751 | 1.8963 | 5.945 | 1.7916 | 10.5001 | +| 13 | 0.0343 | 0.0651 | 0.1027 | 0.1093 | 0.1278 | 5.562 | 0.994 | 5.5422 | 1.0268 | 9.7037 | +| 14 | 0.0329 | 0.0717 | 0.0938 | 0.1113 | 0.124 | 5.1604 | 2.535 | 5.3534 | 2.6942 | 10.454 | +| 15 | 0.0231 | 0.0529 | 0.0807 | 0.1003 | 0.1121 | 4.646 | 2.0366 | 4.701 | 1.9473 | 10.0764 | +| 16 | 0.0311 | 0.0575 | 0.1021 | 0.1079 | 0.1166 | 5.6578 | 1.6703 | 5.3935 | 1.7436 | 12.5273 | +| 17 | 0.0284 | 0.0629 | 0.0897 | 0.1079 | 0.1172 | 5.1796 | 1.5146 | 5.3012 | 1.4919 | 11.4887 | +| 18 | 0.0302 | 0.0668 | 0.0929 | 0.1106 | 0.1157 | 5.9201 | 1.0449 | 5.9403 | 0.8958 | 11.4593 | +| 19 | 0.0248 | 0.054 | 0.0962 | 0.1046 | 0.1182 | 5.3302 | 2.3861 | 5.3644 | 2.4404 | 11.6885 | +| 20 | 0.0232 | 0.0537 | 0.0834 | 0.0981 | 0.1009 | 5.2209 | 2.1628 | 5.2129 | 2.1078 | 11.1264 | +| 21 | 0.0237 | 0.0609 | 0.0793 | 0.1 | 0.0977 | 5.5532 | 1.7551 | 5.6004 | 1.6219 | 12.1105 | +| 22 | 0.0252 | 0.0568 | 0.0813 | 0.0996 | 0.103 | 4.8141 | 2.8054 | 4.7433 | 2.8088 | 11.5489 | +| 23 | 0.0327 | 0.0627 | 0.0911 | 0.108 | 0.1273 | 5.9474 | -0.3203 | 5.8594 | -0.2171 | 12.2077 | +| 24 | 0.0313 | 0.0611 | 0.0923 | 0.1085 | 0.1096 | 5.565 | 1.667 | 5.4902 | 1.9746 | 14.0795 | +| 25 | 0.0357 | 0.0752 | 0.1021 | 0.1211 | 0.1489 | 4.4083 | 1.6014 | 4.2135 | 1.7296 | 9.5865 | +| 26 | 0.0321 | 0.0703 | 0.0933 | 0.1098 | 0.1208 | 5.6937 | 3.166 | 5.6328 | 3.6129 | 13.1783 | +| 27 | 0.0247 | 0.0575 | 0.0845 | 0.1051 | 0.1129 | 4.1762 | 1.453 | 4.2148 | 1.4453 | 11.8495 | +| 28 | 0.0318 | 0.0609 | 0.0965 | 0.1104 | 0.1173 | 5.2632 | 3.2019 | 5.2519 | 3.1841 | 13.0749 | +| 29 | 0.0368 | 0.061 | 0.0992 | 0.115 | 0.1278 | 6.585 | 1.755 | 6.4859 | 1.6068 | 12.7777 | +| 30 | 0.0289 | 0.0577 | 0.0871 | 0.1062 | 0.1169 | 5.0937 | 3.2484 | 5.3222 | 3.2723 | 11.8586 | +| 31 | 0.0369 | 0.0671 | 0.0994 | 0.1129 | 0.1618 | 5.5144 | 2.3549 | 5.5315 | 2.3749 | 10.1762 | +| 32 | 0.0356 | 0.0753 | 0.0981 | 0.1176 | 0.1373 | 5.5471 | 0.2552 | 5.5173 | 0.4556 | 8.9759 | +| 33 | 0.02 | 0.0478 | 0.0775 | 0.0956 | 0.0944 | 4.6799 | 1.9003 | 4.7444 | 1.9062 | 14.5507 | +| 34 | 0.0226 | 0.0487 | 0.0789 | 0.0922 | 0.0963 | 5.6734 | 3.4928 | 5.7506 | 3.5023 | 13.0373 | +| 35 | 0.0213 | 0.0512 | 0.0804 | 0.0959 | 0.1018 | 6.0567 | 0.9755 | 6.0311 | 0.8335 | 12.5048 | +| 36 | 0.0273 | 0.0548 | 0.0844 | 0.1004 | 0.1263 | 5.1413 | 1.308 | 5.2466 | 1.2221 | 9.8688 | +| 37 | 0.0325 | 0.0621 | 0.0895 | 0.1121 | 0.1271 | 3.2417 | 0.9704 | 3.3774 | 1.0713 | 9.3579 | +| 38 | 0.043 | 0.0661 | 0.1029 | 0.1173 | 0.1312 | 6.1339 | 3.4028 | 6.1527 | 3.423 | 11.6961 | +| 39 | 0.0279 | 0.0573 | 0.0905 | 0.1034 | 0.1118 | 6.7051 | 1.803 | 6.6982 | 1.7571 | 12.1701 | +| 40 | 0.0453 | 0.07 | 0.0986 | 0.1161 | 0.1526 | 6.9221 | 2.9829 | 6.734 | 3.0083 | 12.1203 | +| 41 | 0.0303 | 0.0638 | 0.0931 | 0.1113 | 0.1403 | 4.4597 | 0.6138 | 4.3089 | 0.772 | 9.2097 | +| 42 | 0.0219 | 0.0505 | 0.0802 | 0.0988 | 0.0977 | 6.1306 | 1.9262 | 6.1526 | 1.6924 | 12.1102 | +| 43 | 0.0274 | 0.0604 | 0.0823 | 0.1083 | 0.1113 | 3.8352 | 2.3836 | 3.9082 | 2.5251 | 11.412 | +| 44 | 0.0338 | 0.0694 | 0.0923 | 0.1073 | 0.1273 | 6.3879 | 0.8005 | 6.1468 | 0.8924 | 10.3299 | +| 45 | 0.0271 | 0.0553 | 0.0888 | 0.1016 | 0.1065 | 7.5199 | 1.7131 | 7.4145 | 1.6298 | 12.9432 | +| 46 | 0.0258 | 0.0526 | 0.0834 | 0.0984 | 0.117 | 4.7368 | 0.6839 | 4.7965 | 0.7926 | 9.1506 | +| 47 | 0.0295 | 0.0564 | 0.0896 | 0.1123 | 0.1127 | 5.1667 | 2.7415 | 5.2693 | 2.779 | 12.7713 | +[2025-11-19 07:02:38,387][training][INFO] - R2 score for lift: 0.9807 +[2025-11-19 07:02:38,387][training][INFO] - R2 score for drag: 0.9844 +[2025-11-19 07:02:38,387][training][INFO] - Summary: +| Batch | Loss | L2 Pressure | L2 Shear X | L2 Shear Y | L2 Shear Z | Predicted Drag Coefficient | Pred Lift Coefficient | True Drag Coefficient | True Lift Coefficient | Elapsed (s) | +|---------|--------|---------------|--------------|--------------|--------------|------------------------------|-------------------------|-------------------------|-------------------------|---------------| +| Mean | 0.0311 | 0.0614 | 0.0921 | 0.108 | 0.1214 | 5.2949 | 1.9137 | 5.2962 | 1.9329 | 11.4647 | +``` -First, the validation dataset in Zarr format can be loaded, processed, and the L2 -metrics summarized in `inference_on_zarr.py`. Alternatively, the model can be used + -In `inference_on_zarr.py`, the dataset examples are downsampled and preprocessed + -## Future work +## Transolver++ -The Transolver model is a promising, transformer-based model that produces high -quality predictions for CFD surrogate simulations. In the future, we may update -the example to include domain parallelism and Transolver++ enhancements, -as well as volumetric data examples. If you -have issues, requests, or other items please feel free to open an issue and discuss! +Transolver++ is supported in both models with the `plus` flag to the model. In +our experiments, we did not see gains and have focused on `transolver` and `typhon`. +You are welcome to try it and share your results with us on GitHub! diff --git a/examples/cfd/external_aerodynamics/transolver/inference_on_zarr.py b/examples/cfd/external_aerodynamics/transolver/inference_on_zarr.py deleted file mode 100644 index 9801b7d7b0..0000000000 --- a/examples/cfd/external_aerodynamics/transolver/inference_on_zarr.py +++ /dev/null @@ -1,195 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import torch -import torchinfo - -import hydra -from omegaconf import DictConfig -from physicsnemo.models.transolver.transolver import Transolver -from physicsnemo.launch.utils import load_checkpoint -from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper - -from physicsnemo.distributed import DistributedManager - -import time - -from datapipe import DomainParallelZarrDataset - -from train import forward_pass -from tabulate import tabulate - -import transformer_engine.pytorch as te -from transformer_engine.common.recipe import Format, DelayedScaling -from torch.amp import autocast -from contextlib import nullcontext - -from train import ( - get_autocast_context, - pad_input_for_fp8, - unpad_output_for_fp8, - update_model_params_for_fp8, -) - - -def inference(cfg: DictConfig) -> None: - """ - Run inference on a validation Zarr dataset using a trained Transolver model. - - Args: - cfg (DictConfig): Hydra configuration object containing model, data, and training settings. - - Returns: - None - """ - DistributedManager.initialize() - - dist_manager = DistributedManager() - - logger = RankZeroLoggingWrapper(PythonLogger(name="training"), dist_manager) - - cfg, output_pad_size = update_model_params_for_fp8(cfg, logger) - - # Set up model - model = hydra.utils.instantiate(cfg.model) - logger.info(f"\n{torchinfo.summary(model, verbose=0)}") - model.eval() - model.to(dist_manager.device) - - if cfg.training.compile: - model = torch.compile(model) - - # Validation dataset - - val_dataset = DomainParallelZarrDataset( - data_path=cfg.data.val.data_path, # Assuming validation data path is configured - device_mesh=None, - placements=None, - max_workers=cfg.data.max_workers, - pin_memory=cfg.data.pin_memory, - keys_to_read=cfg.data.data_keys, - large_keys=cfg.data.large_keys, - ) - - ckpt_args = { - "path": f"{cfg.output_dir}/{cfg.run_id}/checkpoints", - "models": model, - } - - # Load the normalization factors: - norm_file = "surface_fields_normalization.npz" - norm_data = np.load(norm_file) - norm_factors = { - "mean": torch.from_numpy(norm_data["mean"]).to(dist_manager.device), - "std": torch.from_numpy(norm_data["std"]).to(dist_manager.device), - } - - loaded_epoch = load_checkpoint(device=dist_manager.device, **ckpt_args) - logger.info(f"loaded epoch: {loaded_epoch}") - - results = [] - start = time.time() - for batch_idx in range(len(val_dataset)): - batch = val_dataset[batch_idx] - with torch.no_grad(): - loss, metrics = forward_pass( - batch, - model, - cfg.training.precision, - output_pad_size, - dist_manager, - cfg, - norm_factors, - ) - - # Extract metric values and convert tensors to floats - l2_pressure = ( - metrics["l2_pressure"].item() - if hasattr(metrics["l2_pressure"], "item") - else metrics["l2_pressure"] - ) - l2_shear_x = ( - metrics["l2_shear_x"].item() - if hasattr(metrics["l2_shear_x"], "item") - else metrics["l2_shear_x"] - ) - l2_shear_y = ( - metrics["l2_shear_y"].item() - if hasattr(metrics["l2_shear_y"], "item") - else metrics["l2_shear_y"] - ) - l2_shear_z = ( - metrics["l2_shear_z"].item() - if hasattr(metrics["l2_shear_z"], "item") - else metrics["l2_shear_z"] - ) - - end = time.time() - elapsed = end - start - logger.info(f"Finished batch {batch_idx} in {elapsed:.4f} seconds") - results.append( - [ - batch_idx, - f"{loss:.4f}", - f"{l2_pressure:.4f}", - f"{l2_shear_x:.4f}", - f"{l2_shear_y:.4f}", - f"{l2_shear_z:.4f}", - f"{elapsed:.4f}", - ] - ) - - start = time.time() - - headers = [ - "Batch", - "Loss", - "L2 Pressure", - "L2 Shear X", - "L2 Shear Y", - "L2 Shear Z", - "Elapsed (s)", - ] - logger.info(f"Results:\n{tabulate(results, headers=headers, tablefmt='github')}") - - # Calculate means for each metric (skip batch index) - if results: - # Convert string values back to float for mean calculation - arr = np.array(results)[:, 1:].astype(float) - means = arr.mean(axis=0) - mean_row = ["Mean"] + [f"{m:.4f}" for m in means] - logger.info( - f"Summary:\n{tabulate([mean_row], headers=headers, tablefmt='github')}" - ) - - -@hydra.main(version_base=None, config_path="conf", config_name="train_surface") -def launch(cfg: DictConfig) -> None: - """ - Launch inference with Hydra configuration. - - Args: - cfg (DictConfig): Hydra configuration object. - - Returns: - None - """ - inference(cfg) - - -if __name__ == "__main__": - launch() diff --git a/examples/cfd/external_aerodynamics/transolver/loss.py b/examples/cfd/external_aerodynamics/transolver/loss.py deleted file mode 100644 index 6eb2873b73..0000000000 --- a/examples/cfd/external_aerodynamics/transolver/loss.py +++ /dev/null @@ -1,88 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from typing import Literal - - -def loss_fn( - pred: torch.Tensor, - target: torch.Tensor, - mode: Literal["surface", "volume"], -) -> torch.Tensor: - """ - Compute the main loss function for the model. - - Args: - pred: Predicted tensor from the model. - target: Ground truth tensor. - others: Dictionary of additional tensors (e.g., surface_areas, surface_normals, stream_velocity). - - Returns: - Loss value as a scalar tensor. - """ - if mode == "surface": - loss = loss_fn_surface(pred, target, "mse") - elif mode == "volume": - loss = loss_fn_volume(pred, target, "mse") - return loss - - -def loss_fn_volume( - pred: torch.Tensor, target: torch.Tensor, mode: Literal["mse", "rmse"] -) -> torch.Tensor: - """ - Compute the main loss function for the model. - """ - - raise NotImplementedError("Volumetric loss not yet implemented.") - - -def loss_fn_surface( - output: torch.Tensor, target: torch.Tensor, loss_type: Literal["mse", "rmse"] -) -> torch.Tensor: - """Calculate loss for surface data by handling scalar and vector components separately. - - Args: - output: Predicted surface values from the model. - target: Ground truth surface values. - loss_type: Type of loss to calculate ("mse" or "rmse"). - - Returns: - Combined scalar and vector loss as a scalar tensor. - """ - # Separate the scalar and vector components: - output_pressure, output_sheer = torch.split(output, [1, 3], dim=2) - target_pressure, target_sheer = torch.split(target, [1, 3], dim=2) - - numerator_pressure = torch.mean((output_pressure - target_pressure) ** 2.0) - numerator_sheer = torch.mean((target_sheer - output_sheer) ** 2.0, (0, 1)) - - eps = 1e-8 - if loss_type == "mse": - loss_pressure = numerator_pressure - loss_wall_sheer = torch.sum(numerator_sheer) - else: - denom = torch.mean((target_pressure) ** 2.0) + eps - loss_pressure = numerator_pressure / denom - - # Compute the mean diff**2 of the vector component, leave the last dimension: - denom_sheer = torch.mean((target_sheer) ** 2.0, (0, 1)) + eps - loss_wall_sheer = torch.sum(numerator_sheer / denom_sheer) - - loss = loss_pressure + loss_wall_sheer - - return loss / 4.0 diff --git a/examples/cfd/external_aerodynamics/transolver/requirements.txt b/examples/cfd/external_aerodynamics/transolver/requirements.txt index db3bcf0fc7..ffc351ec7b 100644 --- a/examples/cfd/external_aerodynamics/transolver/requirements.txt +++ b/examples/cfd/external_aerodynamics/transolver/requirements.txt @@ -5,5 +5,5 @@ termcolor torchinfo einops transformer_engine[pytorch] +tensorstore zarr>=3.0 -zarrs \ No newline at end of file diff --git a/examples/cfd/external_aerodynamics/transolver/src/benchmark_dataloading.py b/examples/cfd/external_aerodynamics/transolver/src/benchmark_dataloading.py new file mode 100644 index 0000000000..fe58c3a240 --- /dev/null +++ b/examples/cfd/external_aerodynamics/transolver/src/benchmark_dataloading.py @@ -0,0 +1,166 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This is a standalone script for benchmarking and testing the Transolver +datapipe in surface or volume mode. +""" + +from pathlib import Path + +import time +import os +import re +import torch + +import numpy as np + +from typing import Literal, Any + + +import hydra +from omegaconf import DictConfig, OmegaConf + + +import torch.distributed as dist +from torch.utils.data.distributed import DistributedSampler + + +from physicsnemo.distributed import DistributedManager +from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper + +from physicsnemo.datapipes.cae.transolver_datapipe import ( + create_transolver_dataset, +) + + +from physicsnemo.utils.profiling import profile, Profiler + + +@profile +def main(cfg: DictConfig): + """Main training function + + Args: + cfg: Hydra configuration object + """ + + DistributedManager.initialize() + + # Set up distributed training + dist_manager = DistributedManager() + + # Set up logging + logger = RankZeroLoggingWrapper(PythonLogger(name="training"), dist_manager) + + logger.info(f"Config:\n{OmegaConf.to_yaml(cfg, resolve=True)}") + + # Load the normalization file: + norm_dir = getattr(cfg.data, "normalization_dir", ".") + if cfg.data.mode == "surface": + norm_file = str(Path(norm_dir) / "surface_fields_normalization.npz") + elif cfg.data.mode == "volume": + norm_file = str(Path(norm_dir) / "volume_fields_normalization.npz") + + norm_data = np.load(norm_file) + norm_factors = { + "mean": torch.from_numpy(norm_data["mean"]).to(dist_manager.device), + "std": torch.from_numpy(norm_data["std"]).to(dist_manager.device), + } + # Training dataset + + train_dataloader = create_transolver_dataset( + cfg.data, + phase="train", + scaling_factors=norm_factors, + ) + + # Validation dataset + + val_dataloader = create_transolver_dataset( + cfg.data, + phase="val", + scaling_factors=norm_factors, + ) + + num_replicas = dist_manager.world_size + data_rank = dist_manager.rank + + # Set up distributed samplers + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataloader, + num_replicas=num_replicas, + rank=data_rank, + shuffle=True, + drop_last=True, + ) + + val_sampler = torch.utils.data.distributed.DistributedSampler( + val_dataloader, + num_replicas=num_replicas, + rank=data_rank, + shuffle=False, # No shuffling for validation + drop_last=True, + ) + + # Training loop + logger.info("Starting IO benchmark...") + for epoch in range(1): + # Set the epoch in the samplers + train_sampler.set_epoch(epoch) + val_sampler.set_epoch(epoch) + train_dataloader.dataset.set_indices(list(train_sampler)) + val_dataloader.dataset.set_indices(list(val_sampler)) + + start_time = time.time() + # Training phase + start = time.time() + with Profiler(): + for i_batch, data in enumerate(train_dataloader): + print(f"Train {i_batch} elapsed time: {time.time() - start}") + start = time.time() + + end_time = time.time() + train_duration = end_time - start_time + + # Log epoch results + logger.info( + f"Epoch [{epoch}/{cfg.training.num_epochs}] [duration: {train_duration:.2f}s]" + ) + + logger.info("Benchmark completed!") + + +@hydra.main(version_base=None, config_path="conf", config_name="train_surface") +def launch(cfg: DictConfig): + """Launch training with hydra configuration + + Args: + cfg: Hydra configuration object + """ + + # If you want to use `line_profiler` or PyTorch's profiler, enable them here. + + profiler = Profiler() + if cfg.profile: + profiler.enable("torch") + profiler.initialize() + main(cfg) + profiler.finalize() + + +if __name__ == "__main__": + launch() diff --git a/examples/cfd/external_aerodynamics/transolver/compute_normalizations.py b/examples/cfd/external_aerodynamics/transolver/src/compute_normalizations.py similarity index 80% rename from examples/cfd/external_aerodynamics/transolver/compute_normalizations.py rename to examples/cfd/external_aerodynamics/transolver/src/compute_normalizations.py index 6fa3cae9b3..749a7ab2f7 100644 --- a/examples/cfd/external_aerodynamics/transolver/compute_normalizations.py +++ b/examples/cfd/external_aerodynamics/transolver/src/compute_normalizations.py @@ -26,17 +26,20 @@ """ from pathlib import Path +import time import numpy as np import torch import hydra from omegaconf import DictConfig -from datapipe import DomainParallelZarrDataset +from physicsnemo.datapipes.cae.cae_dataset import CAEDataset def compute_mean_std_min_max( - dataset: DomainParallelZarrDataset, field_key: str + dataset: CAEDataset, + field_key: str, + max_samples: int = 100, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Compute the mean, standard deviation, minimum, and maximum for a specified field @@ -45,19 +48,22 @@ def compute_mean_std_min_max( Uses a numerically stable online algorithm for mean and variance. Args: - dataset (DomainParallelZarrDataset): The dataset to process. + dataset (CAEDataset): The dataset to process. field_key (str): The key for the field to normalize. Returns: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: mean, std, min, max tensors for the field. """ - N = 0 # Total number of elements processed + N = torch.tensor( + 0, dtype=torch.int64, device="cpu" + ) # Total number of elements processed mean = None M2 = None # Sum of squares of differences from the current mean min_val = None max_val = None + time_start = time.time() for i in range(len(dataset)): print(f"reading file: {i}") data = dataset[i][field_key] @@ -67,17 +73,17 @@ def compute_mean_std_min_max( M2 = torch.zeros(data.shape[-1], device=data.device) min_val = torch.full((data.shape[-1],), float("inf"), device=data.device) max_val = torch.full((data.shape[-1],), float("-inf"), device=data.device) - n = data.shape[1] + n = data.shape[0] N += n # Compute batch statistics - batch_mean = data.mean(axis=(0, 1)) - batch_M2 = ((data - batch_mean) ** 2).sum(axis=(0, 1)) - batch_n = data.shape[1] + batch_mean = data.mean(axis=(0,)) + batch_M2 = ((data - batch_mean) ** 2).sum(axis=(0,)) + batch_n = data.shape[0] # Update min/max - batch_min = data.amin(dim=(0, 1)) - batch_max = data.amax(dim=(0, 1)) + batch_min = data.amin(dim=(0,)) + batch_max = data.amax(dim=(0,)) min_val = torch.minimum(min_val, batch_min) max_val = torch.maximum(max_val, batch_max) @@ -86,6 +92,11 @@ def compute_mean_std_min_max( N += batch_n mean = mean + delta * (batch_n / N) M2 = M2 + batch_M2 + delta**2 * (batch_n * N) / N + time_end = time.time() + print(f"Time taken for file {i}: {time_end - time_start:.2f} seconds") + time_start = time.time() + if i >= max_samples: + break var = M2 / (N - 1) std = torch.sqrt(var) @@ -100,8 +111,10 @@ def main(cfg: DictConfig) -> None: The computed statistics are printed and saved to a .npz file. """ + # Choose which field to normalize (can be overridden via command line) - field_key: str = cfg.get("field_key", "surface_fields") + field_key: str = cfg.data.mode + "_fields" + # Normalization directory can be configured (backward compatible: defaults to current directory) normalization_dir: str = getattr(cfg.data, "normalization_dir", ".") @@ -110,19 +123,21 @@ def main(cfg: DictConfig) -> None: Path(normalization_dir) / f"{field_key}_normalization.npz" ) + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + # Create the dataset using configuration parameters - dataset = DomainParallelZarrDataset( - data_path=cfg.data.train.data_path, - device_mesh=None, - placements=None, - max_workers=cfg.data.max_workers, + dataset = CAEDataset( + data_dir=cfg.data.train.data_path, + keys_to_read=[ + field_key, + ], + keys_to_read_if_available={}, + output_device=device, + preload_depth=cfg.data.preload_depth, pin_memory=cfg.data.pin_memory, - keys_to_read=[field_key], - large_keys=[field_key], ) - # Compute normalization statistics - mean, std, min_val, max_val = compute_mean_std_min_max(dataset, field_key) + mean, std, min_val, max_val = compute_mean_std_min_max(dataset, field_key, 100) print(f"Mean for {field_key}: {mean}") print(f"Std for {field_key}: {std}") print(f"Min for {field_key}: {min_val}") diff --git a/examples/cfd/external_aerodynamics/transolver/src/conf/transolver_surface.yaml b/examples/cfd/external_aerodynamics/transolver/src/conf/transolver_surface.yaml new file mode 100644 index 0000000000..db7f5938c9 --- /dev/null +++ b/examples/cfd/external_aerodynamics/transolver/src/conf/transolver_surface.yaml @@ -0,0 +1,105 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + + +output_dir: "runs" +checkpoint_dir: null # Optional: set custom checkpoint path, defaults to output_dir +run_id: "surface/bfloat16" + +# Performance considerations: +precision: bfloat16 # float32, float16, bfloat16, or float8 +compile: true +profile: false + +# Training configuration +training: + num_epochs: 501 # Add one to save at 250 + save_interval: 25 # Save checkpoint every N epochs + + # StepLR scheduler: Decays the learning rate by gamma every step_size epochs + scheduler: + name: "StepLR" + params: + step_size: 100 # Decay every 200 epochs (set X as desired) + gamma: 0.5 # Decay factor + + # Optimizer configuration + optimizer: + _target_: torch.optim.AdamW + lr: 1.0e-3 + weight_decay: 1.0e-4 + betas: [0.9, 0.999] + eps: 1.0e-8 + +# Model configuration +model: + _target_: physicsnemo.models.transolver.Transolver + functional_dim: 2 # Input feature dimension + out_dim: 4 # Output feature dimension + embedding_dim: 6 # Spatial embedding dimension + n_layers: 8 # Number of transformer layers + n_hidden: 256 # Hidden dimension + dropout: 0.0 # Dropout rate + n_head: 8 # Number of attention heads + act: "gelu" # Activation function + mlp_ratio: 2 # MLP ratio in attention blocks + slice_num: 512 # Number of slices in physics attention + unified_pos: false # Whether to use unified positional embeddings + ref: 8 # Reference dimension for unified pos + structured_shape: null + use_te: false # Use transformer engine + time_input: false # Whether to use time embeddings + plus: false + + +# Data configuration +data: + train: + data_path: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/drivaer_aws/domino/train/ + val: + data_path: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/drivaer_aws/domino/val/ + max_workers: 8 + normalization_dir: "src/" # Directory for normalization files + preload_depth: 1 + pin_memory: true + resolution: 300_000 + mode: surface + # Preprocessing switches: + # (Changing thes will change the embedding dim) + include_normals: true + include_sdf: false + translational_invariance: true + scale_invariance: true + reference_scale: [12.0, 4.5, 3.25] + data_keys: + - "surface_fields" + - "surface_mesh_centers" + - "surface_normals" + - "surface_areas" + - "air_density" + - "stream_velocity" + - "stl_faces" + - "stl_centers" + - "stl_coordinates" + include_geometry: false + broadcast_global_features: true + return_mesh_features: false + +# Logging configuration +logging: + level: INFO + format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s' \ No newline at end of file diff --git a/examples/cfd/external_aerodynamics/transolver/src/conf/transolver_volume.yaml b/examples/cfd/external_aerodynamics/transolver/src/conf/transolver_volume.yaml new file mode 100644 index 0000000000..04a907c1b4 --- /dev/null +++ b/examples/cfd/external_aerodynamics/transolver/src/conf/transolver_volume.yaml @@ -0,0 +1,101 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + + +output_dir: "runs" +checkpoint_dir: null # Optional: set custom checkpoint path, defaults to output_dir +run_id: "volume/bfloat16" + +# Performance considerations: +precision: bfloat16 # float32, float16, bfloat16, or float8 +compile: true +profile: false + +# Training configuration +training: + num_epochs: 501 # Add one to save at 250 + save_interval: 25 # Save checkpoint every N epochs + + # StepLR scheduler: Decays the learning rate by gamma every step_size epochs + scheduler: + name: "StepLR" + params: + step_size: 100 # Decay every 200 epochs (set X as desired) + gamma: 0.5 # Decay factor + + # Optimizer configuration + optimizer: + _target_: torch.optim.AdamW + lr: 1.0e-3 + weight_decay: 1.0e-4 + betas: [0.9, 0.999] + eps: 1.0e-8 + +# Model configuration +model: + _target_: physicsnemo.models.transolver.Transolver + functional_dim: 2 # Input feature dimension + out_dim: 5 # Output feature dimension + embedding_dim: 7 # Spatial embedding dimension + n_layers: 8 # Number of transformer layers + n_hidden: 256 # Hidden dimension + dropout: 0.0 # Dropout rate + n_head: 8 # Number of attention heads + act: "gelu" # Activation function + mlp_ratio: 2 # MLP ratio in attention blocks + slice_num: 512 # Number of slices in physics attention + unified_pos: false # Whether to use unified positional embeddings + ref: 8 # Reference dimension for unified pos + structured_shape: null + use_te: false # Use transformer engine + time_input: false # Whether to use time embeddings + plus: false + + +# Data configuration +data: + train: + data_path: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/drivaer_aws/domino/train/ + val: + data_path: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/drivaer_aws/domino/val/ + max_workers: 8 + normalization_dir: "src/" # Directory for normalization files + preload_depth: 1 + volume_sample_from_disk: true # Enable faster IO on pre-shuffled volumetric data + pin_memory: true + resolution: 300_000 + # Preprocessing switches: + # (Changing thes will change the embedding dim) + include_normals: true + include_sdf: true + translational_invariance: true + scale_invariance: true + reference_scale: [12.0, 4.5, 3.25] + mode: volume + data_keys: + - "volume_fields" + - "volume_mesh_centers" + - "air_density" + - "stream_velocity" + - "stl_faces" + - "stl_centers" + - "stl_coordinates" + +# Logging configuration +logging: + level: INFO + format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s' diff --git a/examples/cfd/external_aerodynamics/transolver/src/conf/typhon_surface.yaml b/examples/cfd/external_aerodynamics/transolver/src/conf/typhon_surface.yaml new file mode 100644 index 0000000000..3d56806e94 --- /dev/null +++ b/examples/cfd/external_aerodynamics/transolver/src/conf/typhon_surface.yaml @@ -0,0 +1,105 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + + +output_dir: "runs" +checkpoint_dir: null # Optional: set custom checkpoint path, defaults to output_dir +run_id: "typhon/surface/bfloat16" + +# Performance considerations: +precision: bfloat16 # float32, float16, bfloat16, or float8 +compile: true +profile: false + +# Training configuration +training: + + num_epochs: 501 # Add one to save at 250 + save_interval: 25 # Save checkpoint every N epochs + + # StepLR scheduler: Decays the learning rate by gamma every step_size epochs + scheduler: + name: "StepLR" + params: + step_size: 100 # Decay every 200 epochs (set X as desired) + gamma: 0.5 # Decay factor + + + # Optimizer configuration + optimizer: + _target_: torch.optim.AdamW + lr: 1.0e-3 + weight_decay: 1.0e-4 + betas: [0.9, 0.999] + eps: 1.0e-8 + +# Model configuration +model: + _target_: physicsnemo.experimental.models.typhon.Typhon + functional_dim: 6 # Input feature dimension + global_dim: 2 + geometry_dim: 3 + out_dim: 4 # Output feature dimension + n_layers: 8 # Number of transformer layers + n_hidden: 256 # Hidden dimension + dropout: 0.0 # Dropout rate + n_head: 8 # Number of attention heads + act: "gelu" # Activation function + mlp_ratio: 4 # MLP ratio in attention blocks + slice_num: 512 # Number of slices in physics attention + use_te: false # Use transformer engine + plus: false + + +# Data configuration +data: + train: + data_path: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/drivaer_aws/domino/train/ + val: + data_path: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/drivaer_aws/domino/val/ + max_workers: 8 + normalization_dir: "src/" # Directory for normalization files + preload_depth: 1 + pin_memory: true + resolution: 300_000 + mode: surface + # Preprocessing switches: + # (Changing thes will change the embedding dim) + data_keys: + - "surface_fields" + - "surface_areas" + - "surface_mesh_centers" + - "surface_normals" + - "air_density" + - "stream_velocity" + - "stl_faces" + - "stl_centers" + - "stl_coordinates" + include_geometry: true + include_normals: true + include_sdf: false + translational_invariance: true + scale_invariance: true + reference_scale: [12.0, 4.5, 3.25] + geometry_sampling: 300_000 + broadcast_global_features: false + return_mesh_features: False + +# Logging configuration +logging: + level: INFO + format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s' diff --git a/examples/cfd/external_aerodynamics/transolver/src/conf/typhon_volume.yaml b/examples/cfd/external_aerodynamics/transolver/src/conf/typhon_volume.yaml new file mode 100644 index 0000000000..d3e6dda2cc --- /dev/null +++ b/examples/cfd/external_aerodynamics/transolver/src/conf/typhon_volume.yaml @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + + +output_dir: "runs" +checkpoint_dir: null # Optional: set custom checkpoint path, defaults to output_dir +run_id: "volumeX/fake-name" + +# Performance considerations: +precision: bfloat16 # float32, float16, bfloat16, or float8 +compile: true +profile: false + +# Training configuration +training: + num_epochs: 501 # Add one to save at 250 + save_interval: 25 # Save checkpoint every N epochs + + # StepLR scheduler: Decays the learning rate by gamma every step_size epochs + scheduler: + name: "StepLR" + params: + step_size: 100 # Decay every 200 epochs (set X as desired) + gamma: 0.5 # Decay factor + + + # Optimizer configuration + optimizer: + _target_: torch.optim.AdamW + lr: 1.0e-3 + weight_decay: 1.0e-4 + betas: [0.9, 0.999] + eps: 1.0e-8 + + +# Model configuration +model: + _target_: physicsnemo.experimental.models.typhon.Typhon + functional_dim: 7 # Input feature dimension + global_dim: 2 + geometry_dim: 3 + out_dim: 5 # Output feature dimension + n_layers: 6 # Number of transformer layers + n_hidden: 256 # Hidden dimension + dropout: 0.0 # Dropout rate + n_head: 8 # Number of attention heads + act: "gelu" # Activation function + mlp_ratio: 1 # MLP ratio in attention blocks + slice_num: 256 # Number of slices in physics attention + use_te: false # Use transformer engine + plus: false + + + +# Data configuration +data: + train: + data_path: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/drivaer_aws/domino/train/ + val: + data_path: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/drivaer_aws/domino/val/ + max_workers: 8 + normalization_dir: "src/" # Directory for normalization files + preload_depth: 1 + volume_sample_from_disk: true # Enable faster IO on pre-shuffled volumetric data + pin_memory: true + resolution: 100_000 + # Preprocessing switches: + # (Changing these will change the embedding dim) + include_geometry: true + include_normals: true + include_sdf: true + translational_invariance: true + scale_invariance: true + reference_scale: [12.0, 4.5, 3.25] + geometry_sampling: 300_000 + broadcast_global_features: false + mode: volume + data_keys: + - "volume_fields" + - "volume_mesh_centers" + - "air_density" + - "stream_velocity" + - "stl_faces" + - "stl_centers" + - "stl_coordinates" + +# Logging configuration +logging: + level: INFO + format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s' diff --git a/examples/cfd/external_aerodynamics/transolver/src/inference_on_zarr.py b/examples/cfd/external_aerodynamics/transolver/src/inference_on_zarr.py new file mode 100644 index 0000000000..bc3115cb37 --- /dev/null +++ b/examples/cfd/external_aerodynamics/transolver/src/inference_on_zarr.py @@ -0,0 +1,449 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +import numpy as np +import torch +import torchinfo +import typing +import collections +from typing import Literal + +import hydra +import omegaconf +from omegaconf import DictConfig +from physicsnemo.models.transolver.transolver import Transolver +from physicsnemo.launch.utils import load_checkpoint +from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper + +from sklearn.metrics import r2_score + +from physicsnemo.distributed import DistributedManager + +import time + +from physicsnemo.datapipes.cae.transolver_datapipe import ( + create_transolver_dataset, + TransolverDataPipe, +) +from train import forward_pass +from tabulate import tabulate + +# import transformer_engine.pytorch as te +# from transformer_engine.common.recipe import Format, DelayedScaling +from torch.amp import autocast +from contextlib import nullcontext + +from train import ( + get_autocast_context, + pad_input_for_fp8, + unpad_output_for_fp8, + update_model_params_for_fp8, +) + +torch.serialization.add_safe_globals([omegaconf.listconfig.ListConfig]) +torch.serialization.add_safe_globals([omegaconf.base.ContainerMetadata]) +torch.serialization.add_safe_globals([typing.Any]) +torch.serialization.add_safe_globals([list]) +torch.serialization.add_safe_globals([collections.defaultdict]) +torch.serialization.add_safe_globals([dict]) +torch.serialization.add_safe_globals([int]) +torch.serialization.add_safe_globals([omegaconf.nodes.AnyNode]) +torch.serialization.add_safe_globals([omegaconf.base.Metadata]) + + +@torch.no_grad() +def compute_force_coefficients( + normals: torch.Tensor, + area: torch.Tensor, + coeff: float, + p: torch.Tensor, + wss: torch.Tensor, + force_direction: torch.Tensor = np.array([1, 0, 0]), +): + """ + Computes force coefficients for a given mesh. Output includes the pressure and skin + friction components. Can be used to compute lift and drag. + For drag, use the `force_direction` as the direction of the motion, + e.g. [1, 0, 0] for flow in x direction. + For lift, use the `force_direction` as the direction perpendicular to the motion, + e.g. [0, 1, 0] for flow in x direction and weight in y direction. + + Parameters: + ----------- + normals: torch.Tensor + The surface normals on cells of the mesh + area: torch.Tensor + The surface areas of each cell + coeff: float + Reciprocal of dynamic pressure times the frontal area, i.e. 2/(A * rho * U^2) + p: torch.Tensor + Pressure distribution on the mesh (on each cell) + wss: torch.Tensor + Wall shear stress distribution on the mesh (on each cell) + force_direction: torch.Tensor + Direction to compute the force, default is np.array([1, 0, 0]) + + Returns: + -------- + c_total: float + Computed total force coefficient + c_p: float + Computed pressure force coefficient + c_f: float + Computed skin friction coefficient + """ + + # Compute coefficients + c_p = coeff * torch.sum(torch.sum(normals * force_direction, dim=-1) * area * p) + c_f = -coeff * torch.sum(torch.sum(wss * force_direction, dim=-1) * area) + + # Compute total force coefficients + c_total = c_p + c_f + + return c_total, c_p, c_f + + +def batched_inference_loop( + batch: dict, + model: torch.nn.Module, + precision: str, + data_mode: Literal["surface", "volume"], + batch_resolution: int, + output_pad_size: int | None, + dist_manager: DistributedManager, + datapipe: TransolverDataPipe, +) -> tuple[float, dict, tuple[torch.Tensor, torch.Tensor]]: + N = batch["embeddings"].shape[1] + # This generates a random ordering of the input points, + # Which we'll then slice up into inputs to the model. + indices = torch.randperm(N, device=batch["fx"].device) + + index_blocks = torch.split(indices, batch_resolution) + + global_preds_targets = [] + global_weight = 0.0 + for i, index_block in enumerate(index_blocks): + # We compute the local_batch by slicing from embeddings and fields: + local_embeddings = batch["embeddings"][:, index_block] + local_fields = batch["fields"][:, index_block] + + # fx does not need to be sliced for TransolverX: + if "geometry" not in batch.keys(): + local_fx = batch["fx"][:, index_block] + else: + local_fx = batch["fx"] + + local_batch = { + "fx": local_fx, + "embeddings": local_embeddings, + "fields": local_fields, + } + + if "air_density" in batch.keys() and "stream_velocity" in batch.keys(): + local_batch["air_density"] = batch["air_density"] + local_batch["stream_velocity"] = batch["stream_velocity"] + + if "geometry" in batch.keys(): + local_batch["geometry"] = batch["geometry"] + + # Run the forward inference pass: + local_loss, local_metrics, local_preds_targets = forward_pass( + local_batch, + model, + precision, + output_pad_size, + dist_manager, + data_mode, + datapipe, + ) + + # Accumulate the loss and metrics: + # (Still on the GPU) + weight = index_block.shape[0] / N + global_weight += weight + if i == 0: + metrics = {k: local_metrics[k] * weight for k in local_metrics.keys()} + loss = local_loss * weight + else: + metrics = { + k: metrics[k] + local_metrics[k] * weight for k in metrics.keys() + } + loss += local_loss * weight + + global_preds_targets.append(local_preds_targets) + + # Now, compute the overall loss, metrics, and coefficients: + metrics = {k: v / global_weight for k, v in metrics.items()} + loss = loss / global_weight + + global_predictions = torch.cat([l[0] for l in global_preds_targets], dim=1) + global_targets = torch.cat([l[1] for l in global_preds_targets], dim=1) + + # Now, we have to *unshuffle* the prediction to the original index + inverse_indices = torch.empty_like(indices) + inverse_indices[indices] = torch.arange(indices.size(0), device=indices.device) + # Suppose prediction is of shape [batch, N, ...] + global_predictions = global_predictions[:, inverse_indices] + global_targets = global_targets[:, inverse_indices] + return loss, metrics, (global_predictions, global_targets) + + # # Now, we have to *unshuffle* the prediction to the original index + # inverse_indices = torch.empty_like(indices) + # inverse_indices[indices] = torch.arange( + # indices.size(0), device=indices.device + # ) + # # Suppose prediction is of shape [batch, N, ...] + # prediction = prediction[:, inverse_indices] + + +def inference(cfg: DictConfig) -> None: + """ + Run inference on a validation Zarr dataset using a trained Transolver model. + + Args: + cfg (DictConfig): Hydra configuration object containing model, data, and training settings. + + Returns: + None + """ + DistributedManager.initialize() + + dist_manager = DistributedManager() + + logger = RankZeroLoggingWrapper(PythonLogger(name="training"), dist_manager) + + cfg, output_pad_size = update_model_params_for_fp8(cfg, logger) + + # Set up model + model = hydra.utils.instantiate(cfg.model) + logger.info(f"\n{torchinfo.summary(model, verbose=0)}") + model.eval() + model.to(dist_manager.device) + + num_params = sum(p.numel() for p in model.parameters()) + logger.info(f"Number of parameters: {num_params}") + + # Load the normalization file from configured directory (defaults to current dir) + norm_dir = getattr(cfg.data, "normalization_dir", ".") + if cfg.data.mode == "surface": + norm_file = str(Path(norm_dir) / "surface_fields_normalization.npz") + elif cfg.data.mode == "volume": + norm_file = str(Path(norm_dir) / "volume_fields_normalization.npz") + + norm_data = np.load(norm_file) + norm_factors = { + "mean": torch.from_numpy(norm_data["mean"]).to(dist_manager.device), + "std": torch.from_numpy(norm_data["std"]).to(dist_manager.device), + } + + if cfg.training.compile: + model = torch.compile(model, dynamic=True) + + # For INFERENCE, we deliberately set the resolution in the data pipe to NONE + # so there is not downsampling. We still batch it in the inference script + # for efficiency + + batch_resolution = cfg.data.resolution + cfg.data.resolution = None + + # Validation dataset + val_dataset = create_transolver_dataset( + cfg.data, + phase="val", + scaling_factors=norm_factors, + ) + + ckpt_args = { + "path": f"{cfg.output_dir}/{cfg.run_id}/checkpoints", + "models": model, + } + + loaded_epoch = load_checkpoint(device=dist_manager.device, **ckpt_args) + logger.info(f"loaded epoch: {loaded_epoch}") + + results = [] + start = time.time() + for batch_idx, batch in enumerate(val_dataset): + with torch.no_grad(): + loss, metrics, (global_predictions, global_targets) = ( + batched_inference_loop( + batch, + model, + cfg.training.precision, + cfg.data.mode, + batch_resolution, + output_pad_size, + dist_manager, + val_dataset, + ) + ) + + if cfg.data.mode == "surface": + coeff = 1.0 + + # Compute the drag and loss coefficients: + # (Index on [0] is to remove the 1 batch index) + pred_pressure, pred_shear = torch.split( + global_predictions[0], (1, 3), dim=-1 + ) + # pred_pressure = pred_pressure * ( + # batch["air_density"] * batch["stream_velocity"] ** 2 + # ) + # pred_shear = pred_shear * (batch["air_density"] * batch["stream_velocity"] ** 2) + + pred_pressure = pred_pressure.reshape(-1) + pred_drag_coeff, _, _ = compute_force_coefficients( + batch["surface_normals"][0], + batch["surface_areas"], + coeff, + pred_pressure, + pred_shear, + torch.tensor([[1, 0, 0]], device=dist_manager.device), + ) + + pred_lift_coeff, _, _ = compute_force_coefficients( + batch["surface_normals"][0], + batch["surface_areas"], + coeff, + pred_pressure, + pred_shear, + torch.tensor([[0, 0, 1]], device=dist_manager.device), + ) + + # air_density = batch["air_density"] if "air_density" in batch.keys() else None + # stream_velocity = batch["stream_velocity"] if "stream_velocity" in batch.keys() else None + # true_fields = val_dataset.unscale_model_targets(batch["fields"], air_density=air_density, stream_velocity=stream_velocity) + true_pressure, true_shear = torch.split(global_targets[0], (1, 3), dim=-1) + + true_pressure = true_pressure.reshape(-1) + true_drag_coeff, _, _ = compute_force_coefficients( + batch["surface_normals"][0], + batch["surface_areas"], + coeff, + true_pressure, + true_shear, + torch.tensor([[1, 0, 0]], device=dist_manager.device), + ) + + true_lift_coeff, _, _ = compute_force_coefficients( + batch["surface_normals"][0], + batch["surface_areas"], + coeff, + true_pressure, + true_shear, + torch.tensor([[0, 0, 1]], device=dist_manager.device), + ) + + pred_lift_coeff = pred_lift_coeff.item() + pred_drag_coeff = pred_drag_coeff.item() + + # Extract metric values and convert tensors to floats + l2_pressure = ( + metrics["l2_pressure_surf"].item() + if hasattr(metrics["l2_pressure_surf"], "item") + else metrics["l2_pressure_surf"] + ) + l2_shear_x = ( + metrics["l2_shear_x"].item() + if hasattr(metrics["l2_shear_x"], "item") + else metrics["l2_shear_x"] + ) + l2_shear_y = ( + metrics["l2_shear_y"].item() + if hasattr(metrics["l2_shear_y"], "item") + else metrics["l2_shear_y"] + ) + l2_shear_z = ( + metrics["l2_shear_z"].item() + if hasattr(metrics["l2_shear_z"], "item") + else metrics["l2_shear_z"] + ) + + end = time.time() + elapsed = end - start + logger.info(f"Finished batch {batch_idx} in {elapsed:.4f} seconds") + results.append( + [ + batch_idx, + f"{loss:.4f}", + f"{l2_pressure:.4f}", + f"{l2_shear_x:.4f}", + f"{l2_shear_y:.4f}", + f"{l2_shear_z:.4f}", + f"{pred_drag_coeff:.4f}", + f"{pred_lift_coeff:.4f}", + f"{true_drag_coeff:.4f}", + f"{true_lift_coeff:.4f}", + f"{elapsed:.4f}", + ] + ) + + start = time.time() + + pred_drag_coeffs = [r[6] for r in results] + pred_lift_coeffs = [r[7] for r in results] + true_drag_coeffs = [r[8] for r in results] + true_lift_coeffs = [r[9] for r in results] + + # Compute the R2 scores for lift and drag: + r2_lift = r2_score(true_lift_coeffs, pred_lift_coeffs) + r2_drag = r2_score(true_drag_coeffs, pred_drag_coeffs) + + headers = [ + "Batch", + "Loss", + "L2 Pressure", + "L2 Shear X", + "L2 Shear Y", + "L2 Shear Z", + "Predicted Drag Coefficient", + "Pred Lift Coefficient", + "True Drag Coefficient", + "True Lift Coefficient", + "Elapsed (s)", + ] + logger.info(f"Results:\n{tabulate(results, headers=headers, tablefmt='github')}") + logger.info(f"R2 score for lift: {r2_lift:.4f}") + logger.info(f"R2 score for drag: {r2_drag:.4f}") + # Calculate means for each metric (skip batch index) + if results: + # Convert string values back to float for mean calculation + arr = np.array(results)[:, 1:].astype(float) + means = arr.mean(axis=0) + mean_row = ["Mean"] + [f"{m:.4f}" for m in means] + logger.info( + f"Summary:\n{tabulate([mean_row], headers=headers, tablefmt='github')}" + ) + + +@hydra.main(version_base=None, config_path="conf", config_name="train_surface") +def launch(cfg: DictConfig) -> None: + """ + Launch inference with Hydra configuration. + + Args: + cfg (DictConfig): Hydra configuration object. + + Returns: + None + """ + inference(cfg) + + +if __name__ == "__main__": + launch() diff --git a/examples/cfd/external_aerodynamics/transolver/metrics.py b/examples/cfd/external_aerodynamics/transolver/src/metrics.py similarity index 76% rename from examples/cfd/external_aerodynamics/transolver/metrics.py rename to examples/cfd/external_aerodynamics/transolver/src/metrics.py index 4b03b02139..143e4fa338 100644 --- a/examples/cfd/external_aerodynamics/transolver/metrics.py +++ b/examples/cfd/external_aerodynamics/transolver/src/metrics.py @@ -38,21 +38,22 @@ def all_reduce_dict( if dm.world_size == 1: return metrics - for key, value in metrics.items(): - dist.all_reduce(value) - value = value / dm.world_size - metrics[key] = value + # Pack the metrics together: + merged_metrics = torch.stack(list(metrics.values()), dim=-1) + dist.all_reduce(merged_metrics) + merged_metrics = merged_metrics / dm.world_size + + # Unstack metrics: + metrics = {key: merged_metrics[i] for i, key in enumerate(metrics.keys())} return metrics def metrics_fn( pred: torch.Tensor, target: torch.Tensor, - others: dict[str, torch.Tensor], dm: DistributedManager, mode: str, - norm_factors: dict[str, torch.Tensor], ) -> dict[str, torch.Tensor]: """ Computes metrics for either surface or volume data. @@ -69,9 +70,9 @@ def metrics_fn( """ with torch.no_grad(): if mode == "surface": - metrics = metrics_fn_surface(pred, target, others, dm, norm_factors) + metrics = metrics_fn_surface(pred, target, dm) elif mode == "volume": - metrics = metrics_fn_volume(pred, target, others, dm, norm_factors) + metrics = metrics_fn_volume(pred, target, dm) else: raise ValueError(f"Unknown data mode: {mode}") @@ -82,9 +83,7 @@ def metrics_fn( def metrics_fn_volume( pred: torch.Tensor, target: torch.Tensor, - others: dict[str, torch.Tensor], dm: DistributedManager, - norm_factors: dict[str, torch.Tensor], ) -> dict[str, torch.Tensor]: """ Placeholder for volume metrics computation. @@ -99,15 +98,31 @@ def metrics_fn_volume( Raises: NotImplementedError: Always, as this function is not yet implemented. """ - raise NotImplementedError("Volume metrics not yet implemented.") + l2_num = (pred - target) ** 2 + l2_num = torch.sum(l2_num, dim=1) + l2_num = torch.sqrt(l2_num) + + l2_denom = target**2 + l2_denom = torch.sum(l2_denom, dim=1) + l2_denom = torch.sqrt(l2_denom) + + l2 = l2_num / l2_denom + + metrics = { + "l2_pressure_vol": torch.mean(l2[:, 3]), + "l2_velocity_x": torch.mean(l2[:, 0]), + "l2_velocity_y": torch.mean(l2[:, 1]), + "l2_velocity_z": torch.mean(l2[:, 2]), + "l2_nut": torch.mean(l2[:, 4]), + } + + return metrics def metrics_fn_surface( pred: torch.Tensor, target: torch.Tensor, - others: dict[str, torch.Tensor], dm: DistributedManager, - norm_factors: dict[str, torch.Tensor], ) -> dict[str, torch.Tensor]: """ Computes L2 surface metrics between prediction and target. @@ -123,8 +138,8 @@ def metrics_fn_surface( Dictionary of L2 surface metrics for pressure and shear components. """ # Unnormalize the surface values for L2: - target = target * norm_factors["std"] + norm_factors["mean"] - pred = pred * norm_factors["std"] + norm_factors["mean"] + # target = target * norm_factors["std"] + norm_factors["mean"] + # pred = pred * norm_factors["std"] + norm_factors["mean"] l2_num = (pred - target) ** 2 l2_num = torch.sum(l2_num, dim=1) @@ -137,26 +152,10 @@ def metrics_fn_surface( l2 = l2_num / l2_denom metrics = { - "l2_pressure": torch.mean(l2[:, 0]), + "l2_pressure_surf": torch.mean(l2[:, 0]), "l2_shear_x": torch.mean(l2[:, 1]), "l2_shear_y": torch.mean(l2[:, 2]), "l2_shear_z": torch.mean(l2[:, 3]), } return metrics - - -def metrics_fn_surface_pressure( - pred: torch.Tensor, target: torch.Tensor -) -> torch.Tensor: - """ - Computes mean squared error between predicted and target surface pressure. - - Args: - pred: Predicted surface pressure. - target: Target surface pressure. - - Returns: - Mean squared error as a torch.Tensor. - """ - return torch.mean((pred - target) ** 2.0) diff --git a/examples/cfd/external_aerodynamics/transolver/preprocess.py b/examples/cfd/external_aerodynamics/transolver/src/preprocess.py similarity index 100% rename from examples/cfd/external_aerodynamics/transolver/preprocess.py rename to examples/cfd/external_aerodynamics/transolver/src/preprocess.py diff --git a/examples/cfd/external_aerodynamics/transolver/src/surface_fields_normalization.npz b/examples/cfd/external_aerodynamics/transolver/src/surface_fields_normalization.npz new file mode 100644 index 0000000000..b6809d416c Binary files /dev/null and b/examples/cfd/external_aerodynamics/transolver/src/surface_fields_normalization.npz differ diff --git a/examples/cfd/external_aerodynamics/transolver/train.py b/examples/cfd/external_aerodynamics/transolver/src/train.py similarity index 60% rename from examples/cfd/external_aerodynamics/transolver/train.py rename to examples/cfd/external_aerodynamics/transolver/src/train.py index 7b56aaa62a..133ced5924 100644 --- a/examples/cfd/external_aerodynamics/transolver/train.py +++ b/examples/cfd/external_aerodynamics/transolver/src/train.py @@ -14,39 +14,128 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Core python imports: import os import time from pathlib import Path +from typing import Literal, Any, Callable, Sequence +import collections +from contextlib import nullcontext -import torch +# Configuration: import hydra import omegaconf -from tabulate import tabulate from omegaconf import DictConfig -import torchinfo + +# Pytorch imports: +import torch +from torch.optim import Optimizer +from torch.amp import autocast, GradScaler from torch.utils.tensorboard import SummaryWriter +# For metrics and model printouts: +from tabulate import tabulate +import torchinfo + +# For loading dataset stats: import numpy as np +# Physicsnemo imports ... from physicsnemo.launch.utils import load_checkpoint, save_checkpoint from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper from physicsnemo.distributed import DistributedManager - from physicsnemo.utils.profiling import profile, Profiler +from physicsnemo.datapipes.cae.transolver_datapipe import ( + create_transolver_dataset, + TransolverDataPipe, +) -from datapipe import DomainParallelZarrDataset -from loss import loss_fn +# Local folder imports for this example from metrics import metrics_fn from preprocess import ( preprocess_surface_data, downsample_surface, ) -from contextlib import nullcontext -from torch.amp import autocast, GradScaler +# Special import, if transformer engine is available: +from physicsnemo.utils.version_check import check_min_version + +TE_AVAILABLE = check_min_version("transformer_engine", "0.0.0", hard_fail=False) + +if TE_AVAILABLE: + import transformer_engine.pytorch as te + from transformer_engine.common.recipe import Format, DelayedScaling +else: + te, Format, DelayedScaling = None, None, None + +# This will go away when checkpointing is refined further below: +torch.serialization.add_safe_globals([omegaconf.listconfig.ListConfig]) +torch.serialization.add_safe_globals([omegaconf.base.ContainerMetadata]) +torch.serialization.add_safe_globals([Any]) +torch.serialization.add_safe_globals([list]) +torch.serialization.add_safe_globals([collections.defaultdict]) +torch.serialization.add_safe_globals([dict]) +torch.serialization.add_safe_globals([int]) +torch.serialization.add_safe_globals([omegaconf.nodes.AnyNode]) +torch.serialization.add_safe_globals([omegaconf.base.Metadata]) + -import transformer_engine.pytorch as te -from transformer_engine.common.recipe import Format, DelayedScaling +class CombinedOptimizer(Optimizer): + """Combine multiple PyTorch optimizers into a single Optimizer-like interface. + + The wrapper concatenates the *param_groups* from all contained optimizers so + that learning-rate schedulers (e.g., ReduceLROnPlateau, CosineAnnealingLR) + operate transparently across every parameter. Only a minimal subset of the + *torch.optim.Optimizer* API is implemented—extend as needed. + + Note: + This will get upstreamed to physicsnemo shortly. Don't count on this + class existing here in the future! + """ + + def __init__( + self, + optimizers: Sequence[Optimizer], + torch_compile_kwargs: dict[str, Any] | None = None, + ): + if not optimizers: + raise ValueError("`optimizers` must contain at least one optimizer.") + + self.optimizers = optimizers + + # Collect parameter groups from all optimizers. We pass an empty + # *defaults* dict because hyper-parameters are managed by the inner + # optimizers, not this wrapper. + param_groups = [g for opt in optimizers for g in opt.param_groups] + super().__init__(param_groups, defaults={}) + + if torch_compile_kwargs is None: + self.step_fns: list[Callable] = [opt.step for opt in optimizers] + else: + self.step_fns: list[Callable] = [ + torch.compile(opt.step, **torch_compile_kwargs) for opt in optimizers + ] + + def zero_grad(self, *args, **kwargs) -> None: + """Nullify gradients""" + for opt in self.optimizers: + opt.zero_grad(*args, **kwargs) + + def step(self, closure=None) -> None: + for step_fn in self.step_fns: + if closure is None: + step_fn() + else: + step_fn(closure) + + def state_dict(self): + return {"optimizers": [opt.state_dict() for opt in self.optimizers]} + + def load_state_dict(self, state_dict): + for opt, sd in zip(self.optimizers, state_dict["optimizers"]): + opt.load_state_dict(sd) + + self.param_groups = [g for opt in self.optimizers for g in opt.param_groups] def get_autocast_context(precision: str) -> nullcontext: @@ -63,7 +152,7 @@ def get_autocast_context(precision: str) -> nullcontext: return autocast("cuda", dtype=torch.float16) elif precision == "bfloat16": return autocast("cuda", dtype=torch.bfloat16) - elif precision == "float8": + elif precision == "float8" and TE_AVAILABLE: fp8_format = Format.HYBRID fp8_recipe = DelayedScaling( fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max" @@ -73,29 +162,30 @@ def get_autocast_context(precision: str) -> nullcontext: return nullcontext() -def cast_precisions( - features: torch.Tensor, embeddings: torch.Tensor, precision: str -) -> tuple[torch.Tensor, torch.Tensor]: +def cast_precisions(*tensors: torch.Tensor, precision: str) -> list[torch.Tensor]: + """ + Casts the tensors to the specified precision. """ - Casts the features and embeddings tensors to the specified precision. - Args: - features (torch.Tensor): The input features tensor. - embeddings (torch.Tensor): The input embeddings tensor. - precision (str): The desired precision ("float16", "bfloat16", or other for no cast). + match precision: + case "float16": + dtype = torch.float16 + case "bfloat16": + dtype = torch.bfloat16 + case _: + dtype = None - Returns: - Tuple[torch.Tensor, torch.Tensor]: The features and embeddings tensors cast to the specified precision. - """ - if precision == "float16": - return features.to(torch.float16), embeddings.to(torch.float16) - elif precision == "bfloat16": - return features.to(torch.bfloat16), embeddings.to(torch.bfloat16) - else: - return features, embeddings + if dtype is not None: + tensors = [t.to(dtype) for t in tensors] + + return tensors -def pad_input_for_fp8(features: torch.Tensor, embeddings: torch.Tensor) -> torch.Tensor: +def pad_input_for_fp8( + features: torch.Tensor, + embeddings: torch.Tensor, + geometry: torch.Tensor | None = None, +) -> torch.Tensor: """ Pads the input features tensor so that the concatenated feature and embedding dimension is a multiple of 16, which is required for FP8 operations. Only the features is updated. @@ -113,7 +203,14 @@ def pad_input_for_fp8(features: torch.Tensor, embeddings: torch.Tensor) -> torch features = torch.nn.functional.pad(features, (0, pad_size)) fx_dim = features.shape[-1] + embeddings.shape[-1] - return features + if geometry is not None: + geometry_dim = geometry.shape[-1] if geometry is not None else 0 + if geometry_dim % 16 != 0: + pad_size = 16 - (geometry_dim % 16) + geometry = torch.nn.functional.pad(geometry, (0, pad_size)) + geometry_dim = geometry.shape[-1] + + return features, geometry def unpad_output_for_fp8( @@ -141,53 +238,62 @@ def forward_pass( precision: str, output_pad_size: int | None, dist_manager: DistributedManager, - cfg: DictConfig, - norm_factors: dict[str, torch.Tensor], + data_mode: Literal["surface", "volume"], + datapipe: TransolverDataPipe, ): """ Run the forward pass of the model for one batch, including metrics and loss calculation. """ - if cfg.data.mode == "surface": - features, embeddings, targets, others = preprocess_surface_data( - batch, norm_factors - ) - features, embeddings, targets = downsample_surface( - features, embeddings, targets, cfg.data.resolution - ) + features = batch["fx"] + embeddings = batch["embeddings"] + targets = batch["fields"] - elif cfg.data.mode == "volume": - # This is a feature to implement in the future. - pass - else: - raise ValueError(f"Unknown data mode: {cfg.data.mode}") + # Cast precisions: + features, embeddings = cast_precisions(features, embeddings, precision=precision) - # del batch + if "geometry" in batch.keys(): + (geometry,) = cast_precisions(batch["geometry"], precision=precision) + else: + geometry = None - # Cast precisions: - features, embeddings = cast_precisions(features, embeddings, precision) with get_autocast_context(precision): # For fp8, we may have to pad the inputs: - if precision == "float8": - features = pad_input_for_fp8(features, embeddings) + if precision == "float8" and TE_AVAILABLE: + features, geometry = pad_input_for_fp8(features, embeddings, geometry) - outputs = model(features, embeddings) + if "geometry" in batch.keys(): + outputs = model( + global_embedding=features, local_embedding=embeddings, geometry=geometry + ) + else: + outputs = model(fx=features, embedding=embeddings) outputs = unpad_output_for_fp8(outputs, output_pad_size) - loss = loss_fn(outputs, targets, cfg.data.mode) + loss = torch.nn.functional.mse_loss(outputs, targets) + + air_density = batch["air_density"] if "air_density" in batch.keys() else None + stream_velocity = ( + batch["stream_velocity"] if "stream_velocity" in batch.keys() else None + ) - metrics = metrics_fn( - outputs, targets, others, dist_manager, cfg.data.mode, norm_factors + unscaled_outputs = datapipe.unscale_model_targets( + outputs, air_density=air_density, stream_velocity=stream_velocity ) + unscaled_targets = datapipe.unscale_model_targets( + targets, air_density=air_density, stream_velocity=stream_velocity + ) + + metrics = metrics_fn(unscaled_outputs, unscaled_targets, dist_manager, data_mode) - return loss, metrics + return loss, metrics, (unscaled_outputs, unscaled_targets) @profile def train_epoch( dataloader, - sampler: torch.utils.data.Sampler, + epoch_len: int, model: torch.nn.Module, output_pad_size: int | None, optimizer: torch.optim.Optimizer, @@ -197,7 +303,6 @@ def train_epoch( epoch: int, cfg: DictConfig, dist_manager: DistributedManager, - norm_factors: dict[str, torch.Tensor], scaler: GradScaler | None = None, ) -> float: """ @@ -205,8 +310,8 @@ def train_epoch( Args: dataloader: Training data loader - sampler (torch.utils.data.Sampler): Sampler for distributed or sequential sampling. model (torch.nn.Module): The neural network model to train. + epoch_len (int): Length of the epoch. output_pad_size (int | None): Optional output padding size for lowest precisions (FP8). optimizer (torch.optim.Optimizer): Optimizer for model parameters. scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler. @@ -215,7 +320,6 @@ def train_epoch( epoch (int): Current epoch number. cfg (DictConfig): Hydra configuration object. dist_manager (DistributedManager): Distributed manager from physicsnemo. - norm_factors (dict[str, torch.Tensor]): Normalization factors for the data. scaler (GradScaler | None, optional): Gradient scaler for mixed precision training. Returns: float: The average training loss for the epoch. @@ -224,76 +328,73 @@ def train_epoch( total_loss = 0 total_metrics = {} - epoch_indices = list(sampler) if sampler is not None else range(len(dataloader)) - epoch_len = len(epoch_indices) precision = getattr(cfg.training, "precision", "float32") start_time = time.time() - with Profiler(): - for i, batch_idx in enumerate(epoch_indices): - batch = dataloader[batch_idx] - # preload the next batch, if we're not on the last batch - if i < epoch_len - 1 and sampler is not None: - dataloader.preload(epoch_indices[i + 1]) + for i, batch in enumerate(dataloader): + # TransolverX has a different forward pass: - loss, metrics = forward_pass( - batch, - model, - precision, - output_pad_size, - dist_manager, - cfg, - norm_factors, - ) + loss, metrics, _ = forward_pass( + batch, + model, + precision, + output_pad_size, + dist_manager, + cfg.data.mode, + dataloader, + ) - optimizer.zero_grad() - if precision == "float16" and scaler is not None: - scaler.scale(loss).backward() - scaler.step(optimizer) - scaler.update() - else: - loss.backward() - optimizer.step() + optimizer.zero_grad() + if precision == "float16" and scaler is not None: + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + else: + loss.backward() + optimizer.step() - if not isinstance(scheduler, torch.optim.lr_scheduler.StepLR): - scheduler.step() + if not isinstance(scheduler, torch.optim.lr_scheduler.StepLR): + scheduler.step() - end_time = time.time() + end_time = time.time() - # Logging - this_loss = loss.detach().item() - total_loss += this_loss + # Logging + this_loss = loss.detach().item() + total_loss += this_loss - if i == 0: - total_metrics = metrics - else: - total_metrics = { - k: total_metrics[k] + metrics[k].item() for k in metrics.keys() - } + if i == 0: + total_metrics = metrics + else: + total_metrics = { + k: total_metrics[k] + metrics[k].item() for k in metrics.keys() + } - duration = end_time - start_time - start_time = end_time - images_per_second = 1 / duration + duration = end_time - start_time + start_time = end_time + images_per_second = 1 / duration - mem_usage = torch.cuda.memory_reserved() / 1024**3 + mem_usage = torch.cuda.memory_reserved() / 1024**3 - logger.info( - f"Epoch {epoch} [{i}/{epoch_len}] Loss: {this_loss:.6f} Duration: {duration:.2f}s Mem: {mem_usage:.2f}GB" + logger.info( + f"Epoch {epoch} [{i}/{epoch_len}] Loss: {this_loss:.6f} Duration: {duration:.2f}s Mem: {mem_usage:.2f}GB" + ) + if dist_manager.rank == 0: + writer.add_scalar( + "batch/learning_rate", + optimizer.param_groups[0]["lr"], + i + epoch_len * epoch, ) - if dist_manager.rank == 0: - writer.add_scalar( - "batch/learning_rate", - optimizer.param_groups[0]["lr"], - i + epoch_len * epoch, - ) - writer.add_scalar("batch/loss", this_loss, i + epoch_len * epoch) + writer.add_scalar("batch/loss", this_loss, i + epoch_len * epoch) + writer.add_scalar( + "batch/throughpu_per_gpu", images_per_second, i + epoch_len * epoch + ) + for metric_name, metric_value in metrics.items(): writer.add_scalar( - "batch/throughpu_per_gpu", images_per_second, i + epoch_len * epoch + f"batch/{metric_name}", metric_value, i + epoch_len * epoch ) - for metric_name, metric_value in metrics.items(): - writer.add_scalar( - f"batch/{metric_name}", metric_value, i + epoch_len * epoch - ) + + if cfg.profile and i >= 10: + break # Stop profiling after 10 batches avg_loss = total_loss / epoch_len avg_metrics = {k: v / epoch_len for k, v in total_metrics.items()} @@ -314,7 +415,7 @@ def train_epoch( @profile def val_epoch( dataloader, - sampler: torch.utils.data.Sampler | None, + epoch_len: int, model: torch.nn.Module, output_pad_size: int | None, logger: PythonLogger, @@ -322,14 +423,13 @@ def val_epoch( epoch: int, cfg: DictConfig, dist_manager: DistributedManager, - norm_factors: dict[str, torch.Tensor], ) -> float: """ Run validation for one epoch. Args: dataloader: Validation data loader. - sampler (torch.utils.data.Sampler | None): Sampler for distributed or sequential sampling. + epoch_len (int): Length of the epoch. model (torch.nn.Module): The model to evaluate. output_pad_size (int | None): Optional output padding size for lowest precisions (FP8). logger (PythonLogger): Logger for validation progress. @@ -337,7 +437,6 @@ def val_epoch( epoch (int): Current epoch number. cfg (DictConfig): Hydra configuration object. dist_manager (DistributedManager): Distributed manager instance. - norm_factors (dict[str, torch.Tensor]): Normalization factors for the data. Returns: float: The average validation loss for the epoch. """ @@ -346,28 +445,19 @@ def val_epoch( total_loss = 0 total_metrics = {} - epoch_indices = list(sampler) if sampler is not None else range(len(dataloader)) - epoch_len = len(epoch_indices) precision = getattr(cfg.training, "precision", "float32") start_time = time.time() with torch.no_grad(): # Disable gradient computation - for i, batch_idx in enumerate(epoch_indices): - # Get data from batch - batch = dataloader[batch_idx] - - # preload the next batch, if we're not on the last batch - if i < epoch_len - 1 and sampler is not None: - dataloader.preload(epoch_indices[i + 1]) - - loss, metrics = forward_pass( + for i, batch in enumerate(dataloader): + loss, metrics, _ = forward_pass( batch, model, precision, output_pad_size, dist_manager, - cfg, - norm_factors, + cfg.data.mode, + dataloader, ) if i == 0: @@ -390,6 +480,9 @@ def val_epoch( ) # We don't add individual loss measurements to tensorboard in the validation loop. + if cfg.profile and i >= 10: + break # Stop profiling after 10 batches + avg_loss = total_loss / epoch_len avg_metrics = {k: v / epoch_len for k, v in total_metrics.items()} if dist_manager.rank == 0: @@ -428,7 +521,7 @@ def update_model_params_for_fp8(cfg, logger) -> tuple | None: # if (cfg.model.embedding_dim + cfg.model.functional_dim) % 16 != 0: output_pad_size = None - if cfg.training.precision == "float8": + if cfg.precision == "float8": if cfg.model.out_dim % 16 != 0: # pad the output: output_pad_size = 16 - (cfg.model.out_dim % 16) @@ -509,24 +602,32 @@ def main(cfg: DictConfig): num_params = sum(p.numel() for p in model.parameters()) logger.info(f"Number of parameters: {num_params}") - # Training dataset + # Load the normalization file from configured directory (defaults to current dir) + norm_dir = getattr(cfg.data, "normalization_dir", ".") + if cfg.data.mode == "surface": + norm_file = str(Path(norm_dir) / "surface_fields_normalization.npz") + elif cfg.data.mode == "volume": + norm_file = str(Path(norm_dir) / "volume_fields_normalization.npz") + + norm_data = np.load(norm_file) + norm_factors = { + "mean": torch.from_numpy(norm_data["mean"]).to(dist_manager.device), + "std": torch.from_numpy(norm_data["std"]).to(dist_manager.device), + } - train_dataset = DomainParallelZarrDataset( - data_path=cfg.data.train.data_path, - max_workers=cfg.data.max_workers, - pin_memory=cfg.data.pin_memory, - keys_to_read=cfg.data.data_keys, - large_keys=cfg.data.large_keys, + # Training dataset + train_dataloader = create_transolver_dataset( + cfg.data, + phase="train", + scaling_factors=norm_factors, ) # Validation dataset - val_dataset = DomainParallelZarrDataset( - data_path=cfg.data.val.data_path, # Assuming validation data path is configured - max_workers=cfg.data.max_workers, - pin_memory=cfg.data.pin_memory, - keys_to_read=cfg.data.data_keys, - large_keys=cfg.data.large_keys, + val_dataloader = create_transolver_dataset( + cfg.data, + phase="val", + scaling_factors=norm_factors, ) num_replicas = dist_manager.world_size @@ -534,7 +635,7 @@ def main(cfg: DictConfig): # Set up distributed samplers train_sampler = torch.utils.data.distributed.DistributedSampler( - train_dataset, + train_dataloader, num_replicas=num_replicas, rank=data_rank, shuffle=True, @@ -542,52 +643,46 @@ def main(cfg: DictConfig): ) val_sampler = torch.utils.data.distributed.DistributedSampler( - val_dataset, + val_dataloader, num_replicas=num_replicas, rank=data_rank, shuffle=False, # No shuffling for validation drop_last=True, ) - # Load the normalization file from configured directory (defaults to current dir) - norm_dir = getattr(cfg.data, "normalization_dir", ".") - if cfg.data.mode == "surface": - norm_file = str(Path(norm_dir) / "surface_fields_normalization.npz") - elif cfg.data.mode == "volume": - raise Exception("Volume training not yet supported.") - - norm_data = np.load(norm_file) - norm_factors = { - "mean": torch.from_numpy(norm_data["mean"]).to(dist_manager.device), - "std": torch.from_numpy(norm_data["std"]).to(dist_manager.device), - } + muon_params = [p for p in model.parameters() if p.ndim == 2] + other_params = [p for p in model.parameters() if p.ndim != 2] # Set up optimizer and scheduler - optimizer = hydra.utils.instantiate(cfg.optimizer, params=model.parameters()) + optimizer = hydra.utils.instantiate(cfg.training.optimizer, params=other_params) + + optimizer = CombinedOptimizer( + optimizers=[ + torch.optim.Muon( + muon_params, + lr=cfg.training.optimizer.lr, + weight_decay=cfg.training.optimizer.weight_decay, + adjust_lr_fn="match_rms_adamw", + ), + optimizer, + ], + ) # Set up learning rate scheduler based on config - scheduler_cfg = cfg.scheduler + scheduler_cfg = cfg.training.scheduler scheduler_name = scheduler_cfg.name scheduler_params = dict(scheduler_cfg.params) - if scheduler_name == "OneCycleLR": - scheduler_params.setdefault("max_lr", cfg.optimizer.lr) - # Dynamically compute total_steps - total_steps = len(list(train_sampler)) * cfg.training.num_epochs - scheduler_params["total_steps"] = total_steps - scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, **scheduler_params) - elif scheduler_name == "ReduceLROnPlateau": - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer, **scheduler_params - ) - elif scheduler_name == "StepLR": - scheduler = torch.optim.lr_scheduler.StepLR(optimizer, **scheduler_params) - else: - raise ValueError(f"Unknown scheduler: {scheduler_name}") + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, **scheduler_params) - precision = getattr(cfg.training, "precision", "float32") + precision = cfg.precision scaler = GradScaler() if precision == "float16" else None + if precision == "float8" and not TE_AVAILABLE: + raise ImportError( + "TransformerEngine is not installed. Please install it to use float8 precision." + ) + ckpt_args = { "path": f"{checkpoint_dir}/{cfg.run_id}/checkpoints", "optimizer": optimizer, @@ -597,7 +692,7 @@ def main(cfg: DictConfig): loaded_epoch = load_checkpoint(device=dist_manager.device, **ckpt_args) - if cfg.training.compile: + if cfg.compile: model = torch.compile(model) # Training loop @@ -606,43 +701,44 @@ def main(cfg: DictConfig): # Set the epoch in the samplers train_sampler.set_epoch(epoch) val_sampler.set_epoch(epoch) + train_dataloader.dataset.set_indices(list(train_sampler)) + val_dataloader.dataset.set_indices(list(val_sampler)) start_time = time.time() # Training phase - train_loss = train_epoch( - train_dataset, - train_sampler, - model, - output_pad_size, - optimizer, - scheduler, - logger, - writer, - epoch, - cfg, - dist_manager, - norm_factors, - scaler, - ) - end_time = time.time() - train_duration = end_time - start_time + with Profiler(): + train_loss = train_epoch( + train_dataloader, + len(list(train_sampler)), + model, + output_pad_size, + optimizer, + scheduler, + logger, + writer, + epoch, + cfg, + dist_manager, + scaler, + ) + end_time = time.time() + train_duration = end_time - start_time - start_time = time.time() - # Validation phase - val_loss = val_epoch( - val_dataset, - val_sampler, - model, - output_pad_size, - logger, - val_writer, - epoch, - cfg, - dist_manager, - norm_factors, - ) - end_time = time.time() - val_duration = end_time - start_time + start_time = time.time() + # Validation phase + val_loss = val_epoch( + val_dataloader, + len(list(val_sampler)), + model, + output_pad_size, + logger, + val_writer, + epoch, + cfg, + dist_manager, + ) + end_time = time.time() + val_duration = end_time - start_time # Log epoch results logger.info( @@ -670,8 +766,9 @@ def launch(cfg: DictConfig): # If you want to use `line_profiler` or PyTorch's profiler, enable them here. profiler = Profiler() - # profiler.enable("torch") - # profiler.enable("line_profiler") + if cfg.profile: + profiler.enable("torch") + profiler.enable("line_profiler") profiler.initialize() main(cfg) profiler.finalize() diff --git a/physicsnemo/datapipes/cae/__init__.py b/physicsnemo/datapipes/cae/__init__.py index b733c6b6d5..a188adb0da 100644 --- a/physicsnemo/datapipes/cae/__init__.py +++ b/physicsnemo/datapipes/cae/__init__.py @@ -15,4 +15,6 @@ # limitations under the License. from .domino_datapipe import DoMINODataPipe -from .mesh_datapipe import MeshDatapipe + +# from .mesh_datapipe import MeshDatapipe +from .transolver_datapipe import TransolverDataPipe diff --git a/physicsnemo/datapipes/cae/cae_dataset.py b/physicsnemo/datapipes/cae/cae_dataset.py index 705c18f92a..f4ce9f19d5 100644 --- a/physicsnemo/datapipes/cae/cae_dataset.py +++ b/physicsnemo/datapipes/cae/cae_dataset.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import pathlib import time from abc import ABC, abstractmethod @@ -89,6 +90,13 @@ def read_file(self, filename: pathlib.Path) -> dict[str, torch.Tensor]: """ pass + @abstractmethod + def read_file_attributes(self, filename: pathlib.Path) -> dict[str, torch.Tensor]: + """ + Read the attributes of a file and return a dictionary of tensors. + """ + pass + @abstractmethod def read_file_sharded( self, filename: pathlib.Path, device_mesh: torch.distributed.DeviceMesh @@ -255,14 +263,16 @@ def read_file(self, filename: pathlib.Path) -> dict[str, torch.Tensor]: # Make sure to select the slice outside of the loop. if self.is_volumetric: + volume_key = next(key for key in in_data.keys() if "volume" in key) + volume_shape = in_data[volume_key].shape[0] if self.volume_sampling_size is not None: volume_slice = self.select_random_sections_from_slice( 0, - in_data["volume_mesh_centers"].shape[0], + volume_shape, self.volume_sampling_size, ) else: - volume_slice = slice(0, in_data["volume_mesh_centers"].shape[0]) + volume_slice = slice(0, volume_shape) # This is a slower basic way to do this, to be improved: data = {} @@ -302,35 +312,60 @@ def __init__( ) -> None: super().__init__(keys_to_read, keys_to_read_if_available) + def read_file_attributes(self, filename: pathlib.Path) -> dict[str, torch.Tensor]: + """ + Read the attributes of a file and return a dictionary of tensors. + """ + group = zarr.open_group(filename, mode="r") + return group.attrs + def read_file(self, filename: pathlib.Path) -> dict[str, torch.Tensor]: """ Read a file and return a dictionary of tensors. """ group = zarr.open_group(filename, mode="r") - missing_keys = set(self.keys_to_read) - set(group.keys()) + attributes = self.read_file_attributes(filename) + + missing_keys = ( + set(self.keys_to_read) - set(group.keys()) - set(attributes.keys()) + ) + data = {} if len(missing_keys) > 0: raise ValueError(f"Keys {missing_keys} not found in file {filename}") + # Read in attributes: + for key in self.keys_to_read: + if key in attributes.keys(): + data[key] = torch.tensor(attributes[key]) + # Make sure to select the slice outside of the loop. if self.is_volumetric: + volume_key = next(key for key in group.keys() if "volume" in key) + volume_shape = group[volume_key].shape[0] if self.volume_sampling_size is not None: volume_slice = self.select_random_sections_from_slice( 0, - group["volume_mesh_centers"].shape[0], + volume_shape, self.volume_sampling_size, ) else: - volume_slice = slice(0, group["volume_mesh_centers"].shape[0]) + volume_slice = slice(0, volume_shape) # This is a slower basic way to do this, to be improved: - data = {} for key in self.keys_to_read: - if "volume" not in key: - data[key] = torch.from_numpy(group[key][:]) + # Don't read things that came from attributes, potentially; + if key in data.keys(): + continue + + if group[key].shape == (): + data[key] = torch.from_numpy(np.array(group[key])).to(torch.float32) else: - data[key] = torch.from_numpy(group[key][volume_slice]) + if "volume" not in key: + data[key] = torch.from_numpy(group[key][:]) + else: + data[key] = torch.from_numpy(group[key][volume_slice]) return self.fill_optional_keys(data) @@ -347,14 +382,28 @@ def read_file_sharded( group = zarr.open_group(filename, mode="r") - missing_keys = set(self.keys_to_read) - set(group.keys()) + attributes = self.read_file_attributes(filename) + + missing_keys = ( + set(self.keys_to_read) - set(group.keys()) - set(attributes.keys()) + ) if len(missing_keys) > 0: raise ValueError(f"Keys {missing_keys} not found in file {filename}") data = {} + + # Read in attributes: + for key in self.keys_to_read: + if key in attributes.keys(): + data[key] = torch.tensor(attributes[key]) + specs = {} for key in self.keys_to_read: + # Skip attributes: + if key in data.keys(): + continue + # Open the array in zarr without reading it and get info: zarr_array = group[key] array_shape = zarr_array.shape @@ -572,14 +621,47 @@ def __init__( } ) + def read_file_attributes( + self, filename: pathlib.Path + ) -> dict[str, torch.Tensor]: + """ + Read the attributes of a file and return a dictionary of tensors. + """ + store_spec = self.spec_template["kvstore"].copy() + store_spec["path"] = str(filename) + store = ts.KvStore.open(store_spec).result() + + keys = store.list().result() + + # Zarr 3 check: + if b"/zarr.json" in keys: + zarr_json = store.read(b"/zarr.json").result() + # load into json's parser: + attributes_dict = json.loads(zarr_json.value)["attributes"] + attributes = {k: torch.tensor(v) for k, v in attributes_dict.items()} + return attributes + elif b"/.zattrs" in keys: + # Zarr 2: + zarr_attrs = store.read(b"/.zattrs").result() + attributes_dict = json.loads(zarr_attrs.value) + attributes = {k: torch.tensor(v) for k, v in attributes_dict.items()} + return attributes + else: + return {} + def read_file(self, filename: pathlib.Path) -> dict[str, torch.Tensor]: """ Read a file and return a dictionary of tensors. """ + # We need to figure out, first, which keys are attributes. + attributes = self.read_file_attributes(filename) + + local_keys_to_read = set(self.keys_to_read) - set(attributes.keys()) + # Trigger an async open of each data item: read_futures = {} - for key in self.keys_to_read: + for key in local_keys_to_read: spec = self.spec_template.copy() spec["kvstore"]["path"] = str(filename) + "/" + str(key) @@ -593,23 +675,22 @@ def read_file(self, filename: pathlib.Path) -> dict[str, torch.Tensor]: } # Make sure to select the slice outside of the loop. - # We need if self.is_volumetric: + volume_key = next(key for key in read_futures.keys() if "volume" in key) + volume_shape = read_futures[volume_key].shape[0] if self.volume_sampling_size is not None: volume_slice = self.select_random_sections_from_slice( 0, - read_futures["volume_mesh_centers"].shape[0], + volume_shape, self.volume_sampling_size, ) else: - volume_slice = slice( - 0, read_futures["volume_mesh_centers"].shape[0] - ) + volume_slice = slice(0, volume_shape) # Trigger an async read of each data item: # (Each item will be a numpy ndarray after this:) tensor_futures = {} - for key in self.keys_to_read: + for key in local_keys_to_read: if "volume" not in key: tensor_futures[key] = read_futures[key].read() # For the volume data, read the slice: @@ -620,9 +701,12 @@ def read_file(self, filename: pathlib.Path) -> dict[str, torch.Tensor]: # (make sure to block for the result) data = { key: torch.as_tensor(tensor_futures[key].result(), dtype=torch.float32) - for key in self.keys_to_read + for key in local_keys_to_read } + # Patch in the attributes: + data.update(attributes) + return self.fill_optional_keys(data) def read_file_sharded( @@ -631,6 +715,10 @@ def read_file_sharded( """ Read a file and return a dictionary of tensors. """ + # We need to figure out, first, which keys are attributes. + attributes = self.read_file_attributes(filename) + + local_keys_to_read = set(self.keys_to_read) - set(attributes.keys()) # We need the coordinates of this GPU: this_rank = device_mesh.get_local_rank() @@ -638,7 +726,7 @@ def read_file_sharded( # This pulls a list of store objects in tensorstore: stores = {} - for key in self.keys_to_read: + for key in local_keys_to_read: spec = self.spec_template.copy() spec["kvstore"]["path"] = str(filename) + "/" + str(key) @@ -650,7 +738,7 @@ def read_file_sharded( data = {} specs = {} - for key in self.keys_to_read: + for key in local_keys_to_read: # Open the array in zarr without reading it and get info: store = stores[key] array_shape = store.shape @@ -694,7 +782,7 @@ def read_file_sharded( specs[key] = (placement, chunk_sizes) # Finally, await the full data read: - for key in self.keys_to_read: + for key in local_keys_to_read: data[key] = torch.as_tensor(data[key].result()) # Patch in the optional keys: diff --git a/physicsnemo/datapipes/cae/transolver_datapipe.py b/physicsnemo/datapipes/cae/transolver_datapipe.py new file mode 100644 index 0000000000..6844f04e14 --- /dev/null +++ b/physicsnemo/datapipes/cae/transolver_datapipe.py @@ -0,0 +1,734 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code provides the datapipe for reading the processed npy files, +generating multi-res grids, calculating signed distance fields, +sampling random points in the volume and on surface, +normalizing fields and returning the output tensors as a dictionary. + +This datapipe also non-dimensionalizes the fields, so the order in which the variables should +be fixed: velocity, pressure, turbulent viscosity for volume variables and +pressure, wall-shear-stress for surface variables. The different parameters such as +variable names, domain resolution, sampling size etc. are configurable in config.yaml. +""" + +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable, Literal, Optional + +import torch +from omegaconf import DictConfig +from torch.utils.data import Dataset + +from physicsnemo.datapipes.cae.cae_dataset import ( + CAEDataset, +) +from physicsnemo.distributed import DistributedManager +from physicsnemo.utils.domino.utils import ( + normalize, + standardize, + unnormalize, + unstandardize, +) +from physicsnemo.utils.sdf import signed_distance_field + + +@dataclass +class TransolverDataConfig: + """ + Configuration for Transolver data processing pipeline. + + Attributes: + + Attributes: + data_path: Path to the dataset to load. + model_type: Type of the model ("surface" or "volume"). + resolution: Resolution of the sampled data, per batch. + include_normals: Whether to include surface normals in embeddings. + include_sdf: Whether to include signed distance fields in embeddings. + translational_invariance: Enable translational adjustment using center of mass. + reference_origin: Origin for translational invariance, defaults to the center of mass. + broadcast_global_features: Whether to apply global features across all points. + volume_sample_from_disk: Whether to sample points from the disk for volume data. + return_mesh_features: Whether to return the mesh areas and normals for the surface data. + Used to compute force coefficients. Transformations are applied to the mesh coordinates. + """ + + data_path: Path | None + model_type: Literal["surface", "volume"] = "surface" + resolution: int = 200_000 + + # Control what features are added to the inputs to the model: + include_normals: bool = True + include_sdf: bool = True + + # Control the geometry configuration: + include_geometry: bool = False + geometry_sampling: int = 300_000 + + # For controlling the normalization of target values: + scaling_type: Optional[Literal["min_max_scaling", "mean_std_scaling"]] = None + normalization_factors: Optional[torch.Tensor] = None + + ############################################################ + # Translation invariance configuration: + ############################################################ + + translational_invariance: bool = False + # If none, uses the center of mass from the STLs: + reference_origin: torch.Tensor | None = None + + ############################################################ + # Scale Invariance: + ############################################################ + scale_invariance: bool = False + # Must be set if scale invariance is enabled. + # Should be castable to torch tensor + reference_scale: list[float] | None = None + + broadcast_global_features: bool = True + + volume_sample_from_disk: bool = True + + return_mesh_features: bool = False + + def __post_init__(self): + if self.data_path is not None: + # Ensure data_path is a Path object: + if isinstance(self.data_path, str): + self.data_path = Path(self.data_path) + self.data_path = self.data_path.expanduser() + + if not self.data_path.exists(): + raise ValueError(f"Path {self.data_path} does not exist") + + if not self.data_path.is_dir(): + raise ValueError(f"Path {self.data_path} is not a directory") + + if self.scaling_type is not None: + if self.scaling_type not in [ + # "min_max_scaling", + "mean_std_scaling", + ]: + raise ValueError( + f"scaling_type should be one of ['min_max_scaling', 'mean_std_scaling'], got {self.scaling_type}" + ) + + if self.scale_invariance: + if self.reference_scale is None: + raise ValueError( + "reference_scale must be set if scale invariance is enabled" + ) + + self.reference_scale = list(self.reference_scale) + if len(self.reference_scale) != 3: + raise ValueError("reference_scale must be a list of 3 floats") + self.reference_scale = ( + torch.tensor(self.reference_scale).to(torch.float32).reshape(1, 3) + ) + + +class TransolverDataPipe(Dataset): + """ + Base Datapipe for Transolver + + Leverages a dataset for the actual reading of the data, and this + object is responsible for preprocessing the data. + + """ + + def __init__( + self, + input_path, + model_type: Literal["surface", "volume"], + pin_memory: bool = False, + **data_config_overrides, + ): + # Perform config packaging and validation + self.config = TransolverDataConfig( + data_path=input_path, model_type=model_type, **data_config_overrides + ) + + # Set up the distributed manager: + if not DistributedManager.is_initialized(): + DistributedManager.initialize() + + self.dataset = None + + def preprocess_surface_data( + self, + data_dict, + center_of_mass: torch.Tensor | None = None, + scale_factor: torch.Tensor | None = None, + ): + positions = data_dict["surface_mesh_centers"] + + if self.config.resolution is not None: + idx = torch.multinomial( + torch.ones(data_dict["surface_mesh_centers"].shape[0]), + self.config.resolution, + ) + else: + idx = None + + if idx is not None: + positions = positions[idx] + + # This is a center of mass computation for the stl surface, + # using the size of each mesh point as weight. + if self.config.translational_invariance: + positions -= center_of_mass + + if self.config.scale_invariance: + positions = positions / scale_factor + + # Build the embeddings: + embeddings_inputs = [positions] + + # Surface SDF is always 0: + if self.config.include_sdf: + sdf = torch.zeros_like(positions[:, 0:1]) + embeddings_inputs.append(sdf) + + if self.config.include_normals: + normals = data_dict["surface_normals"] + if idx is not None: + normals = normals[idx] + normals = normals / torch.norm(normals, dim=-1, keepdim=True) + embeddings_inputs.append(normals) + + embeddings = torch.cat(embeddings_inputs, dim=-1) + + # Build fx: + fx_inputs = [ + data_dict["air_density"], + data_dict["stream_velocity"], + ] + fx = torch.stack(fx_inputs, dim=-1) + + if self.config.broadcast_global_features: + fx = fx.broadcast_to(embeddings.shape[0], -1) + else: + fx = fx.unsqueeze(0) + + fields = data_dict["surface_fields"] + if idx is not None: + fields = fields[idx] + + if self.config.scaling_type is not None: + fields = self.scale_model_targets(fields, self.config.normalization_factors) + + return { + "embeddings": embeddings, + "fx": fx, + "fields": fields, + } + + def preprocess_volume_data( + self, + data_dict, + center_of_mass: torch.Tensor | None = None, + scale_factor: torch.Tensor | None = None, + ): + positions = data_dict["volume_mesh_centers"] + + if self.config.resolution is not None: + idx = poisson_sample_indices_fixed( + positions.shape[0], self.config.resolution, device=positions.device + ) + else: + idx = None + + if idx is not None: + positions = positions[idx] + + # We need the CoM for some operations, regardless of translation invariance: + if center_of_mass is None: + center_of_mass = torch.mean(data_dict["stl_centers"], dim=0).unsqueeze(0) + + if self.config.translational_invariance: + positions -= center_of_mass + + if self.config.scale_invariance: + positions = positions / scale_factor + + # Build the embeddings: + embeddings_inputs = [positions] + + if self.config.include_sdf: + coords = data_dict["stl_coordinates"] + # Remove CoM, optionally: + if self.config.translational_invariance: + coords = coords - center_of_mass + + # Set scale, optionally: + if self.config.scale_invariance: + coords = coords / scale_factor + + sdf, closest_points = signed_distance_field( + coords, + data_dict["stl_faces"].flatten().to(torch.int32), + positions, + use_sign_winding_number=True, + ) + + embeddings_inputs.append(sdf.reshape(-1, 1)) + else: + closest_points = center_of_mass + + # Make sure we have a scale-invariant component to subtract + # from scale-invariant positions, below: + if self.config.scale_invariance: + closest_points = closest_points / scale_factor + + if self.config.include_normals: + normals = positions - closest_points + + # Be sure to normalize: + + # Sometimes, if the points are very close or on the mesh, the + # sdf is 0.0, and the norm goes to 0.0 + + distance_to_closest_point = torch.norm(positions - closest_points, dim=-1) + null_points = distance_to_closest_point < 1e-6 + + # In these cases, we update the vector to be from the center of mass + normals[null_points] = positions[null_points] - center_of_mass + + norm = torch.norm(normals, dim=-1, keepdim=True) + 1e-6 + normals = normals / norm + + embeddings_inputs.append(normals) + + embeddings = torch.cat(embeddings_inputs, dim=-1) + + # Build fx: + fx_inputs = [ + data_dict["air_density"], + data_dict["stream_velocity"], + ] + fx = torch.stack(fx_inputs, dim=-1) + + if self.config.broadcast_global_features: + fx = fx.broadcast_to(embeddings.shape[0], -1) + else: + fx = fx.unsqueeze(0) + + fields = data_dict["volume_fields"] + if idx is not None: + fields = fields[idx] + + if self.config.scaling_type is not None: + fields = self.scale_model_targets(fields, self.config.normalization_factors) + + return { + "embeddings": embeddings, + "fx": fx, + "fields": fields, + } + + def process_geometry( + self, + data_dict, + center_of_mass: torch.Tensor | None = None, + scale_factor: torch.Tensor | None = None, + ): + """ + Process the geometry data. + """ + geometry_coordinates = data_dict["stl_coordinates"] + if self.config.geometry_sampling is not None: + # idx = torch.multinomial( + # torch.ones(data_dict["stl_coordinates"].shape[0]), + # self.config.geometry_sampling, + # ) + idx = poisson_sample_indices_fixed( + data_dict["stl_coordinates"].shape[0], + self.config.geometry_sampling, + device=data_dict["stl_coordinates"].device, + ) + geometry_coordinates = geometry_coordinates[idx] + + if self.config.translational_invariance: + geometry_coordinates -= center_of_mass + + if self.config.scale_invariance: + geometry_coordinates = geometry_coordinates / scale_factor + + return geometry_coordinates + + @torch.no_grad() + def process_data(self, data_dict): + """ + Preprocess the data. We have slight differences between surface and volume data processing, + mostly revolving around the keys that represent the inputs. + + - For surface data, we use the mesh coordinates and normals as the embeddings. + - Normals are always normalized to 1.0, and are a relative direction. + - coordinates can be shifted to the center of mass, and then the whole + coordinate system can be aligned to the preferred direction. + - SDF is identically 0 for surface data. + - Optionally, if the scale invariance is enabled, the coordinates + are scaled by the (maybe-rotated) scale factor. + + - For Volume data: we still use the volume coordinates + - normals are approximated as the direction between the volume point + and closest mesh point. Normalized to 1.0. + - SDF is not zero for volume data. + + + To make the calculations consistent and logical to follow: + - First, get the coordinates (volume_mesh_centers or surface_mesh_centers, usually) + which is a configuration. + - Second, get the STL information. We need the "stl_vertices" and "stl_indices" + to compute an SDF. We downsample "stl_coordinates" to potentially encode + a geometry tensor, which is optional. + + Then, start imposing optional symmetries: + - Impose translation invariance. For every "position-like" tensor, subtract + off the reference_origin if translation invariance is enabled. + - Second, impose scale invariance: for every position-like tensor, multiply + by the reference scale. + - Finally, apply rotation invariance. Normals are rotated, points are rotated. + Roation requires not just a reference vector (in the config) but a + vector unique to this example to come from the data - we have to rotate to it. + + After that, the rest is simple: + - Spatial Encodings are the point locations + normal vectors (optional) + sdf (optional) + - If the normals aren't provided, we derive them from the center of mass (without SDF) or SDF point (with SDF) + - Geometry encoding (if using) is the STL coordinates, downsampled. + - parameter encodings are straight forward vectors / reference values. + + The downstream applications can take the embeddings and the features as needed. + + """ + + # Validate that all required keys are present in data_dict + required_keys = [ + "stl_centers", + ] + + if self.config.model_type == "volume": + # We need these for the SDF calculation: + required_keys.extend( + [ + "stl_coordinates", + "stl_faces", + ] + ) + elif self.config.model_type == "surface": + required_keys.extend( + [ + "surface_normals", + ] + ) + + if self.config.translational_invariance: + if self.config.reference_origin is not None: + center_of_mass = self.config.reference_origin + else: + center_of_mass = torch.mean(data_dict["stl_centers"], dim=0) + center_of_mass = center_of_mass.unsqueeze(0) # (1, 3) + else: + center_of_mass = None + + field_key = f"{self.config.model_type}_fields" + coords_key = f"{self.config.model_type}_mesh_centers" + + required_keys.extend( + [ + field_key, + coords_key, + ] + ) + + missing_keys = [key for key in required_keys if key not in data_dict] + if missing_keys: + raise ValueError( + f"Missing required keys in data_dict: {missing_keys}. " + f"Required keys are: {required_keys}" + ) + + scale_factor = ( + self.config.reference_scale if self.config.scale_invariance else None + ) + + if self.config.model_type == "surface": + outputs = self.preprocess_surface_data( + data_dict, center_of_mass, scale_factor + ) + elif self.config.model_type == "volume": + outputs = self.preprocess_volume_data( + data_dict, center_of_mass, scale_factor + ) + + if self.config.include_geometry: + outputs["geometry"] = self.process_geometry( + data_dict, center_of_mass, scale_factor + ) + + if self.config.return_mesh_features: + outputs["surface_areas"] = data_dict["surface_areas"] + outputs["surface_normals"] = data_dict["surface_normals"] + + if "air_density" in data_dict: + outputs["air_density"] = data_dict["air_density"] + if "stream_velocity" in data_dict: + outputs["stream_velocity"] = data_dict["stream_velocity"] + + return outputs + + def scale_model_targets( + self, fields: torch.Tensor, factors: torch.Tensor + ) -> torch.Tensor: + """ + Scale the model targets based on the configured scaling factors. + """ + if self.config.scaling_type == "mean_std_scaling": + field_mean = factors["mean"] + field_std = factors["std"] + return standardize(fields, field_mean, field_std) + elif self.config.scaling_type == "min_max_scaling": + field_min = factors["min"] + field_max = factors["max"] + return normalize(fields, field_max, field_min) + + def unscale_model_targets( + self, + fields: torch.Tensor | None = None, + air_density: torch.Tensor | None = None, + stream_velocity: torch.Tensor | None = None, + ): + """ + Unscale the model outputs based on the configured scaling factors. + + The unscaling is included here to make it a consistent interface regardless + of the scaling factors and type used. + + """ + + factors = self.config.normalization_factors + + if self.config.scaling_type == "mean_std_scaling": + field_mean = factors["mean"] + field_std = factors["std"] + fields = unstandardize(fields, field_mean, field_std) + elif self.config.scaling_type == "min_max_scaling": + field_min = factors["min"] + field_max = factors["max"] + fields = unnormalize(fields, field_max, field_min) + + if air_density is not None and stream_velocity is not None: + fields = fields * air_density * stream_velocity**2 + + return fields + + def set_dataset(self, dataset: Iterable) -> None: + """ + Pass a dataset to the datapipe to enable iterating over both in one pass. + """ + self.dataset = dataset + + if self.config.scale_invariance: + self.config.reference_scale = self.config.reference_scale.to( + self.dataset.output_device + ) + + if self.config.model_type == "volume" and self.config.volume_sample_from_disk: + # We deliberately double the data to read compared to the sampling size: + self.dataset.set_volume_sampling_size(25 * self.config.resolution) + + def __len__(self): + if self.dataset is not None: + return len(self.dataset) + else: + return 0 + + def __getitem__(self, idx): + """ + Function for fetching and processing a single file's data. + + Domino, in general, expects one example per file and the files + are relatively large due to the mesh size. + + Requires the user to have set a dataset via `set_dataset`. + """ + if self.dataset is None: + raise ValueError("Dataset is not present") + + # Get the data from the dataset. + # Under the hood, this may be fetching preloaded data. + data_dict = self.dataset[idx] + + return self.__call__(data_dict) + + def __call__(self, data_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Process the incoming data dictionary. + - Processes the data + - moves it to GPU + - adds a batch dimension + + Args: + data_dict: Dictionary containing the data to process as torch.Tensors. + + Returns: + Dictionary containing the processed data as torch.Tensors. + + """ + outputs = self.process_data(data_dict) + + for key in outputs.keys(): + outputs[key] = outputs[key].unsqueeze(0) + + return outputs + + def __iter__(self): + if self.dataset is None: + raise ValueError( + "Dataset is not present, can not use the datapipe as an iterator." + ) + + for i, batch in enumerate(self.dataset): + yield self.__call__(batch) + + +def create_transolver_dataset( + cfg: DictConfig, + phase: Literal["train", "val", "test"], + # keys_to_read: list[str], + # keys_to_read_if_available: dict[str, torch.Tensor], + scaling_factors: list[float], + # normalize_coordinates: bool = True, + device_mesh: torch.distributed.DeviceMesh | None = None, + placements: dict[str, torch.distributed.tensor.Placement] | None = None, +): + model_type = cfg.mode + if phase == "train": + input_path = cfg.train.data_path + elif phase == "val": + input_path = cfg.val.data_path + # elif phase == "test": + # input_path = cfg.eval.test_path + else: + raise ValueError(f"Invalid phase {phase}") + + # The dataset path works in two pieces: + # There is a core "dataset" which is loading data and moving to GPU + # And there is the preprocess step, here. + + # Optionally, and for backwards compatibility, the preprocess + # object can accept a dataset which will enable it as an iterator. + # The iteration function will loop over the dataset, preprocess the + # output, and return it. + + keys_to_read = cfg.data_keys + + overrides = {} + + dm = DistributedManager() + + if torch.cuda.is_available(): + device = dm.device + consumer_stream = torch.cuda.default_stream() + else: + device = torch.device("cpu") + consumer_stream = None + + if cfg.get("preload_depth", None) is not None: + preload_depth = cfg.preload_depth + else: + preload_depth = 1 + + if cfg.get("pin_memory", None) is not None: + pin_memory = cfg.pin_memory + else: + pin_memory = False + + # These are keys that could be set in the config, + # but have a sensible default if not. + optional_cfg_keys = [ + "include_normals", + "include_sdf", + "volume_sample_from_disk", + "broadcast_global_features", + "include_geometry", + "geometry_sampling", + "translational_invariance", + "reference_origin", + "scale_invariance", + "reference_scale", + "return_mesh_features", + ] + + for optional_key in optional_cfg_keys: + if cfg.get(optional_key, None) is not None: + overrides[optional_key] = cfg[optional_key] + + dataset = CAEDataset( + data_dir=input_path, + keys_to_read=keys_to_read, + keys_to_read_if_available={}, + output_device=device, + preload_depth=preload_depth, + pin_memory=pin_memory, + device_mesh=device_mesh, + placements=placements, + consumer_stream=consumer_stream, + ) + + datapipe = TransolverDataPipe( + input_path, + resolution=cfg.resolution, + normalization_factors=scaling_factors, + model_type=model_type, + scaling_type="mean_std_scaling", + **overrides, + ) + + datapipe.set_dataset(dataset) + + return datapipe + + +def poisson_sample_indices_fixed(N: int, k: int, device=None): + """ + This function is a nearly uniform sampler of indices for when the + number of indices to sample is very, very large. It's useful when + the number of indices to sample is larger than 2^24 and torch + multinomial can't work. Unlike using randperm, there is no + need to materialize and randomize the entire tensor of indices. + + """ + # Draw exponential gaps off of random initializations: + gaps = torch.rand(k, device=device).exponential_() + + summed = gaps.sum() + + # Normalize so total cumulative sum == N + gaps *= N / summed + + # Compute cumulative positions + idx = torch.cumsum(gaps, dim=0) + + # Shift down so range starts at 0 and ends below N + idx -= gaps[0] / 2 + + # Round to nearest integer index + idx = torch.clamp(idx.floor().long(), min=0, max=N - 1) + + return idx diff --git a/physicsnemo/experimental/models/typhon/__init__.py b/physicsnemo/experimental/models/typhon/__init__.py new file mode 100644 index 0000000000..430d0e93d2 --- /dev/null +++ b/physicsnemo/experimental/models/typhon/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .typhon import Typhon + +__all__ = ["Typhon"] diff --git a/physicsnemo/experimental/models/typhon/typhon.py b/physicsnemo/experimental/models/typhon/typhon.py new file mode 100644 index 0000000000..4351a17e26 --- /dev/null +++ b/physicsnemo/experimental/models/typhon/typhon.py @@ -0,0 +1,821 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +import torch +import torch.nn as nn +from einops import rearrange + +import physicsnemo # noqa: F401 for docs +from physicsnemo.utils.version_check import check_min_version +from physicsnemo.models.transolver.Physics_Attention import ( + PhysicsAttentionIrregularMesh, + gumbel_softmax, +) +from physicsnemo.models.transolver.transolver import MLP + +from physicsnemo.models.meta import ModelMetaData +from physicsnemo.models.module import Module + +# Check optional dependency availability +TE_AVAILABLE = check_min_version("transformer-engine", "0.1.0", hard_fail=False) +if TE_AVAILABLE: + import transformer_engine.pytorch as te + +ACTIVATION = { + "gelu": nn.GELU, + "tanh": nn.Tanh, + "sigmoid": nn.Sigmoid, + "relu": nn.ReLU, + "leaky_relu": nn.LeakyReLU(0.1), + "softplus": nn.Softplus, + "ELU": nn.ELU, + "silu": nn.SiLU, +} + + +class GALE(PhysicsAttentionIrregularMesh): + r"""Geometry-Aware Latent Embeddings (GALE) attention layer. + + This is an extension of the Transolver PhysicsAttention mechanism to support + cross-attention with a context vector, built from geometry and global embeddings. + GALE combines self-attention on learned physical state slices with cross-attention + to geometry-aware context, using a learnable mixing weight to blend the two. + + Parameters + ---------- + dim : int + Input dimension of the features. + heads : int, optional + Number of attention heads. Default is 8. + dim_head : int, optional + Dimension of each attention head. Default is 64. + dropout : float, optional + Dropout rate. Default is 0.0. + slice_num : int, optional + Number of learned physical state slices. Default is 64. + use_te : bool, optional + Whether to use Transformer Engine backend when available. Default is True. + plus : bool, optional + Whether to use Transolver++ features. Default is False. + context_dim : int, optional + Dimension of the context vector for cross-attention. Default is 0. + + Notes + ----- + The mixing between self-attention and cross-attention is controlled by a learnable + parameter ``state_mixing`` which is passed through a sigmoid function to ensure + the mixing weight stays in \([0, 1]\). + + See Also + -------- + :class:`physicsnemo.models.transolver.Physics_Attention.PhysicsAttentionIrregularMesh` : Base physics attention class. + :class:`GALE_block` : Transformer block using GALE attention. + """ + + def __init__( + self, + dim, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + slice_num: int = 64, + use_te: bool = True, + plus: bool = False, + context_dim: int = 0, + ): + super().__init__(dim, heads, dim_head, dropout, slice_num, use_te, plus) + + linear_layer = te.Linear if self.use_te else nn.Linear + + # We have additional parameters, here: + self.cross_q = linear_layer(dim_head, dim_head) + self.cross_k = linear_layer(context_dim, dim_head) + self.cross_v = linear_layer(context_dim, dim_head) + + # This is the learnable mixing weight between self and cross attention. + # We start near 0.0 since it is passed through a sigmoid to keep the + # mixing weight between 0 and 1. + self.state_mixing = nn.Parameter(torch.tensor(0.0)) + + def compute_slice_attention_cross( + self, slice_tokens: torch.Tensor, context: torch.Tensor + ) -> torch.Tensor: + r"""Compute cross-attention between slice tokens and context. + + Parameters + ---------- + slice_tokens : torch.Tensor + Slice tokens of shape \((B, H, N, D)\) where \(B\) is batch size, \(H\) is number of heads, \(N\) is number of slices, and \(D\) is head dimension. + context : torch.Tensor + Context tensor of shape \((B, H, N_c, D_c)\) where \(N_c\) is number of context slices and \(D_c\) is context dimension. + + Returns + ------- + torch.Tensor + Cross-attention output of shape \((B, H, N, D)\). + """ + + # Project the slice and context tokens: + q = self.cross_q(slice_tokens) + k = self.cross_k(context) + v = self.cross_v(context) + + # Compute the attention: + if self.use_te: + cross_attention = self.attn_fn(q, k, v) + else: + cross_attention = torch.nn.functional.scaled_dot_product_attention( + q, k, v, is_causal=False + ) + + return cross_attention + + def forward( + self, x: torch.Tensor, context: torch.Tensor | None = None + ) -> torch.Tensor: + r"""Forward pass of the GALE module. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape \((B, N, C)\) where \(B\) is batch size, \(N\) is number of tokens, and \(C\) is number of channels. + context : torch.Tensor, optional + Context tensor for cross-attention of shape \((B, H, S_c, D_c)\) where \(H\) is number of heads, \(S_c\) is number of context slices, and \(D_c\) is context dimension. If None, only self-attention is applied. Default is None. + + Returns + ------- + torch.Tensor + Output tensor of shape \((B, N, C)\), same shape as input. + """ + + # Project the inputs onto learned spaces: + if self.plus: + x_mid = self.project_input_onto_slices(x) + # In transolver ++, fx_mid is gone. + # x_mid is used to compute the projections instead: + fx_mid = x_mid + else: + x_mid, fx_mid = self.project_input_onto_slices(x) + + # Perform the linear projection of learned latent space onto slices: + slice_projections = self.in_project_slice(x_mid) + + # Slice projections has shape [B, N_head, N_tokens, Head_dim], but head_dim may have changed! + + # Use the slice projections and learned spaces to compute the slices, and their weights: + slice_weights, slice_tokens = self.compute_slices_from_projections( + slice_projections, fx_mid + ) + # slice_weights has shape [Batch, N_heads, N_tokens, Slice_num] + # slice_tokens has shape [Batch, N_heads, N_tokens, head_dim] + + # Apply attention to the slice tokens + if self.use_te: + self_slice_token = self.compute_slice_attention_te(slice_tokens) + else: + self_slice_token = self.compute_slice_attention_sdpa(slice_tokens) + + # HERE, we are differing: apply cross-attention with physical states: + if context is not None: + cross_slice_token = self.compute_slice_attention_cross( + slice_tokens, context + ) + + # Apply learnable mixing: + mixing_weight = torch.sigmoid(self.state_mixing) + out_slice_token = ( + mixing_weight * self_slice_token + + (1 - mixing_weight) * cross_slice_token + ) + + else: + # Just keep self attention: + out_slice_token = self_slice_token + + # Shape unchanged + + # Deslice: + outputs = self.project_attention_outputs(out_slice_token, slice_weights) + + # Outputs now has the same shape as the original input x + + return outputs + + +class GALE_block(nn.Module): + r"""Transformer encoder block using GALE attention. + + This block replaces standard self-attention with the GALE (Geometry-Aware Latent + Embeddings) attention mechanism, which combines physics-aware self-attention with + cross-attention to geometry and global context. + + Parameters + ---------- + num_heads : int + Number of attention heads. + hidden_dim : int + Hidden dimension of the transformer. + dropout : float + Dropout rate. + act : str, optional + Activation function name. Default is "gelu". + mlp_ratio : int, optional + Ratio of MLP hidden dimension to ``hidden_dim``. Default is 4. + last_layer : bool, optional + Whether this is the last layer in the model. Default is False. + out_dim : int, optional + Output dimension (only used if ``last_layer=True``). Default is 1. + slice_num : int, optional + Number of learned physical state slices. Default is 32. + use_te : bool, optional + Whether to use Transformer Engine backend. Default is True. + plus : bool, optional + Whether to use Transolver++ features. Default is False. + context_dim : int, optional + Dimension of the context vector for cross-attention. Default is 0. + + Notes + ----- + The block applies layer normalization before the attention operation and uses + residual connections after both the attention and MLP layers. + """ + + def __init__( + self, + num_heads: int, + hidden_dim: int, + dropout: float, + act="gelu", + mlp_ratio=4, + last_layer=False, + out_dim=1, + slice_num=32, + use_te=True, + plus: bool = False, + context_dim: int = 0, + ): + super().__init__() + + if use_te and not TE_AVAILABLE: + raise ImportError( + "Transformer Engine is not installed. Please install it with: pip install transformer-engine>=0.1.0" + ) + + self.last_layer = last_layer + if use_te: + self.ln_1 = te.LayerNorm(hidden_dim) + else: + self.ln_1 = nn.LayerNorm(hidden_dim) + + self.Attn = GALE( + hidden_dim, + heads=num_heads, + dim_head=hidden_dim // num_heads, + dropout=dropout, + slice_num=slice_num, + use_te=use_te, + plus=plus, + context_dim=context_dim, + ) + + if use_te: + self.ln_mlp1 = te.LayerNormMLP( + hidden_size=hidden_dim, + ffn_hidden_size=hidden_dim * mlp_ratio, + ) + else: + self.ln_mlp1 = nn.Sequential( + nn.LayerNorm(hidden_dim), + MLP( + hidden_dim, + hidden_dim * mlp_ratio, + hidden_dim, + n_layers=0, + res=False, + act=act, + use_te=False, + ), + ) + + def forward(self, fx: torch.Tensor, global_context: torch.Tensor) -> torch.Tensor: + r"""Forward pass of the GALE block. + + Parameters + ---------- + fx : torch.Tensor + Input tensor of shape \((B, N, C)\) where \(B\) is batch size, \(N\) is number of tokens, and \(C\) is hidden dimension. + global_context : torch.Tensor + Global context tensor for cross-attention of shape \((B, H, S_c, D_c)\) where \(H\) is number of heads, \(S_c\) is number of context slices, and \(D_c\) is context dimension. + + Returns + ------- + torch.Tensor + Output tensor of shape \((B, N, C)\), same shape as input. + """ + fx = self.Attn(self.ln_1(fx), global_context) + fx + fx = self.ln_mlp1(fx) + fx + + return fx + + +@dataclass +class TyphonMetaData(ModelMetaData): + """ + Data class for storing essential meta data needed for the Typhon model. + """ + + name: str = "Typhon" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp: bool = True + # Inference + onnx_cpu: bool = False # No FFT op on CPU + onnx_gpu: bool = True + onnx_runtime: bool = True + # Physics informed + var_dim: int = 1 + func_torch: bool = False + auto_grad: bool = False + + +class ContextProjector(nn.Module): + r"""Projects context features onto physical state space. + + This context projector is conceptually similar to half of a GALE attention layer. + It projects context values (geometry or global embeddings) onto a learned physical + state space, but unlike a full attention layer, it never projects back to the + original space. The projected features are used as context in all GALE blocks + of the Typhon model. + + Parameters + ---------- + dim : int + Input dimension of the context features. + heads : int, optional + Number of projection heads. Default is 8. + dim_head : int, optional + Dimension of each projection head. Default is 64. + dropout : float, optional + Dropout rate. Default is 0.0. + slice_num : int, optional + Number of learned physical state slices. Default is 64. + use_te : bool, optional + Whether to use Transformer Engine backend when available. Default is True. + plus : bool, optional + Whether to use Transolver++ features. Default is False. + + Notes + ----- + The global features are reused in all blocks of the model, so the learned + projections must capture globally useful features rather than layer-specific ones. + + See Also + -------- + :class:`GALE` : Full GALE attention layer that uses these projected context features. + :class:`Typhon` : Main model that uses ContextProjector for geometry and global embeddings. + """ + + def __init__( + self, + dim, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + slice_num: int = 64, + use_te: bool = True, + plus: bool = False, + ): + super().__init__() + inner_dim = dim_head * heads + self.dim_head = dim_head + self.heads = heads + self.plus = plus + self.scale = dim_head**-0.5 + self.use_te = use_te + + # Keep below here: + if use_te: + self.in_project_x = te.Linear(dim, inner_dim) + if not plus: + self.in_project_fx = te.Linear(dim, inner_dim) + else: + self.in_project_x = nn.Linear(dim, inner_dim) + if not plus: + self.in_project_fx = nn.Linear(dim, inner_dim) + + self.softmax = nn.Softmax(dim=-1) + self.dropout = nn.Dropout(dropout) + self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5) + + if plus: + linear_layer = te.Linear if self.use_te else nn.Linear + self.proj_temperature = torch.nn.Sequential( + linear_layer(self.dim_head, slice_num), + nn.GELU(), + linear_layer(slice_num, 1), + nn.GELU(), + ) + + if self.use_te: + self.in_project_slice = te.Linear(dim_head, slice_num) + else: + self.in_project_slice = nn.Linear(dim_head, slice_num) + + def project_input_onto_slices( + self, x: torch.Tensor + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + r"""Project the input onto the slice space. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape \((B, N, C)\) where \(B\) is batch size, \(N\) is number of tokens, and \(C\) is number of channels. + + Returns + ------- + torch.Tensor or tuple[torch.Tensor, torch.Tensor] + If ``plus=True``, returns single tensor ``x_mid`` of shape \((B, H, N, D)\) where \(H\) is number of heads and \(D\) is head dimension. If ``plus=False``, returns tuple ``(x_mid, fx_mid)`` both of shape \((B, H, N, D)\). + """ + x_mid = rearrange( + self.in_project_x(x), "B N (h d) -> B h N d", h=self.heads, d=self.dim_head + ) + if self.plus: + return x_mid + else: + fx_mid = rearrange( + self.in_project_fx(x), + "B N (h d) -> B h N d", + h=self.heads, + d=self.dim_head, + ) + + return x_mid, fx_mid + + def compute_slices_from_projections( + self, slice_projections: torch.Tensor, fx: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + r"""Compute slice weights and slice tokens from input projections and latent features. + + Parameters + ---------- + slice_projections : torch.Tensor + Projected input tensor of shape \((B, N, H, S)\) where \(B\) is batch size, \(H\) is number of heads, \(N\) is number of tokens, and \(S\) is number of slices, representing the projection of each token onto each slice for each attention head. + fx : torch.Tensor + Latent feature tensor of shape \((B, N, H, D)\) where \(D\) is head dimension, representing the learned states to be aggregated by the slice weights. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + - ``slice_weights``: Tensor of shape \((B, N, H, S)\), representing the normalized weights for each slice per token and head. + - ``slice_token``: Tensor of shape \((B, H, S, D)\), representing the aggregated latent features for each slice, head, and batch. + + Notes + ----- + The function computes a temperature-scaled softmax over the slice projections to obtain + slice weights, then aggregates the latent features for each slice using these weights. + The aggregated features are normalized by the sum of weights for numerical stability. + """ + + # Project the latent space vectors on to the weight computation space, + # and compute a temperature adjusted softmax. + + if self.plus: + temperature = self.temperature + self.proj_temperature(fx) + clamped_temp = torch.clamp(temperature, min=0.01).to( + slice_projections.dtype + ) + slice_weights = gumbel_softmax( + slice_projections, clamped_temp + ) # [Batch, N_heads, N_tokens, Slice_num] + + else: + clamped_temp = torch.clamp(self.temperature, min=0.5, max=5).to( + slice_projections.dtype + ) + slice_weights = nn.functional.softmax( + slice_projections / clamped_temp, dim=-1 + ) # [Batch, N_heads, N_tokens, Slice_num] + + # Cast to the computation type (since the parameter is probably fp32) + slice_weights = slice_weights.to(slice_projections.dtype) + + # This does the projection of the latent space fx by the weights: + + # Computing the slice tokens is a matmul followed by a normalization. + # It can, unfortunately, overflow in reduced precision, so normalize first: + slice_norm = slice_weights.sum(2) # [Batch, N_heads, Slice_num] + normed_weights = slice_weights / (slice_norm[:, :, None, :] + 1e-2) + slice_token = torch.matmul(normed_weights.transpose(2, 3), fx) + + # Return the original weights, not the normed weights: + return slice_weights, slice_token + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r"""Reduced forward pass projecting inputs to physical state slices. + + This performs a partial physics attention operation: it projects the input onto + learned physical state slices but does not project back to the original space. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape \((B, N, C)\) where \(B\) is batch size, \(N\) is number of tokens, and \(C\) is number of channels. + + Returns + ------- + torch.Tensor + Slice tokens of shape \((B, H, S, D)\) where \(H\) is number of heads, \(S\) is number of slices, and \(D\) is head dimension. + """ + + # All of this is derived from the PhysicsAttention Layer + + # Project the inputs onto learned spaces: + if self.plus: + x_mid = self.project_input_onto_slices(x) + # In transolver ++, fx_mid is gone. + # x_mid is used to compute the projections instead: + fx_mid = x_mid + else: + x_mid, fx_mid = self.project_input_onto_slices(x) + + # Perform the linear projection of learned latent space onto slices: + slice_projections = self.in_project_slice(x_mid) + + # Slice projections has shape [B, N_head, N_tokens, Head_dim], but head_dim may have changed! + + # Use the slice projections and learned spaces to compute the slices, and their weights: + _, slice_tokens = self.compute_slices_from_projections( + slice_projections, fx_mid + ) + # _ has shape [Batch, N_heads, N_tokens, Slice_num] + # slice_tokens has shape [Batch, N_heads, N_tokens, head_dim] + + return slice_tokens + + +class Typhon(Module): + r"""Typhon: Geometry-Aware Physics Attention Transformer. + + Typhon is an adaptation of the Transolver architecture, replacing standard attention + with GALE (Geometry-Aware Latent Embeddings) attention. GALE combines physics-aware + self-attention on learned state slices with cross-attention to geometry and global + context embeddings. + + The model projects geometry and global features onto physical state spaces, which are + then used as context in all transformer blocks. This design enables the model to + incorporate geometric structure and global information throughout the forward pass. + + Parameters + ---------- + functional_dim : int + Dimension of the input values (local embeddings), not including global embeddings or geometry features. Input will be projected to ``n_hidden`` before processing. + out_dim : int + Dimension of the output of the model. + geometry_dim : int, optional + Pointwise dimension of the geometry input features. If provided, geometry features will be projected onto physical states and used as context in all GALE layers. Default is None. + global_dim : int, optional + Dimension of the global embedding features. If provided, global features will be projected onto physical states and used as context in all GALE layers. Default is None. + n_layers : int, optional + Number of GALE layers in the model. Default is 4. + n_hidden : int, optional + Hidden dimension of the transformer. Default is 256. + dropout : float, optional + Dropout rate applied across the GALE layers. Default is 0.0. + n_head : int, optional + Number of attention heads in each GALE layer. Must evenly divide ``n_hidden`` to yield an integer head dimension. Default is 8. + act : str, optional + Activation function name. Default is "gelu". + mlp_ratio : int, optional + Ratio of MLP hidden dimension to ``n_hidden``. Default is 4. + slice_num : int, optional + Number of learned physical state slices in the GALE layers, representing the number of learned states each layer should project inputs onto. Default is 32. + use_te : bool, optional + Whether to use Transformer Engine backend when available. Default is True. + time_input : bool, optional + Whether to include time embeddings. Default is False. + plus : bool, optional + Whether to use Transolver++ features in the GALE layers. Default is False. + + Raises + ------ + ValueError + If ``n_hidden`` is not evenly divisible by ``n_head``. + + + Forward + ---------- + local_embedding : torch.Tensor + Local embedding of the input data of shape \((B, N, C)\) where \(B\) is batch size, \(N\) is number of nodes/tokens, and \(C\) is ``functional_dim``. Output will have the same \((B, N)\) shape but with ``out_dim`` channels. + global_embedding : torch.Tensor, optional + Global embedding of the input data of shape \((B, N_g, C_g)\) where \(N_g\) is number of global tokens and \(C_g\) is ``global_dim``. If None, global context is not used. Default is None. + geometry : torch.Tensor, optional + Geometry features of the input data of shape \((B, N, C_{geo})\) where \(C_{geo}\) is ``geometry_dim``. If None, geometry context is not used. Default is None. + time : torch.Tensor, optional + Time embedding (currently not implemented). Default is None. + + Returns + ------- + torch.Tensor + Output tensor of shape \((B, N, C_{out})\) where \(C_{out}\) is ``out_dim``. + + Notes + ----- + Typhon currently supports unstructured mesh input only. Enhancements for image-based + and voxel-based inputs may be available in the future. + + For more details on Transolver, see: + - https://arxiv.org/pdf/2402.02366 + - https://arxiv.org/pdf/2502.02414 + + See Also + -------- + :class:`GALE` : The attention mechanism used in Typhon. + :class:`GALE_block` : Transformer block using GALE attention. + :class:`ContextProjector` : Projects context features onto physical states. + + Examples + -------- + Basic usage with local embeddings only: + + >>> import torch + >>> import physicsnemo + >>> model = physicsnemo.models.Typhon( + ... functional_dim=64, + ... out_dim=3, + ... n_hidden=256, + ... n_layers=4 + ... ) + >>> local_emb = torch.randn(2, 1000, 64) # (batch, nodes, features) + >>> output = model(local_emb) + >>> output.shape + torch.Size([2, 1000, 3]) + + Usage with geometry and global context: + + >>> model = physicsnemo.models.Typhon( + ... functional_dim=64, + ... out_dim=3, + ... geometry_dim=3, + ... global_dim=16, + ... n_hidden=256, + ... n_layers=4 + ... ) + >>> local_emb = torch.randn(2, 1000, 64) + >>> geometry = torch.randn(2, 1000, 3) # (batch, nodes, spatial_dim) + >>> global_emb = torch.randn(2, 1, 16) # (batch, 1, global_features) + >>> output = model(local_emb, global_embedding=global_emb, geometry=geometry) + >>> output.shape + torch.Size([2, 1000, 3]) + """ + + def __init__( + self, + functional_dim: int, + out_dim: int, + geometry_dim: int | None = None, + global_dim: int | None = None, + n_layers: int = 4, + n_hidden: int = 256, + dropout: float = 0.0, + n_head: int = 8, + act: str = "gelu", + mlp_ratio: int = 4, + slice_num: int = 32, + use_te: bool = True, + time_input: bool = False, + plus: bool = False, + ) -> None: + super().__init__(meta=TyphonMetaData()) + self.__name__ = "Typhon" + + self.use_te = use_te + # Check that the hidden dimension and head dimensions are compatible: + if not n_hidden % n_head == 0: + raise ValueError( + f"Typhon requires n_hidden % n_head == 0, but instead got {n_hidden % n_head}" + ) + + # These are to project geometry embeddings and global embeddings onto + # a physical state space: + context_dim = 0 + if geometry_dim is not None: + self.geometry_tokenizer = ContextProjector( + geometry_dim, + n_head, + n_hidden // n_head, + dropout, + slice_num, + use_te, + plus, + ) + context_dim += n_hidden // n_head + if global_dim is not None: + self.global_tokenizer = ContextProjector( + global_dim, n_head, n_hidden // n_head, dropout, slice_num, use_te, plus + ) + context_dim += n_hidden // n_head + + # This MLP is the initial projection onto the hidden space + self.preprocess = MLP( + functional_dim, + n_hidden * 2, + n_hidden, + n_layers=0, + res=False, + act=act, + use_te=use_te, + ) + + self.n_hidden = n_hidden + + self.blocks = nn.ModuleList( + [ + GALE_block( + num_heads=n_head, + hidden_dim=n_hidden, + dropout=dropout, + act=act, + mlp_ratio=mlp_ratio, + slice_num=slice_num, + last_layer=(_ == n_layers - 1), + use_te=use_te, + plus=plus, + context_dim=context_dim, + ) + for _ in range(n_layers) + ] + ) + + if use_te: + self.ln_mlp_out = te.LayerNormLinear( + in_features=n_hidden, out_features=out_dim + ) + else: + self.ln_mlp_out = nn.Sequential( + nn.LayerNorm(n_hidden), + nn.Linear(n_hidden, out_dim), + ) + + self.time_input = time_input + if time_input: + self.time_fc = nn.Sequential( + nn.Linear(n_hidden, n_hidden), nn.SiLU(), nn.Linear(n_hidden, n_hidden) + ) + + def forward( + self, + local_embedding: torch.Tensor, + global_embedding: torch.Tensor | None = None, + geometry: torch.Tensor | None = None, + time: torch.Tensor | None = None, + ) -> torch.Tensor: + r"""Forward pass of the Typhon model. + + The model constructs global context embeddings from geometry and global features by + projecting them onto physical state spaces. These context embeddings are then used + in all GALE blocks via cross-attention, allowing geometric and global information to + guide the learned physical state dynamics. + + """ + + # First, construct the global context vectors: + global_context_input = [] + + if geometry is not None: + geometry_states = self.geometry_tokenizer(geometry) + global_context_input.append(geometry_states) + + if global_embedding is not None: + global_states = self.global_tokenizer(global_embedding) + global_context_input.append(global_states) + + # Construct the embedding states: + if len(global_context_input) > 0: + embedding_states = torch.cat(global_context_input, dim=-1) + + # Project the inputs to the hidden dimension: + x = self.preprocess(local_embedding) + + for block in self.blocks: + x = block(x, embedding_states) + + # Now, pass the data through the model: + x = self.ln_mlp_out(x) + + return x diff --git a/physicsnemo/models/transolver/Physics_Attention.py b/physicsnemo/models/transolver/Physics_Attention.py index 2c793aaf59..978a065dc5 100644 --- a/physicsnemo/models/transolver/Physics_Attention.py +++ b/physicsnemo/models/transolver/Physics_Attention.py @@ -34,7 +34,15 @@ import torch import torch.nn as nn -import transformer_engine.pytorch as te # noqa: F401 + +try: + import transformer_engine.pytorch as te +except (ImportError, FileNotFoundError): + te = None + TE_AVAILABLE = False +else: + TE_AVAILABLE = True + from einops import rearrange from torch.autograd.profiler import record_function from torch.distributed.tensor.placement_types import Replicate @@ -42,6 +50,30 @@ from physicsnemo.distributed import ShardTensor +def gumbel_softmax(logits: torch.Tensor, tau: float = 1.0) -> torch.Tensor: + """ + Implementation of Gumblel Softmax from transolver++. + + Original code: https://github.com/thuml/Transolver_plus/blob/main/models/Transolver_plus.py#L69 + + Args: + logits (torch.Tensor): The logits to apply Gumblel Softmax to. + tau (float): The temperature parameter for the Gumblel Softmax. + + Returns: + torch.Tensor: The Gumblel Softmax of the logits. + """ + u = torch.rand_like(logits) + gumbel_noise = -torch.log(-torch.log(u + 1e-8) + 1e-8) + + y = logits + gumbel_noise + y = y / tau + + y = torch.nn.functional.softmax(y, dim=-1) + + return y + + class PhysicsAttentionBase(nn.Module, ABC): """ Base class for all physics attention modules. @@ -59,18 +91,36 @@ class PhysicsAttentionBase(nn.Module, ABC): """ - def __init__(self, dim, heads, dim_head, dropout, slice_num, use_te): + def __init__( + self, + dim: int, + heads: int, + dim_head: int, + dropout: float, + slice_num: int, + use_te: bool, + plus: bool, + ): super().__init__() inner_dim = dim_head * heads self.dim_head = dim_head self.heads = heads - + self.plus = plus self.scale = dim_head**-0.5 + self.use_te = use_te self.softmax = nn.Softmax(dim=-1) self.dropout = nn.Dropout(dropout) self.temperature = nn.Parameter(torch.ones([1, 1, heads, 1]) * 0.5) - self.use_te = use_te + + if plus: + linear_layer = te.Linear if self.use_te else nn.Linear + self.proj_temperature = torch.nn.Sequential( + linear_layer(self.dim_head, slice_num), + nn.GELU(), + linear_layer(slice_num, 1), + nn.GELU(), + ) if self.use_te: self.in_project_slice = te.Linear(dim_head, slice_num) @@ -138,48 +188,57 @@ def compute_slices_from_projections( - The aggregated features are normalized by the sum of weights for numerical stability. """ - with record_function("compute_slices_from_projections"): - # Project the latent space vectors on to the weight computation space, - # and compute a temperature adjusted softmax. - clamped_temp = torch.clamp(self.temperature, min=0.5, max=5).to( + # Project the latent space vectors on to the weight computation space, + # and compute a temperature adjusted softmax. + + if self.plus: + temperature = self.temperature + self.proj_temperature(fx) + clamped_temp = torch.clamp(temperature, min=0.01).to( slice_projections.dtype ) - - slice_weights = nn.functional.softmax( - slice_projections / clamped_temp, dim=-1 + slice_weights = gumbel_softmax( + slice_projections, clamped_temp ) # [Batch, N_tokens, N_heads, Slice_num] - # Cast to the computation type (since the parameter is probably fp32) - slice_weights = slice_weights.to(slice_projections.dtype) - - # This does the projection of the latent space fx by the weights: - - # Computing the slice tokens is a matmul followed by a normalization. - # It can, unfortunately, overflow in reduced precision, so normalize first: - slice_norm = slice_weights.sum(1) # [Batch, N_heads, Slice_num] - # Sharded note: slice_norm will be a partial sum at this point. - # That's because the we're summing over the tokens, which are distributed - normed_weights = slice_weights / (slice_norm[:, None, :, :] + 1e-2) - # Normed weights has shape - # (batch, n_tokens, n_heads, slice_num) - - # Sharded note: normed_weights will resolve the partial slice_norm - # and the output normed_weights will be sharded. - # fx has shape (Batch, n_tokens, n_heads, head_dim) - # This matmul needs to contract over the tokens - # This should produce an output with shape - # [Batch, N_heads, Slice_num, Head_dim] - - # Like the weight norm, this sum is a **partial** sum since we are summing - # over the tokens - - slice_token = torch.matmul( - normed_weights.permute(0, 2, 3, 1), fx.permute(0, 2, 1, 3) + else: + clamped_temp = torch.clamp(self.temperature, min=0.5, max=5).to( + slice_projections.dtype ) + slice_weights = nn.functional.softmax( + slice_projections / clamped_temp, dim=-1 + ) # [Batch, N_heads, N_tokens, Slice_num] + + # Cast to the computation type (since the parameter is probably fp32) + slice_weights = slice_weights.to(slice_projections.dtype) + + # This does the projection of the latent space fx by the weights: + + # Computing the slice tokens is a matmul followed by a normalization. + # It can, unfortunately, overflow in reduced precision, so normalize first: + slice_norm = slice_weights.sum(1) + 1e-2 # [Batch, N_heads, Slice_num] + # Sharded note: slice_norm will be a partial sum at this point. + # That's because the we're summing over the tokens, which are distributed + normed_weights = slice_weights / (slice_norm[:, None, :, :]) + # Normed weights has shape + # (batch, n_tokens, n_heads, slice_num) + + # Sharded note: normed_weights will resolve the partial slice_norm + # and the output normed_weights will be sharded. + # fx has shape (Batch, n_tokens, n_heads, head_dim) + # This matmul needs to contract over the tokens + # This should produce an output with shape + # [Batch, N_heads, Slice_num, Head_dim] + + # Like the weight norm, this sum is a **partial** sum since we are summing + # over the tokens + + slice_token = torch.matmul( + normed_weights.permute(0, 2, 3, 1), fx.permute(0, 2, 1, 3) + ) - # Return the original weights, not the normed weights: + # Return the original weights, not the normed weights: - return slice_weights, slice_token + return slice_weights, slice_token def compute_slice_attention_te(self, slice_tokens: torch.Tensor) -> torch.Tensor: """ @@ -277,37 +336,42 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Input x should have shape of [Batch, N_tokens, N_Channels] ([B, N, C]) """ - with record_function("forward"): - # Project the inputs onto learned spaces: + # Project the inputs onto learned spaces: + if self.plus: + x_mid = self.project_input_onto_slices(x) + # In transolver ++, fx_mid is gone. + # x_mid is used to compute the projections instead: + fx_mid = x_mid + else: x_mid, fx_mid = self.project_input_onto_slices(x) - # Perform the linear projection of learned latent space onto slices: + # Perform the linear projection of learned latent space onto slices: - slice_projections = self.in_project_slice(x_mid) + slice_projections = self.in_project_slice(x_mid) - # Slice projections has shape [B, N_tokens, N_head, Head_dim], but head_dim may have changed! + # Slice projections has shape [B, N_tokens, N_head, Head_dim], but head_dim may have changed! - # Use the slice projections and learned spaces to compute the slices, and their weights: - slice_weights, slice_tokens = self.compute_slices_from_projections( - slice_projections, fx_mid - ) - # slice_weights has shape [Batch, N_tokens, N_heads, Slice_num] - # slice_tokens has shape [Batch, N_tokens, N_heads, head_dim] + # Use the slice projections and learned spaces to compute the slices, and their weights: + slice_weights, slice_tokens = self.compute_slices_from_projections( + slice_projections, fx_mid + ) + # slice_weights has shape [Batch, N_tokens, N_heads, Slice_num] + # slice_tokens has shape [Batch, N_tokens, N_heads, head_dim] - # Apply attention to the slice tokens - if self.use_te: - out_slice_token = self.compute_slice_attention_te(slice_tokens) - else: - out_slice_token = self.compute_slice_attention_sdpa(slice_tokens) + # Apply attention to the slice tokens + if self.use_te: + out_slice_token = self.compute_slice_attention_te(slice_tokens) + else: + out_slice_token = self.compute_slice_attention_sdpa(slice_tokens) - # Shape unchanged + # Shape unchanged - # Deslice: - outputs = self.project_attention_outputs(out_slice_token, slice_weights) + # Deslice: + outputs = self.project_attention_outputs(out_slice_token, slice_weights) - # Outputs now has the same shape as the original input x + # Outputs now has the same shape as the original input x - return outputs + return outputs class PhysicsAttentionIrregularMesh(PhysicsAttentionBase): @@ -316,18 +380,29 @@ class PhysicsAttentionIrregularMesh(PhysicsAttentionBase): """ def __init__( - self, dim, heads=8, dim_head=64, dropout=0.0, slice_num=64, use_te=True + self, + dim, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + slice_num: int = 64, + use_te: bool = True, + plus: bool = False, ): - super().__init__(dim, heads, dim_head, dropout, slice_num, use_te) + super().__init__(dim, heads, dim_head, dropout, slice_num, use_te, plus) inner_dim = dim_head * heads if use_te: self.in_project_x = te.Linear(dim, inner_dim) - self.in_project_fx = te.Linear(dim, inner_dim) + if not plus: + self.in_project_fx = te.Linear(dim, inner_dim) else: self.in_project_x = nn.Linear(dim, inner_dim) - self.in_project_fx = nn.Linear(dim, inner_dim) + if not plus: + self.in_project_fx = nn.Linear(dim, inner_dim) - def project_input_onto_slices(self, x) -> tuple[torch.Tensor, torch.Tensor]: + def project_input_onto_slices( + self, x + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """ Project the input onto the slice space. @@ -338,14 +413,20 @@ def project_input_onto_slices(self, x) -> tuple[torch.Tensor, torch.Tensor]: tuple[torch.Tensor, torch.Tensor]: The projected x and fx tensors of shape [Batch, N_tokens, N_Channels], [Batch, N_tokens, N_heads, Head_dim] """ - fx = self.in_project_fx(x) - fx_mid = rearrange(fx, "B N (h d) -> B N h d", h=self.heads, d=self.dim_head) - x_mid = rearrange( self.in_project_x(x), "B N (h d) -> B N h d", h=self.heads, d=self.dim_head ) + if self.plus: + return x_mid + else: + fx_mid = rearrange( + self.in_project_fx(x), + "B N (h d) -> B N h d", + h=self.heads, + d=self.dim_head, + ) - return x_mid, fx_mid + return x_mid, fx_mid class PhysicsAttentionStructuredMesh2D(PhysicsAttentionBase): @@ -357,14 +438,15 @@ class PhysicsAttentionStructuredMesh2D(PhysicsAttentionBase): def __init__( self, - dim, - spatial_shape, - heads=8, + dim: int, + spatial_shape: tuple[int, int], + heads: int = 8, dim_head=64, - dropout=0.0, - slice_num=64, - kernel=3, - use_te=True, + dropout: float = 0.0, + slice_num: int = 64, + kernel: int = 3, + use_te: bool = True, + plus: bool = False, ): # kernel=3): super().__init__(dim, heads, dim_head, dropout, slice_num, use_te) @@ -373,9 +455,12 @@ def __init__( self.W = spatial_shape[1] self.in_project_x = nn.Conv2d(dim, inner_dim, kernel, 1, kernel // 2) - self.in_project_fx = nn.Conv2d(dim, inner_dim, kernel, 1, kernel // 2) + if not plus: + self.in_project_fx = nn.Conv2d(dim, inner_dim, kernel, 1, kernel // 2) - def project_input_onto_slices(self, x) -> tuple[torch.Tensor, torch.Tensor]: + def project_input_onto_slices( + self, x + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: # Rearrange the input tokens back to an image shape: b = x.shape[0] c = x.shape[-1] @@ -384,25 +469,29 @@ def project_input_onto_slices(self, x) -> tuple[torch.Tensor, torch.Tensor]: x = x.permute(0, 3, 1, 2) # Apply the projections, here they are convolutions in 2D: - input_projected_fx = self.in_project_fx(x) - input_projected_x = self.in_project_x(x) - # Next, re-reshape the projections into token-like shapes: - input_projected_fx = rearrange( - input_projected_fx, - "b (n_heads head_dim) h w -> b (h w) n_heads head_dim", - head_dim=self.dim_head, - n_heads=self.heads, - ) + input_projected_x = self.in_project_x(x) input_projected_x = rearrange( input_projected_x, "b (n_heads head_dim) h w -> b (h w) n_heads head_dim", head_dim=self.dim_head, n_heads=self.heads, ) + if self.plus: + return input_projected_x + else: + input_projected_fx = self.in_project_fx(x) + + # Next, re-reshape the projections into token-like shapes: + input_projected_fx = rearrange( + input_projected_fx, + "b (n_heads head_dim) h w -> b (h w) n_heads head_dim", + head_dim=self.dim_head, + n_heads=self.heads, + ) - # Return the projections: - return input_projected_x, input_projected_fx + # Return the projections: + return input_projected_x, input_projected_fx class PhysicsAttentionStructuredMesh3D(PhysicsAttentionBase): @@ -414,14 +503,15 @@ class PhysicsAttentionStructuredMesh3D(PhysicsAttentionBase): def __init__( self, - dim, - spatial_shape, - heads=8, - dim_head=64, - dropout=0.0, - slice_num=32, - kernel=3, - use_te=True, + dim: int, + spatial_shape: tuple[int, int, int], + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + slice_num: int = 32, + kernel: int = 3, + use_te: int = True, + plus: bool = False, ): super().__init__(dim, heads, dim_head, dropout, slice_num, use_te) @@ -431,9 +521,12 @@ def __init__( self.D = spatial_shape[2] self.in_project_x = nn.Conv3d(dim, inner_dim, kernel, 1, kernel // 2) - self.in_project_fx = nn.Conv3d(dim, inner_dim, kernel, 1, kernel // 2) + if not plus: + self.in_project_fx = nn.Conv3d(dim, inner_dim, kernel, 1, kernel // 2) - def project_input_onto_slices(self, x) -> tuple[torch.Tensor, torch.Tensor]: + def project_input_onto_slices( + self, x + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """ Project the input onto the slice space. @@ -448,21 +541,23 @@ def project_input_onto_slices(self, x) -> tuple[torch.Tensor, torch.Tensor]: x = x.permute(0, 4, 1, 2, 3) # Apply the projections, here they are convolutions: - input_projected_fx = self.in_project_fx(x) input_projected_x = self.in_project_x(x) # Next, re-reshape the projections into token-like shapes: - input_projected_fx = rearrange( - input_projected_fx, - "b (n_heads head_dim) h w d-> b (h w d) n_heads head_dim", - head_dim=self.dim_head, - n_heads=self.heads, - ) input_projected_x = rearrange( input_projected_x, "b (n_heads head_dim) h w d -> b (h w d) n_heads head_dim", head_dim=self.dim_head, n_heads=self.heads, ) - - return input_projected_x, input_projected_fx + if self.plus: + return input_projected_x + else: + input_projected_fx = self.in_project_fx(x) + input_projected_fx = rearrange( + input_projected_fx, + "b (n_heads head_dim) h w -> b (h w d) n_heads head_dim", + head_dim=self.dim_head, + n_heads=self.heads, + ) + return input_projected_x, input_projected_fx diff --git a/physicsnemo/models/transolver/transolver.py b/physicsnemo/models/transolver/transolver.py index d24a49ff4e..06f744b855 100644 --- a/physicsnemo/models/transolver/transolver.py +++ b/physicsnemo/models/transolver/transolver.py @@ -40,7 +40,7 @@ import transformer_engine.pytorch as te TE_AVAILABLE = True -except ImportError: +except (ImportError, FileNotFoundError): TE_AVAILABLE = False import physicsnemo # noqa: F401 for docs @@ -123,6 +123,7 @@ def __init__( slice_num=32, spatial_shape: tuple[int, ...] | None = None, use_te=True, + plus: bool = False, ): super().__init__() @@ -145,6 +146,7 @@ def __init__( dropout=dropout, slice_num=slice_num, use_te=use_te, + plus=plus, ) else: if len(spatial_shape) == 2: @@ -156,6 +158,7 @@ def __init__( dropout=dropout, slice_num=slice_num, use_te=use_te, + plus=plus, ) elif len(spatial_shape) == 3: self.Attn = PhysicsAttentionStructuredMesh3D( @@ -166,6 +169,7 @@ def __init__( dropout=dropout, slice_num=slice_num, use_te=use_te, + plus=plus, ) else: raise Exception( @@ -323,6 +327,7 @@ def __init__( structured_shape: None | tuple[int] = None, use_te: bool = True, time_input: bool = False, + plus: bool = False, ) -> None: super().__init__(meta=MetaData()) self.__name__ = "Transolver" @@ -405,6 +410,7 @@ def __init__( spatial_shape=structured_shape, last_layer=(_ == n_layers - 1), use_te=use_te, + plus=plus, ) for _ in range(n_layers) ]