UNet vs BasicUNet speed and memory usage. #3048
Replies: 4 comments 1 reply
-
Hi @ericspod, could you please share some details of the implementation that may cause this difference? Thanks in advance.
-
passing both through Output from
Output from
-
To follow on from @rijobro's analysis, running the following will produce a full report of the layers in the networks, as you've already done:

from monai.networks.nets import UNet, BasicUNet
import torchinfo
unet = UNet(
spatial_dims=3,  # named `dimensions` in older MONAI releases, where it was later deprecated
in_channels=1,
out_channels=2,
channels=(16, 32, 64, 128, 256),
strides=(2, 2, 2, 2),
num_res_units=2,
)
basic_unet = BasicUNet()
print(torchinfo.summary(unet, (1, 1, 128, 128, 128), depth=20))
print(torchinfo.summary(basic_unet, (1, 1, 128, 128, 128), depth=20))

The slight difference here is setting
What we're seeing here with an input of
-
Hi, I was looking at the implementations of both `UNet` and `BasicUNet`. While the `BasicUNet` implementation is very straightforward, the `UNet` implementation uses a recursive function to build the network level by level. I tested both using common configurations that yielded a somewhat similar parameter count. Here are the functions I used:

And these are the tested architectures:

The results I got showed that `UNet` is almost 10x faster than `BasicUNet` for performing inference on a `(1, 1, 128, 128, 128)` volume of ones. Are there any specific architectural changes that would explain such a difference between the inference times?

Furthermore, if I try increasing the batch size from 1 to 2, i.e. `(2, 1, 128, 128, 128)`, I run into a memory error using `BasicUNet`, while for `UNet` I can increase it much more. In `BasicUNet` the tensors `x0` to `x4` and `u4` to `u1` are kept until `logits` is returned. Will the recursive implementation help with the memory consumption, or is this just the benefit of using a 2-strided convolution?

Any comments are very much appreciated! Thanks!
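One rough way to reason about the strided-convolution part of this question is plain arithmetic on feature-map sizes. The sketch below is not taken from either implementation. It assumes float32 activations, that `BasicUNet` is used with its default first feature width of 32 at full input resolution, and that the first `UNet` layer (16 channels, stride 2) halves every spatial dimension of the input. `activation_mib` is a hypothetical helper, and it sizes only a single first-level tensor per network, not the total peak memory.

```python
from math import prod

BYTES_PER_FLOAT32 = 4  # assumption: activations are stored as float32


def activation_mib(channels, spatial, batch=1):
    """Size in MiB of one float32 tensor of shape (batch, channels, *spatial)."""
    return batch * channels * prod(spatial) * BYTES_PER_FLOAT32 / 2**20


input_spatial = (128, 128, 128)

# BasicUNet (assumed default first feature width of 32): the first double
# convolution keeps the full input resolution.
basic_first = activation_mib(32, input_spatial)

# UNet as configured in the thread: channels[0] = 16 and strides[0] = 2,
# so the first layer is assumed to halve each spatial dimension.
unet_spatial = tuple(s // 2 for s in input_spatial)
unet_first = activation_mib(16, unet_spatial)

print(f"BasicUNet first-level feature map: {basic_first:.0f} MiB")
print(f"UNet first-level feature map:      {unet_first:.0f} MiB")
print(f"ratio: {basic_first / unet_first:.0f}x")
```

Under these assumptions a single first-level feature map is 256 MiB for `BasicUNet` against 16 MiB for `UNet`, and the figures scale linearly with batch size. Several such full-resolution tensors exist in `BasicUNet` (intermediate convolution outputs plus the skip connection), so the early stride alone may account for much of the gap, independently of whether the network is built recursively.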