Skip to content

Commit 78a984b

Browse files
committed
minor fix
1 parent 981ceaf commit 78a984b

File tree

2 files changed

+20
-17
lines changed

2 files changed

+20
-17
lines changed

graph_net/paddle/test_compiler.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,18 @@ def get_input_spec(args):
8181
return input_spec
8282

8383

84+
def regular_item(item):
85+
if isinstance(item, paddle.Tensor) and (
86+
item.dtype == paddle.bfloat16 or item.dtype == paddle.bfloat32
87+
):
88+
item = np.array(item.astype("float32"))
89+
else:
90+
item = np.array(item)
91+
if item.dtype == np.bool_:
92+
item = item.astype("float32")
93+
return item
94+
95+
8496
def test_single_model(args):
8597
synchronizer_func = get_synchronizer_func(args)
8698
input_dict = get_input_dict(args)
@@ -115,24 +127,15 @@ def test_single_model(args):
115127
with naive_timer(compiled_duration_box, synchronizer_func):
116128
compiled_out = compiled_model(**input_dict)
117129
if isinstance(expected_out, paddle.Tensor):
118-
expected_out = [expected_out.numpy().astype("float32")]
119-
compiled_out = [compiled_out.numpy().astype("float32")]
120-
elif isinstance(expected_out, list) or isinstance(expected_out, tuple):
121-
if isinstance(expected_out, tuple):
122-
expected_out = list(expected_out)
123-
compiled_out = list(compiled_out)
124-
new_expected = [
125-
np.array(item).astype("float32")
126-
for item in expected_out
127-
if np.array(item).size != 0
130+
expected_out = [expected_out]
131+
compiled_out = [compiled_out]
132+
if isinstance(expected_out, list) or isinstance(expected_out, tuple):
133+
expected_out = [
134+
regular_item(item) for item in expected_out if np.array(item).size != 0
128135
]
129-
new_compiled = [
130-
np.array(item).astype("float32")
131-
for item in compiled_out
132-
if np.array(item).size != 0
136+
compiled_out = [
137+
regular_item(item) for item in compiled_out if np.array(item).size != 0
133138
]
134-
expected_out = new_expected
135-
compiled_out = new_compiled
136139
else:
137140
raise ValueError("Illegal return value.")
138141

graph_net/paddle/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def replay_tensor(info):
184184
return paddle.reshape(info["data"], shape).to(dtype).to(device)
185185
elif dtype == paddle.int32 or dtype == paddle.int64:
186186
return paddle.cast(
187-
paddle.randint(low=0, high=1, shape=shape, dtype="int64"),
187+
paddle.randint(low=0, high=2, shape=shape, dtype="int64"),
188188
dtype,
189189
).to(device)
190190
elif dtype == paddle.bool:

0 commit comments

Comments
 (0)