Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions qdax/baselines/genetic_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import jax

from qdax.core.containers.ga_repertoire import GARepertoire
from qdax.core.containers.ga_popullation import GAPopulation
from qdax.core.emitters.emitter import Emitter, EmitterState
from qdax.custom_types import ExtraScores, Fitness, Genotype, Metrics, RNGKey

Expand All @@ -32,7 +32,7 @@ def __init__(
[Genotype, RNGKey], Tuple[Fitness, ExtraScores, RNGKey]
],
emitter: Emitter,
metrics_function: Callable[[GARepertoire], Metrics],
metrics_function: Callable[[GAPopulation], Metrics],
) -> None:
self._scoring_function = scoring_function
self._emitter = emitter
Expand All @@ -41,7 +41,7 @@ def __init__(
@partial(jax.jit, static_argnames=("self", "population_size"))
def init(
self, genotypes: Genotype, population_size: int, random_key: RNGKey
) -> Tuple[GARepertoire, Optional[EmitterState], RNGKey]:
) -> Tuple[GAPopulation, Optional[EmitterState], RNGKey]:
"""Initialize a GARepertoire with an initial population of genotypes.

Args:
Expand All @@ -59,7 +59,7 @@ def init(
)

# init the repertoire
repertoire = GARepertoire.init(
repertoire = GAPopulation.init(
genotypes=genotypes,
fitnesses=fitnesses,
population_size=population_size,
Expand All @@ -80,10 +80,10 @@ def init(
@partial(jax.jit, static_argnames=("self",))
def update(
self,
repertoire: GARepertoire,
repertoire: GAPopulation,
emitter_state: Optional[EmitterState],
random_key: RNGKey,
) -> Tuple[GARepertoire, Optional[EmitterState], Metrics, RNGKey]:
) -> Tuple[GAPopulation, Optional[EmitterState], Metrics, RNGKey]:
"""
Performs one iteration of a Genetic algorithm.
1. A batch of genotypes is sampled in the repertoire and the genotypes
Expand Down Expand Up @@ -134,9 +134,9 @@ def update(
@partial(jax.jit, static_argnames=("self",))
def scan_update(
self,
carry: Tuple[GARepertoire, Optional[EmitterState], RNGKey],
carry: Tuple[GAPopulation, Optional[EmitterState], RNGKey],
unused: Any,
) -> Tuple[Tuple[GARepertoire, Optional[EmitterState], RNGKey], Metrics]:
) -> Tuple[Tuple[GAPopulation, Optional[EmitterState], RNGKey], Metrics]:
"""Rewrites the update function in a way that makes it compatible with the
jax.lax.scan primitive.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
from functools import partial
from typing import Callable, Tuple

import flax
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

from qdax.core.containers.repertoire import Repertoire
from qdax.custom_types import Fitness, Genotype, RNGKey


class GARepertoire(Repertoire):
class GAPopulation(flax.struct.PyTreeNode):
"""Class for a simple repertoire for a simple genetic
algorithm.

Expand All @@ -23,7 +23,7 @@ class GARepertoire(Repertoire):
shape (population_size, num_features).
fitnesses: an array containing the fitness of the individuals
in the population. With shape (population_size, fitness_dim).
The implementation of GARepertoire was thought for the case
The implementation of GAPopulation was thought for the case
where fitness_dim equals 1 but the class can be herited and
rules adapted for cases where fitness_dim is greater than 1.
"""
Expand Down Expand Up @@ -55,7 +55,7 @@ def flatten_genotype(genotype: Genotype) -> jnp.ndarray:
jnp.save(path + "scores.npy", self.fitnesses)

@classmethod
def load(cls, reconstruction_fn: Callable, path: str = "./") -> GARepertoire:
def load(cls, reconstruction_fn: Callable, path: str = "./") -> GAPopulation:
"""Loads a GA Repertoire.

