
Commit fb70a3c

fixed some issues with setting leaves and made partitioning deterministic (#880)
1 parent 1e93a45 · commit fb70a3c

File tree

2 files changed: +7 -0 lines changed


functorch/_src/aot_autograd.py

Lines changed: 4 additions & 0 deletions
@@ -147,6 +147,10 @@ class CompiledFunction(torch.autograd.Function):
     def forward(ctx, *flat_tensor_args):
         nonlocal compiled_fw, compiled_bw, num_outs
         if compiled_fw is None:
+            # Set input tensors that require grad to leaves
+            flat_tensor_args = pytree.tree_map(
+                lambda x: x.detach().requires_grad_(x.requires_grad), flat_tensor_args
+            )
             with torch.set_grad_enabled(grad_state):
                 out = flat_fn(*flat_tensor_args)
                 out = pytree.tree_map(
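For context, this change detaches each flattened input and re-applies its original requires_grad flag, so every input the compiled function traces over becomes an autograd leaf. Below is a minimal standalone sketch of that idiom in plain PyTorch; the as_leaves helper and the example tensor are illustrative only (the commit applies the same lambda through pytree.tree_map over the flattened args, not this helper).

# Minimal sketch (not functorch code): detach() drops the autograd history,
# so the copy is a leaf tensor; requires_grad_() restores the original flag
# so gradients still accumulate on the copy.
import torch

def as_leaves(flat_tensor_args):
    # Same transform the commit performs via pytree.tree_map, on a plain list.
    return [x.detach().requires_grad_(x.requires_grad) for x in flat_tensor_args]

y = torch.randn(3, requires_grad=True) * 2.0  # non-leaf: produced by an op
leaf, = as_leaves([y])
assert leaf.is_leaf and leaf.requires_grad
leaf.sum().backward()                         # .grad is populated on the leaf copy
print(leaf.grad)                              # tensor([1., 1., 1.])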

functorch/_src/partitioners.py

Lines changed: 3 additions & 0 deletions
@@ -369,6 +369,9 @@ def get_node_weight(node):
             node_name = node_in[:-3]
             cut_nodes.add(node_name)

+    # To make this stuff deterministic
+    node_idx = {node: idx for idx, node in enumerate(joint_module.graph.nodes)}
+    saved_values = sorted((name_to_node[node] for node in cut_nodes), key=lambda x: node_idx[x])
     saved_values = [name_to_node[node] for node in cut_nodes]

     return _extract_fwd_bwd_modules(joint_module, saved_values)
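The second change targets nondeterminism that comes from iterating over cut_nodes, a Python set: sorting by each node's position in the FX graph gives the saved values a stable order. A small illustrative sketch follows, using an assumed toy torch.fx module; the names Toy, gm, and the hard-coded cut-node set are made up for the example and are not taken from the partitioner.

# Minimal sketch (assumed toy module, not the functorch partitioner):
# order cut nodes by their index in the traced graph so the result is stable.
import torch
import torch.fx

class Toy(torch.nn.Module):
    def forward(self, x):
        a = x + 1
        b = a * 2
        return a, b

gm = torch.fx.symbolic_trace(Toy())
name_to_node = {n.name: n for n in gm.graph.nodes}
cut_nodes = {"mul", "add"}  # a set: iteration order can vary across runs

# Graph order provides a deterministic sort key for the saved values.
node_idx = {node: idx for idx, node in enumerate(gm.graph.nodes)}
saved_values = sorted((name_to_node[n] for n in cut_nodes), key=lambda v: node_idx[v])
print([n.name for n in saved_values])  # ['add', 'mul'] on every run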
