Skip to content

Commit e9f8c3a

Browse files
alhendrickson, Tom Searle, mart-r, tomolopolis
authored
feat(medcat-trainer): CU-869a4br6j Support MedCAT v2 for trainer (#68)
* interim changes for medcat-v2 * updated TODOs * Update dependency to medcat v2 * Update CDB/Vocab load to use the load classmethod again * Move away from pkg_resources (deprecated) * Use v2 based API for loading addons (MetaCATs) * Update MetaCAT loading * Update metrics to v2 format * Do config parsing locally * Update to correct attribute name * Update solr utils to v2 * Fix config access for v2 * Remove addons from CDB config upon load * Fix syntax error * Update Meta Annotation getting so as to avoid error if none set * Fix entity CUI / start/end char access * Fix some more entity detail access * Remove unigram table error (irrelevant / redundant) * Log more info regarding failure upon document preparation * Centralising clearing CDB addons after explicit load * More specific import * Clear CDB config addons everywhere if/when applicable * Avoid circular imports by importing dynamically * Correctly set CDB path within v2 model packs * Update (very old) notebook to v2 * Update (very old) notebook for v2 installation * CU-869aknppd: medcattrainer: upgrade dep --------- Co-authored-by: Tom Searle <[email protected]> Co-authored-by: mart-r <[email protected]> Co-authored-by: tomolopolis <[email protected]>
1 parent 4c0495e commit e9f8c3a

File tree

12 files changed

+291
-150
lines changed

12 files changed

+291
-150
lines changed

medcat-trainer/notebook_docs/API_Examples.ipynb

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -200,11 +200,11 @@
200200
"for name, d_s in datasets:\n",
201201
" payload = {\n",
202202
" 'dataset_name': name, # Name that appears in each\n",
203-
" 'dataset': d_s.loc[:, ['name', 'text']].to_dict(), # Dictionary representation of only \n",
203+
" 'dataset': d_s.loc[:, ['name', 'text']].to_dict(), # Dictionary representation of only\n",
204204
" 'description': f'{name} first 20 notes from each category' # Description that appears in the trainer\n",
205205
" }\n",
206206
" resp = requests.post(f'{URL}/api/create-dataset/', json=payload, headers=headers)\n",
207-
" dataset_ids.append(json.loads(resp.text)['dataset_id']) \n",
207+
" dataset_ids.append(json.loads(resp.text)['dataset_id'])\n",
208208
"# New datasets created in the trainer have the following IDs\n",
209209
"dataset_ids"
210210
]
@@ -262,7 +262,7 @@
262262
},
263263
{
264264
"cell_type": "code",
265-
"execution_count": 12,
265+
"execution_count": null,
266266
"metadata": {
267267
"tags": []
268268
},
@@ -273,7 +273,7 @@
273273
},
274274
{
275275
"cell_type": "code",
276-
"execution_count": 14,
276+
"execution_count": null,
277277
"metadata": {
278278
"tags": []
279279
},
@@ -290,7 +290,7 @@
290290
}
291291
],
292292
"source": [
293-
"CDB.load('../../medcat-models/deid_medcat_n2c2_modelpack/cdb.dat')"
293+
"cdb = CDB.load('../../medcat-models/deid_medcat_n2c2_modelpack/cdb.dat')"
294294
]
295295
},
296296
{
@@ -301,8 +301,8 @@
301301
},
302302
"outputs": [],
303303
"source": [
304-
"txt = json.loads(requests.post(f'{URL}/api/concept-dbs/', headers=headers, \n",
305-
" data={'name': 'example_cdb', 'use_for_training': True}, \n",
304+
"txt = json.loads(requests.post(f'{URL}/api/concept-dbs/', headers=headers,\n",
305+
" data={'name': 'example_cdb', 'use_for_training': True},\n",
306306
" files={'cdb_file': open('../../medcat-models/deid_medcat_n2c2_modelpack/cdb.dat', 'rb')}).text)"
307307
]
308308
},
@@ -342,8 +342,8 @@
342342
},
343343
"outputs": [],
344344
"source": [
345-
"txt = json.loads(requests.put(f'{URL}/api/concept-dbs/21/', headers=headers, \n",
346-
" data={'name': 'example_cdb-EDITED', 'use_for_training': True}, \n",
345+
"txt = json.loads(requests.put(f'{URL}/api/concept-dbs/21/', headers=headers,\n",
346+
" data={'name': 'example_cdb-EDITED', 'use_for_training': True},\n",
347347
" files={'cdb_file': open('../../medcat-models/deid_medcat_n2c2_modelpack/cdb.dat', 'rb')}).text)"
348348
]
349349
},
@@ -379,8 +379,8 @@
379379
}
380380
],
381381
"source": [
382-
"requests.post(f'{URL}/api/concept-dbs/', headers=headers, \n",
383-
" data={'name': 'example_cdb', 'use_for_training': True}, \n",
382+
"requests.post(f'{URL}/api/concept-dbs/', headers=headers,\n",
383+
" data={'name': 'example_cdb', 'use_for_training': True},\n",
384384
" files={'cdb_file': open('../../medcat-models/deid_medcat_n2c2_modelpack/cdb.dat', 'rb')}).text)"
385385
]
386386
},
@@ -404,7 +404,7 @@
404404
"metadata": {},
405405
"outputs": [],
406406
"source": [
407-
"txt = json.loads(requests.post(f'{URL}/api/vocab/', headers=headers, \n",
407+
"txt = json.loads(requests.post(f'{URL}/api/vocab/', headers=headers,\n",
408408
" files={'cdb_file': open('<<LOCATION OF vocab>>', 'rb')}).text)"
409409
]
410410
},
@@ -465,7 +465,7 @@
465465
"all_cdbs = json.loads(requests.get(f'{URL}/api/concept-dbs/', headers=headers).text)['results']\n",
466466
"# the CDB ID we'll use for this example\n",
467467
"cdb_to_use = all_cdbs[0]['id']\n",
468-
"# you might have many CDBs here. First 2 cdbs: \n",
468+
"# you might have many CDBs here. First 2 cdbs:\n",
469469
"all_cdbs[0:2]"
470470
]
471471
},
@@ -521,12 +521,12 @@
521521
"for d_id, p_name in zip(dataset_ids, project_names):\n",
522522
" payload = {\n",
523523
" 'name': f'{p_name} Annotation Project',\n",
524-
" 'description': 'Example projects', \n",
525-
" 'cuis': '', \n",
524+
" 'description': 'Example projects',\n",
525+
" 'cuis': '',\n",
526526
" 'tuis': '',\n",
527527
" 'dataset': d_id,\n",
528-
" 'concept_db': cdb_to_use, \n",
529-
" 'vocab': vocab_to_use, \n",
528+
" 'concept_db': cdb_to_use,\n",
529+
" 'vocab': vocab_to_use,\n",
530530
" 'members': users_ids\n",
531531
" }\n",
532532
" project_ids.append(json.loads(requests.post(f'{URL}/api/project-annotate-entities/', json=payload, headers=headers).text))"

