This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 14cda62 (1 parent: c6cc13c)

Commit message: fixed compilecache problem

File tree: 1 file changed (+2 −1 lines)


functorch/_src/aot_autograd.py (2 additions, 1 deletion)

```diff
@@ -167,7 +167,8 @@ def forward(ctx, *flat_tensor_args):
         with preserve_rng_state():
             # Set input tensors that require grad to leaves
             flat_tensor_args = pytree.tree_map(
-                lambda x: x.detach().requires_grad_(x.requires_grad), flat_tensor_args
+                lambda x: x.detach().requires_grad_(x.requires_grad)
+                if isinstance(x, Tensor) else x, flat_tensor_args
             )
             with torch.set_grad_enabled(grad_state):
                 out = flat_fn(*flat_tensor_args)
```
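The fix guards the `tree_map` lambda with an `isinstance` check so that non-tensor leaves (e.g. ints or `None` that can appear among flattened args when compile-cache lookups replay a traced function) pass through unchanged instead of crashing on `.detach()`. A minimal sketch of the pattern, using a stand-in `FakeTensor` class and a simplified `tree_map` rather than the real torch/pytree APIs:

```python
# Stand-in for torch.Tensor: just enough surface (detach / requires_grad_)
# to demonstrate why unguarded tree_map fails on mixed-type leaves.
class FakeTensor:
    def __init__(self, requires_grad=False):
        self.requires_grad = requires_grad

    def detach(self):
        # Detaching yields a new leaf with grad tracking off by default.
        return FakeTensor(requires_grad=False)

    def requires_grad_(self, flag):
        self.requires_grad = flag
        return self

# Simplified tree_map: apply fn to every leaf of a flat tuple.
def tree_map(fn, args):
    return tuple(fn(a) for a in args)

# Flattened args can mix tensors with plain Python values.
flat_tensor_args = (FakeTensor(requires_grad=True), 3, None)

# Patched behavior: only detach tensor leaves, pass everything else through.
out = tree_map(
    lambda x: x.detach().requires_grad_(x.requires_grad)
    if isinstance(x, FakeTensor) else x,
    flat_tensor_args,
)
print([type(o).__name__ for o in out])  # ['FakeTensor', 'int', 'NoneType']
```

Without the `isinstance` guard, the lambda would call `.detach()` on the `int` and `None` leaves and raise `AttributeError`, which is the compile-cache failure this commit addresses.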
