File tree Expand file tree Collapse file tree 1 file changed +5
-0
lines changed 
src/diffusers/models/autoencoders Expand file tree Collapse file tree 1 file changed +5
-0
lines changed Original file line number Diff line number Diff line change @@ -981,7 +981,9 @@ def __init__(
981981        # timestep embedding 
982982        self .time_embedder  =  None 
983983        self .scale_shift_table  =  None 
984+         self .timestep_scale_multiplier  =  None 
984985        if  timestep_conditioning :
986+             self .timestep_scale_multiplier  =  nn .Parameter (torch .tensor (1000.0 , dtype = torch .float32 ))
985987            self .time_embedder  =  PixArtAlphaCombinedTimestepSizeEmbeddings (output_channel  *  2 , 0 )
986988            self .scale_shift_table  =  nn .Parameter (torch .randn (2 , output_channel ) /  output_channel ** 0.5 )
987989
@@ -990,6 +992,9 @@ def __init__(
990992    def  forward (self , hidden_states : torch .Tensor , temb : Optional [torch .Tensor ] =  None ) ->  torch .Tensor :
991993        hidden_states  =  self .conv_in (hidden_states )
992994
995+         if  self .timestep_scale_multiplier  is  not None :
996+             temb  =  temb  *  self .timestep_scale_multiplier 
997+ 
993998        if  torch .is_grad_enabled () and  self .gradient_checkpointing :
994999            hidden_states  =  self ._gradient_checkpointing_func (self .mid_block , hidden_states , temb )
9951000
 
 
   
 
     
   
   
          
    
    
     
    
      
     
     
    You can’t perform that action at this time.
  
 
    
  
    
      
        
     
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments