@@ -86,12 +86,25 @@ def get_3d_sincos_pos_embed(
     temporal_interpolation_scale: float = 1.0,
 ) -> np.ndarray:
     r"""
+    Creates 3D sinusoidal positional embeddings.
+
     Args:
         embed_dim (`int`):
+            The embedding dimension of inputs. It must be divisible by 16.
         spatial_size (`int` or `Tuple[int, int]`):
+            The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
+            spatial dimensions (height and width).
         temporal_size (`int`):
+            The temporal dimension of positional embeddings (number of frames).
         spatial_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for spatial grid interpolation.
         temporal_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for temporal grid interpolation.
+
+    Returns:
+        `np.ndarray`:
+            The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
+            embed_dim]`.
     """
     if embed_dim % 4 != 0:
         raise ValueError("`embed_dim` must be divisible by 4")
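For a quick check of the documented shape, a small usage sketch (assuming the function is imported from `diffusers.models.embeddings`, where this file lives):

```python
from diffusers.models.embeddings import get_3d_sincos_pos_embed

# embed_dim divisible by 16; an 8x8 spatial grid over 4 frames.
pos_embed = get_3d_sincos_pos_embed(embed_dim=64, spatial_size=8, temporal_size=4)
print(pos_embed.shape)  # (4, 64, 64): [temporal_size, height * width, embed_dim]
```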
@@ -129,8 +142,24 @@ def get_2d_sincos_pos_embed(
     embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
 ):
     """
-    grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
-    [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    Creates 2D sinusoidal positional embeddings.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension.
+        grid_size (`int`):
+            The size of the grid height and width.
+        cls_token (`bool`, defaults to `False`):
+            Whether or not to add a classification token.
+        extra_tokens (`int`, defaults to `0`):
+            The number of extra tokens to add.
+        interpolation_scale (`float`, defaults to `1.0`):
+            The scale of the interpolation.
+
+    Returns:
+        pos_embed (`np.ndarray`):
+            Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size * grid_size,
+            embed_dim]` if using cls_token.
     """
     if isinstance(grid_size, int):
         grid_size = (grid_size, grid_size)
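Similarly, a sketch of the two return layouts described above (same assumed import path; this is the NumPy-returning version shown in the diff):

```python
from diffusers.models.embeddings import get_2d_sincos_pos_embed

plain = get_2d_sincos_pos_embed(embed_dim=64, grid_size=8)
with_cls = get_2d_sincos_pos_embed(embed_dim=64, grid_size=8, cls_token=True, extra_tokens=1)
print(plain.shape)     # (64, 64): [grid_size * grid_size, embed_dim]
print(with_cls.shape)  # (65, 64): [1 + grid_size * grid_size, embed_dim]
```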
@@ -148,6 +177,16 @@ def get_2d_sincos_pos_embed(
 
 
 def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+    r"""
+    This function generates 2D sinusoidal positional embeddings from a grid.
+
+    Args:
+        embed_dim (`int`): The embedding dimension.
+        grid (`np.ndarray`): Grid of positions, with the height and width sub-grids stacked along the first axis.
+
+    Returns:
+        `np.ndarray`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`.
+    """
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")
 
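For reference, a minimal sketch of how callers typically build the grid argument; the `[2, 1, H, W]` layout is an assumption based on the usual meshgrid construction in this file:

```python
import numpy as np
from diffusers.models.embeddings import get_2d_sincos_pos_embed_from_grid

h, w, embed_dim = 4, 4, 16
grid = np.meshgrid(np.arange(w, dtype=np.float32), np.arange(h, dtype=np.float32))  # width varies fastest
grid = np.stack(grid, axis=0).reshape([2, 1, h, w])  # stack the two coordinate grids

pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
print(pos_embed.shape)  # (16, 16): (H * W, embed_dim)
```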
@@ -161,7 +200,14 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
 
 def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
     """
-    embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
+    This function generates 1D sinusoidal positional embeddings from a grid.
+
+    Args:
+        embed_dim (`int`): The embedding dimension `D`.
+        pos (`np.ndarray`): 1D array of positions with shape `(M,)`.
+
+    Returns:
+        `np.ndarray`: Sinusoidal positional embeddings of shape `(M, D)`.
     """
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")
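The computation itself is the standard Transformer recipe: frequencies decay geometrically with the channel index, and sines and cosines each fill half of `embed_dim`. A self-contained sketch of the equivalent math:

```python
import numpy as np

def sincos_1d(embed_dim: int, pos: np.ndarray) -> np.ndarray:
    # omega_i = 1 / 10000^(i / (D/2)) for i in [0, D/2)
    omega = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0)
    omega = 1.0 / 10000**omega
    out = np.einsum("m,d->md", pos.reshape(-1).astype(np.float64), omega)  # outer product, (M, D/2)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)  # (M, D)

print(sincos_1d(8, np.arange(4)).shape)  # (4, 8)
```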
@@ -181,7 +227,22 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
 
 
 class PatchEmbed(nn.Module):
-    """2D Image to Patch Embedding with support for SD3 cropping."""
+    """
+    2D Image to Patch Embedding with support for SD3 cropping.
+
+    Args:
+        height (`int`, defaults to `224`): The height of the image.
+        width (`int`, defaults to `224`): The width of the image.
+        patch_size (`int`, defaults to `16`): The size of the patches.
+        in_channels (`int`, defaults to `3`): The number of input channels.
+        embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
+        layer_norm (`bool`, defaults to `False`): Whether or not to use layer normalization.
+        flatten (`bool`, defaults to `True`): Whether or not to flatten the output.
+        bias (`bool`, defaults to `True`): Whether or not to use bias.
+        interpolation_scale (`float`, defaults to `1`): The scale of the interpolation.
+        pos_embed_type (`str`, defaults to `"sincos"`): The type of positional embedding.
+        pos_embed_max_size (`int`, defaults to `None`): The maximum size of the positional embedding.
+    """
 
     def __init__(
         self,
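A usage sketch for the common diffusion-transformer case (shapes are assumptions based on the parameters documented above):

```python
import torch
from diffusers.models.embeddings import PatchEmbed

# A 32x32, 4-channel latent patchified into 2x2 patches -> 16 * 16 = 256 tokens.
patch_embed = PatchEmbed(height=32, width=32, patch_size=2, in_channels=4, embed_dim=768)
latent = torch.randn(1, 4, 32, 32)
tokens = patch_embed(latent)
print(tokens.shape)  # torch.Size([1, 256, 768])
```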
@@ -289,7 +350,15 @@ def forward(self, latent):
 
 
 class LuminaPatchEmbed(nn.Module):
-    """2D Image to Patch Embedding with support for Lumina-T2X"""
+    """
+    2D Image to Patch Embedding with support for Lumina-T2X.
+
+    Args:
+        patch_size (`int`, defaults to `2`): The size of the patches.
+        in_channels (`int`, defaults to `4`): The number of input channels.
+        embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
+        bias (`bool`, defaults to `True`): Whether or not to use bias.
+    """
 
     def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True):
         super().__init__()
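Unlike the convolutional `PatchEmbed` above, Lumina-T2X patchifies with a linear projection over flattened patches. A hypothetical, self-contained sketch of that idea only (not this class's actual `forward`, which also threads through rotary frequency inputs):

```python
import torch
import torch.nn as nn

patch_size, in_channels, embed_dim = 2, 4, 768
proj = nn.Linear(patch_size * patch_size * in_channels, embed_dim)

x = torch.randn(1, in_channels, 32, 32)
# (B, C, H, W) -> (B, num_patches, C * p * p) via non-overlapping patch extraction.
patches = nn.functional.unfold(x, kernel_size=patch_size, stride=patch_size).transpose(1, 2)
tokens = proj(patches)
print(tokens.shape)  # torch.Size([1, 256, 768])
```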
@@ -675,6 +744,20 @@ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
 
 
 def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
+    """
+    Get 2D RoPE from grid.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension size, corresponding to hidden_size_head.
+        grid (`np.ndarray`):
+            The grid of the positional embedding.
+        use_real (`bool`):
+            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+
+    Returns:
+        `torch.Tensor`: positional embedding with shape `(grid_size * grid_size, embed_dim/2)`.
+    """
     assert embed_dim % 4 == 0
 
     # use half of dimensions to encode grid_h
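For intuition, axial 2D RoPE splits the head dimension in half and applies 1D rotary frequencies to the row and column coordinates independently. A minimal sketch of the 1D building block (standard RoPE with the conventional base of 10000, returned as complex numbers as in the `use_real=False` path):

```python
import torch

def rope_1d(dim: int, pos: torch.Tensor) -> torch.Tensor:
    # freqs_k = 1 / 10000^(2k / dim): one frequency per rotated channel pair.
    freqs = 1.0 / 10000 ** (torch.arange(0, dim, 2).float() / dim)
    angles = torch.outer(pos.float(), freqs)              # (M, dim/2)
    return torch.polar(torch.ones_like(angles), angles)   # complex64, (M, dim/2)

# Half of embed_dim encodes rows, the other half encodes columns.
print(rope_1d(8, torch.arange(4)).shape)  # torch.Size([4, 4])
```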
@@ -695,6 +778,24 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
 
 
 def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
+    """
+    Get 2D RoPE for Lumina-T2X.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension size, corresponding to hidden_size_head.
+        len_h (`int`):
+            The height of the grid.
+        len_w (`int`):
+            The width of the grid.
+        linear_factor (`float`):
+            Linear scaling factor applied to the rotary frequencies (position interpolation).
+        ntk_factor (`float`):
+            NTK-aware scaling factor applied to the rotary base frequency.
+
+    Returns:
+        `torch.Tensor`: positional embedding over the `len_h` x `len_w` grid with last dimension `embed_dim/2`.
+    """
     assert embed_dim % 4 == 0
 
     emb_h = get_1d_rotary_pos_embed(
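Both scaling knobs act on the per-axis 1D rotary frequencies this function builds. A hedged sketch of the usual formulation (my reading of `get_1d_rotary_pos_embed`'s scaling, stated as an assumption):

```python
import torch

def scaled_rope_freqs(dim: int, linear_factor: float = 1.0, ntk_factor: float = 1.0) -> torch.Tensor:
    theta = 10000.0 * ntk_factor  # NTK-aware scaling enlarges the frequency base
    # Position interpolation: dividing by linear_factor slows every rotation uniformly.
    return 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) / linear_factor

print(scaled_rope_freqs(8, linear_factor=2.0, ntk_factor=4.0))
```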