diff --git a/losses.py b/losses.py index ff24b25..c8f5b0a 100644 --- a/losses.py +++ b/losses.py @@ -46,7 +46,7 @@ def forward(self, outputs): image_ssl_embed = outputs["image_ssl_embed"] inputs = {} inputs["aug1_embed"] = image_ssl_embed[:bs] - inputs["aug2_embed"] = image_ssl_embed[:bs] + inputs["aug2_embed"] = image_ssl_embed[bs:] simclr_loss_dict = self.simclr_loss(inputs) def loss_fn(x, y): diff --git a/models.py b/models.py index 0aa50c1..10fd1ac 100644 --- a/models.py +++ b/models.py @@ -763,7 +763,7 @@ def ACLIP_VITS16(mask_ratio=0, **kwargs): "mask_vit_small_patch16_224", num_classes=0, mask_ratio=mask_ratio ) vision_model_ema = timm.create_model( - "mask_vit_small_patch16_224", num_classes=0, mask_ratio=mask_ratio + "mask_vit_small_patch16_224", num_classes=0, mask_ratio=0 ) model = ACLIP( embed_dim=512,