From ba51279e34508f0a6b29affcfaa7612eb65dcbee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=AB=A5=E7=A8=8B?= Date: Tue, 2 Dec 2025 19:04:59 +0800 Subject: [PATCH] Fixed an issue with abnormal training loss in Dreambooth Flux2. --- examples/dreambooth/train_dreambooth_lora_flux2_img2img.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index 32bce9531b71..bc13a1e4cc17 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -1652,6 +1652,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # concatenate the model inputs with the cond inputs packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_model_input], dim=1) + orig_inp_shape = packed_noisy_model_input.shape + orig_inp_ids_shape = model_input_ids.shape model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1) # handle guidance @@ -1668,8 +1670,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): img_ids=model_input_ids, # B, image_seq_len, 4 return_dict=False, )[0] - model_pred = model_pred[:, : packed_noisy_model_input.size(1) :] - + model_pred = model_pred[:, : orig_inp_shape[1] :] + model_input_ids = model_input_ids[:, :orig_inp_ids_shape[1] :] model_pred = Flux2Pipeline._unpack_latents_with_ids(model_pred, model_input_ids) # these weighting schemes use a uniform timestep sampling