Skip to content

Commit 3dea00f

Browse files
authored
Refactor contrastive loss (#35)
1 parent 54fc119 commit 3dea00f

File tree

12 files changed

+508
-265
lines changed

12 files changed

+508
-265
lines changed

mmlearn/cli/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def main(cfg: MMLearnConf) -> None: # noqa: PLR0912
4545

4646
if is_torch_tf32_available():
4747
torch.backends.cuda.matmul.allow_tf32 = True
48-
if "16-mixed" in cfg.trainer.precision:
48+
if "16-mixed" in str(cfg.trainer.precision):
4949
cfg.trainer.precision = "bf16-mixed"
5050

5151
# setup trainer first so that we can get some variables for distributed training

mmlearn/conf/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ class MMLearnConf:
168168
job=JobConf(
169169
name=II("experiment_name"),
170170
env_set={
171-
"TORCH_NCCL_ASYNC_ERROR_HANDLING": "3",
171+
"TORCH_NCCL_ASYNC_ERROR_HANDLING": "1",
172172
"HYDRA_FULL_ERROR": "1",
173173
},
174174
),

mmlearn/hf_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def load_huggingface_model(
6767
return_unused_kwargs=True,
6868
**model_config_kwargs,
6969
)
70-
model = model_type._from_config(config, **kwargs)
70+
model = model_type.from_config(config, **kwargs)
7171

7272
if get_model_attr is not None and hasattr(model, get_model_attr):
7373
model = getattr(model, get_model_attr)

mmlearn/modules/encoders/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
HFCLIPTextEncoderWithProjection,
66
HFCLIPVisionEncoder,
77
HFCLIPVisionEncoderWithProjection,
8-
PubMedBERTForCLIPTextEncoding,
98
)
109
from mmlearn.modules.encoders.text import HFTextEncoder
10+
from mmlearn.modules.encoders.vision import TimmViT
1111

1212

1313
__all__ = [
@@ -16,5 +16,5 @@
1616
"HFCLIPTextEncoderWithProjection",
1717
"HFCLIPVisionEncoder",
1818
"HFCLIPVisionEncoderWithProjection",
19-
"PubMedBERTForCLIPTextEncoding",
19+
"TimmViT",
2020
]

mmlearn/modules/encoders/clip.py

Lines changed: 0 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -474,123 +474,6 @@ def forward(self, inputs: Dict[str, Any]) -> Tuple[torch.Tensor]:
474474
return (self.model.visual_projection(pooled_output),)
475475

476476

