@@ -400,31 +400,6 @@ Tensor slice_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes,
   return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
 }
 
-Tensor diagonal_batching_rule(const Tensor& self, int64_t offset, int64_t dim1, int64_t dim2) {
-  if (!participatesInCurrentLevel(self)) {
-    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
-    return at::diagonal(self, offset, dim1, dim2);
-  }
-  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
-  auto dim1_physical = self_physical.getPhysicalDim(dim1);
-  auto dim2_physical = self_physical.getPhysicalDim(dim2);
-  auto result = at::diagonal(self_physical.tensor(), offset, dim1_physical, dim2_physical);
-  return self_physical.getPhysicalToLogicalMap().apply(result);
-}
-
-Tensor diagonal_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
-  if (!participatesInCurrentLevel(grad)) {
-    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
-    return at::diagonal_backward(grad, input_sizes, offset, dim1, dim2);
-  }
-  auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad);
-  auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options());
-  auto dim1_physical = getGradInputPhysicalDim(dim1, input_sizes, grad_physical.numBatchDims());
-  auto dim2_physical = getGradInputPhysicalDim(dim2, input_sizes, grad_physical.numBatchDims());
-  grad_input.diagonal(offset, dim1_physical, dim2_physical).copy_(grad_physical.tensor());
-  return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
-}
-
 Tensor movedim_batching_rule(const Tensor& self, IntArrayRef source, IntArrayRef destination) {
   if (!participatesInCurrentLevel(self)) {
     c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
@@ -1026,7 +1001,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   // m.impl("chunk", chunk_batching_rule);
   m.impl("tensor_split.sections", tensor_split_sections_batching_rule);
   m.impl("tensor_split.indices", tensor_split_indices_batching_rule);
-  m.impl("diagonal", diagonal_batching_rule);
   m.impl("expand", expand_batching_rule);
   m.impl("movedim.intlist", movedim_batching_rule);
   m.impl("movedim.int", static_cast<Tensor (*)(const Tensor&,int64_t,int64_t)>(native::movedim)); // composite wrt autograd
@@ -1103,10 +1077,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   m.impl("is_same_size", native::is_same_size);
   // //
   // // // backward operators
-  // // m.impl("select_backward", select_backward_batching_rule);
-  // // m.impl("slice_backward", slice_backward_batching_rule);
   // // m.impl("trace_backward", trace_backward_batching_rule);
-  // // m.impl("diagonal_backward", diagonal_backward_batching_rule);
   // //
   // // // Tensor.new_* operators
   // m.impl("ones_like", ones_like_batching_rule);