@@ -218,13 +218,15 @@ def observation_extractor_fn(
218218 target_repertoire_size = 1024
219219
220220 previous_error = jnp .sum (repertoire .fitnesses != - jnp .inf ) - target_repertoire_size
221+ update_fn = jax .jit (aurora .update )
222+ container_size_control_fn = jax .jit (aurora .container_size_control )
221223
222224 iteration = 0
223225 while iteration < max_iterations :
224226 # standard MAP-Elites-like loop
225227 for _ in range (log_freq ):
226228 key , subkey = jax .random .split (key )
227- repertoire , emitter_state , _ = jax . jit ( aurora . update ) (
229+ repertoire , emitter_state , _ = update_fn (
228230 repertoire ,
229231 emitter_state ,
230232 subkey ,
@@ -244,7 +246,7 @@ def observation_extractor_fn(
244246
245247 elif iteration % 2 == 0 :
246248 # only CSC
247- repertoire , previous_error = jax . jit ( aurora . container_size_control ) (
249+ repertoire , previous_error = container_size_control_fn (
248250 repertoire ,
249251 target_size = target_repertoire_size ,
250252 previous_error = previous_error ,
@@ -427,22 +429,23 @@ def observation_extractor_fn(
427429 target_repertoire_size = 1024
428430
429431 previous_error = jnp .sum (repertoire .fitnesses != - jnp .inf ) - target_repertoire_size
432+ ask_fn = jax .jit (aurora .ask )
433+ tell_fn = jax .jit (aurora .tell )
434+ container_size_control_fn = jax .jit (aurora .container_size_control )
430435
431436 iteration = 0
432437
433438 while iteration < max_iterations :
434439 # standard MAP-Elites-like loop
435440 for _ in range (log_freq ):
436441 key , subkey = jax .random .split (key )
437- genotypes , extra_info = jax .jit (aurora .ask )(
438- repertoire , emitter_state , subkey
439- )
442+ genotypes , extra_info = ask_fn (repertoire , emitter_state , subkey )
440443
441444 # scores the offsprings
442445 key , subkey = jax .random .split (key )
443446 fitnesses , descriptors , extra_scores = aurora_scoring_fn (genotypes , subkey )
444447
445- repertoire , emitter_state , _ = jax . jit ( aurora . tell ) (
448+ repertoire , emitter_state , _ = tell_fn (
446449 genotypes = genotypes ,
447450 fitnesses = fitnesses ,
448451 descriptors = descriptors ,
@@ -466,7 +469,7 @@ def observation_extractor_fn(
466469
467470 elif iteration % 2 == 0 :
468471 # only CSC
469- repertoire , previous_error = jax . jit ( aurora . container_size_control ) (
472+ repertoire , previous_error = container_size_control_fn (
470473 repertoire ,
471474 target_size = target_repertoire_size ,
472475 previous_error = previous_error ,
0 commit comments