2424from ..attention_processor import SanaMultiscaleLinearAttention
2525from ..modeling_utils import ModelMixin
2626from ..normalization import RMSNorm , get_normalization
27+ from .vae import DecoderOutput
2728
2829
2930class GLUMBConv (nn .Module ):
@@ -90,8 +91,8 @@ class EfficientViTBlock(nn.Module):
9091 def __init__ (
9192 self ,
9293 in_channels : int ,
93- heads_ratio : float = 1.0 ,
94- dim : int = 32 ,
94+ mult : float = 1.0 ,
95+ attention_head_dim : int = 32 ,
9596 qkv_multiscales : Tuple [int , ...] = (5 ,),
9697 norm_type : str = "batch_norm" ,
9798 ) -> None :
@@ -100,8 +101,8 @@ def __init__(
100101 self .attn = SanaMultiscaleLinearAttention (
101102 in_channels = in_channels ,
102103 out_channels = in_channels ,
103- heads_ratio = heads_ratio ,
104- attention_head_dim = dim ,
104+ mult = mult ,
105+ attention_head_dim = attention_head_dim ,
105106 norm_type = norm_type ,
106107 kernel_sizes = qkv_multiscales ,
107108 residual_connection = True ,
@@ -122,6 +123,7 @@ def get_block(
122123 block_type : str ,
123124 in_channels : int ,
124125 out_channels : int ,
126+ attention_head_dim : int ,
125127 norm_type : str ,
126128 act_fn : str ,
127129 qkv_mutliscales : Tuple [int ] = (),
@@ -130,7 +132,9 @@ def get_block(
130132 block = ResBlock (in_channels , out_channels , norm_type , act_fn )
131133
132134 elif block_type == "EfficientViTBlock" :
133- block = EfficientViTBlock (in_channels , norm_type = norm_type , qkv_multiscales = qkv_mutliscales )
135+ block = EfficientViTBlock (
136+ in_channels , attention_head_dim = attention_head_dim , norm_type = norm_type , qkv_multiscales = qkv_mutliscales
137+ )
134138
135139 else :
136140 raise ValueError (f"Block with { block_type = } is not supported." )
@@ -224,6 +228,7 @@ def __init__(
224228 self ,
225229 in_channels : int ,
226230 latent_channels : int ,
231+ attention_head_dim : int = 32 ,
227232 block_type : Union [str , Tuple [str ]] = "ResBlock" ,
228233 block_out_channels : Tuple [int ] = (128 , 256 , 512 , 512 , 1024 , 1024 ),
229234 layers_per_block : Tuple [int ] = (2 , 2 , 2 , 2 , 2 , 2 ),
@@ -262,6 +267,7 @@ def __init__(
262267 block_type [i ],
263268 out_channel ,
264269 out_channel ,
270+ attention_head_dim = attention_head_dim ,
265271 norm_type = "rms_norm" ,
266272 act_fn = "silu" ,
267273 qkv_mutliscales = qkv_multiscales [i ],
@@ -305,6 +311,7 @@ def __init__(
305311 self ,
306312 in_channels : int ,
307313 latent_channels : int ,
314+ attention_head_dim : int = 32 ,
308315 block_type : Union [str , Tuple [str ]] = "ResBlock" ,
309316 block_out_channels : Tuple [int ] = (128 , 256 , 512 , 512 , 1024 , 1024 ),
310317 layers_per_block : Tuple [int ] = (2 , 2 , 2 , 2 , 2 , 2 ),
@@ -348,6 +355,7 @@ def __init__(
348355 block_type [i ],
349356 out_channel ,
350357 out_channel ,
358+ attention_head_dim = attention_head_dim ,
351359 norm_type = norm_type [i ],
352360 act_fn = act_fn [i ],
353361 qkv_mutliscales = qkv_multiscales [i ],
@@ -425,13 +433,14 @@ class AutoencoderDC(ModelMixin, ConfigMixin):
425433 A scaling factor applied during model operations.
426434 """
427435
428- _supports_gradient_checkpointing = True
436+ _supports_gradient_checkpointing = False
429437
430438 @register_to_config
431439 def __init__ (
432440 self ,
433441 in_channels : int = 3 ,
434442 latent_channels : int = 32 ,
443+ attention_head_dim : int = 32 ,
435444 encoder_block_types : Union [str , Tuple [str ]] = "ResBlock" ,
436445 decoder_block_types : Union [str , Tuple [str ]] = "ResBlock" ,
437446 encoder_block_out_channels : Tuple [int , ...] = (128 , 256 , 512 , 512 , 1024 , 1024 ),
@@ -451,6 +460,7 @@ def __init__(
451460 self .encoder = Encoder (
452461 in_channels = in_channels ,
453462 latent_channels = latent_channels ,
463+ attention_head_dim = attention_head_dim ,
454464 block_type = encoder_block_types ,
455465 block_out_channels = encoder_block_out_channels ,
456466 layers_per_block = encoder_layers_per_block ,
@@ -460,6 +470,7 @@ def __init__(
460470 self .decoder = Decoder (
461471 in_channels = in_channels ,
462472 latent_channels = latent_channels ,
473+ attention_head_dim = attention_head_dim ,
463474 block_type = decoder_block_types ,
464475 block_out_channels = decoder_block_out_channels ,
465476 layers_per_block = decoder_layers_per_block ,
@@ -480,7 +491,9 @@ def decode(self, x: torch.Tensor) -> torch.Tensor:
480491 x = self .decoder (x )
481492 return x
482493
483- def forward (self , x : torch .Tensor ) -> torch .Tensor :
484- x = self .encoder (x )
485- x = self .decoder (x )
486- return x
494+ def forward (self , sample : torch .Tensor , return_dict : bool = True ) -> torch .Tensor :
495+ z = self .encode (sample )
496+ dec = self .decode (z )
497+ if not return_dict :
498+ return (dec ,)
499+ return DecoderOutput (sample = dec )