Skip to content

Commit 11a435a

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/GraphNet into develop
2 parents 2907b1f + 0bf4827 commit 11a435a

File tree

11 files changed

+1496
-37
lines changed

11 files changed

+1496
-37
lines changed

graph_net/paddle/test_compiler.py

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,24 @@ def get_input_spec(args):
7373
inputs_params_list = utils.load_converted_list_from_text(f"{args.model_path}")
7474
input_spec = [None] * len(inputs_params_list)
7575
for i, v in enumerate(inputs_params_list):
76+
name = v["name"]
7677
dtype = v["info"]["dtype"]
7778
shape = v["info"]["shape"]
79+
# print(f"-- i: {i}, v: name={name}, shape={shape}, dtype={dtype}")
7880
input_spec[i] = paddle.static.InputSpec(shape, dtype)
7981
return input_spec
8082

8183

84+
def regular_item(item):
    """Convert one model output into a NumPy array.

    A ``paddle.Tensor`` whose dtype is ``paddle.bfloat16`` is cast to
    ``float32`` before being handed to ``np.array``; every other value is
    handed to ``np.array`` as-is. If the resulting array is boolean it is
    cast to ``float32`` as well.

    Args:
        item: A ``paddle.Tensor`` or any value ``np.array`` accepts.

    Returns:
        The converted ``numpy.ndarray``.
    """
    is_bf16_tensor = isinstance(item, paddle.Tensor) and item.dtype == paddle.bfloat16
    source = item.astype("float32") if is_bf16_tensor else item
    array = np.array(source)
    return array.astype("float32") if array.dtype == np.bool_ else array
92+
93+
8294
def test_single_model(args):
8395
synchronizer_func = get_synchronizer_func(args)
8496
input_dict = get_input_dict(args)
@@ -88,21 +100,18 @@ def test_single_model(args):
88100
build_strategy.build_cinn_pass = False
89101

90102
# eager
91-
model = paddle.jit.to_static(
92-
model_dy,
93-
full_graph=False,
94-
)
95-
model.eval()
103+
print("-- Run with eager mode")
104+
model_dy.eval()
96105
for _ in range(args.warmup if args.warmup > 0 else 0):
97-
model(**input_dict)
106+
model_dy(**input_dict)
98107
eager_duration_box = DurationBox(-1)
99108
with naive_timer(eager_duration_box, synchronizer_func):
100-
expected_out = model(**input_dict)
109+
expected_out = model_dy(**input_dict)
101110

