Skip to content

Add preference optimization (Diffusion-DPO, MaPO, DDO, BPO, CPO, SDPO, SimPO)#1427

Draft
rockerBOO wants to merge 28 commits intokohya-ss:sd3from
rockerBOO:po
Draft

Add preference optimization (Diffusion-DPO, MaPO, DDO, BPO, CPO, SDPO, SimPO)#1427
rockerBOO wants to merge 28 commits intokohya-ss:sd3from
rockerBOO:po

Conversation

@rockerBOO
Copy link
Contributor

@rockerBOO rockerBOO commented Jul 13, 2024

Add preference optimization (PO) support
Add paired images in dataset.

Preference Optimization algo implemented:

Currently looking for feedback about implementation.

Decisions I made and why:

Pairing images

Pairing images in ImageSetInfo (exetnding ImageInfo)

  • Pairing images keeps the images from shuffling
  • Paired images should be updated after seeing both.

Batch size of 1 will load 2 captions and 2 images for 1 requested image/caption pair

Dataset

Datasets can be defined as "preference" args.preference --preference and in dataset_config.toml

  [[datasets.subsets]]
  image_dir = "dataset/1_name"
  preference = true

To create a pattern of dataset, we are hard coding in dataset/1_name/w and dataset/1_name/l. You would then have a typical dreambooth-like dataset with the following.

  • dataset/1_name/w/image.png
  • dataset/1_name/w/image.caption
  • dataset/1_name/l/image.png
  • dataset/1_name/l/image.caption

Note w and l are like the typical dataset with image/caption file pairs. They all have the same file name to create the pairs.

Good idea to consider other file dataset patterns.

Preference dataset examples:

Pickapic dataset is a preference between 2 images and showing the pairing and embedding the 2 images into the dataset.

Caption prefix/suffix for preference/non-preference

Prefix/suffix allow some techniques of moving away from some concepts. Allows different ones for preference/non preference, to give flexibility in experimentation.

Training

Added PO into the main training script to allow flexibility but will be moved to the typical functions for these. I have it setup for network training but would work other scripts.

  • Images come in as pairs on the tensors
  • Pass the loss through the PO algorithm for the pairs
  • Log the associated values from the training

Hyperparameters

--beta_dpo = KL-divergence parameter beta for Diffusion-DPO

2500 for 1.5, 5000 for SDXL were what I have found suggested.

--mapo_weight = MaPO contribution factor

Start around 0.1 but adjusting this can be helpful at how much the contribution of the preference optimization will have on the training. See

TODO

  • Cache support
  • ControlNet dataset (For masking)

Possible issues

Preference and regular training datasets mixed

This mixing would need to worked on at higher than 1 batch size. We assume chunking of pairs so unpaired images won't work that way.

The implementations may not be accurate

If you see something not correct, let me know.

Usage

State: This is currently working and producing favorable results.

Images/caption pairs stored in w and l directories.

dataset/1_myimages/w/image.jpg
dataset/1_myimages/w/image.txt

dataset/1_myimages/l/image.jpg
dataset/1_myimages/l/image.txt

NOTE Use the same name for images in w and l directories to make them paired

python train_network.py ... --preference --dataset_dir dataset --mapo_weight=0.1

or in your dataset config

  [[datasets.subsets]]
  image_dir = "dataset/1_name"
  preference = true

Related tickets: #1040

@feffy380
Copy link
Contributor

Do you have any training samples from this?

@rockerBOO
Copy link
Contributor Author

36ee5259 = pickapic dataset sample (500 preferences)
a9a03acb = my own preference dataset from prompts generated on SD 1.5

10 epochs, LoRA 16/16 with Prodigy d_coef at 1.25

MaPO with contribution weight of mapo_weight = 0.1

xyz_grid-0010-1229252821
Dreamshaper 8

xyz_grid-0005-2443401054
3D Animation 1.0

xyz_grid-0002-3672213076
SD 1.5 base model

