
The limitation is that the matrix inverse is not implemented in bfloat16:

import jax.numpy as jnp
x = jnp.ones((3, 3), dtype='bfloat16')
jnp.linalg.inv(x)
# NotImplementedError: Unsupported dtype bfloat16

One way to work around this is to cast to float32, invert, and cast back:

jnp.linalg.inv(x.astype('float32')).astype('bfloat16')
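If you need this in several places, the cast can be wrapped in a small helper. This is a minimal sketch: `inv_upcast` is a hypothetical name, not a JAX API, and it assumes float32 is an acceptable compute precision. Also note that `jnp.ones((3, 3))` is singular, so its inverse is not meaningful; the sketch uses a diagonal matrix instead.

```python
import jax.numpy as jnp

def inv_upcast(a, compute_dtype=jnp.float32):
    """Invert `a` in `compute_dtype`, then cast back to a's original dtype.

    Hypothetical helper (not part of JAX) wrapping the cast-to-float32 workaround.
    """
    orig_dtype = a.dtype
    return jnp.linalg.inv(a.astype(compute_dtype)).astype(orig_dtype)

# A well-conditioned bfloat16 matrix (an all-ones matrix is singular).
x = (2.0 * jnp.eye(3)).astype(jnp.bfloat16)
x_inv = inv_upcast(x)
print(x_inv.dtype)  # bfloat16
```

The result keeps the caller's dtype, so downstream bfloat16 code is unaffected; only the inversion itself runs in float32.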

As an aside: in general, you should avoid computing an explicit inverse in floating point and use a linear solve instead, which is typically both cheaper and more accurate.
So instead of jnp.dot(jnp.linalg.inv(I_w), jac_full.T), you should write:

jnp.linalg.solve(I_w, jac_full.T)
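A quick sanity check that the two forms compute the same thing. The real `I_w` and `jac_full` from the question aren't shown, so the shapes and random values below are assumptions: `I_w` is built as a well-conditioned symmetric positive-definite n x n matrix and `jac_full` as an m x n Jacobian.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the arrays in the question.
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
n, m = 4, 6
a = jax.random.normal(key_a, (n, n))
I_w = a @ a.T + n * jnp.eye(n)  # symmetric positive definite, well-conditioned
jac_full = jax.random.normal(key_b, (m, n))

# Explicit inverse followed by a matrix product (the pattern to avoid).
via_inv = jnp.dot(jnp.linalg.inv(I_w), jac_full.T)

# Linear solve: the same (n x m) result without forming inv(I_w).
# Because jac_full.T is 2-D, solve treats it as a matrix of right-hand sides.
via_solve = jnp.linalg.solve(I_w, jac_full.T)

print(via_solve.shape, jnp.allclose(via_inv, via_solve, atol=1e-4))
```

On a well-conditioned matrix like this the two agree to float32 precision; the advantage of `solve` shows up as the matrix becomes ill-conditioned.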

Note that solve has the same limitation as inv: if the arguments are bfloat16, you'll have to cast them to float32.
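Applying the same cast to the solve might look like the sketch below. `solve_upcast` is a hypothetical helper, not a JAX API; it assumes float32 is an acceptable compute precision and returns the result in the inputs' common dtype. The small diagonal `I_w` and all-ones `jac_full` are placeholder values chosen so the answer is easy to check by hand.

```python
import jax.numpy as jnp

def solve_upcast(a, b, compute_dtype=jnp.float32):
    """Solve a @ x = b in `compute_dtype`, returning x in the inputs' dtype.

    Hypothetical helper (not part of JAX): upcasts bfloat16 inputs to float32
    for the solve, then casts the result back.
    """
    out_dtype = jnp.result_type(a, b)
    x = jnp.linalg.solve(a.astype(compute_dtype), b.astype(compute_dtype))
    return x.astype(out_dtype)

# Placeholder bfloat16 inputs (assumed shapes: I_w is 3x3, jac_full is 2x3).
I_w = (4.0 * jnp.eye(3)).astype(jnp.bfloat16)
jac_full = jnp.ones((2, 3), dtype=jnp.bfloat16)
x = solve_upcast(I_w, jac_full.T)
print(x.dtype, x.shape)  # bfloat16 (3, 2)
```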

Answer selected by pvasired
Category: Q&A