File tree Expand file tree Collapse file tree 1 file changed +5
-0
lines changed
src/diffusers/models/autoencoders Expand file tree Collapse file tree 1 file changed +5
-0
lines changed Original file line number Diff line number Diff line change @@ -992,7 +992,9 @@ def __init__(
992992 # timestep embedding
993993 self .time_embedder = None
994994 self .scale_shift_table = None
995+ self .timestep_scale_multiplier = None
995996 if timestep_conditioning :
997+ self .timestep_scale_multiplier = nn .Parameter (torch .tensor (1000.0 , dtype = torch .float32 ))
996998 self .time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings (output_channel * 2 , 0 )
997999 self .scale_shift_table = nn .Parameter (torch .randn (2 , output_channel ) / output_channel ** 0.5 )
9981000
@@ -1001,6 +1003,9 @@ def __init__(
10011003 def forward (self , hidden_states : torch .Tensor , temb : Optional [torch .Tensor ] = None ) -> torch .Tensor :
10021004 hidden_states = self .conv_in (hidden_states )
10031005
1006+ if self .timestep_scale_multiplier is not None :
1007+ temb = temb * self .timestep_scale_multiplier
1008+
10041009 if torch .is_grad_enabled () and self .gradient_checkpointing :
10051010 hidden_states = self ._gradient_checkpointing_func (self .mid_block , hidden_states , temb )
10061011
You can’t perform that action at this time.
0 commit comments