Skip to content

Commit db00642

Browse files
committed
fix potential typing issue
1 parent 1696a06 commit db00642

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

topoloss/core.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,8 @@ def __init__(
2424
self.losses = losses
2525

2626
def check_layer_type(self, layer):
27-
assert isinstance(
28-
layer, (nn.Conv2d, nn.Linear)
29-
), f"Expect layer to be either nn.Conv2d or nn.Linear, but got: {type(layer)}"
27+
assert isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear), f"Expect layer to be either nn.Conv2d or nn.Linear, but got: {type(layer)}"
28+
3029

3130
def get_layerwise_topo_losses(self, model, do_scaling: bool = True) -> dict:
3231
layer_wise_losses = {}

0 commit comments

Comments
 (0)