 from simple_nets import init_data


-def simple_net1():
+def case1_fill_grad_vars():
     x = fluid.layers.data(name='image', shape=[784], dtype='float32')
     label = fluid.layers.data(name='label', shape=[1], dtype='int64')
     feature = fluid.layers.fc(input=x, size=20, act=None)
@@ -30,7 +30,7 @@ def simple_net1():
     return loss


-def simple_net2():
+def case2_prune_no_grad_branch():
     x = fluid.layers.data(name='image', shape=[784], dtype='float32')
     label = fluid.layers.data(name='label', shape=[1], dtype='int64')
     feature = fluid.layers.fc(input=x, size=10, act=None)
@@ -42,14 +42,28 @@ def simple_net2():
     return loss


+def case3_prune_no_grad_branch2():
+    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+    label = fluid.layers.cast(label, dtype="float32")
+    label = fluid.layers.cast(label, dtype='int64')
+    out = fluid.layers.one_hot(input=label, depth=100)
+    loss = fluid.layers.mean(out)
+    return loss
+
+
+def case4_with_no_grad_op_maker():
+    out = fluid.layers.gaussian_random(shape=[20, 30])
+    loss = fluid.layers.mean(out)
+    return loss
+
+
 class TestBackward(unittest.TestCase):
-    def check_backward(self, model):
+    def check_backward(self, model, feed_dict):
         place = fluid.CPUPlace()
         exe = fluid.Executor(place)

         main = fluid.Program()
         startup = fluid.Program()
-        batch_size = 2

         with fluid.program_guard(main, startup):
             loss = model()
@@ -58,12 +72,16 @@ def check_backward(self, model):
             optimizer.minimize(loss)

             exe.run(fluid.default_startup_program())
-            img, label = init_data(batch_size, img_shape=[784], label_range=9)
-            exe.run(feed={'image': img, 'label': label})
+            exe.run(feed=feed_dict)

     def test_backward(self):
-        self.check_backward(simple_net1)
-        self.check_backward(simple_net2)
+        batch_size = 2
+        img, label = init_data(batch_size, img_shape=[784], label_range=9)
+        feed_dict = {'image': img, 'label': label}
+        self.check_backward(case1_fill_grad_vars, feed_dict)
+        self.check_backward(case2_prune_no_grad_branch, feed_dict)
+        self.check_backward(case3_prune_no_grad_branch2, {'label': label})
+        self.check_backward(case4_with_no_grad_op_maker, {})


 if __name__ == '__main__':
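
For reference, here is a minimal standalone sketch (not part of the patch) that replays what check_backward now does for case4_with_no_grad_op_maker outside the unittest harness. It assumes a Paddle 1.x environment where the paddle.fluid static-graph API is available; the SGD optimizer and its learning rate are assumptions, since the optimizer construction line is elided from this hunk.

import paddle.fluid as fluid

main, startup = fluid.Program(), fluid.Program()
with fluid.program_guard(main, startup):
    # gaussian_random has no grad op maker, so backward construction has to
    # skip it rather than fail.
    out = fluid.layers.gaussian_random(shape=[20, 30])
    loss = fluid.layers.mean(out)
    optimizer = fluid.optimizer.SGD(learning_rate=0.1)  # assumed; elided in the hunk
    optimizer.minimize(loss)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())
    exe.run(feed={})  # nothing to feed; success means backward was built cleanly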