JIT-compiled Histogram Function with a data-dependent bin width #6679
-
Hello, I have a question that is probably related to #5186. I wanted to know if there was a way to have a jit-compiled histogram function with a data-dependent bin width. From the error messages, the problem occurs within the `jnp.linspace` function. Thank you in advance!

Code Example

Context: I would like to calculate the entropy using the histogram function, to be used within an iterative method where the data is always changing. So I wanted to experiment with different bin estimators to check how that affects my entropy estimation. This function will be used internally within other functions, so it would be nice to jit-compile everything from the top level.

Histogram Entropy Function

import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.special import entr
def histogram_entropy(data, bins=None):
    """Estimate univariate (differential) entropy with a histogram.

    Parameters
    ----------
    data : array
        One-dimensional array of samples.
    bins : int, array of bin edges, or None
        Forwarded to ``jnp.histogram``. When None, the number of bins is
        derived from the data with Scott's rule.

    Returns
    -------
    H : scalar array
        Entropy estimate in nats.

    Notes
    -----
    * uses scott's method when ``bins`` is None
    * entropy is in nats
    * the continuous correction assumes equal-width bins
    * with ``bins=None`` the bin count depends on the data values, so that
      path cannot be traced by ``jax.jit``
    * constant data (zero spread) gives a zero bin width and is not handled
    """
    if bins is None:
        lo, hi = data.min(), data.max()
        # histogram bin width (scotts)
        bin_width = 3.5 * jnp.std(data) / (data.shape[0] ** (1 / 3))
        # number of bins needed to cover the data range
        nbins = jnp.ceil((hi - lo) / bin_width).astype(jnp.int32)
        # nbins bins are delimited by nbins + 1 edges
        bins = jnp.linspace(lo, hi, nbins + 1)

    # histogram
    counts, bin_edges = jnp.histogram(data, bins=bins, density=False)

    # normalize the bin counts into probabilities
    pk = counts / jnp.sum(counts)

    # discrete entropy of the bin probabilities
    H = jnp.sum(entr(pk))

    # add correction for continuous case; the first two edges always exist,
    # even when there is only a single bin
    delta = bin_edges[1] - bin_edges[0]
    H += jnp.log(delta)
    return H


# Code: Jitted w/ Fixed Bin Width
data = np.random.randn(1_000)
# `jax.partial` was only a re-export of `functools.partial` and is not
# available in current JAX releases, so use the standard-library function.
import functools

data = jnp.array(data, dtype=jnp.float32)
# `bins` is bound as a Python constant, so the histogram shape is static.
f = jax.jit(functools.partial(histogram_entropy, bins=10))
f(data.ravel())
# Output: DeviceArray(1.3789848, dtype=float32)

