diff --git a/README.md b/README.md index 8e9ec8912..12410f9ec 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,7 @@ PaddleScience 是一个基于深度学习框架 PaddlePaddle 开发的科学计 | 微分方程 | [若斯叻方程](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/rossler) | 数据驱动 | Transformer-Physx | 监督学习 | [Data](https://github.com/zabaras/transformer-physx) | [Paper](https://arxiv.org/abs/2010.03957) | | 算子学习 | [DeepONet](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/deeponet) | 数据驱动 | MLP | 监督学习 | [Data](https://deepxde.readthedocs.io/en/latest/demos/operator/antiderivative_unaligned.html) | [Paper](https://export.arxiv.org/pdf/1910.03193.pdf) | | 微分方程 | [梯度增强的物理知识融合 PDE 求解](https://github.com/PaddlePaddle/PaddleScience/blob/develop/examples/gpinn/poisson_1d.py) | 机理驱动 | gPINN | 无监督学习 | - | [Paper](https://doi.org/10.1016/j.cma.2022.114823) | +| 微分方程 | [PDE 求解](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/data_efficient_nopt) | 数据驱动 | FNO/Transformer | 无监督学习 | - | [Paper](https://arxiv.org/abs/2402.15734) | | 积分方程 | [沃尔泰拉积分方程](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/volterra_ide) | 机理驱动 | MLP | 无监督学习 | - | [Project](https://github.com/lululxvi/deepxde/blob/master/examples/pinn_forward/Volterra_IDE.py) | | 微分方程 | [分数阶微分方程](https://github.com/PaddlePaddle/PaddleScience/blob/develop/examples/fpde/fractional_poisson_2d.py) | 机理驱动 | MLP | 无监督学习 | - | - | | 光孤子 | [Optical soliton](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/nlsmb) | 机理驱动 | MLP | 无监督学习 | - | [Paper](https://doi.org/10.1007/s11071-023-08824-w)| diff --git a/docs/index.md b/docs/index.md index 4a9e43aa5..3fb6f4b4f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -83,6 +83,7 @@ | 微分方程 | [若斯叻方程](./zh/examples/rossler.md) | 数据驱动 | Transformer-Physx | 监督学习 | [Data](https://github.com/zabaras/transformer-physx) | [Paper](https://arxiv.org/abs/2010.03957) | | 算子学习 | [DeepONet](./zh/examples/deeponet.md) | 数据驱动 | MLP | 监督学习 | [Data](https://deepxde.readthedocs.io/en/latest/demos/operator/antiderivative_unaligned.html) | [Paper](https://export.arxiv.org/pdf/1910.03193.pdf) | | 微分方程 | [梯度增强的物理知识融合 PDE 求解](https://github.com/PaddlePaddle/PaddleScience/blob/develop/examples/gpinn/poisson_1d.py) | 机理驱动 | gPINN | 无监督学习 | - | [Paper](https://doi.org/10.1016/j.cma.2022.114823) | +| 微分方程 | [PDE 求解](./zh/examples/data_efficient_nopt.md) | 数据驱动 | FNO/Transformer | 无监督学习 | - | [Paper](https://arxiv.org/abs/2402.15734) | | 积分方程 | [沃尔泰拉积分方程](./zh/examples/volterra_ide.md) | 机理驱动 | MLP | 无监督学习 | - | [Project](https://github.com/lululxvi/deepxde/blob/master/examples/pinn_forward/Volterra_IDE.py) | | 微分方程 | [分数阶微分方程](https://github.com/PaddlePaddle/PaddleScience/blob/develop/examples/fpde/fractional_poisson_2d.py) | 机理驱动 | MLP | 无监督学习 | - | - | | 光孤子 | [Optical soliton](./zh/examples/nlsmb.md) | 机理驱动 | MLP | 无监督学习 | - | [Paper](https://doi.org/10.1007/s11071-023-08824-w)| diff --git a/docs/zh/examples/data_efficient_nopt.md b/docs/zh/examples/data_efficient_nopt.md new file mode 100644 index 000000000..8fad20c63 --- /dev/null +++ b/docs/zh/examples/data_efficient_nopt.md @@ -0,0 +1,266 @@ +# DataEfficientNopt + +## 论文信息 + +| 年份 | 会议 | 作者 | 引用数 | 论文 PDF | +| ---- | ---------------------------------------------------------------- | ------------------------------------------------------ | ------ | ----------------------------------------------------------------------------------------------------------------------------- | +| 2024 | 38th Conference on Neural Information Processing Systems (NeurIPS 2024) | 12 | Data-Efficient Operator Learning via Unsupervised Pretraining and In-Context Learning | + +## 代码信息 + +| Model | Checkpoint | **R2** | **Slope** | +| :---------: | :--------------------: | :----: | :-------: | +| FNO_Possion | [finetune_b01_m0_n8192](https://dataset.bj.bcebos.com/PaddleScience/data_efficient_nopt/possion_data/finetune_b01_m0_n8192.pdparams) | 0.9765 | 0.9752 | + +=== "模型训练命令" + + ``` sh + # Download possion_64 data and model into `examples/data_efficient_nopt/data` + cd examples/data_efficient_nopt + mkdir -p data/possion_64 && cd data/possion_64 + wget https://dataset.bj.bcebos.com/PaddleScience/data_efficient_nopt/possion_data/poisson_64_e1_20_train.h5 + wget https://dataset.bj.bcebos.com/PaddleScience/data_efficient_nopt/possion_data/poisson_64_e1_20_val.h5 + wget https://dataset.bj.bcebos.com/PaddleScience/data_efficient_nopt/possion_data/poisson_64_e1_20_test.h5 + wget https://dataset.bj.bcebos.com/PaddleScience/data_efficient_nopt/possion_data/poisson_64_e1_20_train_scale.npy + wget https://dataset.bj.bcebos.com/PaddleScience/data_efficient_nopt/possion_data/poisson_64_e5_15_train.h5 + wget https://dataset.bj.bcebos.com/PaddleScience/data_efficient_nopt/possion_data/poisson_64_e5_15_train_scale.npy + wget https://dataset.bj.bcebos.com/PaddleScience/data_efficient_nopt/possion_data/poisson_64_e5_15_val.h5 + wget https://dataset.bj.bcebos.com/PaddleScience/data_efficient_nopt/possion_data/poisson_64_e5_15_test.h5 + wget https://dataset.bj.bcebos.com/PaddleScience/data_efficient_nopt/possion_data/train_rand_idx.npy + wget https://dataset.bj.bcebos.com/PaddleScience/data_efficient_nopt/possion_data/pretrain_b01_m0.pdparams + + + # pretrain + cd ../.. + python data_efficient_nopt.py \ + --config-name data_efficient_nopt_fno_poisson \ + config=pois-64-pretrain-e1_20_m0 + + # finetune + python data_efficient_nopt.py \ + --config-name data_efficient_nopt_fno_poisson \ + mode=finetune \ + config=pois_64_finetune_e5_15 \ + train_config.pois_64_finetune_e5_15.pretrained_ckpt_path="./data/possion_64/pretrain_b01_m0.pdparams" + ``` + +=== "模型推理命令" + + ``` sh + cd examples/data_efficient_nopt + mkdir -p data/possion_64 && cd data/possion_64 + wget https://dataset.bj.bcebos.com/PaddleScience/data_efficient_nopt/possion_data/finetune_b01_m0_n8192.pdparams + wget https://dataset.bj.bcebos.com/PaddleScience/data_efficient_nopt/possion_data/poisson_64_e15_50_test.h5 + wget https://dataset.bj.bcebos.com/PaddleScience/data_efficient_nopt/possion_data/poisson_64_e15_50_train.h5 + wget https://dataset.bj.bcebos.com/PaddleScience/data_efficient_nopt/possion_data/poisson_64_e5_15_train_scale.npy + + cd ../.. + python data_efficient_nopt.py \ + --config-name=data_efficient_nopt_fno_poisson.yaml \ + mode=infer \ + infer_config.ckpt_path=./data/possion_64/finetune_b01_m0_n8192.pdparams + ``` + +## 1. 背景简介 + +在广阔而深邃的科学与工程领域中,许多最为核心且复杂的问题都绕不开**偏微分方程(PDEs)**的精确求解。这些方程是理解和描述从宏观宇宙现象(如星系演化、天气系统)到微观世界(如量子力学、分子动力学)以及工程应用(如流体力学、结构力学、热传导、电磁学)等各种物理过程与自然规律的基本数学框架。长期以来,传统的数值求解方法,例如**有限元法(Finite Element Method, FEM)、有限体积法(Finite Volume Method, FVM)和有限差分法(Finite Difference Method, FDM)**,一直是求解这些复杂PDEs的主流且不可或缺的手段。然而,这些方法往往面临着**计算成本极其高昂、网格生成复杂、迭代求解耗时、难以处理高维问题以及在复杂几何形状上效率低下**等诸多挑战。它们对计算资源的巨大需求和潜在的数值误差也常常限制了其在大规模和实时应用中的可行性。 + +近年来,随着**机器学习(Machine Learning, ML)**,特别是**深度学习(Deep Learning, DL)**技术的爆炸式发展和在各领域的突破性应用,研究者们开始积极探索将这些前沿的计算范式与深厚的物理领域知识相结合,催生了**科学机器学习(Scientific Machine Learning, SciML)**这一新兴交叉学科。在SciML的众多前沿方向中,**算子学习(Operator Learning)**无疑是一个前景最为广阔且极具颠覆性的研究领域。其核心目标是超越传统神经网络仅能学习有限维度输入到有限维度输出映射的局限,转而学习从**一个函数空间到另一个函数空间**的映射,即直接学习PDEs的**解算子(Solution Operator)**本身。这种对“算子”而非“点”的学习,使得模型能够处理不同分辨率、不同离散化方案甚至连续的函数作为输入和输出,从而赋予了模型前所未有的**更强泛化能力和分辨率无关性**。例如,像**深度算子网络(DeepONet)和傅里叶神经算子(Fourier Neural Operator, FNO)**等模型,正是通过学习这种高维函数映射,展现出强大的潜力,能够高效、准确地解决一系列复杂的PDEs。这些开创性的工作不仅为传统数值方法提供了革命性的新替代方案,更在计算效率和对未见数据的泛化能力上实现了质的飞跃,为各种科学应用开辟了前所未有的新途径。 + +然而,尽管深度学习方法在求解PDEs方面取得了令人瞩目的显著进展,并预示着计算效率的巨大提升,但它们,特别是**神经算子(Neural Operators)**,如同大多数深度学习方法一样,也面临着普遍存在的**数据饥渴(data-hungry)挑战**。为了训练出高性能、高精度的神经算子,使其能够准确捕捉复杂物理现象的非线性关系和多尺度特性,通常需要**极其大量的、高质量且已精确标记(即包含输入参数和对应真实解)的PDE解数据**。而这些关键数据,正如论文中所反复强调的,往往是通过运行**高精度的数值模拟**获得的。这些模拟过程本身就异常**耗时费力且计算资源消耗巨大**,尤其是在处理高维、复杂几何、多物理场耦合或需要长时间积分的物理场景时。论文中生动提及的例子——**模拟一场旧金山7.0级地震的整个过程,便需要投入巨大的计算能力和漫长的计算时间**——这直观地揭示了获取高质量PDE数据所涉及的巨大成本和时间代价。这种对昂贵数据的固有依赖,在一定程度上削弱了机器学习方法最初旨在**避免传统昂贵数值模拟**的初衷。论文发现,自己陷入了一个**看似矛盾的悖论**:为了摆脱高昂的数值模拟负担,论文转向了以深度学习为代表的先进计算范式,但讽刺的是,深度学习反过来又要求论文进行大量的,甚至可能更加耗时和资源密集型的数值模拟来生成其训练数据。这使得数据的收集和生成,而非模型训练本身,成为了应用这些先进机器学习方法的主要**瓶颈和障碍**,严重限制了它们在实际科学和工程领域的普适性和效率。 + +为了**根本性地应对这一在科学机器学习领域日益突出的数据效率和成本挑战**,本研究提出了一种**具有高度创新性的数据高效算子学习方法**。该方法的核心目标是**显著减少**训练神经算子对昂贵数值模拟所产生标记数据的严重依赖。论文的核心策略经过精心设计,包含**两个紧密相连且相辅相成的关键组成部分**,共同构筑了解决上述瓶颈的有效路径。 + +1. **首先,论文设计并提出了一种新颖的**无监督预训练(Unsupervised Pretraining)**方案,旨在从根本上提升神经算子处理数据时的效率和学习能力。** 其核心思想在于**充分利用海量的、相对易于获取的未标记PDE数据**——即使这些数据缺乏对应的精确模拟解。这些未标记数据可能来源于简化模拟、低精度测量、或者仅仅是物理参数的配置集合,它们生成成本远低于完整的标记解。论文通过构建一系列**受物理学启发且基于重建(Reconstruction-Based)的代理任务(Proxy Tasks)**来对神经算子进行预训练。这些代理任务的目标是巧妙地迫使模型从不完整、受损或经过特定扰动的输入中,精确地重建出原始的PDE数据。例如,模型可能被要求恢复输入场中被随机掩码(遮盖)的区域、将低分辨率的物理场提升为高分辨率、或者强制模型输出满足已知的物理守恒律。通过执行这些无监督的重建任务,神经算子被激励和引导去**学习并内化 PDE 底层的物理规律、数据分布的固有特征以及复杂的非线性映射关系**,而无需依赖任何昂贵的监督信号(即精确的模拟解)。这种创新的方法极大地扩展了可用的训练数据范围,因为它允许论文利用大量成本极低、易于获取的未标记数据来为模型提供丰富的先验知识和领域适应性,从而为后续在少量标记数据上的微调奠定了坚实的基础。 + +2. **其次,为了在预训练的基础上进一步显著增强模型在**分布外(Out-of-Distribution, OOD)**场景下的泛化性能和鲁棒性**,论文巧妙地引入了一种**基于相似性的上下文学习(Similarity-Based In-Context Learning)机制**。这种机制允许预训练的神经算子在**推理阶段**,以一种高度灵活且动态的方式,利用少量与当前查询输入在物理特性上高度相似的“上下文示例”(或称“演示样本”)来辅助其进行预测。这些上下文示例通常是从一个小型、高质的标记演示数据集中检索而来,包含与当前查询相似的物理参数输入及其对应的精确解。其关键在于,这种上下文学习机制**无需任何额外的训练成本或复杂的设计**,它是在模型已经完成预训练和/或微调的基础上,在推理时才发挥作用。模型首先计算当前查询输入与预先存储的上下文示例之间的相似度(例如通过高维特征空间的距离度量),然后利用这些被检索到的相似示例所蕴含的丰富信息,动态调整或优化其对当前查询的预测结果。这种非参数化的适应能力使得模型能够有效地适应并处理那些在训练阶段从未见过的、甚至超出其训练数据分布范围的物理条件,极大地提升了模型在复杂真实世界场景中的实用性和可靠性。 + +论文对本框架在多组不同类型的PDEs上进行了广泛而深入的实证评估,涵盖了从线性到非线性的多种物理现象和复杂性,并验证了其在模拟真实世界场景中的有效性。实验结果以无可辩驳的证据有力地证明了论文方法所具备的卓越性能:它展现出惊人的**数据效率**,能够大幅削减对昂贵标记数据的需求,显著降低了数据生成成本;同时,模型通过无监督预训练内化了物理知识,使得其输出**具有更高的可解释性**,其预测结果更符合物理直觉,并能更好地反映底层的物理原理和结构;在处理复杂物理问题时,模型表现出**显著的优越性能**,其预测精度和鲁棒性均达到甚至超越了当前领域的领先水平。尤其值得注意的是,论文的方法在某些关键指标上甚至**超越了在通用图像或视频数据集上预训练的传统视觉预测模型**,这进一步凸显了其在科学计算领域,尤其是在解决当前普遍面临的**数据稀缺性问题**上所蕴含的巨大潜力。因此,本研究的背景旨在直击神经算子在**数据效率方面的核心痛点**,寻求一种能够有效利用海量廉价无标签物理数据的方法,以期**显著减少对昂贵PDE解数据的依赖**,从而最终使神经算子在更广泛的实际科学和工程应用中变得更具**可行性、普适性和成本效益**。· + +## 2. 问题定义 + +### 2.1 数据集呈现 + +数据集分别从[Goole Drive for Helmholtz](https://drive.google.com/drive/folders/1UjIaF6FsjmN_xlGGSUX-1K2V3EF2Zalw)和[Google Drive for Poisson](https://drive.google.com/drive/folders/1crIsTZGxZULWhrXkwGDiWF33W6RHxJkf),目前代码提供FNO_Poisson和FNO_Helmholtz网络的预训练和推理。 + +- 对Poisson方程,论文使用扩散特征值[1, 20]进行预训练,[5, 15]进行微调。 +- 对Helmholtz方程,论文使用扩散特征值[1, 20]进行预训练,[5, 15]进行微调。 + +### 2.2 核心问题 + +在诸多至关重要的科学计算领域,例如复杂的气候与天气预报模型、前沿的材料科学研究、精密的生物医学模拟以及基础物理学的各项探索等,**偏微分方程(PDEs)**始终扮演着不可或缺的核心角色。它们不仅仅是抽象的数学工具,更是描述从微观粒子行为到宏观宇宙演变等各种自然现象的根本性语言。长期以来,诸如**有限元法(Finite Element Method, FEM)、有限体积法(Finite Volume Method, FVM)和有限差分法(Finite Difference Method, FDM)**等传统的数值求解方法,一直是解决这些复杂方程的主流且被广泛信赖的手段。然而,这些经典方法在面对高维问题、复杂几何边界或需要高精度解析时,往往会遭遇**计算成本高昂、数值稳定性挑战以及效率相对较低**等重重瓶颈。其固有的离散化误差和迭代收敛速度也常常限制了它们的实际应用范围和效率。 + +近期,伴随着**深度学习技术**的迅猛发展和在各领域的成功应用,研究者们开始积极探索利用神经网络这一强大的工具来近似求解PDEs。这不仅为传统数值方法提供了革命性的新替代方案,更在计算效率和泛化能力上实现了显著的飞跃。例如,**深度算子网络(DeepONet)**和**傅里叶神经算子(Fourier Neural Operator, FNO)**等开创性模型,突破了传统神经网络仅学习点对点映射的局限,转而致力于学习从**输入函数空间到输出函数空间的映射**,即直接学习PDEs的**解算子(solution operator)**本身。这种范式转变使得这些模型能够以惊人的效率和精度解决一系列复杂的PDEs,它们能够快速地对新的输入函数(如不同的初始条件或物理参数)给出预测,而无需重新进行耗时的数值模拟。这些具有前瞻性的工作无疑为基于PDE的科学计算领域带来了深刻的革新,为从基础科学研究到工程应用等各种科学领域开辟了前所未有的新途径。 + +尽管深度学习方法在求解PDEs方面取得了令人瞩目的显著进展,并且在理论上预示着计算效率的巨大提升,但其成功的实现和高性能的维持,在很大程度上仍然**高度依赖于获得大量高质量、且已进行精确标记(即带有对应真实解)的训练数据**。这些宝贵的标记数据,通常并非直接从物理实验中采集,而是通过运行**传统的高精度数值模拟(如利用有限元方法对PDE进行耗时的离散化和求解)**来生成。然而,这些复杂的数值模拟过程本身就极其**耗时、费力且对计算资源消耗巨大**,特别是在处理高维、多物理场耦合或涉及复杂非线性的物理场景时,往往需要动用超级计算机进行数天甚至数周的运算。这种对大规模昂贵标记数据的固有依赖,在一定程度上反而**削弱了深度学习方法最初旨在降低计算成本的核心优势**。简而言之,论文在实践中发现自己陷入了一个**看似矛盾的悖论**:为了避免进行高昂的数值模拟,论文转向了深度学习这一先进工具,但深度学习反过来又要求论文进行大量的,甚至可能更加繁重的数值模拟来生成其所需的训练数据。这种循环使得**数据的收集和生成**成为应用这些先进机器学习方法,特别是深度算子学习方法时所面临的**主要瓶颈和限制因素**。 + +为了**根本性地应对这一数据效率和成本的严峻挑战**,论文提出了一种**具有高度创新性的数据高效算子学习方法**。该方法旨在从根本上**显著减少**训练过程中对昂贵数值模拟所产生标记数据的依赖。论文的核心策略经过精心设计,包含**两个紧密相连且相辅相成的关键组成部分**:**无监督预训练(Unsupervised Pretraining)**和**上下文学习(In-Context Learning)**,二者共同构筑了解决上述瓶颈的有效路径。 + +1. **首先,论文设计并提出了一种新颖的无监督预训练方案,旨在显著提升神经算子处理数据时的效率。** 该方案的核心思想是**充分利用海量的、相对易于获取的未标记PDE数据**。这些数据的一个关键特征是它们**没有对应精确的模拟解**,这意味着它们无需经过耗时的数值模拟过程,例如可以是从物理系统传感器实时采集的原始参数数据,或是通过低精度、快速模拟得到的近似物理场。论文通过构建**基于物理启发的重建代理任务(Physics-Inspired Reconstruction Proxy Tasks)**来对神经算子进行预训练。这些代理任务的目标是巧妙地迫使模型从不完整、受损或经过特定扰动的输入中,精确地重建出原始的PDE数据。例如,模型可能被要求恢复被掩码的区域,或者将低分辨率的输入提升为高分辨率。通过执行这些重建任务,模型被激励和引导去**学习并内化 PDE 底层的物理规律和数据分布的固有特征**,而无需依赖任何昂贵的监督信号(即精确的模拟解)。这种创新的方法极大地扩展了可用的训练数据范围,因为它允许论文利用大量成本极低、易于获取的未标记数据来为模型提供丰富的先验知识和领域适应性。 + +2. **其次,为了进一步显著提升模型在**分布外(Out-of-Distribution, OOD)**场景下的泛化性能和适应性**,论文巧妙地引入了一种**基于相似性的上下文学习机制(Similarity-Based In-Context Learning Mechanism)**。这种机制允许预训练的神经算子在**推理阶段**,以一种高度灵活且动态的方式,利用少量与当前查询输入在物理特性上高度相似的“上下文示例”来辅助其进行预测。这些上下文示例通常包含来自训练数据集中与当前查询输入相似的物理参数以及它们对应的精确解。其关键在于,这种上下文学习机制**无需额外的训练成本或复杂的设计**,它是在模型已经完成预训练和/或微调的基础上,在推理时才发挥作用。具体实现上,模型会首先计算当前查询输入与预先存储的上下文示例之间的相似度(例如通过嵌入空间的距离),然后利用这些被检索到的相似示例所蕴含的丰富信息,来动态调整或优化其对当前查询的预测结果。这种机制使得模型能够有效地适应并处理那些在训练阶段从未见过的、甚至超出其训练数据分布范围的物理条件,从而大幅提升了其在复杂真实世界场景中的实用性和可靠性。 + +论文对本框架在多组不同类型的PDEs上进行了广泛而深入的实验评估,涵盖了多种物理现象和复杂性。实验结果以无可辩驳的证据有力地证明了论文方法所具备的卓越性能。具体而言,该框架展现出惊人的**数据效率**,能够大幅削减对昂贵标记数据的需求;其**更高的可解释性**体现在模型学习到的表示能够更好地反映底层的物理原理和结构,使模型的决策过程更具物理意义;同时,在处理复杂物理问题时,模型表现出**显著的优越性能**,其预测精度和鲁棒性均达到业界领先水平。值得特别注意的是,论文的方法在某些关键指标上甚至**超越了专门用于视觉预测的传统模型**,这进一步凸显了其在科学计算领域,尤其是在解决当前普遍面临的**数据稀缺性问题**上所蕴含的巨大潜力,为未来科学人工智能的发展提供了新的范式和方向。 + +### 3. 问题求解 + +论文详尽阐述了为提升算子学习数据效率所提出的方法。如图1所示,论文的框架融合了两个核心组件:无监督预训练和上下文学习。无监督预训练旨在利用大量未标记数据来学习物理先验知识,从而在下游任务中减少对昂贵标记数据的需求。而上下文学习则通过整合相似的上下文示例,进一步增强模型在分布外(OOD)场景下的泛化能力, + +![fig1](https://dataset.bj.bcebos.com/PaddleScience/data_efficient_nopt/fig1.png) + +#### 3.1 无监督预训练 (Unsupervised Pretraining) + +无监督(或自监督)预训练的核心思想是使用合理设计的代理任务来训练神经网络。这些代任务需要标签数据,但它们被设计得与感兴趣的监督学习目标高度相关。虽然在计算机视觉和自然语言处理领域很受欢迎,但在机器学习的偏微分方程算子学习中,从未探索过对未标记数据的无监督预训练。无监督预训练的目标是在不使用昂贵PDE解数据的情况下,预先训练神经算子以学习物理相关的特征表示,为后续的监督微调提供一个良好的初始化。 + +##### 3.1.1 无标签PDE数据 + +一般来说,当神经算子在偏微分方程(PDE)数据集上进行训练时,它被设计并学习构建一个复杂的映射关系,即将多模态的输入信息——包括**物理参数**(如材料属性、介质常数)、**空间坐标**、**外部强迫函数**(driving forces)以及**初始条件**(或边界条件)等——精确地映射到相应的PDE解,即物理场在特定空间和时间点的分布。因此,给定一组通过数值模拟或实验观测收集到的PDE数据,其**无标签版本**被明确定义为:仅仅包含上述各类输入信息,而**不附带其对应的精确PDE解**的数据。论文所指的无标签PDE数据,是一个更广泛的概念,它涵盖了与所建模的PDE系统相关的各种输入配置和状态,而无需付出昂贵的计算成本去获取这些输入所对应的完整解。 + +为了更具象化这一概念,我们可以考虑一个**二阶线性偏微分方程**作为通用示例,其形式如下: + +$$\sum_{i,j=1}^{n}a_{ij}(x)u_{{x_i}{x_j}}+\sum_{i=1}^n{b_i(x)u_{x_i}}+c(x)u=f(x)$$ + +在这个方程中,$x \in R^n$ 代表在不同物理系统中变化的**物理空间维度**,其中 $n$ 可以是空间维度,若考虑时间相关的PDE,则 $n$ 还可以包含时间维度(例如,对于二维时变PDE,我们可以将时间作为第三个维度,此时 $n=3$)。方程中的系数 $a_{ij}(x)$、$b_i(x)$ 和 $c(x)$ 通常代表了与物理过程相关的**物理参数或材料属性**,它们的值通常是已知的或可配置的。$u$ 代表的是我们试图求解的目标解,即描述物理现象的未知场(例如温度分布、流体速度、电磁场强度等)。而 $f(x)$ 则表示作用于系统上的**外部强迫函数**,它是驱动系统演化的外部激励。在这种通用框架下,我们可以考虑两种主要情况,即PDE解是**不可用(unlabeled)**的场景: + +* **时间无关方程(Time-Independent Equations):** 对于稳态或时间无关的PDE,其解不随时间变化。在此情况下,论文的**未标记PDE数据**将仅包括定义该方程的各项输入,即**物理参数 $a_{ij}$、$b_i$、$c$ 的具体数值或空间分布、外部强迫函数 $f$ 的定义,以及所考虑的**空间坐标(即离散物理空间的网格点或其几何描述)**。在这种场景下,尽管我们拥有完整的输入信息,但由于计算资源限制或模拟耗时等原因,我们并未对每个输入配置都运行昂贵的数值求解器来获取对应的稳态解。这使得大量的输入参数组合可以作为无标签数据被收集和利用。 + +* **时间相关方程(Time-Dependent Equations):** 对于涉及时间动态演化的PDE,获取其随时间步进的完整解序列通常需要极其高昂的计算成本。在这种情况下,论文的**未标记PDE数据**策略是:**在不模拟完整时间动态的情况下,仅收集定义PDE系统的初始快照 $u_0(x)$**。这意味着我们只拥有系统在某个起始时刻的状态信息,而其在后续时间步长上的演化过程(即解 $u(x, t)$)则未被模拟或标注。请注意,在**大规模场景**中,收集带有时间动态的连续快照序列,比仅仅捕获单个瞬时快照要**复杂得多**,且对资源要求更高。例如,像**天气预报**或**烟雾分散模拟**这样的应用,需要对大气或流场进行**连续的监控**和利用**多个传感器的协同工作**来获取实时的、随时间变化的物理量。这种连续性监测涉及庞大的数据流、复杂的同步机制和巨大的存储需求。与此形成鲜明对比的是,获取**单个测量值**或某个初始状态的数据则相对简单得多,且对计算和资源的需求也**远不那么密集**。与仅仅进行一次性测量或记录不同,**长期的数据收集**通常需要部署**广泛的传感器网络**,并涉及复杂的**数据传输、存储和后期处理管道**,这进一步凸显了时间相关PDE数据获取的成本和挑战。 + +在偏微分方程(PDE)领域,收集大规模、高质量的PDE数据往往伴随着极其昂贵的计算成本,其中一个核心且关键的原因在于数值模拟过程中所涉及的复杂且耗时的**时间推进(Time Propagation)**过程。这种逐步演进的计算方法,尤其是在**高精度、多维度**以及**长时间尺度**的模拟中,会消耗**巨量的计算资源**(包括CPU/GPU计算能力、内存和存储)和**宝贵的时间**。每一个时间步长都需要进行复杂的数值运算,确保数值稳定性,并累积误差。然而,论文提出了一种**极具成本效益的替代方案**,旨在打破这一数据获取的瓶颈:即**仅仅生成未经标签标注的PDE数据**,特别是那些不包含复杂时间动态的瞬时“快照”或纯粹的物理参数配置。这类数据的生成所付出的计算代价将**远低于**运行完整且带有精确解的数值模拟来生成标记数据。例如,获取一个初始时刻的场分布,或仅仅是物理参数的组合,不需要进行长时间的迭代求解。这一**显著的成本差异**使得论文所提出的**无监督预训练方法在实践中变得高度可行且极具吸引力**。这种策略不仅极大地降低了数据获取的门槛,也为深度学习模型在PDE领域的应用带来了前所未有的数据效率提升。 + +论文所采取的这种创新训练策略被证明是**极其高效**的。通过巧妙地利用**大量的、廉价且易于获取的无标签数据**进行预训练,论文能够成功地**规避**在传统方法中模拟大型、高保真标记解决方案时,因处理复杂时间依赖方程所产生的**沉重计算负担**。这意味着,在无需耗费巨额计算资源来求解每一个时刻的详细解、进行漫长的时域积分的情况下,模型依然能够通过对无标签数据的学习,从本质上**内化重要的物理规律和底层数据特征**。这种“先学习基础知识,再进行精细调整”的范式,使得整个训练流程在**资源消耗和时间成本**上都表现出卓越的效率,从而为PDE问题的机器学习求解提供了一条更加经济、快捷的途径。 + +除了上文所述的在数据获取成本方面的显著优势,对未标记的偏微分方程(PDE)数据进行无监督预训练,还为神经算子的学习过程带来了多方面、深远且不可忽视的额外益处: + +1. 强大的正则化效应有助于预防过拟合 (Potent Regularization for Overfitting Prevention) + +无监督预训练过程本身就扮演着一种**强大而内在的正则化手段**。这种学习范式迫使模型在没有明确的监督信号或标签信息的情况下,主动地、自适应地从海量的无标签数据中捕捉到**更深层次、更本质的结构和不变性**。这意味着模型不仅仅学习表面的数据关联,而是深入理解数据背后的生成机制和潜在规律。这种自适应的学习过程有助于模型建立起**更加鲁棒和泛化能力更强的内部表示**,从而使其在面对新的、未曾见过的数据时,能够表现出卓越的性能,有效避免了模型对训练数据的过度拟合。尤其在**标记数据稀缺**的场景下,无监督预训练的重要性尤为突出。它能够有效缓解传统深度学习中因数据量不足而导致的过拟合问题,使得模型不会过度依赖有限的标记样本的特定噪声或偏差,而是学习到更具普遍性和迁移性的特征,从而提高了模型在实际应用中的可靠性。这种正则化作用类似于为模型打下了坚实的基础,使其在后续的监督学习任务中能够更加稳健地泛化到未见数据。 + +2. 显著加速模型收敛 (Significantly Accelerated Model Convergence) + +对大规模未标记PDE数据进行无监督预训练,为神经算子提供了一个**极佳的、领域适应性强的初始化表示**。换言之,模型在正式开始有监督微调阶段之前,就已经通过无监督学习对PDE数据所处的**函数空间、潜在的物理规律、数据点之间的相互依赖关系以及数据的固有结构**有了初步且深刻的理解。这种经过“预热”和“领域对齐”的初始化状态,使得模型在后续的有监督训练阶段能够从一个**更优、更接近全局最优解的起点**开始优化,从而大大加速了训练的收敛速度。模型不再需要从零开始在庞大且复杂的参数空间中进行盲目探索,而是能够沿着一个已经预设好的、更有利于优化的方向进行迭代。这直接导致了**训练周期的显著缩短**,意味着在相同的时间内可以完成更多次的实验迭代,或者以更短的时间达到更高的性能水平。这种效率的提升不仅节约了宝贵的计算资源(如GPU小时),也极大地加速了科研和工程项目的开发进程。 + +3. 提取具有深层意义的表示 (Extraction of Profoundly Meaningful Representations) + +通过在海量的未标记PDE数据上进行无监督预训练,模型能够被有效引导,**自动地提取出那些对于后续神经算子学习至关重要的、高度抽象且富有意义的特征表示**。这些表示并非仅仅是对原始数据的简单复制或压缩,它们是数据中蕴含的**物理规则、多尺度模式、内在结构以及系统动态的高效编码**。例如,模型可能学习到如何识别流体的涡流结构、热传导的梯度模式、或波传播的特征,即使在预训练时并没有明确的标签告知它这些概念。一旦模型具备了这种**高质量的特征提取能力**,后续的神经算子在处理特定任务(例如预测某个特定物理参数下的PDE解)时,就能够**更加高效地利用这些预先学习到的、低维但信息丰富的特征**。这种高级表示的学习,避免了从原始高维数据中直接学习复杂映射的困难,从而极大地提升了神经算子在下游任务中的学习效率和预测精度。最终,这使得模型在解决复杂的科学计算问题时,不仅表现出**更强的能力和更广的适用性**,其输出也可能因为根植于更深刻的物理理解而更具**物理合理性和可解释性**。 + +##### 3.1.2 代理任务 + +为了系统且清晰地阐明论文在构建**代理任务(Proxy Tasks)**方面的通用方法学,本研究精心地选择了两种基于**重建(Reconstruction-Based)**的变体作为其核心代理任务。这些代理任务并非直接进行有监督预测,而是旨在通过巧妙设计的自监督机制,在模型**缺乏直接监督信号**(即没有精确的PDE解标签)的情况下,引导其学习到**有意义且鲁棒的内部表示**。具体而言,论文将未经标签标注的偏微分方程(PDE)数据作为输入,精心喂入论文所设计的神经算子。此后,经过一个专门设计的**解码器网络(Decoder Network)**处理,神经算子将被强制要求其输出能够**高度近似于原始输入数据**。这种自动编码器(Autoencoder)式的框架迫使模型去理解和内化输入的固有结构。论文在此框架下细致考量了两种关键的**扰动变体**,它们也可以被形象地称之为数据的**增强视图(Augmented Views)**。这两种设计思想并非凭空而来,而是直接来源于**真实世界中科学数据采集的实际设置和挑战**,并且它们深刻地代表了在科学机器学习(SciML)模型中**必须融入的重要不变性原则**,以确保模型在面对真实物理系统时的泛化能力、实用性和鲁棒性。 + +首先,论文深入探索并引入了**掩码自编码器(Masked Autoencoders, MAEs)**这一强大的自监督学习范式。MAEs在近年来已被广泛证明是具有**卓越可扩展性**的自监督学习范式,其在学习高质量特征表示方面的能力令人瞩目。该方法在概念上极为简洁,却异常有效:通过**策略性地移除输入数据中的一部分内容(即进行“掩码”操作)**,模型被训练来精确地预测并恢复这些被移除的信息。这种机制强迫模型不仅仅关注可见部分的局部特征,更要从全局语境和剩余信息中推断出缺失部分,从而学习到数据的深层、非局部依赖关系。正是这种强大的自我监督机制,使得MAEs(及其前身如BERT中的掩码语言模型)在自然语言处理(NLP)领域能够训练出规模庞大、参数量超过千亿的通用模型,它们能够捕捉复杂的语言结构和语义信息。同样地,在计算机视觉(CV)领域,MAEs也展现出颠覆性的潜力,例如通过掩码图像块来学习图像的丰富视觉特征。 + +在此研究中,论文深入探索了MAEs在**科学建模**,特别是**PDE领域**中的应用价值和深层潜力。其核心动机根植于一个重要的物理洞察:**PDE所描述的动力学过程对于全场科学数据的稀疏感知具有内在的不变性**。换言之,无论我们是通过完整的高分辨率传感器网络还是通过稀疏、离散的传感器点来观测物理场,底层的物理规律和现象本身是保持一致的。物理世界的本质不会因为我们的观测手段是完整还是稀疏而改变。在现实世界中,科学数据常常不可避免地需要从**离散或稀疏的传感器网络**中进行收集,例如气候监测站、地震传感器阵列、医学成像设备的有限采样,或工业生产线上的局部监测点。因此,**重建或生成缺失传感器区域的数据,以及从不完整观测中推断完整物理场**,成为一项极其常见且关键的任务。论文正是利用这一点,通过在未经标签标注的PDE数据中引入**随机掩码机制**——例如,随机遮盖一部分网格点或物理参数配置——来**强制模型学习对传感器稀疏性的不变性**。这种训练迫使模型从无标签PDE数据的各种扭曲视图中,有效提取出那些不随观测条件变化(即无论观测是稀疏还是密集)的本质特征和物理规律。通过学习这种对科学数据稀疏感知的不变性属性,MAE所学习到的数据表示将显著增强其**鲁棒性**(抵抗数据缺失或噪声的能力)和**实用性**(在真实世界不完整数据上的表现),从而使神经算子能够更好地应用于实际的科学计算和工程问题。 + +![fig9](https://dataset.bj.bcebos.com/PaddleScience/data_efficient_nopt/fig9.png) + +此外,对海量的、易于获取的未标记偏微分方程(PDE)数据进行无监督预训练,还为神经算子的学习过程带来了多方面、深远且不可或视的显著优势,这些优势超越了单纯的数据效率提升: + +首先,强大的正则化效应有助于预防过拟合,无监督预训练过程本身就扮演着一种**强大而内在的正则化手段**。这种学习范式迫使模型在**没有明确监督信号**(即不依赖昂贵的精确PDE解标签)的情况下,主动地、自适应地从大规模无标签数据中捕捉到**更深层次、更本质的结构和不变性**。这意味着模型不仅仅学习数据表面的关联,而是深入理解数据背后的生成机制和潜在物理规律。例如,它可能学习到能量守恒、质量守恒等基本物理原理的隐式表示,或者数据在不同条件下应表现出的对称性。这种**自适应的学习过程**有助于模型建立起**更加鲁棒和泛化能力更强的内部表示**,从而使其在面对新的、未曾见过的数据(包括来自不同分布的数据)时,能够表现出卓越的预测性能,有效避免了模型对训练数据的过度拟合。尤其在**标记数据极其稀缺**的科学计算场景下,无监督预训练的重要性尤为突出。它能够有效缓解传统深度学习中因数据量不足而导致的过拟合问题,使得模型不会过度依赖有限的标记样本的特定噪声或偏差,而是学习到更具普遍性和迁移性的物理特征,从而提高了模型在实际应用中的可靠性和外推能力。这种正则化作用类似于为模型打下了坚实而宽广的知识基础,使其在后续的监督学习任务中能够更加稳健地泛化到未见数据。 + +其次,显著加速模型收敛,对大规模未标记PDE数据进行无监督预训练,为神经算子提供了一个**极佳的、领域适应性强的初始化表示**。换言之,模型在正式开始有监督微调阶段之前,就已经通过无监督学习对PDE数据所处的**函数空间、潜在的物理规律、数据点之间的相互依赖关系以及数据的固有结构**有了初步且深刻的理解。这种经过“预热”和“领域对齐”的初始化状态,使得模型的参数已经位于一个相对“有意义”且接近最优解的区域。这与从完全随机的参数初始化开始训练形成鲜明对比,后者需要模型从零开始探索整个巨大的参数空间。因此,经过预训练的模型在后续的**有监督训练阶段**能够更快地找到最优解,从而**大大加速了训练的收敛速度**。优化器不再需要进行大量的“探索性”更新,而是能够沿着一个已经预设好的、更有利于优化的方向进行快速迭代。这直接导致了**训练周期的显著缩短**,意味着在相同的时间内可以完成更多次的实验迭代,或者以更短的时间达到更高的性能水平。这种效率的提升不仅节约了宝贵的计算资源(如GPU小时),降低了能源消耗,也极大地加速了科研和工程项目的开发进程,使得模型能够更快地投入实际应用。 + +第三,提取具有深层意义的表示,通过在海量的未标记PDE数据上进行无监督预训练,模型能够被有效引导,**自动地提取出那些对于后续神经算子学习至关重要的、高度抽象且富有深层意义的特征表示**。这些表示并非仅仅是对原始数据的简单复制或压缩,它们是数据中蕴含的**物理规则、多尺度模式、潜在结构以及系统动态的高效编码**。例如,模型可能在无监督学习中便内化了对流体边界层、激波、湍流涡旋、热扩散梯度或波传播特性等复杂物理现象的抽象理解,即使在预训练时并没有明确的标签告知它这些概念。这种对底层物理机制的隐式编码,使得模型能够从数据的表象中提炼出本质。一旦模型具备了这种**高质量的、富含物理信息的特征提取能力**,后续的神经算子在处理特定任务(例如预测某个特定物理参数下的PDE解、进行逆问题求解或进行参数发现)时,就能够**更加有效地利用这些预先学习到的、低维但信息丰富的特征**。这种高级表示的学习,避免了从原始高维数据中直接学习复杂映射的困难,降低了对有监督数据的需求,从而极大地提升了神经算子在下游任务中的学习效率和预测精度。最终,这使得模型在解决复杂的科学计算问题时,不仅表现出**更强的能力和更广的适用性**,其输出也可能因为根植于更深刻的物理理解而更具**物理合理性、可解释性和预测的可靠性**。 + +##### 3.1.3 偏微分方程 + +在无标签PDE数据上进行预训练后,论文会在PDE的模拟解上对神经算子进行微调。论文研究了两种与时间无关的PDE(泊松方程、亥姆霍兹方程)和两种与时间有关的PDE(反应-扩散方程、纳维-斯托克斯方程)。 + +##### 3.1.4 模型架构 + +为了与之前在神经算子学习领域的工作进行公平且深入的比较,论文慎重考量并最终选择了两种在当前机器学习社群中享有盛誉且应用广泛的模型架构。这两种架构均遵循**编码器-解码器(Encoder-Decoder)**的通用范式,其核心设计理念在于,当模型仅能观测到部分输入数据时,它能够通过内在的学习机制,精确地重建出原始完整的输入信息。具体来说,编码器的职责是将观测到的、未经标签标注的偏微分方程(PDE)数据,高效地映射到一个维度更低但信息更丰富的**潜在空间(Latent Space)**中。这个潜在空间捕捉了数据的核心特征和内在结构。随后,解码器则负责利用这个潜在表示,重新构建出原始的输入条件或其完整形态,从而实现对缺失信息的推理和重建。这种架构尤其适用于自监督学习场景,因为它允许模型通过自身的重建任务来学习强大的数据表示。 + +**傅里叶神经算子(Fourier Neural Operator, FNO)** + +傅里叶神经算子(FNO)作为一种开创性的架构,其独特之处在于它旨在**傅里叶空间**中学习PDE数据所蕴含的复杂规律。与传统的基于网格的方法不同,FNO通过将数据转换到频率域,能够更高效地捕捉到PDE解中的**全局依赖关系和长程相互作用**,这对于处理具有非局部特性的物理系统尤为关键。FNO的原始模型骨干,即其**编码器部分**,首先通过**傅里叶变换**将输入的空间数据转换到频域,随后利用一系列**线性变换**来学习并处理那些代表数据主要特征的**较低傅里叶模式**(即低频成分)。这种在频率域进行操作的能力,使得FNO能够以一种参数效率极高的方式近似复杂的函数映射,因为物理系统的许多核心行为往往可以通过其在频率空间中的少量模式来表征。最终,FNO骨干会将这些在频率域处理过的特征,通过逆傅里叶变换,高效地输出回**空间域**,形成在“像素级别”嵌入的物理量,从而直接与实际的物理空间坐标相对应。 + +* **预训练阶段:** 在无监督预训练阶段,为了实现有效的自监督学习,论文将傅里叶神经算子的**解码器部分**构建得与编码器在结构上保持一致,尽管它们的具体输入/输出维度可能有所差异。这种对称设计是自编码器范式中的常见策略,旨在确保模型在潜在空间中学习到有效的、可逆的表示。在此阶段,未经标签标注的PDE数据在**像素级别**上被随机地进行**掩码(masking)**处理。这意味着输入数据的部分像素值会被故意移除或置零,以此来强制模型学习如何从剩余的可见信息中,精确地推断并重建出这些缺失的像素,从而增强模型对局部信息完整性的理解和恢复能力。 +* **微调阶段:** 完成了无监督预训练后,模型已经学习到了PDE数据的丰富内在表示。在随后的**有监督微调阶段**,论文选择**丢弃预训练时使用的解码器**。其原因在于,微调的目标不再是重建原始输入,而是直接从物理参数函数 `a` 预测其对应的PDE解 `u`。此时,编码器已经足以提取出用于预测所需的高级特征。在编码器输出的特征之上,论文遵循原始FNO设计的惯例,额外**附加了两个全连接层(带有ReLU激活函数)**。这些层作为一个轻量级的回归器,负责将编码器学习到的抽象特征,进一步转换为最终的、高精度的空间解预测,同时引入非线性,以增强模型的表达能力。 + +**Transformer 架构** + +与傅里叶神经算子直接在网格或傅里叶空间上进行操作的方式截然不同,**Transformer架构**以其主要采用的**自注意力块(Self-Attention Blocks)**和**线性变换块(Linear Transformation Blocks)**而闻名。这种设计使其在处理序列数据方面表现出卓越的能力,并在自然语言处理(NLP)和计算机视觉(CV)等多个领域展现出突破性的潜力。Transformer的核心优势在于其能够捕获数据中的**长程依赖关系**和复杂的**上下文关系**,而无需依赖于固定的网格结构或对数据进行特定的傅里叶变换。为了适应PDE数据的特点,Transformer将原始的网格数据进行**分词(tokenization)**处理,并将其**分组为若干个“块”(patches)**。这意味着每个经过分词处理的块都嵌入了原始网格中一个**子网格的局部邻域信息**,从而将空间数据转化为一种更类似于序列的输入格式,以便强大的自注意力机制能够对其进行高效处理。论文在此研究中遵循了**Video-MAE的3D Transformer架构**,这是一种为处理视频或三维数据设计的、在掩码自编码器框架下表现出色的模型,其选择旨在有效处理PDE数据可能具备的三维或时空特性。论文的**编码器**通过**线性投影**将这些分词后的块转换为高维度的**嵌入向量**,并且,为了保留空间信息(因为分词过程会损失原始的空间位置信息),论文还像标准的Vision Transformer (ViT) 中那样,为这些嵌入向量添加了至关重要的**位置嵌入(Positional Embeddings)**。随后,这些包含了丰富局部内容和精确位置信息的嵌入向量集合,会通过一系列串联的**Transformer块**进行深度处理,从而捕获数据中复杂的交互和依赖。对于Transformer架构的无监督预训练,未经标签标注的PDE数据同样被随机地进行**掩码处理,但这次是在“块级别”**上进行,这意味着整个子网格区域被掩盖,而非单个像素。 + +* **预训练阶段:** 在Transformer的预训练阶段,为了实现计算效率,论文的**Transformer编码器**仅被应用于**可见令牌(即,未被掩码的块)的子集**,并且被掩码的块在编码器的输入阶段即被移除。这种不对称的处理方式使得编码器能够高效地从部分输入中学习到强大的表示。而**MAE解码器的输入**则是一个**完整的令牌集合**,它包含了两个关键部分:**(i) 编码器处理后得到的可见块的特征表示**,以及**(ii) 专门的“掩码令牌”(mask tokens)**。这些掩码令牌是一个**共享的学习向量**,它的存在明确地指示了原始输入中存在一个待预测的缺失块。为了确保解码器能够理解每个令牌(无论是可见的还是被掩码的)在完整输入中的空间位置,论文向这个完整的令牌集合中的所有令牌都**添加了位置嵌入**。如果缺少这一点,掩码令牌将无法获得其在原始输入中的位置信息,从而导致重建困难。遵循[19, 65]的实践,论文采用了**不对称设计**,即解码器相对于编码器而言更为**轻量化**(即,其层数更少、宽度更窄)。这种设计理念基于这样的认识:编码器负责从复杂输入中提取高层特征,需要更强的容量;而解码器则主要负责基于这些特征进行重建,通常可以用更少的参数完成任务,从而提高了整体训练效率。 +* **微调阶段:** 与FNO不同,在Transformer架构的**微调阶段**,**解码器得以保留**。这是因为Transformer编码器的输出通常是一系列经过处理的块嵌入,而非直接的空间解。为了将这些块级别的表示重新组织并重建回原始的空间网格结构,从而获得最终的PDE解,解码器是必不可少的。它的存在确保了从抽象的块级特征到具体物理空间解的平滑转换,使得整个模型在有监督微调时能够进行端到端的学习和优化。 + +### 3.2 基于相似性的上下文示例挖掘 + +**分布外(Out-Of-Distribution, OOD)泛化能力**,即模型在训练数据分布之外的、从未见过的数据上依然能够保持高性能的强大能力,这不仅在科学机器学习(SciML)这一新兴领域内部构成一个**核心且极具挑战性的技术难题**,而且在更广泛的科学人工智能(Scientific AI)的多个分支和实际应用场景中,同样被视为亟待解决的**关键瓶颈**。鉴于真实世界的科学现象和实验条件往往具有**高度的复杂性、不确定性和多样性**(例如,在极端物理参数、全新的几何构型、多物理场耦合效应,或含有噪声/不完整性的真实测量数据等情境下),模型能否在**超出其训练经验所覆盖的条件下**,依然做出准确、可靠且具有物理合理性的预测,直接决定了其在实际科学研究和工程应用中的**有效性、可信度乃至安全性**。 + +为了显著提升神经算子在面对这些未见分布数据时的泛化能力,并同时减少在下游任务中进行额外且耗时的模型微调工作量和计算开销,学术界已经提出了一种**创新性的推理范式**:在这种范式下,当模型接收到一个新的查询输入(即需要求解的特定PDE实例)时,它不再仅仅依靠自身内部参数进行预测,而是**同时被提供一些精心挑选的“支持示例”(Support Examples,通常被称为“演示”,即Demonstrations)**。这些支持示例连同其对应的真实解一同呈现给模型,作为辅助性的情境信息,以指导其进行最终的预测。这种**在推理时动态引入情境信息**的方法,赋予了模型进行**“开放集”(Open-Set)泛化**的独特能力,使其能够对那些在训练阶段从未出现过的、甚至分布迥异的样本做出合理且准确的预测,从而极大地拓宽了模型的应用范围,使其能够更好地应对真实世界中层出不穷的新挑战。 + +最初,在**少样本学习(Few-Shot Learning)**的广阔研究文献中,研究人员为了实现上述高度灵活的推理目标,开发出了许多**精细且复杂的架构**。这些传统架构的核心设计理念在于,通过引入专门的模块和复杂的网络结构,主动地寻找目标查询与支持示例之间在特征层面的**深层对应关系和相似性**。这类额外架构/训练设计的目的可以概括为两个主要方面:首先,它们旨在**精准地量化或识别目标输入与支持示例之间存在的相似性**,这可能涉及到复杂的度量学习(Metric Learning)、特征对齐技术(Feature Alignment)或基于注意力的相似性计算机制;其次,它们致力于**有效地聚合来自支持示例的标签信息**,以作为指导信号,用于生成最终的预测结果。这种聚合可能通过加权平均、神经网络融合或迭代优化等方式实现。近年来,在基于PDE数据进行**情境学习(In-Context Learning, ICL)**的研究工作中,也普遍采用了这种策略。例如,一些模型通过引入**Transformer架构**和**Cross-Attention层**来实现情境示例与查询之间的信息交互,以便模型能够从演示中学习如何适应新任务。然而,这种方法通常意味着需要**额外的架构设计**和相应的**训练成本**来优化这些复杂的交互机制,这无疑增加了模型的训练复杂度和计算负担,也可能限制了其在大规模应用中的可扩展性。 + +然而,**大型语言模型(Large Language Models, LLMs)的崛起及其所促成的情境学习(ICL)范式**,为我们提供了一种截然不同且更具效率的策略,深刻地改变了我们对模型泛化方式的认知。在这种范式下,模型的**预训练阶段依然保持其标准化和简洁性**,通常围绕着预测下一个掩码令牌(Masked Token Prediction)或自回归语言建模等任务进行,而**无需引入额外的、为支持ICL而设计的训练成本或复杂的架构调整**。这种简化预训练过程的优势在于,它使得模型的训练更具**可扩展性**,并能更高效地利用海量无标签数据,因为核心模型只需学习通用的表示,而非针对特定情境学习机制进行优化。在推理过程中,LLMs展现出**高度的灵活性和强大能力**:它们能够以**自回归的方式**接受任意数量的少样本示例作为输入序列的一部分,无需预先设定示例的数量或格式。通过其**强大的自注意力机制(Self-Attention Mechanism)**,LLMs能够在输入的少样本示例中的令牌(tokens)与目标查询中的令牌之间**自动地建立起复杂的相似性关联**,并进行**深层的信息整合和推理**。随后,通过对少样本示例中令牌嵌入的有效聚合和处理,模型能够生成高质量的响应或预测,其结果能够根据提供的情境示例进行动态调整。LLM中所采用的这种ICL策略,以其**无与伦比的高度可扩展性和卓越的训练效率**,为科学机器学习,尤其是神经算子学习,提供了**新的启示和可能性**。它强烈表明,**无需复杂的额外训练或架构设计**,仅通过在推理时灵活地利用情境信息,也能**显著提升模型在复杂科学问题上的泛化能力和数据效率**,开启了在数据稀缺和OOD场景下部署SciML模型的新篇章。 + +受此启发,论文提出通过两个步骤来利用情境示例: + +1. 通过预测的相似性。 论文通过计算它们在输出空间中的距离来找到空间和时间上的相似演示 。这意味着,对于空间和时间域上的两个输入位置,如果论文发现它们经过训练的神经算子的输出相似,那么论文就将它们视为相似样本 。遵循 [24, 25],论文假设演示与查询共享相同的物理参数分布 。 + +2. 聚合。 对于查询的每个空间-时间位置,在找到其在演示中的相似样本后,论文聚合并平均它们的解作为预测 。 + +![fig10](https://dataset.bj.bcebos.com/PaddleScience/data_efficient_nopt/fig10.png) + +## 4. 完整代码 + +``` py linenums="1" title="examples/data_efficient_nopt/data_efficient_nopt.py" +--8<-- +examples/data_efficient_nopt/data_efficient_nopt.py +--8<-- +``` + +## 5. 结果展示 + +论文通过在多个**偏微分方程(PDE)基准测试**中进行大量且严谨的实证评估,并模拟了**泊松方程、亥姆霍兹方程**等多种真实的物理场景,从而全面且有力地展示了其所提出的创新框架在解决科学计算问题方面的卓越性能和独特优越性。这些评估不仅覆盖了不同类型的PDE,还深入探讨了模型在面对实际应用挑战时的鲁棒性和泛化能力。 + +### 显著提升数据效率 (Significant Improvement in Data Efficiency) + +本研究的核心成果之一在于**显著提升了数据效率**,这在PDE模拟数据生成成本高昂的背景下显得尤为重要。 + +* **数据节省的巨大潜力:** 实验结果清晰地表明,在广泛的无标签PDE数据上进行无监督预训练,能够极其显著地减少对昂贵且难以获取的**标签数据(即带有精确模拟解的数据)**的需求。与那些从随机初始化状态开始训练的神经算子相比,经过无监督预训练后的模型,在达到相同的性能水平时,所需的标签数据量大大减少。这意味着,模型在学习复杂物理规律和函数映射时,能够更加高效地利用有限的监督信息,从而极大地降低了数据采集和处理的成本。 +* **量化效益的惊人体现:** 这种数据效率的提升并非仅仅是概念上的,论文通过量化分析揭示了其在实践中的惊人效益。在不同的PDE系统和仿真任务中,本研究所提出的方法能够**节省多达 $5 \times 10^2$ 到 $8 \times 10^5$ 个昂贵的数值模拟解**。这意味着,那些原本需要动用大型计算集群、耗费数万甚至数十万次高精度数值模拟才能训练出的复杂模型,现在可能仅仅需要数百或数千次模拟便能达到相似甚至更优的性能。这种计算资源的节省,不仅体现在CPU/GPU时间的缩短上,也包括了能源消耗的降低以及科研人员宝贵时间的解放,从而极大地加速了科学发现和工程设计周期。 + +### 超越现有通用预训练模型 (Outperforming Existing General Pretrained Models) + +为了进一步验证论文方法在物理领域特定性上的优势,论文还进行了一项关键的对比实验。 + +* **与通用模型的性能对比:** 论文将本研究提出的、针对物理领域设计的无监督预训练方法,与那些来自**计算机视觉领域**的现有最先进通用预训练模型进行了深入对比。例如,论文评估了在大型通用数据集(如SSV2数据集)上预训练过的**Video-MAE**模型。令人瞩目的是,即便这些通用模型在各自领域展现出强大的性能,但在有限的PDE模拟数据下进行微调时,它们的性能表现却不尽人意,往往难以达到与论文方法相媲美的精度和鲁棒性。 +* **强调领域适应性预训练的重要性:** 这一对比结果有力地强调了在科学机器学习(SciML)领域进行**领域适应性无监督预训练**的极端重要性。通用模型在像素级或图像块级的特征学习上可能表现出色,但它们通常缺乏对物理数据所特有的**内在结构、多尺度特性以及底层物理约束(如守恒定律、边界条件)**的先验知识或有效编码。物理数据与自然图像数据在本质上存在巨大差异,这使得通用预训练模型难以在仅通过少量PDE数据微调后即刻适应复杂的物理规律,从而凸显了专门为物理系统设计的无监督预训练方法的不可替代性。 + +### 更好的泛化能力 (Improved Generalization Capability) + +本框架的另一项显著优势在于其能够带来**显著提升的泛化能力**,这对于模型在真实世界中的可靠应用至关重要。 + +* **显著减小泛化差距:** 无监督预训练的引入,使得模型在学习过程中能够从海量的无标签数据中提取出更加**鲁棒和通用的特征表示**。这种机制强有力地规范化了模型,有效防止了其仅仅记忆训练集中的特定模式,从而**显著减小了模型的泛化差距(即测试误差与训练误差之间的差异)**。这意味着经过预训练的模型在面对从未见过的测试数据时,能够表现出更稳定、更准确的预测性能,其学习到的知识更具迁移性。 +* **在OOD场景下的卓越性能提升:** 当本研究的无监督预训练框架与**情境学习(In-Context Learning, ICL)**机制相结合时,模型在**分布外(Out-of-Distribution, OOD)**的场景下展现出了前所未有的强大泛化能力。例如,当PDE的**物理参数(如扩散系数、反应速率)或初始条件(如初始流场、温度分布)**超出了模型在有监督阶段所训练的范围时,模型仍能做出高度合理且准确的预测。这种在复杂外推任务中的可靠性,使得模型能够被更广泛地应用于探索新的物理现象或模拟极端条件。 + +### 更快的收敛速度 (Faster Convergence Speed) + +除了性能上的提升,本研究提出的预训练策略还带来了训练效率上的巨大飞跃。 + +* **领域适应的良好初始化:** 在大规模无标签PDE数据上进行预训练,为神经算子提供了一个极其优越且**领域适应性强**的良好初始化。这意味着模型在开始有监督微调之前,其内部权重就已经对PDE数据的基本结构、物理约束以及潜在模式有了初步而深刻的理解,而非从完全随机的状态开始学习。 +* **加速监督微调过程:** 这种经过“预热”和“领域对齐”的预训练权重,使得模型在后续的**监督微调阶段**能够以显著更快的速度收敛到最优解。优化器不再需要在巨大的参数空间中盲目探索,而是从一个已经靠近“解决方案区域”的有利位置开始。这直接转化表现为**训练时间的显著缩短**、**计算资源消耗(如GPU小时数)的降低**,以及科研人员更快的实验迭代周期,极大地提高了研究和开发的效率。 + +### 情境学习的额外泛化优势 (Additional Generalization Advantage from In-Context Learning) + +情境学习作为论文框架的另一关键组成部分,在不增加任何额外训练开销的前提下,为模型的泛化能力提供了独特的加成。 + +* **零训练开销的灵活性:** 情境学习最显著的优势在于其**无需任何额外的训练成本**。它在**推理阶段**通过灵活地引入与当前查询任务相似的支持示例,来动态地增强模型的预测能力。这意味着模型在部署之后,无需进行任何参数更新或再训练,即可通过选择相关的上下文信息来适应新的情况,从而大大提升了模型的实时适应性和实用性。 +* **持续且可量化的性能提升:** 论文的实验结果明确证明,通过**增加情境示例的数量**,神经算子在各种类型的PDE问题上,尤其是在**OOD泛化能力方面**,能够持续且可量化地得到提升。这为模型在部署之后,通过简单的上下文管理(例如,在发现新的类似数据时将其加入情境示例库),来**动态提升其在复杂未知情况下的性能**提供了切实可行的路径。尤其是在解决OOD问题时,情境示例能够提供关键的**校准信息**,帮助模型校准其输出结果的**量级和模式**,使其预测结果能够更加精准地匹配真实的物理解,从而在面对极端或异常的物理条件时,依然能够提供高度可靠的预测,显著增强了模型的鲁棒性和可靠性。 + +图2体现了在该方法下,在多种泛化场景下,数据效率使用的提升,有更好的收敛速度。 + +![fig2](https://dataset.bj.bcebos.com/PaddleScience/data_efficient_nopt/fig3.png) + +图3体现该种方法在情景推理场景中,超越了现有的通用预训练模型,具有额外泛化优势。 + +![fig6](https://dataset.bj.bcebos.com/PaddleScience/data_efficient_nopt/fig6.png) + +总而言之,这篇论文提出了一种创新且高效的神经算子学习框架,通过无监督预训练在大量廉价的无标签物理数据上学习通用表示,并通过情境学习在推理阶段利用少量相似案例来提升OOD泛化能力。这一框架显著降低了对昂贵模拟数据的需求,并提高了模型在复杂物理问题中的适应性和泛化性,为科学机器学习的数据高效发展开辟了新途径。通过不同模型上实际数据与预测误差的对比,验证了该论文优秀的泛化性。 + +![fig7](https://dataset.bj.bcebos.com/PaddleScience/data_efficient_nopt/fig7.png) + +## 6. 参考资料 + +- [Data-Efficient Operator Learning via Unsupervised Pretraining and In-Context Learning](https://arxiv.org/abs/2402.15734) diff --git a/examples/data_efficient_nopt/config/data_efficient_nopt_fno_poisson.yaml b/examples/data_efficient_nopt/config/data_efficient_nopt_fno_poisson.yaml new file mode 100644 index 000000000..15aabae85 --- /dev/null +++ b/examples/data_efficient_nopt/config/data_efficient_nopt_fno_poisson.yaml @@ -0,0 +1,325 @@ +# general settings +mode: train +seed: 42 + +# training settings +run_name: r0 +use_ddp: False +config: pois-64-pretrain-e1_20_m3 +sweep_id: '' +logdir: exp +output_dir: ${hydra:run.dir} + +train_config: + default: &DEFAULT + num_data_workers: 4 + # model + model: 'fno' + depth: 5 + in_dim: 2 + out_dim: 1 + dropout: 0 + # data/domain + Lx: !!float 1.0 + Ly: !!float 1.0 + nx: 256 + ny: 256 + # optimization + loss_style: 'mean' + loss_func: 'mse' + optimizer: 'adam' + scheduler: 'none' + learning_rate: !!float 1.0 + max_epochs: 500 + scheduler_epochs: 500 + weight_decay: 0 + batch_size: 25 + # misc + log_to_screen: !!bool True + save_checkpoint: !!bool False + seed: 0 + plot_figs: !!bool False + pack_data: !!bool False + # Weights & Biases + entity: 'entity_name' + project: 'proj_name' + group: 'poisson' + log_to_wandb: !!bool False + distill: !!bool False + subsample: 1 + exp_dir: './exp/' + tie_fields: !!bool False # Whether to use 1 embedding per field per data + use_all_fields: !!bool True # Prepopulate the field metadata dictionary from dictionary in datasets + tie_batches: !!bool False # Force everything in batch to come from one dset + model_type: fno + pretrained: False + warmup_steps: 0 + epoch_size: 1 + accum_grad: 1 + enable_amp: !!bool False + log_interval: 1 + checkpoint_save_interval: 1000 + debug_grad: False + + poisson: &poisson + <<: *DEFAULT + n_demos: 0 + batch_size: 512 + nx: 128 + ny: 128 + save_checkpoint: !!bool True + max_epochs: 500 + scheduler: 'cosine' + + model: 'fno' + layers: [64, 64, 64, 64, 64] + modes1: [65, 65, 65, 65] + modes2: [65, 65, 65, 65] + fc_dim: 256 + + in_dim: 4 + out_dim: 1 + mode_cut: 16 + embed_cut: 64 + fc_cut: 2 + + optimizer: 'adam' + + learning_rate: 1E-3 + pack_data: !!bool False + + + poisson-64-scale-e5_15: &poisson_64_e5_15 + <<: *poisson + train_path: 'data/possion_64/poisson_64_e5_15_train.h5' + val_path: 'data/possion_64/poisson_64_e5_15_val.h5' + test_path: 'data/possion_64/poisson_64_e5_15_test.h5' + scales_path: 'data/possion_64/poisson_64_e5_15_train_scale.npy' + train_rand_idx_path: 'data/possion_64/train_rand_idx.npy' + batch_size: 128 + log_to_wandb: !!bool False + learning_rate: 1E-3 + + mode_cut: 32 + embed_cut: 64 + fc_cut: 2 + subsample: 1 + nx: 64 + ny: 64 + + pt: "train" + pt_split: [46080, 8192] + pretrained: False + + + pois-64-pretrain-e1_20: &pois_64_e1_20_pt + <<: *poisson + train_path: 'data/possion_64/poisson_64_e1_20_train.h5' + val_path: 'data/possion_64/poisson_64_e1_20_val.h5' + test_path: 'data/possion_64/poisson_64_e1_20_test.h5' + scales_path: 'data/possion_64/poisson_64_e1_20_train_scale.npy' + train_rand_idx_path: 'data/possion_64/train_rand_idx.npy' + batch_size: 128 + log_to_wandb: !!bool False + mode_cut: 32 + embed_cut: 64 + fc_cut: 2 + subsample: 1 + nx: 64 + ny: 64 + learning_rate: 1E-3 + pt: "pretrain" + pt_split: [46080, 8192] + blur: [0, 1] + + + pois_64_finetune_e5_15: &pois_64_e5_15_ft + <<: *poisson + train_path: 'data/possion_64/poisson_64_e5_15_train.h5' + val_path: 'data/possion_64/poisson_64_e5_15_val.h5' + test_path: 'data/possion_64/poisson_64_e5_15_test.h5' + scales_path: 'data/possion_64/poisson_64_e5_15_train_scale.npy' + train_rand_idx_path: 'data/possion_64/train_rand_idx.npy' + batch_size: 128 + log_to_wandb: !!bool False + mode_cut: 32 + embed_cut: 64 + fc_cut: 2 + subsample: 1 + nx: 64 + ny: 64 + learning_rate: 1E-3 + pt: "train" + pt_split: [46080, 8192] + fix_backbone: False + resuming: False + pretrained: True + pretrained_ckpt_path: /pretrained_ckpt_path/training_checkpoints/ckpt.tar + + pois-64-e5_15_ft0: &pois_64_e5_15_ft0 + <<: *pois_64_e5_15_ft + subsample: 1 + + pois-64-e5_15_ft1: &pois_64_e5_15_ft1 + <<: *pois_64_e5_15_ft + subsample: 2 + + pois-64-e5_15_ft2: &pois_64_e5_15_ft2 + <<: *pois_64_e5_15_ft + subsample: 4 + + pois-64-e5_15_ft3: &pois_64_e5_15_ft3 + <<: *pois_64_e5_15_ft + subsample: 8 + + pois-64-e5_15_ft4: &pois_64_e5_15_ft4 + <<: *pois_64_e5_15_ft + subsample: 16 + + pois-64-e5_15_ft5: &pois_64_e5_15_ft5 + <<: *pois_64_e5_15_ft + subsample: 32 + + pois-64-e5_15_ft6: &pois_64_e5_15_ft6 + <<: *pois_64_e5_15_ft + subsample: 64 + + pois-64-e5_15_ft7: &pois_64_e5_15_ft7 + <<: *pois_64_e5_15_ft + subsample: 128 + batch_size: 64 + + pois-64-e5_15_ft8: &pois_64_e5_15_ft8 + <<: *pois_64_e5_15_ft + subsample: 256 + batch_size: 32 + + pois-64-e5_15_ft9: &pois_64_e5_15_ft9 + <<: *pois_64_e5_15_ft + subsample: 512 + batch_size: 16 + + pois-64-pretrain-e1_20_m0: &pois-64-e1_20_pt_m0 + <<: *pois_64_e1_20_pt + mask_ratio: 0. + + pois-64-pretrain-e1_20_m1: &pois-64-e1_20_pt_m1 + <<: *pois_64_e1_20_pt + mask_ratio: 0.1 + + pois-64-pretrain-e1_20_m2: &pois-64-e1_20_pt_m2 + <<: *pois_64_e1_20_pt + mask_ratio: 0.2 + + pois-64-pretrain-e1_20_m3: &pois-64-e1_20_pt_m3 + <<: *pois_64_e1_20_pt + mask_ratio: 0.3 + + pois-64-pretrain-e1_20_m4: &pois-64-e1_20_pt_m4 + <<: *pois_64_e1_20_pt + mask_ratio: 0.4 + + pois-64-pretrain-e1_20_m5: &pois-64-e1_20_pt_m5 + <<: *pois_64_e1_20_pt + mask_ratio: 0.5 + + pois-64-pretrain-e1_20_m6: &pois-64-e1_20_pt_m6 + <<: *pois_64_e1_20_pt + mask_ratio: 0.6 + + pois-64-pretrain-e1_20_m7: &pois-64-e1_20_pt_m7 + <<: *pois_64_e1_20_pt + mask_ratio: 0.7 + + pois-64-pretrain-e1_20_m8: &pois-64-e1_20_pt_m8 + <<: *pois_64_e1_20_pt + mask_ratio: 0.8 + + pois-64-pretrain-e1_20_m9: &pois-64-e1_20_pt_m9 + <<: *pois_64_e1_20_pt + mask_ratio: 0.9 + + + + poisson-64-e5_15_bsln: &pois_64_e5_15_baseline + <<: *poisson_64_e5_15 + + # 8192 + poisson-64-e5_15_b0: &pois_64_e5_15_ss4 + <<: *pois_64_e5_15_baseline + subsample: 1 + + poisson-64-e5_15_b1: &pois_64_e5_15_ss8 + <<: *pois_64_e5_15_baseline + subsample: 2 + + poisson-64-e5_15_b2: &pois_64_e5_15_ss16 + <<: *pois_64_e5_15_baseline + subsample: 4 + + poisson-64-e5_15_b3: &pois_64_e5_15_ss32 + <<: *pois_64_e5_15_baseline + subsample: 8 + + poisson-64-e5_15_b4: &pois_64_e5_15_ss64 + <<: *pois_64_e5_15_baseline + subsample: 16 + + poisson-64-e5_15_b5: &pois_64_e5_15_ss128 + <<: *pois_64_e5_15_baseline + subsample: 32 + + poisson-64-e5_15_b6: &pois_64_e5_15_ss256 + <<: *pois_64_e5_15_baseline + subsample: 64 + + poisson-64-e5_15_b7: &pois_64_e5_15_ss512 + <<: *pois_64_e5_15_baseline + subsample: 128 + batch_size: 64 + + +# inference settings +ckpt_path: data/pd_finetune_b01_m0_n8192.tar +num_demos: 1 +tqdm: False +save_pred: False + +infer_config: + train_path: 'data/possion_64/poisson_64_e15_50_train.h5' # pick demos + test_path: 'data/possion_64/poisson_64_e15_50_test.h5' + scales_path: 'data/possion_64/poisson_64_e5_15_train_scale.npy' + ckpt_path: data/possion_64/finetune_b01_m0_n8192.pdparams + + num_data_workers: 4 + subsample: 1 + num_demos: 0 + shuffle: False + nx: 64 + nt: 64 + Lx: !!float 1.0 + Ly: !!float 1.0 + pack_data: !!bool False + + model: 'fno' + layers: [64, 64, 64, 64, 64] + modes1: [65, 65, 65, 65] + modes2: [65, 65, 65, 65] + fc_dim: 128 + + in_dim: 4 + out_dim: 1 + mode_cut: 32 + embed_cut: 64 + fc_cut: 2 + dropout: 0 + + fix_backbone: True + + loss_func: mse + + batch_size: 1 + loss_style: sum + + log_to_wandb: !!bool False + logdir: ./log diff --git a/examples/data_efficient_nopt/data_efficient_nopt.py b/examples/data_efficient_nopt/data_efficient_nopt.py new file mode 100644 index 000000000..db16fdda0 --- /dev/null +++ b/examples/data_efficient_nopt/data_efficient_nopt.py @@ -0,0 +1,736 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import os +import random +from argparse import Namespace +from collections import OrderedDict +from os import path as osp + +import hydra +import numpy as np +import paddle +import paddle.amp as amp +import paddle.distributed as dist +import paddle.nn as nn +import paddle.nn.functional as F +import paddle.optimizer as optim +from einops import rearrange +from omegaconf import DictConfig +from ruamel.yaml import YAML +from scipy.stats import linregress +from tqdm import tqdm + +import ppsci +from ppsci.arch.data_efficient_nopt_model import add_weight_decay +from ppsci.arch.data_efficient_nopt_model import build_fno +from ppsci.arch.data_efficient_nopt_model import fno_pretrain as fno +from ppsci.arch.data_efficient_nopt_model import gaussian_blur +from ppsci.arch.data_efficient_nopt_model import get_cutoff +from ppsci.arch.data_efficient_nopt_model import grad_norm +from ppsci.arch.data_efficient_nopt_model import l2_err +from ppsci.arch.data_efficient_nopt_model import param_diff +from ppsci.arch.data_efficient_nopt_model import param_norm +from ppsci.data.dataset.data_efficient_nopt_dataset import MixedDatasetLoader +from ppsci.data.dataset.data_efficient_nopt_dataset import PoisHelmDatasetLoader +from ppsci.utils import logger + + +class Trainer: + def __init__( + self, params, global_rank, local_rank, device, output_dir, sweep_id=None + ): + self.device = device + self.params = params + self.output_dir = output_dir + self.global_rank = global_rank + self.local_rank = local_rank + self.world_size = int(os.environ.get("WORLD_SIZE", 1)) + self.sweep_id = sweep_id + self.log_to_screen = params.log_to_screen + self.train_loss = nn.MSELoss() + self.startEpoch = 0 + self.epoch = 0 + self.debug_grad = params.debug_grad + self.mp_type = ( + "bfloat16" + if paddle.device.cuda.device_count() >= 1 + and paddle.amp.is_bfloat16_supported() + else "float16" + ) + + self.iters = 0 + self.initialize_data(self.params) + self.initialize_model(self.params) + self.initialize_optimizer(self.params) + if params.resuming: + logger.info("Loading checkpoint %s" % params.checkpoint_path) + self.restore_checkpoint(params.checkpoint_path) + elif params.resuming is False and params.pretrained: + logger.info( + "Starting from pretrained model at %s" % params.pretrained_ckpt_path + ) + self.restore_checkpoint(params.pretrained_ckpt_path) + self.iters = 0 + self.startEpoch = 0 + else: + pass + + self.initialize_scheduler(self.params) + + def initialize_data(self, params): + if params.tie_batches: + in_rank = 0 + else: + in_rank = self.global_rank + if self.log_to_screen: + print(f"Initializing data on rank {self.global_rank}") + + if self.params.model_type == "fno": + if params.mode == "train": + params.masking = ((params.nx, params.ny), params.mask_ratio) + ( + self.train_data_loader, + self.train_dataset, + self.train_sampler, + ) = PoisHelmDatasetLoader( + params, params.train_path, dist.is_initialized(), train=True + ) + ( + self.valid_data_loader, + self.valid_dataset, + self.valid_sampler, + ) = PoisHelmDatasetLoader( + params, params.val_path, dist.is_initialized(), train=False + ) + elif self.params.model_type == "vmae": + params.masking = ( + ( + params.n_steps, + params.input_size // params.patch_size, + params.input_size // params.patch_size, + ), + params.mask_ratio, + ) + + ( + self.train_data_loader, + self.train_dataset, + self.train_sampler, + ) = MixedDatasetLoader( + params, + params.train_data_paths, + dist.is_initialized(), + split="train", + rank=in_rank, + train_offset=self.params.embedding_offset, + ) + self.valid_data_loader, self.valid_dataset, _ = MixedDatasetLoader( + params, + params.valid_data_paths, + dist.is_initialized(), + split="val", + rank=in_rank, + ) + if dist.is_initialized(): + self.train_sampler.set_epoch(0) + + def initialize_model(self, params): + if self.params.model_type == "fno": + if self.params.mode == "train": + self.model = fno(params) + elif self.params.mode == "finetune": + logger.info("Using Build FNO") + self.model = build_fno(params) + else: + raise NotImplementedError("Only support FNO for now") + + if dist.is_initialized(): + self.model = paddle.DataParallel( + self.model, + find_unused_parameters=True, + ) + + print( + f"Model parameter count: {sum([p.numel() for p in self.model.parameters()])}" + ) + + def initialize_optimizer(self, params): + parameters = add_weight_decay(self.model, self.params.weight_decay) + if params.optimizer == "adam": + self.optimizer = optim.AdamW( + parameters=parameters, learning_rate=params.learning_rate + ) + else: + raise ValueError(f"Optimizer {params.optimizer} not supported") + self.gscaler = amp.GradScaler( + enable=(self.mp_type == paddle.float16 and params.enable_amp) + ) + + def initialize_scheduler(self, params): + if params.scheduler_epochs > 0: + sched_epochs = params.scheduler_epochs + else: + sched_epochs = params.max_epochs + if params.scheduler == "cosine": + if self.params.learning_rate < 0: + self.scheduler = paddle.optimizer.lr.CosineAnnealingDecay( + learning_rate=self.optimizer.get_lr(), + last_epoch=(self.startEpoch * params.epoch_size) - 1, + T_max=sched_epochs * params.epoch_size, + eta_min=params.learning_rate / 100, + ) + self.optimizer.set_lr_scheduler(self.scheduler) + else: + k = params.warmup_steps + if (self.startEpoch * params.epoch_size) < k: + warmup = paddle.optimizer.lr.LinearLR( + learning_rate=self.optimizer.get_lr(), + start_factor=0.01, + end_factor=1.0, + total_iters=k, + ) + self.optimizer.set_lr_scheduler(warmup) + decay = paddle.optimizer.lr.CosineAnnealingDecay( + learning_rate=self.optimizer.get_lr(), + eta_min=params.learning_rate / 100, + T_max=sched_epochs, + ) + self.optimizer.set_lr_scheduler(decay) + + raise NotImplementedError("Scheduler not implemented yet") + else: + self.scheduler = paddle.optimizer.lr.CosineAnnealingDecay( + learning_rate=self.optimizer.get_lr(), T_max=sched_epochs + ) + self.optimizer.set_lr_scheduler(self.scheduler) + elif params.scheduler == "reducelr": + self.scheduler = paddle.optimizer.lr.ReduceOnPlateau( + learning_rate=self.optimizer.get_lr(), + mode="min", + patience=params.patience, + verbose=True, + min_lr=1e-3 * 1e-5, + factor=0.2, + ) + self.optimizer.set_lr_scheduler(self.scheduler) + else: + self.scheduler = None + + def save_checkpoint(self, checkpoint_path, model=None): + """Save model and optimizer to checkpoint""" + if not model: + model = self.model + + paddle.save( + { + "iters": self.epoch * self.params.epoch_size, + "epoch": self.epoch, + "model_state": model.state_dict(), + "optimizer_state_dict": self.optimizer.state_dict(), + }, + checkpoint_path, + ) + + def restore_checkpoint(self, checkpoint_path): + """Load model/opt from path""" + checkpoint = paddle.load(checkpoint_path) + try: + self.model.set_state_dict(checkpoint["model_state"]) + except: # noqa + new_state_dict = OrderedDict() + for key, val in checkpoint["model_state"].items(): + name = key[7:] + new_state_dict[name] = val + self.model.set_state_dict(new_state_dict) + + if self.params.resuming: + self.optimizer.set_state_dict(checkpoint["optimizer_state_dict"]) + self.startEpoch = checkpoint["epoch"] + self.epoch = self.startEpoch + self.iters = checkpoint["iters"] + else: + self.iters = 0 + checkpoint = None + self.model = self.model + + def train_one_epoch(self): + self.model.train() + self.epoch += 1 + logs = { + "train_rmse": paddle.zeros([1]), + "train_nrmse": paddle.zeros([1]), + "train_l1": paddle.zeros([1]), + "train_l2": paddle.zeros([1]), + "train_loss": paddle.zeros([1]), + } + steps = 0 + n = len(self.train_data_loader) + for batch_idx, data in enumerate(self.train_data_loader): + steps += 1 + if len(data) == 3: + inp, label, mask = map(lambda x: x, data) + if sum(self.params.blur) > 0: + inp_blur = [] + for _inp in inp: + sigma = random.uniform(*self.params.blur) + _kernel = min( + int((sigma * 4 + 1) / 2) * 2 + 1, + (_inp.shape[2] // 2) * 2 - 1, + ) + if _kernel >= 2: + _inp = gaussian_blur( + _inp, kernel_size=[_kernel, _kernel], sigma=sigma + ) + inp_blur.append(_inp) + inp_blur = paddle.stack(inp_blur, axis=0) + else: + inp_blur = inp.detach().clone() + else: + inp, label = map(lambda x: x, data) + mask = None + inp_blur = inp.detach().clone() + if len(inp.shape) == 5: + inp = rearrange(inp, "b t c h w -> t b c h w") + inp_blur = rearrange(inp_blur, "b t c h w -> t b c h w") + + self.model.require_backward_grad_sync = ( + 1 + batch_idx + ) % self.params.accum_grad == 0 + with amp.auto_cast(self.params.enable_amp, dtype=self.mp_type): + if self.params.mode == "train": + output = self.model(inp_blur, mask) + elif self.params.mode == "finetune": + output = self.model(inp) + else: + raise ValueError(f"Invalid mode {self.params.mode}") + + if self.params.mode == "train": + label = inp + else: + label = label + + if len(label.shape) == 5: + labels = rearrange( + label.permute(1, 2, 0, 3, 4), + "b c t (h p1) (w p2) -> b (t h w) (p1 p2 c)", + h=label.shape[3] // self.model.patch_size, + w=label.shape[4] // self.model.patch_size, + p1=self.model.patch_size, + p2=self.model.patch_size, + ) + if mask is not None: + mask = mask.flatten(1).to(paddle.bool) + if mask.sum() == 0: + labels = labels[~mask] + else: + labels = labels[mask] + labels = labels.reshape( + label.shape[1], -1, label.shape[2] * self.model.patch_size**2 + ) + spatial_dims = tuple(range(output.ndim))[2:] + + residuals = output - labels + inp_norm = 1e-7 + labels.pow(2).mean(spatial_dims, keepdim=True) + raw_loss = (residuals).pow(2).mean( + spatial_dims, keepdim=True + ) / inp_norm + elif len(label.shape) == 4: + spatial_dims = tuple(range(output.ndim))[1:] + if mask is not None: + labels = label * (1 - mask) + output = output * (1 - mask) + else: + labels = label + + residuals = output - labels + raw_loss = (residuals) ** 2 + loss = raw_loss.sum() / output.shape[0] / self.params.accum_grad + + with paddle.no_grad(): + logs["train_l1"] += F.l1_loss(output, labels) + logs["train_nrmse"] += self.train_loss(output, labels) + logs["train_rmse"] += ( + residuals.pow(2).mean(spatial_dims).sqrt().mean() + ) + logs["train_l2"] += l2_err(output, labels, spatial_dims) + logs["train_loss"] += loss + log_nrmse = raw_loss.sqrt().mean() + self.gscaler.scale(loss).backward() + + if self.debug_grad and self.model.require_backward_grad_sync: + with paddle.no_grad(): + self.gscaler.unscale_(self.optimizer) + grad_diff = grad_norm(self.model.parameters()) + porig = [p.clone() for p in self.model.parameters()] + + if self.model.require_backward_grad_sync: + self.gscaler.unscale_(self.optimizer) + paddle.nn.utils.clip_grad_norm_(self.model.parameters(), 1) + self.gscaler.step(self.optimizer) + self.gscaler.update() + if self.debug_grad: + if self.global_rank == 0: + pdiff = param_diff(self.model.parameters(), porig) + print( + "grad_norm", + grad_diff, + "last_step_size", + pdiff, + "loss", + loss.item(), + "data_shape", + label.shape, + ) + self.optimizer.clear_gradients(set_to_zero=False) + if self.scheduler is not None: + self.scheduler.step() + if ( + self.log_to_screen + and batch_idx % self.params.log_interval == 0 + and self.global_rank == 0 + ): + logger.info( + f"Epoch {self.epoch}/{self.params.max_epochs} Batch {batch_idx+1}/{len(self.train_data_loader)} Train Loss {log_nrmse.item():.2e}" + ) + logs = {k: v / steps for k, v in logs.items()} + if dist.is_initialized(): + for key in sorted(logs.keys()): + dist.all_reduce(logs[key].detach()) + logs[key] = float(logs[key] / dist.get_world_size()) + + self.iters += steps + if self.global_rank == 0: + logs["iters"] = self.iters + logs["parameter norm"] = param_norm(self.model.parameters()) + logs["train_nrmse"] = logs["train_nrmse"].item() / n + logs["train_l2"] = logs["train_l2"].item() / n + return logs + + def single_dset_val(self, subset, logs, cutoff=40): + if self.params.use_ddp: + temp_loader = paddle.io.DataLoader( + subset, + batch_size=self.params.batch_size, + num_workers=self.params.num_data_workers, + ) + else: + temp_loader = paddle.io.DataLoader( + subset, + batch_size=self.params.batch_size, + num_workers=self.params.num_data_workers, + shuffle=True, + drop_last=True, + ) + count = 0 + for _, data in enumerate(temp_loader): + if count > cutoff: + del temp_loader + break + count += 1 + input = data[0] + label = data[1] if len(data) > 1 else None + + # unsupervised pretrain + if self.params.mode == "train": + label = input + if len(input.shape) == 5: + input = rearrange(input, "b t c h w -> t b c h w") + else: + pass + if self.params.mode == "train": + output = self.model(input, None) + elif self.params.mode == "finetune": + output = self.model(input) + if self.params.model_type == "fno": + spatial_dims = tuple(range(output.ndim))[1:] + elif self.params.model_type == "vmae": + spatial_dims = tuple(range(output.ndim))[2:] + else: + raise NotImplementedError + + logs["valid_nrmse"] += self.train_loss(output, label) + logs["valid_l2"] += l2_err(output, label, spatial_dims).item() + else: + del temp_loader + logs["valid_nrmse"] = logs["valid_nrmse"].item() / count + logs["valid_l2"] = logs["valid_l2"].item() / count + return logs + + def validate_one_epoch(self, full=False): + """ + Validates - for each batch just use a small subset to make it easier. + + Note: need to split datasets for meaningful metrics, but TBD. + """ + self.model.eval() + cutoff = get_cutoff(full=full) + with paddle.no_grad(): + with amp.auto_cast(enable=False, dtype=self.mp_type): + logs = { + "valid_nrmse": paddle.zeros([1]), + "valid_l2": paddle.zeros([1]), + } + if hasattr(self.valid_dataset, "sub_dsets"): + for subset_group in self.valid_dataset.sub_dsets: + for subset in subset_group.get_per_file_dsets(): + logs = self.single_dset_val(subset, logs, cutoff) + else: + logs = self.single_dset_val(self.valid_dataset, logs, cutoff) + + if dist.is_initialized(): + for key in sorted(logs.keys()): + dist.all_reduce(logs[key].detach()) + logs[key] = float(logs[key].item() / dist.get_world_size()) + if "rmse" in key: + logs[key] = logs[key] + return logs + + def train(self): + logger.info( + f"iters per epoch = {len(self.train_data_loader)}, samples number = {len(self.train_dataset)}, batch size = {self.params.batch_size}, total batches = {len(self.train_data_loader)*self.params.batch_size}" + ) + best_loss = 1.0 + + for epoch in range(self.startEpoch, self.params.max_epochs): + if dist.is_initialized(): + self.train_sampler.set_epoch(epoch) + train_logs = self.train_one_epoch() + if epoch == self.params.max_epochs - 1: + valid_logs = self.validate_one_epoch(True) + else: + valid_logs = self.validate_one_epoch() + train_logs.update(valid_logs) + gc.collect() + paddle.device.cuda.empty_cache() + if epoch % self.params.checkpoint_save_interval == 0: + save_dir = self.params.checkpoint_path.replace( + "ckpt", f"ckpt_epoch_{epoch}" + ) + logger.info("saving checkpoint : save_dir") + self.save_checkpoint(save_dir) + logger.info( + f"Train loss: {train_logs['train_nrmse']:.2e}, Valid Loss: {valid_logs['valid_nrmse']:.2e}\n" + ) + + if valid_logs["valid_nrmse"] < best_loss: + best_loss = valid_logs["valid_nrmse"] + save_dir = self.output_dir + "/best.pt" + logger.info( + f"saving best in epoch {epoch}, [valid = {best_loss:.2e}] checkpoint : {save_dir}" + ) + self.save_checkpoint(save_dir) + + save_dir = self.params.checkpoint_path.replace("ckpt", "ckpt_last") + logger.info(f"saving checkpoint : {save_dir}") + self.save_checkpoint(save_dir) + + +def train(config: DictConfig): + params = YAML() + params._config_name = config.config + params.params = {} + params.mode = config.mode + params.use_ddp = config.use_ddp + for key, val in config.train_config[config.config].items(): + val = None if val == "None" else val + params.params[key] = val + params.__setattr__(key, val) + + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + global_rank = int(os.environ.get("RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + if config.use_ddp: + dist.init_process_group("nccl") + + device = f"gpu:{local_rank}" if paddle.device.cuda.device_count() >= 1 else "cpu" + paddle.set_device(device) + ppsci.utils.misc.set_random_seed(config.seed) + + params.batch_size = int(params.batch_size // world_size) + params.startEpoch = 0 + exp_dir = os.path.join(params.exp_dir, config.config, str(config.run_name)) + + params.old_exp_dir = exp_dir + params.experiment_dir = os.path.abspath(exp_dir) + params.checkpoint_path = os.path.join(exp_dir, "training_checkpoints/ckpt.tar") + params.best_checkpoint_path = os.path.join( + exp_dir, "training_checkpoints/best_ckpt.tar" + ) + params.old_checkpoint_path = os.path.join( + params.old_exp_dir, "training_checkpoints/best_ckpt.tar" + ) + + if global_rank == 0 and not os.path.isdir(exp_dir): + os.makedirs(exp_dir) + os.makedirs(os.path.join(exp_dir, "training_checkpoints/")) + params.resuming = True if os.path.isfile(params.checkpoint_path) else False + params.name = str(config.run_name) + params.log_to_screen = (global_rank == 0) and params.log_to_screen + + trainer = Trainer( + params, + global_rank, + local_rank, + device, + config.output_dir, + sweep_id=config.sweep_id, + ) + if config.sweep_id and trainer.global_rank == 0: + print(config.sweep_id, trainer.params.entity, trainer.params.project) + else: + trainer.train() + + +@paddle.no_grad() +def inference(config): + config = config.infer_config + if config.ckpt_path: + save_dir = os.path.join( + "/".join(config.ckpt_path.split("/")[:-1]), "results_icl" + ) + else: + basedir = os.path.join("exp", config["log"]["logdir"]) + save_dir = os.path.join(basedir, "results_icl") + os.makedirs(save_dir, exist_ok=True) + save_path = os.path.join( + save_dir, + "fno-prediction-demo%d.pt" % (config.num_demos if config.num_demos else 0), + ) + + params = Namespace(**config) + if not hasattr(params, "n_demos"): + params.n_demos = 0 + if "batch_size" in config: + params.local_valid_batch_size = config["batch_size"] + else: + params.local_valid_batch_size = 1 + dataloader, dataset, sampler = PoisHelmDatasetLoader( + params, params.test_path, dist.is_initialized(), train=False + ) + if config.num_demos is not None and config.num_demos != 0: + params.subsample = 1 + params.local_valid_batch_size = config.num_demos + dataloader_icl, dataset_icl, _ = PoisHelmDatasetLoader( + params, params.train_path, dist.is_initialized(), train=False + ) + input_demos, target_demos = next(iter(dataloader_icl)) + input_demos = input_demos + target_demos = target_demos + + model = build_fno(params) + + if config.ckpt_path: + checkpoint = paddle.load(config.ckpt_path) + try: + model.set_state_dict(checkpoint["model_state"]) + except: # noqa + new_state_dict = OrderedDict() + for key, val in checkpoint["model_state"].items(): + name = key + if "module" in name: + name = name[7:] + new_state_dict[name] = val + state = model.state_dict() + pretrained_dict = { + k: v + for k, v in new_state_dict.items() + if k in state and state[k].size() == new_state_dict[k].size() + } + state.update(pretrained_dict) + + unload_keys = [k for k in new_state_dict.keys() if k not in pretrained_dict] + if len(unload_keys) > 0: + import warnings + + warnings.warn( + "Warning: unload keys during restoring checkpoint: %s" + % (str(unload_keys)) + ) + + model.eval() + truth_list = [] + pred_list = [] + losses = [] + losses_normalized = [] + pbar = tqdm(dataloader, total=len(dataloader)) + for inputs, targets in pbar: + inputs, targets = inputs, targets + if config.num_demos is None or config.num_demos == 0: + u = model(inputs) + else: + model.target = targets + u = model.forward_icl( + inputs, input_demos, target_demos, use_tqdm=config.tqdm + ) + + data_loss = l2_err(u.detach(), targets.detach()) + losses.append(data_loss.item()) + data_loss_normalized = l2_err( + u.detach() / paddle.abs(u).max(), + targets.detach() / paddle.abs(targets).max(), + ) + losses_normalized.append(data_loss_normalized.item()) + truth_list.append(targets.cpu()) + pred_list.append(u.cpu()) + + slope, intercept, r, p, se = linregress( + paddle.concat(pred_list, axis=0).view([-1]).numpy(), + paddle.concat(truth_list, axis=0).view([-1]).numpy(), + ) + print( + "L2:", + np.mean(losses), + "L2 (normalized)", + np.mean(losses_normalized), + "R2:", + r, + "Slope:", + slope, + ) + truth_arr = paddle.concat(truth_list, axis=0) + pred_arr = paddle.concat(pred_list, axis=0) + paddle.save( + { + "truth": truth_arr, + "pred": pred_arr, + "rmse": np.mean(losses), + "rmse_normalized": np.mean(losses_normalized), + "r2": r, + "slope": slope, + }, + save_path, + ) + + +@hydra.main( + version_base=None, + config_path="./config", + config_name="data_efficient_nopt_fno_poisson", +) +def main(config: DictConfig): + logger.init_logger("ppsci", osp.join(config.logdir, f"{config.mode}.log"), "info") + if config.mode == "train" or config.mode == "finetune": + train(config) + elif config.mode == "infer": + inference(config) + else: + raise ValueError( + f"config.mode should in ['train', 'infer'], but got '{config.mode}'" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/data_efficient_nopt/requirements.txt b/examples/data_efficient_nopt/requirements.txt new file mode 100644 index 000000000..7481519bf --- /dev/null +++ b/examples/data_efficient_nopt/requirements.txt @@ -0,0 +1,3 @@ +ruamel.yaml==0.17.32 +ruamel.yaml.clib==0.2.7 +zarr==2.16.1 diff --git a/mkdocs.yml b/mkdocs.yml index 3808615cb..5851641c3 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -54,6 +54,7 @@ nav: - NeuralOperator: zh/examples/neuraloperator.md - Brusselator3D: zh/examples/brusselator3d.md - Transformer4SR: zh/examples/transformer4sr.md + - DataEffcientNopt: zh/examples/data_efficient_nopt.md - 技术科学(AI for Technology): - 流体: - Catheter: zh/examples/catheter.md diff --git a/ppsci/arch/data_efficient_nopt_model.py b/ppsci/arch/data_efficient_nopt_model.py new file mode 100644 index 000000000..51bbb19e3 --- /dev/null +++ b/ppsci/arch/data_efficient_nopt_model.py @@ -0,0 +1,875 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# refs: https://github.com/delta-lab-ai/data_efficient_nopt + +from typing import List + +import paddle +import paddle.nn as nn +import paddle.tensor as Tensor + +from ppsci.arch.activation import act_func_dict + +FULL_MODE_CUTOFF = 999999999999 +NORMAL_MODE_CUTOFF = 40 + + +def get_cutoff(full): + return FULL_MODE_CUTOFF if full else NORMAL_MODE_CUTOFF + + +def l2_err(pred, target, spatial_dim=(-1, -2, -3)): + x = paddle.sum((pred - target) ** 2, axis=spatial_dim) / paddle.sum( + target**2, axis=spatial_dim + ) + x = paddle.sqrt(x) + return paddle.mean(x) + + +def grad_norm(parameters): + with paddle.no_grad(): + total_norm = 0 + for p in parameters: + if p.grad is not None: + total_norm += p.grad.data.pow(2).sum().item() + return total_norm**0.5 + + +def grad_clone(parameters): + with paddle.no_grad(): + clones = [] + for p in parameters: + if p.grad is not None: + clones.append(p.grad.clone()) + else: + clones.append(paddle.zeros_like(p)) + return clones + + +def param_norm(parameters): + with paddle.no_grad(): + total_norm = 0 + for p in parameters: + total_norm += p.pow(2).sum().item() + return total_norm**0.5 + + +def param_diff(params1, params2): + with paddle.no_grad(): + total_norm = 0 + for p1, p2 in zip(params1, params2): + total_norm += (p2 - p1).pow(2).sum().item() + return total_norm**0.5 + + +def add_weight_decay(model, weight_decay=1e-5, inner_lr=1e-3, skip_list=()): + decay = [] + no_decay = [] + for name, param in model.named_parameters(): + if param.stop_gradient: + continue + if len(param.squeeze().shape) <= 1 or name in skip_list: + no_decay.append(param) + else: + decay.append(param) + return [ + { + "params": no_decay, + "weight_decay": 0.0, + }, + {"params": decay, "weight_decay": weight_decay}, + ] + + +def compl_mul2d_v2(a: paddle.Tensor, b: paddle.Tensor) -> paddle.Tensor: + tmp = paddle.einsum("bixys,ioxyt->stboxy", a, b) + return paddle.stack( + [ + tmp[0, 0, :, :, :, :] - tmp[1, 1, :, :, :, :], + tmp[1, 0, :, :, :, :] + tmp[0, 1, :, :, :, :], + ], + axis=-1, + ) + + +class SpectralConv2dV2(nn.Layer): + def __init__(self, in_channels, out_channels, modes1, modes2): + super(SpectralConv2dV2, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.modes1 = modes1 + self.modes2 = modes2 + self.scale = 1 / (in_channels * out_channels) + self.weights1 = paddle.base.framework.EagerParamBase.from_tensor( + self.scale + * paddle.rand([in_channels, out_channels, self.modes1, self.modes2, 2]) + ) + self.weights2 = paddle.base.framework.EagerParamBase.from_tensor( + self.scale + * paddle.rand([in_channels, out_channels, self.modes1, self.modes2, 2]) + ) + + def forward(self, x: paddle.Tensor): + size_0 = x.shape[-2] + size_1 = x.shape[-1] + batchsize = x.shape[0] + x_ft = paddle.fft.rfft2(x.astype(paddle.float32), axes=(-2, -1), norm="ortho") + x_ft = paddle.as_real(x_ft) + + out_ft = paddle.zeros( + [batchsize, self.out_channels, size_0, size_1 // 2 + 1, 2] + ) + out_ft[:, :, : self.modes1, : self.modes2] = compl_mul2d_v2( + x_ft[:, :, : self.modes1, : self.modes2], self.weights1 + ) + out_ft[:, :, -self.modes1 :, : self.modes2] = compl_mul2d_v2( + x_ft[:, :, -self.modes1 :, : self.modes2], self.weights2 + ) + out_ft = paddle.as_complex(out_ft) + + x = paddle.fft.irfft2(out_ft, axes=(-2, -1), norm="ortho", s=(size_0, size_1)) + + return x + + +class FNN2d_Backbone(nn.Layer): + def __init__( + self, + modes1, + modes2, + width=64, + layers=None, + in_dim=3, + dropout=0, + activation="tanh", + ): + super(FNN2d_Backbone, self).__init__() + + """ + The backbone network. It contains 4 layers of the Fourier layer. + 1. Lift the input to the desire channel dimension by self.fc0 . + 2. 4 layers of the integral operators u' = (W + K)(u). + W defined by self.w; K defined by self.conv . + + input: the solution of the coefficient function and locations (a(x, y), x, y) + input shape: (batchsize, c=3, x=s, y=s) + output: the feature + output shape: (batchsize, c=width, x=s, y=s) + """ + + self.modes1 = modes1 + self.modes2 = modes2 + self.width = width + if layers is None: + self.layers = [width] * 4 + else: + self.layers = layers + self.fc0 = nn.Linear(in_dim, self.layers[0]) + + self.sp_convs = nn.LayerList( + [ + SpectralConv2dV2(in_size, out_size, mode1_num, mode2_num) + for in_size, out_size, mode1_num, mode2_num in zip( + self.layers, self.layers[1:], self.modes1, self.modes2 + ) + ] + ) + + self.dropout = nn.Dropout(p=dropout) + + self.ws = nn.LayerList( + [ + nn.Conv1D(in_size, out_size, 1) + for in_size, out_size in zip(self.layers, self.layers[1:]) + ] + ) + + self.activation = act_func_dict[activation] + + def forward(self, x): + """ + (b,c,h,w) -> (b,1,h,w) + """ + length = len(self.ws) + batchsize = x.shape[0] + size_x, size_y = x.shape[2], x.shape[3] + + x = x.transpose([0, 2, 3, 1]) + x = self.fc0(x) + + x = x.transpose([0, 3, 1, 2]) + + for i, (speconv, w) in enumerate(zip(self.sp_convs, self.ws)): + x1 = speconv(x) + x2 = w(x.view([batchsize, self.layers[i], -1])).view( + [batchsize, self.layers[i + 1], size_x, size_y] + ) + x = x1 + x2 + if i != length - 1: + x = self.activation(x) + x = self.dropout(x) + + return x + + +class FNN2d(nn.Layer): + def __init__( + self, + modes1, + modes2, + width=64, + fc_dim=128, + layers=None, + in_dim=3, + out_dim=1, + dropout=0, + activation="tanh", + mean_constraint=False, + ): + super(FNN2d, self).__init__() + + """ + The overall network. The backbone contains 4 layers of the Fourier layer. + 1. Backbone: + 1) Lift the input to the desire channel dimension by self.fc0 . + 2) 4 layers of the integral operators u' = (W + K)(u). + W defined by self.w; K defined by self.conv . + 2. Project from the channel space to the output space by self.fc1 and self.fc2 . + + input: the solution of the coefficient function and locations (a(x, y), x, y) + input shape: (batchsize, c=3, x=s, y=s) + output: the solution + output shape: (batchsize, c=1, x=s, y=s) + """ + + self.backbone = FNN2d_Backbone( + modes1, modes2, width, layers, in_dim, dropout, activation + ) + self.dropout = nn.Dropout(p=dropout) + self.fc1 = nn.Linear(layers[-1], fc_dim) + self.fc2 = nn.Linear(fc_dim, out_dim) + self.activation = act_func_dict[activation] + self.mean_constraint = mean_constraint + + def forward(self, x): + """ + (b,c,h,w) -> (b,1,h,w) + """ + x = self.backbone(x) + x = x.transpose([0, 2, 3, 1]) + x = self.fc1(x) + x = self.activation(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + x = x.transpose([0, 3, 1, 2]) + + if self.mean_constraint: + x = x - paddle.mean(x, axis=(-2, -1), keepdim=True) + + return x + + def forward_icl(self, x, demo_xs, demo_ys, use_tqdm=False): + """ + x: B, C, H, W + demo_xs: J, C, H, W + demo_ys: J, H, W + """ + C_out = 1 + B, C, H, W = x.shape + + repeat = 1 + p = 0.0 + sigma_range = [0, 0] + x_aug = [] + demo_xs_aug = [] + for _ in range(repeat): + if sum(sigma_range) > 0: + import random + + sigma = random.uniform(*sigma_range) + _kernel = min( + int((sigma * 4 + 1) / 2) * 2 + 1, (x.shape[1] // 2) * 2 - 1 + ) + mask = paddle.nn.functional.dropout(paddle.ones([1, C, H, W]), p=p) + if sum(sigma_range) > 0: + _x_aug = gaussian_blur( + x.clone(), kernel_size=[_kernel, _kernel], sigma=sigma + ) + else: + _x_aug = x.clone() + _x_aug = _x_aug * mask + x_aug.append(_x_aug) + _demo_xs_aug = [] + if sum(sigma_range) > 0: + _demo_xs_aug = gaussian_blur( + demo_xs.clone(), kernel_size=[_kernel, _kernel], sigma=sigma + ) + else: + _demo_xs_aug = demo_xs.clone() + _demo_xs_aug = _demo_xs_aug * mask + demo_xs_aug.append(_demo_xs_aug) + x_aug = paddle.stack(x_aug, axis=0) + demo_xs_aug = paddle.stack(demo_xs_aug, axis=0) + + J = demo_xs.shape[0] + pred0 = self.forward(x) + pred = paddle.stack([self.forward(_x) for _x in x_aug], axis=-1) + C = pred.shape[-1] + demo_pred = [] + idx = 0 + for _demo_xs_aug in demo_xs_aug: + idx = 0 + _demo_pred = [] + while idx < _demo_xs_aug.shape[0]: + _x = _demo_xs_aug[idx : idx + B] + _pred = self.forward(_x) + _demo_pred.append(_pred) + idx += _x.shape[0] + demo_pred.append(paddle.concat(_demo_pred, axis=0)) + demo_pred = paddle.stack(demo_pred, axis=-1) + + demo_pred_flat = demo_pred.view([1, -1, C]) + y_nn = paddle.zeros([B, C_out, H, W]) + stds_nn = paddle.zeros([B, 1, H, W]) + batch_b = 1 + _b = 0 + batch_h = 64 + _h = 0 + batch_w = 64 + _w = 0 + + topk = int(20 * (J**0.5)) + pbar = None + while _b < B: + _h = 0 + while _h < H: + _w = 0 + while _w < W: + if pbar is not None: + pbar.set_description("_b %d, _h %d, _w %d" % (_b, _h, _w)) + pbar.update(1) + pred_flat = pred[ + _b : _b + batch_b, :, _h : _h + batch_h, _w : _w + batch_w + ] + __b, _, __h, __w, _ = pred_flat.shape + pred_flat = pred_flat.view([-1, 1, C]) + + gap = paddle.linalg.norm( + (pred_flat - demo_pred_flat).pow(2) / pred_flat.pow(2), axis=-1 + ) + gap_re = gap.view([__b, __h, __w, -1]) + index = paddle.argsort(paddle.abs(gap_re), -1)[:, :, :, :topk] + _y_nn = paddle.stack( + [ + paddle.take_along_axis( + demo_ys.view([-1, C_out]), + index[:, :, :, _k].view([-1, 1]), + axis=0, + ).view([__b, C_out, __h, __w]) + for _k in range(topk) + ], + -1, + ) + + y_nn[ + _b : _b + batch_b, :, _h : _h + batch_h, _w : _w + batch_w + ] = _y_nn.mean(-1) + stds_nn[ + _b : _b + batch_b, :, _h : _h + batch_h, _w : _w + batch_w + ] = paddle.abs(_y_nn.std(-1) / _y_nn.mean(-1)) + + _w += batch_w + _h += batch_h + _b += batch_b + + mask = (stds_nn < stds_nn.mean()).astype(paddle.float32) + return mask * y_nn + (1 - mask) * pred0 + + +class FNN2d_FewShot_Baseline(nn.Layer): + def __init__( + self, + modes1, + modes2, + width=64, + fc_dim=128, + layers=None, + in_dim=3, + out_dim=1, + dropout=0, + activation="tanh", + mean_constraint=False, + n_demos=7, + ): + super(FNN2d_FewShot_Baseline, self).__init__() + + """ + The overall network. The backbone contains 4 layers of the Fourier layer. + 1. Backbone: + 1) Lift the input to the desire channel dimension by self.fc0 . + 2) 4 layers of the integral operators u' = (W + K)(u). + W defined by self.w; K defined by self.conv . + 2. Project from the channel space to the output space by self.fc1 and self.fc2 . + + input: the solution of the coefficient function and locations (a(x, y), x, y) + input shape: (batchsize, c=3, x=s, y=s) + output: the solution + output shape: (batchsize, c=1, x=s, y=s) + """ + self.in_dim = in_dim + self.C_fno = layers[-1] + self.fc_dim = fc_dim + self.out_dim = out_dim + self.backbone = FNN2d_Backbone( + modes1, modes2, width, layers, in_dim, dropout, activation + ) + self.dropout = nn.Dropout(p=dropout) + self.num_heads = 8 + self.fc1 = nn.Linear( + layers[-1] * (n_demos + 1) + out_dim * n_demos * self.num_heads, fc_dim + ) + self.fc2 = nn.Linear(fc_dim, out_dim) + self.activation = act_func_dict[activation] + self.mean_constraint = mean_constraint + self.n_demos = n_demos + + def forward(self, demo_XY_query_x): + """ + demo_XY_query_x: (b, J*c + J*1 + c, h, w) + """ + demo_X, demo_Y, query_x = [], [], None + B = len(demo_XY_query_x) + C = self.in_dim + J = (demo_XY_query_x.shape[1] - C) // (C + 1) + H, W = demo_XY_query_x.shape[-2:] + query_x = demo_XY_query_x[:, -C:] + demo_X = demo_XY_query_x[:, : J * C] + demo_Y = demo_XY_query_x[:, J * C : -C] + """ + demo_X: (b, J*c, h, w) + demo_Y: (b, J*1, h, w) + query_x: (b, c, h, w) + -> (b,1,h,w) + """ + query_features = self.backbone(query_x).transpose([0, 2, 3, 1]) + demo_features = ( + self.backbone(demo_X.view([B, J, C, H, W]).view([B * J, C, H, W])) + .view([B, J * self.C_fno, H, W]) + .transpose([0, 2, 3, 1]) + ) + + x = paddle.stack( + [ + paddle.concat([query_features[_b], demo_features[_b]], axis=-1) + for _b in range(B) + ], + axis=0, + ) + x = paddle.stack( + [ + paddle.concat( + [ + x[_b], + demo_Y[_b].repeat([self.num_heads, 1, 1]).transpose([1, 2, 0]), + ], + axis=-1, + ) + for _b in range(B) + ], + axis=0, + ) + x = self.fc1(x) + x = self.activation(x) + x = self.dropout(x) + + x = self.fc2(x) + x = self.dropout(x) + + x = x.transpose([0, 3, 1, 2]) + + if self.mean_constraint: + x = x - paddle.mean(x, axis=(-2, -1), keepdim=True) + + return x + + +class FNN2d_FewShot_Spatial_v2(nn.Layer): + def __init__( + self, + modes1, + modes2, + width=64, + fc_dim=128, + layers=None, + in_dim=3, + out_dim=1, + dropout=0, + activation="tanh", + mean_constraint=False, + n_demos=7, + l_attn=1, + down=1, + win_s=8, + c_attn_hidden=1024, + skip_backbone=False, + ): + super(FNN2d_FewShot_Spatial_v2, self).__init__() + + """ + The overall network. The backbone contains 4 layers of the Fourier layer. + 1. Backbone: + 1) Lift the input to the desire channel dimension by self.fc0 . + 2) 4 layers of the integral operators u' = (W + K)(u). + W defined by self.w; K defined by self.conv . + 2. Project from the channel space to the output space by self.fc1 and self.fc2 . + + input: the solution of the coefficient function and locations (a(x, y), x, y) + input shape: (batchsize, c=3, x=s, y=s) + output: the solution + output shape: (batchsize, c=1, x=s, y=s) + """ + self.in_dim = in_dim + self.C_fno = layers[-1] + self.fc_dim = fc_dim + self.out_dim = out_dim + self.down = down + self.win_s = win_s + self.skip_backbone = skip_backbone + if not skip_backbone: + self.backbone = FNN2d_Backbone( + modes1, modes2, width, layers, in_dim, dropout, activation + ) + else: + self.in_dim = self.C_fno + self.dropout = nn.Dropout(p=dropout) + self.num_heads = 8 + self.l_attn = l_attn + self.fc1 = nn.Linear(self.C_fno, fc_dim) + self.fc2 = nn.Linear(fc_dim, out_dim) + self.activation = act_func_dict[activation] + self.mean_constraint = mean_constraint + self.n_demos = n_demos + + def forward(self, demo_XY_query_x): + """ + demo_XY_query_x: (b, J*c + J*1 + c, h, w) + """ + demo_X, demo_Y, query_x = [], [], None + B = len(demo_XY_query_x) + C = self.in_dim + J = (demo_XY_query_x.shape[1] - C) // (C + 1) + H, W = demo_XY_query_x.shape[-2:] + query_x = demo_XY_query_x[:, :C] + demo_X = demo_XY_query_x[:, C : (J + 1) * C] + demo_Y = demo_XY_query_x[:, (J + 1) * C :] + """ + demo_X: (b, J*c, h, w) + demo_Y: (b, J*1, h, w) + query_x: (b, c, h, w) + -> (b,1,h,w) + """ + + if not self.skip_backbone: + query_features = self.backbone(query_x) + demo_features = self.backbone( + demo_X.view([B, J, C, H, W]).view([B * J, C, H, W]) + ).view(B, J, self.C_fno, H, W) + else: + query_features = query_x + demo_features = demo_X.view([B, J, self.C_fno, H, W]) + + self.query_score = None + self._attn_mats = [None] + + y = self.fc1(query_features.transpose([0, 2, 3, 1])) + y = self.activation(y) + y = self.dropout(y) + y = self.fc2(y) + y = self.dropout(y) + y = y.transpose([0, 3, 1, 2]) + if self.mean_constraint: + y = y - paddle.mean(y, axis=(-2, -1), keepdim=True) + + y_demo = self.fc1(demo_features.transpose([0, 1, 3, 4, 2])) + y_demo = self.activation(y_demo) + y_demo = self.dropout(y_demo) + y_demo = self.fc2(y_demo) + y_demo = self.dropout(y_demo) + y_demo = y_demo.transpose([0, 1, 4, 2, 3]) + if self.mean_constraint: + y_demo = y_demo - paddle.mean(y_demo, axis=(-2, -1), keepdim=True) + + B, C, H, W = y.shape + + y_flat = y.view([-1, 1]) + y_demo_flat = y_demo.view([1, -1]) + gap = y_flat - y_demo_flat + gap_re = gap.view([B, C, H, W, -1]) + + index = paddle.argsort(paddle.abs(gap_re), -1) + + topk = 100 + y_nn = 0 + for _k in range(topk): + y_nn += paddle.take(demo_Y.view([-1, 1]), index[:, :, :, :, _k]) + y_nn /= topk + y = (y + y_nn) / 2 + return y_nn + + +def build_fno(params): + if params.mode_cut > 0: + params.modes1 = [params.mode_cut] * len(params.modes1) + params.modes2 = [params.mode_cut] * len(params.modes2) + + if params.embed_cut > 0: + params.layers = [params.embed_cut] * len(params.layers) + + if params.fc_cut > 0 and params.embed_cut > 0: + params.fc_dim = params.embed_cut * params.fc_cut + + input_dim = params.in_dim + + if params.n_demos == 0: + return FNN2d( + params.modes1, + params.modes2, + layers=params.layers, + fc_dim=params.fc_dim, + in_dim=input_dim, + out_dim=params.out_dim, + dropout=params.dropout, + activation="gelu", + mean_constraint=(params.loss_func == "pde"), + ) + else: + if hasattr(params, "baseline") and params.baseline: + return FNN2d_FewShot_Baseline( + params.modes1, + params.modes2, + layers=params.layers, + fc_dim=params.fc_dim, + in_dim=input_dim, + out_dim=params.out_dim, + dropout=params.dropout, + activation="gelu", + mean_constraint=(params.loss_func == "pde"), + n_demos=params.n_demos, + ) + elif hasattr(params, "spatial") and params.spatial: + return FNN2d_FewShot_Spatial_v2( + params.modes1, + params.modes2, + layers=params.layers, + fc_dim=params.fc_dim, + in_dim=input_dim, + out_dim=params.out_dim, + dropout=params.dropout, + activation="gelu", + mean_constraint=(params.loss_func == "pde"), + n_demos=params.n_demos, + l_attn=params.l_attn, + c_attn_hidden=params.c_attn_hidden, + down=params.down, + win_s=params.win_s, + skip_backbone=( + params.train_path.endswith("npy") + and ("feature_data" in params.train_path) + ), + ) + + +class FNN2d_MAE(nn.Layer): + def __init__( + self, + modes1, + modes2, + width=64, + fc_dim=128, + layers=None, + in_dim=3, + out_dim=1, + dropout=0, + activation="tanh", + mean_constraint=False, + ): + super(FNN2d_MAE, self).__init__() + + """ + The overall network. The backbone contains 4 layers of the Fourier layer. + Backbone: + 1) Lift the input to the desire channel dimension by self.fc0 . + 2) 4 layers of the integral operators u' = (W + K)(u). + W defined by self.w; K defined by self.conv . + + input: the solution of the coefficient function and locations (a(x, y), x, y) + input shape: (batchsize, c=3, x=s, y=s) + """ + self.in_dim = in_dim + self.C_fno = layers[-1] + self.fc_dim = fc_dim + self.out_dim = out_dim + self.encoder = FNN2d_Backbone( + modes1, modes2, width, layers, in_dim, dropout, activation + ) + self.decoder = FNN2d_Backbone( + modes1, + modes2, + width, + layers[:-1] + [in_dim], + self.C_fno, + dropout, + activation, + ) + self.dropout = nn.Dropout(p=dropout) + self.encoder_to_decoder = nn.Linear(self.C_fno, self.C_fno) + self.activation = act_func_dict[activation] + self.mean_constraint = mean_constraint + + def forward(self, x, mask=None): + """ + x: (b, c, h, w) + """ + if mask is None: + x_enc = self.encoder(x) + else: + x_enc = self.encoder(x * mask) + x_enc = self.encoder_to_decoder(x_enc.transpose([0, 2, 3, 1])).transpose( + [0, 3, 1, 2] + ) + x_dec = self.decoder(x_enc) + return x_dec + + +def fno_pretrain(params): + if params.mode_cut > 0: + params.modes1 = [params.mode_cut] * len(params.modes1) + params.modes2 = [params.mode_cut] * len(params.modes2) + + if params.embed_cut > 0: + params.layers = [params.embed_cut] * len(params.layers) + + if params.fc_cut > 0 and params.embed_cut > 0: + params.fc_dim = params.embed_cut * params.fc_cut + + input_dim = params.in_dim + + return FNN2d_MAE( + params.modes1, + params.modes2, + layers=params.layers, + fc_dim=params.fc_dim, + in_dim=input_dim, + out_dim=params.out_dim, + dropout=params.dropout, + activation="gelu", + mean_constraint=(params.loss_func == "pde"), + ) + + +def _cast_squeeze_in(img: Tensor, req_dtypes: List[paddle.dtype]): + need_squeeze = False + if img.ndim < 4: + img = img.unsqueeze(axis=0) + need_squeeze = True + + out_dtype = img.dtype + need_cast = False + if out_dtype not in req_dtypes: + need_cast = True + req_dtype = req_dtypes[0] + img = img.to(req_dtype) + return img, need_cast, need_squeeze, out_dtype + + +def _get_gaussian_kernel1d( + kernel_size: int, sigma: float, dtype: paddle.dtype +) -> Tensor: + ksize_half = (kernel_size - 1) * 0.5 + + x = paddle.linspace(-ksize_half, ksize_half, num=kernel_size, dtype=dtype) + pdf = paddle.exp(-0.5 * (x / sigma).pow(2)) + kernel1d = pdf / pdf.sum() + + return kernel1d + + +def _get_gaussian_kernel2d( + kernel_size: List[int], sigma: List[float], dtype: paddle.dtype +) -> Tensor: + kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0], dtype) + kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1], dtype) + kernel2d = paddle.mm(kernel1d_y[:, None], kernel1d_x[None, :]) + return kernel2d + + +def _cast_squeeze_out( + img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: paddle.dtype +) -> Tensor: + if need_squeeze: + img = img.squeeze(axis=0) + + if need_cast: + if out_dtype in ( + paddle.uint8, + paddle.int8, + paddle.int16, + paddle.int32, + paddle.int64, + ): + img = paddle.round(img) + img = img.to(out_dtype) + + return img + + +def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor: + if sigma is None: + sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size] + + if sigma is not None and not isinstance(sigma, (int, float, list, tuple)): + raise TypeError( + f"sigma should be either float or sequence of floats. Got {type(sigma)}" + ) + if isinstance(sigma, (int, float)): + sigma = [float(sigma), float(sigma)] + if isinstance(sigma, (list, tuple)) and len(sigma) == 1: + sigma = [sigma[0], sigma[0]] + if len(sigma) != 2: + raise ValueError( + f"If sigma is a sequence, its length should be 2. Got {len(sigma)}" + ) + for s in sigma: + if s <= 0.0: + raise ValueError(f"sigma should have positive values. Got {sigma}") + + dtype = img.dtype if paddle.is_floating_point(img) else paddle.float32 + kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype) + kernel = kernel.expand([img.shape[-3], 1, kernel.shape[0], kernel.shape[1]]) + + img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype]) + + padding = [ + kernel_size[0] // 2, + kernel_size[0] // 2, + kernel_size[1] // 2, + kernel_size[1] // 2, + ] + img = paddle.nn.functional.pad(img, padding, mode="reflect") + img = paddle.nn.functional.conv2d(img, kernel, groups=img.shape[-3]) + + img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype) + return img diff --git a/ppsci/data/dataset/data_efficient_nopt_dataset.py b/ppsci/data/dataset/data_efficient_nopt_dataset.py new file mode 100644 index 000000000..8d20d95e3 --- /dev/null +++ b/ppsci/data/dataset/data_efficient_nopt_dataset.py @@ -0,0 +1,953 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# refs: https://github.com/delta-lab-ai/data_efficient_nopt + +import glob +import os +from typing import Iterator +from typing import TypeVar + +import h5py +import numpy as np +import paddle +from paddle.io import DataLoader +from paddle.io import Dataset +from paddle.io import DistributedBatchSampler +from paddle.io import RandomSampler +from paddle.io import Sampler + +from ppsci.utils import logger + +__all__ = [ + "MultisetSampler", +] + +T_co = TypeVar("T_co", covariant=True) +broken_paths = [] + + +class BaseHDF5DirectoryDataset(Dataset): + """ + Base class for data loaders. Returns data in T x B x C x H x W format. + + Note - doesn't currently normalize because the data is on wildly different + scales but probably should. + + Split is provided so I can be lazy and not separate out HDF5 files. + + Takes in path to directory of HDF5 files to construct dset. + + Args: + path (str): Path to directory of HDF5 files + include_string (str): Only include files with this string in name + n_steps (int): Number of steps to include in each sample + dt (int): Time step between samples + split (str): train/val/test split + train_val_test (tuple): Percent of data to use for train/val/test + subname (str): Name to use for dataset + split_level (str): 'sample' or 'file' - whether to split by samples within a file + (useful for data segmented by parameters) or file (mostly INS right now) + """ + + def __init__( + self, + path, + include_string="", + n_steps=1, + dt=1, + split="train", + train_val_test=None, + subname=None, + extra_specific=False, + rollout=1, + ): + super().__init__() + self.path = path + self.split = split + self.extra_specific = extra_specific + if subname is None: + self.subname = path.split("/")[-1] + else: + self.subname = subname + self.dt = dt + self.rollout = rollout + self.n_steps = n_steps + self.include_string = include_string + self.train_val_test = train_val_test + self.partition = {"train": 0, "val": 1, "test": 2}[split] + ( + self.time_index, + self.sample_index, + self.field_names, + self.type, + self.split_level, + ) = self._specifics() + self._get_directory_stats(path) + if self.extra_specific: + self.title = self.more_specific_title(self.type, path, include_string) + else: + self.title = self.type + + def get_name(self, full_name=False): + if full_name: + return self.subname + "_" + self.type + else: + return self.type + + def more_specific_title(self, type, path, include_string): + """ + Override this to add more info to the dataset name + """ + return type + + @staticmethod + def _specifics(): + raise NotImplementedError + + def get_per_file_dsets(self): + if self.split_level == "file" or len(self.files_paths) == 1: + return [self] + else: + sub_dsets = [] + for file in self.files_paths: + subd = self.__class__( + self.path, + file, + n_steps=self.n_steps, + dt=self.dt, + split=self.split, + train_val_test=self.train_val_test, + subname=self.subname, + extra_specific=True, + ) + sub_dsets.append(subd) + return sub_dsets + + def _get_specific_stats(self, f): + raise NotImplementedError + + def _get_specific_bcs(self, f): + raise NotImplementedError + + def _reconstruct_sample(self, file, sample_idx, time_idx, n_steps): + raise NotImplementedError + + def _get_directory_stats(self, path): + self.files_paths = glob.glob(path + "/*.h5") + glob.glob(path + "/*.hdf5") + self.files_paths.sort() + self.n_files = len(self.files_paths) + self.file_steps = [] + self.file_nsteps = [] + self.file_samples = [] + self.split_offsets = [] + self.offsets = [0] + file_paths = [] + for file in self.files_paths: + if len(self.include_string) > 0 and self.include_string not in file: + continue + elif file in broken_paths: + continue + else: + file_paths.append(file) + try: + with h5py.File(file, "r") as _f: + samples, steps = self._get_specific_stats(_f) + if steps - self.n_steps - (self.dt - 1) < 1: + logger.warning( + "WARNING: File {} has {} steps, but n_steps is {}. Setting file steps = max allowable.".format( + file, steps, self.n_steps + ) + ) + file_nsteps = steps - self.dt + else: + file_nsteps = self.n_steps + self.file_nsteps.append(file_nsteps) + self.file_steps.append(steps - file_nsteps - (self.dt - 1)) + if self.split_level == "sample": + partition = self.partition + sample_per_part = np.ceil( + np.absolute(np.array(self.train_val_test) * samples) + ).astype(int) + sample_per_part[2] = max( + samples - sample_per_part[0] - sample_per_part[1], 0 + ) + if self.train_val_test[0] >= 0: + self.split_offsets.append( + self.file_steps[-1] + * sum(sample_per_part[:partition]) + ) + else: + if partition == 0: + self.split_offsets.append( + self.file_steps[-1] * (1 - sum(sample_per_part)) + ) + else: + self.split_offsets.append( + self.file_steps[-1] + * (1 - sum(sample_per_part[partition:])) + ) + split_samples = sample_per_part[partition] + else: + split_samples = samples + self.file_samples.append(split_samples) + self.offsets.append( + self.offsets[-1] + + (steps - file_nsteps - (self.dt - 1)) * split_samples + ) + except: # noqa + logger.warning( + "WARNING: Failed to open file {}. Continuing without it.".format( + file + ) + ) + raise RuntimeError("Failed to open file {}".format(file)) + self.files_paths = file_paths + self.offsets[0] = -1 + self.files = [None for _ in self.files_paths] + self.len = self.offsets[-1] + if self.split_level == "file": + if self.train_val_test is None: + logger.warning( + "WARNING: No train/val/test split specified. Using all data for training." + ) + self.split_offset = 0 + self.len = self.offsets[-1] + else: + total_samples = sum(self.file_samples) + if ( + self.train_val_test[1] * total_samples < 1 + or self.train_val_test[2] * total_samples < 1 + ): + ideal_split_offsets = [ + self.train_val_test[i] * total_samples for i in range(3) + ] + ideal_split_offsets = [ + int(value) if value >= 1 else value + for value in ideal_split_offsets + ] + else: + ideal_split_offsets = [ + int(self.train_val_test[i] * total_samples) for i in range(3) + ] + if ideal_split_offsets[0] > 0: + end_ind = 0 + elif ideal_split_offsets[0] == 0: + ideal_split_offsets[0] = abs(self.train_val_test[0] * total_samples) + assert ideal_split_offsets[0] < 1 and ideal_split_offsets[0] > 0 + end_ind = total_samples - round(sum(ideal_split_offsets[1:])) - 1 + else: + ideal_split_offsets[0] = -ideal_split_offsets[0] + end_ind = ( + total_samples + - round(sum(ideal_split_offsets[1:])) + - ideal_split_offsets[0] + ) + for i in range(self.partition + 1): + run_sum = 0 + start_ind = end_ind + for samples, steps in zip(self.file_samples, self.file_steps): + run_sum += samples + if run_sum <= ideal_split_offsets[i]: + end_ind += round(samples * (steps)) + if run_sum == ideal_split_offsets[i]: + break + else: + end_ind += round( + np.abs((run_sum - samples) - ideal_split_offsets[i]) + * (steps) + ) + break + start_ind, end_ind = int(start_ind), int(end_ind) + self.split_offset = start_ind + self.len = end_ind - start_ind + + def _open_file(self, file_ind): + _file = h5py.File(self.files_paths[file_ind], "r") + self.files[file_ind] = _file + + def __getitem__(self, index): + if self.split_level == "file": + index = index + self.split_offset + file_idx = int(np.searchsorted(self.offsets, index, side="right") - 1) + nsteps = self.file_nsteps[file_idx] + self.rollout - 1 + local_idx = index - max(self.offsets[file_idx], 0) + if self.split_level == "sample": + sample_idx = (local_idx + self.split_offsets[file_idx]) // self.file_steps[ + file_idx + ] + else: + sample_idx = local_idx // self.file_steps[file_idx] + time_idx = local_idx % self.file_steps[file_idx] + + if self.files[file_idx] is None: + self._open_file(file_idx) + + time_idx = ( + time_idx - self.dt if time_idx >= self.file_steps[file_idx] else time_idx + ) + time_idx += nsteps + trajectory = self._reconstruct_sample( + self.files[file_idx], sample_idx, time_idx, nsteps + ) + try: + trajectory = self._reconstruct_sample( + self.files[file_idx], sample_idx, time_idx, nsteps + ) + except: # noqa: + raise RuntimeError( + f"Failed to reconstruct sample for file {self.files_paths[file_idx]} sample {sample_idx} time {time_idx}" + ) + return trajectory[:-1], trajectory[-1] + + def __len__(self): + return self.len + + +class DiffRe2DDataset(BaseHDF5DirectoryDataset): + @staticmethod + def _specifics(): + time_index = 0 + sample_index = None + field_names = ["activator", "inhibitor"] + type = "diffre2d" + split_level = "sample" + return time_index, sample_index, field_names, type, split_level + + def _get_specific_stats(self, f): + samples = list(f.keys()) + steps = f[samples[0]]["data"].shape[0] + return len(samples), steps + + def _get_specific_bcs(self, f): + return [0, 0] + + def _reconstruct_sample(self, file, sample_idx, time_idx, n_steps): + samples = list(file.keys()) + return file[samples[sample_idx]]["data"][ + time_idx - n_steps * self.dt : time_idx + self.dt + ].transpose(0, 3, 1, 2) + + +class IncompNSDataset(BaseHDF5DirectoryDataset): + """ + Order Vx, Vy, "particles" + """ + + @staticmethod + def _specifics(): + time_index = 1 + sample_index = 0 + field_names = ["Vx", "Vy", "particles"] + type = "incompNS" + split_level = "file" + return time_index, sample_index, field_names, type, split_level + + def _get_specific_stats(self, f): + samples = f["velocity"].shape[0] + steps = f["velocity"].shape[1] + return samples, steps + + def _reconstruct_sample(self, file, sample_idx, time_idx, n_steps): + velocity = file["velocity"][ + sample_idx, time_idx - n_steps * self.dt : time_idx + self.dt + ] + particles = file["particles"][ + sample_idx, time_idx - n_steps * self.dt : time_idx + self.dt + ] + comb = np.concatenate([velocity, particles], -1) + return comb.transpose((0, 3, 1, 2)) + + def _get_specific_bcs(self, f): + return [0, 0] + + +class TubeMaskingGenerator: + def __init__(self, input_size, mask_ratio): + assert mask_ratio < 1 and mask_ratio >= 0 + self.mask_ratio = mask_ratio + self.frames, self.height, self.width = input_size + self.num_patches_per_frame = self.height * self.width + self.total_patches = self.frames * self.num_patches_per_frame + self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame) + self.total_masks = self.frames * self.num_masks_per_frame + + def __repr__(self): + repr_str = "Maks: total patches {}, mask patches {}".format( + self.total_patches, self.total_masks + ) + return repr_str + + def __call__(self): + if self.mask_ratio > 0: + mask_per_frame = np.hstack( + [ + np.zeros(self.num_patches_per_frame - self.num_masks_per_frame), + np.ones(self.num_masks_per_frame), + ] + ) + elif self.mask_ratio == 0: + mask_per_frame = np.hstack( + [ + np.zeros(self.num_patches_per_frame - self.num_masks_per_frame), + ] + ) + np.random.shuffle(mask_per_frame) + mask = np.tile(mask_per_frame, (self.frames, 1)).flatten() + return mask + + +class MaskingGenerator: + def __init__(self, input_size, mask_ratio): + assert mask_ratio < 1 and mask_ratio >= 0 + self.height, self.width = input_size + self.mask_ratio = mask_ratio + self.frames = 1 + self.num_patches_per_frame = self.height * self.width + self.total_patches = self.frames * self.num_patches_per_frame + self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame) + self.total_masks = self.frames * self.num_masks_per_frame + + def __repr__(self): + repr_str = "Maks: total patches {}, mask patches {}".format( + self.total_patches, self.total_masks + ) + return repr_str + + def __call__(self): + if self.mask_ratio > 0: + mask_per_frame = np.hstack( + [ + np.zeros(self.num_masks_per_frame), + np.ones(self.num_patches_per_frame - self.num_masks_per_frame), + ] + ) + else: + mask_per_frame = np.hstack( + [ + np.ones(self.num_patches_per_frame - self.num_masks_per_frame), + ] + ) + np.random.shuffle(mask_per_frame) + mask = np.tile(mask_per_frame, (self.frames, 1)).flatten() + return mask.astype(np.float16) + + +class MultisetSampler(Sampler[T_co]): + r"""Sampler that restricts data loading to a subset of the dataset.""" + + def __init__( + self, + dataset: Dataset, + base_sampler: Sampler, + batch_size: int, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = True, + max_samples=10, + rank=0, + distributed=True, + ) -> None: + self.batch_size = batch_size + self.sub_dsets = dataset.sub_dsets + if distributed: + self.sub_samplers = [ + base_sampler(dataset, drop_last=drop_last) for dataset in self.sub_dsets + ] + else: + self.sub_samplers = [base_sampler(dataset) for dataset in self.sub_dsets] + self.dataset = dataset + self.epoch = 0 + self.drop_last = drop_last + self.shuffle = shuffle + self.seed = seed + self.max_samples = max_samples + self.rank = rank + + def __iter__(self) -> Iterator[T_co]: + samplers = [iter(sampler) for sampler in self.sub_samplers] + sampler_choices = list(range(len(samplers))) + count = 0 + while len(sampler_choices) > 0: + count += 1 + index_sampled = paddle.randint(0, len(sampler_choices), shape=(1,)).item() + dset_sampled = sampler_choices[index_sampled] + offset = max(0, self.dataset.offsets[dset_sampled]) + try: + queue = [] + for i in range(self.batch_size): + queue.append(next(samplers[dset_sampled]) + offset) + if len(queue) == self.batch_size: + for d in queue: + yield d + except Exception as err: + logger.error("ERRRR", err) + sampler_choices.pop(index_sampled) + logger.warning( + f"Note: dset {dset_sampled} fully used. Dsets remaining: {len(sampler_choices)}" + ) + continue + if count >= self.max_samples: + break + + def __len__(self) -> int: + return len(self.dataset) + + def set_epoch(self, epoch: int) -> None: + r""" + Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas + use a different random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + + Args: + epoch (int): Epoch number. + """ + for sampler in self.sub_samplers: + sampler.set_epoch(epoch) + self.epoch = epoch + + +def PoisHelmDatasetLoader(params, location, distributed, train=True): + transform = paddle.to_tensor + dataset = PoisHelmDataset(params, location, transform, train) + + sampler = DistributedBatchSampler(dataset, shuffle=train) if distributed else None + dataloader = DataLoader( + dataset, + batch_size=int(params.batch_size), + num_workers=params.num_data_workers, + shuffle=False, + drop_last=True, + ) + return dataloader, dataset, sampler + + +class PoisHelmDataset(Dataset): + def __init__(self, params, location, transform, train): + self.transform = transform + self.params = params + self.location = location + self.train = train + self.masking = params.mask_ratio if hasattr(params, "mask_ratio") else False + if hasattr(self.params, "subsample") and (self.train): + self.subsample = self.params.subsample + else: + self.subsample = 1 + self.scales = None + self._get_files_stats() + if isinstance(self.masking, float): + self.mask_generator = MaskingGenerator( + (self.img_shape_x, self.img_shape_y), self.masking + ) + file = self._open_file(self.location) + self.data = file["fields"] + if self.train: + if hasattr(self.params, "train_rand_idx_path"): + self.train_rand_idx = np.load(self.params.train_rand_idx_path) + logger.info("Randomizing train dataset using given random index path") + else: + self.train_rand_idx = range(self.data.shape[0]) + self.train_rand_idx = self.train_rand_idx[self.pt_idxs[0] : self.pt_idxs[1]] + self.data = self.data[()][self.train_rand_idx, ...] + logger.info( + "Getting only data idx for training set for length: {}".format( + len(self.train_rand_idx) + ) + ) + if "tensor" in list(file.keys()): + self.tensor = file["tensor"] + if self.train: + self.tensor = self.tensor[()][self.train_rand_idx, ...] + else: + self.tensor = None + + def _get_files_stats(self): + self.file = self.location + with h5py.File(self.file, "r") as _f: + logger.info("Getting file stats from {}".format(self.file)) + if len(_f["fields"].shape) == 4: + self.n_demos = None + self.n_samples = _f["fields"].shape[0] + self.img_shape_x = _f["fields"].shape[2] + self.img_shape_y = _f["fields"].shape[3] + self.in_channels = _f["fields"].shape[1] - 1 + elif len(_f["fields"].shape) == 5: + self.n_demos = _f["fields"].shape[2] + assert self.n_demos >= self.params.n_demos + self.n_samples = _f["fields"].shape[0] + self.img_shape_x = _f["fields"].shape[3] + self.img_shape_y = _f["fields"].shape[4] + self.in_channels = _f["fields"].shape[1] - 1 + if "tensor" in list(_f.keys()): + self.tensor_shape = _f["tensor"].shape[1] + else: + self.tensor_shape = 0 + if self.train: + if hasattr(self.params, "pt_split"): + self.pt_split = self.params.pt_split + else: + self.pt_split = [0.9, 0.1] + logger.info( + "Split training set into {} for pretrain, {} for train. ".format( + self.pt_split[0], self.pt_split[1] + ) + ) + if hasattr(self.params, "pt"): + self.pt = self.params.pt + else: + self.pt = "train" + if int(sum(self.pt_split)) == 1: + self.n_samples *= self.pt_split[-1 if self.pt == "train" else 0] + else: + assert int(sum(self.pt_split)) <= self.n_samples + self.n_samples = self.pt_split[-1 if self.pt == "train" else 0] + self.n_samples = int(self.n_samples) + self.pt_idxs = ( + [-self.n_samples, None] if self.pt == "train" else [0, self.n_samples] + ) + self.n_samples /= self.subsample + self.n_samples = int(self.n_samples) + logger.info( + "Found data at path {}. Number of examples: {}. Image Shape: {} x {}".format( + self.location, self.n_samples, self.img_shape_x, self.img_shape_y + ) + ) + if hasattr(self.params, "scales_path"): + self.scales = np.load(self.params.scales_path) + self.scales = np.array([s if s != 0 else 1 for s in self.scales]) + self.scales = self.scales.astype("float32") + measure_x = self.scales[-2] / self.img_shape_x + measure_y = self.scales[-1] / self.img_shape_y + self.measure = measure_x * measure_y + logger.info( + "Scales for PDE are (source, tensor, sol, domain): {}".format( + self.scales + ) + ) + logger.info( + "Measure of the set is lx/nx * ly/ny = {}/{} * {}/{}".format( + self.scales[-2], self.img_shape_x, self.scales[-1], self.img_shape_y + ) + ) + + def __len__(self): + return self.n_samples + + def _open_file(self, path): + return h5py.File(path, "r") + + def _getitem_single(self, local_idx): + if self.params.n_demos == 0: + if self.n_demos is None: + X = self.data[local_idx, 0 : self.in_channels] + else: + X = self.data[local_idx, 0 : self.in_channels, 0] + else: + if self.train: + demo_indices = np.random.choice( + range(self.n_demos), self.params.n_demos, replace=False + ) + X = np.take( + self.data[local_idx, 0 : self.in_channels], + np.insert(demo_indices, 0, 0), + 1, + ) + else: + X = self.data[ + local_idx, 0 : self.in_channels, : self.params.n_demos + 1 + ] + if self.tensor is not None: + tensor = [] + for tidx in range(self.tensor_shape): + coef = np.full( + (1, self.img_shape_x, self.img_shape_y), + self.tensor[local_idx, tidx], + ) + tensor.append(coef) + X = np.concatenate([X] + tensor, axis=0).astype("float32") + + if self.scales is not None: + f_norm = np.linalg.norm(X[0]) * self.measure + f_scaling = f_norm / self.scales[0] + X = X / f_scaling + X[self.in_channels :] = ( + X[self.in_channels :] + / self.scales[ + self.in_channels : (self.in_channels + self.tensor_shape), + None, + None, + ] + ) + + X = self.transform(X) + + if self.params.n_demos == 0: + if self.n_demos is None: + y = self.data[local_idx, self.in_channels :] + else: + y = self.data[local_idx, self.in_channels :, 0] + else: + if self.train: + y = np.take( + self.data[local_idx, self.in_channels :], + np.insert(demo_indices, 0, 0), + 1, + ) + else: + y = self.data[local_idx, self.in_channels :, : self.params.n_demos + 1] + y = self.transform(y) + + if isinstance(self.masking, float): + mask = self.mask_generator().reshape(1, self.img_shape_x, self.img_shape_y) + return X, y, mask + else: + return X, y + + def __getitem__(self, idx): + local_idx = int(idx * self.subsample) + if self.params.n_demos > 0 and self.n_demos is None: + candidate_idx = list(range(self.n_samples)) + candidate_idx.remove(idx) + idx_range = ( + ( + np.random.choice( + candidate_idx, size=self.params.n_demos, replace=False + ) + * self.subsample + ) + .astype(int) + .tolist() + ) + idx_range.append(local_idx) + X, Y, y = [], [], [] + _X, y = self._getitem_single(idx_range[-1]) + X.append(_X) + for idx in idx_range[:-1]: + _X, _y = self._getitem_single(idx) + X.append(_X) + Y.append(_y) + X += Y + X = paddle.concat(X, axis=0) + return X, y + else: + mask = None + _data = self._getitem_single(local_idx) + if len(_data) == 2: + X, y = _data + else: + X, y, mask = _data + if self.params.n_demos > 0: + X = paddle.concat( + [ + X.view([-1, self.img_shape_x, self.img_shape_y]), + y[:, 1:].view([-1, self.img_shape_x, self.img_shape_y]), + ], + axis=0, + ) + y = y[:, 0] + if mask is None: + return X, y + else: + return X, y, mask + + +def MixedDatasetLoader( + params, paths, distributed, split="train", rank=0, train_offset=0 +): + train_val_test = params.train_val_test + if split == "pretrain": + train_val_test = [ + params.train_val_test[0] * params.pretrain_train[0], + train_val_test[1], + train_val_test[2], + ] + split = "train" + elif split == "train": + train_val_test = [ + -params.train_val_test[0] + * params.pretrain_train[1] + * params.train_subsample, + train_val_test[1], + train_val_test[2], + ] + dataset = MixedDataset( + paths, + n_steps=params.n_steps, + train_val_test=train_val_test, + split=split, + tie_fields=params.tie_fields, + use_all_fields=params.use_all_fields, + enforce_max_steps=params.enforce_max_steps, + train_offset=train_offset, + masking=params.masking if hasattr(params, "masking") else None, + blur=params.blur if hasattr(params, "blur") else None, + rollout=getattr(params, "rollout", 1), + ) + if distributed: + base_sampler = DistributedBatchSampler + else: + base_sampler = RandomSampler + sampler = MultisetSampler( + dataset, + base_sampler, + params.batch_size, + distributed=distributed, + max_samples=params.epoch_size, + rank=rank, + ) + dataloader = DataLoader( + dataset, + batch_size=int(params.batch_size), + num_workers=params.num_data_workers, + shuffle=False, + drop_last=True, + ) + return dataloader, dataset, sampler + + +DSET_NAME_TO_OBJECT = { + "incompNS": IncompNSDataset, + "diffre2d": DiffRe2DDataset, +} + + +class MixedDataset(Dataset): + def __init__( + self, + path_list=[], + n_steps=1, + dt=1, + train_val_test=(0.8, 0.1, 0.1), + split="train", + tie_fields=True, + use_all_fields=True, + extended_names=False, + enforce_max_steps=False, + train_offset=0, + masking=None, + blur=None, + rollout=1, + ): + super().__init__() + self.train_offset = train_offset + self.path_list, self.type_list, self.include_string = zip(*path_list) + self.tie_fields = tie_fields + self.extended_names = extended_names + self.split = split + self.sub_dsets = [] + self.offsets = [0] + self.train_val_test = train_val_test + self.use_all_fields = use_all_fields + self.rollout = rollout + + for dset, path, include_string in zip( + self.type_list, self.path_list, self.include_string + ): + subdset = DSET_NAME_TO_OBJECT[dset]( + path, + include_string, + n_steps=n_steps, + dt=dt, + train_val_test=train_val_test, + split=split, + rollout=self.rollout, + ) + try: + len(subdset) + except ValueError: + raise ValueError( + f"Dataset {path} is empty. Check that n_steps < trajectory_length in file." + ) + self.sub_dsets.append(subdset) + self.offsets.append(self.offsets[-1] + len(self.sub_dsets[-1])) + self.offsets[0] = -1 + + self.subset_dict = self._build_subset_dict() + + self.masking = masking + if ( + self.masking + and type(self.masking) in [tuple, list] + and len(self.masking) == 2 + ): + self.mask_generator = TubeMaskingGenerator(self.masking[0], self.masking[1]) + self.blur = blur + + def get_state_names(self): + name_list = [] + if self.use_all_fields: + for name, dset in DSET_NAME_TO_OBJECT.items(): + field_names = dset._specifics()[2] + name_list += field_names + return name_list + else: + visited = set() + for dset in self.sub_dsets: + name = dset.get_name() + if name not in visited: + visited.add(name) + name_list.append(dset.field_names) + return [f for fl in name_list for f in fl] + + def _build_subset_dict(self): + if self.tie_fields: + subset_dict = { + "swe": [3], + "incompNS": [0, 1, 2], + "compNS": [0, 1, 2, 3], + "diffre2d": [4, 5], + } + elif self.use_all_fields: + cur_max = 0 + subset_dict = {} + for name, dset in DSET_NAME_TO_OBJECT.items(): + field_names = dset._specifics()[2] + subset_dict[name] = list(range(cur_max, cur_max + len(field_names))) + cur_max += len(field_names) + else: + subset_dict = {} + cur_max = self.train_offset + for dset in self.sub_dsets: + name = dset.get_name(self.extended_names) + if name not in subset_dict: + subset_dict[name] = list( + range(cur_max, cur_max + len(dset.field_names)) + ) + cur_max += len(dset.field_names) + return subset_dict + + def __getitem__(self, index): + file_idx = np.searchsorted(self.offsets, index, side="right") - 1 + local_idx = index - max(self.offsets[file_idx], 0) + + x, y = self.sub_dsets[file_idx][local_idx] + try: + x, y = self.sub_dsets[file_idx][local_idx] + except: # noqa + logger.error( + "FAILED AT ", file_idx, local_idx, index, int(os.environ.get("RANK", 0)) + ) + + if ( + self.masking + and type(self.masking) in [tuple, list] + and len(self.masking) == 2 + ): + mask = self.mask_generator() + return x, y, mask + else: + return x, y + + def __len__(self): + return sum([len(dset) for dset in self.sub_dsets])