Skip to content

Commit 05daeb6

Browse files
trivialfisimeyer2
andauthored
[backport] Add enable_categorical to the apply method (dmlc#11550) (dmlc#11581)
--------- Co-authored-by: imeyer2 <[email protected]>
1 parent 45b8fa3 commit 05daeb6

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

python-package/xgboost/sklearn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,6 +1387,7 @@ def apply(
13871387
missing=self.missing,
13881388
feature_types=self.feature_types,
13891389
nthread=self.n_jobs,
1390+
enable_categorical=self.enable_categorical,
13901391
)
13911392
return self.get_booster().predict(
13921393
test_dmatrix, pred_leaf=True, iteration_range=iteration_range

tests/python/test_with_sklearn.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1619,3 +1619,23 @@ def test_doc_link() -> None:
16191619
name = est.__class__.__name__
16201620
link = est._get_doc_link()
16211621
assert f"xgboost.{name}" in link
1622+
1623+
1624+
def test_apply_method():
1625+
import pandas as pd
1626+
1627+
X_num = np.random.rand(5, 5)
1628+
df = pd.DataFrame(X_num, columns=[f"f{i}" for i in range(X_num.shape[1])])
1629+
df["test"] = pd.Series(
1630+
["one", "two", "three", "four", "five"], dtype="category"
1631+
) # <- categorical column
1632+
y = np.arange(len(df))
1633+
1634+
model = xgb.XGBClassifier(enable_categorical=True)
1635+
model.fit(df, y)
1636+
1637+
model.apply(df) # this must not raise
1638+
1639+
model.set_params(enable_categorical=False)
1640+
with pytest.raises(ValueError, match="`enable_categorical`"):
1641+
model.apply(df)

0 commit comments

Comments
 (0)