Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 25 additions & 22 deletions Orange/widgets/classify/owrandomforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ class OWRandomForest(OWBaseLearner):
use_random_state = settings.Setting(False)
max_depth = settings.Setting(3)
use_max_depth = settings.Setting(False)
max_leaf_nodes = settings.Setting(5)
use_max_leaf_nodes = settings.Setting(True)
min_samples_split = settings.Setting(5)
use_min_samples_split = settings.Setting(True)
index_output = settings.Setting(0)

def add_main_layout(self):
Expand All @@ -35,7 +35,7 @@ def add_main_layout(self):

form.addWidget(QLabel(self.tr("Number of trees: ")),
0, 0, Qt.AlignLeft)
spin = gui.spin(basic_box, self, "n_estimators", minv=1, maxv=1e4,
spin = gui.spin(basic_box, self, "n_estimators", minv=1, maxv=10000,
callback=self.settings_changed, addToLayout=False,
controlWidth=50)
form.addWidget(spin, 0, 1, Qt.AlignRight)
Expand Down Expand Up @@ -76,25 +76,25 @@ def add_main_layout(self):
addToLayout=False)

max_depth_spin = gui.spin(
growth_box, self, "max_depth", 2, 50, addToLayout=False,
growth_box, self, "max_depth", 1, 50, addToLayout=False,
callback=self.settings_changed)

form.addWidget(max_depth_cb, 3, 0, Qt.AlignLeft)
form.addWidget(max_depth_spin, 3, 1, Qt.AlignRight)

max_leaf_nodes_cb = gui.checkBox(
growth_box, self, "use_max_leaf_nodes",
min_samples_split_cb = gui.checkBox(
growth_box, self, "use_min_samples_split",
label="Do not split subsets smaller than: ",
callback=self.settings_changed, addToLayout=False)

max_leaf_nodes_spin = gui.spin(
growth_box, self, "max_leaf_nodes", 0, 100, addToLayout=False,
min_samples_split_spin = gui.spin(
growth_box, self, "min_samples_split", 1, 1000, addToLayout=False,
callback=self.settings_changed)

form.addWidget(max_leaf_nodes_cb, 4, 0, Qt.AlignLeft)
form.addWidget(max_leaf_nodes_spin, 4, 1, Qt.AlignRight)
form.addWidget(min_samples_split_cb, 4, 0, Qt.AlignLeft)
form.addWidget(min_samples_split_spin, 4, 1, Qt.AlignRight)
self._max_depth_spin = max_depth_spin
self._max_leaf_nodes_spin = max_leaf_nodes_spin
self._min_samples_split_spin = min_samples_split_spin

# Index on the output
# gui.doubleSpin(self.controlArea, self, "index_output", 0, 10000, 1,
Expand All @@ -108,8 +108,8 @@ def create_learner(self):
common_args["random_state"] = self.random_state
if self.use_max_depth:
common_args["max_depth"] = self.max_depth
if self.use_max_leaf_nodes:
common_args["max_leaf_nodes"] = self.max_leaf_nodes
if self.use_min_samples_split:
common_args["min_samples_split"] = self.min_samples_split

return self.LEARNER(preprocessors=self.preprocessors, **common_args)

Expand All @@ -118,17 +118,20 @@ def settings_changed(self):
self._max_features_spin.setEnabled(self.use_max_features)
self._random_state_spin.setEnabled(self.use_random_state)
self._max_depth_spin.setEnabled(self.use_max_depth)
self._max_leaf_nodes_spin.setEnabled(self.use_max_leaf_nodes)
self._min_samples_split_spin.setEnabled(self.use_min_samples_split)

def get_learner_parameters(self):
return (("Number of trees", self.n_estimators),
("Maximal number of considered features",
self.max_features if self.use_max_features else "unlimited"),
("Fixed random seed", self.use_random_state and self.random_state),
("Maximal tree depth",
self.max_depth if self.use_max_depth else "unlimited"),
("Stop splitting nodes with maximum instances",
self.max_leaf_nodes if self.use_max_leaf_nodes else "unlimited"))
"""Called by send report to list the parameters of the learner."""
return (
("Number of trees", self.n_estimators),
("Maximal number of considered features",
self.max_features if self.use_max_features else "unlimited"),
("Fixed random seed", self.use_random_state and self.random_state),
("Maximal tree depth",
self.max_depth if self.use_max_depth else "unlimited"),
("Stop splitting nodes with maximum instances",
self.min_samples_split if self.use_min_samples_split else "unlimited")
)


if __name__ == "__main__":
Expand Down