
Commit 665a2fd

Fix to batching rule (#685)
The to(self, other) batching rule doesn't actually handle the case where self is not Batched and other is Batched.
1 parent: f80a466 · commit: 665a2fd

File tree

2 files changed: +9, −9 lines
functorch/csrc/BatchRulesUnaryOps.cpp

Lines changed: 9 additions & 0 deletions
@@ -71,6 +71,14 @@ view_as_complex_batch_rule(const Tensor& self, optional<int64_t> self_bdim) {
   return std::make_tuple(result, 0);
 }
 
+std::tuple<Tensor,optional<int64_t>>
+to_other_batch_rule(const Tensor& self, optional<int64_t> self_bdim,
+                    const Tensor& other, optional<int64_t> other_bdim,
+                    bool non_blocking,
+                    bool copy, c10::optional<at::MemoryFormat> memory_format) {
+  return std::make_tuple(self.to(other, non_blocking, copy, memory_format), self_bdim);
+}
+
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
 
 #define UNARY_POINTWISE_ALL2(op, overload) \
@@ -89,6 +97,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   VMAP_SUPPORT2(to, device, BASIC_UNARY_BATCH_RULE(ATEN_FN2(to, device)));
   VMAP_SUPPORT2(to, dtype, BASIC_UNARY_BATCH_RULE(ATEN_FN2(to, dtype)));
   VMAP_SUPPORT2(to, dtype_layout, BASIC_UNARY_BATCH_RULE(ATEN_FN2(to, dtype_layout)));
+  VMAP_SUPPORT2(to, other, to_other_batch_rule);
 
   UNARY_POINTWISE(_to_copy);
   UNARY_POINTWISE(alias);

functorch/csrc/BatchingRegistrations.cpp

Lines changed: 0 additions & 9 deletions
@@ -672,15 +672,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   m.impl("squeeze_.dim", squeeze_dim__batching_rule);
   m.impl("unsqueeze_", unsqueeze__batching_rule);
 
-  // still legacy b/c this op is weird
-  #define TO_BATCHING_RULE(name, ...) \
-    { \
-      using to_type = Tensor(Tensor::*)(__VA_ARGS__) const; \
-      m.impl(name, unwrap_and_call_method< \
-          to_type, &Tensor::to, __VA_ARGS__>);\
-    }
-  TO_BATCHING_RULE("to.other", const Tensor&, bool, bool, optional<MemoryFormat>)
-
   // still legacy because these are ridiculously complicated
   m.impl("as_strided", as_strided_batching_rule);
   m.impl("new_empty_strided", new_empty_strided_batching_rule);
