|
| 1 | +import os |
| 2 | + |
| 3 | + |
| 4 | +def tolerance_generator(t): |
| 5 | + # for float16 |
| 6 | + yield 10 ** (t * 3 / 5), 10**t |
| 7 | + # for bfloat16 |
| 8 | + yield 10 ** (t * 1.796 / 5), 10**t |
| 9 | + # yield float32 |
| 10 | + yield 10 ** (t * 5.886 / 5), 10**t |
| 11 | + # yield float64 |
| 12 | + yield 10 ** (t * 7 / 5), 10 ** (t * 7 / 5) |
| 13 | + |
| 14 | + |
| 15 | +def calculate_tolerance_pair(begin, end): |
| 16 | + tolerance_pair_list = [] |
| 17 | + for t in range(begin, end + 1): |
| 18 | + for rtol, atol in tolerance_generator(t): |
| 19 | + effective_atol = float(f"{atol:.3g}") |
| 20 | + effective_rtol = float(f"{rtol:.3g}") |
| 21 | + tolerance_pair_list.append( |
| 22 | + { |
| 23 | + "atol": effective_atol, |
| 24 | + "rtol": effective_rtol, |
| 25 | + } |
| 26 | + ) |
| 27 | + return tolerance_pair_list |
| 28 | + |
| 29 | + |
| 30 | +def generate_allclose_configs(cmp_all_close_func): |
| 31 | + tolerance_pair_list = calculate_tolerance_pair(-10, 5) |
| 32 | + |
| 33 | + cmp_configs = [] |
| 34 | + for pair in tolerance_pair_list: |
| 35 | + atol, rtol = pair["atol"], pair["rtol"] |
| 36 | + cmp_configs.append( |
| 37 | + (f"[all_close_atol_{atol:.2E}_rtol_{rtol:.2E}]", cmp_all_close_func, pair) |
| 38 | + ) |
| 39 | + return cmp_configs |
0 commit comments