Skip to content

Commit d3872b9

Browse files
authored
Merge pull request #252 from basf/data_check
Data check
2 parents fdbb12b + 4211d29 commit d3872b9

File tree

7 files changed

+143
-2
lines changed

7 files changed

+143
-2
lines changed

mambular/__version__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,5 @@
1717

1818
# The following line *must* be the last in the module, exactly as formatted:
1919

20-
__version__ = "1.3.2"
20+
__version__ = "1.4.0"
21+

mambular/models/utils/sklearn_base_classifier.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch
55
from sklearn.metrics import accuracy_score, log_loss
66
from .sklearn_parent import SklearnBase
7+
import numpy as np
78

89

910
class SklearnBaseClassifier(SklearnBase):
@@ -85,6 +86,8 @@ def build_model(
8586
The built classifier.
8687
"""
8788

89+
num_classes = len(np.unique(y))
90+
8891
return super()._build_model(
8992
X,
9093
y,
@@ -94,6 +97,7 @@ def build_model(
9497
y_val=y_val,
9598
embeddings=embeddings,
9699
embeddings_val=embeddings_val,
100+
num_classes=num_classes,
97101
random_state=random_state,
98102
batch_size=batch_size,
99103
shuffle=shuffle,
@@ -190,6 +194,7 @@ def fit(
190194
The fitted classifier.
191195
"""
192196

197+
num_classes = len(np.unique(y))
193198
return super().fit(
194199
X=X,
195200
y=y,
@@ -215,6 +220,7 @@ def fit(
215220
train_metrics=train_metrics,
216221
val_metrics=val_metrics,
217222
rebuild=rebuild,
223+
num_classes=num_classes,
218224
**trainer_kwargs,
219225
)
220226

mambular/models/utils/sklearn_base_regressor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def build_model(
9393
y_val=y_val,
9494
embeddings=embeddings,
9595
embeddings_val=embeddings_val,
96+
num_classes=1,
9697
random_state=random_state,
9798
batch_size=batch_size,
9899
shuffle=shuffle,
@@ -198,6 +199,7 @@ def fit(
198199
y_val=y_val,
199200
embeddings=embeddings,
200201
embeddings_val=embeddings_val,
202+
num_classes=1,
201203
max_epochs=max_epochs,
202204
random_state=random_state,
203205
batch_size=batch_size,

mambular/models/utils/sklearn_parent.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def _build_model(
120120
y_val=None,
121121
embeddings=None,
122122
embeddings_val=None,
123+
num_classes: int = None,
123124
random_state: int = 101,
124125
batch_size: int = 128,
125126
shuffle: bool = True,
@@ -223,6 +224,7 @@ def _build_model(
223224
weight_decay=(
224225
weight_decay if weight_decay is not None else self.config.weight_decay
225226
),
227+
num_classes=num_classes,
226228
train_metrics=train_metrics,
227229
val_metrics=val_metrics,
228230
optimizer_type=self.optimizer_type,
@@ -273,6 +275,7 @@ def fit(
273275
y_val=None,
274276
embeddings=None,
275277
embeddings_val=None,
278+
num_classes: int = None,
276279
max_epochs: int = 100,
277280
random_state: int = 101,
278281
batch_size: int = 128,
@@ -357,6 +360,7 @@ def fit(
357360
y_val=y_val,
358361
embeddings=embeddings,
359362
embeddings_val=embeddings_val,
363+
num_classes=num_classes,
360364
random_state=random_state,
361365
batch_size=batch_size,
362366
shuffle=shuffle,

mambular/preprocessing/preprocessor.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
OneHotFromOrdinal,
2828
ToFloatTransformer,
2929
)
30+
from .utils import check_inputs
3031
from sklearn.base import TransformerMixin
3132

3233

@@ -118,6 +119,7 @@ def __init__(
118119
use_decision_tree_knots=True,
119120
knots_strategy="uniform",
120121
spline_implementation="sklearn",
122+
min_unique_vals=5,
121123
):
122124
self.n_bins = n_bins
123125
self.numerical_preprocessing = (
@@ -176,6 +178,7 @@ def __init__(
176178
self.use_decision_tree_knots = use_decision_tree_knots
177179
self.knots_strategy = knots_strategy
178180
self.spline_implementation = spline_implementation
181+
self.min_unique_vals = min_unique_vals
179182

180183
def get_params(self, deep=True):
181184
"""Get parameters for the preprocessor.
@@ -307,6 +310,15 @@ def fit(self, X, y=None, embeddings=None):
307310
self._fit_embeddings(embeddings)
308311

309312
numerical_features, categorical_features = self._detect_column_types(X)
313+
314+
check_inputs(
315+
X,
316+
y,
317+
numerical_features,
318+
categorical_features,
319+
task_type=self.task,
320+
min_samples=self.min_unique_vals,
321+
)
310322
transformers = []
311323

312324
if numerical_features:

mambular/preprocessing/utils.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import pandas as pd
2+
import numpy as np
3+
import warnings
4+
5+
6+
def check_inputs(
7+
X,
8+
y=None,
9+
numerical_columns=None,
10+
categorical_columns=None,
11+
task_type=None,
12+
min_samples=5,
13+
):
14+
"""
15+
Perform thorough validation on input features and target.
16+
17+
Parameters
18+
----------
19+
X : pd.DataFrame or dict
20+
Input features.
21+
y : array-like, optional
22+
Target values.
23+
numerical_columns : list of str
24+
Columns expected to be numerical.
25+
categorical_columns : list of str
26+
Columns expected to be categorical.
27+
task_type : str, optional
28+
One of {"regression", "binary", "multiclass"}. If specified, target checks will apply accordingly.
29+
min_samples : int, optional
30+
Minimum number of distinct values required in any feature or target.
31+
32+
Raises
33+
------
34+
ValueError
35+
If any feature or target fails validation checks.
36+
"""
37+
if isinstance(X, dict):
38+
X = pd.DataFrame(X)
39+
40+
if not isinstance(X, pd.DataFrame):
41+
raise TypeError("X must be a DataFrame or a dict convertible to DataFrame.")
42+
43+
if X.empty:
44+
raise ValueError("X must not be empty.")
45+
46+
if numerical_columns is None:
47+
numerical_columns = []
48+
if categorical_columns is None:
49+
categorical_columns = []
50+
51+
all_cols = set(numerical_columns) | set(categorical_columns)
52+
missing_cols = all_cols - set(X.columns)
53+
if missing_cols:
54+
raise ValueError(
55+
f"The following specified columns are missing in X: {missing_cols}"
56+
)
57+
58+
# Check numerical features
59+
for col in numerical_columns:
60+
series = X[col]
61+
if series.nunique(dropna=False) < min_samples:
62+
raise ValueError(
63+
f"Numerical feature '{col}' has less than {min_samples} unique values."
64+
)
65+
if not np.issubdtype(series.dtype, np.number):
66+
raise TypeError(f"Numerical feature '{col}' must be numeric.")
67+
if not np.all(np.isfinite(series.dropna())):
68+
raise ValueError(
69+
f"Numerical feature '{col}' contains non-finite values (inf or NaN)."
70+
)
71+
72+
# Check categorical features
73+
for col in categorical_columns:
74+
series = X[col]
75+
if series.nunique(dropna=False) < 2:
76+
raise ValueError(
77+
f"Categorical feature '{col}' has less only a single value ."
78+
)
79+
if pd.api.types.is_numeric_dtype(
80+
series
81+
) and not pd.api.types.is_categorical_dtype(series):
82+
# allow numerical dtypes only if user intends to encode them
83+
pass # optionally warn or convert
84+
if series.isnull().all():
85+
raise ValueError(f"Categorical feature '{col}' contains only NaNs.")
86+
87+
# Check y
88+
if y is not None:
89+
y = np.array(y)
90+
91+
if y.ndim != 1:
92+
raise ValueError("y must be a 1D array or Series.")
93+
94+
if len(y) != len(X):
95+
raise ValueError("X and y must have the same number of samples.")
96+
97+
unique_targets = np.unique(y[~pd.isnull(y)])
98+
n_classes = len(unique_targets)
99+
100+
if task_type == "regression":
101+
if not np.issubdtype(y.dtype, np.number):
102+
raise TypeError("For regression, target y must be numeric.")
103+
if not np.all(np.isfinite(y)):
104+
raise ValueError("Target y contains non-finite values.")
105+
106+
if n_classes <= 10:
107+
warnings.warn(
108+
f"Target y has only {n_classes} unique values. "
109+
"Consider if this should be a classification problem instead of regression.",
110+
UserWarning,
111+
)
112+
113+
elif task_type == "classification":
114+
if n_classes < 2:
115+
raise ValueError("Classification tasks requires more than 1 class.")

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
[tool.poetry]
22
name = "mambular"
33

4-
version = "1.3.2"
4+
version = "1.4.0"
5+
56

67
description = "A python package for tabular deep learning with mamba blocks."
78
authors = ["Anton Thielmann", "Manish Kumar", "Christoph Weisser"]

0 commit comments

Comments
 (0)