diff --git a/qdax/baselines/genetic_algorithm.py b/qdax/baselines/genetic_algorithm.py index b4c6a32f2..0d1d141d3 100644 --- a/qdax/baselines/genetic_algorithm.py +++ b/qdax/baselines/genetic_algorithm.py @@ -5,7 +5,7 @@ import jax -from qdax.core.containers.ga_repertoire import GARepertoire +from qdax.core.containers.ga_popullation import GAPopulation from qdax.core.emitters.emitter import Emitter, EmitterState from qdax.custom_types import ExtraScores, Fitness, Genotype, Metrics, RNGKey @@ -32,7 +32,7 @@ def __init__( [Genotype, RNGKey], Tuple[Fitness, ExtraScores, RNGKey] ], emitter: Emitter, - metrics_function: Callable[[GARepertoire], Metrics], + metrics_function: Callable[[GAPopulation], Metrics], ) -> None: self._scoring_function = scoring_function self._emitter = emitter @@ -41,7 +41,7 @@ def __init__( @partial(jax.jit, static_argnames=("self", "population_size")) def init( self, genotypes: Genotype, population_size: int, random_key: RNGKey - ) -> Tuple[GARepertoire, Optional[EmitterState], RNGKey]: + ) -> Tuple[GAPopulation, Optional[EmitterState], RNGKey]: """Initialize a GARepertoire with an initial population of genotypes. Args: @@ -59,7 +59,7 @@ def init( ) # init the repertoire - repertoire = GARepertoire.init( + repertoire = GAPopulation.init( genotypes=genotypes, fitnesses=fitnesses, population_size=population_size, @@ -80,10 +80,10 @@ def init( @partial(jax.jit, static_argnames=("self",)) def update( self, - repertoire: GARepertoire, + repertoire: GAPopulation, emitter_state: Optional[EmitterState], random_key: RNGKey, - ) -> Tuple[GARepertoire, Optional[EmitterState], Metrics, RNGKey]: + ) -> Tuple[GAPopulation, Optional[EmitterState], Metrics, RNGKey]: """ Performs one iteration of a Genetic algorithm. 1. A batch of genotypes is sampled in the repertoire and the genotypes @@ -134,9 +134,9 @@ def update( @partial(jax.jit, static_argnames=("self",)) def scan_update( self, - carry: Tuple[GARepertoire, Optional[EmitterState], RNGKey], + carry: Tuple[GAPopulation, Optional[EmitterState], RNGKey], unused: Any, - ) -> Tuple[Tuple[GARepertoire, Optional[EmitterState], RNGKey], Metrics]: + ) -> Tuple[Tuple[GAPopulation, Optional[EmitterState], RNGKey], Metrics]: """Rewrites the update function in a way that makes it compatible with the jax.lax.scan primitive. diff --git a/qdax/core/containers/ga_repertoire.py b/qdax/core/containers/ga_popullation.py similarity index 96% rename from qdax/core/containers/ga_repertoire.py rename to qdax/core/containers/ga_popullation.py index 403331ff1..c2ea4427d 100644 --- a/qdax/core/containers/ga_repertoire.py +++ b/qdax/core/containers/ga_popullation.py @@ -5,15 +5,15 @@ from functools import partial from typing import Callable, Tuple +import flax import jax import jax.numpy as jnp from jax.flatten_util import ravel_pytree -from qdax.core.containers.repertoire import Repertoire from qdax.custom_types import Fitness, Genotype, RNGKey -class GARepertoire(Repertoire): +class GAPopulation(flax.struct.PyTreeNode): """Class for a simple repertoire for a simple genetic algorithm. @@ -23,7 +23,7 @@ class GARepertoire(Repertoire): shape (population_size, num_features). fitnesses: an array containing the fitness of the individuals in the population. With shape (population_size, fitness_dim). - The implementation of GARepertoire was thought for the case + The implementation of GAPopulation was thought for the case where fitness_dim equals 1 but the class can be herited and rules adapted for cases where fitness_dim is greater than 1. """ @@ -55,7 +55,7 @@ def flatten_genotype(genotype: Genotype) -> jnp.ndarray: jnp.save(path + "scores.npy", self.fitnesses) @classmethod - def load(cls, reconstruction_fn: Callable, path: str = "./") -> GARepertoire: + def load(cls, reconstruction_fn: Callable, path: str = "./") -> GAPopulation: """Loads a GA Repertoire. Args: @@ -107,7 +107,7 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey @jax.jit def add( self, batch_of_genotypes: Genotype, batch_of_fitnesses: Fitness - ) -> GARepertoire: + ) -> GAPopulation: """Implements the repertoire addition rules. Parents and offsprings are gathered and only the population_size @@ -154,7 +154,7 @@ def init( # type: ignore genotypes: Genotype, fitnesses: Fitness, population_size: int, - ) -> GARepertoire: + ) -> GAPopulation: """Initializes the repertoire. Start with default values and adds a first batch of genotypes diff --git a/qdax/core/containers/nsga2_repertoire.py b/qdax/core/containers/nsga2_repertoire.py index 331ef153c..5687ac9c6 100644 --- a/qdax/core/containers/nsga2_repertoire.py +++ b/qdax/core/containers/nsga2_repertoire.py @@ -5,12 +5,12 @@ import jax import jax.numpy as jnp -from qdax.core.containers.ga_repertoire import GARepertoire +from qdax.core.containers.ga_popullation import GAPopulation from qdax.custom_types import Fitness, Genotype from qdax.utils.pareto_front import compute_masked_pareto_front -class NSGA2Repertoire(GARepertoire): +class NSGA2Repertoire(GAPopulation): """Repertoire used for the NSGA2 algorithm. Inherits from the GARepertoire. The data stored are the genotypes diff --git a/qdax/core/containers/repertoire.py b/qdax/core/containers/repertoire.py deleted file mode 100644 index 77c916832..000000000 --- a/qdax/core/containers/repertoire.py +++ /dev/null @@ -1,53 +0,0 @@ -"""This file contains util functions and a class to define -a repertoire, used to store individuals in the MAP-Elites -algorithm as well as several variants.""" - -from __future__ import annotations - -from abc import ABC, abstractmethod - -import flax - -from qdax.custom_types import Genotype, RNGKey - - -class Repertoire(flax.struct.PyTreeNode, ABC): - """Abstract class for any repertoire of genotypes. - - We decided not to add the attributes Genotypes even if - it will be shared by all children classes because we want - to keep the parent classes explicit and transparent. - """ - - @classmethod - @abstractmethod - def init(cls) -> Repertoire: # noqa: N805 - """Create a repertoire.""" - pass - - @abstractmethod - def sample( - self, - random_key: RNGKey, - num_samples: int, - ) -> Genotype: - """Sample genotypes from the repertoire. - - Args: - random_key: a random key to handle stochasticity. - num_samples: the number of genotypes to sample. - - Returns: - The sample of genotypes. - """ - pass - - @abstractmethod - def add(self) -> Repertoire: - """Implements the rule to add new genotypes to a - repertoire. - - Returns: - The udpated repertoire. - """ - pass diff --git a/qdax/core/containers/spea2_repertoire.py b/qdax/core/containers/spea2_repertoire.py index 33c31547d..22ab09039 100644 --- a/qdax/core/containers/spea2_repertoire.py +++ b/qdax/core/containers/spea2_repertoire.py @@ -4,11 +4,11 @@ import jax import jax.numpy as jnp -from qdax.core.containers.ga_repertoire import GARepertoire +from qdax.core.containers.ga_popullation import GAPopulation from qdax.custom_types import Fitness, Genotype -class SPEA2Repertoire(GARepertoire): +class SPEA2Repertoire(GAPopulation): """Repertoire used for the SPEA2 algorithm. Inherits from the GARepertoire. The data stored are the genotypes @@ -101,7 +101,7 @@ def init( # type: ignore fitnesses: Fitness, population_size: int, num_neighbours: int, - ) -> GARepertoire: + ) -> GAPopulation: """Initializes the repertoire. Start with default values and adds a first batch of genotypes diff --git a/qdax/core/emitters/dcrl_emitter.py b/qdax/core/emitters/dcrl_emitter.py index b353a22f1..992c9f521 100644 --- a/qdax/core/emitters/dcrl_emitter.py +++ b/qdax/core/emitters/dcrl_emitter.py @@ -11,7 +11,7 @@ import optax from jax import numpy as jnp -from qdax.core.containers.repertoire import Repertoire +from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.emitters.emitter import Emitter, EmitterState from qdax.core.neuroevolution.buffers.buffer import DCRLTransition, ReplayBuffer from qdax.core.neuroevolution.losses.td3_loss import make_td3_loss_dc_fn @@ -128,7 +128,7 @@ def use_all_data(self) -> bool: def init( self, key: RNGKey, - repertoire: Repertoire, + repertoire: MapElitesRepertoire, genotypes: Genotype, fitnesses: Fitness, descriptors: Descriptor, @@ -280,7 +280,7 @@ def _compute_equivalent_params_with_desc( ) def emit( self, - repertoire: Repertoire, + repertoire: MapElitesRepertoire, emitter_state: DCRLEmitterState, key: RNGKey, ) -> Tuple[Genotype, ExtraScores, RNGKey]: @@ -393,7 +393,7 @@ def emit_actor(self, emitter_state: DCRLEmitterState) -> Genotype: def state_update( self, emitter_state: DCRLEmitterState, - repertoire: Repertoire, + repertoire: MapElitesRepertoire, genotypes: Genotype, fitnesses: Fitness, descriptors: Descriptor, diff --git a/qdax/core/emitters/dpg_emitter.py b/qdax/core/emitters/dpg_emitter.py index ea921237d..9c84fa891 100644 --- a/qdax/core/emitters/dpg_emitter.py +++ b/qdax/core/emitters/dpg_emitter.py @@ -11,7 +11,7 @@ import optax from qdax.core.containers.archive import Archive -from qdax.core.containers.repertoire import Repertoire +from qdax.core.containers.ga_popullation import GAPopulation from qdax.core.emitters.qpg_emitter import ( QualityPGConfig, QualityPGEmitter, @@ -80,7 +80,7 @@ def __init__( def init( self, random_key: RNGKey, - repertoire: Repertoire, + repertoire: GAPopulation, genotypes: Genotype, fitnesses: Fitness, descriptors: Descriptor, @@ -136,7 +136,7 @@ def init( def state_update( self, emitter_state: DiversityPGEmitterState, - repertoire: Optional[Repertoire], + repertoire: Optional[GAPopulation], genotypes: Optional[Genotype], fitnesses: Optional[Fitness], descriptors: Optional[Descriptor], diff --git a/qdax/core/emitters/emitter.py b/qdax/core/emitters/emitter.py index 211393564..4cba2d6f4 100644 --- a/qdax/core/emitters/emitter.py +++ b/qdax/core/emitters/emitter.py @@ -5,7 +5,7 @@ import jax from flax.struct import PyTreeNode -from qdax.core.containers.repertoire import Repertoire +from qdax.core.containers.ga_popullation import GAPopulation from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey @@ -32,7 +32,7 @@ class Emitter(ABC): def init( self, random_key: RNGKey, - repertoire: Repertoire, + repertoire: GAPopulation, genotypes: Genotype, fitnesses: Fitness, descriptors: Descriptor, @@ -54,7 +54,7 @@ def init( @abstractmethod def emit( self, - repertoire: Optional[Repertoire], + repertoire: Optional[GAPopulation], emitter_state: Optional[EmitterState], random_key: RNGKey, ) -> Tuple[Genotype, ExtraScores, RNGKey]: @@ -80,7 +80,7 @@ def emit( def state_update( self, emitter_state: Optional[EmitterState], - repertoire: Optional[Repertoire] = None, + repertoire: Optional[GAPopulation] = None, genotypes: Optional[Genotype] = None, fitnesses: Optional[Fitness] = None, descriptors: Optional[Descriptor] = None, diff --git a/qdax/core/emitters/multi_emitter.py b/qdax/core/emitters/multi_emitter.py index 17cb8ace9..b3375f188 100644 --- a/qdax/core/emitters/multi_emitter.py +++ b/qdax/core/emitters/multi_emitter.py @@ -6,7 +6,7 @@ from chex import ArrayTree from jax import numpy as jnp -from qdax.core.containers.repertoire import Repertoire +from qdax.core.containers.ga_popullation import GAPopulation from qdax.core.emitters.emitter import Emitter, EmitterState from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey @@ -58,7 +58,7 @@ def get_indexes_separation_batches( def init( self, random_key: RNGKey, - repertoire: Repertoire, + repertoire: GAPopulation, genotypes: Genotype, fitnesses: Fitness, descriptors: Descriptor, @@ -97,7 +97,7 @@ def init( @partial(jax.jit, static_argnames=("self",)) def emit( self, - repertoire: Optional[Repertoire], + repertoire: Optional[GAPopulation], emitter_state: Optional[MultiEmitterState], random_key: RNGKey, ) -> Tuple[Genotype, ExtraScores, RNGKey]: @@ -145,7 +145,7 @@ def emit( def state_update( self, emitter_state: Optional[MultiEmitterState], - repertoire: Optional[Repertoire] = None, + repertoire: Optional[GAPopulation] = None, genotypes: Optional[Genotype] = None, fitnesses: Optional[Fitness] = None, descriptors: Optional[Descriptor] = None, diff --git a/qdax/core/emitters/pbt_me_emitter.py b/qdax/core/emitters/pbt_me_emitter.py index 55bded4e9..5e184d250 100644 --- a/qdax/core/emitters/pbt_me_emitter.py +++ b/qdax/core/emitters/pbt_me_emitter.py @@ -9,7 +9,7 @@ from qdax.baselines.pbt import PBTTrainingState from qdax.baselines.sac_pbt import PBTSAC from qdax.baselines.td3_pbt import PBTTD3 -from qdax.core.containers.repertoire import Repertoire +from qdax.core.containers.ga_popullation import GAPopulation from qdax.core.emitters.emitter import Emitter, EmitterState from qdax.core.neuroevolution.buffers.buffer import ReplayBuffer, Transition from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, Params, RNGKey @@ -93,7 +93,7 @@ def __init__( def init( self, random_key: RNGKey, - repertoire: Repertoire, + repertoire: GAPopulation, genotypes: Genotype, fitnesses: Fitness, descriptors: Descriptor, @@ -169,7 +169,7 @@ def init( ) def emit( self, - repertoire: Repertoire, + repertoire: GAPopulation, emitter_state: PBTEmitterState, random_key: RNGKey, ) -> Tuple[Genotype, ExtraScores, RNGKey]: @@ -221,7 +221,7 @@ def batch_size(self) -> int: def state_update( self, emitter_state: PBTEmitterState, - repertoire: Repertoire, + repertoire: GAPopulation, genotypes: Optional[Genotype], fitnesses: Fitness, descriptors: Optional[Descriptor], diff --git a/qdax/core/emitters/qpg_emitter.py b/qdax/core/emitters/qpg_emitter.py index 63373494e..67911a3e8 100644 --- a/qdax/core/emitters/qpg_emitter.py +++ b/qdax/core/emitters/qpg_emitter.py @@ -12,7 +12,7 @@ import optax from jax import numpy as jnp -from qdax.core.containers.repertoire import Repertoire +from qdax.core.containers.ga_popullation import GAPopulation from qdax.core.emitters.emitter import Emitter, EmitterState from qdax.core.neuroevolution.buffers.buffer import QDTransition, ReplayBuffer from qdax.core.neuroevolution.losses.td3_loss import make_td3_loss_fn @@ -121,7 +121,7 @@ def use_all_data(self) -> bool: def init( self, random_key: RNGKey, - repertoire: Repertoire, + repertoire: GAPopulation, genotypes: Genotype, fitnesses: Fitness, descriptors: Descriptor, @@ -197,7 +197,7 @@ def init( ) def emit( self, - repertoire: Repertoire, + repertoire: GAPopulation, emitter_state: QualityPGEmitterState, random_key: RNGKey, ) -> Tuple[Genotype, ExtraScores, RNGKey]: @@ -286,7 +286,7 @@ def emit_actor(self, emitter_state: QualityPGEmitterState) -> Genotype: def state_update( self, emitter_state: QualityPGEmitterState, - repertoire: Optional[Repertoire], + repertoire: Optional[GAPopulation], genotypes: Optional[Genotype], fitnesses: Optional[Fitness], descriptors: Optional[Descriptor], diff --git a/qdax/core/emitters/standard_emitters.py b/qdax/core/emitters/standard_emitters.py index 1d949b2d9..34e3d310e 100644 --- a/qdax/core/emitters/standard_emitters.py +++ b/qdax/core/emitters/standard_emitters.py @@ -4,7 +4,7 @@ import jax import jax.numpy as jnp -from qdax.core.containers.repertoire import Repertoire +from qdax.core.containers.ga_popullation import GAPopulation from qdax.core.emitters.emitter import Emitter, EmitterState from qdax.custom_types import ExtraScores, Genotype, RNGKey @@ -28,7 +28,7 @@ def __init__( ) def emit( self, - repertoire: Repertoire, + repertoire: GAPopulation, emitter_state: Optional[EmitterState], random_key: RNGKey, ) -> Tuple[Genotype, ExtraScores, RNGKey]: diff --git a/qdax/utils/metrics.py b/qdax/utils/metrics.py index 509c6d91d..8f8272766 100644 --- a/qdax/utils/metrics.py +++ b/qdax/utils/metrics.py @@ -9,7 +9,7 @@ import jax from jax import numpy as jnp -from qdax.core.containers.ga_repertoire import GARepertoire +from qdax.core.containers.ga_popullation import GAPopulation from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.containers.mome_repertoire import MOMERepertoire from qdax.custom_types import Metrics @@ -50,7 +50,7 @@ def log(self, metrics: Dict[str, float]) -> None: def default_ga_metrics( - repertoire: GARepertoire, + repertoire: GAPopulation, ) -> Metrics: """Compute the usual GA metrics that one can retrieve from a GA repertoire.