No, in general JAX will not make copies of unmapped arrays (arguments passed with `in_axes=None`), although it's possible for particular primitives to do so in their batching rules. If you want to be sure, you can use `jax.make_jaxpr` and inspect the sizes of the arrays created by the transformed function. For example:

```python
import jax
import jax.numpy as jnp

mapped_array = jnp.zeros((100000, 3))
array1 = jnp.arange(3)  # int32
array2 = jnp.ones(3)    # float32

def func(arr, arr1, arr2):
  return arr + arr1 + arr2

# Map over axis 0 of the first argument only; arr1 and arr2 are unmapped.
print(jax.make_jaxpr(jax.vmap(func, in_axes=(0, None, None)))(mapped_array, array1, array2))
```

```
{ lambda ; a:f32[100000,3] b:i32[3] c:f32[3]. let
    d:f32[3] = convert_element_type[new_dtype=float32 weak_type=False] b
    e:f32[1,3] = broadca…
```
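The same inspection can be automated. The sketch below uses a hypothetical helper, `unmapped_copies` (not part of JAX), that walks the top-level equations of a jaxpr and flags any array whose shape contains the batch size but that is computed only from unmapped inputs. It assumes the jaxpr internals `ClosedJaxpr.jaxpr.eqns`, `eqn.primitive.name` and `var.aval.shape`, which are not a stable public API and may change between JAX versions. It also does not look inside nested jaxprs (for example from `jit` or `scan`). For contrast, a hypothetical `manual` function explicitly broadcasts `array1` to the full batch with `jnp.broadcast_to`, which does materialize a copy.

```python
import jax
import jax.numpy as jnp


def func(arr, arr1, arr2):
    return arr + arr1 + arr2


def manual(arr, arr1, arr2):
    # Hypothetical contrast: explicitly broadcast arr1 to the full batch,
    # which DOES materialize a (100000, 3) copy of it.
    return arr + jnp.broadcast_to(arr1, arr.shape) + arr2


def unmapped_copies(closed_jaxpr, mapped_argnums, batch_size):
    """Hypothetical helper: find batch-sized arrays built only from unmapped inputs.

    Only top-level equations are inspected; nested jaxprs (e.g. inside a jit
    or scan equation) are treated as opaque single steps.
    """
    jaxpr = closed_jaxpr.jaxpr
    # Track "depends on a mapped input" by id(), so literal operands never
    # need to be hashed or compared.
    mapped_ids = {id(v) for i, v in enumerate(jaxpr.invars) if i in mapped_argnums}
    copies = []
    for eqn in jaxpr.eqns:
        if any(id(v) in mapped_ids for v in eqn.invars):
            # Outputs of an equation that reads mapped data are mapped too.
            mapped_ids.update(id(v) for v in eqn.outvars)
            continue
        # This equation reads only unmapped data; a batch-sized output means
        # the unmapped value was copied out to the full batch.
        for v in eqn.outvars:
            shape = tuple(v.aval.shape)
            if batch_size in shape:
                copies.append((eqn.primitive.name, shape))
    return copies


mapped_array = jnp.zeros((100000, 3))
array1 = jnp.arange(3)
array2 = jnp.ones(3)

closed = jax.make_jaxpr(jax.vmap(func, in_axes=(0, None, None)))(
    mapped_array, array1, array2
)

# Quick visual check: every intermediate's primitive and shape.
for eqn in closed.jaxpr.eqns:
    for v in eqn.outvars:
        print(eqn.primitive.name, v.aval.shape)

print("vmap copies:", unmapped_copies(closed, {0}, 100000))

closed_manual = jax.make_jaxpr(manual)(mapped_array, array1, array2)
manual_copies = unmapped_copies(closed_manual, {0}, 100000)
print("manual copies:", manual_copies)
```

For the `vmap` version the list should come back empty, since the unmapped arrays are only broadcast to a leading size-1 axis (the `f32[1,3]` above), while the `manual` version should report at least one batch-sized array derived from `array1`.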

Replies: 1 comment

Answer selected by ajithmoola
Category
Q&A
2 participants