Skip to content

Commit e2c1320

Browse files
Refactored train_eval to use num_modules from gin configs (#53)
1 parent 4f17124 commit e2c1320

File tree

3 files changed

+4
-3
lines changed

3 files changed

+4
-3
lines changed

compiler_opt/rl/inlining/gin_configs/ppo_nn_agent.gin

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ include 'compiler_opt/rl/inlining/gin_configs/common.gin'
1212
train_eval.agent_name=%constant.AgentName.PPO
1313
train_eval.warmstart_policy_dir=''
1414
train_eval.num_policy_iterations=3000
15+
train_eval.num_modules=100
1516
train_eval.num_iterations=300
1617
train_eval.batch_size=256
1718
train_eval.train_sequence_length=16

compiler_opt/rl/regalloc/gin_configs/ppo_nn_agent.gin

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ include 'compiler_opt/rl/regalloc/gin_configs/network.gin'
1313
train_eval.agent_name=%constant.AgentName.PPO
1414
train_eval.warmstart_policy_dir=''
1515
train_eval.num_policy_iterations=3000
16+
train_eval.num_modules=512
1617
train_eval.num_iterations=200
1718
train_eval.batch_size=256
1819
train_eval.train_sequence_length=16

compiler_opt/rl/train_locally.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,6 @@
4848
flags.DEFINE_integer(
4949
'num_workers', None,
5050
'Number of parallel data collection workers. `None` for max available')
51-
flags.DEFINE_integer('num_modules', 100,
52-
'Number of modules to collect data for each iteration.')
5351
flags.DEFINE_multi_string('gin_files', [],
5452
'List of paths to gin configuration files.')
5553
flags.DEFINE_multi_string(
@@ -63,6 +61,7 @@
6361
def train_eval(agent_name=constant.AgentName.PPO,
6462
warmstart_policy_dir=None,
6563
num_policy_iterations=0,
64+
num_modules=100,
6665
num_iterations=100,
6766
batch_size=64,
6867
train_sequence_length=1,
@@ -149,7 +148,7 @@ def sequence_example_iterator_fn(seq_ex: List[str]):
149148
delete_flags=delete_compilation_flags) as worker_pool:
150149
data_collector = local_data_collector.LocalDataCollector(
151150
file_paths=tuple(file_paths),
152-
num_modules=FLAGS.num_modules,
151+
num_modules=num_modules,
153152
worker_pool=worker_pool,
154153
parser=sequence_example_iterator_fn,
155154
reward_stat_map=reward_stat_map)

0 commit comments

Comments
 (0)