Skip to content

Commit d71f205

Browse files
committed
Add y input to preprocess for all models
1 parent fb1cabe commit d71f205

File tree

9 files changed

+9
-9
lines changed

9 files changed

+9
-9
lines changed

examples/benchmarking/custom_tabarena_model/custom_random_forest_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def _fit(
6262
# case for 'binary' and 'multiclass',
6363
model_cls = RandomForestClassifier
6464

65-
X = self.preprocess(X, is_train=True)
65+
X = self.preprocess(X, y=y, is_train=True)
6666
params = self._get_model_params()
6767
self.model = model_cls(**params)
6868
self.model.fit(X, y)

tabarena/tabarena/benchmark/models/ag/ebm/ebm_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def _fit(
4545
**kwargs,
4646
):
4747
# Preprocess data.
48-
X = self.preprocess(X)
48+
X = self.preprocess(X, y=y)
4949
if X_val is not None:
5050
X_val = self.preprocess(X_val)
5151

tabarena/tabarena/benchmark/models/ag/knn_new/knn_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def _fit(self, X, y, num_cpus=-1, time_limit=None, sample_weight=None, **kwargs)
7373
cat_cols = self._feature_metadata.get_features(valid_raw_types=[R_CATEGORY])
7474
self.knn_preprocessor = KNNPreprocessor(cat_threshold=cat_threshold, categorical_features=cat_cols, numeric_strategy=scaler)
7575

76-
X = self.preprocess(X, is_train=True)
76+
X = self.preprocess(X, y=y, is_train=True)
7777

7878
num_rows_max = len(X)
7979

tabarena/tabarena/benchmark/models/ag/modernnca/modernnca_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def _fit(
346346

347347
hyp = self._get_model_params()
348348
bool_to_cat = hyp.pop("bool_to_cat", True)
349-
X = self.preprocess(X, is_train=True, bool_to_cat=bool_to_cat)
349+
X = self.preprocess(X, y=y, is_train=True, bool_to_cat=bool_to_cat)
350350
if X_val is not None:
351351
X_val = self.preprocess(X_val)
352352

tabarena/tabarena/benchmark/models/ag/sap_rpt_oss/sap_rpt_oss_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def _fit(
5757
# TODO: make code support this like a normal sklearn model
5858
self.model.seed = random_state
5959

60-
X = self.preprocess(X) # does nothing, as no preprocessing is defined
60+
X = self.preprocess(X, y=y) # does nothing, as no preprocessing is defined
6161
self.model = self.model.fit(
6262
X=X,
6363
y=y,

tabarena/tabarena/benchmark/models/ag/tabdpt/tabdpt_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def _fit(
5858
random_seed = hps.pop(self.seed_name, self.default_random_seed)
5959
self._predict_hps = {k: v for k, v in hps.items() if k in supported_predict_hps}
6060
self._predict_hps["seed"] = random_seed
61-
X = self.preprocess(X)
61+
X = self.preprocess(X, y=y)
6262
y = y.to_numpy()
6363
self.model = model_cls(
6464
device=device,

tabarena/tabarena/benchmark/models/ag/tabicl/tabicl_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def _fit(
118118
device=device,
119119
n_jobs=num_cpus,
120120
)
121-
X = self.preprocess(X)
121+
X = self.preprocess(X, y=y)
122122
self.model = self.model.fit(
123123
X=X,
124124
y=y,

tabarena/tabarena/benchmark/models/ag/tabm/tabm_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def _fit(
9292
hyp = self._get_model_params()
9393
bool_to_cat = hyp.pop("bool_to_cat", True)
9494

95-
X = self.preprocess(X, is_train=True, bool_to_cat=bool_to_cat)
95+
X = self.preprocess(X, y=y, is_train=True, bool_to_cat=bool_to_cat)
9696
if X_val is not None:
9797
X_val = self.preprocess(X_val)
9898

tabarena/tabarena/benchmark/models/ag/xrfm/xrfm_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ def _fit(
295295
# todo: do we already need to move stuff to the correct device?
296296

297297
X = self.preprocess(
298-
X, is_train=True, bool_to_cat=bool_to_cat, impute_bool=impute_bool
298+
X, y=y, is_train=True, bool_to_cat=bool_to_cat, impute_bool=impute_bool
299299
)
300300

301301
if X_val is not None:

0 commit comments

Comments
 (0)