Skip to content

Commit eec8238

Browse files
support w4 for qw_vl and disk offload for qw_2511 (#681)
Co-authored-by: gushiqiao <975033167> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent b5752db commit eec8238

File tree

13 files changed

+807
-322
lines changed

13 files changed

+807
-322
lines changed

lightx2v/common/offload/manager.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,10 @@ def init_cuda_buffer(self, blocks_cuda_buffer=None, phases_cuda_buffer=None):
5050
def init_first_buffer(self, blocks, adapter_block_idx=None):
5151
with torch_device_module.stream(self.init_stream):
5252
if hasattr(self, "cpu_buffers"):
53-
self.cuda_buffers[0].load_state_dict(self.cpu_buffers[0][0].state_dict(), 0, adapter_block_idx)
53+
if self.offload_granularity == "block":
54+
self.cuda_buffers[0].load_state_dict(self.cpu_buffers[0].state_dict(), 0, adapter_block_idx)
55+
else:
56+
self.cuda_buffers[0].load_state_dict(self.cpu_buffers[0][0].state_dict(), 0, adapter_block_idx)
5457
else:
5558
if self.offload_granularity == "block":
5659
self.cuda_buffers[0].load_state_dict(blocks[0].state_dict(), 0, adapter_block_idx)
@@ -62,8 +65,7 @@ def init_first_buffer(self, blocks, adapter_block_idx=None):
6265
def prefetch_weights(self, block_idx, blocks, adapter_block_idx=None):
6366
with torch_device_module.stream(self.cuda_load_stream):
6467
if hasattr(self, "cpu_buffers"):
65-
self.cpu_buffers[1].load_state_dict_from_disk(block_idx, adapter_block_idx)
66-
self.cuda_buffers[1].load_state_dict(self.cpu_buffers[1].state_dict(), block_idx, adapter_block_idx)
68+
self.cuda_buffers[1].load_state_dict(self.cpu_buffers[0].state_dict(), block_idx, adapter_block_idx)
6769
else:
6870
self.cuda_buffers[1].load_state_dict(blocks[block_idx].state_dict(), block_idx, adapter_block_idx)
6971

@@ -110,12 +112,17 @@ def init_lazy_load(self, num_workers=6):
110112
def start_prefetch_block(self, block_idx, adapter_block_idx=None):
111113
self.prefetch_block_idx = block_idx
112114
self.prefetch_futures = []
113-
for phase in self.cpu_buffers[1]:
114-
future = self.executor.submit(phase.load_state_dict_from_disk, block_idx, adapter_block_idx)
115+
if self.offload_granularity == "block":
116+
future = self.executor.submit(self.cpu_buffers[1].load_state_dict_from_disk, block_idx, adapter_block_idx)
115117
self.prefetch_futures.append(future)
118+
else:
119+
for phase in self.cpu_buffers[1]:
120+
future = self.executor.submit(phase.load_state_dict_from_disk, block_idx, adapter_block_idx)
121+
self.prefetch_futures.append(future)
116122

117123
def swap_cpu_buffers(self):
118-
# wait_start = time.time()
124+
# import time
125+
# wait_start = time.time()
119126
# already_done = all(f.done() for f in self.prefetch_futures)
120127
for f in self.prefetch_futures:
121128
f.result()

lightx2v/common/ops/mm/mm_weight.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1531,15 +1531,15 @@ def __init__(
15311531
if ops is not None:
15321532
self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm
15331533
else:
1534-
self.act_quant_func = self.fp8_quantize_triton
1534+
self.act_quant_func = fp8_quantize_triton
15351535

15361536
def apply(self, input_tensor):
15371537
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
15381538
output_tensor = fp8_linear(
15391539
input_tensor_quant,
15401540
self.weight,
15411541
self.bias.float() if self.bias is not None else None,
1542-
input_tensor_scale,
1542+
input_tensor_scale.float(),
15431543
self.weight_scale,
15441544
out_dtype=self.infer_dtype,
15451545
)
@@ -1582,15 +1582,15 @@ def __init__(
15821582
if ops is not None:
15831583
self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
15841584
else:
1585-
self.act_quant_func = self.int8_quantize_triton
1585+
self.act_quant_func = int8_quantize_triton
15861586

15871587
def apply(self, input_tensor):
15881588
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
15891589
output_tensor = q8_linear(
15901590
input_tensor_quant,
15911591
self.weight,
15921592
self.bias.float() if self.bias is not None else None,
1593-
input_tensor_scale,
1593+
input_tensor_scale.float(),
15941594
self.weight_scale,
15951595
fuse_gelu=False,
15961596
out_dtype=self.infer_dtype,

lightx2v/models/input_encoders/hf/qwen25/qwen25_vlforconditionalgeneration.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,20 @@ def __init__(self, config):
7171
self.load()
7272

7373
def load(self):
74-
self.text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(os.path.join(self.config["model_path"], "text_encoder"), torch_dtype=torch.bfloat16)
74+
if self.config.get("qwen25vl_quantized", False):
75+
assert self.config["qwen25vl_quant_scheme"] == "int4"
76+
if self.config["cpu_offload"]:
77+
self.device_map = {
78+
"lm_head": AI_DEVICE,
79+
"model.visual": "cpu",
80+
"model.language_model": "cpu",
81+
}
82+
else:
83+
self.device_map = "auto"
84+
self.text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(self.config["qwen25vl_quantized_ckpt"], dtype=torch.bfloat16, device_map=self.device_map, low_cpu_mem_usage=True)
85+
else:
86+
self.text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(os.path.join(self.config["model_path"], "text_encoder"), torch_dtype=torch.bfloat16)
87+
7588
if not self.cpu_offload:
7689
self.text_encoder = self.text_encoder.to(AI_DEVICE)
7790

@@ -99,7 +112,8 @@ def preprocess_image(self, image):
99112
@torch.no_grad()
100113
def infer(self, text, image_list=None):
101114
if self.cpu_offload:
102-
self.text_encoder.to(AI_DEVICE)
115+
if not hasattr(self, "device_map") or self.device_map == "auto":
116+
self.text_encoder.to(AI_DEVICE)
103117

104118
if image_list is not None:
105119
condition_image_list = []
@@ -143,7 +157,6 @@ def infer(self, text, image_list=None):
143157
image_grid_thw=model_inputs.image_grid_thw,
144158
output_hidden_states=True,
145159
)
146-
147160
image_info = {
148161
"condition_image_list": condition_image_list,
149162
"vae_image_list": vae_image_list,
@@ -183,7 +196,8 @@ def infer(self, text, image_list=None):
183196
prompt_embeds_mask = prompt_embeds_mask.view(1 * 1, seq_len)
184197

185198
if self.cpu_offload:
186-
self.text_encoder.to(torch.device("cpu"))
199+
if not hasattr(self, "device_map") or self.device_map == "auto":
200+
self.text_encoder.to(torch.device("cpu"))
187201
torch_device_module.empty_cache()
188202
gc.collect()
189203

lightx2v/models/networks/qwen_image/infer/offload/transformer_infer.py

Lines changed: 107 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import torch
22

33
from lightx2v.common.offload.manager import WeightAsyncStreamManager
4-
from lightx2v.models.networks.qwen_image.infer.transformer_infer import QwenImageTransformerInfer
4+
from lightx2v.models.networks.qwen_image.infer.transformer_infer import (
5+
QwenImageTransformerInfer,
6+
)
57
from lightx2v_platform.base.global_var import AI_DEVICE
68

79
torch_device_module = getattr(torch, AI_DEVICE)
@@ -10,35 +12,125 @@
1012
class QwenImageOffloadTransformerInfer(QwenImageTransformerInfer):
1113
def __init__(self, config):
1214
super().__init__(config)
13-
self.phases_num = 3
1415
self.num_blocks = config["num_layers"]
16+
self.phases_num = 4
1517
if self.config.get("cpu_offload", False):
1618
if "offload_ratio" in self.config:
1719
self.offload_ratio = self.config["offload_ratio"]
1820
else:
1921
self.offload_ratio = 1
2022
offload_granularity = self.config.get("offload_granularity", "block")
2123
if offload_granularity == "block":
22-
if not self.config.get("lazy_load", False):
23-
self.infer_func = self.infer_with_blocks_offload
24-
else:
25-
assert NotImplementedError
26-
27-
if offload_granularity != "model":
24+
self.infer_func = self.infer_with_blocks_offload
2825
self.offload_manager = WeightAsyncStreamManager(offload_granularity=offload_granularity)
29-
else:
30-
assert NotImplementedError
26+
elif offload_granularity == "phase":
27+
self.infer_func = self.infer_with_phases_offload
28+
self.offload_manager = WeightAsyncStreamManager(offload_granularity=offload_granularity)
29+
30+
self.lazy_load = self.config.get("lazy_load", False)
31+
if self.lazy_load:
32+
self.offload_manager.init_lazy_load(num_workers=self.config.get("num_disk_workers", 4))
33+
34+
def infer_with_phases_offload(
35+
self,
36+
blocks,
37+
hidden_states,
38+
encoder_hidden_states,
39+
temb_img_silu,
40+
temb_txt_silu,
41+
image_rotary_emb,
42+
modulate_index,
43+
):
44+
for block_idx in range(len(blocks)):
45+
self.block_idx = block_idx
46+
if self.lazy_load:
47+
next_prefetch = (block_idx + 1) % len(blocks)
48+
self.offload_manager.start_prefetch_block(next_prefetch)
49+
50+
for phase_idx in range(self.phases_num):
51+
# if self.offload_manager.need_init_first_buffer:
52+
if block_idx == 0 and phase_idx == 0:
53+
self.offload_manager.init_first_buffer(blocks)
54+
55+
next_block_idx = (block_idx + 1) % len(blocks) if phase_idx == self.phases_num - 1 else block_idx
56+
next_phase_idx = (phase_idx + 1) % self.phases_num
57+
if self.lazy_load:
58+
if phase_idx == self.phases_num - 1:
59+
self.offload_manager.swap_cpu_buffers()
3160

32-
def infer_with_blocks_offload(self, block_weights, hidden_states, encoder_hidden_states, temb_img_silu, temb_txt_silu, image_rotary_emb, modulate_index):
61+
self.offload_manager.prefetch_phase(next_block_idx, next_phase_idx, blocks)
62+
with torch_device_module.stream(self.offload_manager.compute_stream):
63+
if phase_idx == 0:
64+
img_query, img_key, img_value, img_gate1, img_mod2 = self.infer_img_qkv(
65+
img_attn_phase=self.offload_manager.cuda_buffers[phase_idx],
66+
hidden_states=hidden_states,
67+
temb_img_silu=temb_img_silu,
68+
img_freqs=image_rotary_emb[0],
69+
modulate_index=modulate_index,
70+
)
71+
elif phase_idx == 1:
72+
txt_query, txt_key, txt_value, seq_txt, txt_gate1, txt_mod2 = self.infer_txt_qkv(
73+
txt_attn_phase=self.offload_manager.cuda_buffers[phase_idx],
74+
encoder_hidden_states=encoder_hidden_states,
75+
temb_txt_silu=temb_txt_silu,
76+
txt_freqs=image_rotary_emb[1],
77+
)
78+
elif phase_idx == 2:
79+
hidden_states, encoder_hidden_states = self.infer_cross_attn(
80+
cross_attn_phase=self.offload_manager.cuda_buffers[phase_idx],
81+
seq_txt=seq_txt,
82+
img_query=img_query,
83+
img_key=img_key,
84+
img_value=img_value,
85+
txt_query=txt_query,
86+
txt_key=txt_key,
87+
txt_value=txt_value,
88+
img_gate1=img_gate1,
89+
txt_gate1=txt_gate1,
90+
hidden_states=hidden_states,
91+
encoder_hidden_states=encoder_hidden_states,
92+
)
93+
94+
elif phase_idx == 3:
95+
encoder_hidden_states, hidden_states = self.infer_ffn(
96+
ffn_phase=self.offload_manager.cuda_buffers[phase_idx],
97+
hidden_states=hidden_states,
98+
encoder_hidden_states=encoder_hidden_states,
99+
img_mod2=img_mod2,
100+
txt_mod2=txt_mod2,
101+
modulate_index=modulate_index,
102+
)
103+
self.offload_manager.swap_phases()
104+
105+
return hidden_states
106+
107+
def infer_with_blocks_offload(
108+
self,
109+
blocks,
110+
hidden_states,
111+
encoder_hidden_states,
112+
temb_img_silu,
113+
temb_txt_silu,
114+
image_rotary_emb,
115+
modulate_index,
116+
):
33117
for block_idx in range(self.num_blocks):
34118
self.block_idx = block_idx
119+
120+
if self.lazy_load:
121+
next_prefetch = (block_idx + 1) % self.num_blocks
122+
self.offload_manager.start_prefetch_block(next_prefetch)
123+
35124
if block_idx == 0:
36-
self.offload_manager.init_first_buffer(block_weights.blocks)
37-
if block_idx + 1 < self.num_blocks:
38-
self.offload_manager.prefetch_weights(block_idx + 1, block_weights.blocks)
125+
self.offload_manager.init_first_buffer(blocks)
126+
127+
if self.lazy_load:
128+
self.offload_manager.swap_cpu_buffers()
129+
self.offload_manager.prefetch_weights((block_idx + 1) % self.num_blocks, blocks)
130+
39131
with torch_device_module.stream(self.offload_manager.compute_stream):
40132
encoder_hidden_states, hidden_states = self.infer_block(
41-
block_weight=self.offload_manager.cuda_buffers[0],
133+
block=self.offload_manager.cuda_buffers[0],
42134
hidden_states=hidden_states,
43135
encoder_hidden_states=encoder_hidden_states,
44136
temb_img_silu=temb_img_silu,

0 commit comments

Comments
 (0)