@@ -352,18 +352,139 @@ def vmap(
     vmap does not provide general autobatching or handle variable-length
     sequences out of the box.
     """
-    if randomness not in ['error', 'different', 'same']:
-        raise RuntimeError(f"Only allowed values for randomness are 'error', 'different', or 'same'. Got {randomness}")
+    _check_randomness_arg(randomness)
 
     @functools.wraps(func)
     def wrapped(*args, **kwargs):
         _check_out_dims_is_int_or_int_pytree(out_dims, func)
         batch_size, flat_in_dims, flat_args, args_spec = _process_batched_inputs(in_dims, args, func)
-        vmap_level = _vmap_increment_nesting(batch_size, randomness)
-        try:
-            batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
-            batched_outputs = func(*batched_inputs, **kwargs)
-            return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
-        finally:
-            _vmap_decrement_nesting()
+        return _flat_vmap(
+            func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
+        )
+
     return wrapped
+
+
+def chunk_vmap(
+        func: Callable,
+        in_dims: in_dims_t = 0,
+        out_dims: out_dims_t = 0,
+        randomness: str = 'error',
+        chunks=2) -> Callable:
+    """
+    chunk_vmap is the vectorizing map (vmap) applied over chunks of the input data. It is a mix of vmap
+    (which vectorizes everything) and map (which executes things sequentially): ``chunk_vmap`` splits the
+    inputs into ``chunks`` pieces along the mapped dimension and vectorizes over one chunk at a time.
+    For more details about the vectorizing map, see :func:`vmap`.
+
+    Args:
+        func (function): A Python function that takes one or more arguments.
+            Must return one or more Tensors.
+        in_dims (int or nested structure): Specifies which dimension of the
+            inputs should be mapped over. :attr:`in_dims` should have a
+            structure like the inputs. If the :attr:`in_dim` for a particular
+            input is None, then that indicates there is no map dimension.
+            Default: 0.
+        out_dims (int or Tuple[int]): Specifies where the mapped dimension
+            should appear in the outputs. If :attr:`out_dims` is a Tuple, then
+            it should have one element per output. Default: 0.
+        randomness (str): Specifies whether the randomness in this
+            vmap should be the same or different across batches. If 'different',
+            the randomness for each batch will be different. If 'same', the
+            randomness will be the same across batches. If 'error', any calls to
+            random functions will error. Default: 'error'. WARNING: this flag
+            only applies to random PyTorch operations and does not apply to
+            Python's random module or numpy randomness.
+        chunks (int): Number of chunks to split the input data into. Default: 2.
+            If equal to 1, :func:`vmap` is called directly.
+
+    Returns:
+        Returns a new "batched" function. It takes the same inputs as
+        :attr:`func`, except each input has an extra dimension at the index
+        specified by :attr:`in_dims`. It returns the same outputs as
+        :attr:`func`, except each output has an extra dimension at the index
+        specified by :attr:`out_dims`.
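+
+    Example (a minimal illustrative sketch; ``torch.sum`` and the shapes shown
+    are arbitrary choices for demonstration)::
+
+        >>> x = torch.randn(8, 3)
+        >>> # sum each row of ``x``, evaluated in 4 chunks of 2 rows each
+        >>> chunked_sum = chunk_vmap(torch.sum, in_dims=0, chunks=4)
+        >>> chunked_sum(x).shape
+        torch.Size([8])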
+    """
+    _check_randomness_arg(randomness)
+
+    if chunks == 1:
+        return vmap(func, in_dims=in_dims, out_dims=out_dims, randomness=randomness)
+
+    def _get_chunk_flat_args(flat_args_, flat_in_dims_, chunks_):
+        flat_args_chunks = tuple(
+            t.chunk(chunks_, dim=in_dim) if in_dim is not None else [t, ] * chunks_
+            for t, in_dim in zip(flat_args_, flat_in_dims_)
+        )
+        # transpose the chunk dim and flatten the structure
+        # chunks_flat_args is a list of flattened args
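+        # e.g. (hypothetical args) for two flat args a, b and chunks_ == 2:
+        #   flat_args_chunks == ((a0, a1), (b0, b1))
+        #   zip(*flat_args_chunks) -> ((a0, b0), (a1, b1)), one flat arg tuple per chunk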
+        chunks_flat_args = zip(*flat_args_chunks)
+        return chunks_flat_args
+
+    def _flatten_chunks_output(chunks_output_):
+        # chunks_output_ is a list of chunked outputs
+        # flatten each chunked output:
+        flat_chunks_output = []
+        arg_spec_list = []
+        for output in chunks_output_:
+            flat_output, arg_specs = tree_flatten(output)
+            flat_chunks_output.append(flat_output)
+            arg_spec_list.append(arg_specs)
+
+        arg_spec = arg_spec_list[0]  # all specs should be the same
+        # transpose the chunk dim and flatten the structure
+        # flat_output_chunks is a flat list of chunks
+        flat_output_chunks = list(zip(*flat_chunks_output))
+        return flat_output_chunks, arg_spec
+
+    @functools.wraps(func)
+    def wrapped_with_chunks(*args, **kwargs):
+        _check_out_dims_is_int_or_int_pytree(out_dims, func)
+        _, flat_in_dims, flat_args, args_spec = _process_batched_inputs(in_dims, args, func)
+        # Chunk flat arguments
+        chunks_flat_args = _get_chunk_flat_args(flat_args, flat_in_dims, chunks)
+
+        # Apply vmap on each chunk
+        chunks_output = []
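+        # With randomness='same', capture the (CPU) RNG state once and restore it
+        # before each chunk so that every chunk sees the same random draws.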
+        rs = torch.get_rng_state() if randomness == "same" else None
+        for flat_args in chunks_flat_args:
+            batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args)
+            if rs is not None:
+                torch.set_rng_state(rs)
+            chunks_output.append(
+                _flat_vmap(
+                    func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
+                )
+            )
+        flat_output_chunks, arg_spec = _flatten_chunks_output(chunks_output)
+        # Removing temporary variables helps to reduce memory usage on devices like CUDA
+        del chunks_output
+
+        # concat chunks on out_dim
+        flat_out_dims = _broadcast_to_and_flatten(out_dims, arg_spec)
+        assert len(flat_out_dims) == len(flat_output_chunks)
+        flat_output = []
+        for out_dim in flat_out_dims:
+            flat_output.append(torch.cat(flat_output_chunks[0], dim=out_dim))
+            # release source data
+            del flat_output_chunks[0]
+        del flat_output_chunks
+
+        # finally unflatten the output
+        return tree_unflatten(flat_output, arg_spec)
+
+    return wrapped_with_chunks
+
+
+# Vmap refactored helper functions:
+def _check_randomness_arg(randomness):
+    if randomness not in ['error', 'different', 'same']:
+        raise RuntimeError(f"Only allowed values for randomness are 'error', 'different', or 'same'. Got {randomness}")
+
+
+def _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs):
+    vmap_level = _vmap_increment_nesting(batch_size, randomness)
+    try:
+        batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
+        batched_outputs = func(*batched_inputs, **kwargs)
+        return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
+    finally:
+        _vmap_decrement_nesting()