Skip to content

Commit a8559c1

Browse files
Remove greedy flag from es_trainer_lib
This is already configurable through gin in policy_utils, is not consistently used within es_trainer_lib, and overall just adds unnecessary complexity.

Reviewers: mtrofin
Reviewed By: mtrofin
Pull Request: #427
1 parent 29d562c commit a8559c1

File tree

1 file changed

+2
-8
lines changed

1 file changed

+2
-8
lines changed

compiler_opt/es/es_trainer_lib.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,6 @@
4646
"gradient_ascent_optimizer_type", None,
4747
"Gradient ascent optimization algorithm: 'momentum' or 'adam'")
4848
flags.mark_flag_as_required("gradient_ascent_optimizer_type")
49-
_GREEDY = flags.DEFINE_bool(
50-
"greedy",
51-
None,
52-
"Whether to construct a greedy policy (argmax). \
53-
If False, a sampling-based policy will be used.",
54-
required=True)
5549
_MOMENTUM = flags.DEFINE_float(
5650
"momentum", 0.0, "Momentum for momentum gradient ascent optimizer.")
5751
_OUTPUT_PATH = flags.DEFINE_string("output_path", "",
@@ -82,7 +76,7 @@ def train(additional_compilation_flags=(),
8276
tf.io.gfile.makedirs(_OUTPUT_PATH.value)
8377

8478
# Construct the policy and upload it
85-
policy = policy_utils.create_actor_policy(greedy=_GREEDY.value)
79+
policy = policy_utils.create_actor_policy()
8680
saver = policy_saver.PolicySaver({POLICY_NAME: policy})
8781

8882
# Save the policy
@@ -121,7 +115,7 @@ def train(additional_compilation_flags=(),
121115
replace_flags=replace_compilation_flags)
122116

123117
# Construct policy saver
124-
saved_policy = policy_utils.create_actor_policy(greedy=True)
118+
saved_policy = policy_utils.create_actor_policy()
125119
policy_saver_function = functools.partial(
126120
policy_utils.save_policy,
127121
policy=saved_policy,

0 commit comments

Comments
 (0)