Skip to content

Commit 39f18af

Browse files
committed
Fix cond.
1 parent 3609c86 commit 39f18af

File tree

2 files changed

+38
-2
lines changed

2 files changed

+38
-2
lines changed

src/graph/ops_manual.ml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,11 +221,15 @@ let cond t ~if_true ~if_false =
221221
let t_false, t_true = Ops_generated.switch t t in
222222
let if_true =
223223
Ops_generated.identity if_true
224-
~control_inputs:[ Node.P t_true ]
224+
(* It is important to keep the [identity] below as control inputs do not handle
225+
ports. *)
226+
~control_inputs:[ Node.P (Ops_generated.identity t_true) ]
225227
in
226228
let if_false =
227229
Ops_generated.identity if_false
228-
~control_inputs:[ Node.P t_false ]
230+
(* It is important to keep the [identity] below as control inputs do not handle
231+
ports. *)
232+
~control_inputs:[ Node.P (Ops_generated.identity t_false) ]
229233
in
230234
Ops_generated.merge [ if_true; if_false ]
231235
|> fst

tests/operator_tests.ml

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@ open Core_kernel.Std
22
open Tensorflow
33
module O = Ops
44

5+
let assert_equal_int value ~expected_value =
6+
if value <> expected_value
7+
then failwithf "Got %d, expected %d" value expected_value ()
8+
59
let assert_equal value ~expected_value ~tol =
610
if Float.abs (value -. expected_value) > tol
711
then failwithf "Got %f, expected %f" value expected_value ()
@@ -114,7 +118,35 @@ let test_batch_normalization () =
114118
in
115119
Tensor.print (Tensor.P tensor)
116120

121+
let test_cond true_false =
122+
let testing = Ops.placeholder ~type_:Bool [ 1 ] in
123+
let true_false = if true_false then 1 else 0 in
124+
let cond =
125+
Ops.cond (Ops.Placeholder.to_node testing)
126+
~if_true:Ops.one32
127+
~if_false:Ops.zero32
128+
in
129+
let testing_tensor = Tensor.create0 Int8_unsigned in
130+
Tensor.copy_elt_list testing_tensor [ true_false ];
131+
let tensor =
132+
Session.run Session.Output.(int32 cond)
133+
~inputs:
134+
[ Session.Input.bool testing testing_tensor
135+
]
136+
in
137+
let index =
138+
match Tensor.dims tensor with
139+
| [||] -> [||]
140+
| [| 1 |] -> [| 0 |]
141+
| [| n |] -> failwithf "Single dimension tensor with %d elements" n ()
142+
| _ -> failwith "Multi-dimensional tensor."
143+
in
144+
let value = Tensor.get tensor index |> Int32.to_int_exn in
145+
assert_equal_int value ~expected_value:true_false
146+
117147
let () =
118148
test_scalar ();
119149
test_vector ();
150+
test_cond true;
151+
test_cond false;
120152
test_batch_normalization ()

0 commit comments

Comments
 (0)