Skip to content

Commit cc60b1b

Browse files
authored
Merge pull request #890 from Trusted-AI/development_issue_889
Define dtype in StandardisationMeanStdPyTorch
2 parents ceb8ba2 + f31b65a commit cc60b1b

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

art/preprocessing/standardisation_mean_std/standardisation_mean_std_pytorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ def forward(
7272
"""
7373
import torch # lgtm [py/repeated-import]
7474

75-
mean = torch.tensor(self.mean, device=self._device)
76-
std = torch.tensor(self.std, device=self._device)
75+
mean = torch.tensor(self.mean, device=self._device, dtype=torch.float32)
76+
std = torch.tensor(self.std, device=self._device, dtype=torch.float32)
7777

7878
x_norm = x - mean
7979
x_norm = x_norm / std

0 commit comments

Comments
 (0)