Skip to content

Commit 83c6c27

Browse files
author
gushiqiao
committed
support w4 for qw_vl and disk offload for qw_2511
1 parent 3b18483 commit 83c6c27

File tree

10 files changed

+285
-124
lines changed

10 files changed

+285
-124
lines changed

lightx2v/common/offload/manager.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,10 @@ def init_cuda_buffer(self, blocks_cuda_buffer=None, phases_cuda_buffer=None):
5050
def init_first_buffer(self, blocks, adapter_block_idx=None):
5151
with torch_device_module.stream(self.init_stream):
5252
if hasattr(self, "cpu_buffers"):
53-
self.cuda_buffers[0].load_state_dict(self.cpu_buffers[0][0].state_dict(), 0, adapter_block_idx)
53+
if self.offload_granularity == "block":
54+
self.cuda_buffers[0].load_state_dict(self.cpu_buffers[0].state_dict(), 0, adapter_block_idx)
55+
else:
56+
self.cuda_buffers[0].load_state_dict(self.cpu_buffers[0][0].state_dict(), 0, adapter_block_idx)
5457
else:
5558
if self.offload_granularity == "block":
5659
self.cuda_buffers[0].load_state_dict(blocks[0].state_dict(), 0, adapter_block_idx)
@@ -62,8 +65,7 @@ def init_first_buffer(self, blocks, adapter_block_idx=None):
6265
def prefetch_weights(self, block_idx, blocks, adapter_block_idx=None):
6366
with torch_device_module.stream(self.cuda_load_stream):
6467
if hasattr(self, "cpu_buffers"):
65-
self.cpu_buffers[1].load_state_dict_from_disk(block_idx, adapter_block_idx)
66-
self.cuda_buffers[1].load_state_dict(self.cpu_buffers[1].state_dict(), block_idx, adapter_block_idx)
68+
self.cuda_buffers[1].load_state_dict(self.cpu_buffers[0].state_dict(), block_idx, adapter_block_idx)
6769
else:
6870
self.cuda_buffers[1].load_state_dict(blocks[block_idx].state_dict(), block_idx, adapter_block_idx)
6971

@@ -110,12 +112,17 @@ def init_lazy_load(self, num_workers=6):
110112
def start_prefetch_block(self, block_idx, adapter_block_idx=None):
111113
self.prefetch_block_idx = block_idx
112114
self.prefetch_futures = []
113-
for phase in self.cpu_buffers[1]:
114-
future = self.executor.submit(phase.load_state_dict_from_disk, block_idx, adapter_block_idx)
115+
if self.offload_granularity == "block":
116+
future = self.executor.submit(self.cpu_buffers[1].load_state_dict_from_disk, block_idx, adapter_block_idx)
115117
self.prefetch_futures.append(future)
118+
else:
119+
for phase in self.cpu_buffers[1]:
120+
future = self.executor.submit(phase.load_state_dict_from_disk, block_idx, adapter_block_idx)
121+
self.prefetch_futures.append(future)
116122

117123
def swap_cpu_buffers(self):
118-
# wait_start = time.time()
124+
# import time
125+
# wait_start = time.time()
119126
# already_done = all(f.done() for f in self.prefetch_futures)
120127
for f in self.prefetch_futures:
121128
f.result()

lightx2v/models/input_encoders/hf/qwen25/qwen25_vlforconditionalgeneration.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,22 @@ def __init__(self, config):
7171
self.load()
7272

7373
def load(self):
74-
self.text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(os.path.join(self.config["model_path"], "text_encoder"), torch_dtype=torch.bfloat16)
74+
if self.config.get("qwen25vl_quantized", False):
75+
assert self.config["qwen25vl_quant_scheme"] == "int4"
76+
if self.config["cpu_offload"]:
77+
self.device_map = {
78+
"lm_head": AI_DEVICE,
79+
"model.embed_tokens": AI_DEVICE,
80+
"model.norm": AI_DEVICE,
81+
"model.visual": "cpu",
82+
"model.language_model": "cpu",
83+
}
84+
else:
85+
self.device_map = "auto"
86+
self.text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(self.config["qwen25vl_quantized_ckpt"], dtype=torch.bfloat16, device_map=self.device_map, low_cpu_mem_usage=True)
87+
else:
88+
self.text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(os.path.join(self.config["model_path"], "text_encoder"), torch_dtype=torch.bfloat16)
89+
7590
if not self.cpu_offload:
7691
self.text_encoder = self.text_encoder.to(AI_DEVICE)
7792

@@ -99,7 +114,8 @@ def preprocess_image(self, image):
99114
@torch.no_grad()
100115
def infer(self, text, image_list=None):
101116
if self.cpu_offload:
102-
self.text_encoder.to(AI_DEVICE)
117+
if not hasattr(self, "device_map") or self.device_map == "auto":
118+
self.text_encoder.to(AI_DEVICE)
103119

