@@ -164,17 +164,17 @@ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
         return x
 
 
-class CombinedTimestepTextProjChromaEmbeddings(nn.Module):
-    def __init__(self, factor: int, hidden_dim: int, out_dim: int, n_layers: int, embedding_dim: int):
+class ChromaCombinedTimestepTextProjEmbeddings(nn.Module):
+    def __init__(self, num_channels: int, out_dim: int):
         super().__init__()
 
-        self.time_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
-        self.guidance_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.time_proj = Timesteps(num_channels=num_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.guidance_proj = Timesteps(num_channels=num_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
 
         self.register_buffer(
             "mod_proj",
             get_timestep_embedding(
-                torch.arange(out_dim) * 1000, 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0
+                torch.arange(out_dim) * 1000, 2 * num_channels, flip_sin_to_cos=True, downscale_freq_shift=0
             ),
             persistent=False,
         )
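
For context, a minimal sketch of what the renamed ChromaCombinedTimestepTextProjEmbeddings sets up at construction time, using the Timesteps and get_timestep_embedding helpers the diff already relies on (imported here from diffusers.models.embeddings). The concrete values below (num_channels=16, out_dim=344) are illustrative assumptions, not taken from the diff:

import torch
from diffusers.models.embeddings import Timesteps, get_timestep_embedding

# Illustrative values (assumptions): 16 sinusoidal channels per projection,
# 344 modulation indices requested from the approximator.
num_channels, out_dim = 16, 344

time_proj = Timesteps(num_channels=num_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
guidance_proj = Timesteps(num_channels=num_channels, flip_sin_to_cos=True, downscale_freq_shift=0)

# The non-persistent "mod_proj" buffer: one 2 * num_channels wide sinusoidal
# embedding per modulation index, with the index scaled by 1000 like a timestep.
mod_proj = get_timestep_embedding(
    torch.arange(out_dim) * 1000, 2 * num_channels, flip_sin_to_cos=True, downscale_freq_shift=0
)
print(mod_proj.shape)  # torch.Size([344, 32]), i.e. (out_dim, 2 * num_channels)
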
@@ -426,14 +426,16 @@ def __init__(
 
         self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
 
-        self.time_text_embed = CombinedTimestepTextProjChromaEmbeddings(
-            factor=approximator_in_factor,
-            hidden_dim=approximator_hidden_dim,
+        self.time_text_embed = ChromaCombinedTimestepTextProjEmbeddings(
+            num_channels=approximator_in_factor,
             out_dim=3 * num_single_layers + 2 * 6 * num_layers + 2,
-            embedding_dim=self.inner_dim,
+        )
+        self.distilled_guidance_layer = ChromaApproximator(
+            in_dim=64,
+            out_dim=self.inner_dim,
+            hidden_dim=approximator_hidden_dim,
             n_layers=approximator_layers,
         )
-        self.distilled_guidance_layer = ChromaApproximator(in_dim=64, out_dim=3072, hidden_dim=5120, n_layers=5)
 
         self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
         self.x_embedder = nn.Linear(in_channels, self.inner_dim)
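
The second hunk stops hard-coding the approximator (previously out_dim=3072, hidden_dim=5120, n_layers=5) and instead derives its shape from the transformer config, with out_dim=self.inner_dim replacing the fixed 3072. A rough sketch of the modulation-count arithmetic feeding ChromaCombinedTimestepTextProjEmbeddings, assuming Flux-style defaults of 19 double-stream and 38 single-stream blocks (illustrative assumptions, not read from the diff):

# Reading of the formula: 3 modulation vectors per single-stream block,
# 6 per stream (x2 streams) per double-stream block, plus 2 for the final layer.
num_layers = 19          # double-stream blocks (assumed default)
num_single_layers = 38   # single-stream blocks (assumed default)

out_dim = 3 * num_single_layers + 2 * 6 * num_layers + 2
print(out_dim)  # 114 + 228 + 2 = 344 modulation embeddings
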