1
+ from sota_extractor2 .models .linking .metrics import Metrics
1
2
from ..models .structure import TableType
2
3
from ..loggers import StructurePredictionEvaluator , LinkerEvaluator , FilteringEvaluator
3
4
import pandas as pd
5
+ import numpy as np
6
+ from ..helpers .jupyter import table_to_html
7
+ from sota_extractor2 .models .linking .format import extract_value
4
8
5
9
6
- class TableTypeExplainer :
10
class Reason:
    """Marker base class for the explanation objects defined in this module.

    It carries no state or behaviour of its own; subclasses add both.
    """
12
+
13
+
14
class IrrelevantTable(Reason):
    """Explanation for a table together with its predicted table type.

    Holds the paper, the table, the predicted type and the per-type
    probabilities, and can render itself as plain text (``str``) or as an
    HTML fragment (``_repr_html_``) for notebook display.
    """

    def __init__(self, paper, table, table_type, probs):
        """Store the inputs; ``probs`` is wrapped into a two-column frame.

        :param paper: paper the table belongs to.
        :param table: table object; ``name`` and ``caption`` are read later.
        :param table_type: predicted type; its ``name`` is used in messages.
        :param probs: data convertible to a frame with the columns
            ``type`` and ``probability``.
        """
        self.paper, self.table, self.table_type = paper, table, table_type
        # Kept as a DataFrame so it can be rendered through pandas styling.
        self.probs = pd.DataFrame(probs, columns=["type", "probability"])

    def __str__(self):
        """Return the one-sentence summary of the table-type prediction."""
        table_name = self.table.name
        type_name = self.table_type.name
        return f"Table {table_name} was labelled as {type_name} ."

    def _repr_html_(self):
        """Return summary, caption and probability table as one HTML string."""
        fragments = [
            f'<div>{self} </div>',
            f'<div>Caption: {self.table.caption} </div>',
            # Probabilities are shown with two decimal places.
            self.probs.style.format({"probability": "{:.2f}"})._repr_html_(),
        ]
        return fragments[0] + fragments[1] + fragments[2]
29
+
30
+
31
class MislabeledCell(Reason):
    """Explanation attached to a single cell of a table.

    NOTE(review): judging by the name this is meant for cells whose predicted
    label is wrong; nothing in this file constructs it yet, so confirm the
    intended use before relying on it.
    """

    def __init__(self, paper, table, row, col, probs):
        """Store every constructor argument on the instance.

        :param paper: paper the table belongs to.
        :param table: table containing the cell.
        :param row: row position of the cell.
        :param col: column position of the cell.
        :param probs: prediction probabilities for the cell, kept as given.
        """
        self.paper = paper
        self.table = table
        # Previously these three arguments were accepted but dropped, which
        # left the instance unable to say which cell it refers to.
        self.row = row
        self.col = col
        self.probs = probs
