Skip to content

Commit 2f4c918

Browse files
authored
Remove activation from final layer of HNs
1 parent 3e15f8e commit 2f4c918

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

modules/hypernetworks/hypernetwork.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=N
4141
# Add a fully-connected layer
4242
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
4343

44-
# Add an activation func
45-
if activation_func == "linear" or activation_func is None:
44+
# Add an activation func except last layer
45+
if activation_func == "linear" or activation_func is None or i >= len(layer_structure) - 3:
4646
pass
4747
elif activation_func in self.activation_dict:
4848
linears.append(self.activation_dict[activation_func]())
@@ -53,7 +53,7 @@ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=N
5353
if add_layer_norm:
5454
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
5555

56-
# Add dropout expect last layer
56+
# Add dropout except last layer
5757
if use_dropout and i < len(layer_structure) - 3:
5858
linears.append(torch.nn.Dropout(p=0.3))
5959

0 commit comments

Comments (0)