|
236 | 236 | "\n", |
237 | 237 | "def scoring_function(genotypes, key):\n", |
238 | 238 | " population_size = jax.tree.leaves(genotypes)[0].shape[0]\n", |
239 | | - " first_states = jax.tree_map(\n", |
| 239 | + " first_states = jax.tree.map(\n", |
240 | 240 | " lambda x: jnp.expand_dims(x, axis=0), eval_env_first_states\n", |
241 | 241 | " )\n", |
242 | | - " first_states = jax.tree_map(\n", |
| 242 | + " first_states = jax.tree.map(\n", |
243 | 243 | " lambda x: jnp.repeat(x, population_size, axis=0), first_states\n", |
244 | 244 | " )\n", |
245 | 245 | " population_returns, population_descriptors, _, _ = eval_policy(genotypes, first_states)\n", |
|
378 | 378 | " repertoire, emitter_state, metrics = update_fn(\n", |
379 | 379 | " repertoire, emitter_state, keys\n", |
380 | 380 | " )\n", |
381 | | - " metrics_cpu = jax.tree_map(\n", |
| 381 | + " metrics_cpu = jax.tree.map(\n", |
382 | 382 | " lambda x: jax.device_put(x, jax.devices(\"cpu\")[0])[0], metrics\n", |
383 | 383 | " )\n", |
384 | 384 | " timelapse = time.time() - start_time\n", |
|
401 | 401 | "outputs": [], |
402 | 402 | "source": [ |
403 | 403 | "# Create the performance evolution plots and visualize final grid\n", |
404 | | - "repertoire_cpu = jax.tree_map(\n", |
| 404 | + "repertoire_cpu = jax.tree.map(\n", |
405 | 405 | " lambda x: jax.device_put(x, jax.devices(\"cpu\")[0])[0], repertoire\n", |
406 | 406 | ")\n", |
407 | 407 | "num_loops_with_init = num_loops + 1\n", |
|
510 | 510 | "key, subkey = jax.random.split(key)\n", |
511 | 511 | "env_state = jax.jit(env.reset)(rng=subkey)\n", |
512 | 512 | "\n", |
513 | | - "training_state, env_state = jax.tree_map(\n", |
| 513 | + "training_state, env_state = jax.tree.map(\n", |
514 | 514 | " lambda x: jnp.expand_dims(x, axis=0), (training_state, env_state)\n", |
515 | 515 | ")\n", |
516 | 516 | "\n", |
|
529 | 529 | "outputs": [], |
530 | 530 | "source": [ |
531 | 531 | "rollout = [\n", |
532 | | - " jax.tree_map(lambda x: jax.device_put(x[0], jax.devices(\"cpu\")[0]), env_state)\n", |
| 532 | + " jax.tree.map(lambda x: jax.device_put(x[0], jax.devices(\"cpu\")[0]), env_state)\n", |
533 | 533 | " for env_state in rollout\n", |
534 | 534 | "]" |
535 | 535 | ] |
|
0 commit comments