@@ -2,6 +2,10 @@ open Core_kernel.Std
22open Tensorflow
33module O = Ops
44
5+ let assert_equal_int value ~expected_value =
6+ if value <> expected_value
7+ then failwithf " Got %d, expected %d" value expected_value ()
8+
59let assert_equal value ~expected_value ~tol =
610 if Float. abs (value -. expected_value) > tol
711 then failwithf " Got %f, expected %f" value expected_value ()
@@ -114,7 +118,35 @@ let test_batch_normalization () =
114118 in
115119 Tensor. print (Tensor. P tensor)
116120
121+ let test_cond true_false =
122+ let testing = Ops. placeholder ~type_: Bool [ 1 ] in
123+ let true_false = if true_false then 1 else 0 in
124+ let cond =
125+ Ops. cond (Ops.Placeholder. to_node testing)
126+ ~if_true: Ops. one32
127+ ~if_false: Ops. zero32
128+ in
129+ let testing_tensor = Tensor. create0 Int8_unsigned in
130+ Tensor. copy_elt_list testing_tensor [ true_false ];
131+ let tensor =
132+ Session. run Session.Output. (int32 cond)
133+ ~inputs:
134+ [ Session.Input. bool testing testing_tensor
135+ ]
136+ in
137+ let index =
138+ match Tensor. dims tensor with
139+ | [||] -> [||]
140+ | [| 1 |] -> [| 0 |]
141+ | [| n |] -> failwithf " Single dimension tensor with %d elements" n ()
142+ | _ -> failwith " Multi-dimensional tensor."
143+ in
144+ let value = Tensor. get tensor index |> Int32. to_int_exn in
145+ assert_equal_int value ~expected_value: true_false
146+
117147let () =
118148 test_scalar () ;
119149 test_vector () ;
150+ test_cond true ;
151+ test_cond false ;
120152 test_batch_normalization ()
0 commit comments