[Feature Proposal] Add LoRA training support for pi0.5 (PyTorch backend) #842

Description

@sen-code-lost

Motivation
Currently, the PyTorch implementation in this repository seems to lack native LoRA training support. Many users (including myself) want to fine-tune pi0.5 on consumer-grade hardware (e.g., an RTX 4090), which is difficult with full-parameter fine-tuning due to VRAM constraints.

Proposal
I have implemented a native LoRA training pipeline for the PyTorch version of pi0.5. This implementation is lightweight and efficient, enabling high-performance fine-tuning on a single GPU.

Implementation Details

  • Zero Dependencies: Unlike approaches using peft or other external libraries, this implementation is pure PyTorch. It does not bloat the project's dependency list and offers better integration with the existing codebase.
  • Target Modules: The implementation identifies and injects adapters into the attention and MLP projections within each transformer block (a minimal sketch follows this list):
    • Attention: q_proj, k_proj, v_proj, o_proj
    • FFN/MLP: gate_proj, up_proj, down_proj
  • Flexibility: The implementation supports unfreezing specific components (e.g., Vision Encoder) alongside LoRA adapters.
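
To make the approach concrete, here is a minimal sketch of what the pure-PyTorch injection could look like. The names `LoRALinear` and `inject_lora` are illustrative, not the final API in the PR; only the target module names come from the list above:

```python
import math

import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Wraps a frozen nn.Linear with a trainable low-rank update:
    y = W x + (alpha / r) * B(A(x)); A and B are the only trainable weights."""

    def __init__(self, base: nn.Linear, rank: int = 16, alpha: int = 32):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # the original projection stays frozen
        self.lora_a = nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, base.out_features, bias=False)
        nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_b.weight)  # adapter starts as an exact no-op
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))


# Module names taken from the target list above.
TARGETS = {"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"}


def inject_lora(model: nn.Module, rank: int = 16) -> None:
    """Swap every targeted nn.Linear for a LoRA-wrapped copy, in place."""
    for module in model.modules():
        for child_name, child in list(module.named_children()):
            if child_name in TARGETS and isinstance(child, nn.Linear):
                setattr(module, child_name, LoRALinear(child, rank=rank))
```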

Parameter Breakdown
Below is the configuration used for validation (Hybrid: LoRA for Transformer + Unfrozen Vision Encoder):

Total Parameters:           3,643,299,600
Frozen Parameters:          3,202,149,376

Trainable Parameters:
-----------------------------------------
LoRA Adapters (Rank=16):       26,542,080  (0.73%)
Vision Encoder (Unfrozen):    412,442,352
Projection Layers:              2,165,792
-----------------------------------------
Total Trainable:              441,150,224  (12.11%)
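
For reference, a breakdown like the one above falls directly out of the `requires_grad` flags; a minimal sketch (`summarize_trainable` is a hypothetical helper, not part of the codebase):

```python
import torch.nn as nn


def summarize_trainable(model: nn.Module) -> None:
    """Print total / frozen / trainable counts in the format of the table above."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total Parameters:    {total:>15,}")
    print(f"Frozen Parameters:   {total - trainable:>15,}")
    print(f"Total Trainable:     {trainable:>15,}  ({100 * trainable / total:.2f}%)")
```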

Results & Validation
I have successfully trained a pi0.5 model on the Libero dataset using this implementation.

  • Hardware: Single RTX 4090 (24GB)

  • VRAM Usage: ~16 GB (Peak during training)

  • Training Steps: 30k

  • Batch Size: 16

  • Loss Curve (Libero dataset): [training loss curve image]

  • Evaluation (Libero Benchmark)
    The fine-tuned model demonstrates promising results across Libero tasks:

| Model | Libero Spatial | Libero Object | Libero Goal | Libero 10 | Average |
|---|---|---|---|---|---|
| π0.5 (PyTorch LoRA @ 30k) | 96.6 | 98.2 | 97.8 | 95.2 | 96.95 |
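
One setup detail worth noting for the PR: only parameters with `requires_grad=True` need to reach the optimizer. A sketch, where the optimizer choice and learning rate are placeholders rather than the values used in the run above:

```python
import torch
import torch.nn as nn


def make_optimizer(model: nn.Module, lr: float = 1e-4) -> torch.optim.Optimizer:
    """Only LoRA adapters and any unfrozen components carry gradients,
    so only those tensors are handed to the optimizer."""
    trainable = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.AdamW(trainable, lr=lr)  # optimizer/lr are placeholders
```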

Next Steps
I am ready to submit a Pull Request. Please let me know if this aligns with the roadmap and if there are any specific coding guidelines or branch preferences I should follow.

