Skip to content

Commit f26384d

Browse files
committed
Deploying to gh-pages from @ 28a364d 🚀
1 parent 5a94439 commit f26384d

File tree

2 files changed

+24
-11
lines changed

2 files changed

+24
-11
lines changed

generated/array_api_extra.testing.lazy_xp_function.html

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@
273273
<h1>array_api_extra.testing.lazy_xp_function<a class="headerlink" href="#array-api-extra-testing-lazy-xp-function" title="Link to this heading"></a></h1>
274274
<dl class="py function">
275275
<dt class="sig sig-object py" id="array_api_extra.testing.lazy_xp_function">
276-
<span class="sig-prename descclassname"><span class="pre">array_api_extra.testing.</span></span><span class="sig-name descname"><span class="pre">lazy_xp_function</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">func</span></span></em>, <em class="sig-param"><span class="keyword-only-separator o"><abbr title="Keyword-only parameters separator (PEP 3102)"><span class="pre">*</span></abbr></span></em>, <em class="sig-param"><span class="n"><span class="pre">allow_dask_compute</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">jax_jit</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">static_argnums</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">static_argnames</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#array_api_extra.testing.lazy_xp_function" title="Link to this definition"></a></dt>
276+
<span class="sig-prename descclassname"><span class="pre">array_api_extra.testing.</span></span><span class="sig-name descname"><span class="pre">lazy_xp_function</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">func</span></span></em>, <em class="sig-param"><span class="keyword-only-separator o"><abbr title="Keyword-only parameters separator (PEP 3102)"><span class="pre">*</span></abbr></span></em>, <em class="sig-param"><span class="n"><span class="pre">allow_dask_compute</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">jax_jit</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">static_argnums</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">Deprecated.DEPRECATED</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">static_argnames</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">Deprecated.DEPRECATED</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#array_api_extra.testing.lazy_xp_function" title="Link to this definition"></a></dt>
277277
<dd><p>Tag a function to be tested on lazy backends.</p>
278278
<p>Tag a function so that when any tests are executed with <code class="docutils literal notranslate"><span class="pre">xp=jax.numpy</span></code> the
279279
function is replaced with a jitted version of itself, and when it is executed with
@@ -303,14 +303,27 @@ <h1>array_api_extra.testing.lazy_xp_function<a class="headerlink" href="#array-a
303303
<p>Default: False, meaning that <cite>func</cite> must be fully lazy and never materialize the
304304
graph.</p>
305305
</p></li>
306-
<li><p><strong>jax_jit</strong> (<a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.13)"><em>bool</em></a><em>, </em><em>optional</em>) – Set to True to replace <cite>func</cite> with <code class="docutils literal notranslate"><span class="pre">jax.jit(func)</span></code> after calling the
307-
<a class="reference internal" href="array_api_extra.testing.patch_lazy_xp_functions.html#array_api_extra.testing.patch_lazy_xp_functions" title="array_api_extra.testing.patch_lazy_xp_functions"><code class="xref py py-func docutils literal notranslate"><span class="pre">patch_lazy_xp_functions()</span></code></a> test helper with <code class="docutils literal notranslate"><span class="pre">xp=jax.numpy</span></code>. Set to False
308-
if <cite>func</cite> is only compatible with eager (non-jitted) JAX. Default: True.</p></li>
309-
<li><p><strong>static_argnums</strong> (<a class="reference external" href="https://docs.python.org/3/library/functions.html#int" title="(in Python v3.13)"><em>int</em></a><em> | </em><em>Sequence</em><em>[</em><a class="reference external" href="https://docs.python.org/3/library/functions.html#int" title="(in Python v3.13)"><em>int</em></a><em>]</em><em>, </em><em>optional</em>) – Passed to jax.jit. Positional arguments to treat as static (compile-time
310-
constant). Default: infer from <cite>static_argnames</cite> using
311-
<cite>inspect.signature(func)</cite>.</p></li>
312-
<li><p><strong>static_argnames</strong> (<a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.13)"><em>str</em></a><em> | </em><em>Iterable</em><em>[</em><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.13)"><em>str</em></a><em>]</em><em>, </em><em>optional</em>) – Passed to jax.jit. Named arguments to treat as static (compile-time constant).
313-
Default: infer from <cite>static_argnums</cite> using <cite>inspect.signature(func)</cite>.</p></li>
306+
<li><p><strong>jax_jit</strong> (<a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.13)"><em>bool</em></a><em>, </em><em>optional</em>) – <p>Set to True to replace <cite>func</cite> with a smart variant of <code class="docutils literal notranslate"><span class="pre">jax.jit(func)</span></code> after
307+
calling the <a class="reference internal" href="array_api_extra.testing.patch_lazy_xp_functions.html#array_api_extra.testing.patch_lazy_xp_functions" title="array_api_extra.testing.patch_lazy_xp_functions"><code class="xref py py-func docutils literal notranslate"><span class="pre">patch_lazy_xp_functions()</span></code></a> test helper with <code class="docutils literal notranslate"><span class="pre">xp=jax.numpy</span></code>.
308+
Set to False if <cite>func</cite> is only compatible with eager (non-jitted) JAX.</p>
309+
<p>Unlike with vanilla <code class="docutils literal notranslate"><span class="pre">jax.jit</span></code>, all arguments and return types that are not JAX
310+
arrays are treated as static; the function can accept and return arbitrary
311+
wrappers around JAX arrays. This difference is because, in real life, most users
312+
won’t wrap the function directly with <code class="docutils literal notranslate"><span class="pre">jax.jit</span></code> but rather they will use it
313+
within their own code, which is itself then wrapped by <code class="docutils literal notranslate"><span class="pre">jax.jit</span></code>, and
314+
internally consume the function’s outputs.</p>
315+
<p>In other words, the pattern that is being tested is:</p>
316+
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="nd">@jax</span><span class="o">.</span><span class="n">jit</span>
317+
<span class="gp">... </span><span class="k">def</span><span class="w"> </span><span class="nf">user_func</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
318+
<span class="gp">... </span> <span class="n">y</span> <span class="o">=</span> <span class="n">user_prepares_inputs</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
319+
<span class="gp">... </span> <span class="n">z</span> <span class="o">=</span> <span class="n">func</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">some_static_arg</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
320+
<span class="gp">... </span> <span class="k">return</span> <span class="n">user_consumes</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
321+
</pre></div>
322+
</div>
323+
<p>Default: True.</p>
324+
</p></li>
325+
<li><p><strong>static_argnums</strong> (<span class="sphinx_autodoc_typehints-type"><code class="xref py py-class docutils literal notranslate"><span class="pre">Deprecated</span></code></span>) – Deprecated; ignored</p></li>
326+
<li><p><strong>static_argnames</strong> (<span class="sphinx_autodoc_typehints-type"><code class="xref py py-class docutils literal notranslate"><span class="pre">Deprecated</span></code></span>) – Deprecated; ignored</p></li>
314327
</ul>
315328
</dd>
316329
<dt class="field-even">Return type<span class="colon">:</span></dt>
@@ -334,7 +347,7 @@ <h1>array_api_extra.testing.lazy_xp_function<a class="headerlink" href="#array-a
334347

335348
<span class="k">def</span><span class="w"> </span><span class="nf">test_myfunc</span><span class="p">(</span><span class="n">xp</span><span class="p">):</span>
336349
<span class="n">a</span> <span class="o">=</span> <span class="n">xp</span><span class="o">.</span><span class="n">asarray</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span>
337-
<span class="c1"># When xp=jax.numpy, this is the same as `b = jax.jit(myfunc)(a)`</span>
350+
<span class="c1"># When xp=jax.numpy, this is similar to `b = jax.jit(myfunc)(a)`</span>
338351
<span class="c1"># When xp=dask.array, crash on compute() or persist()</span>
339352
<span class="n">b</span> <span class="o">=</span> <span class="n">myfunc</span><span class="p">(</span><span class="n">a</span><span class="p">)</span>
340353
</pre></div>

0 commit comments

Comments
 (0)