Skip to content

Commit 66f99b9

Browse files
authored
Merge branch 'main' into introduce_autopipeline_for_text2video
2 parents 41ec606 + 54fa074 commit 66f99b9

File tree

11 files changed

+1230
-59
lines changed

11 files changed

+1230
-59
lines changed

docs/source/en/_toctree.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,7 @@
651651
- local: api/pipelines/wuerstchen
652652
title: Wuerstchen
653653
- local: api/pipelines/z_image
654-
title: Z-Image
654+
title: Z-Image
655655
title: Image
656656
- sections:
657657
- local: api/pipelines/allegro

docs/source/en/api/pipelines/z_image.md

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,41 @@ specific language governing permissions and limitations under the License.
2626

2727
Z-Image-Turbo is a distilled version of Z-Image that matches or exceeds leading competitors with only 8 NFEs (Number of Function Evaluations). It offers sub-second inference latency on enterprise-grade H800 GPUs and fits comfortably within 16G VRAM consumer devices. It excels in photorealistic image generation, bilingual text rendering (English & Chinese), and robust instruction adherence.
2828

29+
## Image-to-image
30+
31+
Use [`ZImageImg2ImgPipeline`] to transform an existing image based on a text prompt.
32+
33+
```python
34+
import torch
35+
from diffusers import ZImageImg2ImgPipeline
36+
from diffusers.utils import load_image
37+
38+
pipe = ZImageImg2ImgPipeline.from_pretrained("Tongyi-MAI/Z-Image-Turbo", torch_dtype=torch.bfloat16)
39+
pipe.to("cuda")
40+
41+
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
42+
init_image = load_image(url).resize((1024, 1024))
43+
44+
prompt = "A fantasy landscape with mountains and a river, detailed, vibrant colors"
45+
image = pipe(
46+
prompt,
47+
image=init_image,
48+
strength=0.6,
49+
num_inference_steps=9,
50+
guidance_scale=0.0,
51+
generator=torch.Generator("cuda").manual_seed(42),
52+
).images[0]
53+
image.save("zimage_img2img.png")
54+
```
55+
2956
## ZImagePipeline
3057

3158
[[autodoc]] ZImagePipeline
3259
- all
33-
- __call__
60+
- __call__
61+
62+
## ZImageImg2ImgPipeline
63+
64+
[[autodoc]] ZImageImg2ImgPipeline
65+
- all
66+
- __call__

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,7 @@
663663
"WuerstchenCombinedPipeline",
664664
"WuerstchenDecoderPipeline",
665665
"WuerstchenPriorPipeline",
666+
"ZImageImg2ImgPipeline",
666667
"ZImagePipeline",
667668
]
668669
)
@@ -1361,6 +1362,7 @@
13611362
WuerstchenCombinedPipeline,
13621363
WuerstchenDecoderPipeline,
13631364
WuerstchenPriorPipeline,
1365+
ZImageImg2ImgPipeline,
13641366
ZImagePipeline,
13651367
)
13661368

src/diffusers/models/transformers/transformer_prx.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import torch
1818
from torch import nn
19-
from torch.nn.functional import fold, unfold
2019

