@@ -780,24 +780,24 @@ def color_tree_only():
780780 if g_list_miss :
781781 if len (g_miss ) > 0 :
782782 for node in g_miss :
783- # 3 means this node is missed
784- node .data ._type = 2
783+ # 9 means this node is missed
784+ node .data ._type = 9
785785 if len (g_excess_nodes ) > 0 :
786786 for node in g_excess_nodes .keys ():
787- # 4 means this node is excessive
788- node .data ._type = 3
787+ # 10 means this node is excessive
788+ node .data ._type = 10
789789
790790 if g_list_continuations :
791791 if len (g_continuation ) > 0 :
792792 for node in g_continuation :
793- # 5 means this node is a continuation
794- node .data ._type = 4
793+ # 11 means this node is a continuation
794+ node .data ._type = 11
795795
796796 if g_list_distant_matches :
797797 if len (g_distance_match ) > 0 :
798798 for node in g_distance_match :
799- # 6 means this node is a distant match
800- node .data ._type = 5
799+ # 12 means this node is a distant match
800+ node .data ._type = 12
801801
802802
803803def print_result ():
@@ -821,8 +821,8 @@ def print_result():
821821 print ("node_ID = {} poi = {} weight = {}" .format (
822822 node .data .get_id (), node .data ._pos , g_weight_dict [node ]
823823 ))
824- # 3 means this node is missed
825- node .data ._type = 2
824+ # 9 means this node is missed
825+ node .data ._type = 9
826826 print ("--END--" )
827827 else :
828828 print ("---Nodes that are missed:None---" )
@@ -835,8 +835,8 @@ def print_result():
835835 print ("node_ID = {} poi = {} weight = {}" .format (
836836 node .data .get_id (), node .data ._pos , g_excess_nodes [node ]
837837 ))
838- # 4 means this node is excessive
839- node .data ._type = 3
838+ # 10 means this node is excessive
839+ node .data ._type = 10
840840 else :
841841 print ("---extra Nodes in test reconstruction: None---" )
842842
@@ -848,8 +848,8 @@ def print_result():
848848 print ("node_ID = {} poi = {} weight = {}" .format (
849849 node .data .get_id (), node .data ._pos , g_weight_dict [node ]
850850 ))
851- # 5 means this node is a continuation
852- node .data ._type = 4
851+ # 11 means this node is a continuation
852+ node .data ._type = 11
853853 else :
854854 print ("---continuation Nodes None---" )
855855
@@ -861,8 +861,8 @@ def print_result():
861861 print ("node_ID = {} poi = {} weight = {}" .format (
862862 node .data .get_id (), node .data ._pos , g_weight_dict [node ]
863863 ))
864- # 6 means this node is a distant match
865- node .data ._type = 5
864+ # 12 means this node is a distant match
865+ node .data ._type = 12
866866 else :
867867 print ("Distant Matches: none" )
868868
@@ -935,8 +935,8 @@ def diadem_metric(gold_swc_tree, test_swc_tree, config):
935935 """
936936 global g_spur_set
937937 global g_weight_dict
938- gold_swc_tree .type_clear ( 0 )
939- test_swc_tree .type_clear ( 1 )
938+ gold_swc_tree .set_node_type_by_topo ( root_id = 1 )
939+ test_swc_tree .set_node_type_by_topo ( root_id = 5 )
940940 diadem_init ()
941941 config_init (config )
942942 diadam_match_utils .diadem_utils_init (config )
@@ -979,6 +979,7 @@ def diadem_metric(gold_swc_tree, test_swc_tree, config):
979979 print ('match1 = {}, match2 = {}' .format (
980980 key .data .get_id (), g_matches [key ].data .get_id ()
981981 ))
982+ color_tree_only ()
982983 if debug :
983984 for k in g_weight_dict :
984985 print ("id = {} wt = {}" .format (k .data .get_id (), g_weight_dict [k ]))
@@ -991,7 +992,7 @@ def diadem_metric(gold_swc_tree, test_swc_tree, config):
991992 "score_sum" : g_score_sum ,
992993 "final_score" : g_final_score
993994 }
994- return res
995+ return res , gold_swc_tree , test_swc_tree
995996
996997
997998def pyneval_diadem_metric (gold_swc , test_swc , config ):
@@ -1010,7 +1011,7 @@ def pyneval_diadem_metric(gold_swc, test_swc, config):
10101011 gold_tree .load_list (read_swc .adjust_swcfile (gold_swc ))
10111012 test_tree .load_list (read_swc .adjust_swcfile (test_swc ))
10121013
1013- diadem_res = diadem_metric (gold_swc_tree = gold_tree ,
1014+ diadem_res = diadem_metric (gold_swc_tree = gold_tree ,
10141015 test_swc_tree = test_tree ,
10151016 config = config )
10161017
@@ -1030,8 +1031,8 @@ def pyneval_diadem_metric(gold_swc, test_swc, config):
10301031 testTree = swc_node .SwcTree ()
10311032 goldTree = swc_node .SwcTree ()
10321033
1033- goldTree .load ("../../data/test_data/topo_metric_data/ExampleGoldStandard .swc" )
1034- testTree .load ("../../data/test_data/topo_metric_data/ExampleTest .swc" )
1034+ goldTree .load ("../../data/test_data/topo_metric_data/gold_fake_data3 .swc" )
1035+ testTree .load ("../../data/test_data/topo_metric_data/test_fake_data3 .swc" )
10351036 config_utils .get_default_threshold (goldTree )
10361037 config = read_json .read_json ("../../config/diadem_metric.json" )
10371038 config_schema = read_json .read_json ("../../config/schemas/diadem_metric_schema.json" )
@@ -1041,9 +1042,9 @@ def pyneval_diadem_metric(gold_swc, test_swc, config):
10411042 except Exception as e :
10421043 raise Exception ("[Error: ]Error in analyzing config json file" )
10431044
1044- diadem_result = diadem_metric (test_swc_tree = testTree ,
1045- gold_swc_tree = goldTree ,
1046- config = config )
1045+ diadem_result , tmp1 , tmp2 = diadem_metric (test_swc_tree = testTree ,
1046+ gold_swc_tree = goldTree ,
1047+ config = config )
10471048 print ("matched weight = {}\n "
10481049 "total weight = {}\n "
10491050 "diadem score = {}\n " .
0 commit comments