medcat-trainer/notebook_docs/Train_MedCAT_Models.ipynb

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
},
2626
{
2727
"cell_type": "code",
28-
"execution_count": 2,
28+
"execution_count": null,
2929
"metadata": {},
3030
"outputs": [
3131
{
@@ -177,7 +177,7 @@
177177
],
178178
"source": [
179179
"# install medcat\n",
180-
"!pip install medcat\n",
180+
"!pip install \"medcat[spacy,meta-cat,rel-cat,deid]>=2.0.0\"\n",
181181
"# scispacy medium models\n",
182182
"!pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.2.5/en_core_sci_md-0.2.5.tar.gz\n",
183183
"# ipywidgets\n",
@@ -186,7 +186,7 @@
186186
},
187187
{
188188
"cell_type": "code",
189-
"execution_count": 2,
189+
"execution_count": null,
190190
"metadata": {
191191
"ExecuteTime": {
192192
"end_time": "2020-09-08T11:27:34.270631Z",
@@ -195,9 +195,11 @@
195195
},
196196
"outputs": [],
197197
"source": [
198+
"import json\n",
199+
"\n",
198200
"from medcat.cat import CAT\n",
199201
"from medcat.cdb import CDB\n",
200-
"from medcat.utils.vocab import Vocab"
202+
"from medcat.vocab import Vocab"
201203
]
202204
},
203205
{
@@ -310,7 +312,7 @@
310312
},
311313
{
312314
"cell_type": "code",
313-
"execution_count": 5,
315+
"execution_count": null,
314316
"metadata": {
315317
"ExecuteTime": {
316318
"end_time": "2020-09-08T11:27:59.782731Z",
@@ -319,16 +321,14 @@
319321
},
320322
"outputs": [],
321323
"source": [
322-
"cdb = CDB()\n",
323-
"cdb.load_dict(cdb_path)\n",
324-
"vocab = Vocab()\n",
325-
"vocab.load_dict(vocab_path)\n",
324+
"cdb = CDB.load(cdb_path)\n",
325+
"vocab = Vocab.load(vocab_path)\n",
326326
"cat = CAT(cdb, vocab)"
327327
]
328328
},
329329
{
330330
"cell_type": "code",
331-
"execution_count": 10,
331+
"execution_count": null,
332332
"metadata": {
333333
"ExecuteTime": {
334334
"end_time": "2020-09-08T11:37:38.546552Z",
@@ -382,7 +382,7 @@
382382
"name": "stdout",
383383
"output_type": "stream",
384384
"text": [
385-
"\r",
385+
"\r\n",
386386
"Epoch: 0, Prec: 0.36538461538461536, Rec: 0.8444444444444444, F1: 0.6049145299145299\n",
387387
"\n",
388388
"Docs with false positives: Psych Text 1; Psych Text 2\n",
@@ -1383,12 +1383,13 @@
13831383
}
13841384
],
13851385
"source": [
1386-
"cat.train_supervised(data_path=\"example_data/MedCAT_Export_With_Text_2020-05-22_10_34_09.json\", \n",
1387-
" nepochs=1,\n",
1388-
" lr=0.1,\n",
1389-
" anneal=False, # Unless we are reseting the CDB or cui_count this is False\n",
1390-
" print_stats=True, \n",
1391-
" use_filters=True)"
1386+
"with open(\"example_data/MedCAT_Export_With_Text_2020-05-22_10_34_09.json\") as f:\n",
1387+
" data = json.load(f)\n",
1388+
"cat.trainer.train_supervised_raw(\n",
1389+
" data=data,\n",
1390+
" nepochs=1,\n",
1391+
" print_stats=True,\n",
1392+
" use_filters=True)"
13921393
]
13931394
},
13941395
{
@@ -1402,7 +1403,7 @@
14021403
},
14031404
{
14041405
"cell_type": "code",
1405-
"execution_count": 50,
1406+
"execution_count": null,
14061407
"metadata": {
14071408
"ExecuteTime": {
14081409
"end_time": "2020-09-08T15:04:02.394607Z",
@@ -1411,14 +1412,14 @@
14111412
},
14121413
"outputs": [],
14131414
"source": [
1414-
"from medcat.meta_cat import MetaCAT\n",
1415+
"from medcat.components.addons.meta_cat import MetaCAT\n",
14151416
"from tokenizers import ByteLevelBPETokenizer\n",
14161417
"from itertools import chain"
14171418
]
14181419
},
14191420
{
14201421
"cell_type": "code",
1421-
"execution_count": 18,
1422+
"execution_count": null,
14221423
"metadata": {
14231424
"ExecuteTime": {
14241425
"end_time": "2020-09-08T14:46:39.070589Z",
@@ -1427,6 +1428,7 @@
14271428
},
14281429
"outputs": [],
14291430
"source": [
1431+
"import numpy as np\n",
14301432
"# Tokenizer instantiation\n",
14311433
"tokenizer = ByteLevelBPETokenizer(vocab_file='data/medmen-vocab.json', merges_file='data/medmen-merges.txt')\n",
14321434
"embeddings = np.load(open('data/embeddings.npy', 'rb'))"
@@ -1443,7 +1445,7 @@
14431445
},
14441446
"outputs": [],
14451447
"source": [
1446-
"metacat = MetaCAT(tokenizer=tokenizer, embeddings=embeddings, \n",
1448+
"metacat = MetaCAT(tokenizer=tokenizer, embeddings=embeddings,\n",
14471449
" pad_id=len(embeddings) -1, save_dir='mc_status', device='cpu')"
14481450
]
14491451
},

medcat-trainer/webapp/api/api/admin/actions.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212

1313
from api.models import AnnotatedEntity, MetaAnnotation, EntityRelation, Document, ConceptDB
1414
from api.solr_utils import drop_collection, import_all_concepts
15+
from api.utils import clear_cdb_cnf_addons
16+
17+
from medcat.cdb import CDB
1518

1619
logger = logging.getLogger(__name__)
1720

@@ -356,20 +359,18 @@ def dataset_document_counts(dataset):
356359

357360
@background(schedule=5)
358361
def _reset_cdb_filters(id):
359-
from medcat.cdb import CDB
360362
concept_db = ConceptDB.objects.get(id=id)
361363
cdb = CDB.load(concept_db.cdb_file.path)
362-
cdb.config.linking['filters'] = {'cuis': set()}
364+
clear_cdb_cnf_addons(cdb, id)
365+
cdb.config.components.linking.filters = {'cuis': set()}
363366
cdb.save(concept_db.cdb_file.path)
364367

365368

366369
@background(schedule=5)
367370
def import_concepts_from_cdb(cdb_model_id: int):
368-
from medcat.cdb import CDB
369-
370371
cdb_model = ConceptDB.objects.get(id=cdb_model_id)
371372
cdb = CDB.load(cdb_model.cdb_file.path)
372-
373+
clear_cdb_cnf_addons(cdb, cdb_model_id)
373374
import_all_concepts(cdb, cdb_model)
374375

375376

medcat-trainer/webapp/api/api/metrics.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,20 @@
1313
from background_task.models import Task
1414
from django.contrib.auth.models import User
1515
from django.db.models import QuerySet
16+
from medcat.stats.stats import get_stats
1617
from medcat.cat import CAT
1718
from medcat.cdb import CDB
18-
from medcat.config_meta_cat import ConfigMetaCAT
19-
from medcat.meta_cat import MetaCAT
20-
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBase
21-
from medcat.utils.meta_cat.data_utils import prepare_from_json, encode_category_values
22-
from medcat.utils.meta_cat.ml_utils import create_batch_piped_data
19+
from medcat.config.config_meta_cat import ConfigMetaCAT
20+
from medcat.components.addons.meta_cat.meta_cat import MetaCATAddon
21+
from medcat.components.addons.meta_cat.mctokenizers.tokenizers import TokenizerWrapperBase
22+
from medcat.components.addons.meta_cat.data_utils import prepare_from_json, encode_category_values
23+
from medcat.components.addons.meta_cat.ml_utils import create_batch_piped_data
2324
from medcat.vocab import Vocab
2425
from torch import nn
2526

2627
from api.admin import retrieve_project_data
2728
from api.models import AnnotatedEntity, ProjectAnnotateEntities, ProjectMetrics as AppProjectMetrics
29+
from api.utils import clear_cdb_cnf_addons
2830
from core.settings import MEDIA_ROOT
2931

3032
_dt_fmt = '%Y-%m-%d %H:%M:%S.%f'
@@ -51,6 +53,7 @@ def calculate_metrics(project_ids: List[int], report_name: str):
5153
else:
5254
# assume the cdb / vocab is set in these projects
5355
cdb = CDB.load(projects[0].concept_db.cdb_file.path)
56+
clear_cdb_cnf_addons(cdb, projects[0].concept_db.name)
5457
vocab = Vocab.load(projects[0].vocab.vocab_file.path)
5558
cat = CAT(cdb, vocab, config=cdb.config)
5659
project_data = retrieve_project_data(projects)
@@ -116,7 +119,7 @@ def annotation_df(self):
116119
"""
117120
annotation_df = pd.DataFrame(self.annotations)
118121
if self.cat:
119-
annotation_df.insert(5, 'concept_name', annotation_df['cui'].map(self.cat.cdb.cui2preferred_name))
122+
annotation_df.insert(5, 'concept_name', annotation_df['cui'].map(self.cat.cdb.get_name))
120123
annotation_df['last_modified'] = pd.to_datetime(annotation_df['last_modified']).dt.tz_localize(None)
121124
return annotation_df
122125

@@ -138,9 +141,10 @@ def concept_summary(self, extra_cui_filter=None):
138141
concept_count_df['count_variations_ratio'] = round(concept_count_df['concept_count'] /
139142
concept_count_df['variations'], 3)
140143
if self.cat:
141-
fps, fns, tps, cui_prec, cui_rec, cui_f1, cui_counts, examples = self.cat._print_stats(data=self.mct_export,
142-
use_project_filters=True,
143-
extra_cui_filter=extra_cui_filter)
144+
fps, fns, tps, cui_prec, cui_rec, cui_f1, cui_counts, examples = get_stats(self.cat,
145+
data=self.mct_export,
146+
use_project_filters=True,
147+
extra_cui_filter=extra_cui_filter)
144148
# remap tps, fns, fps to specific user annotations
145149
examples = self.enrich_medcat_metrics(examples)
146150
concept_count_df['fps'] = concept_count_df['cui'].map(fps)
@@ -242,11 +246,11 @@ def rename_meta_anns(self, meta_anns2rename=dict(), meta_ann_values2rename=dict(
242246
return
243247

244248
def _eval_model(self, model: nn.Module, data: List, config: ConfigMetaCAT, tokenizer: TokenizerWrapperBase) -> Dict:
245-
device = torch.device(config.general['device']) # Create a torch device
246-
batch_size_eval = config.general['batch_size_eval']
247-
pad_id = config.model['padding_idx']
248-
ignore_cpos = config.model['ignore_cpos']
249-
class_weights = config.train['class_weights']
249+
device = torch.device(config.general.device) # Create a torch device
250+
batch_size_eval = config.general.batch_size_eval
251+
pad_id = config.model.padding_idx
252+
ignore_cpos = config.model.ignore_cpos
253+
class_weights = config.train.class_weights
250254

251255
if class_weights is not None:
252256
class_weights = torch.FloatTensor(class_weights).to(device)
@@ -323,9 +327,17 @@ def full_annotation_df(self) -> pd.DataFrame:
323327
~anns_df['killed'] & ~anns_df['irrelevant']]
324328
meta_df = meta_df.reset_index(drop=True)
325329

326-
for meta_model in self.cat._meta_cats:
327-
logger.info(f'Checking metacat model: {meta_model}')
328-
meta_model_task = meta_model.name
330+
all_meta_cats = self.cat.get_addons_of_type(MetaCATAddon)
331+
332+
for meta_model_card in self.cat.get_model_card(as_dict=True)['MetaCAT models']:
333+
meta_model_task = meta_model_card['Category Name']
334+
logger.info(f'Checking metacat model: {meta_model_task}')
335+
_meta_models = [mc for mc in all_meta_cats
336+
if mc.config.general.category_name == meta_model_task]
337+
if not _meta_models:
338+
logger.warning(f'MetaCAT model {meta_model_task} not found in the CAT instance.')
339+
continue
340+
meta_model = _meta_models[0]
329341
meta_results = self._eval(meta_model, self.mct_export)
330342
meta_values = {v: k for k, v in meta_results['meta_values'].items()}
331343
pred_meta_values = []
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Generated by Django 2.2.28 on 2023-12-11 15:26
2+
3+
from django.db import migrations, models
4+
5+
6+
class Migration(migrations.Migration):
7+
8+
dependencies = [
9+
('api', '0073_auto_20231022_0028'),
10+
]
11+
12+
operations = [
13+
migrations.AlterField(
14+
model_name='projectmetrics',
15+
name='projects',
16+
field=models.ManyToManyField(blank=True, to='api.ProjectAnnotateEntities'),
17+
),
18+
]

0 commit comments

Comments
 (0)