Efficient ways to sum only over the tril part of a point-wise scalar-valued function (R^d -> R) evaluated on an n×n×d tensor #9692
YouJiacheng asked this question in Q&A
Update: The n×n×d tensor is produced by the (self) pairwise difference of an n×d tensor (actually the coordinates of n points, with d typically ≤ 3). That is why I only need the tril/triu part. I am trying to implement the fast multipole method in JAX; any help will be appreciated. 🥰
In detail, there are 2 tasks:
Task A:
Task B:
Here is my use case and some implementations.
Using `jnp.tril` after the point-wise function may be faster when the point-wise function is trivial, since `lax.select` is faster than `lax.gather` or `lax.scatter`. However, the redundant computation can become significant as the point-wise function gets more complex. (In fact, even a sum over the last dimension is complex enough to make `post_tril` slower.) The efficiency of gradient evaluation should be considered as well.
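To make the trade-off concrete, here is a minimal sketch of the two strategies: masking after evaluating f on every pair, versus gathering only the i > j pairs first. These are hypothetical stand-ins (`f`, `sum_tril_post_mask` and `sum_tril_pre_gather` are our names), not the author's `post_tril`/`pre_idx` code; they assume the strict lower triangle and use an example Gaussian kernel for f.

```python
import jax.numpy as jnp


def f(r):
    # Example point-wise function R^d -> R (an assumption; stands in for the real kernel).
    return jnp.exp(-jnp.sum(r * r, axis=-1))


def sum_tril_post_mask(x):
    # Evaluate f on all n*n pairs, then keep only the strict lower triangle.
    # jnp.tril applies a mask (no gather), but about half of the f evaluations are wasted.
    diff = x[:, None, :] - x[None, :, :]
    return jnp.sum(jnp.tril(f(diff), k=-1))


def sum_tril_pre_gather(x):
    # Gather only the n(n-1)/2 pairs with i > j, then evaluate f on them.
    # No wasted f evaluations, but the forward pass gathers and the gradient scatter-adds.
    n = x.shape[0]
    i, j = jnp.tril_indices(n, k=-1)
    return jnp.sum(f(x[i] - x[j]))
```

Both return the same value. The mask version evaluates f n² times; the gather version evaluates it n(n-1)/2 times, but its backward pass routes cotangents through a scatter-add, which is where the gradient cost mentioned above shows up.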
The jit compilation time of `pre_mask` and `pre_idx`, in particular, is painfully long when n ≥ 5000.
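One possible workaround, offered as our own suggestion rather than anything from the thread, is to tile the pair space. For n = 5000, the element-level index arrays from `jnp.tril_indices(n, k=-1)` have n(n-1)/2 = 12,497,500 entries each; if those end up as compile-time constants, that may contribute to the long compile times (an unverified guess, since the author's code is not shown). The sketch below instead runs `jax.lax.scan` over a short static list of tile pairs on or below the block diagonal, so it skips roughly half of the f evaluations while its index list has only nb(nb+1)/2 entries. The name `blocked_tril_sum` and the default `block=128` are hypothetical choices.

```python
import jax
import jax.numpy as jnp
import numpy as np


def blocked_tril_sum(f, x, block=128):
    """Sum f(x[i] - x[j]) over all pairs i > j, one tile at a time.

    Only tiles on or below the block diagonal are evaluated, so roughly half
    of the n*n evaluations of f are skipped. `f` must map (..., d) -> (...).
    """
    n, d = x.shape
    nb = -(-n // block)  # tiles per side (ceil division)
    pad = nb * block - n
    xb = jnp.pad(x, ((0, pad), (0, 0))).reshape(nb, block, d)
    valid = (jnp.arange(nb * block) < n).reshape(nb, block)  # False for padding rows

    # Static (bi, bj) tile pairs with bi >= bj: nb*(nb+1)/2 entries,
    # far fewer than the n(n-1)/2 element-level indices.
    bi, bj = np.tril_indices(nb)
    pairs = jnp.asarray(np.stack([bi, bj], axis=1))

    # Strict lower-triangle mask, applied only inside diagonal tiles.
    local_tril = jnp.tril(jnp.ones((block, block), dtype=bool), k=-1)

    def body(acc, p):
        i, j = p[0], p[1]
        diff = xb[i][:, None, :] - xb[j][None, :, :]  # (block, block, d)
        vals = f(diff)  # (block, block)
        mask = valid[i][:, None] & valid[j][None, :]
        mask = jnp.where(i == j, mask & local_tril, mask)
        return acc + jnp.sum(jnp.where(mask, vals, 0)), None

    total, _ = jax.lax.scan(body, jnp.zeros((), x.dtype), pairs)
    return total
```

The tile size trades per-step memory (block² × d) against loop length and is worth tuning. One caveat: if f is singular at r = 0 (e.g. a 1/|r| kernel), the masked-out diagonal and padding entries still produce inf/NaN inside `jnp.where`, which yields NaN gradients; the usual fix is to substitute a safe input before calling f (the "double where" pattern described in the JAX FAQ).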