@@ -327,6 +327,18 @@ def num_timesteps(self):
     @property
     def interrupt(self):
         return self._interrupt
+
+    def enable_transformer_block_cpu_offload(self, device: Union[torch.device, str] = "cuda"):
+        torch_device = torch.device(device)
+        for name, param in self.transformer.named_parameters():
+            if 'layers' in name and 'layers.0' not in name:
+                param.data = param.data.cpu()
+            else:
+                param.data = param.data.to(torch_device)
+        for buffer_name, buffer in self.transformer.patch_embedding.named_buffers():
+            setattr(self.transformer.patch_embedding, buffer_name, buffer.to(torch_device))
+        self.vae.to(torch_device)
+        self.offload_transformer_block = True
 
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
@@ -440,6 +452,9 @@ def __call__(
         # using Float32 for the VAE doesn't take up much memory but can prevent potential black image outputs.
         self.vae.to(torch.float32)
 
+        if offload_transformer_block:
+            self.enable_transformer_block_cpu_offload()
+
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
@@ -460,9 +475,10 @@ def __call__(
             batch_size = len(prompt)
         device = self._execution_device
 
+
         # 3. process multi-modal instructions
         if max_input_image_size != self.multimodal_processor.max_image_size:
-            self.multimodal_processor = OmniGenMultiModalProcessor(self.text_tokenizer, max_image_size=max_input_image_size)
+            self.multimodal_processor = OmniGenMultiModalProcessor(self.tokenizer, max_image_size=max_input_image_size)
         processed_data = self.multimodal_processor(prompt,
                                                    input_images,
                                                    height=height,
@@ -521,7 +537,7 @@ def __call__(
                     position_ids=processed_data['position_ids'],
                     attention_kwargs=attention_kwargs,
                     past_key_values=cache,
-                    offload_transformer_block=offload_transformer_block,
+                    offload_transformer_block=self.offload_transformer_block if hasattr(self, 'offload_transformer_block') else offload_transformer_block,
                     return_dict=False,
                 )
 
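For context, a minimal usage sketch of the offload path added above. It assumes this diff targets the diffusers OmniGenPipeline; the checkpoint id and prompt are placeholders, and `offload_transformer_block` / `enable_transformer_block_cpu_offload` are the call argument and method introduced in this diff:

import torch
from diffusers import OmniGenPipeline  # assumed target pipeline of this diff

# Placeholder checkpoint id; substitute the OmniGen weights you actually use.
pipe = OmniGenPipeline.from_pretrained("<omnigen-checkpoint>", torch_dtype=torch.bfloat16)

# Option A: explicitly park every transformer block except layers.0 on the CPU,
# keep the remaining parameters, patch-embedding buffers, and VAE on the GPU,
# and set the flag that __call__ then forwards to the transformer.
pipe.enable_transformer_block_cpu_offload(device="cuda")
image = pipe(prompt="a photo of a cat").images[0]

# Option B: pass the new __call__ flag and let the pipeline invoke
# enable_transformer_block_cpu_offload() itself.
image = pipe(prompt="a photo of a cat", offload_transformer_block=True).images[0]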