diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index 8b4bf989e84e..48d669418fd8 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -1597,6 +1597,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                             tokenizers=[None, None],
                             text_input_ids_list=[tokens_one, tokens_two],
                             max_sequence_length=args.max_sequence_length,
+                            device=accelerator.device,
                             prompt=prompts,
                         )
                 else:
@@ -1606,6 +1607,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                             tokenizers=[None, None],
                             text_input_ids_list=[tokens_one, tokens_two],
                             max_sequence_length=args.max_sequence_length,
+                            device=accelerator.device,
                             prompt=args.instance_prompt,
                         )

diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index fd0881a14880..e38efe668c6c 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -465,6 +465,7 @@ def forward(
                 "Please remove the batch dimension and pass it as a 2d torch Tensor"
             )
             img_ids = img_ids[0]
+
         ids = torch.cat((txt_ids, img_ids), dim=0)
         image_rotary_emb = self.pos_embed(ids)

diff --git a/tests/pipelines/flux/test_pipeline_flux_img2img.py b/tests/pipelines/flux/test_pipeline_flux_img2img.py
index ec89f0538269..a038b1725812 100644
--- a/tests/pipelines/flux/test_pipeline_flux_img2img.py
+++ b/tests/pipelines/flux/test_pipeline_flux_img2img.py
@@ -18,11 +18,11 @@
 enable_full_determinism()


-@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.")
 class FluxImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
     pipeline_class = FluxImg2ImgPipeline
     params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
     batch_params = frozenset(["prompt"])
+    test_xformers_attention = False

     def get_dummy_components(self):
         torch.manual_seed(0)
diff --git a/tests/pipelines/flux/test_pipeline_flux_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_inpaint.py
index 7ad77cb6ea1c..ac2eb1fa261b 100644
--- a/tests/pipelines/flux/test_pipeline_flux_inpaint.py
+++ b/tests/pipelines/flux/test_pipeline_flux_inpaint.py
@@ -18,11 +18,11 @@
 enable_full_determinism()


-@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.")
 class FluxInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
     pipeline_class = FluxInpaintPipeline
     params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
     batch_params = frozenset(["prompt"])
+    test_xformers_attention = False

     def get_dummy_components(self):
         torch.manual_seed(0)