Skip to content

Commit 448aff2

Browse files
author
Lisa
committed
fix tests
1 parent 70939fd commit 448aff2

File tree

3 files changed

+12
-17
lines changed

3 files changed

+12
-17
lines changed

examples/pga_aurora.ipynb

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,6 @@
514514
"n_target = 1024\n",
515515
"\n",
516516
"previous_error = jnp.sum(repertoire.fitnesses != -jnp.inf) - n_target\n",
517-
"container_size_control_jitted = jax.jit(aurora.container_size_control)\n",
518517
"\n",
519518
"iteration = 0\n",
520519
"while iteration < max_iterations:\n",
@@ -543,7 +542,7 @@
543542
" )\n",
544543
"\n",
545544
" elif iteration % 2 == 0:\n",
546-
" repertoire, previous_error = container_size_control_jitted(\n",
545+
" repertoire, previous_error = jax.jit(aurora.container_size_control)(\n",
547546
" repertoire,\n",
548547
" target_size=n_target,\n",
549548
" previous_error=previous_error,\n",

tests/core_test/aurora_test.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -218,15 +218,13 @@ def observation_extractor_fn(
218218
target_repertoire_size = 1024
219219

220220
previous_error = jnp.sum(repertoire.fitnesses != -jnp.inf) - target_repertoire_size
221-
container_size_control_jitted = jax.jit(aurora.container_size_control)
222-
update_jitted = jax.jit(aurora.update)
223221

224222
iteration = 0
225223
while iteration < max_iterations:
226224
# standard MAP-Elites-like loop
227225
for _ in range(log_freq):
228226
key, subkey = jax.random.split(key)
229-
repertoire, emitter_state, _ = update_jitted(
227+
repertoire, emitter_state, _ = jax.jit(aurora.update)(
230228
repertoire,
231229
emitter_state,
232230
subkey,
@@ -246,7 +244,7 @@ def observation_extractor_fn(
246244

247245
elif iteration % 2 == 0:
248246
# only CSC
249-
repertoire, previous_error = container_size_control_jitted(
247+
repertoire, previous_error = jax.jit(aurora.container_size_control)(
250248
repertoire,
251249
target_size=target_repertoire_size,
252250
previous_error=previous_error,
@@ -431,21 +429,20 @@ def observation_extractor_fn(
431429
previous_error = jnp.sum(repertoire.fitnesses != -jnp.inf) - target_repertoire_size
432430

433431
iteration = 0
434-
container_size_control_jitted = jax.jit(aurora.container_size_control)
435-
ask_jitted = jax.jit(aurora.ask)
436-
tell_jitted = jax.jit(aurora.tell)
437432

438433
while iteration < max_iterations:
439434
# standard MAP-Elites-like loop
440435
for _ in range(log_freq):
441436
key, subkey = jax.random.split(key)
442-
genotypes, extra_info = ask_jitted(repertoire, emitter_state, subkey)
437+
genotypes, extra_info = jax.jit(aurora.ask)(
438+
repertoire, emitter_state, subkey
439+
)
443440

444441
# scores the offsprings
445442
key, subkey = jax.random.split(key)
446443
fitnesses, descriptors, extra_scores = aurora_scoring_fn(genotypes, subkey)
447444

448-
repertoire, emitter_state, _ = tell_jitted(
445+
repertoire, emitter_state, _ = jax.jit(aurora.tell)(
449446
genotypes=genotypes,
450447
fitnesses=fitnesses,
451448
descriptors=descriptors,
@@ -469,7 +466,7 @@ def observation_extractor_fn(
469466

470467
elif iteration % 2 == 0:
471468
# only CSC
472-
repertoire, previous_error = container_size_control_jitted(
469+
repertoire, previous_error = jax.jit(aurora.container_size_control)(
473470
repertoire,
474471
target_size=target_repertoire_size,
475472
previous_error=previous_error,

tests/core_test/map_elites_test.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -270,22 +270,21 @@ def play_step_fn(
270270
extra_scores=extra_scores,
271271
)
272272

273-
ask_jitted = jax.jit(map_elites.ask)
274-
tell_jitted = jax.jit(map_elites.tell)
275-
276273
# Run the algorithm
277274
for _ in range(num_iterations):
278275
key, subkey = jax.random.split(key)
279276
# Generate solutions
280-
genotypes, extra_info = ask_jitted(repertoire, emitter_state, subkey)
277+
genotypes, extra_info = jax.jit(map_elites.ask)(
278+
repertoire, emitter_state, subkey
279+
)
281280

282281
# Evaluate solutions: get fitness, descriptor and extra scores.
283282
# This is where custom evaluations on CPU or GPU can be added.
284283
key, subkey = jax.random.split(key)
285284
fitnesses, descriptors, extra_scores = scoring_fn(genotypes, subkey)
286285

287286
# Update MAP-Elites
288-
repertoire, emitter_state, current_metrics = tell_jitted(
287+
repertoire, emitter_state, current_metrics = jax.jit(map_elites.tell)(
289288
genotypes=genotypes,
290289
fitnesses=fitnesses,
291290
descriptors=descriptors,

0 commit comments

Comments
 (0)