I have the following function, which I am trying to speed up taking gradients of by using the bfloat16 data type. However, when I try to call this function directly or during a gradient computation, I get an error. I am not sure if this is a limitation of my hardware (not compatible with bfloat16?) or if the values in the array are outside of the range that bfloat16 can represent. Would appreciate any advice here!
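A minimal sketch of the failing pattern (the function body, names, and shapes here are illustrative stand-ins, not my actual code):

```python
import jax
import jax.numpy as jnp

def loss(I_w, jac_full):
    # Explicit matrix inverse inside the function to be differentiated.
    return jnp.sum(jnp.dot(jnp.linalg.inv(I_w), jac_full.T))

I_w = jnp.eye(3, dtype='bfloat16')
jac_full = jnp.ones((4, 3), dtype='bfloat16')

loss(I_w, jac_full)            # fails when called directly...
jax.grad(loss)(I_w, jac_full)  # ...and when taking gradients
```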
The limitation is that the matrix inverse is not implemented in bfloat16:

```python
import jax.numpy as jnp

x = jnp.ones((3, 3), dtype='bfloat16')
jnp.linalg.inv(x)
# NotImplementedError: Unsupported dtype bfloat16
```

One way you could work around this is by casting to float32:

```python
jnp.linalg.inv(x.astype('float32')).astype('bfloat16')
```

One aside: in general, you should avoid computing an explicit floating-point inverse and use a linear solve instead. So instead of `jnp.dot(jnp.linalg.inv(I_w), jac_full.T)`, you should write:

```python
jnp.linalg.solve(I_w, jac_full.T)
```

Here `solve` has the same issue as `inv`; if the arguments are bfloat16 you'll have to cast to float32.
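For example, one way the cast-and-solve pattern could look (a sketch; the `solve_bf16` name is hypothetical):

```python
import jax.numpy as jnp

def solve_bf16(I_w, jac_full):
    # solve, like inv, lacks a bfloat16 kernel, so cast up to float32
    # for the linear solve and cast the result back down.
    out = jnp.linalg.solve(I_w.astype('float32'),
                           jac_full.T.astype('float32'))
    return out.astype('bfloat16')
```

Since the casts themselves are differentiable, a function written this way can still be passed through `jax.grad`.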