Skip to content

Commit 6aba944

Browse files
authored
Merge pull request #2677 from Trusted-AI/dependabot/pip/timm-1.0.16
Bump timm from 1.0.15 to 1.0.16
2 parents 88e0c50 + 4f50f61 commit 6aba944

File tree

4 files changed

+9
-4
lines changed

4 files changed

+9
-4
lines changed

art/estimators/certification/derandomized_smoothing/pytorch.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,12 +406,15 @@ def get_models(cls, generate_from_null: bool = False) -> list[str]:
406406
return supported
407407

408408
@staticmethod
409-
def create_vision_transformer(variant: str, pretrained: bool = False, **kwargs) -> "PyTorchVisionTransformer":
409+
def create_vision_transformer(
410+
variant: str, pretrained: bool = False, use_naflex: bool | None = None, **kwargs
411+
) -> "PyTorchVisionTransformer":
410412
"""
411413
Creates a vision transformer using PyTorchViT which controls the forward pass of the model
412414
413415
:param variant: The name of the vision transformer to load
414416
:param pretrained: If to load pre-trained weights
417+
:param use_naflex: If using NaFlexVit.
415418
:return: A ViT with the required methods needed for ART
416419
"""
417420

art/estimators/certification/derandomized_smoothing/vision_transformers/pytorch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,11 +167,12 @@ def drop_tokens(x: torch.Tensor, indexes: torch.Tensor) -> torch.Tensor:
167167
x_no_cl = torch.reshape(x_no_cl, shape=(shape[0], -1, shape[-1]))
168168
return torch.cat((cls_token, x_no_cl), dim=1)
169169

170-
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
170+
def forward_features(self, x: torch.Tensor, attn_mask: torch.Tensor | None = None) -> torch.Tensor:
171171
"""
172172
The forward pass of the ViT.
173173
174174
:param x: Input data.
175+
:param attn_mask: Attention mask.
175176
:return: The input processed by the ViT backbone
176177
"""
177178

requirements_test.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ torchaudio==2.7.1
3333
torchvision==0.22.1
3434

3535
# PyTorch image transformers
36-
timm==1.0.15
36+
timm==1.0.16
3737

3838
# YOLO dependencies
3939
ultralytics==8.3.162

tests/estimators/certification/test_vision_transformers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,7 @@ def embedder(cls, x, pos_embed, cls_token):
549549
x = torch.cat((cls_token.expand(x.shape[0], -1, -1), x), dim=1)
550550
return x + pos_embed
551551

552-
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
552+
def forward_features(self, x: torch.Tensor, attn_mask: torch.Tensor | None = None) -> torch.Tensor:
553553
"""
554554
This is a copy of the function in ArtViT.forward_features
555555
except we also perform an equivalence assertion compared to the implementation
@@ -558,6 +558,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
558558
The forward pass of the ViT.
559559
560560
:param x: Input data.
561+
:param attn_mask: Attention mask.
561562
:return: The input processed by the ViT backbone
562563
"""
563564
import copy

0 commit comments

Comments
 (0)