ViT not working for multi-class classification #4525
Replies: 2 comments 4 replies
Thanks for starting the discussion.
1.) Why are you using out_channels=4 with DenseNet121 but num_classes=10 for ViT? Is it because you tested DenseNet on 2D rotation as a classification task?
2.) If the answer to the question above is "yes", I would suggest trying the same with ViT first, before claiming it does not work. :)
3.) I would also encourage you to share the training and validation curves for DenseNet and for ViT, so that we can help you better.
As an additional point, is 3D rotation posed this way a well-posed problem? For classification, one usually uses a 4-class rotation task by fixing an axis. If you have a paper/reference for 3D rotation classification, please share it; it would help in understanding the use case and what you have implemented.
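To make the "4-class rotation by fixing an axis" idea concrete, here is a minimal sketch using NumPy. It is an illustration, not the code from this thread: the function name `make_rotation_sample`, the (D, H, W) patch shape, and the choice of rotating in the H-W plane are all assumptions.

```python
import numpy as np

def make_rotation_sample(volume, rng):
    """Rotate a (D, H, W) patch by k * 90 degrees in the H-W plane.

    Returns the rotated patch and the class label k in {0, 1, 2, 3}.
    The rotation axis is fixed, so there are exactly 4 classes.
    """
    k = int(rng.integers(0, 4))  # hypothetical label: number of 90-degree turns
    rotated = np.rot90(volume, k=k, axes=(1, 2)).copy()
    return rotated, k

# Random data standing in for a CT patch (assumed square in H and W,
# so the shape is unchanged by the rotation).
rng = np.random.default_rng(0)
volume = rng.random((8, 16, 16), dtype=np.float32)
rotated, label = make_rotation_sample(volume, rng)
```

With the axis fixed, each label corresponds to exactly one transform, and rotating back by the same number of turns recovers the original patch.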
@fengling0410 Thanks for sharing the reference and the plots. The training loss does seem to be going down, but the validation plot for ViT does not reflect the same trend. It looks like the learning rate is too high: the training loss is decreasing, but not smoothly. Can you try lowering the learning rate for the ViT model, and train with only 4 classes for now?
Hi all, I'm doing self-supervised learning for a CT scan ViT encoder. The pre-training task is rotation where I rotate the image patches according to 10 directions (so I have 10 classes in total) and let the ViT classification model predict the class. However, I find that the ViT classification model doesn't work. The model constantly gives the same prediction for all the inputs. But when I replace the ViT with DenseNet121, it works. I'm pretty confused about this and want to ask the community for advice.
Here is my code:
After I replace the ViT model with
model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=4).to(device)
the code works well. I would appreciate any advice from the community. Thank you in advance!
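The symptom described in this post, the same prediction for every input, can be confirmed with a quick check on the validation predictions. The sketch below is only an illustration: `prediction_report` is a hypothetical helper, and the label and prediction lists are made-up values, not results from this thread.

```python
from collections import Counter

def prediction_report(predictions, labels):
    """Summarise how predictions are spread across classes.

    A collapsed model predicts a single class, so its accuracy can be
    no better than the share of that class in the labels.
    """
    pred_counts = Counter(predictions)
    label_counts = Counter(labels)
    n = len(labels)
    accuracy = sum(p == y for p, y in zip(predictions, labels)) / n
    return {
        "pred_counts": dict(pred_counts),
        "accuracy": accuracy,
        "majority_baseline": max(label_counts.values()) / n,
        "collapsed": len(pred_counts) == 1,
    }

# Made-up example: 4 balanced classes, model always predicts class 2.
labels = [0, 1, 2, 3, 0, 1, 2, 3]
predictions = [2] * len(labels)
report = prediction_report(predictions, labels)
print(report)
```

If `collapsed` is true and the accuracy sits at the majority-class baseline, the model has probably not learned anything from the inputs yet, whatever the training loss looks like.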