Skip to content

Commit 9fe10d6

Browse files
hxzd5568Xreki
andauthored
Fix validate and test_compiler for paddle (#225)
* Update validate.py * minor fix * Fix test_compiler * Fix test_compiler of paddle. * fix inputs generation error and acc calculation error * minor fix * delete unused expr * delete unused expr --------- Co-authored-by: Liu Yiqun <[email protected]>
1 parent bdfbfeb commit 9fe10d6

File tree

2 files changed

+41
-17
lines changed

2 files changed

+41
-17
lines changed

graph_net/paddle/test_compiler.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,26 @@ def get_input_spec(args):
7373
inputs_params_list = utils.load_converted_list_from_text(f"{args.model_path}")
7474
input_spec = [None] * len(inputs_params_list)
7575
for i, v in enumerate(inputs_params_list):
76+
name = v["name"]
7677
dtype = v["info"]["dtype"]
7778
shape = v["info"]["shape"]
79+
# print(f"-- i: {i}, v: name={name}, shape={shape}, dtype={dtype}")
7880
input_spec[i] = paddle.static.InputSpec(shape, dtype)
7981
return input_spec
8082

8183

84+
def regular_item(item):
85+
if isinstance(item, paddle.Tensor) and (
86+
item.dtype == paddle.bfloat16 or item.dtype == paddle.bfloat32
87+
):
88+
item = np.array(item.astype("float32"))
89+
else:
90+
item = np.array(item)
91+
if item.dtype == np.bool_:
92+
item = item.astype("float32")
93+
return item
94+
95+
8296
def test_single_model(args):
8397
synchronizer_func = get_synchronizer_func(args)
8498
input_dict = get_input_dict(args)
@@ -88,20 +102,18 @@ def test_single_model(args):
88102
build_strategy.build_cinn_pass = False
89103

90104
# eager
91-
model = paddle.jit.to_static(
92-
model_dy,
93-
full_graph=False,
94-
)
95-
model.eval()
105+
print("-- Run with eager mode")
106+
model_dy.eval()
96107
for _ in range(args.warmup if args.warmup > 0 else 0):
97-
model(**input_dict)
108+
model_dy(**input_dict)
98109
eager_duration_box = DurationBox(-1)
99110
with naive_timer(eager_duration_box, synchronizer_func):
100-
expected_out = model(**input_dict)
111+
expected_out = model_dy(**input_dict)
101112

102113
# compiled
114+
print("-- Run with compiled mode")
103115
build_strategy = paddle.static.BuildStrategy()
104-
build_strategy.build_cinn_pass = True
116+
# build_strategy.build_cinn_pass = True
105117
compiled_model = paddle.jit.to_static(
106118
model_dy,
107119
input_spec=input_spec,
@@ -115,11 +127,15 @@ def test_single_model(args):
115127
with naive_timer(compiled_duration_box, synchronizer_func):
116128
compiled_out = compiled_model(**input_dict)
117129
if isinstance(expected_out, paddle.Tensor):
118-
expected_out = expected_out.numpy()
119-
compiled_out = compiled_out.numpy()
120-
elif isinstance(expected_out, list) or isinstance(expected_out, tuple):
121-
expected_out = expected_out[0].numpy()
122-
compiled_out = compiled_out[0].numpy()
130+
expected_out = [expected_out]
131+
compiled_out = [compiled_out]
132+
if isinstance(expected_out, list) or isinstance(expected_out, tuple):
133+
expected_out = [
134+
regular_item(item) for item in expected_out if np.array(item).size != 0
135+
]
136+
compiled_out = [
137+
regular_item(item) for item in compiled_out if np.array(item).size != 0
138+
]
123139
else:
124140
raise ValueError("Illegal return value.")
125141

graph_net/paddle/utils.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import argparse
77
import importlib
88
import inspect
9+
import ast
910
import paddle
1011

1112

@@ -115,8 +116,7 @@ def load_converted_list_from_text(file_path):
115116
weight_info = [
116117
data for data in convert_meta_classes_to_tensors(f"{file_path}/weight_meta.py")
117118
]
118-
119-
return [*input_info, *weight_info]
119+
return [*weight_info, *input_info]
120120

121121

122122
def convert_meta_classes_to_tensors(file_path):
@@ -152,10 +152,17 @@ def convert_meta_classes_to_tensors(file_path):
152152

153153

154154
def _get_classes(file_path):
155+
with open(file_path, "r", encoding="utf-8") as f:
156+
tree = ast.parse(f.read(), filename=file_path)
157+
158+
class_names = [node.name for node in tree.body if isinstance(node, ast.ClassDef)]
159+
155160
spec = importlib.util.spec_from_file_location("unnamed", file_path)
156161
unnamed = importlib.util.module_from_spec(spec)
157162
spec.loader.exec_module(unnamed)
158-
yield from inspect.getmembers(unnamed, inspect.isclass)
163+
164+
classes = [(name, getattr(unnamed, name)) for name in class_names]
165+
return classes
159166

160167

161168
def extract_dynamic_shapes(example_inputs):
@@ -173,8 +180,9 @@ def replay_tensor(info):
173180
if "data" in info and info["data"] is not None:
174181
return paddle.reshape(info["data"], shape).to(dtype).to(device)
175182
elif dtype == paddle.int32 or dtype == paddle.int64:
183+
# for some ops(binary_cross_entropy), label data can only be set 0 or 1.
176184
return paddle.cast(
177-
paddle.randint(low=min_value, high=max_value, shape=shape, dtype="int64"),
185+
paddle.randint(low=0, high=2, shape=shape, dtype="int64"),
178186
dtype,
179187
).to(device)
180188
elif dtype == paddle.bool:

0 commit comments

Comments
 (0)