Skip to content

Commit 36f8a4a

Browse files
committed
Updated comments
1 parent 5e58f4a commit 36f8a4a

File tree

3 files changed

+30
-25
lines changed

3 files changed

+30
-25
lines changed

thunder/core/jit_ext.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -941,8 +941,6 @@ def _generate_random_str_id() -> str:
941941
return "".join(secrets.choice(string.ascii_lowercase) for _ in range(length))
942942

943943
# Support both stable PyTorch (with args_tensor_mask) and nightly (without it)
944-
# See changelog.md for details on the args_tensor_mask removal in nightly PyTorch
945-
# Note: Use "in" check rather than .get() to handle wrapped values correctly
946944
if "args_tensor_mask" in fwd_kwargs:
947945
args_tensor_mask = unwrap(fwd_kwargs["args_tensor_mask"])
948946
else:
@@ -972,8 +970,8 @@ def _generate_random_str_id() -> str:
972970
# With args_tensor_mask, the fwd_body expects ctx as first argument
973971
new_fwd_args = (wrap_const(None),) + tuple(new_fwd_args)
974972
else:
975-
# For nightly PyTorch without args_tensor_mask, the ctx handling is internalized
976-
# by dynamo. The fwd_body GraphModule does NOT expect a ctx argument.
973+
# For nightly PyTorch without args_tensor_mask, the fwd_body
974+
# GraphModule does NOT expect a ctx argument.
977975
# We pass all args as-is without prepending None.
978976
new_fwd_args = tuple(fwd_args)
979977
unwrapped_fwd_args = tree_map(lambda t: unwrap(t), new_fwd_args)
@@ -1014,8 +1012,6 @@ def forward(*args, **kwargs):
10141012
grads = sequencify(tree_map(lambda t: TensorProxy(like=t), sequencify(output)))
10151013
bwd_tensor_args = grads + tuple(saved_values)
10161014
# Support both stable PyTorch (with args_tensor_mask) and nightly (without it)
1017-
# With args_tensor_mask, bwd_body expects ctx as first argument
1018-
# Without args_tensor_mask, ctx handling is internalized - no ctx argument needed
10191015
if args_tensor_mask is not None:
10201016
bwd_args = (None,) + bwd_tensor_args
10211017
else:
@@ -1050,11 +1046,11 @@ def grad_transform(*args, **kwargs):
10501046
# Support both stable PyTorch (with args_tensor_mask) and nightly (without it)
10511047
if args_tensor_mask is not None:
10521048
bwd_args = (None,) + tuple(grads) + tuple(sequencify(residuals))
1053-
# Old API: first arg is ctx, skip it for put_grads
1049+
# Stable PT: first arg is ctx, skip it for put_grads
10541050
grad_inputs = args[1:]
10551051
else:
10561052
bwd_args = tuple(grads) + tuple(sequencify(residuals))
1057-
# New API: no ctx, use all args
1053+
# Nightly PT: no ctx, use all args
10581054
grad_inputs = args
10591055
result = interpret_trace(aliased_bwd_trace, *bwd_args)
10601056
put_grads(grad_inputs, result)

thunder/tests/test_jit_general.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,15 @@
1414

1515
import thunder
1616

17+
from thunder.tests.framework import requiresCUDA, IS_WINDOWS
18+
from thunder.core.options import CACHE_OPTIONS
19+
import thunder.core.prims as prims
20+
from thunder import pytorch_executor, nvfuser_executor
21+
from thunder.executors.sdpaex import sdpa_ex
22+
from thunder.core.transforms import Transform
1723

