Skip to content

Commit 020c124

Browse files
authored
port and fix unfold batching rule (#206)
1 parent 5497d75 commit 020c124

File tree

5 files changed: +18 −18 lines changed

functorch/csrc/BatchRulesScatterOps.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,6 @@ std::tuple<Tensor,optional<int64_t>> scatter_reduce_batch_rule(
293293
const Tensor& index, optional<int64_t> index_bdim,
294294
const Tensor& src, optional<int64_t> src_bdim,
295295
const c10::string_view reduce) {
296-
using scatter_reduce_value_sig = Tensor (*)(const Tensor&, int64_t, const Tensor&, const Tensor&, const c10::string_view reduce);
297296
return scatter_batch_rule(ATEN_FN2(scatter, reduce),
298297
self, self_bdim, dim, index, index_bdim, src, src_bdim, reduce);
299298
}
@@ -304,7 +303,6 @@ std::tuple<Tensor,optional<int64_t>> scatter_value_reduce_batch_rule(
304303
const Tensor& index, optional<int64_t> index_bdim,
305304
const Scalar& src,
306305
const c10::string_view reduce) {
307-
using scatter_reduce_value_sig = Tensor (*)(const Tensor&, int64_t, const Tensor&, const Scalar&, const c10::string_view reduce);
308306
return scatter_batch_rule(ATEN_FN2(scatter, value_reduce),
309307
self, self_bdim, dim, index, index_bdim, src, reduce);
310308
}

functorch/csrc/BatchRulesViews.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,23 @@ std::tuple<Tensor, optional<int64_t>> expand_batch_rule(
429429
return std::make_tuple(self_.view(view_shape).expand(size_, implicit), 0);
430430
}
431431

432+
// Batching rule for Tensor::unfold under vmap.
//
// self      - the batched input tensor (physical layout, batch dim somewhere).
// self_bdim - index of the batch dimension in `self`; must be present.
// dim/size/step - the logical unfold arguments as the user wrote them.
//
// Returns the unfolded tensor together with its batch-dim index (always 0,
// since the batch dimension is moved to the front first).
std::tuple<Tensor, optional<int64_t>> unfold_batch_rule(
    const Tensor &self, optional<int64_t> self_bdim, int64_t dim, int64_t size, int64_t step)
{
  TORCH_INTERNAL_ASSERT(self_bdim.has_value());
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto logical_rank = rankWithoutBatchDim(self, self_bdim);
  // Wrap `dim` against the logical (batch-free) rank, then shift by one to
  // skip over the batch dimension that now sits at position 0.
  dim = maybe_wrap_dim(dim, logical_rank) + 1;
  if (logical_rank==0) {
    // Logical scalar: unsqueeze a trailing dim so unfold has a real
    // dimension to operate on ...
    self_ = self_.unsqueeze(-1);
  }
  auto result = self_.unfold(dim, size, step);
  if (logical_rank==0) {
    // ... and squeeze it back out so the logical shape matches what
    // unfold on an unbatched 0-dim tensor would produce.
    result = result.squeeze(-1);
  }
  // Batch dimension of the result is at the front.
  return std::make_tuple(result, 0);
}
448+
432449
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
433450
VMAP_SUPPORT("diag", diag_batch_rule);
434451
VMAP_SUPPORT("chunk", chunk_batching_rule);
@@ -453,6 +470,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
453470
VMAP_SUPPORT("slice_backward", slice_backward_batch_rule);
454471
VMAP_SUPPORT("view", view_batching_rule);
455472
VMAP_SUPPORT("expand", expand_batch_rule);
473+
VMAP_SUPPORT("unfold", unfold_batch_rule);
456474
}
457475

458476
}}

functorch/csrc/BatchingRegistrations.cpp

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -356,17 +356,6 @@ std::vector<Tensor> unbind_batching_rule(const Tensor& self, int64_t dim) {
356356
return result;
357357
}
358358

359-
// Legacy vmap batching rule for Tensor::unfold (this is the implementation
// the commit removes in favor of unfold_batch_rule in BatchRulesViews.cpp):
// maps the logical `dim` to a physical dim and applies unfold on the
// physical tensor.
Tensor unfold_batching_rule(const Tensor& self, int64_t dim, int64_t size, int64_t step) {
  // Fast path: tensor does not participate at the current vmap level —
  // exclude the batched key and dispatch to the plain unfold.
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return self.unfold(dim, size, step);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = self_physical.tensor().unfold(dim_physical, size, step);
  // Re-wrap the physical result as a logical (batched) tensor.
  return self_physical.getPhysicalToLogicalMap().apply(result);
}
369-
370359
Tensor contiguous_batching_rule(const Tensor& self, MemoryFormat memory_format) {
371360
if (!participatesInCurrentLevel(self)) {
372361
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
@@ -909,7 +898,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
909898
// m.impl("trace", trace_batching_rule);
910899
m.impl("transpose.int", transpose_int_batching_rule);
911900
m.impl("unbind.int", unbind_batching_rule);
912-
m.impl("unfold", unfold_batching_rule);
913901
m.impl("unsqueeze_", unsqueeze__batching_rule);
914902
m.impl("view_as", native::view_as); // composite wrt autograd
915903

test/test_ops.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,6 @@ def vjp_of_vjp(*args_and_cotangents):
355355
xfail('quantile'),
356356
xfail('symeig'),
357357
xfail('take'),
358-
xfail('unfold'),
359358
xfail('linalg.tensorinv'),
360359
xfail('nn.functional.conv_transpose2d', device_type='cuda'),
361360
xfail('nanmean'),
@@ -527,7 +526,6 @@ def test():
527526
xfail('gradient'),
528527
xfail('hsplit'),
529528
xfail('nn.functional.pad', 'circular'),
530-
xfail('unfold'),
531529
xfail('vsplit'),
532530
xfail('dstack'),
533531
xfail('hstack'),

test/test_vmap.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3005,7 +3005,6 @@ class TestVmapOperatorsOpInfo(TestCase):
30053005
30063006
# entries in here don't work and need to be fixed.
30073007
# Each one of these is a bug
3008-
xfail('unfold'),
30093008
xfail('svd', device_type='cuda'),
30103009
xfail('linalg.svd', device_type='cuda'),
30113010
xfail('index_put'),
@@ -3097,7 +3096,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
30973096
xfail('take_along_dim'),
30983097
xfail('tensor_split'),
30993098
xfail('to_sparse'),
3100-
xfail('unfold'),
31013099
xfail('vdot'),
31023100
xfail('vsplit'),
31033101
xfail('__getitem__'),

0 commit comments

Comments (0)