-
A simple suggestion based on what you already have: You can try using the
-
Some updates

After many attempts, and inspired by #4968, I managed to improve the performance by creating a custom batched version of

Pseudocode:

    @jax.jit
    def rays_intersect_any_triangle(
        ray_origins: Float[Array, "num_rays 3"],
        ray_directions: Float[Array, "num_rays 3"],
        triangle_vertices: Float[Array, "num_triangles 3 3"],
    ) -> Bool[Array, "num_rays"]:
        batch_size = 1024
        num_batches = triangle_vertices.shape[-3] // batch_size
        # ... make sure to pad arrays to be a multiple of batch_size

        def body_fun(batch_index, intersect):
            # Slice one batch of triangles along the triangle axis.
            batch_of_triangle_vertices = jax.lax.dynamic_slice_in_dim(
                triangle_vertices, batch_index * batch_size, batch_size, axis=-3)
            # OR the running per-ray result with the hits found in this batch.
            return intersect | jax.vmap(ray_intersect_triangle)(...).any(axis=-1)

        return jax.lax.fori_loop(0, num_batches, body_fun, init_val=jnp.zeros(..., dtype=bool))

The actual code (if you are curious): https://github.com/jeertmans/DiffeRT/pull/300/files.
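For readers who want something runnable, here is a minimal sketch of the batching idea from the pseudocode above. It is not the code from the linked PR: the `ray_intersect_triangle` helper, the `epsilon` tolerance, the `batch_size` keyword, and the zero-padding strategy are all assumptions made for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp


def ray_intersect_triangle(ray_origin, ray_direction, triangle_vertices, epsilon=1e-6):
    """Möller-Trumbore test for one ray and one triangle (hypothetical helper)."""
    v0, v1, v2 = triangle_vertices[0], triangle_vertices[1], triangle_vertices[2]
    edge_1, edge_2 = v1 - v0, v2 - v0
    h = jnp.cross(ray_direction, edge_2)
    a = jnp.dot(edge_1, h)
    parallel = jnp.abs(a) < epsilon  # also true for degenerate triangles
    f = 1.0 / jnp.where(parallel, 1.0, a)  # avoid dividing by zero
    s = ray_origin - v0
    u = f * jnp.dot(s, h)
    q = jnp.cross(s, edge_1)
    v = f * jnp.dot(ray_direction, q)
    t = f * jnp.dot(edge_2, q)
    return ~parallel & (u >= 0.0) & (v >= 0.0) & (u + v <= 1.0) & (t > epsilon)


@partial(jax.jit, static_argnames=("batch_size",))
def rays_intersect_any_triangle(ray_origins, ray_directions, triangle_vertices, batch_size=1024):
    num_rays = ray_origins.shape[0]
    num_triangles = triangle_vertices.shape[0]
    num_batches = -(-num_triangles // batch_size)  # ceiling division
    pad = num_batches * batch_size - num_triangles
    # Assumption: pad with all-zero (degenerate) triangles. Their determinant
    # is zero, so the test above flags them as parallel and never reports a hit.
    padded = jnp.pad(triangle_vertices, ((0, pad), (0, 0), (0, 0)))

    # One ray against a batch of triangles, then every ray against that batch.
    ray_vs_batch = jax.vmap(ray_intersect_triangle, in_axes=(None, None, 0))
    rays_vs_batch = jax.vmap(ray_vs_batch, in_axes=(0, 0, None))

    def body_fun(batch_index, intersect):
        batch = jax.lax.dynamic_slice_in_dim(
            padded, batch_index * batch_size, batch_size, axis=0)
        # Only a (num_rays, batch_size) mask exists at any time.
        return intersect | rays_vs_batch(ray_origins, ray_directions, batch).any(axis=-1)

    return jax.lax.fori_loop(0, num_batches, body_fun, jnp.zeros(num_rays, dtype=bool))


# Usage: ten unit triangles spaced 10 units apart along x, three downward rays.
base = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
offsets = jnp.arange(10.0)[:, None, None] * jnp.array([10.0, 0.0, 0.0])
triangles = base[None] + offsets
origins = jnp.array([[0.25, 0.25, 1.0], [90.25, 0.25, 1.0], [5.0, 5.0, 1.0]])
directions = jnp.tile(jnp.array([0.0, 0.0, -1.0]), (3, 1))
hits = rays_intersect_any_triangle(origins, directions, triangles, batch_size=4)
```

With `batch_size=4` the ten triangles are padded to twelve and processed in three loop iterations, so the second ray exercises the partially filled last batch. The trade-off is between peak memory (proportional to `num_rays * batch_size`) and the number of sequential loop steps.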
-
Hi! First, thanks for this great package!
I use JAX heavily for my research, for which I implemented ray tracing utilities.
One critical operation among many is the ray-triangle intersection test. We usually have millions of rays and millions of triangles, and we would like to test, for each ray, if it intersects with any triangle. A common approach to check for this is to use the Möller-Trumbore algorithm, which I implemented using JAX.
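To make the test concrete, here is a minimal sketch of a single ray-triangle check following the Möller-Trumbore algorithm in JAX. It is not the implementation referred to in this post: the function name, argument layout, and `epsilon` tolerance are assumptions chosen for illustration.

```python
import jax.numpy as jnp


def ray_intersect_triangle(ray_origin, ray_direction, triangle_vertices, epsilon=1e-6):
    """Return a boolean scalar: does the ray hit the triangle?

    ray_origin, ray_direction: arrays of shape (3,)
    triangle_vertices: array of shape (3, 3), one vertex per row
    """
    v0 = triangle_vertices[0]
    edge_1 = triangle_vertices[1] - v0
    edge_2 = triangle_vertices[2] - v0

    h = jnp.cross(ray_direction, edge_2)
    a = jnp.dot(edge_1, h)  # close to zero when the ray is parallel to the plane
    parallel = jnp.abs(a) < epsilon
    f = 1.0 / jnp.where(parallel, 1.0, a)  # avoid dividing by zero

    s = ray_origin - v0
    u = f * jnp.dot(s, h)  # first barycentric coordinate
    q = jnp.cross(s, edge_1)
    v = f * jnp.dot(ray_direction, q)  # second barycentric coordinate
    t = f * jnp.dot(edge_2, q)  # distance along the ray

    inside = (u >= 0.0) & (v >= 0.0) & (u + v <= 1.0)
    return ~parallel & inside & (t > epsilon)


# Usage: a unit right triangle in the z = 0 plane.
triangle = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
down = jnp.array([0.0, 0.0, -1.0])
hit = ray_intersect_triangle(jnp.array([0.25, 0.25, 1.0]), down, triangle)
miss = ray_intersect_triangle(jnp.array([2.0, 2.0, 1.0]), down, triangle)
behind = ray_intersect_triangle(jnp.array([0.25, 0.25, 1.0]), -down, triangle)
```

The function avoids Python branching on array values (it uses `jnp.where` and boolean masks instead), so it can be wrapped in `jax.vmap` and `jax.jit` to test many rays against many triangles.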
However, I am struggling to improve the performance of this function, and I am reaching out to see if anyone has any suggestions. I am considering moving to Mitsuba (or using DrJit directly) or implementing a C++ (or Rust) version, but I hope I can find a way to improve the code without leaving JAX.
After reading some related discussions and issues (https://stackoverflow.com/questions/77527847/jax-vmap-limit-memory, #11319, https://stackoverflow.com/questions/77659069/how-to-understand-and-debug-memory-usage-with-jax), I realized that JIT compilation could optimize away the need for a large allocation.
Here is the code to reproduce (`uv run file.py` works):

FYI, I have an NVIDIA 3070 with 8192 MiB of memory, which is not enormous but should suffice, since the computation should not need to allocate a matrix too large to fit in it.
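As a rough back-of-the-envelope check of why the full ray-by-triangle result must never be materialized, the sketch below estimates the size of a dense boolean hit mask. The counts of one million rays and one million triangles are assumptions standing in for "millions", and only the final mask is counted, not the floating-point intermediates.

```python
num_rays = 1_000_000       # assumed stand-in for "millions of rays"
num_triangles = 1_000_000  # assumed stand-in for "millions of triangles"
bytes_per_entry = 1        # a boolean array entry takes one byte

dense_mask_bytes = num_rays * num_triangles * bytes_per_entry
gpu_memory_bytes = 8192 * 1024**2  # 8192 MiB

dense_mask_gib = dense_mask_bytes / 1024**3
ratio = dense_mask_bytes / gpu_memory_bytes

print(f"dense mask: {dense_mask_gib:.1f} GiB, {ratio:.0f}x the GPU memory")
```

Under these assumptions the dense mask alone would be more than a hundred times larger than the card's memory, so the reduction over triangles has to be fused or done in chunks.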
N.B.: I am aware that I could use a hierarchical representation of the scene to accelerate the intersection-check, but I couldn't find any JAX-compatible library for that, so I would prefer avoiding this solution :'-)
Thanks in advance for your help!
EDIT: other (unsuccessful) things I have considered:

- `vmap`s: inner is on all rays and outer (reduced) is on all triangles;
- `vmap`s have `out_axes=-1` (and thus reduction on last axis);
- `inline=True` with `jax.jit`.