Description
Describe the bug
There appears to be a dtype mismatch in the _get_clip_prompt_embeds function: self.text_encoder ends up in float16 instead of the expected bfloat16, even though the pipeline is loaded with torch_dtype=torch.bfloat16, so the CLIP forward pass fails with a mat1/mat2 dtype error.
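For reference, the failure mode is easy to reproduce in isolation. This is a minimal sketch of the mechanism (a linear layer whose weight dtype differs from its input dtype), not the pipeline code itself; shapes are arbitrary:

import torch

# Minimal sketch: a bfloat16 linear layer fed a float16 input raises the
# same kind of dtype-mismatch RuntimeError as in the logs below
# (c10::Half != c10::BFloat16).
proj = torch.nn.Linear(8, 8, dtype=torch.bfloat16)
pooled = torch.randn(1, 8, dtype=torch.float16)
proj(pooled)  # raises a RuntimeError about mismatched mat1/mat2 dtypes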
Reproduction
I am using the example code from the diffusers documentation.
Diffusers version: 0.34.0.dev0
import torch
from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
from diffusers import UniPCMultistepScheduler, HiDreamImagePipeline
tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
text_encoder_4 = LlamaForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    output_hidden_states=True,
    output_attentions=True,
    torch_dtype=torch.bfloat16,
)
pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
image = pipe(
    'A cat holding a sign that says "Hi-Dreams.ai".',
    height=1024,
    width=1024,
    guidance_scale=5.0,
    num_inference_steps=50,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("output.png")
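A possible workaround (untested here; the component names are assumptions based on the traceback) is to cast the CLIP encoders to bfloat16 right after loading the pipeline, before the call:

# Hypothetical workaround, continuing from the script above: force the
# CLIP text encoders to a single dtype so F.linear inside text_projection
# no longer sees mixed Half/BFloat16 operands. text_encoder and
# text_encoder_2 are assumed to be the pipeline's two CLIP encoders.
for name in ("text_encoder", "text_encoder_2"):
    encoder = getattr(pipe, name, None)
    if encoder is not None:
        encoder.to(torch.bfloat16)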
Logs
(ai-toolkit) (base) om@C2P1R3H5:/mnt/data/om/ai-toolkit$ CUDA_VISIBLE_DEVICES=1 python hidream-inference0.py
/home/om/miniconda3/envs/ai-toolkit/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:818: UserWarning: `return_dict_in_generate` is NOT set to `True`, but `output_attentions` is. When `return_dict_in_generate` is not `True`, `output_attentions` is ignored.
warnings.warn(
/home/om/miniconda3/envs/ai-toolkit/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:818: UserWarning: `return_dict_in_generate` is NOT set to `True`, but `output_hidden_states` is. When `return_dict_in_generate` is not `True`, `output_hidden_states` is ignored.
warnings.warn(
Loading checkpoint shards: 100%|████████████████████████████████████████| 4/4 [00:00<00:00, 10.04it/s]
Loading checkpoint shards: 100%|████████████████████████████████████████| 2/2 [00:00<00:00, 10.92it/s]
Loading checkpoint shards: 100%|████████████████████████████████████████| 7/7 [00:02<00:00, 2.58it/s]
Loading pipeline components...: 100%|████████████████████████████████████████| 11/11 [00:04<00:00, 2.70it/s]
Type of text encoder: torch.float16
Traceback (most recent call last):
File "/mnt/data/om/ai-toolkit/hidream-inference0.py", line 22, in <module>
image = pipe(
^^^^^
File "/home/om/miniconda3/envs/ai-toolkit/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/mnt/data/om/diffusers/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py", line 889, in __call__
) = self.encode_prompt(
^^^^^^^^^^^^^^^^^^^
File "/mnt/data/om/diffusers/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py", line 341, in encode_prompt
pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/data/om/diffusers/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py", line 256, in _get_clip_prompt_embeds
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/om/miniconda3/envs/ai-toolkit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/om/miniconda3/envs/ai-toolkit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/om/miniconda3/envs/ai-toolkit/lib/python3.11/site-packages/accelerate/hooks.py", line 176, in new_forward
output = module._old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/om/miniconda3/envs/ai-toolkit/lib/python3.11/site-packages/transformers/models/clip/modeling_clip.py", line 1490, in forward
text_embeds = self.text_projection(pooled_output)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/om/miniconda3/envs/ai-toolkit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/om/miniconda3/envs/ai-toolkit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/om/miniconda3/envs/ai-toolkit/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 125, in forward
return F.linear(input, self.weight, self.bias)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::Half != c10::BFloat16
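The traceback points at the text_projection layer of the first CLIP text encoder. A quick way to confirm the mixed dtypes is to inspect its submodules directly; the attribute names below follow transformers' CLIPTextModelWithProjection and are an assumption about how the pipeline exposes the encoder:

# Sketch: print the dtypes of two submodules of the first CLIP encoder.
# If they differ (e.g. token_embedding in float16, text_projection in
# bfloat16), the RuntimeError above is expected.
enc = pipe.text_encoder
print("token_embedding:", enc.text_model.embeddings.token_embedding.weight.dtype)
print("text_projection:", enc.text_projection.weight.dtype)

System Info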
- 🤗 Diffusers version: 0.33.1
- Platform: Linux-5.15.0-134-generic-x86_64-with-glibc2.35
- Running on Google Colab?: No
- Python version: 3.10.12
- PyTorch version (GPU?): 2.2.0+cu118 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.30.2
- Transformers version: 4.51.3
- Accelerate version: 1.6.0
- PEFT version: not installed
- Bitsandbytes version: not installed
- Safetensors version: 0.5.3
- xFormers version: not installed
- Accelerator: NVIDIA H100 80GB HBM3, 81559 MiB
NVIDIA H100 80GB HBM3, 81559 MiB
NVIDIA H100 80GB HBM3, 81559 MiB
NVIDIA H100 80GB HBM3, 81559 MiB
NVIDIA H100 80GB HBM3, 81559 MiB
NVIDIA H100 80GB HBM3, 81559 MiB
NVIDIA H100 80GB HBM3, 81559 MiB
NVIDIA H100 80GB HBM3, 81559 MiB
- Using GPU in script?: 0
- Using distributed or parallel set-up in script?: No
Who can help?
No response