This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 5710831

Added chunks arg to vmap (#774)
* Added chunks arg to vmap and a test
* Created chunk_vmap in experimental
* Code formatting
* Updated tests: refactored common code and fixed random state with randomness = same
* Updated docstring and split tests by randomness
1 parent a60ef90 commit 5710831

File tree

functorch/_src/vmap.py
functorch/experimental/__init__.py
test/test_eager_transforms.py
test/test_ops.py
test/test_pythonkey.py
test/test_vmap.py

6 files changed: +174 -24 lines

functorch/_src/vmap.py

Lines changed: 130 additions & 9 deletions
@@ -352,18 +352,139 @@ def vmap(
     vmap does not provide general autobatching or handle variable-length
     sequences out of the box.
     """
-    if randomness not in ['error', 'different', 'same']:
-        raise RuntimeError(f"Only allowed values for randomness are 'error', 'different', or 'same'. Got {randomness}")
+    _check_randomness_arg(randomness)
 
     @functools.wraps(func)
     def wrapped(*args, **kwargs):
         _check_out_dims_is_int_or_int_pytree(out_dims, func)
         batch_size, flat_in_dims, flat_args, args_spec = _process_batched_inputs(in_dims, args, func)
-        vmap_level = _vmap_increment_nesting(batch_size, randomness)
-        try:
-            batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
-            batched_outputs = func(*batched_inputs, **kwargs)
-            return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
-        finally:
-            _vmap_decrement_nesting()
+        return _flat_vmap(
+            func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
+        )
+
     return wrapped
+
+
+def chunk_vmap(
+        func: Callable,
+        in_dims: in_dims_t = 0,
+        out_dims: out_dims_t = 0,
+        randomness: str = 'error',
+        chunks=2) -> Callable:
+    """
+    chunk_vmap is the vectorizing map (vmap) applied to chunks of input data. It is a mix of vmap (which
+    vectorizes everything) and map (which executes things sequentially): ``chunk_vmap`` processes the input
+    in the given number of chunks, vectorizing over each chunk in turn. For more details about vectorizing
+    map, see :func:`vmap`.
+
+    Args:
+        func (function): A Python function that takes one or more arguments.
+            Must return one or more Tensors.
+        in_dims (int or nested structure): Specifies which dimension of the
+            inputs should be mapped over. :attr:`in_dims` should have a
+            structure like the inputs. If the :attr:`in_dim` for a particular
+            input is None, then that indicates there is no map dimension.
+            Default: 0.
+        out_dims (int or Tuple[int]): Specifies where the mapped dimension
+            should appear in the outputs. If :attr:`out_dims` is a Tuple, then
+            it should have one element per output. Default: 0.
+        randomness (str): Specifies whether the randomness in this
+            vmap should be the same or different across batches. If 'different',
+            the randomness for each batch will be different. If 'same', the
+            randomness will be the same across batches. If 'error', any calls to
+            random functions will error. Default: 'error'. WARNING: this flag
+            only applies to random PyTorch operations and does not apply to
+            Python's random module or numpy randomness.
+        chunks (int): Number of chunks used to split the input data. Default is 2.
+            If equal to 1, :func:`vmap` is called.
+
+    Returns:
+        Returns a new "batched" function. It takes the same inputs as
+        :attr:`func`, except each input has an extra dimension at the index
+        specified by :attr:`in_dims`. It returns the same outputs as
+        :attr:`func`, except each output has an extra dimension at the index
+        specified by :attr:`out_dims`.
+    """
+    _check_randomness_arg(randomness)
+
+    if chunks == 1:
+        return vmap(func, in_dims=in_dims, out_dims=out_dims, randomness=randomness)
+
+    def _get_chunk_flat_args(flat_args_, flat_in_dims_, chunks_):
+        flat_args_chunks = tuple(
+            t.chunk(chunks_, dim=in_dim) if in_dim is not None else [t, ] * chunks_
+            for t, in_dim in zip(flat_args_, flat_in_dims_)
+        )
+        # transpose chunk dim and flatten structure
+        # chunks_flat_args is a list of flattened args
+        chunks_flat_args = zip(*flat_args_chunks)
+        return chunks_flat_args
+
+    def _flatten_chunks_output(chunks_output_):
+        # chunks_output is a list of chunked outputs
+        # flatten chunked outputs:
+        flat_chunks_output = []
+        arg_spec_list = []
+        for output in chunks_output_:
+            flat_output, arg_specs = tree_flatten(output)
+            flat_chunks_output.append(flat_output)
+            arg_spec_list.append(arg_specs)
+
+        arg_spec = arg_spec_list[0]  # all specs should be the same
+        # transpose chunk dim and flatten structure
+        # flat_output_chunks is a flat list of chunks
+        flat_output_chunks = list(zip(*flat_chunks_output))
+        return flat_output_chunks, arg_spec
+
+    @functools.wraps(func)
+    def wrapped_with_chunks(*args, **kwargs):
+        _check_out_dims_is_int_or_int_pytree(out_dims, func)
+        _, flat_in_dims, flat_args, args_spec = _process_batched_inputs(in_dims, args, func)
+        # Chunk flat arguments
+        chunks_flat_args = _get_chunk_flat_args(flat_args, flat_in_dims, chunks)
+
+        # Apply vmap on chunks
+        chunks_output = []
+        rs = torch.get_rng_state() if randomness == "same" else None
+        for flat_args in chunks_flat_args:
+            batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args)
+            if rs is not None:
+                torch.set_rng_state(rs)
+            chunks_output.append(
+                _flat_vmap(
+                    func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
+                )
+            )
+        flat_output_chunks, arg_spec = _flatten_chunks_output(chunks_output)
+        # Removing temporary variables helps to reduce memory usage on devices like CUDA
+        del chunks_output
+
+        # concat chunks on out_dim
+        flat_out_dims = _broadcast_to_and_flatten(out_dims, arg_spec)
+        assert len(flat_out_dims) == len(flat_output_chunks)
+        flat_output = []
+        for out_dim in flat_out_dims:
+            flat_output.append(torch.cat(flat_output_chunks[0], dim=out_dim))
+            # release source data
+            del flat_output_chunks[0]
+        del flat_output_chunks
+
+        # finally unflatten the output
+        return tree_unflatten(flat_output, arg_spec)
+
+    return wrapped_with_chunks
+
+
+# Vmap refactored helper functions:
+def _check_randomness_arg(randomness):
+    if randomness not in ['error', 'different', 'same']:
+        raise RuntimeError(f"Only allowed values for randomness are 'error', 'different', or 'same'. Got {randomness}")
+
+
+def _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs):
+    vmap_level = _vmap_increment_nesting(batch_size, randomness)
+    try:
+        batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
+        batched_outputs = func(*batched_inputs, **kwargs)
+        return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
+    finally:
+        _vmap_decrement_nesting()
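
For quick reference, here is a minimal usage sketch of the new chunk_vmap API added above. It assumes functorch is importable; the import path follows the functorch/experimental/__init__.py change below, and the call pattern mirrors the new tests in test/test_vmap.py.

import torch
from functorch import vmap
from functorch.experimental import chunk_vmap

x = torch.randn(4, 5, 6)

def f(t):
    # deterministic element-wise computation, so randomness='error' is fine
    return t.sin() + t.cos()

# Plain vmap: one vectorized pass over dim 0.
expected = vmap(f, in_dims=0, out_dims=0)(x)

# chunk_vmap: split dim 0 into 2 chunks, vmap each chunk, then
# concatenate the per-chunk results along out_dims.
out = chunk_vmap(f, in_dims=0, out_dims=0, randomness='error', chunks=2)(x)
assert torch.allclose(out, expected)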

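The chunking helpers (_get_chunk_flat_args and _flatten_chunks_output) both rely on a chunk-then-transpose pattern: split each flat argument along its in_dim, then zip so that each element groups one chunk of every argument. A standalone sketch of that idea (illustrative variable names, not the library internals):

import torch

# Two flat arguments, mapped over dim 0 and dim 1 respectively.
flat_args = (torch.randn(6, 3), torch.randn(3, 6))
flat_in_dims = (0, 1)
chunks = 2

# Per-argument tuples of chunks: ((a0, a1), (b0, b1)).
per_arg_chunks = tuple(
    t.chunk(chunks, dim=d) if d is not None else [t] * chunks
    for t, d in zip(flat_args, flat_in_dims)
)

# Transpose so that each element holds one chunk of every argument:
# ((a0, b0), (a1, b1)) -- each tuple can be fed to a single vmap call.
chunks_flat_args = list(zip(*per_arg_chunks))
assert len(chunks_flat_args) == chunks
assert chunks_flat_args[0][0].shape == (3, 3)  # first chunk of the dim-0 arg
assert chunks_flat_args[0][1].shape == (3, 3)  # first chunk of the dim-1 arg
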
functorch/experimental/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
 from .batch_norm_replacement import replace_all_batch_norm_modules_
 # PyTorch forward-mode is not mature yet
 from .._src.eager_transforms import jvp, jacfwd, hessian, functionalize
+from .._src.vmap import chunk_vmap

test/test_eager_transforms.py

Lines changed: 0 additions & 1 deletion
@@ -2946,7 +2946,6 @@ def f(x: torch.Tensor) -> torch.Tensor:
             return x
         self._check_functionalize_correctness(f, torch.zeros(4, 2, device=device))
 
-
     def test_inplace_view(self, device):
 
         def f(x: torch.Tensor) -> torch.Tensor:

test/test_ops.py

Lines changed: 2 additions & 2 deletions
@@ -484,8 +484,8 @@ def _test(_op):
     @skipOps('TestOperators', 'test_vjpvjp', vjp_fail.union({
         skip('nn.functional.max_unpool1d'),  # Flaky
         skip('nn.functional.max_unpool2d'),  # Flaky
-        skip('nn.functional.fractional_max_pool2d'),  # randomness
-        skip('nn.functional.fractional_max_pool3d'),  # randomness
+        skip('nn.functional.fractional_max_pool2d'),  # randomness
+        skip('nn.functional.fractional_max_pool3d'),  # randomness
     }))
     @opsToleranceOverride('TestOperators', 'test_vjpvjp', (
         tol1('nn.functional.conv_transpose3d',

test/test_pythonkey.py

Lines changed: 1 addition & 0 deletions
@@ -113,6 +113,7 @@ def f(x):
     def test_make_fx_no_decompose(self, device):
         # FIXME
         return self.skipTest("error: maximum recursion reached")
+
         def f(x):
             return torch.tanh(x).sum()
 

test/test_vmap.py

Lines changed: 40 additions & 12 deletions
@@ -39,6 +39,7 @@
 
 import functorch
 from functorch import vmap, grad, grad_and_value, jvp, vjp
+from functorch.experimental import chunk_vmap
 from functorch._C import reshape_dim_into, reshape_dim_outof
 from functorch._src.make_functional import functional_init_with_buffers
 
@@ -879,7 +880,6 @@ def test_backward_unsupported_interaction(self):
         def backward_on_vmapped_tensor(x):
             x.sum().backward()
 
-
         # FIXME
         return self.skipTest("error: element 0 of tensors does not require grad and does not have a grad_fn")
         with self.assertRaisesRegex(RuntimeError, err_msg):
@@ -2719,14 +2719,28 @@ def naive_f(x, shape):
 
         self.assertTrue(torch.randn(()).dim() == 0)
 
-    @parametrize('op', [torch.cos, torch.sinh], name_fn=lambda f: f.__name__)
-    def test_foobar_parametrize(self, op):
-        pass
+    @parametrize('in_dim', [0, 1, 2])
+    @parametrize('out_dim', [0, 1, 2])
+    @parametrize('randomness', ['error', 'same'])
+    def test_chunk_vmap(self, in_dim, out_dim, randomness):
+
+        x = torch.randn(4, 5, 6)
+
+        def f(x):
+            y = x.sin()
+            if randomness != "error":
+                y = y + torch.rand_like(x)
+            return y
+
+        rs = torch.get_rng_state()
+        expected = vmap(f, in_dims=in_dim, out_dims=out_dim, randomness=randomness)(x)
 
-    @parametrize('op2', [torch.cos, torch.sinh], name_fn=lambda f: f.__name__)
-    @parametrize('op1', [torch.abs, torch.acos], name_fn=lambda f: f.__name__)
-    def test_parametrize_multiple(self, op1, op2):
-        pass
+        for chunks in [1, 2, 3, 4, 7, 10, 16]:
+            torch.set_rng_state(rs)
+            output = chunk_vmap(
+                f, in_dims=in_dim, out_dims=out_dim, randomness=randomness, chunks=chunks
+            )(x)
+            self.assertEqual(output, expected)
 
 
 instantiate_parametrized_tests(TestVmapOperators)
@@ -2906,10 +2920,6 @@ def test_log1p(self, device):
         self._batched_grad_test(torch.log1p, (x,))
         self._batched_grad_grad_test(torch.log1p, (x,))
 
-    @parametrize('param', ['foo', 'bar'])
-    def test_param_device(self, device, param):
-        pass
-
     @allowVmapFallbackUsage
     def test_max(self, device):
         x = torch.randn(2, 3, requires_grad=True, device=device)
@@ -4160,6 +4170,24 @@ def f(z):
             return torch.rrelu(x)
         vmap(f, randomness='same')(z)
 
+    @parametrize('in_dim', [0, 1, 2])
+    @parametrize('out_dim', [0, 1, 2])
+    def test_chunk_vmap(self, in_dim, out_dim):
+
+        randomness = "different"
+
+        x = torch.randn(4, 5, 6)
+
+        def f(x):
+            y = x.sin() + torch.rand_like(x)
+            return y
+
+        for chunks in [1, 2, 3, 4, 7, 10, 16]:
+            output = chunk_vmap(
+                f, in_dims=in_dim, out_dims=out_dim, randomness=randomness, chunks=chunks
+            )(x)
+            self._assert_all_slices_unique(output)
+
 
 class TestTransformFailure(TestCase):
     @parametrize('transform', ['vmap', 'grad', 'grad_and_value', 'vjp', 'jvp', 'jacrev', 'jacfwd'])
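
The randomness='same' test above passes because chunk_vmap captures the RNG state once and restores it before every chunk, so the result does not depend on the chunk count. A minimal sketch of that property, mirroring the parametrized test (assumes the global CPU RNG is the only source of randomness):

import torch
from functorch import vmap
from functorch.experimental import chunk_vmap

x = torch.randn(4, 5, 6)

def f(t):
    # random op: its behavior is controlled by the randomness flag
    return t.sin() + torch.rand_like(t)

rs = torch.get_rng_state()
expected = vmap(f, randomness='same')(x)

for chunks in (1, 2, 4):
    # Reset the global RNG so every chunk_vmap call starts from the same state.
    torch.set_rng_state(rs)
    out = chunk_vmap(f, randomness='same', chunks=chunks)(x)
    assert torch.allclose(out, expected)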
