@@ -367,6 +367,24 @@ <h1>array_api_extra.at<a class="headerlink" href="#array-api-extra-at" title="Li
367367< p > < a class ="reference external " href ="https://sparse.pydata.org/ "> sparse</ a > , as well as read-only arrays from libraries
368368not explicitly covered by < code class ="docutils literal notranslate "> < span class ="pre "> array-api-compat</ span > </ code > , are not supported by update
369369methods.</ p >
370+ < p > Boolean masks are supported on Dask and jitted JAX arrays exclusively
371+ when < cite > idx</ cite > has the same shape as < cite > x</ cite > and < cite > y</ cite > is 0-dimensional.
372+ Note that this support is not available in JAX’s native
373+ < code class ="docutils literal notranslate "> < span class ="pre "> x.at[mask].set(y)</ span > </ code > .</ p >
374+ < p > This pattern:</ p >
375+ < div class ="highlight-default notranslate "> < div class ="highlight "> < pre > < span > </ span > < span class ="gp "> >>> </ span > < span class ="n "> mask</ span > < span class ="o "> =</ span > < span class ="n "> m</ span > < span class ="p "> (</ span > < span class ="n "> x</ span > < span class ="p "> )</ span >
376+ < span class ="gp "> >>> </ span > < span class ="n "> x</ span > < span class ="p "> [</ span > < span class ="n "> mask</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="n "> f</ span > < span class ="p "> (</ span > < span class ="n "> x</ span > < span class ="p "> [</ span > < span class ="n "> mask</ span > < span class ="p "> ])</ span >
377+ </ pre > </ div >
378+ </ div >
379+ < p > Can’t be replaced by < cite > at</ cite > , as it won’t work on Dask and JAX inside jax.jit:</ p >
380+ < div class ="highlight-default notranslate "> < div class ="highlight "> < pre > < span > </ span > < span class ="gp "> >>> </ span > < span class ="n "> mask</ span > < span class ="o "> =</ span > < span class ="n "> m</ span > < span class ="p "> (</ span > < span class ="n "> x</ span > < span class ="p "> )</ span >
381+ < span class ="gp "> >>> </ span > < span class ="n "> x</ span > < span class ="o "> =</ span > < span class ="n "> xpx</ span > < span class ="o "> .</ span > < span class ="n "> at</ span > < span class ="p "> (</ span > < span class ="n "> x</ span > < span class ="p "> ,</ span > < span class ="n "> mask</ span > < span class ="p "> )</ span > < span class ="o "> .</ span > < span class ="n "> set</ span > < span class ="p "> (</ span > < span class ="n "> f</ span > < span class ="p "> (</ span > < span class ="n "> x</ span > < span class ="p "> [</ span > < span class ="n "> mask</ span > < span class ="p "> ])</ span > < span class ="c1 "> # Crash on Dask and jax.jit</ span >
382+ </ pre > </ div >
383+ </ div >
384+ < p > You should instead use:</ p >
385+ < div class ="highlight-default notranslate "> < div class ="highlight "> < pre > < span > </ span > < span class ="gp "> >>> </ span > < span class ="n "> x</ span > < span class ="o "> =</ span > < span class ="n "> xp</ span > < span class ="o "> .</ span > < span class ="n "> where</ span > < span class ="p "> (</ span > < span class ="n "> m</ span > < span class ="p "> (</ span > < span class ="n "> x</ span > < span class ="p "> ),</ span > < span class ="n "> f</ span > < span class ="p "> (</ span > < span class ="n "> x</ span > < span class ="p "> ),</ span > < span class ="n "> x</ span > < span class ="p "> )</ span >
386+ </ pre > </ div >
387+ </ div >
370388< p class ="rubric "> Examples</ p >
371389< p > Given either of these equivalent expressions:</ p >
372390< div class ="highlight-default notranslate "> < div class ="highlight "> < pre > < span > </ span > < span class ="gp "> >>> </ span > < span class ="kn "> import</ span > < span class ="w "> </ span > < span class ="nn "> array_api_extra</ span > < span class ="w "> </ span > < span class ="k "> as</ span > < span class ="w "> </ span > < span class ="nn "> xpx</ span >
0 commit comments