Commit acb46ce

Support vmapvjp through F.pad (#357)
1 parent 5dde9b7 commit acb46ce

3 files changed: +5 −8 lines

functorch/csrc/BatchingRegistrations.cpp

Lines changed: 5 additions & 1 deletion
@@ -368,6 +368,10 @@ static void checkBatchDimsAtFrontInLayout(IntArrayRef physical_strides, int64_t
     // No example dimensions
     return;
   }
+  if (num_batch_dims == 1 && physical_strides.size() > 0 && physical_strides[0] == 0) {
+    // degenerate batch dim
+    return;
+  }
   TORCH_CHECK(*smallest_batch_stride >= *largest_example_stride,
     "vmap: Calling Tensor.as_strided is not supported unless the batch dims being ",
     "vmapped over are at the front of the tensor (in memory layout). When they are ",
@@ -929,7 +933,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   // // // Tensor.new_* operators
   // m.impl("ones_like", ones_like_batching_rule);
   // // m.impl("new_empty", new_empty_batching_rule);
-  // m.impl("new_empty_strided", new_empty_strided_batching_rule);
+  m.impl("new_empty_strided", new_empty_strided_batching_rule);
   // // m.impl("new_zeros", new_zeros_batching_rule);
   // //
   m.impl("contiguous", contiguous_batching_rule);

test/test_ops.py

Lines changed: 0 additions & 5 deletions
@@ -623,9 +623,6 @@ def test_vmapvjp(self, device, dtype, op):
         # xfail "above the line".
         xfail('double', 'channels_last'),
 
-        # See https://github.com/pytorch/pytorch/issues/66357
-        xfail('nn.functional.pad', 'circular'),
-
         # RuntimeError: expand: the number of sizes provided (1) must be greater or
         # equal to the number of dimensions in the tensor (2)
         xfail('nanquantile'),
@@ -690,7 +687,6 @@ def test_vmapjvp(self, device, dtype, op):
         xfail('linalg.inv'),
         xfail('masked_fill'),
         xfail('linalg.tensorinv'),
-        xfail('nn.functional.pad', 'circular'),
         xfail('linalg.matrix_power'),
         xfail('maximum'),
         xfail('linalg.householder_product'),
@@ -796,7 +792,6 @@ def test_vmapjvpall(self, device, dtype, op):
         xfail('nanquantile'),
         xfail('nn.functional.conv_transpose2d'),
         xfail('nn.functional.gelu'),
-        xfail('nn.functional.pad', 'circular'),
         xfail('norm', 'nuc'),
         xfail('pinverse'),
         xfail('prod'),
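
Removing these entries means the vmap-of-vjp (and jvp) OpInfo tests for nn.functional.pad with mode='circular' are now expected to pass rather than being xfailed. Below is a rough sketch of the kind of computation this covers, assuming a functorch-era import path; the shapes and the vjp_of_pad helper are illustrative and not part of the test suite.

import torch
import torch.nn.functional as F
from functorch import vmap, vjp  # illustrative; newer releases expose torch.func

def pad_circular(x):
    # circular padding of the last two dims by 1 on each side
    return F.pad(x, (1, 1, 1, 1), mode='circular')

def vjp_of_pad(xi, cotangent):
    # compute the vjp of pad_circular at xi against the given cotangent
    _, vjp_fn = vjp(pad_circular, xi)
    return vjp_fn(cotangent)[0]

B0 = 4
x = torch.randn(B0, 2, 3, 5, 5)           # per-example input is 4-D, as circular pad expects
cotangents = torch.randn(B0, 2, 3, 7, 7)  # matches the padded per-example output shape

grads = vmap(vjp_of_pad)(x, cotangents)
print(grads.shape)  # torch.Size([4, 2, 3, 5, 5])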

test/test_vmap.py

Lines changed: 0 additions & 2 deletions
@@ -2177,8 +2177,6 @@ def test_new_empty(self):
         result = vmap(vmap(lambda x: op(x, [2, 3])))(torch.randn(B0, B1))
         self.assertEqual(result.shape, [B0, B1, 2, 3])
 
-    # TODO: new_empty_strided BR
-    @unittest.expectedFailure
     def test_new_empty_strided(self):
         # Empty is non-deterministic so we just check that the size and shape
         # of the output are what we expect and that the vmap fallback isn't used
