Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 29 additions & 13 deletions graph_net/paddle/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,26 @@ def get_input_spec(args):
inputs_params_list = utils.load_converted_list_from_text(f"{args.model_path}")
input_spec = [None] * len(inputs_params_list)
for i, v in enumerate(inputs_params_list):
name = v["name"]
dtype = v["info"]["dtype"]
shape = v["info"]["shape"]
# print(f"-- i: {i}, v: name={name}, shape={shape}, dtype={dtype}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

debug信息帮忙删除一下吧

input_spec[i] = paddle.static.InputSpec(shape, dtype)
return input_spec


def regular_item(item):
    """Normalize a single model output into a numpy array for comparison.

    Args:
        item: A ``paddle.Tensor`` or any value ``np.array`` accepts.

    Returns:
        ``np.ndarray``. bfloat16 tensors are upcast to float32 first
        (numpy has no bfloat16 dtype), and boolean arrays are converted
        to float32 so numeric closeness checks work uniformly.
    """
    # NOTE: the original checked `paddle.bfloat32`, which is not a real
    # Paddle dtype — evaluating it raised AttributeError for every tensor
    # whose dtype was not bfloat16. Only bfloat16 needs the upcast.
    if isinstance(item, paddle.Tensor) and item.dtype == paddle.bfloat16:
        item = np.array(item.astype("float32"))
    else:
        item = np.array(item)
    if item.dtype == np.bool_:
        # Booleans become 0.0/1.0 so downstream allclose-style checks apply.
        item = item.astype("float32")
    return item


def test_single_model(args):
synchronizer_func = get_synchronizer_func(args)
input_dict = get_input_dict(args)
Expand All @@ -88,20 +102,18 @@ def test_single_model(args):
build_strategy.build_cinn_pass = False

# eager
model = paddle.jit.to_static(
model_dy,
full_graph=False,
)
model.eval()
print("-- Run with eager mode")
model_dy.eval()
for _ in range(args.warmup if args.warmup > 0 else 0):
model(**input_dict)
model_dy(**input_dict)
eager_duration_box = DurationBox(-1)
with naive_timer(eager_duration_box, synchronizer_func):
expected_out = model(**input_dict)
expected_out = model_dy(**input_dict)

# compiled
print("-- Run with compiled mode")
build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = True
# build_strategy.build_cinn_pass = True
compiled_model = paddle.jit.to_static(
model_dy,
input_spec=input_spec,
Expand All @@ -115,11 +127,15 @@ def test_single_model(args):
with naive_timer(compiled_duration_box, synchronizer_func):
compiled_out = compiled_model(**input_dict)
if isinstance(expected_out, paddle.Tensor):
expected_out = expected_out.numpy()
compiled_out = compiled_out.numpy()
elif isinstance(expected_out, list) or isinstance(expected_out, tuple):
expected_out = expected_out[0].numpy()
compiled_out = compiled_out[0].numpy()
expected_out = [expected_out]
compiled_out = [compiled_out]
if isinstance(expected_out, list) or isinstance(expected_out, tuple):
expected_out = [
regular_item(item) for item in expected_out if np.array(item).size != 0
]
compiled_out = [
regular_item(item) for item in compiled_out if np.array(item).size != 0
]
else:
raise ValueError("Illegal return value.")

Expand Down
16 changes: 12 additions & 4 deletions graph_net/paddle/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import argparse
import importlib
import inspect
import ast
import paddle


Expand Down Expand Up @@ -115,8 +116,7 @@ def load_converted_list_from_text(file_path):
weight_info = [
data for data in convert_meta_classes_to_tensors(f"{file_path}/weight_meta.py")
]

return [*input_info, *weight_info]
return [*weight_info, *input_info]


def convert_meta_classes_to_tensors(file_path):
Expand Down Expand Up @@ -152,10 +152,17 @@ def convert_meta_classes_to_tensors(file_path):


def _get_classes(file_path):
with open(file_path, "r", encoding="utf-8") as f:
tree = ast.parse(f.read(), filename=file_path)

class_names = [node.name for node in tree.body if isinstance(node, ast.ClassDef)]

spec = importlib.util.spec_from_file_location("unnamed", file_path)
unnamed = importlib.util.module_from_spec(spec)
spec.loader.exec_module(unnamed)
yield from inspect.getmembers(unnamed, inspect.isclass)

classes = [(name, getattr(unnamed, name)) for name in class_names]
return classes


def extract_dynamic_shapes(example_inputs):
Expand All @@ -173,8 +180,9 @@ def replay_tensor(info):
if "data" in info and info["data"] is not None:
return paddle.reshape(info["data"], shape).to(dtype).to(device)
elif dtype == paddle.int32 or dtype == paddle.int64:
# for some ops(binary_cross_entropy), label data can only be set 0 or 1.
return paddle.cast(
paddle.randint(low=min_value, high=max_value, shape=shape, dtype="int64"),
paddle.randint(low=0, high=2, shape=shape, dtype="int64"),
dtype,
).to(device)
elif dtype == paddle.bool:
Expand Down