Skip to content

Commit 4cc4f7e

Browse files
authored
Test GTIL with IsolationForest (#370)
1 parent b65bedc commit 4cc4f7e

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

tests/python/test_gtil.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import scipy
1010
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, \
1111
ExtraTreesClassifier, RandomForestRegressor, GradientBoostingRegressor, \
12-
ExtraTreesRegressor
12+
ExtraTreesRegressor, IsolationForest
1313
from sklearn.datasets import load_iris, load_breast_cancer, load_boston, load_svmlight_file
1414
from sklearn.model_selection import train_test_split
1515

@@ -80,6 +80,18 @@ def test_skl_multiclass_classifier(clazz):
8080
np.testing.assert_almost_equal(out_prob, expected_prob, decimal=5)
8181

8282

83+
def test_skl_converter_iforest():
84+
"""Scikit-learn isolation forest"""
85+
X, _ = load_boston(return_X_y=True)
86+
clf = IsolationForest(max_samples=64, random_state=0, n_estimators=10)
87+
clf.fit(X)
88+
expected_pred = clf._compute_chunked_score_samples(X) # pylint: disable=W0212
89+
90+
tl_model = treelite.sklearn.import_model(clf)
91+
out_pred = treelite.gtil.predict(tl_model, X)
92+
np.testing.assert_almost_equal(out_pred, expected_pred, decimal=2)
93+
94+
8395
@pytest.mark.parametrize('objective', ['reg:linear', 'reg:squarederror', 'reg:squaredlogerror',
8496
'reg:pseudohubererror'])
8597
@pytest.mark.parametrize('model_format', ['binary', 'json'])

0 commit comments

Comments
 (0)