diff --git a/README.md b/README.md
index 2159b96a4b..4011049c4f 100644
--- a/README.md
+++ b/README.md
@@ -98,6 +98,7 @@ PaddleScience 是一个基于深度学习框架 PaddlePaddle 开发的科学计
| 热仿真 | [1D 换热器热仿真](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/heat_exchanger) | 机理驱动 | PI-DeepONet | 无监督学习 | - | - |
| 热仿真 | [2D 热仿真](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/heat_pinn) | 机理驱动 | PINN | 无监督学习 | - | [Paper](https://arxiv.org/abs/1711.10561)|
| 热仿真 | [2D 芯片热仿真](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/chip_heat) | 机理驱动 | PI-DeepONet | 无监督学习 | - | [Paper](https://doi.org/10.1063/5.0194245)|
+| 外流空气动力学 | [DoMINO](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/domino) | 数据驱动 | FNO | 监督学习 | [Data](https://caemldatasets.org/drivaerml/) | [Paper](https://arxiv.org/abs/2501.13350)|
材料科学(AI for Material)
diff --git a/docs/index.md b/docs/index.md index 128e4c9464..44f471ab52 100644 --- a/docs/index.md +++ b/docs/index.md @@ -133,6 +133,7 @@ | 热仿真 | [1D 换热器热仿真](./zh/examples/heat_exchanger.md) | 机理驱动 | PI-DeepONet | 无监督学习 | - | - | | 热仿真 | [2D 热仿真](./zh/examples/heat_pinn.md) | 机理驱动 | PINN | 无监督学习 | - | [Paper](https://arxiv.org/abs/1711.10561)| | 热仿真 | [2D 芯片热仿真](./zh/examples/chip_heat.md) | 机理驱动 | PI-DeepONet | 无监督学习 | - | [Paper](https://doi.org/10.1063/5.0194245)| +| 外流空气动力学 | [DoMINO](./zh/examples/domino.md) | 数据驱动 | FNO | 监督学习 | [Data](https://caemldatasets.org/drivaerml/) | [Paper](https://arxiv.org/abs/2501.13350)|材料科学(AI for Material)
diff --git a/docs/zh/examples/domino.md b/docs/zh/examples/domino.md new file mode 100644 index 0000000000..79225ac4a4 --- /dev/null +++ b/docs/zh/examples/domino.md @@ -0,0 +1,142 @@ +# DoMINO + +=== "模型训练命令" + + ``` sh + cd examples/domino + + # 1. Download the DrivAer ML dataset using the provided download_aws_dataset.sh script or using the Hugging Face repo(https://huggingface.co/datasets/neashton/drivaerml). + sh download_aws_dataset.sh + + # 2. Specify the configuration settings in `examples/domino/conf/config.yaml`. + + # 3. Run process_data.py. This will process VTP/VTU files and save them as npy for faster processing in DoMINO datapipe. Modify data_processor key in config file. Additionally, run cache_data.py to save outputs of DoMINO datapipe in the .npy files. The DoMINO datapipe is set up to calculate Signed Distance Field and Nearest Neighbor interpolations on-the-fly during training. Caching will save these as a preprocessing step and should be used in cases where the STL surface meshes are upwards of 30 million cells. The final processed dataset should be divided and saved into 2 directories, for training and validation. Specify these directories in conf/config.yaml. + # specify mode using `process`, set path to data_processor.output_dir and data_processor.input_dir + python3 domino.py + + # 4. run train, specify mode using `train`, set path to data.input_dir and data.input_dir_val + python3 domino.py + ``` + +=== "模型评估命令" + + 暂无 + +=== "模型导出命令" + + 暂无 + +=== "模型推理命令" + + ``` sh + cd examples/domino + # specify mode using `eval`, and set path to eval.test_path, eval.save_path and eval.checkpoint_name + python3 domino.py + ``` + +## 1. 背景简介 + +在现代工程产品的设计与开发过程中,数值模拟(如计算流体动力学,CFD)扮演着至关重要的角色。它们能够提供对复杂物理现象的精确预测,从而指导产品性能优化和设计迭代。然而,传统的高保真数值模拟方法,特别是针对具有复杂几何形状(例如汽车、飞机等)和大规模计算域的场景,往往需要耗费巨大的计算资源和时间。动辄数小时甚至数天的模拟周期,严重制约了设计迭代的效率和并行探索多个设计方案的可能性。 + +为了突破这一瓶颈,近年来研究人员积极探索将机器学习(ML)模型作为传统数值模拟的快速替代(即代理模型)。这些ML模型通过从大量的模拟数据中学习物理系统的输入-输出映射关系,从而在显著减少计算时间的同时,仍能保持可接受的精度。早期的ML代理模型在处理较小规模或简化问题时表现出一定的潜力。然而,当面对大型工程模拟时,这些模型常常暴露出局限性,例如在准确性和可扩展性方面存在瓶颈。许多现有方法依赖于对模拟网格进行大幅度的降采样,这不仅可能导致预测精度的下降,还会损害模型在未见数据上的泛化能力,限制了它们在实际复杂工程问题中的应用。因此,开发一种既能处理大规模数据、又能保持高精度和良好泛化能力的机器学习代理模型,成为了当前研究的迫切需求。 + +## 2. 问题定义 + +### 2.1 数据集 + +DrivAerML数据集是一个专门为汽车空气动力学机器学习研究而设计的大规模、高保真计算流体动力学(CFD)数据集。它包含了数百个经过几何变形的DrivAer溜背式汽车变体的高保真CFD模拟结果。该数据集以其庞大的数据量、高分辨率的网格以及显著的几何变化而著称,这为训练和测试能够处理复杂几何和流场的机器学习模型提供了理想的基准。与早期的DrivAerNet相比,DrivAerML及其后续版本如DrivAerNet++包含了更多样的几何设计(例如,DrivAerNet++包含8000个汽车设计,涵盖了传统的内燃机汽车和电动汽车的各种底盘和车轮设计),以及更为丰富的CFD模拟数据,包括STL格式的参数化汽车几何体、表面压力场数据,以及全三维的压力、速度和湍流场以及壁面剪切应力等。这些数据集的发布旨在为数据驱动的空气动力学设计提供大规模、高保真的数据,支持机器学习模型在空气动力学评估、生成设计等方面的训练。 + +### 2.2 主要问题 + +尽管机器学习代理模型在工程模拟领域展现出广阔前景,但其在大规模工程模拟中的应用仍面临诸多挑战。具体而言,本文所要解决的问题主要集中在以下几个方面: + +- *可扩展性与计算效率*: 传统的ML模型在处理大规模计算域和高分辨率网格(通常包含数亿甚至数十亿个网格元素)时,面临巨大的内存和计算资源需求。它们难以有效地扩展到如此庞大的数据量,导致训练和推理时间过长,无法满足实时或准实时的工程应用需求。 + +- *几何表示与泛化能力*: 许多现有ML方法难以有效地表示复杂的三维几何形状。它们通常试图学习一个全局的几何表示来预测整个计算域的解场。然而,这种全局表示往往是高维且密集的,对于具有精细特征的复杂几何体,很难准确捕捉到解场与局部几何细节之间的复杂关系。此外,许多模型对输入数据的空间结构敏感,导致在不同网格类型或点云分布之间泛化能力较差。例如,在结构化网格上训练的模型可能无法很好地应用于非结构化网格或任意分布的点云数据。 + +- *精度与长程相互作用*: 尽管一些方法能够扩展到大型网格,但由于无法有效捕捉长程物理相互作用(例如,上游几何形状对下游流场的影响),它们的预测精度往往受到限制。这意味着模型可能无法准确预测远离输入几何体或在复杂流态(如尾流)中的物理量。 + +- *迭代特性与物理一致性*: 传统的模拟过程通常是迭代的,逐步收敛到稳态解。ML代理模型往往直接预测最终解,缺乏迭代特性,这可能导致其无法充分利用物理约束或通过迭代修正来提高解的物理一致性。 + +针对这些问题,本文的目标是开发一种新的ML模型架构,能够有效地处理大规模几何数据,准确捕捉局部和长程依赖关系,同时具有良好的泛化能力,并在计算效率上取得显著提升,从而成为高保真工程模拟的实用替代方案。 + +## 3. 模型原理 + +为了克服上述挑战,本文提出了DOMINO(Decomposable Multi-scale Iterative Neural Operator)模型。DOMINO的核心思想是结合了点云处理、多尺度学习和迭代优化,以高效地建模大规模工程模拟。 + +### 3.1 模型架构概述 + +DOMINO模型以三维几何体的表面点云(通常从表面网格转换而来)作为输入。它首先在几何体周围定义一个三维包围盒作为整个计算域。关键创新在于,DOMINO不试图一次性学习整个计算域的全局解,而是将问题分解为局部的、可并行处理的子问题,并通过多尺度和迭代的方式逐步细化解。 + +### 3.2 几何表示与特征提取 + +- *输入表示*: 原始的三维几何体通过其表面网格(例如,三角形网格)或点云进行表示。这些几何信息被转换为一个标准化的N维结构化表示,用于定义计算域内的分辨率。 + +- *全局几何编码网络*: 这是一个多尺度的点卷积网络,旨在从输入的表面几何体中提取丰富的几何特征。 + + - *多尺度点卷积*: 针对表面几何体,采用一系列具有不同半径参数(核大小)的点卷积核。这使得网络能够同时捕捉几何体的精细局部特征和长程的几何相互作用(例如,车身不同部位之间的相对位置关系)。 + - *特征传播到计算域*: 提取到的表面几何特征需要传播到整个计算域。DOMINO提供了两种方法: + 1. *独立的点卷积*: 学习一组单独的多尺度点卷积核,将表面几何信息直接投影到计算域的规则网格或点云上。 + 2. *CNN块传播*: 使用包含卷积、池化和反池化层的U-Net风格的卷积神经网络(CNN)块,将表面包围盒网格上提取的特征有效、分层地传播到计算域包围盒网格上。这种方法能够更好地捕捉不同尺度下的空间上下文信息。 + +- *局部几何编码*: 尽管全局几何编码提供了丰富的上下文信息,但计算域中任何一点的解场主要受其局部物理环境的影响。因此,DOMINO设计了一种机制来从全局几何编码中提取特定点的局部几何编码。这通过在每个采样点周围定义一个子区域并对该子区域内的特征进行聚合来实现,确保模型能够关注与当前预测点最相关的几何特征。 + +### 3.3 迭代与多分辨率预测框架 + +DOMINO采用了一种多分辨率迭代方法来逐步精化预测结果,这模拟了传统数值模拟的迭代求解过程: + +- *计算模板*: 在计算域中随机或均匀采样一批离散点。对于每个采样点,在其周围采样一定数量(p个)的邻近点,形成一个“计算模板”(类似于有限差分或有限体积方法中的计算单元)。这些邻近点及其特征构成了当前点预测的局部上下文。 + +- *聚合网络*: 一个专门的聚合神经网络被设计用于处理每个采样点及其计算模板的输入特征。这个网络结合了局部几何编码、模板中点的坐标信息以及迭代过程中的当前预测值。它通过对这些局部信息的聚合和非线性变换,预测出当前采样点的新解值。 + +- *迭代细化*: DOMINO是一个迭代模型。在每次迭代中,模型会基于当前的解场和几何信息,利用聚合网络更新所有采样点的解。这种迭代过程允许模型逐步收敛到更准确的解,并可以模拟物理系统中的信息传播。 + +- *多分辨率*: DOMINO支持多分辨率处理。模型可以在不同分辨率的网格或点云上进行预测,通过迭代过程在粗粒度上捕捉大尺度特征,然后在细粒度上捕捉局部细节。 + +### 3.4 表面与体积变量预测 + +DOMINO能够同时预测表面变量(如压力系数$C_p$、壁面剪切应力$\tau_w$)和体积变量(如速度场$u$、压力$p$、湍流参数等)。由于表面变量和体积变量的物理特性和分布模式不同,DOMINO为表面预测和体积预测设计了独立的聚合神经网络。然而,共享的全局几何编码网络可以为两者提供统一的几何上下文信息,提高了模型的效率。 + +$$y_i=\Sigma_{j=0}^{j=n_y}f(\overrightarrow{x_i},\overrightarrow{x_j},d_{ij})$$ + +其中,$\overrightarrow{x_i},\overrightarrow{x_j}$分别表示不同的点云数据集,$d_{ij}$表示两者之间的距离,$n_y$表示独立的聚合神经网络。 + +### 3.5 损失函数与训练 + +模型通常通过最小化预测值与真实CFD模拟数据之间的L2范数误差进行训练。为了提高模型的泛化能力和鲁棒性,可能会结合其他损失项,例如物理约束损失或正则化项。 + +$$\epsilon=\frac{\sqrt(\Sigma(y_T^2))-y_2^P}{\sqrt(\Sigma(y_T^2))}$$ + +其中,$y_T^2$和$y_P^2$分别表示真实值与预测CFD模拟数据。 + +## 4. 完整代码 + +``` py linenums="1" title="examples/domino/domino.py" +--8<-- +examples/domino/domino.py +--8<-- +``` + +## 5. 结果展示 + +### 5.1 实验结果与性能评估 + +在DrivAerML数据集上的实验结果充分证明了DOMINO模型的有效性和优越性: + +- *准确捕获表面和体积流场*: DOMINO模型能够准确地预测汽车表面的压力分布和壁面剪切应力,这些是评估车辆空气动力学性能的关键指标。例如,在挡风玻璃、侧后视镜、车身底部等关键区域,模型的预测值与高保真CFD模拟结果高度吻合。对于体积流场,如速度、压力和湍流粘度,DOMINO也能在整个计算域内提供高精度的预测。 + +- *准确捕捉设计趋势与工程指标*: 除了流场的可视化对比,DOMINO还能准确预测关键工程指标,例如汽车的空气阻力(Drag Force)。模型不仅能给出准确的阻力值,还能捕获不同几何变体下的阻力设计趋势,这对于辅助工程师进行快速设计迭代至关重要。研究中提供了模拟值与DOMINO预测值之间的回归图,进一步证实了其预测精度。 + +- *出色的泛化能力*: 模型在训练集中未见过的几何变体上表现出强大的泛化能力,这表明DOMINO能够学习到通用的物理规律和几何-流场映射关系,而不仅仅是记忆训练数据。 + +- *网格独立性*: 这是DOMINO的一个显著优势。模型在均匀采样的点云上而非原始模拟网格上进行验证,证明了其预测能力不依赖于特定的网格结构。这意味着DOMINO可以应用于不同离散化方式或分辨率的数据,极大地增强了其在实际应用中的灵活性。 + +- *计算效率*: 尽管论文中没有直接给出具体的加速倍数,但作为基于深度学习的代理模型,DOMINO的目标就是在保持高精度的前提下,显著降低推理时间,从而实现近乎实时的空气动力学评估,这对于迭代设计和优化过程至关重要。通过利用局部信息和迭代细化,DOMINO避免了处理整个大规模网格的内存和计算瓶颈。 + +### 5.2 局限性与未来工作 + +尽管DOMINO取得了显著进展,但论文可能也指出了未来改进的方向。例如,对于一些稀疏且复杂的物理量(如湍流粘度),特别是在远离几何体的区域,预测精度仍有提升空间。未来工作可能包括探索更先进的神经网络架构、更有效的损失函数、或者结合更多物理约束来进一步提高模型在各种复杂流态下的预测精度和鲁棒性。 + +## 6. 参考资料 + +- [DoMINO: A Decomposable Multi-scale Iterative Neural Operator for Modeling Large Scale Engineering Simulations](https://arxiv.org/abs/2501.13350) diff --git a/examples/domino/conf/config.yaml b/examples/domino/conf/config.yaml new file mode 100644 index 0000000000..1c029b7c89 --- /dev/null +++ b/examples/domino/conf/config.yaml @@ -0,0 +1,125 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +project: # Project name + name: AWS_Dataset + +mode: train # process, train, eval +seed: 42 +exp_tag: 1 # Experiment tag +# Main output directory. +output: outputs/${project.name}/${exp_tag} + +hydra: # Hydra config + run: + dir: ${output} + output_subdir: hydra # Default is .hydra which causes files not being uploaded in W&B. + +data: # Input directory for training and validation data + input_dir: outputs/volume_data/ + input_dir_val: outputs/volume_data/ + bounding_box: # Bounding box dimensions for computational domain + min: [-3.5, -2.25 , -0.32] + max: [8.5 , 2.25 , 3.00] + bounding_box_surface: # Bounding box dimensions for car surface + min: [-1.1, -1.2 , -0.32] + max: [4.5 , 1.2 , 1.2] + +# The directory to search for checkpoints to continue training. +resume_dir: ${output}/models + +variables: + surface: + solution: + # The following is for AWS DrivAer dataset. + pMeanTrim: scalar + wallShearStressMeanTrim: vector + volume: + solution: + # The following is for AWS DrivAer dataset. + UMeanTrim: vector + pMeanTrim: scalar + nutMeanTrim: scalar + +model: + model_type: combined # train which model? surface, volume, combined + loss_function: "mse" # mse or rmse + interp_res: [64, 32, 24] # resolution of latent space + use_sdf_in_basis_func: true # SDF in basis function network + positional_encoding: false # calculate positional encoding? + volume_points_sample: 1024 # Number of points to sample in volume per epoch + surface_points_sample: 1024 # Number of points to sample on surface per epoch + geom_points_sample: 2_000 # Number of points to sample on STL per epoch + surface_neighbors: true # Pre-compute surface neighborhood from input data + num_surface_neighbors: 7 # How many neighbors? + use_surface_normals: true # Use surface normals and surface areas for surface computation? + use_only_normals: true # Use only surface normals and not surface area + integral_loss_scaling_factor: 0 # Scale integral loss by this factor + normalization: min_max_scaling # or mean_std_scaling + encode_parameters: true # encode inlet velocity and air density in the model + geometry_rep: # Hyperparameters for geometry representation network + base_filters: 16 + geo_conv: + base_neurons: 32 # 256 or 64 + base_neurons_out: 1 + radius_short: 0.1 + radius_long: 0.5 # 1.0, 1.5 + hops: 1 + geo_processor: + base_filters: 8 + geo_processor_sdf: + base_filters: 8 + nn_basis_functions: # Hyperparameters for basis function network + base_layer: 512 + aggregation_model: # Hyperparameters for aggregation network + base_layer: 512 + position_encoder: # Hyperparameters for position encoding network + base_neurons: 512 + geometry_local: # Hyperparameters for local geometry extraction + neighbors_in_radius: 64 + radius: 0.05 # 0.2 in expt 7 + base_layer: 512 + parameter_model: + base_layer: 512 + scaling_params: [30.0, 1.226] # [inlet_velocity, air_density] + +train: # Training configurable parameters + epochs: 50 + checkpoint_interval: 1 + dataloader: + batch_size: 1 + sampler: + shuffle: true + drop_last: false + checkpoint_dir: outputs/AWS_Dataset/3/models/ + +val: # Validation configurable parameters + dataloader: + batch_size: 1 + sampler: + shuffle: true + drop_last: false + +eval: # Testing configurable parameters + test_path: drivaer_data_full + save_path: outputs/mesh_predictions_surf_final1/ + checkpoint_name: outputs/AWS_Dataset/1/models/DoMINO.0.30.pdparams + +data_processor: # Data processor configurable parameters + kind: drivaer_aws # must be either drivesim or drivaer_aws + output_dir: data/volume_data/ + input_dir: drivaer_aws/drivaer_data_full/ + num_processors: 12 diff --git a/examples/domino/domino.py b/examples/domino/domino.py new file mode 100644 index 0000000000..8616b3e633 --- /dev/null +++ b/examples/domino/domino.py @@ -0,0 +1,1682 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# refs: https://github.com/NVIDIA/physicsnemo/tree/main/examples/cfd/external_aerodynamics/domino + +import multiprocessing +import os +import re +import time + +import hydra +import numpy as np +import paddle +import paddle.distributed as dist +import pyvista as pv +import vtk +from hydra.utils import to_absolute_path +from omegaconf import DictConfig +from omegaconf import OmegaConf +from paddle import DataParallel +from paddle.amp import GradScaler +from paddle.amp import auto_cast +from paddle.io import DataLoader +from paddle.io import DistributedBatchSampler +from scipy.spatial import KDTree +from vtk.util import numpy_support + +from ppsci.arch.physicsnemo import DoMINO +from ppsci.arch.physicsnemo import create_directory +from ppsci.arch.physicsnemo import get_fields +from ppsci.arch.physicsnemo import get_node_to_elem +from ppsci.arch.physicsnemo import get_volume_data +from ppsci.arch.physicsnemo import load_checkpoint +from ppsci.arch.physicsnemo import mean_std_sampling +from ppsci.arch.physicsnemo import save_checkpoint +from ppsci.arch.physicsnemo import write_to_vtp +from ppsci.arch.physicsnemo import write_to_vtu +from ppsci.data.dataset.domino_datapipe import DoMINODataPipe +from ppsci.data.dataset.domino_datapipe import OpenFoamDataset +from ppsci.data.dataset.domino_datapipe import cal_normal_positional_encoding +from ppsci.data.dataset.domino_datapipe import calculate_center_of_mass +from ppsci.data.dataset.domino_datapipe import create_grid +from ppsci.data.dataset.domino_datapipe import get_filenames +from ppsci.data.dataset.domino_datapipe import normalize +from ppsci.data.dataset.domino_datapipe import unnormalize +from ppsci.data.process.openfoam import process_files +from ppsci.utils.sdf import signed_distance_field + +AIR_DENSITY = 1.205 +STREAM_VELOCITY = 30.00 + +paddle.set_device("gpu") + + +def process(cfg: DictConfig): + print(f"Config summary:\n{OmegaConf.to_yaml(cfg, sort_keys=True)}") + volume_variable_names = list(cfg.variables.volume.solution.keys()) + num_vol_vars = 0 + for j in volume_variable_names: + if cfg.variables.volume.solution[j] == "vector": + num_vol_vars += 3 + else: + num_vol_vars += 1 + + surface_variable_names = list(cfg.variables.surface.solution.keys()) + num_surf_vars = 0 + for j in surface_variable_names: + if cfg.variables.surface.solution[j] == "vector": + num_surf_vars += 3 + else: + num_surf_vars += 1 + + fm_data = OpenFoamDataset( + cfg.data_processor.input_dir, + kind=cfg.data_processor.kind, + volume_variables=volume_variable_names, + surface_variables=surface_variable_names, + model_type=cfg.model.model_type, + ) + output_dir = cfg.data_processor.output_dir + create_directory(output_dir) # noqa: F405 + n_processors = cfg.data_processor.num_processors + + num_files = len(fm_data) + ids = np.arange(num_files) + num_elements = int(num_files / n_processors) + 1 + process_list = [] + ctx = multiprocessing.get_context("spawn") + for i in range(n_processors): + if i != n_processors - 1: + sf = ids[i * num_elements : i * num_elements + num_elements] + else: + sf = ids[i * num_elements :] + # print(sf) + process = ctx.Process(target=process_files, args=(sf, i, fm_data, output_dir)) + + process.start() + process_list.append(process) + + for process in process_list: + process.join() + + +def relative_loss_fn(output, target, padded_value=-10): + mask = abs(target - padded_value) > 1e-3 + mask = mask.to(dtype=output.dtype) + masked_loss = paddle.sum(((output - target) ** 2.0) * mask, (0, 1)) / paddle.sum( + mask, (0, 1) + ) + masked_truth = paddle.sum(((target) ** 2.0) * mask, (0, 1)) / paddle.sum( + mask, (0, 1) + ) + loss = paddle.mean(masked_loss / masked_truth) + return loss + + +def mse_loss_fn(output, target, padded_value=-10): + mask = abs(target - padded_value) > 1e-3 + mask = mask.to(dtype=output.dtype) + masked_loss = paddle.sum(((output - target) ** 2.0) * mask, (0, 1)) / paddle.sum( + mask, (0, 1) + ) + loss = paddle.mean(masked_loss) + return loss + + +def mse_loss_fn_surface(output, target, normals, padded_value=-10): + masked_loss_pres = paddle.mean( + ((output[:, :, :1] - target[:, :, :1]) ** 2.0), (0, 1) + ) + + ws_x_true = target[:, :, 1:2] + ws_x_pred = output[:, :, 1:2] + masked_loss_ws_x = paddle.mean(((ws_x_pred - ws_x_true) ** 2.0), (0, 1)) + + ws_y_true = target[:, :, 2:3] + ws_y_pred = output[:, :, 2:3] + masked_loss_ws_y = paddle.mean(((ws_y_pred - ws_y_true) ** 2.0), (0, 1)) + + ws_z_true = target[:, :, 3:4] + ws_z_pred = output[:, :, 3:4] + masked_loss_ws_z = paddle.mean(((ws_z_pred - ws_z_true) ** 2.0), (0, 1)) + + loss = ( + paddle.mean(masked_loss_pres) + + paddle.mean(masked_loss_ws_x) + + paddle.mean(masked_loss_ws_y) + + paddle.mean(masked_loss_ws_z) + ) + loss = loss / 4 + return loss + + +def relative_loss_fn_surface(output, target, normals, padded_value=-10): + masked_loss_pres = paddle.mean( + ((output[:, :, :1] - target[:, :, :1]) ** 2.0), (0, 1) + ) / paddle.mean(((target[:, :, :1]) ** 2.0), (0, 1)) + + ws_x_true = target[:, :, 1:2] + ws_x_pred = output[:, :, 1:2] + masked_loss_ws_x = paddle.mean( + ((ws_x_pred - ws_x_true) ** 2.0), (0, 1) + ) / paddle.mean(((ws_x_true) ** 2.0), (0, 1)) + + ws_y_true = target[:, :, 2:3] + ws_y_pred = output[:, :, 2:3] + masked_loss_ws_y = paddle.mean( + ((ws_y_pred - ws_y_true) ** 2.0), (0, 1) + ) / paddle.mean(((ws_y_true) ** 2.0), (0, 1)) + + ws_z_true = target[:, :, 3:4] + ws_z_pred = output[:, :, 3:4] + masked_loss_ws_z = paddle.mean( + ((ws_z_pred - ws_z_true) ** 2.0), (0, 1) + ) / paddle.mean(((ws_z_true) ** 2.0), (0, 1)) + + loss = ( + paddle.mean(masked_loss_pres) + + paddle.mean(masked_loss_ws_x) + + paddle.mean(masked_loss_ws_y) + + paddle.mean(masked_loss_ws_z) + ) + loss = loss / 4 + return loss + + +def relative_loss_fn_area(output, target, normals, area, padded_value=-10): + scale_factor = 1.0 # Get this from the dataset + area = area * 10**4 + pres_x_true = target[:, :, :1] * normals[:, :, 0:1] * area * scale_factor**2.0 + pres_x_pred = output[:, :, :1] * normals[:, :, 0:1] * area * scale_factor**2.0 + + masked_loss_pres_x = paddle.mean( + ((pres_x_pred - pres_x_true) ** 2.0), (0, 1) + ) / paddle.mean(((pres_x_true) ** 2.0), (0, 1)) + + ws_x_true = target[:, :, 1:2] * area * scale_factor**2.0 + ws_x_pred = output[:, :, 1:2] * area * scale_factor**2.0 + masked_loss_ws_x = paddle.mean( + ((ws_x_pred - ws_x_true) ** 2.0), (0, 1) + ) / paddle.mean(((ws_x_true) ** 2.0), (0, 1)) + + ws_y_true = target[:, :, 2:3] * area * scale_factor**2.0 + ws_y_pred = output[:, :, 2:3] * area * scale_factor**2.0 + masked_loss_ws_y = paddle.mean( + ((ws_y_pred - ws_y_true) ** 2.0), (0, 1) + ) / paddle.mean(((ws_y_true) ** 2.0), (0, 1)) + + ws_z_true = target[:, :, 3:4] * area * scale_factor**2.0 + ws_z_pred = output[:, :, 3:4] * area * scale_factor**2.0 + masked_loss_ws_z = paddle.mean( + ((ws_z_pred - ws_z_true) ** 2.0), (0, 1) + ) / paddle.mean(((ws_z_true) ** 2.0), (0, 1)) + + loss = ( + paddle.mean(masked_loss_pres_x) + + paddle.mean(masked_loss_ws_x) + + paddle.mean(masked_loss_ws_y) + + paddle.mean(masked_loss_ws_z) + ) + loss = loss / 4 + return loss + + +def mse_loss_fn_area(output, target, normals, area, padded_value=-10): + scale_factor = 1.0 # Get this from the dataset + area = area * 10**4 + + pres_x_true = target[:, :, :1] * normals[:, :, 0:1] * area * scale_factor**2.0 + pres_x_pred = output[:, :, :1] * normals[:, :, 0:1] * area * scale_factor**2.0 + + masked_loss_pres_x = paddle.mean(((pres_x_pred - pres_x_true) ** 2.0), (0, 1)) + + ws_x_true = target[:, :, 1:2] * area * scale_factor**2.0 + ws_x_pred = output[:, :, 1:2] * area * scale_factor**2.0 + masked_loss_ws_x = paddle.mean(((ws_x_pred - ws_x_true) ** 2.0), (0, 1)) + + ws_y_true = target[:, :, 2:3] * area * scale_factor**2.0 + ws_y_pred = output[:, :, 2:3] * area * scale_factor**2.0 + masked_loss_ws_y = paddle.mean(((ws_y_pred - ws_y_true) ** 2.0), (0, 1)) + + ws_z_true = target[:, :, 3:4] * area * scale_factor**2.0 + ws_z_pred = output[:, :, 3:4] * area * scale_factor**2.0 + masked_loss_ws_z = paddle.mean(((ws_z_pred - ws_z_true) ** 2.0), (0, 1)) + + loss = ( + paddle.mean(masked_loss_pres_x) + + paddle.mean(masked_loss_ws_x) + + paddle.mean(masked_loss_ws_y) + + paddle.mean(masked_loss_ws_z) + ) + loss = loss / 4 + return loss + + +def integral_loss_fn(output, target, area, normals, padded_value=-10): + vel_inlet = 30.0 # Get this from the dataset + mask = abs(target - padded_value) > 1e-3 + mask = mask.to(dtype=output.dtype) + area = paddle.unsqueeze(area, -1) + output_true = target * mask * area * (vel_inlet) ** 2.0 + output_pred = output * mask * area * (vel_inlet) ** 2.0 + + output_true[:, :, 0] = output_true[:, :, 0] * normals[:, :, 0] + output_pred[:, :, 0] = output_pred[:, :, 0] * normals[:, :, 0] + + masked_pred = paddle.sum(output_pred, (1)) + masked_truth = paddle.sum(output_true, (1)) + + loss = (masked_pred - masked_truth) ** 2.0 + loss = paddle.mean(loss) + return loss + + +def integral_loss_fn_new(output, target, area, normals, padded_value=-10): + drag_loss = drag_loss_fn(output, target, area, normals, padded_value=-10) + lift_loss = lift_loss_fn(output, target, area, normals, padded_value=-10) + return lift_loss + drag_loss + + +def lift_loss_fn(output, target, area, normals, padded_value=-10): + vel_inlet = 30.0 # Get this from the dataset + mask = abs(target - padded_value) > 1e-3 + mask = mask.to(dtype=output.dtype) + area = paddle.unsqueeze(area, -1) + output_true = target * mask * area * (vel_inlet) ** 2.0 + output_pred = output * mask * area * (vel_inlet) ** 2.0 + + pres_true = output_true[:, :, 0] * normals[:, :, 2] + pres_pred = output_pred[:, :, 0] * normals[:, :, 2] + + wz_true = output_true[:, :, -1] + wz_pred = output_pred[:, :, -1] + + masked_pred = paddle.sum(pres_pred + wz_pred, (1)) / ( + paddle.sum(area) * (vel_inlet) ** 2.0 + ) + masked_truth = paddle.sum(pres_true + wz_true, (1)) / ( + paddle.sum(area) * (vel_inlet) ** 2.0 + ) + + loss = (masked_pred - masked_truth) ** 2.0 + loss = paddle.mean(loss) + return loss + + +def drag_loss_fn(output, target, area, normals, padded_value=-10): + vel_inlet = 30.0 # Get this from the dataset + mask = abs(target - padded_value) > 1e-3 + mask = mask.to(dtype=output.dtype) + area = paddle.unsqueeze(area, -1) + output_true = target * mask * area * (vel_inlet) ** 2.0 + output_pred = output * mask * area * (vel_inlet) ** 2.0 + + pres_true = output_true[:, :, 0] * normals[:, :, 0] + pres_pred = output_pred[:, :, 0] * normals[:, :, 0] + + wx_true = output_true[:, :, 1] + wx_pred = output_pred[:, :, 1] + + masked_pred = paddle.sum(pres_pred + wx_pred, (1)) / ( + paddle.sum(area) * (vel_inlet) ** 2.0 + ) + masked_truth = paddle.sum(pres_true + wx_true, (1)) / ( + paddle.sum(area) * (vel_inlet) ** 2.0 + ) + + loss = (masked_pred - masked_truth) ** 2.0 + loss = paddle.mean(loss) + return loss + + +def validation_step( + dataloader, + model, + device, + use_sdf_basis=False, + use_surface_normals=False, + integral_scaling_factor=1.0, + loss_fn_type="mse", +): + running_vloss = 0.0 + with paddle.no_grad(): + for i_batch, sampled_batched in enumerate(dataloader): + prediction_vol, prediction_surf = model(sampled_batched) + + if prediction_vol is not None: + target_vol = sampled_batched["volume_fields"] + if loss_fn_type == "rmse": + loss_norm_vol = relative_loss_fn( + prediction_vol, target_vol, padded_value=-10 + ) + else: + loss_norm_vol = mse_loss_fn( + prediction_vol, target_vol, padded_value=-10 + ) + + if prediction_surf is not None: + target_surf = sampled_batched["surface_fields"] + surface_normals = sampled_batched["surface_normals"] + surface_areas = sampled_batched["surface_areas"] + if loss_fn_type == "rmse": + loss_norm_surf = relative_loss_fn_surface( + prediction_surf, target_surf, surface_normals, padded_value=-10 + ) + loss_norm_surf_area = relative_loss_fn_area( + prediction_surf, + target_surf, + surface_normals, + surface_areas, + padded_value=-10, + ) + else: + loss_norm_surf = mse_loss_fn_surface( + prediction_surf, target_surf, surface_normals, padded_value=-10 + ) + loss_norm_surf_area = mse_loss_fn_area( + prediction_surf, + target_surf, + surface_normals, + surface_areas, + padded_value=-10, + ) + loss_integral = ( + integral_loss_fn_new( + prediction_surf, + target_surf, + surface_areas, + surface_normals, + padded_value=-10, + ) + ) * integral_scaling_factor + + if prediction_surf is not None and prediction_vol is not None: + vloss = ( + loss_norm_vol + + 0.5 * loss_norm_surf + + loss_integral + + 0.5 * loss_norm_surf_area + ) + elif prediction_vol is not None: + vloss = loss_norm_vol + elif prediction_surf is not None: + vloss = 0.5 * loss_norm_surf + loss_integral + 0.5 * loss_norm_surf_area + + running_vloss += vloss + + avg_vloss = running_vloss / (i_batch + 1) + + return avg_vloss + + +def train_epoch( + dataloader, + model, + optimizer, + scaler, + epoch_index, + device, + integral_scaling_factor, + loss_fn_type, +): + + running_loss = 0.0 + last_loss = 0.0 + loss_interval = 1 + + for i_batch, sampled_batched in enumerate(dataloader): + with auto_cast(enable=False): + prediction_vol, prediction_surf = model(sampled_batched) + + if prediction_vol is not None: + target_vol = sampled_batched["volume_fields"] + if loss_fn_type == "rmse": + loss_norm_vol = relative_loss_fn( + prediction_vol, target_vol, padded_value=-10 + ) + else: + loss_norm_vol = mse_loss_fn( + prediction_vol, target_vol, padded_value=-10 + ) + + if prediction_surf is not None: + + target_surf = sampled_batched["surface_fields"] + surface_areas = sampled_batched["surface_areas"] + surface_normals = sampled_batched["surface_normals"] + if loss_fn_type == "rmse": + loss_norm_surf = relative_loss_fn_surface( + prediction_surf, target_surf, surface_normals, padded_value=-10 + ) + loss_norm_surf_area = relative_loss_fn_area( + prediction_surf, + target_surf, + surface_normals, + surface_areas, + padded_value=-10, + ) + else: + loss_norm_surf = mse_loss_fn_surface( + prediction_surf, target_surf, surface_normals, padded_value=-10 + ) + loss_norm_surf_area = mse_loss_fn_area( + prediction_surf, + target_surf, + surface_normals, + surface_areas, + padded_value=-10, + ) + loss_integral = ( + integral_loss_fn_new( + prediction_surf, + target_surf, + surface_areas, + surface_normals, + padded_value=-10, + ) + ) * integral_scaling_factor + + if prediction_vol is not None and prediction_surf is not None: + loss_norm = ( + loss_norm_vol + + 0.5 * loss_norm_surf + + loss_integral + + 0.5 * loss_norm_surf_area + ) + elif prediction_vol is not None: + loss_norm = loss_norm_vol + elif prediction_surf is not None: + loss_norm = ( + 0.5 * loss_norm_surf + loss_integral + 0.5 * loss_norm_surf_area + ) + + loss = loss_norm + loss = loss / loss_interval + scaler.scale(loss).backward() + + if ((i_batch + 1) % loss_interval == 0) or (i_batch + 1 == len(dataloader)): + scaler.step(optimizer) + scaler.update() + optimizer.clear_gradients() + # Gather data and report + running_loss += loss.item() + + if prediction_vol is not None and prediction_surf is not None: + print( + f"Device {device}, batch processed: {i_batch + 1}, loss volume: {loss_norm_vol:.5f} \ + , loss surface: {loss_norm_surf:.5f}, loss integral: {loss_integral:.5f}, loss surface area: {loss_norm_surf_area:.5f}" + ) + elif prediction_vol is not None: + print( + f"Device {device}, batch processed: {i_batch + 1}, loss volume: {loss_norm_vol:.5f}" + ) + elif prediction_surf is not None: + print( + f"Device {device}, batch processed: {i_batch + 1} \ + , loss surface: {loss_norm_surf:.5f}, loss integral: {loss_integral:.5f}, loss surface area: {loss_norm_surf_area:.5f}" + ) + + last_loss = running_loss / (i_batch + 1) # loss per batch + print(f" Device {device}, batch: {i_batch + 1}, loss norm: {loss:.5f}") + tb_x = epoch_index * len(dataloader) + i_batch + 1 + print(f"Loss/train: {last_loss}/{tb_x}") + + return last_loss + + +def compute_scaling_factors(cfg: DictConfig): + + model_type = cfg.model.model_type + + if model_type == "volume" or model_type == "combined": + vol_save_path = os.path.join( + "outputs", cfg.project.name, "volume_scaling_factors.npy" + ) + if not os.path.exists(vol_save_path): + input_path = cfg.data.input_dir + + volume_variable_names = list(cfg.variables.volume.solution.keys()) + + fm_dict = DoMINODataPipe( + input_path, + phase="train", + grid_resolution=cfg.model.interp_res, + volume_variables=volume_variable_names, + surface_variables=None, + normalize_coordinates=True, + sampling=False, + sample_in_bbox=True, + volume_points_sample=cfg.model.volume_points_sample, + geom_points_sample=cfg.model.geom_points_sample, + positional_encoding=cfg.model.positional_encoding, + model_type=cfg.model.model_type, + bounding_box_dims=cfg.data.bounding_box, + bounding_box_dims_surf=cfg.data.bounding_box_surface, + compute_scaling_factors=True, + ) + + # Calculate mean + if cfg.model.normalization == "mean_std_scaling": + for j in range(len(fm_dict)): + d_dict = fm_dict[j] + vol_fields = d_dict["volume_fields"] + + if vol_fields is not None: + if j == 0: + vol_fields_sum = np.mean(vol_fields, 0) + else: + vol_fields_sum += np.mean(vol_fields, 0) + else: + vol_fields_sum = 0.0 + + vol_fields_mean = vol_fields_sum / len(fm_dict) + + for j in range(len(fm_dict)): + d_dict = fm_dict[j] + vol_fields = d_dict["volume_fields"] + + if vol_fields is not None: + if j == 0: + vol_fields_sum_square = np.mean( + (vol_fields - vol_fields_mean) ** 2.0, 0 + ) + else: + vol_fields_sum_square += np.mean( + (vol_fields - vol_fields_mean) ** 2.0, 0 + ) + else: + vol_fields_sum_square = 0.0 + + vol_fields_std = np.sqrt(vol_fields_sum_square / len(fm_dict)) + + vol_scaling_factors = [vol_fields_mean, vol_fields_std] + + if cfg.model.normalization == "min_max_scaling": + for j in range(len(fm_dict)): + d_dict = fm_dict[j] + vol_fields = d_dict["volume_fields"] + + if vol_fields is not None: + vol_mean = np.mean(vol_fields, 0) + vol_std = np.std(vol_fields, 0) + vol_idx = mean_std_sampling( + vol_fields, vol_mean, vol_std, tolerance=12.0 + ) + vol_fields_sampled = np.delete(vol_fields, vol_idx, axis=0) + if j == 0: + vol_fields_max = np.amax(vol_fields_sampled, 0) + vol_fields_min = np.amin(vol_fields_sampled, 0) + else: + vol_fields_max1 = np.amax(vol_fields_sampled, 0) + vol_fields_min1 = np.amin(vol_fields_sampled, 0) + + for k in range(vol_fields.shape[-1]): + if vol_fields_max1[k] > vol_fields_max[k]: + vol_fields_max[k] = vol_fields_max1[k] + + if vol_fields_min1[k] < vol_fields_min[k]: + vol_fields_min[k] = vol_fields_min1[k] + else: + vol_fields_max = 0.0 + vol_fields_min = 0.0 + + if j > 20: + break + vol_scaling_factors = [vol_fields_max, vol_fields_min] + np.save(vol_save_path, vol_scaling_factors) + + if model_type == "surface" or model_type == "combined": + surf_save_path = os.path.join( + "outputs", cfg.project.name, "surface_scaling_factors.npy" + ) + + if not os.path.exists(surf_save_path): + input_path = cfg.data.input_dir + + volume_variable_names = list(cfg.variables.volume.solution.keys()) + surface_variable_names = list(cfg.variables.surface.solution.keys()) + + fm_dict = DoMINODataPipe( + input_path, + phase="train", + grid_resolution=cfg.model.interp_res, + volume_variables=None, + surface_variables=surface_variable_names, + normalize_coordinates=True, + sampling=False, + sample_in_bbox=True, + volume_points_sample=cfg.model.volume_points_sample, + geom_points_sample=cfg.model.geom_points_sample, + positional_encoding=cfg.model.positional_encoding, + model_type=cfg.model.model_type, + bounding_box_dims=cfg.data.bounding_box, + bounding_box_dims_surf=cfg.data.bounding_box_surface, + compute_scaling_factors=True, + ) + + # Calculate mean + if cfg.model.normalization == "mean_std_scaling": + for j in range(len(fm_dict)): + d_dict = fm_dict[j] + surf_fields = d_dict["surface_fields"] + + if surf_fields is not None: + if j == 0: + surf_fields_sum = np.mean(surf_fields, 0) + else: + surf_fields_sum += np.mean(surf_fields, 0) + else: + surf_fields_sum = 0.0 + + surf_fields_mean = surf_fields_sum / len(fm_dict) + + for j in range(len(fm_dict)): + d_dict = fm_dict[j] + surf_fields = d_dict["surface_fields"] + + if surf_fields is not None: + if j == 0: + surf_fields_sum_square = np.mean( + (surf_fields - surf_fields_mean) ** 2.0, 0 + ) + else: + surf_fields_sum_square += np.mean( + (surf_fields - surf_fields_mean) ** 2.0, 0 + ) + else: + surf_fields_sum_square = 0.0 + + surf_fields_std = np.sqrt(surf_fields_sum_square / len(fm_dict)) + + surf_scaling_factors = [surf_fields_mean, surf_fields_std] + + if cfg.model.normalization == "min_max_scaling": + for j in range(len(fm_dict)): + d_dict = fm_dict[j] + surf_fields = d_dict["surface_fields"] + + if surf_fields is not None: + surf_mean = np.mean(surf_fields, 0) + surf_std = np.std(surf_fields, 0) + surf_idx = mean_std_sampling( + surf_fields, surf_mean, surf_std, tolerance=12.0 + ) + surf_fields_sampled = np.delete(surf_fields, surf_idx, axis=0) + if j == 0: + surf_fields_max = np.amax(surf_fields_sampled, 0) + surf_fields_min = np.amin(surf_fields_sampled, 0) + else: + surf_fields_max1 = np.amax(surf_fields_sampled, 0) + surf_fields_min1 = np.amin(surf_fields_sampled, 0) + + for k in range(surf_fields.shape[-1]): + if surf_fields_max1[k] > surf_fields_max[k]: + surf_fields_max[k] = surf_fields_max1[k] + + if surf_fields_min1[k] < surf_fields_min[k]: + surf_fields_min[k] = surf_fields_min1[k] + else: + surf_fields_max = 0.0 + surf_fields_min = 0.0 + + if j > 20: + break + + surf_scaling_factors = [surf_fields_max, surf_fields_min] + np.save(surf_save_path, surf_scaling_factors) + + +def train(cfg: DictConfig) -> None: + compute_scaling_factors(cfg) + input_path = cfg.data.input_dir + input_path_val = cfg.data.input_dir_val + model_type = cfg.model.model_type + + dist.init_parallel_env() + + print(f"Config summary:\n{OmegaConf.to_yaml(cfg, sort_keys=True)}") + + num_vol_vars = 0 + volume_variable_names = [] + if model_type == "volume" or model_type == "combined": + volume_variable_names = list(cfg.variables.volume.solution.keys()) + for j in volume_variable_names: + if cfg.variables.volume.solution[j] == "vector": + num_vol_vars += 3 + else: + num_vol_vars += 1 + else: + num_vol_vars = None + + num_surf_vars = 0 + surface_variable_names = [] + if model_type == "surface" or model_type == "combined": + surface_variable_names = list(cfg.variables.surface.solution.keys()) + num_surf_vars = 0 + for j in surface_variable_names: + if cfg.variables.surface.solution[j] == "vector": + num_surf_vars += 3 + else: + num_surf_vars += 1 + else: + num_surf_vars = None + + vol_save_path = os.path.join( + "outputs", cfg.project.name, "volume_scaling_factors.npy" + ) + surf_save_path = os.path.join( + "outputs", cfg.project.name, "surface_scaling_factors.npy" + ) + if os.path.exists(vol_save_path) and os.path.exists(surf_save_path): + vol_factors = np.load(vol_save_path) + surf_factors = np.load(surf_save_path) + else: + vol_factors = None + surf_factors = None + + train_dataset = DoMINODataPipe( + input_path, + phase="train", + grid_resolution=cfg.model.interp_res, + volume_variables=volume_variable_names, + surface_variables=surface_variable_names, + normalize_coordinates=True, + sampling=True, + sample_in_bbox=True, + volume_points_sample=cfg.model.volume_points_sample, + surface_points_sample=cfg.model.surface_points_sample, + geom_points_sample=cfg.model.geom_points_sample, + positional_encoding=cfg.model.positional_encoding, + volume_factors=vol_factors, + surface_factors=surf_factors, + scaling_type=cfg.model.normalization, + model_type=cfg.model.model_type, + bounding_box_dims=cfg.data.bounding_box, + bounding_box_dims_surf=cfg.data.bounding_box_surface, + num_surface_neighbors=cfg.model.num_surface_neighbors, + ) + + val_dataset = DoMINODataPipe( + input_path_val, + phase="val", + grid_resolution=cfg.model.interp_res, + volume_variables=volume_variable_names, + surface_variables=surface_variable_names, + normalize_coordinates=True, + sampling=True, + sample_in_bbox=True, + volume_points_sample=cfg.model.volume_points_sample, + surface_points_sample=cfg.model.surface_points_sample, + geom_points_sample=cfg.model.geom_points_sample, + positional_encoding=cfg.model.positional_encoding, + volume_factors=vol_factors, + surface_factors=surf_factors, + scaling_type=cfg.model.normalization, + model_type=cfg.model.model_type, + bounding_box_dims=cfg.data.bounding_box, + bounding_box_dims_surf=cfg.data.bounding_box_surface, + num_surface_neighbors=cfg.model.num_surface_neighbors, + ) + print(f">>>>>> paddle.distributed.get_rank(): {paddle.distributed.get_rank()}") + print( + f">>>>>> paddle.distributed.get_world_size(): {paddle.distributed.get_world_size()}" + ) + train_sampler = DistributedBatchSampler( + train_dataset, + batch_size=1, + num_replicas=paddle.distributed.get_world_size(), + rank=paddle.distributed.get_rank(), + **cfg.train.sampler, + ) + + val_sampler = DistributedBatchSampler( + val_dataset, + batch_size=1, + num_replicas=paddle.distributed.get_world_size(), + rank=paddle.distributed.get_rank(), + **cfg.val.sampler, + ) + + train_dataloader = DataLoader(train_dataset, **cfg.train.dataloader) + val_dataloader = DataLoader(val_dataset, **cfg.val.dataloader) + + model = DoMINO( + input_features=3, + output_features_vol=num_vol_vars, + output_features_surf=num_surf_vars, + model_parameters=cfg.model, + ) + + if paddle.distributed.get_world_size() > 1: + model = DataParallel( + model, + ) + + optimizer = paddle.optimizer.Adam( + parameters=model.parameters(), learning_rate=0.001 + ) + scheduler = paddle.optimizer.lr.MultiStepDecay( + learning_rate=optimizer.get_lr(), + milestones=[50, 100, 150, 200, 250, 300, 350, 400], + gamma=0.5, + ) + optimizer.set_lr_scheduler(scheduler) + + # Initialize the scaler for mixed precision + scaler = GradScaler() + + epoch_number = 0 + + model_save_path = os.path.join(cfg.output, "models") + param_save_path = os.path.join(cfg.output, "param") + best_model_path = os.path.join(model_save_path, "best_model") + if paddle.distributed.get_rank() == 0: + create_directory(model_save_path) + create_directory(param_save_path) + create_directory(best_model_path) + + if paddle.distributed.get_world_size() > 1: + paddle.distributed.barrier() + + init_epoch = load_checkpoint( + to_absolute_path(cfg.resume_dir), + models=model, + optimizer=optimizer, + scheduler=scheduler, + scaler=scaler, + ) + + if init_epoch != 0: + init_epoch += 1 # Start with the next epoch + epoch_number = init_epoch + + # retrive the smallest validation loss if available + numbers = [] + for filename in os.listdir(best_model_path): + match = re.search(r"\d+\.\d*[1-9]\d*", filename) + if match: + number = float(match.group(0)) + numbers.append(number) + + best_vloss = min(numbers) if numbers else 1_000_000.0 + + initial_integral_factor_orig = cfg.model.integral_loss_scaling_factor + + for epoch in range(init_epoch, cfg.train.epochs): + start_time = time.time() + print(f"Device {paddle.distributed.get_rank()}, epoch {epoch_number}:") + + train_sampler.set_epoch(epoch) + val_sampler.set_epoch(epoch) + + initial_integral_factor = initial_integral_factor_orig + + model.train() + avg_loss = train_epoch( + dataloader=train_dataloader, + model=model, + optimizer=optimizer, + scaler=scaler, + epoch_index=epoch, + device=paddle.distributed.get_rank(), + integral_scaling_factor=initial_integral_factor, + loss_fn_type=cfg.model.loss_function, + ) + + model.eval() + avg_vloss = validation_step( + dataloader=val_dataloader, + model=model, + device=paddle.distributed.get_rank(), + use_sdf_basis=cfg.model.use_sdf_in_basis_func, + use_surface_normals=cfg.model.use_surface_normals, + integral_scaling_factor=initial_integral_factor, + loss_fn_type=cfg.model.loss_function, + ) + + scheduler.step() + print( + f"Device {paddle.distributed.get_rank()} " + f"LOSS train {avg_loss:.5f} " + f"valid {avg_vloss:.5f} " + f"Current lr {scheduler.get_lr()}" + f"Integral factor {initial_integral_factor}" + ) + + # Track best performance, and save the model's state + if paddle.distributed.get_world_size() > 1: + paddle.distributed.barrier() + + if avg_vloss < best_vloss: # This only considers GPU: 0, is that okay? + best_vloss = avg_vloss + print( + f"Device { paddle.distributed.get_rank()}, Best val loss {best_vloss}, Time taken {time.time() - start_time}" + ) + + if ( + paddle.distributed.get_rank() == 0 + and (epoch + 1) % cfg.train.checkpoint_interval == 0.0 + ): + save_checkpoint( + to_absolute_path(model_save_path), + models=model, + optimizer=optimizer, + scheduler=scheduler, + scaler=scaler, + epoch=epoch, + ) + + epoch_number += 1 + + if scheduler.get_lr() == 1e-6: + print("Training ended") + exit() + + +def loss_fn(output, target): + masked_loss = paddle.mean(((output - target) ** 2.0), (0, 1, 2)) + loss = paddle.mean(masked_loss) + return loss + + +def test_step(data_dict, model, device, cfg, vol_factors, surf_factors): + running_tloss_vol = 0.0 + running_tloss_surf = 0.0 + + if cfg.model.model_type == "volume" or cfg.model.model_type == "combined": + output_features_vol = True + else: + output_features_vol = None + + if cfg.model.model_type == "surface" or cfg.model.model_type == "combined": + output_features_surf = True + else: + output_features_surf = None + + with paddle.no_grad(): + point_batch_size = 256000 + + # Non-dimensionalization factors + air_density = data_dict["air_density"] + stream_velocity = data_dict["stream_velocity"] + length_scale = data_dict["length_scale"] + + # STL nodes + geo_centers = data_dict["geometry_coordinates"] + + # Bounding box grid + s_grid = data_dict["surf_grid"] + sdf_surf_grid = data_dict["sdf_surf_grid"] + # Scaling factors + surf_max = data_dict["surface_min_max"][:, 1] + surf_min = data_dict["surface_min_max"][:, 0] + + if output_features_vol is not None: + # Represent geometry on computational grid + # Computational domain grid + p_grid = data_dict["grid"] + sdf_grid = data_dict["sdf_grid"] + # Scaling factors + vol_max = data_dict["volume_min_max"][:, 1] + vol_min = data_dict["volume_min_max"][:, 0] + + # Normalize based on computational domain + geo_centers_vol = 2.0 * (geo_centers - vol_min) / (vol_max - vol_min) - 1 + encoding_g_vol = model.module.geo_rep(geo_centers_vol, p_grid, sdf_grid) + + # Normalize based on BBox around surface (car) + geo_centers_surf = ( + 2.0 * (geo_centers - surf_min) / (surf_max - surf_min) - 1 + ) + encoding_g_surf = model.module.geo_rep( + geo_centers_surf, s_grid, sdf_surf_grid + ) + + if output_features_surf is not None: + # Represent geometry on bounding box + geo_centers_surf = ( + 2.0 * (geo_centers - surf_min) / (surf_max - surf_min) - 1 + ) + encoding_g_surf = model.module.geo_rep( + geo_centers_surf, s_grid, sdf_surf_grid + ) + + geo_encoding = 0.5 * encoding_g_surf + # Average the encodings + if output_features_vol is not None: + geo_encoding += 0.5 * encoding_g_vol + + if output_features_vol is not None: + # First calculate volume predictions if required + volume_mesh_centers = data_dict["volume_mesh_centers"] + target_vol = data_dict["volume_fields"] + # SDF on volume mesh nodes + sdf_nodes = data_dict["sdf_nodes"] + # Positional encoding based on closest point on surface to a volume node + pos_volume_closest = data_dict["pos_volume_closest"] + # Positional encoding based on center of mass of geometry to volume node + pos_volume_center_of_mass = data_dict["pos_volume_center_of_mass"] + p_grid = data_dict["grid"] + + prediction_vol = np.zeros_like(target_vol.cpu().numpy()) + num_points = volume_mesh_centers.shape[1] + subdomain_points = int(np.floor(num_points / point_batch_size)) + + for p in range(subdomain_points + 1): + start_idx = p * point_batch_size + end_idx = (p + 1) * point_batch_size + with paddle.no_grad(): + target_batch = target_vol[:, start_idx:end_idx] + volume_mesh_centers_batch = volume_mesh_centers[ + :, start_idx:end_idx + ] + sdf_nodes_batch = sdf_nodes[:, start_idx:end_idx] + pos_volume_closest_batch = pos_volume_closest[:, start_idx:end_idx] + pos_normals_com_batch = pos_volume_center_of_mass[ + :, start_idx:end_idx + ] + geo_encoding_local = model.module.geo_encoding_local( + geo_encoding, volume_mesh_centers_batch, p_grid + ) + if cfg.model.use_sdf_in_basis_func: + pos_encoding = paddle.concat( + ( + sdf_nodes_batch, + pos_volume_closest_batch, + pos_normals_com_batch, + ), + axis=-1, + ) + else: + pos_encoding = pos_normals_com_batch + pos_encoding = model.module.position_encoder( + pos_encoding, eval_mode="volume" + ) + tpredictions_batch = model.module.calculate_solution( + volume_mesh_centers_batch, + geo_encoding_local, + pos_encoding, + stream_velocity, + air_density, + num_sample_points=20, + eval_mode="volume", + ) + running_tloss_vol += loss_fn(tpredictions_batch, target_batch) + prediction_vol[ + :, start_idx:end_idx + ] = tpredictions_batch.cpu().numpy() + + prediction_vol = unnormalize(prediction_vol, vol_factors[0], vol_factors[1]) + + prediction_vol[:, :, :3] = ( + prediction_vol[:, :, :3] * stream_velocity[0, 0].cpu().numpy() + ) + prediction_vol[:, :, 3] = ( + prediction_vol[:, :, 3] + * stream_velocity[0, 0].cpu().numpy() ** 2.0 + * air_density[0, 0].cpu().numpy() + ) + prediction_vol[:, :, 4] = ( + prediction_vol[:, :, 4] + * stream_velocity[0, 0].cpu().numpy() + * length_scale[0].cpu().numpy() + ) + else: + prediction_vol = None + + if output_features_surf is not None: + # Next calculate surface predictions + # Sampled points on surface + surface_mesh_centers = data_dict["surface_mesh_centers"] + surface_normals = data_dict["surface_normals"] + surface_areas = data_dict["surface_areas"] + + # Neighbors of sampled points on surface + surface_mesh_neighbors = data_dict["surface_mesh_neighbors"] + surface_neighbors_normals = data_dict["surface_neighbors_normals"] + surface_neighbors_areas = data_dict["surface_neighbors_areas"] + surface_areas = paddle.unsqueeze(surface_areas, -1) + surface_neighbors_areas = paddle.unsqueeze(surface_neighbors_areas, -1) + pos_surface_center_of_mass = data_dict["pos_surface_center_of_mass"] + num_points = surface_mesh_centers.shape[1] + subdomain_points = int(np.floor(num_points / point_batch_size)) + + target_surf = data_dict["surface_fields"] + prediction_surf = np.zeros_like(target_surf.cpu().numpy()) + + surface_areas = paddle.unsqueeze(surface_areas, -1) + surface_neighbors_areas = paddle.unsqueeze(surface_neighbors_areas, -1) + + for p in range(subdomain_points + 1): + start_idx = p * point_batch_size + end_idx = (p + 1) * point_batch_size + with paddle.no_grad(): + target_batch = target_surf[:, start_idx:end_idx] + surface_mesh_centers_batch = surface_mesh_centers[ + :, start_idx:end_idx + ] + surface_mesh_neighbors_batch = surface_mesh_neighbors[ + :, start_idx:end_idx + ] + surface_normals_batch = surface_normals[:, start_idx:end_idx] + surface_neighbors_normals_batch = surface_neighbors_normals[ + :, start_idx:end_idx + ] + surface_areas_batch = surface_areas[:, start_idx:end_idx] + surface_neighbors_areas_batch = surface_neighbors_areas[ + :, start_idx:end_idx + ] + pos_surface_center_of_mass_batch = pos_surface_center_of_mass[ + :, start_idx:end_idx + ] + geo_encoding_local = model.module.geo_encoding_local_surface( + 0.5 * encoding_g_surf, surface_mesh_centers_batch, s_grid + ) + pos_encoding = pos_surface_center_of_mass_batch + pos_encoding = model.module.position_encoder( + pos_encoding, eval_mode="surface" + ) + + if cfg.model.surface_neighbors: + tpredictions_batch = ( + model.module.calculate_solution_with_neighbors( + surface_mesh_centers_batch, + geo_encoding_local, + pos_encoding, + surface_mesh_neighbors_batch, + surface_normals_batch, + surface_neighbors_normals_batch, + surface_areas_batch, + surface_neighbors_areas_batch, + stream_velocity, + air_density, + ) + ) + else: + tpredictions_batch = model.module.calculate_solution( + surface_mesh_centers_batch, + geo_encoding_local, + pos_encoding, + stream_velocity, + air_density, + num_sample_points=1, + eval_mode="surface", + ) + running_tloss_surf += loss_fn(tpredictions_batch, target_batch) + prediction_surf[ + :, start_idx:end_idx + ] = tpredictions_batch.cpu().numpy() + + prediction_surf = ( + unnormalize(prediction_surf, surf_factors[0], surf_factors[1]) + * stream_velocity[0, 0].cpu().numpy() ** 2.0 + * air_density[0, 0].cpu().numpy() + ) + + else: + prediction_surf = None + + return prediction_vol, prediction_surf + + +def test(cfg: DictConfig): + print(f"Config summary:\n{OmegaConf.to_yaml(cfg, sort_keys=True)}") + + input_path = cfg.eval.test_path + + model_type = cfg.model.model_type + + dist.init_parallel_env() + + if model_type == "volume" or model_type == "combined": + volume_variable_names = list(cfg.variables.volume.solution.keys()) + num_vol_vars = 0 + for j in volume_variable_names: + if cfg.variables.volume.solution[j] == "vector": + num_vol_vars += 3 + else: + num_vol_vars += 1 + else: + num_vol_vars = None + + if model_type == "surface" or model_type == "combined": + surface_variable_names = list(cfg.variables.surface.solution.keys()) + num_surf_vars = 0 + for j in surface_variable_names: + if cfg.variables.surface.solution[j] == "vector": + num_surf_vars += 3 + else: + num_surf_vars += 1 + else: + num_surf_vars = None + + vol_save_path = os.path.join( + "outputs", cfg.project.name, "volume_scaling_factors.npy" + ) + surf_save_path = os.path.join( + "outputs", cfg.project.name, "surface_scaling_factors.npy" + ) + if os.path.exists(vol_save_path) and os.path.exists(surf_save_path): + vol_factors = np.load(vol_save_path) + surf_factors = np.load(surf_save_path) + else: + vol_factors = None + surf_factors = None + + model = DoMINO( + input_features=3, + output_features_vol=num_vol_vars, + output_features_surf=num_surf_vars, + model_parameters=cfg.model, + ) + + checkpoint = paddle.load( + to_absolute_path(os.path.join(cfg.resume_dir, cfg.eval.checkpoint_name)), + ) + + model.set_state_dict(checkpoint) + + print("Model loaded") + + if paddle.distributed.get_world_size() > 1: + model = DataParallel( + model, + ) + + dirnames_per_gpu = get_filenames(input_path) + + pred_save_path = cfg.eval.save_path + create_directory(pred_save_path) + + for count, dirname in enumerate(dirnames_per_gpu): + # print(f"Processing file {dirname}") + filepath = os.path.join(input_path, dirname) + tag = int(re.findall(r"(\w+?)(\d+)", dirname)[0][1]) + stl_path = os.path.join(filepath, f"drivaer_{tag}.stl") + vtp_path = os.path.join(filepath, f"boundary_{tag}.vtp") + vtu_path = os.path.join(filepath, f"volume_{tag}.vtu") + + vtp_pred_save_path = os.path.join( + pred_save_path, f"boundary_{tag}_predicted.vtp" + ) + vtu_pred_save_path = os.path.join(pred_save_path, f"volume_{tag}_predicted.vtu") + + # Read STL + reader = pv.get_reader(stl_path) + mesh_stl = reader.read() + stl_vertices = mesh_stl.points + stl_faces = np.array(mesh_stl.faces).reshape((-1, 4))[ + :, 1: + ] # Assuming triangular elements + mesh_indices_flattened = stl_faces.flatten() + length_scale = np.amax(np.amax(stl_vertices, 0) - np.amin(stl_vertices, 0)) + stl_sizes = mesh_stl.compute_cell_sizes(length=False, area=True, volume=False) + stl_sizes = np.array(stl_sizes.cell_data["Area"], dtype=np.float32) + stl_centers = np.array(mesh_stl.cell_centers().points, dtype=np.float32) + + # Center of mass calculation + center_of_mass = calculate_center_of_mass(stl_centers, stl_sizes) + + if cfg.data.bounding_box_surface is None: + s_max = np.amax(stl_vertices, 0) + s_min = np.amin(stl_vertices, 0) + else: + bounding_box_dims_surf = [] + bounding_box_dims_surf.append(np.asarray(cfg.data.bounding_box_surface.max)) + bounding_box_dims_surf.append(np.asarray(cfg.data.bounding_box_surface.min)) + s_max = np.float32(bounding_box_dims_surf[0]) + s_min = np.float32(bounding_box_dims_surf[1]) + + nx, ny, nz = cfg.model.interp_res + + surf_grid = create_grid(s_max, s_min, [nx, ny, nz]) + surf_grid_reshaped = surf_grid.reshape(nx * ny * nz, 3) + + # SDF calculation on the grid using WARP + sdf_surf_grid = ( + signed_distance_field( + stl_vertices, + mesh_indices_flattened, + surf_grid_reshaped, + use_sign_winding_number=True, + ) + .numpy() + .reshape(nx, ny, nz) + ) + surf_grid = np.float32(surf_grid) + sdf_surf_grid = np.float32(sdf_surf_grid) + surf_grid_max_min = np.float32(np.asarray([s_min, s_max])) + + # Read VTP + if model_type == "surface" or model_type == "combined": + reader = vtk.vtkXMLPolyDataReader() + reader.SetFileName(vtp_path) + reader.Update() + polydata_surf = reader.GetOutput() + + celldata_all = get_node_to_elem(polydata_surf) + + celldata = celldata_all.GetCellData() + surface_fields = get_fields(celldata, surface_variable_names) + surface_fields = np.concatenate(surface_fields, axis=-1) + + mesh = pv.PolyData(polydata_surf) + surface_coordinates = np.array(mesh.cell_centers().points, dtype=np.float32) + + interp_func = KDTree(surface_coordinates) + dd, ii = interp_func.query( + surface_coordinates, k=cfg.model.num_surface_neighbors + ) + + surface_neighbors = surface_coordinates[ii] + surface_neighbors = surface_neighbors[:, 1:] + + surface_normals = np.array(mesh.cell_normals, dtype=np.float32) + surface_sizes = mesh.compute_cell_sizes( + length=False, area=True, volume=False + ) + surface_sizes = np.array(surface_sizes.cell_data["Area"], dtype=np.float32) + + # Normalize cell normals + surface_normals = ( + surface_normals / np.linalg.norm(surface_normals, axis=1)[:, np.newaxis] + ) + surface_neighbors_normals = surface_normals[ii] + surface_neighbors_normals = surface_neighbors_normals[:, 1:] + surface_neighbors_sizes = surface_sizes[ii] + surface_neighbors_sizes = surface_neighbors_sizes[:, 1:] + + dx, dy, dz = ( + (s_max[0] - s_min[0]) / nx, + (s_max[1] - s_min[1]) / ny, + (s_max[2] - s_min[2]) / nz, + ) + + if cfg.model.positional_encoding: + pos_surface_center_of_mass = cal_normal_positional_encoding( + surface_coordinates, center_of_mass, cell_length=[dx, dy, dz] + ) + else: + pos_surface_center_of_mass = surface_coordinates - center_of_mass + + surface_coordinates = normalize(surface_coordinates, s_max, s_min) + surface_neighbors = normalize(surface_neighbors, s_max, s_min) + surf_grid = normalize(surf_grid, s_max, s_min) + + else: + surface_coordinates = None + surface_fields = None + surface_sizes = None + surface_normals = None + surface_neighbors = None + surface_neighbors_normals = None + surface_neighbors_sizes = None + pos_surface_center_of_mass = None + + # Read VTU + if model_type == "volume" or model_type == "combined": + reader = vtk.vtkXMLUnstructuredGridReader() + reader.SetFileName(vtu_path) + reader.Update() + polydata_vol = reader.GetOutput() + volume_coordinates, volume_fields = get_volume_data( + polydata_vol, volume_variable_names + ) + volume_fields = np.concatenate(volume_fields, axis=-1) + # print(f"Processed vtu {vtu_path}") + + bounding_box_dims = [] + bounding_box_dims.append(np.asarray(cfg.data.bounding_box.max)) + bounding_box_dims.append(np.asarray(cfg.data.bounding_box.min)) + + if bounding_box_dims is None: + c_max = s_max + (s_max - s_min) / 2 + c_min = s_min - (s_max - s_min) / 2 + c_min[2] = s_min[2] + else: + c_max = np.float32(bounding_box_dims[0]) + c_min = np.float32(bounding_box_dims[1]) + + dx, dy, dz = ( + (c_max[0] - c_min[0]) / nx, + (c_max[1] - c_min[1]) / ny, + (c_max[2] - c_min[2]) / nz, + ) + # Generate a grid of specified resolution to map the bounding box + # The grid is used for capturing structured geometry features and SDF representation of geometry + grid = create_grid(c_max, c_min, [nx, ny, nz]) + grid_reshaped = grid.reshape(nx * ny * nz, 3) + + # SDF calculation on the grid using WARP + sdf_grid = ( + signed_distance_field( + stl_vertices, + mesh_indices_flattened, + grid_reshaped, + use_sign_winding_number=True, + ) + .numpy() + .reshape(nx, ny, nz) + ) + + # SDF calculation + sdf_nodes, sdf_node_closest_point = signed_distance_field( + stl_vertices, + mesh_indices_flattened, + volume_coordinates, + include_hit_points=True, + use_sign_winding_number=True, + ) + sdf_nodes = sdf_nodes.numpy().reshape(-1, 1) + sdf_node_closest_point = sdf_node_closest_point.numpy() + + if cfg.model.positional_encoding: + pos_volume_closest = cal_normal_positional_encoding( + volume_coordinates, sdf_node_closest_point, cell_length=[dx, dy, dz] + ) + pos_volume_center_of_mass = cal_normal_positional_encoding( + volume_coordinates, center_of_mass, cell_length=[dx, dy, dz] + ) + else: + pos_volume_closest = volume_coordinates - sdf_node_closest_point + pos_volume_center_of_mass = volume_coordinates - center_of_mass + + volume_coordinates = normalize(volume_coordinates, c_max, c_min) + grid = normalize(grid, c_max, c_min) + vol_grid_max_min = np.asarray([c_min, c_max]) + + else: + volume_coordinates = None + volume_fields = None + pos_volume_closest = None + pos_volume_center_of_mass = None + + # print(f"Processed sdf and normalized") + + geom_centers = np.float32(stl_vertices) + + if model_type == "combined": + # Add the parameters to the dictionary + data_dict = { + "pos_volume_closest": pos_volume_closest, + "pos_volume_center_of_mass": pos_volume_center_of_mass, + "pos_surface_center_of_mass": pos_surface_center_of_mass, + "geometry_coordinates": geom_centers, + "grid": grid, + "surf_grid": surf_grid, + "sdf_grid": sdf_grid, + "sdf_surf_grid": sdf_surf_grid, + "sdf_nodes": sdf_nodes, + "surface_mesh_centers": surface_coordinates, + "surface_mesh_neighbors": surface_neighbors, + "surface_normals": surface_normals, + "surface_neighbors_normals": surface_neighbors_normals, + "surface_areas": surface_sizes, + "surface_neighbors_areas": surface_neighbors_sizes, + "volume_fields": volume_fields, + "volume_mesh_centers": volume_coordinates, + "surface_fields": surface_fields, + "volume_min_max": vol_grid_max_min, + "surface_min_max": surf_grid_max_min, + "length_scale": np.array(length_scale, dtype=np.float32), + "stream_velocity": np.expand_dims( + np.array(STREAM_VELOCITY, dtype=np.float32), axis=-1 + ), + "air_density": np.expand_dims( + np.array(AIR_DENSITY, dtype=np.float32), axis=-1 + ), + } + elif model_type == "surface": + data_dict = { + "pos_surface_center_of_mass": np.float32(pos_surface_center_of_mass), + "geometry_coordinates": np.float32(geom_centers), + "surf_grid": np.float32(surf_grid), + "sdf_surf_grid": np.float32(sdf_surf_grid), + "surface_mesh_centers": np.float32(surface_coordinates), + "surface_mesh_neighbors": np.float32(surface_neighbors), + "surface_normals": np.float32(surface_normals), + "surface_neighbors_normals": np.float32(surface_neighbors_normals), + "surface_areas": np.float32(surface_sizes), + "surface_neighbors_areas": np.float32(surface_neighbors_sizes), + "surface_fields": np.float32(surface_fields), + "surface_min_max": np.float32(surf_grid_max_min), + "length_scale": np.array(length_scale, dtype=np.float32), + "stream_velocity": np.expand_dims( + np.array(STREAM_VELOCITY, dtype=np.float32), axis=-1 + ), + "air_density": np.expand_dims( + np.array(AIR_DENSITY, dtype=np.float32), axis=-1 + ), + } + elif model_type == "volume": + data_dict = { + "pos_volume_closest": pos_volume_closest, + "pos_volume_center_of_mass": pos_volume_center_of_mass, + "geometry_coordinates": geom_centers, + "grid": grid, + "surf_grid": surf_grid, + "sdf_grid": sdf_grid, + "sdf_surf_grid": sdf_surf_grid, + "sdf_nodes": sdf_nodes, + "volume_fields": volume_fields, + "volume_mesh_centers": volume_coordinates, + "volume_min_max": vol_grid_max_min, + "surface_min_max": surf_grid_max_min, + "length_scale": np.array(length_scale, dtype=np.float32), + "stream_velocity": np.expand_dims( + np.array(STREAM_VELOCITY, dtype=np.float32), axis=-1 + ), + "air_density": np.expand_dims( + np.array(AIR_DENSITY, dtype=np.float32), axis=-1 + ), + } + + data_dict = { + key: paddle.to_tensor(np.expand_dims(np.float32(value), 0)) + for key, value in data_dict.items() + } + + prediction_vol, prediction_surf = test_step( + data_dict, + model, + paddle.distributed.get_rank(), + cfg, + vol_factors, + surf_factors, + ) + + if prediction_surf is not None: + surface_sizes = np.expand_dims(surface_sizes, -1) + + force_x_pred = np.sum( + prediction_surf[0, :, 0] * surface_normals[:, 0] * surface_sizes[:, 0] + - prediction_surf[0, :, 1] * surface_sizes[:, 0] + ) + force_x_true = np.sum( + surface_fields[:, 0] * surface_normals[:, 0] * surface_sizes[:, 0] + - surface_fields[:, 1] * surface_sizes[:, 0] + ) + print(dirname, force_x_pred, force_x_true) + + if prediction_vol is not None: + target_vol = volume_fields + prediction_vol = prediction_vol[0] + c_min = vol_grid_max_min[0] + c_max = vol_grid_max_min[1] + volume_coordinates = unnormalize(volume_coordinates, c_max, c_min) + ids_in_bbox = np.where( + (volume_coordinates[:, 0] < c_min[0]) + | (volume_coordinates[:, 0] > c_max[0]) + | (volume_coordinates[:, 1] < c_min[1]) + | (volume_coordinates[:, 1] > c_max[1]) + | (volume_coordinates[:, 2] < c_min[2]) + | (volume_coordinates[:, 2] > c_max[2]) + ) + target_vol[ids_in_bbox] = 0.0 + prediction_vol[ids_in_bbox] = 0.0 + l2_gt = np.sum(np.square(target_vol), (0)) + l2_error = np.sum(np.square(prediction_vol - target_vol), (0)) + print( + "L-2 norm:", + dirname, + np.sqrt(l2_error), + np.sqrt(l2_gt), + np.sqrt(l2_error) / np.sqrt(l2_gt), + ) + + if prediction_surf is not None: + surfParam_vtk = numpy_support.numpy_to_vtk(prediction_surf[0, :, 0:1]) + surfParam_vtk.SetName(f"{surface_variable_names[0]}Pred") + celldata_all.GetCellData().AddArray(surfParam_vtk) + + surfParam_vtk = numpy_support.numpy_to_vtk(prediction_surf[0, :, 1:]) + surfParam_vtk.SetName(f"{surface_variable_names[1]}Pred") + celldata_all.GetCellData().AddArray(surfParam_vtk) + + write_to_vtp(celldata_all, vtp_pred_save_path) + + if prediction_vol is not None: + + volParam_vtk = numpy_support.numpy_to_vtk(prediction_vol[:, 0:3]) + volParam_vtk.SetName(f"{volume_variable_names[0]}Pred") + polydata_vol.GetPointData().AddArray(volParam_vtk) + + volParam_vtk = numpy_support.numpy_to_vtk(prediction_vol[:, 3:4]) + volParam_vtk.SetName(f"{volume_variable_names[1]}Pred") + polydata_vol.GetPointData().AddArray(volParam_vtk) + + volParam_vtk = numpy_support.numpy_to_vtk(prediction_vol[:, 4:5]) + volParam_vtk.SetName(f"{volume_variable_names[2]}Pred") + polydata_vol.GetPointData().AddArray(volParam_vtk) + + write_to_vtu(polydata_vol, vtu_pred_save_path) + + +@hydra.main(version_base=None, config_path="conf", config_name="config") +def main(cfg: DictConfig): + if cfg.mode == "train": + train(cfg) + elif cfg.mode == "eval": + test(cfg) + elif cfg.mode == "process": + process(cfg) + else: + raise ValueError( + f"cfg.mode should in ['process', 'train', 'eval'], but got '{cfg.mode}'" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/domino/download_aws_dataset.sh b/examples/domino/download_aws_dataset.sh new file mode 100644 index 0000000000..f793dd021f --- /dev/null +++ b/examples/domino/download_aws_dataset.sh @@ -0,0 +1,64 @@ +#!/bin/bash + +# This Bash script downloads the AWS DrivAer files from the Amazon S3 bucket to a local directory. +# Only the volume files (.vtu), STL files (.stl), and VTP files (.vtp) are downloaded. +# It uses a function, download_run_files, to check for the existence of three specific files (".vtu", ".stl", ".vtp") in a run directory. +# If a file doesn't exist, it's downloaded from the S3 bucket. If it does exist, the download is skipped. +# The script runs multiple downloads in parallel, both within a single run and across multiple runs. +# It also includes checks to prevent overloading the system by limiting the number of parallel downloads. + +# Set the local directory to download the files +LOCAL_DIR="./drivaer_data_full" # <--- This is the directory where the files will be downloaded. + +# Set the S3 bucket and prefix +S3_BUCKET="caemldatasets" +S3_PREFIX="drivaer/dataset" + +# Create the local directory if it doesn't exist +mkdir -p "$LOCAL_DIR" + +# Function to download files for a specific run +download_run_files() { + local i=$1 + RUN_DIR="run_$i" + RUN_LOCAL_DIR="$LOCAL_DIR/$RUN_DIR" + + # Create the run directory if it doesn't exist + mkdir -p "$RUN_LOCAL_DIR" + + # Check if the .vtu file exists before downloading + if [ ! -f "$RUN_LOCAL_DIR/volume_$i.vtu" ]; then + aws s3 cp --no-sign-request "s3://$S3_BUCKET/$S3_PREFIX/$RUN_DIR/volume_$i.vtu" "$RUN_LOCAL_DIR/" & + else + echo "File volume_$i.vtu already exists, skipping download." + fi + + # Check if the .stl file exists before downloading + if [ ! -f "$RUN_LOCAL_DIR/drivaer_$i.stl" ]; then + aws s3 cp --no-sign-request "s3://$S3_BUCKET/$S3_PREFIX/$RUN_DIR/drivaer_$i.stl" "$RUN_LOCAL_DIR/" & + else + echo "File drivaer_$i.stl already exists, skipping download." + fi + + # Check if the .vtp file exists before downloading + if [ ! -f "$RUN_LOCAL_DIR/boundary_$i.vtp" ]; then + aws s3 cp --no-sign-request "s3://$S3_BUCKET/$S3_PREFIX/$RUN_DIR/boundary_$i.vtp" "$RUN_LOCAL_DIR/" & + else + echo "File boundary_$i.vtp already exists, skipping download." + fi + + wait # Ensure that both files for this run are downloaded before moving to the next run +} + +# Loop through the run folders and download the files +for i in $(seq 1 500); do + download_run_files "$i" & + + # Limit the number of parallel jobs to avoid overloading the system + if (( $(jobs -r | wc -l) >= 8 )); then + wait -n # Wait for the next background job to finish before starting a new one + fi +done + +# Wait for all remaining background jobs to finish +wait diff --git a/ppsci/arch/physicsnemo.py b/ppsci/arch/physicsnemo.py new file mode 100644 index 0000000000..ffa401b866 --- /dev/null +++ b/ppsci/arch/physicsnemo.py @@ -0,0 +1,1811 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# refs: https://github.com/NVIDIA/physicsnemo/tree/main/examples/cfd/external_aerodynamics/domino + + +import glob +import math +import os +import re +from pathlib import Path +from typing import Any +from typing import Dict +from typing import List +from typing import NewType +from typing import Optional +from typing import Union + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import warp as wp +from paddle.amp import GradScaler +from paddle.optimizer.lr import LRScheduler +from scipy.spatial import KDTree + +import ppsci +from ppsci.utils import logger + +optimizer = NewType("optimizer", paddle.optimizer) +scheduler = NewType("scheduler", LRScheduler) +scaler = NewType("scaler", GradScaler) + + +def nd_interpolator(coodinates, field, grid): + """Function to for nd interpolation""" + interp_func = KDTree(coodinates[0]) + dd, ii = interp_func.query(grid, k=2) + + field_grid = field[ii] + field_grid = np.float32(np.mean(field_grid, (3))) + return field_grid + + +def pad_inp(arr, npoin, pad_value=0.0): + """Function for padding arrays""" + arr_pad = pad_value * np.ones( + (npoin - arr.shape[0], arr.shape[1], arr.shape[2]), dtype=np.float32 + ) + arr_padded = np.concatenate((arr, arr_pad), axis=0) + return arr_padded + + +def shuffle_array_without_sampling(arr): + """Function for shuffline arrays without sampling.""" + idx = np.arange(arr.shape[0]) + np.random.shuffle(idx) + return arr[idx], idx + + +def create_directory(filepath): + """Function to create directories""" + if not os.path.exists(filepath): + os.makedirs(filepath) + + +def calculate_pos_encoding(nx, d=8): + """Function for calculating positional encoding""" + vec = [] + for k in range(int(d / 2)): + vec.append(np.sin(nx / 10000 ** (2 * (k) / d))) + vec.append(np.cos(nx / 10000 ** (2 * (k) / d))) + return vec + + +def combine_dict(old_dict, new_dict): + """Function to combine dictionaries""" + for j in old_dict.keys(): + old_dict[j] += new_dict[j] + return old_dict + + +def merge(*lists): + """Function to merge lists""" + newlist = lists[:] + for x in lists: + if x not in newlist: + newlist.extend(x) + return newlist + + +def mean_std_sampling(field, mean, std, tolerance=3.0): + """Function for mean/std based sampling""" + idx_all = [] + for v in range(field.shape[-1]): + fv = field[:, v] + idx = np.where( + (fv > mean[v] + tolerance * std[v]) | (fv < mean[v] - tolerance * std[v]) + ) + if len(idx[0]) != 0: + idx_all += list(idx[0]) + + return idx_all + + +def dict_to_device(state_dict, device): + """Function to load dictionary to device""" + new_state_dict = {} + for k, v in state_dict.items(): + new_state_dict[k] = v.to(device) + return new_state_dict + + +class BallQuery(paddle.autograd.PyLayer): + """ + Warp based Ball Query. + """ + + @wp.kernel + def ball_query( + points1: wp.array(dtype=wp.vec3), + points2: wp.array(dtype=wp.vec3), + grid: wp.uint64, + k: wp.int32, + radius: wp.float32, + mapping: wp.array3d(dtype=wp.int32), + num_neighbors: wp.array2d(dtype=wp.int32), + ): + + # Get index of point1 + tid = wp.tid() + + # Get position from points1 + pos = points1[tid] + + # particle contact + neighbors = wp.hash_grid_query(grid, pos, radius) + + # Keep track of the number of neighbors found + nr_found = wp.int32(0) + + # loop through neighbors to compute density + for index in neighbors: + # Check if outside the radius + pos2 = points2[index] + if wp.length(pos - pos2) > radius: + continue + + # Add neighbor to the list + mapping[0, tid, nr_found] = index + + # Increment the number of neighbors found + nr_found += 1 + + # Break if we have found enough neighbors + if nr_found == k: + num_neighbors[0, tid] = k + break + + # Set the number of neighbors + num_neighbors[0, tid] = nr_found + + @wp.kernel + def sparse_ball_query( + points2: wp.array(dtype=wp.vec3), + mapping: wp.array3d(dtype=wp.int32), + num_neighbors: wp.array2d(dtype=wp.int32), + outputs: wp.array4d(dtype=wp.float32), + ): + # Get index of point1 + p1 = wp.tid() + + # Get number of neighbors + k = num_neighbors[0, p1] + + # Loop through neighbors + for _k in range(k): + # Get point2 index + index = mapping[0, p1, _k] + + # Get position from points2 + pos = points2[index] + + # Set the output + outputs[0, p1, _k, 0] = pos[0] + outputs[0, p1, _k, 1] = pos[1] + outputs[0, p1, _k, 2] = pos[2] + + @staticmethod + def forward( + ctx, + points1, + points2, + lengths1, + lengths2, + k, + radius, + hash_grid, + ): + # Only works for batch size 1 + if points1.shape[0] != 1: + raise AssertionError("nly works for batch size 1") + + # Convert from paddle to warp + ctx.points1 = wp.from_paddle( + points1[0], dtype=wp.vec3, requires_grad=points1.stop_gradient + ) + ctx.points2 = wp.from_paddle( + points2[0], dtype=wp.vec3, requires_grad=points2.stop_gradient + ) + ctx.lengths1 = wp.from_paddle(lengths1, dtype=wp.int32, requires_grad=False) + ctx.lengths2 = wp.from_paddle(lengths2, dtype=wp.int32, requires_grad=False) + ctx.k = k + ctx.radius = radius + + # Allocate the mapping and outputs + mapping = paddle.zeros([1, points1.shape[1], k], dtype=paddle.int32) + mapping.stop_gradient = False + ctx.mapping = wp.from_paddle(mapping, dtype=wp.int32, requires_grad=False) + num_neighbors = paddle.zeros([1, points1.shape[1]], dtype=paddle.int32) + num_neighbors.stop_gradient = False + ctx.num_neighbors = wp.from_paddle( + num_neighbors, dtype=wp.int32, requires_grad=False + ) + outputs = paddle.zeros([1, points1.shape[1], k, 3], dtype=paddle.float32) + outputs.stop_gradient = points1.stop_gradient or points2.stop_gradient + ctx.outputs = wp.from_paddle(outputs, dtype=wp.float32) + outputs.stop_gradient = points1.stop_gradient or points2.stop_gradient + + # Make grid + ctx.hash_grid = hash_grid + + # Build the grid + ctx.hash_grid.build(ctx.points2, radius) + + # Run the kernel to get mapping + wp.launch( + BallQuery.ball_query, + inputs=[ + ctx.points1, + ctx.points2, + ctx.hash_grid.id, + k, + radius, + ], + outputs=[ + ctx.mapping, + ctx.num_neighbors, + ], + dim=[ctx.points1.shape[0]], + ) + + # Run the kernel to get outputs + wp.launch( + BallQuery.sparse_ball_query, + inputs=[ + ctx.points2, + ctx.mapping, + ctx.num_neighbors, + ], + outputs=[ + ctx.outputs, + ], + dim=[ctx.points1.shape[0]], + ) + + return ( + wp.to_paddle(ctx.mapping), + wp.to_paddle(ctx.num_neighbors), + wp.to_paddle(ctx.outputs), + ) + + @staticmethod + def backward(ctx, grad_mapping, grad_num_neighbors, grad_outputs): + # Map incoming paddle grads to our output variable + ctx.outputs.grad = wp.from_paddle(grad_outputs, dtype=wp.float32) + + # Run the kernel in adjoint mode + wp.launch( + BallQuery.sparse_ball_query, + inputs=[ + ctx.points2, + ctx.mapping, + ctx.num_neighbors, + ], + outputs=[ + ctx.outputs, + ], + adj_inputs=[ctx.points2.grad, ctx.mapping.grad, ctx.num_neighbors.grad], + adj_outputs=[ + ctx.outputs.grad, + ], + dim=[ctx.points1.shape[0]], + adjoint=True, + ) + + # Return the gradients + return ( + wp.to_paddle(ctx.points1.grad).unsqueeze(0), + wp.to_paddle(ctx.points2.grad).unsqueeze(0), + None, + None, + None, + None, + None, + ) + + +def kaiming_init(layer): + if isinstance(layer, (nn.layer.conv._ConvNd, nn.Linear)): + print(f"layer: {layer} ") + init_kaimingUniform = paddle.nn.initializer.KaimingUniform( + nonlinearity="leaky_relu", negative_slope=math.sqrt(5) + ) + init_kaimingUniform(layer.weight) + if layer.bias is not None: + fan_in, _ = ppsci.utils.initializer._calculate_fan_in_and_fan_out( + layer.weight + ) + if fan_in != 0: + bound = 1 / math.sqrt(fan_in) + init_uniform = paddle.nn.initializer.Uniform(low=-bound, high=bound) + init_uniform(layer.bias) + + +def scale_sdf(sdf): + """Function to scale SDF""" + return sdf / (0.4 + abs(sdf)) + + +def calculate_gradient(sdf): + """Function to calculate the gradients of SDF""" + m, n, o = sdf.shape[2], sdf.shape[3], sdf.shape[4] + sdf_x = sdf[:, :, 2:m, :, :] - sdf[:, :, 0 : m - 2, :, :] + sdf_y = sdf[:, :, :, 2:n, :] - sdf[:, :, :, 0 : n - 2, :] + sdf_z = sdf[:, :, :, :, 2:o] - sdf[:, :, :, :, 0 : o - 2] + + sdf_x = F.pad(x=sdf_x, pad=(0, 0, 0, 0, 0, 1), mode="constant", value=0.0) + sdf_x = F.pad(x=sdf_x, pad=(0, 0, 0, 0, 1, 0), mode="constant", value=0.0) + sdf_y = F.pad(x=sdf_y, pad=(0, 0, 0, 1, 0, 0), mode="constant", value=0.0) + sdf_y = F.pad(x=sdf_y, pad=(0, 0, 1, 0, 0, 0), mode="constant", value=0.0) + sdf_z = F.pad(x=sdf_z, pad=(0, 1, 0, 0, 0, 0), mode="constant", value=0.0) + sdf_z = F.pad(x=sdf_z, pad=(1, 0, 0, 0, 0, 0), mode="constant", value=0.0) + + return sdf_x, sdf_y, sdf_z + + +def binarize_sdf(sdf): + """Function to calculate the binarize the SDF""" + sdf = paddle.where(sdf >= 0, 0.0, 1.0).to(dtype=sdf.dtype) + return sdf + + +class BallQueryLayer(paddle.nn.Layer): + """ + Paddle layer for differentiable and accelerated Ball Query + operation using Warp. + Args: + k (int): Number of neighbors. + radius (float): Radius of influence. + grid_size (int): Uniform grid resolution + """ + + def __init__(self, k, radius, grid_size=32): + super().__init__() + wp.init() + self.k = k + self.radius = radius + self.hash_grid = wp.HashGrid(grid_size, grid_size, grid_size) + + def forward(self, points1, points2, lengths1, lengths2): + return BallQuery.apply( + points1, + points2, + lengths1, + lengths2, + self.k, + self.radius, + self.hash_grid, + ) + + +def _get_checkpoint_filename( + path: str, + base_name: str = "checkpoint", + index: Union[int, None] = None, + saving: bool = False, + model_type: str = "mdlus", +) -> str: + """Gets the file name /path of checkpoint + + This function has three different ways of providing a checkout filename: + - If supplied an index this will return the checkpoint name using that index. + - If index is None and saving is false, this will get the checkpoint with the + largest index (latest save). + - If index is None and saving is true, it will return the next valid index file name + which is calculated by indexing the largest checkpoint index found by one. + + Parameters + ---------- + path : str + Path to checkpoints + base_name: str, optional + Base file name, by default checkpoint + index : Union[int, None], optional + Checkpoint index, by default None + saving : bool, optional + Get filename for saving a new checkpoint, by default False + model_type : str + Model type, by default "mdlus" for Modulus models and "pdparams" for models + + + Returns + ------- + str + Checkpoint file name + """ + # Get model parallel rank so all processes in the first model parallel group + # can save their checkpoint. In the case without model parallelism, + # model_parallel_rank should be the same as the process rank itself and + # only rank 0 saves + model_parallel_rank = 0 + + # Input file name + checkpoint_filename = str( + Path(path).resolve() / f"{base_name}.{model_parallel_rank}" + ) + + # File extension for Modulus models or PaddlePaddle models + file_extension = ".pdparams" + + # If epoch is provided load that file + if index is not None: + checkpoint_filename = checkpoint_filename + f".{index}" + checkpoint_filename += file_extension + # Otherwise try loading the latest epoch or rolling checkpoint + else: + file_names = [ + Path(fname).name + for fname in glob.glob( + checkpoint_filename + "*" + file_extension, recursive=False + ) + ] + + if len(file_names) > 0: + # If checkpoint from a null index save exists load that + # This is the most likely line to error since it will fail with + # invalid checkpoint names + file_idx = [ + int( + re.sub( + f"^{base_name}.{model_parallel_rank}.|" + file_extension, + "", + fname, + ) + ) + for fname in file_names + ] + file_idx.sort() + # If we are saving index by 1 to get the next free file name + if saving: + checkpoint_filename = checkpoint_filename + f".{file_idx[-1]+1}" + else: + checkpoint_filename = checkpoint_filename + f".{file_idx[-1]}" + checkpoint_filename += file_extension + else: + checkpoint_filename += ".0" + file_extension + + return checkpoint_filename + + +def _unique_model_names( + models: List[paddle.nn.Layer], +) -> Dict[str, paddle.nn.Layer]: + """Util to clean model names and index if repeat names, will also strip DDP wrappers + if they exist. + + Parameters + ---------- + model : List[paddle.nn.Layer] + List of models to generate names for + + Returns + ------- + Dict[str, paddle.nn.Layer] + Dictionary of model names and respective modules + """ + # Loop through provided models and set up base names + model_dict = {} + for model0 in models: + if hasattr(model0, "module"): + # Strip out DDP layer + model0 = model0.module + # Base name of model is meta.name unless paddle model + base_name = model0.__class__.__name__ + # if isinstance(model0, modulus): + # base_name = model0.meta.name + # If we have multiple models of the same name, introduce another index + if base_name in model_dict: + model_dict[base_name].append(model0) + else: + model_dict[base_name] = [model0] + + # Set up unique model names if needed + output_dict = {} + for key, model in model_dict.items(): + if len(model) > 1: + for i, model0 in enumerate(model): + output_dict[key + str(i)] = model0 + else: + output_dict[key] = model[0] + + return output_dict + + +def save_checkpoint( + path: str, + models: Union[paddle.nn.Layer, List[paddle.nn.Layer], None] = None, + optimizer: Union[optimizer, None] = None, + scheduler: Union[scheduler, None] = None, + scaler: Union[scaler, None] = None, + epoch: Union[int, None] = None, + metadata: Optional[Dict[str, Any]] = None, +) -> None: + """Training checkpoint saving utility + + This will save a training checkpoint in the provided path following the file naming + convention "checkpoint.{model parallel id}.{epoch/index}.mdlus". The load checkpoint + method in Modulus core can then be used to read this file. + + Parameters + ---------- + path : str + Path to save the training checkpoint + models : Union[paddle.nn.Layer, List[paddle.nn.Layer], None], optional + A single or list of PaddlePaddle models, by default None + optimizer : Union[optimizer, None], optional + Optimizer, by default None + scheduler : Union[scheduler, None], optional + Learning rate scheduler, by default None + scaler : Union[scaler, None], optional + AMP grad scaler. Will attempt to save on in static capture if none provided, by + default None + epoch : Union[int, None], optional + Epoch checkpoint to load. If none this will save the checkpoint in the next + valid index, by default None + metadata : Optional[Dict[str, Any]], optional + Additional metadata to save, by default None + """ + # Create checkpoint directory if it does not exist + if not Path(path).is_dir(): + logger.warning( + f"Output directory {path} does not exist, will " "attempt to create" + ) + Path(path).mkdir(parents=True, exist_ok=True) + + # == Saving model checkpoint == + if models: + if not isinstance(models, list): + models = [models] + models = _unique_model_names(models) + for name, model in models.items(): + # Get model type + model_type = "pdparams" + + # Get full file path / name + file_name = _get_checkpoint_filename( + path, name, index=epoch, saving=True, model_type=model_type + ) + + # Save state dictionary + paddle.save(model.state_dict(), file_name) + logger.info(f"Saved model state dictionary: {file_name}") + + # == Saving training checkpoint == + checkpoint_dict = {} + # Optimizer state dict + if optimizer: + checkpoint_dict["optimizer_state_dict"] = optimizer.state_dict() + + # Scheduler state dict + if scheduler: + checkpoint_dict["scheduler_state_dict"] = scheduler.state_dict() + + # Scheduler state dict + if scaler: + checkpoint_dict["scaler_state_dict"] = scaler.state_dict() + # Static capture is being used, save its grad scaler + # if _StaticCapture._amp_scalers: + # checkpoint_dict["static_capture_state_dict"] = _StaticCapture.state_dict() + + # Output file name + output_filename = _get_checkpoint_filename( + path, index=epoch, saving=True, model_type="pdparams" + ) + if epoch: + checkpoint_dict["epoch"] = epoch + if metadata: + checkpoint_dict["metadata"] = metadata + # Save checkpoint to memory + if bool(checkpoint_dict): + paddle.save( + checkpoint_dict, + output_filename, + ) + logger.info(f"Saved training checkpoint: {output_filename}") + + +def load_checkpoint( + path: str, + models: Union[paddle.nn.Layer, List[paddle.nn.Layer], None] = None, + optimizer: Union[optimizer, None] = None, + scheduler: Union[scheduler, None] = None, + scaler: Union[scaler, None] = None, + epoch: Union[int, None] = None, + metadata_dict: Optional[Dict[str, Any]] = {}, +) -> int: + """Checkpoint loading utility + + This loader is designed to be used with the save checkpoint utility in Modulus + Launch. Given a path, this method will try to find a checkpoint and load state + dictionaries into the provided training objects. + + Parameters + ---------- + path : str + Path to training checkpoint + models : Union[paddle.nn.Layer, List[paddle.nn.Layer], None], optional + A single or list of models, by default None + optimizer : Union[optimizer, None], optional + Optimizer, by default None + scheduler : Union[scheduler, None], optional + Learning rate scheduler, by default None + scaler : Union[scaler, None], optional + AMP grad scaler, by default None + epoch : Union[int, None], optional + Epoch checkpoint to load. If none is provided this will attempt to load the + checkpoint with the largest index, by default None + metadata_dict: Optional[Dict[str, Any]], optional + Dictionary to store metadata from the checkpoint, by default None + + Returns + ------- + int + Loaded epoch + """ + # Check if checkpoint directory exists + if not Path(path).is_dir(): + logger.warning( + f"Provided checkpoint directory {path} does not exist, skipping load" + ) + return 0 + + # == Loading model checkpoint == + if models: + if not isinstance(models, list): + models = [models] + models = _unique_model_names(models) + for name, model in models.items(): + # Get model type + model_type = "pdparams" + + # Get full file path / name + file_name = _get_checkpoint_filename( + path, name, index=epoch, model_type=model_type + ) + if not Path(file_name).exists(): + logger.error( + f"Could not find valid model file {file_name}, skipping load" + ) + continue + # Load state dictionary + model.set_state_dict(paddle.load(file_name)) + + logger.info(f"Loaded model state dictionary {file_name}") + + # == Loading training checkpoint == + checkpoint_filename = _get_checkpoint_filename( + path, index=epoch, model_type="pdparams" + ) + if not Path(checkpoint_filename).is_file(): + logger.warning("Could not find valid checkpoint file, skipping load") + return 0 + + checkpoint_dict = paddle.load(checkpoint_filename) + logger.info(f"Loaded checkpoint file {checkpoint_filename}") + + # Optimizer state dict + if optimizer and "optimizer_state_dict" in checkpoint_dict: + optimizer.set_state_dict(checkpoint_dict["optimizer_state_dict"]) + logger.info("Loaded optimizer state dictionary") + + # Scheduler state dict + if scheduler and "scheduler_state_dict" in checkpoint_dict: + scheduler.set_state_dict(checkpoint_dict["scheduler_state_dict"]) + logger.info("Loaded scheduler state dictionary") + + # Scaler state dict + if scaler and "scaler_state_dict" in checkpoint_dict: + scaler.load_state_dict(checkpoint_dict["scaler_state_dict"]) + logger.info("Loaded grad scaler state dictionary") + + epoch = 0 + if "epoch" in checkpoint_dict: + epoch = checkpoint_dict["epoch"] + # Update metadata if exists and the dictionary object is provided + metadata = checkpoint_dict.get("metadata", {}) + for key, value in metadata.items(): + metadata_dict[key] = value + + return epoch + + +class BQWarp(nn.Layer): + """Warp based ball-query layer""" + + def __init__( + self, + input_features, + grid_resolution=[256, 96, 64], + radius=0.25, + neighbors_in_radius=10, + ): + super().__init__() + self.ball_query_layer = BallQueryLayer(neighbors_in_radius, radius) + self.grid_resolution = grid_resolution + + def forward(self, x, p_grid, reverse_mapping=True): + batch_size = x.shape[0] + nx, ny, nz = ( + self.grid_resolution[0], + self.grid_resolution[1], + self.grid_resolution[2], + ) + + p_grid = paddle.reshape(p_grid, (batch_size, nx * ny * nz, 3)) + p1 = nx * ny * nz + p2 = x.shape[1] + + if reverse_mapping: + lengths1 = paddle.full((batch_size,), p1, dtype=paddle.int32) + lengths2 = paddle.full((batch_size,), p2, dtype=paddle.int32) + mapping, num_neighbors, outputs = self.ball_query_layer( + p_grid, + x, + lengths1, + lengths2, + ) + else: + lengths1 = paddle.full((batch_size,), p2, dtype=paddle.int32) + lengths2 = paddle.full((batch_size,), p1, dtype=paddle.int32) + mapping, num_neighbors, outputs = self.ball_query_layer( + x, + p_grid, + lengths1, + lengths2, + ) + + return mapping, outputs + + +class GeoConvOut(nn.Layer): + """Geometry layer to project STLs on grids""" + + def __init__(self, input_features, model_parameters, grid_resolution=[256, 96, 64]): + super().__init__() + base_neurons = model_parameters.base_neurons + + self.fc1 = nn.Linear(input_features, base_neurons) + self.fc2 = nn.Linear(base_neurons, int(base_neurons / 2)) + self.fc3 = nn.Linear(int(base_neurons / 2), model_parameters.base_neurons_out) + + self.grid_resolution = grid_resolution + + self.activation = F.relu + + def forward(self, x, radius=0.025, neighbors_in_radius=10): + batch_size = x.shape[0] + nx, ny, nz = ( + self.grid_resolution[0], + self.grid_resolution[1], + self.grid_resolution[2], + ) + + mask = abs(x - 0) > 1e-6 + + x = self.activation(self.fc1(x)) + x = self.activation(self.fc2(x)) + x = F.tanh(self.fc3(x)) + mask = mask[:, :, :, 0:1].expand( + [mask.shape[0], mask.shape[1], mask.shape[2], x.shape[-1]] + ) + + # paddle does not support multiplication with boolean tensors, + # so we convert the mask to float + x = paddle.sum(x * mask.to(dtype=x.dtype), 2) + + x = paddle.reshape(x, (batch_size, x.shape[-1], nx, ny, nz)) + return x + + +class GeoProcessor(nn.Layer): + """Geometry processing layer using CNNs""" + + def __init__(self, input_filters, model_parameters): + super().__init__() + base_filters = model_parameters.base_filters + self.conv1 = nn.Conv3D( + input_filters, base_filters, kernel_size=3, padding="same" + ) + self.conv_bn1 = nn.BatchNorm3D(int(base_filters)) + self.conv2 = nn.Conv3D( + base_filters, 2 * base_filters, kernel_size=3, padding="same" + ) + self.conv_bn2 = nn.BatchNorm3D(int(2 * base_filters)) + self.conv3 = nn.Conv3D( + 2 * base_filters, 4 * base_filters, kernel_size=3, padding="same" + ) + self.conv_bn3 = nn.BatchNorm3D(int(4 * base_filters)) + self.conv3_1 = nn.Conv3D( + 4 * base_filters, 4 * base_filters, kernel_size=3, padding="same" + ) + self.conv4 = nn.Conv3D( + 4 * base_filters, 2 * base_filters, kernel_size=3, padding="same" + ) + self.conv_bn4 = nn.BatchNorm3D(int(2 * base_filters)) + self.conv5 = nn.Conv3D( + 4 * base_filters, base_filters, kernel_size=3, padding="same" + ) + self.conv_bn5 = nn.BatchNorm3D(int(base_filters)) + self.conv6 = nn.Conv3D( + 2 * base_filters, input_filters, kernel_size=3, padding="same" + ) + self.conv_bn6 = nn.BatchNorm3D(int(input_filters)) + self.conv7 = nn.Conv3D( + 2 * input_filters, input_filters, kernel_size=3, padding="same" + ) + self.conv8 = nn.Conv3D(input_filters, 1, kernel_size=3, padding="same") + self.avg_pool = paddle.nn.AvgPool3D((2, 2, 2)) + self.max_pool = nn.MaxPool3D(2) + self.upsample = nn.Upsample(scale_factor=2, mode="nearest") + self.activation = F.relu + self.batch_norm = False + + def forward(self, x): + # Encoder + x0 = x + if self.batch_norm: + x = self.activation(self.conv_bn1(self.conv1(x))) + else: + x = self.activation(self.conv1(x)) + x = self.max_pool(x) + x1 = x + if self.batch_norm: + x = self.activation(self.conv_bn2(self.conv2(x))) + else: + x = self.activation((self.conv2(x))) + x = self.max_pool(x) + + x2 = x + if self.batch_norm: + x = self.activation(self.conv_bn3(self.conv2(x))) + else: + x = self.activation((self.conv3(x))) + x = self.max_pool(x) + + # Processor loop + x = F.relu(self.conv3_1(x)) + + # Decoder + if self.batch_norm: + x = self.activation(self.conv_bn4(self.conv4(x))) + else: + x = self.activation((self.conv4(x))) + x = self.upsample(x) + x = paddle.concat((x, x2), axis=1) + + if self.batch_norm: + x = self.activation(self.conv_bn5(self.conv5(x))) + else: + x = self.activation((self.conv5(x))) + x = self.upsample(x) + x = paddle.concat((x, x1), axis=1) + if self.batch_norm: + x = self.activation(self.conv_bn6(self.conv6(x))) + else: + x = self.activation((self.conv6(x))) + x = self.upsample(x) + x = paddle.concat((x, x0), axis=1) + + x = self.activation(self.conv7(x)) + x = self.conv8(x) + + return x + + +class GeometryRep(nn.Layer): + """Geometry representation from STLs block""" + + def __init__(self, input_features, model_parameters=None): + super().__init__() + geometry_rep = model_parameters.geometry_rep + + self.bq_warp_short = BQWarp( + input_features=input_features, + grid_resolution=model_parameters.interp_res, + radius=geometry_rep.geo_conv.radius_short, + ) + + self.bq_warp_long = BQWarp( + input_features=input_features, + grid_resolution=model_parameters.interp_res, + radius=geometry_rep.geo_conv.radius_long, + ) + + self.geo_conv_out = GeoConvOut( + input_features=input_features, + model_parameters=geometry_rep.geo_conv, + grid_resolution=model_parameters.interp_res, + ) + + self.geo_processor_short_range = GeoProcessor( + input_filters=geometry_rep.geo_conv.base_neurons_out, + model_parameters=geometry_rep.geo_processor, + ) + self.geo_processor_long_range = GeoProcessor( + input_filters=geometry_rep.geo_conv.base_neurons_out, + model_parameters=geometry_rep.geo_processor, + ) + self.geo_processor_sdf = GeoProcessor( + input_filters=6, model_parameters=geometry_rep.geo_processor + ) + self.activation = F.relu + self.radius_short = geometry_rep.geo_conv.radius_short + self.radius_long = geometry_rep.geo_conv.radius_long + self.hops = geometry_rep.geo_conv.hops + + def forward(self, x, p_grid, sdf): + + # Expand SDF + sdf = paddle.unsqueeze(sdf, 1) + + # Calculate short-range geoemtry dependency + mapping, k_short = self.bq_warp_short(x, p_grid) + x_encoding_short = self.geo_conv_out(k_short) + + # Calculate long-range geometry dependency + mapping, k_long = self.bq_warp_long(x, p_grid) + x_encoding_long = self.geo_conv_out(k_long) + + # Scaled sdf to emphasis on surface + scaled_sdf = scale_sdf(sdf) + # Binary sdf + binary_sdf = binarize_sdf(sdf) + # Gradients of SDF + sdf_x, sdf_y, sdf_z = calculate_gradient(sdf) + + # Propagate information in the geometry enclosed BBox + for _ in range(self.hops): + dx = self.geo_processor_short_range(x_encoding_short) / self.hops + x_encoding_short = x_encoding_short + dx + + # Propagate information in the computational domain BBox + for _ in range(self.hops): + dx = self.geo_processor_long_range(x_encoding_long) / self.hops + x_encoding_long = x_encoding_long + dx + + # Process SDF and its computed features + sdf = paddle.concat((sdf, scaled_sdf, binary_sdf, sdf_x, sdf_y, sdf_z), 1) + sdf_encoding = self.geo_processor_sdf(sdf) + + # Geometry encoding comprised of short-range, long-range and SDF features + encoding_g = paddle.concat((x_encoding_short, sdf_encoding, x_encoding_long), 1) + + return encoding_g + + +class NNBasisFunctions(nn.Layer): + """Basis function layer for point clouds""" + + def __init__(self, input_features, model_parameters=None): + super(NNBasisFunctions, self).__init__() + self.input_features = input_features + + base_layer = model_parameters.base_layer + self.fc1 = nn.Linear(self.input_features, base_layer) + self.fc2 = nn.Linear(base_layer, int(base_layer)) + self.fc3 = nn.Linear(int(base_layer), int(base_layer)) + self.bn1 = nn.BatchNorm1D(base_layer) + self.bn2 = nn.BatchNorm1D(int(base_layer)) + self.bn3 = nn.BatchNorm1D(int(base_layer)) + + self.activation = F.relu + + def forward(self, x, padded_value=-10): + facets = x + facets = self.activation(self.fc1(facets)) + facets = self.activation(self.fc2(facets)) + facets = self.fc3(facets) + + return facets + + +class ParameterModel(nn.Layer): + """Layer to encode parameters such as inlet velocity and air density""" + + def __init__(self, input_features, model_parameters=None): + super(ParameterModel, self).__init__() + self.input_features = input_features + + base_layer = model_parameters.base_layer + self.fc1 = nn.Linear(self.input_features, base_layer) + self.fc2 = nn.Linear(base_layer, int(base_layer)) + self.fc3 = nn.Linear(int(base_layer), int(base_layer)) + self.bn1 = nn.BatchNorm1D(base_layer) + self.bn2 = nn.BatchNorm1D(int(base_layer)) + self.bn3 = nn.BatchNorm1D(int(base_layer)) + + self.activation = F.relu + + def forward(self, x, padded_value=-10): + params = x + params = self.activation(self.fc1(params)) + params = self.activation(self.fc2(params)) + params = self.fc3(params) + + return params + + +class AggregationModel(nn.Layer): + """Layer to aggregate local geometry encoding with basis functions""" + + def __init__( + self, input_features, output_features, model_parameters=None, new_change=True + ): + super(AggregationModel, self).__init__() + self.input_features = input_features + self.output_features = output_features + self.new_change = new_change + base_layer = model_parameters.base_layer + self.fc1 = nn.Linear(self.input_features, base_layer) + self.fc2 = nn.Linear(base_layer, int(base_layer)) + self.fc3 = nn.Linear(int(base_layer), int(base_layer)) + self.fc4 = nn.Linear(int(base_layer), int(base_layer)) + self.fc5 = nn.Linear(int(base_layer), self.output_features) + self.bn1 = nn.BatchNorm1D(base_layer) + self.bn2 = nn.BatchNorm1D(int(base_layer)) + self.bn3 = nn.BatchNorm1D(int(base_layer)) + self.bn4 = nn.BatchNorm1D(int(base_layer)) + self.activation = F.relu + + def forward(self, x): + out = self.activation(self.fc1(x)) + out = self.activation(self.fc2(out)) + out = self.activation(self.fc3(out)) + out = self.activation(self.fc4(out)) + + out = self.fc5(out) + + return out + + +class DoMINO(nn.Layer): + """DoMINO model architecture + Parameters + ---------- + input_features : int + Number of point input features + output_features_vol : int + Number of output features in volume + output_features_surf : int + Number of output features on surface + model_parameters: dict + Dictionary of model parameters controlled by config.yaml + + Example + ------- + >>> from modulus.models.domino.model import DoMINO + >>> import os + >>> from hydra import compose, initialize + >>> from omegaconf import OmegaConf + >>> cfg = OmegaConf.register_new_resolver("eval", eval) + >>> with initialize(version_base="1.3", config_path="examples/cfd/external_aerodynamics/domino/src/conf"): + ... cfg = compose(config_name="config") + >>> cfg.model.model_type = "combined" + >>> model = DoMINO( + ... input_features=3, + ... output_features_vol=5, + ... output_features_surf=4, + ... model_parameters=cfg.model + ... ) + + Warp ... + >>> bsize = 1 + >>> nx, ny, nz = 128, 64, 48 + >>> num_neigh = 7 + >>> pos_normals_closest_vol = paddle.randn([bsize, 100, 3]) + >>> pos_normals_com_vol = paddle.randn([bsize, 100, 3]) + >>> pos_normals_com_surface = paddle.randn([bsize, 100, 3]) + >>> geom_centers = paddle.randn([bsize, 100, 3]) + >>> grid = paddle.randn([bsize, nx, ny, nz, 3]) + >>> surf_grid = paddle.randn([bsize, nx, ny, nz, 3]) + >>> sdf_grid = paddle.randn([bsize, nx, ny, nz]) + >>> sdf_surf_grid = paddle.randn([bsize, nx, ny, nz]) + >>> sdf_nodes = paddle.randn([bsize, 100, 1]) + >>> surface_coordinates = paddle.randn([bsize, 100, 3]) + >>> surface_neighbors = paddle.randn([bsize, 100, num_neigh, 3]) + >>> surface_normals = paddle.randn([bsize, 100, 3]) + >>> surface_neighbors_normals = paddle.randn([bsize, 100, num_neigh, 3]) + >>> surface_sizes = paddle.randn([bsize, 100, 3]) + >>> surface_neighbors_sizes = paddle.randn([bsize, 100, num_neigh, 3]) + >>> volume_coordinates = paddle.randn([bsize, 100, 3]) + >>> vol_grid_max_min = paddle.randn([bsize, 2, 3]) + >>> surf_grid_max_min = paddle.randn([bsize, 2, 3]) + >>> stream_velocity = paddle.randn([bsize, 1]) + >>> air_density = paddle.randn([bsize, 1]) + >>> input_dict = { + ... "pos_volume_closest": pos_normals_closest_vol, + ... "pos_volume_center_of_mass": pos_normals_com_vol, + ... "pos_surface_center_of_mass": pos_normals_com_surface, + ... "geometry_coordinates": geom_centers, + ... "grid": grid, + ... "surf_grid": surf_grid, + ... "sdf_grid": sdf_grid, + ... "sdf_surf_grid": sdf_surf_grid, + ... "sdf_nodes": sdf_nodes, + ... "surface_mesh_centers": surface_coordinates, + ... "surface_mesh_neighbors": surface_neighbors, + ... "surface_normals": surface_normals, + ... "surface_neighbors_normals": surface_neighbors_normals, + ... "surface_areas": surface_sizes, + ... "surface_neighbors_areas": surface_neighbors_sizes, + ... "volume_mesh_centers": volume_coordinates, + ... "volume_min_max": vol_grid_max_min, + ... "surface_min_max": surf_grid_max_min, + ... "stream_velocity": stream_velocity, + ... "air_density": air_density, + ... } + >>> output = model(input_dict) + Module ... + >>> print(f"{output[0].shape}, {output[1].shape}") + """ + + def __init__( + self, + input_features, + output_features_vol=None, + output_features_surf=None, + model_parameters=None, + ): + super(DoMINO, self).__init__() + self.input_features = input_features + self.output_features_vol = output_features_vol + self.output_features_surf = output_features_surf + + if self.output_features_vol is None and self.output_features_surf is None: + raise ValueError("Need to specify number of volume or surface features") + + self.num_variables_vol = output_features_vol + self.num_variables_surf = output_features_surf + self.grid_resolution = model_parameters.interp_res + self.surface_neighbors = model_parameters.surface_neighbors + self.use_surface_normals = model_parameters.use_surface_normals + self.use_only_normals = model_parameters.use_only_normals + self.encode_parameters = model_parameters.encode_parameters + self.param_scaling_factors = model_parameters.parameter_model.scaling_params + + if self.use_surface_normals: + if self.use_only_normals: + input_features_surface = input_features + 3 + else: + input_features_surface = input_features + 4 + else: + input_features_surface = input_features + + if self.encode_parameters: + # Defining the parameter model + base_layer_p = model_parameters.parameter_model.base_layer + self.parameter_model = ParameterModel( + input_features=2, model_parameters=model_parameters.parameter_model + ) + else: + base_layer_p = 0 + + self.geo_rep = GeometryRep( + input_features=input_features, + model_parameters=model_parameters, + ) + + # Basis functions for surface and volume + base_layer_nn = model_parameters.nn_basis_functions.base_layer + if self.output_features_surf is not None: + self.nn_basis_surf = nn.LayerList() + for _ in range(self.num_variables_surf): + self.nn_basis_surf.append( + NNBasisFunctions( + input_features=input_features_surface, + model_parameters=model_parameters.nn_basis_functions, + ) + ) + + if self.output_features_vol is not None: + self.nn_basis_vol = nn.LayerList() + for _ in range(self.num_variables_vol): + self.nn_basis_vol.append( + NNBasisFunctions( + input_features=input_features, + model_parameters=model_parameters.nn_basis_functions, + ) + ) + + # Positional encoding + position_encoder_base_neurons = model_parameters.position_encoder.base_neurons + if self.output_features_vol is not None: + if model_parameters.positional_encoding: + inp_pos_vol = 25 if model_parameters.use_sdf_in_basis_func else 12 + else: + inp_pos_vol = 7 if model_parameters.use_sdf_in_basis_func else 3 + + self.fc_p_vol = nn.Linear(inp_pos_vol, position_encoder_base_neurons) + + if self.output_features_surf is not None: + if model_parameters.positional_encoding: + inp_pos_surf = 12 + else: + inp_pos_surf = 3 + + self.fc_p_surf = nn.Linear(inp_pos_surf, position_encoder_base_neurons) + + # Positional encoding hidden layers + self.fc_p1 = nn.Linear( + position_encoder_base_neurons, position_encoder_base_neurons + ) + self.fc_p2 = nn.Linear( + position_encoder_base_neurons, position_encoder_base_neurons + ) + + # BQ for surface and volume + self.neighbors_in_radius = model_parameters.geometry_local.neighbors_in_radius + self.radius = model_parameters.geometry_local.radius + self.bq_warp = BQWarp( + input_features=input_features, + grid_resolution=model_parameters.interp_res, + radius=self.radius, + neighbors_in_radius=self.neighbors_in_radius, + ) + + base_layer_geo = model_parameters.geometry_local.base_layer + self.fc_1 = nn.Linear(self.neighbors_in_radius * 3, base_layer_geo) + self.fc_2 = nn.Linear(base_layer_geo, base_layer_geo) + self.activation = F.relu + + # Aggregation model + if self.output_features_surf is not None: + # Surface + self.agg_model_surf = nn.LayerList() + for _ in range(self.num_variables_surf): + self.agg_model_surf.append( + AggregationModel( + input_features=position_encoder_base_neurons + + base_layer_nn + + base_layer_geo + + base_layer_p, + output_features=1, + model_parameters=model_parameters.aggregation_model, + ) + ) + + if self.output_features_vol is not None: + # Volume + self.agg_model_vol = nn.LayerList() + for _ in range(self.num_variables_vol): + self.agg_model_vol.append( + AggregationModel( + input_features=position_encoder_base_neurons + + base_layer_nn + + base_layer_geo + + base_layer_p, + output_features=1, + model_parameters=model_parameters.aggregation_model, + ) + ) + + self.apply(kaiming_init) + + def geometry_encoder(self, geo_centers, p_grid, sdf): + """Function to return local geometry encoding""" + return self.geo_rep(geo_centers, p_grid, sdf) + + def position_encoder(self, encoding_node, eval_mode="volume"): + """Function to calculate positional encoding""" + if eval_mode == "volume": + x = self.activation(self.fc_p_vol(encoding_node)) + elif eval_mode == "surface": + x = self.activation(self.fc_p_surf(encoding_node)) + x = self.activation(self.fc_p1(x)) + x = self.fc_p2(x) + return x + + def geo_encoding_local_surface(self, encoding_g, volume_mesh_centers, p_grid): + """Function to calculate local geometry encoding from global encoding for surface""" + batch_size = volume_mesh_centers.shape[0] + nx, ny, nz = ( + self.grid_resolution[0], + self.grid_resolution[1], + self.grid_resolution[2], + ) + p_grid = paddle.reshape(p_grid, (batch_size, nx * ny * nz, 3)) + mapping, outputs = self.bq_warp( + volume_mesh_centers, p_grid, reverse_mapping=False + ) + mapping = mapping.astype(paddle.int64) + mask = mapping != 0 + + geo_encoding = paddle.reshape(encoding_g[:, 0], (batch_size, 1, nx * ny * nz)) + geo_encoding = geo_encoding.expand( + [batch_size, volume_mesh_centers.shape[1], geo_encoding.shape[2]] + ) + sdf_encoding = paddle.reshape(encoding_g[:, 1], (batch_size, 1, nx * ny * nz)) + sdf_encoding = sdf_encoding.expand( + [batch_size, volume_mesh_centers.shape[1], sdf_encoding.shape[2]] + ) + geo_encoding_long = paddle.reshape( + encoding_g[:, 2], (batch_size, 1, nx * ny * nz) + ) + geo_encoding_long = geo_encoding_long.expand( + [batch_size, volume_mesh_centers.shape[1], geo_encoding_long.shape[2]] + ) + + geo_encoding_sampled = paddle.take_along_axis( + geo_encoding, axis=2, indices=mapping + ) * mask.to(dtype=geo_encoding.dtype) + sdf_encoding_sampled = paddle.take_along_axis( + sdf_encoding, axis=2, indices=mapping + ) * mask.to(dtype=geo_encoding.dtype) + geo_encoding_long_sampled = paddle.take_along_axis( + geo_encoding_long, axis=2, indices=mapping + ) * mask.to(dtype=geo_encoding.dtype) + + encoding_g = paddle.concat( + (geo_encoding_sampled, sdf_encoding_sampled, geo_encoding_long_sampled), + axis=2, + ) + encoding_g = self.activation(self.fc_1(encoding_g)) + encoding_g = self.fc_2(encoding_g) + + return encoding_g + + def geo_encoding_local(self, encoding_g, volume_mesh_centers, p_grid): + """Function to calculate local geometry encoding from global encoding""" + batch_size = volume_mesh_centers.shape[0] + nx, ny, nz = ( + self.grid_resolution[0], + self.grid_resolution[1], + self.grid_resolution[2], + ) + p_grid = paddle.reshape(p_grid, (batch_size, nx * ny * nz, 3)) + mapping, outputs = self.bq_warp( + volume_mesh_centers, p_grid, reverse_mapping=False + ) + mapping = mapping.astype(paddle.int64) + mask = mapping != 0 + + geo_encoding = paddle.reshape(encoding_g[:, 0], (batch_size, 1, nx * ny * nz)) + geo_encoding = geo_encoding.expand( + [batch_size, volume_mesh_centers.shape[1], geo_encoding.shape[2]] + ) + sdf_encoding = paddle.reshape(encoding_g[:, 1], (batch_size, 1, nx * ny * nz)) + sdf_encoding = sdf_encoding.expand( + [batch_size, volume_mesh_centers.shape[1], sdf_encoding.shape[2]] + ) + geo_encoding_long = paddle.reshape( + encoding_g[:, 2], (batch_size, 1, nx * ny * nz) + ) + geo_encoding_long = geo_encoding_long.expand( + [batch_size, volume_mesh_centers.shape[1], geo_encoding_long.shape[2]] + ) + + geo_encoding_sampled = paddle.take_along_axis( + geo_encoding, axis=2, indices=mapping + ) * mask.to(dtype=geo_encoding.dtype) + sdf_encoding_sampled = paddle.take_along_axis( + sdf_encoding, axis=2, indices=mapping + ) * mask.to(dtype=geo_encoding.dtype) + geo_encoding_long_sampled = paddle.take_along_axis( + geo_encoding_long, axis=2, indices=mapping + ) * mask.to(dtype=geo_encoding.dtype) + + encoding_g = paddle.concat( + (geo_encoding_sampled, sdf_encoding_sampled, geo_encoding_long_sampled), + axis=2, + ) + encoding_g = self.activation(self.fc_1(encoding_g)) + encoding_g = self.fc_2(encoding_g) + + return encoding_g + + def calculate_solution_with_neighbors( + self, + surface_mesh_centers, + encoding_g, + encoding_node, + surface_mesh_neighbors, + surface_normals, + surface_neighbors_normals, + surface_areas, + surface_neighbors_areas, + inlet_velocity, + air_density, + ): + """Function to approximate solution given the neighborhood information""" + num_variables = self.num_variables_surf + nn_basis = self.nn_basis_surf + agg_model = self.agg_model_surf + num_sample_points = surface_mesh_neighbors.shape[2] + 1 + + if self.encode_parameters: + inlet_velocity = paddle.unsqueeze(inlet_velocity, 1) + inlet_velocity = inlet_velocity.expand( + [ + inlet_velocity.shape[0], + surface_mesh_centers.shape[1], + inlet_velocity.shape[2], + ] + ) + inlet_velocity = inlet_velocity / self.param_scaling_factors[0] + + air_density = paddle.unsqueeze(air_density, 1) + air_density = air_density.expand( + [ + air_density.shape[0], + surface_mesh_centers.shape[1], + air_density.shape[2], + ] + ) + air_density = air_density / self.param_scaling_factors[1] + + params = paddle.concat((inlet_velocity, air_density), axis=-1) + param_encoding = self.parameter_model(params) + + if self.use_surface_normals: + if self.use_only_normals: + surface_mesh_centers = paddle.concat( + (surface_mesh_centers, surface_normals), + axis=-1, + ) + surface_mesh_neighbors = paddle.concat( + ( + surface_mesh_neighbors, + surface_neighbors_normals, + ), + axis=-1, + ) + + else: + surface_mesh_centers = paddle.concat( + (surface_mesh_centers, surface_normals, 10**5 * surface_areas), + axis=-1, + ) + surface_mesh_neighbors = paddle.concat( + ( + surface_mesh_neighbors, + surface_neighbors_normals, + 10**5 * surface_neighbors_areas, + ), + axis=-1, + ) + + for f in range(num_variables): + for p in range(num_sample_points): + if p == 0: + volume_m_c = surface_mesh_centers + else: + volume_m_c = surface_mesh_neighbors[:, :, p - 1] + noise = surface_mesh_centers - volume_m_c + dist = paddle.sqrt( + noise[:, :, 0:1] ** 2.0 + + noise[:, :, 1:2] ** 2.0 + + noise[:, :, 2:3] ** 2.0 + ) + basis_f = nn_basis[f](volume_m_c) + output = paddle.concat((basis_f, encoding_node, encoding_g), axis=-1) + if self.encode_parameters: + output = paddle.concat((output, param_encoding), axis=-1) + if p == 0: + output_center = agg_model[f](output) + else: + if p == 1: + output_neighbor = agg_model[f](output) * (1.0 / dist) + dist_sum = 1.0 / dist + else: + output_neighbor += agg_model[f](output) * (1.0 / dist) + dist_sum += 1.0 / dist + if num_sample_points > 1: + output_res = 0.5 * output_center + 0.5 * output_neighbor / dist_sum + else: + output_res = output_center + if f == 0: + output_all = output_res + else: + output_all = paddle.concat((output_all, output_res), axis=-1) + + return output_all + + def calculate_solution( + self, + volume_mesh_centers, + encoding_g, + encoding_node, + inlet_velocity, + air_density, + eval_mode, + num_sample_points=20, + noise_intensity=50, + ): + """Function to approximate solution sampling the neighborhood information""" + if eval_mode == "volume": + num_variables = self.num_variables_vol + nn_basis = self.nn_basis_vol + agg_model = self.agg_model_vol + elif eval_mode == "surface": + num_variables = self.num_variables_surf + nn_basis = self.nn_basis_surf + agg_model = self.agg_model_surf + + if self.encode_parameters: + inlet_velocity = paddle.unsqueeze(inlet_velocity, 1) + inlet_velocity = inlet_velocity.expand( + [ + inlet_velocity.shape[0], + volume_mesh_centers.shape[1], + inlet_velocity.shape[2], + ] + ) + inlet_velocity = inlet_velocity / self.param_scaling_factors[0] + + air_density = paddle.unsqueeze(air_density, 1) + air_density = air_density.expand( + [ + air_density.shape[0], + volume_mesh_centers.shape[1], + air_density.shape[2], + ] + ) + air_density = air_density / self.param_scaling_factors[1] + + params = paddle.concat((inlet_velocity, air_density), axis=-1) + param_encoding = self.parameter_model(params) + + for f in range(num_variables): + for p in range(num_sample_points): + if p == 0: + volume_m_c = volume_mesh_centers + else: + noise = paddle.rand( + shape=volume_mesh_centers.shape, dtype=volume_mesh_centers.dtype + ) + noise = 2 * (noise - 0.5) + noise = noise / noise_intensity + dist = paddle.sqrt( + noise[:, :, 0:1] ** 2.0 + + noise[:, :, 1:2] ** 2.0 + + noise[:, :, 2:3] ** 2.0 + ) + volume_m_c = volume_mesh_centers + noise + basis_f = nn_basis[f](volume_m_c) + output = paddle.concat((basis_f, encoding_node, encoding_g), axis=-1) + if self.encode_parameters: + output = paddle.concat((output, param_encoding), axis=-1) + if p == 0: + output_center = agg_model[f](output) + else: + if p == 1: + output_neighbor = agg_model[f](output) * (1.0 / dist) + dist_sum = 1.0 / dist + else: + output_neighbor += agg_model[f](output) * (1.0 / dist) + dist_sum += 1.0 / dist + if num_sample_points > 1: + output_res = 0.5 * output_center + 0.5 * output_neighbor / dist_sum + else: + output_res = output_center + if f == 0: + output_all = output_res + else: + output_all = paddle.concat((output_all, output_res), axis=-1) + + return output_all + + def forward( + self, + data_dict, + ): + # Loading STL inputs, bounding box grids, precomputed SDF and scaling factors + + # STL nodes + geo_centers = data_dict["geometry_coordinates"] + + # Bounding box grid + s_grid = data_dict["surf_grid"] + sdf_surf_grid = data_dict["sdf_surf_grid"] + # Scaling factors + surf_max = data_dict["surface_min_max"][:, 1] + surf_min = data_dict["surface_min_max"][:, 0] + + # Parameters + stream_velocity = data_dict["stream_velocity"] + air_density = data_dict["air_density"] + + if self.output_features_vol is not None: + # Represent geometry on computational grid + # Computational domain grid + p_grid = data_dict["grid"] + sdf_grid = data_dict["sdf_grid"] + # Scaling factors + vol_max = data_dict["volume_min_max"][:, 1] + vol_min = data_dict["volume_min_max"][:, 0] + + # Normalize based on computational domain + geo_centers_vol = 2.0 * (geo_centers - vol_min) / (vol_max - vol_min) - 1 + encoding_g_vol = self.geo_rep(geo_centers_vol, p_grid, sdf_grid) + + # Normalize based on BBox around surface (car) + geo_centers_surf = ( + 2.0 * (geo_centers - surf_min) / (surf_max - surf_min) - 1 + ) + encoding_g_surf = self.geo_rep(geo_centers_surf, s_grid, sdf_surf_grid) + + # SDF on volume mesh nodes + sdf_nodes = data_dict["sdf_nodes"] + # Positional encoding based on closest point on surface to a volume node + pos_volume_closest = data_dict["pos_volume_closest"] + # Positional encoding based on center of mass of geometry to volume node + pos_volume_center_of_mass = data_dict["pos_volume_center_of_mass"] + encoding_node_vol = paddle.concat( + (sdf_nodes, pos_volume_closest, pos_volume_center_of_mass), axis=-1 + ) + + # Calculate positional encoding on volume nodes + encoding_node_vol = self.position_encoder( + encoding_node_vol, eval_mode="volume" + ) + + if self.output_features_surf is not None: + # Represent geometry on bounding box + geo_centers_surf = ( + 2.0 * (geo_centers - surf_min) / (surf_max - surf_min) - 1 + ) + encoding_g_surf = self.geo_rep(geo_centers_surf, s_grid, sdf_surf_grid) + + # Positional encoding based on center of mass of geometry to surface node + pos_surface_center_of_mass = data_dict["pos_surface_center_of_mass"] + encoding_node_surf = pos_surface_center_of_mass + + # Calculate positional encoding on surface centers + encoding_node_surf = self.position_encoder( + encoding_node_surf, eval_mode="surface" + ) + + encoding_g = 0.5 * encoding_g_surf + # Average the encodings + if self.output_features_vol is not None: + encoding_g += 0.5 * encoding_g_vol + + if self.output_features_vol is not None: + # Calculate local geometry encoding for volume + # Sampled points on volume + volume_mesh_centers = data_dict["volume_mesh_centers"] + encoding_g_vol = self.geo_encoding_local( + encoding_g, volume_mesh_centers, p_grid + ) + + # Approximate solution on volume node + output_vol = self.calculate_solution( + volume_mesh_centers, + encoding_g_vol, + encoding_node_vol, + stream_velocity, + air_density, + eval_mode="volume", + ) + else: + output_vol = None + + if self.output_features_surf is not None: + # Sampled points on surface + surface_mesh_centers = data_dict["surface_mesh_centers"] + surface_normals = data_dict["surface_normals"] + surface_areas = data_dict["surface_areas"] + + # Neighbors of sampled points on surface + surface_mesh_neighbors = data_dict["surface_mesh_neighbors"] + surface_neighbors_normals = data_dict["surface_neighbors_normals"] + surface_neighbors_areas = data_dict["surface_neighbors_areas"] + surface_areas = paddle.unsqueeze(surface_areas, -1) + surface_neighbors_areas = paddle.unsqueeze(surface_neighbors_areas, -1) + # Calculate local geometry encoding for surface + encoding_g_surf = self.geo_encoding_local_surface( + 0.5 * encoding_g_surf, surface_mesh_centers, s_grid + ) + + # Approximate solution on surface cell center + if not self.surface_neighbors: + output_surf = self.calculate_solution( + surface_mesh_centers, + encoding_g_surf, + encoding_node_surf, + stream_velocity, + air_density, + eval_mode="surface", + num_sample_points=1, + noise_intensity=500, + ) + else: + output_surf = self.calculate_solution_with_neighbors( + surface_mesh_centers, + encoding_g_surf, + encoding_node_surf, + surface_mesh_neighbors, + surface_normals, + surface_neighbors_normals, + surface_areas, + surface_neighbors_areas, + stream_velocity, + air_density, + ) + else: + output_surf = None + + return output_vol, output_surf + + +if __name__ == "__main__": + from hydra import compose + from hydra import initialize + from omegaconf import OmegaConf + + if paddle.device.cuda.device_count() >= 1: + paddle.set_device("gpu") + else: + paddle.set_device("cpu") + cfg = OmegaConf.register_new_resolver("eval", eval) + with initialize(version_base="1.3", config_path="../../scripts/conf"): + cfg = compose(config_name="config") + cfg.model.model_type = "combined" + model = DoMINO( + input_features=3, + output_features_vol=5, + output_features_surf=4, + model_parameters=cfg.model, + ) + + bsize = 1 + nx, ny, nz = 128, 64, 48 + num_neigh = 7 + pos_normals_closest_vol = paddle.randn([bsize, 100, 3]) + pos_normals_com_vol = paddle.randn([bsize, 100, 3]) + pos_normals_com_surface = paddle.randn([bsize, 100, 3]) + geom_centers = paddle.randn([bsize, 100, 3]) + grid = paddle.randn([bsize, nx, ny, nz, 3]) + surf_grid = paddle.randn([bsize, nx, ny, nz, 3]) + sdf_grid = paddle.randn([bsize, nx, ny, nz]) + sdf_surf_grid = paddle.randn([bsize, nx, ny, nz]) + sdf_nodes = paddle.randn([bsize, 100, 1]) + surface_coordinates = paddle.randn([bsize, 100, 3]) + surface_neighbors = paddle.randn([bsize, 100, num_neigh, 3]) + surface_normals = paddle.randn([bsize, 100, 3]) + surface_neighbors_normals = paddle.randn([bsize, 100, num_neigh, 3]) + surface_sizes = paddle.randn([bsize, 100, 3]) + surface_neighbors_sizes = paddle.randn([bsize, 100, num_neigh, 3]) + volume_coordinates = paddle.randn([bsize, 100, 3]) + vol_grid_max_min = paddle.randn([bsize, 2, 3]) + surf_grid_max_min = paddle.randn([bsize, 2, 3]) + stream_velocity = paddle.randn([bsize, 1]) + air_density = paddle.randn([bsize, 1]) + input_dict = { + "pos_volume_closest": pos_normals_closest_vol, + "pos_volume_center_of_mass": pos_normals_com_vol, + "pos_surface_center_of_mass": pos_normals_com_surface, + "geometry_coordinates": geom_centers, + "grid": grid, + "surf_grid": surf_grid, + "sdf_grid": sdf_grid, + "sdf_surf_grid": sdf_surf_grid, + "sdf_nodes": sdf_nodes, + "surface_mesh_centers": surface_coordinates, + "surface_mesh_neighbors": surface_neighbors, + "surface_normals": surface_normals, + "surface_neighbors_normals": surface_neighbors_normals, + "surface_areas": surface_sizes, + "surface_neighbors_areas": surface_neighbors_sizes, + "volume_mesh_centers": volume_coordinates, + "volume_min_max": vol_grid_max_min, + "surface_min_max": surf_grid_max_min, + "stream_velocity": stream_velocity, + "air_density": air_density, + } + output = model(input_dict) + print(f"{output[0].shape}, {output[1].shape}") diff --git a/ppsci/data/dataset/__init__.py b/ppsci/data/dataset/__init__.py index 9ece354700..4146846cf2 100644 --- a/ppsci/data/dataset/__init__.py +++ b/ppsci/data/dataset/__init__.py @@ -27,6 +27,7 @@ from ppsci.data.dataset.cylinder_dataset import MeshCylinderDataset from ppsci.data.dataset.darcyflow_dataset import DarcyFlowDataset from ppsci.data.dataset.dgmr_dataset import DGMRDataset +from ppsci.data.dataset.domino_datapipe import DoMINODataPipe from ppsci.data.dataset.drivaernet_dataset import DrivAerNetDataset from ppsci.data.dataset.drivaernetplusplus_dataset import DrivAerNetPlusPlusDataset from ppsci.data.dataset.enso_dataset import ENSODataset @@ -93,6 +94,7 @@ "DrivAerNetDataset", "DrivAerNetPlusPlusDataset", "IFMMoeDataset", + "DoMINODataPipe", ] diff --git a/ppsci/data/dataset/domino_datapipe.py b/ppsci/data/dataset/domino_datapipe.py new file mode 100644 index 0000000000..7dc7fbf3a4 --- /dev/null +++ b/ppsci/data/dataset/domino_datapipe.py @@ -0,0 +1,1003 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# refs: https://github.com/NVIDIA/physicsnemo/tree/main/examples/cfd/external_aerodynamics/domino + +""" +This code provides the datapipe for reading the processed npy files, +generating multi-res grids, calculating signed distance fields, +positional encodings, sampling random points in the volume and on surface, +normalizing fields and returning the output tensors as a dictionary. + +This datapipe also non-dimensionalizes the fields, so the order in which the variables should +be fixed: velocity, pressure, turbulent viscosity for volume variables and +pressure, wall-shear-stress for surface variables. The different parameters such as +variable names, domain resolution, sampling size etc. are configurable in config.yaml. +""" + +import os +import random +import time +from pathlib import Path +from typing import Literal +from typing import Optional +from typing import Sequence +from typing import Union + +import numpy as np +from paddle.io import Dataset +from scipy.spatial import KDTree + +from ppsci.utils.sdf import signed_distance_field + +try: + import pyvista as pv + + PV_AVAILABLE = True +except ImportError: + PV_AVAILABLE = False +try: + import vtk + from vtk import vtkDataSetTriangleFilter + from vtk.util import numpy_support + + VTK_AVAILABLE = True +except ImportError: + VTK_AVAILABLE = False + +AIR_DENSITY = 1.205 +STREAM_VELOCITY = 30.00 + + +def calculate_center_of_mass(stl_centers, stl_sizes): + """Function to calculate center of mass""" + stl_sizes = np.expand_dims(stl_sizes, -1) + center_of_mass = np.sum(stl_centers * stl_sizes, axis=0) / np.sum(stl_sizes, axis=0) + return center_of_mass + + +def normalize(field, mx, mn): + """Function to normalize fields""" + return 2.0 * (field - mn) / (mx - mn) - 1.0 + + +def unnormalize(field, mx, mn): + """Function to unnormalize fields""" + return (field + 1.0) * (mx - mn) * 0.5 + mn + + +def standardize(field, mean, std): + """Function to standardize fields""" + return (field - mean) / std + + +def unstandardize(field, mean, std): + """Function to unstandardize fields""" + return field * std + mean + + +def write_to_vtp(polydata, filename): + """Function to write polydata to vtp""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + writer = vtk.vtkXMLPolyDataWriter() + writer.SetFileName(filename) + writer.SetInputData(polydata) + writer.Write() + + +def write_to_vtu(polydata, filename): + """Function to write polydata to vtu""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + writer = vtk.vtkXMLUnstructuredGridWriter() + writer.SetFileName(filename) + writer.SetInputData(polydata) + writer.Write() + + +def extract_surface_triangles(tet_mesh): + """Extracts the surface triangles from a triangular mesh.""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + if not PV_AVAILABLE: + raise ImportError("PyVista is not installed. This function cannot be used.") + surface_filter = vtk.vtkDataSetSurfaceFilter() + surface_filter.SetInputData(tet_mesh) + surface_filter.Update() + + surface_mesh = pv.wrap(surface_filter.GetOutput()) + triangle_indices = [] + faces = surface_mesh.faces.reshape((-1, 4)) + for face in faces: + if face[0] == 3: + triangle_indices.extend([face[1], face[2], face[3]]) + else: + raise ValueError("Face is not a triangle") + + return triangle_indices + + +def convert_to_tet_mesh(polydata): + """Function to convert tet to stl""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + # Create a VTK DataSetTriangleFilter object + tet_filter = vtkDataSetTriangleFilter() + tet_filter.SetInputData(polydata) + tet_filter.Update() # Update to apply the filter + + # Get the output as an UnstructuredGrid + # tet_mesh = pv.wrap(tet_filter.GetOutput()) + tet_mesh = tet_filter.GetOutput() + return tet_mesh + + +def get_node_to_elem(polydata): + """Function to convert node to elem""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + c2p = vtk.vtkPointDataToCellData() + c2p.SetInputData(polydata) + c2p.Update() + cell_data = c2p.GetOutput() + return cell_data + + +def get_fields_from_cell(ptdata, var_list): + """Function to get fields from elem""" + fields = [] + for var in var_list: + variable = ptdata.GetArray(var) + num_tuples = variable.GetNumberOfTuples() + cell_fields = [] + for j in range(num_tuples): + variable_value = np.array(variable.GetTuple(j)) + cell_fields.append(variable_value) + cell_fields = np.asarray(cell_fields) + fields.append(cell_fields) + fields = np.transpose(np.asarray(fields), (1, 0)) + + return fields + + +def get_fields(data, variables): + """Function to get fields from VTP/VTU""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + fields = [] + for array_name in variables: + try: + array = data.GetArray(array_name) + except ValueError: + raise ValueError( + f"Failed to get array {array_name} from the unstructured grid." + ) + array_data = numpy_support.vtk_to_numpy(array).reshape( + array.GetNumberOfTuples(), array.GetNumberOfComponents() + ) + fields.append(array_data) + return fields + + +def get_vertices(polydata): + """Function to get vertices""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + points = polydata.GetPoints() + vertices = numpy_support.vtk_to_numpy(points.GetData()) + return vertices + + +def get_volume_data(polydata, variables): + """Function to get volume data""" + vertices = get_vertices(polydata) + point_data = polydata.GetPointData() + + fields = get_fields(point_data, variables) + + return vertices, fields + + +def get_surface_data(polydata, variables): + """Function to get surface data""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + points = polydata.GetPoints() + vertices = np.array([points.GetPoint(i) for i in range(points.GetNumberOfPoints())]) + + point_data = polydata.GetPointData() + fields = [] + for array_name in variables: + try: + array = point_data.GetArray(array_name) + except ValueError: + raise ValueError( + f"Failed to get array {array_name} from the unstructured grid." + ) + array_data = np.zeros( + (points.GetNumberOfPoints(), array.GetNumberOfComponents()) + ) + for j in range(points.GetNumberOfPoints()): + array.GetTuple(j, array_data[j]) + fields.append(array_data) + + polys = polydata.GetPolys() + if polys is None: + raise ValueError("Failed to get polygons from the polydata.") + polys.InitTraversal() + edges = [] + id_list = vtk.vtkIdList() + for _ in range(polys.GetNumberOfCells()): + polys.GetNextCell(id_list) + num_ids = id_list.GetNumberOfIds() + edges = [ + (id_list.GetId(j), id_list.GetId((j + 1) % num_ids)) for j in range(num_ids) + ] + + return vertices, fields, edges + + +def cal_normal_positional_encoding(coordinates_a, coordinates_b=None, cell_length=[]): + """Function to get normal positional encoding""" + dx = cell_length[0] + dy = cell_length[1] + dz = cell_length[2] + if coordinates_b is not None: + normals = coordinates_a - coordinates_b + pos_x = np.asarray(calculate_pos_encoding(normals[:, 0] / dx, d=4)) + pos_y = np.asarray(calculate_pos_encoding(normals[:, 1] / dy, d=4)) + pos_z = np.asarray(calculate_pos_encoding(normals[:, 2] / dz, d=4)) + pos_normals = np.concatenate((pos_x, pos_y, pos_z), axis=0).reshape(-1, 12) + else: + normals = coordinates_a + pos_x = np.asarray(calculate_pos_encoding(normals[:, 0] / dx, d=4)) + pos_y = np.asarray(calculate_pos_encoding(normals[:, 1] / dy, d=4)) + pos_z = np.asarray(calculate_pos_encoding(normals[:, 2] / dz, d=4)) + pos_normals = np.concatenate((pos_x, pos_y, pos_z), axis=0).reshape(-1, 12) + + return pos_normals + + +def pad(arr, npoin, pad_value=0.0): + """Function for padding""" + arr_pad = pad_value * np.ones( + (npoin - arr.shape[0], arr.shape[1]), dtype=np.float32 + ) + arr_padded = np.concatenate((arr, arr_pad), axis=0) + return arr_padded + + +def shuffle_array(arr, npoin): + """Function for shuffling arrays""" + np.random.seed(seed=int(time.time())) + idx = np.arange(arr.shape[0]) + np.random.shuffle(idx) + idx = idx[:npoin] + return arr[idx], idx + + +def get_filenames(filepath): + """Function to get filenames from a directory""" + if os.path.exists(filepath): + filenames = os.listdir(filepath) + return filenames + else: + FileNotFoundError() + + +def calculate_pos_encoding(nx, d=8): + """Function for calculating positional encoding""" + vec = [] + for k in range(int(d / 2)): + vec.append(np.sin(nx / 10000 ** (2 * (k) / d))) + vec.append(np.cos(nx / 10000 ** (2 * (k) / d))) + return vec + + +def create_grid(mx, mn, nres): + """Function to create grid""" + dx = np.linspace(mn[0], mx[0], nres[0]) + dy = np.linspace(mn[1], mx[1], nres[1]) + dz = np.linspace(mn[2], mx[2], nres[2]) + + xv, yv, zv = np.meshgrid(dx, dy, dz) + xv = np.expand_dims(xv, -1) + yv = np.expand_dims(yv, -1) + zv = np.expand_dims(zv, -1) + grid = np.concatenate((xv, yv, zv), axis=-1) + grid = np.transpose(grid, (1, 0, 2, 3)) + + return grid + + +def area_weighted_shuffle_array(arr, npoin, area): + factor = 1.0 + total_area = np.sum(area**factor) + probs = area**factor / total_area + np.random.seed(seed=int(time.time())) + idx = np.arange(arr.shape[0]) + np.random.shuffle(idx) + ids = np.random.choice(idx, npoin, p=probs[idx]) + return arr[ids], ids + + +class DoMINODataPipe(Dataset): + """ + Datapipe for DoMINO + + """ + + def __init__( + self, + data_path: Union[str, Path], # Input data path + phase: Literal["train", "val", "test"] = "train", # Train, test or val + surface_variables: Optional[Sequence] = ( + "pMean", + "wallShearStress", + ), # Names of surface variables + volume_variables: Optional[Sequence] = ( + "UMean", + "pMean", + ), # Names of volume variables + sampling: bool = False, # Sampling True or False + device: int = 0, # GPU device id + grid_resolution: Optional[Sequence] = ( + 256, + 96, + 64, + ), # Resolution of latent grid + normalize_coordinates: bool = False, # Normalize coordinates? + sample_in_bbox: bool = False, # Sample points in a specified bounding box + volume_points_sample: int = 1024, # Number of volume points sampled per batch + surface_points_sample: int = 1024, # Number of surface points sampled per batch + geom_points_sample: int = 300000, # Number of STL points sampled per batch + positional_encoding: bool = False, # Positional encoding, True or False + volume_factors=None, # Non-dimensionalization factors for volume variables + surface_factors=None, # Non-dimensionalization factors for surface variables + scaling_type=None, # Scaling min_max or mean_std + model_type=None, # Model_type, surface, volume or combined + bounding_box_dims=None, # Dimensions of bounding box + bounding_box_dims_surf=None, # Dimensions of bounding box + compute_scaling_factors=False, + num_surface_neighbors=11, # Surface neighbors to consider + ): + if isinstance(data_path, str): + data_path = Path(data_path) + data_path = data_path.expanduser() + + self.data_path = data_path + + if phase not in [ + "train", + "val", + "test", + ]: + raise AssertionError( + f"phase should be one of ['train', 'val', 'test'], got {phase}" + ) + + if not self.data_path.exists(): + raise AssertionError(f"Path {self.data_path} does not exist") + + if not self.data_path.is_dir(): + raise AssertionError(f"Path {self.data_path} is not a directory") + + self.sampling = sampling + self.grid_resolution = grid_resolution + self.normalize_coordinates = normalize_coordinates + self.model_type = model_type + self.bounding_box_dims = [] + self.bounding_box_dims.append(np.asarray(bounding_box_dims.max)) + self.bounding_box_dims.append(np.asarray(bounding_box_dims.min)) + + self.bounding_box_dims_surf = [] + self.bounding_box_dims_surf.append(np.asarray(bounding_box_dims_surf.max)) + self.bounding_box_dims_surf.append(np.asarray(bounding_box_dims_surf.min)) + + self.filenames = get_filenames(self.data_path) + total_files = len(self.filenames) + + self.phase = phase + if phase == "train": + self.indices = np.array(range(total_files)) + elif phase == "val": + self.indices = np.array(range(total_files)) + elif phase == "test": + self.indices = np.array(range(total_files)) + + np.random.shuffle(self.indices) + self.surface_variables = surface_variables + self.volume_variables = volume_variables + self.volume_points = volume_points_sample + self.surface_points = surface_points_sample + self.geom_points_sample = geom_points_sample + self.sample_in_bbox = sample_in_bbox + self.device = device + self.positional_encoding = positional_encoding + self.volume_factors = volume_factors + self.surface_factors = surface_factors + self.scaling_type = scaling_type + self.compute_scaling_factors = compute_scaling_factors + self.num_surface_neighbors = num_surface_neighbors + + def __len__(self): + return len(self.indices) + + def __getitem__(self, idx): + index = self.indices[idx] + cfd_filename = self.filenames[index] + + filepath = self.data_path / cfd_filename + data_dict = np.load(filepath, allow_pickle=True).item() + + stl_vertices = data_dict["stl_coordinates"] + stl_centers = data_dict["stl_centers"] + mesh_indices_flattened = data_dict["stl_faces"] + stl_sizes = data_dict["stl_areas"] + + # Check if stream velocity in keys + if "stream_velocity" in data_dict.keys(): + STREAM_VELOCITY = data_dict["stream_velocity"] + AIR_DENSITY = data_dict["air_density"] + else: + AIR_DENSITY = 1.205 + STREAM_VELOCITY = 30.00 + + # + length_scale = np.amax(np.amax(stl_vertices, 0) - np.amin(stl_vertices, 0)) + + # Center of mass calculation + center_of_mass = calculate_center_of_mass(stl_centers, stl_sizes) + + if self.bounding_box_dims_surf is None: + s_max = np.amax(stl_vertices, 0) + s_min = np.amin(stl_vertices, 0) + else: + s_max = np.float32(self.bounding_box_dims_surf[0]) + s_min = np.float32(self.bounding_box_dims_surf[1]) + + nx, ny, nz = self.grid_resolution + + surf_grid = create_grid(s_max, s_min, [nx, ny, nz]) + surf_grid_reshaped = surf_grid.reshape(nx * ny * nz, 3) + + # SDF calculation on the grid using WARP + sdf_surf_grid = ( + signed_distance_field( + stl_vertices, + mesh_indices_flattened, + surf_grid_reshaped, + use_sign_winding_number=True, + ) + .numpy() + .reshape(nx, ny, nz) + ) + surf_grid = np.float32(surf_grid) + sdf_surf_grid = np.float32(sdf_surf_grid) + surf_grid_max_min = np.float32(np.asarray([s_min, s_max])) + + if self.model_type == "volume" or self.model_type == "combined": + volume_coordinates = data_dict["volume_mesh_centers"] + volume_fields = data_dict["volume_fields"] + + if not self.compute_scaling_factors: + if self.bounding_box_dims is None: + c_max = s_max + (s_max - s_min) / 2 + c_min = s_min - (s_max - s_min) / 2 + c_min[2] = s_min[2] + else: + c_max = np.float32(self.bounding_box_dims[0]) + c_min = np.float32(self.bounding_box_dims[1]) + + ids_in_bbox = np.where( + (volume_coordinates[:, 0] > c_min[0]) + & (volume_coordinates[:, 0] < c_max[0]) + & (volume_coordinates[:, 1] > c_min[1]) + & (volume_coordinates[:, 1] < c_max[1]) + & (volume_coordinates[:, 2] > c_min[2]) + & (volume_coordinates[:, 2] < c_max[2]) + ) + + if self.sample_in_bbox: + volume_coordinates = volume_coordinates[ids_in_bbox] + volume_fields = volume_fields[ids_in_bbox] + + dx, dy, dz = ( + (c_max[0] - c_min[0]) / nx, + (c_max[1] - c_min[1]) / ny, + (c_max[2] - c_min[2]) / nz, + ) + + # Generate a grid of specified resolution to map the bounding box + # The grid is used for capturing structured geometry features and SDF representation of geometry + grid = create_grid(c_max, c_min, [nx, ny, nz]) + grid_reshaped = grid.reshape(nx * ny * nz, 3) + + # SDF calculation on the grid using WARP + sdf_grid = ( + signed_distance_field( + stl_vertices, + mesh_indices_flattened, + grid_reshaped, + use_sign_winding_number=True, + ) + .numpy() + .reshape(nx, ny, nz) + ) + + if self.sampling: + volume_coordinates_sampled, idx_volume = shuffle_array( + volume_coordinates, self.volume_points + ) + if volume_coordinates_sampled.shape[0] < self.volume_points: + volume_coordinates_sampled = pad( + volume_coordinates_sampled, + self.volume_points, + pad_value=-10.0, + ) + volume_fields = volume_fields[idx_volume] + volume_coordinates = volume_coordinates_sampled + + sdf_nodes, sdf_node_closest_point = signed_distance_field( + stl_vertices, + mesh_indices_flattened, + volume_coordinates, + include_hit_points=True, + use_sign_winding_number=True, + ) + sdf_nodes = sdf_nodes.numpy().reshape(-1, 1) + sdf_node_closest_point = sdf_node_closest_point.numpy() + + if self.positional_encoding: + pos_normals_closest_vol = cal_normal_positional_encoding( + volume_coordinates, + sdf_node_closest_point, + cell_length=[dx, dy, dz], + ) + pos_normals_com_vol = cal_normal_positional_encoding( + volume_coordinates, center_of_mass, cell_length=[dx, dy, dz] + ) + else: + pos_normals_closest_vol = ( + volume_coordinates - sdf_node_closest_point + ) + pos_normals_com_vol = volume_coordinates - center_of_mass + + if self.normalize_coordinates: + volume_coordinates = normalize(volume_coordinates, c_max, c_min) + grid = normalize(grid, c_max, c_min) + + if self.scaling_type is not None: + if self.volume_factors is not None: + if self.scaling_type == "mean_std_scaling": + vol_mean = self.volume_factors[0] + vol_std = self.volume_factors[1] + volume_fields = standardize( + volume_fields, vol_mean, vol_std + ) + elif self.scaling_type == "min_max_scaling": + vol_min = self.volume_factors[1] + vol_max = self.volume_factors[0] + volume_fields = normalize(volume_fields, vol_max, vol_min) + + volume_fields = np.float32(volume_fields) + pos_normals_closest_vol = np.float32(pos_normals_closest_vol) + pos_normals_com_vol = np.float32(pos_normals_com_vol) + volume_coordinates = np.float32(volume_coordinates) + sdf_nodes = np.float32(sdf_nodes) + sdf_grid = np.float32(sdf_grid) + grid = np.float32(grid) + vol_grid_max_min = np.float32(np.asarray([c_min, c_max])) + else: + pos_normals_closest_vol = None + pos_normals_com_vol = None + sdf_nodes = None + sdf_grid = None + grid = None + vol_grid_max_min = None + + else: + volume_coordinates = None + volume_fields = None + pos_normals_closest_vol = None + pos_normals_com_vol = None + sdf_nodes = None + sdf_grid = None + grid = None + vol_grid_max_min = None + + if self.model_type == "surface" or self.model_type == "combined": + surface_coordinates = data_dict["surface_mesh_centers"] + surface_normals = data_dict["surface_normals"] + surface_sizes = data_dict["surface_areas"] + surface_fields = data_dict["surface_fields"] + + if not self.compute_scaling_factors: + + c_max = np.float32(self.bounding_box_dims[0]) + c_min = np.float32(self.bounding_box_dims[1]) + + ids_in_bbox = np.where( + (surface_coordinates[:, 0] > c_min[0]) + & (surface_coordinates[:, 0] < c_max[0]) + & (surface_coordinates[:, 1] > c_min[1]) + & (surface_coordinates[:, 1] < c_max[1]) + & (surface_coordinates[:, 2] > c_min[2]) + & (surface_coordinates[:, 2] < c_max[2]) + ) + surface_coordinates = surface_coordinates[ids_in_bbox] + surface_normals = surface_normals[ids_in_bbox] + surface_sizes = surface_sizes[ids_in_bbox] + surface_fields = surface_fields[ids_in_bbox] + + # Get neighbors + interp_func = KDTree(surface_coordinates) + dd, ii = interp_func.query( + surface_coordinates, k=self.num_surface_neighbors + ) + surface_neighbors = surface_coordinates[ii] + surface_neighbors = surface_neighbors[:, 1:] + + surface_neighbors_normals = surface_normals[ii] + surface_neighbors_normals = surface_neighbors_normals[:, 1:] + surface_neighbors_sizes = surface_sizes[ii] + surface_neighbors_sizes = surface_neighbors_sizes[:, 1:] + + dx, dy, dz = ( + (s_max[0] - s_min[0]) / nx, + (s_max[1] - s_min[1]) / ny, + (s_max[2] - s_min[2]) / nz, + ) + + if self.positional_encoding: + pos_normals_com_surface = cal_normal_positional_encoding( + surface_coordinates, center_of_mass, cell_length=[dx, dy, dz] + ) + else: + pos_normals_com_surface = surface_coordinates - center_of_mass + + if self.normalize_coordinates: + surface_coordinates = normalize(surface_coordinates, s_max, s_min) + surface_neighbors = normalize(surface_neighbors, s_max, s_min) + surf_grid = normalize(surf_grid, s_max, s_min) + + if self.sampling: + ( + surface_coordinates_sampled, + idx_surface, + ) = area_weighted_shuffle_array( + surface_coordinates, self.surface_points, surface_sizes + ) + if surface_coordinates_sampled.shape[0] < self.surface_points: + surface_coordinates_sampled = pad( + surface_coordinates_sampled, + self.surface_points, + pad_value=-10.0, + ) + + surface_fields = surface_fields[idx_surface] + pos_normals_com_surface = pos_normals_com_surface[idx_surface] + surface_normals = surface_normals[idx_surface] + surface_sizes = surface_sizes[idx_surface] + surface_neighbors = surface_neighbors[idx_surface] + surface_neighbors_normals = surface_neighbors_normals[idx_surface] + surface_neighbors_sizes = surface_neighbors_sizes[idx_surface] + surface_coordinates = surface_coordinates_sampled + + if self.scaling_type is not None: + if self.surface_factors is not None: + if self.scaling_type == "mean_std_scaling": + surf_mean = self.surface_factors[0] + surf_std = self.surface_factors[1] + surface_fields = standardize( + surface_fields, surf_mean, surf_std + ) + elif self.scaling_type == "min_max_scaling": + surf_min = self.surface_factors[1] + surf_max = self.surface_factors[0] + surface_fields = normalize( + surface_fields, surf_max, surf_min + ) + + surface_coordinates = np.float32(surface_coordinates) + surface_fields = np.float32(surface_fields) + surface_sizes = np.float32(surface_sizes) + surface_normals = np.float32(surface_normals) + surface_neighbors = np.float32(surface_neighbors) + surface_neighbors_normals = np.float32(surface_neighbors_normals) + surface_neighbors_sizes = np.float32(surface_neighbors_sizes) + pos_normals_com_surface = np.float32(pos_normals_com_surface) + else: + surface_sizes = None + surface_normals = None + surface_neighbors = None + surface_neighbors_normals = None + surface_neighbors_sizes = None + pos_normals_com_surface = None + + else: + surface_coordinates = None + surface_fields = None + surface_sizes = None + surface_normals = None + surface_neighbors = None + surface_neighbors_normals = None + surface_neighbors_sizes = None + pos_normals_com_surface = None + + if self.sampling: + geometry_points = self.geom_points_sample + geometry_coordinates_sampled, idx_geometry = shuffle_array( + stl_vertices, geometry_points + ) + if geometry_coordinates_sampled.shape[0] < geometry_points: + geometry_coordinates_sampled = pad( + geometry_coordinates_sampled, geometry_points, pad_value=-100.0 + ) + geom_centers = geometry_coordinates_sampled + else: + geom_centers = stl_vertices + + geom_centers = np.float32(geom_centers) + + if self.model_type == "combined": + # Add the parameters to the dictionary + return { + "pos_volume_closest": pos_normals_closest_vol, + "pos_volume_center_of_mass": pos_normals_com_vol, + "pos_surface_center_of_mass": pos_normals_com_surface, + "geometry_coordinates": geom_centers, + "grid": grid, + "surf_grid": surf_grid, + "sdf_grid": sdf_grid, + "sdf_surf_grid": sdf_surf_grid, + "sdf_nodes": sdf_nodes, + "surface_mesh_centers": surface_coordinates, + "surface_mesh_neighbors": surface_neighbors, + "surface_normals": surface_normals, + "surface_neighbors_normals": surface_neighbors_normals, + "surface_areas": surface_sizes, + "surface_neighbors_areas": surface_neighbors_sizes, + "volume_fields": volume_fields, + "volume_mesh_centers": volume_coordinates, + "surface_fields": surface_fields, + "volume_min_max": vol_grid_max_min, + "surface_min_max": surf_grid_max_min, + "length_scale": length_scale, + "stream_velocity": np.expand_dims( + np.array(STREAM_VELOCITY, dtype=np.float32), -1 + ), + "air_density": np.expand_dims( + np.array(AIR_DENSITY, dtype=np.float32), -1 + ), + } + elif self.model_type == "surface": + return { + "pos_surface_center_of_mass": pos_normals_com_surface, + "geometry_coordinates": geom_centers, + "surf_grid": surf_grid, + "sdf_surf_grid": sdf_surf_grid, + "surface_mesh_centers": surface_coordinates, + "surface_mesh_neighbors": surface_neighbors, + "surface_normals": surface_normals, + "surface_neighbors_normals": surface_neighbors_normals, + "surface_areas": surface_sizes, + "surface_neighbors_areas": surface_neighbors_sizes, + "surface_fields": surface_fields, + "surface_min_max": surf_grid_max_min, + "length_scale": length_scale, + "stream_velocity": np.expand_dims( + np.array(STREAM_VELOCITY, dtype=np.float32), -1 + ), + "air_density": np.expand_dims( + np.array(AIR_DENSITY, dtype=np.float32), -1 + ), + } + elif self.model_type == "volume": + return { + "pos_volume_closest": pos_normals_closest_vol, + "pos_volume_center_of_mass": pos_normals_com_vol, + "geometry_coordinates": geom_centers, + "grid": grid, + "surf_grid": surf_grid, + "sdf_grid": sdf_grid, + "sdf_surf_grid": sdf_surf_grid, + "sdf_nodes": sdf_nodes, + "volume_fields": volume_fields, + "volume_mesh_centers": volume_coordinates, + "volume_min_max": vol_grid_max_min, + "surface_min_max": surf_grid_max_min, + "length_scale": length_scale, + "stream_velocity": np.expand_dims( + np.array(STREAM_VELOCITY, dtype=np.float32), -1 + ), + "air_density": np.expand_dims( + np.array(AIR_DENSITY, dtype=np.float32), -1 + ), + } + + +class DriveSimPaths: + @staticmethod + def geometry_path(car_dir: Path) -> Path: + return car_dir / "body.stl" + + @staticmethod + def volume_path(car_dir: Path) -> Path: + return car_dir / "VTK/simpleFoam_steady_3000/internal.vtu" + + @staticmethod + def surface_path(car_dir: Path) -> Path: + return car_dir / "VTK/simpleFoam_steady_3000/boundary/aero_suv.vtp" + + +class DrivAerAwsPaths: + @staticmethod + def _get_index(car_dir: Path) -> str: + return car_dir.name.removeprefix("run_") + + @staticmethod + def geometry_path(car_dir: Path) -> Path: + return car_dir / f"drivaer_{DrivAerAwsPaths._get_index(car_dir)}.stl" + + @staticmethod + def volume_path(car_dir: Path) -> Path: + return car_dir / f"volume_{DrivAerAwsPaths._get_index(car_dir)}.vtu" + + @staticmethod + def surface_path(car_dir: Path) -> Path: + return car_dir / f"boundary_{DrivAerAwsPaths._get_index(car_dir)}.vtp" + + +class OpenFoamDataset(Dataset): + """ + Datapipe for converting openfoam dataset to npy + + """ + + def __init__( + self, + data_path: Union[str, Path], + kind: Literal["drivesim", "drivaer_aws"] = "drivesim", + surface_variables: Optional[list] = [ + "pMean", + "wallShearStress", + ], + volume_variables: Optional[list] = ["UMean", "pMean"], + device: int = 0, + model_type=None, + ): + if isinstance(data_path, str): + data_path = Path(data_path) + data_path = data_path.expanduser() + + self.data_path = data_path + + supported_kinds = ["drivesim", "drivaer_aws"] + assert ( + kind in supported_kinds + ), f"kind should be one of {supported_kinds}, got {kind}" + self.path_getter = DriveSimPaths if kind == "drivesim" else DrivAerAwsPaths + + assert self.data_path.exists(), f"Path {self.data_path} does not exist" + + assert self.data_path.is_dir(), f"Path {self.data_path} is not a directory" + + self.filenames = get_filenames(self.data_path) + random.shuffle(self.filenames) + self.indices = np.array(len(self.filenames)) + + self.surface_variables = surface_variables + self.volume_variables = volume_variables + self.device = device + self.model_type = model_type + + def __len__(self): + return len(self.filenames) + + def __getitem__(self, idx): + cfd_filename = self.filenames[idx] + car_dir = self.data_path / cfd_filename + + stl_path = self.path_getter.geometry_path(car_dir) + reader = pv.get_reader(stl_path) + mesh_stl = reader.read() + stl_vertices = mesh_stl.points + stl_faces = np.array(mesh_stl.faces).reshape((-1, 4))[ + :, 1: + ] # Assuming triangular elements + mesh_indices_flattened = stl_faces.flatten() + stl_sizes = mesh_stl.compute_cell_sizes(length=False, area=True, volume=False) + stl_sizes = np.array(stl_sizes.cell_data["Area"]) + stl_centers = np.array(mesh_stl.cell_centers().points) + + length_scale = np.amax(np.amax(stl_vertices, 0) - np.amin(stl_vertices, 0)) + + if self.model_type == "volume" or self.model_type == "combined": + filepath = self.path_getter.volume_path(car_dir) + reader = vtk.vtkXMLUnstructuredGridReader() + reader.SetFileName(filepath) + reader.Update() + + # Get the unstructured grid data + polydata = reader.GetOutput() + volume_coordinates, volume_fields = get_volume_data( + polydata, self.volume_variables + ) + volume_fields = np.concatenate(volume_fields, axis=-1) + + # Non-dimensionalize volume fields + volume_fields[:, :3] = volume_fields[:, :3] / STREAM_VELOCITY + volume_fields[:, 3:4] = volume_fields[:, 3:4] / ( + AIR_DENSITY * STREAM_VELOCITY**2.0 + ) + + volume_fields[:, 4:] = volume_fields[:, 4:] / ( + STREAM_VELOCITY * length_scale + ) + else: + volume_fields = None + volume_coordinates = None + + if self.model_type == "surface" or self.model_type == "combined": + surface_filepath = self.path_getter.surface_path(car_dir) + reader = vtk.vtkXMLPolyDataReader() + reader.SetFileName(surface_filepath) + reader.Update() + polydata = reader.GetOutput() + + celldata_all = get_node_to_elem(polydata) + celldata = celldata_all.GetCellData() + surface_fields = get_fields(celldata, self.surface_variables) + surface_fields = np.concatenate(surface_fields, axis=-1) + + mesh = pv.PolyData(polydata) + surface_coordinates = np.array(mesh.cell_centers().points) + + surface_normals = np.array(mesh.cell_normals) + surface_sizes = mesh.compute_cell_sizes( + length=False, area=True, volume=False + ) + surface_sizes = np.array(surface_sizes.cell_data["Area"]) + + # Normalize cell normals + surface_normals = ( + surface_normals / np.linalg.norm(surface_normals, axis=1)[:, np.newaxis] + ) + + # Non-dimensionalize surface fields + surface_fields = surface_fields / (AIR_DENSITY * STREAM_VELOCITY**2.0) + else: + surface_fields = None + surface_coordinates = None + surface_normals = None + surface_sizes = None + + # Add the parameters to the dictionary + return { + "stl_coordinates": np.float32(stl_vertices), + "stl_centers": np.float32(stl_centers), + "stl_faces": np.float32(mesh_indices_flattened), + "stl_areas": np.float32(stl_sizes), + "surface_mesh_centers": np.float32(surface_coordinates), + "surface_normals": np.float32(surface_normals), + "surface_areas": np.float32(surface_sizes), + "volume_fields": np.float32(volume_fields), + "volume_mesh_centers": np.float32(volume_coordinates), + "surface_fields": np.float32(surface_fields), + "filename": cfd_filename, + "stream_velocity": STREAM_VELOCITY, + "air_density": AIR_DENSITY, + } diff --git a/ppsci/data/process/__init__.py b/ppsci/data/process/__init__.py index f46c8dd9cf..d3272a4730 100644 --- a/ppsci/data/process/__init__.py +++ b/ppsci/data/process/__init__.py @@ -13,9 +13,11 @@ # limitations under the License. from ppsci.data.process import batch_transform +from ppsci.data.process import openfoam from ppsci.data.process import transform __all__ = [ "batch_transform", "transform", + "openfoam", ] diff --git a/ppsci/data/process/openfoam/__init__.py b/ppsci/data/process/openfoam/__init__.py new file mode 100644 index 0000000000..66c1919070 --- /dev/null +++ b/ppsci/data/process/openfoam/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# refs: https://github.com/NVIDIA/physicsnemo/tree/main/examples/cfd/external_aerodynamics/domino + +from ppsci.data.process.openfoam.preprocess import process_files + +__all__ = [ + "process_files", +] diff --git a/ppsci/data/process/openfoam/preprocess.py b/ppsci/data/process/openfoam/preprocess.py new file mode 100644 index 0000000000..573674d879 --- /dev/null +++ b/ppsci/data/process/openfoam/preprocess.py @@ -0,0 +1,42 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# refs: https://github.com/NVIDIA/physicsnemo/tree/main/examples/cfd/external_aerodynamics/domino + +import os +import time + +import numpy as np + + +def process_files(*args_list): + ids = args_list[0] + processor_id = args_list[1] + fm_data = args_list[2] + output_dir = args_list[3] + for j in ids: + fname = fm_data.filenames[j] + if len(os.listdir(os.path.join(fm_data.data_path, fname))) == 0: + print(f"Skipping {fname} - empty.") + continue + outname = os.path.join(output_dir, fname) + print("Filename:%s on processor: %d" % (outname, processor_id)) + filename = f"{outname}.npy" + if os.path.exists(filename): + print(f"Skipping {filename} - already exists.") + continue + start_time = time.time() + data_dict = fm_data[j] + np.save(filename, data_dict) + print("Time taken for %d = %f" % (j, time.time() - start_time)) diff --git a/ppsci/utils/__init__.py b/ppsci/utils/__init__.py index 3382eee856..bc528205a6 100644 --- a/ppsci/utils/__init__.py +++ b/ppsci/utils/__init__.py @@ -35,6 +35,7 @@ from ppsci.utils.save_load import load_checkpoint from ppsci.utils.save_load import load_pretrain from ppsci.utils.save_load import save_checkpoint +from ppsci.utils.sdf import signed_distance_field from ppsci.utils.symbolic import lambdify from ppsci.utils.writer import save_csv_file from ppsci.utils.writer import save_tecplot_file @@ -63,4 +64,5 @@ "load_pretrain", "save_checkpoint", "lambdify", + "signed_distance_field", ] diff --git a/ppsci/utils/sdf.py b/ppsci/utils/sdf.py new file mode 100644 index 0000000000..e789cdfe33 --- /dev/null +++ b/ppsci/utils/sdf.py @@ -0,0 +1,140 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# refs: https://github.com/NVIDIA/physicsnemo/tree/main/examples/cfd/external_aerodynamics/domino + +import warp as wp +from numpy.typing import NDArray + + +@wp.kernel +def _bvh_query_distance( + mesh: wp.uint64, + points: wp.array(dtype=wp.vec3f), + max_dist: wp.float32, + sdf: wp.array(dtype=wp.float32), + sdf_hit_point: wp.array(dtype=wp.vec3f), + sdf_hit_point_id: wp.array(dtype=wp.int32), + use_sign_winding_number: bool = False, +): + + """ + Computes the signed distance from each point in the given array `points` + to the mesh represented by `mesh`,within the maximum distance `max_dist`, + and stores the result in the array `sdf`. + + Parameters: + mesh (wp.uint64): The identifier of the mesh. + points (wp.array): An array of 3D points for which to compute the + signed distance. + max_dist (wp.float32): The maximum distance within which to search + for the closest point on the mesh. + sdf (wp.array): An array to store the computed signed distances. + sdf_hit_point (wp.array): An array to store the computed hit points. + sdf_hit_point_id (wp.array): An array to store the computed hit point ids. + use_sign_winding_number (bool): Flag to use sign_winding_number method for SDF. + + Returns: + None + """ + tid = wp.tid() + + if use_sign_winding_number: + res = wp.mesh_query_point_sign_winding_number(mesh, points[tid], max_dist) + else: + res = wp.mesh_query_point_sign_normal(mesh, points[tid], max_dist) + + mesh_ = wp.mesh_get(mesh) + + p0 = mesh_.points[mesh_.indices[3 * res.face + 0]] + p1 = mesh_.points[mesh_.indices[3 * res.face + 1]] + p2 = mesh_.points[mesh_.indices[3 * res.face + 2]] + + p_closest = res.u * p0 + res.v * p1 + (1.0 - res.u - res.v) * p2 + + sdf[tid] = res.sign * wp.abs(wp.length(points[tid] - p_closest)) + sdf_hit_point[tid] = p_closest + sdf_hit_point_id[tid] = res.face + + +def signed_distance_field( + mesh_vertices: list[tuple[float, float, float]], + mesh_indices: NDArray[float], + input_points: list[tuple[float, float, float]], + max_dist: float = 1e8, + include_hit_points: bool = False, + include_hit_points_id: bool = False, + use_sign_winding_number: bool = False, +) -> wp.array: + """ + Computes the signed distance field (SDF) for a given mesh and input points. + + Parameters: + ---------- + mesh_vertices (list[tuple[float, float, float]]): List of vertices defining the mesh. + mesh_indices (list[tuple[int, int, int]]): List of indices defining the triangles of the mesh. + input_points (list[tuple[float, float, float]]): List of input points for which to compute the SDF. + max_dist (float, optional): Maximum distance within which to search for + the closest point on the mesh. Default is 1e8. + include_hit_points (bool, optional): Whether to include hit points in + the output. Default is False. + include_hit_points_id (bool, optional): Whether to include hit point + IDs in the output. Default is False. + + Returns: + ------- + wp.array: An array containing the computed signed distance field. + + Example: + ------- + >>> mesh_vertices = [(0, 0, 0), (1, 0, 0), (0, 1, 0)] + >>> mesh_indices = np.array((0, 1, 2)) + >>> input_points = [(0.5, 0.5, 0.5)] + >>> signed_distance_field(mesh_vertices, mesh_indices, input_points).numpy() + Module ... + array([0.5], dtype=float32) + """ + + wp.init() + mesh = wp.Mesh( + wp.array(mesh_vertices, dtype=wp.vec3), wp.array(mesh_indices, dtype=wp.int32) + ) + + sdf_points = wp.array(input_points, dtype=wp.vec3) + sdf = wp.zeros(shape=sdf_points.shape, dtype=wp.float32) + sdf_hit_point = wp.zeros(shape=sdf_points.shape, dtype=wp.vec3f) + sdf_hit_point_id = wp.zeros(shape=sdf_points.shape, dtype=wp.int32) + + wp.launch( + kernel=_bvh_query_distance, + dim=len(sdf_points), + inputs=[ + mesh.id, + sdf_points, + max_dist, + sdf, + sdf_hit_point, + sdf_hit_point_id, + use_sign_winding_number, + ], + ) + + if include_hit_points and include_hit_points_id: + return (sdf, sdf_hit_point, sdf_hit_point_id) + elif include_hit_points: + return (sdf, sdf_hit_point) + elif include_hit_points_id: + return (sdf, sdf_hit_point_id) + else: + return sdf diff --git a/requirements.txt b/requirements.txt index 7efcb16d5f..0858a3520e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,11 +3,13 @@ einops h5py hydra-core imageio +importlib_metadata matplotlib meshio==5.3.4 numpy>=1.20.0,<2.0.0 pydantic>=2.5.0 pyevtk +pyvista==0.34.2 pyyaml requests scikit-learn<1.5.0 @@ -15,6 +17,9 @@ scikit-optimize scipy seaborn sympy +termcolor tqdm +treelib typing-extensions +warp-lang wget