@@ -248,7 +248,6 @@ def forward(
248248
249249
250250class SanaTransformer2DModel (ModelMixin , ConfigMixin ):
251- # TODO: Change pixart name below
252251 r"""
253252 A 2D Transformer model as introduced in Sana family of models (https://arxiv.org/abs/2310.00426,
254253 https://arxiv.org/abs/2403.04692).
@@ -272,6 +271,8 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin):
272271 The width of the latent images. This parameter is fixed during training.
273272 patch_size (int, defaults to 1):
274273 Size of the patches the model processes, relevant for architectures working on non-sequential data.
274+ activation_fn (str, optional, defaults to "gelu-approximate"):
275+ Activation function to use in feed-forward networks within Transformer blocks.
275276 num_embeds_ada_norm (int, optional, defaults to 1000):
276277 Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during
277278 inference.
@@ -311,6 +312,7 @@ def __init__(
311312 attention_bias : bool = True ,
312313 sample_size : int = 32 ,
313314 patch_size : int = 1 ,
315+ activation_fn : str = "gelu-approximate" ,
314316 num_embeds_ada_norm : Optional [int ] = 1000 ,
315317 upcast_attention : bool = False ,
316318 norm_type : str = "ada_norm_single" ,
0 commit comments