Skip to content

Commit 2230456

Browse files
authored
Merge pull request #52 from CSDLLab/2021_6_5_rewrite_detail_save
rewrite detail save
2 parents af48664 + c770eb2 commit 2230456

File tree

10 files changed

+61
-62
lines changed

10 files changed

+61
-62
lines changed

config/diadem_metric.json

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
"remove_spur": false,
44
"count_excess_nodes": true,
55
"align_tree_by_root": false,
6-
"list_miss": false,
7-
"list_distant_matches": false,
8-
"list_continuations": false,
6+
"list_miss": true,
7+
"list_distant_matches": true,
8+
"list_continuations": true,
99
"find_proper_root": true,
1010
"z_scale": 1,
1111
"TRAJECTORY_NONE": -1.0,

config/ssd_metric.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"threshold_mode": 1,
3-
"ssd_threshold": 2,
3+
"ssd_threshold": 1,
44
"up_sample_threshold": 1.0,
55
"z_scale": 1,
66
"debug": false

pyneval/cli/pyneval.py

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -185,16 +185,6 @@ def set_configs(abs_dir, args):
185185
# info: how many trees read
186186
print("There are {} test image(s)".format(len(test_swc_trees)))
187187

188-
# argument: output
189-
output_dir = None
190-
if args.output:
191-
output_dir = os.path.join(abs_dir, args.output)
192-
193-
# argument: detail
194-
detail_dir = None
195-
if args.detail:
196-
detail_dir = os.path.join(abs_dir, args.detail)
197-
198188
# argument: config
199189
config_path = args.config
200190
if config_path is None:
@@ -208,36 +198,46 @@ def set_configs(abs_dir, args):
208198
except Exception:
209199
raise Exception("[Error: ]Error in analyzing config json file")
210200

201+
# argument: output
202+
output_dir = None
203+
if args.output:
204+
output_dir = os.path.join(abs_dir, args.output)
205+
206+
# argument: detail
207+
detail_dir = None
208+
if args.detail:
209+
detail_dir = os.path.join(abs_dir, args.detail)
210+
config["detail"] = True
211+
211212
# argument: debug
212213
is_debug = args.debug
213214

214215
return gold_swc_tree, test_swc_trees, metric, output_dir, detail_dir, config, is_debug
215216

216217

217-
def excute_metric(metric, gold_swc_tree, test_swc_tree, config, detail_dir, output_dir, file_name_extra=""):
218+
def excute_metric(metric, gold_swc_tree, test_swc_tree, config, detail_dir, output_dir):
218219
metric_method = get_metric_method(metric)
219220
test_swc_name = test_swc_tree.get_name()
220-
gold_swc_name = gold_swc_tree.get_name()
221221

222-
result = metric_method(gold_swc_tree=gold_swc_tree, test_swc_tree=test_swc_tree, config=config)
222+
result, res_gold_swc_tree, res_test_swc_tree = metric_method(gold_swc_tree=gold_swc_tree,
223+
test_swc_tree=test_swc_tree, config=config)
223224

224225
print("---------------Result---------------")
225226
for key in result:
226227
print("{} = {}".format(key.ljust(15, ' '), result[key]))
227228
print("----------------End-----------------\n")
228229

229-
if file_name_extra == "reverse":
230-
file_name = gold_swc_name[:-4] + "_" + metric + "_" + file_name_extra + ".swc"
231-
else:
232-
file_name = test_swc_name[:-4] + "_" + metric + "_" + file_name_extra + ".swc"
230+
file_name = test_swc_name[:-4] + "_" + metric + "_"
233231

234232
if detail_dir:
235-
swc_save(swc_tree=gold_swc_tree,
236-
out_path=os.path.join(detail_dir, file_name))
233+
swc_save(swc_tree=res_gold_swc_tree,
234+
out_path=os.path.join(detail_dir, file_name + "recall.swc"))
235+
swc_save(swc_tree=res_test_swc_tree,
236+
out_path=os.path.join(detail_dir, file_name + "precision.swc"))
237237

238238
if output_dir:
239239
read_json.save_json(data=result,
240-
json_file_path=os.path.join(output_dir, file_name))
240+
json_file_path=os.path.join(output_dir, file_name + ".json"))
241241

242242

243243
# command program
@@ -254,9 +254,6 @@ def run():
254254
for test_swc_tree in test_swc_trees:
255255
excute_metric(metric=metric, gold_swc_tree=gold_swc_tree, test_swc_tree=test_swc_tree,
256256
config=config, detail_dir=detail_dir, output_dir=output_dir)
257-
if metric in ["length_metric", "diadem_metric"]:
258-
excute_metric(metric=metric, gold_swc_tree=test_swc_tree, test_swc_tree=gold_swc_tree,
259-
config=config, detail_dir=detail_dir, output_dir=output_dir, file_name_extra="reverse")
260257

261258

262259
if __name__ == "__main__":
@@ -278,4 +275,4 @@ def run():
278275

279276
# pyneval --gold .\\data\test_data\geo_metric_data\gold_34_23_10.swc --test .\data\test_data\geo_metric_data\test_34_23_10.swc --metric branch_metric
280277

