Skip to content

Commit 0c80c3f

Browse files
committed
switch to slice assign
1 parent 03c57a4 commit 0c80c3f

File tree

2 files changed

+7
-9
lines changed

2 files changed

+7
-9
lines changed

tests/layers/test_drop.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
class ClipMaskTests(unittest.TestCase):
1515
def test_clip_mask_2d_odd(self):
1616
mask = drop.clip_mask_2d(h=5, w=7, kernel=3, device=torch_device)
17+
print(mask)
1718
assert mask.device == torch.device(torch_device)
1819
assert mask.tolist() == \
1920
[
@@ -26,6 +27,7 @@ def test_clip_mask_2d_odd(self):
2627

2728
def test_clip_mask_2d_even(self):
2829
mask = drop.clip_mask_2d(h=5, w=7, kernel=2, device=torch_device)
30+
print(mask)
2931
assert mask.device == torch.device(torch_device)
3032
# TODO: This is a suprising result; should even kernels be forbidden?
3133
assert mask.tolist() == \

timm/layers/drop.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
import torch.nn as nn
1919
import torch.nn.functional as F
2020

21-
from .grid import ndgrid
22-
2321

2422
def clip_mask_2d(
2523
h: int,
@@ -49,13 +47,11 @@ def clip_mask_2d(
4947
"""
5048
assert kernel <= min(h, w), f"{kernel=} > min({h=}, {w=})"
5149

52-
h_i, w_i = ndgrid(torch.arange(h, device=device), torch.arange(w, device=device))
53-
return (
54-
(h_i >= kernel // 2) &
55-
(h_i < h - (kernel - 1) // 2) &
56-
(w_i >= kernel // 2) &
57-
(w_i < w - (kernel - 1) // 2)
58-
).reshape(h, w)
50+
mask = torch.zeros((h, w), dtype=torch.bool, device=device)
51+
start = kernel // 2
52+
end = ((kernel - 1) // 2)
53+
mask[start:h-end, start:w-end] = True
54+
return mask
5955

6056

6157
def drop_block_2d(

0 commit comments

Comments
 (0)