Skip to content

Commit d255984

Browse files
authored
Merge pull request CogStack/MedCAT#541 from CogStack/deid_train_eval
DeID improvements
2 parents 7a62ebe + 01b8ef9 commit d255984

File tree

6 files changed

+483
-28
lines changed

6 files changed

+483
-28
lines changed

medcat-v1/medcat/datasets/transformers_ner.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def _info(self):
6464

6565
def _split_generators(self, dl_manager): # noqa
6666
"""Returns SplitGenerators.""" # noqa
67-
return [
67+
splits = [
6868
datasets.SplitGenerator(
6969
name=datasets.Split.TRAIN,
7070
gen_kwargs={
@@ -73,6 +73,19 @@ def _split_generators(self, dl_manager): # noqa
7373
),
7474
]
7575

76+
# Only add test split if test data files are provided
77+
if 'test' in self.config.data_files:
78+
splits.append(
79+
datasets.SplitGenerator(
80+
name=datasets.Split.TEST,
81+
gen_kwargs={
82+
"filepaths": self.config.data_files['test'],
83+
},
84+
)
85+
)
86+
87+
return splits
88+
7689
def _generate_examples(self, filepaths): # noqa
7790
cnt = 0
7891
for filepath in filepaths:

medcat-v1/medcat/ner/transformers_ner.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,8 @@ def train(self,
177177
ignore_extra_labels=False,
178178
dataset=None,
179179
meta_requirements=None,
180+
train_json_path: Union[str, list, None]=None,
181+
test_json_path: Union[str, list, None]=None,
180182
trainer_callbacks: Optional[List[Callable[[Trainer], TrainerCallback]]] = None) -> Tuple:
181183
"""Train or continue training a model give a json_path containing a MedCATtrainer export. It will
182184
continue training if an existing model is loaded or start new training if the model is blank/new.
@@ -187,8 +189,10 @@ def train(self,
187189
ignore_extra_labels:
188190
Makes only sense when an existing deid model was loaded and from the new data we want to ignore
189191
labels that did not exist in the old model.
190-
dataset: Defaults to None.
192+
dataset: Defaults to None. Will be split by self.config.general['test_size'] into train and test datasets.
191193
meta_requirements: Defaults to None
194+
train_json_path (str): Defaults to None. If provided, will be used as the training dataset json_path to load from
195+
test_json_path (str): Defaults to None. If provided, will be used as the test dataset json_path to load from
192196
trainer_callbacks (List[Callable[[Trainer], TrainerCallback]]]):
193197
A list of trainer callbacks for collecting metrics during the training at the client side. The
194198
transformers Trainer object will be passed in when each callback is called.
@@ -200,11 +204,16 @@ def train(self,
200204
Tuple: The dataframe, examples, and the dataset
201205
"""
202206

203-
if dataset is None and json_path is not None:
207+
if dataset is None:
204208
# Load the medcattrainer export
205-
json_path = self._prepare_dataset(json_path, ignore_extra_labels=ignore_extra_labels,
209+
if json_path is not None:
210+
json_path = self._prepare_dataset(json_path, ignore_extra_labels=ignore_extra_labels,
206211
meta_requirements=meta_requirements, file_name='data_eval.json')
207-
# Load dataset
212+
elif test_json_path is not None and train_json_path is not None:
213+
train_json_path = self._prepare_dataset(train_json_path, ignore_extra_labels=ignore_extra_labels,
214+
meta_requirements=meta_requirements, file_name='data_train.json')
215+
test_json_path = self._prepare_dataset(test_json_path, ignore_extra_labels=ignore_extra_labels,
216+
meta_requirements=meta_requirements, file_name='data_test.json')
208217

209218
# NOTE: The following is for backwards compatibility
210219
# in datasets==2.20.0 `trust_remote_code=True` must be explicitly
@@ -216,13 +225,21 @@ def train(self,
216225
ds_load_dataset = partial(datasets.load_dataset, trust_remote_code=True)
217226
else:
218227
ds_load_dataset = datasets.load_dataset
219-
dataset = ds_load_dataset(os.path.abspath(transformers_ner.__file__),
220-
data_files={'train': json_path}, # type: ignore
221-
split='train',
222-
cache_dir='/tmp/')
223-
# We split before encoding so the split is document level, as encoding
224-
#does the document splitting into max_seq_len
225-
dataset = dataset.train_test_split(test_size=self.config.general['test_size']) # type: ignore
228+
229+
if json_path:
230+
dataset = ds_load_dataset(os.path.abspath(transformers_ner.__file__),
231+
data_files={'train': json_path}, # type: ignore
232+
split='train',
233+
cache_dir='/tmp/')
234+
# We split before encoding so the split is document level, as encoding
235+
# does the document splitting into max_seq_len
236+
dataset = dataset.train_test_split(test_size=self.config.general['test_size']) # type: ignore
237+
elif train_json_path and test_json_path:
238+
dataset = ds_load_dataset(os.path.abspath(transformers_ner.__file__),
239+
data_files={'train': train_json_path, 'test': test_json_path}, # type: ignore
240+
cache_dir='/tmp/')
241+
else:
242+
raise ValueError("Either json_path or train_json_path and test_json_path must be provided when no dataset is provided")
226243

227244
# Update labelmap in case the current dataset has more labels than what we had before
228245
self.tokenizer.calculate_label_map(dataset['train'])
@@ -231,8 +248,8 @@ def train(self,
231248
if self.model.num_labels != len(self.tokenizer.label_map):
232249
logger.warning("The dataset contains labels we've not seen before, model is being reinitialized")
233250
logger.warning("Model: {} vs Dataset: {}".format(self.model.num_labels, len(self.tokenizer.label_map)))
234-
self.model = AutoModelForTokenClassification.from_pretrained(self.config.general['model_name'],
235-
num_labels=len(self.tokenizer.label_map),
251+
self.model = AutoModelForTokenClassification.from_pretrained(self.config.general['model_name'],
252+
num_labels=len(self.tokenizer.label_map),
236253
ignore_mismatched_sizes=True)
237254
self.tokenizer.cui2name = {k:self.cdb.get_name(k) for k in self.tokenizer.label_map.keys()}
238255

@@ -273,7 +290,6 @@ def train(self,
273290
# NOTE: this shouldn't really happen, but we'll do this for type safety
274291
raise ValueError("Output path should not be None!")
275292
self.save(save_dir_path=os.path.join(output_dir, 'final_model'))
276-
277293
# Run an eval step and return metrics
278294
p = trainer.predict(encoded_dataset['test']) # type: ignore
279295
df, examples = metrics(p, return_df=True, tokenizer=self.tokenizer, dataset=encoded_dataset['test'])

medcat-v1/medcat/utils/ner/deid.py

Lines changed: 137 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
- config
3535
- cdb
3636
"""
37+
import re
3738
from typing import Union, Tuple, Any, List, Iterable, Optional, Dict
3839
import logging
3940

@@ -62,9 +63,11 @@ class DeIdModel(NerModel):
6263
def __init__(self, cat: CAT) -> None:
6364
self.cat = cat
6465

65-
def train(self, json_path: Union[str, list, None],
66+
def train(self, json_path: Union[str, list, None] = None,
6667
*args, **kwargs) -> Tuple[Any, Any, Any]:
67-
return super().train(json_path, *args, train_nr=0, **kwargs) # type: ignore
68+
assert not all([json_path, kwargs.get('train_json_path'), kwargs.get('test_json_path')]), \
69+
"Either json_path or train_json_path and test_json_path must be provided when no dataset is provided"
70+
return super().train(json_path=json_path, *args, **kwargs) # type: ignore
6871

6972
def eval(self, json_path: Union[str, list, None],
7073
*args, **kwargs) -> Tuple[Any, Any, Any]:
@@ -146,7 +149,8 @@ def deid_multi_texts(self,
146149
return out
147150

148151
@classmethod
149-
def load_model_pack(cls, model_pack_path: str, config: Optional[Dict] = None) -> 'DeIdModel':
152+
def load_model_pack(cls, model_pack_path: str,
153+
config: Optional[Dict] = None) -> 'DeIdModel':
150154
"""Load DeId model from model pack.
151155
152156
The method first loads the CAT instance.
@@ -164,7 +168,7 @@ def load_model_pack(cls, model_pack_path: str, config: Optional[Dict] = None) ->
164168
Returns:
165169
DeIdModel: The resulting DeI model.
166170
"""
167-
ner_model = NerModel.load_model_pack(model_pack_path,config=config)
171+
ner_model = NerModel.load_model_pack(model_pack_path, config=config)
168172
cat = ner_model.cat
169173
if not cls._is_deid_model(cat):
170174
raise ValueError(
@@ -180,7 +184,135 @@ def _is_deid_model(cls, cat: CAT) -> bool:
180184
@classmethod
181185
def _get_reason_not_deid(cls, cat: CAT) -> str:
182186
if cat.vocab is not None:
183-
return "Has vocab"
187+
return "Has voc§ab"
184188
if len(cat._addl_ner) != 1:
185189
return f"Incorrect number of addl_ner: {len(cat._addl_ner)}"
186190
return ""
191+
192+
193+
def match_rules(rules: List[Tuple[str, str]], texts: List[str], cui2preferred_name: Dict[str, str]) -> List[List[Dict]]:
    r"""Match a set of rules - regex pattern / CUI combos - as post processing labels.

    Uses a cat DeID model's CUI-to-name mapping for pretty name lookup.

    Args:
        rules (List[Tuple[str, str]]): List of tuples of regex pattern and CUI.
        texts (List[str]): List of texts to match rules on.
        cui2preferred_name (Dict[str, str]): Dictionary of CUI to preferred name,
            likely to be cat.cdb.cui2preferred_name. Every CUI used in `rules`
            must be present, otherwise a KeyError is raised on the first match.

    Examples:
        >>> cat = CAT.load_model_pack(model_pack_path)
        >>> rules = [
        ...     (r'\(\d{3}\) \d{3}-\d{4}', '134'),
        ...     (r'\d{3}\.\d{3}\.\d{4}', '134'),
        ...     (r'\d{10}', '134'),
        ... ]
        >>> texts = [
        ...     'My phone number is (123) 456-7890',
        ...     'My phone number is 123.456.7890',
        ...     'My phone number is 1234567890',
        ... ]
        >>> matches = match_rules(rules, texts, cat.cdb.cui2preferred_name)

    Returns:
        List[List[Dict]]: For each input text, the list of rule-based
            predictions in the same dict format as `cat.get_entities` output
            entities (source_value/pretty_name/start/end/cui/acc).
    """
    rule_matches_per_text = []
    for text in texts:  # index was unused - iterate values directly
        matches_in_text = []
        for pattern, concept in rules:
            # Find all (possibly multiple) matches of the current pattern
            # in the current text; MULTILINE so ^/$ anchors work per line.
            for match in re.finditer(pattern, text, flags=re.M):
                matches_in_text.append({
                    'source_value': match.group(),
                    'pretty_name': cui2preferred_name[concept],
                    'start': match.start(),
                    'end': match.end(),
                    'cui': concept,
                    # Rule hits are treated as certain.
                    'acc': 1.0
                })
        rule_matches_per_text.append(matches_in_text)
    return rule_matches_per_text
243+
244+
245+
def merge_all_preds(model_preds_by_text: List[List[Dict]],
                    rule_matches_per_text: List[List[Dict]],
                    accept_preds: bool = True) -> List[List[Dict]]:
    """Convenience method to merge rule based and DeID model predictions per text.

    Args:
        model_preds_by_text (List[List[Dict]]): per-text predictions from
            `cat.get_entities()`, i.e. `[list(m['entities'].values()) for m in model_preds]`
        rule_matches_per_text (List[List[Dict]]): per-text predictions from
            the output of running `match_rules`
        accept_preds (bool): uses the predicted label from the model,
            model_preds_by_text, over the rule matches if they overlap.
            Defaults to using model preds over rules.

    Raises:
        AssertionError: If the two inputs differ in length.

    Returns:
        List[List[Dict]]: Per-text merged predictions from `merge_preds`.
    """
    assert len(model_preds_by_text) == len(rule_matches_per_text), \
        "model_preds_by_text and rule_matches_per_text must have the same length as they should be CAT.get_entities and match_rules outputs of the same text"
    # zip is safe here: lengths are asserted equal above.
    return [merge_preds(model_preds, rule_matches, accept_preds)
            for model_preds, rule_matches in zip(model_preds_by_text,
                                                 rule_matches_per_text)]


def merge_preds(model_preds: List[Dict],
                rule_matches: List[Dict],
                accept_preds: bool = True) -> List[Dict]:
    """Merge rule based and DeID model predictions for a single text.

    Overlapping spans are resolved by keeping the preferred side (the model
    predictions when `accept_preds` is True, the rule matches otherwise) and
    dropping every span from the other side that overlaps a preferred span.

    Args:
        model_preds (List[Dict]): predictions from `cat.get_entities()` for one text
        rule_matches (List[Dict]): predictions from running `match_rules` on one text
        accept_preds (bool): uses the predicted label from the model,
            model_preds, over the rule matches if they overlap.
            Defaults to using model preds over rules.

    Examples:
        >>> # predictions from `cat.get_entities()` for one text
        >>> model_preds = [
        ...     {'cui': '134', 'start': 10, 'end': 20, 'acc': 1.0,
        ...      'pretty_name': 'Phone Number'},
        ... ]
        >>> # predictions from `match_rules` for the same text
        >>> rule_matches = [
        ...     {'cui': '134', 'start': 25, 'end': 35, 'acc': 1.0,
        ...      'pretty_name': 'Phone Number'},
        ... ]
        >>> merged_preds = merge_preds(model_preds, rule_matches)

    Returns:
        List[Dict]: Merged predictions sorted by their start offset.
    """
    if accept_preds:
        preferred, secondary = model_preds, rule_matches
    else:
        preferred, secondary = rule_matches, model_preds

    # Keep only the secondary spans that overlap no preferred span. Two
    # half-open spans [start, end) overlap unless one ends at or before
    # the other's start.
    secondary = [span2 for span2 in secondary
                 if not any(not (span2['end'] <= span1['start'] or span1['end'] <= span2['start'])
                            for span1 in preferred)]
    # Merge and sort on start offset. (A stray no-op `merged_preds`
    # expression statement before the return was removed as dead code.)
    merged_preds = preferred + secondary
    merged_preds.sort(key=lambda x: x['start'])
    return merged_preds

0 commit comments

Comments
 (0)