Skip to content

Commit 17db639

Browse files
committed
Fix mypy warnings for timm 1.0.16
Signed-off-by: Beat Buesser <[email protected]>
1 parent e76ec6c commit 17db639

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

art/estimators/certification/derandomized_smoothing/pytorch.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,12 +406,15 @@ def get_models(cls, generate_from_null: bool = False) -> list[str]:
406406
return supported
407407

408408
@staticmethod
409-
def create_vision_transformer(variant: str, pretrained: bool = False, **kwargs) -> "PyTorchVisionTransformer":
409+
def create_vision_transformer(
410+
variant: str, pretrained: bool = False, use_naflex: bool | None = None, **kwargs
411+
) -> "PyTorchVisionTransformer":
410412
"""
411413
Creates a vision transformer using PyTorchViT which controls the forward pass of the model
412414
413415
:param variant: The name of the vision transformer to load
414416
:param pretrained: If to load pre-trained weights
417+
:param use_naflex: If using NaFlexVit.
415418
:return: A ViT with the required methods needed for ART
416419
"""
417420

art/estimators/certification/derandomized_smoothing/vision_transformers/pytorch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,11 +167,12 @@ def drop_tokens(x: torch.Tensor, indexes: torch.Tensor) -> torch.Tensor:
167167
x_no_cl = torch.reshape(x_no_cl, shape=(shape[0], -1, shape[-1]))
168168
return torch.cat((cls_token, x_no_cl), dim=1)
169169

170-
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
170+
def forward_features(self, x: torch.Tensor, attn_mask: torch.Tensor | None = None) -> torch.Tensor:
171171
"""
172172
The forward pass of the ViT.
173173
174174
:param x: Input data.
175+
:param attn_mask: Attention mask.
175176
:return: The input processed by the ViT backbone
176177
"""
177178

0 commit comments

Comments
 (0)