 from ...utils import is_torch_version
 from ..attention import FeedForward
 from ..attention_processor import Attention, AttentionProcessor
-from ..embeddings import get_1d_rotary_pos_embed
+from ..embeddings import get_1d_rotary_pos_embed, get_timestep_embedding
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
@@ -219,7 +219,8 @@ def __init__(
         )

     def forward(self, t):
-        t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
+        # t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
+        t_freq = get_timestep_embedding(t, self.frequency_embedding_size, flip_sin_to_cos=True, max_period=self.max_period, downscale_freq_shift=0).type(self.mlp[0].weight.dtype)
         t_emb = self.mlp(t_freq)
         return t_emb

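The replacement relies on two non-default arguments: flip_sin_to_cos=True restores cosine-first ordering, and downscale_freq_shift=0 removes the frequency shift that get_timestep_embedding applies by default. A minimal equivalence check, assuming the removed timestep_embedding helper followed the standard DiT-style sinusoidal embedding (the helper below is reconstructed for illustration, not taken from this file):

import math

import torch
from diffusers.models.embeddings import get_timestep_embedding


def dit_timestep_embedding(t, dim, max_period=10000):
    # Assumed shape of the removed helper: cos-first ordering, frequencies exp(-log(max_period) * i / half).
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)


t = torch.arange(0, 1000, 250)
reference = dit_timestep_embedding(t, 256)
replacement = get_timestep_embedding(t, 256, flip_sin_to_cos=True, max_period=10000, downscale_freq_shift=0)
# The two embeddings match element-wise under these settings.
assert torch.allclose(reference, replacement, atol=1e-6)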
@@ -231,24 +232,22 @@ def __init__(
         attention_head_dim: int,
         mlp_width_ratio: str = 4.0,
         mlp_drop_rate: float = 0.0,
-        qkv_bias: bool = True,
+        attention_bias: bool = True,
     ) -> None:
         super().__init__()

         hidden_size = num_attention_heads * attention_head_dim

         self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
-
         self.attn = Attention(
             query_dim=hidden_size,
             cross_attention_dim=None,
             heads=num_attention_heads,
             dim_head=attention_head_dim,
-            bias=True,
+            bias=attention_bias,
         )

         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
-
         self.mlp = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="silu", dropout=mlp_drop_rate)

         self.adaLN_modulation = nn.Sequential(
@@ -286,8 +285,8 @@ def __init__(
         num_layers: int,
         mlp_width_ratio: float = 4.0,
         mlp_drop_rate: float = 0.0,
-        qkv_bias: bool = True,
-    ):
+        attention_bias: bool = True,
+    ) -> None:
         super().__init__()

         self.refiner_blocks = nn.ModuleList(
@@ -297,7 +296,7 @@ def __init__(
                     attention_head_dim=attention_head_dim,
                     mlp_width_ratio=mlp_width_ratio,
                     mlp_drop_rate=mlp_drop_rate,
-                    qkv_bias=qkv_bias,
+                    attention_bias=attention_bias,
                 )
                 for _ in range(num_layers)
             ]
@@ -308,7 +307,7 @@ def forward(
         hidden_states: torch.Tensor,
         temb: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-    ):
+    ) -> None:
         self_attn_mask = None
         if attention_mask is not None:
             batch_size = attention_mask.shape[0]
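The mask handling that starts here typically expands the per-token padding mask into a pairwise mask so that padded tokens neither attend nor get attended to. A standalone sketch of that pattern; the helper name and exact construction are illustrative, not the module's own code:

import torch


def build_self_attention_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # attention_mask: (batch, seq_len) with 1 for real tokens and 0 for padding.
    # Returns a boolean (batch, 1, seq_len, seq_len) mask that allows a pair of
    # positions only when both tokens are real.
    batch_size, seq_len = attention_mask.shape
    mask = attention_mask.bool().view(batch_size, 1, 1, seq_len)
    return mask & mask.transpose(-1, -2)


mask = torch.tensor([[1, 1, 1, 0, 0]])
print(build_self_attention_mask(mask)[0, 0])  # 3x3 block of True; rows/columns for padded tokens are False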
@@ -334,13 +333,15 @@ def __init__(
         num_layers: int,
         mlp_ratio: float = 4.0,
         mlp_drop_rate: float = 0.0,
-        qkv_bias: bool = True,
-    ):
+        attention_bias: bool = True,
+    ) -> None:
         super().__init__()

         hidden_size = num_attention_heads * attention_head_dim

         self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True)
+        # self.time_embed = TimestepEmbedder(hidden_size, nn.SiLU)
+        # self.context_embed = TextProjection(in_channels, hidden_size, nn.SiLU)
         self.t_embedder = TimestepEmbedder(hidden_size, nn.SiLU)
         self.c_embedder = TextProjection(in_channels, hidden_size, nn.SiLU)

@@ -350,7 +351,7 @@ def __init__(
             num_layers=num_layers,
             mlp_width_ratio=mlp_ratio,
             mlp_drop_rate=mlp_drop_rate,
-            qkv_bias=qkv_bias,
+            attention_bias=attention_bias,
         )

     def forward(
@@ -360,6 +361,7 @@ def forward(
         attention_mask: Optional[torch.LongTensor] = None,
     ) -> torch.Tensor:
         original_dtype = hidden_states.dtype
+        # temb = self.time_embed(timestep)
         temb = self.t_embedder(timestep)

         if attention_mask is None:
@@ -369,6 +371,7 @@ def forward(
             pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
             pooled_projections = pooled_projections.to(original_dtype)

+        # pooled_projections = self.context_embed(pooled_projections)
         pooled_projections = self.c_embedder(pooled_projections)
         emb = temb + pooled_projections

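The pooling context lines above compute a masked mean over the prompt tokens, so padding positions do not dilute the pooled vector that is added to the timestep embedding. A small self-contained sketch of that pattern (tensor shapes are illustrative):

import torch

hidden_states = torch.randn(2, 5, 8)                      # (batch, seq_len, dim) prompt token features
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])           # 1 = real token, 0 = padding

mask_float = attention_mask.float().unsqueeze(-1)          # (batch, seq_len, 1)
pooled = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)

# With no padding, the masked mean reduces to a plain mean over the sequence dimension.
assert torch.allclose(pooled[1], hidden_states[1].mean(dim=0), atol=1e-6)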