Skip to content

Commit ce87dc8

Browse files
authored
Merge pull request #1223 from inesvalentim/fix_issue_1221
Fix device_type of default preprocessor of PyTorchEstimator
2 parents 4f8db1f + 145217f commit ce87dc8

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

art/estimators/pytorch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ def __init__(self, device_type: str = "gpu", **kwargs) -> None:
6666
if isinstance(preprocessing, tuple):
6767
from art.preprocessing.standardisation_mean_std.pytorch import StandardisationMeanStdPyTorch
6868

69-
kwargs["preprocessing"] = StandardisationMeanStdPyTorch(mean=preprocessing[0], std=preprocessing[1])
69+
kwargs["preprocessing"] = StandardisationMeanStdPyTorch(
70+
mean=preprocessing[0], std=preprocessing[1], device_type=device_type
71+
)
7072

7173
super().__init__(**kwargs)
7274

0 commit comments

Comments
 (0)