# Code: Jitted w/ Data Dependent Bin Width
data = np.random.randn(1_000)
data = jnp.array(data, dtype=jnp.float32)
# Leave `bins` unset so the bin count is derived from the traced data; this
# is the call that produces the traceback shown below.
f = jax.jit(histogram_entropy)
f(data.ravel()) ---------------------------------------------------------------------------
FilteredStackTrace Traceback (most recent call last)
<ipython-input-5-00599da09630> in <module>
2
----> 3 f(data.ravel())
<ipython-input-1-85524f19ef0c> in histogram_entropy(data)
14 # get bins with linspace
---> 15 bins = jnp.linspace(data.min(), data.max(), nbins)
16
~/.conda/envs/jax_py38/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in linspace(start, stop, num, endpoint, retstep, dtype, axis)
3094 lax._check_user_dtype_supported(dtype, "linspace")
-> 3095 num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.linspace")
3096 axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace")
FilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>
'num' argument of jnp.linspace
While tracing the function histogram_entropy at <ipython-input-1-85524f19ef0c>:4, this concrete value was not available in Python because it depends on the value of the arguments to histogram_entropy at <ipython-input-1-85524f19ef0c>:4 at flattened positions [0], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).
(https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError)
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
ConcretizationTypeError Traceback (most recent call last)
<ipython-input-5-00599da09630> in <module>
1 f = jax.jit(histogram_entropy)
2
----> 3 f(data.ravel())
~/.conda/envs/jax_py38/lib/python3.8/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
137 def reraise_with_filtered_traceback(*args, **kwargs):
138 try:
--> 139 return fun(*args, **kwargs)
140 except Exception as e:
141 if not is_under_reraiser(e):
~/.conda/envs/jax_py38/lib/python3.8/site-packages/jax/api.py in f_jitted(*args, **kwargs)
405 return cache_miss(*args, **kwargs)[0] # probably won't return
406 else:
--> 407 return cpp_jitted_f(*args, **kwargs)
408
409 f_jitted._cpp_jitted_f = cpp_jitted_f
~/.conda/envs/jax_py38/lib/python3.8/site-packages/jax/api.py in cache_miss(*args, **kwargs)
293 _check_arg(arg)
294 flat_fun, out_tree = flatten_fun(f, in_tree)
--> 295 out_flat = xla.xla_call(
296 flat_fun,
297 *args_flat,
~/.conda/envs/jax_py38/lib/python3.8/site-packages/jax/core.py in bind(self, fun, *args, **params)
1400
1401 def bind(self, fun, *args, **params):
-> 1402 return call_bind(self, fun, *args, **params)
1403
1404 def process(self, trace, fun, tracers, params):
~/.conda/envs/jax_py38/lib/python3.8/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
1391 tracers = map(top_trace.full_raise, args)
1392 with maybe_new_sublevel(top_trace):
-> 1393 outs = primitive.process(top_trace, fun, tracers, params)
1394 return map(full_lower, apply_todos(env_trace_todo(), outs))
1395
~/.conda/envs/jax_py38/lib/python3.8/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
1403
1404 def process(self, trace, fun, tracers, params):
-> 1405 return trace.process_call(self, fun, tracers, params)
1406
1407 def post_process(self, trace, out_tracers, params):
~/.conda/envs/jax_py38/lib/python3.8/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
598
599 def process_call(self, primitive, f, tracers, params):
--> 600 return primitive.impl(f, *tracers, **params)
601 process_map = process_call
602
~/.conda/envs/jax_py38/lib/python3.8/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
575
576 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
--> 577 compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
578 *unsafe_map(arg_spec, args))
579 try:
~/.conda/envs/jax_py38/lib/python3.8/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
258 fun.populate_stores(stores)
259 else:
--> 260 ans = call(fun, *args)
261 cache[key] = (ans, fun.stores)
262
~/.conda/envs/jax_py38/lib/python3.8/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
652 abstract_args, arg_devices = unzip2(arg_specs)
653 if config.omnistaging_enabled:
--> 654 jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
655 if any(isinstance(c, core.Tracer) for c in consts):
656 raise core.UnexpectedTracerError("Encountered an unexpected tracer.")
~/.conda/envs/jax_py38/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals)
1226 main.source_info = fun_sourceinfo(fun.f) # type: ignore
1227 main.jaxpr_stack = () # type: ignore
-> 1228 jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
1229 del fun, main
1230 return jaxpr, out_avals, consts
~/.conda/envs/jax_py38/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
1206 trace = DynamicJaxprTrace(main, core.cur_sublevel())
1207 in_tracers = map(trace.new_arg, in_avals)
-> 1208 ans = fun.call_wrapped(*in_tracers)
1209 out_tracers = map(trace.full_raise, ans)
1210 jaxpr, out_avals, consts = frame.to_jaxpr(in_tracers, out_tracers)
~/.conda/envs/jax_py38/lib/python3.8/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
164
165 try:
--> 166 ans = self.f(*args, **dict(self.params, **kwargs))
167 except:
168 # Some transformations yield from inside context managers, so we have to
<ipython-input-1-85524f19ef0c> in histogram_entropy(data)
13
14 # get bins with linspace
---> 15 bins = jnp.linspace(data.min(), data.max(), nbins)
16
17 # # bins with arange (similar to astropy)
~/.conda/envs/jax_py38/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in linspace(start, stop, num, endpoint, retstep, dtype, axis)
3093 """Implementation of linspace differentiable in start and stop args."""
3094 lax._check_user_dtype_supported(dtype, "linspace")
-> 3095 num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.linspace")
3096 axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace")
3097 if num < 0:
~/.conda/envs/jax_py38/lib/python3.8/site-packages/jax/core.py in concrete_or_error(force, val, context)
966 return force(val.aval.val)
967 else:
--> 968 raise ConcretizationTypeError(val, context)
969 else:
970 return force(val)
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>
'num' argument of jnp.linspace
While tracing the function histogram_entropy at <ipython-input-1-85524f19ef0c>:4, this concrete value was not available in Python because it depends on the value of the arguments to histogram_entropy at <ipython-input-1-85524f19ef0c>:4 at flattened positions [0], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).
(https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError) |
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 1 reply
-
JIT compiled functions cannot produce arrays whose static attributes (shape and dtype) are dependent on traced quantities. So no, there is not any way within a JIT-compiled context to produce a histogram with a number of bins that depends upon data values if the data is traced. One option to get around this is to compute the value-dependent statistic you're using outside of JIT (in this case `data.max() - data.min()`) and pass it to the function as a static quantity.
Beta Was this translation helpful? Give feedback.
JIT compiled functions cannot produce arrays whose static attributes (shape and dtype) are dependent on traced quantities. So no, there is not any way within a JIT-compiled context to produce a histogram with a number of bins that depends upon data values if the data is traced.
One option to get around this is to compute the value-dependent statistic you're using outside of JIT (in this case `data.max() - data.min()`) and pass it to the function as a static quantity.