Skip to content

Commit 21217ba

Browse files
author
Agathe Guillemot
committed
Add decision rule & node id in scored dataset
1 parent d721fb9 commit 21217ba

File tree

3 files changed

+40
-15
lines changed

3 files changed

+40
-15
lines changed

python-lib/dku_idtb_decision_tree/node.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ def set_node_info(self, samples, total_samples, probabilities, prediction):
4545
def get_type(self):
4646
raise NotImplementedError
4747

48+
def get_decision_rule(self):
49+
raise NotImplementedError
50+
4851
def rebuild(self, prediction, samples, probabilities):
4952
self.prediction = prediction
5053
self.samples = samples
@@ -67,6 +70,11 @@ def __init__(self, node_id, parent_id, treated_as_numerical, feature, values, ot
6770
def get_type(self):
6871
return Node.TYPES.CAT
6972

73+
def get_decision_rule(self):
74+
return "{feature} {negation}in {values}".format(
75+
feature=self.feature, negation="not " if self.others else "", values=self.values
76+
)
77+
7078
def apply_filter(self, df):
7179
if self.others:
7280
return df[~df[self.feature].isin(self.values)]
@@ -91,6 +99,15 @@ def __init__(self, node_id, parent_id, treated_as_numerical, feature, beginning=
9199
def get_type(self):
92100
return Node.TYPES.NUM
93101

102+
def get_decision_rule(self):
103+
rule = ""
104+
if self.beginning:
105+
rule += "{} ≤ ".format(self.beginning)
106+
rule += self.feature
107+
if self.end:
108+
rule += "< {}".format(self.end)
109+
return rule
110+
94111
def apply_filter(self, df, mean):
95112
if self.beginning is not None:
96113
df = df[df[self.feature].ge(self.beginning, fill_value=mean)]

python-lib/dku_idtb_decision_tree/tree.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,14 @@ def add_node(self, node):
8383
parent_node.children_ids.append(node.id)
8484
super(ScoringTree, self).add_node(node)
8585

86+
def get_decision_rule(self, node_id):
87+
rule = deque()
88+
while node_id > 0:
89+
node = self.get_node(node_id)
90+
rule.appendleft(node.get_decision_rule())
91+
node_id = node.parent_id
92+
return list(rule)
93+
8694
#Used by the webapp
8795
class InteractiveTree(Tree):
8896
"""

python-lib/dku_idtb_scoring/score.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,25 +18,24 @@ def update_input_schema(input_schema, columns):
1818
new_input_schema.append(column)
1919
return new_input_schema
2020

21+
def _add_column(name, type, schema, columns=None):
22+
schema.append({'type': type, 'name': name})
23+
if columns is not None:
24+
columns.append(name)
25+
2126
def get_scored_df_schema(tree, schema, columns, output_probabilities, is_evaluation=False, check_prediction=False):
2227
check_input_schema(tree, set(column["name"] for column in schema), is_evaluation)
2328
if columns is not None:
2429
schema = update_input_schema(schema, columns)
2530
if output_probabilities:
2631
for value in tree.target_values:
27-
schema.append({'type': 'double', 'name': "proba_" + safe_str(value)})
28-
if columns is not None:
29-
columns.append("proba_"+safe_str(value))
30-
schema.append({'type': 'string', 'name': 'prediction'})
31-
if columns is not None:
32-
columns.append("prediction")
32+
_add_column('proba_' + safe_str(value), 'double', schema, columns)
33+
_add_column('prediction', 'string', schema, columns)
3334
if check_prediction:
34-
schema.append({'type': 'boolean', 'name': 'prediction_correct'})
35-
if columns is not None:
36-
columns.append("prediction_correct")
37-
schema.append({'type': 'string', 'name': 'label'})
38-
if columns is not None:
39-
columns.append("label")
35+
_add_column('prediction_correct', 'boolean', schema, columns)
36+
_add_column('decision_rule', 'array', schema, columns)
37+
_add_column('node_id', 'int', schema, columns)
38+
_add_column('label', 'string', schema, columns)
4039
return schema
4140

4241
def get_metric_df_schema(metrics_dict, metrics, recipe_config):
@@ -78,6 +77,7 @@ def add_scoring_columns(tree, df, output_probabilities, is_evaluation=False, che
7877
df.loc[filtered_df_indices, "prediction_correct"] = filtered_df[tree.target] == leaf.prediction
7978
df.loc[label_indices, "label"] = leaf.label
8079

81-
elif leaf.label is not None:
82-
filtered_df = tree.get_filtered_df(leaf, df)
83-
df.loc[filtered_df.index, "label"] = leaf.label
80+
filtered_df = tree.get_filtered_df(leaf, df)
81+
df.loc[filtered_df.index, "decision_rule"] = safe_str(tree.get_decision_rule(leaf_id))
82+
df.loc[filtered_df.index, "node_id"] = leaf_id
83+
df.loc[filtered_df.index, "label"] = leaf.label

0 commit comments

Comments
 (0)