
Allow lower precision calculations in pre-compute transforms#368

Open

willGraham01 wants to merge 23 commits into main from wgraham/320-precompute-precision

Conversation


@willGraham01 willGraham01 commented Mar 5, 2026

Addresses #320

Allows lower-precision array data-types to be used in the pre-compute transforms, for the JAX and torch paths.

Preliminary benchmarking results below

The output arrays for the JAX and Torch pre-compute forward and inverse transforms will now inherit an appropriate lower-precision data-type based on the input signal / coefficients arrays, and the reality argument (since in the inverse transform, one cannot determine if we will have a real signal simply by examining the data type of the harmonic coefficients, which are in general complex).
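To make the dtype-inheritance rule concrete, here is a minimal numpy sketch (illustrative only; `output_dtype` is a hypothetical helper, not the PR's actual code) of how the inverse transform's output dtype can be chosen from the coefficient dtype plus the reality flag:

```python
import numpy as np

def output_dtype(coeff_dtype, reality):
    # Hypothetical helper, not part of the PR: pick the inverse transform's
    # output dtype from the (complex) coefficient dtype and the reality flag.
    cmplx = {
        np.dtype(np.complex64): np.complex64,
        np.dtype(np.complex128): np.complex128,
    }[np.dtype(coeff_dtype)]
    real = {np.complex64: np.float32, np.complex128: np.float64}[cmplx]
    # Complex coefficients alone cannot tell us the signal is real,
    # hence the explicit reality argument.
    return np.dtype(real if reality else cmplx)
```

For example, single-precision coefficients with `reality=True` would map to a `float32` output, while `reality=False` keeps `complex64`.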

It is necessary to have a small "lookup" function (`compatible_cmplx_dtype`) to infer the appropriate data-type to use for the intermediate arrays (and the output arrays).

  • This has been implemented as a lookup function because we don't want a global variable, which would violate JAX's pure-function requirement. We could implement a more general converter (e.g. by inferring the dtype size from the object itself or a string specifier) if we wanted to, but it's not clear how much merit there is in this.
  • This function does require examination of the input array's data-type (dtype). This is a static property of arrays that is known at trace time, so we should be OK to conditionally create other arrays sharing this dtype (or whose dtype is inferred from another input array's dtype).
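For illustration, a minimal numpy sketch of what such a lookup might look like (the mapping below is my assumption about `compatible_cmplx_dtype`'s behaviour, not the PR's actual implementation):

```python
import numpy as np

def compatible_cmplx_dtype(dtype):
    # Map a real or complex dtype to the complex dtype of matching precision.
    # The input dtype is static at trace time, so a plain dict lookup is safe
    # (no global state, keeping the function pure for JAX).
    lookup = {
        np.dtype(np.float32): np.complex64,
        np.dtype(np.float64): np.complex128,
        np.dtype(np.complex64): np.complex64,
        np.dtype(np.complex128): np.complex128,
    }
    return np.dtype(lookup[np.dtype(dtype)])
```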

Code paths are modified so that intermediate arrays are created using the lower-precision data-type throughout, which should reduce memory overhead during calculations by a factor of ~2. Of course, the price we pay is that the exponent of the resulting round-trip error roughly halves (e.g. an error of $\approx 10^{-12}$ in double precision will likely become $\approx 10^{-6}$ when the same calculation is done in single precision). This expectation has been worked into the additional tests for this functionality (though we may decide it is not necessary).
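As a quick illustration of the factor-of-~2 memory saving (the array shape here is arbitrary, chosen only for the example):

```python
import numpy as np

L = 256
# Two arrays of the same (illustrative) shape in the two precisions.
a_double = np.zeros((L, 2 * L - 1), dtype=np.complex128)
a_single = np.zeros((L, 2 * L - 1), dtype=np.complex64)
# Halving the precision halves the per-array memory footprint.
assert a_single.nbytes * 2 == a_double.nbytes
```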

THESE CHANGES DO NOT TOUCH THE numpy PRE-COMPUTE TRANSFORMS - these transforms will continue to upcast to complex128 regardless of the signal data type.
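For contrast, numpy itself exhibits the upcasting behaviour described: its FFT routines always compute in double precision, so even single-precision input yields `complex128` output:

```python
import numpy as np

f = np.ones(8, dtype=np.float32)
flm = np.fft.rfft(f)
# numpy's FFT always computes in double precision, regardless of input dtype.
assert flm.dtype == np.complex128
```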


@willGraham01 willGraham01 marked this pull request as ready for review March 9, 2026 14:13

willGraham01 commented Mar 12, 2026

Running some simple benchmarks using adaptations of the benchmarking/ scripts gives promising results (assuming I'm interpreting the metrics correctly!). The "peak memory usage" statistic seems misleading though, as it appears to record memory usage at trace time rather than at runtime; either way it is too low to reflect the memory usage when performing the calculations.

As such, the best proxy we have is the "tmp mem size (B)" (temporary memory allocated during runtime), which is only recorded for JAX computations. However, we do (promisingly) see that these figures roughly halve when all other parameters are held constant and `long_precision` is switched from double precision to single precision.
This is consistent with what we were expecting, and though it is not reported in the table below, the round-trip errors also all go from $\approx 10^{-12}$ to $\approx 10^{-6}$ which is again consistent with the use of a lower-precision dtype.

We don't gather the temporary memory usage statistic for the torch runs, but given that the torch implementations just wrap the JAX ones, I have reasonable confidence the torch case is also behaving.

And in all cases (JAX + torch), the output array size does approximately halve when toggling double to single precision, so the output array is what we expect when it comes out. This is also one of the checks in the test suite now, but it's nice to have some secondary validation 😅
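A sketch of how such figures can be gathered from a jitted JAX function via its compiled memory analysis (this is my assumption about the approach, not necessarily what the benchmarking/ scripts do; `round_trip` is a stand-in for a forward + inverse transform pair):

```python
import jax
import jax.numpy as jnp

def round_trip(x):
    # Stand-in for a pre-compute forward + inverse transform pair.
    return jnp.fft.ifft(jnp.fft.fft(x))

x = jnp.zeros(1024, dtype=jnp.complex64)
compiled = jax.jit(round_trip).lower(x).compile()
stats = compiled.memory_analysis()  # may be None on some backends
if stats is not None:
    print("tmp mem size (B):", stats.temp_size_in_bytes)
    print("output array size (B):", stats.output_size_in_bytes)
```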

Results table

All run at $L = 256$, no spin.

| sampling | method | direction | reality | long_precision | peak mem (B) | tmp mem size (B) | output array size (B) |
| --- | --- | --- | --- | --- | --- | --- | --- |
| dh | jax | forward | False | False | 7074 | 5.41057e+08 | 1.04653e+06 |
| dh | jax | forward | False | True | 7222 | 1.08211e+09 | 2.09306e+06 |
| dh | jax | forward | True | False | 7186 | 2.70533e+08 | 1.04653e+06 |
| dh | jax | forward | True | True | 7534 | 5.41065e+08 | 2.09306e+06 |
| dh | jax | inverse | False | False | 7150 | 5.36869e+08 | 2.09306e+06 |
| dh | jax | inverse | False | True | 7132 | 1.07374e+09 | 4.18611e+06 |
| dh | jax | inverse | True | False | 7150 | 2.69484e+08 | 1.04653e+06 |
| dh | jax | inverse | True | True | 7186 | 5.38968e+08 | 2.09306e+06 |
| dh | torch | forward | False | False | 214507 | nan | nan |
| dh | torch | forward | False | True | 216532 | nan | nan |
| dh | torch | forward | True | False | 323927 | nan | nan |
| dh | torch | forward | True | True | 329069 | nan | nan |
| dh | torch | inverse | False | False | 217351 | nan | nan |
| dh | torch | inverse | False | True | 218701 | nan | nan |
| dh | torch | inverse | True | False | 241639 | nan | nan |
| dh | torch | inverse | True | True | 241801 | nan | nan |
| gl | jax | forward | False | False | 7150 | 2.70004e+08 | 1.04653e+06 |
| gl | jax | forward | False | True | 7222 | 5.40008e+08 | 2.09306e+06 |
| gl | jax | forward | True | False | 7150 | 1.35266e+08 | 1.04653e+06 |
| gl | jax | forward | True | True | 7222 | 2.70533e+08 | 2.09306e+06 |
| gl | jax | inverse | False | False | 7150 | 2.68958e+08 | 1.04653e+06 |
| gl | jax | inverse | False | True | 7186 | 5.37915e+08 | 2.09306e+06 |
| gl | jax | inverse | True | False | 7150 | 1.35266e+08 | 523264 |
| gl | jax | inverse | True | True | 7186 | 2.70533e+08 | 1.04653e+06 |
| gl | torch | forward | False | False | 226717 | nan | nan |
| gl | torch | forward | False | True | 235696 | nan | nan |
| gl | torch | forward | True | False | 324550 | nan | nan |
| gl | torch | forward | True | True | 325092 | nan | nan |
| gl | torch | inverse | False | False | 155194 | nan | nan |
| gl | torch | inverse | False | True | 156726 | nan | nan |
| gl | torch | inverse | True | False | 171728 | nan | nan |
| gl | torch | inverse | True | True | 171930 | nan | nan |
| healpix | jax | forward | False | False | 7186 | 5.37917e+08 | 1.04653e+06 |
| healpix | jax | forward | False | True | 7222 | 1.07583e+09 | 2.09306e+06 |
| healpix | jax | forward | True | False | 7186 | 2.70512e+08 | 1.04653e+06 |
| healpix | jax | forward | True | True | 7222 | 5.41024e+08 | 2.09306e+06 |
| healpix | jax | inverse | False | False | 7286 | 5.36933e+08 | 1.57286e+06 |
| healpix | jax | inverse | False | True | 7186 | 1.0738e+09 | 3.14573e+06 |
| healpix | jax | inverse | True | False | 7150 | 2.69026e+08 | 786432 |
| healpix | jax | inverse | True | True | 7186 | 5.37984e+08 | 1.57286e+06 |
| healpix | torch | forward | False | False | 16048550 | nan | nan |
| healpix | torch | forward | False | True | 15710477 | nan | nan |
| healpix | torch | forward | True | False | 17767798 | nan | nan |
| healpix | torch | forward | True | True | 19317608 | nan | nan |
| healpix | torch | inverse | False | False | 19431048 | nan | nan |
| healpix | torch | inverse | False | True | 19448210 | nan | nan |
| healpix | torch | inverse | True | False | 22533988 | nan | nan |
| healpix | torch | inverse | True | True | 23431283 | nan | nan |
| mw | jax | forward | False | False | 7186 | 5.47359e+08 | 1.04653e+06 |
| mw | jax | forward | False | True | 7222 | 1.09472e+09 | 2.09306e+06 |
| mw | jax | forward | True | False | 7186 | 2.77348e+08 | 1.04653e+06 |
| mw | jax | forward | True | True | 7222 | 5.54697e+08 | 2.09306e+06 |
| mw | jax | inverse | False | False | 7150 | 2.68958e+08 | 1.04653e+06 |
| mw | jax | inverse | False | True | 7186 | 5.37915e+08 | 2.09306e+06 |
| mw | jax | inverse | True | False | 7096 | 1.35266e+08 | 523264 |
| mw | jax | inverse | True | True | 7186 | 2.70533e+08 | 1.04653e+06 |
| mw | torch | forward | False | False | 872316 | nan | nan |
| mw | torch | forward | False | True | 889309 | nan | nan |
| mw | torch | forward | True | False | 2051718 | nan | nan |
| mw | torch | forward | True | True | 2244823 | nan | nan |
| mw | torch | inverse | False | False | 222575 | nan | nan |
| mw | torch | inverse | False | True | 223871 | nan | nan |
| mw | torch | inverse | True | False | 247996 | nan | nan |
| mw | torch | inverse | True | True | 258274 | nan | nan |
| mwss | jax | forward | False | False | 7186 | 5.47359e+08 | 1.04653e+06 |
| mwss | jax | forward | False | True | 7186 | 1.09472e+09 | 2.09306e+06 |
| mwss | jax | forward | True | False | 7186 | 2.77348e+08 | 1.04653e+06 |
| mwss | jax | forward | True | True | 7498 | 5.54697e+08 | 2.09306e+06 |
| mwss | jax | inverse | False | False | 7096 | 2.70004e+08 | 1.05267e+06 |
| mwss | jax | inverse | False | True | 7222 | 5.40008e+08 | 2.10534e+06 |
| mwss | jax | inverse | True | False | 7092 | 1.35266e+08 | 526336 |
| mwss | jax | inverse | True | True | 7222 | 2.70533e+08 | 1.05267e+06 |
| mwss | torch | forward | False | False | 266798 | nan | nan |
| mwss | torch | forward | False | True | 262938 | nan | nan |
| mwss | torch | forward | True | False | 621577 | nan | nan |
| mwss | torch | forward | True | True | 624458 | nan | nan |
| mwss | torch | inverse | False | False | 217652 | nan | nan |
| mwss | torch | inverse | False | True | 219896 | nan | nan |
| mwss | torch | inverse | True | False | 242237 | nan | nan |
| mwss | torch | inverse | True | True | 243575 | nan | nan |

@matt-graham left a comment

Thanks @willGraham01. Overall this looks good to me. I've added some minor comments / soft suggestions but none of these are critical and I think this could go in as is hence approving now.

Comment on lines +134 to +144
```python
# Allow a -/+1 margin for near-misses during rounding and taking log.
round_trip_error_long_dtype = abs(true_values - long_precision_result).max()
long_dtype_error_oom = np.round(np.log10(round_trip_error_long_dtype))

round_trip_error_short_dtype = abs(true_values - short_precision_result).max()
short_dtype_error_oom = np.round(np.log10(round_trip_error_short_dtype))

assert (
    long_dtype_error_oom <= 2 * short_dtype_error_oom
    or long_dtype_error_oom == pytest.approx(2 * short_dtype_error_oom, abs=1)
)
```

I think this check is useful but the logic here ends up being quite complicated. I think we essentially want to check that

$$\frac{\log(e_\text{single})}{\log(e_\text{double})} \approx 2$$

where $e_\text{x}$ is the round-trip error using x precision floating-point types?

Would something like

Suggested change

```diff
-# Allow a -/+1 margin for near-misses during rounding and taking log.
-round_trip_error_long_dtype = abs(true_values - long_precision_result).max()
-long_dtype_error_oom = np.round(np.log10(round_trip_error_long_dtype))
-round_trip_error_short_dtype = abs(true_values - short_precision_result).max()
-short_dtype_error_oom = np.round(np.log10(round_trip_error_short_dtype))
-assert (
-    long_dtype_error_oom <= 2 * short_dtype_error_oom
-    or long_dtype_error_oom == pytest.approx(2 * short_dtype_error_oom, abs=1)
-)
+round_trip_error_long_dtype = abs(true_values - long_precision_result).max()
+round_trip_error_short_dtype = abs(true_values - short_precision_result).max()
+log_error_ratio = np.log(round_trip_error_long_dtype) / np.log(round_trip_error_short_dtype)
+tolerance = 0.2
+assert 2 - tolerance < log_error_ratio < 2 + tolerance
```

not suffice (and avoid some of the issues around instability due to moving either side of rounding thresholds)? The tolerance value might need to be adjusted here!
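A quick numeric sanity check of this suggestion (with assumed representative round-trip errors of $10^{-12}$ and $10^{-6}$, and the ratio taken as double over single so that it comes out near 2):

```python
import numpy as np

# Assumed representative round-trip errors, not measured values.
e_double, e_single = 1e-12, 1e-6
log_error_ratio = np.log(e_double) / np.log(e_single)
tolerance = 0.2
assert 2 - tolerance < log_error_ratio < 2 + tolerance
```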

@willGraham01 replied:

(Do you mean to have the ratio the other way round above?) The check should be equivalent to

$$\log(e_\text{double}) = 2\log(e_\text{single}) \pm 1$$

so if we want to use the ratio,

$$ \left\lvert \frac{\log(e_\text{double})}{\log(e_\text{single})} - 2 \right\rvert \leq \frac{1}{\vert \log(e_\text{single}) \vert}.$$

We're also always happy if we do better than expected, which is equivalent to

$$ 0 < \frac{\log(e_\text{double})}{\log(e_\text{single})} < 2,$$

($0 <$ since both logs need to have the same sign, $< 2$ since if the ratio is smaller than 2 then the single precision log-error is less than half the log-error of the double).

As such, I think the following code snippet should work

```python
log_round_trip_error_double = np.log10(abs(true_values - double_calc_result).max())
log_round_trip_error_single = np.log10(abs(true_values - single_calc_result).max())
log_error_ratio = log_round_trip_error_double / log_round_trip_error_single
tolerance = 1 / np.abs(log_round_trip_error_single)

ratio_is_approx_2 = -tolerance <= (log_error_ratio - 2.0) <= tolerance
better_than_expected = 0.0 <= log_error_ratio <= 2.0
assert ratio_is_approx_2 or better_than_expected
```

In any case you've caught a bug in my original logic too 😅 The `long_dtype_error_oom <= 2 * short_dtype_error_oom` comparison should have the reverse inequality sign (we've done better than expected if twice the OOM of the single-precision calculation is less than that of the double-precision calculation). This is corrected in the code snippet above now that we're using your idea of ratios instead... but now I'm seeing a lot of errors that aren't within these regions. HEALPix in particular has a ratio around 3 (error $\sim 10^{-12}$ in double, $\sim 10^{-4}$ in single). So perhaps some actual error analysis is warranted to see why this is.
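Plugging the quoted HEALPix figures into the check above (ratio taken as double over single, per the discussion) shows why it fails both branches:

```python
import numpy as np

# Quoted HEALPix round-trip errors: ~1e-12 in double, ~1e-4 in single.
log_double = np.log10(1e-12)  # approx -12
log_single = np.log10(1e-4)   # approx -4
log_error_ratio = log_double / log_single  # approx 3
tolerance = 1 / abs(log_single)  # approx 0.25
ratio_is_approx_2 = abs(log_error_ratio - 2.0) <= tolerance
better_than_expected = 0.0 <= log_error_ratio <= 2.0
# Neither condition holds, so the test's assertion would fail here.
assert not (ratio_is_approx_2 or better_than_expected)
```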
