diff --git a/medcat/config_meta_cat.py b/medcat/config_meta_cat.py
index 42a9dab7..7f92ca78 100644
--- a/medcat/config_meta_cat.py
+++ b/medcat/config_meta_cat.py
@@ -133,6 +133,12 @@ class Model(MixingConfig, BaseModel):
 
     NB! For these changes to take effect, the pipe would need to be recreated.
     """
+    load_bert_pretrained_weights: bool = False
+    """
+    Applicable only when using BERT:
+    Determines if the pretrained weights for BERT are loaded.
+    This should be True if you don't plan on using the model pack weights."""
+
     num_layers: int = 2
     """Number of layers in the model (both LSTM and BERT)
 
@@ -164,7 +170,9 @@ class Model(MixingConfig, BaseModel):
     Paper reference - https://ieeexplore.ieee.org/document/7533053"""
 
     category_undersample: str = ''
-    """When using 2 phase learning, this category is used to undersample the data"""
+    """When using 2 phase learning, this category is used to undersample the data.
+    The number of samples in this category sets the upper limit for all categories."""
+
     model_architecture_config: Dict = {'fc2': True, 'fc3': False,'lr_scheduler': True}
     """Specifies the architecture for BERT model.
 
diff --git a/medcat/meta_cat.py b/medcat/meta_cat.py
index e4e647d7..f9b54d9f 100644
--- a/medcat/meta_cat.py
+++ b/medcat/meta_cat.py
@@ -51,11 +51,13 @@ class MetaCAT(PipeRunner):
 
     def __init__(self, tokenizer: Optional[TokenizerWrapperBase] = None,
                  embeddings: Optional[Union[Tensor, numpy.ndarray]] = None,
-                 config: Optional[ConfigMetaCAT] = None) -> None:
+                 config: Optional[ConfigMetaCAT] = None,
+                 save_dir_path: Optional[str] = None) -> None:
         if config is None:
             config = ConfigMetaCAT()
         self.config = config
         set_all_seeds(config.general['seed'])
+        self.save_dir_path = save_dir_path
 
         if tokenizer is not None:
             # Set it in the config
@@ -90,7 +92,7 @@ def get_model(self, embeddings: Optional[Tensor]) -> nn.Module:
 
         elif config.model['model_name'] == 'bert':
             from medcat.utils.meta_cat.models import BertForMetaAnnotation
-            model = BertForMetaAnnotation(config)
+            model = BertForMetaAnnotation(config, self.save_dir_path)
 
             if not config.model.model_freeze_layers:
                 peft_config = LoraConfig(task_type=TaskType.SEQ_CLS, inference_mode=False, r=8, lora_alpha=16,
@@ -380,6 +382,9 @@ def save(self, save_dir_path: str) -> None:
 
         model_save_path = os.path.join(save_dir_path, 'model.dat')
         torch.save(self.model.state_dict(), model_save_path)
+        if self.config.model.model_name == 'bert':
+            model_config_save_path = os.path.join(save_dir_path, 'bert_config.json')
+            self.model.bert_config.to_json_file(model_config_save_path)  # type: ignore
 
         # This is everything we need to save from the class, we do not
         # save the class itself.
@@ -416,7 +421,7 @@ def load(cls, save_dir_path: str, config_dict: Optional[Dict] = None) -> "MetaCAT":
             tokenizer = TokenizerWrapperBERT.load(save_dir_path, config.model.model_variant)
 
         # Create meta_cat
-        meta_cat = cls(tokenizer=tokenizer, embeddings=None, config=config)
+        meta_cat = cls(tokenizer=tokenizer, embeddings=None, config=config, save_dir_path=save_dir_path)
 
         # Load the model
         model_save_path = os.path.join(save_dir_path, 'model.dat')
diff --git a/medcat/utils/meta_cat/models.py b/medcat/utils/meta_cat/models.py
index 543e0ca6..6b59c2a7 100644
--- a/medcat/utils/meta_cat/models.py
+++ b/medcat/utils/meta_cat/models.py
@@ -87,16 +87,35 @@ def forward(self,
 class BertForMetaAnnotation(nn.Module):
     _keys_to_ignore_on_load_unexpected: List[str] = [r"pooler"]  # type: ignore
 
-    def __init__(self, config):
+    def __init__(self, config, save_dir_path=None):
         super(BertForMetaAnnotation, self).__init__()
-        _bertconfig = AutoConfig.from_pretrained(config.model.model_variant,num_hidden_layers=config.model['num_layers'])
+        if save_dir_path:
+            try:
+                _bertconfig = AutoConfig.from_pretrained(save_dir_path + "/bert_config.json",
+                                                         num_hidden_layers=config.model['num_layers'])
+            except Exception:
+                _bertconfig = AutoConfig.from_pretrained(config.model.model_variant,
+                                                         num_hidden_layers=config.model['num_layers'])
+                logger.info("BERT config not found locally; downloaded from Hugging Face instead.")
+
+        else:
+            _bertconfig = AutoConfig.from_pretrained(config.model.model_variant, num_hidden_layers=config.model['num_layers'])
+
         if config.model['input_size'] != _bertconfig.hidden_size:
             logger.warning("Input size for %s model should be %d, provided input size is %d. Input size changed to %d",config.model.model_variant,_bertconfig.hidden_size,config.model['input_size'],_bertconfig.hidden_size)
 
-        bert = BertModel.from_pretrained(config.model.model_variant, config=_bertconfig)
+        if config.model['load_bert_pretrained_weights']:
+            try:
+                bert = BertModel.from_pretrained(config.model.model_variant, config=_bertconfig)
+            except Exception:
+                raise Exception("Could not load BERT pretrained weights from Hugging Face. \nIf you're seeing a connection error, set `config.model.load_bert_pretrained_weights=False` and make sure to load the model pack from disk instead.")
+        else:
+            bert = BertModel(_bertconfig)
+
         self.config = config
         self.config.use_return_dict = False
         self.bert = bert
+        self.bert_config = _bertconfig
         self.num_labels = config.model["nclasses"]
         for param in self.bert.parameters():
             param.requires_grad = not config.model.model_freeze_layers
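For reviewers, a minimal sketch of the workflow this change enables. Only classes and options touched by this diff are used; the pack directory name 'mc_status' is an illustrative placeholder, not part of the PR.

# Training from scratch (no model pack weights yet): opt in to the pretrained
# BERT weights so BertModel.from_pretrained() is used when the pipe is built.
from medcat.config_meta_cat import ConfigMetaCAT
from medcat.meta_cat import MetaCAT

config = ConfigMetaCAT()
config.model['model_name'] = 'bert'
config.model['load_bert_pretrained_weights'] = True

# ... build and train a MetaCAT with this config, then save it.
# With a BERT model, save() now also writes <pack>/bert_config.json next to model.dat.

# Loading the pack later works offline: load() forwards the pack directory to
# BertForMetaAnnotation, which reads bert_config.json and builds BertModel(_bertconfig)
# because load_bert_pretrained_weights defaults to False.
meta_cat = MetaCAT.load(save_dir_path='mc_status')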