Open
Description
How important is being able to compile JAXNS into your JAX program?
Pros
- You can embed it into a single compiled executable, which you can ship to CPU or GPU.
- You get better performance (sometimes).
Cons
- Worse memory management.
- Distribution over compute resources is poor, because JAX relies on synchronous, homogeneous SIMD execution (one single-instruction, multiple-data kernel at a time, over devices of the same type), whereas sampling benefits most from concurrency, heterogeneous computing, and non-SIMD paradigms.
- Performance is often worse than it should be because the algorithm doesn't nicely fit into `pmap` or `shard_map`. You always end up wasting wall-time.
- Lose ability to make dynamic choices.
- Much harder to distribute over massive compute clusters.
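To make the wasted wall-time point concrete, here is a minimal sketch, not taken from JAXNS. It assumes a hypothetical `count_steps` function standing in for a sampler step whose cost varies per sample. As far as I understand JAX's batching of `lax.while_loop`, a vmapped loop keeps iterating until every lane's condition is false, so cheap lanes wait for the slowest one.

```python
import jax
import jax.numpy as jnp


def count_steps(n):
    # Hypothetical stand-in for a sampler step that needs n iterations.
    def cond(state):
        i, _ = state
        return i < n

    def body(state):
        i, acc = state
        return i + 1, acc + 1

    _, acc = jax.lax.while_loop(cond, body, (jnp.int32(0), jnp.int32(0)))
    return acc


# Three lanes with very different amounts of work.
work = jnp.array([1, 2, 50], dtype=jnp.int32)
steps = jax.vmap(count_steps)(work)

# Iterations each lane actually needed, summed over lanes.
useful = int(steps.sum())
# Arithmetic estimate (not a measurement) of lane-iterations spent if all
# lanes run in lockstep until the slowest lane finishes.
lockstep = int(work.max()) * work.shape[0]
print(useful, lockstep)
```

In this toy case roughly two thirds of the lockstep lane-iterations do no useful work, which is the kind of overhead a non-SIMD scheduler could avoid.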
Alternative: What if the likelihood is the only thing that needs to be JAX?
What you get:
- you can still leverage auto-diff
- you can still ship to CPU or GPU or TPU
- you can distribute over massive compute clusters
- you can easily stop and resume
- almost no noticeable drop in performance
- improved performance in many cases
- but the algorithm itself is not pure JAX, so you cannot compile it.
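A minimal sketch of what this alternative could look like, assuming a toy 2D Gaussian likelihood and a hypothetical `replace_worst` helper (neither is JAXNS API). Only the likelihood and its gradient are compiled with JAX; the sampler loop is plain Python with NumPy state, so it is free to make data-dependent choices.

```python
import jax
import jax.numpy as jnp
import numpy as np


def log_likelihood(theta):
    # Hypothetical toy likelihood: unnormalised standard normal.
    return -0.5 * jnp.sum(theta ** 2)


# Only the likelihood (and its gradient) is compiled.
batched_log_l = jax.jit(jax.vmap(log_likelihood))
grad_log_l = jax.jit(jax.grad(log_likelihood))


def replace_worst(live, log_l, rng, batch_size=64, max_tries=1000):
    """One plain-Python step: replace the lowest-likelihood live point."""
    worst = int(np.argmin(log_l))
    threshold = log_l[worst]
    for _ in range(max_tries):  # data-dependent loop, exits as soon as it can
        proposals = rng.uniform(-5.0, 5.0, size=(batch_size, live.shape[1]))
        prop_log_l = np.asarray(batched_log_l(jnp.asarray(proposals)))
        accepted = np.flatnonzero(prop_log_l > threshold)
        if accepted.size > 0:
            live[worst] = proposals[accepted[0]]
            log_l[worst] = prop_log_l[accepted[0]]
            return threshold
    raise RuntimeError("no proposal exceeded the likelihood threshold")


rng = np.random.default_rng(0)
live = rng.uniform(-5.0, 5.0, size=(50, 2))
log_l = np.array(batched_log_l(jnp.asarray(live)))  # writable NumPy copy
thresholds = [replace_worst(live, log_l, rng) for _ in range(20)]
```

Because `live` and `log_l` are ordinary NumPy arrays in this sketch, writing them to disk between steps (for example with `np.savez`) would be enough to stop and resume, and `grad_log_l` remains available to any gradient-based proposal.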