477-
@store(group="modules/encoders", provider="mmlearn", hydra_convert="object")
478-
class PubMedBERTForCLIPTextEncoding(nn.Module):
479-
"""BiomedNLP's PubMedBERT model for CLIP text encoding.
480-
481-
This module is wrapper around the PubMedBERT model from huggingface.
482-
483-
Parameters
484-
----------
485-
pretrained : bool, default=False
486-
Whether to load the pretrained weights or not.
487-
pooling_layer : nn.Module, optional, default=None
488-
Pooling layer to apply to the last hidden state of the model.
489-
freeze_layers : int | float | List[int] | bool, default=False
490-
Whether to freeze layers of the model and which layers to freeze. If `True`,
491-
all model layers are frozen. If it is an integer, the first `N` layers of
492-
the model are frozen. If it is a float, the first `N` percent of the layers
493-
are frozen. If it is a list of integers, the layers at the indices in the
494-
list are frozen.
495-
freeze_layer_norm : bool, default=True
496-
Whether to freeze the layer normalization layers of the model.
497-
peft_config : PeftConfig, optional, default=None
498-
The configuration from the `peft` library to use to wrap the model
499-
for parameter-efficient finetuning.
500-
model_config_kwargs : Dict[str, Any], optional, default=None
501-
Additional keyword arguments to pass to the model configuration.
502-
503-
Warns
504-
-----
505-
UserWarning
506-
If both `peft_config` and `freeze_layers` are set. The `peft_config` will
507-
override the `freeze_layers` setting.
508-
509-
"""
510-
511-
def __init__(
512-
self,
513-
pretrained: bool = True,
514-
pooling_layer: Optional[nn.Module] = None,
515-
freeze_layers: Union[int, float, List[int], bool] = False,
516-
freeze_layer_norm: bool = True,
517-
peft_config: Optional["PeftConfig"] = None,
518-
model_config_kwargs: Optional[Dict[str, Any]] = None,
519-
) -> None:
520-
"""Initialize the model."""
521-
super().__init__()
522-
_warn_freeze_with_peft(peft_config, freeze_layers)
523-
524-
model = hf_utils.load_huggingface_model(
525-
transformers.AutoModelForMaskedLM,
526-
"microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext",
527-
load_pretrained_weights=pretrained,
528-
get_model_attr="bert",
529-
model_config_kwargs=model_config_kwargs,
530-
)
531-
532-
if isinstance(freeze_layers, bool) and freeze_layers:
533-
for name, param in model.named_parameters():
534-
param.requires_grad = (
535-
(not freeze_layer_norm) if "LayerNorm" in name else False
536-
)
537-
538-
layers = [model.embeddings, *model.encoder.layer]
539-
if isinstance(freeze_layers, float):
540-
freeze_layers = int(freeze_layers * len(layers))
541-
if isinstance(freeze_layers, int):
542-
freeze_layers = list(range(freeze_layers))
543-
544-
if isinstance(freeze_layers, list):
545-
for idx, layer in enumerate(layers):
546-
if idx in freeze_layers:
547-
for name, param in layer.named_parameters():
548-
param.requires_grad = (
549-
(not freeze_layer_norm) if "LayerNorm" in name else False
550-
)
551-
552-
if peft_config is not None:
553-
model = hf_utils._wrap_peft_model(model, peft_config)
554-
555-
self.model = model
556-
self.pooling_layer = pooling_layer
557-
558-
def forward(self, inputs: Dict[str, Any]) -> BaseModelOutput:
559-
"""Run the forward pass.
560-
561-
Parameters
562-
----------
563-
inputs : Dict[str, Any]
564-
The input data. The `input_ids` will be expected under the `Modalities.TEXT`
565-
key.
566-
567-
Returns
568-
-------
569-
BaseModelOutput
570-
The output of the model, including the last hidden state, all hidden states,
571-
and the attention weights, if `output_attentions` is set to `True`.
572-
"""
573-
output = self.model(
574-
input_ids=inputs[Modalities.TEXT.name],
575-
attention_mask=inputs.get(
576-
"attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
577-
),
578-
inputs_embeds=inputs.get("inputs_embeds"),
579-
output_attentions=inputs.get("output_attentions"),
580-
output_hidden_states=True,
581-
return_dict=True,
582-
)
583-
last_hidden_state = output.last_hidden_state
584-
if self.pooling_layer is not None:
585-
last_hidden_state = self.pooling_layer(last_hidden_state)
586-
587-
return BaseModelOutput(
588-
last_hidden_state=last_hidden_state,
589-
hidden_states=output.hidden_states,
590-
attentions=output.attentions,
591-
)
592-
593-
594477
#### Utility methods ####
595478

596479

mmlearn/modules/encoders/vision.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
@store(
2727
group="modules/encoders",
2828
provider="mmlearn",
29-
model_name_or_path="vit_base_patch16_224",
29+
model_name="vit_base_patch16_224",
3030
hydra_convert="object",
3131
)
3232
class TimmViT(nn.Module):

mmlearn/modules/losses/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Loss functions."""
22

3-
from mmlearn.modules.losses.contrastive import CLIPLoss
3+
from mmlearn.modules.losses.contrastive import ContrastiveLoss
44
from mmlearn.modules.losses.data2vec import Data2VecLoss
55

66

7-
__all__ = ["CLIPLoss", "Data2VecLoss"]
7+
__all__ = ["ContrastiveLoss", "Data2VecLoss"]

0 commit comments

Comments
 (0)