Args:
Expand Down Expand Up @@ -107,7 +107,7 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey
@jax.jit
def add(
self, batch_of_genotypes: Genotype, batch_of_fitnesses: Fitness
) -> GARepertoire:
) -> GAPopulation:
"""Implements the repertoire addition rules.

Parents and offsprings are gathered and only the population_size
Expand Down Expand Up @@ -154,7 +154,7 @@ def init( # type: ignore
genotypes: Genotype,
fitnesses: Fitness,
population_size: int,
) -> GARepertoire:
) -> GAPopulation:
"""Initializes the repertoire.

Start with default values and adds a first batch of genotypes
Expand Down
4 changes: 2 additions & 2 deletions qdax/core/containers/nsga2_repertoire.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
import jax
import jax.numpy as jnp

from qdax.core.containers.ga_repertoire import GARepertoire
from qdax.core.containers.ga_popullation import GAPopulation
from qdax.custom_types import Fitness, Genotype
from qdax.utils.pareto_front import compute_masked_pareto_front


class NSGA2Repertoire(GARepertoire):
class NSGA2Repertoire(GAPopulation):
"""Repertoire used for the NSGA2 algorithm.

Inherits from the GARepertoire. The data stored are the genotypes
Expand Down
53 changes: 0 additions & 53 deletions qdax/core/containers/repertoire.py

This file was deleted.

6 changes: 3 additions & 3 deletions qdax/core/containers/spea2_repertoire.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import jax
import jax.numpy as jnp

from qdax.core.containers.ga_repertoire import GARepertoire
from qdax.core.containers.ga_popullation import GAPopulation
from qdax.custom_types import Fitness, Genotype


class SPEA2Repertoire(GARepertoire):
class SPEA2Repertoire(GAPopulation):
"""Repertoire used for the SPEA2 algorithm.

