This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 731e05f

Make one of the boxed fallbacks faster
We should do this for the others
1 parent ea180d2 commit 731e05f
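What changed: the previous boxed_existing_bdim_all_batch_rule popped every argument off the torch::jit::Stack into a temporary vector, collected the unpacked tensors into heap-allocated std::vectors, and then pushed everything back. The new version leaves the arguments on the stack and overwrites only the tensor slots in place, uses SmallVector with inline storage so the common case of a few tensor arguments does not allocate, and factors the unpack-and-compute-batch-size step into a reusable find_and_unpack_tensors helper that the other boxed fallbacks could adopt.

A minimal sketch of the in-place idea, using std::vector<int64_t> as a stand-in for torch::jit::Stack; the names here are illustrative and not part of the commit:

#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in for torch::jit::Stack, which is a vector of c10::IValue.
using Stack = std::vector<int64_t>;

// The old rule popped the trailing num_args values into a temporary vector,
// transformed them, and pushed them back (extra copies plus heap allocations).
// The commit's approach: compute where the arguments begin on the stack and
// rewrite those slots in place.
void transform_args_in_place(Stack& stack, int64_t num_args) {
  const int64_t args_begin = static_cast<int64_t>(stack.size()) - num_args;
  assert(args_begin >= 0);
  for (int64_t idx = 0; idx < num_args; ++idx) {
    stack[args_begin + idx] *= 2;  // placeholder for "reshape the batched tensor"
  }
}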

File tree: 1 file changed

functorch/csrc/BatchRulesHelper.h

Lines changed: 54 additions & 51 deletions
@@ -179,77 +179,80 @@ inline void handle_variadic_bdims(std::vector<std::pair<Tensor, optional<int64_t
 #define VARIADIC_BDIMS_BOXED(op) \
   m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_variadic_bdims), &handle_variadic_bdims>>());
 
+using UnpackedBatchedTensor = std::tuple<Tensor,optional<int64_t>>;
+
+inline void find_and_unpack_tensors(
+    const torch::jit::Stack* stack,
+    int64_t num_args,
+    int64_t cur_level,
+    SmallVector<UnpackedBatchedTensor, 5>* tensors,
+    SmallVector<int64_t, 5>* tensors_pos,
+    int64_t* batch_size) {
+
+  int64_t computed_batch_size = -1;
+  int64_t args_begin = stack->size() - num_args;
+
+  for (const auto idx : c10::irange(0, num_args)) {
+    const auto& ivalue = (*stack)[args_begin + idx];
+    if (!ivalue.isTensor()) {
+      continue;
+    }
+    auto unpacked = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
+    const auto& tensor_value = std::get<0>(unpacked);
+    const auto tensor_bdim = std::get<1>(unpacked);
+    if (tensor_bdim.has_value()) {
+      auto candidate_batch_size = tensor_value.size(*tensor_bdim);
+      if (computed_batch_size == -1) {
+        computed_batch_size = candidate_batch_size;
+      }
+      TORCH_INTERNAL_ASSERT(candidate_batch_size == computed_batch_size);
+    }
+
+    tensors->push_back(std::move(unpacked));
+    tensors_pos->push_back(idx);
+  }
+  TORCH_INTERNAL_ASSERT(computed_batch_size > -1);
+  *batch_size = computed_batch_size;
+}
+
 inline void boxed_existing_bdim_all_batch_rule(
     const c10::OperatorHandle& op, torch::jit::Stack* stack) {
   const auto& schema = op.schema();
   const auto num_returns = schema.returns().size();
   const auto num_arguments = schema.arguments().size();
-  auto arguments = torch::jit::pop(*stack, num_arguments);
 
   c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
   auto maybe_layer = maybeCurrentDynamicLayer();
   TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
   int64_t cur_level = maybe_layer->layerId();
 
-  std::vector<std::pair<Tensor, optional<int64_t>>> tensor_inputs;
-  std::vector<int64_t> tensor_pos;
-  for (const auto idx : c10::irange(0, num_arguments)) {
-    const auto& ivalue = arguments[idx];
-    if (!ivalue.isTensor()) {
-      continue;
-    }
-    Tensor tensor_value;
-    optional<int64_t> tensor_bdim;
-    std::tie(tensor_value, tensor_bdim) = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
-    tensor_inputs.push_back(std::make_pair(tensor_value, tensor_bdim));
-    tensor_pos.push_back(idx);
-  }
+  int64_t args_begin = stack->size() - num_arguments;
+  SmallVector<UnpackedBatchedTensor, 5> tensor_inputs;
+  SmallVector<int64_t, 5> tensor_pos;
+  int64_t batch_size;
 
-  // compute batch size...
-  int64_t batch_size = -1;
-  for (const auto& tensor_input : tensor_inputs) {
-    const auto& value = tensor_input.first;
-    const auto& bdim = tensor_input.second;
-    if (!bdim) {
-      continue;
-    }
-    if (batch_size == -1) {
-      batch_size = value.size(*bdim);
-    }
-    TORCH_INTERNAL_ASSERT(batch_size == value.size(*bdim));
-  }
+  find_and_unpack_tensors(
+      stack, num_arguments, cur_level,
+      &tensor_inputs, &tensor_pos, &batch_size);
 
   // for each tensor, ensure it has a bdim and reshape it.
-  for (auto& tensor_input : tensor_inputs) {
-    auto value = tensor_input.first;
-    auto bdim = tensor_input.second;
-    value = ensure_has_bdim(value, bdim.has_value(), batch_size);
+  for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) {
+    const auto& value = std::get<0>(tensor_inputs[tensor_idx]);
+    auto bdim = std::get<1>(tensor_inputs[tensor_idx]);
+    auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size);
     if (!bdim.has_value()) {
       bdim = 0;
     }
-    tensor_input.first = reshape_dim_into(*bdim, 0, value);
-  }
-
-  size_t tensor_idx = 0;
-  TORCH_INTERNAL_ASSERT(tensor_pos.size() > 0);
-  for (const auto arg_idx : c10::irange(0, num_arguments)) {
-    if (tensor_idx >= tensor_pos.size() || (int64_t)arg_idx != tensor_pos[tensor_idx]) {
-      torch::jit::push(stack, arguments[arg_idx]);
-    } else {
-      TORCH_INTERNAL_ASSERT(tensor_idx < tensor_inputs.size());
-      torch::jit::push(stack, tensor_inputs[tensor_idx].first);
-      tensor_idx++;
-    }
+    (*stack)[args_begin + tensor_pos[tensor_idx]] = reshape_dim_into(*bdim, 0, value_);
   }
 
   op.callBoxed(stack);
-  const auto returns = torch::jit::pop(*stack, num_returns);
-  for (const auto& ret : returns) {
-    if (ret.isTensor()) {
-      torch::jit::push(stack, makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level));
-    } else {
-      TORCH_INTERNAL_ASSERT(false, "This boxed batching rule does not currently support ops that return non-tensor values");
-    }
+
+  for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) {
+    const auto& ret = (*stack)[idx];
+    TORCH_INTERNAL_ASSERT(ret.isTensor(),
+        "This boxed batching rule does not currently support ops that return non-tensor values");
+    (*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level);
   }
 }

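Note on the return handling: after op.callBoxed(stack) the results occupy the stack slots starting at args_begin, so the rewritten rule re-wraps them in place rather than popping and re-pushing them.

For reference, a boxed rule with this signature is registered through torch::CppFunction::makeFromBoxedFunction, in the same style as the VARIADIC_BDIMS_BOXED macro at the top of the hunk. A sketch of what such a registration could look like; the EXISTING_BDIM_ALL_BOXED name is an assumption for illustration and is not shown in this diff:

// Hypothetical registration macro, modeled on VARIADIC_BDIMS_BOXED above;
// the macro name functorch actually uses may differ.
#define EXISTING_BDIM_ALL_BOXED(op) \
  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_existing_bdim_all_batch_rule>());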
0 commit comments