
Question to community: How important is being able to compile JAXNS into your JAX program? #239

@Joshuaalbert

How important is being able to compile JAXNS into your JAX program?

Pros

  1. You can embed it into a single compiled executable, which you can ship to CPU or GPU.
  2. You get better performance (sometimes).
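
As a minimal sketch of the first pro, assuming a hypothetical toy Gaussian likelihood and a plain rejection step from a unit-cube prior (this is not JAXNS's API or its actual sampler), the whole constrained-draw loop can be written with `jax.lax.while_loop` and compiled with `jax.jit`, so the algorithm ships as one XLA program for CPU or GPU:

```python
import jax
import jax.numpy as jnp
from jax import lax


def log_likelihood(theta):
    # Hypothetical toy likelihood: unnormalised Gaussian centred at 0.5.
    return -0.5 * jnp.sum(((theta - 0.5) / 0.1) ** 2)


def constrained_draw(key, log_l_min, ndim=2):
    """Rejection-sample the unit-cube prior until log L(theta) > log_l_min."""

    def cond(state):
        _, _, log_l, _ = state
        return log_l <= log_l_min

    def body(state):
        key, _, _, n_tries = state
        key, sub = jax.random.split(key)
        theta = jax.random.uniform(sub, (ndim,), dtype=jnp.float32)
        return key, theta, log_likelihood(theta), n_tries + 1

    init = (
        key,
        jnp.zeros(ndim, dtype=jnp.float32),
        jnp.array(-jnp.inf, dtype=jnp.float32),
        jnp.array(0, dtype=jnp.int32),
    )
    _, theta, log_l, n_tries = lax.while_loop(cond, body, init)
    return theta, log_l, n_tries


# The whole loop is traced once and compiled into a single XLA program.
# vmap batches independent draws; under SIMD every lane keeps stepping until
# the slowest lane has accepted (finished lanes are masked, not freed).
draw_many = jax.jit(jax.vmap(constrained_draw, in_axes=(0, None)))

keys = jax.random.split(jax.random.PRNGKey(0), 8)
thetas, log_ls, n_tries = draw_many(keys, -2.0)
print(thetas.shape)  # (8, 2)
```

Because of `jax.vmap`, the batched loop keeps executing until the slowest of the eight draws has accepted; lanes that finished early are masked rather than freed, which is one concrete form of the lockstep cost listed under the cons.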

Cons

  1. Worse memory management.
  2. Distribution over compute resources is not great, because JAX relies on synchronous, homogeneous SIMD execution (a single instruction-multiple-data kernel runs at a time, in lockstep, across devices of the same type), whereas sampling benefits most from concurrency, heterogeneous hardware, and non-SIMD execution.
  3. Performance is often worse than it should be, because the algorithm doesn't fit neatly into pmap or shard_map. You always end up wasting wall-clock time.
  4. You lose the ability to make dynamic choices (e.g., control flow or array shapes that depend on runtime values).
  5. It is much harder to distribute over massive compute clusters.
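
The loss of dynamic choices can be illustrated with a short sketch (the function names are hypothetical and unrelated to JAXNS): under `jax.jit`, a Python `if` on a traced value fails, branching has to go through `jax.lax.cond`, and data-dependent array sizes must be padded to a static `size`:

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def python_branch(x):
    # A Python `if` needs a concrete bool, but under jit `x` is an abstract tracer.
    if x.sum() > 0:
        return x * 2.0
    return x


@jax.jit
def traced_branch(x):
    # lax.cond stages both branches into the compiled program; both must
    # return arrays of the same shape and dtype.
    return lax.cond(x.sum() > 0, lambda v: v * 2.0, lambda v: v, x)


@jax.jit
def select_positive(x):
    # Output shapes must be known at compile time, so a data-dependent number
    # of survivors is padded to a fixed `size` with `fill_value`.
    (idx,) = jnp.nonzero(x > 0, size=x.shape[0], fill_value=-1)
    return idx


x = jnp.array([1.0, -2.0, 3.0])
try:
    python_branch(x)
    branch_failed = False
except jax.errors.ConcretizationTypeError:
    branch_failed = True

print(branch_failed)                   # True
print(traced_branch(x).tolist())       # [2.0, -4.0, 6.0]
print(select_positive(x).tolist())     # [0, 2, -1]
```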

Alternative: What if the likelihood is the only thing that needs to be JAX?

What you get:

  1. You can still leverage auto-diff.
  2. You can still ship to CPU, GPU, or TPU.
  3. You can distribute over massive compute clusters.
  4. You can easily stop and resume.
  5. There is almost no noticeable drop in performance.
  6. Performance improves in many cases.
  7. The catch: the algorithm itself is not pure JAX, so it cannot be compiled as a whole.
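
A minimal sketch of the alternative, assuming a hypothetical toy Gaussian likelihood and a simple batched rejection step (the driver `constrained_draws` and its parameters are illustrative, not JAXNS code): only the likelihood is JAX, compiled with `jax.jit` and differentiated with `jax.grad`, while the outer loop is ordinary Python with NumPy, free to keep a variable number of samples and to decide when to stop.

```python
import numpy as np
import jax
import jax.numpy as jnp


def log_likelihood(theta):
    # Hypothetical toy model: only this function is JAX.
    return -0.5 * jnp.sum(((theta - 0.5) / 0.1) ** 2)


# Compile the batched likelihood and its gradient once; both run on CPU/GPU/TPU.
batched_loglike = jax.jit(jax.vmap(log_likelihood))
grad_loglike = jax.jit(jax.grad(log_likelihood))


def constrained_draws(log_l_min, n_wanted, ndim, rng, batch_size=256, max_batches=1000):
    """Plain-Python driver: rejection-sample the uniform prior subject to L > L_min."""
    accepted = []
    for _ in range(max_batches):
        candidates = rng.uniform(size=(batch_size, ndim))
        log_l = np.asarray(batched_loglike(jnp.asarray(candidates)))
        # Data-dependent selection: the number of survivors varies per batch,
        # which needs padding under jit but is trivial in ordinary Python.
        accepted.extend(candidates[log_l > log_l_min])
        if len(accepted) >= n_wanted:
            break  # Dynamic stopping decision made on the host.
    return np.array(accepted[:n_wanted])


rng = np.random.default_rng(0)
samples = constrained_draws(log_l_min=-2.0, n_wanted=10, ndim=2, rng=rng)
print(samples.shape)  # (10, 2)
print(grad_loglike(jnp.array([0.3, 0.7])))
```

Because the driver's state is plain NumPy and Python objects, it can be checkpointed and resumed, or farmed out to separate workers, without retracing any JAX code; the cost is that the loop itself is never compiled.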
