@@ -38,7 +38,90 @@ def loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values):
    return loop_out


-def get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size=3, bdims=(0, -1), for_batch_norm=False):
+# This is kind of dangerous; please think carefully before using it.
+# Known risks:
+# - The returned value had better not be mutated, so it's best to return
+#   immutable types (e.g. prefer tuples to lists).
+# - Don't hash tensors in a global context; that'll keep them around forever.
+def memoize(fn):
+    memo = {}
+
+    def wrapped(*args):
+        if args not in memo:
+            memo[args] = fn(*args)
+        return memo[args]
+    return wrapped
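+
+# Example with a toy, hypothetical function: results are cached keyed on the
+# positional-args tuple, so every argument must be hashable:
+#
+#   @memoize
+#   def square(x):
+#       return x * x
+#
+#   square(3)  # computes and caches 9
+#   square(3)  # returns the cached 9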
+
+
+# NB: This is O(2 ** num_tensors).
+# num_tensors ranges from 1 to 10, with 2-4 being most common.
+# Try not to make it any more expensive if you're modifying it.
+@memoize
+def get_bdim_choices(num_tensors):
+    choices = []
+
+    # full of zeros
+    choices.append((0,) * num_tensors)
+
+    # All combinations of (-1, None)
+    options = (-1, None)
+    for choice in itertools.product(options, repeat=num_tensors):
+        choices.append(choice)
+
+    assert choices[-1] == (None,) * num_tensors
+    return tuple(choices[:-1])
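+
+# e.g. get_bdim_choices(2) == ((0, 0), (-1, -1), (-1, None), (None, -1)).
+# The all-None combination is asserted to be last and dropped, since it would
+# mean vmapping with no batched inputs at all.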
+
+
+def add_batch_dim(arg, bdim, batch_size=3):
+    assert bdim == 0 or bdim == -1
+    assert isinstance(arg, torch.Tensor)
+    if bdim == 0:
+        shape = [1] * len(arg.shape)
+        shape.insert(bdim, batch_size)
+        return (arg.repeat(shape), bdim)
+    if bdim == -1:
+        arg = arg.unsqueeze(-1).expand(*arg.shape, batch_size).contiguous()
+        return (arg, bdim)
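+
+# e.g. for arg of shape (2, 3) and batch_size=3:
+#   bdim == 0  -> repeat([3, 1, 1]) yields shape (3, 2, 3) (batch dim in front)
+#   bdim == -1 -> unsqueeze/expand yields shape (2, 3, 3) (batch dim at the end)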
+
+
+def construct_in_dims(bdim_choice_for_tensors, is_tensors):
+    result = []
+    bdim = iter(bdim_choice_for_tensors)
+    for is_tensor in is_tensors:
+        if not is_tensor:
+            result.append(None)
+            continue
+        result.append(next(bdim))
+    return tuple(result)
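+
+# e.g. construct_in_dims((0, -1), [True, False, True]) == (0, None, -1):
+# non-tensor args always get an in_dim of None, and the chosen bdims are
+# assigned to the tensor args in order.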
+
+
+def get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size=2, *, for_batch_norm=False):
+    if for_batch_norm:
+        # TODO: delete this path
+        return get_exhaustive_batched_inputs_batch_norm(arg_values, kwarg_values, batch_size)
+
+    flat_args, arg_spec = pytree.tree_flatten(tuple(arg_values))
+    is_tensors = [isinstance(a, torch.Tensor) for a in flat_args]
+    bdim_choices = get_bdim_choices(sum(is_tensors))
+
+    @memoize
+    def get_batched_arg(arg, bdim):
+        assert isinstance(arg, torch.Tensor)
+        assert bdim is not None
+        result, _ = add_batch_dim(arg, bdim, batch_size)
+        return result
+
+    for bdim_choice in bdim_choices:
+        flat_in_dims = construct_in_dims(bdim_choice, is_tensors)
+
+        flat_batched_args = tuple(arg if in_dim is None else get_batched_arg(arg, in_dim)
+                                  for arg, in_dim in zip(flat_args, flat_in_dims))
+        batched_args = pytree.tree_unflatten(flat_batched_args, arg_spec)
+        in_dims = pytree.tree_unflatten(flat_in_dims, arg_spec)
+        yield batched_args, in_dims, kwarg_values
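+
+# Each yielded triple is ready for the transform under test, roughly
+# (sketch, assuming functorch's vmap API):
+#   vmap(op, in_dims)(*batched_args, **kwarg_values)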
+
+
+def get_exhaustive_batched_inputs_batch_norm(arg_values, kwarg_values, batch_size=3, bdims=(0, -1)):
+    for_batch_norm = True
    assert bdims == (0,) or bdims == (0, -1)

    def add_batch_dim(arg, bdim, batch_size=3):
@@ -112,18 +195,12 @@ def add_batch_choices(a):
        yield batched_args_tuple, in_dims_tuple, kwarg_values


-def get_exhaustive_batched_inputs_for_batch_norm(arg_values, kwarg_values, batch_size=3, bdims=(0, -1)):
-    return get_exhaustive_batched_inputs(arg_values, kwarg_values,
-                                         batch_size=batch_size, bdims=bdims, for_batch_norm=True)
-
-
-def get_fallback_and_vmap_exhaustive(op, arg_values, kwarg_values, opinfo=None, compute_loop_out=True, bdims=(0, -1)):
+def get_fallback_and_vmap_exhaustive(op, arg_values, kwarg_values, opinfo=None, compute_loop_out=True):
    out_dim = 0
    batch_size = 2
-    generator = get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size, bdims=bdims)
    batch_norm_fns = ("nn.functional.batch_norm", "nn.functional.instance_norm")  # instance norm calls batch norm
-    if opinfo is not None and opinfo.name in batch_norm_fns:
-        generator = get_exhaustive_batched_inputs_for_batch_norm(arg_values, kwarg_values, batch_size, bdims=bdims)
+    for_batch_norm = opinfo is not None and opinfo.name in batch_norm_fns
+    generator = get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size, for_batch_norm=for_batch_norm)
    for batched_args, in_dims, kwarg_values in generator:
        if compute_loop_out:
            loop_out = loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values)