Skip to content

Commit 767fab9

Browse files
author
[zebinyang]
committed
fix a bug in decision_rule; update version 0.1.3
1 parent ec0d871 commit 767fab9

File tree

3 files changed

+19
-15
lines changed

3 files changed

+19
-15
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from setuptools import setup
22

33
setup(name='simtree',
4-
version='0.1.2',
4+
version='0.1.3',
55
description='Single-index model tree',
66
url='https://github.com/ZebinYang/SIMTree',
77
author='Zebin Yang',

simtree/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,5 @@
88
"SIMTreeRegressor", "SIMTreeClassifier",
99
"CustomMobTreeRegressor", "CustomMobTreeClassifier"]
1010

11-
__version__ = '0.1.2'
11+
__version__ = '0.1.3'
1212
__author__ = 'Zebin Yang'

simtree/mobtree.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -445,25 +445,29 @@ def decision_rule(self, node_id):
445445
key = str(parent_node["feature"])
446446
if key not in rule_dict.keys():
447447
if current_node["is_left"]:
448-
rule_dict.update({key:{"split_feature": parent_node["feature"],
449-
"threshold_left": parent_node["threshold"]}})
448+
rule_dict.update({key:{"left": parent_node["threshold"]}})
450449
else:
451-
rule_dict.update({key:{"split_feature": parent_node["feature"],
452-
"threshold_right": parent_node["threshold"]}})
453-
elif "threshold_left" not in rule_dict[key].keys():
454-
rule_dict[key].update({"threshold_left": parent_node["threshold"]})
455-
elif "threshold_right" not in rule_dict[key].keys():
456-
rule_dict[key].update({"threshold_right": parent_node["threshold"]})
450+
rule_dict.update({key:{"right": parent_node["threshold"]}})
451+
else:
452+
if "left" not in rule_dict[key].keys():
453+
rule_dict[key].update({"left": parent_node["threshold"]})
454+
else:
455+
rule_dict[key].update({"left": min(parent_node["threshold"], rule_dict[key]["left"])})
456+
if "right" not in rule_dict[key].keys():
457+
rule_dict[key].update({"right": parent_node["threshold"]})
458+
else:
459+
rule_dict[key].update({"right": max(parent_node["threshold"], rule_dict[key]["right"])})
457460
current_node = parent_node
461+
print(rule_dict)
458462

459463
rule_list = []
460464
for key, item in rule_dict.items():
461465
rule = ""
462-
if "threshold_right" in item.keys():
463-
rule += str(round(item["threshold_right"], 3)) + "<"
464-
rule += self.feature_names[item["split_feature"]]
465-
if "threshold_left" in item.keys():
466-
rule += "<=" + str(round(item["threshold_left"], 3))
466+
if "right" in item.keys():
467+
rule += str(round(item["right"], 3)) + "<"
468+
rule += self.feature_names[int(key)]
469+
if "left" in item.keys():
470+
rule += "<=" + str(round(item["left"], 3))
467471
rule_list.append(rule)
468472
return rule_list
469473

0 commit comments

Comments
 (0)