Skip to content

Commit adf1f91

Browse files
authored
[Tests] fix some fast gpu tests. (#9379)
fix some fast gpu tests.
1 parent f28a8c2 commit adf1f91

File tree

4 files changed

+5
-2
lines changed

4 files changed

+5
-2
lines changed

examples/dreambooth/train_dreambooth_lora_flux.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1597,6 +1597,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
15971597
tokenizers=[None, None],
15981598
text_input_ids_list=[tokens_one, tokens_two],
15991599
max_sequence_length=args.max_sequence_length,
1600+
device=accelerator.device,
16001601
prompt=prompts,
16011602
)
16021603
else:
@@ -1606,6 +1607,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16061607
tokenizers=[None, None],
16071608
text_input_ids_list=[tokens_one, tokens_two],
16081609
max_sequence_length=args.max_sequence_length,
1610+
device=accelerator.device,
16091611
prompt=args.instance_prompt,
16101612
)
16111613

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,7 @@ def forward(
465465
"Please remove the batch dimension and pass it as a 2d torch Tensor"
466466
)
467467
img_ids = img_ids[0]
468+
468469
ids = torch.cat((txt_ids, img_ids), dim=0)
469470
image_rotary_emb = self.pos_embed(ids)
470471

tests/pipelines/flux/test_pipeline_flux_img2img.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
enable_full_determinism()
1919

2020

21-
@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.")
2221
class FluxImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
2322
pipeline_class = FluxImg2ImgPipeline
2423
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
2524
batch_params = frozenset(["prompt"])
25+
test_xformers_attention = False
2626

2727
def get_dummy_components(self):
2828
torch.manual_seed(0)

tests/pipelines/flux/test_pipeline_flux_inpaint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
enable_full_determinism()
1919

2020

21-
@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.")
2221
class FluxInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
2322
pipeline_class = FluxInpaintPipeline
2423
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
2524
batch_params = frozenset(["prompt"])
25+
test_xformers_attention = False
2626

2727
def get_dummy_components(self):
2828
torch.manual_seed(0)

0 commit comments

Comments
 (0)