Summary
Converting a formula such as $\frac{X}{1}$ fails because the denominator is translated into JAX as an integer. Integers can be added and multiplied without problems, but division is expressed as multiplication by a reciprocal, i.e. $X \cdot 1^{-1}$. JAX cannot compute the gradient through an integer raised to a power, so differentiating the converted expression fails.
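This is a minimal sketch of one possible workaround. It assumes the translator emits `x * jnp.power(c, -1)` when it divides by a constant `c`. The function name `x_over_one` is hypothetical and does not come from the converter. The sketch writes the constant as a floating-point value (`1.0`) rather than an integer, so the reciprocal is computed on a float and `jax.grad` can differentiate the expression. It does not show how the library itself handles this.

```python
import jax
import jax.numpy as jnp


def x_over_one(x):
    # Hypothetical translation of \frac{X}{1}: the division is written as X * 1^{-1}.
    # The constant is created as a float array instead of the integer 1,
    # so the power with exponent -1 is evaluated in floating point.
    one = jnp.asarray(1.0)
    return x * jnp.power(one, -1)


# Differentiate with respect to x; d/dx (x * 1^{-1}) = 1.
grad_fn = jax.grad(x_over_one)
print(grad_fn(3.0))
```

In the same spirit, a converter could promote every integer literal to a float before it emits a power with a negative exponent. This is a design assumption, not confirmed behavior.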