Skip to content

Commit 50faa11

Browse files
authored
Fixes for downstream integration: (#458)
- np.Inf becomes np.inf - explicit `module` for a test gin-configurable
1 parent 6354047 commit 50faa11

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

compiler_opt/rl/imitation_learning/generate_bc_trajectories_lib.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -542,8 +542,8 @@ def explore_at_state_generator(
542542

543543
distr_logits = explore_policy(explore_state).action.logits.numpy()[0]
544544
for _ in range(num_samples):
545-
distr_logits[replay_prefix[explore_step]] = -np.Inf
546-
if all(-np.Inf == logit for logit in distr_logits):
545+
distr_logits[replay_prefix[explore_step]] = -np.inf
546+
if all(-np.inf == logit for logit in distr_logits):
547547
break
548548
replay_prefix[explore_step] = np.random.choice(
549549
range(distr_logits.shape[0]), p=scipy.special.softmax(distr_logits))

compiler_opt/rl/imitation_learning/generate_bc_trajectories_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -453,9 +453,9 @@ def _explore_policy(state: time_step.TimeStep):
453453
# will explore every 4-th step
454454
logits = [[
455455
4.0 + 1e-3 * float(env_test._NUM_STEPS - times_called),
456-
-np.Inf,
457-
-np.Inf,
458-
-np.Inf,
456+
-np.inf,
457+
-np.inf,
458+
-np.inf,
459459
float(np.mod(times_called, 5)),
460460
]]
461461
return policy_step.PolicyStep(
@@ -632,7 +632,7 @@ def test_process_succeeded(self):
632632
self.assertEqual(seq_example, succeeded_comp[0][0][2])
633633

634634

635-
@gin.configurable
635+
@gin.configurable(module='generate_bc_trajectories_test')
636636
class MockModuleWorker(generate_bc_trajectories_lib.ModuleWorker):
637637

638638
@mock.patch('subprocess.Popen')

0 commit comments

Comments
 (0)