Skip to content

Commit 2b535dd

Browse files
author
Samantha Andow
authored
(fix CI) Add batch rule for split.sizes (#952)
* add batch rule for split.sizes * switch to doing split.sizes as a decomposition
1 parent 9b96f14 commit 2b535dd

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

functorch/csrc/BatchRulesDecompositions.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
181181
OP_DECOMPOSE(special_multigammaln);
182182
OP_DECOMPOSE(special_polygamma);
183183
OP_DECOMPOSE(special_softmax);
184+
OP_DECOMPOSE2(split, sizes);
184185
OP_DECOMPOSE(square);
185186
OP_DECOMPOSE(numpy_T);
186187
OP_DECOMPOSE(reshape_as);

functorch/csrc/LegacyBatchingRegistrations.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,11 +270,11 @@ std::vector<Tensor> split_batching_rule(const Tensor& self, int64_t split_size,
270270
std::vector<Tensor> split_with_sizes_batching_rule(const Tensor& self, IntArrayRef split_sizes, int64_t dim) {
271271
if (!participatesInCurrentLevel(self)) {
272272
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
273-
return at::split_with_sizes(self, split_sizes, dim);
273+
return split_with_sizes(self, split_sizes, dim);
274274
}
275275
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
276276
auto dim_physical = self_physical.getPhysicalDim(dim);
277-
auto result = at::split_with_sizes(self_physical.tensor(), split_sizes, dim_physical);
277+
auto result = split_with_sizes(self_physical.tensor(), split_sizes, dim_physical);
278278
self_physical.getPhysicalToLogicalMap().applyInplace(result);
279279
return result;
280280
}

0 commit comments

Comments
 (0)