Closed
Labels: bug (Something isn't working)
Description
Describe the bug
This seems to be a recurring problem with MPS support: the `rope` implementation in `transformer_hidream_image.py` hard-codes `dtype=torch.float64` in its `torch.arange` call, and the MPS backend does not support float64:
```
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/models/transformers/transformer_hidream_image.py", line 98, in rope
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
```
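One possible fix is to choose the frequency dtype from the device: float32 on MPS, float64 everywhere else. The sketch below is a reconstruction for illustration only, assuming `rope(pos, dim, theta)` builds Flux-style 2x2 rotation matrices from a `(batch, seq_len)` position tensor; it is not the actual diffusers source, and `freqs_dtype_for` is a hypothetical helper name.

```python
import torch


def freqs_dtype_for(device: torch.device) -> torch.dtype:
    # MPS cannot hold float64 tensors, so fall back to float32 there;
    # keep float64 on other devices to preserve the original precision.
    return torch.float32 if device.type == "mps" else torch.float64


def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """Hypothetical device-aware variant of the failing `rope` helper.

    pos: (batch, seq_len) positions along one axis.
    Returns a (batch, seq_len, dim // 2, 2, 2) float32 tensor of rotation matrices.
    """
    assert dim % 2 == 0, "The dimension must be even."
    dtype = freqs_dtype_for(pos.device)
    # The hard-coded float64 version of this line is what raises on MPS.
    scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim
    omega = 1.0 / (theta ** scale)  # (dim // 2,) inverse frequencies
    # Outer product of positions and frequencies: (batch, seq_len, dim // 2).
    out = torch.einsum("...n,d->...nd", pos.to(dtype), omega)
    cos_out, sin_out = torch.cos(out), torch.sin(out)
    # Pack [[cos, -sin], [sin, cos]] for every (position, frequency) pair.
    stacked = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
    return stacked.view(*out.shape, 2, 2).float()
```

The trade-off is slightly lower precision for the rotary frequencies on MPS only; whether that visibly changes generated images has not been checked here.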
Reproduction
```python
import torch
from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
from diffusers import UniPCMultistepScheduler, HiDreamImagePipeline, HiDreamImageTransformer2DModel
import gc
torch.mps.set_per_process_memory_fraction(0.0)
def flush():
    gc.collect()
    torch.mps.empty_cache()
    gc.collect()
    torch.mps.empty_cache()
scheduler = UniPCMultistepScheduler(
    flow_shift=3.0, prediction_type="flow_prediction", use_flow_sigmas=True
)
tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
text_encoder_4 = LlamaForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    output_hidden_states=True,
    output_attentions=True,
    torch_dtype=torch.bfloat16,
)
prompt = 'A cat holding a sign that says "Hi-Dreams.ai".'
te_pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",
    scheduler=scheduler,
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    transformer=None,
    vae=None,
    torch_dtype=torch.bfloat16,
)
te_pipe.to("mps")
with torch.no_grad():
    (prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds) = te_pipe.encode_prompt(
            prompt=prompt, prompt_2=None, prompt_3=None, prompt_4=None
    )
del text_encoder_4
del te_pipe
flush()
transformer = HiDreamImageTransformer2DModel.from_pretrained(
    "HiDream-ai/HiDream-I1-Full", subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",
    scheduler=scheduler,
    tokenizer=None,
    text_encoder=None,
    tokenizer_2=None,
    text_encoder_2=None,
    tokenizer_3=None,
    text_encoder_3=None,
    tokenizer_4=None,
    text_encoder_4=None,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.to('mps')
image = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
    height=1024,
    width=1024,
    guidance_scale=5.0,
    num_inference_steps=50,
    generator=torch.Generator("mps").manual_seed(0),
).images[0]
image.save("hidream_output.png")
```
Logs
```
$ time python hi.py
/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:817: UserWarning: `return_dict_in_generate` is NOT set to `True`, but `output_attentions` is. When `return_dict_in_generate` is not `True`, `output_attentions` is ignored.
  warnings.warn(
/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:817: UserWarning: `return_dict_in_generate` is NOT set to `True`, but `output_hidden_states` is. When `return_dict_in_generate` is not `True`, `output_hidden_states` is ignored.
  warnings.warn(
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 4/4 [00:19<00:00,  4.99s/it]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:04<00:00,  2.02s/it]
Loading pipeline components...: 100%|█████████████████████████████████████████████████████| 9/9 [00:12<00:00,  1.36s/it]
`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.
Fetching 7 files: 100%|████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 71435.83it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 7/7 [01:03<00:00,  9.06s/it]
Loading pipeline components...: 100%|█████████████████████████████████████████████████████| 3/3 [00:00<00:00, 14.59it/s]
  0%|                                                                                            | 0/50 [00:03<?, ?it/s]
Traceback (most recent call last):
  File "/Volumes/SSD2TB/AI/Diffusers/hi.py", line 68, in <module>
    image = pipe(
            ^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/pipelines/hidream_image/pipeline_hidream_image.py", line 682, in __call__
    noise_pred = self.transformer(
                 ^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/models/transformers/transformer_hidream_image.py", line 801, in forward
    image_rotary_emb = self.pe_embedder(ids)
                       ^^^^^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/models/transformers/transformer_hidream_image.py", line 120, in forward
    [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/models/transformers/transformer_hidream_image.py", line 120, in <listcomp>
    [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/models/transformers/transformer_hidream_image.py", line 98, in rope
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
real	9m30.801s
user	0m22.554s
sys	1m41.065s
```
System Info
- 🤗 Diffusers version: 0.33.0.dev0
- Platform: macOS-15.3.2-arm64-arm-64bit
- Running on Google Colab?: No
- Python version: 3.11.10
- PyTorch version (GPU?): 2.6.0 (False)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.27.1
- Transformers version: 4.50.3
- Accelerate version: 0.34.2
- PEFT version: not installed
- Bitsandbytes version: not installed
- Safetensors version: 0.4.5
- xFormers version: not installed
- Accelerator: Apple M3
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: No
Who can help?
No response
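Until a fix lands, a possible user-side workaround is to run `rope` on the CPU (which supports float64) for MPS inputs and move the float32 result back. This is a minimal sketch, assuming the HiDream transformer looks up the module-level `rope(pos, dim, theta)` seen in the traceback at call time; `with_cpu_fallback` and `patch_hidream_rope` are hypothetical names, not diffusers APIs.

```python
import functools
from typing import Callable

import torch


def with_cpu_fallback(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
    """Wrap a rope-style function so MPS inputs are computed on the CPU.

    The wrapped function keeps its float64 math on the CPU; the result is
    cast to float32 before moving back, since MPS cannot hold float64.
    Non-MPS inputs are passed through unchanged.
    """
    @functools.wraps(fn)
    def wrapper(pos: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        if pos.device.type != "mps":
            return fn(pos, *args, **kwargs)
        out = fn(pos.cpu(), *args, **kwargs)
        return out.to(dtype=torch.float32).to(pos.device)
    return wrapper


def patch_hidream_rope() -> None:
    """Hypothetical helper: swap the module-level `rope` for the wrapped one.

    Assumes `transformer_hidream_image.rope` exists with signature
    (pos, dim, theta), as the traceback suggests.
    """
    from diffusers.models.transformers import transformer_hidream_image as hidream

    if not getattr(hidream.rope, "_cpu_fallback", False):
        hidream.rope = with_cpu_fallback(hidream.rope)
        hidream.rope._cpu_fallback = True
```

Calling `patch_hidream_rope()` once before `pipe(...)` should route the positional-embedding computation through the CPU. Only the small position tensor makes the round trip each step, so the overhead is expected to be minor, though it has not been measured.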