@@ -264,6 +264,71 @@ inline void boxed_existing_bdim_all_batch_rule(
#define EXISTING_BDIM_ALL_BOXED(op) \
  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_existing_bdim_all_batch_rule>());

+template <int64_t feature_rank>
+inline void boxed_all_tensors_have_optional_bdim(
+    const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+  const auto& schema = op.schema();
+  const auto num_returns = schema.returns().size();
+  const auto num_arguments = schema.arguments().size();
+
+  c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+  auto maybe_layer = maybeCurrentDynamicLayer();
+  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
+  int64_t cur_level = maybe_layer->layerId();
+
+  int64_t args_begin = stack->size() - num_arguments;
+  SmallVector<UnpackedBatchedTensor, 5> tensor_inputs;
+  SmallVector<int64_t, 5> tensor_pos;
+  int64_t batch_size;
+
+  find_and_unpack_tensors(
+      stack, num_arguments, cur_level,
+      &tensor_inputs, &tensor_pos, &batch_size);
+
+  optional<bool> is_no_batch_dim_case;
+
+  for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) {
+    const auto& value = std::get<0>(tensor_inputs[tensor_idx]);
+    auto bdim = std::get<1>(tensor_inputs[tensor_idx]);
+    const auto logical_rank = rankWithoutBatchDim(value, bdim);
+
+    if (!is_no_batch_dim_case.has_value()) {
+      is_no_batch_dim_case = (logical_rank == feature_rank);
+    }
+    auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size);
+    if (!bdim.has_value()) {
+      bdim = 0;
+    }
+    if (*is_no_batch_dim_case) {
+      TORCH_INTERNAL_ASSERT(logical_rank == feature_rank);
+      (*stack)[args_begin + tensor_pos[tensor_idx]] = moveBatchDimToFront(value_, bdim);
+      continue;
+    }
+    TORCH_INTERNAL_ASSERT(logical_rank == feature_rank + 1);
+    (*stack)[args_begin + tensor_pos[tensor_idx]] = reshape_dim_into(*bdim, 0, value_);
+  }
+
+  op.callBoxed(stack);
+
+  for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) {
+    const auto& ret = (*stack)[idx];
+    TORCH_INTERNAL_ASSERT(ret.isTensor(),
+        "This boxed batching rule does not currently support ops that return non-tensor values");
+    if (*is_no_batch_dim_case) {
+      (*stack)[idx] = makeBatched(ret.toTensor(), 0, cur_level);
+    } else {
+      (*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level);
+    }
+  }
+}
+
+// Useful for many NN operators.
+// The operator must satisfy the following:
+// - All arguments must accept an optional batch dim.
+// - All arguments must be the same rank
+#define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED(feature_rank, op) \
+  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_all_tensors_have_optional_bdim<feature_rank>>());
+
template <typename A, A a, typename C>
struct ExistingBdimBatchRuleHelper;

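Usage note (not part of the patch): the new macro is meant to be invoked from a batching-rule registration block, with feature_rank set to the logical rank an input has when it carries no example/batch dimension of its own. Inputs of that rank get the vmap batch dim materialized and moved to dim 0; inputs of rank feature_rank + 1 have the vmap batch dim folded into dim 0, and outputs are reshaped back and re-wrapped as batched tensors. A minimal hypothetical sketch follows; the op name example_nn_op, the feature rank of 3, and the FT_BATCHED_KEY dispatch-key macro are illustrative assumptions, not taken from this diff.

TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  // Hypothetical op whose unbatched input is rank 3 (e.g. C x H x W);
  // under vmap, a rank-4 input is treated as already having a batch dim.
  ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED(3, example_nn_op);
}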
@@ -304,5 +369,29 @@ Tensor& unary_inplace_batch_rule(Tensor& self, optional<int64_t>, ExtraArgs... e
  return self;
}

+inline int64_t get_bdim_size3(
+    const Tensor& a_value, optional<int64_t> a_bdim,
+    const Tensor& b_value, optional<int64_t> b_bdim,
+    const Tensor& c_value, optional<int64_t> c_bdim) {
+  if (a_bdim)
+    return a_value.size(*a_bdim);
+  if (b_bdim)
+    return b_value.size(*b_bdim);
+  if (c_bdim)
+    return c_value.size(*c_bdim);
+  TORCH_INTERNAL_ASSERT(false);
+}
+
+inline int64_t get_bdim_size2(
+    const Tensor& a_value, optional<int64_t> a_bdim,
+    const Tensor& b_value, optional<int64_t> b_bdim) {
+  if (a_bdim)
+    return a_value.size(*a_bdim);
+  if (b_bdim)
+    return b_value.size(*b_bdim);
+  TORCH_INTERNAL_ASSERT(false);
+}
+
+
}}
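For context (not part of the patch): get_bdim_size2/get_bdim_size3 recover the vmap batch size from whichever argument actually carries a batch dim, which is useful in hand-written batch rules where any subset of the operands may be batched at the current level. A rough sketch of the intended pattern, with an invented rule name and at::add standing in for the real op, and reusing the ensure_has_bdim/moveBatchDimToFront helpers seen above:

std::tuple<Tensor, optional<int64_t>> example_binary_batch_rule(
    const Tensor& a, optional<int64_t> a_bdim,
    const Tensor& b, optional<int64_t> b_bdim) {
  // Batch size comes from whichever operand is batched at this level.
  const auto batch_size = get_bdim_size2(a, a_bdim, b, b_bdim);
  // Materialize a leading batch dim of that size on both operands...
  auto a_ = ensure_has_bdim(moveBatchDimToFront(a, a_bdim), a_bdim.has_value(), batch_size);
  auto b_ = ensure_has_bdim(moveBatchDimToFront(b, b_bdim), b_bdim.has_value(), batch_size);
  // ...and run the underlying op once over the whole batch,
  // reporting the output batch dim as 0.
  return std::make_tuple(at::add(a_, b_), 0);
}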