Skip to content

Commit 547903f

Browse files
authored
[BC] weighted_bc_trainer_lib.py bug fixes (#445)
Fixed `metric.reset_states()` -> `metric.reset_state()`, removed a leftover debugging early exit (`if step > 1000: break`) in `weighted_bc_trainer_lib`, and made slight naming changes for `generate_bc_trajectories`.
1 parent 935f5fa commit 547903f

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

compiler_opt/rl/imitation_learning/weighted_bc_trainer_lib.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ def train(self, filepaths: list[str]):
392392
for epoch in range(self._epochs):
393393
logging.info('Epoch %s', epoch)
394394
for metric in self._metrics:
395-
metric.reset_states()
395+
metric.reset_state()
396396
for step, (x_batch_train, y_batch_train) in enumerate(dataset):
397397
weight_labels = [y_batch_train[:, 1]]
398398
weights_arr = [self._trainig_weights.get_weights()]
@@ -411,8 +411,6 @@ def train(self, filepaths: list[str]):
411411
(step + 1) * self._batch_size)
412412
for metric in self._metrics:
413413
logging.info('%s: %s', metric.name, metric.result())
414-
if step > 1000: # debugging
415-
break
416414

417415
if self._save_model_dir:
418416
keras.models.save_model(self._model,

compiler_opt/rl/inlining/imitation_learning_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def get_task_type() -> type[env.InliningForSizeTask]:
8181

8282

8383
@gin.register
84-
def greedy_policy(state: time_step.TimeStep):
84+
def default_policy(state: time_step.TimeStep):
8585
"""Greedy policy playing the inlining_default action."""
8686
return np.array(state.observation['inlining_default'])
8787

compiler_opt/rl/inlining/imitation_learning_runner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ def main(_):
3737
logging.info(gin.config_str())
3838

3939
generate_bc_trajectories_lib.gen_trajectories(
40-
callable_policies=[imitation_learning_config.greedy_policy],
40+
# Set callable policies here directly since callables can not be
41+
# gin configured when they are pickled.
42+
callable_policies=[imitation_learning_config.default_policy],
4143
explore_on_features={
4244
'is_callee_avail_external':
4345
imitation_learning_config.explore_on_avail_external

0 commit comments

Comments
 (0)