Skip to content

Commit 39f264c

Browse files
committed
bug fix
1 parent f60eae7 commit 39f264c

16 files changed

+45
-16
lines changed

tests/modules/scoring/test_bert.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def test_bert_in_pipeline(dataset):
123123
search_space = [
124124
{
125125
"node_type": "scoring",
126+
"target_metric": "scoring_roc_auc",
126127
"search_space": [
127128
{
128129
"module_name": "bert",
@@ -132,7 +133,7 @@ def test_bert_in_pipeline(dataset):
132133
}
133134
],
134135
},
135-
{"node_type": "decision", "search_space": [{"module_name": "argmax"}]},
136+
{"node_type": "decision", "target_metric": "decision_accuracy", "search_space": [{"module_name": "argmax"}]},
136137
]
137138

138139
pipeline = Pipeline.from_search_space(search_space)

tests/modules/scoring/test_catboost.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def test_catboost_in_pipeline(dataset):
154154
search_space = [
155155
{
156156
"node_type": "scoring",
157+
"target_metric": "scoring_roc_auc",
157158
"search_space": [
158159
{
159160
"module_name": "catboost",
@@ -163,7 +164,7 @@ def test_catboost_in_pipeline(dataset):
163164
}
164165
],
165166
},
166-
{"node_type": "decision", "search_space": [{"module_name": "argmax"}]},
167+
{"node_type": "decision", "target_metric": "decision_accuracy", "search_space": [{"module_name": "argmax"}]},
167168
]
168169

169170
pipeline = Pipeline.from_search_space(search_space)

tests/modules/scoring/test_cnn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def test_cnn_in_pipeline(dataset):
128128
search_space = [
129129
{
130130
"node_type": "scoring",
131+
"target_metric": "scoring_roc_auc",
131132
"search_space": [
132133
{
133134
"module_name": "cnn",
@@ -136,7 +137,7 @@ def test_cnn_in_pipeline(dataset):
136137
}
137138
],
138139
},
139-
{"node_type": "decision", "search_space": [{"module_name": "argmax"}]},
140+
{"node_type": "decision", "target_metric": "decision_accuracy", "search_space": [{"module_name": "argmax"}]},
140141
]
141142

142143
pipeline = Pipeline.from_search_space(search_space)

tests/modules/scoring/test_description_bi.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def test_description_bi_in_pipeline(dataset):
6464
search_space = [
6565
{
6666
"node_type": "scoring",
67+
"target_metric": "scoring_roc_auc",
6768
"search_space": [
6869
{
6970
"module_name": "description_bi",
@@ -72,7 +73,7 @@ def test_description_bi_in_pipeline(dataset):
7273
}
7374
],
7475
},
75-
{"node_type": "decision", "search_space": [{"module_name": "argmax"}]},
76+
{"node_type": "decision", "target_metric": "decision_accuracy", "search_space": [{"module_name": "argmax"}]},
7677
]
7778

7879
pipeline = Pipeline.from_search_space(search_space)

tests/modules/scoring/test_description_cross.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def test_description_cross_in_pipeline(dataset):
7272
search_space = [
7373
{
7474
"node_type": "scoring",
75+
"target_metric": "scoring_roc_auc",
7576
"search_space": [
7677
{
7778
"module_name": "description_cross",
@@ -80,7 +81,7 @@ def test_description_cross_in_pipeline(dataset):
8081
}
8182
],
8283
},
83-
{"node_type": "decision", "search_space": [{"module_name": "argmax"}]},
84+
{"node_type": "decision", "target_metric": "decision_accuracy", "search_space": [{"module_name": "argmax"}]},
8485
]
8586

8687
pipeline = Pipeline.from_search_space(search_space)

