We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent de096d0 commit 7207e3bCopy full SHA for 7207e3b
modules/hypernetworks/hypernetwork.py
@@ -32,7 +32,7 @@ class HypernetworkModule(torch.nn.Module):
32
"tanh": torch.nn.Tanh,
33
"sigmoid": torch.nn.Sigmoid,
34
}
35
- activation_dict.update({cls_name: cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
+ activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
36
37
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False):
38
super().__init__()
0 commit comments