Motivation
Currently, the PyTorch implementation in this repository seems to lack native LoRA training support. Many users (myself included) want to fine-tune pi0.5 on consumer-grade hardware (e.g., an RTX 4090), which is impractical with full-parameter fine-tuning due to VRAM constraints.
Proposal
I have implemented a native LoRA training pipeline for the PyTorch version of pi0.5. This implementation is lightweight and efficient, enabling high-performance fine-tuning on a single GPU.
Implementation Details
- Zero Dependencies: Unlike approaches using `peft` or other external libraries, this implementation is pure PyTorch. It does not bloat the project's dependency list and integrates cleanly with the existing codebase.
- Target Modules: The implementation identifies and injects adapters into the attention mechanism and MLP layers within the transformer blocks:
  - Attention: `q_proj`, `k_proj`, `v_proj`, `o_proj`
  - FFN/MLP: `gate_proj`, `up_proj`, `down_proj`
- Flexibility: Specific components (e.g., the Vision Encoder) can be unfrozen and trained alongside the LoRA adapters (see the sketch after this list).
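For reference, below is a simplified sketch of what this kind of pure-PyTorch injection looks like; the `LoRALinear` / `inject_lora` names and the exact configuration are illustrative, and the actual implementation differs in details.

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Wrap a frozen nn.Linear with a trainable low-rank update: y = Wx + (alpha / r) * B(Ax)."""

    def __init__(self, base: nn.Linear, rank: int = 16, alpha: float = 32.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # original projection stays frozen
        device, dtype = base.weight.device, base.weight.dtype
        self.lora_A = nn.Parameter(torch.zeros(rank, base.in_features, device=device, dtype=dtype))
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, rank, device=device, dtype=dtype))
        nn.init.kaiming_uniform_(self.lora_A, a=5 ** 0.5)  # A random, B zero -> no-op at init
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling


# Projections targeted inside the attention and MLP blocks.
TARGET_MODULES = ("q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj")


def inject_lora(model: nn.Module, rank: int = 16) -> nn.Module:
    """Freeze the model, then replace each targeted nn.Linear with a LoRALinear wrapper."""
    for p in model.parameters():
        p.requires_grad = False
    for module in model.modules():
        for child_name, child in list(module.named_children()):
            if isinstance(child, nn.Linear) and child_name in TARGET_MODULES:
                setattr(module, child_name, LoRALinear(child, rank=rank))
    return model
```

With rank 16 on these seven projections, the adapter matrices account for the ~26.5M trainable LoRA parameters shown in the breakdown below.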
Parameter Breakdown
Below is the configuration used for validation (Hybrid: LoRA for Transformer + Unfrozen Vision Encoder):
| Component | Parameters | Share of Total |
|---|---|---|
| Total | 3,643,299,600 | 100% |
| Frozen | 3,202,149,376 | 87.89% |
| LoRA Adapters (rank 16) | 26,542,080 | 0.73% |
| Vision Encoder (unfrozen) | 412,442,352 | 11.32% |
| Projection Layers | 2,165,792 | 0.06% |
| Total Trainable | 441,150,224 | 12.11% |
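A breakdown like the one above can be reproduced with a small pure-PyTorch helper; the `model.vision_encoder` attribute name in the commented usage is a placeholder, not the actual attribute in the codebase.

```python
import torch.nn as nn


def summarize_parameters(model: nn.Module) -> None:
    """Print total, frozen, and trainable parameter counts for a model."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total Parameters:     {total:,}")
    print(f"Frozen Parameters:    {total - trainable:,}")
    print(f"Trainable Parameters: {trainable:,} ({100 * trainable / total:.2f}%)")


# Hybrid setup: LoRA adapters on the transformer plus a fully unfrozen vision encoder.
# for p in model.vision_encoder.parameters():  # placeholder attribute name
#     p.requires_grad = True
# summarize_parameters(model)
```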
Results & Validation
I have successfully trained a pi0.5 model on the Libero dataset using this implementation.
- Hardware: Single RTX 4090 (24 GB)
- VRAM Usage: ~16 GB peak during training
- Training Steps: 30k
- Batch Size: 16

Evaluation (Libero Benchmark)
The fine-tuned model demonstrates promising results across Libero tasks:
| Model | Libero Spatial | Libero Object | Libero Goal | Libero 10 | Average |
|---|---|---|---|---|---|
| π0.5 (PyTorch LoRA @ 30k) | 96.6 | 98.2 | 97.8 | 95.2 | 96.95 |
Next Steps
I am ready to submit a Pull Request. Please let me know if this aligns with the roadmap and if there are any specific coding guidelines or branch preferences I should follow.