tests/modules/scoring/test_description_llm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,15 @@ def test_llm_description_in_pipeline(dataset):
6262
search_space = [
6363
{
6464
"node_type": "scoring",
65+
"target_metric": "scoring_roc_auc",
6566
"search_space": [
6667
{
6768
"module_name": "description_llm",
6869
"temperature": [0.3],
6970
}
7071
],
7172
},
72-
{"node_type": "decision", "search_space": [{"module_name": "argmax"}]},
73+
{"node_type": "decision", "target_metric": "decision_accuracy", "search_space": [{"module_name": "argmax"}]},
7374
]
7475

7576
pipeline = Pipeline.from_search_space(search_space)

tests/modules/scoring/test_dnnc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def test_dnnc_in_pipeline(dataset):
4949
search_space = [
5050
{
5151
"node_type": "scoring",
52+
"target_metric": "scoring_roc_auc",
5253
"search_space": [
5354
{
5455
"module_name": "dnnc",
@@ -57,7 +58,7 @@ def test_dnnc_in_pipeline(dataset):
5758
}
5859
],
5960
},
60-
{"node_type": "decision", "search_space": [{"module_name": "argmax"}]},
61+
{"node_type": "decision", "target_metric": "decision_accuracy", "search_space": [{"module_name": "argmax"}]},
6162
]
6263

6364
pipeline = Pipeline.from_search_space(search_space)

tests/modules/scoring/test_gcn_scorer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def test_gcn_in_pipeline(dataset):
9898
search_space = [
9999
{
100100
"node_type": "scoring",
101+
"target_metric": "scoring_hit_rate",
101102
"search_space": [
102103
{
103104
"module_name": "gcn",
@@ -106,10 +107,15 @@ def test_gcn_in_pipeline(dataset):
106107
}
107108
],
108109
},
109-
{"node_type": "decision", "search_space": [{"module_name": "threshold", "thresh": [0.5]}]},
110+
{
111+
"node_type": "decision",
112+
"target_metric": "decision_accuracy",
113+
"search_space": [{"module_name": "threshold", "thresh": [0.5]}],
114+
},
110115
]
111116

112117
pipeline = Pipeline.from_search_space(search_space)
118+
pipeline.set_config(get_test_embedder_config())
113119
pipeline.fit(dataset.to_multilabel())
114120
predictions = pipeline.predict(["test utterance"])
115121
assert len(predictions) == 1

tests/modules/scoring/test_knn.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def test_knn_in_pipeline(dataset):
5353
search_space = [
5454
{
5555
"node_type": "scoring",
56+
"target_metric": "scoring_roc_auc",
5657
"search_space": [
5758
{
5859
"module_name": "knn",
@@ -61,10 +62,11 @@ def test_knn_in_pipeline(dataset):
6162
}
6263
],
6364
},
64-
{"node_type": "decision", "search_space": [{"module_name": "argmax"}]},
65+
{"node_type": "decision", "target_metric": "decision_accuracy", "search_space": [{"module_name": "argmax"}]},
6566
]
6667

6768
pipeline = Pipeline.from_search_space(search_space)
69+
pipeline.set_config(get_test_embedder_config())
6870
pipeline.fit(dataset)
6971
predictions = pipeline.predict(["test utterance"])
7072
assert len(predictions) == 1

tests/modules/scoring/test_linear.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,18 @@ def test_linear_in_pipeline(dataset):
5353
search_space = [
5454
{
5555
"node_type": "scoring",
56+
"target_metric": "scoring_roc_auc",
5657
"search_space": [
5758
{
5859
"module_name": "linear",
5960
}
6061
],
6162
},
62-
{"node_type": "decision", "search_space": [{"module_name": "argmax"}]},
63+
{"node_type": "decision", "target_metric": "decision_accuracy", "search_space": [{"module_name": "argmax"}]},
6364
]
6465

6566
pipeline = Pipeline.from_search_space(search_space)
67+
pipeline.set_config(get_test_embedder_config())
6668
pipeline.fit(dataset)
6769
predictions = pipeline.predict(["test utterance"])
6870
assert len(predictions) == 1

0 commit comments

Comments
 (0)