Python overheads of jitted functions? #18784
Qiustander asked this question in Q&A (Unanswered, 1 reply)
-
Hi - thanks for the question. I think in order to answer this we'll need more information: there's too much guesswork in trying to understand what's going on given the code you've shown. Can you edit your question to add a minimal reproducible example?
-
Hi all, I implemented a Kalman filter based on the TensorFlow Probability (TFP) JAX substrate. It returns all results wrapped in a `namedtuple` (for example, `predicted_mean`, `predicted_covariance`). Then I wrap the whole function like this:
The return values of `kalman_filter` are optional:
Then I noticed some very confusing runtime behaviour. I tested the following cases:
- 1D Kalman filter, 200 time steps (11 matrix multiplications): 0.002 s with `log_likelihood=True`, i.e. returning only the final likelihood, which is a scalar
- 1D Kalman filter, 200 time steps (11 matrix multiplications): 0.02 s with `log_likelihood=False`
- 3D Kalman filter, 200 time steps (33 matrix multiplications): 0.045 s with `log_likelihood=True`
- 3D Kalman filter, 200 time steps (33 matrix multiplications): 0.069 s with `log_likelihood=False`
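When comparing timings like these, it may help to separate the one-time trace/compile cost (paid once per value of a static argument, then cached) from the steady-state cost, and to wait on JAX's asynchronous dispatch before stopping the clock. Below is a minimal benchmarking sketch under those assumptions; `workload` is a hypothetical stand-in (a 200-step scan of small matrix products with a static flag choosing a scalar or the full history), not the poster's filter, and `benchmark` is likewise hypothetical.

```python
import time
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="reduce")
def workload(x, reduce=True):
    """Hypothetical stand-in: 200 steps of small matmuls, scalar or full output."""

    def step(carry, _):
        carry = jnp.tanh(carry @ carry.T)
        return carry, carry

    last, history = jax.lax.scan(step, x, None, length=200)
    return jnp.sum(last) if reduce else history  # () vs (200, n, n)


def benchmark(fn, *args, repeats=50, **kwargs):
    """Return (first_call_seconds, steady_state_seconds_per_call)."""
    t0 = time.perf_counter()
    jax.block_until_ready(fn(*args, **kwargs))  # includes tracing + compilation
    first = time.perf_counter() - t0

    t0 = time.perf_counter()
    for _ in range(repeats):
        # Dispatch is asynchronous, so block on each result before moving on.
        jax.block_until_ready(fn(*args, **kwargs))
    steady = (time.perf_counter() - t0) / repeats
    return first, steady


x = 0.5 * jnp.eye(3)
for reduce in (True, False):
    first, steady = benchmark(workload, x, reduce=reduce)
    print(f"reduce={reduce}: first call {first:.4f} s, steady state {steady:.6f} s")
```

If the steady-state numbers for the two flags still differ noticeably, one plausible explanation is the cost of producing and returning the larger outputs (and of any later host transfer, e.g. converting them to NumPy), rather than jit caching; a profile would be needed to confirm this.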
So it seems there is some overhead, perhaps in caching? Could we reduce this time? Thanks!