@@ -82,6 +82,7 @@ def __init__(
8282 out_channels : int = 2 ,
8383 extra_in_channels : int = 0 ,
8484 time_embedding_type : str = "fourier" ,
85+ time_embedding_dim : Optional [int ] = None ,
8586 flip_sin_to_cos : bool = True ,
8687 use_timestep_embedding : bool = False ,
8788 freq_shift : float = 0.0 ,
@@ -100,15 +101,23 @@ def __init__(
100101
101102 # time
102103 if time_embedding_type == "fourier" :
104+ time_embed_dim = time_embedding_dim or block_out_channels [0 ] * 2
105+ if time_embed_dim % 2 != 0 :
106+ raise ValueError (f"`time_embed_dim` should be divisible by 2, but is { time_embed_dim } ." )
103107 self .time_proj = GaussianFourierProjection (
104- embedding_size = 8 , set_W_to_weight = False , log = False , flip_sin_to_cos = flip_sin_to_cos
108+ embedding_size = time_embed_dim // 2 , set_W_to_weight = False , log = False , flip_sin_to_cos = flip_sin_to_cos
105109 )
106- timestep_input_dim = 2 * block_out_channels [ 0 ]
110+ timestep_input_dim = time_embed_dim
107111 elif time_embedding_type == "positional" :
112+ time_embed_dim = time_embedding_dim or block_out_channels [0 ] * 4
108113 self .time_proj = Timesteps (
109114 block_out_channels [0 ], flip_sin_to_cos = flip_sin_to_cos , downscale_freq_shift = freq_shift
110115 )
111116 timestep_input_dim = block_out_channels [0 ]
117+ else :
118+ raise ValueError (
119+ f"{ time_embedding_type } does not exist. Please make sure to use one of `fourier` or `positional`."
120+ )
112121
113122 if use_timestep_embedding :
114123 time_embed_dim = block_out_channels [0 ] * 4
0 commit comments