Skip to content

Commit 2cfb896

Browse files
authored
Fix reshape in ViTAutoEnc (#3841)
* Add failing test — Signed-off-by: Sven Koitka <[email protected]>
* Fix reshape after ViT — Signed-off-by: Sven Koitka <[email protected]>
* Remove unused import — Signed-off-by: Sven Koitka <[email protected]>
1 parent e570c3b commit 2cfb896

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

monai/networks/nets/vitautoenc.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,6 @@
1010
# limitations under the License.
1111

1212

13-
import math
1413
from typing import Sequence, Union
1514

1615
import torch
@@ -19,6 +18,7 @@
1918
from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
2019
from monai.networks.blocks.transformerblock import TransformerBlock
2120
from monai.networks.layers import Conv
21+
from monai.utils import ensure_tuple_rep
2222

2323
__all__ = ["ViTAutoEnc"]
2424

@@ -74,6 +74,7 @@ def __init__(
7474

7575
super().__init__()
7676

77+
self.patch_size = ensure_tuple_rep(patch_size, spatial_dims)
7778
self.spatial_dims = spatial_dims
7879

7980
self.patch_embedding = PatchEmbeddingBlock(
@@ -105,14 +106,15 @@ def forward(self, x):
105106
x: input tensor must have isotropic spatial dimensions,
106107
such as ``[batch_size, channels, sp_size, sp_size[, sp_size]]``.
107108
"""
109+
spatial_size = x.shape[2:]
108110
x = self.patch_embedding(x)
109111
hidden_states_out = []
110112
for blk in self.blocks:
111113
x = blk(x)
112114
hidden_states_out.append(x)
113115
x = self.norm(x)
114116
x = x.transpose(1, 2)
115-
d = [round(math.pow(x.shape[2], 1 / self.spatial_dims))] * self.spatial_dims
117+
d = [s // p for s, p in zip(spatial_size, self.patch_size)]
116118
x = torch.reshape(x, [x.shape[0], x.shape[1], *d])
117119
x = self.conv3d_transpose(x)
118120
x = self.conv3d_transpose_1(x)

tests/test_vitautoenc.py

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,25 @@
4141

4242
TEST_CASE_Vitautoenc.append(test_case)
4343

44+
TEST_CASE_Vitautoenc.append(
45+
[
46+
{
47+
"in_channels": 1,
48+
"img_size": (512, 512, 32),
49+
"patch_size": (16, 16, 16),
50+
"hidden_size": 768,
51+
"mlp_dim": 3072,
52+
"num_layers": 4,
53+
"num_heads": 12,
54+
"pos_embed": "conv",
55+
"dropout_rate": 0.6,
56+
"spatial_dims": 3,
57+
},
58+
(2, 1, 512, 512, 32),
59+
(2, 1, 512, 512, 32),
60+
]
61+
)
62+
4463

4564
class TestPatchEmbeddingBlock(unittest.TestCase):
4665
@parameterized.expand(TEST_CASE_Vitautoenc)

0 commit comments

Comments (0)