281-
# pyneval --gold ./data/test_data/geo_metric_data/gold_fake_data1.swc --test ./data/test_data/geo_test/ --metric branch_metric --detail ./output
278+
# pyneval --gold ./data/test_data/geo_metric_data/gold_fake_data1.swc --test ./data/test_data/geo_test/test_fake_data1.swc --metric branch_metric --detail ./output/detail --output ./output/output

pyneval/io/read_json.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def save_json(json_file_path, data, DEBUG=False):
2626
raise Exception("[Error: ] \" {} \" is not a json file. Wrong format".format(json_file_path))
2727
try:
2828
with open(json_file_path, 'w') as f:
29-
json.dump(data, f)
29+
json.dump(data, f, indent=4)
3030
if DEBUG:
3131
print(type(data))
3232
except:

pyneval/metric/branch_leaf_metric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def branch_leaf_metric(gold_swc_tree, test_swc_tree, config):
188188
"pt_cost": branch_result_tuple[7],
189189
"iso_node_num": branch_result_tuple[8]
190190
}
191-
return branch_result
191+
return branch_result, gold_swc_tree, test_swc_tree
192192

193193

194194
if __name__ == "__main__":

pyneval/metric/diadem_metric.py

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -780,24 +780,24 @@ def color_tree_only():
780780
if g_list_miss:
781781
if len(g_miss) > 0:
782782
for node in g_miss:
783-
# 3 means this node is missed
784-
node.data._type = 2
783+
# 9 means this node is missed
784+
node.data._type = 9
785785
if len(g_excess_nodes) > 0:
786786
for node in g_excess_nodes.keys():
787-
# 4 means this node is excessive
788-
node.data._type = 3
787+
# 10 means this node is excessive
788+
node.data._type = 10
789789

790790
if g_list_continuations:
791791
if len(g_continuation) > 0:
792792
for node in g_continuation:
793-
# 5 means this node is a continuation
794-
node.data._type = 4
793+
# 11 means this node is a continuation
794+
node.data._type = 11
795795

796796
if g_list_distant_matches:
797797
if len(g_distance_match) > 0:
798798
for node in g_distance_match:
799-
# 6 means this node is a distant match
800-
node.data._type = 5
799+
# 12 means this node is a distant match
800+
node.data._type = 12
801801

802802

803803
def print_result():
@@ -821,8 +821,8 @@ def print_result():
821821
print("node_ID = {} poi = {} weight = {}".format(
822822
node.data.get_id(), node.data._pos, g_weight_dict[node]
823823
))
824-
# 3 means this node is missed
825-
node.data._type = 2
824+
# 9 means this node is missed
825+
node.data._type = 9
826826
print("--END--")
827827
else:
828828
print("---Nodes that are missed:None---")
@@ -835,8 +835,8 @@ def print_result():
835835
print("node_ID = {} poi = {} weight = {}".format(
836836
node.data.get_id(), node.data._pos, g_excess_nodes[node]
837837
))
838-
# 4 means this node is excessive
839-
node.data._type = 3
838+
# 10 means this node is excessive
839+
node.data._type = 10
840840
else:
841841
print("---extra Nodes in test reconstruction: None---")
842842

@@ -848,8 +848,8 @@ def print_result():
848848
print("node_ID = {} poi = {} weight = {}".format(
849849
node.data.get_id(), node.data._pos, g_weight_dict[node]
850850
))
851-
# 5 means this node is a continuation
852-
node.data._type = 4
851+
# 11 means this node is a continuation
852+
node.data._type = 11
853853
else:
854854
print("---continuation Nodes None---")
855855

@@ -861,8 +861,8 @@ def print_result():
861861
print("node_ID = {} poi = {} weight = {}".format(
862862
node.data.get_id(), node.data._pos, g_weight_dict[node]
863863
))
864-
# 6 means this node is a distant match
865-
node.data._type = 5
864+
# 12 means this node is a distant match
865+
node.data._type = 12
866866
else:
867867
print("Distant Matches: none")
868868

