Skip to content

Commit 9845ada

Browse files
author
Lisa
committed
last updates to address PR reviews
1 parent 448aff2 commit 9845ada

File tree

10 files changed

+34
-33
lines changed

10 files changed

+34
-33
lines changed

examples/aurora.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,7 @@
452452
"n_target = 1024\n",
453453
"\n",
454454
"previous_error = jnp.sum(repertoire.fitnesses != -jnp.inf) - n_target\n",
455+
"container_size_control_fn = jax.jit(aurora.container_size_control)\n",
455456
"\n",
456457
"iteration = 0\n",
457458
"while iteration < max_iterations:\n",
@@ -480,7 +481,7 @@
480481
" )\n",
481482
"\n",
482483
" elif iteration % 2 == 0:\n",
483-
" repertoire, previous_error = jax.jit(aurora.container_size_control)(\n",
484+
" repertoire, previous_error = container_size_control_fn(\n",
484485
" repertoire,\n",
485486
" target_size=n_target,\n",
486487
" previous_error=previous_error,\n",

examples/cmaes.ipynb

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -208,18 +208,21 @@
208208
"covs = [(state.sigma**2) * state.cov_matrix]\n",
209209
"\n",
210210
"iteration_count = 0\n",
211+
"sample_fn = jax.jit(cmaes.sample)\n",
212+
"update_fn = jax.jit(cmaes.update)\n",
213+
"stop_condition_fn = jax.jit(cmaes.stop_condition)\n",
211214
"for _ in range(num_iterations):\n",
212215
" iteration_count += 1\n",
213216
"\n",
214217
" # sample\n",
215218
" key, subkey = jax.random.split(key)\n",
216-
" samples = jax.jit(cmaes.sample)(state, subkey)\n",
219+
" samples = sample_fn(state, subkey)\n",
217220
"\n",
218221
" # update\n",
219-
" state = jax.jit(cmaes.update)(state, samples)\n",
222+
" state = update_fn(state, samples)\n",
220223
"\n",
221224
" # check stop condition\n",
222-
" stop_condition = jax.jit(cmaes.stop_condition)(state)\n",
225+
" stop_condition = stop_condition_fn(state)\n",
223226
"\n",
224227
" if stop_condition:\n",
225228
" break\n",

examples/mapelites_asktell.ipynb

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -327,20 +327,23 @@
327327
"except ImportError:\n",
328328
" bar = range(num_iterations)\n",
329329
"\n",
330+
"ask_fn = jax.jit(map_elites.ask)\n",
331+
"tell_fn = jax.jit(map_elites.tell)\n",
332+
"\n",
330333
"# Main loop\n",
331334
"for i in bar:\n",
332335
" start_time = time.time()\n",
333336
" key, subkey = jax.random.split(key)\n",
334337
" # Generate solutions\n",
335-
" genotypes, extra_info = map_elites.ask(repertoire, emitter_state, subkey)\n",
338+
" genotypes, extra_info = ask_fn(repertoire, emitter_state, subkey)\n",
336339
"\n",
337340
" # Evaluate solutions: get fitness, descriptor and extra scores.\n",
338341
" # This is where custom evaluations on CPU or GPU can be added.\n",
339342
" key, subkey = jax.random.split(key)\n",
340343
" fitnesses, descriptors, extra_scores = scoring_fn(genotypes, subkey)\n",
341344
"\n",
342345
" # Update MAP-Elites\n",
343-
" repertoire, emitter_state, current_metrics = map_elites.tell(\n",
346+
" repertoire, emitter_state, current_metrics = tell_fn(\n",
344347
" genotypes=genotypes,\n",
345348
" fitnesses=fitnesses,\n",
346349
" descriptors=descriptors,\n",

examples/mees.ipynb

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -448,11 +448,8 @@
448448
"provenance": []
449449
},
450450
"gpuClass": "standard",
451-
"interpreter": {
452-
"hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64"
453-
},
454451
"kernelspec": {
455-
"display_name": "Python 3 (ipykernel)",
452+
"display_name": ".venv",
456453
"language": "python",
457454
"name": "python3"
458455
},
@@ -466,7 +463,7 @@
466463
"name": "python",
467464
"nbconvert_exporter": "python",
468465
"pygments_lexer": "ipython3",
469-
"version": "3.10.12"
466+
"version": "3.11.10"
470467
}
471468
},
472469
"nbformat": 4,

examples/nsga2_spea2.ipynb

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,6 @@
296296
"\n",
297297
"# init spea2\n",
298298
"key, subkey = jax.random.split(key)\n",
299-
"init_fn = partial(spea2.init, population_size=population_size)\n",
300299
"repertoire, emitter_state, init_metrics = spea2.init(\n",
301300
" genotypes,\n",
302301
" population_size,\n",

examples/pga_aurora.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,7 @@
514514
"n_target = 1024\n",
515515
"\n",
516516
"previous_error = jnp.sum(repertoire.fitnesses != -jnp.inf) - n_target\n",
517+
"container_size_control_fn = jax.jit(aurora.container_size_control)\n",
517518
"\n",
518519
"iteration = 0\n",
519520
"while iteration < max_iterations:\n",
@@ -542,7 +543,7 @@
542543
" )\n",
543544
"\n",
544545
" elif iteration % 2 == 0:\n",
545-
" repertoire, previous_error = jax.jit(aurora.container_size_control)(\n",
546+
" repertoire, previous_error = container_size_control_fn(\n",
546547
" repertoire,\n",
547548
" target_size=n_target,\n",
548549
" previous_error=previous_error,\n",

qdax/core/containers/archive.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from __future__ import annotations
44

5-
from functools import partial
65
from typing import Any, Dict, Tuple
76

