Skip to content

Commit cb50d6c

Browse files
committed
[Bug Fix] Support bool dtype in replay_tensor
1 parent b3fb4bf commit cb50d6c

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

graph_net/torch/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,4 +265,6 @@ def replay_tensor(info):
265265
std = info["info"]["std"]
266266
if "data" in info and info["data"] is not None:
267267
return info["data"].to(device)
268+
if dtype is torch.bool:
269+
return (torch.randn(size=shape) > 0.5).to(dtype).to(device)
268270
return torch.randn(size=shape).to(dtype).to(device) * std * 0.2 + mean

0 commit comments

Comments
 (0)