
Commit 01a7212

Merge pull request #244 from basf/develop
Include new model: Tangos. Rename self.base_model to self.estimator in the sklearn interface.
2 parents 180f126 + 9fe0cea · commit 01a7212

15 files changed (+382 −45 lines)

README.md

Lines changed: 4 additions & 1 deletion
```diff
@@ -23,6 +23,8 @@ Mambular is a Python library for tabular deep learning. It includes models that

 <h3>⚡ What's New ⚡</h3>
 <ul>
+  <li>New Models: `Tangos`, `AutoInt`, `Trompt`</li>
+  <li>Pretraining optionality for suitable models.</li>
   <li>Individual preprocessing: preprocess each feature differently, use pre-trained models for categorical encoding</li>
   <li>Extract latent representations of tables</li>
   <li>Use embeddings as inputs</li>
@@ -78,7 +80,8 @@ Mambular is a Python package that brings the power of advanced deep learning arc
 | `NDTF` | A neural decision forest using soft decision trees. See [Kontschieder et al.](https://openaccess.thecvf.com/content_iccv_2015/html/Kontschieder_Deep_Neural_Decision_ICCV_2015_paper.html) for inspiration. |
 | `SAINT` | Improve neural networks via Row Attention and Contrastive Pre-Training, introduced [here](https://arxiv.org/pdf/2106.01342). |
 | `AutoInt` | Automatic Feature Interaction Learning via Self-Attentive Neural Networks introduced [here](https://arxiv.org/abs/1810.11921). |
-| `Trompt ` | Trompt: Towards a Better Deep Neural Network for Tabular Data introduced [here](https://arxiv.org/abs/2305.18446). |
+| `Trompt` | Trompt: Towards a Better Deep Neural Network for Tabular Data introduced [here](https://arxiv.org/abs/2305.18446). |
+| `Tangos` | Tangos: Regularizing Tabular Neural Networks through Gradient Orthogonalization and Specialization introduced [here](https://openreview.net/pdf?id=n6H86gW8u0d). |
```
mambular/__version__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -17,5 +17,5 @@

 # The following line *must* be the last in the module, exactly as formatted:

-__version__ = "1.3.0"
+__version__ = "1.3.1"
```
mambular/base_models/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -13,8 +13,10 @@
 from .autoint import AutoInt
 from .trompt import Trompt
 from .enode import ENODE
+from .tangos import Tangos

 __all__ = [
+    "Tangos",
     "ENODE",
     "Trompt",
     "AutoInt",
```

mambular/base_models/tangos.py

Lines changed: 228 additions & 0 deletions
@@ -0,0 +1,228 @@ (new file)

```python
import torch
import torch.nn as nn
import numpy as np
from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer
from ..configs.tangos_config import DefaultTangosConfig
from ..utils.get_feature_dimensions import get_feature_dimensions
from .utils.basemodel import BaseModel


class Tangos(BaseModel):
    """
    A Multi-Layer Perceptron (MLP) model with optional GLU activation, batch normalization,
    layer normalization, and dropout.
    It includes a penalty term for specialization and orthogonality.

    Parameters
    ----------
    feature_information : tuple
        A tuple containing feature information for numerical and categorical features.
    num_classes : int, optional (default=1)
        The number of output classes.
    config : DefaultTangosConfig, optional (default=DefaultTangosConfig())
        Configuration object defining model hyperparameters.
    **kwargs : dict
        Additional arguments for the base model.

    Attributes
    ----------
    returns_ensemble : bool
        Whether the model returns an ensemble of predictions.
    lamda1 : float
        Regularization weight for the specialization loss.
    lamda2 : float
        Regularization weight for the orthogonality loss.
    subsample : float
        Proportion of neuron pairs to use for orthogonality loss calculation.
    embedding_layer : EmbeddingLayer or None
        Optional embedding layer for categorical features.
    layers : nn.ModuleList
        The main MLP layers including linear, normalization, and activation layers.
    head : nn.Linear
        The final output layer.
    """

    def __init__(
        self,
        feature_information: tuple,
        num_classes=1,
        config: DefaultTangosConfig = DefaultTangosConfig(),
        **kwargs
    ):
        super().__init__(config=config, **kwargs)
        self.save_hyperparameters(ignore=["feature_information"])
        self.returns_ensemble = False

        self.lamda1 = config.lamda1
        self.lamda2 = config.lamda2
        self.subsample = config.subsample

        input_dim = get_feature_dimensions(*feature_information)

        # Initialize layers
        self.layers = nn.ModuleList()

        # Input layer
        self.layers.append(nn.Linear(input_dim, self.hparams.layer_sizes[0]))
        if self.hparams.batch_norm:
            self.layers.append(nn.BatchNorm1d(self.hparams.layer_sizes[0]))

        if self.hparams.use_glu:
            self.layers.append(nn.GLU())
        else:
            self.layers.append(self.hparams.activation)
        if self.hparams.dropout > 0.0:
            self.layers.append(nn.Dropout(self.hparams.dropout))

        # Hidden layers
        for i in range(1, len(self.hparams.layer_sizes)):
            self.layers.append(
                nn.Linear(self.hparams.layer_sizes[i - 1], self.hparams.layer_sizes[i])
            )
            if self.hparams.batch_norm:
                self.layers.append(nn.BatchNorm1d(self.hparams.layer_sizes[i]))
            if self.hparams.layer_norm:
                self.layers.append(nn.LayerNorm(self.hparams.layer_sizes[i]))
            if self.hparams.use_glu:
                self.layers.append(nn.GLU())
            else:
                self.layers.append(self.hparams.activation)
            if self.hparams.dropout > 0.0:
                self.layers.append(nn.Dropout(self.hparams.dropout))

        # Output layer
        self.head = nn.Linear(self.hparams.layer_sizes[-1], num_classes)

    def repr_forward(self, x) -> torch.Tensor:
        """
        Computes the forward pass for feature representations.

        This method processes the input through the MLP layers, optionally using
        skip connections.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, feature_dim).

        Returns
        -------
        torch.Tensor
            Output tensor after passing through the representation layers.
        """

        x = x.unsqueeze(0)

        for i in range(len(self.layers)):
            if isinstance(self.layers[i], nn.Linear):
                out = self.layers[i](x)
                if self.hparams.skip_connections and x.shape == out.shape:
                    x = x + out
                else:
                    x = out
            else:
                x = self.layers[i](x)

        return x

    def forward(self, *data) -> torch.Tensor:
        """
        Performs a forward pass of the MLP model.

        This method concatenates all input tensors before applying MLP layers.

        Parameters
        ----------
        data : tuple
            A tuple containing lists of numerical, categorical, and embedded feature tensors.

        Returns
        -------
        torch.Tensor
            The output tensor of shape (batch_size, num_classes).
        """

        x = torch.cat([t for tensors in data for t in tensors], dim=1)

        for i in range(len(self.layers)):
            if isinstance(self.layers[i], nn.Linear):
                out = self.layers[i](x)
                if self.hparams.skip_connections and x.shape == out.shape:
                    x = x + out
                else:
                    x = out
            else:
                x = self.layers[i](x)
        x = self.head(x)
        return x

    def penalty_forward(self, *data):
        """
        Computes both the model predictions and a penalty term.

        The penalty term includes:
        - **Specialization loss**: Measures feature importance concentration.
        - **Orthogonality loss**: Encourages diversity among learned features.

        The method uses `jacrev` to compute the Jacobian of the representation function.

        Parameters
        ----------
        data : tuple
            A tuple containing lists of numerical, categorical, and embedded feature tensors.

        Returns
        -------
        tuple
            - predictions : torch.Tensor
                Model predictions of shape (batch_size, num_classes).
            - penalty : torch.Tensor
                The computed penalty term for regularization.
        """

        x = torch.cat([t for tensors in data for t in tensors], dim=1)
        batch_size = x.shape[0]
        subsample = np.int32(self.subsample * batch_size)

        # Flatten before passing to jacrev
        flat_data = torch.cat([t for tensors in data for t in tensors], dim=1)

        # Compute Jacobian
        jacobian = torch.func.vmap(torch.func.jacrev(self.repr_forward), randomness="different")(flat_data)
        jacobian = jacobian.squeeze()

        neuron_attr = jacobian.swapaxes(0, 1)
        h_dim = neuron_attr.shape[0]
        if len(neuron_attr.shape) > 3:
            # h_dim x batch_size x features
            neuron_attr = neuron_attr.flatten(start_dim=2)

        # calculate specialization loss component
        spec_loss = torch.norm(neuron_attr, p=1) / (batch_size * h_dim * neuron_attr.shape[2])
        cos = nn.CosineSimilarity(dim=1, eps=1e-6)
        orth_loss = torch.tensor(0.0, requires_grad=True).to(x.device)
        # apply subsampling routine for orthogonalization loss
        if self.subsample > 0 and self.subsample < h_dim * (h_dim - 1) / 2:
            tensor_pairs = [
                list(np.random.choice(h_dim, size=(2), replace=False))
                for i in range(subsample)
            ]
            for tensor_pair in tensor_pairs:
                pairwise_corr = cos(
                    neuron_attr[tensor_pair[0], :, :], neuron_attr[tensor_pair[1], :, :]
                ).norm(p=1)
                orth_loss = orth_loss + pairwise_corr

            orth_loss = orth_loss / (batch_size * self.subsample)
        else:
            for neuron_i in range(1, h_dim):
                for neuron_j in range(0, neuron_i):
                    pairwise_corr = cos(
                        neuron_attr[neuron_i, :, :], neuron_attr[neuron_j, :, :]
                    ).norm(p=1)
                    orth_loss = orth_loss + pairwise_corr
            num_pairs = h_dim * (h_dim - 1) / 2
            orth_loss = orth_loss / (batch_size * num_pairs)

        penalty = self.lamda1 * spec_loss + self.lamda2 * orth_loss
        predictions = self.forward(*data)

        return predictions, penalty
```
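In symbols, the penalty that `penalty_forward` computes can be summarized as follows (an editorial reading of the code above, not text from the commit). Let $A \in \mathbb{R}^{H \times B \times F}$ be `neuron_attr`, with $H$ the representation width, $B$ the batch size, $F$ the number of input features, and let $a_h^{(b)} \in \mathbb{R}^{F}$ be the attribution of neuron $h$ for sample $b$:

$$
\mathcal{L}_{\text{spec}} = \frac{1}{BHF}\sum_{h,b,f}\lvert A_{h,b,f}\rvert,\qquad
\mathcal{L}_{\text{orth}} = \frac{1}{BP}\sum_{i<j}\sum_{b=1}^{B}\bigl\lvert\cos\bigl(a_i^{(b)}, a_j^{(b)}\bigr)\bigr\rvert,\qquad
\text{penalty} = \lambda_1\,\mathcal{L}_{\text{spec}} + \lambda_2\,\mathcal{L}_{\text{orth}},
$$

where the pair sum runs over all $P = H(H-1)/2$ neuron pairs when pairs are enumerated exhaustively; with `subsample` active, the sum instead runs over randomly drawn pairs and is normalized as in the code above.

The least familiar piece of the implementation is probably the `torch.func.vmap(torch.func.jacrev(...))` call that produces the per-sample attribution tensor. Below is a minimal, self-contained sketch of that pattern; the toy network and shapes are illustrative assumptions, not taken from the commit.

```python
import torch
import torch.nn as nn

# Toy stand-in for repr_forward: maps one sample's features to a hidden representation.
net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

def repr_forward(x):      # x: (feature_dim,)
    return net(x)         # -> (repr_dim,)

x = torch.randn(32, 8)    # (batch_size, feature_dim)

# jacrev differentiates the representation w.r.t. a single sample's input;
# vmap maps it over the batch, giving shape (batch_size, repr_dim, feature_dim).
jacobian = torch.func.vmap(torch.func.jacrev(repr_forward))(x)

# As in penalty_forward: neuron_attr[h, b, :] is neuron h's attribution for sample b.
neuron_attr = jacobian.swapaxes(0, 1)
print(neuron_attr.shape)  # torch.Size([4, 32, 8])
```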

mambular/base_models/utils/lightning_wrapper.py

Lines changed: 11 additions & 11 deletions
```diff
@@ -89,7 +89,7 @@ def __init__(
         else:
             output_dim = num_classes

-        self.base_model = model_class(
+        self.estimator = model_class(
             config=config,
             feature_information=feature_information,
             num_classes=output_dim,
@@ -112,7 +112,7 @@ def forward(self, num_features, cat_features, embeddings):
             Model output.
         """

-        return self.base_model.forward(num_features, cat_features, embeddings)
+        return self.estimator.forward(num_features, cat_features, embeddings)

     def compute_loss(self, predictions, y_true):
         """Compute the loss for the given predictions and true labels.
@@ -130,7 +130,7 @@ def compute_loss(self, predictions, y_true):
             Computed loss.
         """
         if self.lss:
-            if getattr(self.base_model, "returns_ensemble", False):
+            if getattr(self.estimator, "returns_ensemble", False):
                 loss = 0.0
                 for ensemble_member in range(predictions.shape[1]):
                     loss += self.family.compute_loss(  # type: ignore
@@ -143,7 +143,7 @@ def compute_loss(self, predictions, y_true):
                 y_true.squeeze(-1),
             )

-        if getattr(self.base_model, "returns_ensemble", False):  # Ensemble case
+        if getattr(self.estimator, "returns_ensemble", False):  # Ensemble case
             if (
                 self.loss_fct.__class__.__name__ == "CrossEntropyLoss"
                 and predictions.dim() == 3
@@ -191,8 +191,8 @@ def training_step(self, batch, batch_idx):  # type: ignore
         data, labels = batch

         # Check if the model has a `penalty_forward` method
-        if hasattr(self.base_model, "penalty_forward"):
-            preds, penalty = self.base_model.penalty_forward(*data)
+        if hasattr(self.estimator, "penalty_forward"):
+            preds, penalty = self.estimator.penalty_forward(*data)
             loss = self.compute_loss(preds, labels) + penalty
         else:
             preds = self(*data)
@@ -396,7 +396,7 @@ def configure_optimizers(self):  # type: ignore

         # Initialize the optimizer with the chosen class and parameters
         optimizer = optimizer_class(
-            self.base_model.parameters(),
+            self.estimator.parameters(),
             lr=self.lr,
             weight_decay=self.weight_decay,
             **self.optimizer_params,  # Pass any additional optimizer-specific parameters
@@ -443,9 +443,9 @@ def pretrain_embeddings(
             Path to save the pretrained embeddings.
         """
         print("🚀 Pretraining embeddings...")
-        self.base_model.train()
+        self.estimator.train()

-        optimizer = torch.optim.Adam(self.base_model.embedding_parameters(), lr=lr)
+        optimizer = torch.optim.Adam(self.estimator.embedding_parameters(), lr=lr)

         # 🔥 Single tqdm progress bar across all epochs and batches
         total_batches = pretrain_epochs * len(train_dataloader)
@@ -459,7 +459,7 @@ def pretrain_embeddings(
                 optimizer.zero_grad()

                 # Forward pass through embeddings only
-                embeddings = self.base_model.encode(data, grad=True)
+                embeddings = self.estimator.encode(data, grad=True)

                 # Compute nearest neighbors based on task type
                 knn_indices = self.get_knn(labels, k_neighbors, regression)
@@ -481,7 +481,7 @@ def pretrain_embeddings(
         progress_bar.close()

         # Save pretrained embeddings
-        torch.save(self.base_model.get_embedding_state_dict(), save_path)
+        torch.save(self.estimator.get_embedding_state_dict(), save_path)
         print(f"✅ Embeddings saved to {save_path}")

     def get_knn(self, labels, k_neighbors=5, regression=True, device=""):
```

mambular/base_models/utils/pretraining.py

Lines changed: 6 additions & 6 deletions
```diff
@@ -20,8 +20,8 @@ def __init__(
         pool_sequence=True,
     ):
         super().__init__()
-        self.base_model = base_model
-        self.base_model.eval()
+        self.estimator = base_model
+        self.estimator.eval()
         self.k_neighbors = k_neighbors
         self.temperature = temperature
         self.lr = lr
@@ -33,9 +33,9 @@ def __init__(
         self.loss_fn = nn.CosineEmbeddingLoss(margin=margin, reduction="mean")

     def forward(self, x):
-        x = self.base_model.encode(x, grad=True)
+        x = self.estimator.encode(x, grad=True)
         if self.pool_sequence:
-            return self.base_model.pool_sequence(x)
+            return self.estimator.pool_sequence(x)
         return x  # Return unpooled sequence embeddings (N, S, D)

     def get_knn(self, labels):
@@ -140,7 +140,7 @@ def contrastive_loss(self, embeddings, knn_indices, neg_indices):

     def training_step(self, batch, batch_idx):

-        self.base_model.embedding_layer.train()
+        self.estimator.embedding_layer.train()

         data, labels = batch
         embeddings = self(data)
@@ -173,7 +173,7 @@ def validation_step(self, batch, batch_idx):
         return loss

     def configure_optimizers(self):
-        params = chain(self.base_model.parameters())
+        params = chain(self.estimator.parameters())
         return torch.optim.Adam(params, lr=self.lr)
```