 import pandas as pd
 from .models.structure.experiment import Experiment, label_map, Labels
 from .models.structure.type_predictor import TableType
+from copy import deepcopy
+import pickle


 class BaseLogger:
@@ -21,30 +23,85 @@ def __call__(self, step, **kwargs):
         print(f"[STEP] {step}: {kwargs}", file=self.file)


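+# Captures every (step, kwargs) event emitted through the pipeline logger so a
+# whole session can be replayed later, saved to disk, or loaded back.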
+class SessionRecorder:
+    def __init__(self, pipeline_logger):
+        self.pipeline_logger = pipeline_logger
+        self.session = []
+        self._recording = False
+
+    def __call__(self, step, **kwargs):
+        self.session.append((step, deepcopy(kwargs)))
+
+    def reset(self):
+        self.session = []
+
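+    # record()/stop() toggle a catch-all (".*") subscription on the pipeline
+    # logger; replay() stops recording and re-emits the captured events in order.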
+    def record(self):
+        if not self._recording:
+            self.pipeline_logger.register(".*", self)
+            self._recording = True
+
+    def stop(self):
+        if self._recording:
+            self.pipeline_logger.unregister(".*", self)
+            self._recording = False
+
+    def replay(self):
+        self.stop()
+        for step, kwargs in self.session:
+            self.pipeline_logger(step, **kwargs)
+
+    def save_session(self, path):
+        with open(path, "wb") as f:
+            pickle.dump(self.session, f)
+
+    def load_session(self, path):
+        with open(path, "rb") as f:
+            self.session = pickle.load(f)
+
+
 class StructurePredictionEvaluator:
     def __init__(self, pipeline_logger, pc):
-        pipeline_logger.register("structure_prediction::tables_labelled", self.on_tables_labelled)
+        pipeline_logger.register("structure_prediction::evidences_split", self.on_evidences_split)
+        pipeline_logger.register("structure_prediction::tables_labeled", self.on_tables_labeled)
         pipeline_logger.register("type_prediction::predicted", self.on_type_predicted)
+        pipeline_logger.register("type_prediction::multiclass_predicted", self.on_type_multiclass_predicted)
         self.pc = pc
         self.results = {}
         self.type_predictions = {}
+        self.type_multiclass_predictions = {}
+        self.evidences = pd.DataFrame()
+
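+    # Raw multiclass scores per (paper_id, table_name); the decision threshold
+    # used for the irrelevant class is stored alongside under IRRELEVANT.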
+    def on_type_multiclass_predicted(self, step, paper, tables, threshold, predictions):
+        for table, prediction in zip(tables, predictions):
+            self.type_multiclass_predictions[paper.paper_id, table.name] = {
+                TableType.SOTA: prediction[0],
+                TableType.ABLATION: prediction[1],
+                TableType.IRRELEVANT: threshold
+            }

     def on_type_predicted(self, step, paper, tables, predictions):
-        self.type_predictions[paper.paper_id] = predictions
+        for table, prediction in zip(tables, predictions):
+            self.type_predictions[paper.paper_id, table.name] = prediction
+
+    def on_evidences_split(self, step, evidences, evidences_num):
+        self.evidences = pd.concat([self.evidences, evidences])

-    def on_tables_labelled(self, step, paper, tables):
+    def on_tables_labeled(self, step, paper, labeled_tables):
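+        # Compare predicted table types and cell labels against the gold
+        # annotations of the paper with a matching title in self.pc.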
         golds = [p for p in self.pc if p.text.title == paper.text.title]
         paper_id = paper.paper_id
         type_results = []
         cells_results = []
+        labeled_tables = {table.name: table for table in labeled_tables}
         if len(golds) == 1:
             gold = golds[0]
-            for gold_table, table, table_type in zip(gold.tables, paper.tables, self.type_predictions.get(paper.paper_id, [])):
+            for gold_table, table in zip(gold.tables, paper.tables):
+                table_type = self.type_predictions[paper.paper_id, table.name]
                 is_important = table_type == TableType.SOTA or table_type == TableType.ABLATION
                 gold_is_important = "sota" in gold_table.gold_tags or "ablation" in gold_table.gold_tags
                 type_results.append({"predicted": is_important, "gold": gold_is_important, "name": table.name})
                 if not is_important:
                     continue
+                table = labeled_tables[table.name]
                 rows, cols = table.df.shape
                 for r in range(rows):
                     for c in range(cols):
@@ -76,6 +133,14 @@ def metrics(self, paper_id):
         e._set_results(paper_id, self.map_tags(results['cells'].predicted), self.map_tags(results['cells'].gold))
         e.show_results(paper_id, normalize=True)

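+    # Returns the predicted type of a table together with its multiclass scores,
+    # sorted from the highest to the lowest score.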
+    def get_table_type_predictions(self, paper_id, table_name):
+        prediction = self.type_predictions.get((paper_id, table_name))
+        multi_predictions = self.type_multiclass_predictions.get((paper_id, table_name))
+        if prediction is not None:
+            multi_predictions = sorted(multi_predictions.items(), key=lambda x: x[1], reverse=True)
+            return prediction, [(k.name, v) for k, v in multi_predictions]
+

 class LinkerEvaluator:
     def __init__(self, pipeline_logger, pc):
@@ -102,3 +167,18 @@ def on_taxonomy_topk(self, step, ext_id, topk):

     def top_matches(self, paper_id, table_name, row, col):
         return self.topk[(paper_id, table_name, row, col)]
+
+
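+# Collects, per filtering step, the proposals reported by the
+# "filtering::<step>::filtered" events, the corresponding selections, and the
+# recorded reasons for filtering.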
+class FilteringEvaluator:
+    def __init__(self, pipeline_logger):
+        pipeline_logger.register("filtering::.*::filtered", self.on_filtered)
+        self.proposals = {}
+        self.which = {}
+        self.reason = pd.Series(dtype=str)
+
+    def on_filtered(self, step, proposals, which, reason, **kwargs):
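+        # step has the form "filtering::<filter_step>::filtered"; results of the
+        # compound_filtering step are not accumulated per step.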
+        _, filter_step, _ = step.split('::')
+        if filter_step != "compound_filtering":
+            self.proposals[filter_step] = pd.concat(self.proposals.get(filter_step, []) + [proposals])
+            self.which[filter_step] = pd.concat(self.which.get(filter_step, []) + [which])
+            self.reason = self.reason.append(reason)