102111
# compiled
112+
print("-- Run with compiled mode")
103113
build_strategy = paddle.static.BuildStrategy()
104-
build_strategy.build_cinn_pass = True
105-
compilation_start_time = time.time()
114+
# build_strategy.build_cinn_pass = True
106115
compiled_model = paddle.jit.to_static(
107116
model_dy,
108117
input_spec=input_spec,
@@ -118,15 +127,25 @@ def test_single_model(args):
118127
with naive_timer(compiled_duration_box, synchronizer_func):
119128
compiled_out = compiled_model(**input_dict)
120129

121-
expected_out_list = (
122-
expected_out if isinstance(expected_out, (list, tuple)) else [expected_out]
123-
)
124-
compiled_out_list = (
125-
compiled_out if isinstance(compiled_out, (list, tuple)) else [compiled_out]
126-
)
127-
128-
processed_expected_out = [t.numpy() for t in expected_out_list]
129-
processed_compiled_out = [t.numpy() for t in compiled_out_list]
130+
if isinstance(expected_out, paddle.Tensor):
131+
expected_out = [expected_out]
132+
compiled_out = [compiled_out]
133+
if isinstance(expected_out, list) or isinstance(expected_out, tuple):
134+
for a, b in zip(expected_out, compiled_out):
135+
if (a is None and b is not None) or (a is not None and b is None):
136+
raise ValueError("Both expected_out and compiled_out must be not None.")
137+
expected_out = [
138+
regular_item(item)
139+
for item in expected_out
140+
if item is not None and np.array(item).size != 0
141+
]
142+
compiled_out = [
143+
regular_item(item)
144+
for item in compiled_out
145+
if item is not None and np.array(item).size != 0
146+
]
147+
else:
148+
raise ValueError("Illegal return value.")
130149

131150
def print_cmp(key, func, **kwargs):
132151
cmp_ret = func(processed_expected_out, processed_compiled_out, **kwargs)

graph_net/paddle/utils.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import argparse
77
import importlib
88
import inspect
9+
import ast
910
import paddle
1011

1112

@@ -115,8 +116,7 @@ def load_converted_list_from_text(file_path):
115116
weight_info = [
116117
data for data in convert_meta_classes_to_tensors(f"{file_path}/weight_meta.py")
117118
]
118-
119-
return [*input_info, *weight_info]
119+
return [*weight_info, *input_info]
120120

121121

122122
def convert_meta_classes_to_tensors(file_path):
@@ -127,32 +127,42 @@ def convert_meta_classes_to_tensors(file_path):
127127
if not k.startswith("__") and not callable(v)
128128
}
129129
data_value = None
130-
data_type = getattr(paddle, attrs.get("dtype", "paddle.float").split(".")[-1])
130+
data_type = getattr(paddle, attrs.get("dtype", "float32"))
131131
if attrs.get("data") is not None:
132132
if isinstance(attrs.get("data"), str):
133133
raise ValueError("Unimplemented")
134134
else:
135-
data_value = paddle.to_tensor(
136-
attrs.get("data"), dtype=data_type
137-
).reshape(attrs.get("shape"), [])
135+
data_value = paddle.reshape(
136+
paddle.to_tensor(attrs.get("data"), dtype=data_type),
137+
attrs.get("shape", []),
138+
)
138139
yield {
139140
"info": {
140141
"shape": attrs.get("shape", []),
141142
"dtype": data_type,
142143
"device": attrs.get("device", "gpu"),
143144
"mean": attrs.get("mean", 0.0),
144145
"std": attrs.get("std", 1.0),
146+
"low": attrs.get("low", 0),
147+
"high": attrs.get("high", 2),
145148
},
146149
"data": data_value,
147150
"name": attrs.get("name"),
148151
}
149152

150153

151154
def _get_classes(file_path):
155+
with open(file_path, "r", encoding="utf-8") as f:
156+
tree = ast.parse(f.read(), filename=file_path)
157+
158+
class_names = [node.name for node in tree.body if isinstance(node, ast.ClassDef)]
159+
152160
spec = importlib.util.spec_from_file_location("unnamed", file_path)
153161
unnamed = importlib.util.module_from_spec(spec)
154162
spec.loader.exec_module(unnamed)
155-
yield from inspect.getmembers(unnamed, inspect.isclass)
163+
164+
classes = [(name, getattr(unnamed, name)) for name in class_names]
165+
return classes
156166

157167

158168
def extract_dynamic_shapes(example_inputs):
@@ -163,11 +173,28 @@ def replay_tensor(info):
163173
device = info["info"]["device"]
164174
dtype = info["info"]["dtype"]
165175
shape = info["info"]["shape"]
176+
min_value = info["info"]["low"] if "low" in info["info"] else 0
177+
max_value = info["info"]["high"] if "high" in info["info"] else 0.5
166178
if None in shape:
167179
shape = list(map(lambda i: i if i is not None else 1, shape))
168-
mean = info["info"]["mean"]
169-
std = info["info"]["std"]
170180
if "data" in info and info["data"] is not None:
171-
return info["data"].to(device)
172-
173-
return (paddle.randn(shape).cast(dtype).to(device) * std * 1e-3 + 1e-2).cast(dtype)
181+
return paddle.reshape(info["data"], shape).to(dtype).to(device)
182+
elif dtype == paddle.int32 or dtype == paddle.int64:
183+
# for some ops(binary_cross_entropy), label data can only be set 0 or 1.
184+
return paddle.cast(
185+
paddle.randint(low=0, high=2, shape=shape, dtype="int64"),
186+
dtype,
187+
).to(device)
188+
elif dtype == paddle.bool:
189+
return paddle.cast(
190+
paddle.randint(low=0, high=2, shape=shape, dtype="int32"),
191+
paddle.bool,
192+
).to(device)
193+
else:
194+
std = info["info"]["std"]
195+
# return paddle.randn(shape).to(dtype).to(device) * std * 1e-3 + 1e-2
196+
return (
197+
paddle.uniform(shape, dtype="float32", min=min_value, max=max_value)
198+
.to(dtype)
199+
.to(device)
200+
)

