Allow lower precision calculations in pre-compute transforms #368
willGraham01 wants to merge 23 commits into main
Conversation
Running some simple benchmarks using adaptations to the existing benchmark scripts. As such, the best proxy we have is the "tmp mem size (B)" (temporary memory allocated during runtime), which is only recorded for JAX computations. However, we do (promisingly) see that these figures roughly halve when all other parameters are held constant and the precision is switched from double to single.

We don't gather the temporary memory usage statistic for the torch runs, but given that the torch implementations just wrap the JAX ones, I have reasonable confidence the torch case is also behaving. And in all cases (JAX + torch), the output array size does approximately halve when toggling double to single precision, so the output array is what we expect when it comes out. This is also one of the checks in the test suite now, but it's nice to have some secondary validation 😅
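The output-array halving mentioned above is easy to sanity-check generically with NumPy. This is just an illustration of `complex128` vs `complex64` element sizes, not the project's benchmark script, and the shape is a made-up example:

```python
import numpy as np

# Illustrative only: output-array memory should roughly halve when the
# transform emits complex64 instead of complex128 coefficients, since
# complex128 uses 16 bytes per element and complex64 uses 8.
shape = (256, 512)  # hypothetical output shape
flm_double = np.zeros(shape, dtype=np.complex128)
flm_single = np.zeros(shape, dtype=np.complex64)

print(flm_double.nbytes)  # 2097152
print(flm_single.nbytes)  # 1048576
assert flm_double.nbytes == 2 * flm_single.nbytes
```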
matt-graham left a comment
Thanks @willGraham01. Overall this looks good to me. I've added some minor comments / soft suggestions, but none of these are critical and I think this could go in as is, hence approving now.
```python
# Allow a -/+1 margin for near-misses during rounding and taking log.
round_trip_error_long_dtype = abs(true_values - long_precision_result).max()
long_dtype_error_oom = np.round(np.log10(round_trip_error_long_dtype))

round_trip_error_short_dtype = abs(true_values - short_precision_result).max()
short_dtype_error_oom = np.round(np.log10(round_trip_error_short_dtype))

assert (
    long_dtype_error_oom <= 2 * short_dtype_error_oom
    or long_dtype_error_oom == pytest.approx(2 * short_dtype_error_oom, abs=1)
)
```
I think this check is useful, but the logic here ends up being quite complicated. I think we essentially want to check that

$$\log_{10}(\epsilon_{\text{double}}) \approx 2 \log_{10}(\epsilon_{\text{single}})$$

where $\epsilon_{\text{double}}$ and $\epsilon_{\text{single}}$ are the maximum absolute round-trip errors in the long and short dtypes respectively.

Would something like
```diff
- # Allow a -/+1 margin for near-misses during rounding and taking log.
- round_trip_error_long_dtype = abs(true_values - long_precision_result).max()
- long_dtype_error_oom = np.round(np.log10(round_trip_error_long_dtype))
- round_trip_error_short_dtype = abs(true_values - short_precision_result).max()
- short_dtype_error_oom = np.round(np.log10(round_trip_error_short_dtype))
- assert (
-     long_dtype_error_oom <= 2 * short_dtype_error_oom
-     or long_dtype_error_oom == pytest.approx(2 * short_dtype_error_oom, abs=1)
- )
+ round_trip_error_long_dtype = abs(true_values - long_precision_result).max()
+ round_trip_error_short_dtype = abs(true_values - short_precision_result).max()
+ log_error_ratio = np.log(round_trip_error_long_dtype) / np.log(round_trip_error_short_dtype)
+ tolerance = 0.2
+ assert 2 - tolerance < log_error_ratio < 2 + tolerance
```
not suffice (and avoid some of the issues around instability due to moving either side of rounding thresholds)? The tolerance value might need to be adjusted here!
(Do you mean to have the ratio the other way round above?) The check should be equivalent to

$$\log_{10}(\epsilon_{\text{double}}) \approx 2 \log_{10}(\epsilon_{\text{single}}),$$

so if we want to use the ratio, we want $\log_{10}(\epsilon_{\text{double}}) / \log_{10}(\epsilon_{\text{single}}) \approx 2$.

We're also always happy if we do better than expected, which is equivalent to

$$0 \leq \log_{10}(\epsilon_{\text{double}}) / \log_{10}(\epsilon_{\text{single}}) \leq 2.$$

As such, I think the following code snippet should work:

```python
log_round_trip_error_double = np.log10(abs(true_values - double_calc_result).max())
log_round_trip_error_single = np.log10(abs(true_values - single_calc_result).max())
log_error_ratio = log_round_trip_error_double / log_round_trip_error_single
tolerance = 1 / np.abs(log_round_trip_error_single)
ratio_is_approx_2 = -tolerance <= (log_error_ratio - 2.0) <= tolerance
better_than_expected = 0.0 <= log_error_ratio <= 2.0
assert ratio_is_approx_2 or better_than_expected
```

In any case, you've caught a bug in my original logic too 😅 The `long_dtype_error_oom <= 2 * short_dtype_error_oom` comparison should have the reverse inequality sign (we've done better than expected if double the OOM of the single precision calculation is less than the double precision calc). 😅 This is corrected in the code snippet above now that we're using your idea of ratios instead... but now I'm seeing a lot of errors that aren't within these regions. HEALPix in particular has a ratio around 3 (error …)
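As a quick sanity check of this ratio logic, here is a self-contained sketch with synthetic error values. The `1e-12` / `1e-6` figures are illustrative, taken from the expectation stated in the PR description, not measured results:

```python
import numpy as np

# Synthetic round-trip errors: the order of magnitude is expected to
# roughly halve when moving from double to single precision.
err_double = 1e-12  # illustrative double-precision round-trip error
err_single = 1e-6   # illustrative single-precision round-trip error

log_err_double = np.log10(err_double)
log_err_single = np.log10(err_single)

# Ratio of log-errors is ~2 when the order of magnitude exactly halves.
log_error_ratio = log_err_double / log_err_single
tolerance = 1 / abs(log_err_single)  # a +/-1 OOM margin, in ratio terms

ratio_is_approx_2 = abs(log_error_ratio - 2.0) <= tolerance
better_than_expected = 0.0 <= log_error_ratio <= 2.0
assert ratio_is_approx_2 or better_than_expected
```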
Addresses #320
Allows lower-precision array data-types to be used in the pre-compute transforms, for the JAX and torch paths.
Preliminary benchmarking results below
The output arrays for the JAX and Torch pre-compute forward and inverse transforms will now inherit an appropriate lower-precision data-type based on the input signal / coefficients arrays, and the `reality` argument (since in the inverse transform, one cannot determine if we will have a real signal simply by examining the data type of the harmonic coefficients, which are in general complex).

It is necessary to have a small "lookup" function (`compatible_cmplx_dtype`) to infer the appropriate data-type to use for the intermediate arrays (and the output arrays) from the input array's `dtype`. This is a static property of arrays that is known at trace time, so we should be OK to conditionally create other arrays sharing this `dtype` (or whose `dtype` is inferred from another input array's `dtype`).

Code paths are modified so that intermediate arrays are created using the lower-precision data-type throughout, which should reduce memory overhead during calculations by a factor of ~2. Of course, the price we pay is that we expect the order of magnitude of the resulting round-trip error to halve (e.g. an error of $\approx 10^{-12}$ in double precision will likely become $\approx 10^{-6}$ when the same calculation is done in single precision). This expectation has been worked into the additional tests for this functionality (though we may decide it is not necessary).
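For illustration, a minimal sketch of what a dtype lookup of this kind might look like. The actual `compatible_cmplx_dtype` in the PR may well differ; the mapping below is an assumption that simply follows NumPy's precision pairings (float32 ↔ complex64, float64 ↔ complex128):

```python
import numpy as np

# Hypothetical sketch of a dtype "lookup" like compatible_cmplx_dtype:
# map an input dtype to the complex dtype of matching precision, so that
# intermediate and output arrays stay in lower precision.
_REAL_TO_COMPLEX = {
    np.dtype(np.float32): np.dtype(np.complex64),
    np.dtype(np.float64): np.dtype(np.complex128),
}

def compatible_cmplx_dtype(dtype) -> np.dtype:
    """Return the complex dtype matching the precision of ``dtype``."""
    dtype = np.dtype(dtype)
    if dtype.kind == "c":  # already complex: use as-is
        return dtype
    try:
        return _REAL_TO_COMPLEX[dtype]
    except KeyError:
        raise TypeError(f"No compatible complex dtype for {dtype}") from None

print(compatible_cmplx_dtype(np.float32))     # complex64
print(compatible_cmplx_dtype(np.complex128))  # complex128
```

Because a dtype is static at trace time, a lookup like this can safely drive array creation inside jitted code.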
THESE CHANGES DO NOT TOUCH THE `numpy` PRE-COMPUTE TRANSFORMS - these transforms will continue to upcast to `complex128` regardless of the signal data type.
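For context on that caveat, NumPy's own FFT routines show the same upcasting behaviour: a `float32` signal comes back as `complex128`. This is an analogy, not this project's code path:

```python
import numpy as np

# NumPy's FFT always computes in double precision, so even a float32
# signal produces complex128 coefficients -- analogous to the numpy
# pre-compute transforms left untouched by this PR.
signal = np.ones(8, dtype=np.float32)
coeffs = np.fft.fft(signal)
print(coeffs.dtype)  # complex128
```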