Skip to content

Commit bdfbfeb

Browse files
authored
Update validate.py (#221)
* Update validate.py * minor fix
1 parent 6053442 commit bdfbfeb

File tree

3 files changed

+38
-16
lines changed

3 files changed

+38
-16
lines changed

graph_net/paddle/test_compiler.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,14 @@ def test_single_model(args):
114114
compiled_duration_box = DurationBox(-1)
115115
with naive_timer(compiled_duration_box, synchronizer_func):
116116
compiled_out = compiled_model(**input_dict)
117-
expected_out = expected_out.numpy()
118-
compiled_out = compiled_out.numpy()
117+
if isinstance(expected_out, paddle.Tensor):
118+
expected_out = expected_out.numpy()
119+
compiled_out = compiled_out.numpy()
120+
elif isinstance(expected_out, list) or isinstance(expected_out, tuple):
121+
expected_out = expected_out[0].numpy()
122+
compiled_out = compiled_out[0].numpy()
123+
else:
124+
raise ValueError("Illegal return value.")
119125

120126
def print_cmp(key, func, **kwargs):
121127
cmp_ret = func(expected_out, compiled_out, **kwargs)

graph_net/paddle/utils.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -127,21 +127,24 @@ def convert_meta_classes_to_tensors(file_path):
127127
if not k.startswith("__") and not callable(v)
128128
}
129129
data_value = None
130-
data_type = getattr(paddle, attrs.get("dtype", "paddle.float").split(".")[-1])
130+
data_type = getattr(paddle, attrs.get("dtype", "float32"))
131131
if attrs.get("data") is not None:
132132
if isinstance(attrs.get("data"), str):
133133
raise ValueError("Unimplemented")
134134
else:
135-
data_value = paddle.to_tensor(
136-
attrs.get("data"), dtype=data_type
137-
).reshape(attrs.get("shape"), [])
135+
data_value = paddle.reshape(
136+
paddle.to_tensor(attrs.get("data"), dtype=data_type),
137+
attrs.get("shape", []),
138+
)
138139
yield {
139140
"info": {
140141
"shape": attrs.get("shape", []),
141142
"dtype": data_type,
142143
"device": attrs.get("device", "gpu"),
143144
"mean": attrs.get("mean", 0.0),
144145
"std": attrs.get("std", 1.0),
146+
"low": attrs.get("low", 0),
147+
"high": attrs.get("high", 2),
145148
},
146149
"data": data_value,
147150
"name": attrs.get("name"),
@@ -163,11 +166,27 @@ def replay_tensor(info):
163166
device = info["info"]["device"]
164167
dtype = info["info"]["dtype"]
165168
shape = info["info"]["shape"]
169+
min_value = info["info"]["low"] if "low" in info["info"] else 0
170+
max_value = info["info"]["high"] if "high" in info["info"] else 0.5
166171
if None in shape:
167172
shape = list(map(lambda i: i if i is not None else 1, shape))
168-
mean = info["info"]["mean"]
169-
std = info["info"]["std"]
170173
if "data" in info and info["data"] is not None:
171-
return info["data"].to(device)
172-
173-
return (paddle.randn(shape).cast(dtype).to(device) * std * 1e-3 + 1e-2).cast(dtype)
174+
return paddle.reshape(info["data"], shape).to(dtype).to(device)
175+
elif dtype == paddle.int32 or dtype == paddle.int64:
176+
return paddle.cast(
177+
paddle.randint(low=min_value, high=max_value, shape=shape, dtype="int64"),
178+
dtype,
179+
).to(device)
180+
elif dtype == paddle.bool:
181+
return paddle.cast(
182+
paddle.randint(low=0, high=2, shape=shape, dtype="int32"),
183+
paddle.bool,
184+
).to(device)
185+
else:
186+
std = info["info"]["std"]
187+
# return paddle.randn(shape).to(dtype).to(device) * std * 1e-3 + 1e-2
188+
return (
189+
paddle.uniform(shape, dtype="float32", min=min_value, max=max_value)
190+
.to(dtype)
191+
.to(device)
192+
)

graph_net/paddle/validate.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,8 @@ def main(args):
7171
print(np.argmin(y), np.argmax(y))
7272
if isinstance(y, paddle.Tensor):
7373
print(y.shape)
74-
elif (isinstance(y, list) or isinstance(y, tuple)) and all(
75-
isinstance(obj, paddle.Tensor) for obj in y
76-
):
77-
# list of paddle.Tensor
78-
print(y[0].shape)
74+
elif isinstance(y, list) or isinstance(y, tuple):
75+
print(y[0].shape if isinstance(y[0], paddle.tensor) else y[0])
7976
else:
8077
raise ValueError("Illegal return value.")
8178

0 commit comments

Comments
 (0)