Papers generally suggesting around 1000 preferences. I have been making a preference creation tool so one could make their own preferences on their own dataset.

@feffy380
Copy link
Contributor

I'm finding I need the learning rate as low as 1e-6 (for an SDXL lora), possibly lower if you have a bigger dataset. I also had text encoder training disabled.

One thing I want to try is using real chosen plus AI generated rejected images, inspired by what LLM folks have been doing to bypass the need for collecting real preference pairs

@rockerBOO
Copy link
Contributor Author

I'm finding I need the learning rate as low as 1e-6 (for an SDXL lora), possibly lower if you have a bigger dataset. I also had text encoder training disabled.

Which did you try Diffusion-DPO or MaPO? I found it to train slowly at 1e-4 on SD 1.5 with MaPO weight of 0.1. Haven't done a full hyperparameter test yet though.

@feffy380
Copy link
Contributor

MaPO with LR 1e-6 and beta 0.1 on sdxl. My dataset consists of real images in the target style and each has a matching AI image without style prompts as the rejected image. The differences are extreme, so maybe that's why I need lower learning rates?

@rockerBOO
Copy link
Contributor Author

the mapo_weight is described as the contribution which is basically the difference between the preference and non-preference. So you could adjust the weight to be like 0.05 and keep your original LR. I'm not sure what is the most efficient though.

@feffy380
Copy link
Contributor

feffy380 commented Aug 5, 2024

A few more observations with adamw and my real chosen, synthetic rejected dataset:

  • Stippling and gridlike artifacts appear when training for longer periods
  • Dropping the margin loss 75% of the time delayed the appearance of artifacts without slowing down style learning too much. This is not equivalent to reducing beta_mapo by 75%
  • Artifacts were less severe with min_snr_gamma=1 compared to without

I haven't managed to completely eliminate the artifacts, so my only option is early stopping. I've seen this with other preference optimization papers (if you look, they all train for very short periods of around 2000 steps) and it's annoying that it's never addressed.

@feffy380
Copy link
Contributor

feffy380 commented Aug 10, 2024

I finally sat down and built a real preference dataset and found pairs of images generated by the same model don't cause as many artifacts (probably because they come from the same distribution and any encoding artifacts cancel out)

@rockerBOO rockerBOO changed the base branch from dev to sd3 April 28, 2025 20:02
@SharkWipf
Copy link

Bit of a side-track, but...

The differences are extreme, so maybe that's why I need lower learning rates?

@feffy380 Have you considered generating the "negative" images with an img2img of the original images (or even controlnet, or both)? That should significantly reduce the difference while still largely representing the model's output.
No idea if this would work, but I don't see why it shouldn't (though it might reduce the effectiveness of training on specific poses).

@rockerBOO
Copy link
Contributor Author

@SharkWipf There is this paper https://arxiv.org/abs/2405.20216 which they talk about doing the img2img like technique for intermediate and "hard" stages. You may use a inverse sampling to get the noise and then denoise it from there to create small perbutations to train on. The paper goes more into detail how they pick different parts but it's a good idea.

To implement that for this PR we'd probably want to do 3 different training runs to have it be done simply but maybe later it could be possible to have "staged" training for different parameters to come into effect at different parts of the training.

