Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 39f0906

Browse files
committed
Revert "Fix one composability problem"
This reverts commit 671d28c.
1 parent 671d28c commit 39f0906

File tree

2 files changed

+19
-34
lines changed

2 files changed

+19
-34
lines changed

functorch/csrc/BatchRulesFactory.cpp

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -47,40 +47,12 @@ std::tuple<Tensor,optional<int64_t>> _new_zeros_with_same_feature_meta_batch_rul
4747
return std::make_tuple(result, 0);
4848
}
4949

50-
std::tuple<Tensor,optional<int64_t>> randn_like_batch_rule(
51-
const Tensor& self, optional<int64_t> self_bdim,
52-
c10::optional<ScalarType> dtype,
53-
c10::optional<Layout> layout,
54-
c10::optional<Device> device,
55-
c10::optional<bool> pin_memory,
56-
c10::optional<c10::MemoryFormat> optional_memory_format) {
57-
// Disable the random key
58-
c10::impl::ExcludeDispatchKeyGuard guard(kVmapModeKey);
59-
return std::make_tuple(
60-
at::randn_like(self, dtype, layout, device, pin_memory, optional_memory_format),
61-
self_bdim);
62-
}
63-
64-
std::tuple<Tensor,optional<int64_t>> rand_like_batch_rule(
65-
const Tensor& self, optional<int64_t> self_bdim,
66-
c10::optional<ScalarType> dtype,
67-
c10::optional<Layout> layout,
68-
c10::optional<Device> device,
69-
c10::optional<bool> pin_memory,
70-
c10::optional<c10::MemoryFormat> optional_memory_format) {
71-
// Disable the random key
72-
c10::impl::ExcludeDispatchKeyGuard guard(kVmapModeKey);
73-
return std::make_tuple(
74-
at::rand_like(self, dtype, layout, device, pin_memory, optional_memory_format),
75-
self_bdim);
76-
}
77-
7850
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
7951
VMAP_SUPPORT("ones_like", BASIC_UNARY_BATCH_RULE(ATEN_FN(ones_like)));
8052
VMAP_SUPPORT("zeros_like", BASIC_UNARY_BATCH_RULE(ATEN_FN(zeros_like)));
8153
VMAP_SUPPORT("empty_like", BASIC_UNARY_BATCH_RULE(ATEN_FN(empty_like)));
82-
VMAP_SUPPORT("randn_like", randn_like_batch_rule);
83-
VMAP_SUPPORT("rand_like", rand_like_batch_rule);
54+
VMAP_SUPPORT("randn_like", BASIC_UNARY_BATCH_RULE(ATEN_FN(randn_like)));
55+
VMAP_SUPPORT("rand_like", BASIC_UNARY_BATCH_RULE(ATEN_FN(rand_like)));
8456
VMAP_SUPPORT("full_like", BASIC_UNARY_BATCH_RULE(ATEN_FN(full_like)));
8557
VMAP_SUPPORT("new_empty", NEW_BLAH_BATCH_RULE(ATEN_FN(new_empty)));
8658
VMAP_SUPPORT("new_zeros", NEW_BLAH_BATCH_RULE(ATEN_FN(new_zeros)));

functorch/csrc/DynamicLayer.cpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,18 @@ struct WithoutTop {
525525
DynamicLayer layer_;
526526
};
527527

528+
struct SaveLocalDispatchKeySet {
529+
public:
530+
SaveLocalDispatchKeySet() :
531+
saved_keyset_(c10::impl::tls_local_dispatch_key_set()) {}
532+
~SaveLocalDispatchKeySet() {
533+
c10::impl::_force_tls_local_dispatch_key_set(saved_keyset_);
534+
}
535+
536+
private:
537+
c10::impl::LocalDispatchKeySet saved_keyset_;
538+
};
539+
528540
void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
529541
auto cur_level = getDynamicLayerStack().back().layerId();
530542
auto cur_key = getDynamicLayerStack().back().key();
@@ -588,10 +600,11 @@ void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack*
588600
WithoutTop guard;
589601

590602
// "reset exclude set"
591-
auto local_keyset = c10::impl::tls_local_dispatch_key_set();
592-
local_keyset.included_ = local_keyset.included_ - (local_keyset.included_ & all_dynlayer_keyset.add(kVmapModeKey));
593-
local_keyset.excluded_ = local_keyset.excluded_ - (local_keyset.excluded_ & all_dynlayer_keyset.add(kVmapModeKey));
594-
ForceLocalDispatchKeySet save_guard(local_keyset);
603+
// TODO: Still a problem with composabiilty and AutoNonVariableTypeGuard.
604+
// Users cannot do torch.no_grad otherwise there will be problems.
605+
SaveLocalDispatchKeySet save_guard;
606+
auto keyset = c10::impl::PODLocalDispatchKeySet();
607+
c10::impl::_force_tls_local_dispatch_key_set(keyset);
595608
setDynamicLayerFrontBackKeysIncluded(true);
596609

597610
#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE

0 commit comments

Comments
 (0)