Skip to content

Conversation

@konakaji
Copy link

@konakaji konakaji commented Dec 1, 2025

Implement GQE Manuscript V2 features (GRPO loss, Replay Buffer, Variance-based scheduler)

Summary

This PR implements key features from the updated manuscript (arXiv:2401.09253v2)
including GRPO loss function, Replay Buffer mechanism, and Variance-based temperature scheduler,
along with a comprehensive refactoring of the GQE (Generative Quantum Eigensolver) internal
implementation to improve modularity and add user customization options.

Motivation

  • Implement Manuscript V2 features: This PR implements the GRPO loss function, Replay
    Buffer mechanism, and Variance-based temperature scheduler described in the updated manuscript
    (arXiv:2401.09253v2), which were not present
    in the original implementation
  • Previous implementation mixed Fabric and Lightning frameworks, causing complexity
  • Model and training logic were tightly coupled

Changes

Core Implementation

  • Add GRPO loss function and Replay Buffer: Implemented GRPO (Group Relative Policy
    Optimization) loss and replay buffer mechanism as described in
    arXiv:2401.09253v2, set as default
  • Add Variance-based temperature scheduler: Implemented adaptive temperature scheduling
    based on energy variance as described in the manuscript V2
  • Separate Model and LightningModule: Decoupled previously combined Transformer into
    GPT2 (model) and Pipeline (LightningModule)
  • Migrate to PyTorch Lightning: Unified training implementation to PyTorch Lightning
    framework, removing Fabric/Lightning mixing. Updated logger from Fabric to Lightning
    accordingly
  • Extract Scheduler: Moved scheduler implementations to separate file for better
    modularity

User-Facing Features

  • Add customization options: Users can now provide custom trainer_kwargs and
    callbacks to customize Lightning Trainer behavior
  • Add operator pool utilities: New utils.py with helper functions for generating
    operator pools (get_identity, get_gqe_pauli_pool)

Examples and Tests

  • Add N2 molecule example: New gqe_n2.py demonstrating Pauli operator pool usage
  • Add comprehensive tests: Tests for operator pool utilities, schedulers and updated existing tests

Breaking Changes

⚠️ Configuration structure changed:

  • cfg.fabric_loggercfg.lightning_logger
  • cfg.use_fabric_loggingcfg.use_lightning_logging

Testing

  • All existing tests pass (12/12 in test_gqe.py)
  • New tests added for operator pool utilities
  • Verified with H2 and N2 molecule examples

References

  • K. Nakaji et al., "Generative Quantum Eigensolver with Adaptive Ansatz",
    arXiv:2401.09253v2 (2024)

- Split transformer.py into pipeline.py and model.py for better modularity
- pipeline.py: Contains Pipeline class for data processing
- model.py: Contains GPT2 model class
- Updated gqe.py to import from new modules
- Improves code organization and maintainability

Signed-off-by: Kohei Nakaji <[email protected]>
- Add factory.py: Factory class for creating loss functions
- Refactor pipeline.py: Simplify code and remove redundant logic
- Update gqe.py: Improve code structure and readability
- Update loss.py: Minor improvements for better consistency
- Reduces code complexity and improves maintainability

Signed-off-by: Kohei Nakaji <[email protected]>
- Add scheduler.py: Contains TemperatureScheduler, DefaultScheduler, and CosineScheduler
- Update gqe.py: Import schedulers from new module, remove scheduler definitions
- Improves code organization and separation of concerns
- Reduces gqe.py complexity by ~70 lines

Signed-off-by: Kohei Nakaji <[email protected]>
- Migrate from manual Fabric loop to Lightning Trainer
- Extract ReplayBuffer and BufferDataset to data.py
- Create callbacks.py with MinEnergyCallback and TrajectoryCallback
- Simplify TemperatureScheduler interface (get_inverse_temperature, update)
- Add Factory.create_temperature_scheduler method
- Move seed_everything to Pipeline.__init__
- Disable checkpointing for performance (2.5s -> 0.001s between epochs)
- Add num_sanity_val_steps=0 to suppress warnings
- Fix device placement issues for CUDA tensors
- Set DataLoader num_workers=0 to avoid pickling SpinOperator

Signed-off-by: Kohei Nakaji <[email protected]>
- Add GRPOLoss class inheriting from Loss base class
- Update loss.compute() signature to use **kwargs instead of context
- Add GRPOLoss to Factory with configurable clip_ratio
- Remove unnecessary logger check in Pipeline
- Update gqe_h2.py example with max_iters=50
- Clean up pyscf-generated files (.log, .chk)

Signed-off-by: Kohei Nakaji <[email protected]>
- Implement variance-based adaptive temperature scheduler
  - Adjusts temperature based on energy variance in training batches
  - Increases temperature for high variance (exploration)
  - Decreases temperature for low variance (exploitation)

- Add scheduler factory support
  - Extend Factory.create_temperature_scheduler() to support 'variance' mode
  - Configure via cfg.scheduler='variance' and cfg.target_variance

- Fix Loss classes to properly inherit from torch.nn.Module
  - Add super().__init__() calls to all Loss subclasses
  - Fix device placement issues in GFlowLogitMatching

- Enhance test coverage
  - Add test_variance_scheduler() for VarBasedScheduler unit testing
  - Add test_solvers_gqe_with_variance_scheduler() for integration testing
  - Add test_solvers_gqe_with_cosine_scheduler() for CosineScheduler
  - Add test_solvers_gqe_with_exp_loss() for ExpLogitMatching
  - Fix existing scheduler tests to use new API methods

All 11 GQE tests pass successfully.

Signed-off-by: Kohei Nakaji <[email protected]>
- Add trainer_kwargs and callbacks configuration options for customization
- Flatten config structure (cfg.trainer.* → cfg.*)
- Add operator pool utility functions (utils.py)
- Add N2 molecule example (gqe_n2.py)
- Update config docstring to match implementation

Signed-off-by: Kohei Nakaji <[email protected]>
…fault

- Add test_get_gqe_pauli_pool to verify operator pool generation
- Set enable_checkpointing=False by default in trainer config
- Minor formatting fixes in test file

Signed-off-by: Kohei Nakaji <[email protected]>
@copy-pr-bot
Copy link

copy-pr-bot bot commented Dec 1, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@konakaji konakaji changed the title Feature/gqe customization and grpo Implement GQE Manuscript V2 features and refactor architecture Dec 1, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant