Skip to content

Commit 26418cd

Browse files
committed
Deploying to gh-pages from @ ce7342e 🚀
1 parent f2bd03b commit 26418cd

File tree

3 files changed

+53
-13
lines changed

3 files changed

+53
-13
lines changed

generated/array_api_extra.testing.lazy_xp_function.html

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -273,9 +273,8 @@ <h1>array_api_extra.testing.lazy_xp_function<a class="headerlink" href="#array-a
273273
<dt class="sig sig-object py" id="array_api_extra.testing.lazy_xp_function">
274274
<span class="sig-prename descclassname"><span class="pre">array_api_extra.testing.</span></span><span class="sig-name descname"><span class="pre">lazy_xp_function</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">func</span></span></em>, <em class="sig-param"><span class="keyword-only-separator o"><abbr title="Keyword-only parameters separator (PEP 3102)"><span class="pre">*</span></abbr></span></em>, <em class="sig-param"><span class="n"><span class="pre">allow_dask_compute</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">jax_jit</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">static_argnums</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">static_argnames</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#array_api_extra.testing.lazy_xp_function" title="Link to this definition"></a></dt>
275275
<dd><p>Tag a function to be tested on lazy backends.</p>
276-
<p>Tag a function, which must be imported in the test module globals, so that when any
277-
tests defined in the same module are executed with <code class="docutils literal notranslate"><span class="pre">xp=jax.numpy</span></code> the function is
278-
replaced with a jitted version of itself, and when it is executed with
276+
<p>Tag a function so that when any tests are executed with <code class="docutils literal notranslate"><span class="pre">xp=jax.numpy</span></code> the
277+
function is replaced with a jitted version of itself, and when it is executed with
279278
<code class="docutils literal notranslate"><span class="pre">xp=dask.array</span></code> the function will raise if it attempts to materialize the graph.
280279
This will be later expanded to provide test coverage for other lazy backends.</p>
281280
<p>In order for the tag to be effective, the test or a fixture must call
@@ -336,17 +335,56 @@ <h1>array_api_extra.testing.lazy_xp_function<a class="headerlink" href="#array-a
336335
</pre></div>
337336
</div>
338337
<p class="rubric">Notes</p>
339-
<p>A test function can circumvent this monkey-patching system by calling <cite>func</cite> as an
340-
attribute of the original module. You need to sanitize your code to make sure this
341-
does not happen.</p>
342-
<p>Example:</p>
343-
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">mymodule</span> <span class="kn">from</span><span class="w"> </span><span class="nn">mymodule</span><span class="w"> </span><span class="kn">import</span> <span class="n">myfunc</span>
338+
<p>In order for this tag to be effective, the test function must be imported into the
339+
test module globals without its namespace; alternatively its namespace must be
340+
declared in a <code class="docutils literal notranslate"><span class="pre">lazy_xp_modules</span></code> list in the test module globals.</p>
341+
<p>Example 1:</p>
342+
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span><span class="w"> </span><span class="nn">mymodule</span><span class="w"> </span><span class="kn">import</span> <span class="n">myfunc</span>
344343

345344
<span class="n">lazy_xp_function</span><span class="p">(</span><span class="n">myfunc</span><span class="p">)</span>
346345

