Implement GQE Manuscript V2 features and refactor architecture #373
Open: konakaji wants to merge 9 commits into NVIDIA:main from konakaji:feature/gqe-customization-and-grpo
Conversation
- Split transformer.py into pipeline.py and model.py for better modularity
  - pipeline.py: contains the Pipeline class for data processing
  - model.py: contains the GPT2 model class
- Updated gqe.py to import from the new modules
- Improves code organization and maintainability
Signed-off-by: Kohei Nakaji <[email protected]>
- Add factory.py: Factory class for creating loss functions
- Refactor pipeline.py: simplify code and remove redundant logic
- Update gqe.py: improve code structure and readability
- Update loss.py: minor improvements for better consistency
- Reduces code complexity and improves maintainability
Signed-off-by: Kohei Nakaji <[email protected]>
- Add scheduler.py: contains TemperatureScheduler, DefaultScheduler, and CosineScheduler
- Update gqe.py: import schedulers from the new module and remove the in-file scheduler definitions
- Improves code organization and separation of concerns
- Reduces gqe.py complexity by ~70 lines
Signed-off-by: Kohei Nakaji <[email protected]>
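For illustration, a minimal sketch of what such a scheduler interface might look like. The class names and the get_inverse_temperature/update methods follow the commit messages; the constructor parameters (beta_min, beta_max, total_steps) and the exact cosine ramp are assumptions, not the PR's actual signature.

```python
# Hypothetical sketch of the scheduler interface named above; the real
# scheduler.py in this PR may differ in details.
import math
from abc import ABC, abstractmethod


class TemperatureScheduler(ABC):
    """Controls the inverse temperature (beta) used during training."""

    @abstractmethod
    def get_inverse_temperature(self) -> float:
        ...

    @abstractmethod
    def update(self, **kwargs) -> None:
        """Advance the schedule by one training step."""
        ...


class CosineScheduler(TemperatureScheduler):
    """Ramps beta from beta_min to beta_max along a cosine curve."""

    def __init__(self, beta_min: float, beta_max: float, total_steps: int):
        self.beta_min, self.beta_max = beta_min, beta_max
        self.total_steps = total_steps
        self.step = 0

    def get_inverse_temperature(self) -> float:
        progress = min(self.step / self.total_steps, 1.0)
        # Cosine ramp: starts at beta_min, ends at beta_max.
        return self.beta_min + 0.5 * (self.beta_max - self.beta_min) * (
            1.0 - math.cos(math.pi * progress)
        )

    def update(self, **kwargs) -> None:
        self.step += 1
```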
- Migrate from the manual Fabric loop to the Lightning Trainer
- Extract ReplayBuffer and BufferDataset to data.py
- Create callbacks.py with MinEnergyCallback and TrajectoryCallback
- Simplify the TemperatureScheduler interface (get_inverse_temperature, update)
- Add Factory.create_temperature_scheduler method
- Move seed_everything to Pipeline.__init__
- Disable checkpointing for performance (2.5s -> 0.001s between epochs)
- Add num_sanity_val_steps=0 to suppress warnings
- Fix device placement issues for CUDA tensors
- Set DataLoader num_workers=0 to avoid pickling SpinOperator
Signed-off-by: Kohei Nakaji <[email protected]>
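For illustration, a minimal sketch of what the extracted ReplayBuffer and BufferDataset could look like. Only the class names come from the commit message; the capacity limit, eviction policy, and stored fields are assumptions.

```python
# Hypothetical sketch of the replay-buffer pieces mentioned above (data.py);
# the actual classes in this PR may store different fields.
import random
from dataclasses import dataclass, field

import torch
from torch.utils.data import Dataset


@dataclass
class ReplayBuffer:
    """Fixed-capacity store of (token sequence, energy) pairs from past epochs."""
    capacity: int = 1000
    items: list = field(default_factory=list)

    def add(self, tokens: torch.Tensor, energy: float) -> None:
        if len(self.items) >= self.capacity:
            self.items.pop(0)  # drop the oldest entry
        self.items.append((tokens, energy))

    def sample(self, batch_size: int) -> list:
        return random.sample(self.items, min(batch_size, len(self.items)))


class BufferDataset(Dataset):
    """Thin Dataset wrapper so the buffer can feed a Lightning DataLoader."""

    def __init__(self, buffer: ReplayBuffer):
        self.buffer = buffer

    def __len__(self) -> int:
        return len(self.buffer.items)

    def __getitem__(self, idx: int):
        tokens, energy = self.buffer.items[idx]
        return tokens, torch.tensor(energy, dtype=torch.float32)
```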
- Add GRPOLoss class inheriting from the Loss base class
- Update loss.compute() signature to use **kwargs instead of context
- Add GRPOLoss to Factory with configurable clip_ratio
- Remove an unnecessary logger check in Pipeline
- Update the gqe_h2.py example with max_iters=50
- Clean up pyscf-generated files (.log, .chk)
Signed-off-by: Kohei Nakaji <[email protected]>
- Implement a variance-based adaptive temperature scheduler
  - Adjusts temperature based on the energy variance in training batches
  - Increases temperature for high variance (exploration)
  - Decreases temperature for low variance (exploitation)
- Add scheduler factory support
  - Extend Factory.create_temperature_scheduler() to support 'variance' mode
  - Configure via cfg.scheduler='variance' and cfg.target_variance
- Fix Loss classes to properly inherit from torch.nn.Module
  - Add super().__init__() calls to all Loss subclasses
  - Fix device placement issues in GFlowLogitMatching
- Enhance test coverage
  - Add test_variance_scheduler() for VarBasedScheduler unit testing
  - Add test_solvers_gqe_with_variance_scheduler() for integration testing
  - Add test_solvers_gqe_with_cosine_scheduler() for CosineScheduler
  - Add test_solvers_gqe_with_exp_loss() for ExpLogitMatching
  - Fix existing scheduler tests to use the new API methods
- All 11 GQE tests pass successfully
Signed-off-by: Kohei Nakaji <[email protected]>
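For illustration, a minimal sketch of the variance-based rule described above. The class name VarBasedScheduler, the target_variance setting, and the get_inverse_temperature/update interface follow the commit message; the multiplicative update rate and its value are assumptions.

```python
# Hypothetical sketch of a variance-based adaptive temperature scheduler;
# the actual VarBasedScheduler in this PR may use a different update rule.
import torch


class VarBasedScheduler:
    """Adapts the inverse temperature from the energy variance of each batch."""

    def __init__(self, beta: float = 1.0, target_variance: float = 0.1,
                 rate: float = 0.05):
        self.beta = beta                        # inverse temperature
        self.target_variance = target_variance
        self.rate = rate                        # assumed adjustment rate

    def get_inverse_temperature(self) -> float:
        return self.beta

    def update(self, energies: torch.Tensor) -> None:
        variance = torch.var(energies).item()
        if variance > self.target_variance:
            # High variance: lower beta (i.e. raise temperature) -> exploration.
            self.beta *= 1.0 - self.rate
        else:
            # Low variance: raise beta (i.e. lower temperature) -> exploitation.
            self.beta *= 1.0 + self.rate
```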
- Add trainer_kwargs and callbacks configuration options for customization
- Flatten the config structure (cfg.trainer.* → cfg.*)
- Add operator pool utility functions (utils.py)
- Add an N2 molecule example (gqe_n2.py)
- Update the config docstring to match the implementation
Signed-off-by: Kohei Nakaji <[email protected]>
…fault
- Add test_get_gqe_pauli_pool to verify operator pool generation
- Set enable_checkpointing=False by default in the trainer config
- Minor formatting fixes in the test file
Signed-off-by: Kohei Nakaji <[email protected]>
Implement GQE Manuscript V2 features (GRPO loss, Replay Buffer, Variance-based scheduler)
Summary
This PR implements key features from the updated manuscript (arXiv:2401.09253v2)
including GRPO loss function, Replay Buffer mechanism, and Variance-based temperature scheduler,
along with a comprehensive refactoring of the GQE (Generative Quantum Eigensolver) internal
implementation to improve modularity and add user customization options.
Motivation
Implement the GRPO loss function, Replay Buffer mechanism, and variance-based temperature scheduler described in the updated manuscript (arXiv:2401.09253v2), which were not present in the original implementation.
Changes
Core Implementation
- Add the GRPO (Group Relative Policy Optimization) loss and replay buffer mechanism as described in arXiv:2401.09253v2, set as the default (see the sketch after this list)
- Add a variance-based adaptive temperature scheduler that adjusts the temperature based on energy variance, as described in the manuscript V2
- Split the transformer implementation into GPT2 (model) and Pipeline (LightningModule)
- Migrate the training loop to the Lightning Trainer framework, removing the Fabric/Lightning mixing; update the logger from Fabric to Lightning accordingly
- Extract schedulers, data handling, callbacks, and the loss factory into separate modules for better modularity
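For illustration, a minimal sketch of a GRPO-style clipped loss matching the names used in this PR (GRPOLoss, clip_ratio, compute(**kwargs)). The advantage normalization and the exact term weighting are assumptions and may differ from the actual loss.py implementation.

```python
# Hypothetical sketch of a GRPO-style clipped objective for GQE training.
import torch
import torch.nn as nn


class GRPOLoss(nn.Module):
    """Clipped policy-gradient loss with group-relative advantages."""

    def __init__(self, clip_ratio: float = 0.2):
        super().__init__()
        self.clip_ratio = clip_ratio

    def compute(self, log_probs: torch.Tensor, old_log_probs: torch.Tensor,
                energies: torch.Tensor, **kwargs) -> torch.Tensor:
        # Group-relative advantage: how much better each sampled circuit is
        # than the batch average (lower energy => positive advantage).
        advantages = -(energies - energies.mean()) / (energies.std() + 1e-8)

        # Probability ratio between the current and the sampling-time policy.
        ratio = torch.exp(log_probs - old_log_probs)
        clipped = torch.clamp(ratio, 1.0 - self.clip_ratio, 1.0 + self.clip_ratio)

        # PPO-style pessimistic bound, averaged over the batch.
        return -torch.min(ratio * advantages, clipped * advantages).mean()
```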
User-Facing Features
- Add trainer_kwargs and callbacks configuration options to customize Lightning Trainer behavior (see the sketch below)
- Add utils.py with helper functions for generating operator pools (get_identity, get_gqe_pauli_pool)
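For illustration, a minimal sketch of how these customization points might be used. The config keys (trainer_kwargs, callbacks, scheduler, target_variance) follow this PR's description, but the config object, the monitored metric name, and the Lightning import path are assumptions and may not match the merged API.

```python
# Illustrative only: a stand-in config showing the new customization options.
from types import SimpleNamespace

from lightning.pytorch.callbacks import EarlyStopping  # assumes Lightning 2.x

# Stand-in for the GQE config object; the real cfg comes from the solver API.
cfg = SimpleNamespace()
cfg.trainer_kwargs = {            # forwarded to lightning.Trainer(...)
    "max_epochs": 100,
    "enable_checkpointing": False,
}
cfg.callbacks = [                 # extra Lightning callbacks used by the pipeline
    EarlyStopping(monitor="min_energy", mode="min", patience=10),  # metric name assumed
]
cfg.scheduler = "variance"        # select the variance-based temperature scheduler
cfg.target_variance = 0.05
# Operator pools can be built with the helpers added in utils.py
# (get_identity, get_gqe_pauli_pool); their signatures are not shown here.
```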
Examples and Tests
- Add an N2 molecule example gqe_n2.py demonstrating Pauli operator pool usage
Breaking Changes
- cfg.fabric_logger → cfg.lightning_logger
- cfg.use_fabric_logging → cfg.use_lightning_logging
Testing
All 11 GQE tests pass successfully.
References
arXiv:2401.09253v2 (2024)