
I looked into this. The problem is that calling ensure_compile_time_eval under shard_map creates Manual HloShardings, which don't really work in the execution path, i.e. the impl rules of the primitives.

Your workaround is fine for now, but you can also use jax.sharding.use_abstract_mesh to override the mesh (just be careful about doing that under a shard_map).

Replies: 1 comment 4 replies

@yashk2810

Answer selected by inailuig
@PhilipVinc

@inailuig

@yashk2810
Category
Q&A
Labels
None yet
3 participants