Skip to content

Commit 8449450

Browse files
committed
Update with missing torch import
Signed-off-by: MatteoFasulo <[email protected]>
1 parent 5508da1 commit 8449450

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

art/defences/preprocessor/variance_minimization_pytorch.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
from tqdm.auto import tqdm
1717

18+
import torch
19+
1820
import numpy as np
1921

2022
from art.defences.preprocessor.preprocessor import PreprocessorPyTorch
@@ -43,7 +45,7 @@ def __init__(
4345
lamb: float = 0.5,
4446
max_iter: int = 10,
4547
channels_first: bool = True,
46-
clip_values: "CLIP_VALUES_TYPE" | None = None,
48+
clip_values: "CLIP_VALUES_TYPE | None" = None,
4749
apply_fit: bool = False,
4850
apply_predict: bool = True,
4951
verbose: bool = False,
@@ -80,8 +82,8 @@ def __init__(
8082
self._check_params()
8183

8284
def forward(
83-
self, x: "torch.Tensor", y: "torch.Tensor" | None = None
84-
) -> tuple["torch.Tensor", "torch.Tensor" | None]:
85+
self, x: "torch.Tensor", y: "torch.Tensor | None" = None
86+
) -> tuple["torch.Tensor", "torch.Tensor | None"]:
8587
"""
8688
Apply total variance minimization to sample `x`.
8789

0 commit comments

Comments
 (0)