Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 754ee26

Browse files
authored
Fix expand.SymInt issues by adding a decomp (#903)
To fully support expand.SymInt, we should write a batching rule for it. Unfortunately, we are unable to write a batching rule because there are not enough SymInt overloads (e.g. we need a view.SymInt, that doesn't exist today). This PR adds a decomp for expand.SymInt as a short-term fix.
1 parent 091d999 commit 754ee26

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

functorch/csrc/BatchRulesViews.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,14 @@ std::tuple<Tensor, optional<int64_t>> diag_embed_batch_rule(const Tensor& self,
502502
return std::make_tuple(at::diag_embed(self_, offset, dim1, dim2), 0);
503503
}
504504

505+
// We need to write a real batching rule to fully support symint.
506+
// This requires symint variants of other operations, like `view`,
507+
// which don't exist yet.
508+
Tensor expand_symint_decomp_hack(const Tensor& self, SymIntArrayRef packed_size, bool implicit) {
509+
auto size = asIntArrayRefSlow(packed_size);
510+
return self.expand(size, implicit);
511+
}
512+
505513
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
506514
VMAP_SUPPORT(diag, diag_batch_rule);
507515
VMAP_SUPPORT(chunk, chunk_batching_rule);
@@ -532,6 +540,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
532540
VMAP_SUPPORT2(slice, Tensor, slice_batch_rule);
533541
VMAP_SUPPORT2(transpose, int, transpose_int_batch_rule);
534542
VMAP_SUPPORT(diag_embed, diag_embed_batch_rule);
543+
m.impl("expand.SymInt", expand_symint_decomp_hack);
535544
}
536545

537546
}}

0 commit comments

Comments
 (0)