Skip to content

Commit dca2b5c

Browse files
author
gushiqiao
committed
support phase offload for qwen-image
1 parent 25345ab commit dca2b5c

File tree

5 files changed

+488
-135
lines changed

5 files changed

+488
-135
lines changed

lightx2v/models/input_encoders/hf/qwen25/qwen25_vlforconditionalgeneration.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -76,8 +76,6 @@ def load(self):
7676
if self.config["cpu_offload"]:
7777
self.device_map = {
7878
"lm_head": AI_DEVICE,
79-
"model.embed_tokens": AI_DEVICE,
80-
"model.norm": AI_DEVICE,
8179
"model.visual": "cpu",
8280
"model.language_model": "cpu",
8381
}

lightx2v/models/networks/qwen_image/infer/offload/transformer_infer.py

Lines changed: 90 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,9 @@
11
import torch
22

33
from lightx2v.common.offload.manager import WeightAsyncStreamManager
4-
from lightx2v.models.networks.qwen_image.infer.transformer_infer import QwenImageTransformerInfer
4+
from lightx2v.models.networks.qwen_image.infer.transformer_infer import (
5+
QwenImageTransformerInfer,
6+
)
57
from lightx2v_platform.base.global_var import AI_DEVICE
68

79
torch_device_module = getattr(torch, AI_DEVICE)
@@ -11,6 +13,7 @@ class QwenImageOffloadTransformerInfer(QwenImageTransformerInfer):
1113
def __init__(self, config):
1214
super().__init__(config)
1315
self.num_blocks = config["num_layers"]
16+
self.phases_num = 3
1417
if self.config.get("cpu_offload", False):
1518
if "offload_ratio" in self.config:
1619
self.offload_ratio = self.config["offload_ratio"]
@@ -20,12 +23,94 @@ def __init__(self, config):
2023
if offload_granularity == "block":
2124
self.infer_func = self.infer_with_blocks_offload
2225
self.offload_manager = WeightAsyncStreamManager(offload_granularity=offload_granularity)
26+
elif offload_granularity == "phase":
27+
self.infer_func = self.infer_with_phases_offload
28+
self.offload_manager = WeightAsyncStreamManager(offload_granularity=offload_granularity)
2329

2430
self.lazy_load = self.config.get("lazy_load", False)
2531
if self.lazy_load:
2632
self.offload_manager.init_lazy_load(num_workers=self.config.get("num_disk_workers", 4))
2733

28-
def infer_with_blocks_offload(self, block_weights, hidden_states, encoder_hidden_states, temb_img_silu, temb_txt_silu, image_rotary_emb, modulate_index):
34+
def infer_with_phases_offload(
35+
self,
36+
blocks,
37+
hidden_states,
38+
encoder_hidden_states,
39+
temb_img_silu,
40+
temb_txt_silu,
41+
image_rotary_emb,
42+
modulate_index,
43+
):
44+
for block_idx in range(len(blocks)):
45+
self.block_idx = block_idx
46+
if self.lazy_load:
47+
next_prefetch = (block_idx + 1) % len(blocks)
48+
self.offload_manager.start_prefetch_block(next_prefetch)
49+
50+
for phase_idx in range(self.phases_num):
51+
# if self.offload_manager.need_init_first_buffer:
52+
if block_idx == 0 and phase_idx == 0:
53+
self.offload_manager.init_first_buffer(blocks)
54+
55+
next_block_idx = (block_idx + 1) % len(blocks) if phase_idx == self.phases_num - 1 else block_idx
56+
next_phase_idx = (phase_idx + 1) % self.phases_num
57+
if self.lazy_load:
58+
if phase_idx == self.phases_num - 1:
59+
self.offload_manager.swap_cpu_buffers()
60+
61+
self.offload_manager.prefetch_phase(next_block_idx, next_phase_idx, blocks)
62+
with torch_device_module.stream(self.offload_manager.compute_stream):
63+
if phase_idx == 0:
64+
(
65+
img_modulated,
66+
txt_modulated,
67+
img_gate1,
68+
txt_gate1,
69+
img_mod2,
70+
txt_mod2,
71+
) = self.infer_modulate(
72+
mod_phase=self.offload_manager.cuda_buffers[phase_idx],
73+
hidden_states=hidden_states,
74+
encoder_hidden_states=encoder_hidden_states,
75+
temb_img_silu=temb_img_silu,
76+
temb_txt_silu=temb_txt_silu,
77+
modulate_index=modulate_index,
78+
)
79+
elif phase_idx == 1:
80+
hidden_states, encoder_hidden_states = self.infer_attn(
81+
attn_phase=self.offload_manager.cuda_buffers[phase_idx],
82+
img_modulated=img_modulated,
83+
txt_modulated=txt_modulated,
84+
img_gate1=img_gate1,
85+
txt_gate1=txt_gate1,
86+
hidden_states=hidden_states,
87+
encoder_hidden_states=encoder_hidden_states,
88+
image_rotary_emb=image_rotary_emb,
89+
)
90+
elif phase_idx == 2:
91+
encoder_hidden_states, hidden_states = self.infer_ffn(
92+
ffn_phase=self.offload_manager.cuda_buffers[phase_idx],
93+
hidden_states=hidden_states,
94+
encoder_hidden_states=encoder_hidden_states,
95+
img_mod2=img_mod2,
96+
txt_mod2=txt_mod2,
97+
modulate_index=modulate_index,
98+
)
99+
100+
self.offload_manager.swap_phases()
101+
102+
return hidden_states
103+
104+
def infer_with_blocks_offload(
105+
self,
106+
blocks,
107+
hidden_states,
108+
encoder_hidden_states,
109+
temb_img_silu,
110+
temb_txt_silu,
111+
image_rotary_emb,
112+
modulate_index,
113+
):
29114
for block_idx in range(self.num_blocks):
30115
self.block_idx = block_idx
31116

@@ -34,15 +119,15 @@ def infer_with_blocks_offload(self, block_weights, hidden_states, encoder_hidden
34119
self.offload_manager.start_prefetch_block(next_prefetch)
35120

36121
if block_idx == 0:
37-
self.offload_manager.init_first_buffer(block_weights.blocks)
122+
self.offload_manager.init_first_buffer(blocks)
38123

39124
if self.lazy_load:
40125
self.offload_manager.swap_cpu_buffers()
41-
self.offload_manager.prefetch_weights((block_idx + 1) % self.num_blocks, block_weights.blocks)
126+
self.offload_manager.prefetch_weights((block_idx + 1) % self.num_blocks, blocks)
42127

43128
with torch_device_module.stream(self.offload_manager.compute_stream):
44129
encoder_hidden_states, hidden_states = self.infer_block(
45-
block_weight=self.offload_manager.cuda_buffers[0],
130+
block=self.offload_manager.cuda_buffers[0],
46131
hidden_states=hidden_states,
47132
encoder_hidden_states=encoder_hidden_states,
48133
temb_img_silu=temb_img_silu,

lightx2v/models/networks/qwen_image/infer/post_infer.py

100644 → 100755
File mode changed.

0 commit comments

Comments
 (0)