1824
# Detect once at module load time whether PyTorch uses args_tensor_mask.
19-
# This must be done outside the JIT-traced function to avoid interpreter issues
20-
# with inspect.getsource() and tokenize internals.
25+
# This must be done outside the JIT-traced function to avoid interpreter issues.
2126
def _detect_has_args_tensor_mask():
2227
"""Check if autograd_function_apply uses args_tensor_mask.
2328
@@ -46,14 +51,6 @@ def _autograd_function_apply_kwargs(args_tensor_mask, non_differentiable_idx=Non
4651
return kwargs
4752

4853

49-
from thunder.tests.framework import requiresCUDA, IS_WINDOWS
50-
from thunder.core.options import CACHE_OPTIONS
51-
import thunder.core.prims as prims
52-
from thunder import pytorch_executor, nvfuser_executor
53-
from thunder.executors.sdpaex import sdpa_ex
54-
from thunder.core.transforms import Transform
55-
56-
5754
thunder_jit = partial(thunder.jit, debug_options=thunder.DebugOptions(check_traces=2))
5855

5956
#
@@ -1292,8 +1289,9 @@ def test_autograd_function_apply():
12921289
# since https://github.com/pytorch/pytorch/pull/169528 `torch.ops.higher_order.autograd_function_apply`
12931290
# no longer accepts simple callables, but rather `torch.fx.GraphModule`s.
12941291

1292+
# TODO: Remove this once this autograd API becomes stable.
12951293
# On stable PyTorch (with args_tensor_mask), forward/backward expect ctx as first arg.
1296-
# On nightly PyTorch (without args_tensor_mask), ctx handling is internalized.
1294+
# On nightly PyTorch (without args_tensor_mask), ctx is not an argument.
12971295
if _HAS_ARGS_TENSOR_MASK:
12981296

12991297
class FwdModule(torch.nn.Module):
@@ -1341,6 +1339,9 @@ def my_sin(x):
13411339
expect_grad = torch.autograd.grad(y_ref, x_ref, grad)
13421340
torch.testing.assert_close(actual_grad, expect_grad)
13431341

1342+
# TODO: Remove this once this autograd API becomes stable.
1343+
# On stable PyTorch (with args_tensor_mask), forward/backward expect ctx as first arg.
1344+
# On nightly PyTorch (without args_tensor_mask), ctx is not an argument.
13441345
if _HAS_ARGS_TENSOR_MASK:
13451346

13461347
class WrongBwdModule(torch.nn.Module):
@@ -1383,8 +1384,9 @@ def my_sin_with_wrong_backward(x):
13831384

13841385
def test_autograd_function_apply_with_no_grad():
13851386
# This case is using `torch` operations
1387+
# TODO: Remove this once this autograd API becomes stable.
13861388
# On stable PyTorch (with args_tensor_mask), forward/backward expect ctx as first arg.
1387-
# On nightly PyTorch (without args_tensor_mask), ctx handling is internalized.
1389+
# On nightly PyTorch (without args_tensor_mask), ctx is not an argument.
13881390
if _HAS_ARGS_TENSOR_MASK:
13891391

13901392
def forward(_, x):
@@ -1429,6 +1431,9 @@ def my_sin(x):
14291431

14301432
# This is using `thunder` operations
14311433
# NOTE - This takes a different codepath compared to above.
1434+
# TODO: Remove this once this autograd API becomes stable.
1435+
# On stable PyTorch (with args_tensor_mask), forward/backward expect ctx as first arg.
1436+
# On nightly PyTorch (without args_tensor_mask), ctx is not an argument.
14321437
if _HAS_ARGS_TENSOR_MASK:
14331438

14341439
def forward(_, x):

thunder/torch/__init__.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6716,9 +6716,9 @@ def autograd_function_apply(
67166716
args_tensor_mask: Sequence[bool] | None = None,
67176717
non_differentiable_idx: Sequence[int] | None = None,
67186718
) -> TensorProxy | tuple[TensorProxy, ...]:
6719-
# Support both stable PyTorch (with args_tensor_mask) and nightly (without it)
6720-
# With args_tensor_mask, fwd expects ctx as first argument
6721-
# Without args_tensor_mask, ctx handling is internalized - no ctx argument needed
6719+
# TODO: Remove this once this autograd API becomes stable.
6720+
# On stable PyTorch, fwd expects ctx as first argument
6721+
# On nightly PyTorch, ctx is not an argument
67226722
if args_tensor_mask is not None:
67236723
result, saved_for_backward = call_higher_order_function_and_consider_outer_autograd_setting(fwd)(None, *args)
67246724
else:
@@ -6734,7 +6734,9 @@ def augmented_forward_autograd_function_apply(
67346734
args_tensor_mask: Sequence[bool] | None = None,
67356735
non_differentiable_idx: Sequence[int] | None = None,
67366736
) -> tuple[TensorProxy | tuple[TensorProxy, ...], tuple[Any, ...]]:
6737-
# Support both stable PyTorch (with args_tensor_mask) and nightly (without it)
6737+
# TODO: Remove this once this autograd API becomes stable.
6738+
# On stable PyTorch, fwd expects ctx as first argument
6739+
# On nightly PyTorch, ctx is not an argument
67386740
if args_tensor_mask is not None:
67396741
result, saved_for_backward = fwd(None, *args)
67406742
else:
@@ -6750,7 +6752,9 @@ def backward_autograd_function_apply(
67506752
non_differentiable_idx: Sequence[int] | None = None,
67516753
*grad_output: Sequence[TensorProxy],
67526754
) -> tuple[Any, ...]:
6753-
# Support both stable PyTorch (with args_tensor_mask) and nightly (without it)
6755+
# TODO: Remove this once this autograd API becomes stable.
6756+
# On stable PyTorch, bwd expects ctx as first argument
6757+
# On nightly PyTorch, ctx is not an argument
67546758
if args_tensor_mask is not None:
67556759
return bwd(None, *grad_output, *saved_for_backward)
67566760
else:

0 commit comments

Comments
 (0)