@@ -935,8 +935,8 @@ def diadem_metric(gold_swc_tree, test_swc_tree, config):
935935
"""
936936
global g_spur_set
937937
global g_weight_dict
938-
gold_swc_tree.type_clear(0)
939-
test_swc_tree.type_clear(1)
938+
gold_swc_tree.set_node_type_by_topo(root_id=1)
939+
test_swc_tree.set_node_type_by_topo(root_id=5)
940940
diadem_init()
941941
config_init(config)
942942
diadam_match_utils.diadem_utils_init(config)
@@ -979,6 +979,7 @@ def diadem_metric(gold_swc_tree, test_swc_tree, config):
979979
print('match1 = {}, match2 = {}'.format(
980980
key.data.get_id(), g_matches[key].data.get_id()
981981
))
982+
color_tree_only()
982983
if debug:
983984
for k in g_weight_dict:
984985
print("id = {} wt = {}".format(k.data.get_id(), g_weight_dict[k]))
@@ -991,7 +992,7 @@ def diadem_metric(gold_swc_tree, test_swc_tree, config):
991992
"score_sum": g_score_sum,
992993
"final_score": g_final_score
993994
}
994-
return res
995+
return res, gold_swc_tree, test_swc_tree
995996

996997

997998
def pyneval_diadem_metric(gold_swc, test_swc, config):
@@ -1010,7 +1011,7 @@ def pyneval_diadem_metric(gold_swc, test_swc, config):
10101011
gold_tree.load_list(read_swc.adjust_swcfile(gold_swc))
10111012
test_tree.load_list(read_swc.adjust_swcfile(test_swc))
10121013

1013-
diadem_res= diadem_metric(gold_swc_tree=gold_tree,
1014+
diadem_res = diadem_metric(gold_swc_tree=gold_tree,
10141015
test_swc_tree=test_tree,
10151016
config=config)
10161017

@@ -1030,8 +1031,8 @@ def pyneval_diadem_metric(gold_swc, test_swc, config):
10301031
testTree = swc_node.SwcTree()
10311032
goldTree = swc_node.SwcTree()
10321033

1033-
goldTree.load("../../data/test_data/topo_metric_data/ExampleGoldStandard.swc")
1034-
testTree.load("../../data/test_data/topo_metric_data/ExampleTest.swc")
1034+
goldTree.load("../../data/test_data/topo_metric_data/gold_fake_data3.swc")
1035+
testTree.load("../../data/test_data/topo_metric_data/test_fake_data3.swc")
10351036
config_utils.get_default_threshold(goldTree)
10361037
config = read_json.read_json("../../config/diadem_metric.json")
10371038
config_schema = read_json.read_json("../../config/schemas/diadem_metric_schema.json")
@@ -1041,9 +1042,9 @@ def pyneval_diadem_metric(gold_swc, test_swc, config):
10411042
except Exception as e:
10421043
raise Exception("[Error: ]Error in analyzing config json file")
10431044

1044-
diadem_result = diadem_metric(test_swc_tree=testTree,
1045-
gold_swc_tree=goldTree,
1046-
config=config)
1045+
diadem_result, tmp1, tmp2 = diadem_metric(test_swc_tree=testTree,
1046+
gold_swc_tree=goldTree,
1047+
config=config)
10471048
print("matched weight = {}\n"
10481049
"total weight = {}\n"
10491050
"diadem score = {}\n".

pyneval/metric/length_metric.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ def length_metric(gold_swc_tree, test_swc_tree, config):
9393

9494
gold_swc_tree.z_rescale(z_scale)
9595
test_swc_tree.z_rescale(z_scale)
96+
gold_swc_tree.set_node_type_by_topo(root_id=1)
97+
test_swc_tree.set_node_type_by_topo(root_id=5)
9698

9799
if rad_mode == 1:
98100
rad_threshold *= -1
@@ -113,7 +115,7 @@ def length_metric(gold_swc_tree, test_swc_tree, config):
113115
"recall": recall,
114116
"precision": precision
115117
}
116-
return res
118+
return res, gold_swc_tree, test_swc_tree
117119

118120

119121
def web_length_metric(gold_swc, test_swc, mode, rad_threshold, len_threshold):

pyneval/metric/link_metric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def link_metric(gold_swc_tree, test_swc_tree, config):
8888
"edge_loss": edge_loss,
8989
"tree_dis_loss": tree_dis_loss
9090
}
91-
return res
91+
return res, None, None
9292

9393

9494
if __name__ == "__main__":

pyneval/metric/ssd_metric.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,6 @@ def ssd_metric(gold_swc_tree: swc_node.SwcTree, test_swc_tree: swc_node.SwcTree,
119119
t2g_score, t2g_num = get_mse(src_tree=u_test_swc_tree, tar_tree=u_gold_swc_tree,
120120
ssd_threshold=ssd_threshold, mode=threshold_mode)
121121

122-
if "detail_path" in config:
123-
swc_writer.swc_save(u_gold_swc_tree, config["detail_path"][:-4] + "_gold_upsampled.swc")
124-
swc_writer.swc_save(u_test_swc_tree, config["detail_path"][:-4] + "_test_upsampled.swc")
125-
126122
if debug:
127123
print("recall_num = {}, pre_num = {}, gold_tot_num = {}, test_tot_num = {} {} {}".format(
128124
g2t_num, t2g_num, u_gold_swc_tree.size(), u_test_swc_tree.size(), gold_swc_tree.length(), test_swc_tree.length()
@@ -133,6 +129,9 @@ def ssd_metric(gold_swc_tree: swc_node.SwcTree, test_swc_tree: swc_node.SwcTree,
133129
"recall": 1 - g2t_num/u_gold_swc_tree.size(),
134130
"precision": 1 - t2g_num/u_test_swc_tree.size()
135131
}
132+
133+
if "detail" in config:
134+
return res, u_gold_swc_tree, u_test_swc_tree
136135
return res
137136

138137

pyneval/metric/volume_metric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def volume_metric(gold_swc_tree, test_swc_tree, config=None):
104104
res = {
105105
"recall": recall
106106
}
107-
return res
107+
return res, None, None
108108

109109

110110
if __name__ == "__main__":

0 commit comments

Comments
 (0)