You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: generated/array_api_extra.testing.lazy_xp_function.html
+23-10Lines changed: 23 additions & 10 deletions
Original file line number
Diff line number
Diff line change
@@ -273,7 +273,7 @@
273
273
<h1>array_api_extra.testing.lazy_xp_function<aclass="headerlink" href="#array-api-extra-testing-lazy-xp-function" title="Link to this heading">¶</a></h1>
<p>Default: False, meaning that <cite>func</cite> must be fully lazy and never materialize the
304
304
graph.</p>
305
305
</p></li>
306
-
<li><p><strong>jax_jit</strong> (<aclass="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.13)"><em>bool</em></a><em>, </em><em>optional</em>) – Set to True to replace <cite>func</cite> with <codeclass="docutils literal notranslate"><spanclass="pre">jax.jit(func)</span></code> after calling the
307
-
<aclass="reference internal" href="array_api_extra.testing.patch_lazy_xp_functions.html#array_api_extra.testing.patch_lazy_xp_functions" title="array_api_extra.testing.patch_lazy_xp_functions"><codeclass="xref py py-func docutils literal notranslate"><spanclass="pre">patch_lazy_xp_functions()</span></code></a> test helper with <codeclass="docutils literal notranslate"><spanclass="pre">xp=jax.numpy</span></code>. Set to False
308
-
if <cite>func</cite> is only compatible with eager (non-jitted) JAX. Default: True.</p></li>
309
-
<li><p><strong>static_argnums</strong> (<aclass="reference external" href="https://docs.python.org/3/library/functions.html#int" title="(in Python v3.13)"><em>int</em></a><em> | </em><em>Sequence</em><em>[</em><aclass="reference external" href="https://docs.python.org/3/library/functions.html#int" title="(in Python v3.13)"><em>int</em></a><em>]</em><em>, </em><em>optional</em>) – Passed to jax.jit. Positional arguments to treat as static (compile-time
310
-
constant). Default: infer from <cite>static_argnames</cite> using
311
-
<cite>inspect.signature(func)</cite>.</p></li>
312
-
<li><p><strong>static_argnames</strong> (<aclass="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.13)"><em>str</em></a><em> | </em><em>Iterable</em><em>[</em><aclass="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.13)"><em>str</em></a><em>]</em><em>, </em><em>optional</em>) – Passed to jax.jit. Named arguments to treat as static (compile-time constant).
313
-
Default: infer from <cite>static_argnums</cite> using <cite>inspect.signature(func)</cite>.</p></li>
306
+
<li><p><strong>jax_jit</strong> (<aclass="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.13)"><em>bool</em></a><em>, </em><em>optional</em>) – <p>Set to True to replace <cite>func</cite> with a smart variant of <codeclass="docutils literal notranslate"><spanclass="pre">jax.jit(func)</span></code> after
307
+
calling the <aclass="reference internal" href="array_api_extra.testing.patch_lazy_xp_functions.html#array_api_extra.testing.patch_lazy_xp_functions" title="array_api_extra.testing.patch_lazy_xp_functions"><codeclass="xref py py-func docutils literal notranslate"><spanclass="pre">patch_lazy_xp_functions()</span></code></a> test helper with <codeclass="docutils literal notranslate"><spanclass="pre">xp=jax.numpy</span></code>.
308
+
Set to False if <cite>func</cite> is only compatible with eager (non-jitted) JAX.</p>
309
+
<p>Unlike with vanilla <codeclass="docutils literal notranslate"><spanclass="pre">jax.jit</span></code>, all arguments and return types that are not JAX
310
+
arrays are treated as static; the function can accept and return arbitrary
311
+
wrappers around JAX arrays. This difference is because, in real life, most users
312
+
won’t wrap the function directly with <codeclass="docutils literal notranslate"><spanclass="pre">jax.jit</span></code> but rather they will use it
313
+
within their own code, which is itself then wrapped by <codeclass="docutils literal notranslate"><spanclass="pre">jax.jit</span></code>, and
314
+
internally consume the function’s outputs.</p>
315
+
<p>In other words, the pattern that is being tested is:</p>
0 commit comments