Skip to content

Commit c00eebc

Browse files
Use .item() before casting a single-item NumPy array (#542)
This is needed for compatibility with NumPy 2.4, which no longer allows implicit conversion of a one-element array to a Python scalar.
1 parent 2372e55 commit c00eebc

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

compiler_opt/rl/imitation_learning/generate_bc_trajectories_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ def compile_module(
396396
while curr_obs_dict.step_type != env.StepType.LAST:
397397
timestep = self._create_timestep(curr_obs_dict)
398398
action = policy(timestep)
399-
add_int_feature(sequence_example, int(action),
399+
add_int_feature(sequence_example, int(action.item()),
400400
SequenceExampleFeatureNames.action)
401401
curr_obs_dict = self._env.step(action)
402402
curr_obs = curr_obs_dict.obs

compiler_opt/rl/imitation_learning/generate_bc_trajectories_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,8 @@ class ExplorationWithPolicyTest(tf.test.TestCase):
9191
def _explore_policy(self,
9292
state: time_step.TimeStep) -> policy_step.PolicyStep:
9393
probs = [
94-
0.5 * float(state.observation['feature_3'].numpy()),
95-
1 - 0.5 * float(state.observation['feature_3'].numpy())
94+
0.5 * float(state.observation['feature_3'].numpy().item()),
95+
1 - 0.5 * float(state.observation['feature_3'].numpy().item())
9696
]
9797
logits = [[0.0, tf.math.log(probs[1] / (1.0 - probs[1] + _eps))]]
9898
return policy_step.PolicyStep(

0 commit comments

Comments
 (0)