graph_net/paddle/validate.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,16 +66,13 @@ def main(args):
6666
params.update(inputs)
6767
state_dict = {k: utils.replay_tensor(v) for k, v in params.items()}
6868

69-
y = model(**state_dict)[0]
69+
y = model(**state_dict)
7070

71-
print(np.argmin(y), np.argmax(y))
71+
# print(np.argmin(y), np.argmax(y))
7272
if isinstance(y, paddle.Tensor):
7373
print(y.shape)
74-
elif (isinstance(y, list) or isinstance(y, tuple)) and all(
75-
isinstance(obj, paddle.Tensor) for obj in y
76-
):
77-
# list of paddle.Tensor
78-
print(y[0].shape)
74+
elif isinstance(y, list) or isinstance(y, tuple):
75+
print(y[0].shape if isinstance(y[0], paddle.Tensor) else y[0])
7976
else:
8077
raise ValueError("Illegal return value.")
8178

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import unittest
2+
from graph_net.torch.rp_expr import Tokenize
3+
from graph_net.torch.rp_expr.rp_expr_passes import (
4+
FlattenTokenListPass,
5+
FoldTokensPass,
6+
RecursiveFoldTokensPass,
7+
FoldIfTokenIdGreatEqualPass,
8+
)
9+
from graph_net.torch.rp_expr.nested_range import Range, Tree
10+
from graph_net.torch.rp_expr.rp_expr_parser import RpExprParser
11+
from graph_net.torch.rp_expr.rp_expr_util import (
12+
MakeNestedIndexRangeFromLetsListTokenRpExpr,
13+
)
14+
15+
16+
class TestTokenize(unittest.TestCase):
    """Checks that ``Tokenize`` yields one tensor per primitive ID list."""

    def test_simple(self):
        # Five lists of increasing length: range(10), range(11), ..., range(14).
        id_lists = [list(range(length)) for length in range(10, 15)]
        tokens, _allocator, _ = Tokenize(id_lists)
        self.assertEqual(len(tokens.tensors), len(id_lists))
23+
24+
25+
class TestFlattenTokenListPass(unittest.TestCase):
    """Checks ``FlattenTokenListPass`` on a small tokenized input."""

    def test_simple(self):
        start, count = 10, 5
        id_lists = [list(range(start + offset)) for offset in range(count)]
        tokens, allocator, _ = Tokenize(id_lists)
        flatten = FlattenTokenListPass(allocator)
        ok, _flattened = flatten(tokens)
        self.assertTrue(ok)
        # The pass is expected to leave the allocator's next token id here.
        self.assertEqual(allocator.NextTokenId(), start + 2 * count - 1)
37+
38+
39+
class TestFoldTokensPass(unittest.TestCase):
    """Checks a single ``FoldTokensPass`` run on a flattened token sequence."""

    def test_simple(self):
        start, count = 3, 3
        id_lists = [list(range(start + offset)) for offset in range(count)]
        tokens, allocator, _ = Tokenize(id_lists)
        flatten = FlattenTokenListPass(allocator)
        _, flattened = flatten(tokens)
        fold = FoldTokensPass(allocator)
        ok, folded = fold(flattened.flattened_tensor)
        self.assertTrue(ok)

        # Pull everything out as plain Python lists before comparing.
        flat_ids = flattened.flattened_tensor.tensor.numpy().tolist()
        folded_pattern = folded.symbol_token_tensors[0].numpy().tolist()
        symbol_id = folded.symbol_token_ids[0]
        body_ids = folded.body_rp_expr.tensor.numpy().tolist()

        self.assertEqual(flat_ids, [3, 4, 5, 1, 3, 4, 5, 6, 2, 3, 4, 5, 6, 7])
        # The run [3, 4, 5] is replaced by the new symbol 8.
        self.assertEqual(folded_pattern, [3, 4, 5])
        self.assertEqual(symbol_id, 8)
        self.assertEqual(body_ids, [8, 1, 8, 6, 2, 8, 6, 7])
