This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit aa2cd89

Author: Samantha Andow

Prelu batching rule (forward + backward) (#609)

* prelu forward rule
* prelu backward rule

1 parent: 89baedd

File tree: 3 files changed, +181 −2.

New file (181 additions, 0 deletions):
```cpp
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <functorch/csrc/BatchRulesHelper.h>
#include <functorch/csrc/PlumbingHelper.h>
#include <ATen/Operators.h>

// NB: most activation functions fit the pointwise unary or binary rules.
// These are only the ones that need special batch rules, kept here for organization.
namespace at { namespace functorch {

std::tuple<Tensor,optional<int64_t>> prelu_batch_rule(
    const Tensor& input, optional<int64_t> input_bdim,
    const Tensor& weight, optional<int64_t> weight_bdim) {
  if (!weight_bdim && weight.dim() == 0) {
    return std::make_tuple(at::prelu(input, weight), input_bdim);
  }

  const auto input_ = moveBatchDimToFront(input, input_bdim);
  auto weight_flatten = moveBatchDimToFront(weight, weight_bdim);

  if (weight_flatten.dim() > 1) {
    // For an input [N, C, ...], weight can be a non-vector, but its total
    // number of elements must equal C.
    weight_flatten = at::flatten(weight_flatten, weight_bdim.has_value() ? 1 : 0, -1);
  }

  const int64_t input_logical_rank = rankWithoutBatchDim(input, input_bdim);
  VmapDimVector new_shape(weight_flatten.sizes().begin(), weight_flatten.sizes().end());
  const int64_t final_size = weight_bdim ? (input_logical_rank + 1) : input_logical_rank;
  new_shape.reserve(final_size);

  if (weight_flatten.dim() == 2 || !weight_bdim) {
    // If weight (without batching) is not a scalar, its size must match the
    // "channel dimension" of input. To do the decomposition, we pad the weight
    // with singleton dimensions so it broadcasts against that channel dimension.

    // copies checks from prelu if the weight (without vmap) is not a scalar
    TORCH_CHECK(input_logical_rank > 0, "Not allow zero-dim input tensor.");

    int64_t channel_size = 1; // channel_size defaults to 1
    if (input_logical_rank > 1) {
      const auto channel_dim = input_bdim ? 2 : 1;
      channel_size = input_.size(channel_dim);
    }
    const auto weight_num = weight_flatten.size(-1);
    TORCH_CHECK(channel_size == weight_num,
      "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num,
      " and channel size = ", channel_size, ".");

    // pads to the left so that the flattened shape lines up with the channel dimension
    if (!weight_bdim) {
      new_shape.insert(new_shape.begin(), 1);
    } else {
      new_shape.insert(new_shape.begin() + 1, 1);
    }
  }

  for (int64_t i = new_shape.size(); i < final_size; i++) {
    new_shape.push_back(1);
  }
  TORCH_INTERNAL_ASSERT(new_shape.size() == final_size);
  const auto weight_padded = weight_flatten.view(new_shape);
  auto zero_tensor = at::zeros(1, input.options());

  // decomposes prelu: max(0, x) + weight * min(0, x)
  auto res = at::maximum(zero_tensor, input_) + weight_padded * at::minimum(zero_tensor, input_);
  return std::make_tuple(res, 0);
}
```
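The forward rule handles a batched or non-scalar weight by decomposing prelu rather than calling at::prelu directly. Below is a minimal sketch in plain PyTorch (the function name and shapes are illustrative, not part of the commit) of the identity it relies on: prelu(x, w) equals max(0, x) + w * min(0, x), with the flattened weight left-padded by singleton dims so it broadcasts along the channel dimension.

```python
import torch
import torch.nn.functional as F

def prelu_decomposed(x, w_flat):
    # Left-pad the flattened weight with singleton dims so it lines up with
    # the channel dimension (dim 1) of an [N, C, ...] input.
    w_padded = w_flat.view((1, -1) + (1,) * (x.dim() - 2))
    zero = torch.zeros(1, dtype=x.dtype)
    return torch.maximum(zero, x) + w_padded * torch.minimum(zero, x)

x = torch.randn(4, 3, 5, 5)
w = torch.randn(3)
assert torch.allclose(prelu_decomposed(x, w), F.prelu(x, w))
```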
The backward rule needs a few shape helpers, plus a shared gradient helper:

```cpp
VmapDimVector ensure_shape_with_bdim(const Tensor& input, const bool has_bdim, const int64_t batch_size) {
  // Helper that gets the size of input, guaranteeing a leading batch dim,
  // without actually expanding input.
  if (has_bdim) {
    // Sad to have to copy, but returning an IntArrayRef from input.sizes() gave garbage.
    VmapDimVector new_shape(input.sizes().begin(), input.sizes().end());
    return new_shape;
  }
  VmapDimVector new_shape(1, batch_size);
  new_shape.reserve(input.dim() + 1);
  new_shape.insert(new_shape.end(), input.sizes().begin(), input.sizes().end());
  return new_shape;
}

VmapDimVector shape_maybe_with_bdim(const Tensor& input, const bool need_bdim, const bool has_bdim, const int64_t batch_size) {
  // If need_bdim, returns the shape of input with a guaranteed batch dim.
  // Otherwise, returns the logical shape of input (no batch dim).
  if (need_bdim) {
    return ensure_shape_with_bdim(input, has_bdim, batch_size);
  } else if (has_bdim) { // !need_bdim && has_bdim
    VmapDimVector new_shape(input.sizes().begin() + 1, input.sizes().end());
    return new_shape;
  } else { // !need_bdim && !has_bdim
    VmapDimVector new_shape(input.sizes().begin(), input.sizes().end());
    return new_shape;
  }
}

std::tuple<Tensor, Tensor> prelu_backward_batched(
    const Tensor& grad_out, const Tensor& self, const Tensor& weight,
    const VmapDimVector& self_grad_shape, const VmapDimVector& weight_grad_padded_shape, const VmapDimVector& weight_grad_shape) {
  // Helper that produces a batched gradient for prelu, using a decomposition
  // inspired by the AOTAutograd ones.
  const auto input_grad_collector = at::where(self > 0, grad_out, weight * grad_out);
  const auto input_grad = native::sum_to_size(input_grad_collector, self_grad_shape);
  const auto weight_grad_collector = at::where(self > 0, at::zeros(1, self.options()), self * grad_out);
  const auto weight_grad_collector_2 = native::sum_to_size(weight_grad_collector, weight_grad_padded_shape);
  const auto weight_grad = weight_grad_collector_2.view(weight_grad_shape);
  return std::make_tuple(input_grad, weight_grad);
}
```
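prelu_backward_batched mirrors a decomposition of prelu's backward: the input gradient is grad_out where self > 0 and weight * grad_out elsewhere, while the weight gradient collects self * grad_out where self <= 0 and is summed down to the (padded) weight shape. A rough sanity check in plain PyTorch (helper name and shapes are illustrative), compared against autograd:

```python
import torch
import torch.nn.functional as F

def prelu_backward_decomposed(grad_out, x, w_padded, w_shape):
    input_grad = torch.where(x > 0, grad_out, w_padded * grad_out)
    weight_grad = torch.where(x > 0, torch.zeros(1), x * grad_out)
    # Tensor.sum_to_size plays the role of native::sum_to_size in the C++ rule.
    weight_grad = weight_grad.sum_to_size(w_padded.shape).view(w_shape)
    return input_grad, weight_grad

x = torch.randn(4, 3, 5, requires_grad=True)
w = torch.randn(3, requires_grad=True)
g = torch.randn(4, 3, 5)
F.prelu(x, w).backward(g)
gi, gw = prelu_backward_decomposed(g, x.detach(), w.detach().view(1, 3, 1), (3,))
assert torch.allclose(gi, x.grad) and torch.allclose(gw, w.grad)
```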
The backward batch rule itself, plus the registrations:

```cpp
std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>> prelu_backward_batch_rule(
    const Tensor& grad_out, optional<int64_t> grad_out_bdim,
    const Tensor& self, optional<int64_t> self_bdim,
    const Tensor& weight, optional<int64_t> weight_bdim) {
  const auto batch_size = get_bdim_size3(grad_out, grad_out_bdim, self, self_bdim, weight, weight_bdim);
  const auto grad_out_ = moveBatchDimToFront(grad_out, grad_out_bdim);
  const auto self_ = moveBatchDimToFront(self, self_bdim);
  const auto self_size_with_bdim = ensure_shape_with_bdim(self_, self_bdim.has_value(), batch_size);
  if (!weight_bdim && weight.dim() == 0) {
    VmapDimVector weight_grad_shape(1, batch_size);
    VmapDimVector weight_grad_shape_padded(self_bdim.has_value() ? self.dim() : self.dim() + 1, 1);
    weight_grad_shape_padded[0] = batch_size;
    const auto grads = prelu_backward_batched(grad_out_, self_, weight, self_size_with_bdim, weight_grad_shape_padded, weight_grad_shape);
    return std::make_tuple(std::get<0>(grads), 0, std::get<1>(grads), 0);
  }
  const auto weight_ = moveBatchDimToFront(weight, weight_bdim);
  auto weight_flatten = weight_;
  if (weight_flatten.dim() > 1) {
    // For an input [N, C, ...], weight can be a non-vector, but its total
    // number of elements must equal C.
    weight_flatten = at::flatten(weight_flatten, weight_bdim.has_value() ? 1 : 0, -1);
  }

  const int64_t self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  VmapDimVector new_shape(weight_flatten.sizes().begin(), weight_flatten.sizes().end());
  const int64_t final_size = weight_bdim ? (self_logical_rank + 1) : self_logical_rank;
  new_shape.reserve(final_size);

  if (weight_flatten.dim() == 2 || !weight_bdim) {
    // If weight (without batching) is not a scalar, its size must match the
    // "channel dimension" of input. To do the decomposition, we pad the weight
    // with singleton dimensions so it broadcasts against that channel dimension.

    // copies checks from prelu if the weight (without vmap) is not a scalar
    TORCH_CHECK(self_logical_rank > 0, "Not allow zero-dim input tensor.");

    int64_t channel_size = 1; // channel_size defaults to 1
    if (self_logical_rank > 1) {
      channel_size = self_.size(self_bdim.has_value() ? 2 : 1);
    }

    const auto weight_num = weight_flatten.size(-1);
    TORCH_CHECK(channel_size == weight_num,
      "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num,
      " and channel size = ", channel_size, ".");

    // pads to the left so that the flattened shape lines up with the channel dimension
    if (!weight_bdim) {
      new_shape.insert(new_shape.begin(), 1);
    } else {
      new_shape.insert(new_shape.begin() + 1, 1);
    }
  }

  for (int64_t i = new_shape.size(); i < final_size; i++) {
    new_shape.push_back(1);
  }
  // The weight grad does not depend on the weight values; it is batched iff
  // grad_out or self is batched.
  const auto weight_grad_is_batched = grad_out_bdim.has_value() || self_bdim.has_value();

  const auto weight_padded = weight_flatten.view(new_shape);
  const auto weight_grad_shape = shape_maybe_with_bdim(weight_, weight_grad_is_batched, weight_bdim.has_value(), batch_size);
  const auto weight_padded_grad_shape = shape_maybe_with_bdim(weight_padded, weight_grad_is_batched, weight_bdim.has_value(), batch_size);

  const auto grads = prelu_backward_batched(grad_out_, self_, weight_padded, self_size_with_bdim, weight_padded_grad_shape, weight_grad_shape);
  return std::make_tuple(std::get<0>(grads), 0, std::get<1>(grads), (weight_grad_is_batched ? optional<int64_t>(0) : nullopt));
}

TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  VMAP_SUPPORT(prelu, prelu_batch_rule)
  VMAP_SUPPORT(prelu_backward, prelu_backward_batch_rule)
}
}} // namespace at::functorch
```
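With these registrations in place, vmap dispatches prelu to prelu_batch_rule, and differentiating under vmap should hit prelu_backward_batch_rule. A small usage sketch against the functorch Python API (shapes are illustrative):

```python
import torch
import torch.nn.functional as F
from functorch import vmap, grad

xs = torch.randn(8, 4, 3, 5)   # a batch of 8 inputs, each [N=4, C=3, L=5]
ws = torch.randn(8, 3)         # a per-example weight, batched by vmap

out = vmap(F.prelu)(xs, ws)    # exercises prelu_batch_rule

# Gradients under vmap go through prelu_backward on batched tensors.
loss = lambda x, w: F.prelu(x, w).sum()
per_example_input_grads = vmap(grad(loss))(xs, ws)
```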

test/test_ops.py (0 additions, 1 deletion):

```diff
@@ -973,7 +973,6 @@ def test():
         xfail('nn.functional.huber_loss'),
         xfail('nn.functional.poisson_nll_loss'),
         xfail('nn.functional.bilinear'),
-        xfail('nn.functional.prelu'),
         xfail('nn.functional.glu'),
         xfail('nn.functional.fractional_max_pool3d'),
         xfail('as_strided'),
```

test/test_vmap.py (0 additions, 1 deletion):

```diff
@@ -3240,7 +3240,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
         xfail('stft'),
         xfail('linalg.solve_triangular'),
         xfail('nn.functional.glu'),
-        xfail('nn.functional.prelu'),
         xfail('isclose'),
         xfail('nn.functional.fractional_max_pool3d'),
         xfail('nn.functional.bilinear'),
```
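What the un-xfailed tests assert, in spirit, is that vmap output matches a per-example loop. A minimal sketch of that check (illustrative shapes, not the actual OpInfo machinery):

```python
import torch
import torch.nn.functional as F
from functorch import vmap

xs = torch.randn(8, 4, 3, 5)
w = torch.randn(3)  # unbatched weight (bdim = None)

batched = vmap(F.prelu, in_dims=(0, None))(xs, w)
looped = torch.stack([F.prelu(x, w) for x in xs])
assert torch.allclose(batched, looped)
```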
