Skip to content

Commit e3cff98

Browse files
authored
Use lower case spec name for future compatibility (#160)
1 parent 00129b5 commit e3cff98

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

compiler_opt/rl/policy_saver.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def _write_output_signature(self, saver, path):
191191

192192
# Map spec name to index in flattened outputs.
193193
sm_action_indices = dict(
194-
(k.name, i) for i, k in enumerate(sm_action_signature))
194+
(k.name.lower(), i) for i, k in enumerate(sm_action_signature))
195195

196196
# List mapping flattened structured outputs to tensors.
197197
sm_action_tensors = saved_model.signatures['action'].outputs
@@ -204,7 +204,7 @@ def _write_output_signature(self, saver, path):
204204

205205
# Find the decision's tensor in the flattened output tensor list.
206206
sm_action_decision = (
207-
sm_action_tensors[sm_action_indices[decision_spec[0].name]])
207+
sm_action_tensors[sm_action_indices[decision_spec[0].name.lower()]])
208208

209209
sm_action_decision = _get_non_identity_op(sm_action_decision)
210210

@@ -220,7 +220,7 @@ def _write_output_signature(self, saver, path):
220220
}
221221
}]
222222
for info_spec in tf.nest.flatten(action_signature.info):
223-
sm_action_info = sm_action_tensors[sm_action_indices[info_spec.name]]
223+
sm_action_info = sm_action_tensors[sm_action_indices[info_spec.name.lower()]]
224224
sm_action_info = _get_non_identity_op(sm_action_info)
225225
(tensor_op, tensor_port) = _split_tensor_name(sm_action_info.name)
226226
output_list.append({

0 commit comments

Comments
 (0)