Skip to content

Commit 2e8ef81

Browse files
authored
APTx Function: Default value of gamma should be 0.5 when trainable=False (#222)
In APTx activation function, the default value of gamma should be 0.5 when trainable=False
1 parent fb91d55 commit 2e8ef81

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

neurodiffeq/networks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,13 +190,13 @@ class APTx(nn.Module):
190190
:type trainable: bool
191191
"""
192192

193-
def __init__(self, alpha=1.0, beta=1.0, gamma=1.0, trainable=False):
193+
def __init__(self, alpha=1.0, beta=1.0, gamma=0.5, trainable=False):
194194
super(APTx, self).__init__()
195195
alpha = float(alpha)
196196
beta = float(beta)
197197
gamma = float(gamma)
198198
self.trainable = trainable
199-
if trainable:
199+
if self.trainable:
200200
self.alpha = nn.Parameter(torch.tensor(alpha))
201201
self.beta = nn.Parameter(torch.tensor(beta))
202202
self.gamma = nn.Parameter(torch.tensor(gamma))
@@ -206,4 +206,4 @@ def __init__(self, alpha=1.0, beta=1.0, gamma=1.0, trainable=False):
206206
self.gamma = gamma
207207

208208
def forward(self, x):
209-
return (self.alpha + torch.nn.functional.tanh(self.beta*x))*self.gamma*x
209+
return (self.alpha + torch.nn.functional.tanh(self.beta*x))*self.gamma*x

0 commit comments

Comments
 (0)