|
46 | 46 | "gradient_ascent_optimizer_type", None,
|
47 | 47 | "Gradient ascent optimization algorithm: 'momentum' or 'adam'")
|
48 | 48 | flags.mark_flag_as_required("gradient_ascent_optimizer_type")
|
49 |
| -_GREEDY = flags.DEFINE_bool( |
50 |
| - "greedy", |
51 |
| - None, |
52 |
| - "Whether to construct a greedy policy (argmax). \ |
53 |
| - If False, a sampling-based policy will be used.", |
54 |
| - required=True) |
55 | 49 | _MOMENTUM = flags.DEFINE_float(
|
56 | 50 | "momentum", 0.0, "Momentum for momentum gradient ascent optimizer.")
|
57 | 51 | _OUTPUT_PATH = flags.DEFINE_string("output_path", "",
|
@@ -82,7 +76,7 @@ def train(additional_compilation_flags=(),
|
82 | 76 | tf.io.gfile.makedirs(_OUTPUT_PATH.value)
|
83 | 77 |
|
84 | 78 | # Construct the policy and upload it
|
85 |
| - policy = policy_utils.create_actor_policy(greedy=_GREEDY.value) |
| 79 | + policy = policy_utils.create_actor_policy() |
86 | 80 | saver = policy_saver.PolicySaver({POLICY_NAME: policy})
|
87 | 81 |
|
88 | 82 | # Save the policy
|
@@ -121,7 +115,7 @@ def train(additional_compilation_flags=(),
|
121 | 115 | replace_flags=replace_compilation_flags)
|
122 | 116 |
|
123 | 117 | # Construct policy saver
|
124 |
| - saved_policy = policy_utils.create_actor_policy(greedy=True) |
| 118 | + saved_policy = policy_utils.create_actor_policy() |
125 | 119 | policy_saver_function = functools.partial(
|
126 | 120 | policy_utils.save_policy,
|
127 | 121 | policy=saved_policy,
|
|
0 commit comments