87
import jax
@@ -240,8 +239,7 @@ def score_euclidean_novelty(
240239
Returns:
241240
The novelty scores of the given state descriptors.
242241
"""
243-
knn_fn = partial(knn, k=num_nearest_neighb)
244-
values, _indices = knn_fn(archive.data, state_descriptors)
242+
values, _indices = knn(archive.data, state_descriptors, num_nearest_neighb)
245243

246244
summed_distances = jnp.mean(jnp.square(values), axis=1)
247245
return scaling_ratio * summed_distances
@@ -278,8 +276,7 @@ def knn(
278276
dist = jnp.sqrt(jnp.clip(dist, min=0.0))
279277

280278
# return values, indices
281-
qdax_top_k_fn = partial(qdax_top_k, k=k)
282-
values, indices = qdax_top_k_fn(-dist)
279+
values, indices = qdax_top_k(-dist, k=k)
283280

284281
return -values, indices
285282

qdax/utils/plotting.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from typing import Any, Dict, Iterable, List, Optional, Tuple
22

3-
import jax
43
import jax.numpy as jnp
54
import matplotlib as mpl
65
import matplotlib.cm as cm
@@ -456,9 +455,7 @@ def plot_mome_pareto_fronts(
456455
axes[1].set_ylim(minval, maxval)
457456

458457
if with_global:
459-
global_pareto_front, pareto_bool = jax.jit(
460-
repertoire.compute_global_pareto_front
461-
)()
458+
global_pareto_front, pareto_bool = repertoire.compute_global_pareto_front()
462459
global_pareto_descriptors = jnp.concatenate(repertoire_descriptors)[pareto_bool]
463460
axes[0].scatter(
464461
global_pareto_front[:, 0],

tests/core_test/aurora_test.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -218,13 +218,15 @@ def observation_extractor_fn(
218218
target_repertoire_size = 1024
219219

220220
previous_error = jnp.sum(repertoire.fitnesses != -jnp.inf) - target_repertoire_size
221+
update_fn = jax.jit(aurora.update)
222+
container_size_control_fn = jax.jit(aurora.container_size_control)
221223

222224
iteration = 0
223225
while iteration < max_iterations:
224226
# standard MAP-Elites-like loop
225227
for _ in range(log_freq):
226228
key, subkey = jax.random.split(key)
227-
repertoire, emitter_state, _ = jax.jit(aurora.update)(
229+
repertoire, emitter_state, _ = update_fn(
228230
repertoire,
229231
emitter_state,
230232
subkey,
@@ -244,7 +246,7 @@ def observation_extractor_fn(
244246

245247
elif iteration % 2 == 0:
246248
# only CSC
247-
repertoire, previous_error = jax.jit(aurora.container_size_control)(
249+
repertoire, previous_error = container_size_control_fn(
248250
repertoire,
249251
target_size=target_repertoire_size,
250252
previous_error=previous_error,
@@ -427,22 +429,23 @@ def observation_extractor_fn(
427429
target_repertoire_size = 1024
428430

429431
previous_error = jnp.sum(repertoire.fitnesses != -jnp.inf) - target_repertoire_size
432+
ask_fn = jax.jit(aurora.ask)
433+
tell_fn = jax.jit(aurora.tell)
434+
container_size_control_fn = jax.jit(aurora.container_size_control)
430435

431436
iteration = 0
432437

433438
while iteration < max_iterations:
434439
# standard MAP-Elites-like loop
435440
for _ in range(log_freq):
436441
key, subkey = jax.random.split(key)
437-
genotypes, extra_info = jax.jit(aurora.ask)(
438-
repertoire, emitter_state, subkey
439-
)
442+
genotypes, extra_info = ask_fn(repertoire, emitter_state, subkey)
440443

441444
# scores the offsprings
442445
key, subkey = jax.random.split(key)
443446
fitnesses, descriptors, extra_scores = aurora_scoring_fn(genotypes, subkey)
444447

445-
repertoire, emitter_state, _ = jax.jit(aurora.tell)(
448+
repertoire, emitter_state, _ = tell_fn(
446449
genotypes=genotypes,
447450
fitnesses=fitnesses,
448451
descriptors=descriptors,
@@ -466,7 +469,7 @@ def observation_extractor_fn(
466469

467470
elif iteration % 2 == 0:
468471
# only CSC
469-
repertoire, previous_error = jax.jit(aurora.container_size_control)(
472+
repertoire, previous_error = container_size_control_fn(
470473
repertoire,
471474
target_size=target_repertoire_size,
472475
previous_error=previous_error,

tests/core_test/map_elites_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -269,22 +269,22 @@ def play_step_fn(
269269
key=key,
270270
extra_scores=extra_scores,
271271
)
272+
ask_fn = jax.jit(map_elites.ask)
273+
tell_fn = jax.jit(map_elites.tell)
272274

273275
# Run the algorithm
274276
for _ in range(num_iterations):
275277
key, subkey = jax.random.split(key)
276278
# Generate solutions
277-
genotypes, extra_info = jax.jit(map_elites.ask)(
278-
repertoire, emitter_state, subkey
279-
)
279+
genotypes, extra_info = ask_fn(repertoire, emitter_state, subkey)
280280

281281
# Evaluate solutions: get fitness, descriptor and extra scores.
282282
# This is where custom evaluations on CPU or GPU can be added.
283283
key, subkey = jax.random.split(key)
284284
fitnesses, descriptors, extra_scores = scoring_fn(genotypes, subkey)
285285

286286
# Update MAP-Elites
287-
repertoire, emitter_state, current_metrics = jax.jit(map_elites.tell)(
287+
repertoire, emitter_state, current_metrics = tell_fn(
288288
genotypes=genotypes,
289289
fitnesses=fitnesses,
290290
descriptors=descriptors,

0 commit comments

Comments
 (0)