
Commit 721375b

update
1 parent 98cc6d0 commit 721375b

3 files changed: +41 -0 lines changed


docs/source/en/api/models/hidream_image_transformer.md

Lines changed: 23 additions & 0 deletions
@@ -21,6 +21,29 @@ from diffusers import HiDreamImageTransformer2DModel
 transformer = HiDreamImageTransformer2DModel.from_pretrained("HiDream-ai/HiDream-I1-Full", subfolder="transformer", torch_dtype=torch.bfloat16)
 ```
 
+## Loading GGUF quantized checkpoints
+
+GGUF checkpoints for the `HiDreamImageTransformer2DModel` can be loaded using `~FromOriginalModelMixin.from_single_file`.
+
+```python
+from diffusers import HiDreamImageTransformer2DModel
+
+ckpt_path = "https://huggingface.co/city96/HiDream-I1-Dev-gguf/blob/main/hidream-i1-dev-Q2_K.gguf"
+transformer = HiDreamImageTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
+```
+
+If you are trying to use a GGUF checkpoint for the `HiDream-ai/HiDream-E1-Full` model, you will have to pass in a `config` argument to properly configure the model. This is because the HiDream I1 and E1 models share the same state dict keys, so it is currently not possible to automatically infer the model type from the checkpoint itself.
+
+```python
+from diffusers import HiDreamImageTransformer2DModel
+
+ckpt_path = "https://huggingface.co/ND911/HiDream_e1_full_bf16-ggufs/blob/main/hidream_e1_full_bf16-Q2_K.gguf"
+transformer = HiDreamImageTransformer2DModel.from_single_file(ckpt_path, config="HiDream-ai/HiDream-E1-Full", subfolder="transformer", torch_dtype=torch.bfloat16)
+```
+
 ## HiDreamImageTransformer2DModel
 
 [[autodoc]] HiDreamImageTransformer2DModel
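
For readers who want to try the documented change end to end, here is a self-contained sketch (not part of this commit) that loads the GGUF transformer and swaps it into `HiDreamImagePipeline`. It assumes the `HiDream-ai/HiDream-I1-Dev` hub repository supplies the remaining pipeline components; depending on how that repository is set up, the Llama text encoder may need to be loaded and passed in separately.

```python
import torch

from diffusers import HiDreamImagePipeline, HiDreamImageTransformer2DModel

# Load the GGUF-quantized transformer, as documented above.
ckpt_path = "https://huggingface.co/city96/HiDream-I1-Dev-gguf/blob/main/hidream-i1-dev-Q2_K.gguf"
transformer = HiDreamImageTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)

# Swap the quantized transformer into the pipeline; the other components
# (VAE, text encoders, scheduler) come from the hub repository.
pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
```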

src/diffusers/loaders/single_file_model.py

Lines changed: 5 additions & 0 deletions
@@ -31,6 +31,7 @@
     convert_autoencoder_dc_checkpoint_to_diffusers,
     convert_controlnet_checkpoint,
     convert_flux_transformer_checkpoint_to_diffusers,
+    convert_hidream_transformer_to_diffusers,
     convert_hunyuan_video_transformer_to_diffusers,
     convert_ldm_unet_checkpoint,
     convert_ldm_vae_checkpoint,
@@ -133,6 +134,10 @@
         "checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
         "default_subfolder": "vae",
     },
+    "HiDreamImageTransformer2DModel": {
+        "checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
+        "default_subfolder": "transformer",
+    },
 }
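
The `SINGLE_FILE_LOADABLE_CLASSES` entry added above is what lets `from_single_file` handle the new model class: the class name keys a small record holding the state-dict mapping function and the subfolder that contains the diffusers config. A simplified, hypothetical sketch of that lookup (the real dispatch in `FromOriginalModelMixin.from_single_file` has more steps):

```python
# Hypothetical, stripped-down illustration of how the registry entry is consumed.
def convert_hidream_transformer_to_diffusers(checkpoint, **kwargs):
    # Stand-in for the real mapping function added in single_file_utils.py.
    return checkpoint


SINGLE_FILE_LOADABLE_CLASSES = {
    "HiDreamImageTransformer2DModel": {
        "checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
}


def get_single_file_entry(model_cls_name):
    # Look up the function used to remap the original state dict and the
    # subfolder from which the diffusers config is fetched.
    entry = SINGLE_FILE_LOADABLE_CLASSES[model_cls_name]
    return entry["checkpoint_mapping_fn"], entry["default_subfolder"]


mapping_fn, subfolder = get_single_file_entry("HiDreamImageTransformer2DModel")
print(subfolder)  # transformer
```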

src/diffusers/loaders/single_file_utils.py

Lines changed: 13 additions & 0 deletions
@@ -126,6 +126,7 @@
     ],
     "wan": ["model.diffusion_model.head.modulation", "head.modulation"],
     "wan_vae": "decoder.middle.0.residual.0.gamma",
+    "hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
 }
 
 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -190,6 +191,7 @@
     "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
     "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
     "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
+    "hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
 }
 
 # Use to configure model sample size when original config is provided
@@ -701,6 +703,8 @@ def infer_diffusers_model_type(checkpoint):
     elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
         # All Wan models use the same VAE so we can use the same default model repo to fetch the config
         model_type = "wan-t2v-14B"
+    elif CHECKPOINT_KEY_NAMES["hidream"] in checkpoint:
+        model_type = "hidream"
     else:
         model_type = "v1"
 
@@ -3293,3 +3297,12 @@ def convert_wan_vae_to_diffusers(checkpoint, **kwargs):
         converted_state_dict[key] = value
 
     return converted_state_dict
+
+
+def convert_hidream_transformer_to_diffusers(checkpoint, **kwargs):
+    keys = list(checkpoint.keys())
+    for k in keys:
+        if "model.diffusion_model." in k:
+            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+    return checkpoint
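
The two additions to `single_file_utils.py` are easy to exercise in isolation: the `double_stream_blocks.0.block.adaLN_modulation.1.bias` key identifies a HiDream checkpoint in `infer_diffusers_model_type`, and `convert_hidream_transformer_to_diffusers` strips the `model.diffusion_model.` prefix that original-format checkpoints carry. A toy example (the values and the second key name are stand-ins, not taken from a real checkpoint):

```python
# Toy demonstration of the key remapping performed by
# convert_hidream_transformer_to_diffusers (copied from the diff above).
def convert_hidream_transformer_to_diffusers(checkpoint, **kwargs):
    keys = list(checkpoint.keys())
    for k in keys:
        if "model.diffusion_model." in k:
            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
    return checkpoint


state_dict = {
    # Detection key added to CHECKPOINT_KEY_NAMES, shown here with the original prefix.
    "model.diffusion_model.double_stream_blocks.0.block.adaLN_modulation.1.bias": 0.0,
    # Hypothetical extra key, just to show that every prefixed key is renamed.
    "model.diffusion_model.some_other.weight": 0.0,
}

converted = convert_hidream_transformer_to_diffusers(state_dict)
print(sorted(converted.keys()))
# ['double_stream_blocks.0.block.adaLN_modulation.1.bias', 'some_other.weight']
```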
