diff --git a/sklift/models/models.py b/sklift/models/models.py index 52d3e05..ed48829 100644 --- a/sklift/models/models.py +++ b/sklift/models/models.py @@ -637,8 +637,8 @@ def predict(self, X): else: if self._type_of_target == 'binary': - self.ctrl_preds_ = self.estimator_ctrl.predict_proba(X)[:, 1] - self.trmnt_preds_ = self.estimator_trmnt.predict_proba(X)[:, 1] + self.ctrl_preds_ = self.estimator_ctrl.predict_proba(X) + self.trmnt_preds_ = self.estimator_trmnt.predict_proba(X) else: self.ctrl_preds_ = self.estimator_ctrl.predict(X) self.trmnt_preds_ = self.estimator_trmnt.predict(X) diff --git a/sklift/tests/conftest.py b/sklift/tests/conftest.py index ed23faf..01b5e50 100644 --- a/sklift/tests/conftest.py +++ b/sklift/tests/conftest.py @@ -36,7 +36,7 @@ def random_xy_dataset_regr(request): treat = (np.random.normal(0, 2, (n,)) > 0.0).astype(int) if dataset_type == 'numpy': return X, y, treat - return pd.DataFrame(X), pd.Series(y), pd.Series(treat) + return pd.DataFrame(X, columns=[f"feat_{i}" for i in range(X.shape[1])]), pd.Series(y), pd.Series(treat) @pytest.fixture( @@ -65,5 +65,5 @@ def random_xyt_dataset_clf(request): if dataset_type == 'numpy': return X, y, treat - return pd.DataFrame(X), pd.Series(y), pd.Series(treat) + return pd.DataFrame(X, columns=[f"feat_{i}" for i in range(X.shape[1])]), pd.Series(y), pd.Series(treat) diff --git a/sklift/tests/test_models.py b/sklift/tests/test_models.py index 2e58281..e8c3d1b 100644 --- a/sklift/tests/test_models.py +++ b/sklift/tests/test_models.py @@ -27,7 +27,9 @@ ) def test_shape_classification(model, random_xyt_dataset_clf): X, y, treat = random_xyt_dataset_clf - assert model.fit(X, y, treat).predict(X).shape[0] == y.shape[0] + preds = model.fit(X, y, treat).predict(X) + assert preds.shape[0] == y.shape[0], 'different 0 dim' + assert pd.DataFrame(preds).shape[1] == pd.DataFrame(y).shape[1], 'different 1 dim' pipe = Pipeline(steps=[("scaler", StandardScaler()), ("clf", model)]) assert pipe.fit(X, y, clf__treatment=treat).predict(X).shape[0] == y.shape[0]