You can explore the underlying computations used by each of these approaches with `jax.make_jaxpr`:

```python
with jax.disable_jit():  # for cleaner jaxprs
    print("gram1:")
    print(jax.make_jaxpr(gram1)(X, Y))
    print("\ngram2:")
    print(jax.make_jaxpr(gram2)(X, Y))
    print("\ngram3:")
    print(jax.make_jaxpr(gram3)(X, Y))
```

Output:
You can see in the jaxprs that the first two approaches materialize large intermediate arrays of pairwise differences, while the third does not. Now, it's true that in general XLA compilation may rearrange these kinds of computations to avoid unnecessary allocation, but it appears that for this particular computation the compiler does not automatically reduce the operation to the more efficient form. I suspect the reason this kind of rewrite is not built into the compiler is that the two forms are not equivalent in floating-point arithmetic. For example:

```python
import numpy as np

x = np.float64(1E8)
y = x + 1
print((y - x) * (y - x))
# 1.0
print(y ** 2 + x ** 2 - 2 * y * x)
# 0.0
```
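A practical consequence of this non-equivalence (my own illustration, not from the thread): the expanded form `||x||^2 + ||y||^2 - 2 x.y` can suffer the same cancellation inside a distance computation, occasionally even producing tiny negative "squared distances", which is why it is commonly clamped before taking a square root:

```python
# Illustration (assumed example, not from the thread): the expanded
# squared-distance formula can lose precision where the direct
# difference does not.
import numpy as np

x = np.array([1e4, 1.0], dtype=np.float32)
y = np.array([1e4 + 1.0, 1.0], dtype=np.float32)

# Direct form: subtract first, then square and sum.
direct = np.sum((x - y) ** 2)

# Expanded form: large terms of similar magnitude cancel.
expanded = np.sum(x ** 2) + np.sum(y ** 2) - 2 * np.dot(x, y)

print(direct)    # 1.0
print(expanded)  # may differ from 1.0 due to cancellation

# Common guard before a square root:
safe = np.maximum(expanded, 0.0)
```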
Hello,
while trying to compute Gram matrices for kernel methods with JAX, I realized that the computation time can vary considerably, even for simple L2 distances.
Here are three different ways to compute the same thing:
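(The original code blocks did not survive extraction; the following is a plausible sketch of three such implementations. Only the names `gram1`, `gram2`, and `gram3` come from the thread; the exact bodies are assumptions.)

```python
import jax
import jax.numpy as jnp

def gram1(X, Y):
    # Nested vmap over rows: squared L2 distance for each (x, y) pair.
    return jax.vmap(lambda x: jax.vmap(lambda y: jnp.sum((x - y) ** 2))(Y))(X)

def gram2(X, Y):
    # Broadcasting: materializes an (N, M, D) array of pairwise differences.
    return jnp.sum((X[:, None, :] - Y[None, :, :]) ** 2, axis=-1)

def gram3(X, Y):
    # Expanded form ||x||^2 + ||y||^2 - 2 x.y: only matmul-sized intermediates.
    return (jnp.sum(X ** 2, axis=1)[:, None]
            + jnp.sum(Y ** 2, axis=1)[None, :]
            - 2 * X @ Y.T)
```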
Timing these three implementations shows that the last one is by far the fastest. Does anyone have any idea why there are such large differences in computation time (x40 between the slowest and the fastest)?