Commit 03c57a4
Fix bug in timm.layers.drop.drop_block_2d when H != W
There are two bugs in the `valid_block` code for `drop_block_2d`:
- a (W, H) grid being reshaped as (H, W);
- odd handling of even `block_size` / `clipped_block_size` values (see the note below).

The current code uses (W, H) order to generate the meshgrid, but then uses `.reshape((1, 1, H, W))` to unsqueeze the block map.
The simplest fix to the first bug is a one-line change:
```python
h_i, w_i = ndgrid(torch.arange(H), torch.arange(W))
```
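To see why the mismatch matters, here is a tiny standalone sketch (not part of the patch): `.reshape` rewinds a (W, H) grid's values in row-major order, which is not the same as transposing it into (H, W) layout.

```python
import torch

# A (W, H) coordinate grid, as in the unfixed code path.
H, W = 2, 3
w_i, h_i = torch.meshgrid(torch.arange(W), torch.arange(H), indexing='ij')

# w_i has shape (W, H) = (3, 2): [[0, 0], [1, 1], [2, 2]].
# reshape rewraps the row-major values; transpose reorders the axes.
reshaped = w_i.reshape(H, W)   # scrambled: [[0, 0, 1], [1, 2, 2]]
transposed = w_i.T             # intended:  [[0, 1, 2], [0, 1, 2]]
print(reshaped)
print(transposed)
```

The scrambled coordinates are what produce the wild mask pattern shown further down.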
This is a longer patch that attempts to make the code testable.
Note: the current code also behaves oddly when `block_size` or
`clipped_block_size` is even; I've added tests exposing that behavior,
but have not changed it.
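The even-size oddity comes from the margin arithmetic in `valid_block`: the low-side margin is `clipped_block_size // 2` while the high-side margin is `(clipped_block_size - 1) // 2`. A tiny standalone sketch of those margins:

```python
# Rows/columns excluded on each side of the map, as computed in valid_block.
# For even block sizes the low edge loses one more row/column than the
# high edge, so the valid region is off-center.
for block_size in (2, 3, 4, 5):
    low = block_size // 2          # excluded at the low (top/left) edge
    high = (block_size - 1) // 2   # excluded at the high (bottom/right) edge
    print(f"{block_size=} {low=} {high=}")
```

For `block_size=2` this gives `low=1, high=0`; for odd sizes the margins match.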
When you trigger the reshape bug, you get wild results:
```
$ python scratch.py
{'H': 4, 'W': 5, 'block_size': 3, 'fix_reshape': False}
grid.shape=torch.Size([1, 1, 4, 5])
tensor([[[[False, False, False, False, False],
[ True, True, False, False, True],
[ True, False, False, True, True],
[False, False, False, False, False]]]])
{'H': 4, 'W': 5, 'block_size': 3, 'fix_reshape': True}
grid.shape=torch.Size([1, 1, 4, 5])
tensor([[[[False, False, False, False, False],
[False, True, True, True, False],
[False, True, True, True, False],
[False, False, False, False, False]]]])
```
Here's a tiny excerpt script showing the problem;
it generated the output above.
```python
import torch
from typing import Tuple


def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]:
    """Generate an N-D grid in dimension ('ij') order.

    ndgrid is like meshgrid except that the order of the first two input
    arguments is switched. That is, the statement
        [X1, X2, X3] = ndgrid(x1, x2, x3)
    produces the same result as
        [X2, X1, X3] = meshgrid(x2, x1, x3)
    The naming follows MATLAB; the purpose is to avoid confusion now that
    torch.meshgrid's behaviour is moving from matching ndgrid ('ij')
    indexing to the numpy meshgrid default of ('xy').
    """
    try:
        return torch.meshgrid(*tensors, indexing='ij')
    except TypeError:
        # Old PyTorch (< 1.10) takes this path as it has no indexing arg;
        # the old behaviour of meshgrid was 'ij'.
        return torch.meshgrid(*tensors)


def valid_block(H, W, block_size, fix_reshape=False):
    clipped_block_size = min(block_size, H, W)
    if fix_reshape:
        # This matches the .reshape() dimension order below.
        h_i, w_i = ndgrid(torch.arange(H), torch.arange(W))
    else:
        # The original produces crazy stride patterns, due to .reshape()
        # offset winding. This is only visible when H != W.
        w_i, h_i = ndgrid(torch.arange(W), torch.arange(H))
    valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
        ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
    valid_block = torch.reshape(valid_block, (1, 1, H, W))
    return valid_block


def main():
    common_args = dict(H=4, W=5, block_size=3)
    for fix in [False, True]:
        args = dict(**common_args, fix_reshape=fix)
        grid = valid_block(**args)
        print(args)
        print(f"{grid.shape=}")
        print(grid)
        print()


if __name__ == "__main__":
    main()
```

1 parent: 954613a
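Since the point of the longer patch is testability, here is a minimal sketch of the kind of property test that catches the bug (an illustration, not the actual tests added by the commit): with an odd block size, the fixed mask should be symmetric under flipping both spatial axes, even when H != W.

```python
import torch

def valid_block_fixed(H, W, block_size):
    # Fixed variant: grid generated in (H, W) order, matching the reshape.
    clipped = min(block_size, H, W)
    h_i, w_i = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
    mask = ((w_i >= clipped // 2) & (w_i < W - (clipped - 1) // 2)) & \
           ((h_i >= clipped // 2) & (h_i < H - (clipped - 1) // 2))
    return mask.reshape(1, 1, H, W)

mask = valid_block_fixed(4, 5, 3)
# Flipping both spatial axes should leave an odd-block mask unchanged;
# the buggy path fails this whenever H != W.
assert torch.equal(mask, mask.flip(2).flip(3))
```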
3 files changed: +106 −6 lines.