
Commit ae89243

[port] diagonal to new api (#165)

1 parent e1f6f77 commit ae89243

2 files changed (+13, -29 lines)


functorch/csrc/BatchRulesViews.cpp

Lines changed: 13 additions & 0 deletions
@@ -302,6 +302,18 @@ std::tuple<Tensor, optional<int64_t>> roll_batch_rule(const Tensor& self, option
   return std::make_tuple(output, 0);
 }

+std::tuple<Tensor, optional<int64_t>> diagonal_batching_rule(
+    const Tensor &self, optional<int64_t> self_bdim,
+    int64_t offset, int64_t dim1, int64_t dim2)
+{
+  auto logical_rank = rankWithoutBatchDim(self, self_bdim);
+  auto self_ = moveBatchDimToFront(self, self_bdim);
+  auto dim1_ = maybe_wrap_dim(dim1, logical_rank) + 1;
+  auto dim2_ = maybe_wrap_dim(dim2, logical_rank) + 1;
+  auto result = at::diagonal(self_, offset, dim1_, dim2_);
+  return std::make_tuple(std::move(result), 0);
+}
+
 std::tuple<Tensor,optional<int64_t>> diagonal_backward_batch_rule(
     const Tensor& grad_input, optional<int64_t> grad_input_bdim,
     IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
@@ -359,6 +371,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   VMAP_SUPPORT("squeeze.dim", squeeze_dim_batch_rule);
   VMAP_SUPPORT("_reshape_alias", _reshape_alias_batch_rule);
   VMAP_SUPPORT("roll", roll_batch_rule);
+  VMAP_SUPPORT("diagonal", diagonal_batching_rule);
   VMAP_SUPPORT("diagonal_backward", diagonal_backward_batch_rule);
   VMAP_SUPPORT("select_backward", select_backward_batch_rule);
   VMAP_SUPPORT("slice_backward", slice_backward_batch_rule);

functorch/csrc/BatchingRegistrations.cpp

Lines changed: 0 additions & 29 deletions
@@ -400,31 +400,6 @@ Tensor slice_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes,
   return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
 }

-Tensor diagonal_batching_rule(const Tensor& self, int64_t offset, int64_t dim1, int64_t dim2) {
-  if (!participatesInCurrentLevel(self)) {
-    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
-    return at::diagonal(self, offset, dim1, dim2);
-  }
-  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
-  auto dim1_physical = self_physical.getPhysicalDim(dim1);
-  auto dim2_physical = self_physical.getPhysicalDim(dim2);
-  auto result = at::diagonal(self_physical.tensor(), offset, dim1_physical, dim2_physical);
-  return self_physical.getPhysicalToLogicalMap().apply(result);
-}
-
-Tensor diagonal_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
-  if (!participatesInCurrentLevel(grad)) {
-    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
-    return at::diagonal_backward(grad, input_sizes, offset, dim1, dim2);
-  }
-  auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad);
-  auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options());
-  auto dim1_physical = getGradInputPhysicalDim(dim1, input_sizes, grad_physical.numBatchDims());
-  auto dim2_physical = getGradInputPhysicalDim(dim2, input_sizes, grad_physical.numBatchDims());
-  grad_input.diagonal(offset, dim1_physical, dim2_physical).copy_(grad_physical.tensor());
-  return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
-}
-
 Tensor movedim_batching_rule(const Tensor& self, IntArrayRef source, IntArrayRef destination) {
   if (!participatesInCurrentLevel(self)) {
     c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
@@ -1026,7 +1001,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   // m.impl("chunk", chunk_batching_rule);
   m.impl("tensor_split.sections", tensor_split_sections_batching_rule);
   m.impl("tensor_split.indices", tensor_split_indices_batching_rule);
-  m.impl("diagonal", diagonal_batching_rule);
   m.impl("expand", expand_batching_rule);
   m.impl("movedim.intlist", movedim_batching_rule);
   m.impl("movedim.int", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t)>(native::movedim)); // composite wrt autograd
@@ -1103,10 +1077,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   m.impl("is_same_size", native::is_same_size);
   // //
   // // // backward operators
-  // // m.impl("select_backward", select_backward_batching_rule);
-  // // m.impl("slice_backward", slice_backward_batching_rule);
   // // m.impl("trace_backward", trace_backward_batching_rule);
-  // // m.impl("diagonal_backward", diagonal_backward_batching_rule);
   // //
   // // // Tensor.new_* operators
   // m.impl("ones_like", ones_like_batching_rule);
