@@ -474,123 +474,6 @@ def forward(self, inputs: Dict[str, Any]) -> Tuple[torch.Tensor]:
         return (self.model.visual_projection(pooled_output),)
 
 
-@store(group="modules/encoders", provider="mmlearn", hydra_convert="object")
-class PubMedBERTForCLIPTextEncoding(nn.Module):
-    """BiomedNLP's PubMedBERT model for CLIP text encoding.
-
-    This module is a wrapper around the PubMedBERT model from huggingface.
-
-    Parameters
-    ----------
-    pretrained : bool, default=True
-        Whether to load the pretrained weights or not.
-    pooling_layer : nn.Module, optional, default=None
-        Pooling layer to apply to the last hidden state of the model.
-    freeze_layers : int | float | List[int] | bool, default=False
-        Whether to freeze layers of the model and which layers to freeze. If `True`,
-        all model layers are frozen. If it is an integer, the first `N` layers of
-        the model are frozen. If it is a float, that fraction of the layers
-        (counting from the first) is frozen. If it is a list of integers, the
-        layers at the indices in the list are frozen.
-    freeze_layer_norm : bool, default=True
-        Whether to freeze the layer normalization layers of the model.
-    peft_config : PeftConfig, optional, default=None
-        The configuration from the `peft` library to use to wrap the model
-        for parameter-efficient finetuning.
-    model_config_kwargs : Dict[str, Any], optional, default=None
-        Additional keyword arguments to pass to the model configuration.
-
-    Warns
-    -----
-    UserWarning
-        If both `peft_config` and `freeze_layers` are set. The `peft_config` will
-        override the `freeze_layers` setting.
-
-    """
-
-    def __init__(
-        self,
-        pretrained: bool = True,
-        pooling_layer: Optional[nn.Module] = None,
-        freeze_layers: Union[int, float, List[int], bool] = False,
-        freeze_layer_norm: bool = True,
-        peft_config: Optional["PeftConfig"] = None,
-        model_config_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> None:
-        """Initialize the model."""
-        super().__init__()
-        _warn_freeze_with_peft(peft_config, freeze_layers)
-
-        model = hf_utils.load_huggingface_model(
-            transformers.AutoModelForMaskedLM,
-            "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext",
-            load_pretrained_weights=pretrained,
-            get_model_attr="bert",
-            model_config_kwargs=model_config_kwargs,
-        )
-
-        if isinstance(freeze_layers, bool) and freeze_layers:
-            for name, param in model.named_parameters():
-                param.requires_grad = (
-                    (not freeze_layer_norm) if "LayerNorm" in name else False
-                )
-
-        layers = [model.embeddings, *model.encoder.layer]
-        if isinstance(freeze_layers, float):
-            freeze_layers = int(freeze_layers * len(layers))
-        if isinstance(freeze_layers, int):
-            freeze_layers = list(range(freeze_layers))
-
-        if isinstance(freeze_layers, list):
-            for idx, layer in enumerate(layers):
-                if idx in freeze_layers:
-                    for name, param in layer.named_parameters():
-                        param.requires_grad = (
-                            (not freeze_layer_norm) if "LayerNorm" in name else False
-                        )
-
-        if peft_config is not None:
-            model = hf_utils._wrap_peft_model(model, peft_config)
-
-        self.model = model
-        self.pooling_layer = pooling_layer
-
-    def forward(self, inputs: Dict[str, Any]) -> BaseModelOutput:
-        """Run the forward pass.
-
-        Parameters
-        ----------
-        inputs : Dict[str, Any]
-            The input data. The `input_ids` will be expected under the
-            `Modalities.TEXT` key.
-
-        Returns
-        -------
-        BaseModelOutput
-            The output of the model, including the last hidden state, all hidden
-            states, and the attention weights, if `output_attentions` is set
-            to `True`.
-        """
-        output = self.model(
-            input_ids=inputs[Modalities.TEXT.name],
-            attention_mask=inputs.get(
-                "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
-            ),
-            inputs_embeds=inputs.get("inputs_embeds"),
-            output_attentions=inputs.get("output_attentions"),
-            output_hidden_states=True,
-            return_dict=True,
-        )
-        last_hidden_state = output.last_hidden_state
-        if self.pooling_layer is not None:
-            last_hidden_state = self.pooling_layer(last_hidden_state)
-
-        return BaseModelOutput(
-            last_hidden_state=last_hidden_state,
-            hidden_states=output.hidden_states,
-            attentions=output.attentions,
-        )
-
-
 #### Utility methods ####
 
 
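The deleted `__init__` normalizes every accepted `freeze_layers` type down to a list of layer indices over `[embeddings, *encoder.layer]`. A self-contained sketch of that resolution logic, pulled out of the constructor for clarity (the helper name `resolve_freeze_indices` is mine, for illustration; note the `bool` case must be checked before `int`, since `isinstance(True, int)` is true in Python):

from typing import List, Union

def resolve_freeze_indices(
    freeze_layers: Union[int, float, List[int], bool], num_layers: int
) -> List[int]:
    # Mirrors the deleted __init__: bool freezes all or none, a float is a
    # fraction of the layers, an int is a count, a list is taken verbatim.
    if isinstance(freeze_layers, bool):
        return list(range(num_layers)) if freeze_layers else []
    if isinstance(freeze_layers, float):
        freeze_layers = int(freeze_layers * num_layers)
    if isinstance(freeze_layers, int):
        return list(range(freeze_layers))
    return freeze_layers

# BiomedBERT-base: embeddings + 12 encoder layers = 13 entries.
assert resolve_freeze_indices(0.5, 13) == [0, 1, 2, 3, 4, 5]
assert resolve_freeze_indices(True, 13) == list(range(13))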
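For context on how the removed encoder was driven, a minimal usage sketch, assuming the class is still importable: the `Modalities` import path, tokenizer call, and sample text are illustrative assumptions, not part of this commit, but the input-dict keys follow the `forward` signature above.

import torch
from transformers import AutoTokenizer
from mmlearn.datasets.core import Modalities  # assumed import path

# freeze_layers=0.5 resolves to int(0.5 * 13) = 6: the embeddings plus the
# first five of the twelve BERT encoder layers are frozen.
encoder = PubMedBERTForCLIPTextEncoding(pretrained=True, freeze_layers=0.5)

tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
)
batch = tokenizer(["no acute cardiopulmonary process"], return_tensors="pt")

# forward() looks up token ids under Modalities.TEXT.name ("text") and
# falls back to the plain "attention_mask" key for the mask.
inputs = {
    Modalities.TEXT.name: batch["input_ids"],
    "attention_mask": batch["attention_mask"],
}
with torch.no_grad():
    output = encoder(inputs)  # BaseModelOutput

print(output.last_hidden_state.shape)  # e.g. torch.Size([1, seq_len, 768])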