Skip to content

Commit 8a03aeb

Browse files
committed
Try to fix TabM Python 3.9 compatibility (replace PEP 604/585 annotations with `typing.Optional`/`List`/`Dict`)
1 parent eba8a36 commit 8a03aeb

File tree

1 file changed

+18
-22
lines changed

1 file changed

+18
-22
lines changed

pytabkit/models/nn_models/tabm.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from __future__ import annotations
77

88
import itertools
9-
from typing import Any, Literal
9+
from typing import Any, Literal, Optional, Union, List, Dict
1010

1111
from pytabkit.models.nn_models import rtdl_num_embeddings
1212

@@ -80,7 +80,7 @@ class OneHotEncoding0d(nn.Module):
8080
# Input: (*, n_cat_features=len(cardinalities))
8181
# Output: (*, sum(cardinalities))
8282

83-
def __init__(self, cardinalities: list[int]) -> None:
83+
def __init__(self, cardinalities: List[int]) -> None:
8484
super().__init__()
8585
self._cardinalities = cardinalities
8686

@@ -161,9 +161,9 @@ class LinearEfficientEnsemble(nn.Module):
161161
avoids the term "adapter".
162162
"""
163163

164-
r: None | Tensor
165-
s: None | Tensor
166-
bias: None | Tensor
164+
r: Optional[Tensor]
165+
s: Optional[Tensor]
166+
bias: Optional[Tensor]
167167

168168
def __init__(
169169
self,
@@ -261,8 +261,8 @@ class MLP(nn.Module):
261261
def __init__(
262262
self,
263263
*,
264-
d_in: None | int = None,
265-
d_out: None | int = None,
264+
d_in: Optional[int] = None,
265+
d_out: Optional[int] = None,
266266
n_blocks: int,
267267
d_block: int,
268268
dropout: float,
@@ -326,7 +326,7 @@ def _get_first_ensemble_layer(backbone: MLP) -> LinearEfficientEnsemble:
326326
def _init_first_adapter(
327327
weight: Tensor,
328328
distribution: Literal['normal', 'random-signs'],
329-
init_sections: list[int],
329+
init_sections: List[int],
330330
) -> None:
331331
"""Initialize the first adapter.
332332
@@ -389,20 +389,16 @@ def default_zero_weight_decay_condition(
389389
del module_name, parameter
390390
return parameter_name.endswith('bias') or isinstance(
391391
module,
392-
nn.BatchNorm1d
393-
| nn.LayerNorm
394-
| nn.InstanceNorm1d
395-
| rtdl_num_embeddings.LinearEmbeddings
396-
| rtdl_num_embeddings.LinearReLUEmbeddings
397-
| rtdl_num_embeddings._Periodic,
392+
(nn.BatchNorm1d, nn.LayerNorm, nn.InstanceNorm1d, rtdl_num_embeddings.LinearEmbeddings,
393+
rtdl_num_embeddings.LinearReLUEmbeddings, rtdl_num_embeddings._Periodic),
398394
)
399395

400396

401397
def make_parameter_groups(
402398
module: nn.Module,
403399
zero_weight_decay_condition=default_zero_weight_decay_condition,
404-
custom_groups: None | list[dict[str, Any]] = None,
405-
) -> list[dict[str, Any]]:
400+
custom_groups: Optional[List[Dict[str, Any]]] = None,
401+
) -> List[Dict[str, Any]]:
406402
if custom_groups is None:
407403
custom_groups = []
408404
custom_params = frozenset(
@@ -441,11 +437,11 @@ def __init__(
441437
self,
442438
*,
443439
n_num_features: int,
444-
cat_cardinalities: list[int],
445-
n_classes: None | int,
440+
cat_cardinalities: List[int],
441+
n_classes: Optional[int],
446442
backbone: dict,
447-
bins: None | list[Tensor], # For piecewise-linear encoding/embeddings.
448-
num_embeddings: None | dict = None,
443+
bins: Optional[List[Tensor]], # For piecewise-linear encoding/embeddings.
444+
num_embeddings: Optional[Dict] = None,
449445
arch_type: Literal[
450446
# Plain feed-forward network without any kind of ensembling.
451447
'plain',
@@ -467,7 +463,7 @@ def __init__(
467463
# This variant was not used in the paper.
468464
'tabm-mini-normal',
469465
],
470-
k: None | int = None,
466+
k: Optional[int] = None,
471467
share_training_batches: bool = True,
472468
) -> None:
473469
# >>> Validate arguments.
@@ -596,7 +592,7 @@ def __init__(
596592
self.share_training_batches = share_training_batches
597593

598594
def forward(
599-
self, x_num: None | Tensor = None, x_cat: None | Tensor = None
595+
self, x_num: Optional[Tensor] = None, x_cat: Optional[Tensor] = None
600596
) -> Tensor:
601597
x = []
602598
if x_num is not None:

0 commit comments

Comments
 (0)