2 files changed: 7 additions and 0 deletions

@@ -147,6 +147,10 @@ class CompiledFunction(torch.autograd.Function):
     def forward(ctx, *flat_tensor_args):
         nonlocal compiled_fw, compiled_bw, num_outs
         if compiled_fw is None:
+            # Set input tensors that require grad to leaves
+            flat_tensor_args = pytree.tree_map(
+                lambda x: x.detach().requires_grad_(x.requires_grad), flat_tensor_args
+            )
             with torch.set_grad_enabled(grad_state):
                 out = flat_fn(*flat_tensor_args)
                 out = pytree.tree_map(
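This hunk makes every traced input a leaf tensor before the forward graph is traced: detach() severs whatever autograd history the caller's tensors carry, and requires_grad_(x.requires_grad) restores the original flag. Presumably this keeps gradients in the compiled graph flowing to the flat inputs themselves rather than into pre-existing history. A minimal sketch of the same trick, assuming only stock PyTorch (the to_leaf helper is hypothetical, not part of the patch):

    import torch

    def to_leaf(x: torch.Tensor) -> torch.Tensor:
        # detach() drops any existing autograd history; requires_grad_()
        # restores the original flag, yielding a fresh leaf tensor.
        return x.detach().requires_grad_(x.requires_grad)

    y = torch.randn(3, requires_grad=True) * 2  # non-leaf: the output of an op
    leaf = to_leaf(y)
    assert not y.is_leaf and leaf.is_leaf and leaf.requires_grad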
@@ -369,6 +369,9 @@ def get_node_weight(node):
         node_name = node_in[:-3]
         cut_nodes.add(node_name)
 
+    # To make this stuff deterministic
+    node_idx = {node: idx for idx, node in enumerate(joint_module.graph.nodes)}
+    saved_values = sorted((name_to_node[node] for node in cut_nodes), key=lambda x: node_idx[x])
     saved_values = [name_to_node[node] for node in cut_nodes]
 
     return _extract_fwd_bwd_modules(joint_module, saved_values)
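This hunk pins down a source of nondeterminism: cut_nodes is a set, and set iteration order varies from run to run (string hashing is randomized), so the order of saved activations, and with it the extracted forward/backward modules, could differ across runs. Ranking every node by its position in the joint graph and sorting the cut by that rank fixes the order. Note that, as rendered in this diff, the original unsorted assignment survives as an unchanged line directly below the sorted one and would overwrite it; the sorting recipe itself is what the change introduces. A small illustration of that recipe on a toy FX graph (the function f and its node names are assumptions for the example):

    import torch.fx as fx

    def f(x):
        a = x * 2  # traced as node "mul"
        b = a + 1  # traced as node "add"
        return a, b

    gm = fx.symbolic_trace(f)
    name_to_node = {n.name: n for n in gm.graph.nodes}
    cut_nodes = {"add", "mul"}  # unordered, like the partitioner's set

    # The diff's recipe: rank nodes by graph position, sort the cut by rank.
    node_idx = {node: idx for idx, node in enumerate(gm.graph.nodes)}
    saved_values = sorted((name_to_node[n] for n in cut_nodes),
                          key=lambda x: node_idx[x])
    print([n.name for n in saved_values])  # deterministically ['mul', 'add']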