11import torch
22
33from lightx2v .common .offload .manager import WeightAsyncStreamManager
4- from lightx2v .models .networks .qwen_image .infer .transformer_infer import QwenImageTransformerInfer
4+ from lightx2v .models .networks .qwen_image .infer .transformer_infer import (
5+ QwenImageTransformerInfer ,
6+ )
57from lightx2v_platform .base .global_var import AI_DEVICE
68
79torch_device_module = getattr (torch , AI_DEVICE )
@@ -11,6 +13,7 @@ class QwenImageOffloadTransformerInfer(QwenImageTransformerInfer):
1113 def __init__ (self , config ):
1214 super ().__init__ (config )
1315 self .num_blocks = config ["num_layers" ]
16+ self .phases_num = 3
1417 if self .config .get ("cpu_offload" , False ):
1518 if "offload_ratio" in self .config :
1619 self .offload_ratio = self .config ["offload_ratio" ]
@@ -20,12 +23,94 @@ def __init__(self, config):
2023 if offload_granularity == "block" :
2124 self .infer_func = self .infer_with_blocks_offload
2225 self .offload_manager = WeightAsyncStreamManager (offload_granularity = offload_granularity )
26+ elif offload_granularity == "phase" :
27+ self .infer_func = self .infer_with_phases_offload
28+ self .offload_manager = WeightAsyncStreamManager (offload_granularity = offload_granularity )
2329
2430 self .lazy_load = self .config .get ("lazy_load" , False )
2531 if self .lazy_load :
2632 self .offload_manager .init_lazy_load (num_workers = self .config .get ("num_disk_workers" , 4 ))
2733
28- def infer_with_blocks_offload (self , block_weights , hidden_states , encoder_hidden_states , temb_img_silu , temb_txt_silu , image_rotary_emb , modulate_index ):
34+ def infer_with_phases_offload (
35+ self ,
36+ blocks ,
37+ hidden_states ,
38+ encoder_hidden_states ,
39+ temb_img_silu ,
40+ temb_txt_silu ,
41+ image_rotary_emb ,
42+ modulate_index ,
43+ ):
44+ for block_idx in range (len (blocks )):
45+ self .block_idx = block_idx
46+ if self .lazy_load :
47+ next_prefetch = (block_idx + 1 ) % len (blocks )
48+ self .offload_manager .start_prefetch_block (next_prefetch )
49+
50+ for phase_idx in range (self .phases_num ):
51+ # if self.offload_manager.need_init_first_buffer:
52+ if block_idx == 0 and phase_idx == 0 :
53+ self .offload_manager .init_first_buffer (blocks )
54+
55+ next_block_idx = (block_idx + 1 ) % len (blocks ) if phase_idx == self .phases_num - 1 else block_idx
56+ next_phase_idx = (phase_idx + 1 ) % self .phases_num
57+ if self .lazy_load :
58+ if phase_idx == self .phases_num - 1 :
59+ self .offload_manager .swap_cpu_buffers ()
60+
61+ self .offload_manager .prefetch_phase (next_block_idx , next_phase_idx , blocks )
62+ with torch_device_module .stream (self .offload_manager .compute_stream ):
63+ if phase_idx == 0 :
64+ (
65+ img_modulated ,
66+ txt_modulated ,
67+ img_gate1 ,
68+ txt_gate1 ,
69+ img_mod2 ,
70+ txt_mod2 ,
71+ ) = self .infer_modulate (
72+ mod_phase = self .offload_manager .cuda_buffers [phase_idx ],
73+ hidden_states = hidden_states ,
74+ encoder_hidden_states = encoder_hidden_states ,
75+ temb_img_silu = temb_img_silu ,
76+ temb_txt_silu = temb_txt_silu ,
77+ modulate_index = modulate_index ,
78+ )
79+ elif phase_idx == 1 :
80+ hidden_states , encoder_hidden_states = self .infer_attn (
81+ attn_phase = self .offload_manager .cuda_buffers [phase_idx ],
82+ img_modulated = img_modulated ,
83+ txt_modulated = txt_modulated ,
84+ img_gate1 = img_gate1 ,
85+ txt_gate1 = txt_gate1 ,
86+ hidden_states = hidden_states ,
87+ encoder_hidden_states = encoder_hidden_states ,
88+ image_rotary_emb = image_rotary_emb ,
89+ )
90+ elif phase_idx == 2 :
91+ encoder_hidden_states , hidden_states = self .infer_ffn (
92+ ffn_phase = self .offload_manager .cuda_buffers [phase_idx ],
93+ hidden_states = hidden_states ,
94+ encoder_hidden_states = encoder_hidden_states ,
95+ img_mod2 = img_mod2 ,
96+ txt_mod2 = txt_mod2 ,
97+ modulate_index = modulate_index ,
98+ )
99+
100+ self .offload_manager .swap_phases ()
101+
102+ return hidden_states
103+
104+ def infer_with_blocks_offload (
105+ self ,
106+ blocks ,
107+ hidden_states ,
108+ encoder_hidden_states ,
109+ temb_img_silu ,
110+ temb_txt_silu ,
111+ image_rotary_emb ,
112+ modulate_index ,
113+ ):
29114 for block_idx in range (self .num_blocks ):
30115 self .block_idx = block_idx
31116
@@ -34,15 +119,15 @@ def infer_with_blocks_offload(self, block_weights, hidden_states, encoder_hidden
34119 self .offload_manager .start_prefetch_block (next_prefetch )
35120
36121 if block_idx == 0 :
37- self .offload_manager .init_first_buffer (block_weights . blocks )
122+ self .offload_manager .init_first_buffer (blocks )
38123
39124 if self .lazy_load :
40125 self .offload_manager .swap_cpu_buffers ()
41- self .offload_manager .prefetch_weights ((block_idx + 1 ) % self .num_blocks , block_weights . blocks )
126+ self .offload_manager .prefetch_weights ((block_idx + 1 ) % self .num_blocks , blocks )
42127
43128 with torch_device_module .stream (self .offload_manager .compute_stream ):
44129 encoder_hidden_states , hidden_states = self .infer_block (
45- block_weight = self .offload_manager .cuda_buffers [0 ],
130+ block = self .offload_manager .cuda_buffers [0 ],
46131 hidden_states = hidden_states ,
47132 encoder_hidden_states = encoder_hidden_states ,
48133 temb_img_silu = temb_img_silu ,
0 commit comments