60+
61+
62+
class TestRecursiveFoldTokensPass(unittest.TestCase):
    """Checks ``RecursiveFoldTokensPass`` on a flattened token sequence."""

    def test_simple(self):
        start, count = 3, 3
        id_lists = [list(range(start + offset)) for offset in range(count)]
        tokens, allocator, _ = Tokenize(id_lists)
        flatten = FlattenTokenListPass(allocator)
        _, flattened = flatten(tokens)
        recursive_fold = RecursiveFoldTokensPass(allocator)
        ok, folded = recursive_fold(flattened.flattened_tensor)
        self.assertTrue(ok)

        flat_ids = flattened.flattened_tensor.tensor.numpy().tolist()
        folded_patterns = [t.numpy().tolist() for t in folded.symbol_token_tensors]
        symbol_ids = folded.symbol_token_ids
        body_ids = folded.body_rp_expr.tensor.numpy().tolist()

        self.assertEqual(flat_ids, [3, 4, 5, 1, 3, 4, 5, 6, 2, 3, 4, 5, 6, 7])
        # Two symbols are produced: 8 for [3, 4, 5], then 9 for [8, 6].
        self.assertEqual(folded_patterns, [[3, 4, 5], [8, 6]])
        self.assertEqual(symbol_ids, [8, 9])
        self.assertEqual(body_ids, [8, 1, 9, 2, 9, 7])
83+
84+
85+
class TestFoldIfTokenIdGreatEqualPass(unittest.TestCase):
    """Checks ``FoldIfTokenIdGreatEqualPass`` on a recursively folded sequence."""

    def test_simple(self):
        start, count = 3, 3
        id_lists = [list(range(start + offset)) for offset in range(count)]
        tokens, allocator, _ = Tokenize(id_lists)
        flatten = FlattenTokenListPass(allocator)
        _, flattened = flatten(tokens)
        recursive_fold = RecursiveFoldTokensPass(allocator)
        recursive_ok, recursive = recursive_fold(flattened.flattened_tensor)
        self.assertTrue(recursive_ok)

        # The threshold passed to the pass is the number of input lists.
        threshold_fold = FoldIfTokenIdGreatEqualPass(
            id_allocator=allocator,
            threshold_start_token_id=len(id_lists),
        )
        threshold_ok, thresholded = threshold_fold(recursive.body_rp_expr)
        self.assertTrue(threshold_ok)

        body_before = recursive.body_rp_expr.tensor.numpy().tolist()
        folded_patterns = [
            t.numpy().tolist() for t in thresholded.symbol_token_tensors
        ]
        symbol_ids = thresholded.symbol_token_ids
        self.assertEqual(len(thresholded.body_rp_expr), 3)
        body_after = [t.numpy().tolist() for t in thresholded.body_rp_expr]

        self.assertEqual(body_before, [8, 1, 9, 2, 9, 7])
        self.assertEqual(folded_patterns, [[9, 7]])
        self.assertEqual(symbol_ids, [10])
        self.assertEqual(body_after, [[8], [9], [10]])
115+
116+
117+
if __name__ == "__main__":
118+
unittest.main()
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from .rp_expr import Tokenize
2+
from .rp_expr_parser import RpExprParser
3+
from .nested_range import Range, Tree
4+
from .rp_expr_util import MakeNestedIndexRangeFromLetsListTokenRpExpr
5+
from .rp_expr_passes import (
6+
FlattenTokenListPass,
7+
FoldTokensPass,
8+
RecursiveFoldTokensPass,
9+
FoldIfTokenIdGreatEqualPass,
10+
)
11+
12+
__all__ = [
13+
"Tokenize",
14+
"RpExprParser",
15+
"Range",
16+
"Tree",
17+
"MakeNestedIndexRangeFromLetsListTokenRpExpr",
18+
"FlattenTokenListPass",
19+
"FoldTokensPass",
20+
"RecursiveFoldTokensPass",
21+
"FoldIfTokenIdGreatEqualPass",
22+
]

0 commit comments

Comments
 (0)