@@ -1637,35 +1637,6 @@ def forward(self, timestep, guidance, pooled_projection):
         return conditioning


-class CombinedTimestepTextProjChromaEmbeddings(nn.Module):
-    def __init__(self, factor: int, hidden_dim: int, out_dim: int, n_layers: int, embedding_dim: int):
-        super().__init__()
-
-        self.time_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
-        self.guidance_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
-
-        self.register_buffer(
-            "mod_proj",
-            get_timestep_embedding(torch.arange(out_dim) * 1000, 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0),
-            persistent=False,
-        )
-
-    def forward(
-        self, timestep: torch.Tensor, guidance: Optional[torch.Tensor], pooled_projections: torch.Tensor
-    ) -> torch.Tensor:
-        mod_index_length = self.mod_proj.shape[0]
-        timesteps_proj = self.time_proj(timestep).to(dtype=timestep.dtype)
-        guidance_proj = self.guidance_proj(torch.tensor([0])).to(dtype=timestep.dtype, device=timestep.device)
-
-        mod_proj = self.mod_proj.to(dtype=timesteps_proj.dtype, device=timesteps_proj.device)
-        timestep_guidance = (
-            torch.cat([timesteps_proj, guidance_proj], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
-        )
-        input_vec = torch.cat([timestep_guidance, mod_proj.unsqueeze(0)], dim=-1)
-
-        return input_vec
-
-
 class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
     def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
         super().__init__()
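Note (not part of the diff): the removed CombinedTimestepTextProjChromaEmbeddings packed three sinusoidal pieces, the timestep projection, a guidance projection pinned to a constant 0, and one fixed embedding per modulation index, into a single (1, out_dim, 4 * factor) tensor. Below is a minimal standalone sketch of that computation; the concrete sizes (factor=16, out_dim=344) are illustrative assumptions, not values taken from this PR.

    import torch
    from diffusers.models.embeddings import Timesteps, get_timestep_embedding

    factor, out_dim = 16, 344  # assumed Chroma-like sizes, for illustration only
    time_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)

    timestep = torch.tensor([1.0])                  # batch size 1: the dim=1 cat below relies on it
    timesteps_proj = time_proj(timestep)            # (1, factor)
    guidance_proj = time_proj(torch.tensor([0.0]))  # (1, factor); guidance embedded as a constant 0
    mod_proj = get_timestep_embedding(              # (out_dim, 2 * factor), one row per modulation index
        torch.arange(out_dim) * 1000, 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0
    )

    tg = torch.cat([timesteps_proj, guidance_proj], dim=1)      # (1, 2 * factor)
    tg = tg.unsqueeze(1).repeat(1, out_dim, 1)                  # (1, out_dim, 2 * factor)
    input_vec = torch.cat([tg, mod_proj.unsqueeze(0)], dim=-1)  # (1, out_dim, 4 * factor)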
@@ -2259,25 +2230,6 @@ def forward(self, caption):
         return hidden_states


-class ChromaApproximator(nn.Module):
-    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers: int = 5):
-        super().__init__()
-        self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True)
-        self.layers = nn.ModuleList(
-            [PixArtAlphaTextProjection(hidden_dim, hidden_dim, act_fn="silu") for _ in range(n_layers)]
-        )
-        self.norms = nn.ModuleList([nn.RMSNorm(hidden_dim) for _ in range(n_layers)])
-        self.out_proj = nn.Linear(hidden_dim, out_dim)
-
-    def forward(self, x):
-        x = self.in_proj(x)
-
-        for layer, norms in zip(self.layers, self.norms):
-            x = x + layer(norms(x))
-
-        return self.out_proj(x)
-
-
 class IPAdapterPlusImageProjectionBlock(nn.Module):
     def __init__(
         self,
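Note (not part of the diff): the removed ChromaApproximator was the MLP that consumed the tensor built above, a stack of residual PixArtAlphaTextProjection blocks, each preceded by an RMSNorm, between two linear projections. A minimal usage sketch, assuming the class definition from the left-hand side of this hunk is in scope; in_dim = 4 * factor = 64 matches the concatenation above, while hidden_dim=5120 and out_dim=3072 are assumed sizes.

    import torch

    approximator = ChromaApproximator(in_dim=64, out_dim=3072, hidden_dim=5120, n_layers=5)
    input_vec = torch.randn(1, 344, 64)    # (batch, mod_index_length, in_dim), as produced above
    mod_vectors = approximator(input_vec)  # (1, 344, 3072): one modulation vector per index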