
Commit e82e64a

Present Random state (#887)
* Present Random state
* Add tests
1 parent a3137b6 commit e82e64a
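
What the change does (as read from the diff below): functorch's AOT compilation traces the user function once at compile time, and that tracing run consumes random numbers; the new preserve_rng_state context manager snapshots the CPU RNG state (and the CUDA RNG state when CUDA is available) before tracing and restores it afterwards, so the compiled call sees the same random stream as an eager call from the same seed. A minimal sketch of the intended behavior, restating the context manager as added in functorch/_src/aot_autograd.py (the short usage check at the end is my own illustration):

from contextlib import contextmanager

import torch


@contextmanager
def preserve_rng_state():
    # Snapshot the global CPU RNG state (and the CUDA state, if available)...
    rng_state = torch.clone(torch.random.get_rng_state())
    if torch.cuda.is_available():
        cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
    try:
        yield
    finally:
        # ...and restore it regardless of what ran inside the block.
        torch.random.set_rng_state(rng_state)
        if torch.cuda.is_available():
            torch.cuda.set_rng_state(cuda_rng_state)


torch.manual_seed(0)
with preserve_rng_state():
    torch.rand(4)          # draws random numbers inside the guarded block
after = torch.rand(4)      # the state was restored, so this draw matches...

torch.manual_seed(0)
reference = torch.rand(4)  # ...the first draw made right after the same seed
assert torch.allclose(after, reference)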

2 files changed: +53, -19 lines


functorch/_src/aot_autograd.py

Lines changed: 35 additions & 19 deletions
@@ -1,3 +1,4 @@
+from contextlib import contextmanager
 import torch
 import torch.nn as nn
 from torch import Tensor
@@ -52,6 +53,19 @@ def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
 aten = torch.ops.aten


+@contextmanager
+def preserve_rng_state():
+    rng_state = torch.clone(torch.random.get_rng_state())
+    if torch.cuda.is_available():
+        cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
+    try:
+        yield
+    finally:
+        torch.random.set_rng_state(rng_state)
+        if torch.cuda.is_available():
+            torch.cuda.set_rng_state(cuda_rng_state)
+
+
 def create_joint_forward_backward(fn):
     def joint_forward_backward(
         primals: List[Any], tangents: List[Any]
@@ -147,27 +161,29 @@ class CompiledFunction(torch.autograd.Function):
         def forward(ctx, *flat_tensor_args):
             nonlocal compiled_fw, compiled_bw, num_outs
             if compiled_fw is None:
-                # Set input tensors that require grad to leaves
-                flat_tensor_args = pytree.tree_map(
-                    lambda x: x.detach().requires_grad_(x.requires_grad), flat_tensor_args
-                )
-                with torch.set_grad_enabled(grad_state):
-                    out = flat_fn(*flat_tensor_args)
-                out = pytree.tree_map(
-                    lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x, out
-                )
+                with preserve_rng_state():
+                    # Set input tensors that require grad to leaves
+                    flat_tensor_args = pytree.tree_map(
+                        lambda x: x.detach().requires_grad_(x.requires_grad), flat_tensor_args
+                    )
+                    with torch.set_grad_enabled(grad_state):
+                        out = flat_fn(*flat_tensor_args)
+                    out = pytree.tree_map(
+                        lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x, out
+                    )

-                if isinstance(out, (list, tuple)):
-                    num_outs = len(out)
-                else:
-                    num_outs = 1
+                    if isinstance(out, (list, tuple)):
+                        num_outs = len(out)
+                    else:
+                        num_outs = 1
+
+                    joint_inputs = (flat_tensor_args, out)
+                    aot_decompositions = {**aot_autograd_decompositions, **decompositions}
+                    with torch.set_grad_enabled(grad_state):
+                        fx_g = make_fx(joint_forward_backward, aot_decompositions)(
+                            *joint_inputs
+                        )

-                joint_inputs = (flat_tensor_args, out)
-                aot_decompositions = {**aot_autograd_decompositions, **decompositions}
-                with torch.set_grad_enabled(grad_state):
-                    fx_g = make_fx(joint_forward_backward, aot_decompositions)(
-                        *joint_inputs
-                    )

             fw_module, bw_module = partition_fn(fx_g, joint_inputs)
             # print(fw_module.code, bw_module.code)

test/test_pythonkey.py

Lines changed: 18 additions & 0 deletions
@@ -546,6 +546,24 @@ def forward(self, x, y):
         assert torch.allclose(inputs[1].grad, cloned_inputs[1].grad)


+class TestRandom(TestCase):
+    def test_preserve_random(self):
+        def fn(x):
+            return torch.nn.functional.dropout(x, 0.5) + x
+
+
+        x = torch.randn(4)
+
+        torch.manual_seed(0)
+        ref = fn(x)
+
+        torch.manual_seed(0)
+        aot_fn = aot_function(fn, nop)
+        res = aot_fn(x)
+
+        assert torch.allclose(ref, res)
+
+
 only_for = ("cpu")
 instantiate_device_type_tests(
     TestPythonKey,
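
To run just the new test locally (a sketch; it assumes pytest is installed and is launched from the repository root):

import pytest

# Select only the new TestRandom.test_preserve_random case.
pytest.main(["test/test_pythonkey.py", "-k", "test_preserve_random"])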
