Hi, I noticed that in the tutorial, load_from is used to initialize the Swin UNETR encoder from self-supervised pre-trained weights, while for the test the full model weights are loaded with torch.load and load_state_dict.
I would like to know why it is necessary to use load_from rather than simply filtering out the mismatched keys and using load_state_dict. I ask because when I load a pre-trained model to continue training on the same task, the loss is higher than it should be and performance degrades; using load_from instead does not solve the problem either. I would appreciate it if someone could point out the difference between the two functions. Thanks.
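For concreteness, the manual alternative I have in mind looks roughly like the following sketch; the model configuration follows the tutorial's BTCV setup, and the checkpoint path is a placeholder:

```python
import torch
from monai.networks.nets import SwinUNETR

# configuration follows the tutorial's BTCV example; adjust to your task
# (img_size is deprecated in newer MONAI releases and can be omitted there)
model = SwinUNETR(img_size=(96, 96, 96), in_channels=1, out_channels=14, feature_size=48)

# "checkpoint.pth" is a placeholder for my own pre-trained weights
state_dict = torch.load("checkpoint.pth", map_location="cpu")
if "state_dict" in state_dict:
    state_dict = state_dict["state_dict"]

model_dict = model.state_dict()
# keep only entries whose key and tensor shape match the current model
filtered = {k: v for k, v in state_dict.items()
            if k in model_dict and v.shape == model_dict[k].shape}
model_dict.update(filtered)
model.load_state_dict(model_dict)
```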
Hi @Levishery, thanks for the question.
The first load_from is used to load the pre-trained weights from self-supervised learning before fine-tuning on the segmentation task. The pre-training checkpoint contains only the Swin Transformer, which serves as the encoder of Swin UNETR, so the load_from function matches keys only for the Swin UNETR encoder.
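As a rough sketch, assuming the checkpoint name and model configuration from the tutorial:

```python
import torch
from monai.networks.nets import SwinUNETR

# same configuration as in the tutorial's fine-tuning script
model = SwinUNETR(img_size=(96, 96, 96), in_channels=1, out_channels=14, feature_size=48)

# "model_swinvit.pt" is the self-supervised checkpoint from the tutorial;
# load_from copies only the keys belonging to the swinViT encoder and
# leaves the randomly initialized decoder untouched
ssl_weights = torch.load("./model_swinvit.pt")
model.load_from(weights=ssl_weights)
```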
Later, for the test, the weights of the entire Swin UNETR model need to be loaded, so the checkpoint is loaded directly with torch.load and passed to load_state_dict.
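Continuing the sketch above, the test-time path is the standard PyTorch load of a full-model checkpoint (the filename is a placeholder for whatever you saved during fine-tuning):

```python
# the whole network (encoder + decoder) was saved during fine-tuning,
# so every key matches and a strict load works
model.load_state_dict(torch.load("best_metric_model.pth"))
model.eval()
```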
For your case: if your pre-trained model was trained on the entire Swin UNETR, you can load its weights directly with torch.load and load_state_dict; if your pre-trained model covers only the encoder, load_from might help.
Thanks.