Feat: Add PyTorch implementation of LoRA fine-tuning #854
+527
−13
Title
Feat: Add PyTorch-based LoRA Fine-tuning Support
Description
Summary
This PR introduces PyTorch-based LoRA fine-tuning support for OpenPI. It enables the `pi05_libero_lora_pytorch` configuration to be trained with LoRA natively in PyTorch and integrates seamlessly with the existing inference pipeline, while remaining fully compatible with the original JAX training and inference workflows.
Changes
1. Training Configuration (`src/openpi/training/config.py`)
- `TrainConfig`: Added the `lora_config: lora_pytorch.LoRATrainingConfig | None = None` field to unify the definition of PyTorch LoRA hyperparameters (rank, alpha, dropout, target modules, etc.) at the configuration level (see the sketch after this list).
- Introduced `openpi.models_pytorch.lora_pytorch` to serve as the unified source for LoRA-related configurations and utilities.
- Added `pi05_libero_lora_pytorch`:
  - Keeps the base model variants (`gemma_2b`/`gemma_300m`) with dynamic LoRA application via `lora_config`.
  - Sets `pytorch_weight_path` to the pre-trained PyTorch weights directory.
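A minimal sketch of the configuration-level wiring described above. The `lora_config` and `pytorch_weight_path` field names come from this PR; the internal layout of `LoRATrainingConfig` (fields such as `rank`, `alpha`, `dropout`, `target_modules`), the module names, and all default values are illustrative assumptions rather than the actual definitions in `openpi.models_pytorch.lora_pytorch`:

```python
# Sketch only: LoRATrainingConfig's real field names and defaults live in
# openpi.models_pytorch.lora_pytorch and may differ from this illustration.
import dataclasses


@dataclasses.dataclass(frozen=True)
class LoRATrainingConfig:
    enabled: bool = True
    rank: int = 16                      # LoRA rank r
    alpha: float = 32.0                 # scaling factor
    dropout: float = 0.05               # dropout on the LoRA branch
    # Which submodules receive adapters; names here are placeholders.
    target_modules: tuple[str, ...] = ("q_proj", "k_proj", "v_proj", "o_proj")


@dataclasses.dataclass(frozen=True)
class TrainConfig:
    # ... existing fields (model, data, optimizer, ...) elided ...
    # New in this PR: unified PyTorch LoRA hyperparameters; None disables LoRA.
    lora_config: LoRATrainingConfig | None = None
    # Directory of pre-trained PyTorch weights loaded before LoRA injection.
    pytorch_weight_path: str | None = None


# A config analogous to pi05_libero_lora_pytorch (values illustrative):
pi05_libero_lora_pytorch = TrainConfig(
    lora_config=LoRATrainingConfig(rank=16, alpha=32.0),
    pytorch_weight_path="/path/to/pi05_base_pytorch_weights",
)
```

Keeping the LoRA hyperparameters on `TrainConfig` means the same object can be handed to both the training script and `policy_config.create_trained_policy`, so training and inference always agree on the adapter structure.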
2. Model Loading (`src/openpi/models/model.py`)
- `BaseModelConfig.load_pytorch`:
  - Builds the `PI0Pytorch` model using `train_config.model`.
  - Checks whether `train_config.lora_config` is enabled. If true, calls `lora_pytorch.apply_lora_to_pi0_pytorch` to dynamically inject LoRA layers into the LLM and expert modules (see the sketch after this list).
  - Loads `model.safetensors` after the structure is prepared.
- By passing the full `TrainConfig` (containing `lora_config`) to `policy_config.create_trained_policy` during inference, the system automatically reconstructs the exact LoRA structure used during training.
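A rough sketch of the `load_pytorch` flow described above, assuming the `apply_lora_to_pi0_pytorch` helper named in this PR; its exact signature, the model factory call, and the `safetensors` handling are simplified placeholders:

```python
# Schematic version of BaseModelConfig.load_pytorch for the LoRA case.
# apply_lora_to_pi0_pytorch is the helper added by this PR; its signature and
# the create_pytorch_model() factory below are assumptions for illustration.
import os

import safetensors.torch

from openpi.models_pytorch import lora_pytorch


def load_pytorch(train_config, weight_dir: str):
    # 1. Build the PI0Pytorch model from the model config.
    model = train_config.model.create_pytorch_model()

    # 2. Inject LoRA layers into the LLM / expert modules *before* loading
    #    weights so the checkpoint keys match the final module structure.
    if train_config.lora_config is not None and train_config.lora_config.enabled:
        model = lora_pytorch.apply_lora_to_pi0_pytorch(model, train_config.lora_config)

    # 3. Load model.safetensors once the structure is prepared.
    state_dict = safetensors.torch.load_file(os.path.join(weight_dir, "model.safetensors"))
    model.load_state_dict(state_dict)
    return model
```

Because `policy_config.create_trained_policy` receives the same `TrainConfig`, it can repeat step 2 with the identical `lora_config` before restoring the checkpoint, which is what keeps training and inference structurally in sync.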
3. PyTorch Training Script (`scripts/train_pytorch.py`)
- LoRA mode is detected via `config.lora_config.enabled`.
- Base weights are loaded from `config.pytorch_weight_path` before applying LoRA adapters to prevent structural/weight mismatches.
- Non-LoRA parameters are frozen according to `lora_config` after injection.
- `AdamW` initialization: in LoRA mode, the optimizer is created only for parameters with `requires_grad=True`. This avoids allocating optimizer states for frozen parameters, significantly reducing VRAM usage (see the sketch after this list).
- The model is wrapped in `DistributedDataParallel` after LoRA application to ensure adapters are correctly synchronized, with `find_unused_parameters=True` to support cases where certain parameters might be skipped during steps.
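A minimal sketch of the LoRA-mode training setup described above (freeze, optimizer, DDP). The adapter-name check and hyperparameters are placeholders; only the overall pattern, AdamW built over `requires_grad=True` parameters and DDP wrapping after LoRA injection with `find_unused_parameters=True`, reflects the PR:

```python
# Illustrative only: assumes torch.distributed is already initialized and that
# injected adapter parameters can be identified by a "lora_" name prefix.
import torch
from torch.nn.parallel import DistributedDataParallel as DDP


def setup_lora_training(model: torch.nn.Module, lr: float, local_rank: int):
    # Freeze everything except the injected LoRA adapter parameters.
    for name, param in model.named_parameters():
        param.requires_grad = "lora_" in name

    # Build AdamW only over trainable parameters, so no optimizer state
    # (exp_avg / exp_avg_sq) is allocated for the frozen base weights.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=lr)

    # Wrap in DDP *after* LoRA injection so the adapters are synchronized;
    # find_unused_parameters=True tolerates parameters skipped on some steps.
    model = DDP(model.to(local_rank), device_ids=[local_rank], find_unused_parameters=True)
    return model, optimizer
```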
Motivation
- LoRA hyperparameters for the PyTorch path are unified in `TrainConfig`.
Testing
- Verified inference with the trained checkpoints via `serve_policy.py` and `policy_config.create_trained_policy`.
Results