docs/ffi.ipynb (8 additions, 8 deletions)

````diff
@@ -299,7 +299,7 @@
 "\n",
 " # The `vmap_method` parameter controls this function's behavior under `vmap`\n",
 " # as discussed below.\n",
-" vmap_method=\"broadcast_fullrank\",\n",
+" vmap_method=\"broadcast_all\",\n",
 " )\n",
 "\n",
 " # Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for\n",
@@ -342,9 +342,9 @@
 "The simplest `vmap_method` is `\"sequential\"`.\n",
 "In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.\n",
 "This implementation is general purpose, but it doesn't parallelize very well.\n",
-"Many FFI calls provide more efficient batching behavior and, in some simple cases, the `\"broadcast\"` or `\"broadcast_fullrank\"` methods can be used to expose a better implementation.\n",
+"Many FFI calls provide more efficient batching behavior and, in some simple cases, the `\"expand_dims\"` or `\"broadcast_all\"` methods can be used to expose a better implementation.\n",
 "\n",
-"In this case, since we only have one input argument, `\"broadcast\"` and `\"broadcast_fullrank\"` actually have the same behavior.\n",
+"In this case, since we only have one input argument, `\"expand_dims\"` and `\"broadcast_all\"` actually have the same behavior.\n",
 "The specific assumption required to use these methods is that the foreign function knows how to handle batch dimensions.\n",
 "Another way of saying this is that the result of calling `ffi_call` on the batched inputs is assumed to be equal to stacking the repeated application of `ffi_call` to each element in the batched input, roughly:\n",
 "\n",
@@ -354,11 +354,11 @@
 "\n",
 "```{tip}\n",
 "Note that things get a bit more complicated when we have multiple input arguments.\n",
-"For simplicity, we will use the `\"broadcast_fullrank\"` throughout this tutorial, which guarantees that all inputs will be broadcasted to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `\"broadcast\"` method.\n",
+"For simplicity, we will use the `\"broadcast_all\"` throughout this tutorial, which guarantees that all inputs will be broadcasted to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `\"expand_dims\"` method.\n",
 "The documentation for {func}`~jax.pure_callback` includes some examples of this\n",
 "```\n",
 "\n",
-"Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vmap_method=\"broadcast_fullrank\"` out of the box:"
+"Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vmap_method=\"broadcast_all\"` out of the box:"
````
docs/ffi.md (8 additions, 8 deletions)

````diff
@@ -260,7 +260,7 @@ def rms_norm(x, eps=1e-5):
 
 # The `vmap_method` parameter controls this function's behavior under `vmap`
 # as discussed below.
-vmap_method="broadcast_fullrank",
+vmap_method="broadcast_all",
 )
 
 # Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for
@@ -299,9 +299,9 @@ The docs for {func}`~jax.pure_callback` provide more details about the `vmap_met
 The simplest `vmap_method` is `"sequential"`.
 In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.
 This implementation is general purpose, but it doesn't parallelize very well.
-Many FFI calls provide more efficient batching behavior and, in some simple cases, the `"broadcast"` or `"broadcast_fullrank"` methods can be used to expose a better implementation.
+Many FFI calls provide more efficient batching behavior and, in some simple cases, the `"expand_dims"` or `"broadcast_all"` methods can be used to expose a better implementation.
 
-In this case, since we only have one input argument, `"broadcast"` and `"broadcast_fullrank"` actually have the same behavior.
+In this case, since we only have one input argument, `"expand_dims"` and `"broadcast_all"` actually have the same behavior.
 The specific assumption required to use these methods is that the foreign function knows how to handle batch dimensions.
 Another way of saying this is that the result of calling `ffi_call` on the batched inputs is assumed to be equal to stacking the repeated application of `ffi_call` to each element in the batched input, roughly:
 
@@ -311,11 +311,11 @@ ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])
 
 ```{tip}
 Note that things get a bit more complicated when we have multiple input arguments.
-For simplicity, we will use the `"broadcast_fullrank"` throughout this tutorial, which guarantees that all inputs will be broadcasted to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `"broadcast"` method.
+For simplicity, we will use the `"broadcast_all"` throughout this tutorial, which guarantees that all inputs will be broadcasted to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `"expand_dims"` method.
 The documentation for {func}`~jax.pure_callback` includes some examples of this
 ```
 
-Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vmap_method="broadcast_fullrank"` out of the box:
+Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vmap_method="broadcast_all"` out of the box: