Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit d1ec060

Browse files
committed
Ensure WithoutTop doesn't get rid of mode metadata
Fixes #129
1 parent d83f371 commit d1ec060

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

functorch/csrc/DynamicLayer.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,25 @@ static DynamicLayer popDynamicLayer() {
103103
return result;
104104
}
105105

106+
static int64_t pushDynamicLayer(DynamicLayer&& dynamic_layer) {
107+
auto& dynamicLayerStack = dynamicLayerStackAccessor();
108+
int64_t layerId = 1 + dynamicLayerStack.size();
109+
TORCH_INTERNAL_ASSERT(layerId == dynamic_layer.layerId());
110+
dynamicLayerStack.emplace_back(dynamic_layer);
111+
112+
if (layerId == 2) {
113+
c10::impl::tls_set_dispatch_key_included(kDynamicLayerFrontModeKey, true);
114+
c10::impl::tls_set_dispatch_key_included(kDynamicLayerBackModeKey, true);
115+
}
116+
117+
return layerId;
118+
}
119+
106120
static int64_t pushDynamicLayer(DispatchKey key, optional<int64_t> batch_size = nullopt) {
107121
auto& dynamicLayerStack = dynamicLayerStackAccessor();
108122
TORCH_INTERNAL_ASSERT(key != DispatchKey::Undefined);
109123
TORCH_INTERNAL_ASSERT(key != DispatchKey::Batched);
124+
110125
auto layerId = 1 + dynamicLayerStack.size();
111126
dynamicLayerStack.emplace_back(key, layerId, batch_size);
112127

@@ -356,7 +371,7 @@ struct WithoutTop {
356371
WithoutTop(): layer_(popDynamicLayer()) {
357372
}
358373
~WithoutTop() {
359-
pushDynamicLayer(layer_.key());
374+
pushDynamicLayer(std::move(layer_));
360375
}
361376

362377
bool prev_grad_enabled_;

0 commit comments

Comments
 (0)