44
55class AutodiffTest (TestCase ):
66
7- def test_basic (self ):
7+ def test_sanity_check (self ):
88
9+ a = Value (shape = (2 ,), name = 'a' )
10+ a .forward (a = array ([2 , 3 ]))
11+ a .backward ()
12+ self .assertTrue (allclose (a .grad , [1 , 1 ]))
13+
14+ def test_basic (self ):
15+ # test arithmetic, relu
916 a = Value (shape = (2 ,), name = 'a' )
1017 b = Value (shape = (2 ,), name = 'b' )
1118 c = (a + 2 ).relu () * b ** 2
@@ -16,6 +23,7 @@ def test_basic(self):
1623 self .assertTrue (allclose (b .grad , [0 , 18 ]))
1724 self .assertTrue (allclose (c .grad , [1 , 1 ]))
1825
26+ # test log1p
1927 d = a .log1p ()
2028 d .forward (a = array ([2 , 3 ]))
2129 d .backward ()
@@ -27,25 +35,52 @@ def test_basic(self):
2735 self .assertTrue (allclose (d .data , [nan , 1.38629436 ], equal_nan = True ))
2836 self .assertTrue (allclose (a .grad , [nan , 0.25 ], equal_nan = True ))
2937
38+ # test transpose
3039 f = Value (shape = (2 , 1 ), name = 'f' )
3140 g = f .T ** 2
3241 g .forward (f = array ([[2 ], [- 1 ]]))
3342 g .backward ()
3443 self .assertTrue (allclose (f .grad , [[4 ], [- 2 ]]))
3544
45+ # TODO: add test of inner product
46+
47+ # test arctanh
3648 h = Value (shape = (5 ,), name = 'h' )
3749 k = (h * 2 ).arctanh ()
3850 k .forward (h = array ([- 1 , - .5 , 0 , .5 , 1 ]))
3951 k .backward ()
4052 self .assertTrue (allclose (h .grad , [nan , inf , 2 , inf , nan ],
4153 equal_nan = True ))
4254
43- def test_sanity_check (self ):
55+ def test_reduce_ops (self ):
4456
45- a = Value (shape = (2 ,), name = 'a' )
46- a .forward (a = array ([2 , 3 ]))
47- a .backward ()
48- self .assertTrue (allclose (a .grad , [1 , 1 ]))
57+ a = Value (shape = (2 , 2 , 3 ), name = 'a' )
58+ b = (a .sum (axis = (0 , 2 )) - 31 ).relu ()
59+ b .forward (a = array ([[[1 , 2 , 3 ], [4 , 5 , 6 ]],
60+ [[7 , 8 , 9 ], [10 , 11 , 12 ]]]))
61+ b .backward ()
62+ self .assertTrue (allclose (b .data , [0 , 17 ]))
63+ self .assertTrue (allclose (a .grad , [[[0 , 0 , 0 ], [1 , 1 , 1 ]],
64+ [[0 , 0 , 0 ], [1 , 1 , 1 ]]]))
65+
66+ b = (a .sum (axis = 0 ) - 10 ).relu ()
67+ b .forward (a = array ([[[1 , 2 , 3 ], [4 , 5 , 6 ]],
68+ [[7 , 8 , 9 ], [10 , 11 , 12 ]]]))
69+ b .backward ()
70+ self .assertTrue (allclose (a .grad , [[[0 , 0 , 1 ], [1 , 1 , 1 ]],
71+ [[0 , 0 , 1 ], [1 , 1 , 1 ]]]))
72+
73+ b = (a .sum () - 77 ).relu ()
74+ c = (a .sum () - 79 ).relu ()
75+ b .forward (a = array ([[[1 , 2 , 3 ], [4 , 5 , 6 ]],
76+ [[7 , 8 , 9 ], [10 , 11 , 12 ]]]))
77+ b .backward ()
78+ self .assertTrue (allclose (a .grad , 1 ))
79+
80+ c .forward (a = array ([[[1 , 2 , 3 ], [4 , 5 , 6 ]],
81+ [[7 , 8 , 9 ], [10 , 11 , 12 ]]]))
82+ c .backward ()
83+ self .assertTrue (allclose (a .grad , 0 ))
4984
5085 def test_ops (self ):
5186 x = Value (- 4.0 )
0 commit comments