diff --git a/src/diffusers/models/embeddings_flax.py b/src/diffusers/models/embeddings_flax.py index 8e343be0d3b7..92b5a6c35883 100644 --- a/src/diffusers/models/embeddings_flax.py +++ b/src/diffusers/models/embeddings_flax.py @@ -29,11 +29,21 @@ def get_sinusoidal_embeddings( """Returns the positional encoding (same as Tensor2Tensor). Args: - timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - embedding_dim: The number of output channels. - min_timescale: The smallest time unit (should probably be 0.0). - max_timescale: The largest time unit. + timesteps (`jnp.ndarray` of shape `(N,)`): + A 1-D array of N indices, one per batch element. These may be fractional. + embedding_dim (`int`): + The number of output channels. + freq_shift (`float`, *optional*, defaults to `1`): + Shift applied to the frequency scaling of the embeddings. + min_timescale (`float`, *optional*, defaults to `1`): + The smallest time unit used in the sinusoidal calculation (should probably be 0.0). + max_timescale (`float`, *optional*, defaults to `1.0e4`): + The largest time unit used in the sinusoidal calculation. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the order of sinusoidal components to cosine first. + scale (`float`, *optional*, defaults to `1.0`): + A scaling factor applied to the positional embeddings. + Returns: a Tensor of timing signals [N, num_channels] """ @@ -61,9 +71,9 @@ class FlaxTimestepEmbedding(nn.Module): Args: time_embed_dim (`int`, *optional*, defaults to `32`): - Time step embedding dimension - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` + Time step embedding dimension. + dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): + The data type for the embedding parameters. """ time_embed_dim: int = 32 @@ -83,7 +93,11 @@ class FlaxTimesteps(nn.Module): Args: dim (`int`, *optional*, defaults to `32`): - Time step embedding dimension + Time step embedding dimension. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sinusoidal function from sine to cosine. + freq_shift (`float`, *optional*, defaults to `1`): + Frequency shift applied to the sinusoidal embeddings. """ dim: int = 32