@@ -546,6 +546,7 @@ class PirateNetBlock(nn.Layer):
     $$

     Args:
+        input_dim (int): Input dimension.
         embed_dim (int): Embedding dimension.
         activation (str, optional): Name of activation function. Defaults to "tanh".
         random_weight (Optional[Dict[str, float]]): Mean and std of random weight
@@ -554,16 +555,17 @@ class PirateNetBlock(nn.Layer):

     def __init__(
         self,
+        input_dim: int,
         embed_dim: int,
         activation: str = "tanh",
         random_weight: Optional[Dict[str, float]] = None,
     ):
         super().__init__()
         self.linear1 = (
-            nn.Linear(embed_dim, embed_dim)
+            nn.Linear(input_dim, embed_dim)
             if random_weight is None
             else RandomWeightFactorization(
-                embed_dim,
+                input_dim,
                 embed_dim,
                 mean=random_weight["mean"],
                 std=random_weight["std"],
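The hunks above decouple the block's input width from its embedding width: `linear1` (and its `RandomWeightFactorization` variant) now maps `input_dim -> embed_dim` instead of assuming a square `embed_dim -> embed_dim` projection. A minimal construction sketch; the import path and the concrete numbers are assumptions, only the keyword names come from the diff:

```python
# Sketch only: the import path is assumed, not taken from this diff.
from ppsci.arch.mlp import PirateNetBlock

# A block whose incoming feature width (128) differs from its hidden
# embedding width (256) -- not expressible before this change.
block = PirateNetBlock(
    input_dim=128,
    embed_dim=256,
    activation="tanh",
    random_weight={"mean": 1.0, "std": 0.1},  # keys mirror the diff above
)
# block.linear1 now projects 128 -> 256 rather than 256 -> 256.
```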
@@ -721,6 +723,9 @@ def __init__(
                 cur_size, fourier["dim"], fourier["scale"]
             )
             cur_size = fourier["dim"]
+        else:
+            self.linear_emb = nn.Linear(cur_size, hidden_size[0])
+            cur_size = hidden_size[0]

         self.embed_u = nn.Sequential(
             (
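The new `else` branch gives the network an embedding stage even when no Fourier feature mapping is configured: raw inputs of width `cur_size` are projected by a plain `nn.Linear` to `hidden_size[0]`, and `cur_size` is updated so the first `PirateNetBlock` sees a consistent width. A standalone shape check of that fallback; the concrete sizes are made up for illustration:

```python
import paddle
import paddle.nn as nn

# Stand-in values; in the model, cur_size is the raw input width and
# hidden_size comes from the network config.
cur_size, hidden_size = 3, [256, 256, 256]
linear_emb = nn.Linear(cur_size, hidden_size[0])

x = paddle.rand([8, cur_size])  # batch of 8 input points
y = linear_emb(x)
print(y.shape)                  # [8, 256], i.e. hidden_size[0]
```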
@@ -769,6 +774,7 @@ def __init__(
             self.blocks.append(
                 PirateNetBlock(
                     cur_size,
+                    _size,
                     activation=activation,
                     random_weight=random_weight,
                 )
@@ -811,6 +817,8 @@ def forward(self, x):

         if self.fourier:
             y = self.fourier_emb(y)
+        else:
+            y = self.linear_emb(y)

         y = self.forward_tensor(y)
         y = self.split_to_dict(y, self.output_keys, axis=-1)
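With the matching `else` in `forward`, the embedding step is always applied: inputs either pass through `fourier_emb` or through the new `linear_emb` before `forward_tensor`. A self-contained stand-in for that routing; the class below only mirrors the branch added here and is not the real PirateNet:

```python
import paddle
import paddle.nn as nn


class _EmbedRouter(nn.Layer):
    """Minimal stand-in mirroring the fourier/linear branch of forward()."""

    def __init__(self, in_size: int, first_hidden: int, fourier_emb: nn.Layer = None):
        super().__init__()
        self.fourier = fourier_emb is not None
        if self.fourier:
            self.fourier_emb = fourier_emb
        else:
            # fallback introduced by this diff
            self.linear_emb = nn.Linear(in_size, first_hidden)

    def forward(self, y):
        if self.fourier:
            y = self.fourier_emb(y)
        else:
            y = self.linear_emb(y)
        return y


router = _EmbedRouter(in_size=2, first_hidden=256)  # no Fourier features configured
out = router(paddle.rand([16, 2]))
print(out.shape)  # [16, 256]
```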