Skip to content

Commit 7f6dbe9

Browse files
authored
Clean up grad state handling (#681)
1 parent 6a52d17 commit 7f6dbe9

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

functorch/csrc/DynamicLayer.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -674,20 +674,20 @@ void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack*
674674
}
675675
#endif
676676

677-
// Re-dispatch
677+
// See NOTE [grad and vjp interaction with no_grad]
678+
optional<c10::AutoGradMode> grad_guard;
678679
if (cur_key == DispatchKey::Autograd && prev_grad_mode.has_value() && *prev_grad_mode == false) {
679-
// See NOTE [grad and vjp interaction with no_grad]
680-
c10::AutoGradMode guard(*prev_grad_mode);
681-
op.callBoxed(stack);
680+
grad_guard.emplace(*prev_grad_mode);
682681
}
683-
else if (cur_key == DispatchKey::Autograd &&
684-
prev_fwd_grad_mode.has_value() && prev_fwd_grad_mode.value() == false) {
685-
c10::AutoFwGradMode guard(*prev_fwd_grad_mode);
686-
op.callBoxed(stack);
687-
} else {
688-
op.callBoxed(stack);
682+
optional<c10::AutoFwGradMode> fw_grad_guard;
683+
if (cur_key == DispatchKey::Autograd &&
684+
prev_fwd_grad_mode.has_value() && prev_fwd_grad_mode.value() == false) {
685+
fw_grad_guard.emplace(*prev_fwd_grad_mode);
689686
}
690687

688+
// Re-dispatch
689+
op.callBoxed(stack);
690+
691691
// Step 4, 5, 6
692692
auto ret_size = op.schema().returns().size();
693693
if (cur_key == DispatchKey::Autograd) {

0 commit comments

Comments
 (0)