|
1 | 1 |
|
| 2 | +# RFC: ComfyUI Training Modules |
| 3 | + |
| 4 | +- Start Date: 2025-03-01 |
| 5 | +- Target Major Version: TBD |
| 6 | + |
| 7 | +## Summary |
| 8 | + |
| 9 | +This RFC proposes the addition of training capabilities to ComfyUI, enabling users to create and fine-tune LoRA (Low-Rank Adaptation) models directly through the ComfyUI interface. The proposal includes a set of node implementations for loading image datasets, training LoRAs, visualizing training progress, and saving trained models. |
| 10 | + |
| 11 | +## Basic example |
| 12 | + |
| 13 | +The basic workflow would allow users to: |
| 14 | + |
| 15 | +1. Load an image dataset: |
| 16 | + |
| 17 | + |
| 18 | + |
| 19 | +2. Train a LoRA on these images: |
| 20 | + |
| 21 | + |
| 22 | +3. Save the resulting LoRA: |
| 23 | + |
| 24 | + |
| 25 | + |
| 26 | +4. Visualize training loss: |
| 27 | + |
| 28 | + |
| 29 | + |
| 30 | +## Motivation |
| 31 | + |
| 32 | +Currently, users who want to create custom LoRA models need to: |
| 33 | + |
| 34 | +1. Use external tools and scripts for training, which often requires command-line expertise |
| 35 | +2. Set up specialized environments for training |
| 36 | +3. Manually move the trained models between systems |
| 37 | + |
| 38 | +Adding training capabilities directly to ComfyUI would: |
| 39 | + |
| 40 | +1. **Simplify the training workflow**: Users can train models in the same interface where they use them |
| 41 | +2. **Increase accessibility**: Users without programming experience can customize models |
| 42 | +3. **Enable rapid iteration**: The ability to train and immediately test models in the same interface |
| 43 | +4. **Provide visual feedback**: Real-time visualization of the training process |
| 44 | +5. **Maintain workflow continuity**: The entire model creation, training, and inference pipeline can be represented as a unified workflow |
| 45 | + |
| 46 | +## Detailed design |
| 47 | + |
| 48 | +The implementation consists of four main components: |
| 49 | + |
| 50 | +### 1. Image Dataset Loading |
| 51 | + |
| 52 | +Two nodes are proposed for loading image datasets: |
| 53 | + |
| 54 | +- `LoadImageSetNode`: Loads individual images selected by the user |
| 55 | +- `LoadImageSetFromFolderNode`: Loads all images from a specified folder |
| 56 | + |
| 57 | +These nodes offer options for handling images of different sizes (stretch, crop, pad) and prepare the images for training. |
| 58 | + |
| 59 | +```python |
| 60 | +class LoadImageSetFromFolderNode: |
| 61 | + @classmethod |
| 62 | + def INPUT_TYPES(s): |
| 63 | + return { |
| 64 | + "required": { |
| 65 | + "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."}) |
| 66 | + }, |
| 67 | + "optional": { |
| 68 | + "resize_method": ( |
| 69 | + ["None", "Stretch", "Crop", "Pad"], |
| 70 | + {"default": "None"}, |
| 71 | + ), |
| 72 | + } |
| 73 | + } |
| 74 | + |
| 75 | + RETURN_TYPES = ("IMAGE",) |
| 76 | + FUNCTION = "load_images" |
| 77 | + CATEGORY = "loaders" |
| 78 | + EXPERIMENTAL = True |
| 79 | + DESCRIPTION = "Loads a batch of images from a directory for training." |
| 80 | +``` |
| 81 | + |
| 82 | +### 2. LoRA Training Node |
| 83 | + |
| 84 | +The `TrainLoraNode` is the core component that handles the training process: |
| 85 | + |
| 86 | +```python |
| 87 | +class TrainLoraNode: |
| 88 | + @classmethod |
| 89 | + def INPUT_TYPES(s): |
| 90 | + return { |
| 91 | + "required": { |
| 92 | + "model": (IO.MODEL, {"tooltip": "The model to train the LoRA on."}), |
| 93 | + "vae": (IO.VAE, {"tooltip": "The VAE model to use for encoding images for training."}), |
| 94 | + "positive": (IO.CONDITIONING, {"tooltip": "The positive conditioning to use for training."}), |
| 95 | + "image": (IO.IMAGE, {"tooltip": "The image or image batch to train the LoRA on."}), |
| 96 | + "batch_size": (IO.INT, {"default": 1, "min": 1, "max": 10000, "step": 1}), |
| 97 | + "steps": (IO.INT, {"default": 50, "min": 1, "max": 1000}), |
| 98 | + "learning_rate": (IO.FLOAT, {"default": 0.0003, "min": 0.0000001, "max": 1.0, "step": 0.00001}), |
| 99 | + "rank": (IO.INT, {"default": 8, "min": 1, "max": 128}), |
| 100 | + "optimizer": (["Adam", "AdamW", "SGD", "RMSprop"], {"default": "Adam"}), |
| 101 | + "loss_function": (["MSE", "L1", "Huber", "SmoothL1"], {"default": "MSE"}), |
| 102 | + "seed": (IO.INT, {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF}), |
| 103 | + "training_dtype": (["bf16", "fp32"], {"default": "bf16"}), |
| 104 | + "existing_lora": (folder_paths.get_filename_list("loras") + ["[None]"], {"default": "[None]"}), |
| 105 | + }, |
| 106 | + } |
| 107 | + |
| 108 | + RETURN_TYPES = (IO.MODEL, IO.LORA_MODEL, IO.LOSS_MAP, IO.INT) |
| 109 | + RETURN_NAMES = ("model_with_lora", "lora", "loss", "steps") |
| 110 | + FUNCTION = "train" |
| 111 | + CATEGORY = "training" |
| 112 | + EXPERIMENTAL = True |
| 113 | +``` |
| 114 | + |
| 115 | +The training process: |
| 116 | +1. Takes a batch of images and encodes them using a VAE |
| 117 | +2. Sets up LoRA layers for all eligible weights in the model |
| 118 | +3. Configures an optimizer and loss function based on user selections |
| 119 | +4. Performs gradient-based training for the specified number of steps |
| 120 | +5. Returns the model with LoRA applied, the LoRA weights, a map of training losses, and the total training steps |
| 121 | + |
| 122 | +### 3. Model Saving Node |
| 123 | + |
| 124 | +The `SaveLoRA` node enables users to save their trained LoRA models: |
| 125 | + |
| 126 | +```python |
| 127 | +class SaveLoRA: |
| 128 | + @classmethod |
| 129 | + def INPUT_TYPES(s): |
| 130 | + return { |
| 131 | + "required": { |
| 132 | + "lora": (IO.LORA_MODEL, {"tooltip": "The LoRA model to save."}), |
| 133 | + "prefix": (IO.STRING, {"default": "trained_lora"}), |
| 134 | + }, |
| 135 | + "optional": { |
| 136 | + "steps": (IO.INT, {"forceInput": True}), |
| 137 | + }, |
| 138 | + } |
| 139 | + |
| 140 | + RETURN_TYPES = () |
| 141 | + FUNCTION = "save" |
| 142 | + CATEGORY = "loaders" |
| 143 | + EXPERIMENTAL = True |
| 144 | + OUTPUT_NODE = True |
| 145 | +``` |
| 146 | + |
| 147 | +The node saves the LoRA weights in SafeTensors format, with a filename that includes the number of training steps and a timestamp. |
| 148 | + |
| 149 | +### 4. Training Visualization Node |
| 150 | + |
| 151 | +The `LossGraphNode` visualizes the training progress: |
| 152 | + |
| 153 | +```python |
| 154 | +class LossGraphNode: |
| 155 | + @classmethod |
| 156 | + def INPUT_TYPES(s): |
| 157 | + return { |
| 158 | + "required": { |
| 159 | + "loss": (IO.LOSS_MAP, {"default": {}}), |
| 160 | + "filename_prefix": (IO.STRING, {"default": "loss_graph"}), |
| 161 | + }, |
| 162 | + } |
| 163 | + |
| 164 | + RETURN_TYPES = () |
| 165 | + FUNCTION = "plot_loss" |
| 166 | + OUTPUT_NODE = True |
| 167 | + CATEGORY = "training" |
| 168 | + EXPERIMENTAL = True |
| 169 | + DESCRIPTION = "Plots the loss graph and saves it to the output directory." |
| 170 | +``` |
| 171 | + |
| 172 | +This node generates a graph showing the training loss over time, providing visual feedback on the training process. |
| 173 | + |
| 174 | +### Supporting Components |
| 175 | + |
| 176 | +The implementation also includes several support classes: |
| 177 | + |
| 178 | +1. `TrainSampler`: A custom sampler that performs gradient updates during the sampling process |
| 179 | +2. `LoraDiff` and `BiasDiff`: Weight wrapper classes that apply LoRA adaptations to model weights |
| 180 | + |
| 181 | +## Drawbacks |
| 182 | + |
| 183 | +1. **Resource Consumption**: Training is computationally intensive and may strain systems with limited resources |
| 184 | +2. **UI Responsiveness**: Long training processes could make the ComfyUI interface less responsive |
| 185 | +3. **Complexity**: Adding training capabilities increases the complexity of the ComfyUI codebase |
| 186 | +4. **Learning Curve**: Users may need to understand more ML concepts to effectively use the training features |
| 187 | + |
| 188 | +## Adoption strategy |
| 189 | + |
| 190 | +1. **Experimental Flag**: Initially release nodes with the `EXPERIMENTAL = True` flag to indicate the developing nature of the feature |
| 191 | +2. **Documentation**: Provide comprehensive documentation and tutorial workflows |
| 192 | +3. **Gradual Feature Addition**: Start with basic LoRA training and expand to other training types based on user feedback |
| 193 | +4. **Default Parameters**: Set sensible defaults to help users get started without deep ML knowledge |
| 194 | + |
| 195 | +## Unresolved questions |
| 196 | + |
| 197 | +1. **Memory Management**: How will the system handle memory during training, especially for large models and datasets? |
| 198 | +2. **Checkpoint Frequency**: Should the system automatically save checkpoints during training to prevent loss of progress? |
| 199 | +3. **Training Interruption**: How should the system handle interrupted training sessions? |
| 200 | +4. **Hyperparameter Optimization**: Should the system provide tools for automatically finding optimal hyperparameters? |
| 201 | +5. **Multi-GPU Support**: How will training utilize multiple GPUs if available? |
| 202 | +6. **Integration with Existing Workflows**: How can trained models be seamlessly integrated into existing inference workflows? |
| 203 | +7. **Performance Metrics**: Should additional metrics beyond loss be tracked and visualized? |
| 204 | +8. **Dataset Preparation**: Should the system provide more tools for dataset curation and augmentation? |
| 205 | + |
| 206 | +## Implementation Plan |
| 207 | + |
| 208 | +### Phase 1: Basic LoRA Training |
| 209 | + |
| 210 | +Initial implementation of the nodes described in this RFC. |
| 211 | + |
| 212 | +### Phase 2: Enhanced Features |
| 213 | + |
| 214 | +- Checkpoint saving during training |
| 215 | +- More advanced training visualizations |
| 216 | +- Support for additional training techniques (e.g., DreamBooth, Control model training like Control LoRA and IPA) |
| 217 | + |
| 218 | +### Phase 3: Workflow Integration |
| 219 | + |
| 220 | +- Templates for common training scenarios |
| 221 | +- Integration with model merging and inference workflows |
| 222 | +- Advanced dataset management tools |
| 223 | + |
| 224 | +### Phase 4: Model Format |
| 225 | + |
| 226 | +- New model format to improve model memory management and metadata of models in ComfyUI |
| 227 | + |
| 228 | +## Links |
| 229 | + |
| 230 | +<!-- |
| 231 | + Both links below will be automatically filled in when you create the PR. |
| 232 | + You do not need to modify this section. |
| 233 | +--> |
| 234 | + |
| 235 | +- [Full Rendered Proposal]() |
| 236 | + |
| 237 | +- [Discussion Thread]() |
| 238 | + |
| 239 | +<!-- |
| 240 | + Optional: Include any additional links to related issues or resources below |
| 241 | +--> |
| 242 | + |
| 243 | +--- |
| 244 | + |
| 245 | +**Important: Do NOT comment on this PR. Please use the discussion thread linked above to provide feedback, as it provides branched discussions that are easier to follow. This also makes the edit history of the PR clearer.** |
0 commit comments