35
+
36
+
37
class TableExplanation:
    """HTML explanation of the linking proposals made for one table.

    ``_repr_html_`` renders the table with a tooltip per proposed cell and,
    when any proposal has no rejection reason, a listing of those proposals.
    """

    def __init__(self, paper, table, table_type, proposals, reasons, topk):
        """Keep references to everything needed for rendering.

        :param paper: paper the table belongs to.
        :param table: table object (``name``, ``caption``, ``matrix_html``,
            ``matrix_tags`` and ``matrix_layout`` are read when rendering).
        :param table_type: predicted table type; its ``name`` is displayed.
        :param proposals: frame of proposals whose index labels have the form
            ``<paper_id>/<table_name>/<row>.<col>``.
        :param reasons: mapping-like object from cell id to a reason string;
            it must support ``in``, item lookup and expose ``.index``.
        :param topk: container indexed by ``(row, col)``.
        """
        self.paper = paper
        self.table = table
        self.table_type = table_type
        self.proposals = proposals
        self.reasons = reasons
        self.topk = topk

    def _format_tooltip(self, proposal):
        """Return the tooltip text describing a single proposal."""
        return (
            f"dataset: {proposal.dataset} \n "
            f"metric: {proposal.metric} \n "
            f"task: {proposal.task} \n "
            f"score: {proposal.parsed} \n "
            f"confidence: {proposal.confidence:0.2f} "
        )

    def _format_topk(self, topk):
        """Placeholder: the top-k candidates are not rendered, so return ""."""
        return ""

    def _annotate_cells(self, matrix):
        """Build the per-cell prediction classes and tooltips.

        Both results are object arrays shaped like ``matrix``; cells without
        a proposal keep the zero they were initialised with.
        """
        predictions = np.zeros_like(matrix, dtype=object)
        tooltips = np.zeros_like(matrix, dtype=object)
        for cell_ext_id, proposal in self.proposals.iterrows():
            # The id must consist of exactly three "/"-separated parts.
            _paper_id, _table_name, position = cell_ext_id.split("/")
            row, col = map(int, position.split('.'))

            if cell_ext_id not in self.reasons:
                # No rejection reason recorded: this proposal is final.
                predictions[row, col] = 'final-proposal'
                tooltips[row, col] = self._format_tooltip(proposal)
                continue

            reason = self.reasons[cell_ext_id]
            tooltips[row, col] = reason
            if reason.startswith("replaced by "):
                tooltips[row, col] += "\n \n " + self._format_tooltip(proposal)
            elif reason.startswith("confidence "):
                tooltips[row, col] += "\n \n " + self._format_topk(self.topk[row, col])
        return predictions, tooltips

    def _repr_html_(self):
        """Return the HTML fragment shown for this explanation in a notebook."""
        matrix = self.table.matrix_html.values
        predictions, tooltips = self._annotate_cells(matrix)

        header = (
            f'<div>Table {self.table.name} was labelled as {self.table_type.name} .</div>'
            f'<div>Caption: {self.table.caption} </div>'
        )
        html = header + table_to_html(
            matrix,
            self.table.matrix_tags.values,
            self.table.matrix_layout.values,
            predictions,
            tooltips,
        )

        # Proposals without a recorded reason are listed below the table.
        has_reason = self.proposals.index.isin(self.reasons.index)
        accepted = self.proposals[~has_reason]
        if not len(accepted):
            return html

        shown_columns = ["dataset", "metric", "task", "model", "parsed"]
        listing = accepted[shown_columns].reset_index(drop=True)
        listing = listing.rename(columns={"parsed": "score"})
        html2 = listing._repr_html_()
        return f"<div><div>{html} </div><div>Proposals</div><div>{html2} </div></div>"
19
89
20
90
21
91
class Explainer :
92
+ _sota_record_columns = ['task' , 'dataset' , 'metric' , 'format' , 'model' , 'model_type' , 'raw_value' , 'parsed' ]
93
+
22
94
def __init__(self, pipeline_logger, paper_collection):
    """Create the evaluators this explainer reads its data from.

    :param pipeline_logger: passed to the constructor of every evaluator.
    :param paper_collection: kept on the instance (papers are looked up
        with ``get_by_id``) and passed to the structure-prediction and
        linker evaluators.
    """
    self.paper_collection = paper_collection
    # spe: queried for table-type predictions.
    self.spe = StructurePredictionEvaluator(pipeline_logger, paper_collection)
    # le: source of the per-paper ``proposals`` and of ``topk``.
    self.le = LinkerEvaluator(pipeline_logger, paper_collection)
    # fe: source of the per-cell ``reason`` entries.
    self.fe = FilteringEvaluator(pipeline_logger)
