Diffusion Loss Is LYING to the Optimizer! (AKA my addition to "Everything you know about loss is a LIE!" #294) #1994
jilek-josef started this conversation in Show and tell
Reply: I tested with the flow-chroma model.
Well, not exactly. I just wanted a funny title, but it's not completely wrong either. Let me explain.
Timesteps, Optimizers, Batches & Loss
First, let's recap what we know so far. Diffusion models are based on a de-noising process: all the model does is remove noise from an image (more precisely from a latent space, but for simplicity let's think of it as an image for now). So during training we add noise to the images, and the amount of noise is set by the timestep: the higher the timestep, the more noise is added; the lower the timestep, the less. That is all fine, except that it breaks our loss function.
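To make the timestep concrete, here is a minimal sketch of how noise is mixed in during training. It assumes a rectified-flow (flow-matching) formulation, which Flux-style models are trained with: a timestep t in [0, 1] linearly interpolates between the clean latent and Gaussian noise. The function name `add_noise` is hypothetical, and plain Python lists stand in for latent tensors.

```python
import random


def add_noise(x0, noise, t):
    """Mix a clean latent with noise at timestep t in [0, 1].

    Assumes the rectified-flow convention x_t = (1 - t) * x0 + t * noise,
    so t = 0 is the clean image and t = 1 is pure noise. The regression
    target for the model is the velocity (noise - x0).
    """
    if not 0.0 <= t <= 1.0:
        raise ValueError("t must be in [0, 1]")
    x_t = [(1.0 - t) * a + t * n for a, n in zip(x0, noise)]
    velocity_target = [n - a for a, n in zip(x0, noise)]
    return x_t, velocity_target


# A tiny 4-value "latent" noised at a low and at a high timestep.
rng = random.Random(0)
x0 = [0.5, -0.2, 0.1, 0.9]
noise = [rng.gauss(0.0, 1.0) for _ in x0]
low_t, _ = add_noise(x0, noise, 0.1)   # mostly image, easy to denoise
high_t, _ = add_noise(x0, noise, 0.9)  # mostly noise, hard to denoise
```

Other trainers flip the convention (t = 1 clean) or use a different noise schedule; the point is only that the timestep controls how much of the input is noise.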
Naturally, the model can't be as accurate when facing a lot of noise, while the task is much simpler when there is little noise. High-noise timesteps therefore naturally produce a higher loss.
In most machine learning, the loss directly indicates how well the prediction matches the target, but here, because of timesteps, the loss is only really meaningful when compared between samples at the same timestep. This is partly mitigated with very large batches (100+, or better 1000+), where the loss averages over many timesteps and stays relatively stable.
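One practical consequence is that the loss is easier to read when it is logged per timestep range instead of as one global number. A minimal sketch, assuming timesteps normalized to [0, 1] and equal-width buckets; the class name `BucketedLossTracker` and the smoothing factor are hypothetical choices for illustration:

```python
class BucketedLossTracker:
    """Track an exponential moving average of the loss per timestep bucket.

    Comparing a new loss only against the average of its own bucket avoids
    mixing easy (low-noise) and hard (high-noise) timesteps in one number.
    """

    def __init__(self, num_buckets=10, smoothing=0.9):
        self.num_buckets = num_buckets
        self.smoothing = smoothing
        self.averages = [None] * num_buckets  # None until a bucket is seen

    def bucket(self, t):
        # Equal-width buckets over [0, 1]; t == 1.0 goes into the last one.
        return min(int(t * self.num_buckets), self.num_buckets - 1)

    def update(self, t, loss):
        """Fold one loss value into its bucket's average; return (bucket, avg)."""
        b = self.bucket(t)
        prev = self.averages[b]
        if prev is None:
            self.averages[b] = loss
        else:
            self.averages[b] = self.smoothing * prev + (1.0 - self.smoothing) * loss
        return b, self.averages[b]


tracker = BucketedLossTracker()
tracker.update(0.05, 0.02)  # low-noise sample, bucket 0
tracker.update(0.95, 0.60)  # high-noise sample, bucket 9
```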
With LoRA, however, this is a problem: we can't use batches that large because of compute and VRAM constraints, so the loss is not averaged well enough within a batch and fluctuates. This is amplified by the fact that the timesteps in a single batch are random, so sometimes a batch contains only low or only high timesteps.
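To put a rough number on it: assuming each of the B timesteps in a batch is drawn independently and uniformly from [0, 1] (real trainers often use a non-uniform, e.g. shifted or logit-normal, schedule, which changes the numbers but not the idea), the chance that a whole batch lands on one side of t = 0.5 is:

$$
P(\text{all } t_i < 0.5) = \left(\tfrac{1}{2}\right)^{B},
\qquad
P(\text{all low or all high}) = 2\left(\tfrac{1}{2}\right)^{B} = 2^{1-B}
$$

$$
B = 4:\ 2^{-3} = 0.125 \ (\text{about one batch in eight}),
\qquad
B = 64:\ 2^{-63} \approx 1.1 \times 10^{-19}
$$

So at LoRA-sized batches a fully one-sided batch is routine, while at large batch sizes it practically never happens.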
Why does the loss matter? It directly guides the optimizer. A higher loss produces larger gradients, which effectively acts like a bigger learning rate, and when a run of high losses is suddenly followed by a low loss, it causes overshooting. Imagine you are blindly trying different things and are always told it's no good, and then suddenly you get praised: "this is the best thing you have ever done". You would get confused, or perhaps think you had finally made it. That is roughly the optimizer's situation when it suddenly gets a low loss.
Loli Optimizer: A New Approach
To address the loss issue we have roughly two options. The first is to scale the loss directly, for example multiplying it by 2 when the timestep is below some threshold x. But then we would be lying to the optimizer, and while training would be more stable, it has side effects such as an effectively increased learning rate. It also doesn't generalize, since the right scaling may differ depending on what we are training (style vs. character, for example).
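For comparison, the first option could look like the sketch below. The threshold and factor are hypothetical values, not tuned ones, and this is the approach the post argues against, not the one it ends up using.

```python
def scale_loss_by_timestep(loss, t, threshold=0.3, factor=2.0):
    """Naive per-timestep loss scaling (option 1).

    Multiplies the loss of low-noise samples (t < threshold) by `factor`
    so their magnitude is closer to that of high-noise samples. Gradients
    scale linearly with the loss, so for plain SGD this is exactly a
    per-timestep learning-rate change for those samples.
    """
    return loss * factor if t < threshold else loss
```

Adaptive optimizers such as Adam partly normalize a constant loss scale away, but the mixed magnitudes still flow into their running statistics, which is why a fixed factor is a blunt tool.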
The second option I came up with is a special "timestep-aware" optimizer. To be fair, I don't think a truly timestep-aware optimizer is practical, but I came up with a workaround: why not cluster the timesteps and keep separate optimizer states for each cluster?
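A minimal sketch of the idea: split [0, 1] into equal-width timestep clusters and keep an independent Adam state (first and second moments plus step count) for each cluster, all updating the same parameters. Equal-width clusters, the plain Adam update, and the class name `ClusteredAdam` are assumptions for illustration; the linked repository may cluster and store state differently. Plain Python lists stand in for parameter tensors.

```python
import math


class ClusteredAdam:
    """Adam with one independent optimizer state per timestep cluster.

    All clusters update the same parameter list, but each keeps its own
    moment estimates, so statistics from high-noise timesteps never mix
    with statistics from low-noise timesteps.
    """

    def __init__(self, num_params, num_clusters=10, lr=1e-3,
                 betas=(0.9, 0.999), eps=1e-8):
        self.num_clusters = num_clusters
        self.lr, self.betas, self.eps = lr, betas, eps
        self.states = [
            {"step": 0, "m": [0.0] * num_params, "v": [0.0] * num_params}
            for _ in range(num_clusters)
        ]

    def cluster_of(self, t):
        # Equal-width clusters over t in [0, 1]; t == 1.0 joins the last one.
        return min(int(t * self.num_clusters), self.num_clusters - 1)

    def step(self, params, grads, t):
        """Apply one Adam update in place, using only the state of t's cluster."""
        state = self.states[self.cluster_of(t)]
        b1, b2 = self.betas
        state["step"] += 1
        k = state["step"]
        for i, g in enumerate(grads):
            state["m"][i] = b1 * state["m"][i] + (1 - b1) * g
            state["v"][i] = b2 * state["v"][i] + (1 - b2) * g * g
            m_hat = state["m"][i] / (1 - b1 ** k)  # bias-corrected moments
            v_hat = state["v"][i] / (1 - b2 ** k)
            params[i] -= self.lr * m_hat / (math.sqrt(v_hat) + self.eps)
        return params


params = [0.0, 0.0]
opt = ClusteredAdam(num_params=2, num_clusters=10, lr=0.1)
opt.step(params, grads=[1.0, -1.0], t=0.05)  # touches cluster 0's state only
```

In a PyTorch trainer the same structure could be one optimizer instance per cluster over the same parameters, stepping only the one that matches the batch.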
So I did that. But it requires one more change to work properly: instead of random timesteps in a batch, each batch must contain timesteps from only one cluster. This has an interesting effect: each cluster ends up updating somewhat different weights. That is good, and actually more or less required, since separate optimizer states can produce different and possibly contradicting updates.
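The batching change can be sketched as a sampler that first picks a cluster and then draws every timestep of the batch from that cluster's range. Picking the cluster uniformly at random and sampling uniformly inside it are assumptions here; the function name `sample_cluster_batch` is hypothetical.

```python
import random


def sample_cluster_batch(batch_size, num_clusters=10, rng=random):
    """Draw a batch of timesteps that all come from a single cluster.

    Returns (cluster_id, timesteps). The cluster is chosen uniformly,
    then each timestep is drawn uniformly within that cluster's
    equal-width slice of [0, 1).
    """
    cluster = rng.randrange(num_clusters)
    width = 1.0 / num_clusters
    low = cluster * width
    timesteps = [low + rng.random() * width for _ in range(batch_size)]
    return cluster, timesteps


rng = random.Random(42)
cluster, ts = sample_cluster_batch(4, num_clusters=10, rng=rng)
# Every t in ts falls inside the slice [cluster / 10, (cluster + 1) / 10).
```

In a training loop, the returned cluster id would select which per-cluster optimizer state to step, so a batch's gradients only ever update the state of the cluster its timesteps came from.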
A good balance, I think, is about 10 clusters; there isn't much benefit to having more. Fewer, around 5, is also acceptable, but going below that may start to be problematic. This should also apply to full fine-tuning, possibly speeding up convergence there as well, but it may not be the best fit, since 10 clusters means 10x the optimizer memory. For LoRA it works because the optimizer size scales with the LoRA rank.
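To see why this is cheap for LoRA but expensive for full training, here is a rough estimate of Adam's optimizer memory. It assumes two fp32 moment buffers (8 bytes per trainable parameter) per cluster and a LoRA update of r·(d_in + d_out) parameters per adapted linear layer; the layer shapes and counts in the example are hypothetical, not taken from Flux or Chroma.

```python
def lora_params(d_in, d_out, rank):
    """Trainable parameters of one LoRA adapter: A is rank x d_in, B is d_out x rank."""
    return rank * (d_in + d_out)


def adam_state_bytes(num_params, num_clusters=1, bytes_per_value=4):
    """Memory for Adam's two moment buffers, one full copy per cluster."""
    return 2 * num_params * bytes_per_value * num_clusters


# Hypothetical example: 500 adapted 3072x3072 layers at rank 16.
n = 500 * lora_params(3072, 3072, 16)                # 49,152,000 params
print(adam_state_bytes(n, 1) / 2**30)                # ~0.37 GiB with one state
print(adam_state_bytes(n, 10) / 2**30)               # ~3.7 GiB with 10 clusters
# Full fine-tune of a roughly Flux-sized 12B-parameter model, 10 clusters:
print(adam_state_bytes(12_000_000_000, 10) / 1e9)    # 960.0 GB
```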
Loli Loss Optimization
There are multiple loss functions, each working a bit differently and suited to different things. The commonly used MSE strongly pushes toward a pixel-perfect match. That can be a really good thing, but for diffusion it isn't always. Say we want a tree in the image: the training image has the tree on one side, but in our example the model puts it on the other side. MSE would treat that as if everything were wrong.
Then there is L1, which balances this and doesn't push as hard for a pixel-perfect result. We can also base the loss on cosine similarity, which tolerates some positional shifts.
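A small worked example of the difference between MSE and L1: two error vectors with the same total absolute error, one spread evenly and one concentrated in a single value.

$$
e_a = (1, 1, 1, 1):\quad \text{L1} = \tfrac{1}{4}(1+1+1+1) = 1,\qquad \text{MSE} = \tfrac{1}{4}(1+1+1+1) = 1
$$

$$
e_b = (0, 0, 0, 4):\quad \text{L1} = \tfrac{1}{4}(0+0+0+4) = 1,\qquad \text{MSE} = \tfrac{1}{4}(0+0+0+16) = 4
$$

L1 scores both the same, while MSE punishes the concentrated error four times as hard, which is why MSE pushes harder toward matching every value exactly.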
L1 is nice and probably the most balanced option here, but why not use all three? Computing loss is cheap, so we can compute all three and sum them with weights, e.g. MSE*0.4 + L1*0.4 + cos_loss*0.2. We can even expose the weights as training options and tune them further, but in my testing this mix already turned out to be a better alternative to pure MSE.
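A minimal sketch of the combined loss, with the weights exposed as parameters. It assumes the cosine term is 1 minus the cosine similarity computed over the flattened prediction and target, and uses plain Python lists in place of tensors; the function names are hypothetical and the linked repository may compute the terms differently (for example per sample, then averaged over the batch).

```python
import math


def mse(pred, target):
    """Mean squared error: squares each difference, so large errors dominate."""
    return sum((p - q) ** 2 for p, q in zip(pred, target)) / len(pred)


def l1(pred, target):
    """Mean absolute error: every unit of error costs the same."""
    return sum(abs(p - q) for p, q in zip(pred, target)) / len(pred)


def cosine_loss(pred, target, eps=1e-8):
    """1 - cosine similarity: 0 when the vectors point the same way, regardless of scale."""
    dot = sum(p * q for p, q in zip(pred, target))
    norm = math.sqrt(sum(p * p for p in pred)) * math.sqrt(sum(q * q for q in target))
    return 1.0 - dot / max(norm, eps)


def combined_loss(pred, target, w_mse=0.4, w_l1=0.4, w_cos=0.2):
    """Weighted sum of MSE, L1 and cosine loss (default weights from the post)."""
    return (w_mse * mse(pred, target)
            + w_l1 * l1(pred, target)
            + w_cos * cosine_loss(pred, target))


pred = [0.9, -0.1, 0.4]
target = [1.0, 0.0, 0.5]
loss = combined_loss(pred, target)
```

In a PyTorch trainer the same formula maps onto `torch.nn.functional.mse_loss`, `torch.nn.functional.l1_loss` and `torch.nn.functional.cosine_similarity`.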
Code
This was specifically tested on Flux/Chroma, but it should generalize to any diffusion model. Here is code implementing it for a Chroma LoRA trainer: https://github.com/jilek-josef/flow-chroma
This ultimately needs more testing, since there may be scenarios where it is not ideal. Still, in my testing it helped a lot, especially for Flux/Chroma, where without it training was significantly unstable and performed poorly.