Skip to content

Commit bb30a84

Browse files
authored
Implement a common function to calculate the tolerance configs. (#286)
1 parent 3c3fc04 commit bb30a84

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

graph_net/test_compiler_util.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import os
2+
3+
4+
def tolerance_generator(t):
5+
# for float16
6+
yield 10 ** (t * 3 / 5), 10**t
7+
# for bfloat16
8+
yield 10 ** (t * 1.796 / 5), 10**t
9+
# yield float32
10+
yield 10 ** (t * 5.886 / 5), 10**t
11+
# yield float64
12+
yield 10 ** (t * 7 / 5), 10 ** (t * 7 / 5)
13+
14+
15+
def calculate_tolerance_pair(begin, end):
16+
tolerance_pair_list = []
17+
for t in range(begin, end + 1):
18+
for rtol, atol in tolerance_generator(t):
19+
effective_atol = float(f"{atol:.3g}")
20+
effective_rtol = float(f"{rtol:.3g}")
21+
tolerance_pair_list.append(
22+
{
23+
"atol": effective_atol,
24+
"rtol": effective_rtol,
25+
}
26+
)
27+
return tolerance_pair_list
28+
29+
30+
def generate_allclose_configs(cmp_all_close_func):
31+
tolerance_pair_list = calculate_tolerance_pair(-10, 5)
32+
33+
cmp_configs = []
34+
for pair in tolerance_pair_list:
35+
atol, rtol = pair["atol"], pair["rtol"]
36+
cmp_configs.append(
37+
(f"[all_close_atol_{atol:.2E}_rtol_{rtol:.2E}]", cmp_all_close_func, pair)
38+
)
39+
return cmp_configs

0 commit comments

Comments
 (0)