@@ -1,3 +1,4 @@
+from contextlib import contextmanager
 import torch
 import torch.nn as nn
 from torch import Tensor
@@ -52,6 +53,19 @@ def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
 aten = torch.ops.aten
 
 
+@contextmanager
+def preserve_rng_state():
+    rng_state = torch.clone(torch.random.get_rng_state())
+    if torch.cuda.is_available():
+        cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
+    try:
+        yield
+    finally:
+        torch.random.set_rng_state(rng_state)
+        if torch.cuda.is_available():
+            torch.cuda.set_rng_state(cuda_rng_state)
+
+
 def create_joint_forward_backward(fn):
     def joint_forward_backward(
         primals: List[Any], tangents: List[Any]
@@ -147,27 +161,29 @@ class CompiledFunction(torch.autograd.Function):
         def forward(ctx, *flat_tensor_args):
             nonlocal compiled_fw, compiled_bw, num_outs
             if compiled_fw is None:
-                # Set input tensors that require grad to leaves
-                flat_tensor_args = pytree.tree_map(
-                    lambda x: x.detach().requires_grad_(x.requires_grad), flat_tensor_args
-                )
-                with torch.set_grad_enabled(grad_state):
-                    out = flat_fn(*flat_tensor_args)
-                out = pytree.tree_map(
-                    lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x, out
-                )
+                with preserve_rng_state():
+                    # Set input tensors that require grad to leaves
+                    flat_tensor_args = pytree.tree_map(
+                        lambda x: x.detach().requires_grad_(x.requires_grad), flat_tensor_args
+                    )
+                    with torch.set_grad_enabled(grad_state):
+                        out = flat_fn(*flat_tensor_args)
+                    out = pytree.tree_map(
+                        lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x, out
+                    )
 
-                if isinstance(out, (list, tuple)):
-                    num_outs = len(out)
-                else:
-                    num_outs = 1
+                    if isinstance(out, (list, tuple)):
+                        num_outs = len(out)
+                    else:
+                        num_outs = 1
+
+                    joint_inputs = (flat_tensor_args, out)
+                    aot_decompositions = {**aot_autograd_decompositions, **decompositions}
+                    with torch.set_grad_enabled(grad_state):
+                        fx_g = make_fx(joint_forward_backward, aot_decompositions)(
+                            *joint_inputs
+                        )
 
-                joint_inputs = (flat_tensor_args, out)
-                aot_decompositions = {**aot_autograd_decompositions, **decompositions}
-                with torch.set_grad_enabled(grad_state):
-                    fx_g = make_fx(joint_forward_backward, aot_decompositions)(
-                        *joint_inputs
-                    )
             fw_module, bw_module = partition_fn(fx_g, joint_inputs)
             # print(fw_module.code, bw_module.code)
 
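
As context for the new helper, here is a minimal, self-contained sketch (not part of the commit) that mirrors preserve_rng_state and checks that the global CPU RNG position is unchanged after code inside the block consumes random numbers; the CUDA branch only runs when a GPU is available.

from contextlib import contextmanager

import torch


@contextmanager
def preserve_rng_state():
    # Mirror of the helper added in this commit: snapshot the global RNG
    # state(s) on entry and restore them on exit.
    rng_state = torch.clone(torch.random.get_rng_state())
    if torch.cuda.is_available():
        cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
    try:
        yield
    finally:
        torch.random.set_rng_state(rng_state)
        if torch.cuda.is_available():
            torch.cuda.set_rng_state(cuda_rng_state)


torch.manual_seed(0)
expected = torch.rand(3)      # the numbers eager code would draw next

torch.manual_seed(0)
with preserve_rng_state():
    _ = torch.rand(100)       # consume RNG inside the block
restored = torch.rand(3)      # RNG position was restored, so this matches

assert torch.equal(expected, restored)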
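The forward path is wrapped because tracing the joint forward-backward on the first call runs the user function, and ops like dropout advance the global RNG; restoring the state presumably keeps the RNG stream seen by real calls identical to plain eager execution. Below is a hedged sketch of that effect, reusing the preserve_rng_state sketch above; the warm-up call merely stands in for AOT Autograd's trace-at-first-call step and is not the library's actual compile path.

import torch
import torch.nn as nn

# Assumes `preserve_rng_state` from the sketch above is in scope.
net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.randn(2, 4)

torch.manual_seed(0)
reference = net(x)              # what a plain eager call produces

torch.manual_seed(0)
with preserve_rng_state():
    _ = net(x)                  # stand-in for the tracing run: dropout consumes RNG
first_real_call = net(x)        # sees the same RNG stream as `reference`

assert torch.allclose(reference, first_real_call)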