# RFC: ComfyUI Training Modules

- Start Date: 2025-03-01
- Target Major Version: TBD

## Summary

This RFC proposes the addition of training capabilities to ComfyUI, enabling users to create and fine-tune LoRA (Low-Rank Adaptation) models directly through the ComfyUI interface. The proposal includes a set of node implementations for loading image datasets, training LoRAs, visualizing training progress, and saving trained models.

## Basic example

The basic workflow would allow users to:

1. Load an image dataset:

   ![image](https://github.com/user-attachments/assets/3e00d09c-14ea-432d-a694-270ab13367ec)

2. Train a LoRA on these images:

   ![image](https://github.com/user-attachments/assets/e631e59a-9944-4fc2-b6dd-13e0f0c132f9)

3. Save the resulting LoRA:

   ![image](https://github.com/user-attachments/assets/dbcbf1e4-af13-4095-86e9-7c4eada23432)

4. Visualize training loss:

   ![image](https://github.com/user-attachments/assets/d036b420-ef6c-4d2e-af55-d25c11724623)

## Motivation

Currently, users who want to create custom LoRA models need to:

1. Use external tools and scripts for training, which often requires command-line expertise
2. Set up specialized environments for training
3. Manually move the trained models between systems

Adding training capabilities directly to ComfyUI would:

1. **Simplify the training workflow**: Users can train models in the same interface where they use them
2. **Increase accessibility**: Users without programming experience can customize models
3. **Enable rapid iteration**: Users can train and immediately test models in the same interface
4. **Provide visual feedback**: Training progress can be visualized in real time
5. **Maintain workflow continuity**: The entire model creation, training, and inference pipeline can be represented as a unified workflow

## Detailed design

The implementation consists of four main components:

### 1. Image Dataset Loading

Two nodes are proposed for loading image datasets:

- `LoadImageSetNode`: Loads individual images selected by the user
- `LoadImageSetFromFolderNode`: Loads all images from a specified folder

These nodes offer options for handling images of different sizes (stretch, crop, pad) and prepare the images for training.

```python
class LoadImageSetFromFolderNode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."})
            },
            "optional": {
                "resize_method": (
                    ["None", "Stretch", "Crop", "Pad"],
                    {"default": "None"},
                ),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "load_images"
    CATEGORY = "loaders"
    EXPERIMENTAL = True
    DESCRIPTION = "Loads a batch of images from a directory for training."
```
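
For illustration, the sketch below shows how the `resize_method` options could normalize a single image before batching. The helper name, target size, and tensor layout are assumptions made for this sketch, not part of the proposed node API.

```python
# Rough sketch only: normalizing one image according to the proposed
# resize_method option. Helper name and target size are illustrative.
import numpy as np
import torch
from PIL import Image, ImageOps

def prepare_image(path: str, target_size=(512, 512), resize_method: str = "None") -> torch.Tensor:
    img = Image.open(path).convert("RGB")
    if resize_method == "Stretch":
        img = img.resize(target_size, Image.LANCZOS)         # ignore aspect ratio
    elif resize_method == "Crop":
        img = ImageOps.fit(img, target_size, Image.LANCZOS)  # center-crop to fill the target
    elif resize_method == "Pad":
        img = ImageOps.pad(img, target_size, Image.LANCZOS)  # letterbox to fit the target
    # "None" keeps the original resolution; images must then share a size to be batched.
    array = np.asarray(img, dtype=np.float32) / 255.0
    return torch.from_numpy(array)[None, ...]  # shape [1, H, W, C], ComfyUI's IMAGE layout
```
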
### 2. LoRA Training Node

The `TrainLoraNode` is the core component that handles the training process:

```python
class TrainLoraNode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": (IO.MODEL, {"tooltip": "The model to train the LoRA on."}),
                "vae": (IO.VAE, {"tooltip": "The VAE model to use for encoding images for training."}),
                "positive": (IO.CONDITIONING, {"tooltip": "The positive conditioning to use for training."}),
                "image": (IO.IMAGE, {"tooltip": "The image or image batch to train the LoRA on."}),
                "batch_size": (IO.INT, {"default": 1, "min": 1, "max": 10000, "step": 1}),
                "steps": (IO.INT, {"default": 50, "min": 1, "max": 1000}),
                "learning_rate": (IO.FLOAT, {"default": 0.0003, "min": 0.0000001, "max": 1.0, "step": 0.00001}),
                "rank": (IO.INT, {"default": 8, "min": 1, "max": 128}),
                "optimizer": (["Adam", "AdamW", "SGD", "RMSprop"], {"default": "Adam"}),
                "loss_function": (["MSE", "L1", "Huber", "SmoothL1"], {"default": "MSE"}),
                "seed": (IO.INT, {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF}),
                "training_dtype": (["bf16", "fp32"], {"default": "bf16"}),
                "existing_lora": (folder_paths.get_filename_list("loras") + ["[None]"], {"default": "[None]"}),
            },
        }

    RETURN_TYPES = (IO.MODEL, IO.LORA_MODEL, IO.LOSS_MAP, IO.INT)
    RETURN_NAMES = ("model_with_lora", "lora", "loss", "steps")
    FUNCTION = "train"
    CATEGORY = "training"
    EXPERIMENTAL = True
```

The training process, sketched below, proceeds as follows:

1. Takes a batch of images and encodes them using a VAE
2. Sets up LoRA layers for all eligible weights in the model
3. Configures an optimizer and loss function based on user selections
4. Performs gradient-based training for the specified number of steps
5. Returns the model with LoRA applied, the LoRA weights, a map of training losses, and the total training steps
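
As a rough illustration of steps 3–5, the loop below maps the optimizer and loss selections onto plain PyTorch. The helper names (`denoiser`, `lora_params`) and the simplified noising scheme are placeholders for this sketch; the actual node works through ComfyUI's sampling and model-patching machinery rather than calling the model directly.

```python
# Minimal sketch of the training loop under the assumptions stated above.
import torch

def train_lora(latents, cond, denoiser, lora_params, steps=50, lr=3e-4,
               optimizer_name="Adam", loss_name="MSE"):
    optim_cls = {"Adam": torch.optim.Adam, "AdamW": torch.optim.AdamW,
                 "SGD": torch.optim.SGD, "RMSprop": torch.optim.RMSprop}[optimizer_name]
    loss_fn = {"MSE": torch.nn.MSELoss(), "L1": torch.nn.L1Loss(),
               "Huber": torch.nn.HuberLoss(), "SmoothL1": torch.nn.SmoothL1Loss()}[loss_name]
    optimizer = optim_cls(lora_params, lr=lr)

    loss_map = {}
    for step in range(steps):
        optimizer.zero_grad()
        # Noise the VAE latents at a random noise level (standard diffusion-style training).
        noise = torch.randn_like(latents)
        sigma = torch.rand(latents.shape[0], device=latents.device)
        noised = latents + noise * sigma.view(-1, 1, 1, 1)
        # Predict the noise with the LoRA-patched denoiser and regress against it.
        pred = denoiser(noised, sigma, cond)
        loss = loss_fn(pred, noise)
        loss.backward()
        optimizer.step()
        loss_map[step] = loss.item()
    return loss_map
```
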
### 3. Model Saving Node

The `SaveLoRA` node enables users to save their trained LoRA models:

```python
class SaveLoRA:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "lora": (IO.LORA_MODEL, {"tooltip": "The LoRA model to save."}),
                "prefix": (IO.STRING, {"default": "trained_lora"}),
            },
            "optional": {
                "steps": (IO.INT, {"forceInput": True}),
            },
        }

    RETURN_TYPES = ()
    FUNCTION = "save"
    CATEGORY = "loaders"
    EXPERIMENTAL = True
    OUTPUT_NODE = True
```
The node saves the LoRA weights in SafeTensors format, with a filename that includes the number of training steps and a timestamp.
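
A minimal sketch of that save step, assuming the `safetensors` library; the exact filename pattern and any metadata fields shown here are illustrative, not specified by this RFC:

```python
# Sketch of the save step under the assumptions stated above.
import os
import time
from safetensors.torch import save_file

def save_lora(lora_state_dict, output_dir, prefix="trained_lora", steps=None):
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    steps_part = f"_{steps}_steps" if steps is not None else ""
    path = os.path.join(output_dir, f"{prefix}{steps_part}_{timestamp}.safetensors")
    # safetensors expects contiguous CPU tensors
    tensors = {k: v.detach().cpu().contiguous() for k, v in lora_state_dict.items()}
    save_file(tensors, path, metadata={"training_steps": str(steps or 0)})
    return path
```
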
### 4. Training Visualization Node

The `LossGraphNode` visualizes the training progress:

```python
class LossGraphNode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "loss": (IO.LOSS_MAP, {"default": {}}),
                "filename_prefix": (IO.STRING, {"default": "loss_graph"}),
            },
        }

    RETURN_TYPES = ()
    FUNCTION = "plot_loss"
    OUTPUT_NODE = True
    CATEGORY = "training"
    EXPERIMENTAL = True
    DESCRIPTION = "Plots the loss graph and saves it to the output directory."
```
This node generates a graph showing the training loss over time, providing visual feedback on the training process.
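
A plotting sketch is shown below; matplotlib is an assumption chosen for illustration here, and the node could equally rasterize the graph with PIL:

```python
# Sketch: render a loss map (step -> loss value) to a PNG without a display.
import matplotlib
matplotlib.use("Agg")  # headless backend, no GUI required
import matplotlib.pyplot as plt

def plot_loss(loss_map: dict, out_path: str = "loss_graph.png"):
    steps = sorted(loss_map.keys())
    values = [loss_map[s] for s in steps]
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(steps, values)
    ax.set_xlabel("step")
    ax.set_ylabel("loss")
    ax.set_title("LoRA training loss")
    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)
```
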

### Supporting Components

The implementation also includes several support classes (a conceptual sketch follows this list):

1. `TrainSampler`: A custom sampler that performs gradient updates during the sampling process
2. `LoraDiff` and `BiasDiff`: Weight wrapper classes that apply LoRA adaptations to model weights
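
To illustrate the idea behind the weight wrappers, the snippet below applies a trainable low-rank delta on top of a frozen base weight. The class name and constructor are invented for this sketch and do not mirror the actual `LoraDiff` implementation.

```python
# Conceptual sketch: effective weight = frozen base + scaled low-rank update (B @ A).
import torch

class LowRankDelta(torch.nn.Module):
    def __init__(self, weight: torch.Tensor, rank: int = 8, alpha: float = 8.0):
        super().__init__()
        out_features, in_features = weight.shape
        self.register_buffer("base", weight)  # frozen original weight
        self.down = torch.nn.Parameter(torch.randn(rank, in_features) * 0.01)
        self.up = torch.nn.Parameter(torch.zeros(out_features, rank))  # zero init: delta starts at 0
        self.scale = alpha / rank

    def forward(self) -> torch.Tensor:
        # Return the patched weight; only `down` and `up` receive gradients.
        return self.base + (self.up @ self.down) * self.scale
```
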
## Drawbacks

1. **Resource Consumption**: Training is computationally intensive and may strain systems with limited resources
2. **UI Responsiveness**: Long training processes could make the ComfyUI interface less responsive
3. **Complexity**: Adding training capabilities increases the complexity of the ComfyUI codebase
4. **Learning Curve**: Users may need to understand more ML concepts to effectively use the training features

## Adoption strategy

1. **Experimental Flag**: Initially release nodes with the `EXPERIMENTAL = True` flag to indicate the developing nature of the feature
2. **Documentation**: Provide comprehensive documentation and tutorial workflows
3. **Gradual Feature Addition**: Start with basic LoRA training and expand to other training types based on user feedback
4. **Default Parameters**: Set sensible defaults to help users get started without deep ML knowledge

## Unresolved questions

1. **Memory Management**: How will the system handle memory during training, especially for large models and datasets?
2. **Checkpoint Frequency**: Should the system automatically save checkpoints during training to prevent loss of progress?
3. **Training Interruption**: How should the system handle interrupted training sessions?
4. **Hyperparameter Optimization**: Should the system provide tools for automatically finding optimal hyperparameters?
5. **Multi-GPU Support**: How will training utilize multiple GPUs if available?
6. **Integration with Existing Workflows**: How can trained models be seamlessly integrated into existing inference workflows?
7. **Performance Metrics**: Should additional metrics beyond loss be tracked and visualized?
8. **Dataset Preparation**: Should the system provide more tools for dataset curation and augmentation?

## Implementation Plan

### Phase 1: Basic LoRA Training

Initial implementation of the nodes described in this RFC.

### Phase 2: Enhanced Features

- Checkpoint saving during training
- More advanced training visualizations
- Support for additional training techniques (e.g., DreamBooth, control model training such as Control LoRA and IPA)

### Phase 3: Workflow Integration

- Templates for common training scenarios
- Integration with model merging and inference workflows
- Advanced dataset management tools

### Phase 4: Model Format

- A new model format to improve memory management and model metadata handling in ComfyUI

## Links

<!--
Both links below will be automatically filled in when you create the PR.
You do not need to modify this section.
-->

- [Full Rendered Proposal]()
- [Discussion Thread]()

<!--
Optional: Include any additional links to related issues or resources below
-->

---

**Important: Do NOT comment on this PR. Please use the discussion thread linked above to provide feedback, as it provides branched discussions that are easier to follow. This also makes the edit history of the PR clearer.**
