diff --git a/examples/structural_mechanics/crash_domino/README.md b/examples/structural_mechanics/crash_domino/README.md new file mode 100644 index 0000000000..30be4e026f --- /dev/null +++ b/examples/structural_mechanics/crash_domino/README.md @@ -0,0 +1,448 @@ +# DoMINO: Decomposable Multi-scale Iterative Neural Operator for Crash Simulation + +DoMINO is a local, multi-scale, point-cloud based model architecture to model large-scale +physics problems such as structural mechanics crash simulations. The DoMINO model architecture takes STL +geometries as input and evaluates structural response quantities such as displacement fields +on the surface of structures over time. The DoMINO architecture is designed to be a fast, accurate +and scalable surrogate model for transient structural dynamics simulations. + +DoMINO uses local geometric information to predict solutions on discrete points. First, +a global geometry encoding is learnt from point clouds using a multi-scale, iterative +approach. The geometry representation takes into account both short- and long-range +dependencies that are typically encountered in structural dynamics problems. Additional information +such as signed distance field (SDF), positional encoding, and temporal information are used to enrich the global encoding. +Next, discrete points are randomly sampled, a sub-region is constructed around each point +and the local geometry encoding is extracted in this region from the global encoding. +The local geometry information is learnt using dynamic point convolution kernels. +Finally, a computational stencil is constructed dynamically around each discrete point +by sampling random neighboring points within the same sub-region. The local-geometry +encoding and the computational stencil are aggregated to predict the solutions on the +discrete points. + +A preprint describing additional details about the model architecture can be found here +[paper](https://arxiv.org/abs/2501.13350). + +## Recent Updates + +### Model Refactoring (Latest) + +The DoMINO model architecture has been significantly refactored to improve code quality: + +- **Modular Forward Pass**: The monolithic forward method has been decomposed into 12 focused helper methods, each with a single responsibility +- **Separation of Concerns**: Clear distinction between geometry encodings and positional encodings +- **Improved Documentation**: Comprehensive docstrings for all methods with detailed parameter descriptions +- **Enhanced Maintainability**: Easier to test, debug, and extend individual components +- **Backward Compatible**: No changes to the public API - existing training scripts work without modification + +See the [Model Architecture and Code Structure](#model-architecture-and-code-structure) section for details. + +## Prerequisites + +Install the required dependencies by running below: + +```bash +pip install -r requirements.txt +``` + +## Getting started with the Crash Simulation example + +### Configuration basics + +DoMINO training and testing is managed through YAML configuration files +powered by Hydra. The base configuration file, `config.yaml` is located in `src/conf` +directory. + +To select a specific configuration, use the `--config-name` option when running +the scripts. +You can modify configuration options in two ways: + +1. **Direct Editing:** Modify the YAML files directly +2. **Command Line Override:** Use Hydra's `++` syntax to override settings at runtime + +For example, to change the training epochs (controlled by `train.epochs`): + +```bash +python train.py ++train.epochs=200 # Sets number of epochs to 200 +``` + +This modular configuration system allows for flexible experimentation while +maintaining reproducibility. + +#### Project logs + +Save and track project logs, experiments, tensorboard files etc. by specifying a +project directory with `project.name`. Tag experiments with `expt`. + +### Data + +#### Dataset details + +In this example, the DoMINO model is trained using crash simulation datasets for +structural mechanics applications. The dataset contains transient structural dynamics +simulations of crash scenarios, including geometries and time-series displacement fields. +Each simulation includes: +- STL geometry files representing the initial structure +- Time-series displacement fields on the surface mesh +- Global parameters such as applied stress +- Temporal information capturing the evolution of structural deformation + +The data is processed to include multiple timesteps, capturing the transient behavior +of structures under impact or loading conditions. + +#### Data Preprocessing + +`PhysicsNeMo` has a related project to help with data processing, called +[PhysicsNeMo-Curator](https://github.com/NVIDIA/physicsnemo-curator). +Using `PhysicsNeMo-Curator`, the data needed to train a DoMINO model can be setup easily. +Please refer to +[these instructions on getting started](https://github.com/NVIDIA/physicsnemo-curator?tab=readme-ov-file#what-is-physicsnemo-curator) +with `PhysicsNeMo-Curator`. + +The first step for running the DoMINO pipeline requires processing the raw data +(VTP and STL) into either Zarr or NumPy format for training. +Each of the raw simulation files should be in `vtp` (for time-series surface data) and `stl` (for geometry) formats. +The data processing pipeline extracts displacement fields at multiple timesteps and prepares them for training. + +Caching is implemented in the DoMINO datapipe. +Optionally, users can run `cache_data.py` to save outputs +of DoMINO datapipe in the `.npy` files. The DoMINO datapipe is set up to calculate +Signed Distance Field and Nearest Neighbor interpolations on-the-fly during +training. Caching will save these as a preprocessing step and can be used in +cases where the **STL surface meshes or VTP time-series data are very large**. +Data processing is parallelized and takes a couple of hours to write all the +processed files. + +The final processed dataset should be divided and saved into 2 directories, +for training and validation. + +#### Data Scaling factors + +DoMINO has several data-specific configuration tools that rely on some +knowledge of the dataset: + +- The output fields (the labels) are normalized during training to a mean + of zero and a standard deviation of one, averaged over the dataset. + The scaling is controlled by passing the `surface_factors` values to the datapipe. +- The input locations are scaled by, and optionally cropped to, user defined + bounding boxes for the surface. Whether cropping occurs, or not, + is controlled by the `sample_in_bbox` value of the datapipe. Normalization + to the bounding box is enabled with `normalize_coordinates`. By default, + both are set to true. The value of the boxes are configured in the + `config.yaml` file. + +> Note: The datapipe module has a helper function `create_domino_dataset` +> with sensible defaults to help create a Domino Datapipe. + +To facilitate setting reasonable values of these, you can use the +`compute_statistics.py` script. This will load the core dataset as defined +in your `config.yaml` file, loop over several events (200, by default), and +both print and store the surface field statistics as well as the +coordinate statistics. + +#### Training + +Specify the training and validation data paths, bounding box sizes etc. in the +`data` tab and the training configs such as epochs, batch size etc. +in the `train` tab. + +#### Testing + +The testing is directly carried out on raw files. +Specify the testing configs in the `test` tab. + +### Training the DoMINO model + +To train and test the DoMINO model on crash simulation datasets, follow these steps: + +1. Specify the configuration settings in `conf/config.yaml`. + +2. Run `train.py` to start the training. Modify data, train and model keys in config file. + If using cached data then use `conf/cached.yaml` instead of `conf/config.yaml`. + +3. Run `test.py` to test on `.vtp` / `.vtu`. Predictions are written to the same file. + Modify eval key in config file to specify checkpoint, input and output directory. + Important to note that the data used for testing is in the raw simulation format and + should not be processed to `.npy`. + +4. Download the validation results (saved in form of point clouds in `.vtp` / `.vtu` format), + and visualize in Paraview. + +**Training Guidelines:** + +- Duration: Training time depends on dataset size and complexity +- Checkpointing: Automatically resumes from latest checkpoint if interrupted +- Multi-GPU Support: Compatible with `torchrun` or MPI for distributed training +- If the training crashes because of OOM, modify the points sampled on surface + `model.surface_points_sample` and time points `model.time_points_sample` + to manage memory requirements for your GPU +- The DoMINO model for crash simulation focuses on surface displacement fields + over time. The model can be configured for transient simulations with + `model.transient: true` and integration scheme with `model.transient_scheme` + (either "explicit" or "implicit"). +- MSE loss for the surface model gives the best results. +- Bounding box is configurable and will depend on the usecase. + +### Training with Domain Parallelism + +DoMINO has support for training and inference using domain parallelism in PhysicsNeMo, +via the `ShardTensor` mechanisms and pytorch's FSDP tools. `ShardTensor`, built on +PyTorch's `DTensor` object, is a domain-parallel-aware tensor that can live on multiple +GPUs and perform operations in a numerically consistent way. For more information +about the techniques of domain parallelism and `ShardTensor`, refer to PhysicsNeMo +tutorials such as [`ShardTensor`](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/api/physicsnemo.distributed.shardtensor.html). + +In DoMINO specifically, domain parallelism has been enabled in two ways, which +can be used concurrently or separately. First, the input sampled surface points +can be sharded to accommodate higher resolution point sampling across multiple timesteps. +Second, the latent space of the model - typically a regularized grid - can be +sharded to reduce computational complexity of the latent processing. When training +with sharded models in DoMINO, the primary objective is to enable higher +resolution inputs and larger latent spaces without sacrificing substantial compute time. + +When configuring DoMINO for sharded training, adjust the following parameters +from `src/conf/config.yaml`: + +```yaml +domain_parallelism: + domain_size: 2 + shard_grid: True + shard_points: True +``` + +The `domain_size` represents the number of GPUs used for each batch - setting +`domain_size: 1` is not advised since that is the standard training regime, +but with extra overhead. `shard_grid` and `shard_points` will enable domain +parallelism over the latent space and input/output points, respectively. + +As one last note regarding domain-parallel training: in the phase of the DoMINO +where the output solutions are calculated, the model can used two different +techniques (numerically identical) to calculate the output. Due to the +overhead of potential communication at each operation, it's recommended to +use the `one-loop` mode with `model.solution_calculation_mode` when doing +sharded training. This technique launches vectorized kernels with less +launch overhead at the cost of more memory use. For non-sharded +training, the `two-loop` setting is more optimal. The difference in `one-loop` +or `two-loop` is purely computational, not algorithmic. + +### Performance Optimizations + +The training and inference scripts for DoMINO contain several performance +enhancements to accelerate the training and usage of the model. In this +section we'll highlight several of them, as well as how to customize them +if needed. + +#### Memory Pool Optimizations + +The preprocessor of DoMINO requires a computation of k Nearest Neighbors, +which is accelerated via the `cuml` Neighbors tool. By default, `cuml` and +`torch` both use memory allocation pools to speed up allocating tensors, but +they do not use the same pool. This means that during preprocessing, it's +possible for the kNN operation to spend a significant amount of time in +memory allocations - and further, it limits the available memory to `torch`. + +To mitigate this, by default in DoMINO we use the Rapids Memory Manager +([`rmm`](https://github.com/rapidsai/rmm)). If, for some reason, you wish +to disable this you can do so with an environment variable: + +```bash +export PHYSICSNEMO_DISABLE_RMM=True +``` + +Or remove this line from the training script: + +```python +from physicsnemo.utils.memory import unified_gpu_memory +``` + +> Note - why not make it configurable? We have to set up the shared memory +> pool allocation very early in the program, before the config has even +> been read. So, we enable by default and the opt-out path is via the +> environment. + +#### Transient Data Handling + +The dataset size for transient crash simulation data can be substantial due to +multiple timesteps. Each simulation includes time-series displacement data across +all surface points. + +DoMINO's data pipeline handles transient data efficiently by: +- Sampling random time points during training via `model.time_points_sample` +- Sampling surface points at each timestep via `model.surface_points_sample` +- Supporting both explicit and implicit time integration schemes + +For large time-series datasets, preprocessing with `PhysicsNeMo-Curator` can help +organize the data efficiently. The data reader supports both Zarr and NumPy formats. + +#### Overall Performance + +DoMINO is a computationally complex and challenging workload. Over the course +of several releases, we have chipped away at performance bottlenecks to speed +up the training and inference time. We hope these optimizations enable you to explore more +parameters and surrogate models; if there is a performance issue you see, +please open an issue on GitHub. + +### Example Training Results + +To provide an example of what a successful training should look like, monitor +the following during training: + +- **Training Loss**: MSE loss on displacement predictions should decrease over epochs +- **Validation Loss**: Should track training loss without significant divergence +- **L2 Metrics**: Relative L2 error on displacement fields (X, Y, Z components) +- **Displacement Magnitude**: Error in total displacement magnitude across timesteps + +The test script will output detailed metrics including: +- L2 norm for each displacement component +- Mean squared error for displacement fields +- Maximum displacement error +- Time-series displacement accuracy across all timesteps + +Results can be visualized in Paraview using the generated VTP files with time-series data. + + +## Model Architecture and Code Structure + +The DoMINO model has been refactored with improved modularity and maintainability. +The forward pass is now organized into focused, well-documented helper methods +that separate different computational concerns: + +### Key Architectural Components + +1. **Feature Validation**: `_validate_and_extract_features()` + - Validates nodal surface, volume, and geometry features + - Ensures dimensional consistency + +2. **Geometry Encodings**: + - `_compute_volume_encodings()` - Computes volume geometry representations + - `_compute_surface_encodings()` - Computes surface geometry representations + - Handles grid normalization and spatial structure + +3. **Positional Encodings**: + - `_compute_volume_positional_encoding()` - SDF-based and positional features for volume + - `_compute_surface_positional_encoding()` - Positional features for surface + - Applies Fourier-based positional encoders + +4. **Local Geometry Processing**: + - `_compute_volume_local_encodings()` - Local volume geometry features + - `_compute_surface_local_encodings()` - Local surface geometry features + - Supports both transient and steady-state simulations + +5. **Solution Computation**: + - `_compute_volume_output_implicit/explicit()` - Volume solutions + - `_compute_surface_output_implicit/explicit()` - Surface solutions + - Handles both implicit and explicit time integration schemes + +### Benefits of Modular Architecture + +- **Improved Testability**: Each component can be tested independently +- **Better Maintainability**: Changes are localized to specific methods +- **Enhanced Readability**: Clear separation of concerns makes code easier to understand +- **Flexible Development**: Easy to modify geometry vs positional encoding independently + +For developers extending DoMINO, this modular structure makes it straightforward +to customize specific components (e.g., adding new encoding types) without +affecting the entire pipeline. + +### DoMINO model pipeline for inference on test samples + +After training is completed, `test.py` script can be used to run inference on +test samples. Follow the below steps to run the `test.py` + +1. Update the config in the `conf/config.yaml` under the `Testing data Configs` + tab. + +2. The test script is designed to run inference on the raw `.stl` and `.vtp` + files for each test sample. Use the same scaling parameters that + were generated during the training. Typically this is `outputs//`, + where `project.name` is as defined in the `config.yaml`. Update the + `eval.scaling_param_path` accordingly. + +3. Run the `test.py`. The test script can be run in parallel as well. Refer to + the training guidelines for Multi-GPU. Note, for running `test.py` in parallel, + the number of GPUs chosen must be <= the number of test samples. + +4. The output will include time-series VTP files showing predicted displacement fields + at each timestep, which can be loaded in Paraview for visualization. + +## Extending DoMINO to a custom dataset + +This repository includes examples of **DoMINO** training on crash simulation datasets. +However, many use cases require training **DoMINO** on a **custom dataset**. +The steps below outline the process. + +> **Note for Developers**: The refactored model architecture makes it easier to customize +> specific components. For example, if you need custom positional encodings, you can modify +> `_compute_volume_positional_encoding()` or `_compute_surface_positional_encoding()` without +> affecting geometry encoding logic. Similarly, custom geometry representations can be added +> by modifying the `_compute_*_encodings()` methods. All helper methods are well-documented +> with clear input/output specifications. + +1. Reorganize your dataset to have a consistent directory structure. The + raw data directory should contain a separate directory for each simulation. + Each simulation directory needs to contain mainly 2 files: `stl` and `vtp`, + corresponding to the geometry and time-series surface field information. + Additional details such as loading conditions, for example applied stress or impact velocity, + may be added in a separate metadata file, in case these vary from one case to the next. +2. Modify the following parameters in `conf/config.yaml` + - `project.name`: Specify a name for your project. + - `exp_tag`: This is the experiment tag. + - `data_processor.input_dir`: Input directory where the raw simulation dataset is stored. + - `data_processor.output_dir`: Output directory to save the processed dataset (`.npy`). + - `data_processor.num_processors`: Number of parallel processors for data processing. + - `variables.surface`: Variable names of surface fields and fields type (vector or scalar). + For crash simulations, typically this is `Displacement: vector`. + - `variables.global_parameters`: Global parameters like stress, material properties, etc. + - `data.input_dir`: Processed files used for training. + - `data.input_dir_val`: Processed files used for validation. + - `data.bounding_box_surface`: Dimensions of bounding box enclosing the biggest geometry + in dataset. Surface fields are modeled inside this bounding box. + - `train.epochs`: Set the number of training epochs. + - `model.surface_points_sample`: Number of points to sample on the surface mesh per epoch + per batch. Tune based on GPU memory. + - `model.time_points_sample`: Number of time steps to sample per epoch per batch. + - `model.geom_points_sample`: Number of points to sample on STL mesh per epoch per batch. + Ensure point sampled is less than number of points on STL (for coarser STLs). + - `model.transient`: Set to `true` for transient crash simulations. + - `model.transient_scheme`: Choose `explicit` or `implicit` time integration. + - `model.integration_steps`: Number of time steps in the simulation. + - `eval.test_path`: Path of directory of raw simulation files for testing and verification. + - `eval.save_path`: Path of directory where the AI predicted simulation files are saved. + - `eval.checkpoint_name`: Checkpoint name `outputs/{project.name}/models` to evaluate model. + - `eval.scaling_param_path`: Scaling parameters populated in `outputs/{project.name}`. +3. Before running `process_data.py` to process the data, modify it to match your + dataset structure. Key modifications include: + - Non-dimensionalization schemes based on the order of your variables + - Path definitions for STL geometry files and VTP time-series data + - Extraction of displacement fields at each timestep + - Handling of global parameters (stress, loading conditions, etc.) + + For example, you may need to define custom path functions: + + ```python + class CrashSimPaths: + # Specify the name of the STL in your dataset + @staticmethod + def geometry_path(sim_dir: Path) -> Path: + return sim_dir / "geometry.stl" + + # Specify the name of the VTP with time-series data + @staticmethod + def surface_path(sim_dir: Path) -> Path: + return sim_dir / "displacement_timeseries.vtp" + ``` + +4. Before running `train.py`, modify the loss functions in `loss.py` if needed. + The default configuration uses MSE loss with optional area weighting. + For crash simulations with displacement fields, the current loss formulation + works well, but you may want to customize it based on your specific requirements + (e.g., emphasizing certain displacement components or adding physics-based constraints). + +5. Run `test.py` to validate the trained model on test simulations. + +The DoMINO model architecture for crash simulations demonstrates the versatility of +the framework for handling transient structural dynamics problems with complex geometries +and time-varying displacement fields. + +## References + +1. [DoMINO: A Decomposable Multi-scale Iterative Neural Operator for Modeling Large Scale Engineering Simulations](https://arxiv.org/abs/2501.13350) diff --git a/examples/structural_mechanics/crash_domino/requirements.txt b/examples/structural_mechanics/crash_domino/requirements.txt new file mode 100644 index 0000000000..f7204a27ec --- /dev/null +++ b/examples/structural_mechanics/crash_domino/requirements.txt @@ -0,0 +1,6 @@ +torchinfo +warp-lang +tensorboard +cuml +einops +tensorstore diff --git a/examples/structural_mechanics/crash_domino/src/compute_statistics.py b/examples/structural_mechanics/crash_domino/src/compute_statistics.py new file mode 100644 index 0000000000..5a8dfbf052 --- /dev/null +++ b/examples/structural_mechanics/crash_domino/src/compute_statistics.py @@ -0,0 +1,168 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Compute and save scaling factors for DoMINO datasets. + +This script computes mean, standard deviation, minimum, and maximum values +for all field variables in a DoMINO dataset. The computed statistics are +saved in a structured format that can be easily loaded and used for +normalization during training and inference. + +The script uses the same configuration system as the training script, +ensuring consistency in dataset handling and processing parameters. +""" + +import os +import time +from pathlib import Path + +import hydra +import torch +from omegaconf import DictConfig, OmegaConf + +from physicsnemo.distributed import DistributedManager +from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper + +from physicsnemo.datapipes.cae.domino_datapipe_transient import compute_scaling_factors +from utils import ScalingFactors + + +@hydra.main(version_base="1.3", config_path="conf", config_name="config") +def main(cfg: DictConfig) -> None: + """ + Main function to compute and save scaling factors. + + Args: + cfg: Hydra configuration object containing all parameters + """ + # ################################ + # # Force single-process mode for statistics computation + # # This script doesn't benefit from distributed execution + # ################################ + # for var in ['RANK', 'WORLD_RANK', 'WORLD_SIZE', 'LOCAL_RANK']: + # os.environ.pop(var, None) + + ################################ + # Initialize distributed manager + ################################ + DistributedManager.initialize() + dist = DistributedManager() + + ################################ + # Initialize logger + ################################ + logger = PythonLogger("ComputeStatistics") + logger = RankZeroLoggingWrapper(logger, dist) + + logger.info("Starting scaling factors computation") + logger.info(f"Config summary:\n{OmegaConf.to_yaml(cfg, sort_keys=True)}") + + ################################ + # Create output directory + ################################ + output_dir = os.path.dirname(cfg.data.scaling_factors) + os.makedirs(output_dir, exist_ok=True) + + if dist.world_size > 1: + torch.distributed.barrier() + + ################################ + # Check if scaling exists + ################################ + pickle_path = output_dir + "/scaling_factors.pkl" + + try: + scaling_factors = ScalingFactors.load(pickle_path) + logger.info(f"Scaling factors loaded from: {pickle_path}") + except FileNotFoundError: + logger.info(f"Scaling factors not found at: {pickle_path}; recomputing.") + scaling_factors = None + + ################################ + # Compute scaling factors + ################################ + if scaling_factors is None: + logger.info("Computing scaling factors from dataset...") + start_time = time.perf_counter() + + target_keys = [ + "surface_fields", + "surface_mesh_centers", + ] + + mean, std, min_val, max_val = compute_scaling_factors( + cfg=cfg, + input_path=cfg.data.input_dir, + target_keys=target_keys, + max_samples=cfg.data.max_samples_for_statistics, + ) + mean = {k: m.cpu().numpy() for k, m in mean.items()} + std = {k: s.cpu().numpy() for k, s in std.items()} + min_val = {k: m.cpu().numpy() for k, m in min_val.items()} + max_val = {k: m.cpu().numpy() for k, m in max_val.items()} + + compute_time = time.perf_counter() - start_time + logger.info( + f"Scaling factors computation completed in {compute_time:.2f} seconds" + ) + + ################################ + # Create structured data object + ################################ + dataset_info = { + "input_path": cfg.data.input_dir, + "model_type": cfg.model.model_type, + "normalization": cfg.model.normalization, + "compute_time": compute_time, + "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), + "config_name": cfg.project.name, + } + + scaling_factors = ScalingFactors( + mean=mean, + std=std, + min_val=min_val, + max_val=max_val, + field_keys=target_keys, + ) + + ################################ + # Save scaling factors + ################################ + if dist.rank == 0: + # Save as structured pickle file + pickle_path = output_dir + "/scaling_factors.pkl" + scaling_factors.save(pickle_path) + logger.info(f"Scaling factors saved to: {pickle_path}") + + # Save summary report + summary_path = output_dir + "/scaling_factors_summary.txt" + with open(summary_path, "w") as f: + f.write(scaling_factors.summary()) + logger.info(f"Summary report saved to: {summary_path}") + + ################################ + # Display summary + ################################ + logger.info("Scaling factors computation summary:") + logger.info(f"Field keys processed: {scaling_factors.field_keys}") + + logger.info("Scaling factors computation completed successfully!") + + +if __name__ == "__main__": + main() diff --git a/examples/structural_mechanics/crash_domino/src/conf/cached.yaml b/examples/structural_mechanics/crash_domino/src/conf/cached.yaml new file mode 100644 index 0000000000..f2a316df51 --- /dev/null +++ b/examples/structural_mechanics/crash_domino/src/conf/cached.yaml @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +defaults: + - config + - _self_ + +exp_tag: cached + +data: # Input directory for training and validation data + input_dir: /lustre/cached/drivaer_aws/drivaer_data_full/train/ + input_dir_val: /lustre/cached/drivaer_aws/drivaer_data_full/val/ +data_processor: + use_cache: true + +train: # Training configurable parameters + dataloader: + num_workers: 12 + +val: # Validation configurable parameters + dataloader: + num_workers: 6 \ No newline at end of file diff --git a/examples/structural_mechanics/crash_domino/src/conf/config.yaml b/examples/structural_mechanics/crash_domino/src/conf/config.yaml new file mode 100644 index 0000000000..afac7471f7 --- /dev/null +++ b/examples/structural_mechanics/crash_domino/src/conf/config.yaml @@ -0,0 +1,231 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ┌───────────────────────────────────────────┐ +# │ Project Details │ +# └───────────────────────────────────────────┘ +project: # Project name + name: Crash_Dataset_Displacement + +exp_tag: 1 # Experiment tag +# Main output directory. +project_dir: outputs/${project.name}/ +output: outputs/${project.name}/${exp_tag} + +hydra: # Hydra config + run: + dir: ${output} + output_subdir: hydra # Default is .hydra which causes files not being uploaded in W&B. + +# The directory to search for checkpoints to continue training. +resume_dir: ${output}/models + +# ┌───────────────────────────────────────────┐ +# │ Data Preprocessing │ +# └───────────────────────────────────────────┘ +data_processor: # Data processor configurable parameters + output_dir: /user/crash_data_all/ + input_dir: /user/data/crash_data_all/ + cached_dir: /user/cached/crash_data_all/ + use_cache: false + num_processors: 12 + +# ┌───────────────────────────────────────────┐ +# │ Solution variables │ +# └───────────────────────────────────────────┘ +variables: + surface: + solution: + # The following is for AWS DrivAer dataset. + Displacement: vector + volume: + solution: + # The following is for AWS DrivAer dataset. + Stress: vector + global_parameters: + stress: + type: scalar + reference: [1.0] + +# ┌───────────────────────────────────────────┐ +# │ Data Configs │ +# └───────────────────────────────────────────┘ +data: # Input directory for training and validation data + input_dir: /user/data/crash_data_all/ + input_dir_val: /user/data/crash_data_all_val/ + bounding_box: # Bounding box dimensions for computational domain + min: [560, -840, 650] + max: [3350 , 850, 1320] + bounding_box_surface: # Bounding box dimensions for car surface + min: [560, -840, 650] + max: [3350, 850, 1320] + gpu_preprocessing: true + gpu_output: true + normalize_coordinates: true + sample_in_bbox: false + sampling: true + scaling_factors: ${project_dir}/scaling_factors/scaling_factors.pkl + volume_sample_from_disk: false + max_samples_for_statistics: 200 + +# ┌───────────────────────────────────────────┐ +# │ Domain Parallelism Settings │ +# └───────────────────────────────────────────┘ +domain_parallelism: + domain_size: 1 + shard_grid: false + shard_points: false + +# ┌───────────────────────────────────────────┐ +# │ Model Parameters │ +# └───────────────────────────────────────────┘ +model: + model_type: surface # train which model? surface, volume, combined + transient: true # Whether to use transient model + transient_scheme: "explicit" # "explicit" or "implicit" + integration_steps: 10 # Number of integration steps for transient model + activation: "relu" # "relu" or "gelu" + loss_function: + loss_type: "mse" # mse or rmse + area_weighing_factor: 0.004 # Generally inverse of maximum area + interp_res: [128, 32, 32] # resolution of latent space 128, 64, 48 + use_sdf_in_basis_func: false # SDF in basis function network + volume_points_sample: 8192 # Number of points to sample in volume per epoch + surface_points_sample: 2000 # Number of points to sample on surface per epoch + time_points_sample: 10 # Number of time points to sample per epoch + surface_sampling_algorithm: random #random or area_weighted + mesh_type: "node" # element or node + geom_points_sample: 80_000 # Number of points to sample on STL per epoch + num_neighbors_surface: 7 # How many neighbors on surface? + num_neighbors_volume: 10 # How many neighbors on volume? + combine_volume_surface: false # combine volume and surface encodings + return_volume_neighbors: false # Whether to return volume neighbors or not + use_surface_normals: false # Use surface normals and surface areas for surface computation? + use_surface_area: false # Use only surface normals and not surface area + integral_loss_scaling_factor: 100 # Scale integral loss by this factor + normalization: min_max_scaling # or mean_std_scaling + encode_parameters: false # encode inlet velocity and air density in the model + surf_loss_scaling: 1.0 # scale surface loss with this factor in combined mode + vol_loss_scaling: 1.0 # scale volume loss with this factor in combined mode + geometry_encoding_type: stl # geometry encoder type, sdf, stl, both + solution_calculation_mode: two-loop # one-loop is better for sharded, two-loop is lower memory but more overhead. Physics losses are not supported via one-loop presently. + geometry_rep: # Hyperparameters for geometry representation network + geo_conv: + base_neurons: 32 # 256 or 64 + base_neurons_in: 1 + base_neurons_out: 1 + volume_radii: [0.01, 0.05, 0.1, 0.5, 1.0] # radii for volume + surface_radii: [0.05, 0.1, 0.5, 1.0] # radii for surface + surface_hops: 1 # Number of surface iterations + volume_hops: 1 # Number of volume iterations + volume_neighbors_in_radius: [8, 64, 128, 256, 512] # Number of neighbors in radius for volume + surface_neighbors_in_radius: [8, 16, 64, 128] # Number of neighbors in radius for surface + fourier_features: false + num_modes: 5 + activation: ${model.activation} + geo_processor: + base_filters: 8 + activation: ${model.activation} + processor_type: conv # conv or unet (conv is better; fno, fignet to be added) + self_attention: false # can be used only with unet + cross_attention: false # can be used only with unet + surface_sdf_scaling_factor: [0.01, 0.02, 0.04] # Scaling factor for SDF, smaller is more emphasis on surface + volume_sdf_scaling_factor: [0.04] # Scaling factor for SDF, smaller is more emphasis on surface + nn_basis_functions: # Hyperparameters for basis function network + base_layer: 512 + fourier_features: true + num_modes: 5 + activation: ${model.activation} + local_point_conv: + activation: ${model.activation} + aggregation_model: # Hyperparameters for aggregation network + base_layer: 512 + activation: ${model.activation} + position_encoder: # Hyperparameters for position encoding network + base_neurons: 512 + activation: ${model.activation} + fourier_features: true + num_modes: 5 + geometry_local: # Hyperparameters for local geometry extraction + volume_neighbors_in_radius: [64, 128] # Number of radius points + surface_neighbors_in_radius: [32, 64, 128, 256] # Number of radius points + volume_radii: [0.1, 0.25] # Volume radii + surface_radii: [0.05, 0.1, 0.5, 1.0] # Surface radii + base_layer: 512 + parameter_model: + base_layer: 512 + fourier_features: false + num_modes: 5 + activation: ${model.activation} + +# ┌───────────────────────────────────────────┐ +# │ Training Configs │ +# └───────────────────────────────────────────┘ +train: # Training configurable parameters + epochs: 1000 + checkpoint_interval: 2 + dataloader: + batch_size: 1 + preload_depth: 1 + pin_memory: True # if the preprocessing is outputing GPU data, set this to false + sampler: + shuffle: true + drop_last: false + checkpoint_dir: /user/models/ # Use only for retraining + add_physics_loss: false + lr_scheduler: + name: MultiStepLR # Also supports CosineAnnealingLR + milestones: [50, 200, 400, 500, 600, 700, 800, 900] # only used if lr_scheduler is MultiStepLR + gamma: 0.5 # only used if lr_scheduler is MultiStepLR + T_max: ${train.epochs} # only used if lr_scheduler is CosineAnnealingLR + eta_min: 1e-6 # only used if lr_scheduler is CosineAnnealingLR + optimizer: + name: Adam # or AdamW + lr: 0.001 + weight_decay: 0.0 + amp: + enabled: true + autocast: + dtype: torch.float16 + scaler: + _target_: torch.cuda.amp.GradScaler + enabled: ${..enabled} + clip_grad: true + grad_max_norm: 2.0 + + +# ┌───────────────────────────────────────────┐ +# │ Validation Configs │ +# └───────────────────────────────────────────┘ +val: # Validation configurable parameters + dataloader: + batch_size: 1 + preload_depth: 1 + pin_memory: true # if the preprocessing is outputing GPU data, set this to false + sampler: + shuffle: true + drop_last: false + +# ┌───────────────────────────────────────────┐ +# │ Testing data Configs │ +# └───────────────────────────────────────────┘ +eval: # Testing configurable parameters + test_path: /user/testing_data # Dir for testing data in raw format (vtp, vtu ,stls) + save_path: /user/predicted_data # Dir to save predicted results in raw format (vtp, vtu) + checkpoint_name: DoMINO.0.455.pt # Name of checkpoint to select from saved checkpoints + scaling_param_path: /user/scaling_params + refine_stl: False # Automatically refine STL during inference + num_points: 1_240_000 # Number of points to sample on surface and volume per batch diff --git a/examples/structural_mechanics/crash_domino/src/crash_datapipe.py b/examples/structural_mechanics/crash_domino/src/crash_datapipe.py new file mode 100644 index 0000000000..47411fad26 --- /dev/null +++ b/examples/structural_mechanics/crash_domino/src/crash_datapipe.py @@ -0,0 +1,198 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This is the datapipe to read OpenFoam files (vtp/vtu/stl) and save them as point clouds +in npy format. + +""" + +import time, random +from collections import defaultdict +from pathlib import Path +from typing import Any, Iterable, List, Literal, Mapping, Optional, Union, Callable, Dict + +import numpy as np +import pandas as pd +import pyvista as pv +import vtk +from physicsnemo.utils.domino.utils import * +from torch.utils.data import Dataset +from utils import extract_index_from_filename, extract_time_series_info, get_time_series_data + +class CrashDataset(Dataset): + """ + Datapipe for converting openfoam dataset to npy + + """ + + def __init__( + self, + input_dir: Union[str, Path], + surface_variables: Optional[list] = [ + "pMean", + "wallShearStress", + ], + volume_variables: Optional[list] = ["UMean", "pMean"], + global_params_types: Optional[dict] = { + "stress": "vector", + }, + global_params_reference: Optional[dict] = { + "stress": [1.0], + }, + device: int = 0, + model_type=None, + transient_scheme="explicit", + ): + if isinstance(input_dir, str): + input_dir = Path(input_dir) + input_dir = input_dir.expanduser() + + self.data_path = input_dir + + assert self.data_path.exists(), f"Path {self.data_path} does not exist" + + assert self.data_path.is_dir(), f"Path {self.data_path} is not a directory" + + self.filenames = get_filenames(self.data_path) + random.shuffle(self.filenames) + self.indices = np.array(len(self.filenames)) + + self.surface_variables = surface_variables + self.volume_variables = volume_variables + + self.global_params_types = global_params_types + self.global_params_reference = global_params_reference + + self.stress = self.global_params_reference["stress"] + + self.device = device + self.model_type = model_type + self.transient_scheme = transient_scheme + + def __len__(self): + return len(self.filenames) + + def __getitem__(self, idx): + cfd_filename = self.filenames[idx] + file_index = extract_index_from_filename(cfd_filename) + + displacement_dir = self.data_path / f"run{file_index}_displacement.vtp" + + mesh_displacement = pv.read(displacement_dir) + + stl_vertices = mesh_displacement.points + + mesh_indices_flattened = np.array(mesh_displacement.faces).reshape((-1, 4))[:, 1:].flatten() # Assuming triangular elements + length_scale = np.amax(np.amax(stl_vertices, 0) - np.amin(stl_vertices, 0)) + + stl_sizes = mesh_displacement.compute_cell_sizes(length=False, area=True, volume=False) + stl_sizes = np.array(stl_sizes.cell_data["Area"]) + stl_centers = np.array(mesh_displacement.cell_centers().points) + + cell_data = mesh_displacement.point_data_to_cell_data() + surface_coordinates_centers = cell_data.cell_centers().points + surface_normals = np.array(cell_data.cell_normals) + surface_sizes = cell_data.compute_cell_sizes( + length=False, area=True, volume=False + ) + surface_sizes = np.array(surface_sizes.cell_data["Area"]) + timesteps, displacement_data, magnitude_data = get_time_series_data(mesh_displacement, data_prefix="displacement") + surface_fields = displacement_data #Displacements are from the starting position, not the previous timestep + surface_coordinates = mesh_displacement.points + + surface_coordinates_all = [] + surface_normals_all = [] + surface_sizes_all = [] + for i in range(surface_fields.shape[0]): + surface_coordinates_all.append(surface_coordinates + surface_fields[i]) + surface_normals_all.append(surface_normals) + surface_sizes_all.append(surface_sizes) + surface_coordinates_all = np.asarray(surface_coordinates_all) + surface_normals_all = np.asarray(surface_normals_all) + surface_sizes_all = np.asarray(surface_sizes_all) + + surface_coordinates = np.concatenate([np.expand_dims(surface_coordinates, 0), surface_coordinates_all], axis=0) + surface_normals = np.concatenate([np.expand_dims(surface_normals, 0), surface_normals_all], axis=0) + surface_sizes = np.concatenate([np.expand_dims(surface_sizes, 0), surface_sizes_all], axis=0) + + # For implicit scheme, we need to add the displacements from the previous timestep to the current position + if self.transient_scheme == "implicit": + surface_fields_new = [] + for i in range(surface_coordinates.shape[0]-1): + surface_fields_new.append(surface_coordinates[i+1] - surface_coordinates[i]) + surface_fields = np.asarray(surface_fields_new) + + surface_coordinates = surface_coordinates[:-1] + surface_normals = surface_normals[:-1] + surface_sizes = surface_sizes[:-1] + + # Arrange global parameters reference in a list based on the type of the parameter + global_params_reference_list = [] + for name, type in self.global_params_types.items(): + if type == "vector": + global_params_reference_list.extend(self.global_params_reference[name]) + elif type == "scalar": + global_params_reference_list.append(self.global_params_reference[name]) + else: + raise ValueError( + f"Global parameter {name} not supported for this dataset" + ) + global_params_reference = np.array( + global_params_reference_list, dtype=np.float32 + ) + + # Prepare the list of global parameter values for each simulation file + # Note: The user must ensure that the values provided here correspond to the + # `global_parameters` specified in `config.yaml` and that these parameters + # exist within each simulation file. + global_params_values_list = [] + for key in self.global_params_types.keys(): + if key == "stress": + global_params_values_list.extend( + self.global_params_reference["stress"] + ) + else: + raise ValueError( + f"Global parameter {key} not supported for this dataset" + ) + global_params_values = np.array(global_params_values_list, dtype=np.float32) + + # Add the parameters to the dictionary + return { + "stl_coordinates": np.float32(surface_coordinates[0]), # np.float32(surface_coordinates[0]) + # "stl_centers": np.float32(stl_centers), # surface centers + # "stl_faces": np.float32(mesh_indices_flattened), # + # "stl_areas": np.float32(stl_sizes), + "surface_mesh_centers": np.float32(surface_coordinates), # (t, N 3) coordinates, x, y, z + # "surface_normals": np.float32(surface_normals), + # "surface_areas": np.float32(surface_sizes), + "surface_fields": np.float32(surface_fields), # (t, N, 3) acceleration + "surface_features": np.float32(surface_coordinates), # (N, 1) thickness, This can be thickness and material properties on nodes + "stl_features": np.float32(surface_coordinates[0]), # This can be thickness and material properties on nodes # rename tostl_features + "timesteps": np.float32(timesteps), # t = [0, 1, 2, 3, 4, 5, 6] + "filename": cfd_filename, + } + + +if __name__ == "__main__": + fm_data = CrashDataset( + input_dir="/user/data/", + surface_variables=["pMean", "wallShearStress"], + global_params_types={"stress": "vector"}, + global_params_reference={"stress": [1.0]}, + ) + d_dict = fm_data[1] \ No newline at end of file diff --git a/examples/structural_mechanics/crash_domino/src/loss.py b/examples/structural_mechanics/crash_domino/src/loss.py new file mode 100644 index 0000000000..249cb1e562 --- /dev/null +++ b/examples/structural_mechanics/crash_domino/src/loss.py @@ -0,0 +1,118 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from typing import Literal, Any + +from physicsnemo.utils.domino.utils import unnormalize + +from typing import Literal, Any + +import torch.cuda.nvtx as nvtx + +from physicsnemo.utils.domino.utils import * + + +def loss_fn( + output: torch.Tensor, + target: torch.Tensor, + loss_type: Literal["mse", "rmse"], + padded_value: float = -10, +) -> torch.Tensor: + """Calculate mean squared error or root mean squared error with masking for padded values. + + Args: + output: Predicted values from the model + target: Ground truth values + loss_type: Type of loss to calculate ("mse" or "rmse") + padded_value: Value used for padding in the tensor + + Returns: + Calculated loss as a scalar tensor + """ + mask = abs(target - padded_value) > 1e-3 + + dims = (0, 1, 2) + + num = torch.sum(mask * (output - target) ** 2.0, dims) + if loss_type == "rmse": + denom = torch.sum(mask * (target - torch.mean(target, dims)) ** 2.0, dims) + loss = torch.mean(num / denom) + elif loss_type == "mse": + denom = torch.sum(mask, dims) + loss = torch.mean(num / denom) + else: + raise ValueError(f"Invalid loss type: {loss_type}") + + return loss + + + +def compute_loss_dict( + prediction_vol: torch.Tensor, + prediction_surf: torch.Tensor, + batch_inputs: dict, + loss_fn_type: dict, + surf_loss_scaling: float, +) -> tuple[torch.Tensor, dict]: + """ + Compute the loss terms in a single function call. + + Computes: + - Volume loss if prediction_vol is not None + - Surface loss if prediction_surf is not None + - Total loss as a weighted sum of the above + + Returns: + - Total loss as a scalar tensor + - Dictionary of loss terms (for logging, etc) + """ + nvtx.range_push("Loss Calculation") + total_loss_terms = [] + loss_dict = {} + + if prediction_vol is not None: + target_vol = batch_inputs["volume_fields"] + + loss_vol = loss_fn( + prediction_vol, + target_vol, + loss_fn_type.loss_type, + padded_value=-10, + ) + loss_dict["loss_vol"] = loss_vol + total_loss_terms.append(loss_vol) + + if prediction_surf is not None: + target_surf = batch_inputs["surface_fields"] + + loss_surf = loss_fn( + prediction_surf, + target_surf, + loss_fn_type.loss_type, + ) + + if loss_fn_type.loss_type == "mse": + loss_surf = loss_surf * surf_loss_scaling + + total_loss_terms.append(loss_surf) + loss_dict["loss_surf"] = loss_surf + + total_loss = sum(total_loss_terms) + loss_dict["total_loss"] = total_loss + nvtx.range_pop() + + return total_loss, loss_dict diff --git a/examples/structural_mechanics/crash_domino/src/process_data.py b/examples/structural_mechanics/crash_domino/src/process_data.py new file mode 100644 index 0000000000..17e7da0a7a --- /dev/null +++ b/examples/structural_mechanics/crash_domino/src/process_data.py @@ -0,0 +1,117 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code runs the data processing in parallel to load OpenFoam files, process them +and save in the npy format for faster processing in the DoMINO datapipes. Several +parameters such as number of processors, input and output paths, etc. can be +configured in config.yaml in the data_processing tab. +""" + +from crash_datapipe import CrashDataset +from physicsnemo.utils.domino.utils import * +import multiprocessing +import hydra, time, os +from hydra.utils import to_absolute_path +from omegaconf import DictConfig, OmegaConf +import numpy as np + + +def process_files(*args_list): + ids = args_list[0] + processor_id = args_list[1] + fm_data = args_list[2] + output_dir = args_list[3] + for j in ids: + fname = fm_data.filenames[j] + outname = os.path.join(output_dir, fname) + print("Filename:%s on processor: %d" % (outname, processor_id)) + filename = f"{outname}.npy" + if os.path.exists(filename): + print(f"Skipping {filename} - already exists.") + continue + start_time = time.time() + data_dict = fm_data[j] + np.save(filename, data_dict) + print("Time taken for %d = %f" % (j, time.time() - start_time)) + + +@hydra.main(version_base="1.3", config_path="conf", config_name="config") +def main(cfg: DictConfig): + print(f"Config summary:\n{OmegaConf.to_yaml(cfg, sort_keys=True)}") + phase = "train" + volume_variable_names = list(cfg.variables.volume.solution.keys()) + num_vol_vars = 0 + for j in volume_variable_names: + if cfg.variables.volume.solution[j] == "vector": + num_vol_vars += 3 + else: + num_vol_vars += 1 + + surface_variable_names = list(cfg.variables.surface.solution.keys()) + num_surf_vars = 0 + for j in surface_variable_names: + if cfg.variables.surface.solution[j] == "vector": + num_surf_vars += 3 + else: + num_surf_vars += 1 + + # Extract global parameters names and reference values + global_params_names = list(cfg.variables.global_parameters.keys()) + global_params_reference = { + name: cfg.variables.global_parameters[name]["reference"] + for name in global_params_names + } + global_params_types = { + name: cfg.variables.global_parameters[name]["type"] + for name in global_params_names + } + + fm_data = CrashDataset( + input_dir=cfg.data_processor.input_dir, + volume_variables=volume_variable_names, + surface_variables=surface_variable_names, + global_params_types=global_params_types, + global_params_reference=global_params_reference, + model_type=cfg.model.model_type, + transient_scheme=cfg.model.transient_scheme + ) + output_dir = cfg.data_processor.output_dir + create_directory(output_dir) + n_processors = cfg.data_processor.num_processors + + num_files = len(fm_data) + ids = np.arange(num_files) + num_elements = int(num_files / n_processors) + 1 + process_list = [] + ctx = multiprocessing.get_context("spawn") + for i in range(n_processors): + if i != n_processors - 1: + sf = ids[i * num_elements : i * num_elements + num_elements] + else: + sf = ids[i * num_elements :] + # print(sf) + process = ctx.Process(target=process_files, args=(sf, i, fm_data, output_dir)) + + process.start() + process_list.append(process) + + for process in process_list: + process.join() + + +if __name__ == "__main__": + main() diff --git a/examples/structural_mechanics/crash_domino/src/test.py b/examples/structural_mechanics/crash_domino/src/test.py new file mode 100644 index 0000000000..41d5d9a088 --- /dev/null +++ b/examples/structural_mechanics/crash_domino/src/test.py @@ -0,0 +1,733 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code defines a distributed pipeline for testing the DoMINO model on +Crash datasets. It includes the instantiating the DoMINO model and datapipe, +automatically loading the most recent checkpoint, reading the VTP/VTU/STL +testing files, calculation of parameters required for DoMINO model and +evaluating the model in parallel using DistributedDataParallel across multiple +GPUs. This is a common recipe that enables training of surface model. +The model predictions are loaded in +the the VTP/VTU/STL files and saved in the specified directory. The eval tab in +config.yaml can be used to specify the input and output directories. +""" + +import os, re +import time + +import hydra +from hydra.utils import to_absolute_path +from omegaconf import DictConfig, OmegaConf + +# This will set up the cupy-ecosystem and pytorch to share memory pools +from physicsnemo.utils.memory import unified_gpu_memory + +import numpy as np +import cupy as cp + +from collections import defaultdict +from pathlib import Path +from typing import Any, Iterable, List, Literal, Mapping, Optional, Union, Callable + +import pandas as pd +import pyvista as pv + +import torch +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data import DataLoader, Dataset + +import vtk +from vtk.util import numpy_support + +from physicsnemo.distributed import DistributedManager +from physicsnemo.datapipes.cae.domino_datapipe_transient import DoMINODataPipe +from physicsnemo.models.domino_transient.model import DoMINO +from physicsnemo.models.domino_transient.geometry_rep import scale_sdf +from physicsnemo.utils.domino.utils import * +from physicsnemo.utils.domino.vtk_file_utils import * +from physicsnemo.utils.sdf import signed_distance_field +from physicsnemo.utils.neighbors import knn +from utils import ScalingFactors, load_scaling_factors +from utils import get_time_series_data + +def loss_fn(output, target): + masked_loss = torch.mean(((output - target) ** 2.0), (0, 1, 2)) + loss = torch.mean(masked_loss) + return loss + + +def test_step(data_dict, model, device, cfg, surf_factors): + + output_features_surf = True + + with torch.no_grad(): + point_batch_size = 256000 + # data_dict = dict_to_device(data_dict, device) + + # Non-dimensionalization factors + length_scale = data_dict["length_scale"] + + global_params_values = data_dict["global_params_values"] + global_params_reference = data_dict["global_params_reference"] + stress = global_params_reference[:, 0, :] + + # STL nodes + geo_centers = data_dict["geometry_coordinates"] + + # Bounding box grid + s_grid = data_dict["surf_grid"] + sdf_surf_grid = data_dict["sdf_surf_grid"] + # Scaling factors + surf_max = data_dict["surface_min_max"][:, 1] + surf_min = data_dict["surface_min_max"][:, 0] + + if output_features_surf is not None: + # Represent geometry on bounding box + geo_centers_surf = ( + 2.0 * (geo_centers - surf_min) / (surf_max - surf_min) - 1 + ) + encoding_g_surf = model.geo_rep_surface( + geo_centers_surf, s_grid, sdf_surf_grid + ) + + if output_features_surf is not None: + # Next calculate surface predictions + # Sampled points on surface + surface_mesh_centers = data_dict["surface_mesh_centers"] + surface_normals = data_dict["surface_normals"] + surface_areas = data_dict["surface_areas"] + + # Neighbors of sampled points on surface + surface_mesh_neighbors = data_dict["surface_mesh_neighbors"] + surface_neighbors_normals = data_dict["surface_neighbors_normals"] + surface_neighbors_areas = data_dict["surface_neighbors_areas"] + surface_areas = torch.unsqueeze(surface_areas, -1) + surface_neighbors_areas = torch.unsqueeze(surface_neighbors_areas, -1) + pos_surface_center_of_mass = data_dict["pos_surface_center_of_mass"] + num_points = surface_mesh_centers.shape[2] + subdomain_points = int(np.floor(num_points / point_batch_size)) + + target_surf = data_dict["surface_fields"] + prediction_surf = torch.zeros_like(target_surf) + + start_time = time.time() + + for p in range(subdomain_points + 1): + start_idx = p * point_batch_size + end_idx = (p + 1) * point_batch_size + with torch.no_grad(): + target_batch = target_surf[:, start_idx:end_idx] + surface_mesh_centers_batch = surface_mesh_centers[ + :, :, start_idx:end_idx + ] + surface_mesh_neighbors_batch = surface_mesh_neighbors[ + :, :, start_idx:end_idx + ] + surface_normals_batch = surface_normals[:, :, start_idx:end_idx] + surface_neighbors_normals_batch = surface_neighbors_normals[ + :, :, start_idx:end_idx + ] + surface_areas_batch = surface_areas[:, :, start_idx:end_idx] + surface_neighbors_areas_batch = surface_neighbors_areas[ + :, :, start_idx:end_idx + ] + pos_surface_center_of_mass_batch = pos_surface_center_of_mass[ + :, :, start_idx:end_idx + ] + + if cfg.model.transient: + geo_encoding_local_all = [] + for i in range(surface_mesh_centers.shape[1]): + geo_encoding_local_i = model.surface_local_geo_encodings( + 0.5 * encoding_g_surf, surface_mesh_centers_batch[:, i, :, :3], s_grid + ) + geo_encoding_local_all.append(torch.unsqueeze(geo_encoding_local_i, 1)) + geo_encoding_local = torch.cat(geo_encoding_local_all, dim=1) + else: + geo_encoding_local = model.surface_local_geo_encodings( + 0.5 * encoding_g_surf, + surface_mesh_centers_batch, + s_grid, + ) + pos_encoding = model.fc_p_surf(pos_surface_center_of_mass_batch) + + if cfg.model.transient_scheme == "implicit": + for i in range(cfg.model.integration_steps): + if i == 0: + surface_mesh_centers_batch_i = surface_mesh_centers_batch[:, i].clone() + surface_mesh_neighbors_batch_i = surface_mesh_neighbors_batch[:, i].clone() + else: + surface_mesh_centers_batch_i[:, :, :3] += tpredictions_batch + for j in range(surface_mesh_neighbors_batch_i.shape[2]): + surface_mesh_neighbors_batch_i[:, :, j, :3] += tpredictions_batch + + tpredictions_batch = model.solution_calculator_surf( + surface_mesh_centers_batch_i, + geo_encoding_local[:, i], + pos_encoding[:, i], + surface_mesh_neighbors_batch_i, + surface_normals_batch[:, i], + surface_neighbors_normals_batch[:, i], + surface_areas_batch[:, i], + surface_neighbors_areas_batch[:, i], + global_params_values, + global_params_reference, + ) + prediction_surf[:, i, start_idx:end_idx] = tpredictions_batch + else: + for i in range(surface_mesh_centers.shape[1]): + tpredictions_batch = model.solution_calculator_surf( + surface_mesh_centers_batch[:, i], + geo_encoding_local[:, i], + pos_encoding[:, i], + surface_mesh_neighbors_batch[:, i], + surface_normals_batch[:, i], + surface_neighbors_normals_batch[:, i], + surface_areas_batch[:, i], + surface_neighbors_areas_batch[:, i], + global_params_values, + global_params_reference, + ) + prediction_surf[:, i, start_idx:end_idx] = tpredictions_batch + + if cfg.model.normalization == "min_max_scaling": + prediction_surf = unnormalize( + prediction_surf, surf_factors[0], surf_factors[1] + ) + elif cfg.model.normalization == "mean_std_scaling": + prediction_surf = unstandardize( + prediction_surf, surf_factors[0], surf_factors[1] + ) + + else: + prediction_surf = None + + return prediction_surf + + +@hydra.main(version_base="1.3", config_path="conf", config_name="config") +def main(cfg: DictConfig): + print(f"Config summary:\n{OmegaConf.to_yaml(cfg, sort_keys=True)}") + + input_path = cfg.eval.test_path + + model_type = cfg.model.model_type + + # initialize distributed manager + DistributedManager.initialize() + dist = DistributedManager() + + if model_type == "volume" or model_type == "combined": + volume_variable_names = list(cfg.variables.volume.solution.keys()) + num_vol_vars = 0 + for j in volume_variable_names: + if cfg.variables.volume.solution[j] == "vector": + num_vol_vars += 3 + else: + num_vol_vars += 1 + else: + num_vol_vars = None + + if model_type == "surface" or model_type == "combined": + surface_variable_names = list(cfg.variables.surface.solution.keys()) + num_surf_vars = 0 + for j in surface_variable_names: + if cfg.variables.surface.solution[j] == "vector": + num_surf_vars += 3 + else: + num_surf_vars += 1 + else: + num_surf_vars = None + + global_features = 0 + global_params_names = list(cfg.variables.global_parameters.keys()) + for param in global_params_names: + if cfg.variables.global_parameters[param].type == "vector": + global_features += len(cfg.variables.global_parameters[param].reference) + else: + global_features += 1 + + ###################################################### + # Get scaling factors - precompute them if this fails! + ###################################################### + pickle_path = os.path.join(cfg.data.scaling_factors) + + vol_factors, surf_factors = load_scaling_factors(cfg) + print("Vol factors:", vol_factors) + print("Surf factors:", surf_factors) + + model = DoMINO( + input_features=3, + output_features_vol=num_vol_vars, + output_features_surf=num_surf_vars, + global_features=global_features, + model_parameters=cfg.model, + ).to(dist.device) + + model = torch.compile(model, disable=True) + + checkpoint = torch.load( + to_absolute_path(os.path.join(cfg.resume_dir, cfg.eval.checkpoint_name)), + map_location=dist.device, + ) + + model.load_state_dict(checkpoint) + + print("Model loaded") + + if dist.world_size > 1: + model = DistributedDataParallel( + model, + device_ids=[dist.local_rank], + output_device=dist.device, + broadcast_buffers=dist.broadcast_buffers, + find_unused_parameters=dist.find_unused_parameters, + gradient_as_bucket_view=True, + static_graph=True, + ) + model = model.module + + dirnames = get_filenames(input_path) + dev_id = torch.cuda.current_device() + num_files = int(len(dirnames) / dist.world_size) + dirnames_per_gpu = dirnames[int(num_files * dev_id) : int(num_files * (dev_id + 1))] + + pred_save_path = cfg.eval.save_path + + if dist.rank == 0: + create_directory(pred_save_path) + + l2_surface_all = [] + for count, dirname in enumerate(dirnames_per_gpu): + filepath = os.path.join(input_path, dirname) + tag = int(re.findall(r"(\w+?)(\d+)", dirname)[0][1]) + + # Read STL + reader = pv.get_reader(filepath) + mesh_stl = reader.read() + stl_vertices = mesh_stl.points + stl_faces = np.array(mesh_stl.faces).reshape((-1, 4))[ + :, 1: + ] # Assuming triangular elements + mesh_indices_flattened = stl_faces.flatten() + length_scale = np.array( + np.amax(np.amax(stl_vertices, 0) - np.amin(stl_vertices, 0)), + dtype=np.float32, + ) + length_scale = torch.from_numpy(length_scale).to(torch.float32).to(dist.device) + stl_sizes = mesh_stl.compute_cell_sizes(length=False, area=True, volume=False) + stl_sizes = np.array(stl_sizes.cell_data["Area"], dtype=np.float32) + stl_centers = np.array(mesh_stl.cell_centers().points, dtype=np.float32) + + # Convert to torch tensors and load on device + stl_vertices = torch.from_numpy(stl_vertices).to(torch.float32).to(dist.device) + stl_sizes = torch.from_numpy(stl_sizes).to(torch.float32).to(dist.device) + stl_centers = torch.from_numpy(stl_centers).to(torch.float32).to(dist.device) + mesh_indices_flattened = ( + torch.from_numpy(mesh_indices_flattened).to(torch.int32).to(dist.device) + ) + + # Center of mass calculation + center_of_mass = calculate_center_of_mass(stl_centers, stl_sizes) + + s_max = ( + torch.from_numpy(np.asarray(cfg.data.bounding_box_surface.max)) + .to(torch.float32) + .to(dist.device) + ) + s_min = ( + torch.from_numpy(np.asarray(cfg.data.bounding_box_surface.min)) + .to(torch.float32) + .to(dist.device) + ) + + nx, ny, nz = cfg.model.interp_res + + surf_grid = create_grid( + s_max, s_min, torch.from_numpy(np.asarray([nx, ny, nz])).to(dist.device) + ) + + normed_stl_vertices_cp = normalize(stl_vertices, s_max, s_min) + surf_grid_normed = normalize(surf_grid, s_max, s_min) + + # SDF calculation on the grid using WARP + time_start = time.time() + sdf_surf_grid, _ = signed_distance_field( + normed_stl_vertices_cp, + mesh_indices_flattened, + surf_grid_normed, + use_sign_winding_number=True, + ) + + surf_grid_max_min = torch.stack([s_min, s_max]) + + # Get global parameters and global parameters scaling from config.yaml + global_params_names = list(cfg.variables.global_parameters.keys()) + global_params_reference = { + name: cfg.variables.global_parameters[name]["reference"] + for name in global_params_names + } + global_params_types = { + name: cfg.variables.global_parameters[name]["type"] + for name in global_params_names + } + stress = global_params_reference["stress"] + + # Arrange global parameters reference in a list, ensuring it is flat + global_params_reference_list = [] + for name, type in global_params_types.items(): + if type == "vector": + global_params_reference_list.extend(global_params_reference[name]) + elif type == "scalar": + global_params_reference_list.append(global_params_reference[name]) + else: + raise ValueError( + f"Global parameter {name} not supported for this dataset" + ) + global_params_reference = np.array( + global_params_reference_list, dtype=np.float32 + ) + global_params_reference = torch.from_numpy(global_params_reference).to( + dist.device + ) + + # Define the list of global parameter values for each simulation. + # Note: The user must ensure that the values provided here correspond to the + # `global_parameters` specified in `config.yaml` and that these parameters + # exist within each simulation file. + global_params_values_list = [] + for key in global_params_types.keys(): + if key == "stress": + global_params_values_list.append(stress) + else: + raise ValueError(f"Global parameter {key} not supported for this dataset") + global_params_values_list = np.array( + global_params_values_list, dtype=np.float32 + ) + global_params_values = torch.from_numpy(global_params_values_list).to( + dist.device + ) + + # Read VTP + if model_type == "surface" or model_type == "combined": + cell_data = mesh_stl.point_data_to_cell_data() + + if cfg.model.mesh_type == "node": + timesteps, surface_fields, magnitude_data = get_time_series_data(mesh_stl, data_prefix="displacement") + surface_coordinates = mesh_stl.points + else: + surface_coordinates = cell_data.cell_centers().points + timesteps, surface_fields, magnitude_data = get_time_series_data(cell_data, data_prefix="displacement") + + num_timesteps = len(timesteps) + num_points = surface_coordinates.shape[0] + + t_max = np.amax(timesteps) + t_min = np.amin(timesteps) + timesteps = torch.from_numpy(timesteps).to(torch.float32).to(dist.device) + timesteps = normalize(timesteps, t_max, t_min) + timesteps = repeat_array(timesteps, num_points, axis=1, new_axis=True) + timesteps = torch.unsqueeze(timesteps, axis=-1) + + surface_normals = np.array(cell_data.cell_normals, dtype=np.float32) + surface_sizes = cell_data.compute_cell_sizes( + length=False, area=True, volume=False + ) + surface_sizes = np.array(surface_sizes.cell_data["Area"], dtype=np.float32) + + # Normalize cell normals + surface_normals = ( + surface_normals / np.linalg.norm(surface_normals, axis=1)[:, np.newaxis] + ) + + surface_coordinates_all = [] + surface_normals_all = [] + surface_sizes_all = [] + for i in range(surface_fields.shape[0]): + surface_coordinates_all.append(surface_coordinates + surface_fields[i]) + surface_normals_all.append(surface_normals) + surface_sizes_all.append(surface_sizes) + surface_coordinates_all = np.asarray(surface_coordinates_all) + surface_normals_all = np.asarray(surface_normals_all) + surface_sizes_all = np.asarray(surface_sizes_all) + + surface_coordinates = np.concatenate([np.expand_dims(surface_coordinates, 0), surface_coordinates_all], axis=0) + surface_normals = np.concatenate([np.expand_dims(surface_normals, 0), surface_normals_all], axis=0) + surface_sizes = np.concatenate([np.expand_dims(surface_sizes, 0), surface_sizes_all], axis=0) + + # For implicit scheme, we need to add the displacements from the previous timestep to the current position + if cfg.model.transient_scheme == "implicit": + surface_fields_new = [] + for i in range(surface_coordinates.shape[0]-1): + surface_fields_new.append(surface_coordinates[i+1] - surface_coordinates[i]) + surface_fields = np.asarray(surface_fields_new) + + surface_coordinates = surface_coordinates[:-1] + surface_normals = surface_normals[:-1] + surface_sizes = surface_sizes[:-1] + # print(surface_coordinates.shape, surface_normals.shape, surface_sizes.shape, surface_fields.shape) + # exit() + if cfg.model.transient_scheme == "explicit": + surface_coordinates_init = surface_coordinates[0] + surface_normals_init = surface_normals[0] + surface_sizes_init = surface_sizes[0] + + for j in range(surface_coordinates.shape[0]): + surface_coordinates[j] = surface_coordinates_init + surface_normals[j] = surface_normals_init + surface_sizes[j] = surface_sizes_init + + surface_coordinates = ( + torch.from_numpy(surface_coordinates).to(torch.float32).to(dist.device) + ) + surface_normals = ( + torch.from_numpy(surface_normals).to(torch.float32).to(dist.device) + ) + surface_sizes = ( + torch.from_numpy(surface_sizes).to(torch.float32).to(dist.device) + ) + surface_fields = ( + torch.from_numpy(surface_fields).to(torch.float32).to(dist.device) + ) + + if cfg.model.num_neighbors_surface > 1: + time_start = time.time() + ii, dd = knn( + points=surface_coordinates[0], + queries=surface_coordinates[0], + k=cfg.model.num_neighbors_surface, + ) + + surface_neighbors = surface_coordinates[:, ii] + surface_neighbors = surface_neighbors[:, :, 1:] + + timesteps_neighbors = repeat_array(timesteps, cfg.model.num_neighbors_surface-1, axis=2, new_axis=True) + + if cfg.model.mesh_type == "element": + surface_neighbors_normals = surface_normals[:, ii] + surface_neighbors_normals = surface_neighbors_normals[:, :, 1:] + surface_neighbors_sizes = surface_sizes[:, ii] + surface_neighbors_sizes = surface_neighbors_sizes[:, :, 1:] + else: + surface_neighbors_normals = surface_normals + surface_neighbors_sizes = surface_sizes + + else: + surface_neighbors = surface_coordinates + surface_neighbors_normals = surface_normals + surface_neighbors_sizes = surface_sizes + + if cfg.data.normalize_coordinates: + surface_coordinates = normalize(surface_coordinates, s_max, s_min) + surf_grid = normalize(surf_grid, s_max, s_min) + center_of_mass_normalized = normalize(center_of_mass, s_max, s_min) + surface_neighbors = normalize(surface_neighbors, s_max, s_min) + else: + center_of_mass_normalized = center_of_mass + pos_surface_center_of_mass = surface_coordinates - center_of_mass_normalized + + surface_coordinates = torch.cat([surface_coordinates, timesteps], axis=-1) + if cfg.model.num_neighbors_surface > 1: + surface_neighbors = torch.cat([surface_neighbors, timesteps_neighbors], axis=-1) + else: + surface_neighbors = surface_neighbors + + else: + surface_coordinates = None + surface_fields = None + surface_sizes = None + surface_normals = None + surface_neighbors = None + surface_neighbors_normals = None + surface_neighbors_sizes = None + pos_surface_center_of_mass = None + + geom_centers = stl_vertices + + if model_type == "surface": + data_dict = { + "pos_surface_center_of_mass": pos_surface_center_of_mass, + "geometry_coordinates": geom_centers, + "surf_grid": surf_grid, + "sdf_surf_grid": sdf_surf_grid, + "surface_mesh_centers": surface_coordinates, + "surface_mesh_neighbors": surface_neighbors, + "surface_normals": surface_normals, + "surface_neighbors_normals": surface_neighbors_normals, + "surface_areas": surface_sizes, + "surface_neighbors_areas": surface_neighbors_sizes, + "surface_fields": surface_fields, + "surface_min_max": surf_grid_max_min, + "length_scale": length_scale, + "global_params_values": torch.unsqueeze(global_params_values, -1), + "global_params_reference": torch.unsqueeze(global_params_reference, -1), + } + else: + raise ValueError(f"Model type: {model_type} not supported yet") + + data_dict = {key: torch.unsqueeze(value, 0) for key, value in data_dict.items()} + + prediction_surf = test_step(data_dict, model, dist.device, cfg, surf_factors) + + prediction_surf = prediction_surf[0].reshape(num_timesteps, num_points, prediction_surf.shape[-1]) + surface_fields = surface_fields.reshape(num_timesteps, num_points, surface_fields.shape[-1]) + + surface_coordinates_initial = unnormalize(surface_coordinates[0, :, :3], s_max, s_min) + surface_coordinates_unnormalized = unnormalize(surface_coordinates[:, :, :3], s_max, s_min) + + if cfg.model.transient_scheme == "implicit": + for i in range(num_timesteps): + if i == 0: + prediction_surf[i, :, :] += surface_coordinates_initial + surface_fields[i, :, :] += surface_coordinates_initial + else: + d_prediction_surf = prediction_surf[i-1, :, :] + d_truth_surf = surface_fields[i-1, :, :] + prediction_surf[i, :, :] = prediction_surf[i, :, :] + d_prediction_surf + surface_fields[i, :, :] = (surface_coordinates_unnormalized[i, :, :] - surface_coordinates_unnormalized[i-1, :, :]) + d_truth_surf + elif cfg.model.transient_scheme == "explicit": + for i in range(num_timesteps): + prediction_surf[i, :, :] += surface_coordinates_initial + surface_fields[i, :, :] += surface_coordinates_initial + else: + raise ValueError(f"Invalid transient scheme: {cfg.model.transient_scheme}") + # import pdb; pdb.set_trace() + vtp_pred_save_path = os.path.join( + pred_save_path, dirname[:-4], "predicted" + ) + create_directory(vtp_pred_save_path) + vtp_true_save_path = os.path.join( + pred_save_path, dirname[:-4], "true" + ) + create_directory(vtp_true_save_path) + + prediction_surf = prediction_surf.cpu().numpy() + surface_fields = surface_fields.cpu().numpy() + + surface_coordinates_initial = surface_coordinates_initial.cpu().numpy() + timesteps = unnormalize(timesteps, t_max, t_min) + timesteps = timesteps.cpu().numpy() + if prediction_surf is not None: + + mesh_stl.clear_cell_data() + mesh_stl.clear_point_data() + + mesh_stl_deformed_new = mesh_stl.copy() + initial_field_pred = mesh_stl_deformed_new.points + initial_field_true = mesh_stl_deformed_new.points + + for i in range(1, cfg.model.integration_steps + 1): + vtp_pred_save_path_new = os.path.join( + vtp_pred_save_path, f"boundary_predicted_{i}.vtp" + ) + vtp_true_save_path_new = os.path.join( + vtp_true_save_path, f"boundary_true_{i}.vtp" + ) + vector_field_name = f"displacement" + + initial_field_pred_new = prediction_surf[i, :, :] + initial_field_true_new = surface_fields[i, :, :] + + mesh_stl_deformed_new.points = initial_field_pred_new + mesh_stl_deformed_new[vector_field_name] = prediction_surf[i, :, :] - surface_coordinates_initial + mesh_stl_deformed_new.save(vtp_pred_save_path_new) + mesh_stl_deformed_new.points = initial_field_true_new + mesh_stl_deformed_new[vector_field_name] = surface_fields[i, :, :] - surface_coordinates_initial + mesh_stl_deformed_new.save(vtp_true_save_path_new) + + pvd_content = """ + + + """ + for timestep in range(1, cfg.model.integration_steps + 1): + pvd_content += f' \n' + pvd_content += """ + + """ + + pvd_filename = os.path.join(os.path.join(vtp_pred_save_path, "predicted.pvd")) + with open(pvd_filename, "w") as f: + f.write(pvd_content) + + pvd_content = """ + + + """ + for timestep in range(1, cfg.model.integration_steps + 1): + pvd_content += f' \n' + pvd_content += """ + + """ + pvd_filename = os.path.join(os.path.join(vtp_true_save_path, "truth.pvd")) + with open(pvd_filename, "w") as f: + f.write(pvd_content) + + if prediction_surf is not None: + + for ii in range(surface_fields.shape[0]): + print("Timestep:", ii) + l2_gt = np.mean(np.square(surface_fields[ii] - surface_coordinates_initial), (0)) + l2_error = np.mean(np.square(prediction_surf[ii] - surface_fields[ii]), (0)) + l2_surface_all.append(np.sqrt(l2_error / l2_gt)) + + error_max = (np.max(np.abs(prediction_surf[ii] - surface_coordinates_initial), axis=(0)) - np.amax(abs(surface_fields[ii] - surface_coordinates_initial), axis=(0)))/np.amax(np.abs(surface_fields[ii] - surface_coordinates_initial), axis=(0)) + pred_displacement_mag = np.sqrt(np.sum(np.square(prediction_surf[ii] - surface_coordinates_initial), axis=(1))) + true_displacement_mag = np.sqrt(np.sum(np.square(surface_fields[ii] - surface_coordinates_initial), axis=(1))) + l2_gt_displacement_mag = np.mean(np.square(true_displacement_mag), (0)) + l2_error_displacement_mag = np.mean(np.square(pred_displacement_mag - true_displacement_mag), (0)) + error_max_displacement = (np.max(np.abs(pred_displacement_mag), axis=(0)) - np.amax(abs(true_displacement_mag), axis=(0)))/np.amax(np.abs(true_displacement_mag), axis=(0)) + + print( + "Surface L-2 norm:", + dirname, + np.sqrt(l2_error) / np.sqrt(l2_gt), + ) + print( + "Surface mse:", + dirname, + l2_error, + ) + print( + "Surface error max:", + dirname, + error_max, + ) + print( + "Displacement L-2 norm:", + dirname, + np.sqrt(l2_error_displacement_mag) / np.sqrt(l2_gt_displacement_mag), + ) + print( + "Displacement mse:", + dirname, + l2_error_displacement_mag, + ) + print( + "Displacement error max:", + dirname, + error_max_displacement, + ) + + l2_surface_all = np.asarray(l2_surface_all) # num_files, 4 + l2_surface_mean = np.mean(l2_surface_all, 0) + print( + f"Mean over all samples, surface={l2_surface_mean}" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/structural_mechanics/crash_domino/src/train.py b/examples/structural_mechanics/crash_domino/src/train.py new file mode 100644 index 0000000000..6230842d6f --- /dev/null +++ b/examples/structural_mechanics/crash_domino/src/train.py @@ -0,0 +1,598 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code defines a distributed pipeline for training the DoMINO model on +Crash datasets. It includes the computation of scaling factors, instantiating +the DoMINO model and datapipe, automatically loading the most recent checkpoint, +training the model in parallel using DistributedDataParallel across multiple +GPUs, calculating the loss and updating model parameters using mixed precision. +This is a common recipe that enables training of surface model. +Validation is also conducted every epoch, +where predictions are compared against ground truth values. The code logs training +and validation metrics to TensorBoard. The train tab in config.yaml can be used to +specify batch size, number of epochs and other training parameters. +""" + +import time +import os +import re +from typing import Literal, Any +from tabulate import tabulate + +import numpy as np +import hydra +from hydra.utils import to_absolute_path +from omegaconf import DictConfig, OmegaConf + +# This will set up the cupy-ecosystem and pytorch to share memory pools +from physicsnemo.utils.memory import unified_gpu_memory + +import torchinfo +import torch.distributed as dist +from torch.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from torch.utils.tensorboard import SummaryWriter +from nvtx import annotate as nvtx_annotate +import torch.cuda.nvtx as nvtx + + +from physicsnemo.distributed import DistributedManager +from physicsnemo.launch.utils import load_checkpoint, save_checkpoint +from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper + +from physicsnemo.datapipes.cae.domino_datapipe_transient import ( + DoMINODataPipe, + create_domino_dataset, +) + +from physicsnemo.models.domino_transient.model import DoMINO +from physicsnemo.utils.domino.utils import * + +from utils import ScalingFactors, get_keys_to_read, coordinate_distributed_environment + +# This is included for GPU memory tracking: +from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo +import time + + +# Initialize NVML +nvmlInit() + + +from physicsnemo.utils.profiling import profile, Profiler + + +from loss import compute_loss_dict +from utils import get_num_vars, load_scaling_factors, compute_l2, all_reduce_dict + + +def validation_step( + dataloader, + model, + device, + logger, + tb_writer, + epoch_index, + use_sdf_basis=False, + use_surface_normals=False, + loss_fn_type=None, + vol_loss_scaling=None, + surf_loss_scaling=None, + vol_factors: torch.Tensor | None = None, + autocast_enabled=None, +): + dm = DistributedManager() + running_vloss = 0.0 + with torch.no_grad(): + metrics = None + + for i_batch, sample_batched in enumerate(dataloader): + sampled_batched = dict_to_device(sample_batched, device) + + with autocast("cuda", enabled=autocast_enabled, cache_enabled=False): + prediction_vol, prediction_surf = model(sampled_batched) + + loss, loss_dict = compute_loss_dict( + prediction_vol, + prediction_surf, + sampled_batched, + loss_fn_type, + surf_loss_scaling, + ) + + running_vloss += loss.item() + local_metrics = compute_l2( + prediction_surf, prediction_vol, sampled_batched, dataloader + ) + if metrics is None: + metrics = local_metrics + else: + metrics = { + key: metrics[key] + local_metrics[key] for key in metrics.keys() + } + + avg_vloss = running_vloss / (i_batch + 1) + metrics = {key: metrics[key] / (i_batch + 1) for key in metrics.keys()} + + metrics = all_reduce_dict(metrics, dm) + + if dm.rank == 0: + logger.info( + f" Device {device}, batch: {i_batch + 1}, VAL loss norm: {loss.detach().item():.5f}" + ) + tb_x = epoch_index + for key in metrics.keys(): + tb_writer.add_scalar(f"L2 Metrics/val/{key}", metrics[key], tb_x) + + metrics_table = tabulate( + [[k, v] for k, v in metrics.items()], + headers=["Metric", "Average Value"], + tablefmt="pretty", + ) + logger.info( + f"\nEpoch {epoch_index} VALIDATION Average Metrics:\n{metrics_table}\n" + ) + + return avg_vloss + + +@profile +def train_epoch( + dataloader, + model, + optimizer, + scaler, + tb_writer, + logger, + gpu_handle, + epoch_index, + device, + loss_fn_type, + surf_loss_scaling=None, + autocast_enabled=None, + grad_clip_enabled=None, + grad_max_norm=None, +): + dm = DistributedManager() + + running_loss = 0.0 + last_loss = 0.0 + loss_interval = 1 + + gpu_start_info = nvmlDeviceGetMemoryInfo(gpu_handle) + start_time = time.perf_counter() + with Profiler(): + io_start_time = time.perf_counter() + metrics = None + for i_batch, sampled_batched in enumerate(dataloader): + io_end_time = time.perf_counter() + with autocast("cuda", enabled=autocast_enabled, cache_enabled=False): + with nvtx.range("Model Forward Pass"): + prediction_vol, prediction_surf = model(sampled_batched) + + loss, loss_dict = compute_loss_dict( + prediction_vol, + prediction_surf, + sampled_batched, + loss_fn_type, + surf_loss_scaling, + ) + + local_metrics = compute_l2( + prediction_surf, prediction_vol, sampled_batched, dataloader + ) + if metrics is None: + metrics = local_metrics + else: + # Sum the running total: + metrics = { + key: metrics[key] + local_metrics[key] for key in metrics.keys() + } + + loss = loss / loss_interval + scaler.scale(loss).backward() + + if ((i_batch + 1) % loss_interval == 0) or (i_batch + 1 == len(dataloader)): + if grad_clip_enabled: + # Unscales the gradients of optimizer's assigned params in-place. + scaler.unscale_(optimizer) + + # Since the gradients of optimizer's assigned params are unscaled, clips as usual. + torch.nn.utils.clip_grad_norm_(model.parameters(), grad_max_norm) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + # Gather data and report + running_loss += loss.detach().item() + elapsed_time = time.perf_counter() - start_time + io_time = io_end_time - io_start_time + start_time = time.perf_counter() + gpu_end_info = nvmlDeviceGetMemoryInfo(gpu_handle) + gpu_memory_used = gpu_end_info.used / (1024**3) + gpu_memory_delta = (gpu_end_info.used - gpu_start_info.used) / (1024**3) + + logging_string = f"Device {device}, batch processed: {i_batch + 1}\n" + # Format the loss dict into a string: + loss_string = ( + " " + + "\t".join( + [f"{key.replace('loss_', ''):<10}" for key in loss_dict.keys()] + ) + + "\n" + ) + loss_string += ( + " " + + f"\t".join( + [f"{l.detach().item():<10.3e}" for l in loss_dict.values()] + ) + + "\n" + ) + + logging_string += loss_string + logging_string += f" GPU memory used: {gpu_memory_used:.3f} Gb (delta: {gpu_memory_delta:.3f})\n" + logging_string += f" Timings: (IO: {io_time:.2f}, Model: {elapsed_time - io_time:.2f}, Total: {elapsed_time:.2f})s\n" + logger.info(logging_string) + gpu_start_info = nvmlDeviceGetMemoryInfo(gpu_handle) + io_start_time = time.perf_counter() + + last_loss = running_loss / (i_batch + 1) # loss per batch + # Normalize metrics: + metrics = {key: metrics[key] / (i_batch + 1) for key in metrics.keys()} + # reduce metrics across batch: + metrics = all_reduce_dict(metrics, dm) + if dm.rank == 0: + logger.info( + f" Device {device}, batch: {i_batch + 1}, loss norm: {loss.detach().item():.5f}" + ) + tb_x = epoch_index * len(dataloader) + i_batch + 1 + tb_writer.add_scalar("Loss/train", last_loss, tb_x) + for key in metrics.keys(): + tb_writer.add_scalar(f"L2 Metrics/train/{key}", metrics[key], epoch_index) + + metrics_table = tabulate( + [[k, v] for k, v in metrics.items()], + headers=["Metric", "Average Value"], + tablefmt="pretty", + ) + logger.info(f"\nEpoch {epoch_index} Average Metrics:\n{metrics_table}\n") + + return last_loss + + +@hydra.main(version_base="1.3", config_path="conf", config_name="config") +def main(cfg: DictConfig) -> None: + ###################################################### + # initialize distributed manager + ###################################################### + DistributedManager.initialize() + dist = DistributedManager() + + # DoMINO supports domain parallel training. This function helps coordinate + # how to set that up, if needed. + domain_mesh, data_mesh, placements = coordinate_distributed_environment(cfg) + + ################################ + # Initialize NVML + ################################ + nvmlInit() + gpu_handle = nvmlDeviceGetHandleByIndex(dist.device.index) + + ###################################################### + # Initialize logger + ###################################################### + + logger = PythonLogger("Train") + logger = RankZeroLoggingWrapper(logger, dist) + + logger.info(f"Config summary:\n{OmegaConf.to_yaml(cfg, sort_keys=True)}") + + ###################################################### + # Get scaling factors - precompute them if this fails! + ###################################################### + # vol_factors, surf_factors = load_scaling_factors(cfg) + try: + vol_factors, surf_factors = load_scaling_factors(cfg) + except FileNotFoundError: + surf_factors = None + vol_factors = None + + if surf_factors is None and (cfg.model.model_type == "surface" or cfg.model.model_type == "combined"): + raise FileNotFoundError(f"Scaling factors not found at: {cfg.data.scaling_factors}; please run compute_statistics.py to compute them.") + + if vol_factors is None and (cfg.model.model_type == "volume" or cfg.model.model_type == "combined"): + raise FileNotFoundError(f"Scaling factors not found at: {cfg.data.scaling_factors}; please run compute_statistics.py to compute them.") + + ###################################################### + # Configure the model + ###################################################### + model_type = cfg.model.model_type + num_vol_vars, num_surf_vars, num_global_features = get_num_vars(cfg, model_type) + + ###################################################### + # Configure the dataset + ###################################################### + + # This helper function is to determine which keys to read from the data + # (and which to use default values for, if they aren't present - like + # stress, for example) + keys_to_read, keys_to_read_if_available = get_keys_to_read( + cfg, model_type, get_ground_truth=True + ) + + # The dataset actually works in two pieces + # The core dataset just reads data from disk, and puts it on the GPU if needed. + # The data processesing pipeline will preprocess that data and prepare it for the model. + # Obviously, you need both, so this function will return the datapipeline in + # a way that can be iterated over. + # + # To properly shuffle the data, we use a distributed sampler too. + # It's configured properly for optional domain parallelism, and you have + # to make sure to call set_epoch below. + + train_dataloader = create_domino_dataset( + cfg, + phase="train", + keys_to_read=keys_to_read, + keys_to_read_if_available=keys_to_read_if_available, + vol_factors=vol_factors, + surf_factors=surf_factors, + device_mesh=domain_mesh, + placements=placements, + normalize_coordinates=cfg.data.normalize_coordinates, + sample_in_bbox=cfg.data.sample_in_bbox, + sampling=cfg.data.sampling, + ) + train_sampler = DistributedSampler( + train_dataloader, + num_replicas=data_mesh.size(), + rank=data_mesh.get_local_rank(), + **cfg.train.sampler, + ) + + val_dataloader = create_domino_dataset( + cfg, + phase="val", + keys_to_read=keys_to_read, + keys_to_read_if_available=keys_to_read_if_available, + vol_factors=vol_factors, + surf_factors=surf_factors, + device_mesh=domain_mesh, + placements=placements, + normalize_coordinates=cfg.data.normalize_coordinates, + sample_in_bbox=cfg.data.sample_in_bbox, + sampling=cfg.data.sampling, + ) + val_sampler = DistributedSampler( + val_dataloader, + num_replicas=data_mesh.size(), + rank=data_mesh.get_local_rank(), + **cfg.val.sampler, + ) + + ###################################################### + # Configure the model + ###################################################### + model = DoMINO( + input_features=3, + output_features_vol=num_vol_vars, + output_features_surf=num_surf_vars, + global_features=num_global_features, + model_parameters=cfg.model, + nodal_surface_features=0, + nodal_geometry_features=0, + ).to(dist.device) + + # Print model summary (structure and parmeter count). + logger.info(f"Model summary:\n{torchinfo.summary(model, verbose=0, depth=2)}\n") + + if dist.world_size > 1: + model = DistributedDataParallel( + model, + device_ids=[dist.local_rank], + output_device=dist.device, + broadcast_buffers=dist.broadcast_buffers, + find_unused_parameters=dist.find_unused_parameters, + gradient_as_bucket_view=True, + static_graph=True, + ) + + ###################################################### + # Initialize optimzer and gradient scaler + ###################################################### + + optimizer_class = None + if cfg.train.optimizer.name == "Adam": + optimizer_class = torch.optim.Adam + elif cfg.train.optimizer.name == "AdamW": + optimizer_class = torch.optim.AdamW + else: + raise ValueError(f"Unsupported optimizer: {cfg.train.optimizer.name}") + optimizer = optimizer_class( + model.parameters(), + lr=cfg.train.optimizer.lr, + weight_decay=cfg.train.optimizer.weight_decay, + ) + if cfg.train.lr_scheduler.name == "MultiStepLR": + scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=cfg.train.lr_scheduler.milestones, + gamma=cfg.train.lr_scheduler.gamma, + ) + elif cfg.train.lr_scheduler.name == "CosineAnnealingLR": + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=cfg.train.lr_scheduler.T_max, + eta_min=cfg.train.lr_scheduler.eta_min, + ) + else: + raise ValueError(f"Unsupported scheduler: {cfg.train.lr_scheduler.name}") + + # Initialize the scaler for mixed precision + scaler = GradScaler() + + ###################################################### + # Initialize output tools + ###################################################### + + # Tensorboard Writer to track training. + writer = SummaryWriter(os.path.join(cfg.output, "tensorboard")) + + epoch_number = 0 + + model_save_path = os.path.join(cfg.output, "models") + param_save_path = os.path.join(cfg.output, "param") + best_model_path = os.path.join(model_save_path, "best_model") + if dist.rank == 0: + create_directory(model_save_path) + create_directory(param_save_path) + create_directory(best_model_path) + + if dist.world_size > 1: + torch.distributed.barrier() + + ###################################################### + # Load checkpoint if available + ###################################################### + init_epoch = load_checkpoint( + to_absolute_path(cfg.resume_dir), + models=model, + optimizer=optimizer, + scheduler=scheduler, + scaler=scaler, + device=dist.device, + ) + + if init_epoch != 0: + init_epoch += 1 # Start with the next epoch + epoch_number = init_epoch + + # retrive the smallest validation loss if available + numbers = [] + for filename in os.listdir(best_model_path): + match = re.search(r"\d+\.\d*[1-9]\d*", filename) + if match: + number = float(match.group(0)) + numbers.append(number) + + best_vloss = min(numbers) if numbers else 1_000_000.0 + + ###################################################### + # Begin Training loop over epochs + ###################################################### + + for epoch in range(init_epoch, cfg.train.epochs): + start_time = time.perf_counter() + logger.info(f"Device {dist.device}, epoch {epoch_number}:") + + # This controls what indices to use for each epoch. + train_sampler.set_epoch(epoch) + val_sampler.set_epoch(epoch) + train_dataloader.dataset.set_indices(list(train_sampler)) + val_dataloader.dataset.set_indices(list(val_sampler)) + + if epoch > 250: + surface_scaling_loss = 1.0 * cfg.model.surf_loss_scaling + else: + surface_scaling_loss = cfg.model.surf_loss_scaling + + model.train(True) + epoch_start_time = time.perf_counter() + avg_loss = train_epoch( + dataloader=train_dataloader, + model=model, + optimizer=optimizer, + scaler=scaler, + tb_writer=writer, + logger=logger, + gpu_handle=gpu_handle, + epoch_index=epoch, + device=dist.device, + loss_fn_type=cfg.model.loss_function, + surf_loss_scaling=surface_scaling_loss, + autocast_enabled=cfg.train.amp.enabled, + grad_clip_enabled=cfg.train.amp.clip_grad, + grad_max_norm=cfg.train.amp.grad_max_norm, + ) + epoch_end_time = time.perf_counter() + logger.info( + f"Device {dist.device}, Epoch {epoch_number} took {epoch_end_time - epoch_start_time:.3f} seconds" + ) + epoch_end_time = time.perf_counter() + + model.eval() + avg_vloss = validation_step( + dataloader=val_dataloader, + model=model, + device=dist.device, + logger=logger, + tb_writer=writer, + epoch_index=epoch, + use_sdf_basis=cfg.model.use_sdf_in_basis_func, + use_surface_normals=cfg.model.use_surface_normals, + loss_fn_type=cfg.model.loss_function, + surf_loss_scaling=surface_scaling_loss, + autocast_enabled=cfg.train.amp.enabled, + ) + + scheduler.step() + logger.info( + f"Device {dist.device} " + f"LOSS train {avg_loss:.5f} " + f"valid {avg_vloss:.5f} " + f"Current lr {scheduler.get_last_lr()[0]} " + ) + + if dist.rank == 0: + writer.add_scalars( + "Training vs. Validation Loss", + {"Training": avg_loss, "Validation": avg_vloss}, + epoch_number, + ) + writer.flush() + + # Track best performance, and save the model's state + if dist.world_size > 1: + torch.distributed.barrier() + + if avg_vloss < best_vloss: # This only considers GPU: 0, is that okay? + best_vloss = avg_vloss + + if dist.rank == 0: + print(f"Device {dist.device}, Best val loss {best_vloss}") + + if dist.rank == 0 and (epoch + 1) % cfg.train.checkpoint_interval == 0.0: + save_checkpoint( + to_absolute_path(model_save_path), + models=model, + optimizer=optimizer, + scheduler=scheduler, + scaler=scaler, + epoch=epoch, + ) + + epoch_number += 1 + + if scheduler.get_last_lr()[0] == 1e-6: + print("Training ended") + exit() + + +if __name__ == "__main__": + main() diff --git a/examples/structural_mechanics/crash_domino/src/utils.py b/examples/structural_mechanics/crash_domino/src/utils.py new file mode 100644 index 0000000000..1c71735a8e --- /dev/null +++ b/examples/structural_mechanics/crash_domino/src/utils.py @@ -0,0 +1,629 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from dataclasses import dataclass +from typing import Dict, Optional, Any +import numpy as np +import torch +import torch.distributed as dist +import pickle +from pathlib import Path +from typing import Literal, Tuple +from omegaconf import DictConfig +from physicsnemo.distributed import DistributedManager + +from torch.distributed.tensor.placement_types import ( + Shard, + Replicate, +) +import pyvista as pv + + +def get_num_vars(cfg: dict, model_type: Literal["volume", "surface", "combined"]): + """Calculate the number of variables for volume, surface, and global features. + + This function analyzes the configuration to determine how many variables are needed + for different mesh data types based on the model type. Vector variables contribute + 3 components (x, y, z) while scalar variables contribute 1 component each. + + Args: + cfg: Configuration object containing variable definitions for volume, surface, + and global parameters with their types (scalar/vector). + model_type (str): Type of model - can be "volume", "surface", or "combined". + Determines which variable types are included in the count. + + Returns: + tuple: A 3-tuple containing: + - num_vol_vars (int or None): Number of volume variables. None if model_type + is not "volume" or "combined". + - num_surf_vars (int or None): Number of surface variables. None if model_type + is not "surface" or "combined". + - num_global_features (int): Number of global parameter features. + """ + num_vol_vars = 0 + volume_variable_names = [] + if model_type == "volume" or model_type == "combined": + volume_variable_names = list(cfg.variables.volume.solution.keys()) + for j in volume_variable_names: + if cfg.variables.volume.solution[j] == "vector": + num_vol_vars += 3 + else: + num_vol_vars += 1 + else: + num_vol_vars = None + + num_surf_vars = 0 + surface_variable_names = [] + if model_type == "surface" or model_type == "combined": + surface_variable_names = list(cfg.variables.surface.solution.keys()) + num_surf_vars = 0 + for j in surface_variable_names: + if cfg.variables.surface.solution[j] == "vector": + num_surf_vars += 3 + else: + num_surf_vars += 1 + else: + num_surf_vars = None + + num_global_features = 0 + global_params_names = list(cfg.variables.global_parameters.keys()) + for param in global_params_names: + if cfg.variables.global_parameters[param].type == "vector": + num_global_features += len(cfg.variables.global_parameters[param].reference) + elif cfg.variables.global_parameters[param].type == "scalar": + num_global_features += 1 + else: + raise ValueError(f"Unknown global parameter type") + + return num_vol_vars, num_surf_vars, num_global_features + + +def get_keys_to_read( + cfg: dict, + model_type: Literal["volume", "surface", "combined"], + get_ground_truth: bool = True, +): + """ + This function helps configure the keys to read from the dataset. + + And, if some global parameter values are provided in the config, + they are also read here and passed to the dataset. + + """ + + # Always read these keys: + keys_to_read = ["stl_coordinates", "timesteps"] + + # If these keys are in the config, use them, else provide defaults in + # case they aren't in the dataset: + # cfg_params_vec = [] + # for key in cfg.variables.global_parameters: + # if cfg.variables.global_parameters[key].type == "vector": + # cfg_params_vec.extend(cfg.variables.global_parameters[key].reference) + # else: + # cfg_params_vec.append(cfg.variables.global_parameters[key].reference) + # keys_to_read_if_available = { + # "global_params_values": torch.tensor(cfg_params_vec).reshape(-1, 1), + # "global_params_reference": torch.tensor(cfg_params_vec).reshape(-1, 1), + # } + keys_to_read_if_available = {} + + # Volume keys: + volume_keys = [ + "volume_mesh_centers", + ] + if get_ground_truth: + volume_keys.append("volume_fields") + + # Surface keys: + surface_keys = [ + "surface_mesh_centers", + # "surface_normals", + # "surface_areas", + ] + if get_ground_truth: + surface_keys.append("surface_fields") + + if model_type == "volume" or model_type == "combined": + keys_to_read.extend(volume_keys) + if model_type == "surface" or model_type == "combined": + keys_to_read.extend(surface_keys) + + return keys_to_read, keys_to_read_if_available + + +def coordinate_distributed_environment(cfg: DictConfig): + """ + Initialize the distributed env for DoMINO. This is actually always a 2D Mesh: + one dimension is the data-parallel dimension (DDP), and the other is the + domain dimension. + + For the training scripts, we need to know the rank, size of each dimension, + and return the domain_mesh and placements for the loader. + + Args: + cfg: Configuration object containing the domain parallelism configuration. + + Returns: + domain_mesh: torch.distributed.DeviceMesh: The domain mesh for the domain parallel dimension. + data_mesh: torch.distributed.DeviceMesh: The data mesh for the data parallel dimension. + placements: dict[str, torch.distributed.tensor.Placement]: The placements for the data set + """ + + if not DistributedManager.is_initialized(): + DistributedManager.initialize() + dist = DistributedManager() + + # Default to no domain parallelism: + domain_size = cfg.get("domain_parallelism", {}).get("domain_size", 1) + + # Initialize the device mesh: + mesh = dist.initialize_mesh( + mesh_shape=(-1, domain_size), mesh_dim_names=("ddp", "domain") + ) + domain_mesh = mesh["domain"] + data_mesh = mesh["ddp"] + + if domain_size > 1: + # Define the default placements for each tensor that might show up in + # the data. Note that we'll define placements for all keys, even if + # they aren't actually used. + + # Note that placements are defined for pre-batched data, no batch index! + + grid_like_placement = [ + Shard(0), + ] + point_like_placement = [ + Shard(0), + ] + replicate_placement = [ + Replicate(), + ] + placements = { + "stl_coordinates": point_like_placement, + "stl_centers": point_like_placement, + "stl_faces": point_like_placement, + "stl_areas": point_like_placement, + "surface_fields": point_like_placement, + "volume_mesh_centers": point_like_placement, + "volume_fields": point_like_placement, + "surface_mesh_centers": point_like_placement, + "surface_normals": point_like_placement, + "surface_areas": point_like_placement, + } + else: + domain_mesh = None + placements = None + + return domain_mesh, data_mesh, placements + + +@dataclass +class ScalingFactors: + """ + Data structure for storing scaling factors computed for DoMINO datasets. + + This class provides a clean, easily serializable format for storing + mean, std, min, and max values for different array keys in the dataset. + Uses numpy arrays for easy serialization and cross-platform compatibility. + + Attributes: + mean: Dictionary mapping keys to mean numpy arrays + std: Dictionary mapping keys to standard deviation numpy arrays + min_val: Dictionary mapping keys to minimum value numpy arrays + max_val: Dictionary mapping keys to maximum value numpy arrays + field_keys: List of field keys for which statistics were computed + """ + + mean: Dict[str, np.ndarray] + std: Dict[str, np.ndarray] + min_val: Dict[str, np.ndarray] + max_val: Dict[str, np.ndarray] + field_keys: list[str] + + def to_torch( + self, device: Optional[torch.device] = None + ) -> Dict[str, Dict[str, torch.Tensor]]: + """Convert numpy arrays to torch tensors for use in training/inference.""" + device = device or torch.device("cpu") + + return { + "mean": {k: torch.from_numpy(v).to(device) for k, v in self.mean.items()}, + "std": {k: torch.from_numpy(v).to(device) for k, v in self.std.items()}, + "min_val": { + k: torch.from_numpy(v).to(device) for k, v in self.min_val.items() + }, + "max_val": { + k: torch.from_numpy(v).to(device) for k, v in self.max_val.items() + }, + } + + def save(self, filepath: str | Path) -> None: + """Save scaling factors to pickle file.""" + filepath = Path(filepath) + filepath.parent.mkdir(parents=True, exist_ok=True) + + with open(filepath, "wb") as f: + pickle.dump(self, f) + + @classmethod + def load(cls, filepath: str | Path) -> "ScalingFactors": + """Load scaling factors from pickle file.""" + with open(filepath, "rb") as f: + factors = pickle.load(f) + return factors + + def get_field_shapes(self) -> Dict[str, tuple]: + """Get the shape of each field's statistics.""" + return {key: self.mean[key].shape for key in self.field_keys} + + def summary(self) -> str: + """Generate a human-readable summary of the scaling factors.""" + summary = ["Scaling Factors Summary:"] + summary.append(f"Field Keys: {self.field_keys}") + + for key in self.field_keys: + mean_val = self.mean[key] + std_val = self.std[key] + min_val = self.min_val[key] + max_val = self.max_val[key] + + summary.append(f"\n{key}:") + summary.append(f" Shape: {mean_val.shape}") + summary.append(f" Mean: {mean_val}") + summary.append(f" Std: {std_val}") + summary.append(f" Min: {min_val}") + summary.append(f" Max: {max_val}") + + return "\n".join(summary) + + +def load_scaling_factors( + cfg: DictConfig, logger=None +) -> tuple[torch.Tensor, torch.Tensor]: + """Load scaling factors from the configuration.""" + pickle_path = os.path.join(cfg.data.scaling_factors) + + try: + scaling_factors = ScalingFactors.load(pickle_path) + if logger is not None: + logger.info(f"Scaling factors loaded from: {pickle_path}") + except FileNotFoundError: + raise FileNotFoundError( + f"Scaling factors not found at: {pickle_path}; please run compute_statistics.py to compute them." + ) + + if cfg.model.normalization == "min_max_scaling": + if cfg.model.model_type == "volume" or cfg.model.model_type == "combined": + vol_factors = np.asarray( + [ + scaling_factors.max_val["volume_fields"], + scaling_factors.min_val["volume_fields"], + ] + ) + if cfg.model.model_type == "surface" or cfg.model.model_type == "combined": + surf_factors = np.asarray( + [ + scaling_factors.max_val["surface_fields"], + scaling_factors.min_val["surface_fields"], + ] + ) + elif cfg.model.normalization == "mean_std_scaling": + if cfg.model.model_type == "volume" or cfg.model.model_type == "combined": + vol_factors = np.asarray( + [ + scaling_factors.mean["volume_fields"], + scaling_factors.std["volume_fields"], + ] + ) + if cfg.model.model_type == "surface" or cfg.model.model_type == "combined": + surf_factors = np.asarray( + [ + scaling_factors.mean["surface_fields"], + scaling_factors.std["surface_fields"], + ] + ) + else: + raise ValueError(f"Invalid normalization mode: {cfg.model.normalization}") + + dm = DistributedManager() + if cfg.model.model_type == "volume" or cfg.model.model_type == "combined": + vol_factors_tensor = torch.from_numpy(vol_factors) + vol_factors_tensor = vol_factors_tensor.to(dm.device, dtype=torch.float32) + else: + vol_factors_tensor = None + if cfg.model.model_type == "surface" or cfg.model.model_type == "combined": + surf_factors_tensor = torch.from_numpy(surf_factors) + surf_factors_tensor = surf_factors_tensor.to(dm.device, dtype=torch.float32) + else: + surf_factors_tensor = None + return vol_factors_tensor, surf_factors_tensor + +def compute_l2( + pred_surface: torch.Tensor | None, + pred_volume: torch.Tensor | None, + batch, + dataloader, +) -> dict[str, torch.Tensor]: + """ + Compute the L2 norm between prediction and target. + + Requires the dataloader to unscale back to original values + """ + + l2_dict = {} + + if pred_surface is not None: + _, target_surface = dataloader.unscale_model_outputs( + surface_fields=batch["surface_fields"] + ) + _, pred_surface = dataloader.unscale_model_outputs(surface_fields=pred_surface) + l2_surface = metrics_fn_surface(pred_surface, target_surface) + l2_dict.update(l2_surface) + if pred_volume is not None: + target_volume, _ = dataloader.unscale_model_outputs( + volume_fields=batch["volume_fields"] + ) + pred_volume, _ = dataloader.unscale_model_outputs(volume_fields=pred_volume) + l2_volume = metrics_fn_volume(pred_volume, target_volume) + l2_dict.update(l2_volume) + + return l2_dict + + +def metrics_fn_surface( + pred: torch.Tensor, + target: torch.Tensor, +) -> dict[str, torch.Tensor]: + """ + Computes L2 surface metrics between prediction and target. + + Args: + pred: Predicted values (normalized). + target: Target values (normalized). + + Returns: + Dictionary of L2 surface metrics for pressure and shear components. + """ + + l2_num = (pred - target) ** 2 + l2_num = torch.sum(l2_num, dim=(1,2)) + l2_num = torch.sqrt(l2_num) + + l2_denom = target**2 + l2_denom = torch.sum(l2_denom, dim=(1,2)) + l2_denom = torch.sqrt(l2_denom) + + l2 = l2_num / l2_denom + + metrics = { + "l2_displacement_x": torch.mean(l2[:, 0]), + "l2_displacement_y": torch.mean(l2[:, 1]), + "l2_displacement_z": torch.mean(l2[:, 2]), + } + + return metrics + + +def metrics_fn_volume( + pred: torch.Tensor, + target: torch.Tensor, +) -> dict[str, torch.Tensor]: + """ + Computes L2 volume metrics between prediction and target. + """ + l2_num = (pred - target) ** 2 + l2_num = torch.sum(l2_num, dim=1) + l2_num = torch.sqrt(l2_num) + + l2_denom = target**2 + l2_denom = torch.sum(l2_denom, dim=1) + l2_denom = torch.sqrt(l2_denom) + + l2 = l2_num / l2_denom + + metrics = { + "stress": torch.mean(l2[:, 0]), + } + + return metrics + + +def all_reduce_dict( + metrics: dict[str, torch.Tensor], dm: DistributedManager +) -> dict[str, torch.Tensor]: + """ + Reduces a dictionary of metrics across all distributed processes. + + Args: + metrics: Dictionary of metric names to torch.Tensor values. + dm: DistributedManager instance for distributed context. + + Returns: + Dictionary of reduced metrics. + """ + # TODO - update this to use domains and not the full world + + if dm.world_size == 1: + return metrics + + for key, value in metrics.items(): + dist.all_reduce(value) + value = value / dm.world_size + metrics[key] = value + + return metrics + +def extract_index_from_filename(filename: str, pattern: str = "auto") -> int: + """Extract numeric index from filename using various patterns. + + This function extracts numeric indices from filenames to help with + ordering and processing files in sequence. + + Args: + filename: The filename to extract index from. + pattern: Pattern to use for extraction: + - "auto": Automatically detect common patterns + - "suffix": Extract number at end before extension (file_001.csv) + - "prefix": Extract number at beginning (001_file.csv) + - "middle": Extract first number found anywhere + - "last": Extract last number found anywhere + + Returns: + int: Extracted index number, or -1 if no number found. + + Examples: + # Various filename patterns + extract_index_from_filename("data_001.csv") # Returns: 1 + extract_index_from_filename("file001.csv") # Returns: 1 + extract_index_from_filename("001_data.csv") # Returns: 1 + extract_index_from_filename("mesh_5_final.csv") # Returns: 5 + extract_index_from_filename("output123data.csv") # Returns: 123 + """ + import re + + # Remove file extension for cleaner processing + base_name = Path(filename).stem + + if pattern == "auto": + # Try different patterns in order of preference + patterns = [ + r'_(\d+)$', # Underscore followed by number at end: file_001 + r'(\d+)$', # Number at end: file001 + r'^(\d+)_', # Number at start with underscore: 001_file + r'^(\d+)', # Number at start: 001file + r'_(\d+)_', # Number between underscores: file_001_data + r'(\d+)', # Any number (first occurrence) + ] + + for p in patterns: + match = re.search(p, base_name) + if match: + return int(match.group(1)) + + elif pattern == "suffix": + # Extract number at end before extension + match = re.search(r'(\d+)$', base_name) + if match: + return int(match.group(1)) + + elif pattern == "prefix": + # Extract number at beginning + match = re.search(r'^(\d+)', base_name) + if match: + return int(match.group(1)) + + elif pattern == "middle": + # Extract first number found + match = re.search(r'(\d+)', base_name) + if match: + return int(match.group(1)) + + elif pattern == "last": + # Extract last number found + matches = re.findall(r'(\d+)', base_name) + if matches: + return int(matches[-1]) + + # No number found + return -1 + +def extract_time_series_info(mesh: pv.PolyData, data_prefix: str = "displacement") -> Dict: + """Extract information about time series data in the mesh. + + Args: + mesh: PyVista mesh object. + data_prefix: Prefix of the time series data fields. + + Returns: + dict: Information about the time series including timesteps and field names. + """ + # Find all arrays that match the prefix + time_arrays = [name for name in mesh.array_names if name.startswith(data_prefix)] + magnitude_arrays = [name for name in time_arrays if "magnitude" in name] + vector_arrays = [name for name in time_arrays if "magnitude" not in name] + + # Extract timesteps from field names + timesteps = [] + for name in vector_arrays: + # Extract timestep from name like "displacement_t0.123" + if "_t" in name: + try: + timestep_str = name.split("_t")[1] + timestep = float(timestep_str) + timesteps.append(timestep) + except (IndexError, ValueError): + print(f"Warning: Could not extract timestep from {name}") + + timesteps = sorted(timesteps) + + info = { + 'n_timesteps': len(timesteps), + 'timesteps': np.array(timesteps), + 'vector_arrays': sorted(vector_arrays), + 'magnitude_arrays': sorted(magnitude_arrays), + 'all_time_arrays': sorted(time_arrays), + 'data_prefix': data_prefix + } + + return info + +def get_time_series_data(mesh: pv.PolyData, data_prefix: str = "displacement") -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Extract time series data from mesh into numpy arrays. + + Args: + mesh: PyVista mesh object. + data_prefix: Prefix of the time series data fields. + + Returns: + tuple: (timesteps, vector_data, magnitude_data) + - timesteps: Array of timestep values + - vector_data: Array of shape (n_timesteps, n_points, 3) + - magnitude_data: Array of shape (n_timesteps, n_points) + """ + info = extract_time_series_info(mesh, data_prefix) + + if info['n_timesteps'] == 0: + print(f"No time series data found with prefix '{data_prefix}'") + return np.array([]), np.array([]), np.array([]) + + n_points = mesh.n_points + n_timesteps = info['n_timesteps'] + timesteps = info['timesteps'] + + # Initialize arrays + vector_data = np.zeros((n_timesteps, n_points, 3)) + magnitude_data = np.zeros((n_timesteps, n_points)) + + # Extract data for each timestep + for i, timestep in enumerate(timesteps): + # Vector data + vector_field_name = f"{data_prefix}_t{timestep:.3f}" + if vector_field_name in mesh.array_names: + vector_data[i, :, :] = mesh[vector_field_name] + + # Magnitude data + magnitude_field_name = f"{data_prefix}_magnitude_t{timestep:.3f}" + if magnitude_field_name in mesh.array_names: + magnitude_data[i, :] = mesh[magnitude_field_name] + else: + # Calculate magnitude if not stored + magnitude_data[i, :] = np.linalg.norm(vector_data[i, :, :], axis=1) + + return timesteps, vector_data, magnitude_data diff --git a/examples/structural_mechanics/crash_domino/src/validate_cache.py b/examples/structural_mechanics/crash_domino/src/validate_cache.py new file mode 100644 index 0000000000..f22039dd7b --- /dev/null +++ b/examples/structural_mechanics/crash_domino/src/validate_cache.py @@ -0,0 +1,172 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script processes DoMINODataPipe format files into cached versions +for faster loading during training. It processes files in parallel and can be +configured through config.yaml in the data_processing tab. +""" + +from physicsnemo.datapipes.cae.domino_datapipe import ( + CachedDoMINODataset, + DoMINODataPipe, +) +import hydra +import numpy as np +import os +from omegaconf import DictConfig +import torch +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from physicsnemo.distributed import DistributedManager + + +@hydra.main(version_base="1.3", config_path="conf", config_name="config") +def main(cfg: DictConfig) -> None: + assert cfg.data_processor.use_cache, "Cache must be enabled to be validated!" + # initialize distributed manager + DistributedManager.initialize() + dist = DistributedManager() + + vol_save_path = os.path.join(cfg.project_dir, "volume_scaling_factors.npy") + surf_save_path = os.path.join(cfg.project_dir, "surface_scaling_factors.npy") + if os.path.exists(vol_save_path): + vol_factors = np.load(vol_save_path) + else: + vol_factors = None + + if os.path.exists(surf_save_path): + surf_factors = np.load(surf_save_path) + else: + surf_factors = None + + # Set up variables based on model type + model_type = cfg.model.model_type + volume_variable_names = [] + surface_variable_names = [] + + if model_type in ["volume", "combined"]: + volume_variable_names = list(cfg.variables.volume.solution.keys()) + if model_type in ["surface", "combined"]: + surface_variable_names = list(cfg.variables.surface.solution.keys()) + + # Create dataset once + dataset_orig = DoMINODataPipe( + data_path=cfg.data_processor.output_dir, # Caching comes after data processing + phase="train", # Phase doesn't matter for caching + grid_resolution=cfg.model.interp_res, + volume_variables=volume_variable_names, + surface_variables=surface_variable_names, + normalize_coordinates=True, + sampling=True, + sample_in_bbox=True, + volume_points_sample=cfg.model.volume_points_sample, + surface_points_sample=cfg.model.surface_points_sample, + geom_points_sample=cfg.model.geom_points_sample, + positional_encoding=cfg.model.positional_encoding, + volume_factors=vol_factors, + surface_factors=surf_factors, + scaling_type=cfg.model.normalization, + model_type=cfg.model.model_type, + bounding_box_dims=cfg.data.bounding_box, + bounding_box_dims_surf=cfg.data.bounding_box_surface, + num_surface_neighbors=cfg.model.num_surface_neighbors, + for_caching=False, + deterministic_seed=True, + ) + + dataset_cached = CachedDoMINODataset( + data_path=cfg.data_processor.cached_dir, + phase="train", + sampling=True, + volume_points_sample=cfg.model.volume_points_sample, + surface_points_sample=cfg.model.surface_points_sample, + geom_points_sample=cfg.model.geom_points_sample, + model_type=cfg.model.model_type, + deterministic_seed=True, + ) + + # Wait for directory creation + if dist.world_size > 1: + torch.distributed.barrier() + + def get_dataloader(dataset, world_size, rank): + sampler = DistributedSampler( + dataset, num_replicas=world_size, rank=rank, shuffle=False + ) + + return DataLoader( + dataset, + sampler=sampler, + batch_size=1, # Process one at a time for caching + num_workers=0, # Must be 0 due to GPU operations in dataset + ) + + dataloader_orig = get_dataloader(dataset_orig, dist.world_size, dist.rank) + dataloader_cached = get_dataloader(dataset_cached, dist.world_size, dist.rank) + + # Process and cache files + for _, (sample_orig, sample_cached) in enumerate( + zip(dataloader_orig, dataloader_cached) + ): + filename_orig = sample_orig["filename"][0] + filename_cached = sample_cached["filename"][0] + mismatched = False + if filename_orig != filename_cached: + print( + f"Rank {dist.rank}: Mismatched filenames: {filename_orig} != {filename_cached}" + ) + mismatched = True + for k, v in sample_orig.items(): + if k in ["filename"]: + continue + if k not in sample_cached: + print(f"Rank {dist.rank}: Key {k} missing from cached sample") + mismatched = True + elif not torch.allclose(v, sample_cached[k]): + print(f"Rank {dist.rank}: Mismatched values for key {k}") + # Get boolean mask of mismatches + mismatches = v != sample_cached[k] + # Get indices where values mismatch + mismatch_indices = torch.nonzero(mismatches, as_tuple=False) + print( + f" Found {len(mismatch_indices)} mismatches, of {v.numel()} total values" + ) + print(f" Tensor shape: {v.shape}, vs {sample_cached[k].shape}") + # Get the actual values at those positions + for idx in mismatch_indices[:5]: # Show only first 5 mismatches + idx_tuple = tuple( + idx.tolist() + ) # Convert index tensor to tuple for indexing + val1 = v[idx_tuple].item() + val2 = sample_cached[k][idx_tuple].item() + print(f" Index {idx_tuple}: {val1} vs {val2}") + mismatched = True + if mismatched: + print(f"FAILED Rank {dist.rank}: {filename_orig}") + else: + print(f"Rank {dist.rank}: {filename_orig} validated") + + # Wait for all processes to complete + if dist.world_size > 1: + torch.distributed.barrier() + + if dist.rank == 0: + print("All processing complete!") + + +if __name__ == "__main__": + main() diff --git a/physicsnemo/datapipes/cae/cae_dataset.py b/physicsnemo/datapipes/cae/cae_dataset.py index 705c18f92a..7c89b68523 100644 --- a/physicsnemo/datapipes/cae/cae_dataset.py +++ b/physicsnemo/datapipes/cae/cae_dataset.py @@ -1135,7 +1135,7 @@ def __del__(self): def compute_mean_std_min_max( - dataset: CAEDataset, field_keys: list[str], max_samples: int = 20 + dataset: CAEDataset, field_keys: list[str], max_samples: int = 20, transient: bool = False ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Compute the mean, standard deviation, minimum, and maximum for a specified field @@ -1146,7 +1146,8 @@ def compute_mean_std_min_max( Args: dataset (CAEDataset): The dataset to process. field_key (str): The key for the field to normalize. - + max_samples (int): The maximum number of samples to process. + transient (bool): Whether the dataset is transient. Returns: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: mean, std, min, max tensors for the field. @@ -1184,6 +1185,11 @@ def compute_mean_std_min_max( device=example_data[key].device, ) + if transient: + axis = (0, 1) + else: + axis = (0,) + global_start = time.perf_counter() start = time.perf_counter() data_list = np.arange(len(dataset)) @@ -1197,10 +1203,12 @@ def compute_mean_std_min_max( field_data = data[field_key] # Compute batch statistics - batch_mean = field_data.mean(axis=(0)) - batch_M2 = ((field_data - batch_mean) ** 2).sum(axis=(0)) - batch_n = field_data.shape[0] + batch_mean = field_data.mean(axis=axis) + batch_M2 = ((field_data - batch_mean) ** 2).sum(axis=axis) + if transient: + batch_n = field_data.shape[0] * field_data.shape[1] + # Update running mean and M2 (Welford's algorithm) delta = batch_mean - mean[field_key] N[field_key] += batch_n # batch_n should also be torch.int64 @@ -1235,6 +1243,10 @@ def compute_mean_std_min_max( for field_key in field_keys: field_data = data[field_key] + if transient: + if len(field_data.shape) == 3: + field_data = field_data.reshape(field_data.shape[0] * field_data.shape[1], field_data.shape[-1]) + batch_n = field_data.shape[0] # # Update min/max diff --git a/physicsnemo/datapipes/cae/domino_datapipe_transient.py b/physicsnemo/datapipes/cae/domino_datapipe_transient.py new file mode 100644 index 0000000000..e33962973b --- /dev/null +++ b/physicsnemo/datapipes/cae/domino_datapipe_transient.py @@ -0,0 +1,1495 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code provides the datapipe for reading the processed npy files, +generating multi-res grids, calculating signed distance fields, +sampling random points in the volume and on surface, +normalizing fields and returning the output tensors as a dictionary. + +This datapipe also non-dimensionalizes the fields, so the order in which the variables should +be fixed: velocity, pressure, turbulent viscosity for volume variables and +pressure, wall-shear-stress for surface variables. The different parameters such as +variable names, domain resolution, sampling size etc. are configurable in config.yaml. +""" + +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable, Literal, Optional, Protocol, Sequence, Union + +import numpy as np +import torch +import torch.cuda.nvtx as nvtx +from omegaconf import DictConfig +from torch.distributed.tensor.placement_types import Replicate +from torch.utils.data import Dataset + +from physicsnemo.datapipes.cae.cae_dataset import ( + CAEDataset, + compute_mean_std_min_max, +) +from physicsnemo.distributed import DistributedManager +from physicsnemo.distributed.shard_tensor import ShardTensor, scatter_tensor +from physicsnemo.utils.domino.utils import ( + calculate_center_of_mass, + create_grid, + get_filenames, + normalize, + pad, + shuffle_array, + standardize, + unnormalize, + unstandardize, + repeat_array, +) +from physicsnemo.utils.neighbors import knn +from physicsnemo.utils.profiling import profile +from physicsnemo.utils.sdf import signed_distance_field + + +class BoundingBox(Protocol): + """ + Type definition for the required format of bounding box dimensions. + """ + + min: Sequence + max: Sequence + + +@dataclass +class DoMINODataConfig: + """Configuration for DoMINO dataset processing pipeline. + + Attributes: + data_path: Path to the dataset to load. + phase: Which phase of data to load ("train", "val", or "test"). + surface_variables: (Surface specific) Names of surface variables. + surface_points_sample: (Surface specific) Number of surface points to sample per batch. + num_surface_neighbors: (Surface specific) Number of surface neighbors to consider for nearest neighbors approach. + surface_sampling_algorithm: (Surface specific) Algorithm to use for surface sampling ("area_weighted" or "random"). + surface_factors: (Surface specific) Non-dimensionalization factors for surface variables. + If set, and scaling_type is: + - min_max_scaling -> rescale surface_fields to the min/max set here + - mean_std_scaling -> rescale surface_fields to the mean and std set here. + bounding_box_dims_surf: (Surface specific) Dimensions of bounding box. Must be an object with min/max + attributes that are arraylike. + volume_variables: (Volume specific) Names of volume variables. + volume_points_sample: (Volume specific) Number of volume points to sample per batch. + volume_sample_from_disk: (Volume specific) If the volume data is in a shuffled state on disk, + read contiguous chunks of the data rather than the entire volume data. This greatly + accelerates IO in bandwidth limited systems or when the volumetric data is very large. + volume_factors: (Volume specific) Non-dimensionalization factors for volume variables scaling. + If set, and scaling_type is: + - min_max_scaling -> rescale volume_fields to the min/max set here + - mean_std_scaling -> rescale volume_fields to the mean and std set here. + bounding_box_dims: (Volume specific) Dimensions of bounding box. Must be an object with min/max + attributes that are arraylike. + grid_resolution: Resolution of the latent grid. + normalize_coordinates: Whether to normalize coordinates based on min/max values. + For surfaces: uses s_min/s_max, defined from: + - Surface bounding box, if defined. + - Min/max of the stl_vertices + For volumes: uses c_min/c_max, defined from: + - Volume bounding_box if defined, + - 1.5x s_min/max otherwise, except c_min[2] = s_min[2] in this case + sample_in_bbox: Whether to sample points in a specified bounding box. + Uses the same min/max points as coordinate normalization. + Only performed if compute_scaling_factors is false. + sampling: Whether to downsample the full resolution mesh to fit in GPU memory. + Surface and volume sampling points are configured separately as: + - surface.points_sample + - volume.points_sample + geom_points_sample: Number of STL points sampled per batch. + Independent of volume.points_sample and surface.points_sample. + scaling_type: Scaling type for volume variables. + If used, will rescale the volume_fields and surface fields outputs. + Requires volume.factor and surface.factor to be set. + compute_scaling_factors: Whether to compute scaling factors. + Not available if caching. + Many preprocessing pieces are disabled if computing scaling factors. + caching: Whether this is for caching or serving. + deterministic: Whether to use a deterministic seed for sampling and random numbers. + gpu_preprocessing: Whether to do preprocessing on the GPU (False for CPU). + gpu_output: Whether to return output on the GPU as cupy arrays. + If False, returns numpy arrays. + You might choose gpu_preprocessing=True and gpu_output=False if caching. + """ + + data_path: Path | None + phase: Literal["train", "val", "test"] + mesh_type: Literal["element", "node"] = "element" + + # Surface-specific variables: + surface_variables: Optional[Sequence] = ("pMean", "wallShearStress") + surface_points_sample: int = 1024 + num_surface_neighbors: int = 11 + surface_sampling_algorithm: str = Literal["area_weighted", "random"] + surface_factors: Optional[Sequence] = None + bounding_box_dims_surf: Optional[Union[BoundingBox, Sequence]] = None + use_surface_normals: bool = False + use_surface_area: bool = False + + # Volume specific variables: + volume_variables: Optional[Sequence] = ("UMean", "pMean") + volume_points_sample: int = 1024 + volume_sample_from_disk: bool = False + volume_factors: Optional[Sequence] = None + bounding_box_dims: Optional[Union[BoundingBox, Sequence]] = None + + # Transient specific variables: + time_points_sample: int = 10 + transient_scheme: str = "explicit" # "explicit" or "implicit" + transient: bool = False # Whether to use transient model + + grid_resolution: Sequence = (256, 96, 64) + normalize_coordinates: bool = False + sample_in_bbox: bool = False + sampling: bool = False + geom_points_sample: int = 300000 + scaling_type: Optional[Literal["min_max_scaling", "mean_std_scaling"]] = None + compute_scaling_factors: bool = False + caching: bool = False + deterministic: bool = False + gpu_preprocessing: bool = True + gpu_output: bool = True + + def __post_init__(self): + if self.data_path is not None: + # Ensure data_path is a Path object: + if isinstance(self.data_path, str): + self.data_path = Path(self.data_path) + self.data_path = self.data_path.expanduser() + + if not self.data_path.exists(): + raise ValueError(f"Path {self.data_path} does not exist") + + if not self.data_path.is_dir(): + raise ValueError(f"Path {self.data_path} is not a directory") + + # Object if caching settings are impossible: + if self.caching: + if self.sampling: + raise ValueError("Sampling should be False for caching") + if self.compute_scaling_factors: + raise ValueError("Compute scaling factors should be False for caching") + + if self.phase not in [ + "train", + "val", + "test", + ]: + raise ValueError( + f"phase should be one of ['train', 'val', 'test'], got {self.phase}" + ) + if self.scaling_type is not None: + if self.scaling_type not in [ + "min_max_scaling", + "mean_std_scaling", + ]: + raise ValueError( + f"scaling_type should be one of ['min_max_scaling', 'mean_std_scaling'], got {self.scaling_type}" + ) + + +##### TODO +# - The SDF normalization here is based on using a normalized mesh and +# a normalized coordinate. The alternate method is to normalize to the min/max of the grid. + + +class DoMINODataPipe(Dataset): + """ + Datapipe for DoMINO + + Leverages a dataset for the actual reading of the data, and this + object is responsible for preprocessing the data. + + """ + + def __init__( + self, + input_path, + model_type: Literal["surface", "volume", "combined"], + pin_memory: bool = False, + **data_config_overrides, + ): + # Perform config packaging and validation + self.config = DoMINODataConfig(data_path=input_path, **data_config_overrides) + + # Set up the distributed manager: + if not DistributedManager.is_initialized(): + DistributedManager.initialize() + + dist = DistributedManager() + + if self.config.mesh_type == "node": + if self.config.use_surface_normals is True or self.config.use_surface_area is True: + raise ValueError("use_surface_normals and use_surface_area must be False when mesh_type is node") + + # Set devices for the preprocessing and IO target + self.preproc_device = ( + dist.device if self.config.gpu_preprocessing else torch.device("cpu") + ) + # The cae_dataset will automatically target this device + # In an async transfer. + self.output_device = ( + dist.device if self.config.gpu_output else torch.device("cpu") + ) + + # Model type determines whether we process surface, volume, or both. + self.model_type = model_type + + # Update the arrays for bounding boxes: + if hasattr(self.config.bounding_box_dims, "max") and hasattr( + self.config.bounding_box_dims, "min" + ): + self.config.bounding_box_dims = [ + torch.tensor( + self.config.bounding_box_dims.max, + device=self.preproc_device, + dtype=torch.float32, + ), + torch.tensor( + self.config.bounding_box_dims.min, + device=self.preproc_device, + dtype=torch.float32, + ), + ] + self.default_volume_grid = create_grid( + self.config.bounding_box_dims[0], + self.config.bounding_box_dims[1], + self.config.grid_resolution, + ) + + # And, do the surface bounding box if supplied: + if hasattr(self.config.bounding_box_dims_surf, "max") and hasattr( + self.config.bounding_box_dims_surf, "min" + ): + self.config.bounding_box_dims_surf = [ + torch.tensor( + self.config.bounding_box_dims_surf.max, + device=self.preproc_device, + dtype=torch.float32, + ), + torch.tensor( + self.config.bounding_box_dims_surf.min, + device=self.preproc_device, + dtype=torch.float32, + ), + ] + + self.default_surface_grid = create_grid( + self.config.bounding_box_dims_surf[0], + self.config.bounding_box_dims_surf[1], + self.config.grid_resolution, + ) + + # Ensure the volume and surface scaling factors are torch tensors + # and on the right device: + if self.config.volume_factors is not None: + if not isinstance(self.config.volume_factors, torch.Tensor): + self.config.volume_factors = torch.from_numpy( + self.config.volume_factors + ) + self.config.volume_factors = self.config.volume_factors.to( + self.preproc_device, dtype=torch.float32 + ) + if self.config.surface_factors is not None: + if not isinstance(self.config.surface_factors, torch.Tensor): + self.config.surface_factors = torch.from_numpy( + self.config.surface_factors + ) + self.config.surface_factors = self.config.surface_factors.to( + self.preproc_device, dtype=torch.float32 + ) + + self.dataset = None + + def compute_stl_scaling_and_surface_grids( + self, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute the min and max for the defining mesh. + + If the user supplies a bounding box, we use that. Otherwise, + it raises an error. + + The returned min/max and grid are used for surface data. + """ + + # Check the bounding box is not unit length + + if self.config.bounding_box_dims_surf is not None: + s_max = self.config.bounding_box_dims_surf[0] + s_min = self.config.bounding_box_dims_surf[1] + surf_grid = self.default_surface_grid + else: + raise ValueError("Bounding box dimensions are not set in config") + + return s_min, s_max, surf_grid + + def compute_volume_scaling_and_grids( + self, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute the min and max and grid for volume data. + + If the user supplies a bounding box, we use that. Otherwise, + it raises an error. + + """ + + # Determine the volume min / max locations + if self.config.bounding_box_dims is not None: + c_max = self.config.bounding_box_dims[0] + c_min = self.config.bounding_box_dims[1] + volume_grid = self.default_volume_grid + else: + raise ValueError("Bounding box dimensions are not set in config") + + return c_min, c_max, volume_grid + + @profile + def downsample_geometry( + self, + stl_vertices, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Downsample the geometry to the desired number of points. + + Args: + stl_vertices: The vertices of the surface. + """ + + if self.config.sampling: + geometry_points = self.config.geom_points_sample + + geometry_coordinates_sampled, idx_geometry = shuffle_array( + stl_vertices, geometry_points + ) + if geometry_coordinates_sampled.shape[0] < geometry_points: + raise ValueError( + "Surface mesh has fewer points than requested sample size" + ) + geom_centers = geometry_coordinates_sampled + else: + geom_centers = stl_vertices + idx_geometry = None + + return geom_centers, idx_geometry + + def process_surface( + self, + s_min: torch.Tensor, + s_max: torch.Tensor, + c_min: torch.Tensor, + c_max: torch.Tensor, + *, # Forcing the rest by keyword only since it's a long list ... + center_of_mass: torch.Tensor, + surf_grid: torch.Tensor, + surface_coordinates: torch.Tensor, + surface_normals: torch.Tensor, + surface_sizes: torch.Tensor, + stl_vertices: torch.Tensor, + stl_indices: torch.Tensor, + surface_fields: torch.Tensor | None, + timesteps: torch.Tensor | None, + surface_features: torch.Tensor | None, + ) -> dict[str, torch.Tensor]: + nx, ny, nz = self.config.grid_resolution + + return_dict = {} + + ######################################################################## + # Remove any sizes <= 0: + ######################################################################## + if self.config.mesh_type == "element": + idx = surface_sizes > 0 + if surface_sizes is not None: + surface_sizes = surface_sizes[idx] + if surface_normals is not None: + surface_normals = surface_normals[idx] + surface_coordinates = surface_coordinates[idx] + if surface_fields is not None: + surface_fields = surface_fields[idx] + if surface_features is not None: + surface_features = surface_features[idx] + ######################################################################## + # Reject surface points outside of the Bounding Box + # NOTE - this is using the VOLUME bounding box! + ######################################################################## + if self.config.sample_in_bbox: + ids_min = surface_coordinates[0, :] > c_min + ids_max = surface_coordinates[0, :] < c_max + + ids_in_bbox = ids_min & ids_max + ids_in_bbox = ids_in_bbox.all(dim=-1) + + surface_coordinates = surface_coordinates[:, ids_in_bbox] + if self.config.mesh_type == "element" and surface_normals is not None: + surface_normals = surface_normals[:, ids_in_bbox] + if self.config.mesh_type == "element" and surface_sizes is not None: + surface_sizes = surface_sizes[:, ids_in_bbox] + if surface_fields is not None: + surface_fields = surface_fields[:, ids_in_bbox] + if surface_features is not None: + surface_features = surface_features[:, ids_in_bbox] + + ######################################################################## + # Perform Down sampling of the surface fields. + # Note that we snapshot the full surface coordinates for + # use in the kNN in the next step. + ######################################################################## + + full_surface_coordinates = surface_coordinates + full_surface_features = surface_features + + if surface_normals is not None: + full_surface_normals = surface_normals + else: + full_surface_normals = None + if surface_sizes is not None: + full_surface_sizes = surface_sizes + else: + full_surface_sizes = None + + if self.config.sampling: + # Perform the down sampling: + if self.config.surface_sampling_algorithm == "area_weighted" and self.config.mesh_type == "element": + weights = surface_sizes + else: + weights = None + + surface_coordinates_sampled, idx_surface = shuffle_array( + surface_coordinates[0], + self.config.surface_points_sample, + weights=weights, + ) + + if surface_coordinates_sampled.shape[0] < self.config.surface_points_sample: + raise ValueError( + "Surface mesh has fewer points than requested sample size" + ) + + if self.config.transient: + if self.config.transient_scheme == "explicit": + timesteps_sampled, idx_time = shuffle_array(timesteps, self.config.time_points_sample) + timesteps_sampled = repeat_array(timesteps_sampled, self.config.surface_points_sample, axis=1, new_axis=True) + timesteps_sampled = torch.unsqueeze(timesteps_sampled, axis=-1) + elif self.config.transient_scheme == "implicit": + idx_time_start = torch.randint(low=0, high=surface_fields.shape[0]-self.config.time_points_sample, size=(1,)) + timesteps_sampled = timesteps[idx_time_start:idx_time_start+self.config.time_points_sample] + timesteps_sampled = repeat_array(timesteps_sampled, self.config.surface_points_sample, axis=1, new_axis=True) + timesteps_sampled = torch.unsqueeze(timesteps_sampled, axis=-1) + else: + raise ValueError(f"Invalid transient scheme: {self.config.transient_scheme}") + + # Select out the sampled points for non-neighbor arrays: + if surface_fields is not None: + if self.config.transient: + if self.config.transient_scheme == "explicit": + surface_fields_time = surface_fields[idx_time] + surface_fields = surface_fields_time[:, idx_surface] + elif self.config.transient_scheme == "implicit": + surface_fields_time = surface_fields[idx_time_start:idx_time_start+self.config.time_points_sample] + surface_fields = surface_fields_time[:, idx_surface] + else: + raise ValueError(f"Invalid transient scheme: {self.config.transient_scheme}") + else: + surface_fields = surface_fields[idx_surface] + + # Subsample the normals and sizes: + if self.config.mesh_type == "element": + if surface_normals is not None: + surface_normals = surface_normals[:, idx_surface] + if surface_sizes is not None: + surface_sizes = surface_sizes[:, idx_surface] + + # Update the coordinates to the sampled points: + surface_coordinates = surface_coordinates[:, idx_surface] + if surface_features is not None: + surface_features = surface_features[:, idx_surface] + + if self.config.transient: + if self.config.transient_scheme == "explicit": + idx_time[:] = 0 + surface_coordinates = surface_coordinates[idx_time] + if surface_features is not None: + surface_features = surface_features[idx_time] + if surface_normals is not None: + surface_normals = surface_normals[idx_time] + if surface_sizes is not None: + surface_sizes = surface_sizes[idx_time] + elif self.config.transient_scheme == "implicit": + surface_coordinates = surface_coordinates[idx_time_start:idx_time_start+self.config.time_points_sample] + if surface_features is not None: + surface_features = surface_features[idx_time_start:idx_time_start+self.config.time_points_sample] + + if surface_normals is not None: + surface_normals = surface_normals[idx_time_start:idx_time_start+self.config.time_points_sample] + + if surface_sizes is not None: + surface_sizes = surface_sizes[idx_time_start:idx_time_start+self.config.time_points_sample] + else: + raise ValueError(f"Invalid transient scheme: {self.config.transient_scheme}") + + ######################################################################## + # Perform a kNN on the surface to find the neighbor information + ######################################################################## + if self.config.num_surface_neighbors > 1: + # Perform the kNN: + neighbor_indices, neighbor_distances = knn( + points=full_surface_coordinates[0], + queries=surface_coordinates[0], + k=self.config.num_surface_neighbors, + ) + # print(f"Full surface coordinates shape: {full_surface_coordinates.shape}") + # Pull out the neighbor elements. + # Note that `neighbor_indices` is the index into the original, + # full sized tensors (full_surface_coordinates, etc). + surface_neighbors = full_surface_coordinates[:, neighbor_indices][:, :, 1:] + if surface_features is not None: + surface_neighbors_features = full_surface_features[:, neighbor_indices][:, :, 1:] + else: + surface_neighbors_features = None + if self.config.transient: + if self.config.transient_scheme == "explicit": + surface_neighbors = surface_neighbors[idx_time] + if surface_features is not None: + surface_neighbors_features = surface_neighbors_features[idx_time] + elif self.config.transient_scheme == "implicit": + surface_neighbors = surface_neighbors[idx_time_start:idx_time_start+self.config.time_points_sample] + if surface_features is not None: + surface_neighbors_features = surface_neighbors_features[idx_time_start:idx_time_start+self.config.time_points_sample] + else: + raise ValueError(f"Invalid transient scheme: {self.config.transient_scheme}") + + timesteps_neighbors = repeat_array(timesteps_sampled, self.config.num_surface_neighbors-1, axis=2, new_axis=True) + + if self.config.mesh_type == "element": + if full_surface_normals is not None: + surface_neighbors_normals = full_surface_normals[:, neighbor_indices][:, :, 1:] + else: + surface_neighbors_normals = None + if full_surface_sizes is not None: + surface_neighbors_sizes = full_surface_sizes[:, neighbor_indices][:, :, 1:] + else: + surface_neighbors_sizes = None + else: + surface_neighbors_normals = None + surface_neighbors_sizes = None + + if self.config.transient: + if self.config.transient_scheme == "explicit": + if surface_neighbors_normals is not None: + surface_neighbors_normals = surface_neighbors_normals[idx_time] + if surface_neighbors_sizes is not None: + surface_neighbors_sizes = surface_neighbors_sizes[idx_time] + elif self.config.transient_scheme == "implicit": + if surface_neighbors_normals is not None: + surface_neighbors_normals = surface_neighbors_normals[idx_time_start:idx_time_start+self.config.time_points_sample] + if surface_neighbors_sizes is not None: + surface_neighbors_sizes = surface_neighbors_sizes[idx_time_start:idx_time_start+self.config.time_points_sample] + else: + surface_neighbors = surface_coordinates + if surface_features is not None: + surface_neighbors_features = surface_features + else: + surface_neighbors_features = None + + # Better to normalize everything after the kNN and sampling + if self.config.normalize_coordinates: + surface_coordinates = normalize(surface_coordinates, s_max, s_min) + surface_neighbors = normalize(surface_neighbors, s_max, s_min) + center_of_mass = normalize(center_of_mass, s_max, s_min) + + pos_normals_com_surface = surface_coordinates - center_of_mass + + if self.config.transient: + surface_coordinates = torch.cat([surface_coordinates, timesteps_sampled], axis=-1) + if self.config.num_surface_neighbors > 1: + surface_neighbors = torch.cat([surface_neighbors, timesteps_neighbors], axis=-1) + + ######################################################################## + # Apply scaling to the targets, if desired: + ######################################################################## + if self.config.scaling_type is not None and surface_fields is not None: + surface_fields = self.scale_model_targets( + surface_fields, self.config.surface_factors + ) + + return_dict.update( + { + "pos_surface_center_of_mass": pos_normals_com_surface, + "surface_mesh_centers": surface_coordinates, + "surface_mesh_neighbors": surface_neighbors, + } + ) + + if surface_normals is not None: + return_dict["surface_normals"] = surface_normals + if surface_sizes is not None: + return_dict["surface_areas"] = surface_sizes + if surface_neighbors_normals is not None: + return_dict["surface_neighbors_normals"] = surface_neighbors_normals + if surface_neighbors_sizes is not None: + return_dict["surface_neighbors_areas"] = surface_neighbors_sizes + + if surface_features is not None: + return_dict["surface_features"] = surface_features + return_dict["surface_neighbors_features"] = surface_neighbors_features + if surface_fields is not None: + return_dict["surface_fields"] = surface_fields + + return return_dict + + def process_volume( + self, + c_min: torch.Tensor, + c_max: torch.Tensor, + volume_coordinates: torch.Tensor, + volume_grid: torch.Tensor, + center_of_mass: torch.Tensor, + stl_vertices: torch.Tensor, + stl_indices: torch.Tensor, + volume_fields: torch.Tensor | None, + volume_features: torch.Tensor | None, + ) -> dict[str, torch.Tensor]: + """ + Preprocess the volume data. + + First, if configured, we reject points not in the volume bounding box. + + Next, if sampling is enabled, we sample the volume points and apply that + sampling to the ground truth too, if it's present. + + """ + ######################################################################## + # Reject points outside the volumetric BBox + ######################################################################## + if self.config.sample_in_bbox: + # Remove points in the volume that are outside + # of the bbox area. + min_check = volume_coordinates[0, :] > c_min + max_check = volume_coordinates[0, :] < c_max + + ids_in_bbox = min_check & max_check + ids_in_bbox = ids_in_bbox.all(dim=1) + + volume_coordinates = volume_coordinates[:, ids_in_bbox] + if volume_fields is not None: + volume_fields = volume_fields[:, ids_in_bbox] + if volume_features is not None: + volume_features = volume_features[:, ids_in_bbox] + ######################################################################## + # Apply sampling to the volume coordinates and fields + ######################################################################## + + # If the volume data has been sampled from disk, directly, then + # still apply sampling. We over-pull from disk deliberately. + if self.config.sampling: + # Generate a series of idx to sample the volume + # without replacement + volume_coordinates_sampled, idx_volume = shuffle_array( + volume_coordinates[0], self.config.volume_points_sample + ) + # volume_coordinates_sampled = volume_coordinates[idx_volume] + # In case too few points are in the sampled data (because the + # inputs were too few), pad the outputs: + if volume_coordinates_sampled.shape[0] < self.config.volume_points_sample: + raise ValueError( + "Volume mesh has fewer points than requested sample size" + ) + + if self.config.transient: + if self.config.transient_scheme == "explicit": + timesteps_sampled, idx_time = shuffle_array(timesteps, self.config.time_points_sample) + timesteps_sampled = repeat_array(timesteps_sampled, self.config.volume_points_sample, axis=1, new_axis=True) + timesteps_sampled = torch.unsqueeze(timesteps_sampled, axis=-1) + elif self.config.transient_scheme == "implicit": + idx_time_start = torch.randint(low=0, high=volume_fields.shape[0]-self.config.time_points_sample, size=(1,)) + timesteps_sampled = timesteps[idx_time_start:idx_time_start+self.config.time_points_sample] + timesteps_sampled = repeat_array(timesteps_sampled, self.config.volume_points_sample, axis=1, new_axis=True) + timesteps_sampled = torch.unsqueeze(timesteps_sampled, axis=-1) + else: + raise ValueError(f"Invalid transient scheme: {self.config.transient_scheme}") + + # Apply the same sampling to the targets, too: + if self.config.transient: + if self.config.transient_scheme == "explicit": + volume_fields_time = volume_fields[idx_time] + volume_fields = volume_fields_time[:, idx_volume] + elif self.config.transient_scheme == "implicit": + volume_fields_time = volume_fields[idx_time_start:idx_time_start+self.config.time_points_sample] + volume_fields = volume_fields_time[:, idx_volume] + else: + raise ValueError(f"Invalid transient scheme: {self.config.transient_scheme}") + else: + volume_fields = volume_fields[:, idx_volume] + + if self.config.transient: + if self.config.transient_scheme == "explicit": + idx_time[:] = 0 + volume_coordinates = volume_coordinates[idx_time] + volume_coordinates = volume_coordinates[:, idx_volume] + if volume_features is not None: + volume_features = volume_features[idx_time] + elif self.config.transient_scheme == "implicit": + volume_coordinates = volume_coordinates[idx_time_start:idx_time_start+self.config.time_points_sample] + volume_coordinates = volume_coordinates[:, idx_volume] + if volume_features is not None: + volume_features = volume_features[idx_time_start:idx_time_start+self.config.time_points_sample] + else: + raise ValueError(f"Invalid transient scheme: {self.config.transient_scheme}") + + ######################################################################## + # Apply normalization to the coordinates, if desired: + ######################################################################## + if self.config.normalize_coordinates: + volume_coordinates = normalize(volume_coordinates, c_max, c_min) + grid = normalize(volume_grid, c_max, c_min) + normed_vertices = normalize(stl_vertices, c_max, c_min) + center_of_mass = normalize(center_of_mass, c_max, c_min) + else: + grid = volume_grid + normed_vertices = stl_vertices + center_of_mass = center_of_mass + + ######################################################################## + # Apply scaling to the targets, if desired: + ######################################################################## + if self.config.scaling_type is not None and volume_fields is not None: + volume_fields = self.scale_model_targets( + volume_fields, self.config.volume_factors + ) + + ######################################################################## + # Compute Signed Distance Function for volumetric quantities + # Note - the SDF happens here, after volume data processing finishes, + # because we need to use the (maybe) normalized volume coordinates and grid + ######################################################################## + + if mesh_indices is not None: + sdf_grid, _ = signed_distance_field( + normed_vertices, + mesh_indices, + grid, + use_sign_winding_number=True, + ) + # Get the SDF of all the selected volume coordinates, + # And keep the closest point to each one. + sdf_nodes, sdf_node_closest_point = signed_distance_field( + normed_vertices, + stl_indices, + volume_coordinates[0], + use_sign_winding_number=True, + ) + sdf_nodes = sdf_nodes.reshape((-1, 1)) + else: + sdf_grid = None + sdf_nodes = None + sdf_node_closest_point = None + + # Use the closest point from the mesh to compute the volume encodings: + pos_normals_closest_vol, pos_normals_com_vol = self.calculate_volume_encoding( + volume_coordinates, sdf_node_closest_point, center_of_mass + ) + + return_dict = { + "volume_mesh_centers": volume_coordinates, + "grid": grid, + "pos_volume_center_of_mass": pos_normals_com_vol, + } + + if sdf_nodes is not None: + return_dict["sdf_nodes"] = sdf_nodes + if sdf_grid is not None: + return_dict["sdf_grid"] = sdf_grid + if pos_normals_closest_vol is not None: + return_dict["pos_volume_closest"] = pos_normals_closest_vol + if volume_features is not None: + return_dict["volume_features"] = volume_features + if volume_fields is not None: + return_dict["volume_fields"] = volume_fields + + return return_dict + + def calculate_volume_encoding( + self, + volume_coordinates: torch.Tensor, + sdf_node_closest_point: torch.Tensor, + center_of_mass: torch.Tensor, + ): + if sdf_node_closest_point is not None: + pos_normals_closest_vol = volume_coordinates - sdf_node_closest_point + else: + pos_normals_closest_vol = None + if center_of_mass is not None: + pos_normals_com_vol = volume_coordinates - center_of_mass + else: + pos_normals_com_vol = None + + return pos_normals_closest_vol, pos_normals_com_vol + + @torch.no_grad() + def process_data(self, data_dict): + return_dict = {} + # Validate that all required keys are present in data_dict + required_keys = [ + "stl_coordinates", + ] + + if "global_params_values" in data_dict: + required_keys.append("global_params_values") + if "global_params_reference" in data_dict: + required_keys.append("global_params_reference") + + if self.config.use_surface_normals: + required_keys.append("stl_faces") + required_keys.append("stl_normals") + required_keys.append("stl_centers") + if self.config.use_surface_area: + required_keys.append("stl_areas") + if self.config.transient: + required_keys.append("timesteps") + + missing_keys = [key for key in required_keys if key not in data_dict] + if missing_keys: + raise ValueError( + f"Missing required keys in data_dict: {missing_keys}. " + f"Required keys are: {required_keys}" + ) + + # Start building the preprocessed return dict: + if "global_params_values" in data_dict and "global_params_reference" in data_dict: + return_dict["global_params_values"] = data_dict["global_params_values"] + return_dict["global_params_reference"] = data_dict["global_params_reference"] + + ######################################################################## + # Process the core STL information + ######################################################################## + + # This function gets information about the surface scale, + # and decides what the surface grid will be: + + stl_coordinates = data_dict["stl_coordinates"] # (N, 3) coordinates, x, y, z + + s_min, s_max, surf_grid = self.compute_stl_scaling_and_surface_grids() + + if isinstance(stl_coordinates, ShardTensor): + mesh = stl_coordinates._spec.mesh + # Then, replicate the bounding box along the mesh if present. + s_max = scatter_tensor( + s_max, + 0, + mesh=mesh, + placements=[ + Replicate(), + ], + global_shape=s_max.shape, + dtype=s_max.dtype, + requires_grad=False, + ) + s_min = scatter_tensor( + s_min, + 0, + mesh=mesh, + placements=[ + Replicate(), + ], + global_shape=s_min.shape, + dtype=s_min.dtype, + requires_grad=False, + ) + surf_grid = scatter_tensor( + surf_grid, + 0, + mesh=mesh, + placements=[ + Replicate(), + ], + global_shape=surf_grid.shape, + dtype=surf_grid.dtype, + requires_grad=False, + ) + + # We always need to calculate the SDF on the surface grid: + # This is for the SDF Later: + if self.config.normalize_coordinates: + normed_vertices = normalize(data_dict["stl_coordinates"], s_max, s_min) + surf_grid = normalize(surf_grid, s_max, s_min) + else: + normed_vertices = data_dict["stl_coordinates"] + + # For SDF calculations, make sure the mesh_indices_flattened is an integer array: + if "stl_faces" in data_dict: + mesh_indices_flattened = data_dict["stl_faces"].to(torch.int32) # Make this optional + # Compute signed distance function for the surface grid: + # Make this optional + sdf_surf_grid, _ = signed_distance_field( + mesh_vertices=normed_vertices, + mesh_indices=mesh_indices_flattened, + input_points=surf_grid, + use_sign_winding_number=True, + ) + return_dict["sdf_surf_grid"] = sdf_surf_grid # Make this optional + else: + sdf_surf_grid = None + mesh_indices_flattened = None + + return_dict["surf_grid"] = surf_grid # Make this optional + + # Store this only if normalization is active: + if self.config.normalize_coordinates: + return_dict["surface_min_max"] = torch.stack([s_min, s_max]) + + # This is a center of mass computation for the stl surface, + # using the size of each mesh point as weight. + if "stl_centers" in data_dict and "stl_areas" in data_dict: + center_of_mass = calculate_center_of_mass( + data_dict["stl_centers"], data_dict["stl_areas"] + ) + else: + center_of_mass = torch.mean(data_dict["stl_coordinates"], dim=0) + + # This will apply downsampling if needed to the geometry coordinates + geom_centers, idx_geometry = self.downsample_geometry( + stl_vertices=data_dict["stl_coordinates"], + ) + return_dict["geometry_coordinates"] = geom_centers + if "stl_features" in data_dict: + return_dict["geometry_features"] = data_dict["stl_features"][idx_geometry] + + ######################################################################## + # Determine the volumetric bounds of the data: + ######################################################################## + # Compute the min/max for volume an the unnomralized grid: + c_min, c_max, volume_grid = self.compute_volume_scaling_and_grids() + + ######################################################################## + # Process the transient data + ######################################################################## + if self.config.transient: + timesteps = data_dict["timesteps"] + t_max = torch.amax(timesteps) + t_min = torch.amin(timesteps) + timesteps = normalize(timesteps, t_max, t_min) + # return_dict["timesteps"] = timesteps + # return_dict["t_max"] = t_max + # return_dict["t_min"] = t_min + else: + timesteps = None + + ######################################################################## + # Process the surface data + ######################################################################## + if self.model_type == "surface" or self.model_type == "combined": + surface_fields_raw = ( + data_dict["surface_fields"] if "surface_fields" in data_dict else None + ) + if "surface_features" in data_dict: + surface_features_raw = data_dict["surface_features"] + else: + surface_features_raw = None + + if "surface_normals" in data_dict: + surface_normals_raw = data_dict["surface_normals"] + else: + surface_normals_raw = None + + if "surface_areas" in data_dict: + surface_sizes_raw = data_dict["surface_areas"] + else: + surface_sizes_raw = None + + surface_dict = self.process_surface( + s_min, + s_max, + c_min, + c_max, + center_of_mass=center_of_mass, + surf_grid=surf_grid, + surface_coordinates=data_dict["surface_mesh_centers"], + surface_normals=surface_normals_raw, + surface_sizes=surface_sizes_raw, + stl_vertices=data_dict["stl_coordinates"], + stl_indices=mesh_indices_flattened, + surface_fields=surface_fields_raw, + timesteps=timesteps, + surface_features=surface_features_raw, + ) + + return_dict.update(surface_dict) + + ######################################################################## + # Process the volume data + ######################################################################## + # For volume data, we store this only if normalizing coordinates: + if self.model_type == "volume" or self.model_type == "combined": + if self.config.normalize_coordinates: + return_dict["volume_min_max"] = torch.stack([c_min, c_max]) + + if self.model_type == "volume" or self.model_type == "combined": + volume_fields_raw = ( + data_dict["volume_fields"] if "volume_fields" in data_dict else None + ) + if "volume_features" in data_dict: + volume_features_raw = data_dict["volume_features"] + else: + volume_features_raw = None + volume_dict = self.process_volume( + c_min, + c_max, + volume_coordinates=data_dict["volume_mesh_centers"], + volume_grid=volume_grid, + center_of_mass=center_of_mass, + stl_vertices=data_dict["stl_coordinates"], + stl_indices=mesh_indices_flattened, + volume_fields=volume_fields_raw, + timesteps=timesteps, + volume_features=volume_features_raw, + ) + + return_dict.update(volume_dict) + + return return_dict + + def scale_model_targets( + self, fields: torch.Tensor, factors: torch.Tensor + ) -> torch.Tensor: + """ + Scale the model targets based on the configured scaling factors. + """ + if self.config.scaling_type == "mean_std_scaling": + field_mean = factors[0] + field_std = factors[1] + return standardize(fields, field_mean, field_std) + elif self.config.scaling_type == "min_max_scaling": + field_min = factors[1] + field_max = factors[0] + return normalize(fields, field_max, field_min) + + def unscale_model_outputs( + self, + volume_fields: torch.Tensor | None = None, + surface_fields: torch.Tensor | None = None, + ): + """ + Unscale the model outputs based on the configured scaling factors. + + The unscaling is included here to make it a consistent interface regardless + of the scaling factors and type used. + + """ + + if volume_fields is not None: + if self.config.scaling_type == "mean_std_scaling": + vol_mean = self.config.volume_factors[0] + vol_std = self.config.volume_factors[1] + volume_fields = unstandardize(volume_fields, vol_mean, vol_std) + elif self.config.scaling_type == "min_max_scaling": + vol_min = self.config.volume_factors[1] + vol_max = self.config.volume_factors[0] + volume_fields = unnormalize(volume_fields, vol_max, vol_min) + if surface_fields is not None: + if self.config.scaling_type == "mean_std_scaling": + surf_mean = self.config.surface_factors[0] + surf_std = self.config.surface_factors[1] + surface_fields = unstandardize(surface_fields, surf_mean, surf_std) + elif self.config.scaling_type == "min_max_scaling": + surf_min = self.config.surface_factors[1] + surf_max = self.config.surface_factors[0] + surface_fields = unnormalize(surface_fields, surf_max, surf_min) + + return volume_fields, surface_fields + + def set_dataset(self, dataset: Iterable) -> None: + """ + Pass a dataset to the datapipe to enable iterating over both in one pass. + """ + self.dataset = dataset + + if self.config.volume_sample_from_disk: + # We deliberately double the data to read compared to the sampling size: + self.dataset.set_volume_sampling_size( + 100 * self.config.volume_points_sample + ) + + def __len__(self): + if self.dataset is not None: + return len(self.dataset) + else: + return 0 + + def __getitem__(self, idx): + """ + Function for fetching and processing a single file's data. + + Domino, in general, expects one example per file and the files + are relatively large due to the mesh size. + + Requires the user to have set a dataset via `set_dataset`. + """ + if self.dataset is None: + raise ValueError("Dataset is not present") + + # Get the data from the dataset. + # Under the hood, this may be fetching preloaded data. + data_dict = self.dataset[idx] + + return self.__call__(data_dict) + + def __call__(self, data_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Process the incoming data dictionary. + - Processes the data + - moves it to GPU + - adds a batch dimension + + Args: + data_dict: Dictionary containing the data to process as torch.Tensors. + + Returns: + Dictionary containing the processed data as torch.Tensors. + + """ + data_dict = self.process_data(data_dict) + + # If the data is not on the target device, put it there: + for key, value in data_dict.items(): + if value.device != self.output_device: + data_dict[key] = value.to(self.output_device) + + # Add a batch dimension to the data_dict + data_dict = {k: v.unsqueeze(0) for k, v in data_dict.items()} + + return data_dict + + def __iter__(self): + if self.dataset is None: + raise ValueError( + "Dataset is not present, can not use the datapipe as an iterator." + ) + + for i, batch in enumerate(self.dataset): + yield self.__call__(batch) + + +def compute_scaling_factors( + cfg: DictConfig, + input_path: str, + target_keys: list[str], + max_samples=20, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Using the dataset at the path, compute the mean, std, min, and max of the target keys. + + Args: + cfg: Hydra configuration object containing all parameters + input_path: Path to the dataset to load. + target_keys: List of keys to compute the mean, std, min, and max of. + use_cache: (deprecated) This argument has no effect. + """ + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + dataset = CAEDataset( + data_dir=input_path, + keys_to_read=target_keys, + keys_to_read_if_available={}, + output_device=device, + ) + + mean, std, min_val, max_val = compute_mean_std_min_max( + dataset, + field_keys=target_keys, + max_samples=max_samples, + transient=cfg.model.transient, + ) + + return mean, std, min_val, max_val + + +class CachedDoMINODataset(Dataset): + """ + Dataset for reading cached DoMINO data files, with optional resampling. + Acts as a drop-in replacement for DoMINODataPipe. + """ + + # @nvtx_annotate(message="CachedDoMINODataset __init__") + def __init__( + self, + data_path: Union[str, Path], + phase: Literal["train", "val", "test"] = "train", + sampling: bool = False, + volume_points_sample: Optional[int] = None, + surface_points_sample: Optional[int] = None, + geom_points_sample: Optional[int] = None, + model_type=None, # Model_type, surface, volume or combined + deterministic_seed=False, + surface_sampling_algorithm="area_weighted", + ): + super().__init__() + + self.model_type = model_type + if deterministic_seed: + np.random.seed(42) + + if isinstance(data_path, str): + data_path = Path(data_path) + self.data_path = data_path.expanduser() + + if not self.data_path.exists(): + raise AssertionError(f"Path {self.data_path} does not exist") + if not self.data_path.is_dir(): + raise AssertionError(f"Path {self.data_path} is not a directory") + + self.deterministic_seed = deterministic_seed + self.sampling = sampling + self.volume_points = volume_points_sample + self.surface_points = surface_points_sample + self.geom_points = geom_points_sample + self.surface_sampling_algorithm = surface_sampling_algorithm + + self.filenames = get_filenames(self.data_path, exclude_dirs=True) + + total_files = len(self.filenames) + + self.phase = phase + self.indices = np.array(range(total_files)) + + np.random.shuffle(self.indices) + + if not self.filenames: + raise AssertionError(f"No cached files found in {self.data_path}") + + def __len__(self): + return len(self.indices) + + # @nvtx_annotate(message="CachedDoMINODataset __getitem__") + def __getitem__(self, idx): + if self.deterministic_seed: + np.random.seed(idx) + nvtx.range_push("Load cached file") + + index = self.indices[idx] + cfd_filename = self.filenames[index] + + filepath = self.data_path / cfd_filename + result = np.load(filepath, allow_pickle=True).item() + result = { + k: torch.from_numpy(v) if isinstance(v, np.ndarray) else v + for k, v in result.items() + } + + nvtx.range_pop() + if not self.sampling: + return result + + nvtx.range_push("Sample points") + + # Sample volume points if present + if "volume_mesh_centers" in result and self.volume_points: + coords_sampled, idx_volume = shuffle_array( + result["volume_mesh_centers"], self.volume_points + ) + if coords_sampled.shape[0] < self.volume_points: + coords_sampled = pad( + coords_sampled, self.volume_points, pad_value=-10.0 + ) + + result["volume_mesh_centers"] = coords_sampled + for key in [ + "volume_fields", + "pos_volume_closest", + "pos_volume_center_of_mass", + "sdf_nodes", + ]: + if key in result: + result[key] = result[key][idx_volume] + + # Sample surface points if present + if "surface_mesh_centers" in result and self.surface_points: + if self.surface_sampling_algorithm == "area_weighted": + coords_sampled, idx_surface = shuffle_array( + points=result["surface_mesh_centers"], + n_points=self.surface_points, + weights=result["surface_areas"], + ) + else: + coords_sampled, idx_surface = shuffle_array( + result["surface_mesh_centers"], self.surface_points + ) + + if coords_sampled.shape[0] < self.surface_points: + coords_sampled = pad( + coords_sampled, self.surface_points, pad_value=-10.0 + ) + + ii = result["neighbor_indices"] + result["surface_mesh_neighbors"] = result["surface_mesh_centers"][ii] + result["surface_neighbors_normals"] = result["surface_normals"][ii] + result["surface_neighbors_areas"] = result["surface_areas"][ii] + + result["surface_mesh_centers"] = coords_sampled + + for key in [ + "surface_fields", + "surface_areas", + "surface_normals", + "pos_surface_center_of_mass", + "surface_mesh_neighbors", + "surface_neighbors_normals", + "surface_neighbors_areas", + ]: + if key in result: + result[key] = result[key][idx_surface] + + del result["neighbor_indices"] + + # Sample geometry points if present + if "geometry_coordinates" in result and self.geom_points: + coords_sampled, _ = shuffle_array( + result["geometry_coordinates"], self.geom_points + ) + if coords_sampled.shape[0] < self.geom_points: + coords_sampled = pad(coords_sampled, self.geom_points, pad_value=-100.0) + result["geometry_coordinates"] = coords_sampled + + nvtx.range_pop() + return result + + +def create_domino_dataset( + cfg: DictConfig, + phase: Literal["train", "val", "test"], + keys_to_read: list[str], + keys_to_read_if_available: dict[str, torch.Tensor], + vol_factors: list[float], + surf_factors: list[float], + normalize_coordinates: bool = True, + sample_in_bbox: bool = True, + sampling: bool = True, + device_mesh: torch.distributed.DeviceMesh | None = None, + placements: dict[str, torch.distributed.tensor.Placement] | None = None, +): + model_type = cfg.model.model_type + if phase == "train": + input_path = cfg.data.input_dir + dataloader_cfg = cfg.train.dataloader + elif phase == "val": + input_path = cfg.data.input_dir_val + dataloader_cfg = cfg.val.dataloader + elif phase == "test": + input_path = cfg.eval.test_path + dataloader_cfg = None + else: + raise ValueError(f"Invalid phase {phase}") + + if cfg.data_processor.use_cache: + return CachedDoMINODataset( + input_path, + phase=phase, + sampling=sampling, + volume_points_sample=cfg.model.volume_points_sample, + surface_points_sample=cfg.model.surface_points_sample, + geom_points_sample=cfg.model.geom_points_sample, + model_type=cfg.model.model_type, + surface_sampling_algorithm=cfg.model.surface_sampling_algorithm, + ) + else: + # The dataset path works in two pieces: + # There is a core "dataset" which is loading data and moving to GPU + # And there is the preprocess step, here. + + # Optionally, and for backwards compatibility, the preprocess + # object can accept a dataset which will enable it as an iterator. + # The iteration function will loop over the dataset, preprocess the + # output, and return it. + + overrides = {} + if hasattr(cfg.data, "gpu_preprocessing"): + overrides["gpu_preprocessing"] = cfg.data.gpu_preprocessing + + if hasattr(cfg.data, "gpu_output"): + overrides["gpu_output"] = cfg.data.gpu_output + + dm = DistributedManager() + + if cfg.data.gpu_preprocessing: + device = dm.device + consumer_stream = torch.cuda.default_stream() + else: + device = torch.device("cpu") + consumer_stream = None + + if dataloader_cfg is not None: + preload_depth = dataloader_cfg.preload_depth + pin_memory = dataloader_cfg.pin_memory + else: + preload_depth = 1 + pin_memory = False + + dataset = CAEDataset( + data_dir=input_path, + keys_to_read=keys_to_read, + keys_to_read_if_available=keys_to_read_if_available, + output_device=device, + preload_depth=preload_depth, + pin_memory=pin_memory, + device_mesh=device_mesh, + placements=placements, + consumer_stream=consumer_stream, + ) + + datapipe = DoMINODataPipe( + input_path, + phase=phase, + mesh_type=cfg.model.mesh_type, + transient=cfg.model.transient, + use_surface_normals=cfg.model.use_surface_normals, + use_surface_area=cfg.model.use_surface_area, + grid_resolution=cfg.model.interp_res, + normalize_coordinates=normalize_coordinates, + sampling=sampling, + sample_in_bbox=sample_in_bbox, + volume_points_sample=cfg.model.volume_points_sample, + surface_points_sample=cfg.model.surface_points_sample, + geom_points_sample=cfg.model.geom_points_sample, + volume_factors=vol_factors, + surface_factors=surf_factors, + scaling_type=cfg.model.normalization, + model_type=model_type, + bounding_box_dims=cfg.data.bounding_box, + bounding_box_dims_surf=cfg.data.bounding_box_surface, + volume_sample_from_disk=cfg.data.volume_sample_from_disk, + num_surface_neighbors=cfg.model.num_neighbors_surface, + surface_sampling_algorithm=cfg.model.surface_sampling_algorithm, + transient_scheme=cfg.model.transient_scheme, + **overrides, + ) + + datapipe.set_dataset(dataset) + + return datapipe + + +if __name__ == "__main__": + fm_data = DoMINODataPipe( + data_path="/code/processed_data/new_models_1/", + phase="train", + sampling=False, + sample_in_bbox=False, + ) diff --git a/physicsnemo/models/domino_transient/__init__.py b/physicsnemo/models/domino_transient/__init__.py new file mode 100644 index 0000000000..e64c3ec5da --- /dev/null +++ b/physicsnemo/models/domino_transient/__init__.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .model import DoMINO diff --git a/physicsnemo/models/domino_transient/encodings.py b/physicsnemo/models/domino_transient/encodings.py new file mode 100644 index 0000000000..7b27eeb134 --- /dev/null +++ b/physicsnemo/models/domino_transient/encodings.py @@ -0,0 +1,172 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code contains the DoMINO model architecture. +The DoMINO class contains an architecture to model both surface and +volume quantities together as well as separately (controlled using +the config.yaml file) +""" + +import torch +import torch.nn as nn +from einops import rearrange + +from physicsnemo.models.layers import BQWarp + +from .mlps import LocalPointConv + + +class LocalGeometryEncoding(nn.Module): + """ + A local geometry encoding module. + + This will apply a ball query to the input features, mapping the point cloud + to the volume mesh, and then apply a local point convolution to the output. + + Args: + radius: The radius of the ball query. + neighbors_in_radius: The number of neighbors in the radius of the ball query. + total_neighbors_in_radius: The total number of neighbors in the radius of the ball query. + base_layer: The number of neurons in the hidden layer of the MLP. + activation: The activation function to use in the MLP. + grid_resolution: The resolution of the grid. + """ + + def __init__( + self, + radius: float, + neighbors_in_radius: int, + total_neighbors_in_radius: int, + base_layer: int, + activation: nn.Module, + grid_resolution: tuple[int, int, int], + ): + super().__init__() + self.bq_warp = BQWarp( + radius=radius, + neighbors_in_radius=neighbors_in_radius, + ) + self.local_point_conv = LocalPointConv( + input_features=total_neighbors_in_radius, + base_layer=base_layer, + output_features=neighbors_in_radius, + activation=activation, + ) + self.grid_resolution = grid_resolution + + def forward( + self, + encoding_g: torch.Tensor, + volume_mesh_centers: torch.Tensor, + p_grid: torch.Tensor, + ) -> torch.Tensor: + batch_size = volume_mesh_centers.shape[0] + nx, ny, nz = self.grid_resolution + + p_grid = torch.reshape(p_grid, (batch_size, nx * ny * nz, 3)) + mapping, outputs = self.bq_warp( + volume_mesh_centers, p_grid, reverse_mapping=False + ) + + mapping = mapping.type(torch.int64) + mask = mapping != 0 + + encoding_g_inner = [] + for j in range(encoding_g.shape[1]): + geo_encoding = rearrange(encoding_g[:, j], "b nx ny nz -> b 1 (nx ny nz)") + + geo_encoding_sampled = torch.index_select( + geo_encoding, 2, mapping.flatten() + ) + geo_encoding_sampled = torch.reshape(geo_encoding_sampled, mask.shape) + geo_encoding_sampled = geo_encoding_sampled * mask + + encoding_g_inner.append(geo_encoding_sampled) + encoding_g_inner = torch.cat(encoding_g_inner, dim=2) + encoding_g_inner = self.local_point_conv(encoding_g_inner) + + return encoding_g_inner + + +class MultiGeometryEncoding(nn.Module): + """ + Module to apply multiple local geometry encodings + + This will stack several local geometry encodings together, and concatenate the results. + + Args: + radii: The list of radii of the local geometry encodings. + neighbors_in_radius: The list of number of neighbors in the radius of the local geometry encodings. + geo_encoding_type: The type of geometry encoding to use. Can be "both", "stl", or "sdf". + base_layer: The number of neurons in the hidden layer of the MLP. + activation: The activation function to use in the MLP. + grid_resolution: The resolution of the grid. + """ + + def __init__( + self, + radii: list[float], + neighbors_in_radius: list[int], + geo_encoding_type: str, + n_upstream_radii: int, + base_layer: int, + activation: nn.Module, + grid_resolution: tuple[int, int, int], + ): + super().__init__() + + self.local_geo_encodings = nn.ModuleList( + [ + LocalGeometryEncoding( + radius=r, + neighbors_in_radius=n, + total_neighbors_in_radius=self.calculate_total_neighbors_in_radius( + geo_encoding_type, n, n_upstream_radii + ), + base_layer=base_layer, + activation=activation, + grid_resolution=grid_resolution, + ) + for r, n in zip(radii, neighbors_in_radius) + ] + ) + + def calculate_total_neighbors_in_radius( + self, geo_encoding_type: str, neighbors_in_radius: int, n_upstream_radii: int + ) -> int: + if geo_encoding_type == "both": + total_neighbors_in_radius = neighbors_in_radius * (n_upstream_radii + 1) + elif geo_encoding_type == "stl": + total_neighbors_in_radius = neighbors_in_radius * (n_upstream_radii) + elif geo_encoding_type == "sdf": + total_neighbors_in_radius = neighbors_in_radius + + return total_neighbors_in_radius + + def forward( + self, + encoding_g: torch.Tensor, + volume_mesh_centers: torch.Tensor, + p_grid: torch.Tensor, + ) -> torch.Tensor: + return torch.cat( + [ + local_geo_encoding(encoding_g, volume_mesh_centers, p_grid) + for local_geo_encoding in self.local_geo_encodings + ], + dim=-1, + ) diff --git a/physicsnemo/models/domino_transient/geometry_rep.py b/physicsnemo/models/domino_transient/geometry_rep.py new file mode 100644 index 0000000000..31020faa0a --- /dev/null +++ b/physicsnemo/models/domino_transient/geometry_rep.py @@ -0,0 +1,516 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from physicsnemo.models.layers import BQWarp, Mlp, fourier_encode, get_activation +from physicsnemo.models.unet import UNet + +# from .encodings import fourier_encode + + +def scale_sdf(sdf: torch.Tensor, scaling_factor: float = 0.04) -> torch.Tensor: + """ + Scale a signed distance function (SDF) to emphasize surface regions. + + This function applies a non-linear scaling to the SDF values that compresses + the range while preserving the sign, effectively giving more weight to points + near surfaces where abs(SDF) is small. + + Args: + sdf: Tensor containing signed distance function values + + Returns: + Tensor with scaled SDF values in range [-1, 1] + """ + return sdf / (scaling_factor + torch.abs(sdf)) + + +class GeoConvOut(nn.Module): + """ + Geometry layer to project STL geometry data onto regular grids. + """ + + def __init__( + self, + input_features: int, + neighbors_in_radius: int, + model_parameters, + grid_resolution=None, + nodal_geometry_features: int = 0, + ): + """ + Initialize the GeoConvOut layer. + + Args: + input_features: Number of input feature dimensions + neighbors_in_radius: Number of neighbors in radius + model_parameters: Configuration parameters for the model + grid_resolution: Resolution of the output grid [nx, ny, nz] + """ + super().__init__() + if grid_resolution is None: + grid_resolution = [256, 96, 64] + base_neurons = model_parameters.base_neurons + self.fourier_features = model_parameters.fourier_features + self.num_modes = model_parameters.num_modes + self.nodal_geometry_features = nodal_geometry_features + input_features = input_features + self.nodal_geometry_features + if self.fourier_features: + input_features_calculated = ( + input_features * (1 + 2 * self.num_modes) * neighbors_in_radius + ) + else: + input_features_calculated = input_features * neighbors_in_radius + + self.mlp = Mlp( + in_features=input_features_calculated, + hidden_features=[base_neurons, base_neurons // 2], + out_features=model_parameters.base_neurons_in, + act_layer=get_activation(model_parameters.activation), + drop=0.0, + ) + + self.grid_resolution = grid_resolution + + self.activation = get_activation(model_parameters.activation) + + self.neighbors_in_radius = neighbors_in_radius + + if self.fourier_features: + self.register_buffer( + "freqs", torch.exp(torch.linspace(0, math.pi, self.num_modes)) + ) + + def forward( + self, + x: torch.Tensor, + grid: torch.Tensor, + radius: float = 0.025, + neighbors_in_radius: int = 10, + geometry_features: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Process and project geometric features onto a 3D grid. + + Args: + x: Input tensor containing coordinates of the neighboring points + (batch_size, nx*ny*nz, n_points, 3) + grid: Input tensor represented as a grid of shape + (batch_size, nx, ny, nz, 3) + geometry_features: Geometry features tensor + (batch_size, nx*ny*nz, n_points, n_features) + Returns: + Processed geometry features of shape (batch_size, base_neurons_in, nx, ny, nz) + """ + + nx, ny, nz = ( + self.grid_resolution[0], + self.grid_resolution[1], + self.grid_resolution[2], + ) + grid = grid.reshape(1, nx * ny * nz, 3, 1) + + if self.nodal_geometry_features > 0: + x = torch.cat((x, geometry_features), axis=-1) + + x = rearrange( + x, "b x y z -> b x (y z)", x=nx * ny * nz, y=self.neighbors_in_radius, z=3+self.nodal_geometry_features + ) + + if self.fourier_features: + facets = torch.cat((x, fourier_encode(x, self.freqs)), axis=-1) + else: + facets = x + + x = F.tanh(self.mlp(facets)) + + x = rearrange(x, "b (x y z) c -> b c x y z", x=nx, y=ny, z=nz) + + return x + + +class GeoProcessor(nn.Module): + """Geometry processing layer using CNNs""" + + def __init__(self, input_filters: int, output_filters: int, model_parameters): + """ + Initialize the GeoProcessor network. + + Args: + input_filters: Number of input channels + model_parameters: Configuration parameters for the model + """ + super().__init__() + base_filters = model_parameters.base_filters + self.conv1 = nn.Conv3d( + input_filters, base_filters, kernel_size=3, padding="same" + ) + self.conv2 = nn.Conv3d( + base_filters, 2 * base_filters, kernel_size=3, padding="same" + ) + self.conv3 = nn.Conv3d( + 2 * base_filters, 4 * base_filters, kernel_size=3, padding="same" + ) + self.conv3_1 = nn.Conv3d( + 4 * base_filters, 4 * base_filters, kernel_size=3, padding="same" + ) + self.conv4 = nn.Conv3d( + 4 * base_filters, 2 * base_filters, kernel_size=3, padding="same" + ) + self.conv5 = nn.Conv3d( + 4 * base_filters, base_filters, kernel_size=3, padding="same" + ) + self.conv6 = nn.Conv3d( + 2 * base_filters, input_filters, kernel_size=3, padding="same" + ) + self.conv7 = nn.Conv3d( + 2 * input_filters, input_filters, kernel_size=3, padding="same" + ) + self.conv8 = nn.Conv3d( + input_filters, output_filters, kernel_size=3, padding="same" + ) + self.avg_pool = torch.nn.AvgPool3d((2, 2, 2)) + self.max_pool = nn.MaxPool3d(2) + self.upsample = nn.Upsample(scale_factor=2, mode="nearest") + self.activation = get_activation(model_parameters.activation) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Process geometry information through the 3D CNN network. + + The network follows an encoder-decoder architecture with skip connections: + 1. Downsampling path (encoder) with three levels of max pooling + 2. Processing loop in the bottleneck + 3. Upsampling path (decoder) with skip connections from the encoder + + Args: + x: Input tensor containing grid-represented geometry of shape + (batch_size, input_filters, nx, ny, nz) + + Returns: + Processed geometry features of shape (batch_size, 1, nx, ny, nz) + """ + # Encoder + x0 = x + x = self.conv1(x) + x = self.activation(x) + x = self.max_pool(x) + + x1 = x + x = self.conv2(x) + x = self.activation(x) + x = self.max_pool(x) + + x2 = x + x = self.conv3(x) + x = self.activation(x) + x = self.max_pool(x) + + # Processor loop + x = self.activation(self.conv3_1(x)) + + # Decoder + x = self.conv4(x) + x = self.activation(x) + x = self.upsample(x) + x = torch.cat((x, x2), dim=1) + + x = self.conv5(x) + x = self.activation(x) + x = self.upsample(x) + x = torch.cat((x, x1), dim=1) + + x = self.conv6(x) + x = self.activation(x) + x = self.upsample(x) + x = torch.cat((x, x0), dim=1) + + x = self.activation(self.conv7(x)) + x = self.conv8(x) + + return x + + +class GeometryRep(nn.Module): + """ + Geometry representation module that processes STL geometry data. + + This module constructs a multiscale representation of geometry by: + 1. Computing multi-scale geometry encoding for local and global context + 2. Processing signed distance field (SDF) data for surface information + + The combined encoding enables the model to reason about both local and global + geometric properties. + """ + + def __init__( + self, + input_features: int, + radii: Sequence[float], + neighbors_in_radius, + hops=1, + sdf_scaling_factor: Sequence[float] = [0.04], + model_parameters=None, + nodal_geometry_features: int = 0, + # activation_conv: nn.Module, + # activation_processor: nn.Module, + ): + """ + Initialize the GeometryRep module. + + Args: + input_features: Number of input feature dimensions + model_parameters: Configuration parameters for the model + """ + super().__init__() + geometry_rep = model_parameters.geometry_rep + self.geo_encoding_type = model_parameters.geometry_encoding_type + self.cross_attention = geometry_rep.geo_processor.cross_attention + self.self_attention = geometry_rep.geo_processor.self_attention + self.activation_conv = get_activation(geometry_rep.geo_conv.activation) + self.activation_processor = geometry_rep.geo_processor.activation + self.sdf_scaling_factor = sdf_scaling_factor + + self.bq_warp = nn.ModuleList() + self.geo_processors = nn.ModuleList() + for j in range(len(radii)): + self.bq_warp.append( + BQWarp( + radius=radii[j], + neighbors_in_radius=neighbors_in_radius[j], + ) + ) + if geometry_rep.geo_processor.processor_type == "unet": + h = geometry_rep.geo_processor.base_filters + if self.self_attention: + normalization_in_unet = "layernorm" + else: + normalization_in_unet = None + self.geo_processors.append( + UNet( + in_channels=geometry_rep.geo_conv.base_neurons_in, + out_channels=geometry_rep.geo_conv.base_neurons_out, + model_depth=3, + feature_map_channels=[ + h, + 2 * h, + 4 * h, + ], + num_conv_blocks=1, + kernel_size=3, + stride=1, + conv_activation=self.activation_processor, + padding=1, + padding_mode="zeros", + pooling_type="MaxPool3d", + pool_size=2, + normalization=normalization_in_unet, + use_attn_gate=self.self_attention, + attn_decoder_feature_maps=[4 * h, 2 * h], + attn_feature_map_channels=[2 * h, h], + attn_intermediate_channels=4 * h, + gradient_checkpointing=True, + ) + ) + elif geometry_rep.geo_processor.processor_type == "conv": + self.geo_processors.append( + nn.Sequential( + GeoProcessor( + input_filters=geometry_rep.geo_conv.base_neurons_in, + output_filters=geometry_rep.geo_conv.base_neurons_out, + model_parameters=geometry_rep.geo_processor, + ), + ) + ) + else: + raise ValueError("Invalid prompt. Specify unet or conv ...") + + self.geo_conv_out = nn.ModuleList() + self.geo_processor_out = nn.ModuleList() + for u in range(len(radii)): + self.geo_conv_out.append( + GeoConvOut( + input_features=input_features, + neighbors_in_radius=neighbors_in_radius[u], + model_parameters=geometry_rep.geo_conv, + grid_resolution=model_parameters.interp_res, + nodal_geometry_features=nodal_geometry_features, + ) + ) + self.geo_processor_out.append( + nn.Conv3d( + geometry_rep.geo_conv.base_neurons_out, + 1, + kernel_size=3, + padding="same", + ) + ) + + if geometry_rep.geo_processor.processor_type == "unet": + h = geometry_rep.geo_processor.base_filters + if self.self_attention: + normalization_in_unet = "layernorm" + else: + normalization_in_unet = None + + self.geo_processor_sdf = UNet( + in_channels=5 + len(self.sdf_scaling_factor), + out_channels=geometry_rep.geo_conv.base_neurons_out, + model_depth=3, + feature_map_channels=[ + h, + 2 * h, + 4 * h, + ], + num_conv_blocks=1, + kernel_size=3, + stride=1, + conv_activation=self.activation_processor, + padding=1, + padding_mode="zeros", + pooling_type="MaxPool3d", + pool_size=2, + normalization=normalization_in_unet, + use_attn_gate=self.self_attention, + attn_decoder_feature_maps=[4 * h, 2 * h], + attn_feature_map_channels=[2 * h, h], + attn_intermediate_channels=4 * h, + gradient_checkpointing=True, + ) + elif geometry_rep.geo_processor.processor_type == "conv": + self.geo_processor_sdf = nn.Sequential( + GeoProcessor( + input_filters=5 + len(self.sdf_scaling_factor), + output_filters=geometry_rep.geo_conv.base_neurons_out, + model_parameters=geometry_rep.geo_processor, + ), + ) + else: + raise ValueError("Invalid prompt. Specify unet or conv ...") + self.radii = radii + self.neighbors_in_radius = neighbors_in_radius + self.hops = hops + + self.geo_processor_sdf_out = nn.Conv3d( + geometry_rep.geo_conv.base_neurons_out, 1, kernel_size=3, padding="same" + ) + + if self.cross_attention: + h = geometry_rep.geo_processor.base_filters + self.combined_unet = UNet( + in_channels=1 + len(radii), + out_channels=1 + len(radii), + model_depth=3, + feature_map_channels=[ + h, + 2 * h, + 4 * h, + ], + num_conv_blocks=1, + kernel_size=3, + stride=1, + conv_activation=self.activation_processor, + padding=1, + padding_mode="zeros", + pooling_type="MaxPool3d", + pool_size=2, + normalization="layernorm", + use_attn_gate=True, + attn_decoder_feature_maps=[4 * h, 2 * h], + attn_feature_map_channels=[2 * h, h], + attn_intermediate_channels=4 * h, + gradient_checkpointing=True, + ) + + def forward( + self, x: torch.Tensor, p_grid: torch.Tensor, sdf: torch.Tensor, geometry_features: torch.Tensor | None = None + ) -> torch.Tensor: + """ + Process geometry data to create a comprehensive representation. + + This method combines short-range, long-range, and SDF-based geometry + encodings to create a rich representation of the geometry. + + Args: + x: Input tensor containing geometric point data + p_grid: Grid points for sampling + sdf: Signed distance field tensor + geometry_features: Geometry features tensor + Returns: + Comprehensive geometry encoding that concatenates short-range, + SDF-based, and long-range features + """ + if self.geo_encoding_type == "both" or self.geo_encoding_type == "stl": + # Calculate multi-scale geoemtry dependency + x_encoding = [] + for j in range(len(self.radii)): + + mapping, k_short = self.bq_warp[j](x, p_grid) + if geometry_features is not None: + geometry_features_calculated = torch.unsqueeze(geometry_features[0, mapping[0]], 0) + else: + geometry_features_calculated = None + x_encoding_inter = self.geo_conv_out[j](k_short, p_grid, geometry_features=geometry_features_calculated) + + # Propagate information in the geometry enclosed BBox + for _ in range(self.hops): + dx = self.geo_processors[j](x_encoding_inter) / self.hops + x_encoding_inter = x_encoding_inter + dx + x_encoding_inter = self.geo_processor_out[j](x_encoding_inter) + x_encoding.append(x_encoding_inter) + x_encoding = torch.cat(x_encoding, dim=1) + + if self.geo_encoding_type == "both" or self.geo_encoding_type == "sdf": + # Expand SDF + sdf = torch.unsqueeze(sdf, 1) + # Binary sdf + binary_sdf = torch.where(sdf >= 0, 0.0, 1.0) + # Gradients of SDF + sdf_x, sdf_y, sdf_z = torch.gradient(sdf, dim=[2, 3, 4]) + + scaled_sdf = [] + # Scaled sdf to emphasize near surface + for s in range(len(self.sdf_scaling_factor)): + s_sdf = scale_sdf(sdf, self.sdf_scaling_factor[s]) + scaled_sdf.append(s_sdf) + + scaled_sdf = torch.cat(scaled_sdf, dim=1) + + # Process SDF and its computed features + sdf = torch.cat((sdf, scaled_sdf, binary_sdf, sdf_x, sdf_y, sdf_z), 1) + + sdf_encoding = self.geo_processor_sdf(sdf) + sdf_encoding = self.geo_processor_sdf_out(sdf_encoding) + + if self.geo_encoding_type == "both": + # Geometry encoding comprised of short-range, long-range and SDF features + encoding_g = torch.cat((x_encoding, sdf_encoding), 1) + elif self.geo_encoding_type == "sdf": + encoding_g = sdf_encoding + elif self.geo_encoding_type == "stl": + encoding_g = x_encoding + + if self.cross_attention: + encoding_g = self.combined_unet(encoding_g) + + return encoding_g diff --git a/physicsnemo/models/domino_transient/mlps.py b/physicsnemo/models/domino_transient/mlps.py new file mode 100644 index 0000000000..7223990df1 --- /dev/null +++ b/physicsnemo/models/domino_transient/mlps.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file contains specific MLPs for the DoMINO model. + +The main feature here is we've locked in the number of layers. +""" + +import torch.nn as nn + +from physicsnemo.models.layers import Mlp + + +class AggregationModel(Mlp): + """ + Neural network module to aggregate local geometry encoding with basis functions. + + This module combines basis function representations with geometry encodings + to predict the final output quantities. It serves as the final prediction layer + that integrates all available information sources. + + It is implemented as a straightforward MLP with 6 total layers. + + """ + + def __init__( + self, + input_features: int, + output_features: int, + base_layer: int, + activation: nn.Module, + ): + hidden_features = [base_layer, base_layer, base_layer, base_layer] + + super().__init__( + in_features=input_features, + hidden_features=hidden_features, + out_features=output_features, + act_layer=activation, + drop=0.0, + ) + + +class LocalPointConv(Mlp): + """Layer for local geometry point kernel + + This is a straight forward MLP, with exactly two layers. + """ + + def __init__( + self, + input_features: int, + base_layer: int, + output_features: int, + activation: nn.Module, + ): + super().__init__( + in_features=input_features, + hidden_features=base_layer, + out_features=output_features, + act_layer=activation, + drop=0.0, + ) diff --git a/physicsnemo/models/domino_transient/model.py b/physicsnemo/models/domino_transient/model.py new file mode 100644 index 0000000000..fc06cdbfc9 --- /dev/null +++ b/physicsnemo/models/domino_transient/model.py @@ -0,0 +1,1018 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code contains the DoMINO model architecture. +The DoMINO class contains an architecture to model both surface and +volume quantities together as well as separately (controlled using +the config.yaml file) +""" + +import torch +import torch.nn as nn + +from physicsnemo.models.layers import FourierMLP, get_activation +from physicsnemo.models.unet import UNet + +from .encodings import ( + MultiGeometryEncoding, +) +from .geometry_rep import GeometryRep, scale_sdf +from .mlps import AggregationModel +from .solutions import SolutionCalculatorSurface, SolutionCalculatorVolume + +# @dataclass +# class MetaData(ModelMetaData): +# name: str = "DoMINO" +# # Optimization +# jit: bool = False +# cuda_graphs: bool = True +# amp: bool = True +# # Inference +# onnx_cpu: bool = True +# onnx_gpu: bool = True +# onnx_runtime: bool = True +# # Physics informed +# var_dim: int = 1 +# func_torch: bool = False +# auto_grad: bool = False + + +class DoMINO(nn.Module): + """ + DoMINO model architecture for predicting both surface and volume quantities. + + The DoMINO (Deep Operational Modal Identification and Nonlinear Optimization) model + is designed to model both surface and volume physical quantities in aerodynamic + simulations. It can operate in three modes: + 1. Surface-only: Predicting only surface quantities + 2. Volume-only: Predicting only volume quantities + 3. Combined: Predicting both surface and volume quantities + + The model uses a combination of: + - Geometry representation modules + - Neural network basis functions + - Parameter encoding + - Local and global geometry processing + - Aggregation models for final prediction + + Parameters + ---------- + input_features : int + Number of point input features + output_features_vol : int, optional + Number of output features in volume + output_features_surf : int, optional + Number of output features on surface + model_parameters + Model parameters controlled by config.yaml + + Example + ------- + >>> from physicsnemo.models.domino.model import DoMINO + >>> import torch, os + >>> from hydra import compose, initialize + >>> from omegaconf import OmegaConf + >>> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + >>> cfg = OmegaConf.register_new_resolver("eval", eval) + >>> with initialize(version_base="1.3", config_path="examples/cfd/external_aerodynamics/domino/src/conf"): + ... cfg = compose(config_name="config") + >>> cfg.model.model_type = "combined" + >>> model = DoMINO( + ... input_features=3, + ... output_features_vol=5, + ... output_features_surf=4, + ... model_parameters=cfg.model + ... ).to(device) + + Warp ... + >>> bsize = 1 + >>> nx, ny, nz = cfg.model.interp_res + >>> num_neigh = 7 + >>> global_features = 2 + >>> pos_normals_closest_vol = torch.randn(bsize, 100, 3).to(device) + >>> pos_normals_com_vol = torch.randn(bsize, 100, 3).to(device) + >>> pos_normals_com_surface = torch.randn(bsize, 100, 3).to(device) + >>> geom_centers = torch.randn(bsize, 100, 3).to(device) + >>> grid = torch.randn(bsize, nx, ny, nz, 3).to(device) + >>> surf_grid = torch.randn(bsize, nx, ny, nz, 3).to(device) + >>> sdf_grid = torch.randn(bsize, nx, ny, nz).to(device) + >>> sdf_surf_grid = torch.randn(bsize, nx, ny, nz).to(device) + >>> sdf_nodes = torch.randn(bsize, 100, 1).to(device) + >>> surface_coordinates = torch.randn(bsize, 100, 3).to(device) + >>> surface_neighbors = torch.randn(bsize, 100, num_neigh, 3).to(device) + >>> surface_normals = torch.randn(bsize, 100, 3).to(device) + >>> surface_neighbors_normals = torch.randn(bsize, 100, num_neigh, 3).to(device) + >>> surface_sizes = torch.rand(bsize, 100).to(device) + 1e-6 # Note this needs to be > 0.0 + >>> surface_neighbors_areas = torch.rand(bsize, 100, num_neigh).to(device) + 1e-6 + >>> volume_coordinates = torch.randn(bsize, 100, 3).to(device) + >>> vol_grid_max_min = torch.randn(bsize, 2, 3).to(device) + >>> surf_grid_max_min = torch.randn(bsize, 2, 3).to(device) + >>> global_params_values = torch.randn(bsize, global_features, 1).to(device) + >>> global_params_reference = torch.randn(bsize, global_features, 1).to(device) + >>> input_dict = { + ... "pos_volume_closest": pos_normals_closest_vol, + ... "pos_volume_center_of_mass": pos_normals_com_vol, + ... "pos_surface_center_of_mass": pos_normals_com_surface, + ... "geometry_coordinates": geom_centers, + ... "grid": grid, + ... "surf_grid": surf_grid, + ... "sdf_grid": sdf_grid, + ... "sdf_surf_grid": sdf_surf_grid, + ... "sdf_nodes": sdf_nodes, + ... "surface_mesh_centers": surface_coordinates, + ... "surface_mesh_neighbors": surface_neighbors, + ... "surface_normals": surface_normals, + ... "surface_neighbors_normals": surface_neighbors_normals, + ... "surface_areas": surface_sizes, + ... "surface_neighbors_areas": surface_neighbors_areas, + ... "volume_mesh_centers": volume_coordinates, + ... "volume_min_max": vol_grid_max_min, + ... "surface_min_max": surf_grid_max_min, + ... "global_params_reference": global_params_values, + ... "global_params_values": global_params_reference, + ... } + >>> output = model(input_dict) + >>> print(f"{output[0].shape}, {output[1].shape}") + torch.Size([1, 100, 5]), torch.Size([1, 100, 4]) + """ + + def __init__( + self, + input_features: int, + output_features_vol: int | None = None, + output_features_surf: int | None = None, + global_features: int = 2, + nodal_surface_features: int = 0, + nodal_volume_features: int = 0, + nodal_geometry_features: int = 0, + model_parameters=None, + ): + """ + Initialize the DoMINO model. + + Args: + input_features: Number of input feature dimensions for point data + output_features_vol: Number of output features for volume quantities (None for surface-only mode) + output_features_surf: Number of output features for surface quantities (None for volume-only mode) + transient: Whether the model is transient + tranient_scheme: The scheme to use for the transient model + model_parameters: Configuration parameters for the model + + Raises: + ValueError: If both output_features_vol and output_features_surf are None + """ + super().__init__() + + self.output_features_vol = output_features_vol + self.output_features_surf = output_features_surf + self.num_sample_points_surface = model_parameters.num_neighbors_surface + self.num_sample_points_volume = model_parameters.num_neighbors_volume + self.integration_steps = model_parameters.integration_steps + self.integration_scheme = model_parameters.transient_scheme + self.transient = model_parameters.transient + self.activation_processor = ( + model_parameters.geometry_rep.geo_processor.activation + ) + self.nodal_surface_features = nodal_surface_features + self.nodal_volume_features = nodal_volume_features + self.nodal_geometry_features = nodal_geometry_features + + self.global_features = global_features + + if self.output_features_vol is None and self.output_features_surf is None: + raise ValueError( + "At least one of `output_features_vol` or `output_features_surf` must be specified" + ) + if hasattr(model_parameters, "solution_calculation_mode"): + if model_parameters.solution_calculation_mode not in [ + "one-loop", + "two-loop", + ]: + raise ValueError( + f"Invalid solution_calculation_mode: {model_parameters.solution_calculation_mode}, select 'one-loop' or 'two-loop'." + ) + self.solution_calculation_mode = model_parameters.solution_calculation_mode + else: + self.solution_calculation_mode = "two-loop" + self.num_variables_vol = output_features_vol + self.num_variables_surf = output_features_surf + self.grid_resolution = model_parameters.interp_res + self.use_surface_normals = model_parameters.use_surface_normals + self.use_surface_area = model_parameters.use_surface_area + self.encode_parameters = model_parameters.encode_parameters + self.geo_encoding_type = model_parameters.geometry_encoding_type + + if self.use_surface_normals: + if not self.use_surface_area: + input_features_surface = input_features + 3 + else: + input_features_surface = input_features + 4 + else: + input_features_surface = input_features + + if self.encode_parameters: + # Defining the parameter model + base_layer_p = model_parameters.parameter_model.base_layer + self.parameter_model = FourierMLP( + input_features=self.global_features, + fourier_features=model_parameters.parameter_model.fourier_features, + num_modes=model_parameters.parameter_model.num_modes, + base_layer=model_parameters.parameter_model.base_layer, + activation=get_activation(model_parameters.parameter_model.activation), + ) + else: + base_layer_p = 0 + + self.geo_rep_volume = GeometryRep( + input_features=input_features, + radii=model_parameters.geometry_rep.geo_conv.volume_radii, + neighbors_in_radius=model_parameters.geometry_rep.geo_conv.volume_neighbors_in_radius, + hops=model_parameters.geometry_rep.geo_conv.volume_hops, + sdf_scaling_factor=model_parameters.geometry_rep.geo_processor.volume_sdf_scaling_factor, + model_parameters=model_parameters, + nodal_geometry_features=nodal_geometry_features, + ) + + self.geo_rep_surface = GeometryRep( + input_features=input_features, + radii=model_parameters.geometry_rep.geo_conv.surface_radii, + neighbors_in_radius=model_parameters.geometry_rep.geo_conv.surface_neighbors_in_radius, + hops=model_parameters.geometry_rep.geo_conv.surface_hops, + sdf_scaling_factor=model_parameters.geometry_rep.geo_processor.surface_sdf_scaling_factor, + model_parameters=model_parameters, + nodal_geometry_features=nodal_geometry_features, + ) + + if self.transient: + input_features_surface = input_features_surface + 1 # Adding one for the time step + input_features = input_features + 1 # Adding one for the time step + + # Basis functions for surface and volume + base_layer_nn = model_parameters.nn_basis_functions.base_layer + if self.output_features_surf is not None: + self.nn_basis_surf = nn.ModuleList() + for _ in range( + self.num_variables_surf + ): # Have the same basis function for each variable + self.nn_basis_surf.append( + FourierMLP( + input_features=input_features_surface + self.nodal_surface_features, + base_layer=model_parameters.nn_basis_functions.base_layer, + fourier_features=model_parameters.nn_basis_functions.fourier_features, + num_modes=model_parameters.nn_basis_functions.num_modes, + activation=get_activation( + model_parameters.nn_basis_functions.activation + ), + # model_parameters=model_parameters.nn_basis_functions, + ) + ) + + if self.output_features_vol is not None: + self.nn_basis_vol = nn.ModuleList() + for _ in range( + self.num_variables_vol + ): # Have the same basis function for each variable + self.nn_basis_vol.append( + FourierMLP( + input_features=input_features + self.nodal_volume_features, + base_layer=model_parameters.nn_basis_functions.base_layer, + fourier_features=model_parameters.nn_basis_functions.fourier_features, + num_modes=model_parameters.nn_basis_functions.num_modes, + activation=get_activation( + model_parameters.nn_basis_functions.activation + ), + # model_parameters=model_parameters.nn_basis_functions, + ) + ) + + # Positional encoding + position_encoder_base_neurons = model_parameters.position_encoder.base_neurons + self.activation = get_activation(model_parameters.activation) + self.use_sdf_in_basis_func = model_parameters.use_sdf_in_basis_func + self.sdf_scaling_factor = ( + model_parameters.geometry_rep.geo_processor.volume_sdf_scaling_factor + ) + if self.output_features_vol is not None: + inp_pos_vol = ( + 7 + len(self.sdf_scaling_factor) + if model_parameters.use_sdf_in_basis_func + else 3 + ) + + self.fc_p_vol = FourierMLP( + input_features=inp_pos_vol, + fourier_features=model_parameters.position_encoder.fourier_features, + num_modes=model_parameters.position_encoder.num_modes, + base_layer=model_parameters.position_encoder.base_neurons, + activation=get_activation(model_parameters.position_encoder.activation), + ) + + if self.output_features_surf is not None: + inp_pos_surf = 3 + + self.fc_p_surf = FourierMLP( + input_features=inp_pos_surf, + fourier_features=model_parameters.position_encoder.fourier_features, + num_modes=model_parameters.position_encoder.num_modes, + base_layer=model_parameters.position_encoder.base_neurons, + activation=get_activation(model_parameters.position_encoder.activation), + ) + + # Create a set of local geometry encodings for the surface data: + self.surface_local_geo_encodings = MultiGeometryEncoding( + radii=model_parameters.geometry_local.surface_radii, + neighbors_in_radius=model_parameters.geometry_local.surface_neighbors_in_radius, + geo_encoding_type=self.geo_encoding_type, + n_upstream_radii=len(model_parameters.geometry_rep.geo_conv.surface_radii), + base_layer=512, + activation=get_activation(model_parameters.local_point_conv.activation), + grid_resolution=self.grid_resolution, + ) + + # Create a set of local geometry encodings for the surface data: + self.volume_local_geo_encodings = MultiGeometryEncoding( + radii=model_parameters.geometry_local.volume_radii, + neighbors_in_radius=model_parameters.geometry_local.volume_neighbors_in_radius, + geo_encoding_type=self.geo_encoding_type, + n_upstream_radii=len(model_parameters.geometry_rep.geo_conv.volume_radii), + base_layer=512, + activation=get_activation(model_parameters.local_point_conv.activation), + grid_resolution=self.grid_resolution, + ) + + # Aggregation model + if self.output_features_surf is not None: + # Surface + base_layer_geo_surf = 0 + for j in model_parameters.geometry_local.surface_neighbors_in_radius: + base_layer_geo_surf += j + + self.agg_model_surf = nn.ModuleList() + for _ in range(self.num_variables_surf): + self.agg_model_surf.append( + AggregationModel( + input_features=position_encoder_base_neurons + + base_layer_nn + + base_layer_geo_surf + + base_layer_p, + output_features=1, + base_layer=model_parameters.aggregation_model.base_layer, + activation=get_activation( + model_parameters.aggregation_model.activation + ), + ) + ) + + self.solution_calculator_surf = SolutionCalculatorSurface( + num_variables=self.num_variables_surf, + num_sample_points=self.num_sample_points_surface, + use_surface_normals=self.use_surface_normals, + use_surface_area=self.use_surface_area, + encode_parameters=self.encode_parameters, + parameter_model=self.parameter_model + if self.encode_parameters + else None, + aggregation_model=self.agg_model_surf, + nn_basis=self.nn_basis_surf, + ) + + if self.output_features_vol is not None: + # Volume + base_layer_geo_vol = 0 + for j in model_parameters.geometry_local.volume_neighbors_in_radius: + base_layer_geo_vol += j + + self.agg_model_vol = nn.ModuleList() + for _ in range(self.num_variables_vol): + self.agg_model_vol.append( + AggregationModel( + input_features=position_encoder_base_neurons + + base_layer_nn + + base_layer_geo_vol + + base_layer_p, + output_features=1, + base_layer=model_parameters.aggregation_model.base_layer, + activation=get_activation( + model_parameters.aggregation_model.activation + ), + ) + ) + if hasattr(model_parameters, "return_volume_neighbors"): + return_volume_neighbors = model_parameters.return_volume_neighbors + else: + return_volume_neighbors = False + + self.solution_calculator_vol = SolutionCalculatorVolume( + num_variables=self.num_variables_vol, + num_sample_points=self.num_sample_points_volume, + noise_intensity=50, + return_volume_neighbors=return_volume_neighbors, + encode_parameters=self.encode_parameters, + parameter_model=self.parameter_model + if self.encode_parameters + else None, + aggregation_model=self.agg_model_vol, + nn_basis=self.nn_basis_vol, + ) + + def _validate_and_extract_surface_properties(self, data_dict): + """ + Validate and extract surface properties from data dictionary. + + Args: + data_dict: Input data dictionary + + Returns: + Tuple of (surface_areas, surface_normals, surface_neighbors_areas, surface_neighbors_normals) + """ + surface_areas = None + surface_normals = None + surface_neighbors_areas = None + surface_neighbors_normals = None + + if "surface_areas" in data_dict: + surface_areas = data_dict["surface_areas"] + surface_neighbors_areas = data_dict["surface_neighbors_areas"] + surface_areas = torch.unsqueeze(surface_areas, -1) + surface_neighbors_areas = torch.unsqueeze(surface_neighbors_areas, -1) + if "surface_normals" in data_dict: + surface_normals = data_dict["surface_normals"] + surface_neighbors_normals = data_dict["surface_neighbors_normals"] + + return surface_areas, surface_normals, surface_neighbors_areas, surface_neighbors_normals + + def _validate_and_extract_nodal_features(self, data_dict): + """ + Validate and extract nodal features from data dictionary. + """ + surface_features = None + volume_features = None + geometry_features = None + + if "surface_features" in data_dict: + surface_features = data_dict["surface_features"] + if surface_features.shape[-1] != self.nodal_surface_features: + raise ValueError( + f"Surface features must have {self.nodal_surface_features} features" + ) + + if "volume_features" in data_dict: + volume_features = data_dict["volume_features"] + if volume_features.shape[-1] != self.nodal_volume_features: + raise ValueError( + f"Volume features must have {self.nodal_volume_features} features" + ) + + if "geometry_features" in data_dict: + geometry_features = data_dict["geometry_features"] + if geometry_features.shape[-1] != self.nodal_geometry_features: + raise ValueError( + f"Geometry features must have {self.nodal_geometry_features} features" + ) + + return surface_features, volume_features, geometry_features + + def _compute_volume_positional_encoding(self, data_dict): + """ + Compute positional encodings for volume domain. + + Args: + data_dict: Input data dictionary containing: + - sdf_nodes: SDF values at volume nodes + - pos_volume_closest: Positions of closest surface points + - pos_volume_center_of_mass: Positions relative to geometry center of mass + + Returns: + Positional encoding tensor for volume nodes + """ + # Compute SDF-based features + sdf_nodes = data_dict["sdf_nodes"] + scaled_sdf_nodes = [ + scale_sdf(sdf_nodes, scaling) for scaling in self.sdf_scaling_factor + ] + scaled_sdf_nodes = torch.cat(scaled_sdf_nodes, dim=-1) + + # Compute positional encodings + pos_volume_closest = data_dict["pos_volume_closest"] + pos_volume_center_of_mass = data_dict["pos_volume_center_of_mass"] + + if self.use_sdf_in_basis_func: + encoding_node_vol = torch.cat( + (sdf_nodes, scaled_sdf_nodes, pos_volume_closest, pos_volume_center_of_mass), + dim=-1, + ) + else: + encoding_node_vol = pos_volume_center_of_mass + + # Apply positional encoder network + encoding_node_vol = self.fc_p_vol(encoding_node_vol) + + return encoding_node_vol + + def _compute_volume_encodings(self, data_dict, geo_centers, geometry_features): + """ + Compute geometry encodings for volume domain. + + Args: + data_dict: Input data dictionary + geo_centers: Geometry center coordinates + geometry_features: Optional geometry features + + Returns: + Tuple of (encoding_g_vol, p_grid) + """ + # Computational domain grid + p_grid = data_dict["grid"] + if "sdf_grid" in data_dict: + sdf_grid = data_dict["sdf_grid"] + if sdf_grid.shape[0] != p_grid.shape[0]: + raise ValueError( + "SDF grid and grid must have the same number of points" + ) + else: + sdf_grid = None + + # Normalize geometry centers based on volume domain + if "volume_min_max" in data_dict: + vol_max = data_dict["volume_min_max"][:, 1] + vol_min = data_dict["volume_min_max"][:, 0] + geo_centers_vol = 2.0 * (geo_centers - vol_min) / (vol_max - vol_min) - 1 + else: + geo_centers_vol = geo_centers + + # Compute geometry representation encoding + encoding_g_vol = self.geo_rep_volume( + geo_centers_vol, p_grid, sdf_grid, geometry_features=geometry_features + ) + + return encoding_g_vol, p_grid + + def _compute_surface_positional_encoding(self, data_dict): + """ + Compute positional encodings for surface domain. + + Args: + data_dict: Input data dictionary containing: + - pos_surface_center_of_mass: Positions relative to geometry center of mass + + Returns: + Positional encoding tensor for surface nodes + """ + # Compute positional encoding + pos_surface_center_of_mass = data_dict["pos_surface_center_of_mass"] + encoding_node_surf = self.fc_p_surf(pos_surface_center_of_mass) + + return encoding_node_surf + + def _compute_surface_encodings(self, data_dict, geo_centers, geometry_features): + """ + Compute geometry encodings for surface domain. + + Args: + data_dict: Input data dictionary + geo_centers: Geometry center coordinates + geometry_features: Optional geometry features + + Returns: + Tuple of (encoding_g_surf, s_grid, sdf_surf_grid) + """ + # Surface grid + s_grid = data_dict["surf_grid"] + if "sdf_surf_grid" in data_dict: + sdf_surf_grid = data_dict["sdf_surf_grid"] + if sdf_surf_grid.shape[0] != s_grid.shape[0]: + raise ValueError( + "SDF surface grid and surface grid must have the same number of points" + ) + else: + sdf_surf_grid = None + # Normalize geometry centers based on surface domain + if "surface_min_max" in data_dict: + surf_max = data_dict["surface_min_max"][:, 1] + surf_min = data_dict["surface_min_max"][:, 0] + geo_centers_surf = 2.0 * (geo_centers - surf_min) / (surf_max - surf_min) - 1 + else: + geo_centers_surf = geo_centers + + # Compute geometry representation encoding + encoding_g_surf = self.geo_rep_surface( + geo_centers_surf, s_grid, sdf_surf_grid, geometry_features=geometry_features + ) + + return encoding_g_surf, s_grid, sdf_surf_grid + + def _compute_volume_local_encodings(self, encoding_g_vol, volume_mesh_centers, p_grid): + """ + Compute local geometry encodings for volume mesh. + + Args: + encoding_g_vol: Global volume geometry encoding + volume_mesh_centers: Volume mesh center coordinates + p_grid: Volume grid + + Returns: + Local volume geometry encodings + """ + if self.transient: + encoding_g_vol_all = [] + for i in range(volume_mesh_centers.shape[1]): + encoding_g_vol_i = self.volume_local_geo_encodings( + 0.5 * encoding_g_vol, volume_mesh_centers[:, i, :, :3], p_grid + ) + encoding_g_vol_all.append(torch.unsqueeze(encoding_g_vol_i, 1)) + return torch.cat(encoding_g_vol_all, dim=1) + else: + return self.volume_local_geo_encodings( + 0.5 * encoding_g_vol, volume_mesh_centers, p_grid + ) + + def _compute_surface_local_encodings(self, encoding_g_surf, surface_mesh_centers, s_grid): + """ + Compute local geometry encodings for surface mesh. + + Args: + encoding_g_surf: Global surface geometry encoding + surface_mesh_centers: Surface mesh center coordinates + s_grid: Surface grid + + Returns: + Local surface geometry encodings + """ + if self.transient: + encoding_g_surf_all = [] + for i in range(surface_mesh_centers.shape[1]): + encoding_g_surf_i = self.surface_local_geo_encodings( + 0.5 * encoding_g_surf, surface_mesh_centers[:, i, :, :3], s_grid + ) + encoding_g_surf_all.append(torch.unsqueeze(encoding_g_surf_i, 1)) + return torch.cat(encoding_g_surf_all, dim=1) + else: + return self.surface_local_geo_encodings( + 0.5 * encoding_g_surf, surface_mesh_centers, s_grid + ) + + def _compute_volume_output_implicit( + self, + volume_mesh_centers, + encoding_g_vol, + encoding_node_vol, + global_params_values, + global_params_reference, + volume_features, + ): + """ + Compute volume output using implicit integration scheme. + + Args: + volume_mesh_centers: Volume mesh center coordinates + encoding_g_vol: Volume geometry encodings + encoding_node_vol: Volume node encodings + global_params_values: Global parameter values + global_params_reference: Global parameter references + volume_features: Optional volume features + + Returns: + Volume output tensor + """ + output_vol_all = [] + volume_mesh_centers_i = None + + for i in range(self.integration_steps): + if i == 0: + volume_mesh_centers_i = volume_mesh_centers[:, i] + else: + volume_mesh_centers_i[:, :, :3] += output_vol + + volume_features_i = volume_features[:, i] if volume_features is not None else None + + output_vol = self.solution_calculator_vol( + volume_mesh_centers_i, + encoding_g_vol[:, i], + encoding_node_vol[:, i], + global_params_values, + global_params_reference, + volume_features_i, + ) + output_vol_all.append(torch.unsqueeze(output_vol, 1)) + + return torch.cat(output_vol_all, dim=1) + + def _compute_volume_output_explicit( + self, + volume_mesh_centers, + encoding_g_vol, + encoding_node_vol, + global_params_values, + global_params_reference, + volume_features, + ): + """ + Compute volume output using explicit integration scheme. + + Args: + volume_mesh_centers: Volume mesh center coordinates + encoding_g_vol: Volume geometry encodings + encoding_node_vol: Volume node encodings + global_params_values: Global parameter values + global_params_reference: Global parameter references + volume_features: Optional volume features + + Returns: + Volume output tensor + """ + output_vol_all = [] + + for i in range(volume_mesh_centers.shape[1]): + volume_features_i = volume_features[:, i] if volume_features is not None else None + + output_vol = self.solution_calculator_vol( + volume_mesh_centers[:, i], + encoding_g_vol[:, i], + encoding_node_vol[:, i], + global_params_values, + global_params_reference, + volume_features_i, + ) + output_vol_all.append(torch.unsqueeze(output_vol, 1)) + + return torch.cat(output_vol_all, dim=1) + + def _compute_surface_output_implicit( + self, + surface_mesh_centers, + surface_mesh_neighbors, + surface_normals, + surface_neighbors_normals, + surface_areas, + surface_neighbors_areas, + encoding_g_surf, + encoding_node_surf, + global_params_values, + global_params_reference, + surface_features, + ): + """ + Compute surface output using implicit integration scheme. + + Args: + surface_mesh_centers: Surface mesh center coordinates + surface_mesh_neighbors: Surface mesh neighbor coordinates + surface_normals: Surface normal vectors + surface_neighbors_normals: Surface neighbor normal vectors + surface_areas: Surface element areas + surface_neighbors_areas: Surface neighbor areas + encoding_g_surf: Surface geometry encodings + encoding_node_surf: Surface node encodings + global_params_values: Global parameter values + global_params_reference: Global parameter references + surface_features: Optional surface features + + Returns: + Surface output tensor + """ + output_surf_all = [] + surface_mesh_centers_i = None + surface_mesh_neighbors_i = None + + for i in range(self.integration_steps): + if i == 0: + surface_mesh_centers_i = surface_mesh_centers[:, i] + surface_mesh_neighbors_i = surface_mesh_neighbors[:, i] + else: + surface_mesh_centers_i[:, :, :3] += output_surf + for j in range(surface_mesh_neighbors_i.shape[2]): + surface_mesh_neighbors_i[:, :, j, :3] += output_surf + + surface_features_i = surface_features[:, i] if surface_features is not None else None + + output_surf = self.solution_calculator_surf( + surface_mesh_centers_i, + encoding_g_surf[:, i], + encoding_node_surf[:, i], + surface_mesh_neighbors_i, + surface_normals[:, i] if surface_normals is not None else None, + surface_neighbors_normals[:, i] if surface_neighbors_normals is not None else None, + surface_areas[:, i] if surface_areas is not None else None, + surface_neighbors_areas[:, i] if surface_neighbors_areas is not None else None, + global_params_values, + global_params_reference, + surface_features_i, + ) + output_surf_all.append(torch.unsqueeze(output_surf, 1)) + + return torch.cat(output_surf_all, dim=1) + + def _compute_surface_output_explicit( + self, + surface_mesh_centers, + surface_mesh_neighbors, + surface_normals, + surface_neighbors_normals, + surface_areas, + surface_neighbors_areas, + encoding_g_surf, + encoding_node_surf, + global_params_values, + global_params_reference, + surface_features, + ): + """ + Compute surface output using explicit integration scheme. + + Args: + surface_mesh_centers: Surface mesh center coordinates + surface_mesh_neighbors: Surface mesh neighbor coordinates + surface_normals: Surface normal vectors + surface_neighbors_normals: Surface neighbor normal vectors + surface_areas: Surface element areas + surface_neighbors_areas: Surface neighbor areas + encoding_g_surf: Surface geometry encodings + encoding_node_surf: Surface node encodings + global_params_values: Global parameter values + global_params_reference: Global parameter references + surface_features: Optional surface features + + Returns: + Surface output tensor + """ + output_surf_all = [] + + for i in range(surface_mesh_centers.shape[1]): + surface_features_i = surface_features[:, i] if surface_features is not None else None + + output_surf = self.solution_calculator_surf( + surface_mesh_centers[:, i], + encoding_g_surf[:, i], + encoding_node_surf[:, i], + surface_mesh_neighbors[:, i], + surface_normals[:, i] if surface_normals is not None else None, + surface_neighbors_normals[:, i] if surface_neighbors_normals is not None else None, + surface_areas[:, i] if surface_areas is not None else None, + surface_neighbors_areas[:, i] if surface_neighbors_areas is not None else None, + global_params_values, + global_params_reference, + surface_features_i, + ) + output_surf_all.append(torch.unsqueeze(output_surf, 1)) + + return torch.cat(output_surf_all, dim=1) + + def forward(self, data_dict): + """ + Forward pass of the DoMINO model. + + Args: + data_dict: Dictionary containing all input data including: + - geometry_coordinates: Geometry center coordinates + - surf_grid: Surface grid + - sdf_surf_grid: Surface SDF grid + - sdf_grid: Volume SDF grid + - grid: Volume grid + - volume_mesh_centers: Volume mesh center coordinates + - surface_mesh_centers: Surface mesh center coordinates + - surface_normals: Surface normal vectors + - surface_areas: Surface element areas + - surface_mesh_neighbors: Surface mesh neighbor coordinates + - surface_neighbors_normals: Surface neighbor normal vectors + - surface_neighbors_areas: Surface neighbor areas + - volume_mesh_centers: Volume mesh center coordinates + - surface_mesh_centers: Surface mesh center coordinates + - (optional) surface_normals: Surface normal vectors + - (optional) surface_areas: Surface element areas + - surface_mesh_neighbors: Surface mesh neighbor coordinates + - (optional) surface_neighbors_normals: Surface neighbor normal vectors + - (optional) surface_neighbors_areas: Surface neighbor areas + - global_params_values: Global parameter values + - global_params_reference: Global parameter references + - (optional) surface_features, volume_features, geometry_features + - (optional) volume-specific data if output_features_vol is not None + - (optional) surface-specific data if output_features_surf is not None + + Returns: + Tuple of (output_vol, output_surf) where each can be None if not computed + """ + + # Extract base inputs + geo_centers = data_dict["geometry_coordinates"] + if "global_params_values" in data_dict: + global_params_values = data_dict["global_params_values"] + else: + global_params_values = None + if "global_params_reference" in data_dict: + global_params_reference = data_dict["global_params_reference"] + else: + global_params_reference = None + + # Validate and extract features + surface_features, volume_features, geometry_features = ( + self._validate_and_extract_nodal_features(data_dict) + ) + + surface_areas, surface_normals, surface_neighbors_areas, surface_neighbors_normals = ( + self._validate_and_extract_surface_properties(data_dict) + ) + + # Compute volume outputs if required + output_vol = None + if self.output_features_vol is not None: + # Compute volume geometry encodings + encoding_g_vol, p_grid = self._compute_volume_encodings( + data_dict, geo_centers, geometry_features + ) + + # Compute volume positional encodings + encoding_node_vol = self._compute_volume_positional_encoding(data_dict) + + # Get volume mesh data + volume_mesh_centers = data_dict["volume_mesh_centers"] + + # Compute local geometry encodings + encoding_g_vol = self._compute_volume_local_encodings( + encoding_g_vol, volume_mesh_centers, p_grid + ) + + # Compute volume solution based on integration scheme + if self.integration_scheme == "implicit": + output_vol = self._compute_volume_output_implicit( + volume_mesh_centers, + encoding_g_vol, + encoding_node_vol, + global_params_values, + global_params_reference, + volume_features, + ) + else: + output_vol = self._compute_volume_output_explicit( + volume_mesh_centers, + encoding_g_vol, + encoding_node_vol, + global_params_values, + global_params_reference, + volume_features, + ) + + # Compute surface outputs if required + output_surf = None + if self.output_features_surf is not None: + # Compute surface geometry encodings + encoding_g_surf, s_grid, _ = self._compute_surface_encodings( + data_dict, geo_centers, geometry_features + ) + + # Compute surface positional encodings + encoding_node_surf = self._compute_surface_positional_encoding(data_dict) + + # Get surface mesh data + surface_mesh_centers = data_dict["surface_mesh_centers"] + surface_mesh_neighbors = data_dict["surface_mesh_neighbors"] + + # Compute local geometry encodings + encoding_g_surf = self._compute_surface_local_encodings( + encoding_g_surf, surface_mesh_centers, s_grid + ) + + # Compute surface solution based on integration scheme + if self.integration_scheme == "implicit": + output_surf = self._compute_surface_output_implicit( + surface_mesh_centers, + surface_mesh_neighbors, + surface_normals, + surface_neighbors_normals, + surface_areas, + surface_neighbors_areas, + encoding_g_surf, + encoding_node_surf, + global_params_values, + global_params_reference, + surface_features, + ) + else: + output_surf = self._compute_surface_output_explicit( + surface_mesh_centers, + surface_mesh_neighbors, + surface_normals, + surface_neighbors_normals, + surface_areas, + surface_neighbors_areas, + encoding_g_surf, + encoding_node_surf, + global_params_values, + global_params_reference, + surface_features, + ) + + return output_vol, output_surf diff --git a/physicsnemo/models/domino_transient/solutions.py b/physicsnemo/models/domino_transient/solutions.py new file mode 100644 index 0000000000..bcd24c800d --- /dev/null +++ b/physicsnemo/models/domino_transient/solutions.py @@ -0,0 +1,376 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code contains the DoMINO model architecture. +The DoMINO class contains an architecture to model both surface and +volume quantities together as well as separately (controlled using +the config.yaml file) +""" + +from collections import defaultdict + +import torch +import torch.nn as nn + + +def apply_parameter_encoding( + mesh_centers: torch.Tensor, + global_params_values: torch.Tensor, + global_params_reference: torch.Tensor, +) -> torch.Tensor: + processed_parameters = [] + for k in range(global_params_values.shape[1]): + param = torch.unsqueeze(global_params_values[:, k, :], 1) + ref = torch.unsqueeze(global_params_reference[:, k, :], 1) + param = param.expand( + param.shape[0], + mesh_centers.shape[1], + param.shape[2], + ) + param = param / ref + processed_parameters.append(param) + processed_parameters = torch.cat(processed_parameters, axis=-1) + + return processed_parameters + + +def sample_sphere(center, r, num_points): + """Uniformly sample points in a 3D sphere around the center. + + This method generates random points within a sphere of radius r centered + at each point in the input tensor. The sampling is uniform in volume, + meaning points are more likely to be sampled in the outer regions of the sphere. + + Args: + center: Tensor of shape (batch_size, num_points, 3) containing center coordinates + r: Radius of the sphere for sampling + num_points: Number of points to sample per center + + Returns: + Tensor of shape (batch_size, num_points, num_samples, 3) containing + the sampled points around each center + """ + # Adjust the center points to the final shape: + unsqueezed_center = center.unsqueeze(2).expand(-1, -1, num_points, -1) + + # Generate directions like the centers: + directions = torch.randn_like(unsqueezed_center) + directions = directions / torch.norm(directions, dim=-1, keepdim=True) + + # Generate radii like the centers: + radii = r * torch.pow(torch.rand_like(unsqueezed_center), 1 / 3) + + output = unsqueezed_center + directions * radii + return output + + +def sample_sphere_shell(center, r_inner, r_outer, num_points): + """Uniformly sample points in a 3D spherical shell around a center. + + This method generates random points within a spherical shell (annulus) + between inner radius r_inner and outer radius r_outer centered at each + point in the input tensor. The sampling is uniform in volume within the shell. + + Args: + center: Tensor of shape (batch_size, num_points, 3) containing center coordinates + r_inner: Inner radius of the spherical shell + r_outer: Outer radius of the spherical shell + num_points: Number of points to sample per center + + Returns: + Tensor of shape (batch_size, num_points, num_samples, 3) containing + the sampled points within the spherical shell around each center + """ + + unsqueezed_center = center.unsqueeze(2).expand(-1, -1, num_points, -1) + + # Generate directions like the centers: + directions = torch.randn_like(unsqueezed_center) + directions = directions / torch.norm(directions, dim=-1, keepdim=True) + + radii = torch.rand_like(unsqueezed_center) * (r_outer**3 - r_inner**3) + r_inner**3 + radii = torch.pow(radii, 1 / 3) + + output = unsqueezed_center + directions * radii + + return output + + +class SolutionCalculatorVolume(nn.Module): + """ + Module to calculate the output solution of the DoMINO Model for volume data. + """ + + def __init__( + self, + num_variables: int, + num_sample_points: int, + noise_intensity: float, + encode_parameters: bool, + return_volume_neighbors: bool, + parameter_model: nn.Module | None, + aggregation_model: nn.ModuleList, + nn_basis: nn.ModuleList, + ): + super().__init__() + + self.num_variables = num_variables + self.num_sample_points = num_sample_points + self.noise_intensity = noise_intensity + self.encode_parameters = encode_parameters + self.return_volume_neighbors = return_volume_neighbors + self.parameter_model = parameter_model + self.aggregation_model = aggregation_model + self.nn_basis = nn_basis + + if self.encode_parameters: + if self.parameter_model is None: + raise ValueError( + "Parameter model is required when encode_parameters is True" + ) + + def forward( + self, + volume_mesh_centers: torch.Tensor, + encoding_g: torch.Tensor, + encoding_node: torch.Tensor, + global_params_values: torch.Tensor, + global_params_reference: torch.Tensor, + volume_features: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]: + """ + Forward pass of the SolutionCalculator module. + """ + if self.encode_parameters: + param_encoding = apply_parameter_encoding( + volume_mesh_centers, global_params_values, global_params_reference + ) + param_encoding = self.parameter_model(param_encoding) + + volume_m_c_perturbed = [volume_mesh_centers.unsqueeze(2)] + + if self.return_volume_neighbors: + num_hop1 = self.num_sample_points + num_hop2 = ( + self.num_sample_points // 2 if self.num_sample_points != 1 else 1 + ) # This is per 1 hop node + neighbors = defaultdict(list) + + volume_m_c_hop1 = sample_sphere( + volume_mesh_centers, 1 / self.noise_intensity, num_hop1 + ) + # 1 hop neighbors + for i in range(num_hop1): + idx = len(volume_m_c_perturbed) + volume_m_c_perturbed.append(volume_m_c_hop1[:, :, i : i + 1, :]) + neighbors[0].append(idx) + + # 2 hop neighbors + for i in range(num_hop1): + parent_idx = i + 1 # Skipping the first point, which is the original + parent_point = volume_m_c_perturbed[parent_idx] + + children = sample_sphere_shell( + parent_point.squeeze(2), + 1 / self.noise_intensity, + 2 / self.noise_intensity, + num_hop2, + ) + + for c in range(num_hop2): + idx = len(volume_m_c_perturbed) + volume_m_c_perturbed.append(children[:, :, c : c + 1, :]) + neighbors[parent_idx].append(idx) + + volume_m_c_perturbed = torch.cat(volume_m_c_perturbed, dim=2) + neighbors = dict(neighbors) + field_neighbors = {i: [] for i in range(self.num_variables)} + else: + volume_m_c_sample = sample_sphere( + volume_mesh_centers, 1 / self.noise_intensity, self.num_sample_points + ) + for i in range(self.num_sample_points): + volume_m_c_perturbed.append(volume_m_c_sample[:, :, i : i + 1, :]) + + volume_m_c_perturbed = torch.cat(volume_m_c_perturbed, dim=2) + + for f in range(self.num_variables): + for p in range(volume_m_c_perturbed.shape[2]): + volume_m_c = volume_m_c_perturbed[:, :, p, :] + if p != 0: + dist = torch.norm( + volume_m_c - volume_mesh_centers, dim=-1, keepdim=True + ) + if volume_features is not None: + volume_m_c = torch.cat((volume_m_c, volume_features), dim=-1) + basis_f = self.nn_basis[f](volume_m_c) + output = torch.cat((basis_f, encoding_node, encoding_g), dim=-1) + + if self.encode_parameters: + output = torch.cat((output, param_encoding), dim=-1) + if p == 0: + output_center = self.aggregation_model[f](output) + else: + if p == 1: + output_neighbor = self.aggregation_model[f](output) * ( + 1.0 / dist + ) + dist_sum = 1.0 / dist + else: + output_neighbor += self.aggregation_model[f](output) * ( + 1.0 / dist + ) + dist_sum += 1.0 / dist + if self.return_volume_neighbors: + field_neighbors[f].append(self.aggregation_model[f](output)) + + if self.return_volume_neighbors: + field_neighbors[f] = torch.stack(field_neighbors[f], dim=2) + + if self.num_sample_points > 1: + output_res = ( + 0.5 * output_center + 0.5 * output_neighbor / dist_sum + ) # This only applies to the main point, and not the preturbed points + else: + output_res = output_center + if f == 0: + output_all = output_res + else: + output_all = torch.cat((output_all, output_res), axis=-1) + + if self.return_volume_neighbors: + field_neighbors = torch.cat( + [field_neighbors[i] for i in range(self.num_variables)], dim=3 + ) + return output_all, volume_m_c_perturbed, field_neighbors, neighbors + else: + return output_all + + +class SolutionCalculatorSurface(nn.Module): + """ + Module to calculate the output solution of the DoMINO Model for surface data. + """ + + def __init__( + self, + num_variables: int, + num_sample_points: int, + encode_parameters: bool, + use_surface_normals: bool, + use_surface_area: bool, + parameter_model: nn.Module | None, + aggregation_model: nn.ModuleList, + nn_basis: nn.ModuleList, + ): + super().__init__() + self.num_variables = num_variables + self.num_sample_points = num_sample_points + self.encode_parameters = encode_parameters + self.use_surface_normals = use_surface_normals + self.use_surface_area = use_surface_area + self.parameter_model = parameter_model + self.aggregation_model = aggregation_model + self.nn_basis = nn_basis + + if self.encode_parameters: + if self.parameter_model is None: + raise ValueError( + "Parameter model is required when encode_parameters is True" + ) + + def forward( + self, + surface_mesh_centers: torch.Tensor, + encoding_g: torch.Tensor, + encoding_node: torch.Tensor, + surface_mesh_neighbors: torch.Tensor, + surface_normals: torch.Tensor | None = None, + surface_neighbors_normals: torch.Tensor | None = None, + surface_areas: torch.Tensor | None = None, + surface_neighbors_areas: torch.Tensor | None = None, + global_params_values: torch.Tensor | None = None, + global_params_reference: torch.Tensor | None = None, + surface_features: torch.Tensor | None = None, + ) -> torch.Tensor: + """Function to approximate solution given the neighborhood information""" + + if self.encode_parameters: + param_encoding = apply_parameter_encoding( + surface_mesh_centers, global_params_values, global_params_reference + ) + param_encoding = self.parameter_model(param_encoding) + + centers_inputs = [ + surface_mesh_centers, + ] + neighbors_inputs = [ + surface_mesh_neighbors, + ] + + if self.use_surface_normals and surface_normals is not None: + centers_inputs.append(surface_normals) + if self.num_sample_points > 1 and surface_neighbors_normals is not None: + neighbors_inputs.append(surface_neighbors_normals) + + if self.use_surface_area and surface_areas is not None: + centers_inputs.append(torch.log(surface_areas) / 10) + if self.num_sample_points > 1 and surface_neighbors_areas is not None: + neighbors_inputs.append(torch.log(surface_neighbors_areas) / 10) + + surface_mesh_centers = torch.cat(centers_inputs, dim=-1) + surface_mesh_neighbors = torch.cat(neighbors_inputs, dim=-1) + + for f in range(self.num_variables): + for p in range(self.num_sample_points): + if p == 0: + surface_m_c = surface_mesh_centers + else: + surface_m_c = surface_mesh_neighbors[:, :, p - 1] + 1e-6 + noise = surface_mesh_centers - surface_m_c + dist = torch.norm(noise, dim=-1, keepdim=True) + + if surface_features is not None: + surface_m_c = torch.cat((surface_m_c, surface_features), dim=-1) + basis_f = self.nn_basis[f](surface_m_c) + output = torch.cat((basis_f, encoding_node, encoding_g), dim=-1) + + if self.encode_parameters: + output = torch.cat((output, param_encoding), dim=-1) + if p == 0: + output_center = self.aggregation_model[f](output) + else: + if p == 1: + output_neighbor = self.aggregation_model[f](output) * ( + 1.0 / dist + ) + dist_sum = 1.0 / dist + else: + output_neighbor += self.aggregation_model[f](output) * ( + 1.0 / dist + ) + dist_sum += 1.0 / dist + if self.num_sample_points > 1: + output_res = 0.5 * output_center + 0.5 * output_neighbor / dist_sum + else: + output_res = output_center + if f == 0: + output_all = output_res + else: + output_all = torch.cat((output_all, output_res), dim=-1) + + return output_all diff --git a/physicsnemo/utils/domino/utils.py b/physicsnemo/utils/domino/utils.py index 5942795cc2..52bf9c0932 100644 --- a/physicsnemo/utils/domino/utils.py +++ b/physicsnemo/utils/domino/utils.py @@ -23,13 +23,46 @@ """ from pathlib import Path -from typing import Any, Sequence +from typing import Any, Optional, Sequence import torch from physicsnemo.utils.neighbors import knn +def repeat_array( + arr: torch.Tensor, p: int, axis: Optional[int], new_axis: bool, **kwargs +) -> torch.Tensor: + """Repeat each element p times along the specified axis using torch operations. + + Args: + arr: Input tensor to repeat. + p: Number of times to repeat each element. + axis: Axis along which to repeat. If None and new_axis is True, defaults to 0. + new_axis: If True, adds a new dimension before repeating. If False, repeats along existing axis. + **kwargs: Ignored keyword arguments (for backwards compatibility with xp parameter). + + Returns: + Tensor with repeated elements. + + Examples: + >>> import torch + >>> arr = torch.tensor([1, 2, 3]) + >>> repeat_array(arr, 2, axis=0, new_axis=False) + tensor([1, 1, 2, 2, 3, 3]) + >>> repeat_array(arr, 2, axis=0, new_axis=True).shape + torch.Size([2, 3]) + """ + if new_axis: + # Add new axis and repeat along it + if axis is None: + axis = 0 + expanded = torch.unsqueeze(arr, dim=axis) + return torch.repeat_interleave(expanded, p, dim=axis) + else: + # Repeat along existing axis + return torch.repeat_interleave(arr, p, dim=axis) + def calculate_center_of_mass( centers: torch.Tensor, sizes: torch.Tensor ) -> torch.Tensor: