Skip to content

Commit 4f8d8da

Browse files
committed
Correct the tolerance.
1 parent f65c53e commit 4f8d8da

File tree

1 file changed

+41
-324
lines changed

1 file changed

+41
-324
lines changed

graph_net/test_compiler_util.py

Lines changed: 41 additions & 324 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,44 @@ def check_equal(args, expected_out, compiled_out, cmp_equal_func):
159159
)
160160

161161

162+
def tolerance_generator(t):
163+
# for float16
164+
yield 10 ** (t * 3 / 5), 10**t
165+
# for bfloat16
166+
yield 10 ** (t * 1.796 / 5), 10**t
167+
# yield float32
168+
yield 10 ** (t * 5.886 / 5), 10**t
169+
# yield float64
170+
yield 10 ** (t * 7 / 5), 10 ** (t * 7 / 5)
171+
172+
173+
def compute_tolerance_pair(begin, end):
174+
tolerance_pair_list = []
175+
for t in range(begin, end + 1):
176+
for rtol, atol in tolerance_generator(t):
177+
effective_atol = float(f"{atol:.3g}")
178+
effective_rtol = float(f"{rtol:.3g}")
179+
tolerance_pair_list.append(
180+
{
181+
"atol": effective_atol,
182+
"rtol": effective_rtol,
183+
}
184+
)
185+
return tolerance_pair_list
186+
187+
188+
def generate_allclose_configs(cmp_all_close_func):
189+
tolerance_pair_list = compute_tolerance_pair(-10, 5)
190+
191+
cmp_configs = []
192+
for pair in tolerance_pair_list:
193+
atol, rtol = pair["atol"], pair["rtol"]
194+
cmp_configs.append(
195+
(f"[all_close_atol_{atol:.2E}_rtol_{rtol:.2E}]", cmp_all_close_func, pair)
196+
)
197+
return cmp_configs
198+
199+
162200
def check_allclose(
163201
args,
164202
expected_out,
@@ -168,330 +206,9 @@ def check_allclose(
168206
cmp_mean_diff_func,
169207
cmp_diff_count_func,
170208
):
171-
cmp_configs = [
172-
(
173-
"[all_close_atol_1.00E-06_rtol_1.00E-10]",
174-
cmp_all_close_func,
175-
{"atol": 1.00e-06, "rtol": 1.00e-10},
176-
),
177-
(
178-
"[all_close_atol_2.56E-04_rtol_1.00E-10]",
179-
cmp_all_close_func,
180-
{"atol": 2.56e-04, "rtol": 1.00e-10},
181-
),
182-
(
183-
"[all_close_atol_1.69E-12_rtol_1.00E-10]",
184-
cmp_all_close_func,
185-
{"atol": 1.69e-12, "rtol": 1.00e-10},
186-
),
187-
(
188-
"[all_close_atol_1.00E-14_rtol_1.00E-10]",
189-
cmp_all_close_func,
190-
{"atol": 1.00e-14, "rtol": 1.00e-10},
191-
),
192-
(
193-
"[all_close_atol_3.98E-06_rtol_1.00E-09]",
194-
cmp_all_close_func,
195-
{"atol": 3.98e-06, "rtol": 1.00e-09},
196-
),
197-
(
198-
"[all_close_atol_5.85E-04_rtol_1.00E-09]",
199-
cmp_all_close_func,
200-
{"atol": 5.85e-04, "rtol": 1.00e-09},
201-
),
202-
(
203-
"[all_close_atol_2.54E-11_rtol_1.00E-09]",
204-
cmp_all_close_func,
205-
{"atol": 2.54e-11, "rtol": 1.00e-09},
206-
),
207-
(
208-
"[all_close_atol_2.51E-13_rtol_1.00E-09]",
209-
cmp_all_close_func,
210-
{"atol": 2.51e-13, "rtol": 1.00e-09},
211-
),
212-
(
213-
"[all_close_atol_1.58E-05_rtol_1.00E-08]",
214-
cmp_all_close_func,
215-
{"atol": 1.58e-05, "rtol": 1.00e-08},
216-
),
217-
(
218-
"[all_close_atol_1.34E-03_rtol_1.00E-08]",
219-
cmp_all_close_func,
220-
{"atol": 1.34e-03, "rtol": 1.00e-08},
221-
),
222-
(
223-
"[all_close_atol_3.82E-10_rtol_1.00E-08]",
224-
cmp_all_close_func,
225-
{"atol": 3.82e-10, "rtol": 1.00e-08},
226-
),
227-
(
228-
"[all_close_atol_6.31E-12_rtol_1.00E-08]",
229-
cmp_all_close_func,
230-
{"atol": 6.31e-12, "rtol": 1.00e-08},
231-
),
232-
(
233-
"[all_close_atol_6.31E-05_rtol_1.00E-07]",
234-
cmp_all_close_func,
235-
{"atol": 6.31e-05, "rtol": 1.00e-07},
236-
),
237-
(
238-
"[all_close_atol_3.06E-03_rtol_1.00E-07]",
239-
cmp_all_close_func,
240-
{"atol": 3.06e-03, "rtol": 1.00e-07},
241-
),
242-
(
243-
"[all_close_atol_5.75E-09_rtol_1.00E-07]",
244-
cmp_all_close_func,
245-
{"atol": 5.75e-09, "rtol": 1.00e-07},
246-
),
247-
(
248-
"[all_close_atol_1.58E-10_rtol_1.00E-07]",
249-
cmp_all_close_func,
250-
{"atol": 1.58e-10, "rtol": 1.00e-07},
251-
),
252-
(
253-
"[all_close_atol_2.51E-04_rtol_1.00E-06]",
254-
cmp_all_close_func,
255-
{"atol": 2.51e-04, "rtol": 1.00e-06},
256-
),
257-
(
258-
"[all_close_atol_7.00E-03_rtol_1.00E-06]",
259-
cmp_all_close_func,
260-
{"atol": 7.00e-03, "rtol": 1.00e-06},
261-
),
262-
(
263-
"[all_close_atol_8.65E-08_rtol_1.00E-06]",
264-
cmp_all_close_func,
265-
{"atol": 8.65e-08, "rtol": 1.00e-06},
266-
),
267-
(
268-
"[all_close_atol_3.98E-09_rtol_1.00E-06]",
269-
cmp_all_close_func,
270-
{"atol": 3.98e-09, "rtol": 1.00e-06},
271-
),
272-
(
273-
"[all_close_atol_1.00E-03_rtol_1.00E-05]",
274-
cmp_all_close_func,
275-
{"atol": 1.00e-03, "rtol": 1.00e-05},
276-
),
277-
(
278-
"[all_close_atol_1.60E-02_rtol_1.00E-05]",
279-
cmp_all_close_func,
280-
{"atol": 1.60e-02, "rtol": 1.00e-05},
281-
),
282-
(
283-
"[all_close_atol_1.30E-06_rtol_1.00E-05]",
284-
cmp_all_close_func,
285-
{"atol": 1.30e-06, "rtol": 1.00e-05},
286-
),
287-
(
288-
"[all_close_atol_1.00E-07_rtol_1.00E-05]",
289-
cmp_all_close_func,
290-
{"atol": 1.00e-07, "rtol": 1.00e-05},
291-
),
292-
(
293-
"[all_close_atol_3.98E-03_rtol_1.00E-04]",
294-
cmp_all_close_func,
295-
{"atol": 3.98e-03, "rtol": 1.00e-04},
296-
),
297-
(
298-
"[all_close_atol_3.66E-02_rtol_1.00E-04]",
299-
cmp_all_close_func,
300-
{"atol": 3.66e-02, "rtol": 1.00e-04},
301-
),
302-
(
303-
"[all_close_atol_1.96E-05_rtol_1.00E-04]",
304-
cmp_all_close_func,
305-
{"atol": 1.96e-05, "rtol": 1.00e-04},
306-
),
307-
(
308-
"[all_close_atol_2.51E-06_rtol_1.00E-04]",
309-
cmp_all_close_func,
310-
{"atol": 2.51e-06, "rtol": 1.00e-04},
311-
),
312-
(
313-
"[all_close_atol_1.58E-02_rtol_1.00E-03]",
314-
cmp_all_close_func,
315-
{"atol": 1.58e-02, "rtol": 1.00e-03},
316-
),
317-
(
318-
"[all_close_atol_8.36E-02_rtol_1.00E-03]",
319-
cmp_all_close_func,
320-
{"atol": 8.36e-02, "rtol": 1.00e-03},
321-
),
322-
(
323-
"[all_close_atol_2.94E-04_rtol_1.00E-03]",
324-
cmp_all_close_func,
325-
{"atol": 2.94e-04, "rtol": 1.00e-03},
326-
),
327-
(
328-
"[all_close_atol_6.31E-05_rtol_1.00E-03]",
329-
cmp_all_close_func,
330-
{"atol": 6.31e-05, "rtol": 1.00e-03},
331-
),
332-
(
333-
"[all_close_atol_6.31E-02_rtol_1.00E-02]",
334-
cmp_all_close_func,
335-
{"atol": 6.31e-02, "rtol": 1.00e-02},
336-
),
337-
(
338-
"[all_close_atol_1.91E-01_rtol_1.00E-02]",
339-
cmp_all_close_func,
340-
{"atol": 1.91e-01, "rtol": 1.00e-02},
341-
),
342-
(
343-
"[all_close_atol_4.42E-03_rtol_1.00E-02]",
344-
cmp_all_close_func,
345-
{"atol": 4.42e-03, "rtol": 1.00e-02},
346-
),
347-
(
348-
"[all_close_atol_1.58E-03_rtol_1.00E-02]",
349-
cmp_all_close_func,
350-
{"atol": 1.58e-03, "rtol": 1.00e-02},
351-
),
352-
(
353-
"[all_close_atol_2.51E-01_rtol_1.00E-01]",
354-
cmp_all_close_func,
355-
{"atol": 2.51e-01, "rtol": 1.00e-01},
356-
),
357-
(
358-
"[all_close_atol_4.37E-01_rtol_1.00E-01]",
359-
cmp_all_close_func,
360-
{"atol": 4.37e-01, "rtol": 1.00e-01},
361-
),
362-
(
363-
"[all_close_atol_6.65E-02_rtol_1.00E-01]",
364-
cmp_all_close_func,
365-
{"atol": 6.65e-02, "rtol": 1.00e-01},
366-
),
367-
(
368-
"[all_close_atol_3.98E-02_rtol_1.00E-01]",
369-
cmp_all_close_func,
370-
{"atol": 3.98e-02, "rtol": 1.00e-01},
371-
),
372-
(
373-
"[all_close_atol_1.00E+00_rtol_1.00E+00]",
374-
cmp_all_close_func,
375-
{"atol": 1.00e00, "rtol": 1.00e00},
376-
),
377-
(
378-
"[all_close_atol_1.00E+00_rtol_1.00E+00]",
379-
cmp_all_close_func,
380-
{"atol": 1.00e00, "rtol": 1.00e00},
381-
),
382-
(
383-
"[all_close_atol_1.00E+00_rtol_1.00E+00]",
384-
cmp_all_close_func,
385-
{"atol": 1.00e00, "rtol": 1.00e00},
386-
),
387-
(
388-
"[all_close_atol_1.00E+00_rtol_1.00E+00]",
389-
cmp_all_close_func,
390-
{"atol": 1.00e00, "rtol": 1.00e00},
391-
),
392-
(
393-
"[all_close_atol_3.98E+00_rtol_1.00E+01]",
394-
cmp_all_close_func,
395-
{"atol": 3.98e00, "rtol": 1.00e01},
396-
),
397-
(
398-
"[all_close_atol_2.29E+00_rtol_1.00E+01]",
399-
cmp_all_close_func,
400-
{"atol": 2.29e00, "rtol": 1.00e01},
401-
),
402-
(
403-
"[all_close_atol_1.50E+01_rtol_1.00E+01]",
404-
cmp_all_close_func,
405-
{"atol": 1.50e01, "rtol": 1.00e01},
406-
),
407-
(
408-
"[all_close_atol_2.51E+01_rtol_1.00E+01]",
409-
cmp_all_close_func,
410-
{"atol": 2.51e01, "rtol": 1.00e01},
411-
),
412-
(
413-
"[all_close_atol_1.58E+01_rtol_1.00E+02]",
414-
cmp_all_close_func,
415-
{"atol": 1.58e01, "rtol": 1.00e02},
416-
),
417-
(
418-
"[all_close_atol_5.23E+00_rtol_1.00E+02]",
419-
cmp_all_close_func,
420-
{"atol": 5.23e00, "rtol": 1.00e02},
421-
),
422-
(
423-
"[all_close_atol_2.26E+02_rtol_1.00E+02]",
424-
cmp_all_close_func,
425-
{"atol": 2.26e02, "rtol": 1.00e02},
426-
),
427-
(
428-
"[all_close_atol_6.31E+02_rtol_1.00E+02]",
429-
cmp_all_close_func,
430-
{"atol": 6.31e02, "rtol": 1.00e02},
431-
),
432-
(
433-
"[all_close_atol_6.31E+01_rtol_1.00E+03]",
434-
cmp_all_close_func,
435-
{"atol": 6.31e01, "rtol": 1.00e03},
436-
),
437-
(
438-
"[all_close_atol_1.20E+01_rtol_1.00E+03]",
439-
cmp_all_close_func,
440-
{"atol": 1.20e01, "rtol": 1.00e03},
441-
),
442-
(
443-
"[all_close_atol_3.40E+03_rtol_1.00E+03]",
444-
cmp_all_close_func,
445-
{"atol": 3.40e03, "rtol": 1.00e03},
446-
),
447-
(
448-
"[all_close_atol_1.58E+04_rtol_1.00E+03]",
449-
cmp_all_close_func,
450-
{"atol": 1.58e04, "rtol": 1.00e03},
451-
),
452-
(
453-
"[all_close_atol_2.51E+02_rtol_1.00E+04]",
454-
cmp_all_close_func,
455-
{"atol": 2.51e02, "rtol": 1.00e04},
456-
),
457-
(
458-
"[all_close_atol_2.73E+01_rtol_1.00E+04]",
459-
cmp_all_close_func,
460-
{"atol": 2.73e01, "rtol": 1.00e04},
461-
),
462-
(
463-
"[all_close_atol_5.11E+04_rtol_1.00E+04]",
464-
cmp_all_close_func,
465-
{"atol": 5.11e04, "rtol": 1.00e04},
466-
),
467-
(
468-
"[all_close_atol_3.98E+05_rtol_1.00E+04]",
469-
cmp_all_close_func,
470-
{"atol": 3.98e05, "rtol": 1.00e04},
471-
),
472-
(
473-
"[all_close_atol_1.00E+03_rtol_1.00E+05]",
474-
cmp_all_close_func,
475-
{"atol": 1.00e03, "rtol": 1.00e05},
476-
),
477-
(
478-
"[all_close_atol_6.25E+01_rtol_1.00E+05]",
479-
cmp_all_close_func,
480-
{"atol": 6.25e01, "rtol": 1.00e05},
481-
),
482-
(
483-
"[all_close_atol_7.69E+05_rtol_1.00E+05]",
484-
cmp_all_close_func,
485-
{"atol": 7.69e05, "rtol": 1.00e05},
486-
),
487-
(
488-
"[all_close_atol_1.00E+07_rtol_1.00E+05]",
489-
cmp_all_close_func,
490-
{"atol": 1.00e07, "rtol": 1.00e05},
491-
),
492-
("[max_diff]", cmp_max_diff_func, {}),
493-
("[mean_diff]", cmp_mean_diff_func, {}),
494-
]
209+
cmp_configs = generate_allclose_configs(cmp_all_close_func)
210+
cmp_configs.append(("[max_diff]", cmp_max_diff_func, {}))
211+
cmp_configs.append(("[mean_diff]", cmp_mean_diff_func, {}))
495212

496213
for key, func, kwargs in cmp_configs:
497214
print_and_store_cmp(

0 commit comments

Comments
 (0)