For DDO (https://arxiv.org/abs/2503.01103) they also suggest this iterative process of more and more refinement to get better results. I will write a bit more about DDO in the next week or so but that is what I have been working on adding as it doesn't require a preference dataset to train it.

rockerBOO added 6 commits May 4, 2025 21:19
Refactor Preference Optimization
Refactor preference dataset
Add iterator support for ImageInfo and ImageSetInfo
- Supporting iterating through either ImageInfo or ImageSetInfo to
  clean up preference dataset implementation and support 2 or more
  images more cleanly without needing to duplicate code
Add tests for all PO functions
Add metrics for process_batch
Add losses for gradient manipulation of loss parts
Add normalizing gradient for stabilizing gradients

Args added:

mapo_beta = 0.05
cpo_beta = 0.1
bpo_beta = 0.1
bpo_lambda = 0.2
sdpo_beta = 0.02
simpo_gamma_beta_ratio = 0.25
simpo_beta = 2.0
simpo_smoothing = 0.0
simpo_loss_type = "sigmoid"
ddo_alpha = 4.0
ddo_beta = 0.05
@rockerBOO rockerBOO changed the title Add preference optimization (Diffusion-DPO, MaPO) Add preference optimization (Diffusion-DPO, MaPO, DDO, BPO, CPO, SDPO, SimPO) Jun 3, 2025
@rockerBOO
Copy link
Contributor Author

@kohya-ss This PR adds support for paired (or more) images which might be helpful in Kontext training . This PR is pretty much good to go, especially the dataset parts. Just need to finalize testing all the algorithms.

4444man

This comment was marked as duplicate.

target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)

return model_pred, target, timesteps, weighting
return model_pred, noisy_model_input, target, timesteps, weighting
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return model_pred, noisy_model_input, target, timesteps, weighting
return model_pred, noisy_model_input, target, sigmas, timesteps, weighting

target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)

return model_pred, target, timesteps, weighting
return model_pred, noisy_model_input, target, timesteps, weighting
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return model_pred, noisy_model_input, target, timesteps, weighting
return model_pred, noisy_model_input, target, sigmas, timesteps, weighting

@exdownloader
Copy link

This PR seems like an excellent base upon which to consider ADDifT/Multi-ADDifT training.
This is where both positive and negative image pairs are trained at once and a custom loss is calculated to create a difference or "slider" between them.
I have used this to great effect, but the existing tooling for such training is not maintained and takes some delicate installation to function.
While it's not strictly PO, Is this something that would be an appropriate addition or should this be separate?
Thanks.

@rockerBOO
Copy link
Contributor Author

@fall404 Thank you, I will make these changes soon. Some conflicts since I work with my fork and then upstream them here.

@exdownloader Could be something since PO is generally moving towards "positive/preferred" and away from "negative/non-preferred". So it can be flexible in the applications but not sure how related it would be to a slider in practice.

@exdownloader
Copy link

exdownloader commented Jul 29, 2025

I can confirm the behaviour as doing a fairly good job at activating (or deactivating) a single feature with minimal side effects.
A blogpost was made about this approach here: https://note.com/hakomikan/n/n716397e39d56
It's in Japanese but the visuals are clear.

Having used the project (which now exists in an apparently unmaintained state) to great success, it's the "reinvention of the wheel" that is unfortunate. If the algorithm was lifted into a more updated project like sd-scripts, I imagine that it would see far more use than it currently does.

I'd advocate for its effectiveness but I'm unsure where it could best fit into sd-scripts as I cannot translate between that codebase and this one. Having seen this PO related PR, and having a surface level understanding of "pair of images + custom objective/loss", it makes sense to reuse the groundwork you've laid where you've split out various algorithms. That being said I'm not best suited to make such a statement and would prefer to defer to those far more experienced with the sd-scripts codebase.

That being said there is a conversation already happening here and I don't want to detour too far from it. Let me know if you also think that this algorithm could be a good fit or if it might be better suited to a separate implementation.

@rockerBOO
Copy link
Contributor Author

I can confirm the behaviour as doing a fairly good job at activating (or deactivating) a single feature with minimal side effects. A blogpost was made about this approach here: https://note.com/hakomikan/n/n716397e39d56 It's in Japanese but the visuals are clear.

I see that in more detail now. I think the key difference is this is paired training in 1 learning step, and their proposal is alternating training (more like what Dreambooth is suppose to do) which is suppose to be separate learning steps.

In this PO it's more of a last step process that allows a general movement away in between the 2 images, so somewhat unrelated in it's effects.

For an example, https://arxiv.org/abs/2405.20216 this paper has a good example showcase of what it can do. Curriculum learning isn't in this PR but you could do it in multiple training runs on the same weights in theory, currently.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants