@@ -211,7 +211,7 @@ def __init__(
         scales: tuple[int, ...] = (5,),
         eps=1.0e-15,
     ):
-        super(LiteMLA, self).__init__()
+        super().__init__()
         self.eps = eps
         heads = int(in_channels // dim * heads_ratio) if heads is None else heads
@@ -253,7 +253,6 @@ def __init__(
             act_func=act_func[1],
         )

-    @torch.autocast(device_type="cuda", enabled=False)
     def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor:
         B, _, H, W = list(qkv.size())
@@ -292,7 +291,6 @@ def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor:
         out = torch.reshape(out, (B, -1, H, W))
         return out

-    @torch.autocast(device_type="cuda", enabled=False)
     def relu_quadratic_att(self, qkv: torch.Tensor) -> torch.Tensor:
         B, _, H, W = list(qkv.size())
@@ -657,11 +655,12 @@ def __init__(
         super().__init__()
         num_stages = len(width_list)
         self.num_stages = num_stages
-        assert len(depth_list) == num_stages
-        assert len(width_list) == num_stages
-        assert isinstance(block_type, str) or (
-            isinstance(block_type, list) and len(block_type) == num_stages
-        )
+
+        # validate config
+        if len(depth_list) != num_stages or len(width_list) != num_stages:
+            raise ValueError(f"len(depth_list) {len(depth_list)} and len(width_list) {len(width_list)} should be equal to num_stages {num_stages}")
+        if not isinstance(block_type, (str, list)) or (isinstance(block_type, list) and len(block_type) != num_stages):
+            raise ValueError(f"block_type should be either a str or a list of str with length {num_stages}, but got {block_type}")

         self.project_in = build_encoder_project_in_block(
             in_channels=in_channels,
@@ -725,13 +724,16 @@ def __init__(
         super().__init__()
         num_stages = len(width_list)
         self.num_stages = num_stages
-        assert len(depth_list) == num_stages
-        assert len(width_list) == num_stages
-        assert isinstance(block_type, str) or (
-            isinstance(block_type, list) and len(block_type) == num_stages
-        )
-        assert isinstance(norm, str) or (isinstance(norm, list) and len(norm) == num_stages)
-        assert isinstance(act, str) or (isinstance(act, list) and len(act) == num_stages)
+
+        # validate config
+        if len(depth_list) != num_stages or len(width_list) != num_stages:
+            raise ValueError(f"len(depth_list) {len(depth_list)} and len(width_list) {len(width_list)} should be equal to num_stages {num_stages}")
+        if not isinstance(block_type, (str, list)) or (isinstance(block_type, list) and len(block_type) != num_stages):
+            raise ValueError(f"block_type should be either a str or a list of str with length {num_stages}, but got {block_type}")
+        if not isinstance(norm, (str, list)) or (isinstance(norm, list) and len(norm) != num_stages):
+            raise ValueError(f"norm should be either a str or a list of str with length {num_stages}, but got {norm}")
+        if not isinstance(act, (str, list)) or (isinstance(act, list) and len(act) != num_stages):
+            raise ValueError(f"act should be either a str or a list of str with length {num_stages}, but got {act}")

         self.project_in = build_decoder_project_in_block(
             in_channels=latent_channels,
0 commit comments