diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 702e5b586d59..b423c17c1246 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -84,6 +84,78 @@ def get_3d_sincos_pos_embed( temporal_size: int, spatial_interpolation_scale: float = 1.0, temporal_interpolation_scale: float = 1.0, + device: Optional[torch.device] = None, + output_type: str = "np", +) -> torch.Tensor: + r""" + Creates 3D sinusoidal positional embeddings. + + Args: + embed_dim (`int`): + The embedding dimension of inputs. It must be divisible by 16. + spatial_size (`int` or `Tuple[int, int]`): + The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both + spatial dimensions (height and width). + temporal_size (`int`): + The temporal dimension of postional embeddings (number of frames). + spatial_interpolation_scale (`float`, defaults to 1.0): + Scale factor for spatial grid interpolation. + temporal_interpolation_scale (`float`, defaults to 1.0): + Scale factor for temporal grid interpolation. + + Returns: + `torch.Tensor`: + The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1], + embed_dim]`. + """ + if output_type == "np": + return _get_3d_sincos_pos_embed_np( + embed_dim=embed_dim, + spatial_size=spatial_size, + temporal_size=temporal_size, + spatial_interpolation_scale=spatial_interpolation_scale, + temporal_interpolation_scale=temporal_interpolation_scale, + ) + if embed_dim % 4 != 0: + raise ValueError("`embed_dim` must be divisible by 4") + if isinstance(spatial_size, int): + spatial_size = (spatial_size, spatial_size) + + embed_dim_spatial = 3 * embed_dim // 4 + embed_dim_temporal = embed_dim // 4 + + # 1. Spatial + grid_h = torch.arange(spatial_size[1], device=device, dtype=torch.float32) / spatial_interpolation_scale + grid_w = torch.arange(spatial_size[0], device=device, dtype=torch.float32) / spatial_interpolation_scale + grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first + grid = torch.stack(grid, dim=0) + + grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]]) + pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid, output_type="pt") + + # 2. Temporal + grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32) / temporal_interpolation_scale + pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t, output_type="pt") + + # 3. Concat + pos_embed_spatial = pos_embed_spatial[None, :, :] + pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0) # [T, H*W, D // 4 * 3] + + pos_embed_temporal = pos_embed_temporal[:, None, :] + pos_embed_temporal = pos_embed_temporal.repeat_interleave( + spatial_size[0] * spatial_size[1], dim=1 + ) # [T, H*W, D // 4] + + pos_embed = torch.concat([pos_embed_temporal, pos_embed_spatial], dim=-1) # [T, H*W, D] + return pos_embed + + +def _get_3d_sincos_pos_embed_np( + embed_dim: int, + spatial_size: Union[int, Tuple[int, int]], + temporal_size: int, + spatial_interpolation_scale: float = 1.0, + temporal_interpolation_scale: float = 1.0, ) -> np.ndarray: r""" Creates 3D sinusoidal positional embeddings. @@ -106,6 +178,12 @@ def get_3d_sincos_pos_embed( The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1], embed_dim]`. """ + deprecation_message = ( + "`get_3d_sincos_pos_embed` uses `torch` and supports `device`." + " `from_numpy` is no longer required." + " Pass `output_type='pt' to use the new version now." + ) + deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False) if embed_dim % 4 != 0: raise ValueError("`embed_dim` must be divisible by 4") if isinstance(spatial_size, int): @@ -139,6 +217,143 @@ def get_3d_sincos_pos_embed( def get_2d_sincos_pos_embed( + embed_dim, + grid_size, + cls_token=False, + extra_tokens=0, + interpolation_scale=1.0, + base_size=16, + device: Optional[torch.device] = None, + output_type: str = "np", +): + """ + Creates 2D sinusoidal positional embeddings. + + Args: + embed_dim (`int`): + The embedding dimension. + grid_size (`int`): + The size of the grid height and width. + cls_token (`bool`, defaults to `False`): + Whether or not to add a classification token. + extra_tokens (`int`, defaults to `0`): + The number of extra tokens to add. + interpolation_scale (`float`, defaults to `1.0`): + The scale of the interpolation. + + Returns: + pos_embed (`torch.Tensor`): + Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size, + embed_dim]` if using cls_token + """ + if output_type == "np": + deprecation_message = ( + "`get_2d_sincos_pos_embed` uses `torch` and supports `device`." + " `from_numpy` is no longer required." + " Pass `output_type='pt' to use the new version now." + ) + deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False) + return get_2d_sincos_pos_embed_np( + embed_dim=embed_dim, + grid_size=grid_size, + cls_token=cls_token, + extra_tokens=extra_tokens, + interpolation_scale=interpolation_scale, + base_size=base_size, + ) + if isinstance(grid_size, int): + grid_size = (grid_size, grid_size) + + grid_h = ( + torch.arange(grid_size[0], device=device, dtype=torch.float32) + / (grid_size[0] / base_size) + / interpolation_scale + ) + grid_w = ( + torch.arange(grid_size[1], device=device, dtype=torch.float32) + / (grid_size[1] / base_size) + / interpolation_scale + ) + grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first + grid = torch.stack(grid, dim=0) + + grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type=output_type) + if cls_token and extra_tokens > 0: + pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"): + r""" + This function generates 2D sinusoidal positional embeddings from a grid. + + Args: + embed_dim (`int`): The embedding dimension. + grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`. + + Returns: + `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)` + """ + if output_type == "np": + deprecation_message = ( + "`get_2d_sincos_pos_embed_from_grid` uses `torch` and supports `device`." + " `from_numpy` is no longer required." + " Pass `output_type='pt' to use the new version now." + ) + deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False) + return get_2d_sincos_pos_embed_from_grid_np( + embed_dim=embed_dim, + grid=grid, + ) + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], output_type=output_type) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], output_type=output_type) # (H*W, D/2) + + emb = torch.concat([emb_h, emb_w], dim=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"): + """ + This function generates 1D positional embeddings from a grid. + + Args: + embed_dim (`int`): The embedding dimension `D` + pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)` + + Returns: + `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`. + """ + if output_type == "np": + deprecation_message = ( + "`get_1d_sincos_pos_embed_from_grid` uses `torch` and supports `device`." + " `from_numpy` is no longer required." + " Pass `output_type='pt' to use the new version now." + ) + deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False) + return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos) + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.outer(pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.concat([emb_sin, emb_cos], dim=1) # (M, D) + return emb + + +def get_2d_sincos_pos_embed_np( embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16 ): """ @@ -170,13 +385,13 @@ def get_2d_sincos_pos_embed( grid = np.stack(grid, axis=0) grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + pos_embed = get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid) if cls_token and extra_tokens > 0: pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) return pos_embed -def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): +def get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid): r""" This function generates 2D sinusoidal positional embeddings from a grid. @@ -191,14 +406,14 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): raise ValueError("embed_dim must be divisible by 2") # use half of dimensions to encode grid_h - emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) - emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + emb_h = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[1]) # (H*W, D/2) emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) return emb -def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): +def get_1d_sincos_pos_embed_from_grid_np(embed_dim, pos): """ This function generates 1D positional embeddings from a grid. @@ -288,10 +503,14 @@ def __init__( self.pos_embed = None elif pos_embed_type == "sincos": pos_embed = get_2d_sincos_pos_embed( - embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale + embed_dim, + grid_size, + base_size=self.base_size, + interpolation_scale=self.interpolation_scale, + output_type="pt", ) persistent = True if pos_embed_max_size else False - self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent) + self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=persistent) else: raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}") @@ -341,8 +560,10 @@ def forward(self, latent): grid_size=(height, width), base_size=self.base_size, interpolation_scale=self.interpolation_scale, + device=latent.device, + output_type="pt", ) - pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device) + pos_embed = pos_embed.float().unsqueeze(0) else: pos_embed = self.pos_embed @@ -453,7 +674,9 @@ def __init__( pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames) self.register_buffer("pos_embedding", pos_embedding, persistent=persistent) - def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor: + def _get_positional_embeddings( + self, sample_height: int, sample_width: int, sample_frames: int, device: Optional[torch.device] = None + ) -> torch.Tensor: post_patch_height = sample_height // self.patch_size post_patch_width = sample_width // self.patch_size post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1 @@ -465,8 +688,10 @@ def _get_positional_embeddings(self, sample_height: int, sample_width: int, samp post_time_compression_frames, self.spatial_interpolation_scale, self.temporal_interpolation_scale, + device=device, + output_type="pt", ) - pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1) + pos_embedding = pos_embedding.flatten(0, 1) joint_pos_embedding = torch.zeros( 1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False ) @@ -521,8 +746,10 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): or self.sample_width != width or self.sample_frames != pre_time_compression_frames ): - pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames) - pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype) + pos_embedding = self._get_positional_embeddings( + height, width, pre_time_compression_frames, device=embeds.device + ) + pos_embedding = pos_embedding.to(dtype=embeds.dtype) else: pos_embedding = self.pos_embedding @@ -552,9 +779,11 @@ def __init__( # Linear projection for text embeddings self.text_proj = nn.Linear(text_hidden_size, hidden_size) - pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size) + pos_embed = get_2d_sincos_pos_embed( + hidden_size, pos_embed_max_size, base_size=pos_embed_max_size, output_type="pt" + ) pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size) - self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False) + self.register_buffer("pos_embed", pos_embed.float(), persistent=False) def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor: batch_size, channel, height, width = hidden_states.shape diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index 7e2b1273687d..d34ccfd20108 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -156,9 +156,9 @@ def __init__( # define temporal positional embedding temp_pos_embed = get_1d_sincos_pos_embed_from_grid( - inner_dim, torch.arange(0, video_length).unsqueeze(1) + inner_dim, torch.arange(0, video_length).unsqueeze(1), output_type="pt" ) # 1152 hidden size - self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False) + self.register_buffer("temp_pos_embed", temp_pos_embed.float().unsqueeze(0), persistent=False) self.gradient_checkpointing = False diff --git a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py index cb1514b153ce..1e285a9670e2 100644 --- a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py +++ b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py @@ -104,8 +104,8 @@ def __init__( self.use_pos_embed = use_pos_embed if self.use_pos_embed: - pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5)) - self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) + pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5), output_type="pt") + self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=False) def forward(self, latent): latent = self.proj(latent)