@@ -127,7 +127,7 @@ def __init__(
         time_freq_dim: int,
         time_proj_dim: int,
         text_embed_dim: int,
-        image_embedding_dim: Optional[int] = None,
+        image_embed_dim: Optional[int] = None,
     ):
         super().__init__()

@@ -138,8 +138,8 @@ def __init__(
         self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")

         self.image_embedder = None
-        if image_embedding_dim is not None:
-            self.image_embedder = WanImageEmbedding(image_embedding_dim, dim)
+        if image_embed_dim is not None:
+            self.image_embedder = WanImageEmbedding(image_embed_dim, dim)

     def forward(
         self,
@@ -348,7 +348,7 @@ def __init__(
         cross_attn_norm: bool = True,
         qk_norm: Optional[str] = "rms_norm_across_heads",
         eps: float = 1e-6,
-        image_embedding_dim: Optional[int] = None,
+        image_dim: Optional[int] = None,
         added_kv_proj_dim: Optional[int] = None,
         rope_max_seq_len: int = 1024,
     ) -> None:
@@ -368,7 +368,7 @@ def __init__(
             time_freq_dim=freq_dim,
             time_proj_dim=inner_dim * 6,
             text_embed_dim=text_dim,
-            image_embedding_dim=image_embedding_dim,
+            image_embed_dim=image_dim,
         )

         # 3. Transformer blocks
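
For reference, a minimal construction sketch under the new keyword names: `image_embed_dim` on the condition embedder, `image_dim` on the transformer config. The import path and the dimension values below are assumptions (this appears to be the Wan transformer in diffusers) and are not part of the diff itself.

```python
# Hypothetical sketch, assuming this diff edits diffusers'
# models/transformers/transformer_wan.py; the sizes are illustrative only.
from diffusers.models.transformers.transformer_wan import WanTimeTextImageEmbedding

embedder = WanTimeTextImageEmbedding(
    dim=64,                # transformer inner width
    time_freq_dim=32,      # sinusoidal timestep frequency channels
    time_proj_dim=64 * 6,  # the transformer passes inner_dim * 6 here
    text_embed_dim=48,
    image_embed_dim=40,    # renamed from `image_embedding_dim`
)

# The image projection is only built when `image_embed_dim` is given.
assert embedder.image_embedder is not None
```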