Skip to content

Commit 056ff1f

Browse files
authored
"set_inplace_requires_grad_allowed" should be a context manager (#870)
Test Plan: - run existing tests; code reading
1 parent 5d9e50b commit 056ff1f

File tree

4 files changed

+19
-5
lines changed

4 files changed

+19
-5
lines changed

functorch/_src/eager_transforms.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,20 +33,27 @@
3333
_assert_wrapped_functional,
3434
_propagate_functional_input_mutation,
3535
set_inplace_requires_grad_allowed,
36+
get_inplace_requires_grad_allowed,
3637
)
3738

3839
argnums_t = Union[int, Tuple[int, ...]]
3940

4041

42+
@contextlib.contextmanager
43+
def enable_inplace_requires_grad(enabled=True):
44+
prev_state = get_inplace_requires_grad_allowed()
45+
set_inplace_requires_grad_allowed(enabled)
46+
try:
47+
yield
48+
finally:
49+
set_inplace_requires_grad_allowed(prev_state)
50+
51+
4152
def _create_differentiable(inps, level=None):
4253
def create_differentiable(x):
4354
if isinstance(x, torch.Tensor):
44-
try:
45-
set_inplace_requires_grad_allowed(True)
55+
with enable_inplace_requires_grad():
4656
return x.requires_grad_()
47-
finally:
48-
set_inplace_requires_grad_allowed(False)
49-
5057
raise ValueError(f'Thing passed to transform API must be Tensor, '
5158
f'got {type(x)}')
5259
return tree_map(create_differentiable, inps)

functorch/csrc/DynamicLayer.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,11 @@ void setInplaceRequiresGradAllowed(bool allowed) {
140140
functorch_tls->allow_inplace_requires_grad_ = allowed;
141141
}
142142

143+
bool getInplaceRequiresGradAllowed() {
144+
auto* functorch_tls = getRawFunctorchTLS();
145+
return functorch_tls->allow_inplace_requires_grad_;
146+
}
147+
143148

144149
static std::vector<DynamicLayer>& dynamicLayerStackAccessor() {
145150
return getRawFunctorchTLS()->dynamicLayerStack;

functorch/csrc/DynamicLayer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer);
8686
std::ostream& operator<<(std::ostream& os, const std::vector<DynamicLayer>& dynamicLayerStack);
8787

8888
void setInplaceRequiresGradAllowed(bool allowed);
89+
bool getInplaceRequiresGradAllowed();
8990

9091

9192
}

functorch/csrc/init.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
380380
m.def("_set_vmap_fallback_enabled", &at::functorch::setVmapFallbackEnabled);
381381
m.def("_is_vmap_fallback_enabled", &at::functorch::isVmapFallbackEnabled);
382382
m.def("set_inplace_requires_grad_allowed", &at::functorch::setInplaceRequiresGradAllowed);
383+
m.def("get_inplace_requires_grad_allowed", &at::functorch::getInplaceRequiresGradAllowed);
383384
m.def("dlevel", &at::functorch::dlevel, "dlevel");
384385
m.def("dump_tensor", &at::functorch::dump_tensor, "dump_tensor");
385386
m.def("reshape_dim_into", &at::functorch::reshape_dim_into);

0 commit comments

Comments
 (0)