You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
<spanclass="sig-prename descclassname"><spanclass="pre">array_api_extra.testing.</span></span><spanclass="sig-name descname"><spanclass="pre">lazy_xp_function</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="n"><spanclass="pre">func</span></span></em>, <emclass="sig-param"><spanclass="keyword-only-separator o"><abbrtitle="Keyword-only parameters separator (PEP 3102)"><spanclass="pre">*</span></abbr></span></em>, <emclass="sig-param"><spanclass="n"><spanclass="pre">allow_dask_compute</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">0</span></span></em>, <emclass="sig-param"><spanclass="n"><spanclass="pre">jax_jit</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">True</span></span></em>, <emclass="sig-param"><spanclass="n"><spanclass="pre">static_argnums</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">None</span></span></em>, <emclass="sig-param"><spanclass="n"><spanclass="pre">static_argnames</span></span><spanclass="o"><spanclass="pre">=</span></span><spanclass="default_value"><spanclass="pre">None</span></span></em><spanclass="sig-paren">)</span><aclass="headerlink" href="#array_api_extra.testing.lazy_xp_function" title="Link to this definition">¶</a></dt>
275
275
<dd><p>Tag a function to be tested on lazy backends.</p>
276
-
<p>Tag a function, which must be imported in the test module globals, so that when any
277
-
tests defined in the same module are executed with <codeclass="docutils literal notranslate"><spanclass="pre">xp=jax.numpy</span></code> the function is
278
-
replaced with a jitted version of itself, and when it is executed with
276
+
<p>Tag a function so that when any tests are executed with <codeclass="docutils literal notranslate"><spanclass="pre">xp=jax.numpy</span></code> the
277
+
function is replaced with a jitted version of itself, and when it is executed with
279
278
<codeclass="docutils literal notranslate"><spanclass="pre">xp=dask.array</span></code> the function will raise if it attempts to materialize the graph.
280
279
This will be later expanded to provide test coverage for other lazy backends.</p>
281
280
<p>In order for the tag to be effective, the test or a fixture must call
<spanclass="n">a</span><spanclass="o">=</span><spanclass="n">xp</span><spanclass="o">.</span><spanclass="n">asarray</span><spanclass="p">([</span><spanclass="mi">1</span><spanclass="p">,</span><spanclass="mi">2</span><spanclass="p">])</span><spanclass="n">b</span><spanclass="o">=</span><spanclass="n">myfunc</span><spanclass="p">(</span><spanclass="n">a</span><spanclass="p">)</span><spanclass="c1"># This is jitted when xp=jax.numpy c =</span>
349
-
<spanclass="n">mymodule</span><spanclass="o">.</span><spanclass="n">myfunc</span><spanclass="p">(</span><spanclass="n">a</span><spanclass="p">)</span><spanclass="c1"># This is not</span>
<spanclass="n">b</span><spanclass="o">=</span><spanclass="n">myfunc</span><spanclass="p">(</span><spanclass="n">a</span><spanclass="p">)</span><spanclass="c1"># This is wrapped when xp=jax.numpy or xp=dask.array</span>
372
+
<spanclass="n">c</span><spanclass="o">=</span><spanclass="n">mymodule</span><spanclass="o">.</span><spanclass="n">myfunc</span><spanclass="p">(</span><spanclass="n">a</span><spanclass="p">)</span><spanclass="c1"># This is not</span>
<spanclass="n">b</span><spanclass="o">=</span><spanclass="n">mymodule</span><spanclass="o">.</span><spanclass="n">myfunc</span><spanclass="p">(</span><spanclass="n">a</span><spanclass="p">)</span><spanclass="c1"># This is wrapped when xp=jax.numpy or xp=dask.array</span>
387
+
<spanclass="n">c</span><spanclass="o">=</span><spanclass="n">naked</span><spanclass="o">.</span><spanclass="n">myfunc</span><spanclass="p">(</span><spanclass="n">a</span><spanclass="p">)</span><spanclass="c1"># This is not</span>
<spanclass="sig-prename descclassname"><spanclass="pre">array_api_extra.testing.</span></span><spanclass="sig-name descname"><spanclass="pre">patch_lazy_xp_functions</span></span><spanclass="sig-paren">(</span><emclass="sig-param"><spanclass="n"><spanclass="pre">request</span></span></em>, <emclass="sig-param"><spanclass="n"><spanclass="pre">monkeypatch</span></span></em>, <emclass="sig-param"><spanclass="keyword-only-separator o"><abbrtitle="Keyword-only parameters separator (PEP 3102)"><spanclass="pre">*</span></abbr></span></em>, <emclass="sig-param"><spanclass="n"><spanclass="pre">xp</span></span></em><spanclass="sig-paren">)</span><aclass="headerlink" href="#array_api_extra.testing.patch_lazy_xp_functions" title="Link to this definition">¶</a></dt>
275
275
<dd><p>Test lazy execution of functions tagged with <aclass="reference internal" href="array_api_extra.testing.lazy_xp_function.html#array_api_extra.testing.lazy_xp_function" title="array_api_extra.testing.lazy_xp_function"><codeclass="xref py py-func docutils literal notranslate"><spanclass="pre">lazy_xp_function()</span></code></a>.</p>
276
276
<p>If <codeclass="docutils literal notranslate"><spanclass="pre">xp==jax.numpy</span></code>, search for all functions which have been tagged with
277
-
<aclass="reference internal" href="array_api_extra.testing.lazy_xp_function.html#array_api_extra.testing.lazy_xp_function" title="array_api_extra.testing.lazy_xp_function"><codeclass="xref py py-func docutils literal notranslate"><spanclass="pre">lazy_xp_function()</span></code></a> in the globals of the module that defines the current test
277
+
<aclass="reference internal" href="array_api_extra.testing.lazy_xp_function.html#array_api_extra.testing.lazy_xp_function" title="array_api_extra.testing.lazy_xp_function"><codeclass="xref py py-func docutils literal notranslate"><spanclass="pre">lazy_xp_function()</span></code></a> in the globals of the module that defines the current test,
278
+
as well as in the <codeclass="docutils literal notranslate"><spanclass="pre">lazy_xp_modules</span></code> list in the globals of the same module,
278
279
and wrap them with <aclass="reference external" href="https://docs.jax.dev/en/latest/_autosummary/jax.jit.html#jax.jit" title="(in JAX)"><codeclass="xref py py-func docutils literal notranslate"><spanclass="pre">jax.jit()</span></code></a>. Unwrap them at the end of the test.</p>
279
280
<p>If <codeclass="docutils literal notranslate"><spanclass="pre">xp==dask.array</span></code>, wrap the functions with a decorator that disables
280
-
<codeclass="docutils literal notranslate"><spanclass="pre">compute()</span></code> and <codeclass="docutils literal notranslate"><spanclass="pre">persist()</span></code>.</p>
281
+
<codeclass="docutils literal notranslate"><spanclass="pre">compute()</span></code> and <codeclass="docutils literal notranslate"><spanclass="pre">persist()</span></code> and ensures that exceptions and warnings are raised
282
+
eagerly.</p>
281
283
<p>This function should be typically called by your library’s <cite>xp</cite> fixture that runs
0 commit comments