2120
from ...configuration_utils import ConfigMixin, register_to_config
2221
from ...utils import logging
@@ -532,7 +531,19 @@ def img2seq(img: torch.Tensor, patch_size: int) -> torch.Tensor:
532531
Flattened patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W
533532
// patch_size)` is the number of patches.
534533
"""
535-
return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2)
534+
b, c, h, w = img.shape
535+
p = patch_size
536+
537+
# Reshape to (B, C, H//p, p, W//p, p) separating grid and patch dimensions
538+
img = img.reshape(b, c, h // p, p, w // p, p)
539+
540+
# Permute to (B, H//p, W//p, C, p, p) using einsum
541+
# n=batch, c=channels, h=grid_height, p=patch_height, w=grid_width, q=patch_width
542+
img = torch.einsum("nchpwq->nhwcpq", img)
543+
544+
# Flatten to (B, L, C * p * p)
545+
img = img.reshape(b, -1, c * p * p)
546+
return img
536547

537548

538549
def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Tensor:
@@ -554,12 +565,26 @@ def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Te
554565
Reconstructed image tensor of shape `(B, C, H, W)`.
555566
"""
556567
if isinstance(shape, tuple):
557-
shape = shape[-2:]
568+
h, w = shape[-2:]
558569
elif isinstance(shape, torch.Tensor):
559-
shape = (int(shape[0]), int(shape[1]))
570+
h, w = (int(shape[0]), int(shape[1]))
560571
else:
561572
raise NotImplementedError(f"shape type {type(shape)} not supported")
562-
return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)
573+
574+
b, l, d = seq.shape
575+
p = patch_size
576+
c = d // (p * p)
577+
578+
# Reshape back to grid structure: (B, H//p, W//p, C, p, p)
579+
seq = seq.reshape(b, h // p, w // p, c, p, p)
580+
581+
# Permute back to image layout: (B, C, H//p, p, W//p, p)
582+
# n=batch, h=grid_height, w=grid_width, c=channels, p=patch_height, q=patch_width
583+
seq = torch.einsum("nhwcpq->nchpwq", seq)
584+
585+
# Final reshape to (B, C, H, W)
586+
seq = seq.reshape(b, c, h, w)
587+
return seq
563588

564589

565590
class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@
405405
"Kandinsky5T2IPipeline",
406406
"Kandinsky5I2IPipeline",
407407
]
408-
_import_structure["z_image"] = ["ZImagePipeline"]
408+
_import_structure["z_image"] = ["ZImageImg2ImgPipeline", "ZImagePipeline"]
409409
_import_structure["skyreels_v2"] = [
410410
"SkyReelsV2DiffusionForcingPipeline",
411411
"SkyReelsV2DiffusionForcingImageToVideoPipeline",
@@ -842,7 +842,7 @@
842842
WuerstchenDecoderPipeline,
843843
WuerstchenPriorPipeline,
844844
)
845-
from .z_image import ZImagePipeline
845+
from .z_image import ZImageImg2ImgPipeline, ZImagePipeline
846846

847847
try:
848848
if not is_onnx_available():

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@
119119
)
120120
from .wan import WanAnimatePipeline, WanImageToVideoPipeline, WanPipeline, WanVACEPipeline, WanVideoToVideoPipeline
121121
from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline
122+
from .z_image import ZImageImg2ImgPipeline, ZImagePipeline
122123

123124

124125
AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
@@ -162,6 +163,7 @@
162163
("cogview4-control", CogView4ControlPipeline),
163164
("qwenimage", QwenImagePipeline),
164165
("qwenimage-controlnet", QwenImageControlNetPipeline),
166+
("z-image", ZImagePipeline),
165167
]
166168
)
167169

@@ -189,6 +191,7 @@
189191
("qwenimage", QwenImageImg2ImgPipeline),
190192
("qwenimage-edit", QwenImageEditPipeline),
191193
("qwenimage-edit-plus", QwenImageEditPlusPipeline),
194+
("z-image", ZImageImg2ImgPipeline),
192195
]
193196
)
194197

src/diffusers/pipelines/z_image/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
else:
2424
_import_structure["pipeline_output"] = ["ZImagePipelineOutput"]
2525
_import_structure["pipeline_z_image"] = ["ZImagePipeline"]
26+
_import_structure["pipeline_z_image_img2img"] = ["ZImageImg2ImgPipeline"]
2627

2728

2829
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -35,6 +36,7 @@
3536
else:
3637
from .pipeline_output import ZImagePipelineOutput
3738
from .pipeline_z_image import ZImagePipeline
39+
from .pipeline_z_image_img2img import ZImageImg2ImgPipeline
3840

3941
else:
4042
import sys

0 commit comments

Comments
 (0)