Skip to content

Commit 615d68a

Browse files
committed
Change the init of int tensor.
1 parent 1816e8d commit 615d68a

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

graph_net/paddle/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,10 @@ def replay_tensor(info):
180180
if "data" in info and info["data"] is not None:
181181
return paddle.reshape(info["data"], shape).to(dtype).to(device)
182182
elif dtype == paddle.int32 or dtype == paddle.int64:
183-
# for some ops(binary_cross_entropy), label data can only be set 0 or 1.
184183
return paddle.cast(
185-
paddle.randint(low=0, high=2, shape=shape, dtype="int64"),
184+
paddle.randint(
185+
low=min_value, high=max_value + 1, shape=shape, dtype="int64"
186+
),
186187
dtype,
187188
).to(device)
188189
elif dtype == paddle.bool:
@@ -192,7 +193,6 @@ def replay_tensor(info):
192193
).to(device)
193194
else:
194195
std = info["info"]["std"]
195-
# return paddle.randn(shape).to(dtype).to(device) * std * 1e-3 + 1e-2
196196
return (
197197
paddle.uniform(shape, dtype="float32", min=min_value, max=max_value)
198198
.to(dtype)

0 commit comments

Comments
 (0)