Inherits from the GARepertoire. The data stored are the genotypes
Expand Down Expand Up @@ -101,7 +101,7 @@ def init( # type: ignore
fitnesses: Fitness,
population_size: int,
num_neighbours: int,
) -> GARepertoire:
) -> GAPopulation:
"""Initializes the repertoire.

Start with default values and adds a first batch of genotypes
Expand Down
8 changes: 4 additions & 4 deletions qdax/core/emitters/dcrl_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import optax
from jax import numpy as jnp

from qdax.core.containers.repertoire import Repertoire
from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire
from qdax.core.emitters.emitter import Emitter, EmitterState
from qdax.core.neuroevolution.buffers.buffer import DCRLTransition, ReplayBuffer
from qdax.core.neuroevolution.losses.td3_loss import make_td3_loss_dc_fn
Expand Down Expand Up @@ -128,7 +128,7 @@ def use_all_data(self) -> bool:
def init(
self,
key: RNGKey,
repertoire: Repertoire,
repertoire: MapElitesRepertoire,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
Expand Down Expand Up @@ -280,7 +280,7 @@ def _compute_equivalent_params_with_desc(
)
def emit(
self,
repertoire: Repertoire,
repertoire: MapElitesRepertoire,
emitter_state: DCRLEmitterState,
key: RNGKey,
) -> Tuple[Genotype, ExtraScores, RNGKey]:
Expand Down Expand Up @@ -393,7 +393,7 @@ def emit_actor(self, emitter_state: DCRLEmitterState) -> Genotype:
def state_update(
self,
emitter_state: DCRLEmitterState,
repertoire: Repertoire,
repertoire: MapElitesRepertoire,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
Expand Down
6 changes: 3 additions & 3 deletions qdax/core/emitters/dpg_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import optax

from qdax.core.containers.archive import Archive
from qdax.core.containers.repertoire import Repertoire
from qdax.core.containers.ga_popullation import GAPopulation
from qdax.core.emitters.qpg_emitter import (
QualityPGConfig,
QualityPGEmitter,
Expand Down Expand Up @@ -80,7 +80,7 @@ def __init__(
def init(
self,
random_key: RNGKey,
repertoire: Repertoire,
repertoire: GAPopulation,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
Expand Down Expand Up @@ -136,7 +136,7 @@ def init(
def state_update(
self,
emitter_state: DiversityPGEmitterState,
repertoire: Optional[Repertoire],
repertoire: Optional[GAPopulation],
genotypes: Optional[Genotype],
fitnesses: Optional[Fitness],
descriptors: Optional[Descriptor],
Expand Down
8 changes: 4 additions & 4 deletions qdax/core/emitters/emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import jax
from flax.struct import PyTreeNode

from qdax.core.containers.repertoire import Repertoire
from qdax.core.containers.ga_popullation import GAPopulation
from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey


Expand All @@ -32,7 +32,7 @@ class Emitter(ABC):
def init(
self,
random_key: RNGKey,
repertoire: Repertoire,
repertoire: GAPopulation,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
Expand All @@ -54,7 +54,7 @@ def init(
@abstractmethod
def emit(
self,
repertoire: Optional[Repertoire],
repertoire: Optional[GAPopulation],
emitter_state: Optional[EmitterState],
random_key: RNGKey,
) -> Tuple[Genotype, ExtraScores, RNGKey]:
Expand All @@ -80,7 +80,7 @@ def emit(
def state_update(
self,
emitter_state: Optional[EmitterState],
repertoire: Optional[Repertoire] = None,
repertoire: Optional[GAPopulation] = None,
genotypes: Optional[Genotype] = None,
fitnesses: Optional[Fitness] = None,
descriptors: Optional[Descriptor] = None,
Expand Down
8 changes: 4 additions & 4 deletions qdax/core/emitters/multi_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from chex import ArrayTree
from jax import numpy as jnp

from qdax.core.containers.repertoire import Repertoire
from qdax.core.containers.ga_popullation import GAPopulation
from qdax.core.emitters.emitter import Emitter, EmitterState
from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey

Expand Down Expand Up @@ -58,7 +58,7 @@ def get_indexes_separation_batches(
def init(
self,
random_key: RNGKey,
repertoire: Repertoire,
repertoire: GAPopulation,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
Expand Down Expand Up @@ -97,7 +97,7 @@ def init(
@partial(jax.jit, static_argnames=("self",))
def emit(
self,
repertoire: Optional[Repertoire],
repertoire: Optional[GAPopulation],
emitter_state: Optional[MultiEmitterState],
random_key: RNGKey,
) -> Tuple[Genotype, ExtraScores, RNGKey]:
Expand Down Expand Up @@ -145,7 +145,7 @@ def emit(
def state_update(
self,
emitter_state: Optional[MultiEmitterState],
repertoire: Optional[Repertoire] = None,
repertoire: Optional[GAPopulation] = None,
genotypes: Optional[Genotype] = None,
fitnesses: Optional[Fitness] = None,
descriptors: Optional[Descriptor] = None,
Expand Down
8 changes: 4 additions & 4 deletions qdax/core/emitters/pbt_me_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from qdax.baselines.pbt import PBTTrainingState
from qdax.baselines.sac_pbt import PBTSAC
from qdax.baselines.td3_pbt import PBTTD3
from qdax.core.containers.repertoire import Repertoire
from qdax.core.containers.ga_popullation import GAPopulation
from qdax.core.emitters.emitter import Emitter, EmitterState
from qdax.core.neuroevolution.buffers.buffer import ReplayBuffer, Transition
from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, Params, RNGKey
Expand Down Expand Up @@ -93,7 +93,7 @@ def __init__(
def init(
self,
random_key: RNGKey,
repertoire: Repertoire,
repertoire: GAPopulation,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
Expand Down Expand Up @@ -169,7 +169,7 @@ def init(
)
def emit(
self,
repertoire: Repertoire,
repertoire: GAPopulation,
emitter_state: PBTEmitterState,
random_key: RNGKey,
) -> Tuple[Genotype, ExtraScores, RNGKey]:
Expand Down Expand Up @@ -221,7 +221,7 @@ def batch_size(self) -> int:
def state_update(
self,
emitter_state: PBTEmitterState,
repertoire: Repertoire,
repertoire: GAPopulation,
genotypes: Optional[Genotype],
fitnesses: Fitness,
descriptors: Optional[Descriptor],
Expand Down
Loading
Loading