Here is the raw trace_viewer output in TensorBoard, along with a detailed examination:

I also made a few other attempts, such as jitting version1 or avoiding the intermediate y1 and y2, but none of them gave the desired result:

version1_jitted = jax.jit(version1)

def version2(x):
    return fn1(fn0(x))

version2_jitted = jax.jit(version2)

def version3(x):
    y1 = fn0(x)
    y1 = jax.device_put(y1, device=GPUS[1])
    y2 = fn1(y1)
    return y2

version3_jitted = jax.jit(version3)
