Skip to content

Commit 82e6feb

Browse files
committed
Fix test_compiler of paddle.
1 parent bdfbfeb commit 82e6feb

File tree

2 files changed

+21
-11
lines changed

2 files changed

+21
-11
lines changed

graph_net/paddle/test_compiler.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,10 @@ def get_input_spec(args):
7373
inputs_params_list = utils.load_converted_list_from_text(f"{args.model_path}")
7474
input_spec = [None] * len(inputs_params_list)
7575
for i, v in enumerate(inputs_params_list):
76+
name = v["name"]
7677
dtype = v["info"]["dtype"]
7778
shape = v["info"]["shape"]
79+
# print(f"-- i: {i}, v: name={name}, shape={shape}, dtype={dtype}")
7880
input_spec[i] = paddle.static.InputSpec(shape, dtype)
7981
return input_spec
8082

@@ -88,20 +90,18 @@ def test_single_model(args):
8890
build_strategy.build_cinn_pass = False
8991

9092
# eager
91-
model = paddle.jit.to_static(
92-
model_dy,
93-
full_graph=False,
94-
)
95-
model.eval()
93+
print("-- Run with eager mode")
94+
model_dy.eval()
9695
for _ in range(args.warmup if args.warmup > 0 else 0):
97-
model(**input_dict)
96+
model_dy(**input_dict)
9897
eager_duration_box = DurationBox(-1)
9998
with naive_timer(eager_duration_box, synchronizer_func):
100-
expected_out = model(**input_dict)
99+
expected_out = model_dy(**input_dict)
101100

102101
# compiled
102+
print("-- Run with compiled mode")
103103
build_strategy = paddle.static.BuildStrategy()
104-
build_strategy.build_cinn_pass = True
104+
# build_strategy.build_cinn_pass = True
105105
compiled_model = paddle.jit.to_static(
106106
model_dy,
107107
input_spec=input_spec,

graph_net/paddle/utils.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import argparse
77
import importlib
88
import inspect
9+
import ast
910
import paddle
1011

1112

@@ -112,15 +113,16 @@ def load_converted_list_from_text(file_path):
112113
input_info = [
113114
data for data in convert_meta_classes_to_tensors(f"{file_path}/input_meta.py")
114115
]
116+
# print(f"-- input_info: {input_info}")
115117
weight_info = [
116118
data for data in convert_meta_classes_to_tensors(f"{file_path}/weight_meta.py")
117119
]
118-
119-
return [*input_info, *weight_info]
120+
return [*weight_info, *input_info]
120121

121122

122123
def convert_meta_classes_to_tensors(file_path):
123124
for name, cls in _get_classes(file_path):
125+
# print(f"-- name: {name}")
124126
attrs = {
125127
k: v
126128
for k, v in cls.__dict__.items()
@@ -152,10 +154,18 @@ def convert_meta_classes_to_tensors(file_path):
152154

153155

154156
def _get_classes(file_path):
157+
with open(file_path, "r", encoding="utf-8") as f:
158+
tree = ast.parse(f.read(), filename=file_path)
159+
160+
class_names = [node.name for node in tree.body if isinstance(node, ast.ClassDef)]
161+
155162
spec = importlib.util.spec_from_file_location("unnamed", file_path)
156163
unnamed = importlib.util.module_from_spec(spec)
157164
spec.loader.exec_module(unnamed)
158-
yield from inspect.getmembers(unnamed, inspect.isclass)
165+
# yield from inspect.getmembers(unnamed, inspect.isclass)
166+
167+
classes = [(name, getattr(unnamed, name)) for name in class_names]
168+
return classes
159169

160170

161171
def extract_dynamic_shapes(example_inputs):

0 commit comments

Comments
 (0)