UNet vs BasicUNet speed and memory usage. #3048
Hi @ericspod, could you please share some details of the implementation that may cause this difference? Thanks in advance.
Passing both through … Output from … Output from …
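To follow on from @rijobro's analysis, running the following will produce a full report of the layers in the networks, as you've already done:

```python
import torchinfo
from monai.networks.nets import BasicUNet, UNet

# UNet with five levels and stride-2 downsampling between them
unet = UNet(
    spatial_dims=3,  # this argument was called `dimensions` in older MONAI releases
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)
# BasicUNet with its defaults (3D, 1 input channel, 2 output channels)
basic_unet = BasicUNet()

# depth=20 expands every nested module so the recursive UNet is shown in full
print(torchinfo.summary(unet, (1, 1, 128, 128, 128), depth=20))
print(torchinfo.summary(basic_unet, (1, 1, 128, 128, 128), depth=20))
```

The slight difference here is setting … What we're seeing here with an input of …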
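`torchinfo.summary` reports layer shapes and parameter counts but not run time or peak memory. To measure those two quantities directly, a minimal sketch could look like the following. The helper name `benchmark` and its parameters are hypothetical (not part of MONAI, torchinfo, or the code used in this thread); it only relies on standard PyTorch calls and assumes float32 inputs of ones, as in the question.

```python
import time

import torch


def benchmark(model: torch.nn.Module, input_shape, device: str = "cpu", repeats: int = 5, warmup: int = 1):
    """Return (mean seconds per forward pass, peak CUDA bytes or None on CPU).

    Hypothetical helper for comparing networks such as UNet and BasicUNet.
    """
    model = model.to(device).eval()
    x = torch.ones(input_shape, device=device)  # all-ones volume, as in the question
    use_cuda = device.startswith("cuda")

    with torch.inference_mode():  # no autograd graph is recorded during inference
        for _ in range(warmup):  # first calls include one-off costs (e.g. cuDNN autotuning)
            model(x)
        if use_cuda:
            torch.cuda.synchronize(device)
            torch.cuda.reset_peak_memory_stats(device)
        start = time.perf_counter()
        for _ in range(repeats):
            model(x)
        if use_cuda:
            # CUDA kernels run asynchronously; wait for them before stopping the clock
            torch.cuda.synchronize(device)
        elapsed = (time.perf_counter() - start) / repeats

    peak = torch.cuda.max_memory_allocated(device) if use_cuda else None
    return elapsed, peak
```

With the networks defined above, a call such as `benchmark(unet, (2, 1, 128, 128, 128), device="cuda")` versus the same call with `basic_unet` would compare both speed and peak allocator memory at batch size 2. Peak memory is only tracked on CUDA here because `torch.cuda.max_memory_allocated` has no CPU counterpart.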
Hi, I was looking at the implementations of both `UNet` and `BasicUNet`. While the `BasicUNet` implementation is very straightforward, the `UNet` implementation uses a recursive function to build the network level by level. I tested both using common configurations that yield a roughly similar parameter count. Here are the functions I used: … And these are the tested architectures: …

The results I got showed that `UNet` is almost 10x faster than `BasicUNet` when performing inference on a `(1, 1, 128, 128, 128)` volume of ones. Are there any specific architectural differences that would explain such a gap in inference time?

Furthermore, if I increase the batch size from 1 to 2, i.e. `(2, 1, 128, 128, 128)`, I run into a memory error using `BasicUNet`, while for `UNet` I can increase it much further. In `BasicUNet` the tensors `x0` to `x4` and `u4` to `u1` are kept until `logits` is returned. Will the recursive implementation help with the memory consumption, or is this just the benefit of using a stride-2 convolution?

Any comments are very much appreciated! Thanks!
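To put rough numbers on the memory question, the sketch below counts only the float32 encoder outputs that stay alive until the output is produced. The level sizes are assumptions read off the configurations in this thread: `BasicUNet` with its default `features=(32, 32, 64, 128, 256, 32)`, where `x0` stays at full resolution, and the `UNet` above, whose first layer already uses stride 2 and whose bottom layer keeps the 8^3 resolution. The helper `feature_map_mib` is hypothetical, and temporary tensors inside each convolution block and allocator overhead are ignored, so these are not predicted peaks.

```python
MIB = 1024 ** 2


def feature_map_mib(channels: int, side: int, batch: int = 1, bytes_per_element: int = 4) -> float:
    """MiB needed for one float32 tensor of shape (batch, channels, side, side, side)."""
    return batch * channels * side ** 3 * bytes_per_element / MIB


# (channels, side) of BasicUNet's encoder outputs x0..x4 with default features:
# x0 is computed at the full 128^3 input resolution, later levels max-pool by 2.
basic_unet_skips = [(32, 128), (32, 64), (64, 32), (128, 16), (256, 8)]

# Assumed (channels, side) per level of the UNet above, channels=(16, 32, 64, 128, 256),
# strides=(2, 2, 2, 2): the first layer already halves the resolution,
# and the bottom layer (stride 1) stays at 8^3.
unet_levels = [(16, 64), (32, 32), (64, 16), (128, 8), (256, 8)]

for name, levels in [("BasicUNet", basic_unet_skips), ("UNet", unet_levels)]:
    sizes = [feature_map_mib(c, s) for c, s in levels]
    print(f"{name}: largest {max(sizes):.1f} MiB, total {sum(sizes):.1f} MiB")
```

Under these assumptions BasicUNet's full-resolution `x0` alone takes 256 MiB, sixteen times the largest stored UNet tensor (16 channels at 64^3, 16 MiB), and passing `batch=2` doubles every figure. BasicUNet's decoder output `u1` is again 32 channels at full resolution. This suggests that where the first downsampling happens matters more than whether the network is built recursively, though a profiler run would be needed to confirm it.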