@@ -3,7 +3,7 @@

from sota_extractor2.models.linking.acronym_extractor import AcronymExtractor
from sota_extractor2.models.linking.probs import get_probs, reverse_probs
- from sota_extractor2.models.linking.utils import normalize_dataset, normalize_cell, normalize_cell_ws
+ from sota_extractor2.models.linking.utils import normalize_dataset, normalize_dataset_ws, normalize_cell, normalize_cell_ws
from scipy.special import softmax
import re
import pandas as pd
@@ -13,81 +13,95 @@

from sota_extractor2.pipeline_logger import pipeline_logger

- from sota_extractor2.models.linking.manual_dicts import metrics, datasets, tasks
-
- datasets = {k: (v + ['test']) for k, v in datasets.items()}
- datasets.update({
-     'LibriSpeech dev-clean': ['libri speech dev clean', 'libri speech', 'dev', 'clean', 'dev clean', 'development'],
-     'LibriSpeech dev-other': ['libri speech dev other', 'libri speech', 'dev', 'other', 'dev other', 'development', 'noisy'],
- })
-
- # escaped_ws_re = re.compile(r'\\\s+')
- # def name_to_re(name):
- #     return re.compile(r'(?:^|\s+)' + escaped_ws_re.sub(r'\\s*', re.escape(name.strip())) + r'(?:$|\s+)', re.I)
-
- #all_datasets = set(k for k,v in merged_p.items() if k != '' and not re.match("^\d+$", k) and v.get('NOMATCH', 0.0) < 0.9)
- all_datasets = set(normalize_cell_ws(normalize_dataset(y)) for x in datasets.values() for y in x)
- all_metrics = set(normalize_cell_ws(y) for x in metrics.values() for y in x)
- all_tasks = set(normalize_cell_ws(normalize_dataset(y)) for x in tasks.values() for y in x)
-
- #all_metrics = set(metrics_p.keys())
-
- # all_datasets_re = {x:name_to_re(x) for x in all_datasets}
- # all_metrics_re = {x:name_to_re(x) for x in all_metrics}
- #all_datasets = set(x for v in merged_p.values() for x in v)
-
- # def find_names(text, names_re):
- #     return set(name for name, name_re in names_re.items() if name_re.search(text))
-
-
- def make_trie(names):
-     trie = ahocorasick.Automaton()
-     for name in names:
-         norm = name.replace(" ", "")
-         trie.add_word(norm, (len(norm), name))
-     trie.make_automaton()
-     return trie
-
-
- single_letter_re = re.compile(r"\b\w\b")
- init_letter_re = re.compile(r"\b\w")
- end_letter_re = re.compile(r"\w\b")
- letter_re = re.compile(r"\w")
-
-
- def find_names(text, names_trie):
-     text = text.lower()
-     profile = letter_re.sub("i", text)
-     profile = init_letter_re.sub("b", profile)
-     profile = end_letter_re.sub("e", profile)
-     profile = single_letter_re.sub("x", profile)
-     text = text.replace(" ", "")
-     profile = profile.replace(" ", "")
-     s = set()
-     for (end, (l, word)) in names_trie.iter(text):
-         if profile[end] in ['e', 'x'] and profile[end - l + 1] in ['b', 'x']:
-             s.add(word)
-     return s
-
-
- all_datasets_trie = make_trie(all_datasets)
- all_metrics_trie = make_trie(all_metrics)
- all_tasks_trie = make_trie(all_tasks)
-
-
- def find_datasets(text):
-     return find_names(text, all_datasets_trie)
-
- def find_metrics(text):
-     return find_names(text, all_metrics_trie)
-
- def find_tasks(text):
-     return find_names(text, all_tasks_trie)
+ from sota_extractor2.models.linking import manual_dicts

def dummy_item(reason):
    return pd.DataFrame(dict(dataset=[reason], task=[reason], metric=[reason], evidence=[""], confidence=[0.0]))


+ class EvidenceFinder:
+     single_letter_re = re.compile(r"\b\w\b")
+     init_letter_re = re.compile(r"\b\w")
+     end_letter_re = re.compile(r"\w\b")
+     letter_re = re.compile(r"\w")
+
+     def __init__(self, taxonomy):
+         self._init_structs(taxonomy)
+
+     @staticmethod
+     def evidences_from_name(key):
+         x = normalize_dataset_ws(key)
+         y = x.split()
+         return [x] + y if len(y) > 1 else [x]
+
+     @staticmethod
+     def get_basic_dicts(taxonomy):
+         tasks = {ts: [normalize_dataset_ws(ts)] for ts in taxonomy.tasks}
+         datasets = {ds: EvidenceFinder.evidences_from_name(ds) for ds in taxonomy.datasets}
+         metrics = {ms: EvidenceFinder.evidences_from_name(ms) for ms in taxonomy.metrics}
+         return tasks, datasets, metrics
+
+     @staticmethod
+     def merge_evidences(target, source):
+         for name, evs in source.items():
+             target.setdefault(name, []).extend(evs)
+
+     @staticmethod
+     def make_trie(names):
+         trie = ahocorasick.Automaton()
+         for name in names:
+             norm = name.replace(" ", "")
+             trie.add_word(norm, (len(norm), name))
+         trie.make_automaton()
+         return trie
+
+     @staticmethod
+     def find_names(text, names_trie):
+         text = text.lower()
+         profile = EvidenceFinder.letter_re.sub("i", text)
+         profile = EvidenceFinder.init_letter_re.sub("b", profile)
+         profile = EvidenceFinder.end_letter_re.sub("e", profile)
+         profile = EvidenceFinder.single_letter_re.sub("x", profile)
+         text = text.replace(" ", "")
+         profile = profile.replace(" ", "")
+         s = set()
+         for (end, (l, word)) in names_trie.iter(text):
+             if profile[end] in ['e', 'x'] and profile[end - l + 1] in ['b', 'x']:
+                 s.add(word)
+         return s
+
+     def find_datasets(self, text):
+         return EvidenceFinder.find_names(text, self.all_datasets_trie)
+
+     def find_metrics(self, text):
+         return EvidenceFinder.find_names(text, self.all_metrics_trie)
+
+     def find_tasks(self, text):
+         return EvidenceFinder.find_names(text, self.all_tasks_trie)
+
+     def _init_structs(self, taxonomy):
+         self.tasks, self.datasets, self.metrics = EvidenceFinder.get_basic_dicts(taxonomy)
+         EvidenceFinder.merge_evidences(self.tasks, manual_dicts.tasks)
+         EvidenceFinder.merge_evidences(self.datasets, manual_dicts.datasets)
+         EvidenceFinder.merge_evidences(self.metrics, manual_dicts.metrics)
+         self.datasets = {k: (v + ['test'] if 'val' not in k else v + ['validation', 'dev', 'development']) for k, v in
+                          self.datasets.items()}
+         self.datasets.update({
+             'LibriSpeech dev-clean': ['libri speech dev clean', 'libri speech', 'dev', 'clean', 'dev clean', 'development'],
+             'LibriSpeech dev-other': ['libri speech dev other', 'libri speech', 'dev', 'other', 'dev other', 'development', 'noisy'],
+         })
+
+         self.datasets = {k: set(v) for k, v in self.datasets.items()}
+         self.metrics = {k: set(v) for k, v in self.metrics.items()}
+         self.tasks = {k: set(v) for k, v in self.tasks.items()}
+
+         self.all_datasets = set(normalize_cell_ws(normalize_dataset(y)) for x in self.datasets.values() for y in x)
+         self.all_metrics = set(normalize_cell_ws(y) for x in self.metrics.values() for y in x)
+         self.all_tasks = set(normalize_cell_ws(normalize_dataset(y)) for x in self.tasks.values() for y in x)
+
+         self.all_datasets_trie = EvidenceFinder.make_trie(self.all_datasets)
+         self.all_metrics_trie = EvidenceFinder.make_trie(self.all_metrics)
+         self.all_tasks_trie = EvidenceFinder.make_trie(self.all_tasks)

@njit
def compute_logprobs(taxonomy, reverse_merged_p, reverse_metrics_p, reverse_task_p,
@@ -114,17 +128,18 @@ def compute_logprobs(taxonomy, reverse_merged_p, reverse_metrics_p, reverse_task_p,


class ContextSearch:
-     def __init__(self, taxonomy, context_noise=(0.5, 0.2, 0.1), metrics_noise=None, task_noise=None,
+     def __init__(self, taxonomy, evidence_finder, context_noise=(0.5, 0.2, 0.1), metrics_noise=None, task_noise=None,
                 ds_pb=0.001, ms_pb=0.01, ts_pb=0.01, debug_gold_df=None):
        merged_p = \
-             get_probs({k: Counter([normalize_cell(normalize_dataset(x)) for x in v]) for k, v in datasets.items()})[1]
+             get_probs({k: Counter([normalize_cell(normalize_dataset(x)) for x in v]) for k, v in evidence_finder.datasets.items()})[1]
        metrics_p = \
-             get_probs({k: Counter([normalize_cell(normalize_dataset(x)) for x in v]) for k, v in metrics.items()})[1]
+             get_probs({k: Counter([normalize_cell(normalize_dataset(x)) for x in v]) for k, v in evidence_finder.metrics.items()})[1]
        tasks_p = \
-             get_probs({k: Counter([normalize_cell(normalize_dataset(x)) for x in v]) for k, v in tasks.items()})[1]
+             get_probs({k: Counter([normalize_cell(normalize_dataset(x)) for x in v]) for k, v in evidence_finder.tasks.items()})[1]

        self.queries = {}
        self.taxonomy = taxonomy
+         self.evidence_finder = evidence_finder
        self._taxonomy = typed.List()
        for t in self.taxonomy.taxonomy:
            self._taxonomy.append(t)
@@ -158,9 +173,9 @@ def compute_context_logprobs(self, context, noise, ms_noise, ts_noise, logprobs):
        context = context or ""
        abbrvs = self.extract_acronyms(context)
        context = normalize_cell_ws(normalize_dataset(context))
-         dss = set(find_datasets(context)) | set(abbrvs.keys())
-         mss = set(find_metrics(context))
-         tss = set(find_tasks(context))
+         dss = set(self.evidence_finder.find_datasets(context)) | set(abbrvs.keys())
+         mss = set(self.evidence_finder.find_metrics(context))
+         tss = set(self.evidence_finder.find_tasks(context))
        dss -= mss
        dss -= tss
        dss = [normalize_cell(ds) for ds in dss]
@@ -248,7 +263,8 @@ def __call__(self, query, datasets, caption, topk=1, debug_info=None):

# todo: compare regex approach (old) with find_datasets(.) (current)
class DatasetExtractor:
-     def __init__(self):
+     def __init__(self, evidence_finder):
+         self.evidence_finder = evidence_finder
        self.dataset_prefix_re = re.compile(r"[A-Z]|[a-z]+[A-Z]+|[0-9]")
        self.dataset_name_re = re.compile(r"\b(the)\b\s*(?P<name>((?!(the)\b)\w+\W+){1,10}?)(test|val(\.|idation)?|dev(\.|elopment)?|train(\.|ing)?\s+)?\bdata\s*set\b", re.IGNORECASE)

@@ -260,4 +276,4 @@ def from_paper(self, paper):

    def __call__(self, text):
        text = normalize_cell_ws(normalize_dataset(text))
-         return find_datasets(text) | find_tasks(text)
+         return self.evidence_finder.find_datasets(text) | self.evidence_finder.find_tasks(text)
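The net effect of the commit is that the module-level evidence globals become a single EvidenceFinder object that both ContextSearch and DatasetExtractor receive in their constructors. A minimal wiring sketch under the constructor signatures shown in this diff; the taxonomy variable and instance names are illustrative, not taken from the commit:

# Hypothetical wiring after this change; only the constructor signatures come from the diff.
evidence_finder = EvidenceFinder(taxonomy)                  # builds evidence dicts and tries from the taxonomy plus manual_dicts
context_search = ContextSearch(taxonomy, evidence_finder)   # dataset/metric/task probabilities now read from evidence_finder
dataset_extractor = DatasetExtractor(evidence_finder)       # mention search delegates to evidence_finder.find_datasets/find_tasks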