347346
<span class="k">def</span><span class="w"> </span><span class="nf">test_myfunc</span><span class="p">(</span><span class="n">xp</span><span class="p">):</span>
348-
<span class="n">a</span> <span class="o">=</span> <span class="n">xp</span><span class="o">.</span><span class="n">asarray</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span> <span class="n">b</span> <span class="o">=</span> <span class="n">myfunc</span><span class="p">(</span><span class="n">a</span><span class="p">)</span> <span class="c1"># This is jitted when xp=jax.numpy c =</span>
349-
<span class="n">mymodule</span><span class="o">.</span><span class="n">myfunc</span><span class="p">(</span><span class="n">a</span><span class="p">)</span> <span class="c1"># This is not</span>
347+
<span class="n">x</span> <span class="o">=</span> <span class="n">myfunc</span><span class="p">(</span><span class="n">xp</span><span class="o">.</span><span class="n">asarray</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">]))</span>
348+
</pre></div>
349+
</div>
350+
<p>Example 2:</p>
351+
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">mymodule</span>
352+
353+
<span class="n">lazy_xp_modules</span> <span class="o">=</span> <span class="p">[</span><span class="n">mymodule</span><span class="p">]</span>
354+
<span class="n">lazy_xp_function</span><span class="p">(</span><span class="n">mymodule</span><span class="o">.</span><span class="n">myfunc</span><span class="p">)</span>
355+
356+
<span class="k">def</span><span class="w"> </span><span class="nf">test_myfunc</span><span class="p">(</span><span class="n">xp</span><span class="p">):</span>
357+
<span class="n">x</span> <span class="o">=</span> <span class="n">mymodule</span><span class="o">.</span><span class="n">myfunc</span><span class="p">(</span><span class="n">xp</span><span class="o">.</span><span class="n">asarray</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">]))</span>
358+
</pre></div>
359+
</div>
360+
<p>A test function can circumvent this monkey-patching system by using a namespace
361+
outside of the two above patterns. You need to sanitize your code to make sure this
362+
only happens intentionally.</p>
363+
<p>Example 1:</p>
364+
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">mymodule</span>
365+
<span class="kn">from</span><span class="w"> </span><span class="nn">mymodule</span><span class="w"> </span><span class="kn">import</span> <span class="n">myfunc</span>
366+
367+
<span class="n">lazy_xp_function</span><span class="p">(</span><span class="n">myfunc</span><span class="p">)</span>
368+
369+
<span class="k">def</span><span class="w"> </span><span class="nf">test_myfunc</span><span class="p">(</span><span class="n">xp</span><span class="p">):</span>
370+
<span class="n">a</span> <span class="o">=</span> <span class="n">xp</span><span class="o">.</span><span class="n">asarray</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span>
371+
<span class="n">b</span> <span class="o">=</span> <span class="n">myfunc</span><span class="p">(</span><span class="n">a</span><span class="p">)</span> <span class="c1"># This is wrapped when xp=jax.numpy or xp=dask.array</span>
372+
<span class="n">c</span> <span class="o">=</span> <span class="n">mymodule</span><span class="o">.</span><span class="n">myfunc</span><span class="p">(</span><span class="n">a</span><span class="p">)</span> <span class="c1"># This is not</span>
373+
</pre></div>
374+
</div>
375+
<p>Example 2:</p>
376+
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">mymodule</span>
377+
378+
<span class="k">class</span><span class="w"> </span><span class="nc">naked</span><span class="p">:</span>
379+
<span class="n">myfunc</span> <span class="o">=</span> <span class="n">mymodule</span><span class="o">.</span><span class="n">myfunc</span>
380+
381+
<span class="n">lazy_xp_modules</span> <span class="o">=</span> <span class="p">[</span><span class="n">mymodule</span><span class="p">]</span>
382+
<span class="n">lazy_xp_function</span><span class="p">(</span><span class="n">mymodule</span><span class="o">.</span><span class="n">myfunc</span><span class="p">)</span>
383+
384+
<span class="k">def</span><span class="w"> </span><span class="nf">test_myfunc</span><span class="p">(</span><span class="n">xp</span><span class="p">):</span>
385+
<span class="n">a</span> <span class="o">=</span> <span class="n">xp</span><span class="o">.</span><span class="n">asarray</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span>
386+
<span class="n">b</span> <span class="o">=</span> <span class="n">mymodule</span><span class="o">.</span><span class="n">myfunc</span><span class="p">(</span><span class="n">a</span><span class="p">)</span> <span class="c1"># This is wrapped when xp=jax.numpy or xp=dask.array</span>
387+
<span class="n">c</span> <span class="o">=</span> <span class="n">naked</span><span class="o">.</span><span class="n">myfunc</span><span class="p">(</span><span class="n">a</span><span class="p">)</span> <span class="c1"># This is not</span>
350388
</pre></div>
351389
</div>
352390
</dd></dl>

generated/array_api_extra.testing.patch_lazy_xp_functions.html

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,10 +274,12 @@ <h1>array_api_extra.testing.patch_lazy_xp_functions<a class="headerlink" href="#
274274
<span class="sig-prename descclassname"><span class="pre">array_api_extra.testing.</span></span><span class="sig-name descname"><span class="pre">patch_lazy_xp_functions</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">request</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">monkeypatch</span></span></em>, <em class="sig-param"><span class="keyword-only-separator o"><abbr title="Keyword-only parameters separator (PEP 3102)"><span class="pre">*</span></abbr></span></em>, <em class="sig-param"><span class="n"><span class="pre">xp</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#array_api_extra.testing.patch_lazy_xp_functions" title="Link to this definition"></a></dt>
275275
<dd><p>Test lazy execution of functions tagged with <a class="reference internal" href="array_api_extra.testing.lazy_xp_function.html#array_api_extra.testing.lazy_xp_function" title="array_api_extra.testing.lazy_xp_function"><code class="xref py py-func docutils literal notranslate"><span class="pre">lazy_xp_function()</span></code></a>.</p>
276276
<p>If <code class="docutils literal notranslate"><span class="pre">xp==jax.numpy</span></code>, search for all functions which have been tagged with
277-
<a class="reference internal" href="array_api_extra.testing.lazy_xp_function.html#array_api_extra.testing.lazy_xp_function" title="array_api_extra.testing.lazy_xp_function"><code class="xref py py-func docutils literal notranslate"><span class="pre">lazy_xp_function()</span></code></a> in the globals of the module that defines the current test
277+
<a class="reference internal" href="array_api_extra.testing.lazy_xp_function.html#array_api_extra.testing.lazy_xp_function" title="array_api_extra.testing.lazy_xp_function"><code class="xref py py-func docutils literal notranslate"><span class="pre">lazy_xp_function()</span></code></a> in the globals of the module that defines the current test,
278+
as well as in the <code class="docutils literal notranslate"><span class="pre">lazy_xp_modules</span></code> list in the globals of the same module,
278279
and wrap them with <a class="reference external" href="https://docs.jax.dev/en/latest/_autosummary/jax.jit.html#jax.jit" title="(in JAX)"><code class="xref py py-func docutils literal notranslate"><span class="pre">jax.jit()</span></code></a>. Unwrap them at the end of the test.</p>
279280
<p>If <code class="docutils literal notranslate"><span class="pre">xp==dask.array</span></code>, wrap the functions with a decorator that disables
280-
<code class="docutils literal notranslate"><span class="pre">compute()</span></code> and <code class="docutils literal notranslate"><span class="pre">persist()</span></code>.</p>
281+
<code class="docutils literal notranslate"><span class="pre">compute()</span></code> and <code class="docutils literal notranslate"><span class="pre">persist()</span></code> and ensures that exceptions and warnings are raised
282+
eagerly.</p>
281283
<p>This function should be typically called by your library’s <cite>xp</cite> fixture that runs
282284
tests on multiple backends:</p>
283285
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nd">@pytest</span><span class="o">.</span><span class="n">fixture</span><span class="p">(</span><span class="n">params</span><span class="o">=</span><span class="p">[</span><span class="n">numpy</span><span class="p">,</span> <span class="n">array_api_strict</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">numpy</span><span class="p">,</span> <span class="n">dask</span><span class="o">.</span><span class="n">array</span><span class="p">])</span>

0 commit comments

Comments
 (0)