@@ -428,10 +428,7 @@ def __init__(
         attention_head_dim: int = 128,
         num_attention_heads: int = 24,
         joint_attention_dim: int = 4096,
-        pooled_projection_dim: int = 768,
-        guidance_embeds: bool = False,
         axes_dims_rope: Tuple[int, ...] = (16, 56, 56),
-        variant: str = "flux",
         approximator_in_factor: int = 16,
         approximator_hidden_dim: int = 5120,
         approximator_layers: int = 5,
@@ -446,7 +443,10 @@ def __init__(
             num_channels=approximator_in_factor, out_dim=3 * num_single_layers + 2 * 6 * num_layers + 2
         )
         self.distilled_guidance_layer = ChromaApproximator(
-            in_dim=64, out_dim=3072, hidden_dim=approximator_hidden_dim, n_layers=approximator_layers
+            in_dim=in_channels,
+            out_dim=self.inner_dim,
+            hidden_dim=approximator_hidden_dim,
+            n_layers=approximator_layers,
         )
         self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
         self.x_embedder = nn.Linear(in_channels, self.inner_dim)
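
A quick sanity check on the second hunk (a minimal sketch, not part of the diff): with the defaults shown in this __init__ (num_attention_heads=24, attention_head_dim=128), and assuming in_channels defaults to 64 as in the Flux-style packed latents (that default is not visible in this hunk), the config-derived values reproduce the previously hard-coded ChromaApproximator dimensions.

    # Assumption: in_channels = 64, matching the Flux-style packed-latent channel count.
    num_attention_heads = 24
    attention_head_dim = 128
    in_channels = 64

    inner_dim = num_attention_heads * attention_head_dim  # what self.inner_dim evaluates to
    assert inner_dim == 3072     # matches the old hard-coded out_dim=3072
    assert in_channels == 64     # matches the old hard-coded in_dim=64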