Skip to content

Commit 5b62d27

Browse files
authored
Refactor(medcat):CU-869ak0v7n Add type hints to util methods (#146)
* CU-869ak0v7n: Small refactor and refinement in terms of typing for meta cat utils
* CU-869ak0v7n: Improve typing for encode category values
* CU-869ak0v7n: Improve typing for encode category values (again)
* CU-869ak0v7n: Add meta_anns as an optional part to trainer export annotation typed dict
* CU-869ak0v7n: Some typing fixes
* CU-869ak0v7n: Improve typing for creating batch piped data
* CU-869ak0v7n: Fix MetaCAT typing in annotation
* CU-869ak0v7n: Allow Meta Annotations in a MedCATtrainer export to be either a list or a dict
1 parent 228aab9 commit 5b62d27

File tree

4 files changed

+160
-106
lines changed

4 files changed

+160
-106
lines changed

medcat-v2/medcat/components/addons/meta_cat/data_utils.py

Lines changed: 113 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from typing import Optional, Iterator, cast
22
import copy
33

44
from medcat.components.addons.meta_cat.mctokenizers.tokenizers import (
@@ -15,7 +15,8 @@ def prepare_from_json(data: dict,
1515
cui_filter: Optional[set] = None,
1616
replace_center: Optional[str] = None,
1717
prerequisites: dict = {},
18-
lowercase: bool = True) -> dict:
18+
lowercase: bool = True
19+
) -> dict[str, list[tuple[list, list, str]]]:
1920
"""Convert the data from a json format into a CSV-like format for
2021
training. This function is not very efficient (the one working with
2122
documents as part of the meta_cat.pipe method is much better).
@@ -64,91 +65,110 @@ def prepare_from_json(data: dict,
6465

6566
if len(text) > 0:
6667
doc_text = tokenizer(text)
68+
for name, sample in _prepare_from_json_loop(
69+
document, prerequisites, cui_filter, doc_text,
70+
cntx_left, cntx_right, lowercase, replace_center,
71+
tokenizer):
72+
if name in out_data:
73+
out_data[name].append(sample)
74+
else:
75+
out_data[name] = [sample]
6776

68-
for ann in document.get('annotations', document.get(
69-
# A hack to support entities and annotations
70-
'entities', {}).values()):
71-
cui = ann['cui']
72-
skip = False
73-
if 'meta_anns' in ann and prerequisites:
74-
# It is possible to require certain meta_anns to exist
75-
# and have a specific value
76-
for meta_ann in prerequisites:
77-
if (meta_ann not in ann['meta_anns'] or
78-
ann['meta_anns'][meta_ann][
79-
'value'] != prerequisites[meta_ann]):
80-
# Skip this annotation as the prerequisite
81-
# is not met
82-
skip = True
83-
break
84-
85-
if not skip and (cui_filter is None or
86-
not cui_filter or cui in cui_filter):
87-
if ann.get('validated', True) and (
88-
not ann.get('deleted', False) and
89-
not ann.get('killed', False)
90-
and not ann.get('irrelevant', False)):
91-
start = ann['start']
92-
end = ann['end']
93-
94-
# Updated implementation to extract all the tokens
95-
# for the medical entity (rather than the one)
96-
ctoken_idx = []
97-
for ind, pair in enumerate(
98-
doc_text['offset_mapping']):
99-
if start <= pair[0] or start <= pair[1]:
100-
if end <= pair[1]:
101-
ctoken_idx.append(ind)
102-
break
103-
else:
104-
ctoken_idx.append(ind)
105-
106-
_start = max(0, ctoken_idx[0] - cntx_left)
107-
_end = min(len(doc_text['input_ids']),
108-
ctoken_idx[-1] + 1 + cntx_right)
109-
110-
cpos = cntx_left + min(0, ind - cntx_left)
111-
cpos_new = [x - _start for x in ctoken_idx]
112-
tkns = doc_text['input_ids'][_start:_end]
113-
114-
if replace_center is not None:
115-
if lowercase:
116-
replace_center = replace_center.lower()
117-
for p_ind, pair in enumerate(
118-
doc_text['offset_mapping']):
119-
if start >= pair[0] and start < pair[1]:
120-
s_ind = p_ind
121-
if end > pair[0] and end <= pair[1]:
122-
e_ind = p_ind
123-
124-
ln = e_ind - s_ind
125-
tkns = tkns[:cpos] + tokenizer(
126-
replace_center)['input_ids'] + tkns[
127-
cpos + ln + 1:]
128-
129-
# Backward compatibility if meta_anns is a list vs
130-
# dict in the new approach
131-
meta_anns: list[dict] = []
132-
if 'meta_anns' in ann:
133-
if isinstance(ann['meta_anns'], dict):
134-
meta_anns.extend(ann['meta_anns'].values())
135-
else:
136-
meta_anns.extend(ann['meta_anns'])
137-
138-
# If the annotation is validated
139-
for meta_ann in meta_anns:
140-
name = meta_ann['name']
141-
value = meta_ann['value']
142-
143-
sample = [tkns, cpos_new, value]
144-
145-
if name in out_data:
146-
out_data[name].append(sample)
147-
else:
148-
out_data[name] = [sample]
14977
return out_data
15078

15179

80+
def _prepare_from_json_loop(document: dict,
81+
prerequisites: dict,
82+
cui_filter: Optional[set],
83+
doc_text: dict,
84+
cntx_left: int,
85+
cntx_right: int,
86+
lowercase: bool,
87+
replace_center: Optional[str],
88+
tokenizer: TokenizerWrapperBase,
89+
) -> Iterator[tuple[str, tuple[list, list, str]]]:
90+
for ann in document.get('annotations', document.get(
91+
# A hack to support entities and annotations
92+
'entities', {}).values()):
93+
cui = ann['cui']
94+
skip = False
95+
if 'meta_anns' in ann and prerequisites:
96+
# It is possible to require certain meta_anns to exist
97+
# and have a specific value
98+
for meta_ann in prerequisites:
99+
if (meta_ann not in ann['meta_anns'] or
100+
ann['meta_anns'][meta_ann][
101+
'value'] != prerequisites[meta_ann]):
102+
# Skip this annotation as the prerequisite
103+
# is not met
104+
skip = True
105+
break
106+
107+
if not skip and (cui_filter is None or
108+
not cui_filter or cui in cui_filter):
109+
if ann.get('validated', True) and (
110+
not ann.get('deleted', False) and
111+
not ann.get('killed', False)
112+
and not ann.get('irrelevant', False)):
113+
start = ann['start']
114+
end = ann['end']
115+
116+
# Updated implementation to extract all the tokens
117+
# for the medical entity (rather than the one)
118+
ctoken_idx = []
119+
for ind, pair in enumerate(
120+
doc_text['offset_mapping']):
121+
if start <= pair[0] or start <= pair[1]:
122+
if end <= pair[1]:
123+
ctoken_idx.append(ind)
124+
break
125+
else:
126+
ctoken_idx.append(ind)
127+
128+
_start = max(0, ctoken_idx[0] - cntx_left)
129+
_end = min(len(doc_text['input_ids']),
130+
ctoken_idx[-1] + 1 + cntx_right)
131+
132+
cpos = cntx_left + min(0, ind - cntx_left)
133+
cpos_new = [x - _start for x in ctoken_idx]
134+
tkns = doc_text['input_ids'][_start:_end]
135+
136+
if replace_center is not None:
137+
if lowercase:
138+
replace_center = replace_center.lower()
139+
for p_ind, pair in enumerate(
140+
doc_text['offset_mapping']):
141+
if start >= pair[0] and start < pair[1]:
142+
s_ind = p_ind
143+
if end > pair[0] and end <= pair[1]:
144+
e_ind = p_ind
145+
146+
ln = e_ind - s_ind
147+
tkns = tkns[:cpos] + tokenizer(
148+
replace_center)['input_ids'] + tkns[
149+
cpos + ln + 1:]
150+
151+
# Backward compatibility if meta_anns is a list vs
152+
# dict in the new approach
153+
meta_anns: list[dict] = []
154+
if 'meta_anns' in ann:
155+
if isinstance(ann['meta_anns'], dict):
156+
meta_anns.extend(ann['meta_anns'].values())
157+
else:
158+
meta_anns.extend(ann['meta_anns'])
159+
160+
# If the annotation is validated
161+
for meta_ann in meta_anns:
162+
name = meta_ann['name']
163+
value = meta_ann['value']
164+
165+
# NOTE: representing as tuple so as to have better typing
166+
# but using a list to allow assignment
167+
sample: tuple[list, list, str] = cast(
168+
tuple[list, list, str], [tkns, cpos_new, value])
169+
yield name, sample
170+
171+
152172
def prepare_for_oversampled_data(data: list,
153173
tokenizer: TokenizerWrapperBase) -> list:
154174
"""Convert the data from a json format into a CSV-like format for
@@ -189,20 +209,21 @@ def prepare_for_oversampled_data(data: list,
189209
return data_sampled
190210

191211

192-
def encode_category_values(data: dict,
212+
def encode_category_values(data: list[tuple[list, list, str]],
193213
existing_category_value2id: Optional[dict] = None,
194-
category_undersample=None,
214+
category_undersample: Optional[str] = None,
195215
alternative_class_names: list[list[str]] = []
196-
) -> tuple:
216+
) -> tuple[
217+
list[tuple[list, list, str]], list, dict]:
197218
"""Converts the category values in the data outputted by
198219
`prepare_from_json` into integer values.
199220
200221
Args:
201-
data (dict):
222+
data (list[tuple[list, list, str]]):
202223
Output of `prepare_from_json`.
203224
existing_category_value2id(Optional[dict]):
204225
Map from category_value to id (old/existing).
205-
category_undersample:
226+
category_undersample (Optional[str]):
206227
Name of class that should be used to undersample the data (for 2
207228
phase learning)
208229
alternative_class_names (list[list[str]]):
@@ -211,9 +232,9 @@ def encode_category_values(data: dict,
211232
`config.general.alternative_class_names`.
212233
213234
Returns:
214-
dict:
235+
list[tuple[list, list, str]]:
215236
New data with integers inplace of strings for category values.
216-
dict:
237+
list:
217238
New undersampled data (for 2 phase learning) with integers
218239
inplace of strings for category values
219240
dict:
@@ -288,7 +309,8 @@ def encode_category_values(data: dict,
288309

289310
# Map values to numbers
290311
for i in range(len(data_list)):
291-
data_list[i][2] = category_value2id[data_list[i][2]]
312+
# NOTE: internally, it's a list so assignment will work
313+
data_list[i][2] = category_value2id[data_list[i][2]] # type: ignore
292314

293315
# Creating dict with labels and its number of samples
294316
label_data_ = {v: 0 for v in category_value2id.values()}

medcat-v2/medcat/components/addons/meta_cat/meta_cat.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from medcat.config.config import ComponentConfig
1616
from medcat.config.config_meta_cat import ConfigMetaCAT
1717
from medcat.components.addons.meta_cat.ml_utils import (
18-
predict, train_model, set_all_seeds, eval_model)
18+
predict, train_model, set_all_seeds, eval_model, EvalModelResults)
1919
from medcat.components.addons.meta_cat.data_utils import (
2020
prepare_from_json, encode_category_values, prepare_for_oversampled_data)
2121
from medcat.components.addons.addons import AddonComponent
@@ -632,15 +632,15 @@ def train_raw(self, data_loaded: dict, save_dir_path: Optional[str] = None,
632632
self.config.train.last_train_on = datetime.now().timestamp()
633633
return report
634634

635-
def eval(self, json_path: str) -> dict:
635+
def eval(self, json_path: str) -> EvalModelResults:
636636
"""Evaluate from json.
637637
638638
Args:
639639
json_path (str):
640640
The json file path
641641
642642
Returns:
643-
dict:
643+
EvalModelResults:
644644
The resulting model dict
645645
646646
Raises:

medcat-v2/medcat/components/addons/meta_cat/ml_utils.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77
import pandas as pd
88
import torch.optim as optim
9-
from typing import Optional, Any, Union
9+
from typing import Optional, Any, Union, TypedDict
1010
from torch import nn
1111
from scipy.special import softmax
1212
from medcat.config.config_meta_cat import ConfigMetaCAT
@@ -34,7 +34,9 @@ def set_all_seeds(seed: int) -> None:
3434
def create_batch_piped_data(data: list[tuple[list[int], int, Optional[int]]],
3535
start_ind: int, end_ind: int,
3636
device: Union[torch.device, str],
37-
pad_id: int) -> tuple:
37+
pad_id: int
38+
) -> tuple[torch.Tensor, list[int],
39+
torch.Tensor, Optional[torch.Tensor]]:
3840
"""Creates a batch given data and start/end that denote batch size,
3941
will also add padding and move to the right device.
4042
@@ -52,13 +54,13 @@ def create_batch_piped_data(data: list[tuple[list[int], int, Optional[int]]],
5254
Padding index
5355
5456
Returns:
55-
x ():
57+
x (torch.Tensor):
5658
Same as data, but subsetted and as a tensor
57-
cpos ():
59+
cpos (list[int]):
5860
Center positions for the data
59-
attention_mask:
61+
attention_mask (torch.Tensor):
6062
Indicating padding mask for the data
61-
y:
63+
y (Optional[torch.Tensor]):
6264
class label of the data
6365
"""
6466
max_seq_len = max([len(x[0]) for x in data])
@@ -78,7 +80,7 @@ class label of the data
7880

7981

8082
def predict(model: nn.Module, data: list[tuple[list[int], int, Optional[int]]],
81-
config: ConfigMetaCAT) -> tuple:
83+
config: ConfigMetaCAT) -> tuple[list[int], list[float]]:
8284
"""Predict on data used in the meta_cat.pipe
8385
8486
Args:
@@ -399,8 +401,17 @@ def initialize_model(classifier, data_, batch_size_, lr_, epochs=4):
399401
return winner_report
400402

401403

404+
EvalModelResults = TypedDict('EvalModelResults', {
405+
"precision": float,
406+
"recall": float,
407+
"f1": float,
408+
"examples": dict,
409+
"confusion matrix": pd.DataFrame,
410+
})
411+
412+
402413
def eval_model(model: nn.Module, data: list, config: ConfigMetaCAT,
403-
tokenizer: TokenizerWrapperBase) -> dict:
414+
tokenizer: TokenizerWrapperBase) -> EvalModelResults:
404415
"""Evaluate a trained model on the provided data
405416
406417
Args:
@@ -474,9 +485,22 @@ def eval_model(model: nn.Module, data: list, config: ConfigMetaCAT,
474485
examples: dict = {'FP': {}, 'FN': {}, 'TP': {}}
475486
id2category_value = {v: k for k, v
476487
in config.general.category_value2id.items()}
488+
return _eval_predictions(
489+
tokenizer, data, predictions, confusion, id2category_value,
490+
y_eval, precision, recall, f1, examples)
491+
492+
493+
def _eval_predictions(
494+
tokenizer: TokenizerWrapperBase,
495+
data: list,
496+
predictions: list[int],
497+
confusion: pd.DataFrame,
498+
id2category_value: dict[int, str],
499+
y_eval: list,
500+
precision, recall, f1, examples: dict) -> EvalModelResults:
477501
for i, p in enumerate(predictions):
478502
y = id2category_value[y_eval[i]]
479-
p = id2category_value[p]
503+
pred = id2category_value[p]
480504
c = data[i][1]
481505
if isinstance(c, list):
482506
c = c[-1]
@@ -487,11 +511,11 @@ def eval_model(model: nn.Module, data: list, config: ConfigMetaCAT,
487511
tokenizer.hf_tokenizers.decode(
488512
tkns[c:c + 1]).strip() + ">> " +
489513
tokenizer.hf_tokenizers.decode(tkns[c + 1:]))
490-
info = "Predicted: {}, True: {}".format(p, y)
491-
if p != y:
514+
info = "Predicted: {}, True: {}".format(pred, y)
515+
if pred != y:
492516
# We made a mistake
493517
examples['FN'][y] = examples['FN'].get(y, []) + [(info, text)]
494-
examples['FP'][p] = examples['FP'].get(p, []) + [(info, text)]
518+
examples['FP'][pred] = examples['FP'].get(pred, []) + [(info, text)]
495519
else:
496520
examples['TP'][y] = examples['TP'].get(y, []) + [(info, text)]
497521

0 commit comments

Comments
 (0)