Closed
Labels: bug (Something isn't working)
Description
Describe the bug
The Flux ControlNet training script was contributed by @PromeAIpro and merged around v0.31.0; see https://github.com/huggingface/diffusers/pull/9324/files.
However, around v0.32.0, after the code was prettified and new diffusers APIs were introduced (see https://github.com/huggingface/diffusers/blob/8170dc368d278ec40d27bf04f58bff140cebd99e/examples/flux-control/train_control_flux.py), the training logic appears to be wrong: the code now performs LoRA/full-finetune training, not ControlNet training.
diffusers/examples/flux-control/train_control_flux.py
Lines 1067 to 1097 in 8170dc3
```python
text_encoding_pipeline = text_encoding_pipeline.to("cpu")

# Predict.
model_pred = flux_transformer(
    hidden_states=packed_noisy_model_input,
    timestep=timesteps / 1000,
    guidance=guidance_vec,
    pooled_projections=pooled_prompt_embeds,
    encoder_hidden_states=prompt_embeds,
    txt_ids=text_ids,
    img_ids=latent_image_ids,
    return_dict=False,
)[0]

model_pred = FluxControlPipeline._unpack_latents(
    model_pred,
    height=noisy_model_input.shape[2] * vae_scale_factor,
    width=noisy_model_input.shape[3] * vae_scale_factor,
    vae_scale_factor=vae_scale_factor,
)

# These weighting schemes use a uniform timestep sampling
# and instead post-weight the loss.
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)

# Flow-matching loss.
target = noise - pixel_latents
loss = torch.mean(
    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
    1,
)
loss = loss.mean()
accelerator.backward(loss)
```
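For context, the loss in the snippet above is the standard rectified-flow (flow-matching) objective: the model predicts the velocity `noise - latents` of the linear interpolation between clean latents and noise, with an optional per-sample post-hoc weighting. A minimal NumPy sketch of that objective follows; the names (`latents`, `sigmas`, `weighting`) and the toy shapes are illustrative assumptions, not the script's actual tensors:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-ins: a batch of 2 "latents" with 4 values each.
latents = rng.standard_normal((2, 4))
noise = rng.standard_normal((2, 4))
sigmas = np.array([0.3, 0.7]).reshape(-1, 1)  # per-sample noise level in [0, 1]

# Rectified-flow interpolation: x_t = (1 - sigma) * x0 + sigma * noise.
noisy = (1.0 - sigmas) * latents + sigmas * noise

# Velocity target: d x_t / d sigma = noise - x0.
target = noise - latents

# Pretend the model predicted the target exactly, so the loss is zero.
model_pred = target.copy()
weighting = np.ones_like(sigmas)  # uniform post-hoc weighting

# Weighted per-sample MSE, then mean over the batch (mirrors the script's loss).
loss = np.mean(
    (weighting * (model_pred - target) ** 2).reshape(target.shape[0], -1),
    axis=1,
).mean()
print(loss)  # 0.0
```

Note that this objective itself is the same whether the trained parameters belong to a ControlNet or to the transformer; the question raised here is which module the gradients flow into.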
We need a PR to fix this.
@a-r-r-o-w @sayakpaul @yiyixuxu
Reproduction
(empty)
Logs
(empty)
System Info
(empty)
Who can help?