Skip to content

Commit caa67ac

Browse files
add pirate network
1 parent 6d1d4ef commit caa67ac

File tree

9 files changed

+401
-0
lines changed

9 files changed

+401
-0
lines changed

docs/source/_rst/_code.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ Models
104104
LowRankNeuralOperator <model/low_rank_neural_operator.rst>
105105
GraphNeuralOperator <model/graph_neural_operator.rst>
106106
GraphNeuralKernel <model/graph_neural_operator_integral_kernel.rst>
107+
PirateNet <model/pirate_network.rst>
107108

108109
Blocks
109110
-------------
@@ -121,6 +122,7 @@ Blocks
121122
Continuous Convolution Interface <model/block/convolution_interface.rst>
122123
Continuous Convolution Block <model/block/convolution.rst>
123124
Orthogonal Block <model/block/orthogonal.rst>
125+
PirateNet Block <model/block/pirate_network_block.rst>
124126

125127
Message Passing
126128
-------------------
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
PirateNet Block
2+
=======================================
3+
.. currentmodule:: pina.model.block.pirate_network_block
4+
5+
.. autoclass:: PirateNetBlock
6+
:members:
7+
:show-inheritance:
8+
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
PirateNet
2+
=======================
3+
.. currentmodule:: pina.model.pirate_network
4+
5+
.. autoclass:: PirateNet
6+
:members:
7+
:show-inheritance:

pina/model/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"LowRankNeuralOperator",
1414
"Spline",
1515
"GraphNeuralOperator",
16+
"PirateNet",
1617
]
1718

1819
from .feed_forward import FeedForward, ResidualFeedForward
@@ -24,3 +25,4 @@
2425
from .low_rank_neural_operator import LowRankNeuralOperator
2526
from .spline import Spline
2627
from .graph_neural_operator import GraphNeuralOperator
28+
from .pirate_network import PirateNet

pina/model/block/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"LowRankBlock",
1919
"RBFBlock",
2020
"GNOBlock",
21+
"PirateNetBlock",
2122
]
2223

2324
from .convolution_2d import ContinuousConvBlock
@@ -35,3 +36,4 @@
3536
from .low_rank_block import LowRankBlock
3637
from .rbf_block import RBFBlock
3738
from .gno_block import GNOBlock
39+
from .pirate_network_block import PirateNetBlock
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
"""Module for the PirateNet block class."""
2+
3+
import torch
4+
from ...utils import check_consistency, check_positive_integer
5+
6+
7+
class PirateNetBlock(torch.nn.Module):
8+
"""
9+
The inner block of Physics-Informed residual adaptive network (PirateNet).
10+
11+
The block consists of three dense layers with dual gating operations and an
12+
adaptive residual connection. The trainable ``alpha`` parameter controls
13+
the contribution of the residual connection.
14+
15+
.. seealso::
16+
17+
**Original reference**:
18+
Wang, S., Sankaran, S., Stinis., P., Perdikaris, P. (2025).
19+
*Simulating Three-dimensional Turbulence with Physics-informed Neural
20+
Networks*.
21+
DOI: `arXiv preprint arXiv:2507.08972.
22+
<https://arxiv.org/abs/2507.08972>`_
23+
"""
24+
25+
def __init__(self, inner_size, activation):
26+
"""
27+
Initialization of the :class:`PirateNetBlock` class.
28+
29+
:param int inner_size: The number of hidden units in the dense layers.
30+
:param torch.nn.Module activation: The activation function.
31+
"""
32+
super().__init__()
33+
34+
# Check consistency
35+
check_consistency(activation, torch.nn.Module, subclass=True)
36+
check_positive_integer(inner_size, strict=True)
37+
38+
# Initialize the linear transformations of the dense layers
39+
self.linear1 = torch.nn.Linear(inner_size, inner_size)
40+
self.linear2 = torch.nn.Linear(inner_size, inner_size)
41+
self.linear3 = torch.nn.Linear(inner_size, inner_size)
42+
43+
# Initialize the scales of the dense layers
44+
self.scale1 = torch.nn.Parameter(torch.zeros(inner_size))
45+
self.scale2 = torch.nn.Parameter(torch.zeros(inner_size))
46+
self.scale3 = torch.nn.Parameter(torch.zeros(inner_size))
47+
48+
# Initialize the adaptive residual connection parameter
49+
self._alpha = torch.nn.Parameter(torch.zeros(1))
50+
51+
# Initialize the activation function
52+
self.activation = activation()
53+
54+
def forward(self, x, U, V):
55+
"""
56+
Forward pass of the PirateNet block. It computes the output of the block
57+
by applying the dense layers with scaling, and combines the results with
58+
the input using the adaptive residual connection.
59+
60+
:param x: The input tensor.
61+
:type x: torch.Tensor | LabelTensor
62+
:param torch.Tensor U: The first shared gating tensor. It must have the
63+
same shape as ``x``.
64+
:param torch.Tensor V: The second shared gating tensor. It must have the
65+
same shape as ``x``.
66+
:return: The output tensor of the block.
67+
:rtype: torch.Tensor | LabelTensor
68+
"""
69+
# Compute the output of the first dense layer with scaling
70+
f = self.activation(self.linear1(x) * torch.exp(self.scale1))
71+
z1 = f * U + (1 - f) * V
72+
73+
# Compute the output of the second dense layer with scaling
74+
g = self.activation(self.linear2(z1) * torch.exp(self.scale2))
75+
z2 = g * U + (1 - g) * V
76+
77+
# Compute the output of the block
78+
h = self.activation(self.linear3(z2) * torch.exp(self.scale3))
79+
return self._alpha * h + (1 - self._alpha) * x
80+
81+
@property
82+
def alpha(self):
83+
"""
84+
Return the alpha parameter.
85+
86+
:return: The alpha parameter controlling the residual connection.
87+
:rtype: torch.nn.Parameter
88+
"""
89+
return self._alpha

