Skip to content

Commit 5384062

Browse files
committed
CU-8699np02n Improve legacy conversion (#29)
* CU-8699np02n: Update CDB legacy conversion so that it works with CDBs with no name_isupper attribute * CU-8699np02n: Add method to legacy converter to convert any config * CU-8699np02n: Fix config legacy conversion * CU-8699np02n: Add a few simple tests for Config legacy conversion * CU-8699np02n: Add a little more sophistication to general config conversion tests
1 parent 65008f0 commit 5384062

File tree

6 files changed

+98
-3
lines changed

6 files changed

+98
-3
lines changed

medcat-v2/medcat/config/config.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
from typing import (Optional, Iterator, Iterable, TypeVar, cast, Type, Any,
23
Literal)
34
from typing import Protocol, runtime_checkable
@@ -12,6 +13,9 @@
1213
from medcat.utils.defaults import workers
1314
from medcat.utils.envsnapshot import Environment, get_environment_info
1415
from medcat.utils.iterutils import callback_iterator
16+
from medcat.utils.defaults import (
17+
avoid_legacy_conversion, doing_legacy_conversion_message,
18+
LegacyConversionDisabledError)
1519
from medcat.storage.serialisables import SerialisingStrategy
1620
from medcat.storage.serialisers import deserialise
1721

@@ -80,6 +84,13 @@ def merge_config(self, other: dict):
8084

8185
@classmethod
8286
def load(cls, path: str) -> Self:
87+
if os.path.isfile(path) and path.endswith(".dat"):
88+
if avoid_legacy_conversion():
89+
raise LegacyConversionDisabledError(cls.__name__)
90+
doing_legacy_conversion_message(logger, cls.__name__, path)
91+
from medcat.utils.legacy.convert_config import (
92+
get_config_from_old_per_cls)
93+
return cast(Self, get_config_from_old_per_cls(path, cls))
8394
obj = deserialise(path)
8495
if not isinstance(obj, cls):
8596
raise ValueError(f"The path '{path}' is not a {cls.__name__}!")

medcat-v2/medcat/utils/legacy/convert_config.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import json
2-
from typing import Any, cast, Optional
2+
from typing import Any, cast, Optional, Type
33
import logging
44

55
from pydantic import BaseModel
66

77
from medcat.config import Config
88

99
from medcat.utils.legacy.helpers import fix_old_style_cnf
10+
from medcat.config.config import SerialisableBaseModel
1011

1112

1213
logger = logging.getLogger(__name__)
@@ -185,3 +186,34 @@ def get_config_from_old(path: str) -> Config:
185186
with open(path) as f:
186187
old_cnf_data = json.load(f)
187188
return get_config_from_nested_dict(old_cnf_data)
189+
190+
191+
def get_config_from_old_per_cls(
192+
path: str, cls: Type[SerialisableBaseModel]) -> SerialisableBaseModel:
193+
"""Convert the saved v1 config into a v2 Config for a specific class.
194+
195+
Args:
196+
path (str): The v1 config path.
197+
cls (Type[SerialisableBaseModel]): The class to convert to.
198+
199+
Returns:
200+
SerialisableBaseModel: The converted config.
201+
"""
202+
from medcat.config.config_meta_cat import ConfigMetaCAT
203+
from medcat.config.config_transformers_ner import ConfigTransformersNER
204+
from medcat.config.config_rel_cat import ConfigRelCAT
205+
if cls is Config:
206+
return get_config_from_old(path)
207+
elif cls is ConfigMetaCAT:
208+
from medcat.utils.legacy.convert_meta_cat import (
209+
load_cnf as load_meta_cat_cnf)
210+
return load_meta_cat_cnf(path)
211+
elif cls is ConfigTransformersNER:
212+
from medcat.utils.legacy.convert_deid import (
213+
get_cnf as load_deid_cnf)
214+
return load_deid_cnf(path)
215+
elif cls is ConfigRelCAT:
216+
from medcat.utils.legacy.convert_rel_cat import (
217+
load_cnf as load_rel_cat_cnf)
218+
return load_rel_cat_cnf(path)
219+
raise ValueError(f"The config at '{path}' is not a {cls.__name__}!")
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"general": {"name": "NOT-DEID", "model_name": "roberta-base", "seed": 13, "description": "No description", "pipe_batch_size_in_chars": 20000000, "ner_aggregation_strategy": "simple", "chunking_overlap_window": 5, "test_size": 0.2, "last_train_on": null, "verbose_metrics": false}}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"general": {"device": "cpu", "disable_component_lock": false, "seed": 13, "description": "No description", "category_name": "TEST CATEGORY", "alternative_category_names": [], "category_value2id": {}, "alternative_class_names": [[]], "vocab_size": 3, "lowercase": true, "cntx_left": 15, "cntx_right": 10, "replace_center": null, "batch_size_eval": 5000, "annotate_overlapping": false, "tokenizer_name": "bbpe", "save_and_reuse_tokens": false, "pipe_batch_size_in_chars": 20000000, "span_group": null}, "model": {"model_name": "lstm", "model_variant": "bert-base-uncased", "model_freeze_layers": true, "num_layers": 2, "input_size": 300, "hidden_size": 300, "dropout": 0.5, "phase_number": 0, "category_undersample": "", "model_architecture_config": {"fc2": true, "fc3": false, "lr_scheduler": true}, "num_directions": 2, "nclasses": 2, "padding_idx": -1, "emb_grad": true, "ignore_cpos": false}, "train": {"batch_size": 100, "nepochs": 50, "lr": 0.001, "test_size": 0.1, "shuffle_data": true, "class_weights": null, "compute_class_weights": false, "score_average": "weighted", "prerequisites": {}, "cui_filter": null, "auto_save_model": true, "last_train_on": null, "metric": {"base": "weighted avg", "score": "f1-score"}, "loss_funct": "cross_entropy", "gamma": 2}}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"general": {"device": "cpu", "relation_type_filter_pairs": [], "vocab_size": null, "lowercase": true, "cntx_left": 15, "cntx_right": 15, "window_size": 300, "limit_samples_per_class": -1, "addl_rels_max_sample_size": 200, "create_addl_rels": false, "create_addl_rels_by_type": false, "tokenizer_name": "bert", "model_name": "bert-unknown", "log_level": 20, "max_seq_length": 512, "tokenizer_special_tokens": false, "annotation_schema_tag_ids": [30522, 30523, 30524, 30525], "tokenizer_relation_annotation_special_tokens_tags": ["[s1]", "[e1]", "[s2]", "[e2]"], "tokenizer_other_special_tokens": {"pad_token": "[PAD]"}, "labels2idx": {}, "idx2labels": {}, "pin_memory": true, "seed": 13, "task": "train", "language": "en"}, "model": {"input_size": 300, "hidden_size": 768, "hidden_layers": 3, "model_size": 5120, "dropout": 0.2, "num_directions": 2, "freeze_layers": true, "padding_idx": -1, "emb_grad": true, "ignore_cpos": false, "llama_use_pooled_output": false}, "train": {"nclasses": 2, "batch_size": 25, "nepochs": 1, "lr": 0.0001, "stratified_batching": false, "batching_samples_per_class": [], "batching_minority_limit": 0, "adam_betas": [0.9, 0.999], "adam_weight_decay": 0, "adam_epsilon": 1e-08, "test_size": 0.2, "gradient_acc_steps": 1, "multistep_milestones": [2, 4, 6, 8, 12, 15, 18, 20, 22, 24, 26, 30], "multistep_lr_gamma": 0.8, "max_grad_norm": 1.0, "shuffle_data": true, "class_weights": null, "enable_class_weights": false, "score_average": "weighted", "auto_save_model": true}}

