Skip to content

Commit 406a656

Browse files
committed
update
1 parent 9717944 commit 406a656

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@
3636
Examples:
3737
```py
3838
>>> import torch
39-
>>> from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
40-
>>> from diffusers import UniPCMultistepScheduler, HiDreamImagePipeline
39+
>>> from transformers import AutoTokenizer, LlamaForCausalLM
40+
>>> from diffusers import HiDreamImagePipeline
4141
4242
43-
>>> tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
43+
>>> tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
4444
>>> text_encoder_4 = LlamaForCausalLM.from_pretrained(
4545
... "meta-llama/Meta-Llama-3.1-8B-Instruct",
4646
... output_hidden_states=True,
@@ -901,6 +901,7 @@ def __call__(
901901
pooled_prompt_embeds=pooled_prompt_embeds,
902902
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
903903
device=device,
904+
dtype=self.dtype,
904905
num_images_per_prompt=num_images_per_prompt,
905906
max_sequence_length=max_sequence_length,
906907
lora_scale=lora_scale,

tests/quantization/gguf/test_gguf.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,3 +557,23 @@ class HiDreamGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
557557
torch_dtype = torch.bfloat16
558558
model_cls = HiDreamImageTransformer2DModel
559559
expected_memory_use_in_gb = 8
560+
561+
def get_dummy_inputs(self):
562+
return {
563+
"hidden_states": torch.randn((1, 16, 128, 128), generator=torch.Generator("cpu").manual_seed(0)).to(
564+
torch_device, self.torch_dtype
565+
),
566+
"encoder_hidden_states_t5": torch.randn(
567+
(1, 128, 4096),
568+
generator=torch.Generator("cpu").manual_seed(0),
569+
).to(torch_device, self.torch_dtype),
570+
"encoder_hidden_states_llama3": torch.randn(
571+
(32, 1, 128, 4096),
572+
generator=torch.Generator("cpu").manual_seed(0),
573+
).to(torch_device, self.torch_dtype),
574+
"pooled_embeds": torch.randn(
575+
(1, 2048),
576+
generator=torch.Generator("cpu").manual_seed(0),
577+
).to(torch_device, self.torch_dtype),
578+
"timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
579+
}

0 commit comments

Comments
 (0)