Skip to content

Commit 9335c36

Browse files
committed
matched 1d temb process to 2d
1 parent f1b3cca commit 9335c36

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

src/diffusers/models/unets/unet_1d.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def __init__(
8282
out_channels: int = 2,
8383
extra_in_channels: int = 0,
8484
time_embedding_type: str = "fourier",
85+
time_embedding_dim: int = 0,
8586
flip_sin_to_cos: bool = True,
8687
use_timestep_embedding: bool = False,
8788
freq_shift: float = 0.0,
@@ -100,15 +101,23 @@ def __init__(
100101

101102
# time
102103
if time_embedding_type == "fourier":
104+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
105+
if time_embed_dim % 2 != 0:
106+
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
103107
self.time_proj = GaussianFourierProjection(
104-
embedding_size=block_out_channels[0], set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
108+
embedding_size=time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
105109
)
106-
timestep_input_dim = 2 * block_out_channels[0]
110+
timestep_input_dim = time_embed_dim
107111
elif time_embedding_type == "positional":
112+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
108113
self.time_proj = Timesteps(
109114
block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
110115
)
111116
timestep_input_dim = block_out_channels[0]
117+
else:
118+
raise ValueError(
119+
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
120+
)
112121

113122
if use_timestep_embedding:
114123
time_embed_dim = block_out_channels[0] * 4

0 commit comments

Comments
 (0)