@@ -249,7 +249,7 @@ class ChromEncoder(nn.Module):
249249 processing everything to be the same dimensionality, concatenate to form a single
250250 latent dimension."""
251251
252- def __init__ (self , num_inputs : List [int ], latent_dim : int = 32 , activation = nn .PReLU ):
252+ def __init__ (self , num_inputs : list [int ], latent_dim : int = 32 , activation = nn .PReLU ):
253253 super ().__init__ ()
254254 self .num_inputs = num_inputs
255255 self .act = activation
@@ -290,7 +290,7 @@ class ChromDecoder(nn.Module):
290290
291291 def __init__ (
292292 self ,
293- num_outputs : List [int ], # Per-chromosome list of output sizes
293+ num_outputs : list [int ], # Per-chromosome list of output sizes
294294 latent_dim : int = 32 ,
295295 activation = nn .PReLU ,
296296 final_activations = [Exp (), ClippedSoftplus ()],
@@ -513,7 +513,7 @@ def forward_single(self, x, size_factors=None, in_domain: int = 1, out_domain: i
513513 decoded = decoder (encoded )
514514 return self ._combine_output_and_encoded (decoded , encoded , num_non_latent_out )
515515
516- def forward (self , x , size_factors = None , mode : Union [None , Tuple [int , int ]] = None ):
516+ def forward (self , x , size_factors = None , mode : Union [None , tuple [int , int ]] = None ):
517517 if self .flat_mode :
518518 x = self .split_catted_input (x )
519519 assert isinstance (x , (tuple , list ))
@@ -552,11 +552,11 @@ def __init__(
552552 input_dim1 : int ,
553553 input_dim2 : int ,
554554 hidden_dim : int = 16 ,
555- final_activations1 : Union [Callable , List [Callable ]] = [
555+ final_activations1 : Union [Callable , list [Callable ]] = [
556556 Exp (),
557557 ClippedSoftplus (),
558558 ],
559- final_activations2 : Union [Callable , List [Callable ]] = nn .Sigmoid (),
559+ final_activations2 : Union [Callable , list [Callable ]] = nn .Sigmoid (),
560560 flat_mode : bool = True , # Controls if we have to re-split inputs
561561 seed : int = 182822 ,
562562 ):
@@ -633,7 +633,7 @@ def forward_single(self, x, size_factors=None, in_domain: int = 1, out_domain: i
633633 assert isinstance (retval [0 ], (torch .TensorType , torch .Tensor ))
634634 return retval
635635
636- def forward (self , x , size_factors = None , mode : Union [None , Tuple [int , int ]] = None ):
636+ def forward (self , x , size_factors = None , mode : Union [None , tuple [int , int ]] = None ):
637637 if self .flat_mode :
638638 x = self .split_catted_input (x )
639639 assert isinstance (x , (tuple , list ))
@@ -662,7 +662,7 @@ class AssymSplicedAutoEncoder(SplicedAutoEncoder):
662662 def __init__ (
663663 self ,
664664 input_dim1 : int ,
665- input_dim2 : List [int ],
665+ input_dim2 : list [int ],
666666 hidden_dim : int = 16 ,
667667 final_activations1 : list = [Exp (), ClippedSoftplus ()],
668668 final_activations2 = nn .Sigmoid (),
0 commit comments