Skip to content

Commit 862a53f

Browse files
Author: Agathe Guillemot (committed)
Commit message: Update tests + fixes found with the tests
1 parent 21217ba commit 862a53f

File tree

3 files changed

+57
-28
lines changed

3 files changed

+57
-28
lines changed

python-lib/dku_idtb_decision_tree/node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def get_decision_rule(self):
105105
rule += "{} ≤ ".format(self.beginning)
106106
rule += self.feature
107107
if self.end:
108-
rule += "< {}".format(self.end)
108+
rule += " < {}".format(self.end)
109109
return rule
110110

111111
def apply_filter(self, df, mean):

python-lib/dku_idtb_scoring/score.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def get_scored_df_schema(tree, schema, columns, output_probabilities, is_evaluat
3434
if check_prediction:
3535
_add_column('prediction_correct', 'boolean', schema, columns)
3636
_add_column('decision_rule', 'array', schema, columns)
37-
_add_column('node_id', 'int', schema, columns)
37+
_add_column('leaf_id', 'int', schema, columns)
3838
_add_column('label', 'string', schema, columns)
3939
return schema
4040

@@ -75,9 +75,8 @@ def add_scoring_columns(tree, df, output_probabilities, is_evaluation=False, che
7575
df.loc[filtered_df_indices, "prediction"] = leaf.prediction
7676
if check_prediction:
7777
df.loc[filtered_df_indices, "prediction_correct"] = filtered_df[tree.target] == leaf.prediction
78-
df.loc[label_indices, "label"] = leaf.label
7978

8079
filtered_df = tree.get_filtered_df(leaf, df)
8180
df.loc[filtered_df.index, "decision_rule"] = safe_str(tree.get_decision_rule(leaf_id))
82-
df.loc[filtered_df.index, "node_id"] = leaf_id
81+
df.loc[filtered_df.index, "leaf_id"] = leaf_id
8382
df.loc[filtered_df.index, "label"] = leaf.label

python-tests/test_score.py

Lines changed: 54 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -79,47 +79,77 @@ def get_input_df():
7979
def test_score():
8080
df = get_input_df()
8181
add_scoring_columns(tree, df, True)
82-
expected_df = pd.DataFrame([[.2, "u", "A", .8, .2, "A", "hello there"],
83-
[7, pd.np.nan, "B", pd.np.nan, pd.np.nan, pd.np.nan, "general Kenobi"],
84-
[4, "u", "A", .25, .75, "B", None],
85-
[3, "v", "A", .8, .2, "A", "hello there"],
86-
[pd.np.nan, "u", "C", .8, .2, "A", "hello there"]], columns=("num", "cat", "target", "proba_A", "proba_B", "prediction", "label"))
87-
assert df.equals(expected_df)
82+
expected_df = pd.DataFrame([
83+
[.2, "u", "A", .8, .2, "A", str(["num < 4"]), 1.0, "hello there"],
84+
[7, pd.np.nan, "B", pd.np.nan, pd.np.nan, pd.np.nan, str(["4 ≤ num", "cat not in {}".format(["u", "v"])]), 4.0, "general Kenobi"],
85+
[4, "u", "A", .25, .75, "B", str(["4 ≤ num", "cat in {}".format(["u", "v"])]), 3.0, None],
86+
[3, "v", "A", .8, .2, "A", str(["num < 4"]), 1.0, "hello there"],
87+
[pd.np.nan, "u", "C", .8, .2, "A", str(["num < 4"]), 1.0, "hello there"]
88+
], columns=("num", "cat", "target", "proba_A", "proba_B", "prediction", "decision_rule", "leaf_id", "label"))
89+
pd.testing.assert_frame_equal(df, expected_df)
8890

8991
df = get_input_df()
9092
add_scoring_columns(tree, df, False, True, False)
91-
expected_df = pd.DataFrame([[.2, "u", "A", "A", "hello there"],
92-
[7, pd.np.nan, "B", pd.np.nan, "general Kenobi"],
93-
[4, "u", "A", "B", None],
94-
[3, "v", "A", "A", "hello there"],
95-
[pd.np.nan, "u", "C", pd.np.nan, "hello there"]], columns=("num", "cat", "target", "prediction", "label"))
96-
assert df.equals(expected_df)
93+
expected_df = pd.DataFrame([
94+
[.2, "u", "A", "A", str(["num < 4"]), 1.0, "hello there"],
95+
[7, pd.np.nan, "B", pd.np.nan, str(["4 ≤ num", "cat not in {}".format(["u", "v"])]), 4.0, "general Kenobi"],
96+
[4, "u", "A", "B", str(["4 ≤ num", "cat in {}".format(["u", "v"])]), 3.0, None],
97+
[3, "v", "A", "A", str(["num < 4"]), 1.0, "hello there"],
98+
[pd.np.nan, "u", "C", pd.np.nan, str(["num < 4"]), 1.0, "hello there"]
99+
], columns=("num", "cat", "target", "prediction", "decision_rule", "leaf_id", "label"))
100+
pd.testing.assert_frame_equal(df, expected_df)
97101

98102
df = get_input_df()
99103
add_scoring_columns(tree, df, False, True, True)
100-
expected_df = pd.DataFrame([[.2, "u", "A", "A", True, "hello there"],
101-
[7, pd.np.nan, "B", pd.np.nan, pd.np.nan, "general Kenobi"],
102-
[4, "u", "A", "B", False, None],
103-
[3, "v", "A", "A", True, "hello there"],
104-
[pd.np.nan, "u", "C", pd.np.nan, pd.np.nan, "hello there"]], columns=("num", "cat", "target", "prediction", "prediction_correct", "label"))
105-
assert df.equals(expected_df)
104+
expected_df = pd.DataFrame([
105+
[.2, "u", "A", "A", True, str(["num < 4"]), 1.0, "hello there"],
106+
[7, pd.np.nan, "B", pd.np.nan, pd.np.nan, str(["4 ≤ num", "cat not in {}".format(["u", "v"])]), 4.0, "general Kenobi"],
107+
[4, "u", "A", "B", False, str(["4 ≤ num", "cat in {}".format(["u", "v"])]), 3.0, None],
108+
[3, "v", "A", "A", True, str(["num < 4"]), 1.0, "hello there"],
109+
[pd.np.nan, "u", "C", pd.np.nan, pd.np.nan, str(["num < 4"]), 1.0, "hello there"]
110+
], columns=("num", "cat", "target", "prediction", "prediction_correct", "decision_rule", "leaf_id", "label"))
111+
pd.testing.assert_frame_equal(df, expected_df)
106112

107113
def get_input_schema():
108114
return [{"type": "double", "name": "num"}, {"type": "string", "name": "cat"}, {"type": "string", "name": "target"}]
109115

110116
def test_scored_df_schema():
111117
schema = get_scored_df_schema(tree, get_input_schema(), None, True)
112-
assert schema == [{"type": "double", "name": "num"}, {"type": "string", "name": "cat"}, {"type": "string", "name": "target"},
113-
{"type": "double", "name": "proba_A"}, {"type": "double", "name": "proba_B"}, {"type": "string", "name": "prediction"}, {"type": "string", "name": "label"}]
118+
expected_schema = [
119+
{"type": "double", "name": "num"},
120+
{"type": "string", "name": "cat"},
121+
{"type": "string", "name": "target"},
122+
{"type": "double", "name": "proba_A"},
123+
{"type": "double", "name": "proba_B"},
124+
{"type": "string", "name": "prediction"},
125+
{"type": "array", "name": "decision_rule"},
126+
{"type": "int", "name": "leaf_id"},
127+
{"type": "string", "name": "label"}
128+
]
129+
assert schema == expected_schema
114130
columns = []
115131
schema = get_scored_df_schema(tree, get_input_schema(), columns, False, True, False)
116-
assert schema == [{"type": "string", "name": "prediction"}, {"type": "string", "name": "label"}]
117-
assert columns == ["prediction", "label"]
132+
expected_schema = [
133+
{"type": "string", "name": "prediction"},
134+
{"type": "array", "name": "decision_rule"},
135+
{"type": "int", "name": "leaf_id"},
136+
{"type": "string", "name": "label"}
137+
]
138+
assert schema == expected_schema
139+
assert columns == ["prediction", "decision_rule", "leaf_id", "label"]
118140

119141
columns = ["num"]
120142
schema = get_scored_df_schema(tree, get_input_schema(), columns, False, True, True)
121-
assert schema == [{"type": "double", "name": "num"}, {"type": "string", "name": "prediction"}, {"type": "boolean", "name": "prediction_correct"}, {"type": "string", "name": "label"}]
122-
assert columns == ["num", "prediction", "prediction_correct", "label"]
143+
expected_schema = [
144+
{"type": "double", "name": "num"},
145+
{"type": "string", "name": "prediction"},
146+
{"type": "boolean", "name": "prediction_correct"},
147+
{"type": "array", "name": "decision_rule"},
148+
{"type": "int", "name": "leaf_id"},
149+
{"type": "string", "name": "label"}
150+
]
151+
assert schema == expected_schema
152+
assert columns == ["num", "prediction", "prediction_correct", "decision_rule", "leaf_id", "label"]
123153

124154
schema_missing_feature = [{"type": "double", "name": "num"}, {"type": "string", "name": "target"}]
125155
schema_missing_target = [{"type": "double", "name": "num"}, {"type": "string", "name": "cat"}]

0 commit comments

Comments (0)