@@ -36,7 +36,10 @@ def __init__(self, config):
             transformer_config = json.load(f)
         self.in_channels = transformer_config["in_channels"]
         self.attention_kwargs = {}
-
+        self.remove_keys = []
+        self.lazy_load = self.config.get("lazy_load", False)
+        if self.lazy_load:
+            self.remove_keys.extend(["blocks."])
         self.dit_quantized = self.config.get("dit_quantized", False)

         if self.config["seq_parallel"]:
@@ -75,10 +78,7 @@ def _init_weights(self, weight_dict=None):
             weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
         else:
             # Load quantized weights
-            if not self.config.get("lazy_load", False):
-                weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
-            else:
-                weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)
+            weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)

         if self.config.get("device_mesh") is not None and self.config.get("load_from_rank0", False):
             weight_dict = self._load_weights_from_rank0(weight_dict, is_weight_loader)
@@ -89,7 +89,10 @@ def _init_weights(self, weight_dict=None):

         # Initialize weight containers
         self.pre_weight = self.pre_weight_class(self.config)
-        self.transformer_weights = self.transformer_weight_class(self.config)
+        if self.lazy_load:
+            self.transformer_weights = self.transformer_weight_class(self.config, self.lazy_load_path)
+        else:
+            self.transformer_weights = self.transformer_weight_class(self.config)
         self.post_weight = self.post_weight_class(self.config)
         if not self._should_init_empty_model():
             self._apply_weights()
@@ -150,8 +153,18 @@ def _load_ckpt(self, unified_dtype, sensitive_layer):
         safetensors_path = self.model_path

         if os.path.isdir(safetensors_path):
-            safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
+            if self.lazy_load:
+                self.lazy_load_path = safetensors_path
+                non_block_file = os.path.join(safetensors_path, "non_block.safetensors")
+                if os.path.exists(non_block_file):
+                    safetensors_files = [non_block_file]
+                else:
+                    raise ValueError(f"Non-block file not found in {safetensors_path}. Please check the model path.")
+            else:
+                safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
         else:
+            if self.lazy_load:
+                self.lazy_load_path = safetensors_path
             safetensors_files = [safetensors_path]

         weight_dict = {}
@@ -171,8 +184,18 @@ def _load_quant_ckpt(self, unified_dtype, sensitive_layer):
         safetensors_path = self.model_path

         if os.path.isdir(safetensors_path):
-            safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
+            if self.lazy_load:
+                self.lazy_load_path = safetensors_path
+                non_block_file = os.path.join(safetensors_path, "non_block.safetensors")
+                if os.path.exists(non_block_file):
+                    safetensors_files = [non_block_file]
+                else:
+                    raise ValueError(f"Non-block file not found in {safetensors_path}. Please check the model path.")
+            else:
+                safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
         else:
+            if self.lazy_load:
+                self.lazy_load_path = safetensors_path
             safetensors_files = [safetensors_path]
             safetensors_path = os.path.dirname(safetensors_path)

@@ -204,28 +227,6 @@ def _load_quant_ckpt(self, unified_dtype, sensitive_layer):

         return weight_dict

-    def _load_quant_split_ckpt(self, unified_dtype, sensitive_layer):  # Need rewrite
-        lazy_load_model_path = self.dit_quantized_ckpt
-        logger.info(f"Loading splited quant model from {lazy_load_model_path}")
-        pre_post_weight_dict = {}
-
-        safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors")
-        with safe_open(safetensor_path, framework="pt", device="cpu") as f:
-            for k in f.keys():
-                if f.get_tensor(k).dtype in [
-                    torch.float16,
-                    torch.bfloat16,
-                    torch.float,
-                ]:
-                    if unified_dtype or all(s not in k for s in sensitive_layer):
-                        pre_post_weight_dict[k] = f.get_tensor(k).to(GET_DTYPE()).to(self.device)
-                    else:
-                        pre_post_weight_dict[k] = f.get_tensor(k).to(GET_SENSITIVE_DTYPE()).to(self.device)
-                else:
-                    pre_post_weight_dict[k] = f.get_tensor(k).to(self.device)
-
-        return pre_post_weight_dict
-
     def _load_weights_from_rank0(self, weight_dict, is_weight_loader):
         logger.info("Loading distributed weights")
         global_src_rank = 0
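Note: the removed `_load_quant_split_ckpt` only read the pre/post (non-block) tensors eagerly from `non_block.safetensors`; that responsibility now lives in the regular `_load_ckpt`/`_load_quant_ckpt` path shown above. For reference, a minimal sketch of the same pattern with safetensors' `safe_open`, using a plain dtype argument in place of the repo's `GET_DTYPE`/`GET_SENSITIVE_DTYPE` helpers (an assumption for illustration only):

```python
import os
import torch
from safetensors import safe_open

def load_non_block_weights(model_dir, device="cpu", dtype=torch.bfloat16):
    """Eagerly read only the pre/post weights from non_block.safetensors."""
    path = os.path.join(model_dir, "non_block.safetensors")
    weights = {}
    with safe_open(path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensor = f.get_tensor(key)
            # Cast floating-point tensors to the working dtype; leave int/bool
            # tensors (e.g. quantization metadata) untouched.
            if tensor.is_floating_point():
                tensor = tensor.to(dtype)
            weights[key] = tensor.to(device)
    return weights
```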
@@ -291,6 +292,8 @@ def _init_infer(self):
         self.post_infer = self.post_infer_class(self.config)
         if hasattr(self.transformer_infer, "offload_manager"):
             self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_cuda_buffers, self.transformer_weights.offload_phase_cuda_buffers)
+            if self.lazy_load:
+                self.transformer_infer.offload_manager.init_cpu_buffer(self.transformer_weights.offload_block_cpu_buffers, self.transformer_weights.offload_phase_cpu_buffers)

     def to_cpu(self):
         self.pre_weight.to_cpu()
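The diff records `lazy_load_path` so that transformer block weights can be materialized on demand later (e.g. into the offload manager's CPU buffers) instead of being loaded up front. A rough sketch of that idea, assuming a hypothetical `block_{idx}.safetensors` per-block file layout that is not part of this diff:

```python
import os
from safetensors import safe_open

def load_block_on_demand(lazy_load_path, block_idx, device="cpu"):
    """Read a single transformer block's tensors only when it is needed."""
    # Hypothetical per-block file naming; the real split format may differ.
    block_file = os.path.join(lazy_load_path, f"block_{block_idx}.safetensors")
    block_weights = {}
    with safe_open(block_file, framework="pt", device="cpu") as f:
        for key in f.keys():
            # safe_open reads tensors lazily, so only this block hits memory.
            block_weights[key] = f.get_tensor(key).to(device)
    return block_weights
```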