@@ -485,6 +485,119 @@ def foo(x):
485
485
tensor(3.1400)))""" )
486
486
self .assertEqual (buf , expected )
487
487
488
+ @unittest .expectedFailure
489
+ def test_no_grad_outside (self , device ):
490
+ x = torch .randn ([], device = device )
491
+ with torch .no_grad ():
492
+ y = grad (torch .sin )(x )
493
+ self .assertEqual (y , x .cos ())
494
+
495
+ def test_no_grad_inside (self , device ):
496
+ def f (x ):
497
+ with torch .no_grad ():
498
+ shift = x ** 2
499
+ return x ** 2 - shift
500
+
501
+ x = torch .randn ([], device = device )
502
+ y = grad (f )(x )
503
+ self .assertEqual (y , 2 * x )
504
+ y = grad (grad (f ))(x )
505
+ self .assertEqual (y , 2 )
506
+
507
+ @unittest .expectedFailure
508
+ def test_no_grad_mixed (self , device ):
509
+ def f (x ):
510
+ with torch .no_grad ():
511
+ shift = x ** 2
512
+ return x ** 2 - shift
513
+
514
+ x = torch .randn ([], device = device )
515
+ with torch .no_grad ():
516
+ y = grad (f )(x )
517
+
518
+ self .assertEqual (y , 2 * x )
519
+
520
+ def test_no_grad_nested_simple (self , device ):
521
+ def h (x ):
522
+ with torch .no_grad ():
523
+ shift = grad (lambda x : x ** 3 )(x )
524
+ return x ** 3 - shift
525
+
526
+ x = torch .tensor (2. , device = device )
527
+ y = grad (h )(x )
528
+ self .assertEqual (y , 6 * x )
529
+
530
+ def test_no_grad_nested_complicated (self , device ):
531
+ def f (x ):
532
+ with torch .no_grad ():
533
+ shift = x ** 3
534
+ return x ** 3 - shift
535
+
536
+ def g (x ):
537
+ r1 = grad (f )(x )
538
+ with torch .no_grad ():
539
+ shift = grad (f )(x )
540
+ return r1 - shift
541
+
542
+ x = torch .randn ([])
543
+ y = grad (g )(x )
544
+ # The only differential part of g is x ** 3
545
+ self .assertEqual (y , 6 * x )
546
+
547
+ def test_no_grad_value (self , device ):
548
+ def h (x ):
549
+ with torch .no_grad ():
550
+ gvalue , value = grad_and_value (lambda x : x ** 3 )(x )
551
+ return x ** 3 - value
552
+
553
+ x = torch .tensor (2. , device = device )
554
+ y = grad (h )(x )
555
+ self .assertEqual (y , 6 * x )
556
+
557
+ @unittest .expectedFailure
558
+ def test_no_grad_outside_vjp (self , device ):
559
+ def h (x ):
560
+ return x ** 2
561
+
562
+ x = torch .tensor (2. , requires_grad = True , device = device )
563
+ with torch .no_grad ():
564
+ out , vjp_fn = vjp (h , x )
565
+ y , = vjp_fn (1. )
566
+
567
+ self .assertEqual (y , 2 * x )
568
+ self .assertFalse (y .requires_grad )
569
+ self .assertFalse (out .requires_grad )
570
+
571
+ @unittest .expectedFailure
572
+ def test_no_grad_outside_vjp_fn (self , device ):
573
+ def h (x ):
574
+ return x ** 2
575
+
576
+ x = torch .tensor (2. , requires_grad = True , device = device )
577
+ out , vjp_fn = vjp (h , x )
578
+ with torch .no_grad ():
579
+ y , = vjp_fn (1. )
580
+
581
+ self .assertEqual (y , 2 * x )
582
+ self .assertFalse (y .requires_grad )
583
+ self .assertTrue (out .requires_grad )
584
+
585
+ @unittest .expectedFailure
586
+ def test_no_grad_outside_vjp_only (self , device ):
587
+ def h (x ):
588
+ return x ** 2
589
+
590
+ x = torch .tensor (2. , requires_grad = True , device = device )
591
+ with torch .no_grad ():
592
+ out , vjp_fn = vjp (h , x )
593
+ y , = vjp_fn (1. )
594
+
595
+ self .assertEqual (y , 2 * x )
596
+ self .assertFalse (out .requires_grad )
597
+
598
+ # This one is a little weird. `vjp_fn` didn't save enough info
599
+ # during the forward pass for the output to be differentiable.
600
+ self .assertFalse (y .requires_grad )
488
601
489
602
class TestVmapOfGrad (TestCase ):
490
603
def test_per_sample_grads_inplace_view (self , device ):
0 commit comments