You may be able to rely on the compiler to fuse operations. Even if the jaxpr indicates an intermediate value of a particular shape, it doesn't necessarily mean that the compiled operation will instantiate that intermediate value.

For example, here's the compiled HLO produced by your broadcast function on a T4 GPU:

print(jax.jit(broadcast).lower(weights, xs, idx).compile().as_text())
HloModule jit_broadcast, is_scheduled=true, entry_computation_layout={(f32[8,128,512]{2,1,0}, f32[256,128]{1,0}, s32[256]{0})->f32[256,512]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="4167…
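
To check this yourself, you can compare the jaxpr against the compiled HLO. The original `broadcast` function isn't reproduced here, so the sketch below uses a hypothetical stand-in (a gather followed by a batched contraction) chosen only to match the argument and result shapes in the HLO signature above; the array values are placeholders too.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for `broadcast`; only the shapes come from the
# HLO signature above, the body is an assumption.
def broadcast(weights, xs, idx):
    # weights[idx] nominally has shape (256, 128, 512).
    return jnp.einsum("bi,bio->bo", xs, weights[idx])

# Placeholder inputs with the shapes from the entry computation layout.
weights = jnp.ones((8, 128, 512), dtype=jnp.float32)
xs = jnp.ones((256, 128), dtype=jnp.float32)
idx = jnp.arange(256, dtype=jnp.int32) % 8

# The jaxpr lists the gathered (256, 128, 512) value as an explicit intermediate.
jaxpr_text = str(jax.make_jaxpr(broadcast)(weights, xs, idx))
print(jaxpr_text)

# The compiled HLO is what the backend actually runs after fusion.
hlo_text = jax.jit(broadcast).lower(weights, xs, idx).compile().as_text()

# Rough textual check: does a buffer of the intermediate's shape still appear?
# The answer depends on the backend and the XLA version.
print("f32[256,128,512]" in hlo_text)

out = jax.jit(broadcast)(weights, xs, idx)
print(out.shape)
```

The string search is only a heuristic. Reading the fusion computations in the HLO text is the more reliable way to see whether the intermediate is materialized, and the result can differ between CPU, GPU, and TPU.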

Answer selected by ScottAlexanderCameron