diff --git a/medcat-trainer/notebook_docs/API_Examples.ipynb b/medcat-trainer/notebook_docs/API_Examples.ipynb index 22dcd3beb..97b722327 100644 --- a/medcat-trainer/notebook_docs/API_Examples.ipynb +++ b/medcat-trainer/notebook_docs/API_Examples.ipynb @@ -200,11 +200,11 @@ "for name, d_s in datasets:\n", " payload = {\n", " 'dataset_name': name, # Name that appears in each\n", - " 'dataset': d_s.loc[:, ['name', 'text']].to_dict(), # Dictionary representation of only \n", + " 'dataset': d_s.loc[:, ['name', 'text']].to_dict(), # Dictionary representation of only\n", " 'description': f'{name} first 20 notes from each category' # Description that appears in the trainer\n", " }\n", " resp = requests.post(f'{URL}/api/create-dataset/', json=payload, headers=headers)\n", - " dataset_ids.append(json.loads(resp.text)['dataset_id']) \n", + " dataset_ids.append(json.loads(resp.text)['dataset_id'])\n", "# New datasets created in the trainer have the following IDs\n", "dataset_ids" ] @@ -262,7 +262,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": { "tags": [] }, @@ -273,7 +273,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": { "tags": [] }, @@ -290,7 +290,7 @@ } ], "source": [ - "CDB.load('../../medcat-models/deid_medcat_n2c2_modelpack/cdb.dat')" + "cdb = CDB.load('../../medcat-models/deid_medcat_n2c2_modelpack/cdb.dat')" ] }, { @@ -301,8 +301,8 @@ }, "outputs": [], "source": [ - "txt = json.loads(requests.post(f'{URL}/api/concept-dbs/', headers=headers, \n", - " data={'name': 'example_cdb', 'use_for_training': True}, \n", + "txt = json.loads(requests.post(f'{URL}/api/concept-dbs/', headers=headers,\n", + " data={'name': 'example_cdb', 'use_for_training': True},\n", " files={'cdb_file': open('../../medcat-models/deid_medcat_n2c2_modelpack/cdb.dat', 'rb')}).text)" ] }, @@ -342,8 +342,8 @@ }, "outputs": [], "source": [ - "txt = json.loads(requests.put(f'{URL}/api/concept-dbs/21/', headers=headers, \n", - " data={'name': 'example_cdb-EDITED', 'use_for_training': True}, \n", + "txt = json.loads(requests.put(f'{URL}/api/concept-dbs/21/', headers=headers,\n", + " data={'name': 'example_cdb-EDITED', 'use_for_training': True},\n", " files={'cdb_file': open('../../medcat-models/deid_medcat_n2c2_modelpack/cdb.dat', 'rb')}).text)" ] }, @@ -379,8 +379,8 @@ } ], "source": [ - "requests.post(f'{URL}/api/concept-dbs/', headers=headers, \n", - " data={'name': 'example_cdb', 'use_for_training': True}, \n", + "requests.post(f'{URL}/api/concept-dbs/', headers=headers,\n", + " data={'name': 'example_cdb', 'use_for_training': True},\n", " files={'cdb_file': open('../../medcat-models/deid_medcat_n2c2_modelpack/cdb.dat', 'rb')}).text)" ] }, @@ -404,7 +404,7 @@ "metadata": {}, "outputs": [], "source": [ - "txt = json.loads(requests.post(f'{URL}/api/vocab/', headers=headers, \n", + "txt = json.loads(requests.post(f'{URL}/api/vocab/', headers=headers,\n", " files={'cdb_file': open('<>', 'rb')}).text)" ] }, @@ -465,7 +465,7 @@ "all_cdbs = json.loads(requests.get(f'{URL}/api/concept-dbs/', headers=headers).text)['results']\n", "# the CDB ID we'll use for this example\n", "cdb_to_use = all_cdbs[0]['id']\n", - "# you might have many CDBs here. First 2 cdbs: \n", + "# you might have many CDBs here. First 2 cdbs:\n", "all_cdbs[0:2]" ] }, @@ -521,12 +521,12 @@ "for d_id, p_name in zip(dataset_ids, project_names):\n", " payload = {\n", " 'name': f'{p_name} Annotation Project',\n", - " 'description': 'Example projects', \n", - " 'cuis': '', \n", + " 'description': 'Example projects',\n", + " 'cuis': '',\n", " 'tuis': '',\n", " 'dataset': d_id,\n", - " 'concept_db': cdb_to_use, \n", - " 'vocab': vocab_to_use, \n", + " 'concept_db': cdb_to_use,\n", + " 'vocab': vocab_to_use,\n", " 'members': users_ids\n", " }\n", " project_ids.append(json.loads(requests.post(f'{URL}/api/project-annotate-entities/', json=payload, headers=headers).text))" diff --git a/medcat-trainer/notebook_docs/Train_MedCAT_Models.ipynb b/medcat-trainer/notebook_docs/Train_MedCAT_Models.ipynb index c89e352a5..3b3e74cbd 100644 --- a/medcat-trainer/notebook_docs/Train_MedCAT_Models.ipynb +++ b/medcat-trainer/notebook_docs/Train_MedCAT_Models.ipynb @@ -25,7 +25,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -177,7 +177,7 @@ ], "source": [ "# install medcat\n", - "!pip install medcat\n", + "!pip install \"medcat[spacy,meta-cat,rel-cat,deid]>=2.0.0\"\n", "# scispacy medium models\n", "!pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.2.5/en_core_sci_md-0.2.5.tar.gz\n", "# ipywidgets\n", @@ -186,7 +186,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2020-09-08T11:27:34.270631Z", @@ -195,9 +195,11 @@ }, "outputs": [], "source": [ + "import json\n", + "\n", "from medcat.cat import CAT\n", "from medcat.cdb import CDB\n", - "from medcat.utils.vocab import Vocab" + "from medcat.vocab import Vocab" ] }, { @@ -310,7 +312,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2020-09-08T11:27:59.782731Z", @@ -319,16 +321,14 @@ }, "outputs": [], "source": [ - "cdb = CDB()\n", - "cdb.load_dict(cdb_path)\n", - "vocab = Vocab()\n", - "vocab.load_dict(vocab_path)\n", + "cdb = CDB.load(cdb_path)\n", + "vocab = Vocab.load(vocab_path)\n", "cat = CAT(cdb, vocab)" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2020-09-08T11:37:38.546552Z", @@ -382,7 +382,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "\r", + "\r\n", "Epoch: 0, Prec: 0.36538461538461536, Rec: 0.8444444444444444, F1: 0.6049145299145299\n", "\n", "Docs with false positives: Psych Text 1; Psych Text 2\n", @@ -1383,12 +1383,13 @@ } ], "source": [ - "cat.train_supervised(data_path=\"example_data/MedCAT_Export_With_Text_2020-05-22_10_34_09.json\", \n", - " nepochs=1,\n", - " lr=0.1,\n", - " anneal=False, # Unless we are reseting the CDB or cui_count this is False\n", - " print_stats=True, \n", - " use_filters=True)" + "with open(\"example_data/MedCAT_Export_With_Text_2020-05-22_10_34_09.json\") as f:\n", + " data = json.load(f)\n", + "cat.trainer.train_supervised_raw(\n", + " data=data,\n", + " nepochs=1,\n", + " print_stats=True,\n", + " use_filters=True)" ] }, { @@ -1402,7 +1403,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2020-09-08T15:04:02.394607Z", @@ -1411,14 +1412,14 @@ }, "outputs": [], "source": [ - "from medcat.meta_cat import MetaCAT\n", + "from medcat.components.addons.meta_cat import MetaCAT\n", "from tokenizers import ByteLevelBPETokenizer\n", "from itertools import chain" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2020-09-08T14:46:39.070589Z", @@ -1427,6 +1428,7 @@ }, "outputs": [], "source": [ + "import numpy as np\n", "# Tokenizer instantiation\n", "tokenizer = ByteLevelBPETokenizer(vocab_file='data/medmen-vocab.json', merges_file='data/medmen-merges.txt')\n", "embeddings = np.load(open('data/embeddings.npy', 'rb'))" @@ -1443,7 +1445,7 @@ }, "outputs": [], "source": [ - "metacat = MetaCAT(tokenizer=tokenizer, embeddings=embeddings, \n", + "metacat = MetaCAT(tokenizer=tokenizer, embeddings=embeddings,\n", " pad_id=len(embeddings) -1, save_dir='mc_status', device='cpu')" ] }, diff --git a/medcat-trainer/webapp/api/api/admin/actions.py b/medcat-trainer/webapp/api/api/admin/actions.py index 57b5a16b6..1ec12d16f 100644 --- a/medcat-trainer/webapp/api/api/admin/actions.py +++ b/medcat-trainer/webapp/api/api/admin/actions.py @@ -12,6 +12,9 @@ from api.models import AnnotatedEntity, MetaAnnotation, EntityRelation, Document, ConceptDB from api.solr_utils import drop_collection, import_all_concepts +from api.utils import clear_cdb_cnf_addons + +from medcat.cdb import CDB logger = logging.getLogger(__name__) @@ -356,20 +359,18 @@ def dataset_document_counts(dataset): @background(schedule=5) def _reset_cdb_filters(id): - from medcat.cdb import CDB concept_db = ConceptDB.objects.get(id=id) cdb = CDB.load(concept_db.cdb_file.path) - cdb.config.linking['filters'] = {'cuis': set()} + clear_cdb_cnf_addons(cdb, id) + cdb.config.components.linking.filters = {'cuis': set()} cdb.save(concept_db.cdb_file.path) @background(schedule=5) def import_concepts_from_cdb(cdb_model_id: int): - from medcat.cdb import CDB - cdb_model = ConceptDB.objects.get(id=cdb_model_id) cdb = CDB.load(cdb_model.cdb_file.path) - + clear_cdb_cnf_addons(cdb, cdb_model_id) import_all_concepts(cdb, cdb_model) diff --git a/medcat-trainer/webapp/api/api/metrics.py b/medcat-trainer/webapp/api/api/metrics.py index faaad4960..79d534724 100644 --- a/medcat-trainer/webapp/api/api/metrics.py +++ b/medcat-trainer/webapp/api/api/metrics.py @@ -13,18 +13,20 @@ from background_task.models import Task from django.contrib.auth.models import User from django.db.models import QuerySet +from medcat.stats.stats import get_stats from medcat.cat import CAT from medcat.cdb import CDB -from medcat.config_meta_cat import ConfigMetaCAT -from medcat.meta_cat import MetaCAT -from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBase -from medcat.utils.meta_cat.data_utils import prepare_from_json, encode_category_values -from medcat.utils.meta_cat.ml_utils import create_batch_piped_data +from medcat.config.config_meta_cat import ConfigMetaCAT +from medcat.components.addons.meta_cat.meta_cat import MetaCATAddon +from medcat.components.addons.meta_cat.mctokenizers.tokenizers import TokenizerWrapperBase +from medcat.components.addons.meta_cat.data_utils import prepare_from_json, encode_category_values +from medcat.components.addons.meta_cat.ml_utils import create_batch_piped_data from medcat.vocab import Vocab from torch import nn from api.admin import retrieve_project_data from api.models import AnnotatedEntity, ProjectAnnotateEntities, ProjectMetrics as AppProjectMetrics +from api.utils import clear_cdb_cnf_addons from core.settings import MEDIA_ROOT _dt_fmt = '%Y-%m-%d %H:%M:%S.%f' @@ -51,6 +53,7 @@ def calculate_metrics(project_ids: List[int], report_name: str): else: # assume the cdb / vocab is set in these projects cdb = CDB.load(projects[0].concept_db.cdb_file.path) + clear_cdb_cnf_addons(cdb, projects[0].concept_db.name) vocab = Vocab.load(projects[0].vocab.vocab_file.path) cat = CAT(cdb, vocab, config=cdb.config) project_data = retrieve_project_data(projects) @@ -116,7 +119,7 @@ def annotation_df(self): """ annotation_df = pd.DataFrame(self.annotations) if self.cat: - annotation_df.insert(5, 'concept_name', annotation_df['cui'].map(self.cat.cdb.cui2preferred_name)) + annotation_df.insert(5, 'concept_name', annotation_df['cui'].map(self.cat.cdb.get_name)) annotation_df['last_modified'] = pd.to_datetime(annotation_df['last_modified']).dt.tz_localize(None) return annotation_df @@ -138,9 +141,10 @@ def concept_summary(self, extra_cui_filter=None): concept_count_df['count_variations_ratio'] = round(concept_count_df['concept_count'] / concept_count_df['variations'], 3) if self.cat: - fps, fns, tps, cui_prec, cui_rec, cui_f1, cui_counts, examples = self.cat._print_stats(data=self.mct_export, - use_project_filters=True, - extra_cui_filter=extra_cui_filter) + fps, fns, tps, cui_prec, cui_rec, cui_f1, cui_counts, examples = get_stats(self.cat, + data=self.mct_export, + use_project_filters=True, + extra_cui_filter=extra_cui_filter) # remap tps, fns, fps to specific user annotations examples = self.enrich_medcat_metrics(examples) concept_count_df['fps'] = concept_count_df['cui'].map(fps) @@ -242,11 +246,11 @@ def rename_meta_anns(self, meta_anns2rename=dict(), meta_ann_values2rename=dict( return def _eval_model(self, model: nn.Module, data: List, config: ConfigMetaCAT, tokenizer: TokenizerWrapperBase) -> Dict: - device = torch.device(config.general['device']) # Create a torch device - batch_size_eval = config.general['batch_size_eval'] - pad_id = config.model['padding_idx'] - ignore_cpos = config.model['ignore_cpos'] - class_weights = config.train['class_weights'] + device = torch.device(config.general.device) # Create a torch device + batch_size_eval = config.general.batch_size_eval + pad_id = config.model.padding_idx + ignore_cpos = config.model.ignore_cpos + class_weights = config.train.class_weights if class_weights is not None: class_weights = torch.FloatTensor(class_weights).to(device) @@ -323,9 +327,17 @@ def full_annotation_df(self) -> pd.DataFrame: ~anns_df['killed'] & ~anns_df['irrelevant']] meta_df = meta_df.reset_index(drop=True) - for meta_model in self.cat._meta_cats: - logger.info(f'Checking metacat model: {meta_model}') - meta_model_task = meta_model.name + all_meta_cats = self.cat.get_addons_of_type(MetaCATAddon) + + for meta_model_card in self.cat.get_model_card(as_dict=True)['MetaCAT models']: + meta_model_task = meta_model_card['Category Name'] + logger.info(f'Checking metacat model: {meta_model_task}') + _meta_models = [mc for mc in all_meta_cats + if mc.config.general.category_name == meta_model_task] + if not _meta_models: + logger.warning(f'MetaCAT model {meta_model_task} not found in the CAT instance.') + continue + meta_model = _meta_models[0] meta_results = self._eval(meta_model, self.mct_export) meta_values = {v: k for k, v in meta_results['meta_values'].items()} pred_meta_values = [] diff --git a/medcat-trainer/webapp/api/api/migrations/0074_auto_20231211_1526.py b/medcat-trainer/webapp/api/api/migrations/0074_auto_20231211_1526.py new file mode 100644 index 000000000..e910b17bd --- /dev/null +++ b/medcat-trainer/webapp/api/api/migrations/0074_auto_20231211_1526.py @@ -0,0 +1,18 @@ +# Generated by Django 2.2.28 on 2023-12-11 15:26 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('api', '0073_auto_20231022_0028'), + ] + + operations = [ + migrations.AlterField( + model_name='projectmetrics', + name='projects', + field=models.ManyToManyField(blank=True, to='api.ProjectAnnotateEntities'), + ), + ] diff --git a/medcat-trainer/webapp/api/api/migrations/0090_merge_20250623_1330.py b/medcat-trainer/webapp/api/api/migrations/0090_merge_20250623_1330.py new file mode 100644 index 000000000..8c502e4a7 --- /dev/null +++ b/medcat-trainer/webapp/api/api/migrations/0090_merge_20250623_1330.py @@ -0,0 +1,14 @@ +# Generated by Django 5.1.11 on 2025-06-23 13:30 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('api', '0074_auto_20231211_1526'), + ('api', '0089_projectannotateentities_deid_model_annotation_and_more'), + ] + + operations = [ + ] diff --git a/medcat-trainer/webapp/api/api/model_cache.py b/medcat-trainer/webapp/api/api/model_cache.py index b79a4ce0e..e21048517 100644 --- a/medcat-trainer/webapp/api/api/model_cache.py +++ b/medcat-trainer/webapp/api/api/model_cache.py @@ -1,11 +1,15 @@ import logging import os -from typing import Dict +from typing import Dict, Optional, Any -import pkg_resources +from pydantic import ValidationError + +from medcat import __version__ as mct_version from medcat.cat import CAT +from medcat.config.config import Config, SerialisableBaseModel from medcat.cdb import CDB from medcat.vocab import Vocab +from medcat.utils.legacy.convert_cdb import get_cdb_from_old from api.models import ConceptDB @@ -53,18 +57,30 @@ def get_medcat_from_cdb_vocab(project, cdb_path = project.concept_db.cdb_file.path try: cdb = CDB.load(cdb_path) + except NotADirectoryError as e: + logger.warning("Legacy CDB found, converting to new format") + # TODO: deserialise and write back to the model path? + cdb = get_cdb_from_old(cdb_path) + cdb.save(cdb_path) + cdb_map[cdb_id] = cdb + cdb_path = project.concept_db.cdb_file.path + cdb_map[cdb_id] = cdb + except KeyError as ke: - mc_v = pkg_resources.get_distribution('medcat').version + mc_v = mct_version if int(mc_v.split('.')[0]) > 0: logger.error('Attempted to load MedCAT v0.x model with MCTrainer v1.x') raise Exception('Attempted to load MedCAT v0.x model with MCTrainer v1.x', 'Please re-configure this project to use a MedCAT v1.x CDB or consult the ' 'MedCATTrainer Dev team if you believe this should work') from ke raise + # NOTE: dynamic import to avoid circular imports + from api.utils import clear_cdb_cnf_addons + clear_cdb_cnf_addons(cdb, cdb_id) custom_config = os.getenv("MEDCAT_CONFIG_FILE") if custom_config is not None and os.path.exists(custom_config): - cdb.config.parse_config_file(path=custom_config) + _parse_config_file(cdb.config, custom_config) else: logger.info("No MEDCAT_CONFIG_FILE env var set to valid path, using default config available on CDB") cdb_map[cdb_id] = cdb @@ -81,6 +97,62 @@ def get_medcat_from_cdb_vocab(project, return cat +def _parse_config_file(config: Config, + custom_config_path: str): + # NOTE: the v2 mappings are a little different + mappings = { + "linking": "components.linking", + "ner": "components.ner", + } + mappings_key = { + "spacy_model": "nlp.modelname" + } + with open(custom_config_path) as f: + for line in f: + if not line.strip().startswith("cat"): + continue + line = line[4:] + left, right = line.split("=") + variable, key = left.split(".") + variable = variable.strip() + # map to v2 + variable = mappings.get(variable, variable) + key = key.strip() + # key can also differ + key = mappings_key.get(key, key) + value = eval(right) + alt_value = set() if right.strip() in ({}, "{}") else None + + # get (potentially nested in case of v2 mapping) attribute + cnf = config + while "." in variable: + current, variable = variable.split(".", 1) + cnf = getattr(cnf, current) + attr = getattr(cnf, variable) + while "." in key: + cur_key, key = key.split(".", 1) + attr = getattr(attr, cur_key) + if isinstance(attr, SerialisableBaseModel): + _set_value_or_alt(attr, key, value, alt_value) + elif isinstance(attr, dict): + attr[key] = value + else: + raise ValueError(f'Unknown attribute {attr} for "{line}"') + + +def _set_value_or_alt(conf: SerialisableBaseModel, key: str, value: Any, + alt_value: Any, err: Optional[ValidationError] = None) -> None: + try: + setattr(conf, key, value) # hoping for correct type + except ValidationError as ve: + if alt_value is not None: + _set_value_or_alt(conf, key, alt_value, None, err=ve) + elif err is not None: + raise err + else: + raise ve + + def get_medcat_from_model_pack(project, cat_map: Dict[str, CAT]=CAT_MAP) -> CAT: model_pack_obj = project.model_pack cat_id = 'mp' + str(model_pack_obj.id) @@ -101,8 +173,8 @@ def get_medcat(project, else: cat = get_medcat_from_model_pack(project, cat_map) return cat - except AttributeError: - raise Exception('Failure loading Project ConceptDB, Vocab or Model Pack. Are these set correctly?') + except AttributeError as err: + raise Exception('Failure loading Project ConceptDB, Vocab or Model Pack. Are these set correctly?') from err def get_cached_medcat(project, cat_map: Dict[str, CAT]=CAT_MAP): @@ -133,6 +205,7 @@ def get_cached_cdb(cdb_id: str, cdb_map: Dict[str, CDB]=CDB_MAP) -> CDB: if cdb_id not in cdb_map: cdb_obj = ConceptDB.objects.get(id=cdb_id) cdb = CDB.load(cdb_obj.cdb_file.path) + clear_cdb_cnf_addons(cdb, cdb_id) cdb_map[cdb_id] = cdb return cdb_map[cdb_id] diff --git a/medcat-trainer/webapp/api/api/models.py b/medcat-trainer/webapp/api/api/models.py index f29bc8b76..82705af11 100644 --- a/medcat-trainer/webapp/api/api/models.py +++ b/medcat-trainer/webapp/api/api/models.py @@ -13,7 +13,7 @@ from medcat.cat import CAT from medcat.cdb import CDB from medcat.vocab import Vocab -from medcat.meta_cat import MetaCAT +from medcat.components.addons.meta_cat.meta_cat import MetaCAT, MetaCATAddon from polymorphic.models import PolymorphicModel from core.settings import MEDIA_ROOT @@ -42,14 +42,14 @@ class ModelPack(models.Model): meta_cats = models.ManyToManyField('MetaCATModel', blank=True, default=None) create_time = models.DateTimeField(auto_now_add=True) last_modified = models.DateTimeField(auto_now=True) - last_modified_by = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE, default=None, null=True) + last_modified_by = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE, default=None, null=True) @transaction.atomic def save(self, *args, **kwargs): is_new = self._state.adding if is_new: super().save(*args, **kwargs) - + # Process the model pack logger.info('Loading model pack: %s', self.model_pack) model_pack_name = str(self.model_pack).replace(".zip", "") @@ -64,7 +64,12 @@ def save(self, *args, **kwargs): CAT.load_cdb(unpacked_model_pack_path) concept_db = ConceptDB() unpacked_file_name = self.model_pack.file.name.replace('.zip', '') - concept_db.cdb_file.name = os.path.join(unpacked_file_name, 'cdb.dat') + # cdb path for v2 + cdb_path = os.path.join(unpacked_file_name, 'cdb') + if not os.path.exists(cdb_path): + # cdb path for v1 + cdb_path = os.path.join(unpacked_file_name, 'cdb.dat') + concept_db.cdb_file.name = cdb_path concept_db.name = f'{self.name}_CDB' concept_db.save(skip_load=True) self.concept_db = concept_db @@ -72,6 +77,7 @@ def save(self, *args, **kwargs): raise FileNotFoundError(f'Error loading the CDB from this model pack: {self.model_pack.path}') from exc # Load Vocab + vocab_path = os.path.join(unpacked_model_pack_path, "vocab.dat") if os.path.exists(vocab_path): Vocab.load(vocab_path) @@ -88,7 +94,12 @@ def save(self, *args, **kwargs): try: metaCATmodels = [] # should raise an error if there already is a MetaCAT model with this definition - for meta_cat_dir, meta_cat in CAT.load_meta_cats(unpacked_model_pack_path): + addons = CAT.load_addons(unpacked_model_pack_path) + meta_cat_addons = [ + (addon_path, addon) for addon_path, addon in addons + if isinstance(addon, MetaCATAddon)] + for meta_cat_dir, meta_cat_addon in meta_cat_addons: + meta_cat = meta_cat_addon.mc mc_model = MetaCATModel() mc_model.meta_cat_dir = meta_cat_dir.replace(f'{MEDIA_ROOT}/', '') mc_model.name = f'{meta_cat.config.general.category_name} - {meta_cat.config.model.model_name}' @@ -98,7 +109,7 @@ def save(self, *args, **kwargs): self.meta_cats.set(metaCATmodels) # Use set() instead of add() for atomic operation except Exception as exc: raise MedCATLoadException(f'Failure loading MetaCAT models - {unpacked_model_pack_path}') from exc - + # Only save if this is an update (not a new instance) if not is_new: super().save(*args, **kwargs) diff --git a/medcat-trainer/webapp/api/api/solr_utils.py b/medcat-trainer/webapp/api/api/solr_utils.py index a5b85a831..58e8c04a6 100644 --- a/medcat-trainer/webapp/api/api/solr_utils.py +++ b/medcat-trainer/webapp/api/api/solr_utils.py @@ -6,6 +6,7 @@ import requests from django.http import HttpResponseServerError from medcat.cdb import CDB +from medcat.cdb.concepts import CUIInfo from rest_framework.response import Response from api.models import ConceptDB @@ -129,14 +130,14 @@ def import_all_concepts(cdb: CDB, cdb_model: ConceptDB): if resp.status_code != 200: _solr_error_response(resp, 'Failure creating collection') - cui2name_iter = iter(cdb.cui2names.items()) + cui2info_iter = iter(cdb.cui2info.items()) payload = [] try: while True: for i in range(5000): - cui, name = next(cui2name_iter) - concept_dct = _concept_dct(cui, cdb) + cui, info = next(cui2info_iter) + concept_dct = _concept_dct(cui, cdb, info) payload.append(concept_dct) _upload_payload(f'{base_url}/{collection_name}/update', payload, collection_name) payload = [] @@ -176,7 +177,7 @@ def ensure_concept_searchable(cui, cdb: CDB, cdb_model: ConceptDB): resp = requests.get(url) if resp.status_code == 200: collections = json.loads(resp.text)['collections'] - data = [_concept_dct(cui, cdb)] + data = [_concept_dct(cui, cdb, cdb.cui2info[cui])] if collection in collections: _upload_payload(f'{base_url}/{collection}/update', data, collection, commit=True) @@ -191,14 +192,14 @@ def _upload_payload(update_url, data, collection, commit=False): _solr_error_response(resp, f'error updating {collection}') -def _concept_dct(cui: str, cdb: CDB): - synonyms = list(cdb.addl_info.get('cui2original_names', {}).get(cui, set())) +def _concept_dct(cui: str, cdb: CDB, info: CUIInfo): + synonyms = list(info['original_names'] or []) concept_dct = { 'cui': str(cui), 'pretty_name': cdb.get_name(cui), 'name': re.sub(r'\([\w+\s]+\)', '', cdb.get_name(cui)).strip(), - 'type_ids': list(cdb.cui2type_ids[cui]), - 'desc': cdb.addl_info.get('cui2description', {}).get(cui, ''), + 'type_ids': list(info['type_ids']), + 'desc': info['description'], 'synonyms': synonyms if len(synonyms) > 0 else [cdb.get_name(cui)] } return concept_dct diff --git a/medcat-trainer/webapp/api/api/utils.py b/medcat-trainer/webapp/api/api/utils.py index 8d4d64ad0..483370d48 100644 --- a/medcat-trainer/webapp/api/api/utils.py +++ b/medcat-trainer/webapp/api/api/utils.py @@ -9,9 +9,9 @@ from django.db.models.signals import post_save from django.dispatch import receiver from medcat.cat import CAT -from medcat.utils.filters import check_filters -from medcat.utils.helpers import tkns_from_doc -from medcat.utils.ner.deid import DeIdModel +from medcat.cdb import CDB +from medcat.components.ner.trf.deid import DeIdModel +from medcat.tokenizing.tokens import UnregisteredDataPathException from .model_cache import get_medcat from .models import Entity, AnnotatedEntity, ProjectAnnotateEntities, \ @@ -37,7 +37,7 @@ def remove_annotations(document, project, partial=False): def add_annotations(spacy_doc, user, project, document, existing_annotations, cat): - spacy_doc._.ents.sort(key=lambda x: len(x.text), reverse=True) + spacy_doc.linked_ents.sort(key=lambda x: len(x.text), reverse=True) tkns_in = [] ents = [] @@ -46,10 +46,10 @@ def add_annotations(spacy_doc, user, project, document, existing_annotations, ca # that can be produced are expected to have available models try: metatask2obj = {task_name: MetaTask.objects.get(name=task_name) - for task_name in spacy_doc._.ents[0]._.meta_anns.keys()} + for task_name in spacy_doc.linked_ents[0].get_addon_data('meta_cat_meta_anns').keys()} metataskvals2obj = {task_name: {v.name: v for v in MetaTask.objects.get(name=task_name).values.all()} - for task_name in spacy_doc._.ents[0]._.meta_anns.keys()} - except (AttributeError, IndexError): + for task_name in spacy_doc.linked_ents[0].get_addon_data('meta_cat_meta_anns').keys()} + except (AttributeError, IndexError, UnregisteredDataPathException): # IndexError: ignore if there are no annotations in this doc # AttributeError: ignore meta_anns that are not present - i.e. non model pack preds # or model pack preds with no meta_anns @@ -58,11 +58,17 @@ def add_annotations(spacy_doc, user, project, document, existing_annotations, ca pass def check_ents(ent): - return any((ea[0] < ent.start_char < ea[1]) or - (ea[0] < ent.end_char < ea[1]) for ea in existing_annos_intervals) + return any((ea[0] < ent.start_char_index < ea[1]) or + (ea[0] < ent.end_char_index < ea[1]) for ea in existing_annos_intervals) - for ent in spacy_doc._.ents: - if not check_ents(ent) and check_filters(ent._.cui, cat.config.linking['filters']): + def check_filters(cui, filters): + if cui in filters.cuis or not filters.cuis: + return cui not in filters.cuis_exclude + else: + return False + + for ent in spacy_doc.linked_ents: + if not check_ents(ent) and check_filters(ent.cui, cat.config.components.linking.filters): to_add = True for tkn in ent: if tkn in tkns_in: @@ -75,7 +81,7 @@ def check_ents(ent): logger.debug('Found %s annotations to store', len(ents)) for ent in ents: logger.debug('Processing annotation ent %s of %s', ents.index(ent), len(ents)) - label = ent._.cui + label = ent.cui if not Entity.objects.filter(label=label).exists(): # Create the entity @@ -87,8 +93,8 @@ def check_ents(ent): ann_ent = AnnotatedEntity.objects.filter(project=project, document=document, - start_ind=ent.start_char, - end_ind=ent.end_char).first() + start_ind=ent.start_char_index, + end_ind=ent.end_char_index).first() if ann_ent is None: # If this entity doesn't exist already ann_ent = AnnotatedEntity() @@ -97,30 +103,39 @@ def check_ents(ent): ann_ent.document = document ann_ent.entity = entity ann_ent.value = ent.text - ann_ent.start_ind = ent.start_char - ann_ent.end_ind = ent.end_char - ann_ent.acc = ent._.context_similarity + ann_ent.start_ind = ent.start_char_index + ann_ent.end_ind = ent.end_char_index + ann_ent.acc = ent.context_similarity - MIN_ACC = cat.config.linking.get('similarity_threshold_trainer', 0.2) - if ent._.context_similarity < MIN_ACC: + MIN_ACC = cat.config.components.linking.similarity_threshold + if ent.context_similarity < MIN_ACC: ann_ent.deleted = True ann_ent.validated = True ann_ent.save() - # check the ent._.meta_anns if it exists - if hasattr(ent._, 'meta_anns') and len(metatask2obj) > 0 and len(metataskvals2obj) > 0: - logger.debug('Found %s meta annos on ent', len(ent._.meta_anns.items())) - for meta_ann_task, pred in ent._.meta_anns.items(): - meta_anno_obj = MetaAnnotation() - meta_anno_obj.predicted_meta_task_value = metataskvals2obj[meta_ann_task][pred['value']] - meta_anno_obj.meta_task = metatask2obj[meta_ann_task] - meta_anno_obj.annotated_entity = ann_ent - meta_anno_obj.meta_task_value = metataskvals2obj[meta_ann_task][pred['value']] - meta_anno_obj.acc = pred['confidence'] - meta_anno_obj.save() - logger.debug('Successfully saved %s', meta_anno_obj) - + # TODO: Fix before v2 release. + # check the ent.get_addon_data('meta_cat_meta_anns') if it exists + # if hasattr(ent, 'get_addon_data') and \ + # len(metatask2obj) > 0 and + # len(metataskvals2obj) > 0: + # logger.debug('Found %s meta annos on ent', len(ent._.meta_anns.items())) + # for meta_ann_task, pred in ent._.meta_anns.items(): + # meta_anno_obj = MetaAnnotation() + # meta_anno_obj.predicted_meta_task_value = metataskvals2obj[meta_ann_task][pred['value']] + # meta_anno_obj.meta_task = metatask2obj[meta_ann_task] + # meta_anno_obj.annotated_entity = ann_ent + # meta_anno_obj.meta_task_value = metataskvals2obj[meta_ann_task][pred['value']] + # meta_anno_obj.acc = pred['confidence'] + # meta_anno_obj.save() + # logger.debug('Successfully saved %s', meta_anno_obj) + + +def clear_cdb_cnf_addons(cdb: CDB, cdb_id: str | int): + # NOTE: when loading a CDB separately, we don't necessarily want to + # load / create addons like MetaCAT as well + logger.info('Clearing addons for CDB upon load: %s', cdb_id) + cdb.config.components.addons.clear() def get_create_cdb_infos(cdb, concept, cui, cui_info_prop, code_prop, desc_prop, model_clazz): @@ -206,35 +221,37 @@ def train_medcat(cat, project, document): for ann in anns: cui = ann.entity.label # Indices for this annotation - spacy_entity = tkns_from_doc(spacy_doc=spacy_doc, start=ann.start_ind, end=ann.end_ind) + spacy_entity = [tkn for tkn in spacy_doc if tkn.char_index == ann.start_ind] # This will add the concept if it doesn't exist and if it - #does just link the new name to the concept, if the namee is - #already linked then it will just train. + # does just link the new name to the concept, if the namee is + # already linked then it will just train. manually_created = False if ann.manually_created or ann.alternative: manually_created = True - cat.add_and_train_concept(cui=cui, - name=ann.value, - spacy_doc=spacy_doc, - spacy_entity=spacy_entity, - negative=ann.deleted, - devalue_others=manually_created) + cat.trainer.add_and_train_concept( + cui=cui, + name=ann.value, + mut_doc=spacy_doc, + mut_entity=spacy_entity, + negative=ann.deleted, + devalue_others=manually_created + ) # Completely remove concept names that the user killed killed_anns = AnnotatedEntity.objects.filter(project=project, document=document, killed=True) for ann in killed_anns: cui = ann.entity.label name = ann.value - cat.unlink_concept_name(cui=cui, name=name) + cat.trainer.unlink_concept_name(cui=cui, name=name) # Add irrelevant cuis to cui_exclude irrelevant_anns = AnnotatedEntity.objects.filter(project=project, document=document, irrelevant=True) for ann in irrelevant_anns: cui = ann.entity.label - if 'cuis_exclude' not in cat.config.linking['filters']: - cat.config.linking['filters']['cuis_exclude'] = set() - cat.config.linking['filters'].get('cuis_exclude').update([cui]) + if 'cuis_exclude' not in cat.config.components.linking.filters: + cat.config.components.linking.filters['cuis_exclude'] = set() + cat.config.components.linking.filters.get('cuis_exclude').update([cui]) @background(schedule=1, queue='doc_prep') @@ -247,7 +264,7 @@ def prep_docs(project_id: List[int], doc_ids: List[int], user_id: int): cat = get_medcat(project=project) # Set CAT filters - cat.config.linking['filters']['cuis'] = project.cuis + cat.config.components.linking.filters.cuis = project.cuis for doc in docs: logger.info(f'Running MedCAT model for project {project.id}:{project.name} over doc: {doc.id}') diff --git a/medcat-trainer/webapp/api/api/views.py b/medcat-trainer/webapp/api/api/views.py index b53fb1c0f..813d12cbd 100644 --- a/medcat-trainer/webapp/api/api/views.py +++ b/medcat-trainer/webapp/api/api/views.py @@ -10,11 +10,10 @@ from django.shortcuts import render from django.utils import timezone from django_filters import rest_framework as drf -from medcat.utils.helpers import tkns_from_doc from rest_framework import viewsets from rest_framework.decorators import api_view from rest_framework.response import Response -from medcat.utils.ner.deid import DeIdModel +from medcat.components.ner.trf.deid import DeIdModel from .admin import download_projects_with_text, download_projects_without_text, \ import_concepts_from_cdb @@ -282,7 +281,7 @@ def prepare_documents(request): logger.info('loaded medcat model for project: %s', project.id) # Set CAT filters - cat.config.linking['filters']['cuis'] = cuis + cat.config.components.linking.filters.cuis = cuis if not project.deid_model_annotation: spacy_doc = cat(document.text) @@ -304,6 +303,7 @@ def prepare_documents(request): project.save() except Exception as e: + logger.warning('Error preparing documents for project %s', p_id, exc_info=e) stack = traceback.format_exc() return Response({'message': e.args[0] if len(e.args) > 0 else 'Internal Server Error', 'description': e.args[1] if len(e.args) > 1 else '', @@ -424,9 +424,9 @@ def add_concept(request): if source_val in spacy_doc.text: start = spacy_doc.text.index(source_val) end = start + len(source_val) - spacy_entity = tkns_from_doc(spacy_doc=spacy_doc, start=start, end=end) + spacy_entity = [tkn for tkn in spacy_doc if tkn.idx >= start and tkn.idx <= end] - cat.add_and_train_concept(cui=cui, name=name, name_status='P', spacy_doc=spacy_doc, spacy_entity=spacy_entity) + cat.trainer.add_and_train_concept(cui=cui, name=name, name_status='P', mut_doc=spacy_doc, mut_entity=spacy_entity) id = create_annotation(source_val=source_val, selection_occurrence_index=sel_occur_idx, @@ -461,16 +461,8 @@ def import_cdb_concepts(request): def _submit_document(project: ProjectAnnotateEntities, document: Document): if project.train_model_on_submit: - try: - cat = get_medcat(project=project) - train_medcat(cat, project, document) - except Exception as e: - if project.vocab.id: - if len(VOCAB_MAP[project.vocab.id].unigram_table) == 0: - return Exception('Vocab is missing the unigram table. On the vocab instance ' - 'use vocab.make_unigram_table() to build') - else: - raise e + cat = get_medcat(project=project) + train_medcat(cat, project, document) # Add cuis to filter if they did not exist cuis = [] @@ -614,23 +606,23 @@ def annotate_text(request): project = ProjectAnnotateEntities.objects.get(id=p_id) cat = get_medcat(project=project) - cat.config.linking['filters']['cuis'] = set(cuis) + cat.config.components.linking.filters.cuis = set(cuis) spacy_doc = cat(message) ents = [] anno_tkns = [] - for ent in spacy_doc._.ents: - cnt = Entity.objects.filter(label=ent._.cui).count() + for ent in spacy_doc.linked_ents: + cnt = Entity.objects.filter(label=ent.cui).count() inc_ent = all(tkn not in anno_tkns for tkn in ent) if inc_ent and cnt != 0: anno_tkns.extend([tkn for tkn in ent]) - entity = Entity.objects.get(label=ent._.cui) + entity = Entity.objects.get(label=ent.cui) ents.append({ 'entity': entity.id, - 'value': ent.text, - 'start_ind': ent.start_char, - 'end_ind': ent.end_char, - 'acc': ent._.context_similarity + 'value': ent.base.text, + 'start_ind': ent.base.start_char_index, + 'end_ind': ent.base.end_char_index, + 'acc': ent.context_similarity }) ents.sort(key=lambda e: e['start_ind']) diff --git a/medcat-trainer/webapp/requirements.txt b/medcat-trainer/webapp/requirements.txt index 095bb55ba..55d6611d0 100644 --- a/medcat-trainer/webapp/requirements.txt +++ b/medcat-trainer/webapp/requirements.txt @@ -6,6 +6,6 @@ django-polymorphic==3.0.* djangorestframework==3.15.* django-background-tasks-updated==1.2.* openpyxl==3.1.2 -medcat==1.16.* +medcat[meta-cat,spacy,rel-cat,deid]==2.1.* psycopg[binary,pool]==3.2.9 django-health-check==3.20.0 \ No newline at end of file