|
4 | 4 |
|
5 | 5 | from transformers import AutoTokenizer |
6 | 6 |
|
| 7 | +from medcat.vocab import Vocab |
| 8 | +from medcat.cdb import CDB |
| 9 | +from medcat.cat import CAT |
7 | 10 | from medcat.meta_cat import MetaCAT |
8 | 11 | from medcat.config_meta_cat import ConfigMetaCAT |
9 | 12 | from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBERT |
| 13 | +import tempfile |
10 | 14 | import spacy |
11 | 15 | from spacy.tokens import Span |
12 | 16 |
|
@@ -117,6 +121,70 @@ def test_two_phase(self): |
117 | 121 |
|
118 | 122 | self.meta_cat.config.model['phase_number'] = 0 |
119 | 123 |
|
| 124 | +class CAT_METACATTests(unittest.TestCase): |
| 125 | + META_CAT_JSON_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "resources", |
| 126 | + "mct_export_for_meta_cat_full_text.json") |
120 | 127 |
|
| 128 | + @classmethod |
| 129 | + def _get_meta_cat(cls, meta_cat_dir): |
| 130 | + config = ConfigMetaCAT() |
| 131 | + config.general["category_name"] = "Status" |
| 132 | + config.general['category_value2id'] = {'Other': 0, 'Confirmed': 1} |
| 133 | + config.model['model_name'] = 'bert' |
| 134 | + config.model['model_freeze_layers'] = False |
| 135 | + config.model['num_layers'] = 10 |
| 136 | + config.train['lr'] = 0.001 |
| 137 | + config.train["nepochs"] = 20 |
| 138 | + config.train.class_weights = [0.75,0.3] |
| 139 | + config.train['metric']['base'] = 'macro avg' |
| 140 | + |
| 141 | + meta_cat = MetaCAT(tokenizer=TokenizerWrapperBERT(AutoTokenizer.from_pretrained("bert-base-uncased")), |
| 142 | + embeddings=None, |
| 143 | + config=config) |
| 144 | + os.makedirs(meta_cat_dir, exist_ok=True) |
| 145 | + json_path = cls.META_CAT_JSON_PATH |
| 146 | + meta_cat.train_from_json(json_path, save_dir_path=meta_cat_dir) |
| 147 | + return meta_cat |
| 148 | + |
| 149 | + @classmethod |
| 150 | + def setUpClass(cls) -> None: |
| 151 | + cls.cdb = CDB.load(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "cdb_meta.dat")) |
| 152 | + cls.vocab = Vocab.load(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "vocab_meta.dat")) |
| 153 | + cls.vocab.make_unigram_table() |
| 154 | + cls._temp_logs_folder = tempfile.TemporaryDirectory() |
| 155 | + cls.temp_dir = tempfile.TemporaryDirectory() |
| 156 | + cls.cdb.config.general.spacy_model = os.path.join(cls.temp_dir.name, "en_core_web_md") |
| 157 | + cls.cdb.config.ner.min_name_len = 2 |
| 158 | + cls.cdb.config.ner.upper_case_limit_len = 3 |
| 159 | + cls.cdb.config.general.spell_check = True |
| 160 | + cls.cdb.config.linking.train_count_threshold = 10 |
| 161 | + cls.cdb.config.linking.similarity_threshold = 0.3 |
| 162 | + cls.cdb.config.linking.train = True |
| 163 | + cls.cdb.config.linking.disamb_length_limit = 5 |
| 164 | + cls.cdb.config.general.full_unlink = True |
| 165 | + cls.cdb.config.general.usage_monitor.enabled = True |
| 166 | + cls.cdb.config.general.usage_monitor.log_folder = cls._temp_logs_folder.name |
| 167 | + cls.meta_cat_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "tmp") |
| 168 | + cls.meta_cat = cls._get_meta_cat(cls.meta_cat_dir) |
| 169 | + cls.cat = CAT(cdb=cls.cdb, config=cls.cdb.config, vocab=cls.vocab, meta_cats=[cls.meta_cat]) |
| 170 | + |
| 171 | + @classmethod |
| 172 | + def tearDownClass(cls) -> None: |
| 173 | + cls.cat.destroy_pipe() |
| 174 | + if os.path.exists(cls.meta_cat_dir): |
| 175 | + shutil.rmtree(cls.meta_cat_dir) |
| 176 | + cls._temp_logs_folder.cleanup() |
| 177 | + |
| 178 | + def test_meta_cat_through_cat(self): |
| 179 | + text = "This information is just to add text. The patient denied history of heartburn and/or gastroesophageal reflux disorder. He recently had a stroke in the last week." |
| 180 | + entities = self.cat.get_entities(text) |
| 181 | + meta_status_values = [] |
| 182 | + for en in entities['entities']: |
| 183 | + meta_status_values.append(entities['entities'][en]['meta_anns']['Status']['value']) |
| 184 | + |
| 185 | + self.assertEqual(meta_status_values,['Other','Other','Confirmed']) |
| 186 | + |
| 187 | +import logging |
| 188 | +logging.basicConfig(level=logging.INFO) |
121 | 189 | if __name__ == '__main__': |
122 | 190 | unittest.main() |
0 commit comments