From 4fc6ef8d0be841fce2ae76207e38df6627774b36 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Sep 2024 18:42:16 +0530 Subject: [PATCH] fix some fast gpu tests. --- examples/dreambooth/train_dreambooth_lora_flux.py | 2 ++ src/diffusers/models/transformers/transformer_flux.py | 1 + tests/pipelines/flux/test_pipeline_flux_img2img.py | 2 +- tests/pipelines/flux/test_pipeline_flux_inpaint.py | 2 +- 4 files changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 8b4bf989e84e..48d669418fd8 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1597,6 +1597,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): tokenizers=[None, None], text_input_ids_list=[tokens_one, tokens_two], max_sequence_length=args.max_sequence_length, + device=accelerator.device, prompt=prompts, ) else: @@ -1606,6 +1607,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): tokenizers=[None, None], text_input_ids_list=[tokens_one, tokens_two], max_sequence_length=args.max_sequence_length, + device=accelerator.device, prompt=args.instance_prompt, ) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index fd0881a14880..e38efe668c6c 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -465,6 +465,7 @@ def forward( "Please remove the batch dimension and pass it as a 2d torch Tensor" ) img_ids = img_ids[0] + ids = torch.cat((txt_ids, img_ids), dim=0) image_rotary_emb = self.pos_embed(ids) diff --git a/tests/pipelines/flux/test_pipeline_flux_img2img.py b/tests/pipelines/flux/test_pipeline_flux_img2img.py index ec89f0538269..a038b1725812 100644 --- a/tests/pipelines/flux/test_pipeline_flux_img2img.py +++ b/tests/pipelines/flux/test_pipeline_flux_img2img.py @@ -18,11 +18,11 @@ enable_full_determinism() -@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.") class FluxImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin): pipeline_class = FluxImg2ImgPipeline params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) batch_params = frozenset(["prompt"]) + test_xformers_attention = False def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/flux/test_pipeline_flux_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_inpaint.py index 7ad77cb6ea1c..ac2eb1fa261b 100644 --- a/tests/pipelines/flux/test_pipeline_flux_inpaint.py +++ b/tests/pipelines/flux/test_pipeline_flux_inpaint.py @@ -18,11 +18,11 @@ enable_full_determinism() -@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.") class FluxInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin): pipeline_class = FluxInpaintPipeline params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) batch_params = frozenset(["prompt"]) + test_xformers_attention = False def get_dummy_components(self): torch.manual_seed(0)