Skip to content

Commit 34d1950

Browse files
committed
Update tests for timm 1.0.16
Signed-off-by: Beat Buesser <[email protected]>
1 parent 510b420 commit 34d1950

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tests/estimators/certification/test_vision_transformers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,7 @@ def embedder(cls, x, pos_embed, cls_token):
549549
x = torch.cat((cls_token.expand(x.shape[0], -1, -1), x), dim=1)
550550
return x + pos_embed
551551

552-
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
552+
def forward_features(self, x: torch.Tensor, attn_mask: torch.Tensor | None = None) -> torch.Tensor:
553553
"""
554554
This is a copy of the function in ArtViT.forward_features
555555
except we also perform an equivalence assertion compared to the implementation
@@ -558,6 +558,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
558558
The forward pass of the ViT.
559559
560560
:param x: Input data.
561+
:param attn_mask: Attention mask.
561562
:return: The input processed by the ViT backbone
562563
"""
563564
import copy

0 commit comments

Comments
 (0)