Skip to content

Commit c8b7af0

Browse files
committed
fix types
1 parent 6bc3dc0 commit c8b7af0

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/layers/test_drop.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def test_conv2d_kernel_midpoint_mask_even(self):
3636
shape=(5, 7),
3737
kernel=(2, 2),
3838
device=torch_device,
39-
dtype=torch.bool,
39+
dtype=torch.float32,
4040
)
4141
print(mask)
4242
assert mask.device == torch.device(torch_device)
@@ -55,7 +55,7 @@ def test_clip_mask_2d_kernel_too_big(self):
5555
shape=(4, 7),
5656
kernel=(5, 5),
5757
device=torch_device,
58-
dtype=torch.bool,
58+
dtype=torch.float32,
5959
)
6060
raise RuntimeError("Expected throw")
6161

@@ -102,7 +102,7 @@ def test_drop_filter_messy(self):
102102
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
103103
],
104104
device=torch_device,
105-
dtype=torch.int32,
105+
dtype=torch.float32,
106106
).unsqueeze(0).unsqueeze(0)
107107

108108
result = drop.drop_block_2d_drop_filter_(

0 commit comments

Comments
 (0)