Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/diffusers/models/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@
from .transformer_mochi import MochiTransformer3DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
from .omnigen_transformer import OmniGenTransformer
380 changes: 380 additions & 0 deletions src/diffusers/models/transformers/omnigen_transformer.py

Large diffs are not rendered by default.

183 changes: 183 additions & 0 deletions src/diffusers/models/transformers/phi3_transformer_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint

from transformers.modeling_outputs import (
BaseModelOutputWithPast,
)
from transformers import Phi3Model
from transformers.cache_utils import Cache, DynamicCache
from transformers.utils import logging

logger = logging.get_logger(__name__)

class Phi3Transformer(Phi3Model):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
We only modified the attention mask
Args:
config: Phi3Config
"""
def prefetch_layer(self, layer_idx: int, device: torch.device):
"Starts prefetching the next layer cache"
with torch.cuda.stream(self.prefetch_stream):
# Prefetch next layer tensors to GPU
for name, param in self.layers[layer_idx].named_parameters():
param.data = param.data.to(device, non_blocking=True)

def evict_previous_layer(self, layer_idx: int):
"Moves the previous layer cache to the CPU"
prev_layer_idx = layer_idx - 1
for name, param in self.layers[prev_layer_idx].named_parameters():
param.data = param.data.to("cpu", non_blocking=True)

def get_offlaod_layer(self, layer_idx: int, device: torch.device):
# init stream
if not hasattr(self, "prefetch_stream"):
self.prefetch_stream = torch.cuda.Stream()

# delete previous layer
torch.cuda.current_stream().synchronize()
self.evict_previous_layer(layer_idx)

# make sure the current layer is ready
torch.cuda.synchronize(self.prefetch_stream)

# load next layer
self.prefetch_layer((layer_idx + 1) % len(self.layers), device)


def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
offload_model: Optional[bool] = False,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False

# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
)

# if inputs_embeds is None:
# inputs_embeds = self.embed_tokens(input_ids)

# if cache_position is None:
# past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
# cache_position = torch.arange(
# past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
# )
# if position_ids is None:
# position_ids = cache_position.unsqueeze(0)

if attention_mask is not None and attention_mask.dim() == 3:
dtype = inputs_embeds.dtype
min_dtype = torch.finfo(dtype).min
attention_mask = (1 - attention_mask) * min_dtype
attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
else:
raise
# causal_mask = self._update_causal_mask(
# attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
# )

hidden_states = inputs_embeds

# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None

layer_idx = -1
for decoder_layer in self.layers:
layer_idx += 1

if output_hidden_states:
all_hidden_states += (hidden_states,)

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
)
else:
if offload_model and not self.training:
self.get_offlaod_layer(layer_idx, device=inputs_embeds.device)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)

hidden_states = layer_outputs[0]

if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]

if output_attentions:
all_self_attns += (layer_outputs[1],)

hidden_states = self.norm(hidden_states)

# add hidden states from the last decoder layer
if output_hidden_states:
print('************')
all_hidden_states += (hidden_states,)

next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()

if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
21 changes: 21 additions & 0 deletions src/diffusers/pipelines/omnigen/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import TYPE_CHECKING

from ...utils import DIFFUSERS_SLOW_IMPORT, _LazyModule


_import_structure = {
"pipeline_omnigen": ["OmniGenPipeline"]
}

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_omnigen import OmniGenPipeline

else:
import sys

sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
Loading