medcat-v2/tests/utils/legacy/test_convert_config.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
1+
from typing import Type, Any
12
import os
23

34
from medcat.utils.legacy import convert_config
45

56
from medcat.config import Config
7+
from medcat.config.config import SerialisableBaseModel
8+
from medcat.config.config_meta_cat import ConfigMetaCAT
9+
from medcat.config.config_rel_cat import ConfigRelCAT
10+
from medcat.config.config_transformers_ner import ConfigTransformersNER
611

712
import unittest
813

914

10-
TESTS_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__),
11-
"..", ".."))
15+
from ... import RESOURCES_PATH
16+
TESTS_PATH = os.path.dirname(RESOURCES_PATH)
1217

1318

1419
class ValAndModelGetterTests(unittest.TestCase):
@@ -78,3 +83,47 @@ def test_migrates_partial(self):
7883
def test_preprocesses_sets(self):
7984
self.assertEqual(self.cnf.preprocessing.words_to_skip,
8085
self.EXP_WORDS_TO_SKIP)
86+
87+
88+
class PerClsConfigConversionTests(unittest.TestCase):
89+
# paths, classes, expected path, expected value
90+
# NOTE: These are hard-coded values I know I changed in the confgis
91+
# before saving
92+
PATHS_AND_CLASSES: list[str, Type[SerialisableBaseModel], str, Any] = [
93+
(os.path.join(RESOURCES_PATH,
94+
"mct_v1_cnf.json"), Config,
95+
'meta.description', "FAKE MODEL"),
96+
(os.path.join(RESOURCES_PATH,
97+
"mct_v1_meta_cat_cnf.json"), ConfigMetaCAT,
98+
"general.category_name", 'TEST CATEGORY'),
99+
(os.path.join(RESOURCES_PATH,
100+
"mct_v1_rel_cat_cnf.json"), ConfigRelCAT,
101+
"general.model_name", 'bert-unknown'),
102+
(os.path.join(RESOURCES_PATH,
103+
"mct_v1_deid_cnf.json"), ConfigTransformersNER,
104+
"general.name", 'NOT-DEID'),
105+
]
106+
107+
@classmethod
108+
def setUpClass(cls):
109+
return super().setUpClass()
110+
111+
def _get_attr_nested(self, obj: SerialisableBaseModel, path: str) -> Any:
112+
"""Get an attribute from a nested object using a dot-separated path."""
113+
parts = path.split('.')
114+
for part in parts:
115+
obj = getattr(obj, part)
116+
return obj
117+
118+
def assert_can_convert(
119+
self, path, cls: Type[SerialisableBaseModel],
120+
exp_path: str, exp_value: Any):
121+
cnf = convert_config.get_config_from_old_per_cls(path, cls)
122+
self.assertIsInstance(cnf, cls, f"Failed for {cls.__name__}")
123+
self.assertEqual(self._get_attr_nested(cnf, exp_path), exp_value,
124+
f"Failed for {cls.__name__} at {exp_path}")
125+
126+
def test_can_convert(self):
127+
for path, cls, exp_path, exp_value in self.PATHS_AND_CLASSES:
128+
with self.subTest(f"Testing {cls.__name__} at {path}"):
129+
self.assert_can_convert(path, cls, exp_path, exp_value)

0 commit comments

Comments
 (0)