Skip to content

Commit 54664fc

Browse files
Changed default cuda dtype to torch.float32.
1 parent c683e77 commit 54664fc

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

ml-agents/mlagents/torch_utils/torch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def set_torch_config(torch_settings: TorchSettings) -> None:
5353

5454
if _device.type == "cuda":
5555
torch.set_default_device(_device.type)
56-
torch.set_default_dtype(torch.cuda.FloatTensor)
56+
torch.set_default_dtype(torch.float32)
5757
else:
5858
torch.set_default_dtype(torch.float32)
5959
logger.debug(f"default Torch device: {_device}")

0 commit comments

Comments
 (0)