Q1: I don't think JAX has a PyTorch-like API for interactively checking device memory usage. There is, however, a non-interactive but comprehensive method for device memory profiling: https://jax.readthedocs.io/en/latest/device_memory_profiling.html
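A minimal sketch of that profiling workflow, adapted from the example in the linked page. It assumes a recent JAX install; `func1`, `func2` and the temporary output path are illustrative only. The profile is a snapshot of buffers that are live at the moment `jax.profiler.save_device_memory_profile` is called, so it is taken after `block_until_ready()`.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import jax.profiler

# Illustrative functions that allocate some device buffers.
def func1(x):
    return jnp.tile(x, 10) * 0.5

def func2(x):
    y = func1(x)
    return y, jnp.tile(x, 10) + 1

x = jax.random.normal(jax.random.PRNGKey(42), (1000, 1000))
y, z = func2(x)
z.block_until_ready()  # wait for async dispatch so the buffers actually exist

# Write a pprof-format snapshot of currently live device buffers.
profile_path = os.path.join(tempfile.mkdtemp(), "memory.prof")
jax.profiler.save_device_memory_profile(profile_path)
```

The resulting file can then be inspected offline with pprof (e.g. `pprof --web memory.prof`), which is why this is a snapshot tool rather than an interactive counter.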
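Q2: I think the default behavior of jit is to trace the function with abstract arrays (shape and dtype only, no data) rather than with concrete values, so computations are not evaluated eagerly during compilation. To force eager evaluation of values that don't depend on the inputs, see https://jax.readthedocs.io/en/latest/_autosummary/jax.ensure_compile_time_eval.html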
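A small sketch illustrating this, assuming a recent JAX version. The `seen` list is a hypothetical helper used only to record what the function body observes while it is being traced. Inside `jax.jit` the argument is abstract, so converting it to a Python `bool` raises `jax.errors.ConcretizationTypeError`; inside `jax.ensure_compile_time_eval()`, values that don't depend on the input are computed eagerly and can drive Python control flow.

```python
import jax
import jax.numpy as jnp

seen = []  # records what the body sees at trace time (Python side effects run only while tracing)

@jax.jit
def f(x):
    # x is an abstract tracer here: shape and dtype are known, the data is not.
    seen.append((x.shape, x.dtype))
    try:
        # Needs concrete data, so this fails under jit.
        bool(x.sum() > 0)
        seen.append("concrete")
    except jax.errors.ConcretizationTypeError:
        seen.append("abstract")

    with jax.ensure_compile_time_eval():
        # Does not depend on x, so it is evaluated eagerly during tracing
        # and the result is a concrete value usable in Python control flow.
        c = jnp.sin(3.0)
        c_positive = bool(c > 0)

    return jnp.sin(x) if c_positive else jnp.cos(x)

out = f(jnp.arange(3.0))
f(jnp.arange(3.0))  # same shape/dtype: the cached compilation is reused, no re-trace
```

After both calls `seen` still has only two entries, since the body ran once during tracing and the second call reused the compiled version.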

Answer selected by mariogeiger
Category
Q&A
2 participants