
Commit 7c4453d

More batch rule fixes (#348)
1 parent cd881a7 commit 7c4453d

File tree (5 files changed, +94 −8 lines changed):
  functorch/csrc/BatchRulesDecompositions.cpp
  functorch/csrc/PyTorchOperatorHacks.cpp
  test/discover_coverage.py
  test/test_ops.py
  test/test_vmap.py


functorch/csrc/BatchRulesDecompositions.cpp

Lines changed: 1 addition & 0 deletions
@@ -234,6 +234,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   OP_DECOMPOSE2(conv2d, padding);
   OP_DECOMPOSE2(conv3d, padding);
   OP_DECOMPOSE(_convolution_mode);
+  OP_DECOMPOSE(frobenius_norm);
   OP_DECOMPOSE(type_as);
   DECOMPOSE_FUNCTIONAL(diag_embed);
   DECOMPOSE_FUNCTIONAL(block_diag);
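For context: decomposing frobenius_norm means vmap no longer needs a dedicated batch rule for the no-dim overload; the call falls through to a composite implementation built from ops that already have rules (the frobenius_norm.dim overload is rerouted separately in PyTorchOperatorHacks.cpp below). A minimal sketch of what this enables, assuming the functorch-era import path:

import torch
from functorch import vmap

x = torch.randn(8, 3, 3)  # a batch of 8 matrices

# Per-sample Frobenius norm under vmap; with the decomposition this routes
# through operations that already have batch rules.
per_sample = vmap(lambda m: torch.norm(m, p='fro'))(x)

# Reference computed without vmap.
reference = torch.stack([torch.norm(m, p='fro') for m in x])
assert torch.allclose(per_sample, reference)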

functorch/csrc/PyTorchOperatorHacks.cpp

Lines changed: 91 additions & 0 deletions
@@ -2,6 +2,8 @@
 #include <functorch/csrc/Constants.h>
 #include <torch/library.h>
 #include <ATen/ATen.h>
+#include <functorch/csrc/TensorWrapper.h>
+#include <functorch/csrc/BatchedTensorImpl.h>
 
 namespace at { namespace functorch {
 
@@ -34,9 +36,98 @@ Tensor index_select_backward_hack(const Tensor& grad, IntArrayRef self_sizes, in
   return at::zeros(self_sizes, grad.options()).index_add(dim, index, grad);
 }
 
+// TODO: https://github.com/pytorch/pytorch/issues/69991
+Tensor frobenius_norm_dim_hack(const Tensor& self, IntArrayRef dim, bool keepdim) {
+  if (dim.size() == 1 || dim.size() == 0) {
+    return at::norm(self, 2, dim, keepdim);
+  } else {
+    auto dim_ = dim.vec();
+    maybe_wrap_dims(dim_, self.dim());
+    TORCH_CHECK(dim_[0] != dim_[1], "Expected dims to be different, got ", dim, " instead");
+    if (self.is_complex()) {
+      return at::sqrt(at::sum(at::real(self.conj() * self), dim_, keepdim));
+    } else {
+      return at::sqrt(at::sum((self * self), dim_, keepdim));
+    }
+  }
+}
+
+static optional<std::tuple<Tensor,int64_t>> unwrap(const Tensor& tensor) {
+  auto* wrapped = maybeGetTensorWrapper(tensor);
+  if (wrapped) {
+    if (wrapped->level().has_value()) {
+      return std::make_tuple(wrapped->value(), *wrapped->level());
+    }
+    return unwrap(wrapped->value());
+  }
+  auto* batched = maybeGetBatchedImpl(tensor);
+  if (batched) {
+    return std::make_tuple(batched->value(), batched->level());
+  }
+  return nullopt;
+}
+
+static bool can_perform_inplace(const Tensor& a, const Tensor& b) {
+  // TODO: generalize this to more transforms
+  auto a_ = unwrap(a);
+  auto b_ = unwrap(b);
+  if (!a_.has_value() && b_.has_value()) {
+    return false;
+  }
+  if (!a_.has_value() && !b_.has_value()) {
+    return true;
+  }
+  if (a_.has_value() && !b_.has_value()) {
+    return true;
+  }
+  TORCH_INTERNAL_ASSERT(a_.has_value() && b_.has_value());
+
+  // If b has any wrapper that a does not, then we cannot do a.inplace_(b)
+  if (std::get<1>(*a_) < std::get<1>(*b_)) {
+    return false;
+  }
+  if (std::get<1>(*a_) > std::get<1>(*b_)) {
+    return can_perform_inplace(std::get<0>(*a_), b);
+  }
+  return can_perform_inplace(std::get<0>(*a_), std::get<0>(*b_));
+}
+
+// TODO: linear is pretty important for performance, but I'm not sure how to work
+// around the in-place.
+Tensor linear_hack(const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt) {
+  // See [Note: hacky wrapper removal for optional tensor]
+  auto bias = bias_opt.has_value()
+    ? c10::MaybeOwned<Tensor>::borrowed(*bias_opt)
+    : c10::MaybeOwned<Tensor>::owned(c10::in_place);
+
+  if (input.is_mkldnn()) {
+    return at::mkldnn_linear(input, weight, *bias);
+  }
+#if defined(C10_MOBILE)
+  if (xnnpack::use_linear(input, weight, *bias)) {
+    return xnnpack::linear(input, weight, *bias);
+  }
+#endif
+  if (input.dim() == 2 && bias->defined()) {
+    // Fused op is marginally faster.
+    return at::addmm(*bias, input, weight.t());
+  }
+  auto output = at::matmul(input, weight.t());
+  if (bias->defined()) {
+    // TODO(rzou): I'm a little uncomfortable with this
+    if (can_perform_inplace(output, *bias)) {
+      return output.add_(*bias);
+    }
+    return output.add(*bias);
+  }
+  return output;
+}
+
 TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) {
   m.impl("value_selecting_reduction_backward", value_selecting_reduction_backward_hack);
   m.impl("index_select_backward", index_select_backward_hack);
+  m.impl("frobenius_norm.dim", frobenius_norm_dim_hack);
+  m.impl("linear", linear_hack);
 }
 
 }}
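A rough Python-level sketch of what the two new hacks compute (illustrative only; frobenius_norm_dim_ref and linear_ref are not names in the codebase). frobenius_norm_dim_ref mirrors the multi-dim branch of frobenius_norm_dim_hack, and linear_ref mirrors linear_hack's generic matmul-plus-bias path, which uses an out-of-place add unless can_perform_inplace says the in-place add_ is safe:

import torch

def frobenius_norm_dim_ref(x, dim, keepdim=False):
    # sqrt(sum(real(conj(x) * x))) for complex inputs, sqrt(sum(x * x)) otherwise
    if x.is_complex():
        return torch.sqrt(torch.sum(torch.real(x.conj() * x), dim=dim, keepdim=keepdim))
    return torch.sqrt(torch.sum(x * x, dim=dim, keepdim=keepdim))

def linear_ref(input, weight, bias=None):
    # matmul with the transposed weight, then an out-of-place bias add
    output = input @ weight.t()
    return output if bias is None else output + bias

x = torch.randn(2, 5, 4)
assert torch.allclose(frobenius_norm_dim_ref(x, (1, 2)),
                      torch.norm(x, p='fro', dim=(1, 2)))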

test/discover_coverage.py

Lines changed: 1 addition & 0 deletions
@@ -374,6 +374,7 @@ def print_coverage_info(th=100, nn=25):
         'torch.nonzero', # dynamic
         'torch.masked_select', # dynamic
         'torch.prod', # dynamic (backward)
+        'torch.norm', # norm with nuc is not commonly used.
     }
     remove_from_set(statuses['test_vmap_exhaustive'], vmap_exemptions)
     remove_from_set(statuses['test_vmapvjp'], vmap_exemptions)

test/test_ops.py

Lines changed: 1 addition & 7 deletions
@@ -459,7 +459,6 @@ def vjp_of_vjp(*args_and_cotangents):
         xfail('masked_scatter'),
         xfail('matrix_exp'),
         xfail('nanquantile'),
-        xfail('norm', 'fro'),
         xfail('norm', 'nuc'),
         xfail('prod'),
         xfail('put'),
@@ -481,17 +480,13 @@ def vjp_of_vjp(*args_and_cotangents):
         xfail('_masked.prod'), # calls aten::item
         xfail('stft'),
         xfail('nn.functional.glu'),
-
         xfail('nn.functional.fractional_max_pool3d'),
         xfail('as_strided'),
         xfail('nn.functional.fractional_max_pool2d'),
     })
     @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
     @skipOps('TestOperators', 'test_vmapvjp', vmapvjp_fail)
     def test_vmapvjp(self, device, dtype, op):
-        # These are too annoying to put into the list above
-        if op.name in {'nn.functional.linear'}:
-            self.skipTest("Skipped! ExpectedF failures")
         if not op.supports_autograd:
             self.skipTest("Skipped! Autograd not supported.")
             return
@@ -741,7 +736,6 @@ def test_vmapjvpall(self, device, dtype, op):
         xfail('nn.functional.conv_transpose2d'),
         xfail('nn.functional.gelu'),
         xfail('nn.functional.pad', 'circular'),
-        xfail('norm', 'fro'),
         xfail('norm', 'nuc'),
         xfail('pinverse'),
         xfail('prod'),
@@ -794,7 +788,7 @@ def test_vmapjvpall(self, device, dtype, op):
     }))
     def test_vmapvjp_has_batch_rule(self, device, dtype, op):
         # These are too annoying to put into the list above
-        if op.name in {'nn.functional.linear', 'nn.functional.conv2d'}:
+        if op.name in {'nn.functional.conv2d'}:
             self.skipTest("Skipped! ExpectedF failures")
         if not op.supports_autograd:
             self.skipTest("Skipped! Autograd not supported.")

test/test_vmap.py

Lines changed: 0 additions & 1 deletion
@@ -3197,7 +3197,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
         xfail('masked_scatter'),
         xfail('masked_select'),
         xfail('nanquantile'),
-        xfail('norm', 'fro'),
         xfail('norm', 'nuc'),
         xfail('ormqr'),
         xfail('put'),
