Fixed (Static) Shape and Type for JIT Compilation #10518
Hello JAX Community,

As far as I understand from this context, although just-in-time compilation in JAX contributes greatly to both computational efficiency and reduced memory consumption, it comes with some restrictions and constraints that we are supposed to obey, such as static array shapes. When we jit-compile a function, we want to create a version of it that we can cache and reuse for many different argument values. If we jit-compile a function on an array like … In this case, should our compiled function not work on, for example, 2D matrices, because they don't adhere to the abstraction of …?

Edited: When I put a …

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(13)
key1, key2 = jax.random.split(key)

def matrix_product(matrix1, matrix2):
    result = matrix1 @ matrix2
    return result

jit_matrix_product = jax.jit(matrix_product)

# first series of input arguments: two 1D vectors of shape (3,)
vector1 = jnp.array([1, 2, 3], dtype=jnp.float32)
vector2 = jnp.array([4, 5, 6], dtype=jnp.float32)
# second series of input arguments: 2D matrices of shapes (200, 200) and (200, 300)
matrix1 = jax.random.uniform(key1, (200, 200))
matrix2 = jax.random.uniform(key2, (200, 300))

print("Result of matrix multiplication 1: ", jit_matrix_product(vector1, vector2))
print("Result of matrix multiplication 2: ", jit_matrix_product(matrix1, matrix2))
```
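For comparison, the intuition in the question does hold for a single compiled executable. The sketch below assumes a recent JAX release with the ahead-of-time `lower(...).compile()` API. It compiles `matrix_product` for the two float32 vectors of shape (3,) only; calling that executable with the (200, 200) and (200, 300) matrices raises a `TypeError`, whereas the `jax.jit` wrapper simply compiles a new version for the new shapes. `jnp.ones` stands in for the random matrices only to keep the example short, and the names `compiled` and `rejected` are illustrative.

```python
import jax
import jax.numpy as jnp

def matrix_product(matrix1, matrix2):
    return matrix1 @ matrix2

jit_matrix_product = jax.jit(matrix_product)

vector1 = jnp.array([1, 2, 3], dtype=jnp.float32)
vector2 = jnp.array([4, 5, 6], dtype=jnp.float32)

# Lower and compile ahead of time for float32[3] x float32[3] inputs only.
compiled = jit_matrix_product.lower(vector1, vector2).compile()
dot_result = compiled(vector1, vector2)  # 1*4 + 2*5 + 3*6

matrix1 = jnp.ones((200, 200), dtype=jnp.float32)
matrix2 = jnp.ones((200, 300), dtype=jnp.float32)

try:
    # This executable is tied to the shapes/dtypes it was compiled for.
    compiled(matrix1, matrix2)
    rejected = False
except TypeError as err:
    rejected = True
    print("compiled executable rejected new shapes:", type(err).__name__)

# The jit wrapper, by contrast, traces and compiles a new version for these shapes.
matrix_result = jit_matrix_product(matrix1, matrix2)
print(float(dot_result), matrix_result.shape)
```

In other words, each compiled artifact is specialized to one set of shapes and dtypes, but `jax.jit` keeps a cache of such artifacts and adds a new one whenever it sees a new signature.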
JAX's JIT does just-in-time compilation of functions based on their inputs. Compiling the function for one set of inputs does not preclude later compiling the function again for a second set.

The first time you call a function with inputs of a given shape & dtype, the function is compiled, and then executed. The second time you call the function with inputs of matching shapes/dtypes, the cached compilation result is used (i.e. no new compilation is necessary). If you call the function with a different set of shapes & dtypes, the function is compiled again and the result of the new compilation is cached.

This section of the docs might be good background on the mechanics of JIT: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables

Does that answer your question?
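The caching behaviour described above can be observed with a small sketch. It relies on the fact that Python side effects inside a jitted function run only while JAX traces it, not when a cached compilation is reused. The `trace_log` list is a hypothetical helper added for illustration; it is not part of JAX.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper: records the abstract signature each time JAX traces.
trace_log = []

def matrix_product(matrix1, matrix2):
    # This append runs only during tracing, so there is one entry
    # per new (shape, dtype) combination, not one per call.
    trace_log.append((matrix1.shape, str(matrix1.dtype), matrix2.shape, str(matrix2.dtype)))
    return matrix1 @ matrix2

jit_matrix_product = jax.jit(matrix_product)

vector1 = jnp.array([1, 2, 3], dtype=jnp.float32)
vector2 = jnp.array([4, 5, 6], dtype=jnp.float32)
int_vector = jnp.array([1, 2, 3], dtype=jnp.int32)
matrix1 = jnp.ones((200, 200), dtype=jnp.float32)
matrix2 = jnp.ones((200, 300), dtype=jnp.float32)

jit_matrix_product(vector1, vector2)        # float32[3] inputs: trace + compile
jit_matrix_product(vector2, vector1)        # same shapes/dtypes: cached, no trace
jit_matrix_product(int_vector, int_vector)  # new dtype (int32): trace + compile
jit_matrix_product(matrix1, matrix2)        # new shapes: trace + compile
jit_matrix_product(matrix1, 2 * matrix2)    # same shapes/dtypes as previous call: cached

for entry in trace_log:
    print(entry)
print("number of traces:", len(trace_log))
```

Five calls produce only three traces: one for the float32 vectors, one for the int32 vectors, and one for the matrices.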