@@ -179,77 +179,80 @@ inline void handle_variadic_bdims(std::vector<std::pair<Tensor, optional<int64_t
#define VARIADIC_BDIMS_BOXED(op) \
  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_variadic_bdims), &handle_variadic_bdims>>());

+using UnpackedBatchedTensor = std::tuple<Tensor, optional<int64_t>>;
+
+inline void find_and_unpack_tensors(
+    const torch::jit::Stack* stack,
+    int64_t num_args,
+    int64_t cur_level,
+    SmallVector<UnpackedBatchedTensor, 5>* tensors,
+    SmallVector<int64_t, 5>* tensors_pos,
+    int64_t* batch_size) {
+
+  int64_t computed_batch_size = -1;
+  int64_t args_begin = stack->size() - num_args;
+
+  for (const auto idx : c10::irange(0, num_args)) {
+    const auto& ivalue = (*stack)[args_begin + idx];
+    if (!ivalue.isTensor()) {
+      continue;
+    }
+    auto unpacked = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
+    const auto& tensor_value = std::get<0>(unpacked);
+    const auto tensor_bdim = std::get<1>(unpacked);
+    if (tensor_bdim.has_value()) {
+      auto candidate_batch_size = tensor_value.size(*tensor_bdim);
+      if (computed_batch_size == -1) {
+        computed_batch_size = candidate_batch_size;
+      }
+      TORCH_INTERNAL_ASSERT(candidate_batch_size == computed_batch_size);
+    }
+
+    tensors->push_back(std::move(unpacked));
+    tensors_pos->push_back(idx);
+  }
+  TORCH_INTERNAL_ASSERT(computed_batch_size > -1);
+  *batch_size = computed_batch_size;
+}
+
inline void boxed_existing_bdim_all_batch_rule(
    const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = schema.arguments().size();
-  auto arguments = torch::jit::pop(*stack, num_arguments);

  c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
  auto maybe_layer = maybeCurrentDynamicLayer();
  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
  int64_t cur_level = maybe_layer->layerId();

-  std::vector<std::pair<Tensor, optional<int64_t>>> tensor_inputs;
-  std::vector<int64_t> tensor_pos;
-  for (const auto idx : c10::irange(0, num_arguments)) {
-    const auto& ivalue = arguments[idx];
-    if (!ivalue.isTensor()) {
-      continue;
-    }
-    Tensor tensor_value;
-    optional<int64_t> tensor_bdim;
-    std::tie(tensor_value, tensor_bdim) = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
-    tensor_inputs.push_back(std::make_pair(tensor_value, tensor_bdim));
-    tensor_pos.push_back(idx);
-  }
+  int64_t args_begin = stack->size() - num_arguments;
+  SmallVector<UnpackedBatchedTensor, 5> tensor_inputs;
+  SmallVector<int64_t, 5> tensor_pos;
+  int64_t batch_size;

-  // compute batch size...
-  int64_t batch_size = -1;
-  for (const auto& tensor_input : tensor_inputs) {
-    const auto& value = tensor_input.first;
-    const auto& bdim = tensor_input.second;
-    if (!bdim) {
-      continue;
-    }
-    if (batch_size == -1) {
-      batch_size = value.size(*bdim);
-    }
-    TORCH_INTERNAL_ASSERT(batch_size == value.size(*bdim));
-  }
+  find_and_unpack_tensors(
+      stack, num_arguments, cur_level,
+      &tensor_inputs, &tensor_pos, &batch_size);

  // for each tensor, ensure it has a bdim and reshape it.
-  for (auto& tensor_input : tensor_inputs) {
-    auto value = tensor_input.first;
-    auto bdim = tensor_input.second;
-    value = ensure_has_bdim(value, bdim.has_value(), batch_size);
+  for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) {
+    const auto& value = std::get<0>(tensor_inputs[tensor_idx]);
+    auto bdim = std::get<1>(tensor_inputs[tensor_idx]);
+    auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size);
    if (!bdim.has_value()) {
      bdim = 0;
    }
-    tensor_input.first = reshape_dim_into(*bdim, 0, value);
-  }
-
-  size_t tensor_idx = 0;
-  TORCH_INTERNAL_ASSERT(tensor_pos.size() > 0);
-  for (const auto arg_idx : c10::irange(0, num_arguments)) {
-    if (tensor_idx >= tensor_pos.size() || (int64_t)arg_idx != tensor_pos[tensor_idx]) {
-      torch::jit::push(stack, arguments[arg_idx]);
-    } else {
-      TORCH_INTERNAL_ASSERT(tensor_idx < tensor_inputs.size());
-      torch::jit::push(stack, tensor_inputs[tensor_idx].first);
-      tensor_idx++;
-    }
+    (*stack)[args_begin + tensor_pos[tensor_idx]] = reshape_dim_into(*bdim, 0, value_);
  }

  op.callBoxed(stack);
-  const auto returns = torch::jit::pop(*stack, num_returns);
-  for (const auto& ret : returns) {
-    if (ret.isTensor()) {
-      torch::jit::push(stack, makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level));
-    } else {
-      TORCH_INTERNAL_ASSERT(false, "This boxed batching rule does not currently support ops that return non-tensor values");
-    }
+
+  for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) {
+    const auto& ret = (*stack)[idx];
+    TORCH_INTERNAL_ASSERT(ret.isTensor(),
+        "This boxed batching rule does not currently support ops that return non-tensor values");
+    (*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level);
  }
}
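The rewrite above leans on the boxed calling convention: the operator's arguments sit at the top of the `torch::jit::Stack` when the kernel runs, and `op.callBoxed(stack)` pops them and pushes the returns in their place, so the same `args_begin` offset indexes the outputs afterwards. A minimal sketch of that invariant, assuming only that `torch::jit::Stack` is an alias for `std::vector<c10::IValue>` and using a made-up `fake_call_boxed` stand-in instead of a real dispatcher call:

```cpp
#include <ATen/core/ivalue.h>
#include <cassert>
#include <cstdint>
#include <vector>

using Stack = std::vector<c10::IValue>;  // what torch::jit::Stack aliases

// Stand-in for op.callBoxed: pops `num_args` inputs, pushes `num_returns` outputs.
void fake_call_boxed(Stack* stack, int64_t num_args, int64_t num_returns) {
  stack->erase(stack->end() - num_args, stack->end());
  for (int64_t i = 0; i < num_returns; ++i) {
    stack->push_back(c10::IValue(int64_t(100 + i)));  // dummy outputs
  }
}

int main() {
  Stack stack;
  stack.push_back(c10::IValue(int64_t(7)));  // value from an earlier frame
  stack.push_back(c10::IValue(int64_t(1)));  // arg 0
  stack.push_back(c10::IValue(2.0));         // arg 1
  const int64_t num_args = 2, num_returns = 1;

  // Arguments occupy [args_begin, args_begin + num_args); rewrite them in place
  // instead of popping and re-pushing.
  const int64_t args_begin = stack.size() - num_args;
  stack[args_begin] = c10::IValue(int64_t(42));

  fake_call_boxed(&stack, num_args, num_returns);

  // The returns now start at the same offset, so they can be rewritten in place too.
  for (int64_t idx = args_begin; idx < args_begin + num_returns; ++idx) {
    assert(stack[idx].isInt());
  }
  return 0;
}
```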
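If I read the helpers correctly, `reshape_dim_into(*bdim, 0, value_)` folds the batch dimension into dim 0 before the underlying op runs, and `reshape_dim_outof(0, batch_size, ret.toTensor())` splits it back out of each result. A rough standalone sketch of that kind of reshaping with plain ATen ops, under those assumptions and with made-up names (`fold_dim_into_front` / `unfold_front_dim`, not the functorch helpers):

```cpp
#include <ATen/ATen.h>

// Fold dimension `src` of `x` into dim 0, i.e. move the batch dim to the
// front and merge it with the following dim: [..., B, ...] -> [B * d0, d1, ...].
at::Tensor fold_dim_into_front(int64_t src, const at::Tensor& x) {
  auto x_ = x.movedim(src, 0);   // [B, d0, d1, ...]
  auto shape = x_.sizes().vec();
  shape[1] *= shape[0];          // merge B into the next dim
  shape.erase(shape.begin());
  return x_.reshape(shape);      // [B * d0, d1, ...]
}

// Split a batch of size `batch_size` back out of dim 0:
// [B * d0, d1, ...] -> [B, d0, d1, ...].
at::Tensor unfold_front_dim(int64_t batch_size, const at::Tensor& x) {
  auto shape = x.sizes().vec();
  shape[0] /= batch_size;
  shape.insert(shape.begin(), batch_size);
  return x.reshape(shape);
}

int main() {
  auto x = at::randn({3, 5, 2});                // batch dim of size 2 at index 2
  auto folded = fold_dim_into_front(2, x);      // shape [6, 5]
  auto restored = unfold_front_dim(2, folded);  // shape [2, 3, 5]
  TORCH_INTERNAL_ASSERT(restored.sizes().vec() == std::vector<int64_t>({2, 3, 5}));
  return 0;
}
```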