
Commit a2e4aff

Merge pull request jax-ml#24425 from dfm:rename-vmap-methods
PiperOrigin-RevId: 688547393
2 parents 4f93563 + 61701af commit a2e4aff

7 files changed: +38 -38 lines changed

7 files changed

+38
-38
lines changed

docs/ffi.ipynb

Lines changed: 8 additions & 8 deletions
@@ -299,7 +299,7 @@
 "\n",
 " # The `vmap_method` parameter controls this function's behavior under `vmap`\n",
 " # as discussed below.\n",
-" vmap_method=\"broadcast_fullrank\",\n",
+" vmap_method=\"broadcast_all\",\n",
 " )\n",
 "\n",
 " # Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for\n",
@@ -342,9 +342,9 @@
 "The simplest `vmap_method` is `\"sequential\"`.\n",
 "In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.\n",
 "This implementation is general purpose, but it doesn't parallelize very well.\n",
-"Many FFI calls provide more efficient batching behavior and, in some simple cases, the `\"broadcast\"` or `\"broadcast_fullrank\"` methods can be used to expose a better implementation.\n",
+"Many FFI calls provide more efficient batching behavior and, in some simple cases, the `\"expand_dims\"` or `\"broadcast_all\"` methods can be used to expose a better implementation.\n",
 "\n",
-"In this case, since we only have one input argument, `\"broadcast\"` and `\"broadcast_fullrank\"` actually have the same behavior.\n",
+"In this case, since we only have one input argument, `\"expand_dims\"` and `\"broadcast_all\"` actually have the same behavior.\n",
 "The specific assumption required to use these methods is that the foreign function knows how to handle batch dimensions.\n",
 "Another way of saying this is that the result of calling `ffi_call` on the batched inputs is assumed to be equal to stacking the repeated application of `ffi_call` to each element in the batched input, roughly:\n",
 "\n",
@@ -354,11 +354,11 @@
 "\n",
 "```{tip}\n",
 "Note that things get a bit more complicated when we have multiple input arguments.\n",
-"For simplicity, we will use the `\"broadcast_fullrank\"` throughout this tutorial, which guarantees that all inputs will be broadcasted to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `\"broadcast\"` method.\n",
+"For simplicity, we will use the `\"broadcast_all\"` throughout this tutorial, which guarantees that all inputs will be broadcasted to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `\"expand_dims\"` method.\n",
 "The documentation for {func}`~jax.pure_callback` includes some examples of this\n",
 "```\n",
 "\n",
-"Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vmap_method=\"broadcast_fullrank\"` out of the box:"
+"Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vmap_method=\"broadcast_all\"` out of the box:"
 ]
 },
 {
@@ -460,7 +460,7 @@
 " jax.ShapeDtypeStruct(x.shape, x.dtype),\n",
 " jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),\n",
 " ),\n",
-" vmap_method=\"broadcast_fullrank\",\n",
+" vmap_method=\"broadcast_all\",\n",
 " )(x, eps=np.float32(eps))\n",
 " return y, (res, x)\n",
 "\n",
@@ -474,7 +474,7 @@
 " jex.ffi.ffi_call(\n",
 " \"rms_norm_bwd\",\n",
 " jax.ShapeDtypeStruct(ct.shape, ct.dtype),\n",
-" vmap_method=\"broadcast_fullrank\",\n",
+" vmap_method=\"broadcast_all\",\n",
 " )(res, x, ct),\n",
 " )\n",
 "\n",
@@ -562,7 +562,7 @@
 " return lambda x: jex.ffi.ffi_call(\n",
 " target_name,\n",
 " out_type,\n",
-" vmap_method=\"broadcast_fullrank\",\n",
+" vmap_method=\"broadcast_all\",\n",
 " )(x, eps=np.float32(eps))\n",
 "\n",
 " return jax.lax.platform_dependent(x, cpu=impl(\"rms_norm\"), cuda=impl(\"rms_norm_cuda\"))\n",

docs/ffi.md

Lines changed: 8 additions & 8 deletions
@@ -260,7 +260,7 @@ def rms_norm(x, eps=1e-5):
 
 # The `vmap_method` parameter controls this function's behavior under `vmap`
 # as discussed below.
-vmap_method="broadcast_fullrank",
+vmap_method="broadcast_all",
 )
 
 # Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for
@@ -299,9 +299,9 @@ The docs for {func}`~jax.pure_callback` provide more details about the `vmap_met
 The simplest `vmap_method` is `"sequential"`.
 In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.
 This implementation is general purpose, but it doesn't parallelize very well.
-Many FFI calls provide more efficient batching behavior and, in some simple cases, the `"broadcast"` or `"broadcast_fullrank"` methods can be used to expose a better implementation.
+Many FFI calls provide more efficient batching behavior and, in some simple cases, the `"expand_dims"` or `"broadcast_all"` methods can be used to expose a better implementation.
 
-In this case, since we only have one input argument, `"broadcast"` and `"broadcast_fullrank"` actually have the same behavior.
+In this case, since we only have one input argument, `"expand_dims"` and `"broadcast_all"` actually have the same behavior.
 The specific assumption required to use these methods is that the foreign function knows how to handle batch dimensions.
 Another way of saying this is that the result of calling `ffi_call` on the batched inputs is assumed to be equal to stacking the repeated application of `ffi_call` to each element in the batched input, roughly:
 
@@ -311,11 +311,11 @@ ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])
 
 ```{tip}
 Note that things get a bit more complicated when we have multiple input arguments.
-For simplicity, we will use the `"broadcast_fullrank"` throughout this tutorial, which guarantees that all inputs will be broadcasted to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `"broadcast"` method.
+For simplicity, we will use the `"broadcast_all"` throughout this tutorial, which guarantees that all inputs will be broadcasted to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `"expand_dims"` method.
 The documentation for {func}`~jax.pure_callback` includes some examples of this
 ```
 
-Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vmap_method="broadcast_fullrank"` out of the box:
+Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vmap_method="broadcast_all"` out of the box:
 
 ```{code-cell} ipython3
 np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5)
@@ -378,7 +378,7 @@ def rms_norm_fwd(x, eps=1e-5):
 jax.ShapeDtypeStruct(x.shape, x.dtype),
 jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),
 ),
-vmap_method="broadcast_fullrank",
+vmap_method="broadcast_all",
 )(x, eps=np.float32(eps))
 return y, (res, x)
 
@@ -392,7 +392,7 @@ def rms_norm_bwd(eps, res, ct):
 jex.ffi.ffi_call(
 "rms_norm_bwd",
 jax.ShapeDtypeStruct(ct.shape, ct.dtype),
-vmap_method="broadcast_fullrank",
+vmap_method="broadcast_all",
 )(res, x, ct),
 )
 
@@ -470,7 +470,7 @@ def rms_norm_cross_platform(x, eps=1e-5):
 return lambda x: jex.ffi.ffi_call(
 target_name,
 out_type,
-vmap_method="broadcast_fullrank",
+vmap_method="broadcast_all",
 )(x, eps=np.float32(eps))
 
 return jax.lax.platform_dependent(x, cpu=impl("rms_norm"), cuda=impl("rms_norm_cuda"))
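
The `{tip}` changed above points out that the two methods only differ once there are multiple inputs. Here is a short sketch of that difference, mirroring the `jax.pure_callback` docstring example updated later in this commit; the names `fun` and `callback` are illustrative, not tutorial code, and the callback simply prints the shapes it receives.

```python
from functools import partial
import jax
import jax.numpy as jnp

def callback(x, y):
  # Reports the shapes handed to the host-side function.
  print(x.shape, y.shape)
  return x + y

def fun(x, y, *, vmap_method):
  out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
  return jax.pure_callback(callback, out_type, x, y, vmap_method=vmap_method)

x = jnp.arange(4.0)
y = 1.0

# "expand_dims": the unbatched `y` only gains a length-1 leading axis,
# so the callback sees shapes (4,) and (1,).
jax.vmap(partial(fun, vmap_method="expand_dims"), in_axes=(0, None))(x, y)

# "broadcast_all": `y` is tiled to the full batch size,
# so the callback sees shapes (4,) and (4,).
jax.vmap(partial(fun, vmap_method="broadcast_all"), in_axes=(0, None))(x, y)
```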

examples/ffi/src/jax_ffi_example/rms_norm.py

Lines changed: 3 additions & 3 deletions
@@ -58,7 +58,7 @@ def rms_norm(x, eps=1e-5):
 # above in `register_ffi_target`
 "rms_norm",
 out_type,
-vmap_method="broadcast_fullrank",
+vmap_method="broadcast_all",
 )(x, eps=np.float32(eps))
 
 
@@ -69,7 +69,7 @@ def rms_norm_fwd(x, eps=1e-5):
 jax.ShapeDtypeStruct(x.shape, x.dtype),
 jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),
 ),
-vmap_method="broadcast_fullrank",
+vmap_method="broadcast_all",
 )(x, eps=np.float32(eps))
 return y, (res, x)
 
@@ -83,7 +83,7 @@ def rms_norm_bwd(eps, res, ct):
 jex.ffi.ffi_call(
 "rms_norm_bwd",
 jax.ShapeDtypeStruct(ct.shape, ct.dtype),
-vmap_method="broadcast_fullrank",
+vmap_method="broadcast_all",
 )(res, x, ct),
 )
 
jax/_src/callback.py

Lines changed: 10 additions & 10 deletions
@@ -170,8 +170,8 @@ def callback_batching_rule(
 result_avals=batched_result_avals,
 **kwargs,
 )
-elif vmap_method == "broadcast" or vmap_method == "broadcast_fullrank":
-size = axis_size if vmap_method == "broadcast_fullrank" else 1
+elif vmap_method == "expand_dims" or vmap_method == "broadcast_all":
+size = axis_size if vmap_method == "broadcast_all" else 1
 bcast_args = [
 lax.broadcast(x, (size,)) if d is batching.not_mapped else x
 for x, d in zip(new_args, dims)]
@@ -198,7 +198,7 @@ def _batch_fun(batched_args):
 else:
 raise NotImplementedError(
 f"vmap is only supported for the {prim.name} primitive when vmap_method "
-"is one of 'sequential', 'broadcast', 'broadcast_fullrank', or "
+"is one of 'sequential', 'expand_dims', 'broadcast_all', or "
 "'legacy_vectorized'.")
 return tuple(outvals), (0,) * len(outvals)
 
@@ -327,9 +327,9 @@ def pure_callback(
 is deprecated and it will eventually raise ``NotImplementedError``.
 * ``vmap_method="sequential"`` uses :func:`~jax.lax.map` to loop over
 the batched arugments, calling ``callback`` once for each batch element.
-* ``vmap_method="broadcast"`` calls ``callback`` with new axes of size ``1``
+* ``vmap_method="expand_dims"`` calls ``callback`` with new axes of size ``1``
 added as the leading dimension unbatched inputs.
-* ``vmap_method="broadcast_fullrank"`` behaves like ``broadcast``, but the
+* ``vmap_method="broadcast_all"`` behaves like ``expand_dims``, but the
 inputs are tiled to the expected batched shape.
 
 If necessary, the legacy behavior provided by the deprecated
@@ -383,20 +383,20 @@ def pure_callback(
 ... return jax.pure_callback(callback, out_type, x, y,
 ... vmap_method=vmap_method)
 
-Calling this with ``vmap_method="broadcast"`` adds a new axis of size ``1``
+Calling this with ``vmap_method="expand_dims"`` adds a new axis of size ``1``
 to ``y``:
 
 >>> from functools import partial
 >>> x = jnp.arange(4)
 >>> y = 1.0
->>> jax.vmap(partial(fun, vmap_method="broadcast"), in_axes=(0, None))(x, y)
+>>> jax.vmap(partial(fun, vmap_method="expand_dims"), in_axes=(0, None))(x, y)
 (4,) (1,)
 Array([1., 2., 3., 4.], dtype=float32)
 
-Whereas, ``vmap_method="broadcast_fullrank"`` adds an axis of size ``4`` to
+Whereas, ``vmap_method="broadcast_all"`` adds an axis of size ``4`` to
 ``y``:
 
->>> jax.vmap(partial(fun, vmap_method="broadcast_fullrank"),
+>>> jax.vmap(partial(fun, vmap_method="broadcast_all"),
 ... in_axes=(0, None))(x, y)
 (4,) (4,)
 Array([1., 2., 3., 4.], dtype=float32)
@@ -415,7 +415,7 @@ def pure_callback(
 "the vectorized and vmap_method arguments of jax.pure_callback cannot "
 "be used together. Please use the vmap_method argument.")
 vmap_method = "legacy_vectorized" if vectorized else "sequential"
-allowed_vmap_methods = ["sequential", "broadcast", "broadcast_fullrank",
+allowed_vmap_methods = ["sequential", "expand_dims", "broadcast_all",
 "legacy_vectorized", None]
 if vmap_method not in allowed_vmap_methods:
 raise ValueError(
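
For context on the renamed branch in `callback_batching_rule`: at this level the only difference between the two methods is the size handed to `lax.broadcast`, which prepends the given dimensions to unmapped operands. A tiny sketch with illustrative values (not taken from this file):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((3,))   # an operand that is not mapped by vmap
axis_size = 4        # the vmapped batch size

# "expand_dims"-style: only a length-1 leading axis is added.
print(jax.lax.broadcast(x, (1,)).shape)          # (1, 3)
# "broadcast_all"-style: the operand is tiled to the full batch size.
print(jax.lax.broadcast(x, (axis_size,)).shape)  # (4, 3)
```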

jax/_src/extend/ffi.py

Lines changed: 1 addition & 1 deletion
@@ -256,7 +256,7 @@ def ffi_call(
 "the vectorized and vmap_method arguments of ffi_call cannot "
 "be used together. Please use the vmap_method argument.")
 vmap_method = "legacy_vectorized" if vectorized else "sequential"
-allowed_vmap_methods = ["sequential", "broadcast", "broadcast_fullrank",
+allowed_vmap_methods = ["sequential", "expand_dims", "broadcast_all",
 "legacy_vectorized", None]
 if vmap_method not in allowed_vmap_methods:
 raise ValueError(

tests/extend_test.py

Lines changed: 1 addition & 1 deletion
@@ -245,7 +245,7 @@ def testFfiCall(self, shape, dtype):
 @jtu.sample_product(
 shape=[(1,), (4,), (5,)],
 dtype=(np.int32,),
-vmap_method=("broadcast", "broadcast_fullrank", "sequential",
+vmap_method=("expand_dims", "broadcast_all", "sequential",
 "legacy_vectorized"),
 )
 @jtu.run_on_devices("gpu")
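
The parametrization above also exercises `"sequential"`, which the updated docs describe as looping over the batch with one callback invocation per element. Here is a hedged sketch of that behavior using `jax.pure_callback` rather than the GPU `ffi_call` in this test, so it runs anywhere; `callback` and `f` are illustrative names, not part of the test suite.

```python
import jax
import jax.numpy as jnp
import numpy as np

def callback(x):
  # Under "sequential", the callback sees one unbatched element per call,
  # so for the vmapped scalar inputs below `x` has shape ().
  return np.sin(x)

@jax.jit
@jax.vmap
def f(x):
  return jax.pure_callback(callback, x, x, vmap_method="sequential")

np.testing.assert_allclose(f(jnp.arange(4.0)), np.sin(np.arange(4.0)), rtol=1e-6)
```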

tests/python_callback_test.py

Lines changed: 7 additions & 7 deletions
@@ -696,15 +696,15 @@ def cb2(x):
 @jax.jit
 @jax.vmap
 def g(x):
-return jax.pure_callback(cb2, x, x, vmap_method="broadcast")
+return jax.pure_callback(cb2, x, x, vmap_method="expand_dims")
 
 np.testing.assert_allclose(g(jnp.arange(4.)), np.sin(np.arange(4.)))
 
 @jax.jit
 @functools.partial(jax.vmap, in_axes=(0, None))
 def h(x, y):
 return jax.pure_callback(lambda x, y: np.sin(x) + y, x, x, y,
-vmap_method="broadcast")
+vmap_method="expand_dims")
 out = h(jnp.arange(4.), 4.)
 np.testing.assert_allclose(out, np.sin(np.arange(4.)) + 4.)
 
@@ -725,7 +725,7 @@ def cb(x):
 @jax.jit
 @jax.vmap
 def f(x):
-return jax.pure_callback(cb, x, x, vmap_method="broadcast")
+return jax.pure_callback(cb, x, x, vmap_method="expand_dims")
 
 with self.assertRaises(RuntimeError):
 f(jnp.arange(4.))
@@ -1007,26 +1007,26 @@ def f(x, **kwargs):
 with self.assertWarnsRegex(DeprecationWarning, "The vectorized argument"):
 f(jnp.arange(4.0), vectorized=False)
 
-def test_vmap_method_broadcast(self):
+def test_vmap_method_expand_dims(self):
 def callback(x, y):
 self.assertTupleEqual(x.shape, (4,))
 self.assertTupleEqual(y.shape, (1,))
 return x + y
 
 def f(x, y):
-return jax.pure_callback(callback, x, x, y, vmap_method="broadcast")
+return jax.pure_callback(callback, x, x, y, vmap_method="expand_dims")
 
 jax.vmap(f, in_axes=(0, None))(jnp.arange(4.0), 1.0) # doesn't error
 
-def test_vmap_method_broadcast_fullrank(self):
+def test_vmap_method_broadcast_all(self):
 def callback(x, y):
 self.assertTupleEqual(x.shape, (4,))
 self.assertTupleEqual(y.shape, (4,))
 return x + y
 
 def f(x, y):
 return jax.pure_callback(callback, x, x, y,
-vmap_method="broadcast_fullrank")
+vmap_method="broadcast_all")
 
 jax.vmap(f, in_axes=(0, None))(jnp.arange(4.0), 1.0) # doesn't error
 