Commit bdde5ef

support Z-Image-Turbo model (#687)
1 parent c44ab44 commit bdde5ef

File tree

23 files changed, +4063 -3 lines changed

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
{
    "aspect_ratios": {
        "1:1": [
            1328,
            1328
        ],
        "16:9": [
            1664,
            928
        ],
        "9:16": [
            928,
            1664
        ],
        "4:3": [
            1472,
            1140
        ],
        "3:4": [
            768,
            1024
        ]
    },
    "aspect_ratio": "16:9",
    "num_channels_latents": 16,
    "batchsize": 1,
    "vae_scale_factor": 8,
    "infer_steps": 9,
    "num_layers": 30,
    "attention_out_dim": 3840,
    "attention_dim_head": 128,
    "attn_type": "flash_attn3",
    "enable_cfg": true,
    "sample_guide_scale": 0.0,
    "n_refiner_layers": 2,
    "patch_size": 2,
    "strength": 0.6,
    "transformer_in_channels": 64,
    "_auto_resize": true
}
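
As a quick sanity check on these values, here is a minimal worked sketch (not part of the commit) assuming the usual pixel -> latent -> token convention used elsewhere in this diff:

# Hedged sketch: derive latent and token shapes from the config values above.
width, height = 1664, 928            # "16:9" entry in aspect_ratios (pixels)
vae_scale_factor = 8
patch_size = 2
num_channels_latents = 16

latent_h = height // vae_scale_factor                               # 116
latent_w = width // vae_scale_factor                                # 208
image_tokens = (latent_h // patch_size) * (latent_w // patch_size)  # 58 * 104 = 6032

# Each token flattens a patch of patch_size x patch_size x num_channels_latents values,
# consistent with "transformer_in_channels": 64 (2 * 2 * 16).
assert patch_size * patch_size * num_channels_latents == 64
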
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
{
    "aspect_ratios": {
        "1:1": [
            1328,
            1328
        ],
        "16:9": [
            1664,
            928
        ],
        "9:16": [
            928,
            1664
        ],
        "4:3": [
            1472,
            1140
        ],
        "3:4": [
            768,
            1024
        ]
    },
    "aspect_ratio": "16:9",
    "num_channels_latents": 16,
    "batchsize": 1,
    "vae_scale_factor": 8,
    "infer_steps": 9,
    "num_layers": 30,
    "attention_out_dim": 3840,
    "attention_dim_head": 128,
    "attn_type": "flash_attn3",
    "enable_cfg": false,
    "sample_guide_scale": 0.0,
    "n_refiner_layers": 2,
    "patch_size": 2
}
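
This second config is nearly identical to the first; it disables classifier-free guidance ("enable_cfg": false) and drops the i2i-specific keys (strength, transformer_in_channels, _auto_resize), which suggests it targets the text-to-image path while the first covers image-to-image. As a rough, hedged illustration of how such a pair of knobs is commonly applied (one standard CFG formulation, not necessarily the exact one used in lightx2v):

import torch

# Hedged sketch: how enable_cfg and sample_guide_scale typically interact.
def combine_predictions(cond: torch.Tensor, uncond: torch.Tensor, enable_cfg: bool, guide_scale: float) -> torch.Tensor:
    if not enable_cfg:
        return cond                                    # single conditional pass, no extra compute
    return uncond + guide_scale * (cond - uncond)      # classic CFG blend
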

lightx2v/infer.py

Lines changed: 18 additions & 0 deletions
@@ -15,6 +15,7 @@
 from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_sf_runner import WanSFRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_vace_runner import WanVaceRunner  # noqa: F401
+from lightx2v.models.runners.z_image.z_image_runner import ZImageRunner  # noqa: F401
 from lightx2v.utils.envs import *
 from lightx2v.utils.input_info import set_input_info
 from lightx2v.utils.profiler import *
@@ -56,6 +57,7 @@ def main():
             "wan2.2_animate",
             "hunyuan_video_1.5",
             "hunyuan_video_1.5_distill",
+            "z_image",
         ],
         default="wan2.1",
     )
@@ -118,6 +120,22 @@ def main():
     )
     parser.add_argument("--save_result_path", type=str, default=None, help="The path to save video path/file")
     parser.add_argument("--return_result_tensor", action="store_true", help="Whether to return result tensor. (Useful for comfyui)")
+
+    # Aspect ratio and custom shape for image tasks (t2i, i2i)
+    parser.add_argument(
+        "--aspect_ratio",
+        type=str,
+        default="16:9",
+        choices=["16:9", "9:16", "1:1", "4:3", "3:4"],
+        help="Aspect ratio for image generation. Only used for t2i and i2i tasks.",
+    )
+    parser.add_argument(
+        "--custom_shape",
+        type=str,
+        default=None,
+        help="Custom shape for image generation in format 'height,width' (e.g., '928,1664'). Only used for t2i and i2i tasks. Takes precedence over aspect_ratio.",
+    )
+    parser.add_argument("--strength", type=float, default=0.6, help="The strength for image-to-image generation")
     args = parser.parse_args()
     validate_task_arguments(args)
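
The new --custom_shape flag takes a 'height,width' string and overrides --aspect_ratio. A minimal sketch of the parsing step (the actual helper used by the commit is not shown in this excerpt) could look like:

# Hedged sketch: turn the --custom_shape string (e.g. "928,1664") into integers.
def parse_custom_shape(custom_shape: str | None) -> tuple[int, int] | None:
    if not custom_shape:
        return None                                    # fall back to --aspect_ratio
    height, width = (int(v) for v in custom_shape.split(","))
    return height, width

assert parse_custom_shape("928,1664") == (928, 1664)
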

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
import gc
import os

import torch
from PIL import Image

try:
    from transformers import Qwen2Tokenizer, Qwen3Model
except ImportError:
    Qwen2Tokenizer = None
    Qwen3Model = None

from lightx2v_platform.base.global_var import AI_DEVICE

torch_device_module = getattr(torch, AI_DEVICE)

try:
    from diffusers.image_processor import VaeImageProcessor
except ImportError:
    VaeImageProcessor = None


class Qwen3Model_TextEncoder:
    def __init__(self, config):
        self.config = config
        self.tokenizer_max_length = 512
        self.cpu_offload = config.get("qwen3_cpu_offload", config.get("cpu_offload", False))
        self.dtype = torch.bfloat16
        self.load()

    def load(self):
        self.text_encoder = Qwen3Model.from_pretrained(os.path.join(self.config["model_path"], "text_encoder"), torch_dtype=torch.bfloat16)
        if not self.cpu_offload:
            self.text_encoder = self.text_encoder.to(AI_DEVICE)

        self.tokenizer = Qwen2Tokenizer.from_pretrained(os.path.join(self.config["model_path"], "tokenizer"))

        if self.config["task"] == "i2i":
            self.image_processor = VaeImageProcessor(vae_scale_factor=self.config.get("vae_scale_factor", 8) * 2)

    def preprocess_image(self, image):
        if isinstance(image, Image.Image):
            preprocessed_image = self.image_processor.preprocess(image)
        elif isinstance(image, torch.Tensor):
            if image.dim() == 3:
                image = image.unsqueeze(0)
            preprocessed_image = image
        else:
            raise ValueError(f"Unsupported image type: {type(image)}")

        return preprocessed_image

    @torch.no_grad()
    def infer(self, prompt, image_list=None):
        if self.cpu_offload:
            self.text_encoder.to(AI_DEVICE)

        if isinstance(prompt, str):
            prompt = [prompt]

        for i, prompt_item in enumerate(prompt):
            messages = [{"role": "user", "content": prompt_item}]
            prompt_tokens = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=True)
            prompt[i] = prompt_tokens

        text_inputs = self.tokenizer(prompt, max_length=self.tokenizer_max_length, padding="max_length", truncation=True, return_tensors="pt").to(AI_DEVICE)
        prompt_masks = text_inputs.attention_mask.bool().to(AI_DEVICE)

        prompt_embeds = self.text_encoder(
            input_ids=text_inputs.input_ids,
            attention_mask=prompt_masks,
            output_hidden_states=True,
        ).hidden_states[-2]
        embedding_list = []
        for i in range(len(prompt_embeds)):
            extracted = prompt_embeds[i][prompt_masks[i]]
            embedding_list.append(extracted)
        image_info = {}
        if self.config["task"] == "i2i" and image_list is not None:
            vae_image_list = []
            for image in image_list:
                preprocessed_image = self.preprocess_image(image)
                vae_image_list.append(preprocessed_image)

            image_info = {
                "vae_image_list": vae_image_list,
            }

        if self.cpu_offload:
            self.text_encoder.to(torch.device("cpu"))
            torch_device_module.empty_cache()
            gc.collect()

        return embedding_list, image_info
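
A minimal usage sketch for the encoder above (the checkpoint path is a placeholder, not taken from the commit; only the model_path, task, and cpu_offload keys shown are actually read for t2i):

# Hedged usage sketch for Qwen3Model_TextEncoder.
config = {
    "model_path": "/path/to/Z-Image-Turbo",  # assumed layout with text_encoder/ and tokenizer/ subfolders
    "task": "t2i",
    "cpu_offload": False,
}
encoder = Qwen3Model_TextEncoder(config)
embeddings, image_info = encoder.infer("a watercolor lighthouse at dusk")
# embeddings: one tensor per prompt of shape [num_unpadded_tokens, hidden_dim],
# taken from the second-to-last hidden layer; image_info stays empty for t2i.
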
Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
from dataclasses import dataclass

import torch


@dataclass
class ZPreInferModuleOutput:
    hidden_states: torch.Tensor
    encoder_hidden_states: torch.Tensor
    temb_img_silu: torch.Tensor
    temb_txt_silu: torch.Tensor
    x_freqs_cis: torch.Tensor
    cap_freqs_cis: torch.Tensor
    image_tokens_len: int
    x_item_seqlens: list
    cap_item_seqlens: list

    @property
    def adaln_input(self) -> torch.Tensor:
        return self.temb_img_silu

    @property
    def image_rotary_emb(self) -> torch.Tensor:
        return self.x_freqs_cis

    @property
    def freqs_cis(self) -> torch.Tensor:
        return self.x_freqs_cis
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
import torch

from lightx2v.common.offload.manager import WeightAsyncStreamManager
from lightx2v.models.networks.z_image.infer.transformer_infer import ZImageTransformerInfer
from lightx2v_platform.base.global_var import AI_DEVICE

torch_device_module = getattr(torch, AI_DEVICE)


class ZImageOffloadTransformerInfer(ZImageTransformerInfer):
    def __init__(self, config):
        super().__init__(config)
        self.phases_num = 3
        self.num_blocks = config["num_layers"]
        if self.config.get("cpu_offload", False):
            if "offload_ratio" in self.config:
                self.offload_ratio = self.config["offload_ratio"]
            else:
                self.offload_ratio = 1
            offload_granularity = self.config.get("offload_granularity", "block")
            if offload_granularity == "block":
                if not self.config.get("lazy_load", False):
                    self.infer_func = self.infer_with_blocks_offload
                else:
                    raise NotImplementedError  # lazy_load offload is not supported

            if offload_granularity != "model":
                self.offload_manager = WeightAsyncStreamManager(offload_granularity=offload_granularity)
            else:
                raise NotImplementedError  # model-level offload granularity is not supported

    def infer_with_blocks_offload(self, block_weights, hidden_states, encoder_hidden_states, temb, image_rotary_emb, modulate_index):
        for block_idx in range(self.num_blocks):
            self.block_idx = block_idx
            if self.offload_manager.need_init_first_buffer:
                self.offload_manager.init_first_buffer(block_weights.blocks)

            self.offload_manager.prefetch_weights((block_idx + 1) % self.num_blocks, block_weights.blocks)
            with torch_device_module.stream(self.offload_manager.compute_stream):
                encoder_hidden_states, hidden_states = self.infer_block(
                    block_weight=self.offload_manager.cuda_buffers[0],
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    modulate_index=modulate_index,
                )

            self.offload_manager.swap_blocks()

        return encoder_hidden_states, hidden_states
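
The loop above overlaps weight transfer with compute: while block i runs on the compute stream, block (i + 1) % num_blocks is prefetched into a device buffer, and the buffers are swapped afterwards. A generic, hedged illustration of that double-buffering idea (not the WeightAsyncStreamManager API itself) is:

import torch

# Hedged sketch of CPU->GPU double buffering; real overlap also needs pinned host memory,
# and a real manager copies into reusable device buffers rather than moving modules in place.
def run_blocks_with_prefetch(blocks_cpu, x):
    copy_stream = torch.cuda.Stream()
    resident = blocks_cpu[0].to("cuda")                       # load the first block eagerly
    for i in range(len(blocks_cpu)):
        nxt = blocks_cpu[(i + 1) % len(blocks_cpu)]
        with torch.cuda.stream(copy_stream):                  # async copy of the next block
            prefetched = nxt.to("cuda", non_blocking=True)
        x = resident(x)                                       # compute with the resident block
        torch.cuda.current_stream().wait_stream(copy_stream)  # ensure the copy finished
        resident = prefetched                                 # swap buffers
    return x
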
Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
import torch.nn.functional as F


class ZImagePostInfer:
    def __init__(self, config):
        self.config = config
        self.cpu_offload = config.get("cpu_offload", False)
        self.zero_cond_t = config.get("zero_cond_t", False)

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler

    def infer(self, weights, hidden_states, temb_img_silu, image_tokens_len=None):
        temb_silu = F.silu(temb_img_silu)
        temb1 = weights.norm_out_linear.apply(temb_silu)

        scale = 1.0 + temb1
        normed = weights.norm_out.apply(hidden_states)
        scaled_norm = normed * scale.unsqueeze(1)
        B, T, D = scaled_norm.shape
        hidden_states_2d = scaled_norm.reshape(B * T, D)

        output_2d = weights.proj_out_linear.apply(hidden_states_2d)
        out_dim = output_2d.shape[-1]
        output = output_2d.reshape(B, T, out_dim)

        if image_tokens_len is not None:
            output = output[:, :image_tokens_len, :]

        patch_size = self.config.get("patch_size", 2)
        f_patch_size = 1
        transformer_out_channels = out_dim // (patch_size * patch_size * f_patch_size)
        expected_out_dim = patch_size * patch_size * f_patch_size * transformer_out_channels

        if out_dim != expected_out_dim:
            raise ValueError(f"out_dim mismatch: {out_dim} != {expected_out_dim} (transformer_out_channels={transformer_out_channels})")

        out_channels = transformer_out_channels
        target_shape = self.scheduler.input_info.target_shape

        _, _, height, width = target_shape
        num_frames = 1
        pH = pW = patch_size
        pF = f_patch_size
        F_tokens = num_frames // pF
        H_tokens = height // pH
        W_tokens = width // pW

        expected_T = F_tokens * H_tokens * W_tokens
        if output.shape[1] != expected_T:
            raise ValueError(f"Token count mismatch: output.shape[1]={output.shape[1]} != expected_T={expected_T} (from target_shape={target_shape})")

        output_reshaped = output.view(B, F_tokens, H_tokens, W_tokens, pF, pH, pW, out_channels)
        output_permuted = output_reshaped.permute(0, 7, 1, 4, 2, 5, 3, 6)
        output_4d = output_permuted.reshape(B, out_channels, num_frames, height, width)
        output_4d = output_4d.squeeze(2)

        return output_4d
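
Continuing the 16:9 numbers from the config above, and assuming target_shape holds the latent layout, the unpatchify path in infer maps the token sequence back to a latent image roughly as follows (a hedged numeric walk-through, not code from the commit):

# Hedged walk-through of the reshape in ZImagePostInfer.infer for the 16:9 config.
B, patch_size, out_dim = 1, 2, 64
height, width = 116, 208                                          # latent H, W = 928 / 8, 1664 / 8
H_tokens, W_tokens = height // patch_size, width // patch_size    # 58, 104
expected_T = H_tokens * W_tokens                                  # 6032 tokens
out_channels = out_dim // (patch_size * patch_size)               # 64 / 4 = 16 latent channels

# (B, 6032, 64) --view--> (B, 1, 58, 104, 1, 2, 2, 16) --permute/reshape--> (B, 16, 116, 208)
assert (expected_T, out_channels) == (6032, 16)
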
