Skip to content

Commit 4f3df77

Browse files
committed
Add test for PreprocessorPyTorch initialization
Signed-off-by: Nilaksh Das <[email protected]>
1 parent 53d6cec commit 4f3df77

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import pytest
2+
3+
from art.defences.preprocessor.preprocessor import PreprocessorPyTorch
4+
from tests.utils import ARTTestException
5+
6+
7+
class DummyPreprocessorPyTorch(PreprocessorPyTorch):
8+
def forward(self, x, y):
9+
return x, y
10+
11+
12+
@pytest.mark.parametrize("is_fitted", [True, False])
13+
@pytest.mark.parametrize("apply_fit", [True, False])
14+
@pytest.mark.parametrize("apply_predict", [True, False])
15+
@pytest.mark.only_with_platform("pytorch")
16+
def test_preprocessor_pytorch_init(art_warning, is_fitted, apply_fit, apply_predict):
17+
try:
18+
import torch
19+
20+
preprocessor = DummyPreprocessorPyTorch(
21+
device_type="cpu",
22+
is_fitted=is_fitted,
23+
apply_fit=apply_fit,
24+
apply_predict=apply_predict,
25+
)
26+
27+
assert preprocessor.device == torch.device("cpu")
28+
assert preprocessor.is_fitted == is_fitted
29+
assert preprocessor.apply_fit == apply_fit
30+
assert preprocessor.apply_predict == apply_predict
31+
32+
except ARTTestException as e:
33+
art_warning(e)

0 commit comments

Comments
 (0)