Skip to content

Commit f31b65a

Browse files
author
Beat Buesser
committed
Define dtype
Signed-off-by: Beat Buesser <[email protected]>
1 parent ceb8ba2 commit f31b65a

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

art/preprocessing/standardisation_mean_std/standardisation_mean_std_pytorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ def forward(
7272
"""
7373
import torch # lgtm [py/repeated-import]
7474

75-
mean = torch.tensor(self.mean, device=self._device)
76-
std = torch.tensor(self.std, device=self._device)
75+
mean = torch.tensor(self.mean, device=self._device, dtype=torch.float32)
76+
std = torch.tensor(self.std, device=self._device, dtype=torch.float32)
7777

7878
x_norm = x - mean
7979
x_norm = x_norm / std

0 commit comments

Comments
 (0)