66from __future__ import annotations
77
88import itertools
9- from typing import Any , Literal
9+ from typing import Any , Literal , Optional , Union , List , Dict
1010
1111from pytabkit .models .nn_models import rtdl_num_embeddings
1212
@@ -80,7 +80,7 @@ class OneHotEncoding0d(nn.Module):
8080 # Input: (*, n_cat_features=len(cardinalities))
8181 # Output: (*, sum(cardinalities))
8282
83- def __init__ (self , cardinalities : list [int ]) -> None :
83+ def __init__ (self , cardinalities : List [int ]) -> None :
8484 super ().__init__ ()
8585 self ._cardinalities = cardinalities
8686
@@ -161,9 +161,9 @@ class LinearEfficientEnsemble(nn.Module):
161161 avoids the term "adapter".
162162 """
163163
164- r : None | Tensor
165- s : None | Tensor
166- bias : None | Tensor
164+ r : Optional [ Tensor ]
165+ s : Optional [ Tensor ]
166+ bias : Optional [ Tensor ]
167167
168168 def __init__ (
169169 self ,
@@ -261,8 +261,8 @@ class MLP(nn.Module):
261261 def __init__ (
262262 self ,
263263 * ,
264- d_in : None | int = None ,
265- d_out : None | int = None ,
264+ d_in : Optional [ int ] = None ,
265+ d_out : Optional [ int ] = None ,
266266 n_blocks : int ,
267267 d_block : int ,
268268 dropout : float ,
@@ -326,7 +326,7 @@ def _get_first_ensemble_layer(backbone: MLP) -> LinearEfficientEnsemble:
326326def _init_first_adapter (
327327 weight : Tensor ,
328328 distribution : Literal ['normal' , 'random-signs' ],
329- init_sections : list [int ],
329+ init_sections : List [int ],
330330) -> None :
331331 """Initialize the first adapter.
332332
@@ -389,20 +389,16 @@ def default_zero_weight_decay_condition(
389389 del module_name , parameter
390390 return parameter_name .endswith ('bias' ) or isinstance (
391391 module ,
392- nn .BatchNorm1d
393- | nn .LayerNorm
394- | nn .InstanceNorm1d
395- | rtdl_num_embeddings .LinearEmbeddings
396- | rtdl_num_embeddings .LinearReLUEmbeddings
397- | rtdl_num_embeddings ._Periodic ,
392+ (nn .BatchNorm1d , nn .LayerNorm , nn .InstanceNorm1d , rtdl_num_embeddings .LinearEmbeddings ,
393+ rtdl_num_embeddings .LinearReLUEmbeddings , rtdl_num_embeddings ._Periodic ),
398394 )
399395
400396
401397def make_parameter_groups (
402398 module : nn .Module ,
403399 zero_weight_decay_condition = default_zero_weight_decay_condition ,
404- custom_groups : None | list [ dict [ str , Any ]] = None ,
405- ) -> list [ dict [str , Any ]]:
400+ custom_groups : Optional [ List [ Dict [ str , Any ] ]] = None ,
401+ ) -> List [ Dict [str , Any ]]:
406402 if custom_groups is None :
407403 custom_groups = []
408404 custom_params = frozenset (
@@ -441,11 +437,11 @@ def __init__(
441437 self ,
442438 * ,
443439 n_num_features : int ,
444- cat_cardinalities : list [int ],
445- n_classes : None | int ,
440+ cat_cardinalities : List [int ],
441+ n_classes : Optional [ int ] ,
446442 backbone : dict ,
447- bins : None | list [ Tensor ], # For piecewise-linear encoding/embeddings.
448- num_embeddings : None | dict = None ,
443+ bins : Optional [ List [ Tensor ] ], # For piecewise-linear encoding/embeddings.
444+ num_embeddings : Optional [ Dict ] = None ,
449445 arch_type : Literal [
450446 # Plain feed-forward network without any kind of ensembling.
451447 'plain' ,
@@ -467,7 +463,7 @@ def __init__(
467463 # This variant was not used in the paper.
468464 'tabm-mini-normal' ,
469465 ],
470- k : None | int = None ,
466+ k : Optional [ int ] = None ,
471467 share_training_batches : bool = True ,
472468 ) -> None :
473469 # >>> Validate arguments.
@@ -596,7 +592,7 @@ def __init__(
596592 self .share_training_batches = share_training_batches
597593
598594 def forward (
599- self , x_num : None | Tensor = None , x_cat : None | Tensor = None
595+ self , x_num : Optional [ Tensor ] = None , x_cat : Optional [ Tensor ] = None
600596 ) -> Tensor :
601597 x = []
602598 if x_num is not None :
0 commit comments