diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0053074bad8e..089a18e16ae7 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -193,6 +193,7 @@ "ControlNetXSAdapter", "CosmosTransformer3DModel", "DiTTransformer2DModel", + "DreamTransformer1DModel", "EasyAnimateTransformer3DModel", "FluxControlNetModel", "FluxMultiControlNetModel", @@ -301,6 +302,7 @@ "DPMSolverMultistepInverseScheduler", "DPMSolverMultistepScheduler", "DPMSolverSinglestepScheduler", + "DreamMaskedDiffusionScheduler", "EDMDPMSolverMultistepScheduler", "EDMEulerScheduler", "EulerAncestralDiscreteScheduler", @@ -413,6 +415,7 @@ "CosmosTextToWorldPipeline", "CosmosVideoToWorldPipeline", "CycleDiffusionPipeline", + "DreamTextPipeline", "EasyAnimateControlPipeline", "EasyAnimateInpaintPipeline", "EasyAnimatePipeline", @@ -857,6 +860,7 @@ ControlNetXSAdapter, CosmosTransformer3DModel, DiTTransformer2DModel, + DreamTransformer1DModel, EasyAnimateTransformer3DModel, FluxControlNetModel, FluxMultiControlNetModel, @@ -956,6 +960,7 @@ DPMSolverMultistepInverseScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, + DreamMaskedDiffusionScheduler, EDMDPMSolverMultistepScheduler, EDMEulerScheduler, EulerAncestralDiscreteScheduler, @@ -1047,6 +1052,7 @@ CosmosTextToWorldPipeline, CosmosVideoToWorldPipeline, CycleDiffusionPipeline, + DreamTextPipeline, EasyAnimateControlPipeline, EasyAnimateInpaintPipeline, EasyAnimatePipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 972233bd987d..6ecf44d4ba38 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -80,6 +80,7 @@ _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"] _import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"] _import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"] + _import_structure["transformers.transformer_dream"] = ["DreamTransformer1DModel"] _import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] _import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"] @@ -165,6 +166,7 @@ ConsisIDTransformer3DModel, CosmosTransformer3DModel, DiTTransformer2DModel, + DreamTransformer1DModel, DualTransformer2DModel, EasyAnimateTransformer3DModel, FluxTransformer2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 5550fed92d28..cd0e238a79ce 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -21,6 +21,7 @@ from .transformer_cogview3plus import CogView3PlusTransformer2DModel from .transformer_cogview4 import CogView4Transformer2DModel from .transformer_cosmos import CosmosTransformer3DModel + from .transformer_dream import DreamTransformer1DModel from .transformer_easyanimate import EasyAnimateTransformer3DModel from .transformer_flux import FluxTransformer2DModel from .transformer_hidream_image import HiDreamImageTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_dream.py b/src/diffusers/models/transformers/transformer_dream.py new file mode 100644 index 000000000000..1bf6140058e5 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_dream.py @@ -0,0 +1,532 @@ +import inspect +import math +from typing import Optional, Tuple, Union + 
+import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..activations import get_activation +from ..attention import AttentionModuleMixin +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Based on transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Based on transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1 +) -> torch.Tensor: + """Applies Rotary Position Embedding to the query and key tensors. + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Ultimately from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class DreamAttnProcessor: + _attention_backend = None + + def __call__( + self, + attn: "DreamAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + ) -> torch.Tensor: + # TODO: can caching be implemented in diffusers like it is in the original code? 
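+        # Illustrative shapes (Dream-7B defaults): hidden_dim D = 3584, heads N = 28, kv_heads N_KV = 4,
+        # head_dim H = 128. Keys/values are projected to only 4 heads and later expanded back to 28 via
+        # repeat_kv (grouped-query attention), e.g. [B, 4, L, 128] -> [B, 28, L, 128] with kv_groups = 7.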
+        # hidden_states shape: (batch_size, seq_len, hidden_dim) = [B, L, D]
+        batch_size, query_len, _ = hidden_states.size()
+        query = attn.to_q(hidden_states)  # [B, L, D]
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        key = attn.to_k(encoder_hidden_states)  # [B, L, D] --> [B, L, D_KV]
+        value = attn.to_v(encoder_hidden_states)  # [B, L, D] --> [B, L, D_KV]
+
+        # TODO: call attn.head_to_batch_dim here instead???
+        # original code sends [batch_size, seq_len, hidden_dim] to [batch_size, num_heads, seq_len, head_dim]
+        # batch_to_head_dim instead sends it to [batch_size // num_heads, seq_len, dim * heads]
+        query = query.view(batch_size, query_len, attn.heads, attn.head_dim).transpose(1, 2)  # [B, N, L, H]
+        key = key.view(batch_size, query_len, attn.kv_heads, attn.head_dim).transpose(1, 2)  # [B, N_KV, L, H]
+        value = value.view(batch_size, query_len, attn.kv_heads, attn.head_dim).transpose(1, 2)  # [B, N_KV, L, H]
+
+        if rotary_emb is not None:
+            # TODO: rewrite in terms of embeddings.apply_rotary_emb???
+            query, key = apply_rotary_pos_emb(query, key, rotary_emb[0], rotary_emb[1])
+
+        # Repeat KV heads if attn.kv_heads < attn.heads
+        key = repeat_kv(key, attn.kv_groups)  # [B, N_KV, L, H] --> [B, N, L, H]
+        value = repeat_kv(value, attn.kv_groups)  # [B, N_KV, L, H] --> [B, N, L, H]
+
+        # TODO: call dispatch_attention_fn here to dispatch the implementation to a backend? e.g. FlashAttn
+        # hidden_states = dispatch_attention_fn(
+        #     query, key, value, attn_mask=attention_mask, backend=self._attention_backend
+        # )
+        # TODO: call attn.get_attention_scores here instead???
+        # For example, this would handle upcasting the attention operation for us
+        attn_scores = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(attn.head_dim)  # [B, N, L, L]
+        if attention_mask is not None:
+            # No matter the length, we just slice the attention mask
+            # TODO: check shapes here, is attention_mask expected to be a causal (upper-triangular) mask of shape
+            # [B, 1, L, L]????
+            causal_mask = attention_mask[:, :, :, : key.shape[-2]]
+            attn_scores = attn_scores + causal_mask
+
+        # TODO: could use something like torch.autocast from torch AMP here
+        if attn.upcast_softmax:
+            original_dtype = attn_scores.dtype
+            attn_scores = attn_scores.to(dtype=torch.float32)
+        attn_scores = F.softmax(attn_scores, dim=-1)
+        if attn.upcast_softmax:
+            attn_scores = attn_scores.to(dtype=original_dtype)
+        attn_scores = F.dropout(attn_scores, p=attn.dropout, training=attn.training)
+        hidden_states = torch.matmul(attn_scores, value)  # [B, N, L, H]
+
+        # TODO: call attn.batch_to_head_dim here instead????
+        hidden_states = hidden_states.transpose(1, 2).contiguous()  # [B, L, N, H]
+        hidden_states = hidden_states.reshape(batch_size, query_len, attn.inner_dim)  # [B, L, D]
+
+        hidden_states = attn.to_out(hidden_states)
+
+        return hidden_states
+
+
+class DreamSdpaAttnProcessor:
+    _attention_backend = None
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
+
+    def __call__(
+        self,
+        attn: "DreamAttention",
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
+    ) -> torch.Tensor:
+        # TODO: can caching be implemented in diffusers like it is in the original code?
+        # hidden_states shape: (batch_size, seq_len, hidden_dim) = [B, L, D]
+        batch_size, query_len, _ = hidden_states.size()
+        query = attn.to_q(hidden_states)  # [B, L, D]
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        key = attn.to_k(encoder_hidden_states)  # [B, L, D] --> [B, L, D_KV]
+        value = attn.to_v(encoder_hidden_states)  # [B, L, D] --> [B, L, D_KV]
+
+        # TODO: call attn.head_to_batch_dim here instead???
+        # original code sends [batch_size, seq_len, hidden_dim] to [batch_size, num_heads, seq_len, head_dim]
+        # batch_to_head_dim instead sends it to [batch_size // num_heads, seq_len, dim * heads]
+        query = query.view(batch_size, query_len, attn.heads, attn.head_dim).transpose(1, 2)  # [B, N, L, H]
+        key = key.view(batch_size, query_len, attn.kv_heads, attn.head_dim).transpose(1, 2)  # [B, N_KV, L, H]
+        value = value.view(batch_size, query_len, attn.kv_heads, attn.head_dim).transpose(1, 2)  # [B, N_KV, L, H]
+
+        if rotary_emb is not None:
+            # TODO: rewrite in terms of embeddings.apply_rotary_emb???
+            query, key = apply_rotary_pos_emb(query, key, rotary_emb[0], rotary_emb[1])
+
+        # Repeat KV heads if attn.kv_heads < attn.heads
+        key = repeat_kv(key, attn.kv_groups)  # [B, N_KV, L, H] --> [B, N, L, H]
+        value = repeat_kv(value, attn.kv_groups)  # [B, N_KV, L, H] --> [B, N, L, H]
+
+        # TODO: call dispatch_attention_fn here to dispatch the implementation to a backend? e.g. FlashAttn
+        # hidden_states = dispatch_attention_fn(
+        #     query, key, value, attn_mask=attention_mask, backend=self._attention_backend
+        # )
+        # TODO: check SDPA call here
+        hidden_states = F.scaled_dot_product_attention(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=attn.dropout if attn.training else 0.0,
+            is_causal=False,  # hard-coded like in original code
+        )
+
+        # TODO: call attn.batch_to_head_dim here instead????
+        hidden_states = hidden_states.transpose(1, 2).contiguous()  # [B, L, N, H]
+        hidden_states = hidden_states.reshape(batch_size, query_len, attn.inner_dim)  # [B, L, D]
+
+        hidden_states = attn.to_out(hidden_states)
+
+        return hidden_states
+
+
+class DreamAttention(torch.nn.Module, AttentionModuleMixin):
+    _default_processor_cls = DreamAttnProcessor
+    _available_processors = [
+        DreamAttnProcessor,
+        DreamSdpaAttnProcessor,
+    ]
+
+    def __init__(
+        self,
+        query_dim: int,  # 3584 in Dream-7B???
+ heads: int = 28, + kv_heads: Optional[int] = 4, + dim_head: int = 128, # 3584 // 28 = 128 + dropout: float = 0.0, + bias: bool = True, + out_bias: bool = False, + eps: float = 1e-5, + out_dim: int = None, + elementwise_affine: bool = True, + upcast_softmax: bool = True, + processor=None, + ): + super().__init__() + + self.query_dim = query_dim + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.out_dim = out_dim if out_dim is not None else query_dim + self.heads = out_dim // dim_head if out_dim is not None else heads # num_heads in original code + self.kv_heads = kv_heads if kv_heads is not None else heads + self.kv_inner_dim = dim_head * self.kv_heads + self.kv_groups = self.heads // self.kv_heads # num_key_value_groups + + self.dropout = dropout + self.use_bias = bias + self.upcast_softmax = upcast_softmax + + # q_proj, k_proj, v_proj in original code + self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear(query_dim, self.kv_inner_dim, bias=bias) + self.to_v = torch.nn.Linear(query_dim, self.kv_inner_dim, bias=bias) + + # o_proj in original code + self.to_out = torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {} + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs) + + +# Based on transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Dream +class DreamRotaryEmbedding(nn.Module): + def __init__( + self, + dim: int, + theta: float = 1000000.0, # Not 10000.0 as is standard + ): + super().__init__() + self.theta = theta + + # Default RoPE initialization + inv_freq = 1.0 / (self.theta ** (torch.arange(0, dim, 2) / dim)) # [D // 2] + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # position_ids shape: [B, S] + # TODO: rewrite in terms of get_1d_rotary_pos_embed? + inv_freq_expanded = self.inv_freq[None, :, None].expand(position_ids.shape[0], -1, 1) # [B, D // 2, 1]? + position_ids_expanded = position_ids[:, None, :] # [B, 1, S]? + + freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) # [B, S, D // 2]? 
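+        # The duplicated concatenation below matches the rotate_half convention above: positions [0, D/2) and
+        # [D/2, D) of each head share the same rotation angles.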
+        emb = torch.cat((freqs, freqs), dim=-1)  # [B, S, D]
+        cos = emb.cos()
+        sin = emb.sin()
+
+        return cos, sin
+
+
+# Based on transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Dream
+class DreamRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        DreamRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+# Based on transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Dream
+class DreamMLP(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        dim_out: Optional[int] = None,
+        mult: Optional[int] = 4,  # mult is not an integer for Dream-7B - it's 18944 / 3584 = 37 / 7
+        dropout: float = 0.0,  # dropout is actually not used in the Dream MLP
+        activation_fn: str = "silu",
+        inner_dim: int = 18944,
+        bias: bool = False,
+    ):
+        super().__init__()
+        self.hidden_size = dim
+        if inner_dim is None:
+            inner_dim = int(dim * mult)
+        self.intermediate_size = inner_dim
+        self.dim_out = dim_out if dim_out is not None else dim
+
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
+
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
+        self.act_fn = get_activation(activation_fn)
+
+        self.down_proj = nn.Linear(self.intermediate_size, self.dim_out, bias=bias)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # Gated MLP: down_proj(act_fn(gate_proj(x)) * up_proj(x)); both projections take the block input
+        up_hidden_states = self.up_proj(hidden_states)
+
+        gated_hidden_states = self.gate_proj(hidden_states)
+        gated_hidden_states = self.act_fn(gated_hidden_states)
+
+        hidden_states = gated_hidden_states * up_hidden_states
+        hidden_states = self.down_proj(hidden_states)
+        return hidden_states
+
+
+@maybe_allow_in_graph
+class DreamTransformerBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        num_attention_kv_heads: Optional[int],
+        attention_head_dim: int,
+        ff_intermediate_dim: int = 18944,
+        eps: float = 1e-6,
+    ):
+        super().__init__()
+
+        # Input LayerNorm
+        self.norm1 = DreamRMSNorm(dim, eps=eps)
+
+        self.attn = DreamAttention(
+            query_dim=dim,
+            heads=num_attention_heads,
+            kv_heads=num_attention_kv_heads,
+            dim_head=attention_head_dim,
+            processor=DreamSdpaAttnProcessor(),
+        )
+
+        # Post-attention LayerNorm
+        self.norm2 = DreamRMSNorm(dim, eps=eps)
+        self.ff = DreamMLP(dim=dim, dim_out=dim, inner_dim=ff_intermediate_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,  # temb is not used in Dream (time-invariant model)
+        attention_mask: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> torch.Tensor:
+        # hidden_states shape: [batch_size, seq_len, hidden_dim] = [B, L, D]
+        residual = hidden_states
+
+        # Input LayerNorm
+        hidden_states = self.norm1(hidden_states)
+
+        # Attention + shortcut connection
+        hidden_states = self.attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            rotary_emb=rotary_emb,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully-connected + shortcut connection
+        residual = hidden_states
+        hidden_states = self.norm2(hidden_states)
+
hidden_states = self.ff(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class DreamTransformer1DModel( + ModelMixin, + ConfigMixin, # TODO: add other mixins as necessary +): + """ + The diffusion transformer model used in the Dream-7B diffusion LLM. + + See https://hkunlp.github.io/blog/2025/dream/. The original transformers-style implementation is at + https://huggingface.co/Dream-org/Dream-v0-Base-7B/blob/main/modeling_dream.py. + + Args: + TODO + """ + + _supports_gradient_checkpointing = False + _no_split_modules = ["DreamTransformerBlock"] + _skip_layerwise_casting_patterns = ["embedding", "norm"] + _repeated_blocks = ["DreamTransformerBlock"] + + @register_to_config + def __init__( + self, + num_layers: int = 28, + attention_head_dim: int = 128, + num_attention_heads: int = 28, + num_attention_kv_heads: Optional[int] = 4, + ff_intermediate_dim: int = 18944, + rms_norm_eps: float = 1e-6, + rope_theta: float = 1000000.0, + vocab_size: int = 152064, + pad_token_id: int = 151643, + ): + super().__init__() + self.inner_dim = num_attention_heads * attention_head_dim # hidden_size = 3584 in original code + self.pad_token_id = pad_token_id + + # TODO: can we replace this with a diffusers embedding module? + self.token_embedding = nn.Embedding(vocab_size, self.inner_dim, self.pad_token_id) + self.rotary_embedding = DreamRotaryEmbedding(dim=attention_head_dim, theta=rope_theta) + + self.transformer_blocks = nn.ModuleList( + [ + DreamTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + num_attention_kv_heads=num_attention_kv_heads, + attention_head_dim=attention_head_dim, + ff_intermediate_dim=ff_intermediate_dim, + eps=rms_norm_eps, + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = DreamRMSNorm(self.inner_dim, eps=rms_norm_eps) + self.lm_head = nn.Linear(self.inner_dim, vocab_size, bias=False) + + def embed_tokens(self, text_ids: torch.Tensor) -> torch.Tensor: + return self.token_embedding(text_ids) + + def forward( + self, + text_ids: torch.Tensor = None, + hidden_states: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + timestep: Optional[torch.LongTensor] = None, # not used by Dream (time-invariant model) + attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + The [`DreamTransformer1DModel`] forward method. + + Args: + text_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`): + The indices of the input text tokens. + hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + The already embedded hidden states for the transformer. This is analogous to `inputs_embeds` for a + transformers model. + position_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + The indices of the positions of each token within the input. Will be created if not supplied. + timestep (`torch.LongTensor`): + Used to indicate denoising step. Not used currently as Dream is a time-invariant model. + attention_mask (`torch.Tensor`, *optional*): + An optional attention mask. This is mainly useful for training, as Dream is trained with an attention + mask annealing strategy. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. 
+ + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # text_ids shape: [B, L] + if hidden_states is None: + # Embed text tokens + hidden_states = self.token_embedding(text_ids) # [B, L] --> [B, L, D] + + # Create position_ids if not supplied + if position_ids is None: + position_ids = torch.arange(hidden_states.shape[1], device=hidden_states.device) + position_ids = position_ids.unsqueeze(0) # [L] --> [1, L] + # Get RoPE embeddings (shared across all layers) + rotary_emb = self.rotary_embedding(position_ids) + + # Transformer decoder layers + for block in self.transformer_blocks: + hidden_states = block(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb) + + hidden_states = self.norm_out(hidden_states) + logits = self.lm_head(hidden_states) + + if not return_dict: + return (logits,) + + # TODO: arguably the input is not 2D here since it is of shape (batch_size, seq_len, vocab_size) + return Transformer2DModelOutput(sample=logits) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 535b23dbb4ee..fe5ca1680430 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -226,6 +226,7 @@ "IFPipeline", "IFSuperResolutionPipeline", ] + _import_structure["dream"] = ["DreamTextPipeline"] _import_structure["easyanimate"] = [ "EasyAnimatePipeline", "EasyAnimateInpaintPipeline", @@ -608,6 +609,7 @@ VersatileDiffusionTextToImagePipeline, VQDiffusionPipeline, ) + from .dream import DreamTextPipeline from .easyanimate import ( EasyAnimateControlPipeline, EasyAnimateInpaintPipeline, diff --git a/src/diffusers/pipelines/dream/__init__.py b/src/diffusers/pipelines/dream/__init__.py new file mode 100644 index 000000000000..aa4a8f220d9b --- /dev/null +++ b/src/diffusers/pipelines/dream/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {"pipeline_output": ["DreamTextPipelineOutput"]} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_dream"] = ["DreamTextPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_dream import DreamTextPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/dream/pipeline_dream.py b/src/diffusers/pipelines/dream/pipeline_dream.py new file mode 100644 index 000000000000..6368b5c4478c --- /dev/null +++ b/src/diffusers/pipelines/dream/pipeline_dream.py @@ -0,0 +1,376 @@ +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import 
torch.nn.functional as F + +from ...models import DreamTransformer1DModel +from ...schedulers import DreamMaskedDiffusionScheduler +from ...utils import is_torch_xla_available, logging +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import DreamTextPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class DreamTextPipeline(DiffusionPipeline): + r""" + Dream 7B diffusion based LLM. + + Introduced in https://hkunlp.github.io/blog/2025/dream/. + """ + + model_cpu_offload_seq = "transformer" + _optional_components = [] # TODO: list any optional components here + _callback_tensor_inputs = [] # TODO: what needs to be here? + + def __init__( + self, + tokenizer, + transformer: DreamTransformer1DModel, + scheduler: DreamMaskedDiffusionScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + # 131072 in original code + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 512 + ) + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + def tokenize_prompt( + self, + prompt: Union[str, List[str]], + num_texts_per_prompt: int = 1, + max_sequence_length: int = 512, + apply_chat_template: bool = False, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if apply_chat_template: + prompt_is_chat_template = isinstance(prompt[0], dict) or (isinstance(prompt[0], list) and isinstance(prompt[0][0], dict)) + if not prompt_is_chat_template: + # Apply simple chat template for each supplied prompt + prompt = [{"role": "user", "content": prompt_instance} for prompt_instance in prompt] + # Call the PreTrainedTokenier's apply_chat_template method for chat generation + text_inputs = self.tokenizer.apply_chat_template( + prompt, + return_tensors="pt", + return_dict=False, # List[int] output rather than Dict output + add_generation_prompt=True, + ) + else: + # Call the tokenizer's normal __call__ method + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device=device) + attention_mask = text_inputs.attention_mask.to(device=device) + + # duplicate text tokens and attention mask for each generation per prompt, using mps friendly method + # TODO: this follows e.g. the Flux pipeline's encode_prompts, why do we repeat in the sequence length dim + # rather than the batch length dim...? 
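+        # NOTE: repeating along dim 1 and then reshaping to (batch_size * num_texts_per_prompt, -1) keeps all copies
+        # of a given prompt adjacent in the batch, i.e. it behaves like repeat_interleave on the batch dim while
+        # staying mps-friendly.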
+ text_input_ids = text_input_ids.repeat(1, num_texts_per_prompt) + text_input_ids = text_input_ids.view(batch_size * num_texts_per_prompt, -1) + + attention_mask = attention_mask.repeat(1, num_texts_per_prompt) + attention_mask = attention_mask.view(batch_size * num_texts_per_prompt, -1) + + return text_input_ids, attention_mask + + def prepare_latents( + self, + batch_size: int, + max_sequence_length: int, + text_ids: Optional[torch.Tensor] = None, + latents: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + ): + latents_shape = (batch_size, max_sequence_length) + if latents is None and text_ids is None: + # Create all-masks latents of length max_sequence_length + latents = torch.full(latents_shape, self.scheduler.config.mask_token_id, dtype=torch.long, device=device) + elif latents is None and text_ids is not None: + # Pad text_ids to max_sequence_length with mask tokens + # NOTE: text_ids is assumed to have the correct batch dimension already + latents = F.pad( + text_ids, (0, max_sequence_length - text_ids.shape[1]), value=self.scheduler.config.mask_token_id + ) + else: + if latents.ndim == 1: + # Unsqueeze a batch dim + latents = latents.unsqueeze(0) + # bring latents to the correct batch size + current_batch_size = latents.shape[0] + if batch_size % current_batch_size == 0: + repeat_factor = batch_size // current_batch_size + latents = latents.repeat(repeat_factor, 1) + else: + raise ValueError( + f"The `latents` batch size {current_batch_size} must evenly divide the total batch size" + f" {batch_size}." + ) + + # If latents is not max_sequence_length, pad to max_sequence_length with mask tokens + latents = F.pad( + latents, (0, max_sequence_length - latents.shape[1]), value=self.scheduler.config.mask_token_id + ) + + latents = latents.to(device) + + return latents + + def check_inputs(self, prompt, prompt_embeds, latents): + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is not None and latents is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `latents`: {latents}. Please make sure to" + " only forward one of the two." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @torch.no_grad() + def __call__( + self, + prompt: Optional[Union[List[str], str]] = None, + num_inference_steps: int = 512, + num_texts_per_prompt: Optional[int] = 1, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.IntTensor] = None, # TODO: does supporting both latents and prompt_embeds make sense? + prompt_embeds: Optional[torch.Tensor] = None, + temperature: Union[float, Tuple[float, float], List[float]] = 0.2, + top_p: Union[float, Tuple[float, float], List[float]] = 0.95, + max_sequence_length: int = 512, + apply_chat_template: bool = False, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + output_type: str = "pil", # TODO: replace with options appropriate for text + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ) -> Union[DreamTextPipelineOutput, Tuple[Any]]: + """ + The call function to the pipeline for generation. 
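+
+        Example (illustrative only -- assumes a diffusers-format Dream checkpoint, with tokenizer, transformer, and
+        scheduler subfolders, is available at the hypothetical path below):
+
+            import torch
+            from diffusers import DreamTextPipeline
+
+            pipe = DreamTextPipeline.from_pretrained("path/to/dream-7b-diffusers", torch_dtype=torch.bfloat16)
+            pipe.to("cuda")
+            output = pipe("Briefly explain masked diffusion language models.", num_inference_steps=256)
+            print(output.texts[0])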
+ + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide text generation. A chat template for + `transformers.PreTrainedTokenizer.apply_chat_template` can be used if `apply_chat_template` is set to + `True`. + num_inference_steps (`int`, *optional*, defaults to 512): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + num_texts_per_prompt (`int`, *optional*, defaults to 1): + The number of text outputs to generate per prompt. If neither `prompts` nor `prompt_embeds` is + supplied, this will be interpreted as the batch size for generation. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.IntTensor`, *optional*): + Pre-generated text tokens from which to start generation. If supplied, this should include any + conditioning text tokens (analogous to a tokenized version of `prompt`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. A single vector from the + pooled and projected final hidden states. + temperature (`Union[float, Tuple[float, float], List[float]]`, *optional*, defaults to 0.2): + Configures the temperature scheduler on `self.scheduler`; see `DreamMaskedDiffusionScheduler#set_timesteps`. + top_p (`Union[float, Tuple[float, float], List[float]]`, *optional*, defaults to 0.95): + Configures the top-p probability scheduler on `self.scheduler`; see `DreamMaskedDiffusionScheduler#set_timesteps`. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Returns: + [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a + `tuple` is returned where the first element is a list with the generated images. + """ + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, prompt_embeds, latents) + + # 2. 
Define call parameters
+        # NOTE: it is possible for both prompt and prompt_embeds to be None (which corresponds to "unconditional"
+        # text generation)
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        elif prompt_embeds is not None:
+            batch_size = prompt_embeds.shape[0]
+        else:
+            batch_size = 1
+
+        device = self._execution_device
+
+        self._current_timestep = None
+        self._interrupt = False
+
+        # 3. Tokenize input text prompts, if any
+        if prompt is not None:
+            prompt = [prompt] if isinstance(prompt, str) else prompt
+            text_ids, attention_mask = self.tokenize_prompt(
+                prompt=prompt,
+                num_texts_per_prompt=num_texts_per_prompt,
+                apply_chat_template=apply_chat_template,
+                device=device,
+            )
+        else:
+            text_ids, attention_mask = None, None
+
+        # 4. Prepare latent variables (e.g. the initial sample) for generation
+        total_batch_size = batch_size * num_texts_per_prompt
+        latents = self.prepare_latents(
+            total_batch_size,
+            max_sequence_length,
+            text_ids=text_ids,
+            latents=latents,
+            device=device,
+        )
+
+        if prompt_embeds is None:
+            prompt_embeds = self.transformer.embed_tokens(latents)
+        else:
+            # If prompt_embeds's seq len is not max_sequence_length, concat with embedding of mask tokens for the
+            # remaining length
+            padding_length = max_sequence_length - prompt_embeds.shape[1]
+            if padding_length > 0:
+                padding_mask_tokens = torch.full(
+                    (total_batch_size, padding_length), self.scheduler.config.mask_token_id, device=device
+                )
+                padding_mask_embedding = self.transformer.embed_tokens(padding_mask_tokens)
+                prompt_embeds = torch.cat([prompt_embeds, padding_mask_embedding], dim=1)
+            else:
+                # Truncate to max_sequence_length, if necessary
+                prompt_embeds = prompt_embeds[:, :max_sequence_length, :]
+
+        # 5. Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, temperature=temperature, top_p=top_p, device=device)
+        timesteps = self.scheduler.timesteps
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+        self._num_timesteps = len(timesteps)
+
+        # 6. Denoising loop
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
+                self._current_timestep = t
+
+                if i > 0:
+                    model_output = self.transformer(
+                        text_ids=latents,
+                        attention_mask=attention_mask,
+                        return_dict=False,
+                    )[0]
+                else:
+                    # Use prompt_embeds only at the first step, to support supplying an initial prompt embedding
+                    model_output = self.transformer(
+                        hidden_states=prompt_embeds,
+                        attention_mask=attention_mask,
+                        return_dict=False,
+                    )[0]
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents_dtype = latents.dtype
+                latents = self.scheduler.step(
+                    model_output=model_output,
+                    timestep=t,
+                    sample=latents,
+                    generator=generator,
+                ).prev_sample
+
+                if latents.dtype != latents_dtype:
+                    if torch.backends.mps.is_available():
+                        # some platforms (eg.
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # 7. Post-processing and output handling + self._current_timestep = None + + if output_type == "latent": + texts = latents + else: + # TODO: should there be a text_processor class analogous to e.g. VaeImageProcessor??? + texts = self.tokenizer.batch_decode(latents) + # TODO: if prompt or other conditioning is supplied, remove prompts from generated texts??? + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (texts,) + + return DreamTextPipelineOutput(texts=texts) diff --git a/src/diffusers/pipelines/dream/pipeline_output.py b/src/diffusers/pipelines/dream/pipeline_output.py new file mode 100644 index 000000000000..8d2f39c0ce2c --- /dev/null +++ b/src/diffusers/pipelines/dream/pipeline_output.py @@ -0,0 +1,17 @@ +from dataclasses import dataclass +from typing import List + +from ...utils import BaseOutput + + +@dataclass +class DreamTextPipelineOutput(BaseOutput): + """ + Output class for the Dream-7B diffusion LLM. + + Args: + text + """ + + # For example, should we also accept token ids? Or only output token ids? + texts: List[str] diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 29052c1ba0cb..c843806bd61a 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -54,6 +54,7 @@ _import_structure["scheduling_dpmsolver_multistep"] = ["DPMSolverMultistepScheduler"] _import_structure["scheduling_dpmsolver_multistep_inverse"] = ["DPMSolverMultistepInverseScheduler"] _import_structure["scheduling_dpmsolver_singlestep"] = ["DPMSolverSinglestepScheduler"] + _import_structure["scheduling_dream"] = ["DreamMaskedDiffusionScheduler"] _import_structure["scheduling_edm_dpmsolver_multistep"] = ["EDMDPMSolverMultistepScheduler"] _import_structure["scheduling_edm_euler"] = ["EDMEulerScheduler"] _import_structure["scheduling_euler_ancestral_discrete"] = ["EulerAncestralDiscreteScheduler"] @@ -156,6 +157,7 @@ from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler + from .scheduling_dream import DreamMaskedDiffusionScheduler from .scheduling_edm_dpmsolver_multistep import EDMDPMSolverMultistepScheduler from .scheduling_edm_euler import EDMEulerScheduler from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler diff --git a/src/diffusers/schedulers/scheduling_dream.py b/src/diffusers/schedulers/scheduling_dream.py new file mode 100644 index 000000000000..e49c27a043f6 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_dream.py @@ -0,0 +1,442 @@ +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F + +from ..configuration_utils import ConfigMixin, 
register_to_config
+from ..utils import BaseOutput, logging
+from .scheduling_utils import SchedulerMixin
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+def create_schedule(
+    schedule_params: Optional[Union[float, Tuple[float, float], List[float], torch.Tensor]],
+    num_inference_steps: int,
+    device: Optional[Union[str, torch.device]] = None,
+) -> torch.Tensor:
+    if schedule_params is None:
+        schedule = None
+    elif isinstance(schedule_params, float):
+        # Interpret as a constant schedule for all timesteps
+        schedule = torch.full((num_inference_steps,), schedule_params)
+    elif isinstance(schedule_params, (tuple, list)):
+        # Interpret first and second elems as start and end points of a linear schedule
+        schedule = torch.linspace(schedule_params[0], schedule_params[1], num_inference_steps)
+    elif isinstance(schedule_params, torch.Tensor):
+        # Interpret this as the fully specified schedule
+        if schedule_params.ndim != 1:
+            raise ValueError(f"Expected torch tensor schedule to have 1 dim but has {schedule_params.ndim} dims")
+        if schedule_params.shape[0] != num_inference_steps:
+            raise ValueError(
+                f"Received torch tensor schedule but its length ({schedule_params.shape[0]}) does not match "
+                f"num_inference_steps ({num_inference_steps})"
+            )
+        schedule = schedule_params
+    else:
+        raise ValueError(
+            f"`schedule_params` is of unrecognized type {type(schedule_params)}; should be either a float, tuple, "
+            f"list, or `torch.Tensor`."
+        )
+
+    if schedule is not None:
+        schedule = schedule.to(device=device)
+    return schedule
+
+
+def top_p_logits(logits: torch.Tensor, top_p: Optional[float] = None) -> torch.Tensor:
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+    sorted_indices_to_remove = cumulative_probs > top_p
+    # Shift the indices to the right to keep the first token above the threshold
+    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+    sorted_indices_to_remove[..., 0] = 0
+
+    mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
+    mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
+    logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
+    return logits
+
+
+def top_k_logits(logits: torch.Tensor, top_k: Optional[int] = None) -> torch.Tensor:
+    top_k = min(top_k, logits.size(-1))  # Safety check
+    # Remove all tokens with a probability less than the last token of the top-k
+    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+    logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
+    return logits
+
+
+def sample_tokens(
+    logits: torch.Tensor,
+    temperature: float = 0.0,
+    top_p: Optional[float] = None,
+    top_k: Optional[int] = None,
+    margin_confidence: bool = False,
+    neg_entropy: bool = False,
+    generator: Optional[torch.Generator] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Samples from a sequence of logits of shape [..., vocab_size] and returns both the sampled sequence (as the second
+    return elem) and the model probabilities for the chosen tokens (as the first return elem).
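+
+    For example (illustrative): with `logits` of shape `[B, L, V]`,
+    `confidence, x0 = sample_tokens(logits, temperature=0.2, top_p=0.95)` returns `x0` of shape `[B, L]` with the
+    sampled token ids and `confidence` of shape `[B, L]` with the probability assigned to each sampled token; with
+    `neg_entropy=True`, `confidence` instead holds the negative entropy of each position's distribution.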
+ """ + # logits shape: [B, L, V] + if temperature > 0: + logits = logits / temperature + if top_p is not None and top_p < 1: + logits = top_p_logits(logits, top_p) + if top_k is not None: + logits = top_k_logits(logits, top_k) + + probs = torch.softmax(logits, dim=-1) + device = probs.device + probs_ = probs.to(generator.device) if generator is not None else probs # handles when generator is on CPU + if probs_.device.type == "cpu" and probs_.dtype != torch.float32: + probs_ = probs_.float() # multinomial is not implemented for cpu half precision + if probs.ndim > 2: + probs_ = probs_.reshape(-1, probs.size(-1)) # [B, L, V] --> [B * L, V] + + if temperature > 0: + try: + # Sample x0 ~ Cat(probs) + x0 = torch.multinomial(probs_, 1, generator=generator).to(device=device) + if probs.ndim > 2: + x0 = x0[:, 0].view(*probs.shape[:-1]) # [B * L, 1] --> [B, L] + confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1) # [B, L] + except: + confidence, x0 = probs.max(dim=-1) + else: + confidence, x0 = probs.max(dim=-1) + + if margin_confidence: + sorted_probs, _ = torch.sort(probs, dim=-1, descending=True) + # Extract top1 and top2 probabilities + top1_probs = sorted_probs[:, 0] + top2_probs = sorted_probs[:, 1] + # Calculate confidence as top1 - top2 + confidence = top1_probs - top2_probs + + if neg_entropy: + epsilon = 1e-10 + log_probs = torch.log(probs + epsilon) + confidence = torch.sum(probs * log_probs, dim=-1) + + return confidence, x0 + + +@dataclass +class DreamMaskedDiffusionSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, seq_len)` for text): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.Tensor` of shape `(batch_size, seg_len)` for text): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None + + +class DreamMaskedDiffusionScheduler(SchedulerMixin, ConfigMixin): + """ + Scheduler for the Dream 7B masked diffusion model. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + masking_schedule (`str`, defaults to `"linear"`): + The noise schedule for discrete diffusion, often represented as alpha_t. This determines the probability + that tokens are masked in the forward process. Available choices are `"linear"`, `"cosine"`, and + `"polynomial"`. + timestep_discretization (`str`, defaults to `"linear"`): + The function which specifies how we discretize (continuous) time [0, 1]. Available strategies are + `"linear"` (evenly spaced timesteps) and `"cosine"`. + logit_sampling_alg (`str`, defaults to `"entropy"`): + The algorithm used to sample from the predicted logits. This incorporates sampling techniques such as + temperature, top-p, top-k, etc. Available algorithms are `"origin"`, `"maskgit_plus"`, `"topk_margin"`, + and `"entropy"` (names match those of original code). + shift (`bool`, defaults to `True`): + Whether to shift the logits before sampling. Dream models shift the logits such that the (n - 1)-th token + predicts the n-th token, mirroring the behavior of AR models. 
+ polynomial_exp (`int`, defaults to `1`): + When `masking_schedule` is set to `"polynomial"`, this specifies the exponent of the polynomial. The + default value of `1` is equivalent to a `"linear"` masking schedule. + final_timestep (`float`, defaults to `1e-3`): + The value of the final timestep in the schedule, which should be a small positive number close to 0 for + numerical stability reasons. + temperature: (`float` or `tuple` or `list`, defaults to `0.2`): + The temperature used when taking the softmax of the predicted logits. If this is a float, we will use that + value at each timestep; if a tuple or list, we will interpret the first and second elements as the start + and end points of a linear schedule. If `None`, a temperature of 1.0 will be used. + top_p: (`float` or `tuple` or `list`, *optional*, defaults to `0.96`): + The probability for top-p sampling. If this is a float, we will use that value at each timestep; if a + tuple or list, we will interpret the first and second elements as the start and end points of a linear + schedule. If `None`, top-p sampling will not be performed. + top_k: (`int`, *optional*, defaults to `None`): + The k value for top-p sampling. If not set, top-k sampling will not be performed. + alg_temperature: (`float`, *optional*, defaults to `0.0`): + Used for certain logit sampling strategies, such as `"maskgit_plus"`, `"topk_margin"`, and `"entropy"`. If + > 0, we will use this as a temperature when taking the softmax over the model confidences to decide which + tokens to unmask. Otherwise, we will deterministically select the tokens with the highest confidences. + mask_token_id (`int`, defaults to `151666`): + The token id of the mask token in the tokenizer. The default value corresponds to the mask token id in the + official Dream 7B tokenizer. + start_token_id (`int`, defaults to `151643`): + The token id of the start/BOS token in the tokenizer. The default value corresponds to the BOS token id + in the official Dream 7B tokenizer. + """ + + order = 1 + + @register_to_config + def __init__( + self, + masking_schedule: str = "linear", + timestep_discretization: str = "linear", + logit_sampling_alg: str = "entropy", + shift: bool = True, + polynomial_exp: int = 1, + final_timestep: float = 1e-3, # small positive value for final timestep (eps in original code) + temperature: Optional[Union[float, Tuple[float], List[float]]] = 0.2, + top_p: Optional[Union[float, Tuple[float], List[float]]] = 0.95, + top_k: Optional[int] = None, + alg_temperature: Optional[float] = 0.0, + mask_token_id: int = 151666, + start_token_id: int = 151643, + ): + # Setable values + self.num_inference_steps = None + self.temperatures = None + self.top_p_schedule = None + self.top_k_schedule = None + + self.timesteps = None + self.alphas = None + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. 
for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def set_timesteps( + self, + num_inference_steps: int, + temperature: Optional[Union[float, Tuple[float, float], List[int], torch.Tensor]] = None, + top_p: Optional[Union[float, Tuple[float, float], List[int], torch.Tensor]] = None, + device: Optional[Union[str, torch.device]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + temperature (`float` or `tuple` or `list` or `torch.Tensor`, *optional*, defaults to `None`): + A custom temperature schedule to override the configured temperature schedule. If this is a float, we + will use that float at each timestep; if a tuple or list, we will interpret the first and second + elements as the start and end points of a linear schedule; if a `torch.Tensor`, we will interpret this + as the full temperature schedule (must have length `num_inference_steps`). + top-p (`float` or `tuple` or `list` or `torch.Tensor`, *optional*, defaults to `None`): + A custom top-p schedule to override the configured top-p schedule. If this is a float, we will use + that value at each timestep; if a tuple or list, we will interpret the first and second elements as + the start and end points of a linear schedule; if a `torch.Tensor`, we will interpret this as the full + top-p schedule (must have length `num_inference_steps`). + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + + if self.config.timestep_discretization == "linear": + timesteps = torch.linspace(1.0, self.config.final_timestep, num_inference_steps + 1, device=device) + elif self.config.timestep_discretization == "cosine": + timesteps = torch.linspace(self.config.final_timestep, 1.0, num_inference_steps + 1) + timesteps = torch.cos((torch.pi / 2) * (1.0 - timesteps)).to(device) + else: + raise ValueError( + f"{self.config.timestep_discretization} is not a supported timestep discretization strategy. Current " + f"supported strategies are `linear` and `cosine`." + ) + self.timesteps = timesteps + + # Now calculate the masking or noise schedule (alpha) values at the chosen timestep discretization + if self.config.masking_schedule == "linear": + alphas = 1.0 - self.timesteps + elif self.config.masking_schedule == "cosine": + alphas = 1.0 - torch.cos((torch.pi / 2) * (1.0 - self.timesteps)) + elif self.config.masking_schedule == "polynomial": + alphas = 1.0 - torch.pow(self.timesteps, self.config.polynomial_exp) + else: + raise ValueError( + f"{self.config.masking_schedule} is not a supported masking schedule. Currently supported schedules " + f"are `linear`, `cosine`, and `polynomial`." + ) + self.alphas = alphas.to(device=device) + + # Allow overriding of specific sampling parameters (temperature, top_p, etc.) 
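+        # For example (illustrative): `set_timesteps(4, temperature=(0.2, 0.6))` yields a linear temperature
+        # schedule of [0.2, 0.3333, 0.4667, 0.6], whereas a plain float such as `0.2` is used at every step.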
+        if temperature is None:
+            temperature = self.config.temperature
+        self.temperatures = create_schedule(temperature, num_inference_steps)
+
+        if top_p is None:
+            top_p = self.config.top_p
+        self.top_p_schedule = create_schedule(top_p, num_inference_steps)
+
+        self.num_inference_steps = num_inference_steps
+        # Drop the final (terminal) time from `self.timesteps` so that it has exactly `num_inference_steps` entries
+        # to iterate over; `self.alphas` keeps all `num_inference_steps + 1` values so `step` can look one entry
+        # ahead when computing the unmasking probability.
+        self.timesteps = self.timesteps[:-1]
+
+    def step(
+        self,
+        model_output: torch.Tensor,
+        timestep: Union[float, torch.Tensor],
+        sample: torch.Tensor,
+        generator: Optional[torch.Generator] = None,
+        return_dict: bool = True,
+    ) -> Union[DreamMaskedDiffusionSchedulerOutput, Tuple]:
+        """
+        Predict the sample from the previous timestep by reverse process. This function propagates the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output (`torch.Tensor`):
+                The direct output from the learned diffusion model.
+            timestep (`float`):
+                The current discrete timestep in the diffusion chain.
+            sample (`torch.Tensor`):
+                A current instance of a sample created by the diffusion process.
+            generator (`torch.Generator`, *optional*):
+                A random number generator.
+            return_dict (`bool`):
+                Whether or not to return a [`~schedulers.scheduling_dream.DreamMaskedDiffusionSchedulerOutput`] or
+                tuple.
+
+        Returns:
+            [`~schedulers.scheduling_dream.DreamMaskedDiffusionSchedulerOutput`] or `tuple`:
+                If return_dict is `True`, [`~schedulers.scheduling_dream.DreamMaskedDiffusionSchedulerOutput`] is
+                returned, otherwise a tuple is returned where the first element is the sample tensor.
+        """
+
+        # model_output shape: [B, L, V]
+        # sample shape: [B, L] (sequence of discrete tokens)
+        step_idx = self.index_for_timestep(timestep)
+        t = self.timesteps[step_idx]  # Current timestep
+        temperature = self.temperatures[step_idx] if self.temperatures is not None else 1.0
+        top_p = self.top_p_schedule[step_idx] if self.top_p_schedule is not None else None
+        top_k = self.top_k_schedule[step_idx] if self.top_k_schedule is not None else None
+
+        mask_map = sample == self.config.mask_token_id
+
+        if self.config.shift:
+            # Right-shift the logits from the model.
+            # Dream models are trained to predict at right-shifted positions, analogous to an autoregressive model,
+            # so we also need to shift the outputs at inference time.
+            model_output = torch.cat([model_output[:, :1], model_output[:, :-1]], dim=1)
+
+        # Probability of unmasking each currently-masked token when moving from the current timestep to the next
+        unmask_prob = (self.alphas[step_idx + 1] - self.alphas[step_idx]) / (1 - self.alphas[step_idx])
+        # Unmask all remaining masked tokens at the last inference step
+        unmask_prob = unmask_prob if step_idx < self.num_inference_steps - 1 else 1.0
+
+        # TODO: mask logits (model_output) beforehand? might make it more efficient?
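+        # Example (illustrative): with 10 masked positions remaining and unmask_prob = 0.25, the confidence-based
+        # algorithms below reveal the int(10 * 0.25) = 2 highest-confidence masked positions this step, while
+        # "origin" instead unmasks each masked position independently with probability 0.25.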
+        if self.config.logit_sampling_alg == "origin":
+            to_unmask_mask = torch.rand(*sample.shape, generator=generator, device=sample.device) < unmask_prob
+            confidence, pred_original_sample = sample_tokens(
+                model_output, temperature=temperature, top_p=top_p, top_k=top_k, generator=generator
+            )
+            # Only positions that are currently masked may be revealed
+            prev_sample = torch.where(to_unmask_mask & mask_map, pred_original_sample, sample)
+        else:
+            if self.config.logit_sampling_alg == "maskgit_plus":
+                confidence, pred_original_sample = sample_tokens(
+                    model_output, temperature=temperature, top_p=top_p, top_k=top_k, generator=generator
+                )
+            elif self.config.logit_sampling_alg == "topk_margin":
+                confidence, pred_original_sample = sample_tokens(
+                    model_output,
+                    temperature=temperature,
+                    top_p=top_p,
+                    top_k=top_k,
+                    margin_confidence=True,
+                    generator=generator,
+                )
+            elif self.config.logit_sampling_alg == "entropy":
+                confidence, pred_original_sample = sample_tokens(
+                    model_output,
+                    temperature=temperature,
+                    top_p=top_p,
+                    top_k=top_k,
+                    neg_entropy=True,
+                    generator=generator,
+                )
+            else:
+                raise ValueError(f"{self.config.logit_sampling_alg} is not a supported logit sampling algorithm.")
+
+            # Unmask a fixed number of tokens at each timestep depending on unmask_prob
+            num_masked_tokens = mask_map.sum() / mask_map.shape[0]
+            num_tokens_to_unmask = int(num_masked_tokens * unmask_prob)
+            full_confidence = torch.full_like(sample, -torch.inf, dtype=model_output.dtype, device=sample.device)
+            full_confidence = torch.where(mask_map, confidence, full_confidence)
+
+            if num_tokens_to_unmask > 0:
+                if self.config.alg_temperature is None or self.config.alg_temperature == 0:
+                    _, unmask_index = torch.topk(full_confidence, num_tokens_to_unmask)
+                else:
+                    full_confidence = full_confidence / self.config.alg_temperature
+                    full_confidence = F.softmax(full_confidence, dim=-1)
+                    unmask_index = torch.multinomial(
+                        full_confidence, num_samples=num_tokens_to_unmask, generator=generator
+                    )
+
+                # `unmask_index` holds column indices of shape [B, num_tokens_to_unmask]; turn it into a boolean
+                # mask before selecting which positions to reveal
+                unmask_map = torch.zeros_like(mask_map).scatter_(-1, unmask_index, True)
+                prev_sample = torch.where(unmask_map, pred_original_sample, sample)
+            else:
+                prev_sample = sample
+
+        # TODO: do we need to shift the tokens again at the end???
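+        # (It appears no inverse shift is needed here: the logits were already shifted above before sampling, so
+        # pred_original_sample is aligned with the token positions in `sample`.)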
+        if not return_dict:
+            return (prev_sample, pred_original_sample)
+
+        return DreamMaskedDiffusionSchedulerOutput(prev_sample, pred_original_sample)
+
+    def add_noise(
+        self,
+        original_samples: torch.Tensor,
+        timesteps: torch.Tensor,
+        generator: Optional[torch.Generator] = None,
+    ) -> torch.Tensor:
+        # For each batch instance i with timestep t_i, mask each position independently with prob 1 - alphas[t_i]
+        # original_samples shape: [B, L]
+        # Make sure alphas and timesteps are on the same device as original_samples. The alphas stay floating point,
+        # since original_samples holds integer token ids.
+        alphas = self.alphas.to(device=original_samples.device)
+        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
+            # mps does not support float64
+            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
+        else:
+            schedule_timesteps = self.timesteps.to(original_samples.device)
+            timesteps = timesteps.to(original_samples.device)
+
+        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+
+        mask_probs = 1.0 - alphas[step_indices].flatten()
+        while len(mask_probs.shape) < len(original_samples.shape):
+            mask_probs = mask_probs.unsqueeze(-1)
+
+        mask_indices = (
+            torch.rand(
+                original_samples.shape,
+                device=generator.device if generator is not None else original_samples.device,
+                generator=generator,
+            ).to(original_samples.device)
+            < mask_probs
+        )
+
+        masked_samples = original_samples.clone()
+        masked_samples[mask_indices] = self.config.mask_token_id
+
+        return masked_samples
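A minimal sketch (not part of the diff above) of how `DreamMaskedDiffusionScheduler.add_noise` could be used to corrupt token sequences for training; the batch contents and step count below are illustrative.

    import torch
    from diffusers import DreamMaskedDiffusionScheduler

    scheduler = DreamMaskedDiffusionScheduler()
    scheduler.set_timesteps(num_inference_steps=128)

    token_ids = torch.randint(0, 151_000, (2, 64))  # [B, L] clean token ids
    # Draw one timestep per batch element from the scheduler's discretized grid
    t = scheduler.timesteps[torch.randint(0, len(scheduler.timesteps), (2,))]
    masked = scheduler.add_noise(token_ids, timesteps=t)
    # Positions where `masked == scheduler.config.mask_token_id` were corrupted and would be the prediction targets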