interpol crash #121

@AlexGKim

Description

This interpolation call crashes with a TracerArrayConversionError (traceback below), while the equivalent call with numpy.interp works fine. Unfortunately, my JAX fluency is poor, so I couldn't immediately solve this.

jc.scipy.interpolate.interp(numpy.array([0.5,1.5]),numpy.array([0.,1.,2.]),numpy.array([0.,1.,2.]))

---------------------------------------------------------------------------
TracerArrayConversionError                Traceback (most recent call last)
Cell In[40], line 1
----> 1 jc.scipy.interpolate.interp(numpy.array([0.5,1.5]),numpy.array([0.,1.,2.]),numpy.array([0.,1.,2.]))

    [... skipping hidden 3 frame]

File ~/opt/anaconda3/envs/unity3/lib/python3.12/site-packages/jax_cosmo/scipy/interpolate.py:33, in interp(x, xp, fp)
     30 # Perform linear interpolation
     31 ind = np.clip(ind, 1, len(xp) - 2)
---> 33 xi = xp[ind]
     34 # Figure out if we are on the right or the left of nearest
     35 s = np.sign(np.clip(x, xp[1], xp[-2]) - xi).astype(np.int64)

File ~/opt/anaconda3/envs/unity3/lib/python3.12/site-packages/jax/_src/core.py:710, in Tracer.__array__(self, *args, **kw)
    709 def __array__(self, *args, **kw):
--> 710   raise TracerArrayConversionError(self)

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int64[].
This BatchTracer with object id 5183025136 was created on line:
  /var/folders/91/bt9dzsj545130th75px54m0h0000gq/T/ipykernel_25064/3155834959.py:1 (<module>)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
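The traceback hints at the cause. The BatchTracer in the message suggests that interp is wrapped in a vmap over x, so ind becomes a traced index, while xp is still a plain numpy.ndarray. Indexing a NumPy array with a tracer makes NumPy try to convert the tracer to an array, which JAX forbids. The sketch below reproduces that pattern in isolation. interp_sketch and interp_safe are hypothetical names. interp_sketch only mirrors the lines visible in the traceback (nearest-neighbour search, clip, index into xp) and is not the actual jax_cosmo source. The possible workaround it shows, converting the inputs to jax.numpy arrays before the call, is an assumption and has not been checked against jax_cosmo itself.

```python
import functools

import numpy
import jax
import jax.numpy as jnp


# Hypothetical stand-in for the failing pattern: vmap over the query points x
# only, with xp and fp passed through unbatched (in_axes=None), so they keep
# whatever type the caller gave them.
@functools.partial(jax.vmap, in_axes=(0, None, None))
def interp_sketch(x, xp, fp):
    # Index of the nearest grid point, clipped so ind - 1 and ind + 1 are valid.
    ind = jnp.argmin((x - xp) ** 2)
    ind = jnp.clip(ind, 1, len(xp) - 2)
    # If xp is a numpy.ndarray, NumPy handles this indexing and fails on the tracer.
    xi = xp[ind]
    # Pick the neighbour on the side of xi where x lies, then interpolate linearly.
    j = ind + jnp.where(jnp.clip(x, xp[1], xp[-2]) >= xi, 1, -1)
    a = (fp[j] - fp[ind]) / (xp[j] - xp[ind])
    b = fp[ind] - a * xp[ind]
    return a * x + b


def interp_safe(x, xp, fp):
    # Hypothetical wrapper: turn every input into a jax.Array first, so that
    # indexing inside the vmapped body goes through JAX instead of NumPy.
    return interp_sketch(jnp.asarray(x), jnp.asarray(xp), jnp.asarray(fp))


x = numpy.array([0.5, 1.5])
xp = numpy.array([0.0, 1.0, 2.0])
fp = numpy.array([0.0, 1.0, 2.0])

try:
    interp_sketch(x, xp, fp)  # NumPy grid indexed with a traced index
except TypeError as err:  # JAX's tracer conversion errors subclass TypeError
    print("raised", type(err).__name__)

print(interp_safe(x, xp, fp))
print(numpy.interp(x, xp, fp))
```

If this diagnosis holds, passing jax.numpy arrays to jc.scipy.interpolate.interp (for example jnp.array([0.,1.,2.]) instead of numpy.array([0.,1.,2.])) may already avoid the crash from the caller's side. Converting with jnp.asarray inside the library function would be the more robust fix.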
