
Commit 236f14b

enable_transformer_block_cpu_offload
1 parent c78d1f4 commit 236f14b

2 files changed: +23 −10 lines changed

src/diffusers/models/transformers/transformer_omnigen.py

Lines changed: 5 additions & 8 deletions
@@ -74,21 +74,18 @@ def evict_previous_layer(self, layer_idx: int):
         prev_layer_idx = layer_idx - 1
         for name, param in self.layers[prev_layer_idx].named_parameters():
             param.data = param.data.to("cpu", non_blocking=True)
-
+
     def get_offload_layer(self, layer_idx: int, device: torch.device):
         # init stream
         if not hasattr(self, "prefetch_stream"):
             self.prefetch_stream = torch.cuda.Stream()

         # delete previous layer
-        # main stream sync shouldn't be necessary since all computation on iter i-1 is finished by iter i
-        # torch.cuda.current_stream().synchronize()
-        # avoid extra eviction of last layer
-        if layer_idx > 0:
-            self.evict_previous_layer(layer_idx)
-
+        torch.cuda.current_stream().synchronize()
+        self.evict_previous_layer(layer_idx)
+
         # make sure the current layer is ready
-        self.prefetch_stream.synchronize()
+        torch.cuda.synchronize(self.prefetch_stream)

         # load next layer
         self.prefetch_layer((layer_idx + 1) % len(self.layers), device)
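
The hunk above drops the `if layer_idx > 0` guard, synchronizes the main stream before every eviction, and changes how the prefetch stream is waited on. For context, below is a minimal sketch of the prefetch/evict pattern these methods implement: only the active transformer block is kept on the GPU while a side stream copies the next block in. The class and loop here are illustrative stand-ins, not the diffusers API, and the sketch waits on the stream with Stream.synchronize() rather than the commit's torch.cuda.synchronize(...) call.

# Illustrative sketch (not the diffusers API): a block stack that keeps only
# the active layer on the GPU, prefetching the next one on a side stream.
import torch
import torch.nn as nn

class OffloadedBlockStack(nn.Module):
    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers  # weights assumed to start on CPU
        self.prefetch_stream = torch.cuda.Stream()  # assumes CUDA is available

    def prefetch_layer(self, layer_idx: int, device: torch.device):
        # Issue host-to-device copies for the next block on the side stream.
        with torch.cuda.stream(self.prefetch_stream):
            for param in self.layers[layer_idx].parameters():
                param.data = param.data.to(device, non_blocking=True)

    def evict_previous_layer(self, layer_idx: int):
        # Move the previous block back to CPU; index -1 wraps to the last
        # block, which matches the modulo prefetch below.
        for param in self.layers[layer_idx - 1].parameters():
            param.data = param.data.to("cpu", non_blocking=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        device = hidden_states.device
        self.prefetch_layer(0, device)
        for i, layer in enumerate(self.layers):
            # As in the commit: finish pending compute, then evict unconditionally.
            torch.cuda.current_stream().synchronize()
            self.evict_previous_layer(i)
            # Block until the copies for layer i have landed on the GPU.
            self.prefetch_stream.synchronize()
            # Start copying layer i+1 (wrapping around) while layer i computes.
            self.prefetch_layer((i + 1) % len(self.layers), device)
            hidden_states = layer(hidden_states)
        return hidden_states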

src/diffusers/pipelines/omnigen/pipeline_omnigen.py

Lines changed: 18 additions & 2 deletions
@@ -327,6 +327,18 @@ def num_timesteps(self):
     @property
     def interrupt(self):
         return self._interrupt
+
+    def enable_transformer_block_cpu_offload(self, device: Union[torch.device, str] = "cuda"):
+        torch_device = torch.device(device)
+        for name, param in self.transformer.named_parameters():
+            if 'layers' in name and 'layers.0' not in name:
+                param.data = param.data.cpu()
+            else:
+                param.data = param.data.to(torch_device)
+        for buffer_name, buffer in self.transformer.patch_embedding.named_buffers():
+            setattr(self.transformer.patch_embedding, buffer_name, buffer.to(torch_device))
+        self.vae.to(torch_device)
+        self.offload_transformer_block = True

     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
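
A hypothetical call pattern for the method added in this hunk is sketched below: everything except the first transformer block (layers.0) moves to CPU, and get_offload_layer streams blocks in one at a time during denoising. The checkpoint id and dtype are illustrative placeholders, and "cuda" assumes a CUDA device is available.

# Hypothetical usage of the method added above (not part of this commit).
import torch
from diffusers import OmniGenPipeline

pipe = OmniGenPipeline.from_pretrained(
    "Shitao/OmniGen-v1-diffusers",  # illustrative checkpoint id
    torch_dtype=torch.bfloat16,
)
pipe.enable_transformer_block_cpu_offload(device="cuda")

# All transformer blocks except layers.0 now live on CPU until prefetched.
image = pipe(prompt="A photo of a cat", height=1024, width=1024).images[0]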
@@ -440,6 +452,9 @@ def __call__(
         # using Float32 for the VAE doesn't take up much memory but can prevent potential black image outputs.
         self.vae.to(torch.float32)

+        if offload_transformer_block:
+            self.enable_transformer_block_cpu_offload()
+
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
@@ -460,9 +475,10 @@ def __call__(
         batch_size = len(prompt)
         device = self._execution_device

+
         # 3. process multi-modal instructions
         if max_input_image_size != self.multimodal_processor.max_image_size:
-            self.multimodal_processor = OmniGenMultiModalProcessor(self.text_tokenizer, max_image_size=max_input_image_size)
+            self.multimodal_processor = OmniGenMultiModalProcessor(self.tokenizer, max_image_size=max_input_image_size)
         processed_data = self.multimodal_processor(prompt,
                                                    input_images,
                                                    height=height,
@@ -521,7 +537,7 @@ def __call__(
             position_ids=processed_data['position_ids'],
             attention_kwargs=attention_kwargs,
             past_key_values=cache,
-            offload_transformer_block=offload_transformer_block,
+            offload_transformer_block=self.offload_transformer_block if hasattr(self, 'offload_transformer_block') else offload_transformer_block,
             return_dict=False,
         )
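
The hasattr conditional in this last hunk makes the pipeline-level flag set by enable_transformer_block_cpu_offload() take precedence over the per-call argument. An equivalent, more compact form (a sketch, not part of the commit) would be:

# Equivalent to the hasattr(...) conditional above: prefer the attribute set
# by enable_transformer_block_cpu_offload(), else use the call argument.
offload = getattr(self, "offload_transformer_block", offload_transformer_block)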
