Skip to content

Commit cd37442

Browse files
committed
added rng to trepan recipe
1 parent e21244f commit cd37442

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

generalizedtrees/recipes.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Licensed under the BSD 3-Clause License
44
# Copyright (c) 2020, Yuriy Sverchkov
55

6-
from numpy import ma
6+
from numpy.random import default_rng
77
from sklearn.linear_model import LogisticRegression
88

99
from generalizedtrees.node import MTNode, TrepanNode
@@ -64,6 +64,7 @@ def trepan(
6464
min_samples: int = 1000,
6565
dist_test_alpha = 0.05,
6666
max_attempts = 1000,
67+
rng = default_rng()
6768
) -> GreedyTreeLearner:
6869
"""
6970
Recipe for Trepan* (Craven and Shavlik 1995)
@@ -94,7 +95,7 @@ def trepan(
9495
learner.node_builder = ModelTranslationNodeBuilderLC(
9596
leaf_model=ConstantEstimator,
9697
min_samples=min_samples,
97-
data_factory=TrepanDataFactoryLC(alpha=dist_test_alpha, max_attempts=max_attempts),
98+
data_factory=TrepanDataFactoryLC(alpha=dist_test_alpha, max_attempts=max_attempts, rng=rng),
9899
node_type=TrepanNode)
99100

100101
return learner

0 commit comments

Comments
 (0)