Skip to content

Commit c0a7009

Browse files
willy-liu and hcho3 authored
fix: iForest feature mapping (#620)
* fix: iForest feature mapping

  Ensure correct handling of feature subsampling in Isolation Forest. When `max_features != 1.0`, the feature index is subsampled, which could affect mapping consistency.

* test: add test for IsolationForest with max_features < 1.0

  Adds a unit test to validate Treelite's import_model for IsolationForest when max_features is set to a random float between 0.2 and 0.8.

* Fix formatting

---------

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
Co-authored-by: Hyunsu Cho <phcho@nvidia.com>
1 parent 523f64b commit c0a7009

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

python/treelite/sklearn/importer.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def import_model(sklearn_model) -> Model:
180180
n_node_samples = ArrayOfArrays(dtype=np.int64)
181181
weighted_n_node_samples = ArrayOfArrays(dtype=np.float64)
182182
impurity = ArrayOfArrays(dtype=np.float64)
183-
for estimator in sklearn_model.estimators_:
183+
for tree_idx, estimator in enumerate(sklearn_model.estimators_):
184184
if isinstance(sklearn_model, (GradientBoostingR, GradientBoostingC)):
185185
estimator_range = estimator
186186
learning_rate = sklearn_model.learning_rate
@@ -197,20 +197,31 @@ def import_model(sklearn_model) -> Model:
197197
node_count.append(tree.node_count)
198198
children_left.add(tree.children_left, expected_shape=(tree.node_count,))
199199
children_right.add(tree.children_right, expected_shape=(tree.node_count,))
200-
feature.add(tree.feature, expected_shape=(tree.node_count,))
201200
threshold.add(tree.threshold, expected_shape=(tree.node_count,))
202201
if isinstance(sklearn_model, IsolationForest):
203202
value.add(
204203
isolation_depths.reshape((-1, 1, 1)),
205204
expected_shape=leaf_value_expected_shape(tree.node_count),
206205
)
206+
# Note: for isolation forest, if max_features != 1.0
207+
# the feature index will be subsampled
208+
feature_subsample = np.full(tree.feature.shape, -2, dtype=np.int64)
209+
mask = tree.feature != -2
210+
feature_subsample[mask] = np.array(
211+
sklearn_model.estimators_features_[tree_idx]
212+
)[tree.feature[mask]]
213+
feature.add(
214+
feature_subsample.astype(np.int64),
215+
expected_shape=(tree.node_count,),
216+
)
207217
else:
208218
# Note: for gradient boosted trees, we shrink each leaf output by the
209219
# learning rate
210220
value.add(
211221
tree.value * learning_rate,
212222
expected_shape=leaf_value_expected_shape(tree.node_count),
213223
)
224+
feature.add(tree.feature, expected_shape=(tree.node_count,))
214225
n_node_samples.add(tree.n_node_samples, expected_shape=(tree.node_count,))
215226
weighted_n_node_samples.add(
216227
tree.weighted_n_node_samples, expected_shape=(tree.node_count,)

tests/python/test_sklearn_integration.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,30 @@ def test_skl_converter_iforest(dataset):
195195
np.testing.assert_almost_equal(out_pred, expected_pred)
196196

197197

198+
@given(
199+
dataset=standard_regression_datasets(),
200+
max_feat=floats(min_value=0.2, max_value=0.8),
201+
)
202+
@settings(**standard_settings())
203+
def test_skl_converter_iforest_feature_subsampling(dataset, max_feat):
204+
"""Scikit-learn isolation forest with feature subsampling"""
205+
X, _ = dataset
206+
clf = IsolationForest(
207+
max_samples=64,
208+
max_features=max_feat,
209+
n_estimators=10,
210+
n_jobs=-1,
211+
random_state=0,
212+
)
213+
clf.fit(X)
214+
expected_pred = -clf.score_samples(X).reshape((-1, 1, 1))
215+
216+
tl_model = treelite.sklearn.import_model(clf)
217+
out_pred = treelite.gtil.predict(tl_model, X)
218+
219+
np.testing.assert_almost_equal(out_pred, expected_pred, decimal=5)
220+
221+
198222
@given(
199223
dataset=standard_classification_datasets(
200224
n_classes=integers(min_value=2, max_value=4),

0 commit comments

Comments
 (0)