|
9 | 9 | import scipy |
10 | 10 | from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, \ |
11 | 11 | ExtraTreesClassifier, RandomForestRegressor, GradientBoostingRegressor, \ |
12 | | - ExtraTreesRegressor |
| 12 | + ExtraTreesRegressor, IsolationForest |
13 | 13 | from sklearn.datasets import load_iris, load_breast_cancer, load_boston, load_svmlight_file |
14 | 14 | from sklearn.model_selection import train_test_split |
15 | 15 |
|
@@ -80,6 +80,18 @@ def test_skl_multiclass_classifier(clazz): |
80 | 80 | np.testing.assert_almost_equal(out_prob, expected_prob, decimal=5) |
81 | 81 |
|
82 | 82 |
|
| 83 | +def test_skl_converter_iforest(): |
| 84 | + """Scikit-learn isolation forest""" |
| 85 | + X, _ = load_boston(return_X_y=True) |
| 86 | + clf = IsolationForest(max_samples=64, random_state=0, n_estimators=10) |
| 87 | + clf.fit(X) |
| 88 | + expected_pred = clf._compute_chunked_score_samples(X) # pylint: disable=W0212 |
| 89 | + |
| 90 | + tl_model = treelite.sklearn.import_model(clf) |
| 91 | + out_pred = treelite.gtil.predict(tl_model, X) |
| 92 | + np.testing.assert_almost_equal(out_pred, expected_pred, decimal=2) |
| 93 | + |
| 94 | + |
83 | 95 | @pytest.mark.parametrize('objective', ['reg:linear', 'reg:squarederror', 'reg:squaredlogerror', |
84 | 96 | 'reg:pseudohubererror']) |
85 | 97 | @pytest.mark.parametrize('model_format', ['binary', 'json']) |
|
0 commit comments