Skip to content

Commit 07d7ab4

Browse files
authored
[es] policy_saver_function fix; logging weights fix (#538)
Make `policy_saver_function` an actual function to clarify parameter passing. Drive-by logging message fix.
1 parent 5285826 commit 07d7ab4

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

compiler_opt/es/es_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def main(_):
3535
final_weights = es_trainer_lib.train()
3636

3737
logging.info("Final Weights:")
38-
logging.info(", ".join(final_weights))
38+
logging.info(str(final_weights))
3939

4040

4141
if __name__ == "__main__":

compiler_opt/es/es_trainer_lib.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
from absl import flags, logging
1717
import enum
18-
import functools
1918
import gin
2019
import tensorflow as tf
2120
import os
@@ -119,10 +118,13 @@ def train(additional_compilation_flags=(),
119118

120119
# Construct policy saver
121120
saved_policy = policy_utils.create_actor_policy()
122-
policy_saver_function = functools.partial(
123-
policy_utils.save_policy,
124-
policy=saved_policy,
125-
save_folder=os.path.join(_OUTPUT_PATH.value, "saved_policies"))
121+
122+
def policy_saver_function(parameters, model_name):
123+
policy_utils.save_policy(
124+
parameters=parameters,
125+
policy=saved_policy,
126+
policy_name=model_name,
127+
save_folder=os.path.join(_OUTPUT_PATH.value, "saved_policies"))
126128

127129
# Get learner config
128130
learner_config = blackbox_learner.BlackboxLearnerConfig()

0 commit comments

Comments
 (0)