104120
if image_list is not None:
105121
condition_image_list = []
@@ -143,7 +159,6 @@ def infer(self, text, image_list=None):
143159
image_grid_thw=model_inputs.image_grid_thw,
144160
output_hidden_states=True,
145161
)
146-
147162
image_info = {
148163
"condition_image_list": condition_image_list,
149164
"vae_image_list": vae_image_list,
@@ -183,7 +198,8 @@ def infer(self, text, image_list=None):
183198
prompt_embeds_mask = prompt_embeds_mask.view(1 * 1, seq_len)
184199

185200
if self.cpu_offload:
186-
self.text_encoder.to(torch.device("cpu"))
201+
if not hasattr(self, "device_map") or self.device_map == "auto":
202+
self.text_encoder.to(torch.device("cpu"))
187203
torch_device_module.empty_cache()
188204
gc.collect()
189205

lightx2v/models/networks/qwen_image/infer/offload/transformer_infer.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
class QwenImageOffloadTransformerInfer(QwenImageTransformerInfer):
1111
def __init__(self, config):
1212
super().__init__(config)
13-
self.phases_num = 3
1413
self.num_blocks = config["num_layers"]
1514
if self.config.get("cpu_offload", False):
1615
if "offload_ratio" in self.config:
@@ -19,23 +18,28 @@ def __init__(self, config):
1918
self.offload_ratio = 1
2019
offload_granularity = self.config.get("offload_granularity", "block")
2120
if offload_granularity == "block":
22-
if not self.config.get("lazy_load", False):
23-
self.infer_func = self.infer_with_blocks_offload
24-
else:
25-
assert NotImplementedError
26-
27-
if offload_granularity != "model":
21+
self.infer_func = self.infer_with_blocks_offload
2822
self.offload_manager = WeightAsyncStreamManager(offload_granularity=offload_granularity)
29-
else:
30-
assert NotImplementedError
23+
24+
self.lazy_load = self.config.get("lazy_load", False)
25+
if self.lazy_load:
26+
self.offload_manager.init_lazy_load(num_workers=self.config.get("num_disk_workers", 4))
3127

3228
def infer_with_blocks_offload(self, block_weights, hidden_states, encoder_hidden_states, temb_img_silu, temb_txt_silu, image_rotary_emb, modulate_index):
3329
for block_idx in range(self.num_blocks):
3430
self.block_idx = block_idx
31+
32+
if self.lazy_load:
33+
next_prefetch = (block_idx + 1) % self.num_blocks
34+
self.offload_manager.start_prefetch_block(next_prefetch)
35+
3536
if block_idx == 0:
3637
self.offload_manager.init_first_buffer(block_weights.blocks)
37-
if block_idx + 1 < self.num_blocks:
38-
self.offload_manager.prefetch_weights(block_idx + 1, block_weights.blocks)
38+
39+
if self.lazy_load:
40+
self.offload_manager.swap_cpu_buffers()
41+
self.offload_manager.prefetch_weights((block_idx + 1) % self.num_blocks, block_weights.blocks)
42+
3943
with torch_device_module.stream(self.offload_manager.compute_stream):
4044
encoder_hidden_states, hidden_states = self.infer_block(
4145
block_weight=self.offload_manager.cuda_buffers[0],

lightx2v/models/networks/qwen_image/model.py

Lines changed: 33 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,10 @@ def __init__(self, config):
3636
transformer_config = json.load(f)
3737
self.in_channels = transformer_config["in_channels"]
3838
self.attention_kwargs = {}
39-
39+
self.remove_keys = []
40+
self.lazy_load = self.config.get("lazy_load", False)
41+
if self.lazy_load:
42+
self.remove_keys.extend(["blocks."])
4043
self.dit_quantized = self.config.get("dit_quantized", False)
4144

4245
if self.config["seq_parallel"]:
@@ -75,10 +78,7 @@ def _init_weights(self, weight_dict=None):
7578
weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
7679
else:
7780
# Load quantized weights
78-
if not self.config.get("lazy_load", False):
79-
weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
80-
else:
81-
weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)
81+
weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
8282

8383
if self.config.get("device_mesh") is not None and self.config.get("load_from_rank0", False):
8484
weight_dict = self._load_weights_from_rank0(weight_dict, is_weight_loader)
@@ -89,7 +89,10 @@ def _init_weights(self, weight_dict=None):
8989

9090
# Initialize weight containers
9191
self.pre_weight = self.pre_weight_class(self.config)
92-
self.transformer_weights = self.transformer_weight_class(self.config)
92+
if self.lazy_load:
93+
self.transformer_weights = self.transformer_weight_class(self.config, self.lazy_load_path)
94+
else:
95+
self.transformer_weights = self.transformer_weight_class(self.config)
9396
self.post_weight = self.post_weight_class(self.config)
9497
if not self._should_init_empty_model():
9598
self._apply_weights()
@@ -150,8 +153,18 @@ def _load_ckpt(self, unified_dtype, sensitive_layer):
150153
safetensors_path = self.model_path
151154

152155
if os.path.isdir(safetensors_path):
153-
safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
156+
if self.lazy_load:
157+
self.lazy_load_path = safetensors_path
158+
non_block_file = os.path.join(safetensors_path, "non_block.safetensors")
159+
if os.path.exists(non_block_file):
160+
safetensors_files = [non_block_file]
161+
else:
162+
raise ValueError(f"Non-block file not found in {safetensors_path}. Please check the model path.")
163+
else:
164+
safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
154165
else:
166+
if self.lazy_load:
167+
self.lazy_load_path = safetensors_path
155168
safetensors_files = [safetensors_path]
156169

157170
weight_dict = {}
@@ -171,8 +184,18 @@ def _load_quant_ckpt(self, unified_dtype, sensitive_layer):
171184
safetensors_path = self.model_path
172185

173186
if os.path.isdir(safetensors_path):
174-
safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
187+
if self.lazy_load:
188+
self.lazy_load_path = safetensors_path
189+
non_block_file = os.path.join(safetensors_path, "non_block.safetensors")
190+
if os.path.exists(non_block_file):
191+
safetensors_files = [non_block_file]
192+
else:
193+
raise ValueError(f"Non-block file not found in {safetensors_path}. Please check the model path.")
194+
else:
195+
safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
175196
else:
197+
if self.lazy_load:
198+
self.lazy_load_path = safetensors_path
176199
safetensors_files = [safetensors_path]
177200
safetensors_path = os.path.dirname(safetensors_path)
178201

@@ -204,28 +227,6 @@ def _load_quant_ckpt(self, unified_dtype, sensitive_layer):
204227

205228
return weight_dict
206229

207-
def _load_quant_split_ckpt(self, unified_dtype, sensitive_layer): # Need rewrite
208-
lazy_load_model_path = self.dit_quantized_ckpt
209-
logger.info(f"Loading splited quant model from {lazy_load_model_path}")
210-
pre_post_weight_dict = {}
211-
212-
safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors")
213-
with safe_open(safetensor_path, framework="pt", device="cpu") as f:
214-
for k in f.keys():
215-
if f.get_tensor(k).dtype in [
216-
torch.float16,
217-
torch.bfloat16,
218-
torch.float,
219-
]:
220-
if unified_dtype or all(s not in k for s in sensitive_layer):
221-
pre_post_weight_dict[k] = f.get_tensor(k).to(GET_DTYPE()).to(self.device)
222-
else:
223-
pre_post_weight_dict[k] = f.get_tensor(k).to(GET_SENSITIVE_DTYPE()).to(self.device)
224-
else:
225-
pre_post_weight_dict[k] = f.get_tensor(k).to(self.device)
226-
227-
return pre_post_weight_dict
228-
229230
def _load_weights_from_rank0(self, weight_dict, is_weight_loader):
230231
logger.info("Loading distributed weights")
231232
global_src_rank = 0
@@ -291,6 +292,8 @@ def _init_infer(self):
291292
self.post_infer = self.post_infer_class(self.config)
292293
if hasattr(self.transformer_infer, "offload_manager"):
293294
self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_cuda_buffers, self.transformer_weights.offload_phase_cuda_buffers)
295+
if self.lazy_load:
296+
self.transformer_infer.offload_manager.init_cpu_buffer(self.transformer_weights.offload_block_cpu_buffers, self.transformer_weights.offload_phase_cpu_buffers)
294297

295298
def to_cpu(self):
296299
self.pre_weight.to_cpu()

lightx2v/models/networks/qwen_image/weights/post_weights.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,12 @@ def __init__(self, config):
1010
super().__init__()
1111
self.task = config["task"]
1212
self.config = config
13-
self.lazy_load = self.config.get("lazy_load", False)
14-
if self.lazy_load:
15-
assert NotImplementedError
16-
self.lazy_load_file = False
17-
1813
# norm_out
1914
self.add_module(
2015
"norm_out_linear",
2116
MM_WEIGHT_REGISTER["Default"](
2217
"norm_out.linear.weight",
2318
"norm_out.linear.bias",
24-
self.lazy_load,
25-
self.lazy_load_file,
2619
),
2720
)
2821
self.add_module("norm_out", LN_WEIGHT_REGISTER["Default"](eps=1e-6))
@@ -33,8 +26,6 @@ def __init__(self, config):
3326
MM_WEIGHT_REGISTER["Default"](
3427
"proj_out.weight",
3528
"proj_out.bias",
36-
self.lazy_load,
37-
self.lazy_load_file,
3829
),
3930
)
4031

0 commit comments

Comments
 (0)