Confusion regarding traced objects and if-statements #9322
-
Hi all - I am trying to optimize some of my code using vmap and jit but am running into some errors when the compiler looks at my jnp.where statement. Please see below. import jax.numpy as jnp
from jax import grad, jit, vmap
import numpy as np
import math
import time
import matplotlib.pyplot as plt
# # # # # # # # # # # # # # # # # # # # #
def initializeGaussian(grid, gridpoint, mu, sigma):
CF_eval = jnp.exp(1j * mu * gridpoint - 0.5 * gridpoint ** 2 * sigma ** 2)
return CF_eval
def intermediateLayer(grid, CFx, gridpoint, w, b):
CFy_eval = jnp.exp(1j * b * gridpoint) * jnp.interp(gridpoint, grid, CFx)
return CFy_eval
def computeHilbertTransform(f, grid, x, h, M):
sum = 0
hilb_grid = jnp.linspace(-M, M, 2 * M + 1)
for m in hilb_grid:
numerator = 1 - jnp.cos(jnp.pi * (x - m * h) / h)
denominator = jnp.pi * (x - m * h) / h
sum += jnp.interp(m * h, grid, f) * numerator / denominator
return sum
def hilbertTransform(f, grid, x, h, M):
hilb_grid_scaled = jnp.linspace(-M * h, M * h, 2 * M + 1)
return jnp.where(x in hilb_grid_scaled, 0, computeHilbertTransform(f, grid, x, h, M))
def maxLayer(grid, CFy, gridpoint):
h, M = 0.1, 100
CFz_eval = 0.5 * (1 + jnp.interp(gridpoint, grid, CFy)) + 0.5 * (hilbertTransform(CFy, grid, gridpoint, h, M) - hilbertTransform(CFy, grid, 0, h, M))
return CFz_eval
# # # # # # # # # # # # # # # # # # # # #
# define cutoff and resolution
d, L = 10, 100
# create grid along each axis
grid = jnp.linspace(-d, d, L)
# define initial Gaussian
mu, sigma = 0, 1
# define weight and bias
w, b = 0.5, 1
# # # vectorize function maps # # #
vfunc_IG = vmap(initializeGaussian, in_axes = (None, 0, None, None), out_axes = 0)
vfunc_IL = vmap(intermediateLayer, in_axes = (None, None, 0, None, None), out_axes = 0)
vfunc_ML = vmap(maxLayer, in_axes = (None, None, 0), out_axes = 0)
# # # jit it up # # #
vfunc_IG_jit = jit(vfunc_IG)
vfunc_IL_jit = jit(vfunc_IL)
vfunc_ML_jit = jit(vfunc_ML)
# warm-up
CFx = vfunc_IG_jit(grid, grid, mu, sigma).block_until_ready()
CFy = vfunc_IL_jit(grid, CFx, grid, w, b).block_until_ready()
CFz = vfunc_ML_jit(grid, CFy, grid).block_until_ready()
tic = time.perf_counter()
CFx = vfunc_IG_jit(grid, grid, mu, sigma).block_until_ready()
CFy = vfunc_IL_jit(grid, CFx, grid, w, b).block_until_ready()
CFz = vfunc_ML_jit(grid, CFy, grid).block_until_ready()
toc = time.perf_counter()
print(toc - tic)
# DEBUGGING - plot CFs
plt.plot(grid, CFx)
plt.plot(grid, CFy)
plt.plot(grid, CFz)
plt.show() It is erroring on the line
If you look at the |
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 2 replies
-
Python's |
Beta Was this translation helpful? Give feedback.
Python's
in
operator cannot be overloaded, and so expressions likex in hilb_grid_scaled
are not compatible with traced values. You could use something likejnp.any(x == hilb_grid_scaled)
instead.