Skip to content

Commit ba11302

Browse files
committed
Merge branch 'develop' into add_albert_t5
2 parents bb21125 + f8d61c2 commit ba11302

File tree

2,330 files changed

+1089423
-143604
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

2,330 files changed

+1089423
-143604
lines changed

graph_net/paddle/validate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def _extract_forward_source(model_path, class_name):
4949
def check_graph_hash(args):
5050
model_path = args.model_path
5151
file_path = f"{model_path}/graph_hash.txt"
52-
if args.dump_graph_hash_key:
52+
if not args.no_dump_graph_hash_key:
5353
model_str = _extract_forward_source(model_path, class_name="GraphModule")
5454
assert model_str is not None, f"model_str of {args.model_path} is None."
5555
new_hash_text = _get_sha_hash(model_str)
@@ -128,9 +128,9 @@ def main(args):
128128
help="whether check model graph redundancy",
129129
)
130130
parser.add_argument(
131-
"--dump-graph-hash-key",
131+
"--no-dump-graph-hash-key",
132132
action="store_true",
133-
default=True,
133+
default=False,
134134
help="Dump graph hash key",
135135
)
136136
parser.add_argument(

graph_net/test_compiler_util.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import os
2+
3+
4+
def tolerance_generator(t):
5+
# for float16
6+
yield 10 ** (t * 3 / 5), 10**t
7+
# for bfloat16
8+
yield 10 ** (t * 1.796 / 5), 10**t
9+
# yield float32
10+
yield 10 ** (t * 5.886 / 5), 10**t
11+
# yield float64
12+
yield 10 ** (t * 7 / 5), 10 ** (t * 7 / 5)
13+
14+
15+
def calculate_tolerance_pair(begin, end):
16+
tolerance_pair_list = []
17+
for t in range(begin, end + 1):
18+
for rtol, atol in tolerance_generator(t):
19+
effective_atol = float(f"{atol:.3g}")
20+
effective_rtol = float(f"{rtol:.3g}")
21+
tolerance_pair_list.append(
22+
{
23+
"atol": effective_atol,
24+
"rtol": effective_rtol,
25+
}
26+
)
27+
return tolerance_pair_list
28+
29+
30+
def generate_allclose_configs(cmp_all_close_func):
31+
tolerance_pair_list = calculate_tolerance_pair(-10, 5)
32+
33+
cmp_configs = []
34+
for pair in tolerance_pair_list:
35+
atol, rtol = pair["atol"], pair["rtol"]
36+
cmp_configs.append(
37+
(f"[all_close_atol_{atol:.2E}_rtol_{rtol:.2E}]", cmp_all_close_func, pair)
38+
)
39+
return cmp_configs
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
4695a657867f0039c4ff42f6aecd6345cf777a6a14311274b8358a577c69c09d
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"framework": "paddle",
3+
"model_name": "blenderbot-1B-distill",
4+
"num_devices_required": 1,
5+
"num_nodes_required": 1
6+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
class Program_weight_tensor_data_0:
2+
name = "data_0"
3+
shape = [1, 23]
4+
dtype = "int64"
5+
data = [
6+
6950,
7+
19,
8+
395,
9+
1356,
10+
315,
11+
7140,
12+
21,
13+
281,
14+
632,
15+
3547,
16+
458,
17+
1966,
18+
3244,
19+
5837,
20+
298,
21+
549,
22+
7278,
23+
277,
24+
523,
25+
1499,
26+
21,
27+
228,
28+
2,
29+
]

0 commit comments

Comments
 (0)