Skip to content

Commit 99b6bee

Browse files
shubham-s-agarwalmart-r
authored andcommitted
Pushing bug fix for MetaCAT (#148)
* Pushing bug fix for MetaCAT - Bug fix for MetaCAT through cat.get_entities() - Adding tests to ensure it gets caught earlier * Pushing change for metacat * Pushing change * Update meta_cat.py * Pushing update * Update for v2 * Update meta_cat.py
1 parent 0dc223a commit 99b6bee

File tree

7 files changed

+10750
-4
lines changed

7 files changed

+10750
-4
lines changed

medcat-v2/medcat/components/addons/meta_cat/meta_cat.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -720,10 +720,12 @@ def prepare_document(self, doc: MutableDocument, input_ids: list,
720720
# Checking if we've reached at the start of the entity
721721
if start <= pair[0] or start <= pair[1]:
722722
if end <= pair[1]:
723-
ctoken_idx.append(ind) # End reached
723+
# End reached; update for correct index
724+
ctoken_idx.append(last_ind + ind)
724725
break
725726
else:
726-
ctoken_idx.append(ind) # Keep going
727+
# Keep going; update for correct index
728+
ctoken_idx.append(last_ind + ind)
727729

728730
# Start where the last ent was found, cannot be before it as we've
729731
# sorted

v1/medcat/examples/cdb_meta.csv

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
cui,name,ontologies,name_status,type_ids,description
2+
C0000039,"gastroesophageal reflux",,,T234,
3+
C0000239,"heartburn",,,,
4+
C0000339,"hypertension",,,,
5+
C0000439,"stroke",,,,

v1/medcat/examples/cdb_meta.dat

1.03 KB
Binary file not shown.

v1/medcat/examples/vocab_meta.dat

883 Bytes
Binary file not shown.

v1/medcat/medcat/meta_cat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -489,10 +489,10 @@ def prepare_document(self, doc: Doc, input_ids: List, offset_mapping: List, lowe
489489
# Checking if we've reached at the start of the entity
490490
if start <= pair[0] or start <= pair[1]:
491491
if end <= pair[1]:
492-
ctoken_idx.append(ind) # End reached
492+
ctoken_idx.append(last_ind+ind) # End reached; update the index to reflect the correct position since iteration does not start from the beginning
493493
break
494494
else:
495-
ctoken_idx.append(ind) # Keep going
495+
ctoken_idx.append(last_ind+ind) # Keep going; update the index to reflect the correct position since iteration does not start from the beginning
496496

497497
# Start where the last ent was found, cannot be before it as we've sorted
498498
last_ind += ind # If we did not start from 0 in the for loop

v1/medcat/tests/resources/mct_export_for_meta_cat_full_text.json

Lines changed: 10671 additions & 0 deletions
Large diffs are not rendered by default.

v1/medcat/tests/test_meta_cat.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,13 @@
44

55
from transformers import AutoTokenizer
66

7+
from medcat.vocab import Vocab
8+
from medcat.cdb import CDB
9+
from medcat.cat import CAT
710
from medcat.meta_cat import MetaCAT
811
from medcat.config_meta_cat import ConfigMetaCAT
912
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBERT
13+
import tempfile
1014
import spacy
1115
from spacy.tokens import Span
1216

@@ -117,6 +121,70 @@ def test_two_phase(self):
117121

118122
self.meta_cat.config.model['phase_number'] = 0
119123

124+
class CAT_METACATTests(unittest.TestCase):
125+
META_CAT_JSON_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "resources",
126+
"mct_export_for_meta_cat_full_text.json")
120127

128+
@classmethod
129+
def _get_meta_cat(cls, meta_cat_dir):
130+
config = ConfigMetaCAT()
131+
config.general["category_name"] = "Status"
132+
config.general['category_value2id'] = {'Other': 0, 'Confirmed': 1}
133+
config.model['model_name'] = 'bert'
134+
config.model['model_freeze_layers'] = False
135+
config.model['num_layers'] = 10
136+
config.train['lr'] = 0.001
137+
config.train["nepochs"] = 20
138+
config.train.class_weights = [0.75,0.3]
139+
config.train['metric']['base'] = 'macro avg'
140+
141+
meta_cat = MetaCAT(tokenizer=TokenizerWrapperBERT(AutoTokenizer.from_pretrained("bert-base-uncased")),
142+
embeddings=None,
143+
config=config)
144+
os.makedirs(meta_cat_dir, exist_ok=True)
145+
json_path = cls.META_CAT_JSON_PATH
146+
meta_cat.train_from_json(json_path, save_dir_path=meta_cat_dir)
147+
return meta_cat
148+
149+
@classmethod
150+
def setUpClass(cls) -> None:
151+
cls.cdb = CDB.load(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "cdb_meta.dat"))
152+
cls.vocab = Vocab.load(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "vocab_meta.dat"))
153+
cls.vocab.make_unigram_table()
154+
cls._temp_logs_folder = tempfile.TemporaryDirectory()
155+
cls.temp_dir = tempfile.TemporaryDirectory()
156+
cls.cdb.config.general.spacy_model = os.path.join(cls.temp_dir.name, "en_core_web_md")
157+
cls.cdb.config.ner.min_name_len = 2
158+
cls.cdb.config.ner.upper_case_limit_len = 3
159+
cls.cdb.config.general.spell_check = True
160+
cls.cdb.config.linking.train_count_threshold = 10
161+
cls.cdb.config.linking.similarity_threshold = 0.3
162+
cls.cdb.config.linking.train = True
163+
cls.cdb.config.linking.disamb_length_limit = 5
164+
cls.cdb.config.general.full_unlink = True
165+
cls.cdb.config.general.usage_monitor.enabled = True
166+
cls.cdb.config.general.usage_monitor.log_folder = cls._temp_logs_folder.name
167+
cls.meta_cat_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "tmp")
168+
cls.meta_cat = cls._get_meta_cat(cls.meta_cat_dir)
169+
cls.cat = CAT(cdb=cls.cdb, config=cls.cdb.config, vocab=cls.vocab, meta_cats=[cls.meta_cat])
170+
171+
@classmethod
172+
def tearDownClass(cls) -> None:
173+
cls.cat.destroy_pipe()
174+
if os.path.exists(cls.meta_cat_dir):
175+
shutil.rmtree(cls.meta_cat_dir)
176+
cls._temp_logs_folder.cleanup()
177+
178+
def test_meta_cat_through_cat(self):
179+
text = "This information is just to add text. The patient denied history of heartburn and/or gastroesophageal reflux disorder. He recently had a stroke in the last week."
180+
entities = self.cat.get_entities(text)
181+
meta_status_values = []
182+
for en in entities['entities']:
183+
meta_status_values.append(entities['entities'][en]['meta_anns']['Status']['value'])
184+
185+
self.assertEqual(meta_status_values,['Other','Other','Confirmed'])
186+
187+
import logging
188+
logging.basicConfig(level=logging.INFO)
121189
if __name__ == '__main__':
122190
unittest.main()

0 commit comments

Comments
 (0)