@@ -261,69 +261,6 @@ inline void boxed_existing_bdim_all_batch_rule(
 #define EXISTING_BDIM_ALL_BOXED(op) \
   m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_existing_bdim_all_batch_rule>());
 
-inline void boxed_existing_bdim_batch_rule(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
-  const auto& schema = op.schema();
-  const auto num_returns = schema.returns().size();
-  const auto num_arguments = schema.arguments().size();
-  auto arguments = torch::jit::pop(*stack, num_arguments);
-
-  c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
-  auto maybe_layer = maybeCurrentDynamicLayer();
-  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
-  int64_t cur_level = maybe_layer->layerId();
-
-  std::vector<std::pair<Tensor, optional<int64_t>>> tensor_inputs;
-  std::vector<int64_t> tensor_pos;
-  for (const auto idx : c10::irange(0, num_arguments)) {
-    const auto& ivalue = arguments[idx];
-    if (!ivalue.isTensor()) {
-      continue;
-    }
-    Tensor tensor_value;
-    optional<int64_t> tensor_bdim;
-    std::tie(tensor_value, tensor_bdim) = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
-    tensor_inputs.push_back(std::make_pair(tensor_value, tensor_bdim));
-    tensor_pos.push_back(idx);
-  }
-
-  int64_t batch_size = -1;
-  for (auto& tensor_input : tensor_inputs) {
-    if (tensor_input.second) {
-      if (batch_size == -1) {
-        batch_size = tensor_input.first.size(*tensor_input.second);
-      }
-      TORCH_INTERNAL_ASSERT(batch_size == tensor_input.first.size(*tensor_input.second));
-      tensor_input.first = reshape_dim_into(*tensor_input.second, 0, tensor_input.first);
-    }
-  }
-
-  size_t tensor_idx = 0;
-  TORCH_INTERNAL_ASSERT(tensor_pos.size() > 0);
-  for (const auto arg_idx : c10::irange(0, num_arguments)) {
-    if (tensor_idx >= tensor_pos.size() || (int64_t)arg_idx != tensor_pos[tensor_idx]) {
-      torch::jit::push(stack, arguments[arg_idx]);
-    } else {
-      TORCH_INTERNAL_ASSERT(tensor_idx < tensor_inputs.size());
-      torch::jit::push(stack, tensor_inputs[tensor_idx].first);
-      tensor_idx++;
-    }
-  }
-
-  op.callBoxed(stack);
-  const auto returns = torch::jit::pop(*stack, num_returns);
-  for (const auto& ret : returns) {
-    if (ret.isTensor()) {
-      torch::jit::push(stack, makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level));
-    } else {
-      TORCH_INTERNAL_ASSERT(false, "This boxed batching rule does not currently support ops that return non-tensor values");
-    }
-  }
-}
-
-#define EXISTING_BDIM_BOXED(op) \
-  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_existing_bdim_batch_rule>());
-
-
 template <typename A, A a, typename C>
 struct ExistingBdimBatchRuleHelper;
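
For reference, a minimal sketch of how the surviving registration macro is consumed; the TORCH_LIBRARY_IMPL block, dispatch key name, and op name below are assumptions for illustration, not part of this commit:

// Hypothetical registration site; 'm' is the torch::Library handle that
// TORCH_LIBRARY_IMPL provides, and some_op stands in for a real ATen op.
TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  // Stringizes the op name via #op and routes it through the boxed
  // fallback kernel defined above, i.e. this expands to:
  //   m.impl("some_op", torch::CppFunction::makeFromBoxedFunction<
  //              boxed_existing_bdim_all_batch_rule>());
  EXISTING_BDIM_ALL_BOXED(some_op);
}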