
Commit 1fb0a0f

Batch rule for avg_pool2d_backward
Boxed kernels take a lot of lines to write... Maybe we can make these easier to write.
1 parent: ae54dad

File tree

4 files changed: +104 −20 lines

    functorch/csrc/BatchRulesHelper.h
    functorch/csrc/BatchRulesPooling.cpp
    functorch/csrc/BatchRulesScatterOps.cpp
    test/test_ops.py

functorch/csrc/BatchRulesHelper.h

Lines changed: 103 additions & 7 deletions
@@ -32,6 +32,19 @@ void vmapIncompatibleInplaceError(const char* schema_name);
 
 Tensor maybePadToLogicalRank(const Tensor& tensor, optional<int64_t> has_bdim, int64_t logical_rank);
 
+inline Tensor ensure_has_bdim(const Tensor& tensor, bool has_bdim, int64_t batch_size) {
+  if (has_bdim) {
+    return tensor;
+  }
+  const auto sizes = tensor.sizes();
+  DimVector expanded_shape;
+  expanded_shape.reserve(sizes.size());
+  expanded_shape.emplace_back(batch_size);
+  expanded_shape.insert(expanded_shape.end(), sizes.begin(), sizes.end());
+  return tensor.expand(expanded_shape);
+}
+
+
 #define VMAP_SUPPORT(op, batch_rule) \
   m.impl(op, PrimBatchRule7< \
     decltype(&batch_rule), &batch_rule, to_operator_t<decltype(batch_rule)> \
@@ -166,7 +179,8 @@ inline void handle_variadic_bdims(std::vector<std::pair<Tensor, optional<int64_t
 #define VARIADIC_BDIMS_BOXED(op) \
   m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_variadic_bdims), &handle_variadic_bdims>>());
 
-inline void boxed_existing_bdim_batch_rule(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+inline void boxed_existing_bdim_all_batch_rule(
+    const c10::OperatorHandle& op, torch::jit::Stack* stack) {
   const auto& schema = op.schema();
   const auto num_returns = schema.returns().size();
   const auto num_arguments = schema.arguments().size();
@@ -177,19 +191,101 @@ inline void boxed_existing_bdim_batch_rule(const c10::OperatorHandle& op, torch:
   TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
   int64_t cur_level = maybe_layer->layerId();
 
+  std::vector<std::pair<Tensor, optional<int64_t>>> tensor_inputs;
+  std::vector<int64_t> tensor_pos;
+  for (const auto idx : c10::irange(0, num_arguments)) {
+    const auto& ivalue = arguments[idx];
+    if (!ivalue.isTensor()) {
+      continue;
+    }
+    Tensor tensor_value;
+    optional<int64_t> tensor_bdim;
+    std::tie(tensor_value, tensor_bdim) = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
+    tensor_inputs.push_back(std::make_pair(tensor_value, tensor_bdim));
+    tensor_pos.push_back(idx);
+  }
+
+  // compute batch size...
+  int64_t batch_size = -1;
+  for (const auto& tensor_input : tensor_inputs) {
+    const auto& value = tensor_input.first;
+    const auto& bdim = tensor_input.second;
+    if (!bdim) {
+      continue;
+    }
+    if (batch_size == -1) {
+      batch_size = value.size(*bdim);
+    }
+    TORCH_INTERNAL_ASSERT(batch_size == value.size(*bdim));
+  }
+
+  // for each tensor, ensure it has a bdim and reshape it.
+  for (auto& tensor_input : tensor_inputs) {
+    auto value = tensor_input.first;
+    auto bdim = tensor_input.second;
+    value = ensure_has_bdim(value, bdim.has_value(), batch_size);
+    if (!bdim.has_value()) {
+      bdim = 0;
+    }
+    tensor_input.first = reshape_dim_into(*bdim, 0, value);
+  }
+
+  size_t tensor_idx = 0;
+  TORCH_INTERNAL_ASSERT(tensor_pos.size() > 0);
+  for (const auto arg_idx : c10::irange(0, num_arguments)) {
+    if (tensor_idx >= tensor_pos.size() || (int64_t)arg_idx != tensor_pos[tensor_idx]) {
+      torch::jit::push(stack, arguments[arg_idx]);
+    } else {
+      TORCH_INTERNAL_ASSERT(tensor_idx < tensor_inputs.size());
+      torch::jit::push(stack, tensor_inputs[tensor_idx].first);
+      tensor_idx++;
+    }
+  }
+
+  op.callBoxed(stack);
+  const auto returns = torch::jit::pop(*stack, num_returns);
+  for (const auto& ret : returns) {
+    if (ret.isTensor()) {
+      torch::jit::push(stack, makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level));
+    } else {
+      TORCH_INTERNAL_ASSERT(false, "This boxed batching rule does not currently support ops that return non-tensor values");
+    }
+  }
+}
+
+// Use when all tensor arguments accept one (normal) batch dim.
+// This batching rule expands the batch dim on all Tensors, reshapes it into
+// dim 0, calls the op, and then reshapes the batch dim out of dim 0.
+// This is not the most efficient thing; if there are alternatives, please try
+// to use them. Use this only as a last resort.
+#define EXISTING_BDIM_ALL_BOXED(op) \
+  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_existing_bdim_all_batch_rule>());
+
+inline void boxed_existing_bdim_batch_rule(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+  const auto& schema = op.schema();
+  const auto num_returns = schema.returns().size();
+  const auto num_arguments = schema.arguments().size();
+  auto arguments = torch::jit::pop(*stack, num_arguments);
+
+  c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+  auto maybe_layer = maybeCurrentDynamicLayer();
+  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
+  int64_t cur_level = maybe_layer->layerId();
 
   std::vector<std::pair<Tensor, optional<int64_t>>> tensor_inputs;
   std::vector<int64_t> tensor_pos;
   for (const auto idx : c10::irange(0, num_arguments)) {
     const auto& ivalue = arguments[idx];
-    if (ivalue.isTensor()) {
-      Tensor tensor_value;
-      optional<int64_t> tensor_bdim;
-      std::tie(tensor_value, tensor_bdim) = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
-      tensor_inputs.push_back(std::make_pair(tensor_value, tensor_bdim));
-      tensor_pos.push_back(idx);
+    if (!ivalue.isTensor()) {
+      continue;
     }
+    Tensor tensor_value;
+    optional<int64_t> tensor_bdim;
+    std::tie(tensor_value, tensor_bdim) = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
+    tensor_inputs.push_back(std::make_pair(tensor_value, tensor_bdim));
+    tensor_pos.push_back(idx);
   }
+
   int64_t batch_size = -1;
   for (auto& tensor_input : tensor_inputs) {
     if (tensor_input.second) {
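
To make the shape bookkeeping of the new rule concrete, here is a minimal standalone sketch (not part of this commit) of the flow for avg_pool2d_backward under vmap with batch size B. It assumes the libtorch C++ API; the plain reshape calls stand in for the reshape_dim_into / reshape_dim_outof helpers used above, and all shapes are made-up example values.

// Hypothetical illustration only; not code from this commit.
#include <torch/torch.h>
#include <iostream>

int main() {
  const int64_t B = 4;
  // grad_output carries a vmap batch dim at dim 0; input does not.
  auto grad_output = torch::randn({B, 2, 3, 8, 8});
  auto input       = torch::randn({2, 3, 16, 16});

  // ensure_has_bdim: expand the unbatched input so it also carries B.
  auto input_b = input.expand({B, 2, 3, 16, 16});

  // reshape_dim_into(0, 0, value): fold the vmap dim into dim 0, i.e.
  // (B, N, C, H, W) -> (B*N, C, H, W), so the regular kernel sees
  // ordinary 4-D tensors.
  auto go_folded = grad_output.reshape({B * 2, 3, 8, 8});
  auto in_folded = input_b.reshape({B * 2, 3, 16, 16});

  auto grad_input = at::avg_pool2d_backward(
      go_folded, in_folded,
      /*kernel_size=*/{2, 2}, /*stride=*/{2, 2}, /*padding=*/{0, 0},
      /*ceil_mode=*/false, /*count_include_pad=*/true,
      /*divisor_override=*/c10::nullopt);

  // reshape_dim_outof(0, B, result): split dim 0 back into (B, N, ...)
  // before the result is re-wrapped as a BatchedTensor at the vmap level.
  auto grad_input_batched = grad_input.reshape({B, 2, 3, 16, 16});
  std::cout << grad_input_batched.sizes() << "\n";  // [4, 2, 3, 16, 16]
}

The same fold-call-unfold pattern is what lets a single boxed rule cover any op whose tensor arguments all accept a leading batch dim.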

functorch/csrc/BatchRulesPooling.cpp

Lines changed: 1 addition & 0 deletions
@@ -67,6 +67,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   EXISTING_BDIM(_adaptive_avg_pool2d);
   EXISTING_BDIM(avg_pool2d);
   m.impl("max_pool2d_with_indices_backward", max_pool2d_with_indices_backward_plumbing);
+  EXISTING_BDIM_ALL_BOXED(avg_pool2d_backward);
 }
 
 }}
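
Given the EXISTING_BDIM_ALL_BOXED definition added in BatchRulesHelper.h above, this one-line registration is shorthand for roughly the following (a sketch of the macro expansion, not literal preprocessor output):

  // #op stringizes to "avg_pool2d_backward"
  m.impl("avg_pool2d_backward",
         torch::CppFunction::makeFromBoxedFunction<boxed_existing_bdim_all_batch_rule>());

so calls to aten::avg_pool2d_backward under vmap at this dispatch key are routed through the generic boxed rule instead of a hand-written batch rule.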

functorch/csrc/BatchRulesScatterOps.cpp

Lines changed: 0 additions & 12 deletions
@@ -156,18 +156,6 @@ Tensor& index_put__plumbing(Tensor & self, const List<optional<Tensor>> & indice
   return self;
 }
 
-Tensor ensure_has_bdim(const Tensor& tensor, bool has_bdim, int64_t batch_size) {
-  if (has_bdim) {
-    return tensor;
-  }
-  const auto sizes = tensor.sizes();
-  DimVector expanded_shape;
-  expanded_shape.reserve(sizes.size());
-  expanded_shape.emplace_back(batch_size);
-  expanded_shape.insert(expanded_shape.end(), sizes.begin(), sizes.end());
-  return tensor.expand(expanded_shape);
-}
-
 int64_t bdim_size(
     const Tensor& a, optional<int64_t> a_bdim,
     const Tensor& b, optional<int64_t> b_bdim,

test/test_ops.py

Lines changed: 0 additions & 1 deletion
@@ -437,7 +437,6 @@ def test_vmapvjp(self, device, dtype, op):
     xfail('nanmedian'),
     xfail('nanquantile'),
     xfail('nn.functional.adaptive_avg_pool2d'),
-    xfail('nn.functional.avg_pool2d'),
     xfail('nn.functional.conv_transpose2d'),
     xfail('nn.functional.cross_entropy', 'mean'),
     xfail('nn.functional.cross_entropy', 'none'),
