-
Hi all, I'm toying around with a data structure that represents a centered, variance-scaled dataset using a sparse representation, with the centering and scaling operations expressed as a linear operator. While I had initially rolled my own, I've had fun using the lineax framework to simplify some of the internal mechanics. However, I've run into some numerical stability issues: the output of dense matrix/vector products differs from my data structure's matrix/vector products. I've tried digging into when/how the errors accumulate, but haven't been successful so far. The individual operations seem to be numerically close, but their combination results in too large an atol/rtol for `jnp.allclose` to pass. I'm including a minimal working example below to highlight where the issue arises, using an 'un-rolled' representation of the linear operator.
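For context, the lineax-wrapped version of the operator looks roughly like the sketch below; this is a simplified illustration (a tiny toy `G`, and it assumes lineax's `FunctionLinearOperator`/`.mv` API), not my actual implementation. The full un-rolled MWE follows after it.

```python
import jax
import jax.numpy as jnp
import jax.experimental.sparse as sparse
import lineax as lx

jax.config.update("jax_enable_x64", True)

# toy stand-in for the sparse genotype-like matrix
N, P = 4, 3
G = jnp.array([[0, 1, 2], [2, 0, 1], [1, 1, 0], [0, 2, 1]], dtype=jnp.int8)
Gsp = sparse.BCOO.fromdense(G)
m = jnp.mean(G, axis=0)       # column means
s = 1.0 / jnp.std(G, axis=0)  # inverse column standard deviations


def centered_scaled_mv(v):
    # Z @ v = (G - 1 m^T) diag(s) v, computed without ever densifying G
    sv = s * v
    return Gsp @ sv - jnp.ones(N) * (m @ sv)


# wrap the matrix-vector product as a linear operator
Z_op = lx.FunctionLinearOperator(
    centered_scaled_mv, jax.ShapeDtypeStruct((P,), jnp.float64)
)

v = jnp.arange(P, dtype=jnp.float64)
print(Z_op.mv(v))
```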
```python
#! /usr/bin/env python
import argparse as ap
import os
import sys

import jax.experimental.sparse as sparse
import jax.numpy as jnp
import jax.random as rdm
from jax.config import config

config.update("jax_enable_x64", True)
config.update("jax_default_matmul_precision", "highest")


def _binomial(key, N, p, shape):
    B = jnp.sum(rdm.bernoulli(key, p, shape=(N,) + shape).astype(int), axis=0)
    return B


def main(args):
    argp = ap.ArgumentParser(description="")
    argp.add_argument("-s", "--seed", type=int, default=0)
    argp.add_argument("-n", type=int, default=50)
    argp.add_argument("-p", type=int, default=100)
    argp.add_argument("-o", "--output", type=ap.FileType("w"), default=sys.stdout)
    args = argp.parse_args(args)

    key = rdm.PRNGKey(args.seed)
    maf = 0.1
    N, P = args.n, args.p

    key, g_key = rdm.split(key)

    # simulate matrix
    G = _binomial(g_key, 2, maf, (N, P)).astype(jnp.int8)

    # grab center/scaling values
    M = jnp.mean(G, axis=0)
    S = 1.0 / jnp.std(G, axis=0)

    # centered G, centered-scaled G
    C = G - M
    Z = C * S

    # sparse representation
    Gsp = sparse.BCOO.fromdense(G)

    # (Gsp - T @ M) @ Sd should be mathematically equiv to Z
    # does it behave the same numerically?
    T = jnp.ones((N, 1))
    Mj = M.reshape((1, P))
    Sd = jnp.diag(S)

    key, r_key = rdm.split(key)
    R = rdm.normal(r_key, shape=(P,))
    SdR = Sd @ R

    # standard dot product G @ R
    args.output.write(
        f"Stable[G @ R] = {jnp.allclose(G @ R, Gsp @ R)}" + os.linesep
    )  # dot product is same

    # centered-G @ R
    args.output.write(
        f"Stable[C @ R] = {jnp.allclose(C @ R, Gsp @ R - T @ Mj @ R)}" + os.linesep
    )  # centered dot product is same; centering is okay!

    # scaled-G @ R
    args.output.write(
        f"Stable[G/S @ R] = {jnp.allclose(G @ SdR, Gsp @ SdR)}" + os.linesep
    )  # scaled dot product is same; scaling is okay!

    # save for atol/rtol
    expected = Z @ R
    observed = Gsp @ SdR - T @ Mj @ SdR
    atol = jnp.abs(expected - observed)
    rtol = atol / jnp.minimum(jnp.abs(expected), jnp.abs(observed))

    # centered-scaled-G @ R
    args.output.write(
        f"Stable[Z @ R] = {jnp.allclose(expected, observed)}" + os.linesep
    )  # ope! scaling + centering is not okay!
    args.output.write(f"Max atol[Z @ R] = {jnp.max(atol)}" + os.linesep)
    args.output.write(f"Max rtol[Z @ R] = {jnp.max(rtol)}" + os.linesep)

    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
```

Running this, the first three checks report True, but the final centered+scaled check `Stable[Z @ R]` comes back False, with a non-negligible max atol/rtol.
I should note that this is on CPU (Apple M1), jax version '0.4.11', jaxlib version '0.4.10'.
-
For what it's worth, all the
-
It looks like the issue is that you're comparing float32 arithmetic to float64 arithmetic. If I change this line:

```python
M = jnp.mean(G, axis=0)
```

to this:

```python
M = jnp.mean(G, axis=0).astype('float64')
```

Then I get this output:
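A quick way to see where the float32 creeps in is to print the dtypes of the intermediates along each path. Here is a minimal self-contained sketch (toy stand-ins for `G`, `M`, `S`, and `R`, not the original script):

```python
import jax
import jax.numpy as jnp
import jax.experimental.sparse as sparse

jax.config.update("jax_enable_x64", True)

# small stand-ins for the arrays in the original script
G = jnp.array([[0, 1, 2], [2, 0, 1]], dtype=jnp.int8)
M = jnp.mean(G, axis=0)
S = 1.0 / jnp.std(G, axis=0)
R = jnp.ones(3)
Gsp = sparse.BCOO.fromdense(G)

# dense path: Z is built from M and S before touching R
Z = (G - M) * S
# sparse path: the scaled vector is applied to the int8 sparse matrix directly
SR = S * R

print("M:", M.dtype, " S:", S.dtype, " R:", R.dtype)
print("dense Z @ R:", (Z @ R).dtype)
print("sparse G @ (S * R):", (Gsp @ SR).dtype)
```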