Feat: Add PyTorch-based LoRA Fine-tuning Support

Summary

This PR introduces PyTorch-based LoRA fine-tuning support for OpenPI. It enables the pi05_libero_lora_pytorch configuration to be fine-tuned with LoRA natively in PyTorch and integrates seamlessly with the existing inference pipeline, while remaining fully compatible with the original JAX training and inference workflows.


Changes

1. Training Configuration (src/openpi/training/config.py)

  • Updated TrainConfig: Added the lora_config: lora_pytorch.LoRATrainingConfig | None = None field to unify the definition of PyTorch LoRA hyperparameters (rank, alpha, dropout, target modules, etc.) at the configuration level.
  • Centralized Entry Point: Imported openpi.models_pytorch.lora_pytorch to serve as the unified source for LoRA-related configurations and utilities.
  • New Configuration: Added pi05_libero_lora_pytorch:
    • Uses standard variants (gemma_2b / gemma_300m) with dynamic LoRA application via lora_config.
    • Points pytorch_weight_path to the pre-trained PyTorch weights directory.
    • Adjusts hyperparameters (batch size, learning rate schedule) specifically for LoRA fine-tuning scenarios.
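A minimal, self-contained sketch of the configuration shape described above. Apart from lora_config.enabled (which the training script checks) and the field names quoted in this PR, the exact field names and default values below are illustrative, not the real openpi definitions.

```python
# Sketch of the config-level LoRA additions; field names beyond those quoted
# in the PR description are illustrative.
import dataclasses


@dataclasses.dataclass(frozen=True)
class LoRATrainingConfig:
    enabled: bool = True
    rank: int = 16                  # LoRA rank r
    alpha: float = 32.0             # scaling factor
    dropout: float = 0.0
    target_modules: tuple[str, ...] = ("q_proj", "k_proj", "v_proj", "o_proj")


@dataclasses.dataclass(frozen=True)
class TrainConfigSketch:
    name: str
    pytorch_weight_path: str | None = None
    lora_config: LoRATrainingConfig | None = None   # None => no LoRA (JAX/full FT path unchanged)
    batch_size: int = 32
    learning_rate: float = 2.5e-5


# Roughly what the new entry looks like: standard gemma_2b / gemma_300m variants,
# with LoRA applied dynamically at model-construction time via lora_config.
pi05_libero_lora_pytorch = TrainConfigSketch(
    name="pi05_libero_lora_pytorch",
    pytorch_weight_path="/path/to/pretrained_pytorch_weights",  # placeholder path
    lora_config=LoRATrainingConfig(rank=16, alpha=32.0),
    batch_size=32,
    learning_rate=2.5e-5,
)
```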

2. Model Loading (src/openpi/models/model.py)

  • LoRA Support in BaseModelConfig.load_pytorch:
    • Constructs the PI0Pytorch model using train_config.model.
    • Checks if train_config.lora_config is enabled. If true, calls lora_pytorch.apply_lora_to_pi0_pytorch to dynamically inject LoRA layers into the LLM and expert modules.
    • Crucial Step: Loads the full PyTorch weights from model.safetensors after the structure is prepared.
  • Benefit: Ensures alignment between training and inference. By passing the same TrainConfig (containing lora_config) to policy_config.create_trained_policy during inference, the system automatically reconstructs the exact LoRA structure used during training.
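A hedged sketch of the load_pytorch flow described above. PI0Pytorch and apply_lora_to_pi0_pytorch are the names used in this PR; the import path for PI0Pytorch, the call signatures, and the checkpoint file layout are assumptions for illustration.

```python
# Sketch only: build the structure first, apply LoRA if configured, then load
# the full checkpoint so the keys match the LoRA-augmented module tree.
import pathlib

import safetensors.torch

from openpi.models_pytorch import lora_pytorch, pi0_pytorch  # import path assumed


def load_pytorch_sketch(train_config, checkpoint_dir: str):
    # 1. Construct the PyTorch model from the standard model config.
    model = pi0_pytorch.PI0Pytorch(train_config.model)

    # 2. If LoRA is enabled, inject LoRA layers into the LLM / expert modules
    #    so the structure matches the one used during training.
    if train_config.lora_config is not None and train_config.lora_config.enabled:
        lora_pytorch.apply_lora_to_pi0_pytorch(model, train_config.lora_config)

    # 3. Only then load the full weights from model.safetensors.
    state_dict = safetensors.torch.load_file(
        pathlib.Path(checkpoint_dir) / "model.safetensors"
    )
    model.load_state_dict(state_dict)
    return model
```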

3. PyTorch Training Script (scripts/train_pytorch.py)

  • Integrated LoRA Workflow:
    • Conditionally enables LoRA based on config.lora_config.enabled.
    • Weight Loading Strategy: Loads pre-trained weights from config.pytorch_weight_path before applying LoRA adapters to prevent structural/weight mismatches.
    • After injection, freezes non-LoRA parameters and any additional modules specified via lora_config (see the end-to-end sketch after this list).
  • Optimizer Efficiency:
    • Optimized AdamW initialization: In LoRA mode, the optimizer is created only for parameters with requires_grad=True. This avoids allocating optimizer states for frozen parameters, significantly reducing VRAM usage.
    • Maintains backward compatibility by using full-parameter optimization when LoRA is disabled.
  • DDP Compatibility:
    • Wraps the model with DistributedDataParallel after LoRA application to ensure adapters are correctly synchronized.
    • Retains find_unused_parameters=True to handle cases where some parameters receive no gradient in a given step.
  • Consistency: Kept other training logic (LR scheduling, logging, checkpointing) consistent with the original script.
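An end-to-end sketch of the train_pytorch.py changes above: weights loaded before LoRA injection, non-LoRA parameters frozen, AdamW built only over trainable parameters, and DDP wrapping applied afterwards. The safetensors file layout under pytorch_weight_path, the name-based freezing filter, and the literal learning rate are assumptions; the real script drives these from lora_config and the configured LR schedule.

```python
# Sketch under the assumptions stated above; not the literal script.
import pathlib

import safetensors.torch
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

from openpi.models_pytorch import lora_pytorch


def setup_model_and_optimizer(model, config, local_rank: int):
    use_lora = config.lora_config is not None and config.lora_config.enabled

    if use_lora:
        # Load pre-trained weights first, while the module tree still matches
        # the plain (LoRA-free) checkpoint, then inject the adapters.
        state_dict = safetensors.torch.load_file(
            pathlib.Path(config.pytorch_weight_path) / "model.safetensors"
        )
        model.load_state_dict(state_dict)
        lora_pytorch.apply_lora_to_pi0_pytorch(model, config.lora_config)

        # Freeze everything except LoRA parameters; this name-based filter is a
        # simplification of the lora_config-driven freezing in the real script.
        for name, param in model.named_parameters():
            param.requires_grad = "lora_" in name

    # AdamW only over trainable parameters: no optimizer state is allocated for
    # frozen weights. With LoRA disabled this is ordinary full-parameter training.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=2.5e-5)  # LR illustrative

    # Wrap with DDP after LoRA injection so adapter weights are broadcast and
    # gradient-synchronized; find_unused_parameters=True tolerates parameters
    # that receive no gradient in a given step.
    ddp_model = DDP(
        model.to(local_rank),
        device_ids=[local_rank],
        find_unused_parameters=True,
    )
    return ddp_model, optimizer
```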

Motivation

  • To provide a native PyTorch LoRA fine-tuning path for pi0.5 + Libero scenarios without disrupting existing JAX code.
  • To achieve "One Config for Training and Inference" by centrally managing LoRA settings in TrainConfig.
  • To improve training efficiency and reduce memory footprint by optimizing gradient updates and optimizer states specifically for LoRA.

Testing

  • Local Training Verification (Single/Multi-GPU):
    • Verified that pre-trained PyTorch weights load correctly.
    • Confirmed LoRA adapters are injected and only specified parameters are updated.
    • Verified that checkpoints can be saved and training can resume from them.
  • End-to-End Inference:
    • Verified the training-to-inference loop: the PyTorch checkpoint produced by LoRA training loads successfully via serve_policy.py and policy_config.create_trained_policy (see the example below).
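Example of the verified training-to-inference loop: the same TrainConfig (carrying lora_config) is passed to create_trained_policy, which reconstructs the LoRA-augmented model before loading the checkpoint. The checkpoint path is a placeholder, and the exact keyword/positional form of the call is based on existing openpi usage.

```python
# Illustrative usage; checkpoint path is a placeholder.
from openpi.policies import policy_config
from openpi.training import config as _config

train_config = _config.get_config("pi05_libero_lora_pytorch")
policy = policy_config.create_trained_policy(
    train_config,
    "/path/to/checkpoints/pi05_libero_lora_pytorch/my_experiment/20000",
)
```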