@@ -29,15 +102,94 @@ def explain(self, paper, cell_ext_id):
29
102
if paper .paper_id != paper_id :
30
103
return "No such cell"
31
104
32
- row , col = [int (x ) for x in rc .split ('.' )]
33
-
34
105
table_type , probs = self .spe .get_table_type_predictions (paper_id , table_name )
35
106
36
107
if table_type == TableType .IRRELEVANT :
37
- return TableTypeExplainer (paper , paper .table_by_name (table_name ), table_type , probs )
108
+ return IrrelevantTable (paper , paper .table_by_name (table_name ), table_type , probs )
109
+
110
+ all_proposals = self .le .proposals [paper_id ]
111
+ reasons = self .fe .reason
112
+ table_ext_id = f"{ paper_id } /{ table_name } "
113
+ table_proposals = all_proposals [all_proposals .index .str .startswith (table_ext_id + "/" )]
114
+ topk = {(row , col ): topk for (pid , tn , row , col ), topk in self .le .topk .items ()
115
+ if (pid , tn ) == (paper_id , table_name )}
116
+
117
+ return TableExplanation (paper , paper .table_by_name (table_name ), table_type , table_proposals , reasons , topk )
118
+
119
+ row , col = [int (x ) for x in rc .split ('.' )]
38
120
39
121
reason = self .fe .reason .get (cell_ext_id )
40
122
if reason is None :
41
123
pass
42
124
else :
43
125
return reason
126
+
127
+ def _get_table_sota_records (self , table ):
128
+
129
+ first_model = lambda x : ([a for a in x if a .startswith ('model' )] + ['' ])[0 ]
130
+ if len (table .sota_records ):
131
+ matrix = table .matrix .values
132
+ tags = table .matrix_tags
133
+ model_type_col = tags .apply (first_model )
134
+ model_type_row = tags .T .apply (first_model )
135
+ sota_records = table .sota_records .copy ()
136
+ sota_records ['model_type' ] = ''
137
+ sota_records ['raw_value' ] = ''
138
+ for cell_ext_id , record in sota_records .iterrows ():
139
+ name , rc = cell_ext_id .split ('/' )
140
+ row , col = [int (x ) for x in rc .split ('.' )]
141
+ record .model_type = model_type_col [col ] or model_type_row [row ]
142
+ record .raw_value = matrix [row , col ]
143
+
144
+ sota_records ["parsed" ] = sota_records [["raw_value" , "format" ]].apply (
145
+ lambda row : float (extract_value (row .raw_value , row .format )), axis = 1 )
146
+
147
+ sota_records = sota_records [sota_records ["parsed" ] == sota_records ["parsed" ]]
148
+
149
+ strip_cols = ["task" , "dataset" , "format" , "metric" , "raw_value" , "model" , "model_type" ]
150
+ sota_records = sota_records .transform (
151
+ lambda x : x .str .strip () if x .name in strip_cols else x )
152
+ return sota_records [self ._sota_record_columns ]
153
+ else :
154
+ empty = pd .DataFrame (columns = self ._sota_record_columns )
155
+ empty .index .rename ("cell_ext_id" , inplace = True )
156
+ return empty
157
+
158
+ def _get_sota_records (self , paper ):
159
+ if not len (paper .tables ):
160
+ empty = pd .DataFrame (columns = self ._sota_record_columns )
161
+ empty .index .rename ("cell_ext_id" , inplace = True )
162
+ return empty
163
+ records = [self ._get_table_sota_records (table ) for table in paper .tables ]
164
+ records = pd .concat (records )
165
+ records .index = paper .paper_id + "/" + records .index
166
+ records .index .rename ("cell_ext_id" , inplace = True )
167
+ return records
168
+
169
def linking_metrics(self, experiment_name="unk"):
    """Compare the accepted linking proposals with the gold SOTA records.

    Proposals that have an entry in the filtering reasons are excluded.
    Gold records are collected for every paper that has proposals and is
    present in the paper collection; ids of absent papers are printed.
    Gold and predicted records are outer-joined on the cell id, columns are
    sorted by name and missing values are replaced with ``'not-present'``.

    :param experiment_name: forwarded to ``Metrics``.
    :return: ``Metrics`` built from the joined frame.
    """
    paper_ids = list(self.le.proposals.keys())

    all_proposals = pd.concat(self.le.proposals.values())
    was_filtered = all_proposals.index.isin(self.fe.reason.index)
    proposals = all_proposals[~was_filtered]

    # Split the ids into papers that were found and ids that were not.
    found_papers = []
    missing = []
    for paper_id in paper_ids:
        paper = self.paper_collection.get_by_id(paper_id)
        if paper is None:
            missing.append(paper_id)
        else:
            found_papers.append(paper)
    if missing:
        print("Missing papers in paper collection:")
        print(", ".join(missing))

    if found_papers:
        gold_sota_records = pd.concat(
            [self._get_sota_records(paper) for paper in found_papers])
    else:
        gold_sota_records = pd.DataFrame(columns=self._sota_record_columns)
        gold_sota_records.index.rename("cell_ext_id", inplace=True)

    joined = gold_sota_records.merge(
        proposals,
        how='outer',
        left_index=True,
        right_index=True,
        suffixes=['_gold', '_pred'],
    )
    joined = joined.reindex(sorted(joined.columns), axis=1)
    joined = joined.fillna('not-present')
    if "experiment_name" in joined.columns:
        joined = joined.drop(columns=["experiment_name"])

    return Metrics(joined, experiment_name=experiment_name)
0 commit comments