@@ -86,12 +86,25 @@ def get_3d_sincos_pos_embed(
     temporal_interpolation_scale: float = 1.0,
 ) -> np.ndarray:
     r"""
+    Creates 3D sinusoidal positional embeddings.
+
     Args:
         embed_dim (`int`):
+            The embedding dimension of inputs. It must be divisible by 16.
         spatial_size (`int` or `Tuple[int, int]`):
+            The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
+            spatial dimensions (height and width).
         temporal_size (`int`):
+            The temporal dimension of positional embeddings (number of frames).
         spatial_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for spatial grid interpolation.
         temporal_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for temporal grid interpolation.
+
+    Returns:
+        `np.ndarray`:
+            The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
+            embed_dim]`.
     """
     if embed_dim % 4 != 0:
         raise ValueError("`embed_dim` must be divisible by 4")
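
For reference (an editorial aside, not part of the diff), a minimal sanity check of the documented shape, assuming this revision's helper is importable from `diffusers.models.embeddings`:

from diffusers.models.embeddings import get_3d_sincos_pos_embed

# 8 frames on a 16x16 spatial grid; embed_dim=64 is divisible by 16.
pos_embed = get_3d_sincos_pos_embed(embed_dim=64, spatial_size=16, temporal_size=8)
print(pos_embed.shape)  # expected (8, 256, 64): [temporal_size, H * W, embed_dim]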
@@ -129,8 +142,24 @@ def get_2d_sincos_pos_embed(
     embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
 ):
     """
-    grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
-    [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    Creates 2D sinusoidal positional embeddings.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension.
+        grid_size (`int`):
+            The size of the grid height and width.
+        cls_token (`bool`, defaults to `False`):
+            Whether or not to add a classification token.
+        extra_tokens (`int`, defaults to `0`):
+            The number of extra tokens to add.
+        interpolation_scale (`float`, defaults to `1.0`):
+            The scale of the interpolation.
+
+    Returns:
+        pos_embed (`np.ndarray`):
+            Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size *
+            grid_size, embed_dim]` if using cls_token.
134163 """
135164 if isinstance (grid_size , int ):
136165 grid_size = (grid_size , grid_size )
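
A quick usage sketch of the two documented return shapes (same assumed import path as above):

from diffusers.models.embeddings import get_2d_sincos_pos_embed

pos_embed = get_2d_sincos_pos_embed(embed_dim=64, grid_size=16)
print(pos_embed.shape)  # expected (256, 64): [grid_size * grid_size, embed_dim]

# With a classification token, one zero row per extra token is prepended.
with_cls = get_2d_sincos_pos_embed(embed_dim=64, grid_size=16, cls_token=True, extra_tokens=1)
print(with_cls.shape)  # expected (257, 64)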
@@ -148,6 +177,16 @@ def get_2d_sincos_pos_embed(
 
 
 def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+    r"""
+    This function generates 2D sinusoidal positional embeddings from a grid.
+
+    Args:
+        embed_dim (`int`): The embedding dimension.
+        grid (`np.ndarray`): Grid of positions with shape `(2, 1, H, W)`, stacking the height and width grids.
+
+    Returns:
+        `np.ndarray`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`.
189+ """
151190 if embed_dim % 2 != 0 :
152191 raise ValueError ("embed_dim must be divisible by 2" )
153192
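To illustrate the expected `grid` layout, a small sketch that mirrors how `get_2d_sincos_pos_embed` builds its grid (an assumption based on the surrounding source, which is not shown in this hunk):

import numpy as np
from diffusers.models.embeddings import get_2d_sincos_pos_embed_from_grid

h, w = 4, 6
grid = np.stack(np.meshgrid(np.arange(w), np.arange(h)), axis=0).reshape([2, 1, h, w])
emb = get_2d_sincos_pos_embed_from_grid(32, grid)
print(emb.shape)  # expected (24, 32): (H * W, embed_dim)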
@@ -161,7 +200,14 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
 
 def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
     """
-    embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
+    This function generates 1D sinusoidal positional embeddings from a grid.
+
+    Args:
+        embed_dim (`int`): The embedding dimension `D`.
+        pos (`numpy.ndarray`): 1D tensor of positions with shape `(M,)`.
+
+    Returns:
+        `numpy.ndarray`: Sinusoidal positional embeddings of shape `(M, D)`.
165211 """
166212 if embed_dim % 2 != 0 :
167213 raise ValueError ("embed_dim must be divisible by 2" )
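
The 1D helper is easy to check directly (sin halves and cos halves are concatenated along the last axis):

import numpy as np
from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid

pos = np.arange(10)  # M = 10 positions
emb = get_1d_sincos_pos_embed_from_grid(16, pos)
print(emb.shape)  # expected (10, 16): (M, D)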
@@ -181,7 +227,22 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
 
 
 class PatchEmbed(nn.Module):
-    """2D Image to Patch Embedding with support for SD3 cropping."""
+    """
+    2D Image to Patch Embedding with support for SD3 cropping.
+
+    Args:
+        height (`int`, defaults to `224`): The height of the image.
+        width (`int`, defaults to `224`): The width of the image.
+        patch_size (`int`, defaults to `16`): The size of the patches.
+        in_channels (`int`, defaults to `3`): The number of input channels.
+        embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
+        layer_norm (`bool`, defaults to `False`): Whether or not to use layer normalization.
+        flatten (`bool`, defaults to `True`): Whether or not to flatten the output.
+        bias (`bool`, defaults to `True`): Whether or not to use bias.
+        interpolation_scale (`float`, defaults to `1`): The scale of the interpolation.
+        pos_embed_type (`str`, defaults to `"sincos"`): The type of positional embedding.
+        pos_embed_max_size (`int`, defaults to `None`): The maximum size of the positional embedding.
+    """
 
     def __init__(
         self,
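
A usage sketch for the documented constructor arguments; the `forward(self, latent)` signature is taken from the next hunk header below:

import torch
from diffusers.models.embeddings import PatchEmbed

patch_embed = PatchEmbed(height=32, width=32, patch_size=2, in_channels=4, embed_dim=768)
latent = torch.randn(1, 4, 32, 32)  # (batch, channels, height, width)
tokens = patch_embed(latent)
print(tokens.shape)  # expected (1, 256, 768): (32 / 2) * (32 / 2) = 256 patches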
@@ -289,7 +350,15 @@ def forward(self, latent):
 
 
 class LuminaPatchEmbed(nn.Module):
-    """2D Image to Patch Embedding with support for Lumina-T2X"""
+    """
+    2D Image to Patch Embedding with support for Lumina-T2X.
+
+    Args:
+        patch_size (`int`, defaults to `2`): The size of the patches.
+        in_channels (`int`, defaults to `4`): The number of input channels.
+        embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
+        bias (`bool`, defaults to `True`): Whether or not to use bias.
+    """
 
     def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True):
         super().__init__()
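
A construction-only sketch of the documented defaults (the Lumina forward signature is not shown in this diff, so the call itself is omitted):

from diffusers.models.embeddings import LuminaPatchEmbed

patch_embed = LuminaPatchEmbed(patch_size=2, in_channels=4, embed_dim=768, bias=True)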
@@ -675,6 +744,20 @@ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
 
 
 def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
747+ """
748+ Get 2D RoPE from grid.
749+
750+ Args:
751+ embed_dim: (`int`):
752+ The embedding dimension size, corresponding to hidden_size_head.
753+ grid (`np.ndarray`):
754+ The grid of the positional embedding.
755+ use_real (`bool`):
756+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
757+
758+ Returns:
759+ `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
760+ """
     assert embed_dim % 4 == 0
 
     # use half of dimensions to encode grid_h
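
For the 2D RoPE helpers, a sketch via the public wrapper whose signature appears in the hunk header above, `get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True)`:

from diffusers.models.embeddings import get_2d_rotary_pos_embed

# 16x16 grid over the crop from (0, 0) to (16, 16); use_real=True returns a (cos, sin) pair.
cos, sin = get_2d_rotary_pos_embed(embed_dim=64, crops_coords=((0, 0), (16, 16)), grid_size=(16, 16))
print(cos.shape, sin.shape)  # expected (256, 64) each at this revision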
@@ -695,6 +778,23 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
 
 
 def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
781+ """
782+ Get 2D RoPE from grid.
783+
784+ Args:
785+ embed_dim: (`int`):
786+ The embedding dimension size, corresponding to hidden_size_head.
787+ grid (`np.ndarray`):
788+ The grid of the positional embedding.
789+ linear_factor (`float`):
790+ The linear factor of the positional embedding, which is used to scale the positional embedding in the linear
791+ layer.
792+ ntk_factor (`float`):
793+ The ntk factor of the positional embedding, which is used to scale the positional embedding in the ntk layer.
794+
795+ Returns:
796+ `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
797+ """
     assert embed_dim % 4 == 0
 
     emb_h = get_1d_rotary_pos_embed(
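
And a sketch for the Lumina variant; the expected return shape is a reading of the code above and should be treated as an assumption, since the hunk is truncated:

from diffusers.models.embeddings import get_2d_rotary_pos_embed_lumina

freqs_cis = get_2d_rotary_pos_embed_lumina(embed_dim=64, len_h=16, len_w=16)
print(freqs_cis.shape, freqs_cis.dtype)  # expected (16, 16, 32), complex-valued, at this revision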