@@ -218,15 +218,13 @@ def observation_extractor_fn(
218218 target_repertoire_size = 1024
219219
220220 previous_error = jnp .sum (repertoire .fitnesses != - jnp .inf ) - target_repertoire_size
221- container_size_control_jitted = jax .jit (aurora .container_size_control )
222- update_jitted = jax .jit (aurora .update )
223221
224222 iteration = 0
225223 while iteration < max_iterations :
226224 # standard MAP-Elites-like loop
227225 for _ in range (log_freq ):
228226 key , subkey = jax .random .split (key )
229- repertoire , emitter_state , _ = update_jitted (
227+ repertoire , emitter_state , _ = jax . jit ( aurora . update ) (
230228 repertoire ,
231229 emitter_state ,
232230 subkey ,
@@ -246,7 +244,7 @@ def observation_extractor_fn(
246244
247245 elif iteration % 2 == 0 :
248246 # only CSC
249- repertoire , previous_error = container_size_control_jitted (
247+ repertoire , previous_error = jax . jit ( aurora . container_size_control ) (
250248 repertoire ,
251249 target_size = target_repertoire_size ,
252250 previous_error = previous_error ,
@@ -431,21 +429,20 @@ def observation_extractor_fn(
431429 previous_error = jnp .sum (repertoire .fitnesses != - jnp .inf ) - target_repertoire_size
432430
433431 iteration = 0
434- container_size_control_jitted = jax .jit (aurora .container_size_control )
435- ask_jitted = jax .jit (aurora .ask )
436- tell_jitted = jax .jit (aurora .tell )
437432
438433 while iteration < max_iterations :
439434 # standard MAP-Elites-like loop
440435 for _ in range (log_freq ):
441436 key , subkey = jax .random .split (key )
442- genotypes , extra_info = ask_jitted (repertoire , emitter_state , subkey )
437+ genotypes , extra_info = jax .jit (aurora .ask )(
438+ repertoire , emitter_state , subkey
439+ )
443440
444441 # scores the offsprings
445442 key , subkey = jax .random .split (key )
446443 fitnesses , descriptors , extra_scores = aurora_scoring_fn (genotypes , subkey )
447444
448- repertoire , emitter_state , _ = tell_jitted (
445+ repertoire , emitter_state , _ = jax . jit ( aurora . tell ) (
449446 genotypes = genotypes ,
450447 fitnesses = fitnesses ,
451448 descriptors = descriptors ,
@@ -469,7 +466,7 @@ def observation_extractor_fn(
469466
470467 elif iteration % 2 == 0 :
471468 # only CSC
472- repertoire , previous_error = container_size_control_jitted (
469+ repertoire , previous_error = jax . jit ( aurora . container_size_control ) (
473470 repertoire ,
474471 target_size = target_repertoire_size ,
475472 previous_error = previous_error ,
0 commit comments