-
Notifications
You must be signed in to change notification settings - Fork 1
Dev gan trainer #24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Dev gan trainer #24
Conversation
…ce for complex orchestration that involves combining the context outputs from multiple forward groups.
…AN training process of updating of generator and discriminator separately.
…mprove readability
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I double checked and found the tests all pass, nice job. Given the great stuff going on with testing you might consider making testing a part of automated jobs with a GitHub Actions job. The benefit here is that it'd provide you and reviewers confidence about the changes within the context of the pull request (similar to how the pre-commit checks operate).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I will save the CI test integration for a separate future PR as a careful pass of the test suite to mark and ensure skipping of cuda dependent tests and add cpu alternatives to tests that are only written for cuda as cuda enabled gpu is not offered on the free github action runners (unless I am wrong).
… inputs, targets, and preds.
… optimizer management
… forward groups derived context objects
…ndling and streamline output dimension calculations
…to use lowercase tuple syntax
…lemented to raising NotImplementedError
…od to raise KeyError
…to support a variety of forward function signatures in concrete loss classes, the old abstract AbstractLoss class has been replaced with a protocol and a non-abstract BaseLoss. Related modules and methods using the old abstract losses are also updated. And a new module is implemented to hourse brushed up wGAN training losses whereas the obsetele modules were removed.
… for inputs and outputs, enhancing type clarity and consistency.
…l instead of BaseGeneratorModel for consistency.
… for flexibility.
…flowLogger for optional dictionary parameters
…weights and add model registry for logging parameters
- Introduced a new script for training a Wasserstein GAN (WGAN) using a U-Net generator. - Implemented logging with MLflow to track metrics and visualize predictions during training. - Included dataset processing utilities to handle image files and create a cropped dataset for training. - Added functionality to visualize training outcomes and metrics from the MLflow tracking server. - Configured hyperparameters suitable for demo purposes, including learning rates and batch size.
|
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
|
Thanks @d33bs. Addressed all your comments except the CI test integration is which is better as a separate future PR. Merging now! |
This PR adds back the Generative Adversarial Network (GAN) training functionality following its removal in refactor #19 where the model forward pass logic is decoupled from trainer and promoted as a separate engine subpackage (but #19 only included the engine parts for regular UNet training).
Introduces
src/virtual_stain_flow/engine/forward_groups.py:Added
DiscriminatorForwardGroup, a standardized interface for discriminator forward passessrc/virtual_stain_flow/engine/orchestrators.py:An additional layer of abstraction.
OrchestratedStepandGANOrchestratorbetween the forward passing engine and the trainer classes to define complex orchestrations involving more than one model (neural network), as needed when training GANs.Specifically,
GANOrchestratorprovides training interface to separately train the generator and discriminatorhence the orchestrator abstract helps makes things cleaner.
New operations for Context class
Minimal tests