|
14 | 14 |
|
15 | 15 | import thunder |
16 | 16 |
|
| 17 | +from thunder.tests.framework import requiresCUDA, IS_WINDOWS |
| 18 | +from thunder.core.options import CACHE_OPTIONS |
| 19 | +import thunder.core.prims as prims |
| 20 | +from thunder import pytorch_executor, nvfuser_executor |
| 21 | +from thunder.executors.sdpaex import sdpa_ex |
| 22 | +from thunder.core.transforms import Transform |
17 | 23 |
|
18 | 24 | # Detect once at module load time whether PyTorch uses args_tensor_mask. |
19 | | -# This must be done outside the JIT-traced function to avoid interpreter issues |
20 | | -# with inspect.getsource() and tokenize internals. |
| 25 | +# This must be done outside the JIT-traced function to avoid interpreter issues. |
21 | 26 | def _detect_has_args_tensor_mask(): |
22 | 27 | """Check if autograd_function_apply uses args_tensor_mask. |
23 | 28 |
|
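# --- illustrative sketch (not part of the PR diff) ---------------------------
# A hedged example of the kind of module-level probe the comment above refers
# to. The import path below is an assumption about where the higher-order op is
# implemented; inspect.getsource is used because, per the comment, signature
# introspection of the op is not reliable inside the traced function.
import inspect

def _sketch_detect_has_args_tensor_mask() -> bool:
    try:
        from torch._functorch.autograd_function import AutogradFunctionApply  # assumed location
        src = inspect.getsource(AutogradFunctionApply.__call__)
    except Exception:
        # If the probe fails, fall back to assuming the newer convention.
        return False
    return "args_tensor_mask" in src
# ------------------------------------------------------------------------------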
@@ -46,14 +51,6 @@ def _autograd_function_apply_kwargs(args_tensor_mask, non_differentiable_idx=Non |
46 | 51 | return kwargs |
47 | 52 |
|
48 | 53 |
|
49 | | -from thunder.tests.framework import requiresCUDA, IS_WINDOWS |
50 | | -from thunder.core.options import CACHE_OPTIONS |
51 | | -import thunder.core.prims as prims |
52 | | -from thunder import pytorch_executor, nvfuser_executor |
53 | | -from thunder.executors.sdpaex import sdpa_ex |
54 | | -from thunder.core.transforms import Transform |
55 | | - |
56 | | - |
57 | 54 | thunder_jit = partial(thunder.jit, debug_options=thunder.DebugOptions(check_traces=2)) |
58 | 55 |
|
59 | 56 | # |
@@ -1292,8 +1289,9 @@ def test_autograd_function_apply(): |
1292 | 1289 | # since https://github.com/pytorch/pytorch/pull/169528 `torch.ops.higher_order.autograd_function_apply` |
1293 | 1290 | # no longer accepts simple callables, but rather `torch.fx.GraphModule`s. |
1294 | 1291 |
|
| 1292 | + # TODO: Remove this once this autograd API becomes stable. |
1295 | 1293 | # On stable PyTorch (with args_tensor_mask), forward/backward expect ctx as first arg. |
1296 | | - # On nightly PyTorch (without args_tensor_mask), ctx handling is internalized. |
| 1294 | + # On nightly PyTorch (without args_tensor_mask), ctx is not an argument. |
1297 | 1295 | if _HAS_ARGS_TENSOR_MASK: |
1298 | 1296 |
|
1299 | 1297 | class FwdModule(torch.nn.Module): |
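# --- illustrative sketch (not part of the PR diff) ---------------------------
# A hedged illustration of the two forward conventions the comments describe;
# only the signature difference is shown (return/saved-value handling omitted),
# and the names are hypothetical, not the test's actual code.
import torch

def fwd_stable(ctx, x):   # stable PyTorch: ctx is passed explicitly as the first arg
    return torch.sin(x)

def fwd_nightly(x):       # nightly PyTorch: ctx is not an argument
    return torch.sin(x)

# The _autograd_function_apply_kwargs helper from the earlier hunk presumably
# lets a call site stay version-agnostic by only adding args_tensor_mask on
# builds that still accept it, e.g. (sketch only):
#   kwargs = _autograd_function_apply_kwargs(args_tensor_mask=[True])
#   out = torch.ops.higher_order.autograd_function_apply(fwd, bwd, x, **kwargs)
# ------------------------------------------------------------------------------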
@@ -1341,6 +1339,9 @@ def my_sin(x): |
1341 | 1339 | expect_grad = torch.autograd.grad(y_ref, x_ref, grad) |
1342 | 1340 | torch.testing.assert_close(actual_grad, expect_grad) |
1343 | 1341 |
|
| 1342 | + # TODO: Remove this once this autograd API becomes stable. |
| 1343 | + # On stable PyTorch (with args_tensor_mask), forward/backward expect ctx as first arg. |
| 1344 | + # On nightly PyTorch (without args_tensor_mask), ctx is not an argument. |
1344 | 1345 | if _HAS_ARGS_TENSOR_MASK: |
1345 | 1346 |
|
1346 | 1347 | class WrongBwdModule(torch.nn.Module): |
@@ -1383,8 +1384,9 @@ def my_sin_with_wrong_backward(x): |
1383 | 1384 |
|
1384 | 1385 | def test_autograd_function_apply_with_no_grad(): |
1385 | 1386 | # This case is using `torch` operations |
| 1387 | + # TODO: Remove this once this autograd API becomes stable. |
1386 | 1388 | # On stable PyTorch (with args_tensor_mask), forward/backward expect ctx as first arg. |
1387 | | - # On nightly PyTorch (without args_tensor_mask), ctx handling is internalized. |
| 1389 | + # On nightly PyTorch (without args_tensor_mask), ctx is not an argument. |
1388 | 1390 | if _HAS_ARGS_TENSOR_MASK: |
1389 | 1391 |
|
1390 | 1392 | def forward(_, x): |
@@ -1429,6 +1431,9 @@ def my_sin(x): |
1429 | 1431 |
|
1430 | 1432 | # This is using `thunder` operations |
1431 | 1433 | # NOTE - This takes a different codepath compared to above. |
| 1434 | + # TODO: Remove this once this autograd API becomes stable. |
| 1435 | + # On stable PyTorch (with args_tensor_mask), forward/backward expect ctx as first arg. |
| 1436 | + # On nightly PyTorch (without args_tensor_mask), ctx is not an argument. |
1432 | 1437 | if _HAS_ARGS_TENSOR_MASK: |
1433 | 1438 |
|
1434 | 1439 | def forward(_, x): |
|