Skip to content

Commit 0d7be4a

Browse files
authored
Merge pull request #248 from basf/loss_func_fix
fix num_classes argument for binary classification
2 parents f7fac7b + da9cc8e commit 0d7be4a

File tree

5 files changed

+14
-3
lines changed

5 files changed

+14
-3
lines changed

mambular/__version__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,4 @@
1717

1818
# The following line *must* be the last in the module, exactly as formatted:
1919

20-
__version__ = "1.3.1"
21-
20+
__version__ = "1.3.2"

mambular/models/utils/sklearn_base_classifier.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch
55
from sklearn.metrics import accuracy_score, log_loss
66
from .sklearn_parent import SklearnBase
7+
import numpy as np
78

89

910
class SklearnBaseClassifier(SklearnBase):
@@ -85,6 +86,8 @@ def build_model(
8586
The built classifier.
8687
"""
8788

89+
num_classes = len(np.unique(y))
90+
8891
return super()._build_model(
8992
X,
9093
y,
@@ -94,6 +97,7 @@ def build_model(
9497
y_val=y_val,
9598
embeddings=embeddings,
9699
embeddings_val=embeddings_val,
100+
num_classes=num_classes,
97101
random_state=random_state,
98102
batch_size=batch_size,
99103
shuffle=shuffle,
@@ -190,6 +194,7 @@ def fit(
190194
The fitted classifier.
191195
"""
192196

197+
num_classes = len(np.unique(y))
193198
return super().fit(
194199
X=X,
195200
y=y,
@@ -215,6 +220,7 @@ def fit(
215220
train_metrics=train_metrics,
216221
val_metrics=val_metrics,
217222
rebuild=rebuild,
223+
num_classes=num_classes,
218224
**trainer_kwargs,
219225
)
220226

mambular/models/utils/sklearn_base_regressor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def build_model(
9393
y_val=y_val,
9494
embeddings=embeddings,
9595
embeddings_val=embeddings_val,
96+
num_classes=1,
9697
random_state=random_state,
9798
batch_size=batch_size,
9899
shuffle=shuffle,
@@ -198,6 +199,7 @@ def fit(
198199
y_val=y_val,
199200
embeddings=embeddings,
200201
embeddings_val=embeddings_val,
202+
num_classes=1,
201203
max_epochs=max_epochs,
202204
random_state=random_state,
203205
batch_size=batch_size,

mambular/models/utils/sklearn_parent.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def _build_model(
120120
y_val=None,
121121
embeddings=None,
122122
embeddings_val=None,
123+
num_classes: int = None,
123124
random_state: int = 101,
124125
batch_size: int = 128,
125126
shuffle: bool = True,
@@ -223,6 +224,7 @@ def _build_model(
223224
weight_decay=(
224225
weight_decay if weight_decay is not None else self.config.weight_decay
225226
),
227+
num_classes=num_classes,
226228
train_metrics=train_metrics,
227229
val_metrics=val_metrics,
228230
optimizer_type=self.optimizer_type,
@@ -273,6 +275,7 @@ def fit(
273275
y_val=None,
274276
embeddings=None,
275277
embeddings_val=None,
278+
num_classes: int = None,
276279
max_epochs: int = 100,
277280
random_state: int = 101,
278281
batch_size: int = 128,
@@ -357,6 +360,7 @@ def fit(
357360
y_val=y_val,
358361
embeddings=embeddings,
359362
embeddings_val=embeddings_val,
363+
num_classes=num_classes,
360364
random_state=random_state,
361365
batch_size=batch_size,
362366
shuffle=shuffle,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[tool.poetry]
22
name = "mambular"
33

4-
version = "1.3.1"
4+
version = "1.3.2"
55

66
description = "A python package for tabular deep learning with mamba blocks."
77
authors = ["Anton Thielmann", "Manish Kumar", "Christoph Weisser"]

0 commit comments

Comments
 (0)