ValueError: Sequence must have length 3, got 2. when modifying UNETR for 2D RGB images #3746
-
Dear all, I want to use UNETR for 2D RGB images (shape: (256,256,3)) segmentation. I changed the UNETR code to
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Union
import torch
import torch.nn as nn
from monai.networks.blocks.dynunet_block import UnetOutBlock
from monai.networks.blocks import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock
from monai.networks.nets import ViT
class UNETR2D(nn.Module):
"""
UNETR based on: "Hatamizadeh et al.,
UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>"
"""
def __init__(
self,
in_channels: int,
out_channels: int,
img_size: Tuple[int, int],
feature_size: int = 16,
hidden_size: int = 768,
mlp_dim: int = 3072,
num_heads: int = 12,
pos_embed: str = "perceptron",
norm_name: Union[Tuple, str] = "instance",
conv_block: bool = False,
res_block: bool = True,
dropout_rate: float = 0.0,
) -> None:
super().__init__()
if not (0 <= dropout_rate <= 1):
raise AssertionError("dropout_rate should be between 0 and 1.")
if hidden_size % num_heads != 0:
raise AssertionError("hidden size should be divisible by num_heads.")
if pos_embed not in ["conv", "perceptron"]:
raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.")
self.num_layers = 12
self.patch_size = (16, 16)
self.feat_size = (
img_size[0] // self.patch_size[0],
img_size[1] // self.patch_size[1]
)
self.hidden_size = hidden_size
self.classification = False
self.vit = ViT(
in_channels=in_channels,
img_size=img_size,
patch_size=self.patch_size,
hidden_size=hidden_size,
mlp_dim=mlp_dim,
num_layers=self.num_layers,
num_heads=num_heads,
pos_embed=pos_embed,
classification=self.classification,
dropout_rate=dropout_rate,
)
self.encoder1 = UnetrBasicBlock(
spatial_dims=2,
in_channels=in_channels,
out_channels=feature_size,
kernel_size=3,
stride=1,
norm_name=norm_name,
res_block=res_block,
)
self.encoder2 = UnetrPrUpBlock(
spatial_dims=2,
in_channels=hidden_size,
out_channels=feature_size * 2,
num_layer=2,
kernel_size=3,
stride=1,
upsample_kernel_size=2,
norm_name=norm_name,
conv_block=conv_block,
res_block=res_block,
)
self.encoder3 = UnetrPrUpBlock(
spatial_dims=2,
in_channels=hidden_size,
out_channels=feature_size * 4,
num_layer=1,
kernel_size=3,
stride=1,
upsample_kernel_size=2,
norm_name=norm_name,
conv_block=conv_block,
res_block=res_block,
)
self.encoder4 = UnetrPrUpBlock(
spatial_dims=2,
in_channels=hidden_size,
out_channels=feature_size * 8,
num_layer=0,
kernel_size=3,
stride=1,
upsample_kernel_size=2,
norm_name=norm_name,
conv_block=conv_block,
res_block=res_block,
)
self.decoder5 = UnetrUpBlock(
spatial_dims=2,
in_channels=hidden_size,
out_channels=feature_size * 8,
kernel_size=3,
upsample_kernel_size=2,
norm_name=norm_name,
res_block=res_block,
)
self.decoder4 = UnetrUpBlock(
spatial_dims=2,
in_channels=feature_size * 8,
out_channels=feature_size * 4,
kernel_size=3,
upsample_kernel_size=2,
norm_name=norm_name,
res_block=res_block,
)
self.decoder3 = UnetrUpBlock(
spatial_dims=2,
in_channels=feature_size * 4,
out_channels=feature_size * 2,
kernel_size=3,
upsample_kernel_size=2,
norm_name=norm_name,
res_block=res_block,
)
self.decoder2 = UnetrUpBlock(
spatial_dims=2,
in_channels=feature_size * 2,
out_channels=feature_size,
kernel_size=3,
upsample_kernel_size=2,
norm_name=norm_name,
res_block=res_block,
)
self.out = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels) # type: ignore
def proj_feat(self, x, hidden_size, feat_size):
x = x.view(x.size(0), feat_size[0], feat_size[1], hidden_size)
x = x.permute(0, 3, 1, 2).contiguous()
return x
# def load_from(self, weights):
# with torch.no_grad():
# res_weight = weights
# # copy weights from patch embedding
# for i in weights['state_dict']:
# print(i)
# self.vit.patch_embedding.position_embeddings.copy_(weights['state_dict']['module.transformer.patch_embedding.position_embeddings_3d'])
# self.vit.patch_embedding.cls_token.copy_(weights['state_dict']['module.transformer.patch_embedding.cls_token'])
# self.vit.patch_embedding.patch_embeddings[1].weight.copy_(weights['state_dict']['module.transformer.patch_embedding.patch_embeddings.1.weight'])
# self.vit.patch_embedding.patch_embeddings[1].bias.copy_(weights['state_dict']['module.transformer.patch_embedding.patch_embeddings.1.bias'])
# # copy weights from encoding blocks (default: num of blocks: 12)
# for bname, block in self.vit.blocks.named_children():
# print(block)
# block.loadFrom(weights, n_block=bname)
# # last norm layer of transformer
# self.vit.norm.weight.copy_(weights['state_dict']['module.transformer.norm.weight'])
# self.vit.norm.bias.copy_(weights['state_dict']['module.transformer.norm.bias'])
def forward(self, x_in):
x, hidden_states_out = self.vit(x_in)
enc1 = self.encoder1(x_in)
x2 = hidden_states_out[3]
enc2 = self.encoder2(self.proj_feat(x2, self.hidden_size, self.feat_size))
x3 = hidden_states_out[6]
enc3 = self.encoder3(self.proj_feat(x3, self.hidden_size, self.feat_size))
x4 = hidden_states_out[9]
enc4 = self.encoder4(self.proj_feat(x4, self.hidden_size, self.feat_size))
dec4 = self.proj_feat(x, self.hidden_size, self.feat_size)
dec3 = self.decoder5(dec4, enc4)
dec2 = self.decoder4(dec3, enc3)
dec1 = self.decoder3(dec2, enc2)
out = self.decoder2(dec1, enc1)
logits = self.out(out)
return logits Load the mode networks.unetr2d import UNETR2D
model = UNETR2D(
in_channels=3, # 3 channels, R,G,B
out_channels=3,
img_size=(256, 256),
feature_size=16,
hidden_size=768,
mlp_dim=3072,
num_heads=12,
pos_embed="perceptron",
norm_name="instance",
res_block=True,
dropout_rate=0.0,
) However, I got errors as follows Traceback (most recent call last):
File "/tmp/ipykernel_1400599/1295113233.py", line 1, in <module>
model = UNETR2D(
File "***/unetr2d.py", line 86, in __init__
self.vit = ViT(
File "/home/***/anaconda3/lib/python3.9/site-packages/monai/networks/nets/vit.py", line 82, in __init__
self.patch_embedding = PatchEmbeddingBlock(
File "/home/***/anaconda3/lib/python3.9/site-packages/monai/networks/blocks/patchembedding.py", line 75, in __init__
img_size = ensure_tuple_rep(img_size, spatial_dims)
File "/home/***/anaconda3/lib/python3.9/site-packages/monai/utils/misc.py", line 137, in ensure_tuple_rep
raise ValueError(f"Sequence must have length {dim}, got {len(tup)}.")
ValueError: Sequence must have length 3, got 2. How should I fix this error? Any comments would be highly appreciated. |
Beta Was this translation helpful? Give feedback.
Replies: 2 comments 1 reply
-
Hi @ahatamiz , Could you please help share some comments here? Thanks in advance. |
Beta Was this translation helpful? Give feedback.
-
Thanks for the interest in our work. I made some modifications in your provided snippet. The 2D UNETR should work without any issues as below:
|
Beta Was this translation helpful? Give feedback.
Hi @EdwardZhao1991
Thanks for the interest in our work. I made some modifications in your provided snippet. The 2D UNETR should work without any issues as below: