@@ -29,11 +29,21 @@ def get_sinusoidal_embeddings(
2929    """Returns the positional encoding (same as Tensor2Tensor). 
3030
3131    Args: 
32-         timesteps: a 1-D Tensor of N indices, one per batch element. 
33-         These may be fractional. 
34-         embedding_dim: The number of output channels. 
35-         min_timescale: The smallest time unit (should probably be 0.0). 
36-         max_timescale: The largest time unit. 
32+         timesteps (`jnp.ndarray` of shape `(N,)`): 
33+             A 1-D array of N indices, one per batch element. These may be fractional. 
34+         embedding_dim (`int`): 
35+             The number of output channels. 
36+         freq_shift (`float`, *optional*, defaults to `1`): 
37+             Shift applied to the frequency scaling of the embeddings. 
38+         min_timescale (`float`, *optional*, defaults to `1`): 
39+             The smallest time unit used in the sinusoidal calculation (should probably be 0.0). 
40+         max_timescale (`float`, *optional*, defaults to `1.0e4`): 
41+             The largest time unit used in the sinusoidal calculation. 
42+         flip_sin_to_cos (`bool`, *optional*, defaults to `False`): 
43+             Whether to flip the order of sinusoidal components to cosine first. 
44+         scale (`float`, *optional*, defaults to `1.0`): 
45+             A scaling factor applied to the positional embeddings. 
46+ 
3747    Returns: 
3848        a Tensor of timing signals [N, num_channels] 
3949    """ 
@@ -61,9 +71,9 @@ class FlaxTimestepEmbedding(nn.Module):
6171
6272    Args: 
6373        time_embed_dim (`int`, *optional*, defaults to `32`): 
64-                  Time step embedding dimension 
65-         dtype (:obj: `jnp.dtype`, *optional*, defaults to jnp.float32): 
66-                 Parameters `dtype`  
74+             Time step embedding dimension.  
75+         dtype (`jnp.dtype`, *optional*, defaults to ` jnp.float32` ): 
76+             The data type for the embedding parameters.  
6777    """ 
6878
6979    time_embed_dim : int  =  32 
@@ -83,7 +93,11 @@ class FlaxTimesteps(nn.Module):
8393
8494    Args: 
8595        dim (`int`, *optional*, defaults to `32`): 
86-                 Time step embedding dimension 
96+             Time step embedding dimension. 
97+         flip_sin_to_cos (`bool`, *optional*, defaults to `False`): 
98+             Whether to flip the sinusoidal function from sine to cosine. 
99+         freq_shift (`float`, *optional*, defaults to `1`): 
100+             Frequency shift applied to the sinusoidal embeddings. 
87101    """ 
88102
89103    dim : int  =  32 
0 commit comments