Commit b5472d0

Merge pull request #232 from LisaCoiffard/refactor/remove-jit-decorators
Remove jax.jit decorators
2 parents fdab34f + 9911e74 commit b5472d0

57 files changed: +132, -370 lines changed

examples/aurora.ipynb

Lines changed: 2 additions & 1 deletion
@@ -452,6 +452,7 @@
 "n_target = 1024\n",
 "\n",
 "previous_error = jnp.sum(repertoire.fitnesses != -jnp.inf) - n_target\n",
+"container_size_control_fn = jax.jit(aurora.container_size_control)\n",
 "\n",
 "iteration = 0\n",
 "while iteration < max_iterations:\n",
@@ -480,7 +481,7 @@
 " )\n",
 "\n",
 " elif iteration % 2 == 0:\n",
-" repertoire, previous_error = aurora.container_size_control(\n",
+" repertoire, previous_error = container_size_control_fn(\n",
 " repertoire,\n",
 " target_size=n_target,\n",
 " previous_error=previous_error,\n",

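Note on the pattern above: with the decorator gone from aurora.container_size_control, the notebook builds the jitted wrapper once, outside the loop, and reuses it on every iteration. Below is a minimal sketch of that call-site pattern, using a hypothetical stand-in function rather than the real AURORA method.

import jax
import jax.numpy as jnp

# Toy stand-in for aurora.container_size_control (hypothetical, for illustration only).
def container_size_control(repertoire_fitnesses, target_size, previous_error):
    # compute how far the current container is from the target size
    error = jnp.sum(repertoire_fitnesses != -jnp.inf) - target_size
    return repertoire_fitnesses, error

# Wrap once, outside the loop: tracing/compilation happens on the first call,
# later calls reuse the cached executable.
container_size_control_fn = jax.jit(container_size_control)

fitnesses = jnp.where(jnp.arange(8) % 2 == 0, 1.0, -jnp.inf)
previous_error = jnp.sum(fitnesses != -jnp.inf) - 4
for _ in range(3):
    fitnesses, previous_error = container_size_control_fn(
        fitnesses, target_size=4, previous_error=previous_error
    )
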
examples/cmaes.ipynb

Lines changed: 6 additions & 3 deletions
@@ -208,18 +208,21 @@
 "covs = [(state.sigma**2) * state.cov_matrix]\n",
 "\n",
 "iteration_count = 0\n",
+"sample_fn = jax.jit(cmaes.sample)\n",
+"update_fn = jax.jit(cmaes.update)\n",
+"stop_condition_fn = jax.jit(cmaes.stop_condition)\n",
 "for _ in range(num_iterations):\n",
 " iteration_count += 1\n",
 "\n",
 " # sample\n",
 " key, subkey = jax.random.split(key)\n",
-" samples = cmaes.sample(state, subkey)\n",
+" samples = sample_fn(state, subkey)\n",
 "\n",
 " # update\n",
-" state = cmaes.update(state, samples)\n",
+" state = update_fn(state, samples)\n",
 "\n",
 " # check stop condition\n",
-" stop_condition = cmaes.stop_condition(state)\n",
+" stop_condition = stop_condition_fn(state)\n",
 "\n",
 " if stop_condition:\n",
 " break\n",

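A side note on the jitted stop_condition_fn above: the wrapper returns a device array, and the plain Python `if stop_condition:` forces it to a concrete bool, which is fine because the surrounding loop runs in ordinary Python. A tiny sketch of that interaction with a made-up predicate (not the real CMA-ES stop condition):

import jax
import jax.numpy as jnp

# Made-up convergence predicate, jitted at the call site.
stop_fn = jax.jit(lambda state: jnp.all(jnp.abs(state) < 1e-3))

state = jnp.ones((3,))
for _ in range(20):
    state = state * 0.5        # stand-in for sample/update
    if stop_fn(state):         # device->host sync, then ordinary Python branching
        break
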
examples/mapelites_asktell.ipynb

Lines changed: 5 additions & 2 deletions
@@ -327,20 +327,23 @@
 "except ImportError:\n",
 " bar = range(num_iterations)\n",
 "\n",
+"ask_fn = jax.jit(map_elites.ask)\n",
+"tell_fn = jax.jit(map_elites.tell)\n",
+"\n",
 "# Main loop\n",
 "for i in bar:\n",
 " start_time = time.time()\n",
 " key, subkey = jax.random.split(key)\n",
 " # Generate solutions\n",
-" genotypes, extra_info = map_elites.ask(repertoire, emitter_state, subkey)\n",
+" genotypes, extra_info = ask_fn(repertoire, emitter_state, subkey)\n",
 "\n",
 " # Evaluate solutions: get fitness, descriptor and extra scores.\n",
 " # This is where custom evaluations on CPU or GPU can be added.\n",
 " key, subkey = jax.random.split(key)\n",
 " fitnesses, descriptors, extra_scores = scoring_fn(genotypes, subkey)\n",
 "\n",
 " # Update MAP-Elites\n",
-" repertoire, emitter_state, current_metrics = map_elites.tell(\n",
+" repertoire, emitter_state, current_metrics = tell_fn(\n",
 " genotypes=genotypes,\n",
 " fitnesses=fitnesses,\n",
 " descriptors=descriptors,\n",

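The ask/tell wrappers are built once before the loop, so map_elites.ask and map_elites.tell are traced a single time while the evaluation step between them stays plain Python; that is where the notebook's comment about custom CPU/GPU evaluations applies. A rough sketch of that shape with toy ask/tell functions (hypothetical, not the qdax API):

import jax
import jax.numpy as jnp
import numpy as np

# Toy ask/tell pair, only to show where jit applies.
def ask(state, key):
    # propose perturbed copies of the current population
    return state + 0.1 * jax.random.normal(key, state.shape)

def tell(state, candidates, fitnesses):
    # toy update: move every row toward the best candidate
    best = candidates[jnp.argmax(fitnesses)]
    return 0.5 * (state + best)

ask_fn = jax.jit(ask)      # traced once on the first call
tell_fn = jax.jit(tell)

key = jax.random.PRNGKey(0)
state = jnp.zeros((8, 2))
for _ in range(3):
    key, subkey = jax.random.split(key)
    candidates = ask_fn(state, subkey)
    # evaluation runs outside jit, so plain NumPy (or any external code) is fine here
    fitnesses = -np.sum(np.asarray(candidates) ** 2, axis=1)
    state = tell_fn(state, candidates, fitnesses)
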
examples/mees.ipynb

Lines changed: 4 additions & 7 deletions
@@ -248,11 +248,11 @@
 "\n",
 "# Prepare the scoring functions for the offspring generated following\n",
 "# the approximated gradient (each of them is evaluated 30 times)\n",
-"sampling_fn = functools.partial(\n",
+"sampling_fn = jax.jit(functools.partial(\n",
 " sampling,\n",
 " scoring_fn=scoring_fn,\n",
 " num_samples=30,\n",
-")\n",
+"))\n",
 "\n",
 "# Get minimum reward value to make sure qd_score are positive\n",
 "reward_offset = environments.reward_offset[env_name]\n",
@@ -448,11 +448,8 @@
 "provenance": []
 },
 "gpuClass": "standard",
-"interpreter": {
-"hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64"
-},
 "kernelspec": {
-"display_name": "Python 3 (ipykernel)",
+"display_name": ".venv",
 "language": "python",
 "name": "python3"
 },
@@ -466,7 +463,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.10.12"
+"version": "3.11.10"
 }
 },
 "nbformat": 4,

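Here the jit wraps the functools.partial directly. The keyword arguments bound by the partial (scoring_fn, num_samples) are closed over by the Python callable, so jax.jit never sees them as call arguments and they behave like compile-time constants; only (genotypes, key) remain traced. A hedged sketch of the same wrapping with a dummy scoring function (not the Brax-based scoring_fn of the notebook):

import functools
import jax
import jax.numpy as jnp

# Dummy scoring function, stand-in for the real scoring_fn.
def scoring_fn(genotypes, key):
    return -jnp.sum(genotypes ** 2, axis=-1)

def sampling(genotypes, key, scoring_fn, num_samples):
    # evaluate each genotype num_samples times with different keys and average
    keys = jax.random.split(key, num_samples)
    fitnesses = jax.vmap(lambda k: scoring_fn(genotypes, k))(keys)
    return jnp.mean(fitnesses, axis=0)

# The partial-bound kwargs are baked into the traced function, so no
# static_argnames is needed on the jit wrapper.
sampling_fn = jax.jit(functools.partial(
    sampling,
    scoring_fn=scoring_fn,
    num_samples=30,
))

key = jax.random.PRNGKey(0)
fitness = sampling_fn(jnp.ones((4, 3)), key)
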
examples/pga_aurora.ipynb

Lines changed: 2 additions & 1 deletion
@@ -514,6 +514,7 @@
 "n_target = 1024\n",
 "\n",
 "previous_error = jnp.sum(repertoire.fitnesses != -jnp.inf) - n_target\n",
+"container_size_control_fn = jax.jit(aurora.container_size_control)\n",
 "\n",
 "iteration = 0\n",
 "while iteration < max_iterations:\n",
@@ -542,7 +543,7 @@
 " )\n",
 "\n",
 " elif iteration % 2 == 0:\n",
-" repertoire, previous_error = aurora.container_size_control(\n",
+" repertoire, previous_error = container_size_control_fn(\n",
 " repertoire,\n",
 " target_size=n_target,\n",
 " previous_error=previous_error,\n",

qdax/baselines/cmaes.py

Lines changed: 0 additions & 7 deletions
@@ -3,7 +3,6 @@
 a CMA optimization script. Link to the paper: https://arxiv.org/abs/1604.00772
 """
 
-from functools import partial
 from typing import Callable, Optional, Tuple
 
 import flax
@@ -165,7 +164,6 @@ def init(self) -> CMAESState:
 invsqrt_cov=invsqrt_cov,
 )
 
-@partial(jax.jit, static_argnames=("self",))
 def sample(self, cmaes_state: CMAESState, key: RNGKey) -> Genotype:
 """
 Sample a population.
@@ -186,7 +184,6 @@ def sample(self, cmaes_state: CMAESState, key: RNGKey) -> Genotype:
 )
 return samples
 
-@partial(jax.jit, static_argnames=("self",))
 def update_state(
 self,
 cmaes_state: CMAESState,
@@ -198,7 +195,6 @@ def update_state(
 weights=self._weights,
 )
 
-@partial(jax.jit, static_argnames=("self",))
 def update_state_with_mask(
 self, cmaes_state: CMAESState, sorted_candidates: Genotype, mask: Mask
 ) -> CMAESState:
@@ -217,7 +213,6 @@
 weights=weights,
 )
 
-@partial(jax.jit, static_argnames=("self",))
 def _update_state(
 self,
 cmaes_state: CMAESState,
@@ -332,7 +327,6 @@ def update_eigen(
 
 return cmaes_state
 
-@partial(jax.jit, static_argnames=("self",))
 def update(self, cmaes_state: CMAESState, samples: Genotype) -> CMAESState:
 """Updates the distribution.
 
@@ -352,7 +346,6 @@ def update(self, cmaes_state: CMAESState, samples: Genotype) -> CMAESState:
 
 return new_state # type: ignore
 
-@partial(jax.jit, static_argnames=("self",))
 def stop_condition(self, cmaes_state: CMAESState) -> bool:
 """Determines if the current optimization path must be stopped.
 

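All the removed decorators had the form @partial(jax.jit, static_argnames=("self",)), i.e. the method was jitted at definition time with self treated as a static argument. After this commit the methods are plain Python and callers jit the bound method themselves, so self is simply captured by the bound method instead of being passed through jax.jit. A small sketch of the two styles on a made-up class (not a qdax class):

from functools import partial
import jax
import jax.numpy as jnp

# Old style (what the removed decorators did): jit at definition time,
# with `self` marked static so the instance is not traced.
class ScalerOld:
    def __init__(self, factor):
        self.factor = factor

    @partial(jax.jit, static_argnames=("self",))
    def apply(self, x):
        return self.factor * x

# New style (after this commit): plain method, jitted by the caller.
class Scaler:
    def __init__(self, factor):
        self.factor = factor

    def apply(self, x):
        return self.factor * x

scaler = Scaler(2.0)
apply_fn = jax.jit(scaler.apply)   # `self` is closed over by the bound method
y = apply_fn(jnp.arange(3.0))
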
qdax/baselines/dads.py

Lines changed: 4 additions & 13 deletions
@@ -4,7 +4,6 @@
 """
 
 from dataclasses import dataclass
-from functools import partial
 from typing import Callable, Tuple
 
 import jax
@@ -191,7 +190,6 @@ def init( # type: ignore
 steps=jnp.array(0),
 )
 
-@partial(jax.jit, static_argnames=("self",))
 def _compute_diversity_reward(
 self, transition: QDTransition, training_state: DadsTrainingState
 ) -> Reward:
@@ -244,8 +242,7 @@ def _compute_diversity_reward(
 
 return reward
 
-@partial(jax.jit, static_argnames=("self", "env", "deterministic", "evaluation"))
-def play_step_fn(
+def play_step_fn( # type: ignore
 self,
 env_state: EnvState,
 training_state: DadsTrainingState,
@@ -339,14 +336,13 @@ def play_step_fn(
 
 return next_env_state, training_state, transition
 
-@partial(jax.jit, static_argnames=("self", "play_step_fn", "env_batch_size"))
-def eval_policy_fn(
+def eval_policy_fn( # type: ignore
 self,
 training_state: DadsTrainingState,
 eval_env_first_state: EnvState,
 play_step_fn: Callable[
-[EnvState, Params, RNGKey],
-Tuple[EnvState, Params, RNGKey, QDTransition],
+[EnvState, Params],
+Tuple[EnvState, Params, QDTransition],
 ],
 env_batch_size: int,
 ) -> Tuple[Reward, Reward, Reward, StateDescriptor]:
@@ -400,7 +396,6 @@ def eval_policy_fn(
 
 return true_return, true_returns, diversity_returns, transitions.state_desc
 
-@partial(jax.jit, static_argnames=("self",))
 def _compute_reward(
 self, transition: QDTransition, training_state: DadsTrainingState
 ) -> Reward:
@@ -417,7 +412,6 @@
 transition=transition, training_state=training_state
 )
 
-@partial(jax.jit, static_argnames=("self",))
 def _update_dynamics(
 self, operand: Tuple[DadsTrainingState, QDTransition]
 ) -> Tuple[Params, float, optax.OptState]:
@@ -448,7 +442,6 @@
 dynamics_optimizer_state,
 )
 
-@partial(jax.jit, static_argnames=("self",))
 def _not_update_dynamics(
 self, operand: Tuple[DadsTrainingState, QDTransition]
 ) -> Tuple[Params, float, optax.OptState]:
@@ -464,7 +457,6 @@
 training_state.dynamics_optimizer_state,
 )
 
-@partial(jax.jit, static_argnames=("self",))
 def _update_networks(
 self,
 training_state: DadsTrainingState,
@@ -566,7 +558,6 @@ def _update_networks(
 
 return new_training_state, metrics
 
-@partial(jax.jit, static_argnames=("self",))
 def update(
 self,
 training_state: DadsTrainingState,

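Besides the decorator removals, note the changed Callable annotation for the play_step_fn argument of eval_policy_fn: the explicit RNGKey disappears from the expected signature, the key presumably being carried elsewhere (for example inside the training state). For reference, the old and new shapes of that annotation, with placeholder aliases standing in for the qdax types:

from typing import Any, Callable, Tuple

# Placeholder aliases standing in for qdax's EnvState, Params, RNGKey, QDTransition.
EnvState = Any
Params = Any
RNGKey = Any
QDTransition = Any

# Before this commit: the step callable carried an explicit RNGKey.
PlayStepFnOld = Callable[
    [EnvState, Params, RNGKey],
    Tuple[EnvState, Params, RNGKey, QDTransition],
]

# After: the RNGKey is dropped from the expected signature.
PlayStepFn = Callable[
    [EnvState, Params],
    Tuple[EnvState, Params, QDTransition],
]
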
qdax/baselines/dads_smerl.py

Lines changed: 1 addition & 4 deletions
@@ -5,7 +5,6 @@
 """
 
 from dataclasses import dataclass
-from functools import partial
 from typing import Optional, Tuple
 
 import jax
@@ -40,8 +39,7 @@ def __init__(self, config: DadsSmerlConfig, action_size: int, descriptor_size: i
 super(DADSSMERL, self).__init__(config, action_size, descriptor_size)
 self._config: DadsSmerlConfig = config
 
-@partial(jax.jit, static_argnames=("self",))
-def _compute_reward(
+def _compute_reward( # type: ignore
 self,
 transition: QDTransition,
 training_state: DadsTrainingState,
@@ -74,7 +72,6 @@ def _compute_reward(
 
 return rewards
 
-@partial(jax.jit, static_argnames=("self",))
 def update(
 self,
 training_state: DadsTrainingState,

qdax/baselines/diayn.py

Lines changed: 4 additions & 11 deletions
@@ -4,7 +4,6 @@
 """
 
 from dataclasses import dataclass
-from functools import partial
 from typing import Callable, Tuple
 
 import jax
@@ -177,7 +176,6 @@ def init( # type: ignore
 steps=jnp.array(0),
 )
 
-@partial(jax.jit, static_argnames=("self", "add_log_p_z"))
 def _compute_diversity_reward(
 self,
 transition: QDTransition,
@@ -212,8 +210,7 @@
 reward += jnp.log(self._config.num_skills)
 return reward
 
-@partial(jax.jit, static_argnames=("self", "env", "deterministic"))
-def play_step_fn(
+def play_step_fn( # type: ignore
 self,
 env_state: EnvState,
 training_state: DiaynTrainingState,
@@ -279,14 +276,13 @@
 
 return next_env_state, training_state, transition
 
-@partial(jax.jit, static_argnames=("self", "play_step_fn", "env_batch_size"))
-def eval_policy_fn(
+def eval_policy_fn( # type: ignore
 self,
 training_state: DiaynTrainingState,
 eval_env_first_state: EnvState,
 play_step_fn: Callable[
-[EnvState, Params, RNGKey],
-Tuple[EnvState, Params, RNGKey, QDTransition],
+[EnvState, Params],
+Tuple[EnvState, Params, QDTransition],
 ],
 env_batch_size: int,
 ) -> Tuple[Reward, Reward, Reward, StateDescriptor]:
@@ -347,7 +343,6 @@
 transitions.state_desc,
 )
 
-@partial(jax.jit, static_argnames=("self",))
 def _compute_reward(
 self, transition: QDTransition, training_state: DiaynTrainingState
 ) -> Reward:
@@ -366,7 +361,6 @@
 add_log_p_z=True,
 )
 
-@partial(jax.jit, static_argnames=("self",))
 def _update_networks(
 self,
 training_state: DiaynTrainingState,
@@ -469,7 +463,6 @@
 
 return new_training_state, metrics
 
-@partial(jax.jit, static_argnames=("self",))
 def update(
 self,
 training_state: DiaynTrainingState,

qdax/baselines/diayn_smerl.py

Lines changed: 1 addition & 4 deletions
@@ -5,7 +5,6 @@
 """
 
 from dataclasses import dataclass
-from functools import partial
 from typing import Optional, Tuple
 
 import jax
@@ -39,8 +38,7 @@ def __init__(self, config: DiaynSmerlConfig, action_size: int):
 super(DIAYNSMERL, self).__init__(config, action_size)
 self._config: DiaynSmerlConfig = config
 
-@partial(jax.jit, static_argnames=("self",))
-def _compute_reward(
+def _compute_reward( # type: ignore
 self,
 transition: QDTransition,
 training_state: DiaynTrainingState,
@@ -81,7 +79,6 @@ def _compute_reward(
 
 return rewards
 
-@partial(jax.jit, static_argnames=("self",))
 def update(
 self,
 training_state: DiaynTrainingState,
