Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit d83f371

Browse files
committed
Some test cases for no_grad
1 parent 509424e commit d83f371

File tree

1 file changed

+113
-0
lines changed

1 file changed

+113
-0
lines changed

test/test_eager_transforms.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,119 @@ def foo(x):
485485
tensor(3.1400)))""")
486486
self.assertEqual(buf, expected)
487487

488+
@unittest.expectedFailure
def test_no_grad_outside(self, device):
    """grad() invoked under an ambient torch.no_grad() should still differentiate."""
    point = torch.randn([], device=device)
    with torch.no_grad():
        result = grad(torch.sin)(point)
    self.assertEqual(result, point.cos())
494+
495+
def test_no_grad_inside(self, device):
496+
def f(x):
497+
with torch.no_grad():
498+
shift = x ** 2
499+
return x ** 2 - shift
500+
501+
x = torch.randn([], device=device)
502+
y = grad(f)(x)
503+
self.assertEqual(y, 2 * x)
504+
y = grad(grad(f))(x)
505+
self.assertEqual(y, 2)
506+
507+
@unittest.expectedFailure
def test_no_grad_mixed(self, device):
    """no_grad both inside the differentiated function and around the grad() call."""
    def f(x):
        # Constant w.r.t. differentiation: computed under no_grad.
        with torch.no_grad():
            offset = x ** 2
        return x ** 2 - offset

    point = torch.randn([], device=device)
    with torch.no_grad():
        result = grad(f)(point)

    self.assertEqual(result, 2 * point)
519+
520+
def test_no_grad_nested_simple(self, device):
521+
def h(x):
522+
with torch.no_grad():
523+
shift = grad(lambda x: x ** 3)(x)
524+
return x ** 3 - shift
525+
526+
x = torch.tensor(2., device=device)
527+
y = grad(h)(x)
528+
self.assertEqual(y, 6 * x)
529+
530+
def test_no_grad_nested_complicated(self, device):
531+
def f(x):
532+
with torch.no_grad():
533+
shift = x ** 3
534+
return x ** 3 - shift
535+
536+
def g(x):
537+
r1 = grad(f)(x)
538+
with torch.no_grad():
539+
shift = grad(f)(x)
540+
return r1 - shift
541+
542+
x = torch.randn([])
543+
y = grad(g)(x)
544+
# The only differential part of g is x ** 3
545+
self.assertEqual(y, 6 * x)
546+
547+
def test_no_grad_value(self, device):
548+
def h(x):
549+
with torch.no_grad():
550+
gvalue, value = grad_and_value(lambda x: x ** 3)(x)
551+
return x ** 3 - value
552+
553+
x = torch.tensor(2., device=device)
554+
y = grad(h)(x)
555+
self.assertEqual(y, 6 * x)
556+
557+
@unittest.expectedFailure
def test_no_grad_outside_vjp(self, device):
    """Both vjp() and the pullback run under no_grad: no grad tracking anywhere."""
    def h(x):
        return x ** 2

    point = torch.tensor(2., requires_grad=True, device=device)
    with torch.no_grad():
        out, vjp_fn = vjp(h, point)
        cotangent_grad, = vjp_fn(1.)

    self.assertEqual(cotangent_grad, 2 * point)
    self.assertFalse(cotangent_grad.requires_grad)
    self.assertFalse(out.requires_grad)
570+
571+
@unittest.expectedFailure
def test_no_grad_outside_vjp_fn(self, device):
    """Only the pullback runs under no_grad; the forward pass still tracks grad."""
    def h(x):
        return x ** 2

    point = torch.tensor(2., requires_grad=True, device=device)
    out, vjp_fn = vjp(h, point)
    with torch.no_grad():
        cotangent_grad, = vjp_fn(1.)

    self.assertEqual(cotangent_grad, 2 * point)
    self.assertFalse(cotangent_grad.requires_grad)
    self.assertTrue(out.requires_grad)
584+
585+
@unittest.expectedFailure
def test_no_grad_outside_vjp_only(self, device):
    """Only the vjp() forward pass runs under no_grad; the pullback runs outside."""
    def h(x):
        return x ** 2

    point = torch.tensor(2., requires_grad=True, device=device)
    with torch.no_grad():
        out, vjp_fn = vjp(h, point)
    cotangent_grad, = vjp_fn(1.)

    self.assertEqual(cotangent_grad, 2 * point)
    self.assertFalse(out.requires_grad)

    # Slightly counter-intuitive: because the forward pass ran under
    # no_grad, `vjp_fn` did not save enough information for its output
    # to be differentiable.
    self.assertFalse(cotangent_grad.requires_grad)
488601

489602
class TestVmapOfGrad(TestCase):
490603
def test_per_sample_grads_inplace_view(self, device):

0 commit comments

Comments
 (0)