Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions spras/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,12 +410,16 @@ def precision_recall_curve_node_ensemble(node_ensembles: dict, node_table: pd.Da
'Frequency'
] = 0.0

print(input_nodes_ensemble_df)

y_scores_input_nodes = input_nodes_ensemble_df['Frequency'].tolist()

precision_input_nodes, recall_input_nodes, thresholds_input_nodes = precision_recall_curve(y_true, y_scores_input_nodes)
plt.plot(recall_input_nodes, precision_input_nodes, color='black', marker='o', linestyle='--',
label=f'Input Nodes Baseline')

print(precision_input_nodes)
print(recall_input_nodes)
prc_input_nodes_baseline_data = {
'Threshold': thresholds_input_nodes,
'Precision': precision_input_nodes[:-1],
Expand Down
32 changes: 23 additions & 9 deletions test/evaluate/test_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def setup_class(cls):
'other_files': []
})

# TODO figure out why the input-nodes file is not being included in the data.pickle file
# it keeps coming up empty
with open(out_dataset, 'wb') as f:
pickle.dump(dataset, f)

Expand Down Expand Up @@ -126,18 +128,18 @@ def test_node_ensemble(self):
out_path_file = Path(OUT_DIR + 'node-ensemble.csv')
out_path_file.unlink(missing_ok=True)
ensemble_network = [INPUT_DIR + 'ensemble-network.tsv']
input_network = OUT_DIR + 'data.pickle'
node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(GS_NODE_TABLE, ensemble_network, input_network)
input_data = OUT_DIR + 'data.pickle'
node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(GS_NODE_TABLE, ensemble_network, input_data)
node_ensemble_dict['ensemble'].to_csv(out_path_file, sep='\t', index=False)
assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-node-ensemble.csv', shallow=False)

def test_empty_node_ensemble(self):
out_path_file = Path(OUT_DIR + 'empty-node-ensemble.csv')
out_path_file.unlink(missing_ok=True)
empty_ensemble_network = [INPUT_DIR + 'empty-ensemble-network.tsv']
input_network = OUT_DIR + 'data.pickle'
input_data = OUT_DIR + 'data.pickle'
node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(GS_NODE_TABLE, empty_ensemble_network,
input_network)
input_data)
node_ensemble_dict['empty'].to_csv(out_path_file, sep='\t', index=False)
assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-empty-node-ensemble.csv', shallow=False)

Expand All @@ -147,8 +149,8 @@ def test_multiple_node_ensemble(self):
out_path_empty_file = Path(OUT_DIR + 'empty-node-ensemble.csv')
out_path_empty_file.unlink(missing_ok=True)
ensemble_networks = [INPUT_DIR + 'ensemble-network.tsv', INPUT_DIR + 'empty-ensemble-network.tsv']
input_network = OUT_DIR + 'data.pickle'
node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(GS_NODE_TABLE, ensemble_networks, input_network)
input_data = OUT_DIR + 'data.pickle'
node_ensemble_dict = Evaluation.edge_frequency_node_ensemble(GS_NODE_TABLE, ensemble_networks, input_data)
node_ensemble_dict['ensemble'].to_csv(out_path_file, sep='\t', index=False)
assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-node-ensemble.csv', shallow=False)
node_ensemble_dict['empty'].to_csv(out_path_empty_file, sep='\t', index=False)
Expand All @@ -159,9 +161,19 @@ def test_precision_recall_curve_ensemble_nodes(self):
out_path_png.unlink(missing_ok=True)
out_path_file = Path(OUT_DIR + 'pr-curve-ensemble-nodes.txt')
out_path_file.unlink(missing_ok=True)
input_data = OUT_DIR + 'data.pickle'

pickle = Evaluation.from_file(input_data)
input_nodes_df = pickle.get_node_columns(["prize", "active"])

print(input_nodes_df)

ensemble_file = pd.read_csv(INPUT_DIR + 'node-ensemble.csv', sep='\t', header=0)

print(ensemble_file)

node_ensembles_dict = {'ensemble': ensemble_file}
Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, GS_NODE_TABLE, out_path_png,
Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, GS_NODE_TABLE, input_data, out_path_png,
out_path_file)
assert out_path_png.exists()
assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-pr-curve-ensemble-nodes.txt', shallow=False)
Expand All @@ -171,9 +183,10 @@ def test_precision_recall_curve_ensemble_nodes_empty(self):
out_path_png.unlink(missing_ok=True)
out_path_file = Path(OUT_DIR + 'pr-curve-ensemble-nodes-empty.txt')
out_path_file.unlink(missing_ok=True)
input_data = OUT_DIR + 'data.pickle'
empty_ensemble_file = pd.read_csv(INPUT_DIR + 'node-ensemble-empty.csv', sep='\t', header=0)
node_ensembles_dict = {'ensemble': empty_ensemble_file}
Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, GS_NODE_TABLE, out_path_png,
Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, GS_NODE_TABLE, input_data, out_path_png,
out_path_file)
assert out_path_png.exists()
assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-pr-curve-ensemble-nodes-empty.txt', shallow=False)
Expand All @@ -183,10 +196,11 @@ def test_precision_recall_curve_multiple_ensemble_nodes(self):
out_path_png.unlink(missing_ok=True)
out_path_file = Path(OUT_DIR + 'pr-curve-multiple-ensemble-nodes.txt')
out_path_file.unlink(missing_ok=True)
input_data = OUT_DIR + 'data.pickle'
ensemble_file = pd.read_csv(INPUT_DIR + 'node-ensemble.csv', sep='\t', header=0)
empty_ensemble_file = pd.read_csv(INPUT_DIR + 'node-ensemble-empty.csv', sep='\t', header=0)
node_ensembles_dict = {'ensemble1': ensemble_file, 'ensemble2': ensemble_file, 'ensemble3': empty_ensemble_file}
Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, GS_NODE_TABLE, out_path_png,
Evaluation.precision_recall_curve_node_ensemble(node_ensembles_dict, GS_NODE_TABLE, input_data, out_path_png,
out_path_file, True)
assert out_path_png.exists()
assert filecmp.cmp(out_path_file, EXPECT_DIR + 'expected-pr-curve-multiple-ensemble-nodes.txt', shallow=False)
Loading