From ec0673ad82dfa776c9228d8e4e3bebbb52c091d2 Mon Sep 17 00:00:00 2001 From: wony617 <49024958+Jwaminju@users.noreply.github.com> Date: Sun, 6 Oct 2024 03:18:21 +0900 Subject: [PATCH 1/3] [docs] refactoring docstrings in `models/embeddings_flax.py` --- src/diffusers/models/embeddings_flax.py | 34 +++++++++++++++++-------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/embeddings_flax.py b/src/diffusers/models/embeddings_flax.py index 8e343be0d3b7..fbd9d2fd522f 100644 --- a/src/diffusers/models/embeddings_flax.py +++ b/src/diffusers/models/embeddings_flax.py @@ -29,11 +29,21 @@ def get_sinusoidal_embeddings( """Returns the positional encoding (same as Tensor2Tensor). Args: - timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - embedding_dim: The number of output channels. - min_timescale: The smallest time unit (should probably be 0.0). - max_timescale: The largest time unit. + timesteps (`jnp.ndarray` of shape `(N,)`): + A 1-D array of N indices, one per batch element. These may be fractional. + embedding_dim (`int`): + The number of output channels. + freq_shift (`float`, *optional*, defaults to `1`): + Shift applied to the frequency scaling of the embeddings. + min_timescale (`float`, *optional*, defaults to `1`): + The smallest time unit used in the sinusoidal calculation (should probably be 0.0). + max_timescale (`float`, *optional*, defaults to `1.0e4`): + The largest time unit used in the sinusoidal calculation. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the order of sinusoidal components to cosine first. + scale (`float`, *optional*, defaults to `1.0`): + A scaling factor applied to the positional embeddings. + Returns: a Tensor of timing signals [N, num_channels] """ @@ -58,12 +68,12 @@ def get_sinusoidal_embeddings( class FlaxTimestepEmbedding(nn.Module): r""" Time step Embedding Module. Learns embeddings for input time steps. - + Args: time_embed_dim (`int`, *optional*, defaults to `32`): - Time step embedding dimension - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` + Time step embedding dimension. + dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): + The data type for the embedding parameters. """ time_embed_dim: int = 32 @@ -83,7 +93,11 @@ class FlaxTimesteps(nn.Module): Args: dim (`int`, *optional*, defaults to `32`): - Time step embedding dimension + Time step embedding dimension. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sinusoidal function from sine to cosine. + freq_shift (`float`, *optional*, defaults to `1`): + Frequency shift applied to the sinusoidal embeddings. """ dim: int = 32 From 0b5688396d2dbb529c04e59a6e8e692d9aa721f7 Mon Sep 17 00:00:00 2001 From: wony617 <49024958+Jwaminju@users.noreply.github.com> Date: Tue, 8 Oct 2024 22:07:44 +0900 Subject: [PATCH 2/3] Update src/diffusers/models/embeddings_flax.py --- src/diffusers/models/embeddings_flax.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/embeddings_flax.py b/src/diffusers/models/embeddings_flax.py index fbd9d2fd522f..6ccde81f210a 100644 --- a/src/diffusers/models/embeddings_flax.py +++ b/src/diffusers/models/embeddings_flax.py @@ -68,7 +68,6 @@ def get_sinusoidal_embeddings( class FlaxTimestepEmbedding(nn.Module): r""" Time step Embedding Module. Learns embeddings for input time steps. - Args: time_embed_dim (`int`, *optional*, defaults to `32`): Time step embedding dimension. From b46c71421251702323264f8642df2f520572cee8 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Oct 2024 15:10:25 +0200 Subject: [PATCH 3/3] make style --- src/diffusers/models/embeddings_flax.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/embeddings_flax.py b/src/diffusers/models/embeddings_flax.py index 6ccde81f210a..92b5a6c35883 100644 --- a/src/diffusers/models/embeddings_flax.py +++ b/src/diffusers/models/embeddings_flax.py @@ -68,6 +68,7 @@ def get_sinusoidal_embeddings( class FlaxTimestepEmbedding(nn.Module): r""" Time step Embedding Module. Learns embeddings for input time steps. + Args: time_embed_dim (`int`, *optional*, defaults to `32`): Time step embedding dimension.