Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 671d28c

Browse files
committed
Fix one composability problem
1 parent 8eb6af6 commit 671d28c

File tree

2 files changed

+34
-19
lines changed

2 files changed

+34
-19
lines changed

functorch/csrc/BatchRulesFactory.cpp

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,40 @@ std::tuple<Tensor,optional<int64_t>> _new_zeros_with_same_feature_meta_batch_rul
4747
return std::make_tuple(result, 0);
4848
}
4949

50+
// Batch rule for aten::randn_like under vmap.
// The sampled output mirrors the input tensor's full (batched) layout, so the
// incoming batch-dim index is forwarded to the output unchanged.
std::tuple<Tensor,optional<int64_t>> randn_like_batch_rule(
    const Tensor& self, optional<int64_t> self_bdim,
    c10::optional<ScalarType> dtype,
    c10::optional<Layout> layout,
    c10::optional<Device> device,
    c10::optional<bool> pin_memory,
    c10::optional<c10::MemoryFormat> optional_memory_format) {
  // Temporarily exclude the vmap-mode (random) key so the RNG op dispatches
  // straight through instead of being intercepted.
  c10::impl::ExcludeDispatchKeyGuard vmap_mode_guard(kVmapModeKey);
  auto sampled = at::randn_like(
      self, dtype, layout, device, pin_memory, optional_memory_format);
  return std::make_tuple(std::move(sampled), self_bdim);
}
63+
64+
// Batch rule for aten::rand_like under vmap.
// The sampled output mirrors the input tensor's full (batched) layout, so the
// incoming batch-dim index is forwarded to the output unchanged.
std::tuple<Tensor,optional<int64_t>> rand_like_batch_rule(
    const Tensor& self, optional<int64_t> self_bdim,
    c10::optional<ScalarType> dtype,
    c10::optional<Layout> layout,
    c10::optional<Device> device,
    c10::optional<bool> pin_memory,
    c10::optional<c10::MemoryFormat> optional_memory_format) {
  // Temporarily exclude the vmap-mode (random) key so the RNG op dispatches
  // straight through instead of being intercepted.
  c10::impl::ExcludeDispatchKeyGuard vmap_mode_guard(kVmapModeKey);
  auto sampled = at::rand_like(
      self, dtype, layout, device, pin_memory, optional_memory_format);
  return std::make_tuple(std::move(sampled), self_bdim);
}
77+
5078
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
5179
VMAP_SUPPORT("ones_like", BASIC_UNARY_BATCH_RULE(ATEN_FN(ones_like)));
5280
VMAP_SUPPORT("zeros_like", BASIC_UNARY_BATCH_RULE(ATEN_FN(zeros_like)));
5381
VMAP_SUPPORT("empty_like", BASIC_UNARY_BATCH_RULE(ATEN_FN(empty_like)));
54-
VMAP_SUPPORT("randn_like", BASIC_UNARY_BATCH_RULE(ATEN_FN(randn_like)));
55-
VMAP_SUPPORT("rand_like", BASIC_UNARY_BATCH_RULE(ATEN_FN(rand_like)));
82+
VMAP_SUPPORT("randn_like", randn_like_batch_rule);
83+
VMAP_SUPPORT("rand_like", rand_like_batch_rule);
5684
VMAP_SUPPORT("full_like", BASIC_UNARY_BATCH_RULE(ATEN_FN(full_like)));
5785
VMAP_SUPPORT("new_empty", NEW_BLAH_BATCH_RULE(ATEN_FN(new_empty)));
5886
VMAP_SUPPORT("new_zeros", NEW_BLAH_BATCH_RULE(ATEN_FN(new_zeros)));

functorch/csrc/DynamicLayer.cpp

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -525,18 +525,6 @@ struct WithoutTop {
525525
DynamicLayer layer_;
526526
};
527527

528-
struct SaveLocalDispatchKeySet {
529-
public:
530-
SaveLocalDispatchKeySet() :
531-
saved_keyset_(c10::impl::tls_local_dispatch_key_set()) {}
532-
~SaveLocalDispatchKeySet() {
533-
c10::impl::_force_tls_local_dispatch_key_set(saved_keyset_);
534-
}
535-
536-
private:
537-
c10::impl::LocalDispatchKeySet saved_keyset_;
538-
};
539-
540528
void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
541529
auto cur_level = getDynamicLayerStack().back().layerId();
542530
auto cur_key = getDynamicLayerStack().back().key();
@@ -600,11 +588,10 @@ void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack*
600588
WithoutTop guard;
601589

602590
// "reset exclude set"
603-
// TODO: Still a problem with composability and AutoNonVariableTypeGuard.
604-
// Users cannot do torch.no_grad otherwise there will be problems.
605-
SaveLocalDispatchKeySet save_guard;
606-
auto keyset = c10::impl::PODLocalDispatchKeySet();
607-
c10::impl::_force_tls_local_dispatch_key_set(keyset);
591+
auto local_keyset = c10::impl::tls_local_dispatch_key_set();
592+
local_keyset.included_ = local_keyset.included_ - (local_keyset.included_ & all_dynlayer_keyset.add(kVmapModeKey));
593+
local_keyset.excluded_ = local_keyset.excluded_ - (local_keyset.excluded_ & all_dynlayer_keyset.add(kVmapModeKey));
594+
ForceLocalDispatchKeySet save_guard(local_keyset);
608595
setDynamicLayerFrontBackKeysIncluded(true);
609596

610597
#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE

0 commit comments

Comments
 (0)