@@ -32,6 +32,19 @@ void vmapIncompatibleInplaceError(const char* schema_name);
 
 Tensor maybePadToLogicalRank(const Tensor& tensor, optional<int64_t> has_bdim, int64_t logical_rank);
 
+inline Tensor ensure_has_bdim(const Tensor& tensor, bool has_bdim, int64_t batch_size) {
+  if (has_bdim) {
+    return tensor;
+  }
+  const auto sizes = tensor.sizes();
+  DimVector expanded_shape;
+  expanded_shape.reserve(sizes.size());
+  expanded_shape.emplace_back(batch_size);
+  expanded_shape.insert(expanded_shape.end(), sizes.begin(), sizes.end());
+  return tensor.expand(expanded_shape);
+}
+
+
 #define VMAP_SUPPORT(op, batch_rule) \
   m.impl(op, PrimBatchRule7< \
       decltype(&batch_rule), &batch_rule, to_operator_t<decltype(batch_rule)> \
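For orientation, a minimal sketch of what the new ensure_has_bdim helper does, with made-up shapes (a tensor that already carries a batch dim is returned as-is; otherwise a leading batch dim is added as an expanded, non-copying view):

  // illustrative values only
  Tensor t = at::randn({3, 4});
  Tensor b = ensure_has_bdim(t, /*has_bdim=*/false, /*batch_size=*/5);
  // b.sizes() == [5, 3, 4]; expand() broadcasts the new dim without copying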
@@ -166,7 +179,8 @@ inline void handle_variadic_bdims(std::vector<std::pair<Tensor, optional<int64_t
 #define VARIADIC_BDIMS_BOXED(op) \
   m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_variadic_bdims), &handle_variadic_bdims>>());
 
-inline void boxed_existing_bdim_batch_rule(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+inline void boxed_existing_bdim_all_batch_rule(
+    const c10::OperatorHandle& op, torch::jit::Stack* stack) {
   const auto& schema = op.schema();
   const auto num_returns = schema.returns().size();
   const auto num_arguments = schema.arguments().size();
@@ -177,19 +191,101 @@ inline void boxed_existing_bdim_batch_rule(const c10::OperatorHandle& op, torch:
   TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
   int64_t cur_level = maybe_layer->layerId();
 
+  std::vector<std::pair<Tensor, optional<int64_t>>> tensor_inputs;
+  std::vector<int64_t> tensor_pos;
+  for (const auto idx : c10::irange(0, num_arguments)) {
+    const auto& ivalue = arguments[idx];
+    if (!ivalue.isTensor()) {
+      continue;
+    }
+    Tensor tensor_value;
+    optional<int64_t> tensor_bdim;
+    std::tie(tensor_value, tensor_bdim) = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
+    tensor_inputs.push_back(std::make_pair(tensor_value, tensor_bdim));
+    tensor_pos.push_back(idx);
+  }
+
+  // compute batch size...
+  int64_t batch_size = -1;
+  for (const auto& tensor_input : tensor_inputs) {
+    const auto& value = tensor_input.first;
+    const auto& bdim = tensor_input.second;
+    if (!bdim) {
+      continue;
+    }
+    if (batch_size == -1) {
+      batch_size = value.size(*bdim);
+    }
+    TORCH_INTERNAL_ASSERT(batch_size == value.size(*bdim));
+  }
+
+  // for each tensor, ensure it has a bdim and reshape it.
+  for (auto& tensor_input : tensor_inputs) {
+    auto value = tensor_input.first;
+    auto bdim = tensor_input.second;
+    value = ensure_has_bdim(value, bdim.has_value(), batch_size);
+    if (!bdim.has_value()) {
+      bdim = 0;
+    }
+    tensor_input.first = reshape_dim_into(*bdim, 0, value);
+  }
+
+  size_t tensor_idx = 0;
+  TORCH_INTERNAL_ASSERT(tensor_pos.size() > 0);
+  for (const auto arg_idx : c10::irange(0, num_arguments)) {
+    if (tensor_idx >= tensor_pos.size() || (int64_t)arg_idx != tensor_pos[tensor_idx]) {
+      torch::jit::push(stack, arguments[arg_idx]);
+    } else {
+      TORCH_INTERNAL_ASSERT(tensor_idx < tensor_inputs.size());
+      torch::jit::push(stack, tensor_inputs[tensor_idx].first);
+      tensor_idx++;
+    }
+  }
+
+  op.callBoxed(stack);
+  const auto returns = torch::jit::pop(*stack, num_returns);
+  for (const auto& ret : returns) {
+    if (ret.isTensor()) {
+      torch::jit::push(stack, makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level));
+    } else {
+      TORCH_INTERNAL_ASSERT(false, "This boxed batching rule does not currently support ops that return non-tensor values");
+    }
+  }
+}
+
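For orientation, a rough shape walk-through of what boxed_existing_bdim_all_batch_rule (above) does, with made-up sizes; this assumes reshape_dim_into and reshape_dim_outof fold a dimension into, and split it back out of, dim 0, as their names suggest:

  // per-example input   : [3, 4]       dim 0 is the op's own batch-like dim
  // vmapped input       : [5, 3, 4]    bdim = 0, batch_size = 5
  // reshape_dim_into    : [5, 3, 4] -> [15, 4]   vmap batch folded into dim 0
  // op.callBoxed sees   : [15, 4]
  // reshape_dim_outof   : [15, 4] -> [5, 3, 4]   batch split back out of dim 0
  // makeBatched         : wraps the result with bdim 0 at cur_level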
+// Use when all tensor arguments accept one (normal) batch dim.
+// This batching rule expands the batch dim on all Tensors, reshapes it into
+// dim 0, calls the op, and then reshapes the batch dim out of dim 0.
+// This is not the most efficient thing; if there are alternatives, please try
+// to use them. Use this only as a last resort.
+#define EXISTING_BDIM_ALL_BOXED(op) \
+  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_existing_bdim_all_batch_rule>());
+
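A registration using EXISTING_BDIM_ALL_BOXED might look roughly like the sketch below; the TORCH_LIBRARY_IMPL block mirrors the registration blocks functorch uses elsewhere (FT_BATCHED_KEY is assumed to be its batched dispatch-key alias), and some_op is only a placeholder name:

  TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
    // any op whose tensor arguments all accept a leading batch dim
    EXISTING_BDIM_ALL_BOXED(some_op);
  }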
+inline void boxed_existing_bdim_batch_rule(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+  const auto& schema = op.schema();
+  const auto num_returns = schema.returns().size();
+  const auto num_arguments = schema.arguments().size();
+  auto arguments = torch::jit::pop(*stack, num_arguments);
+
+  c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+  auto maybe_layer = maybeCurrentDynamicLayer();
+  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
+  int64_t cur_level = maybe_layer->layerId();
 
   std::vector<std::pair<Tensor, optional<int64_t>>> tensor_inputs;
   std::vector<int64_t> tensor_pos;
   for (const auto idx : c10::irange(0, num_arguments)) {
     const auto& ivalue = arguments[idx];
-    if (ivalue.isTensor()) {
-      Tensor tensor_value;
-      optional<int64_t> tensor_bdim;
-      std::tie(tensor_value, tensor_bdim) = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
-      tensor_inputs.push_back(std::make_pair(tensor_value, tensor_bdim));
-      tensor_pos.push_back(idx);
+    if (!ivalue.isTensor()) {
+      continue;
     }
+    Tensor tensor_value;
+    optional<int64_t> tensor_bdim;
+    std::tie(tensor_value, tensor_bdim) = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
+    tensor_inputs.push_back(std::make_pair(tensor_value, tensor_bdim));
+    tensor_pos.push_back(idx);
   }
+
   int64_t batch_size = -1;
   for (auto& tensor_input : tensor_inputs) {
     if (tensor_input.second) {