Commit cc56df9

guaneec authored and aria1th committed
Fix dropout logic
1 parent 85fcccc commit cc56df9

1 file changed: +2 -2 lines changed


modules/hypernetworks/hypernetwork.py

Lines changed: 2 additions & 2 deletions
@@ -35,7 +35,7 @@ class HypernetworkModule(torch.nn.Module):
     activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
 
     def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
-                 add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
+                 add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=True):
         super().__init__()
 
         assert layer_structure is not None, "layer_structure must not be None"
@@ -61,7 +61,7 @@ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=N
                 linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
 
             # Add dropout except last layer
-            if 'last_layer_dropout' in kwargs and kwargs['last_layer_dropout'] and use_dropout and i < len(layer_structure) - 2:
+            if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
                 linears.append(torch.nn.Dropout(p=0.3))
 
         self.linear = torch.nn.Sequential(*linears)
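
For context: the old condition only added dropout when callers passed last_layer_dropout through **kwargs and it was truthy, so with the default call the dropout branch was skipped entirely even when use_dropout was True. The new version makes last_layer_dropout an explicit parameter (defaulting to True): dropout now follows every linear layer except the final one, and when last_layer_dropout is False the layer before the final one is excluded as well. Below is a minimal sketch (not part of the commit; the helper name and the example layer_structure values are hypothetical) showing which loop indices would receive a Dropout layer under the new condition.

# Minimal sketch, not part of the commit: evaluates the new dropout
# condition for each index i, mirroring the loop over
# range(len(layer_structure) - 1) in HypernetworkModule.__init__.
def dropout_indices(layer_structure, use_dropout=True, last_layer_dropout=True):
    indices = []
    for i in range(len(layer_structure) - 1):
        # New condition from the commit: dropout on all but the last layer,
        # or all but the last two when last_layer_dropout is False.
        if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
            indices.append(i)  # a torch.nn.Dropout(p=0.3) would be appended here
    return indices

print(dropout_indices([1, 2, 2, 1], last_layer_dropout=True))   # [0, 1]
print(dropout_indices([1, 2, 2, 1], last_layer_dropout=False))  # [0]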
