We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent ab71a49 commit 620ee3eCopy full SHA for 620ee3e
topoloss/core.py
@@ -25,7 +25,7 @@ def __init__(
25
26
def check_layer_type(self, layer):
27
assert isinstance(
28
- layer, Union[nn.Conv2d, nn.Linear]
+ layer, (nn.Conv2d, nn.Linear)
29
), f"Expect layer to be either nn.Conv2d or nn.Linear, but got: {type(layer)}"
30
31
def get_layerwise_topo_losses(self, model, do_scaling: bool = True) -> dict:
0 commit comments