Skip to content

Commit b8848c0

Browse files
committed
Deploying to gh-pages from @ a5cb116 🚀
1 parent 03b855e commit b8848c0

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

generated/array_api_extra.at.html

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,24 @@ <h1>array_api_extra.at<a class="headerlink" href="#array-api-extra-at" title="Li
367367
<p><a class="reference external" href="https://sparse.pydata.org/">sparse</a>, as well as read-only arrays from libraries
368368
not explicitly covered by <code class="docutils literal notranslate"><span class="pre">array-api-compat</span></code>, are not supported by update
369369
methods.</p>
370+
<p>Boolean masks are supported on Dask and jitted JAX arrays exclusively
371+
when <cite>idx</cite> has the same shape as <cite>x</cite> and <cite>y</cite> is 0-dimensional.
372+
Note that this support is not available in JAX’s native
373+
<code class="docutils literal notranslate"><span class="pre">x.at[mask].set(y)</span></code>.</p>
374+
<p>This pattern:</p>
375+
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="n">mask</span> <span class="o">=</span> <span class="n">m</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
376+
<span class="gp">&gt;&gt;&gt; </span><span class="n">x</span><span class="p">[</span><span class="n">mask</span><span class="p">]</span> <span class="o">=</span> <span class="n">f</span><span class="p">(</span><span class="n">x</span><span class="p">[</span><span class="n">mask</span><span class="p">])</span>
377+
</pre></div>
378+
</div>
379+
<p>Can’t be replaced by <cite>at</cite>, as it won’t work on Dask and JAX inside jax.jit:</p>
380+
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="n">mask</span> <span class="o">=</span> <span class="n">m</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
381+
<span class="gp">&gt;&gt;&gt; </span><span class="n">x</span> <span class="o">=</span> <span class="n">xpx</span><span class="o">.</span><span class="n">at</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="n">f</span><span class="p">(</span><span class="n">x</span><span class="p">[</span><span class="n">mask</span><span class="p">])</span> <span class="c1"># Crash on Dask and jax.jit</span>
382+
</pre></div>
383+
</div>
384+
<p>You should instead use:</p>
385+
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="n">x</span> <span class="o">=</span> <span class="n">xp</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">m</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">f</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">x</span><span class="p">)</span>
386+
</pre></div>
387+
</div>
370388
<p class="rubric">Examples</p>
371389
<p>Given either of these equivalent expressions:</p>
372390
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span><span class="w"> </span><span class="nn">array_api_extra</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">xpx</span>

0 commit comments

Comments
 (0)