Skip to content

Conversation

@stmartineau99
Copy link
Contributor

@stmartineau99 stmartineau99 commented Jul 22, 2025

  • added new class NewPseudoLabeler to subtract background from the pseudo labels if a background mask is given
  • updated get_unsupervised_loader to handle 4 cases:
  1. both sample mask and background mask are given, this will be be used by the train loader
  2. sample mask only, this will be used by the validation loader
  3. background mask only, not implemented error
  4. neither mask given, use the default behavior
  • created new classes to ensure that get_unsupervised_loader handles each case correctly: DropChannel, ChannelWiseRawTransform, ChannelWiseAugmentations
  • updated mean_teacher_adaptation to use the NewPseudoLabeler if the background mask is given
  • added subclass NewMeanTeacherTrainer, which drops the background mask from the teacher input after computing pseudo labels. it also drops the background mask from the student input since this is not used. This behavior only occurs when it recieves training data, and not for validation.
  • fixed the NewPseudoLabeler so that it behaves correctly when it recieves training and validation data, since in the case of validation there is no background mask channel.

Copy link
Contributor

@constantinpape constantinpape left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks very good now! The only things to address:

  • You are not actually using the new class for the mean teacher trainer.
  • Some cosmetics.

See details in the comments.

@stmartineau99
Copy link
Contributor Author

I think some changes for these files did not get merged to the main branch:

  • training/semisupervised_training.py
  • training/domain_adaptation.py

@stmartineau99
Copy link
Contributor Author

I added gpu device arguments to some other files, and my personal training scripts.

# Sample from both the supervised and unsupervised loader.
for xu1, xu2 in self.unsupervised_train_loader:

# Assuming shape (B, C, D, H, W), only keep the first channel for xu2 (student input).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please add the assert here or throw some other kind of error if we have a wrong number of channels @stmartineau99 ? See previous comment.

Args:
data_paths: The filepaths to the hdf5 files containing the training data.
data_paths: The filepaths to the mrc files containing the training data.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
data_paths: The filepaths to the mrc files containing the training data.
data_paths: The filepaths to the mrc or hdf5 files containing the training data.

in_channels: int = 1,
out_channels: int = 2,
mask_channel: bool = False,
device: int = 0,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that it's better to pass a torch.device here that is optional:

device: Optional[torch.device] = None

Then you don't need any further changes below.

x = f(x)
return x

class ChannelWiseAugmentations:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class is creating an issue with serialization in Case 1 only (see below) because it is called below like: augmentations = (ChannelWiseAugmentations(weak_augmentations()), ChannelWiseAugmentations(weak_augmentations()))

I rewrote the class to avoid having nested function calls inside get_unsupervised_loader() but this same error still happens.

Copy link
Contributor Author

@stmartineau99 stmartineau99 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • mainly changes to semisupervised_training.py - get_unsupervised_loader()
  • changes to domain_adaptation.py - responded to your last comment about asserting dims.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants