3434- config
3535- cdb
3636"""
37+ import re
3738from typing import Union , Tuple , Any , List , Iterable , Optional , Dict
3839import logging
3940
@@ -62,9 +63,11 @@ class DeIdModel(NerModel):
def __init__(self, cat: CAT) -> None:
    """Create the model around an existing CAT instance.

    Args:
        cat (CAT): The CAT instance to keep on ``self.cat``.
    """
    self.cat = cat
6465
def train(self, json_path: Union[str, list, None] = None,
          *args, **kwargs) -> Tuple[Any, Any, Any]:
    """Validate the data-source arguments and delegate to the parent ``train``.

    Args:
        json_path (Union[str, list, None]): Training data source, forwarded
            unchanged as the first positional argument of the parent
            ``train``. Defaults to None.
        *args: Extra positional arguments, forwarded after ``json_path``.
        **kwargs: Extra keyword arguments, forwarded unchanged. The keys
            ``train_json_path`` and ``test_json_path`` are inspected here.

    Returns:
        Tuple[Any, Any, Any]: Whatever the parent ``train`` returns.

    Raises:
        AssertionError: If ``json_path``, ``train_json_path`` and
            ``test_json_path`` are all provided (truthy) at the same time.
    """
    # NOTE(review): this only rejects the case where all three sources are
    # given at once; the message suggests a stricter rule - confirm intent.
    assert not all([json_path, kwargs.get('train_json_path'), kwargs.get('test_json_path')]), \
        "Either json_path or train_json_path and test_json_path must be provided when no dataset is provided"
    # json_path must be passed positionally: written as a keyword in front of
    # *args, any extra positional argument would be bound to the parent's
    # first parameter and clash with it ("got multiple values for argument").
    return super().train(json_path, *args, **kwargs)  # type: ignore
6871
6972 def eval (self , json_path : Union [str , list , None ],
7073 * args , ** kwargs ) -> Tuple [Any , Any , Any ]:
@@ -146,7 +149,8 @@ def deid_multi_texts(self,
146149 return out
147150
148151 @classmethod
149- def load_model_pack (cls , model_pack_path : str , config : Optional [Dict ] = None ) -> 'DeIdModel' :
152+ def load_model_pack (cls , model_pack_path : str ,
153+ config : Optional [Dict ] = None ) -> 'DeIdModel' :
150154 """Load DeId model from model pack.
151155
152156 The method first loads the CAT instance.
@@ -164,7 +168,7 @@ def load_model_pack(cls, model_pack_path: str, config: Optional[Dict] = None) ->
164168 Returns:
165169 DeIdModel: The resulting DeI model.
166170 """
167- ner_model = NerModel .load_model_pack (model_pack_path ,config = config )
171+ ner_model = NerModel .load_model_pack (model_pack_path , config = config )
168172 cat = ner_model .cat
169173 if not cls ._is_deid_model (cat ):
170174 raise ValueError (
@@ -180,7 +184,135 @@ def _is_deid_model(cls, cat: CAT) -> bool:
@classmethod
def _get_reason_not_deid(cls, cat: CAT) -> str:
    """Describe why the given CAT instance fails the DeId checks made here.

    Args:
        cat (CAT): The CAT instance to inspect.

    Returns:
        str: A short reason if ``cat.vocab`` is set or ``cat._addl_ner``
            does not hold exactly one element; an empty string otherwise.
    """
    if cat.vocab is not None:
        return "Has vocab "
    if len(cat._addl_ner) != 1:
        return f"Incorrect number of addl_ner: {len(cat._addl_ner)} "
    return ""
191+
192+
def match_rules(rules: List[Tuple[str, str]],
                texts: List[str],
                cui2preferred_name: Dict[str, str]) -> List[List[Dict]]:
    """Match regex rules (pattern / cui pairs) against texts as post-processing labels.

    Every pattern is treated as a regular expression and searched with the
    ``re.M`` flag. Literal strings containing regex metacharacters (such as
    parentheses or dots) must be escaped, e.g. with ``re.escape``.

    Args:
        rules (List[Tuple[str, str]]): List of tuples of pattern and cui.
        texts (List[str]): List of texts to match rules on.
        cui2preferred_name (Dict[str, str]): Dictionary of CUI to preferred
            name, likely to be cat.cdb.cui2preferred_name.

    Examples:
        >>> import re
        >>> rules = [
        ...     (re.escape('(123) 456-7890'), '134'),
        ...     ('[0-9]{10}', '134'),
        ... ]
        >>> texts = [
        ...     'My phone number is (123) 456-7890',
        ...     'My phone number is 1234567890',
        ... ]
        >>> matches = match_rules(rules, texts, {'134': 'Phone Number'})
        >>> [[m['source_value'] for m in per_text] for per_text in matches]
        [['(123) 456-7890'], ['1234567890']]

    Returns:
        List[List[Dict]]: One list per input text (same order), holding one
            dict per match with the keys ``source_value``, ``pretty_name``,
            ``start``, ``end``, ``cui`` and ``acc`` (always 1.0). Matches are
            grouped by rule order, not sorted by position.

    Raises:
        KeyError: If a rule matches and its cui is missing from
            ``cui2preferred_name``.
        re.error: If a pattern is not a valid regular expression.
    """
    # Compile once up front: this is loop-invariant work, and materialising
    # the rules means a one-shot iterable is still applied to every text.
    compiled_rules = [(re.compile(pattern, flags=re.M), cui)
                      for pattern, cui in rules]

    rule_matches_per_text = []
    for text in texts:
        matches_in_text = []
        for regex, cui in compiled_rules:
            for match in regex.finditer(text):
                matches_in_text.append({
                    'source_value': match.group(),
                    # Looked up only on an actual match, so rules that never
                    # fire do not need an entry in cui2preferred_name.
                    'pretty_name': cui2preferred_name[cui],
                    'start': match.start(),
                    'end': match.end(),
                    'cui': cui,
                    'acc': 1.0
                })
        rule_matches_per_text.append(matches_in_text)
    return rule_matches_per_text
243+
244+
def merge_all_preds(model_preds_by_text: List[List[Dict]],
                    rule_matches_per_text: List[List[Dict]],
                    accept_preds: bool = True) -> List[List[Dict]]:
    """Convenience method to merge rule based and deID model predictions for many texts.

    Calls `merge_preds` once per text, pairing the two inputs by position.

    Args:
        model_preds_by_text (List[List[Dict]]): per-text lists of predictions
            from `cat.get_entities()`, then
            `[list(m['entities'].values()) for m in model_preds]`
        rule_matches_per_text (List[List[Dict]]): per-text lists of
            predictions from output of running `match_rules`
        accept_preds (bool): uses the predicted label from the model,
            model_preds_by_text, over the rule matches if they overlap.
            Defaults to using model preds over rules.

    Returns:
        List[List[Dict]]: One merged list of predictions per text, in input order.

    Raises:
        AssertionError: If the two inputs do not have the same length.
    """
    assert len(model_preds_by_text) == len(rule_matches_per_text), \
        "model_preds_by_text and rule_matches_per_text must have the same length as they should be CAT.get_entities and match_rules outputs of the same text"
    return [merge_preds(model_preds, rule_matches, accept_preds)
            for model_preds, rule_matches in zip(model_preds_by_text, rule_matches_per_text)]
265+
266+
def merge_preds(model_preds: List[Dict],
                rule_matches: List[Dict],
                accept_preds: bool = True) -> List[Dict]:
    """Merge predictions from rule based and deID model predictions for one text.

    One list takes priority and is kept in full; entries of the other list
    are kept only if they overlap none of the priority entries. Spans that
    merely touch (one's ``end`` equals the other's ``start``) are not
    treated as overlapping. Only the ``start`` and ``end`` keys are read.

    Args:
        model_preds (List[Dict]): predictions from `cat.get_entities()` for a
            single text, as a flat list of entity dicts
        rule_matches (List[Dict]): predictions from output of running
            `match_rules` on a text
        accept_preds (bool): uses the predicted label from the model,
            model_preds, over the rule matches if they overlap. When False
            the rule matches take priority instead.
            Defaults to using model preds over rules.

    Examples:
        >>> model_preds = [
        ...     {'cui': '134', 'start': 10, 'end': 20, 'acc': 0.9,
        ...      'pretty_name': 'Phone Number'},
        ... ]
        >>> rule_matches = [
        ...     {'cui': '134', 'start': 12, 'end': 20, 'acc': 1.0,
        ...      'pretty_name': 'Phone Number'},
        ...     {'cui': '134', 'start': 25, 'end': 35, 'acc': 1.0,
        ...      'pretty_name': 'Phone Number'},
        ... ]
        >>> merged_preds = merge_preds(model_preds, rule_matches)
        >>> [(p['start'], p['end'], p['acc']) for p in merged_preds]
        [(10, 20, 0.9), (25, 35, 1.0)]

    Returns:
        List[Dict]: A new list with the kept predictions sorted by ``start``;
            the input lists are not modified.
    """
    if accept_preds:
        primary, secondary = model_preds, rule_matches
    else:
        primary, secondary = rule_matches, model_preds

    # Drop every lower-priority span that overlaps any priority span.
    kept_secondary = [
        span for span in secondary
        if not any(span['start'] < kept['end'] and kept['start'] < span['end']
                   for kept in primary)
    ]
    # The sort is stable, so on equal starts priority entries come first.
    return sorted(primary + kept_secondary, key=lambda pred: pred['start'])
0 commit comments