Skip to content

Commit ea03b78

Browse files
committed
Better testing for batch normalization.
1 parent 39f18af commit ea03b78

File tree

4 files changed

+79
-45
lines changed

4 files changed

+79
-45
lines changed

src/graph/layer.ml

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,29 +8,28 @@ let batch_normalization ?(decay = 0.9) t ~update_moments ~dims ~feature_count =
88
let beta = Var.create [ feature_count ] ~type_ ~init:zero in
99
let gamma = Var.create [ feature_count ] ~type_ ~init:one in
1010
let batch_moments = Ops.moments t ~dims:(List.init dims ~f:Fn.id) in
11-
let beta_with_update =
12-
let update_beta =
13-
(* EWMA update. *)
14-
Ops.assignSub
15-
beta
16-
Ops.(one_minus_decay * (beta - batch_moments.mean))
17-
in
18-
Ops.identity ~control_inputs:[ Node.P update_beta ] beta
11+
let beta_with_update ~control_inputs =
12+
(* EWMA update. *)
13+
Ops.assignSub
14+
beta
15+
Ops.(one_minus_decay * (beta - batch_moments.mean))
16+
~control_inputs
1917
in
20-
let gamma_with_update =
21-
(* EWMA update. *)
22-
let update_gamma =
23-
Ops.assignSub
24-
gamma
25-
Ops.(one_minus_decay * (gamma - batch_moments.variance))
26-
in
27-
Ops.identity ~control_inputs:[ Node.P update_gamma ] gamma
18+
let gamma_with_update ~control_inputs =
19+
(* EWMA update. *)
20+
Ops.assignSub
21+
gamma
22+
Ops.(one_minus_decay * (gamma - batch_moments.variance))
23+
~control_inputs
2824
in
2925
let beta, gamma =
3026
match update_moments with
31-
| `always -> beta_with_update, gamma_with_update
27+
| `always ->
28+
beta_with_update ~control_inputs:[], gamma_with_update ~control_inputs:[]
3229
| `not_in_testing testing ->
33-
Ops.cond testing ~if_true:beta ~if_false:beta_with_update,
34-
Ops.cond testing ~if_true:gamma ~if_false:gamma_with_update
30+
let beta ~control_inputs:_ = beta in
31+
let gamma ~control_inputs:_ = gamma in
32+
Ops.cond_with_control_inputs testing ~if_true:beta ~if_false:beta_with_update,
33+
Ops.cond_with_control_inputs testing ~if_true:gamma ~if_false:gamma_with_update
3534
in
3635
Ops.normalize t { mean = beta; variance = gamma }

src/graph/ops_manual.ml

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ let get_shape ?shape values =
1818

1919
let const_float
2020
?(name = "Const")
21+
?(control_inputs = [])
2122
?shape
2223
~type_
2324
values
@@ -28,7 +29,7 @@ let const_float
2829
~op_name:(Op_name.of_string "Const")
2930
~output_type:type_
3031
~inputs:[]
31-
~control_inputs:[]
32+
~control_inputs
3233
~attributes:[
3334
"dtype", Type (P type_);
3435
"value", Tensor_float { type_ = P type_; shape; values };
@@ -37,6 +38,7 @@ let const_float
3738

3839
let const_int
3940
?(name = "Const")
41+
?(control_inputs = [])
4042
?shape
4143
~type_
4244
values
@@ -47,7 +49,7 @@ let const_int
4749
~op_name:(Op_name.of_string "Const")
4850
~output_type:type_
4951
~inputs:[]
50-
~control_inputs:[]
52+
~control_inputs
5153
~attributes:[
5254
"dtype", Type (P type_);
5355
"value", Tensor_int { type_ = P type_; shape; values };
@@ -217,19 +219,24 @@ let normalize ?(epsilon = 1e-12) t { mean; variance } =
217219
let epsilon = scalar ~type_:(Node.output_type t) epsilon in
218220
Ops_generated.rsqrt (variance + epsilon) * (t - mean)
219221

220-
let cond t ~if_true ~if_false =
222+
let cond_with_control_inputs t ~if_true ~if_false =
221223
let t_false, t_true = Ops_generated.switch t t in
222224
let if_true =
223-
Ops_generated.identity if_true
225+
if_true
224226
(* It is important to keep the [identity] below as control inputs do not handle
225227
ports. *)
226228
~control_inputs:[ Node.P (Ops_generated.identity t_true) ]
227229
in
228230
let if_false =
229-
Ops_generated.identity if_false
231+
if_false
230232
(* It is important to keep the [identity] below as control inputs do not handle
231233
ports. *)
232234
~control_inputs:[ Node.P (Ops_generated.identity t_false) ]
233235
in
234236
Ops_generated.merge [ if_true; if_false ]
235237
|> fst
238+
239+
let cond t ~if_true ~if_false =
240+
cond_with_control_inputs t
241+
~if_true:(fun ~control_inputs -> Ops_generated.identity ~control_inputs if_true)
242+
~if_false:(fun ~control_inputs -> Ops_generated.identity ~control_inputs if_false)

src/graph/ops_manual.mli

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,15 @@ val cd : ?shape:int list -> float list -> [ `double ] Node.t
3535
(* Some more refined constant creation functions. *)
3636
val const_float
3737
: ?name:string
38+
-> ?control_inputs:Node.p list
3839
-> ?shape:int list
3940
-> type_:([< `float | `double ] as 'dtype) Node.Type.t
4041
-> float list
4142
-> 'dtype Node.t
4243

4344
val const_int
4445
: ?name:string
46+
-> ?control_inputs:Node.p list
4547
-> ?shape:int list
4648
-> type_:([< `int32 | `int64 ] as 'dtype) Node.Type.t
4749
-> int list
@@ -155,6 +157,16 @@ val normalize
155157
-> 'a moments
156158
-> 'a Node.t
157159

160+
(* If [if_true] and [if_false] use their [control_inputs] argument to build a node,
161+
this node will only be evaluated if necessary. *)
162+
val cond_with_control_inputs
163+
: [ `bool ] Node.t
164+
-> if_true:(control_inputs:Node.p list -> 'a Node.t)
165+
-> if_false:(control_inputs:Node.p list -> 'a Node.t)
166+
-> 'a Node.t
167+
168+
(* [if_true] and [if_false] will always be evaluated because of the 'not-so lazy'
169+
behavior of TensorFlow switch. *)
158170
val cond
159171
: [ `bool ] Node.t
160172
-> if_true:'a Node.t

tests/operator_tests.ml

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -92,39 +92,55 @@ let test_vector () =
9292

9393
let test_batch_normalization () =
9494
let batch = Ops.placeholder ~type_:Double [ 3; 4 ] in
95-
let testing = Ops.placeholder ~type_:Bool [] in
95+
let testing = Ops.placeholder ~type_:Bool [ 1 ] in
9696
let ops =
9797
Layer.batch_normalization
9898
(Ops.Placeholder.to_node batch)
99-
~decay:0.
100-
~update_moments:`always
99+
~decay:0.5
100+
~update_moments:(`not_in_testing (Ops.Placeholder.to_node testing))
101101
~dims:1
102102
~feature_count:4
103+
|> Ops.reduce_sum ~dims:[ 0 ]
103104
in
104105
let batch_tensor = Tensor.create2 Float64 3 4 in
105106
let testing_tensor = Tensor.create0 Int8_unsigned in
106107
Tensor.copy_elt_list testing_tensor [ 0 ];
107-
let tensor =
108-
Tensor.copy_elt_list batch_tensor
109-
[ 0.; 4.; 0.; 8.
110-
; 0.; 4.; 9.; 8.
111-
; 0.; 4.; 3.; 3.
112-
];
113-
Session.run Session.Output.(double ops)
114-
~inputs:
115-
[ Session.Input.double batch batch_tensor
116-
; Session.Input.bool testing testing_tensor
117-
]
118-
in
119-
Tensor.print (Tensor.P tensor)
108+
for i = 0 to 4 do
109+
if i >= 3
110+
then Tensor.copy_elt_list testing_tensor [ 1 ];
111+
let tensor =
112+
Tensor.copy_elt_list batch_tensor
113+
[ 0.; 4.; 0.; 8.
114+
; 0.; 4.; 9.; 8.
115+
; 0.; 4.; 3.; 3.
116+
];
117+
Session.run Session.Output.(double ops)
118+
~inputs:
119+
[ Session.Input.double batch batch_tensor
120+
; Session.Input.bool testing testing_tensor
121+
]
122+
in
123+
let blessed_values =
124+
if i = 0 then [ 0.; 12.; 12.; 19. ]
125+
else if i = 1 then [ 0.; 8.485281; 2.190890; 5.247275 ]
126+
else if i = 2 then [ 0.; 6.000000; 0.914991; 2.260197 ]
127+
else if i = 3 then [ 0.; 4.242641; 0.426401; 1.063611 ]
128+
else if i = 4 then [ 0.; 4.242641; 0.426401; 1.063611 ]
129+
else assert false
130+
in
131+
assert_vector tensor ~expected_value:blessed_values ~tol:1e-6
132+
done
120133

121134
let test_cond true_false =
122135
let testing = Ops.placeholder ~type_:Bool [ 1 ] in
123136
let true_false = if true_false then 1 else 0 in
137+
let int32_with_control_inputs ~control_inputs v =
138+
Ops.const_int ~shape:[] ~type_:Int32 ~control_inputs [ v ]
139+
in
124140
let cond =
125-
Ops.cond (Ops.Placeholder.to_node testing)
126-
~if_true:Ops.one32
127-
~if_false:Ops.zero32
141+
Ops.cond_with_control_inputs (Ops.Placeholder.to_node testing)
142+
~if_true:(int32_with_control_inputs 1)
143+
~if_false:(int32_with_control_inputs 0)
128144
in
129145
let testing_tensor = Tensor.create0 Int8_unsigned in
130146
Tensor.copy_elt_list testing_tensor [ true_false ];
@@ -147,6 +163,6 @@ let test_cond true_false =
147163
let () =
148164
test_scalar ();
149165
test_vector ();
166+
test_batch_normalization ();
150167
test_cond true;
151-
test_cond false;
152-
test_batch_normalization ()
168+
test_cond false

0 commit comments

Comments
 (0)