Skip to content

Commit 8e1f922

Browse files
support piratenet without fourier embedding (#1157)
1 parent 359136f commit 8e1f922

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

ppsci/arch/mlp.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,7 @@ class PirateNetBlock(nn.Layer):
546546
$$
547547
548548
Args:
549+
input_dim (int): Input dimension.
549550
embed_dim (int): Embedding dimension.
550551
activation (str, optional): Name of activation function. Defaults to "tanh".
551552
random_weight (Optional[Dict[str, float]]): Mean and std of random weight
@@ -554,16 +555,17 @@ class PirateNetBlock(nn.Layer):
554555

555556
def __init__(
556557
self,
558+
input_dim: int,
557559
embed_dim: int,
558560
activation: str = "tanh",
559561
random_weight: Optional[Dict[str, float]] = None,
560562
):
561563
super().__init__()
562564
self.linear1 = (
563-
nn.Linear(embed_dim, embed_dim)
565+
nn.Linear(input_dim, embed_dim)
564566
if random_weight is None
565567
else RandomWeightFactorization(
566-
embed_dim,
568+
input_dim,
567569
embed_dim,
568570
mean=random_weight["mean"],
569571
std=random_weight["std"],
@@ -721,6 +723,9 @@ def __init__(
721723
cur_size, fourier["dim"], fourier["scale"]
722724
)
723725
cur_size = fourier["dim"]
726+
else:
727+
self.linear_emb = nn.Linear(cur_size, hidden_size[0])
728+
cur_size = hidden_size[0]
724729

725730
self.embed_u = nn.Sequential(
726731
(
@@ -769,6 +774,7 @@ def __init__(
769774
self.blocks.append(
770775
PirateNetBlock(
771776
cur_size,
777+
_size,
772778
activation=activation,
773779
random_weight=random_weight,
774780
)
@@ -811,6 +817,8 @@ def forward(self, x):
811817

812818
if self.fourier:
813819
y = self.fourier_emb(y)
820+
else:
821+
y = self.linear_emb(y)
814822

815823
y = self.forward_tensor(y)
816824
y = self.split_to_dict(y, self.output_keys, axis=-1)

0 commit comments

Comments
 (0)