Profile Jax Jitted Code - Find n slowest functions #20766
Unanswered
KaleabTessera asked this question in Q&A
-
You can directly measure the execution time of individual jitted functions in JAX code using the time module, along with block_until_ready(), which waits for the asynchronously dispatched computation to finish. Example [gist]:

import time

import jax
import jax.numpy as jnp
from jax import random

# Define your jitted functions
@jax.jit
def slow_function(x, key):
    # Large matrix so that the eigendecomposition dominates the run time;
    # reduce the size if this takes too long on your machine
    random_eig = random.normal(key, (5000, 5000))
    eigenvalues, _ = jnp.linalg.eig(random_eig)
    # Return the eigenvalues as well so the compiler cannot discard the unused work
    return x, eigenvalues

@jax.jit
def fast_function(x, key):
    # Your computation here
    random_eig = random.normal(key, (10, 10))
    eigenvalues, _ = jnp.linalg.eig(random_eig)
    return x, eigenvalues

# Time a single jitted function on its own, averaged over n_runs calls
def time_function(fn, x, key, n_runs=100):
    # Warm-up call so that tracing and compilation are not included in the timing
    jax.block_until_ready(fn(x, key))
    start_time = time.perf_counter()
    for _ in range(n_runs):
        key, subkey = random.split(key)
        x, eigenvalues = fn(x, subkey)
        jax.block_until_ready((x, eigenvalues))  # Ensure computation is complete
    return (time.perf_counter() - start_time) / n_runs

# Generate a random key
key = random.PRNGKey(0)

# Measure the average execution time of each function separately
execution_time_slow_function = time_function(slow_function, jnp.zeros((5000, 5000)), key)
execution_time_fast_function = time_function(fast_function, jnp.zeros((10, 10)), key)

# Print the execution time
print("Average execution time of slow_function:", execution_time_slow_function)
print("Average execution time of fast_function:", execution_time_fast_function)
For more detailed profiling, you can leverage JAX's profiling tools. Additionally, if you're using a GPU or TPU with JAX, you can leverage JAX functionality more efficiently, optimizing performance further.
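As a rough illustration of the more detailed profiling mentioned above, here is a minimal sketch that assumes the tool meant is the `jax.profiler` module. The function `matmul_sum`, the annotation name, and the temporary log directory are hypothetical stand-ins, not part of the original answer.

```python
import os
import tempfile

import jax
import jax.numpy as jnp


@jax.jit
def matmul_sum(x):
    # Hypothetical stand-in for one of your own jitted functions
    return jnp.sum(x @ x)


x = jnp.ones((256, 256))
matmul_sum(x).block_until_ready()  # compile before tracing starts

# Hypothetical output location for the trace files
log_dir = tempfile.mkdtemp(prefix="jax-trace-")

with jax.profiler.trace(log_dir):
    # TraceAnnotation labels this region on the host timeline of the trace
    with jax.profiler.TraceAnnotation("matmul_sum_call"):
        result = matmul_sum(x).block_until_ready()

print("Trace written under:", log_dir)
```

The resulting trace can then be opened in a trace viewer such as TensorBoard's profile plugin or Perfetto. Giving each region of interest its own annotation name makes it easier to find in the timeline.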
-
I have a large file with many jitted functions and I am running my code on CPU. I want to find the n slowest jitted functions, so that I can focus on making these functions faster. I have tried all the instructions here, and honestly the profile plots are really hard to figure out and don't seem very useful.
Is there a simple way to just see the slowest jitted functions in my code or how long each jitted function takes to run on average or any relevant metrics?