-
I want to use a mask on an array of inputs. One idea I had was to split the process into 3 operations:
Conceptually, the first step would follow something like this:
and the final step could follow the approach described in #16962. However, I would like to make the whole process both differentiable and jittable. I would be grateful for any help on this!
-
There is no one-size-fits-all answer here, because it's all about tradeoffs.

Broadly speaking, you have an array of inputs that you would like to map to an array of outputs. You also have a device (GPU or TPU) that's purpose-built for array-oriented computing. There are two options:

(1) Embrace that array-oriented computing: compute your function for every entry in the array (taking advantage of the implicit array-oriented parallelism in the device architecture) and then mask out the results you don't want. There is wasted computation here, because you are computing an expensive result that will be thrown away in some cases, but the benefit is that you are using the hardware in precisely the way it was designed.

(2) Turn away from array-oriented computing, with the goal of avoiding this wasted computation. You could do this via some sort of sequential operation (e.g. stepping through the entries with a loop).

In most cases, approach (1) will win out, because the benefit of fully utilizing the accelerator architecture typically outweighs the disadvantage of extra computation, and this is partly why this approach is easiest to express in JAX and XLA. In some special cases, (2) will be better. There's no magic bullet, though, and the overhead involved with moving data, recompiling kernels, etc. will be very expensive. But if it's less expensive than the wasted computation in (1), it may be worth it. Still, you'd lose some of the advantages of JAX (e.g. with a dynamic mask size, you'll not be able to use `jit`).

Does that help answer your question?
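To make approach (1) concrete, here is a minimal sketch. The `expensive_fn` below is a hypothetical stand-in for your real per-entry computation: compute it for every entry, then use `jnp.where` with the mask to select the results you want. Because all shapes stay static, the whole thing remains both jittable and differentiable.

```python
import jax
import jax.numpy as jnp

def expensive_fn(x):
    # Hypothetical stand-in for the real per-entry computation.
    return jnp.sin(x) ** 2 + jnp.cos(x)

@jax.jit
def masked_apply(inputs, mask):
    # Compute the function for every entry (fully parallel on the device)...
    outputs = jax.vmap(expensive_fn)(inputs)
    # ...then mask out the results we don't want. Shapes stay static,
    # so this stays compatible with jit and grad.
    return jnp.where(mask, outputs, 0.0)

inputs = jnp.arange(8.0)
mask = inputs > 3.0

print(masked_apply(inputs, mask))
# Gradients flow only through the unmasked entries:
print(jax.grad(lambda x: masked_apply(x, mask).sum())(inputs))
```

One caveat: if the computation can produce NaN or inf on the masked-out entries, gradients through `jnp.where` may need the usual double-`where` treatment (masking the inputs before the computation as well as the outputs after it) to stay clean.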