Skip to content

Commit d95335c

Browse files
authored
Use the smallest batch size we can in vmap testing (#936)
This speeds up test_vmap by 16% real time on my machine
1 parent 8a5465a commit d95335c

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

test/common_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def get_exhaustive_batched_inputs_for_batch_norm(arg_values, kwarg_values, batch
119119

120120
def get_fallback_and_vmap_exhaustive(op, arg_values, kwarg_values, opinfo=None, compute_loop_out=True, bdims=(0, -1)):
121121
out_dim = 0
122-
batch_size = 4
122+
batch_size = 2
123123
generator = get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size, bdims=bdims)
124124
batch_norm_fns = ("nn.functional.batch_norm", "nn.functional.instance_norm") # instance norm calls batch norm
125125
if opinfo is not None and opinfo.name in batch_norm_fns:

0 commit comments

Comments
 (0)