-
Notifications
You must be signed in to change notification settings - Fork 3
background mask option for pseudo labeler - unsupervised training #135
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
background mask option for pseudo labeler - unsupervised training #135
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks very good now! The only things to address:
- You are not actually using the new class for the mean teacher trainer.
- Some cosmetics.
See details in the comments.
|
I think some changes for these files did not get merged to the main branch:
|
|
I added gpu device arguments to some other files, and my personal training scripts. |
| # Sample from both the supervised and unsupervised loader. | ||
| for xu1, xu2 in self.unsupervised_train_loader: | ||
|
|
||
| # Assuming shape (B, C, D, H, W), only keep the first channel for xu2 (student input). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please add the assert here or throw some other kind of error if we have a wrong number of channels @stmartineau99 ? See previous comment.
| Args: | ||
| data_paths: The filepaths to the hdf5 files containing the training data. | ||
| data_paths: The filepaths to the mrc files containing the training data. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| data_paths: The filepaths to the mrc files containing the training data. | |
| data_paths: The filepaths to the mrc or hdf5 files containing the training data. |
| in_channels: int = 1, | ||
| out_channels: int = 2, | ||
| mask_channel: bool = False, | ||
| device: int = 0, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that it's better to pass a torch.device here that is optional:
device: Optional[torch.device] = NoneThen you don't need any further changes below.
| x = f(x) | ||
| return x | ||
|
|
||
| class ChannelWiseAugmentations: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This class is creating an issue with serialization in Case 1 only (see below) because it is called below like: augmentations = (ChannelWiseAugmentations(weak_augmentations()), ChannelWiseAugmentations(weak_augmentations()))
I rewrote the class to avoid having nested function calls inside get_unsupervised_loader() but this same error still happens.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- mainly changes to semisupervised_training.py - get_unsupervised_loader()
- changes to domain_adaptation.py - responded to your last comment about asserting dims.
NewPseudoLabelerto subtract background from the pseudo labels if a background mask is givenget_unsupervised_loaderto handle 4 cases:get_unsupervised_loaderhandles each case correctly:DropChannel,ChannelWiseRawTransform,ChannelWiseAugmentationsmean_teacher_adaptationto use theNewPseudoLabelerif the background mask is givenNewMeanTeacherTrainer, which drops the background mask from the teacher input after computing pseudo labels. it also drops the background mask from the student input since this is not used. This behavior only occurs when it recieves training data, and not for validation.NewPseudoLabelerso that it behaves correctly when it recieves training and validation data, since in the case of validation there is no background mask channel.