Skip to content

Commit 0121ea1

Browse files
committed
CU-8699vkmu4: Allow load with merging config(s) (#53)
* CU-8699vkmu4: Add option to load MetaCATs with config * CU-8699vkmu4: Add small test to make sure MetaCAT config is merged correctly * CU-8699vkmu4: Allow model config when loading model pack * CU-8699vkmu4: Add small tests for CAT model config addition * CU-8699vkmu4: Allow merging in arbitrary addon configs upon model load if/when specified * CU-8699vkmu4: Add small test for addon config merge upon CAT load * CU-8699vkmu4: Make a method to a class method * CU-8699vkmu4: Limit indentation * CU-8699vkmu4: Generalis load_addons method to for all addons
1 parent 827fb3d commit 0121ea1

File tree

4 files changed

+113
-17
lines changed

4 files changed

+113
-17
lines changed

medcat-v2/medcat/cat.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ def __init__(self,
4949
vocab: Union[Vocab, None] = None,
5050
config: Optional[Config] = None,
5151
model_load_path: Optional[str] = None,
52+
config_dict: Optional[dict] = None,
53+
addon_config_dict: Optional[dict[str, dict]] = None,
5254
) -> None:
5355
self.cdb = cdb
5456
self.vocab = vocab
@@ -60,20 +62,24 @@ def __init__(self,
6062
elif config is not None:
6163
self.cdb.config = config
6264
self.config = config
65+
if config_dict:
66+
self.config.merge_config(config_dict)
6367

6468
self._trainer: Optional[Trainer] = None
65-
self._pipeline = self._recreate_pipe(model_load_path)
69+
self._pipeline = self._recreate_pipe(model_load_path, addon_config_dict)
6670
self.usage_monitor = UsageMonitor(
6771
self._get_hash, self.config.general.usage_monitor)
6872

69-
def _recreate_pipe(self, model_load_path: Optional[str] = None
73+
def _recreate_pipe(self, model_load_path: Optional[str] = None,
74+
addon_config_dict: Optional[dict[str, dict]] = None,
7075
) -> Pipeline:
7176
if hasattr(self, "_pipeline"):
7277
old_pipe = self._pipeline
7378
else:
7479
old_pipe = None
7580
self._pipeline = Pipeline(self.cdb, self.vocab, model_load_path,
76-
old_pipe=old_pipe)
81+
old_pipe=old_pipe,
82+
addon_config_dict=addon_config_dict)
7783
return self._pipeline
7884

7985
@classmethod
@@ -668,11 +674,21 @@ def attempt_unpack(cls, zip_path: str) -> str:
668674
return model_pack_path
669675

670676
@classmethod
671-
def load_model_pack(cls, model_pack_path: str) -> 'CAT':
677+
def load_model_pack(cls, model_pack_path: str,
678+
config_dict: Optional[dict] = None,
679+
addon_config_dict: Optional[dict[str, dict]] = None
680+
) -> 'CAT':
672681
"""Load the model pack from file.
673682
674683
Args:
675684
model_pack_path (str): The model pack path.
685+
config_dict (Optional[dict]): The model config to
686+
merge in before initialising the pipe. Defaults to None.
687+
addon_config_dict (Optional[dict]): The Addon-specific
688+
config dict to merge in before pipe initialisation.
689+
If specified, it needs to have an addon dict per name.
690+
For instance, `{"meta_cat.Subject": {}}` would apply
691+
to the specific MetaCAT.
676692
677693
Raises:
678694
ValueError: If the saved data does not represent a model pack.
@@ -703,7 +719,9 @@ def load_model_pack(cls, model_pack_path: str) -> 'CAT':
703719
TOKENIZER_PREFIX,
704720
# components will be loaded semi-manually
705721
# within the creation of pipe
706-
COMPONENTS_FOLDER})
722+
COMPONENTS_FOLDER},
723+
config_dict=config_dict,
724+
addon_config_dict=addon_config_dict)
707725
# NOTE: deserialising of components that need serialised
708726
# will be dealt with upon pipeline creation automatically
709727
if not isinstance(cat, CAT):
@@ -730,33 +748,40 @@ def load_cdb(cls, model_pack_path: str) -> CDB:
730748

731749
@classmethod
732750
def load_addons(
733-
cls, model_pack_path: str, meta_cat_config_dict: Optional[dict] = None
751+
cls, model_pack_path: str,
752+
addon_config_dict: Optional[dict[str, dict]] = None
734753
) -> list[tuple[str, AddonComponent]]:
735754
"""Load addons based on a model pack path.
736755
737756
Args:
738757
model_pack_path (str): path to model pack, zip or dir.
739-
meta_cat_config_dict (Optional[dict]):
740-
A config dict that will overwrite existing configs in meta_cat.
741-
e.g. meta_cat_config_dict = {'general': {'device': 'cpu'}}.
742-
Defaults to None.
758+
addon_config_dict (Optional[dict]): The Addon-specific
759+
config dict to merge in before pipe initialisation.
760+
If specified, it needs to have an addon dict per name.
761+
For instance,
762+
`{"meta_cat.Subject": {'general': {'device': 'cpu'}}}`
763+
would apply to the specific MetaCAT.
743764
744765
Returns:
745766
List[tuple(str, AddonComponent)]: list of pairs of adddon names the addons.
746767
"""
747768
components_folder = os.path.join(model_pack_path, COMPONENTS_FOLDER)
748769
if not os.path.exists(components_folder):
749770
return []
750-
addon_paths = [
751-
folder_path
771+
addon_paths_and_names = [
772+
(folder_path, folder_name.removeprefix(AddonComponent.NAME_PREFIX))
752773
for folder_name in os.listdir(components_folder)
753774
if os.path.isdir(folder_path := os.path.join(
754775
components_folder, folder_name))
755776
and folder_name.startswith(AddonComponent.NAME_PREFIX)
756777
]
757778
loaded_addons = [
758-
addon for addon_path in addon_paths
759-
if isinstance(addon := deserialise(addon_path), AddonComponent)
779+
addon for addon_path, addon_name in addon_paths_and_names
780+
if isinstance(addon := (
781+
deserialise(addon_path, model_config=addon_config_dict.get(addon_name))
782+
if addon_config_dict else
783+
deserialise(addon_path)
784+
), AddonComponent)
760785
]
761786
return [(addon.full_name, addon) for addon in loaded_addons]
762787

medcat-v2/medcat/components/addons/meta_cat/meta_cat.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,8 @@ def deserialise_from(cls, folder_path: str, **init_kwargs
221221
"Inferring config from file at '%s'", folder_path,
222222
config_path)
223223
cnf = ConfigMetaCAT.load(config_path)
224+
if 'model_config' in init_kwargs:
225+
cnf.merge_config(init_kwargs['model_config'])
224226
if 'tokenizer' in init_kwargs:
225227
tokenizer = init_kwargs['tokenizer']
226228
else:

medcat-v2/medcat/pipeline/pipeline.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ class Pipeline:
7272
def __init__(self, cdb: CDB, vocab: Optional[Vocab],
7373
model_load_path: Optional[str],
7474
# NOTE: upon reload, old pipe can be useful
75-
old_pipe: Optional['Pipeline'] = None):
75+
old_pipe: Optional['Pipeline'] = None,
76+
addon_config_dict: Optional[dict[str, dict]] = None):
7677
self.cdb = cdb
7778
# NOTE: Vocab is None in case of DeID models and thats fine then,
7879
# but it should be non-None otherwise
@@ -81,7 +82,7 @@ def __init__(self, cdb: CDB, vocab: Optional[Vocab],
8182
self._tokenizer = self._init_tokenizer()
8283
self._components: list[CoreComponent] = []
8384
self._addons: list[AddonComponent] = []
84-
self._init_components(model_load_path, old_pipe)
85+
self._init_components(model_load_path, old_pipe, addon_config_dict)
8586

8687
@property
8788
def tokenizer(self) -> BaseTokenizer:
@@ -172,8 +173,45 @@ def _load_saved_core_component(self, cct_name: str, comp_folder_path: str
172173
f"'{comp.get_type().name}' instead.")
173174
return comp
174175

176+
@classmethod
177+
def _attempt_merge(
178+
cls, addon_cnf: ComponentConfig,
179+
addon_config_dict: dict[str, dict]) -> None:
180+
for name, config_dict in addon_config_dict.items():
181+
if not name.startswith(addon_cnf.comp_name):
182+
continue
183+
# TODO: is there an option to do this in a more general way?
184+
# right now it's an implementation-specific code smell
185+
if isinstance(addon_cnf, ConfigMetaCAT):
186+
full_name = f"{addon_cnf.comp_name}.{addon_cnf.general.category_name}"
187+
if name == full_name:
188+
addon_cnf.merge_config(config_dict)
189+
return
190+
continue
191+
logger.warning(
192+
"No implementation specified for defining if/when %s"
193+
"should apply to addon config (e.g %s)",
194+
type(addon_cnf).__name__, name)
195+
# if only 1 of the type, then just merge
196+
similars = [cd for oname, cd in addon_config_dict.items()
197+
if oname.startswith(addon_cnf.comp_name)]
198+
if len(similars) == 1:
199+
logger.warning(
200+
"Since there is only 1 config for this type (%s) specified "
201+
"we will just merge the configs (@%s).",
202+
addon_cnf.comp_name, name)
203+
addon_cnf.merge_config(config_dict)
204+
return
205+
else:
206+
logger.warning(
207+
"There are %d similar configs (%s) specified, so unable to "
208+
"merge the config since it's ambiguous (@%s)",
209+
len(similars), addon_cnf.comp_name, name)
210+
175211
def _init_components(self, model_load_path: Optional[str],
176-
old_pipe: Optional['Pipeline']) -> None:
212+
old_pipe: Optional['Pipeline'],
213+
addon_config_dict: Optional[dict[str, dict]],
214+
) -> None:
177215
(loaded_core_component_paths,
178216
loaded_addon_component_paths) = self._get_loaded_components_paths(
179217
model_load_path)
@@ -186,6 +224,8 @@ def _init_components(self, model_load_path: Optional[str],
186224
CoreComponentType[cct_name], model_load_path)
187225
self._components.append(comp)
188226
for addon_cnf in self.config.components.addons:
227+
if addon_config_dict:
228+
self._attempt_merge(addon_cnf, addon_config_dict)
189229
addon = self._init_addon(
190230
addon_cnf, loaded_addon_component_paths, old_pipe)
191231
# mark as not dirty at loat / init time

medcat-v2/tests/test_cat.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,19 @@ def setUpClass(cls):
7171
cls.model.config.components.linking.train = False
7272

7373

74+
class ConfigMergeTests(unittest.TestCase):
75+
spacy_model_name = 'en_core_web_lg'
76+
model_dict = {
77+
"general": {'nlp': {"modelname": spacy_model_name}}
78+
}
79+
80+
def test_can_merge_config(self):
81+
model = cat.CAT.load_model_pack(
82+
EXAMPLE_MODEL_PACK_ZIP, config_dict=self.model_dict)
83+
self.assertEqual(
84+
model.config.general.nlp.modelname, self.spacy_model_name)
85+
86+
7487
class InferenceFromLoadedTests(TrainedModelTests):
7588

7689
def test_can_load_model(self):
@@ -298,6 +311,22 @@ def test_can_load_meta_cat(self):
298311
_, addon = addons[0]
299312
self.assertIsInstance(addon, MetaCATAddon)
300313

314+
def test_can_load_meta_cat_with_addon_cnf(self, seed: int = -41):
315+
mc: MetaCATAddon = cat.CAT.load_addons(
316+
self.mpp, addon_config_dict={
317+
"meta_cat.Status": {
318+
"general": {"seed": seed}}})[0][1]
319+
self.assertEqual(mc.config.general.seed, seed)
320+
321+
def test_can_merge_cnf_upon_load(self, use_seed: int = -4):
322+
loaded = cat.CAT.load_model_pack(
323+
self.mpp,
324+
addon_config_dict={
325+
"meta_cat.Status": {"general": {"seed": use_seed}}
326+
})
327+
addon: MetaCATAddon = list(loaded._pipeline.iter_addons())[0]
328+
self.assertEqual(addon.config.general.seed, use_seed)
329+
301330

302331
class CatWithChangesMetaCATTests(CatWithMetaCATTests):
303332
EXPECTED_HASH = "0b22401059a08380"

0 commit comments

Comments
 (0)