pina/model/pirate_network.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
"""Module for the PirateNet model class."""
2+
3+
import torch
4+
from .block import FourierFeatureEmbedding, PirateNetBlock
5+
from ..utils import check_consistency, check_positive_integer
6+
7+
8+
class PirateNet(torch.nn.Module):
9+
"""
10+
Implementation of Physics-Informed residual adaptive network (PirateNet).
11+
12+
The model consists of a Fourier feature embedding layer, multiple PirateNet
13+
blocks, and a final output layer. Each PirateNet block consist of three
14+
dense layers with dual gating mechanism and an adaptive residual connection,
15+
whose contribution is controlled by a trainable parameter ``alpha``.
16+
17+
The PirateNet, augmented with random weight factorization, is designed to
18+
mitigate spectral bias in deep networks.
19+
20+
.. seealso::
21+
22+
**Original reference**:
23+
Wang, S., Sankaran, S., Stinis., P., Perdikaris, P. (2025).
24+
*Simulating Three-dimensional Turbulence with Physics-informed Neural
25+
Networks*.
26+
DOI: `arXiv preprint arXiv:2507.08972.
27+
<https://arxiv.org/abs/2507.08972>`_
28+
"""
29+
30+
def __init__(
31+
self,
32+
input_dimension,
33+
inner_size,
34+
output_dimension,
35+
embedding=None,
36+
n_layers=3,
37+
activation=torch.nn.Tanh,
38+
):
39+
"""
40+
Initialization of the :class:`PirateNet` class.
41+
42+
:param int input_dimension: The number of input features.
43+
:param int inner_size: The number of hidden units in the dense layers.
44+
:param int output_dimension: The number of output features.
45+
:param torch.nn.Module embedding: The embedding module used to transform
46+
the input into a higher-dimensional feature space. If ``None``, a
47+
default :class:`~pina.model.block.FourierFeatureEmbedding` with
48+
scaling factor of 2 is used. Default is ``None``.
49+
:param int n_layers: The number of PirateNet blocks in the model.
50+
Default is 3.
51+
:param torch.nn.Module activation: The activation function to be used in
52+
the blocks. Default is :class:`torch.nn.Tanh`.
53+
"""
54+
super().__init__()
55+
56+
# Check consistency
57+
check_consistency(activation, torch.nn.Module, subclass=True)
58+
check_positive_integer(input_dimension, strict=True)
59+
check_positive_integer(inner_size, strict=True)
60+
check_positive_integer(output_dimension, strict=True)
61+
check_positive_integer(n_layers, strict=True)
62+
63+
# Initialize the activation function
64+
self.activation = activation()
65+
66+
# Initialize the Fourier embedding
67+
self.embedding = embedding or FourierFeatureEmbedding(
68+
input_dimension=input_dimension,
69+
output_dimension=inner_size,
70+
sigma=2.0,
71+
)
72+
73+
# Initialize the shared dense layers
74+
self.linear1 = torch.nn.Linear(inner_size, inner_size)
75+
self.linear2 = torch.nn.Linear(inner_size, inner_size)
76+
77+
# Initialize the PirateNet blocks
78+
self.blocks = torch.nn.ModuleList(
79+
[PirateNetBlock(inner_size, activation) for _ in range(n_layers)]
80+
)
81+
82+
# Initialize the output layer
83+
self.output_layer = torch.nn.Linear(inner_size, output_dimension)
84+
85+
def forward(self, input_):
86+
"""
87+
Forward pass of the PirateNet model. It applies the Fourier feature
88+
embedding, computes the shared gating tensors U and V, and passes the
89+
input through each block in the network. Finally, it applies the output
90+
layer to produce the final output.
91+
92+
:param input_: The input tensor for the model.
93+
:type input_: torch.Tensor | LabelTensor
94+
:return: The output tensor of the model.
95+
:rtype: torch.Tensor | LabelTensor
96+
"""
97+
# Apply the Fourier feature embedding
98+
x = self.embedding(input_)
99+
100+
# Compute U and V from the shared dense layers
101+
U = self.activation(self.linear1(x))
102+
V = self.activation(self.linear2(x))
103+
104+
# Pass through each block in the network
105+
for block in self.blocks:
106+
x = block(x, U, V)
107+
108+
return self.output_layer(x)
109+
110+
@property
111+
def alpha(self):
112+
"""
113+
Return the alpha values of all PirateNetBlock layers.
114+
115+
:return: A list of alpha values from each block.
116+
:rtype: list
117+
"""
118+
return [block.alpha.item() for block in self.blocks]
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import torch
2+
import pytest
3+
from pina.model.block import PirateNetBlock
4+
5+
data = torch.rand((20, 3))
6+
7+
8+
@pytest.mark.parametrize("inner_size", [10, 20])
9+
def test_constructor(inner_size):
10+
11+
PirateNetBlock(inner_size=inner_size, activation=torch.nn.Tanh)
12+
13+
# Should fail if inner_size is negative
14+
with pytest.raises(AssertionError):
15+
PirateNetBlock(inner_size=-1, activation=torch.nn.Tanh)
16+
17+
18+
@pytest.mark.parametrize("inner_size", [10, 20])
19+
def test_forward(inner_size):
20+
21+
model = PirateNetBlock(inner_size=inner_size, activation=torch.nn.Tanh)
22+
23+
# Create dummy embedding
24+
dummy_embedding = torch.nn.Linear(data.shape[1], inner_size)
25+
x = dummy_embedding(data)
26+
27+
# Create dummy U and V tensors
28+
U = torch.rand((data.shape[0], inner_size))
29+
V = torch.rand((data.shape[0], inner_size))
30+
31+
output_ = model(x, U, V)
32+
assert output_.shape == (data.shape[0], inner_size)
33+
34+
35+
@pytest.mark.parametrize("inner_size", [10, 20])
36+
def test_backward(inner_size):
37+
38+
model = PirateNetBlock(inner_size=inner_size, activation=torch.nn.Tanh)
39+
data.requires_grad_()
40+
41+
# Create dummy embedding
42+
dummy_embedding = torch.nn.Linear(data.shape[1], inner_size)
43+
x = dummy_embedding(data)
44+
45+
# Create dummy U and V tensors
46+
U = torch.rand((data.shape[0], inner_size))
47+
V = torch.rand((data.shape[0], inner_size))
48+
49+
output_ = model(x, U, V)
50+
51+
loss = torch.mean(output_)
52+
loss.backward()
53+
assert data.grad.shape == data.shape

0 commit comments

Comments
 (0)