diff --git a/acme/__init__.py b/acme/__init__.py index 20c7923e05..01d101e339 100644 --- a/acme/__init__.py +++ b/acme/__init__.py @@ -17,21 +17,14 @@ # Internal import. # Expose specs and types modules. -from acme import specs -from acme import types +from acme import specs, types # Make __version__ accessible. from acme._metadata import __version__ # Expose core interfaces. -from acme.core import Actor -from acme.core import Learner -from acme.core import Saveable -from acme.core import VariableSource -from acme.core import Worker +from acme.core import Actor, Learner, Saveable, VariableSource, Worker # Expose the environment loop. from acme.environment_loop import EnvironmentLoop - from acme.specs import make_environment_spec - diff --git a/acme/_metadata.py b/acme/_metadata.py index 97ea5a9865..e83ff0a616 100644 --- a/acme/_metadata.py +++ b/acme/_metadata.py @@ -19,9 +19,9 @@ """ # We follow Semantic Versioning (https://semver.org/) -_MAJOR_VERSION = '0' -_MINOR_VERSION = '4' -_PATCH_VERSION = '1' +_MAJOR_VERSION = "0" +_MINOR_VERSION = "4" +_PATCH_VERSION = "1" # Example: '0.4.2' -__version__ = '.'.join([_MAJOR_VERSION, _MINOR_VERSION, _PATCH_VERSION]) +__version__ = ".".join([_MAJOR_VERSION, _MINOR_VERSION, _PATCH_VERSION]) diff --git a/acme/adders/__init__.py b/acme/adders/__init__.py index 5d08479a35..0891f530a4 100644 --- a/acme/adders/__init__.py +++ b/acme/adders/__init__.py @@ -17,5 +17,4 @@ # pylint: disable=unused-import from acme.adders.base import Adder -from acme.adders.wrappers import ForkingAdder -from acme.adders.wrappers import IgnoreExtrasAdder +from acme.adders.wrappers import ForkingAdder, IgnoreExtrasAdder diff --git a/acme/adders/base.py b/acme/adders/base.py index 7067e87309..7194aa209a 100644 --- a/acme/adders/base.py +++ b/acme/adders/base.py @@ -16,12 +16,13 @@ import abc -from acme import types import dm_env +from acme import types + class Adder(abc.ABC): - """The Adder interface. + """The Adder interface. An adder packs together data to send to the replay buffer, and potentially performs some reduction/transformation to this data in the process. @@ -49,9 +50,9 @@ class Adder(abc.ABC): timestep is named `next_timestep` precisely to emphasize this point. """ - @abc.abstractmethod - def add_first(self, timestep: dm_env.TimeStep): - """Defines the interface for an adder's `add_first` method. + @abc.abstractmethod + def add_first(self, timestep: dm_env.TimeStep): + """Defines the interface for an adder's `add_first` method. We expect this to be called at the beginning of each episode and it will start a trajectory to be added to replay with an initial observation. @@ -60,14 +61,14 @@ def add_first(self, timestep: dm_env.TimeStep): timestep: a dm_env TimeStep corresponding to the first step. """ - @abc.abstractmethod - def add( - self, - action: types.NestedArray, - next_timestep: dm_env.TimeStep, - extras: types.NestedArray = (), - ): - """Defines the adder `add` interface. + @abc.abstractmethod + def add( + self, + action: types.NestedArray, + next_timestep: dm_env.TimeStep, + extras: types.NestedArray = (), + ): + """Defines the adder `add` interface. Args: action: A possibly nested structure corresponding to a_t. @@ -76,7 +77,6 @@ def add( extras: A possibly nested structure of extra data to add to replay. 
""" - @abc.abstractmethod - def reset(self): - """Resets the adder's buffer.""" - + @abc.abstractmethod + def reset(self): + """Resets the adder's buffer.""" diff --git a/acme/adders/reverb/__init__.py b/acme/adders/reverb/__init__.py index 189f8ce780..83ca16a349 100644 --- a/acme/adders/reverb/__init__.py +++ b/acme/adders/reverb/__init__.py @@ -16,17 +16,19 @@ # pylint: disable=unused-import -from acme.adders.reverb.base import DEFAULT_PRIORITY_TABLE -from acme.adders.reverb.base import PriorityFn -from acme.adders.reverb.base import PriorityFnInput -from acme.adders.reverb.base import ReverbAdder -from acme.adders.reverb.base import Step -from acme.adders.reverb.base import Trajectory - +from acme.adders.reverb.base import ( + DEFAULT_PRIORITY_TABLE, + PriorityFn, + PriorityFnInput, + ReverbAdder, + Step, + Trajectory, +) from acme.adders.reverb.episode import EpisodeAdder -from acme.adders.reverb.sequence import EndBehavior -from acme.adders.reverb.sequence import SequenceAdder -from acme.adders.reverb.structured import create_n_step_transition_config -from acme.adders.reverb.structured import create_step_spec -from acme.adders.reverb.structured import StructuredAdder +from acme.adders.reverb.sequence import EndBehavior, SequenceAdder +from acme.adders.reverb.structured import ( + StructuredAdder, + create_n_step_transition_config, + create_step_spec, +) from acme.adders.reverb.transition import NStepTransitionAdder diff --git a/acme/adders/reverb/base.py b/acme/adders/reverb/base.py index 1fe8d5e7b5..8caac37fc7 100644 --- a/acme/adders/reverb/base.py +++ b/acme/adders/reverb/base.py @@ -16,33 +16,42 @@ import abc import time -from typing import Callable, Iterable, Mapping, NamedTuple, Optional, Sized, Union, Tuple +from typing import ( + Callable, + Iterable, + Mapping, + NamedTuple, + Optional, + Sized, + Tuple, + Union, +) -from absl import logging -from acme import specs -from acme import types -from acme.adders import base import dm_env import numpy as np import reverb import tensorflow as tf import tree +from absl import logging -DEFAULT_PRIORITY_TABLE = 'priority_table' +from acme import specs, types +from acme.adders import base + +DEFAULT_PRIORITY_TABLE = "priority_table" _MIN_WRITER_LIFESPAN_SECONDS = 60 -StartOfEpisodeType = Union[bool, specs.Array, tf.Tensor, tf.TensorSpec, - Tuple[()]] +StartOfEpisodeType = Union[bool, specs.Array, tf.Tensor, tf.TensorSpec, Tuple[()]] # TODO(b/188510142): Delete Step. class Step(NamedTuple): - """Step class used internally for reverb adders.""" - observation: types.NestedArray - action: types.NestedArray - reward: types.NestedArray - discount: types.NestedArray - start_of_episode: StartOfEpisodeType - extras: types.NestedArray = () + """Step class used internally for reverb adders.""" + + observation: types.NestedArray + action: types.NestedArray + reward: types.NestedArray + discount: types.NestedArray + start_of_episode: StartOfEpisodeType + extras: types.NestedArray = () # TODO(b/188510142): Replace with proper Trajectory class. 
@@ -50,37 +59,38 @@ class Step(NamedTuple): class PriorityFnInput(NamedTuple): - """The input to a priority function consisting of stacked steps.""" - observations: types.NestedArray - actions: types.NestedArray - rewards: types.NestedArray - discounts: types.NestedArray - start_of_episode: types.NestedArray - extras: types.NestedArray + """The input to a priority function consisting of stacked steps.""" + + observations: types.NestedArray + actions: types.NestedArray + rewards: types.NestedArray + discounts: types.NestedArray + start_of_episode: types.NestedArray + extras: types.NestedArray # Define the type of a priority function and the mapping from table to function. -PriorityFn = Callable[['PriorityFnInput'], float] +PriorityFn = Callable[["PriorityFnInput"], float] PriorityFnMapping = Mapping[str, Optional[PriorityFn]] def spec_like_to_tensor_spec(paths: Iterable[str], spec: specs.Array): - return tf.TensorSpec.from_spec(spec, name='/'.join(str(p) for p in paths)) + return tf.TensorSpec.from_spec(spec, name="/".join(str(p) for p in paths)) class ReverbAdder(base.Adder): - """Base class for Reverb adders.""" - - def __init__( - self, - client: reverb.Client, - max_sequence_length: int, - max_in_flight_items: int, - delta_encoded: bool = False, - priority_fns: Optional[PriorityFnMapping] = None, - validate_items: bool = True, - ): - """Initialize a ReverbAdder instance. + """Base class for Reverb adders.""" + + def __init__( + self, + client: reverb.Client, + max_sequence_length: int, + max_in_flight_items: int, + delta_encoded: bool = False, + priority_fns: Optional[PriorityFnMapping] = None, + validate_items: bool = True, + ): + """Initialize a ReverbAdder instance. Args: client: A client to the Reverb backend. @@ -98,128 +108,144 @@ def __init__( before they are sent to the server. This requires table signature to be fetched from the server and cached locally. """ - if priority_fns: - priority_fns = dict(priority_fns) - else: - priority_fns = {DEFAULT_PRIORITY_TABLE: None} - - self._client = client - self._priority_fns = priority_fns - self._max_sequence_length = max_sequence_length - self._delta_encoded = delta_encoded - # TODO(b/206629159): Remove this. - self._max_in_flight_items = max_in_flight_items - self._add_first_called = False - - # This is exposed as the _writer property in such a way that it will create - # a new writer automatically whenever the internal __writer is None. Users - # should ONLY ever interact with self._writer. - self.__writer = None - # Every time a new writer is created, it must fetch the signature from the - # Reverb server. If this is set too low it can crash the adders in a - # distributed setup where the replay may take a while to spin up. - self._validate_items = validate_items - - def __del__(self): - if self.__writer is not None: - timeout_ms = 10_000 - # Try flush all appended data before closing to avoid loss of experience. - try: - self.__writer.flush(0, timeout_ms=timeout_ms) - except reverb.DeadlineExceededError as e: - logging.error( - 'Timeout (%d ms) exceeded when flushing the writer before ' - 'deleting it. 
Caught Reverb exception: %s', timeout_ms, str(e)) - self.__writer.close() - self.__writer = None - - @property - def _writer(self) -> reverb.TrajectoryWriter: - if self.__writer is None: - self.__writer = self._client.trajectory_writer( - num_keep_alive_refs=self._max_sequence_length, - validate_items=self._validate_items) - self._writer_created_timestamp = time.time() - return self.__writer - - def add_priority_table(self, table_name: str, - priority_fn: Optional[PriorityFn]): - if table_name in self._priority_fns: - raise ValueError( - f'A priority function already exists for {table_name}. ' - f'Existing tables: {", ".join(self._priority_fns.keys())}.' - ) - self._priority_fns[table_name] = priority_fn - - def reset(self, timeout_ms: Optional[int] = None): - """Resets the adder's buffer.""" - if self.__writer: - # Flush all appended data and clear the buffers. - self.__writer.end_episode(clear_buffers=True, timeout_ms=timeout_ms) - - # Create a new writer unless the current one is too young. - # This is to reduce the relative overhead of creating a new Reverb writer. - if (time.time() - self._writer_created_timestamp > - _MIN_WRITER_LIFESPAN_SECONDS): + if priority_fns: + priority_fns = dict(priority_fns) + else: + priority_fns = {DEFAULT_PRIORITY_TABLE: None} + + self._client = client + self._priority_fns = priority_fns + self._max_sequence_length = max_sequence_length + self._delta_encoded = delta_encoded + # TODO(b/206629159): Remove this. + self._max_in_flight_items = max_in_flight_items + self._add_first_called = False + + # This is exposed as the _writer property in such a way that it will create + # a new writer automatically whenever the internal __writer is None. Users + # should ONLY ever interact with self._writer. self.__writer = None - self._add_first_called = False - - def add_first(self, timestep: dm_env.TimeStep): - """Record the first observation of a trajectory.""" - if not timestep.first(): - raise ValueError('adder.add_first with an initial timestep (i.e. one for ' - 'which timestep.first() is True') - - # Record the next observation but leave the history buffer row open by - # passing `partial_step=True`. - self._writer.append(dict(observation=timestep.observation, - start_of_episode=timestep.first()), - partial_step=True) - self._add_first_called = True - - def add(self, - action: types.NestedArray, - next_timestep: dm_env.TimeStep, - extras: types.NestedArray = ()): - """Record an action and the following timestep.""" - - if not self._add_first_called: - raise ValueError('adder.add_first must be called before adder.add.') - - # Add the timestep to the buffer. - has_extras = (len(extras) > 0 if isinstance(extras, Sized) # pylint: disable=g-explicit-length-test - else extras is not None) - current_step = dict( - # Observation was passed at the previous add call. - action=action, - reward=next_timestep.reward, - discount=next_timestep.discount, - # Start of episode indicator was passed at the previous add call. - **({'extras': extras} if has_extras else {}) - ) - self._writer.append(current_step) - - # Record the next observation and write. - self._writer.append( - dict( - observation=next_timestep.observation, - start_of_episode=next_timestep.first()), - partial_step=True) - self._write() - - if next_timestep.last(): - # Complete the row by appending zeros to remaining open fields. - # TODO(b/183945808): remove this when fields are no longer expected to be - # of equal length on the learner side. 
- dummy_step = tree.map_structure(np.zeros_like, current_step) - self._writer.append(dummy_step) - self._write_last() - self.reset() - - @classmethod - def signature(cls, environment_spec: specs.EnvironmentSpec, - extras_spec: types.NestedSpec = ()): - """This is a helper method for generating signatures for Reverb tables. + # Every time a new writer is created, it must fetch the signature from the + # Reverb server. If this is set too low it can crash the adders in a + # distributed setup where the replay may take a while to spin up. + self._validate_items = validate_items + + def __del__(self): + if self.__writer is not None: + timeout_ms = 10_000 + # Try flush all appended data before closing to avoid loss of experience. + try: + self.__writer.flush(0, timeout_ms=timeout_ms) + except reverb.DeadlineExceededError as e: + logging.error( + "Timeout (%d ms) exceeded when flushing the writer before " + "deleting it. Caught Reverb exception: %s", + timeout_ms, + str(e), + ) + self.__writer.close() + self.__writer = None + + @property + def _writer(self) -> reverb.TrajectoryWriter: + if self.__writer is None: + self.__writer = self._client.trajectory_writer( + num_keep_alive_refs=self._max_sequence_length, + validate_items=self._validate_items, + ) + self._writer_created_timestamp = time.time() + return self.__writer + + def add_priority_table(self, table_name: str, priority_fn: Optional[PriorityFn]): + if table_name in self._priority_fns: + raise ValueError( + f"A priority function already exists for {table_name}. " + f'Existing tables: {", ".join(self._priority_fns.keys())}.' + ) + self._priority_fns[table_name] = priority_fn + + def reset(self, timeout_ms: Optional[int] = None): + """Resets the adder's buffer.""" + if self.__writer: + # Flush all appended data and clear the buffers. + self.__writer.end_episode(clear_buffers=True, timeout_ms=timeout_ms) + + # Create a new writer unless the current one is too young. + # This is to reduce the relative overhead of creating a new Reverb writer. + if ( + time.time() - self._writer_created_timestamp + > _MIN_WRITER_LIFESPAN_SECONDS + ): + self.__writer = None + self._add_first_called = False + + def add_first(self, timestep: dm_env.TimeStep): + """Record the first observation of a trajectory.""" + if not timestep.first(): + raise ValueError( + "adder.add_first with an initial timestep (i.e. one for " + "which timestep.first() is True" + ) + + # Record the next observation but leave the history buffer row open by + # passing `partial_step=True`. + self._writer.append( + dict(observation=timestep.observation, start_of_episode=timestep.first()), + partial_step=True, + ) + self._add_first_called = True + + def add( + self, + action: types.NestedArray, + next_timestep: dm_env.TimeStep, + extras: types.NestedArray = (), + ): + """Record an action and the following timestep.""" + + if not self._add_first_called: + raise ValueError("adder.add_first must be called before adder.add.") + + # Add the timestep to the buffer. + has_extras = ( + len(extras) > 0 + if isinstance(extras, Sized) # pylint: disable=g-explicit-length-test + else extras is not None + ) + current_step = dict( + # Observation was passed at the previous add call. + action=action, + reward=next_timestep.reward, + discount=next_timestep.discount, + # Start of episode indicator was passed at the previous add call. + **({"extras": extras} if has_extras else {}), + ) + self._writer.append(current_step) + + # Record the next observation and write. 
+ self._writer.append( + dict( + observation=next_timestep.observation, + start_of_episode=next_timestep.first(), + ), + partial_step=True, + ) + self._write() + + if next_timestep.last(): + # Complete the row by appending zeros to remaining open fields. + # TODO(b/183945808): remove this when fields are no longer expected to be + # of equal length on the learner side. + dummy_step = tree.map_structure(np.zeros_like, current_step) + self._writer.append(dummy_step) + self._write_last() + self.reset() + + @classmethod + def signature( + cls, environment_spec: specs.EnvironmentSpec, extras_spec: types.NestedSpec = () + ): + """This is a helper method for generating signatures for Reverb tables. Signatures are useful for validating data types and shapes, see Reverb's documentation for details on how they are used. @@ -236,19 +262,20 @@ def signature(cls, environment_spec: specs.EnvironmentSpec, Returns: A `Step` whose leaf nodes are `tf.TensorSpec` objects. """ - spec_step = Step( - observation=environment_spec.observations, - action=environment_spec.actions, - reward=environment_spec.rewards, - discount=environment_spec.discounts, - start_of_episode=specs.Array(shape=(), dtype=bool), - extras=extras_spec) - return tree.map_structure_with_path(spec_like_to_tensor_spec, spec_step) - - @abc.abstractmethod - def _write(self): - """Write data to replay from the buffer.""" - - @abc.abstractmethod - def _write_last(self): - """Write data to replay from the buffer.""" + spec_step = Step( + observation=environment_spec.observations, + action=environment_spec.actions, + reward=environment_spec.rewards, + discount=environment_spec.discounts, + start_of_episode=specs.Array(shape=(), dtype=bool), + extras=extras_spec, + ) + return tree.map_structure_with_path(spec_like_to_tensor_spec, spec_step) + + @abc.abstractmethod + def _write(self): + """Write data to replay from the buffer.""" + + @abc.abstractmethod + def _write_last(self): + """Write data to replay from the buffer.""" diff --git a/acme/adders/reverb/episode.py b/acme/adders/reverb/episode.py index 2b5f644857..158ab91981 100644 --- a/acme/adders/reverb/episode.py +++ b/acme/adders/reverb/episode.py @@ -17,12 +17,7 @@ This implements full episode adders, potentially with padding. """ -from typing import Callable, Optional, Iterable, Tuple - -from acme import specs -from acme import types -from acme.adders.reverb import base -from acme.adders.reverb import utils +from typing import Callable, Iterable, Optional, Tuple import dm_env import numpy as np @@ -30,90 +25,99 @@ import tensorflow as tf import tree +from acme import specs, types +from acme.adders.reverb import base, utils + _PaddingFn = Callable[[Tuple[int, ...], np.dtype], np.ndarray] class EpisodeAdder(base.ReverbAdder): - """Adder which adds entire episodes as trajectories.""" - - def __init__( - self, - client: reverb.Client, - max_sequence_length: int, - delta_encoded: bool = False, - priority_fns: Optional[base.PriorityFnMapping] = None, - max_in_flight_items: int = 1, - padding_fn: Optional[_PaddingFn] = None, - # Deprecated kwargs. 
- chunk_length: Optional[int] = None, - ): - del chunk_length - - super().__init__( - client=client, - max_sequence_length=max_sequence_length, - delta_encoded=delta_encoded, - priority_fns=priority_fns, - max_in_flight_items=max_in_flight_items, - ) - self._padding_fn = padding_fn - - def add( - self, - action: types.NestedArray, - next_timestep: dm_env.TimeStep, - extras: types.NestedArray = (), - ): - if self._writer.episode_steps >= self._max_sequence_length - 1: - raise ValueError( - 'The number of observations within the same episode will exceed ' - 'max_sequence_length with the addition of this transition.') - - super().add(action, next_timestep, extras) - - def _write(self): - # This adder only writes at the end of the episode, see _write_last() - pass - - def _write_last(self): - if self._padding_fn is not None and self._writer.episode_steps < self._max_sequence_length: - history = self._writer.history - padding_step = dict( - observation=history['observation'], - action=history['action'], - reward=history['reward'], - discount=history['discount'], - extras=history.get('extras', ())) - # Get shapes and dtypes from the last element. - padding_step = tree.map_structure( - lambda col: self._padding_fn(col[-1].shape, col[-1].dtype), - padding_step) - padding_step['start_of_episode'] = False - while self._writer.episode_steps < self._max_sequence_length: - self._writer.append(padding_step) - - trajectory = tree.map_structure(lambda x: x[:], self._writer.history) - - # Pack the history into a base.Step structure and get numpy converted - # variant for priotiy computation. - trajectory = base.Trajectory(**trajectory) - - # Calculate the priority for this episode. - table_priorities = utils.calculate_priorities(self._priority_fns, - trajectory) - - # Create a prioritized item for each table. - for table_name, priority in table_priorities.items(): - self._writer.create_item(table_name, priority, trajectory) - self._writer.flush(self._max_in_flight_items) - - # TODO(b/185309817): make this into a standalone method. - @classmethod - def signature(cls, - environment_spec: specs.EnvironmentSpec, - extras_spec: types.NestedSpec = (), - sequence_length: Optional[int] = None): - """This is a helper method for generating signatures for Reverb tables. + """Adder which adds entire episodes as trajectories.""" + + def __init__( + self, + client: reverb.Client, + max_sequence_length: int, + delta_encoded: bool = False, + priority_fns: Optional[base.PriorityFnMapping] = None, + max_in_flight_items: int = 1, + padding_fn: Optional[_PaddingFn] = None, + # Deprecated kwargs. + chunk_length: Optional[int] = None, + ): + del chunk_length + + super().__init__( + client=client, + max_sequence_length=max_sequence_length, + delta_encoded=delta_encoded, + priority_fns=priority_fns, + max_in_flight_items=max_in_flight_items, + ) + self._padding_fn = padding_fn + + def add( + self, + action: types.NestedArray, + next_timestep: dm_env.TimeStep, + extras: types.NestedArray = (), + ): + if self._writer.episode_steps >= self._max_sequence_length - 1: + raise ValueError( + "The number of observations within the same episode will exceed " + "max_sequence_length with the addition of this transition." 
+ ) + + super().add(action, next_timestep, extras) + + def _write(self): + # This adder only writes at the end of the episode, see _write_last() + pass + + def _write_last(self): + if ( + self._padding_fn is not None + and self._writer.episode_steps < self._max_sequence_length + ): + history = self._writer.history + padding_step = dict( + observation=history["observation"], + action=history["action"], + reward=history["reward"], + discount=history["discount"], + extras=history.get("extras", ()), + ) + # Get shapes and dtypes from the last element. + padding_step = tree.map_structure( + lambda col: self._padding_fn(col[-1].shape, col[-1].dtype), padding_step + ) + padding_step["start_of_episode"] = False + while self._writer.episode_steps < self._max_sequence_length: + self._writer.append(padding_step) + + trajectory = tree.map_structure(lambda x: x[:], self._writer.history) + + # Pack the history into a base.Step structure and get numpy converted + # variant for priotiy computation. + trajectory = base.Trajectory(**trajectory) + + # Calculate the priority for this episode. + table_priorities = utils.calculate_priorities(self._priority_fns, trajectory) + + # Create a prioritized item for each table. + for table_name, priority in table_priorities.items(): + self._writer.create_item(table_name, priority, trajectory) + self._writer.flush(self._max_in_flight_items) + + # TODO(b/185309817): make this into a standalone method. + @classmethod + def signature( + cls, + environment_spec: specs.EnvironmentSpec, + extras_spec: types.NestedSpec = (), + sequence_length: Optional[int] = None, + ): + """This is a helper method for generating signatures for Reverb tables. Signatures are useful for validating data types and shapes, see Reverb's documentation for details on how they are used. @@ -133,19 +137,23 @@ def signature(cls, A `Step` whose leaf nodes are `tf.TensorSpec` objects. 
""" - def add_time_dim(paths: Iterable[str], spec: tf.TensorSpec): - return tf.TensorSpec( - shape=(sequence_length, *spec.shape), - dtype=spec.dtype, - name='/'.join(str(p) for p in paths)) - - trajectory_env_spec, trajectory_extras_spec = tree.map_structure_with_path( - add_time_dim, (environment_spec, extras_spec)) - - trajectory_spec = base.Trajectory( - *trajectory_env_spec, - start_of_episode=tf.TensorSpec( - shape=(sequence_length,), dtype=tf.bool, name='start_of_episode'), - extras=trajectory_extras_spec) - - return trajectory_spec + def add_time_dim(paths: Iterable[str], spec: tf.TensorSpec): + return tf.TensorSpec( + shape=(sequence_length, *spec.shape), + dtype=spec.dtype, + name="/".join(str(p) for p in paths), + ) + + trajectory_env_spec, trajectory_extras_spec = tree.map_structure_with_path( + add_time_dim, (environment_spec, extras_spec) + ) + + trajectory_spec = base.Trajectory( + *trajectory_env_spec, + start_of_episode=tf.TensorSpec( + shape=(sequence_length,), dtype=tf.bool, name="start_of_episode" + ), + extras=trajectory_extras_spec + ) + + return trajectory_spec diff --git a/acme/adders/reverb/episode_test.py b/acme/adders/reverb/episode_test.py index 05d1a8e2b3..e83b06805e 100644 --- a/acme/adders/reverb/episode_test.py +++ b/acme/adders/reverb/episode_test.py @@ -14,98 +14,99 @@ """Tests for Episode adders.""" -from acme.adders.reverb import episode as adders -from acme.adders.reverb import test_utils import numpy as np +from absl.testing import absltest, parameterized -from absl.testing import absltest -from absl.testing import parameterized +from acme.adders.reverb import episode as adders +from acme.adders.reverb import test_utils class EpisodeAdderTest(test_utils.AdderTestMixin, parameterized.TestCase): - - @parameterized.parameters(2, 10, 50) - def test_adder(self, max_sequence_length): - adder = adders.EpisodeAdder(self.client, max_sequence_length) - - # Create a simple trajectory to add. - observations = range(max_sequence_length) - first, steps = test_utils.make_trajectory(observations) - - expected_episode = test_utils.make_sequence(observations) - self.run_test_adder( - adder=adder, - first=first, - steps=steps, - expected_items=[expected_episode], - signature=adder.signature(*test_utils.get_specs(steps[0]))) - - @parameterized.parameters(2, 10, 50) - def test_max_sequence_length(self, max_sequence_length): - adder = adders.EpisodeAdder(self.client, max_sequence_length) - - first, steps = test_utils.make_trajectory(range(max_sequence_length + 1)) - adder.add_first(first) - for action, step in steps[:-1]: - adder.add(action, step) - - # We should have max_sequence_length-1 timesteps that have been written, - # where the -1 is due to the dangling observation (ie we have actually - # seen max_sequence_length observations). - self.assertEqual(self.num_items(), 0) - - # Adding one more step should raise an error. - with self.assertRaises(ValueError): - action, step = steps[-1] - adder.add(action, step) - - # Since the last insert failed it should not affect the internal state. - self.assertEqual(self.num_items(), 0) - - @parameterized.parameters((2, 1), (10, 2), (50, 5)) - def test_padding(self, max_sequence_length, padding): - adder = adders.EpisodeAdder( - self.client, - max_sequence_length + padding, - padding_fn=np.zeros) - - # Create a simple trajectory to add. 
- observations = range(max_sequence_length) - first, steps = test_utils.make_trajectory(observations) - - expected_episode = test_utils.make_sequence(observations) - for _ in range(padding): - expected_episode.append((0, 0, 0.0, 0.0, False, ())) - - self.run_test_adder( - adder=adder, - first=first, - steps=steps, - expected_items=[expected_episode], - signature=adder.signature(*test_utils.get_specs(steps[0]))) - - @parameterized.parameters((2, 1), (10, 2), (50, 5)) - def test_nonzero_padding(self, max_sequence_length, padding): - adder = adders.EpisodeAdder( - self.client, - max_sequence_length + padding, - padding_fn=lambda s, d: np.zeros(s, d) - 1) - - # Create a simple trajectory to add. - observations = range(max_sequence_length) - first, steps = test_utils.make_trajectory(observations) - - expected_episode = test_utils.make_sequence(observations) - for _ in range(padding): - expected_episode.append((-1, -1, -1.0, -1.0, False, ())) - - self.run_test_adder( - adder=adder, - first=first, - steps=steps, - expected_items=[expected_episode], - signature=adder.signature(*test_utils.get_specs(steps[0]))) - - -if __name__ == '__main__': - absltest.main() + @parameterized.parameters(2, 10, 50) + def test_adder(self, max_sequence_length): + adder = adders.EpisodeAdder(self.client, max_sequence_length) + + # Create a simple trajectory to add. + observations = range(max_sequence_length) + first, steps = test_utils.make_trajectory(observations) + + expected_episode = test_utils.make_sequence(observations) + self.run_test_adder( + adder=adder, + first=first, + steps=steps, + expected_items=[expected_episode], + signature=adder.signature(*test_utils.get_specs(steps[0])), + ) + + @parameterized.parameters(2, 10, 50) + def test_max_sequence_length(self, max_sequence_length): + adder = adders.EpisodeAdder(self.client, max_sequence_length) + + first, steps = test_utils.make_trajectory(range(max_sequence_length + 1)) + adder.add_first(first) + for action, step in steps[:-1]: + adder.add(action, step) + + # We should have max_sequence_length-1 timesteps that have been written, + # where the -1 is due to the dangling observation (ie we have actually + # seen max_sequence_length observations). + self.assertEqual(self.num_items(), 0) + + # Adding one more step should raise an error. + with self.assertRaises(ValueError): + action, step = steps[-1] + adder.add(action, step) + + # Since the last insert failed it should not affect the internal state. + self.assertEqual(self.num_items(), 0) + + @parameterized.parameters((2, 1), (10, 2), (50, 5)) + def test_padding(self, max_sequence_length, padding): + adder = adders.EpisodeAdder( + self.client, max_sequence_length + padding, padding_fn=np.zeros + ) + + # Create a simple trajectory to add. + observations = range(max_sequence_length) + first, steps = test_utils.make_trajectory(observations) + + expected_episode = test_utils.make_sequence(observations) + for _ in range(padding): + expected_episode.append((0, 0, 0.0, 0.0, False, ())) + + self.run_test_adder( + adder=adder, + first=first, + steps=steps, + expected_items=[expected_episode], + signature=adder.signature(*test_utils.get_specs(steps[0])), + ) + + @parameterized.parameters((2, 1), (10, 2), (50, 5)) + def test_nonzero_padding(self, max_sequence_length, padding): + adder = adders.EpisodeAdder( + self.client, + max_sequence_length + padding, + padding_fn=lambda s, d: np.zeros(s, d) - 1, + ) + + # Create a simple trajectory to add. 
+ observations = range(max_sequence_length) + first, steps = test_utils.make_trajectory(observations) + + expected_episode = test_utils.make_sequence(observations) + for _ in range(padding): + expected_episode.append((-1, -1, -1.0, -1.0, False, ())) + + self.run_test_adder( + adder=adder, + first=first, + steps=steps, + expected_items=[expected_episode], + signature=adder.signature(*test_utils.get_specs(steps[0])), + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/adders/reverb/sequence.py b/acme/adders/reverb/sequence.py index 7d0669e445..c53a2e291e 100644 --- a/acme/adders/reverb/sequence.py +++ b/acme/adders/reverb/sequence.py @@ -21,19 +21,17 @@ import operator from typing import Iterable, Optional -from acme import specs -from acme import types -from acme.adders.reverb import base -from acme.adders.reverb import utils - import numpy as np import reverb import tensorflow as tf import tree +from acme import specs, types +from acme.adders.reverb import base, utils + class EndBehavior(enum.Enum): - """Class to enumerate available options for writing behavior at episode ends. + """Class to enumerate available options for writing behavior at episode ends. Example: @@ -62,32 +60,33 @@ class EndBehavior(enum.Enum): F: First step of the next episode 0: Zero-filled Step """ - WRITE = 'write_buffer' - CONTINUE = 'continue_to_next_episode' - ZERO_PAD = 'zero_pad_til_next_write' - TRUNCATE = 'write_truncated_buffer' + + WRITE = "write_buffer" + CONTINUE = "continue_to_next_episode" + ZERO_PAD = "zero_pad_til_next_write" + TRUNCATE = "write_truncated_buffer" class SequenceAdder(base.ReverbAdder): - """An adder which adds sequences of fixed length.""" - - def __init__( - self, - client: reverb.Client, - sequence_length: int, - period: int, - *, - delta_encoded: bool = False, - priority_fns: Optional[base.PriorityFnMapping] = None, - max_in_flight_items: Optional[int] = 2, - end_of_episode_behavior: Optional[EndBehavior] = None, - # Deprecated kwargs. - chunk_length: Optional[int] = None, - pad_end_of_episode: Optional[bool] = None, - break_end_of_episode: Optional[bool] = None, - validate_items: bool = True, - ): - """Makes a SequenceAdder instance. + """An adder which adds sequences of fixed length.""" + + def __init__( + self, + client: reverb.Client, + sequence_length: int, + period: int, + *, + delta_encoded: bool = False, + priority_fns: Optional[base.PriorityFnMapping] = None, + max_in_flight_items: Optional[int] = 2, + end_of_episode_behavior: Optional[EndBehavior] = None, + # Deprecated kwargs. + chunk_length: Optional[int] = None, + pad_end_of_episode: Optional[bool] = None, + break_end_of_episode: Optional[bool] = None, + validate_items: bool = True, + ): + """Makes a SequenceAdder instance. Args: client: See docstring for BaseAdder. @@ -115,151 +114,162 @@ def __init__( before they are sent to the server. This requires table signature to be fetched from the server and cached locally. """ - del chunk_length - super().__init__( - client=client, - # We need an additional space in the buffer for the partial step the - # base.ReverbAdder will add with the next observation. 
- max_sequence_length=sequence_length+1, - delta_encoded=delta_encoded, - priority_fns=priority_fns, - max_in_flight_items=max_in_flight_items, - validate_items=validate_items) - - if pad_end_of_episode and not break_end_of_episode: - raise ValueError( - 'Can\'t set pad_end_of_episode=True and break_end_of_episode=False at' - ' the same time, since those behaviors are incompatible.') - - self._period = period - self._sequence_length = sequence_length - - if end_of_episode_behavior and (pad_end_of_episode is not None or - break_end_of_episode is not None): - raise ValueError( - 'Using end_of_episode_behavior and either ' - 'pad_end_of_episode or break_end_of_episode is not permitted. ' - 'Please use only end_of_episode_behavior instead.') - - # Set pad_end_of_episode and break_end_of_episode to default values. - if end_of_episode_behavior is None and pad_end_of_episode is None: - pad_end_of_episode = True - if end_of_episode_behavior is None and break_end_of_episode is None: - break_end_of_episode = True - - self._end_of_episode_behavior = EndBehavior.ZERO_PAD - if pad_end_of_episode is not None or break_end_of_episode is not None: - if not break_end_of_episode: - self._end_of_episode_behavior = EndBehavior.CONTINUE - elif break_end_of_episode and pad_end_of_episode: + del chunk_length + super().__init__( + client=client, + # We need an additional space in the buffer for the partial step the + # base.ReverbAdder will add with the next observation. + max_sequence_length=sequence_length + 1, + delta_encoded=delta_encoded, + priority_fns=priority_fns, + max_in_flight_items=max_in_flight_items, + validate_items=validate_items, + ) + + if pad_end_of_episode and not break_end_of_episode: + raise ValueError( + "Can't set pad_end_of_episode=True and break_end_of_episode=False at" + " the same time, since those behaviors are incompatible." + ) + + self._period = period + self._sequence_length = sequence_length + + if end_of_episode_behavior and ( + pad_end_of_episode is not None or break_end_of_episode is not None + ): + raise ValueError( + "Using end_of_episode_behavior and either " + "pad_end_of_episode or break_end_of_episode is not permitted. " + "Please use only end_of_episode_behavior instead." + ) + + # Set pad_end_of_episode and break_end_of_episode to default values. + if end_of_episode_behavior is None and pad_end_of_episode is None: + pad_end_of_episode = True + if end_of_episode_behavior is None and break_end_of_episode is None: + break_end_of_episode = True + self._end_of_episode_behavior = EndBehavior.ZERO_PAD - elif break_end_of_episode and not pad_end_of_episode: - self._end_of_episode_behavior = EndBehavior.TRUNCATE - else: - raise ValueError( - 'Reached an unexpected configuration of the SequenceAdder ' - f'with break_end_of_episode={break_end_of_episode} ' - f'and pad_end_of_episode={pad_end_of_episode}.') - elif isinstance(end_of_episode_behavior, EndBehavior): - self._end_of_episode_behavior = end_of_episode_behavior - else: - raise ValueError('end_of_episod_behavior must be an instance of ' - f'EndBehavior, received {end_of_episode_behavior}.') - - def reset(self): - """Resets the adder's buffer.""" - # If we do not write on end of episode, we should not reset the writer. - if self._end_of_episode_behavior is EndBehavior.CONTINUE: - return - - super().reset() - - def _write(self): - self._maybe_create_item(self._sequence_length) - - def _write_last(self): - # Maybe determine the delta to the next time we would write a sequence. 
- if self._end_of_episode_behavior in (EndBehavior.TRUNCATE, - EndBehavior.ZERO_PAD): - delta = self._sequence_length - self._writer.episode_steps - if delta < 0: - delta = (self._period + delta) % self._period - - # Handle various end-of-episode cases. - if self._end_of_episode_behavior is EndBehavior.CONTINUE: - self._maybe_create_item(self._sequence_length, end_of_episode=True) - - elif self._end_of_episode_behavior is EndBehavior.WRITE: - # Drop episodes that are too short. - if self._writer.episode_steps < self._sequence_length: - return - self._maybe_create_item( - self._sequence_length, end_of_episode=True, force=True) - - elif self._end_of_episode_behavior is EndBehavior.TRUNCATE: - self._maybe_create_item( - self._sequence_length - delta, - end_of_episode=True, - force=True) - - elif self._end_of_episode_behavior is EndBehavior.ZERO_PAD: - zero_step = tree.map_structure(lambda x: np.zeros_like(x[-2].numpy()), - self._writer.history) - for _ in range(delta): - self._writer.append(zero_step) - - self._maybe_create_item( - self._sequence_length, end_of_episode=True, force=True) - else: - raise ValueError( - f'Unhandled end of episode behavior: {self._end_of_episode_behavior}.' - ' This should never happen, please contact Acme dev team.') - - def _maybe_create_item(self, - sequence_length: int, - *, - end_of_episode: bool = False, - force: bool = False): - - # Check conditions under which a new item is created. - first_write = self._writer.episode_steps == sequence_length - # NOTE(bshahr): the following line assumes that the only way sequence_length - # is less than self._sequence_length, is if the episode is shorter than - # self._sequence_length. - period_reached = ( - self._writer.episode_steps > self._sequence_length and - ((self._writer.episode_steps - self._sequence_length) % self._period - == 0)) - - if not first_write and not period_reached and not force: - return - - # TODO(b/183945808): will need to change to adhere to the new protocol. - if not end_of_episode: - get_traj = operator.itemgetter(slice(-sequence_length - 1, -1)) - else: - get_traj = operator.itemgetter(slice(-sequence_length, None)) - - history = self._writer.history - trajectory = base.Trajectory(**tree.map_structure(get_traj, history)) - - # Compute priorities for the buffer. - table_priorities = utils.calculate_priorities(self._priority_fns, - trajectory) - - # Create a prioritized item for each table. - for table_name, priority in table_priorities.items(): - self._writer.create_item(table_name, priority, trajectory) - self._writer.flush(self._max_in_flight_items) - - # TODO(bshahr): make this into a standalone method. Class methods should be - # used as alternative constructors or when modifying some global state, - # neither of which is done here. - @classmethod - def signature(cls, environment_spec: specs.EnvironmentSpec, - extras_spec: types.NestedSpec = (), - sequence_length: Optional[int] = None): - """This is a helper method for generating signatures for Reverb tables. 
+ if pad_end_of_episode is not None or break_end_of_episode is not None: + if not break_end_of_episode: + self._end_of_episode_behavior = EndBehavior.CONTINUE + elif break_end_of_episode and pad_end_of_episode: + self._end_of_episode_behavior = EndBehavior.ZERO_PAD + elif break_end_of_episode and not pad_end_of_episode: + self._end_of_episode_behavior = EndBehavior.TRUNCATE + else: + raise ValueError( + "Reached an unexpected configuration of the SequenceAdder " + f"with break_end_of_episode={break_end_of_episode} " + f"and pad_end_of_episode={pad_end_of_episode}." + ) + elif isinstance(end_of_episode_behavior, EndBehavior): + self._end_of_episode_behavior = end_of_episode_behavior + else: + raise ValueError( + "end_of_episod_behavior must be an instance of " + f"EndBehavior, received {end_of_episode_behavior}." + ) + + def reset(self): + """Resets the adder's buffer.""" + # If we do not write on end of episode, we should not reset the writer. + if self._end_of_episode_behavior is EndBehavior.CONTINUE: + return + + super().reset() + + def _write(self): + self._maybe_create_item(self._sequence_length) + + def _write_last(self): + # Maybe determine the delta to the next time we would write a sequence. + if self._end_of_episode_behavior in ( + EndBehavior.TRUNCATE, + EndBehavior.ZERO_PAD, + ): + delta = self._sequence_length - self._writer.episode_steps + if delta < 0: + delta = (self._period + delta) % self._period + + # Handle various end-of-episode cases. + if self._end_of_episode_behavior is EndBehavior.CONTINUE: + self._maybe_create_item(self._sequence_length, end_of_episode=True) + + elif self._end_of_episode_behavior is EndBehavior.WRITE: + # Drop episodes that are too short. + if self._writer.episode_steps < self._sequence_length: + return + self._maybe_create_item( + self._sequence_length, end_of_episode=True, force=True + ) + + elif self._end_of_episode_behavior is EndBehavior.TRUNCATE: + self._maybe_create_item( + self._sequence_length - delta, end_of_episode=True, force=True + ) + + elif self._end_of_episode_behavior is EndBehavior.ZERO_PAD: + zero_step = tree.map_structure( + lambda x: np.zeros_like(x[-2].numpy()), self._writer.history + ) + for _ in range(delta): + self._writer.append(zero_step) + + self._maybe_create_item( + self._sequence_length, end_of_episode=True, force=True + ) + else: + raise ValueError( + f"Unhandled end of episode behavior: {self._end_of_episode_behavior}." + " This should never happen, please contact Acme dev team." + ) + + def _maybe_create_item( + self, sequence_length: int, *, end_of_episode: bool = False, force: bool = False + ): + + # Check conditions under which a new item is created. + first_write = self._writer.episode_steps == sequence_length + # NOTE(bshahr): the following line assumes that the only way sequence_length + # is less than self._sequence_length, is if the episode is shorter than + # self._sequence_length. + period_reached = self._writer.episode_steps > self._sequence_length and ( + (self._writer.episode_steps - self._sequence_length) % self._period == 0 + ) + + if not first_write and not period_reached and not force: + return + + # TODO(b/183945808): will need to change to adhere to the new protocol. + if not end_of_episode: + get_traj = operator.itemgetter(slice(-sequence_length - 1, -1)) + else: + get_traj = operator.itemgetter(slice(-sequence_length, None)) + + history = self._writer.history + trajectory = base.Trajectory(**tree.map_structure(get_traj, history)) + + # Compute priorities for the buffer. 
+ table_priorities = utils.calculate_priorities(self._priority_fns, trajectory) + + # Create a prioritized item for each table. + for table_name, priority in table_priorities.items(): + self._writer.create_item(table_name, priority, trajectory) + self._writer.flush(self._max_in_flight_items) + + # TODO(bshahr): make this into a standalone method. Class methods should be + # used as alternative constructors or when modifying some global state, + # neither of which is done here. + @classmethod + def signature( + cls, + environment_spec: specs.EnvironmentSpec, + extras_spec: types.NestedSpec = (), + sequence_length: Optional[int] = None, + ): + """This is a helper method for generating signatures for Reverb tables. Signatures are useful for validating data types and shapes, see Reverb's documentation for details on how they are used. @@ -279,18 +289,23 @@ def signature(cls, environment_spec: specs.EnvironmentSpec, A `Trajectory` whose leaf nodes are `tf.TensorSpec` objects. """ - def add_time_dim(paths: Iterable[str], spec: tf.TensorSpec): - return tf.TensorSpec(shape=(sequence_length, *spec.shape), - dtype=spec.dtype, - name='/'.join(str(p) for p in paths)) - - trajectory_env_spec, trajectory_extras_spec = tree.map_structure_with_path( - add_time_dim, (environment_spec, extras_spec)) - - spec_step = base.Trajectory( - *trajectory_env_spec, - start_of_episode=tf.TensorSpec( - shape=(sequence_length,), dtype=tf.bool, name='start_of_episode'), - extras=trajectory_extras_spec) - - return spec_step + def add_time_dim(paths: Iterable[str], spec: tf.TensorSpec): + return tf.TensorSpec( + shape=(sequence_length, *spec.shape), + dtype=spec.dtype, + name="/".join(str(p) for p in paths), + ) + + trajectory_env_spec, trajectory_extras_spec = tree.map_structure_with_path( + add_time_dim, (environment_spec, extras_spec) + ) + + spec_step = base.Trajectory( + *trajectory_env_spec, + start_of_episode=tf.TensorSpec( + shape=(sequence_length,), dtype=tf.bool, name="start_of_episode" + ), + extras=trajectory_extras_spec, + ) + + return spec_step diff --git a/acme/adders/reverb/sequence_test.py b/acme/adders/reverb/sequence_test.py index d50f125062..707da8dc1d 100644 --- a/acme/adders/reverb/sequence_test.py +++ b/acme/adders/reverb/sequence_test.py @@ -14,55 +14,57 @@ """Tests for sequence adders.""" -from acme.adders.reverb import sequence as adders -from acme.adders.reverb import test_cases -from acme.adders.reverb import test_utils +from absl.testing import absltest, parameterized -from absl.testing import absltest -from absl.testing import parameterized +from acme.adders.reverb import sequence as adders +from acme.adders.reverb import test_cases, test_utils class SequenceAdderTest(test_utils.AdderTestMixin, parameterized.TestCase): + @parameterized.named_parameters(*test_cases.TEST_CASES_FOR_SEQUENCE_ADDER) + def test_adder( + self, + sequence_length: int, + period: int, + first, + steps, + expected_sequences, + end_behavior: adders.EndBehavior = adders.EndBehavior.ZERO_PAD, + repeat_episode_times: int = 1, + ): + adder = adders.SequenceAdder( + self.client, + sequence_length=sequence_length, + period=period, + end_of_episode_behavior=end_behavior, + ) + super().run_test_adder( + adder=adder, + first=first, + steps=steps, + expected_items=expected_sequences, + repeat_episode_times=repeat_episode_times, + end_behavior=end_behavior, + signature=adder.signature(*test_utils.get_specs(steps[0])), + ) - @parameterized.named_parameters(*test_cases.TEST_CASES_FOR_SEQUENCE_ADDER) - def test_adder(self, - 
sequence_length: int, - period: int, - first, - steps, - expected_sequences, - end_behavior: adders.EndBehavior = adders.EndBehavior.ZERO_PAD, - repeat_episode_times: int = 1): - adder = adders.SequenceAdder( - self.client, - sequence_length=sequence_length, - period=period, - end_of_episode_behavior=end_behavior) - super().run_test_adder( - adder=adder, - first=first, - steps=steps, - expected_items=expected_sequences, - repeat_episode_times=repeat_episode_times, - end_behavior=end_behavior, - signature=adder.signature(*test_utils.get_specs(steps[0]))) - - @parameterized.parameters( - (True, True, adders.EndBehavior.ZERO_PAD), - (False, True, adders.EndBehavior.TRUNCATE), - (False, False, adders.EndBehavior.CONTINUE), - ) - def test_end_of_episode_behavior_set_correctly(self, pad_end_of_episode, - break_end_of_episode, - expected_behavior): - adder = adders.SequenceAdder( - self.client, - sequence_length=5, - period=3, - pad_end_of_episode=pad_end_of_episode, - break_end_of_episode=break_end_of_episode) - self.assertEqual(adder._end_of_episode_behavior, expected_behavior) + @parameterized.parameters( + (True, True, adders.EndBehavior.ZERO_PAD), + (False, True, adders.EndBehavior.TRUNCATE), + (False, False, adders.EndBehavior.CONTINUE), + ) + def test_end_of_episode_behavior_set_correctly( + self, pad_end_of_episode, break_end_of_episode, expected_behavior + ): + adder = adders.SequenceAdder( + self.client, + sequence_length=5, + period=3, + pad_end_of_episode=pad_end_of_episode, + break_end_of_episode=break_end_of_episode, + ) + self.assertEqual(adder._end_of_episode_behavior, expected_behavior) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/adders/reverb/structured.py b/acme/adders/reverb/structured.py index b99f171589..7cc78a9718 100644 --- a/acme/adders/reverb/structured.py +++ b/acme/adders/reverb/structured.py @@ -16,21 +16,20 @@ import itertools import time - from typing import Callable, List, Optional, Sequence, Sized -from absl import logging -from acme import specs -from acme import types -from acme.adders import base as adders_base -from acme.adders.reverb import base as reverb_base -from acme.adders.reverb import sequence as sequence_adder import dm_env import numpy as np import reverb -from reverb import structured_writer as sw import tensorflow as tf import tree +from absl import logging +from reverb import structured_writer as sw + +from acme import specs, types +from acme.adders import base as adders_base +from acme.adders.reverb import base as reverb_base +from acme.adders.reverb import sequence as sequence_adder Step = reverb_base.Step Trajectory = reverb_base.Trajectory @@ -40,7 +39,7 @@ class StructuredAdder(adders_base.Adder): - """Generic Adder which writes to Reverb using Reverb's `StructuredWriter`. + """Generic Adder which writes to Reverb using Reverb's `StructuredWriter`. The StructuredAdder is a thin wrapper around Reverb's `StructuredWriter` and its behaviour is determined by the configs to __init__. Much of the behaviour @@ -63,9 +62,14 @@ class StructuredAdder(adders_base.Adder): expected to perform preprocessing in the dataset pipeline on the learner. """ - def __init__(self, client: reverb.Client, max_in_flight_items: int, - configs: Sequence[sw.Config], step_spec: Step): - """Initialize a StructuredAdder instance. + def __init__( + self, + client: reverb.Client, + max_in_flight_items: int, + configs: Sequence[sw.Config], + step_spec: Step, + ): + """Initialize a StructuredAdder instance. 
Args: client: A client to the Reverb backend. @@ -79,125 +83,139 @@ def __init__(self, client: reverb.Client, max_in_flight_items: int, and the extras spec. """ - # We validate the configs by attempting to infer the signatures of all - # targeted tables. - for table, table_configs in itertools.groupby(configs, lambda c: c.table): - try: - sw.infer_signature(list(table_configs), step_spec) - except ValueError as e: - raise ValueError( - f'Received invalid configs for table {table}: {str(e)}') from e - - self._client = client - self._configs = tuple(configs) - self._none_step: Step = tree.map_structure(lambda _: None, step_spec) - self._max_in_flight_items = max_in_flight_items - - self._writer = None - self._writer_created_at = None - - def __del__(self): - if self._writer is None: - return - - # Try flush all appended data before closing to avoid loss of experience. - try: - self._writer.flush(0, timeout_ms=10_000) - except reverb.DeadlineExceededError as e: - logging.error( - 'Timeout (10 s) exceeded when flushing the writer before ' - 'deleting it. Caught Reverb exception: %s', str(e)) - - def _make_step(self, **kwargs) -> Step: - """Complete the step with None in the missing positions.""" - return Step(**{**self._none_step._asdict(), **kwargs}) - - @property - def configs(self): - return self._configs - - def reset(self, timeout_ms: Optional[int] = None): - """Marks the active episode as completed and flushes pending items.""" - if self._writer is not None: - # Flush all pending items. - self._writer.end_episode(clear_buffers=True, timeout_ms=timeout_ms) - - # Create a new writer unless the current one is too young. - # This is to reduce the relative overhead of creating a new Reverb writer. - if time.time() - self._writer_created_at > _RESET_WRITER_EVERY_SECONDS: - self._writer = None + # We validate the configs by attempting to infer the signatures of all + # targeted tables. + for table, table_configs in itertools.groupby(configs, lambda c: c.table): + try: + sw.infer_signature(list(table_configs), step_spec) + except ValueError as e: + raise ValueError( + f"Received invalid configs for table {table}: {str(e)}" + ) from e + + self._client = client + self._configs = tuple(configs) + self._none_step: Step = tree.map_structure(lambda _: None, step_spec) + self._max_in_flight_items = max_in_flight_items - def add_first(self, timestep: dm_env.TimeStep): - """Record the first observation of an episode.""" - if not timestep.first(): - raise ValueError( - 'adder.add_first called with a timestep that was not the first of its' - 'episode (i.e. one for which timestep.first() is not True)') - - if self._writer is None: - self._writer = self._client.structured_writer(self._configs) - self._writer_created_at = time.time() - - # Record the next observation but leave the history buffer row open by - # passing `partial_step=True`. - self._writer.append( - data=self._make_step( - observation=timestep.observation, - start_of_episode=timestep.first()), - partial_step=True) - self._writer.flush(self._max_in_flight_items) - - def add(self, - action: types.NestedArray, - next_timestep: dm_env.TimeStep, - extras: types.NestedArray = ()): - """Record an action and the following timestep.""" - - if not self._writer.step_is_open: - raise ValueError('adder.add_first must be called before adder.add.') - - # Add the timestep to the buffer. 
- has_extras = ( - len(extras) > 0 if isinstance(extras, Sized) # pylint: disable=g-explicit-length-test - else extras is not None) - - current_step = self._make_step( - action=action, - reward=next_timestep.reward, - discount=next_timestep.discount, - extras=extras if has_extras else self._none_step.extras) - self._writer.append(current_step) - - # Record the next observation and write. - self._writer.append( - data=self._make_step( - observation=next_timestep.observation, - start_of_episode=next_timestep.first()), - partial_step=True) - self._writer.flush(self._max_in_flight_items) - - if next_timestep.last(): - # Complete the row by appending zeros to remaining open fields. - # TODO(b/183945808): remove this when fields are no longer expected to be - # of equal length on the learner side. - dummy_step = tree.map_structure( - lambda x: None if x is None else np.zeros_like(x), current_step) - self._writer.append(dummy_step) - self.reset() + self._writer = None + self._writer_created_at = None + + def __del__(self): + if self._writer is None: + return + + # Try flush all appended data before closing to avoid loss of experience. + try: + self._writer.flush(0, timeout_ms=10_000) + except reverb.DeadlineExceededError as e: + logging.error( + "Timeout (10 s) exceeded when flushing the writer before " + "deleting it. Caught Reverb exception: %s", + str(e), + ) + + def _make_step(self, **kwargs) -> Step: + """Complete the step with None in the missing positions.""" + return Step(**{**self._none_step._asdict(), **kwargs}) + + @property + def configs(self): + return self._configs + + def reset(self, timeout_ms: Optional[int] = None): + """Marks the active episode as completed and flushes pending items.""" + if self._writer is not None: + # Flush all pending items. + self._writer.end_episode(clear_buffers=True, timeout_ms=timeout_ms) + + # Create a new writer unless the current one is too young. + # This is to reduce the relative overhead of creating a new Reverb writer. + if time.time() - self._writer_created_at > _RESET_WRITER_EVERY_SECONDS: + self._writer = None + + def add_first(self, timestep: dm_env.TimeStep): + """Record the first observation of an episode.""" + if not timestep.first(): + raise ValueError( + "adder.add_first called with a timestep that was not the first of its" + "episode (i.e. one for which timestep.first() is not True)" + ) + + if self._writer is None: + self._writer = self._client.structured_writer(self._configs) + self._writer_created_at = time.time() + + # Record the next observation but leave the history buffer row open by + # passing `partial_step=True`. + self._writer.append( + data=self._make_step( + observation=timestep.observation, start_of_episode=timestep.first() + ), + partial_step=True, + ) + self._writer.flush(self._max_in_flight_items) + + def add( + self, + action: types.NestedArray, + next_timestep: dm_env.TimeStep, + extras: types.NestedArray = (), + ): + """Record an action and the following timestep.""" + + if not self._writer.step_is_open: + raise ValueError("adder.add_first must be called before adder.add.") + + # Add the timestep to the buffer. + has_extras = ( + len(extras) > 0 + if isinstance(extras, Sized) # pylint: disable=g-explicit-length-test + else extras is not None + ) + + current_step = self._make_step( + action=action, + reward=next_timestep.reward, + discount=next_timestep.discount, + extras=extras if has_extras else self._none_step.extras, + ) + self._writer.append(current_step) + + # Record the next observation and write. 
+ self._writer.append( + data=self._make_step( + observation=next_timestep.observation, + start_of_episode=next_timestep.first(), + ), + partial_step=True, + ) + self._writer.flush(self._max_in_flight_items) + + if next_timestep.last(): + # Complete the row by appending zeros to remaining open fields. + # TODO(b/183945808): remove this when fields are no longer expected to be + # of equal length on the learner side. + dummy_step = tree.map_structure( + lambda x: None if x is None else np.zeros_like(x), current_step + ) + self._writer.append(dummy_step) + self.reset() def create_step_spec( environment_spec: specs.EnvironmentSpec, extras_spec: types.NestedSpec = () ) -> Step: - return Step( - *environment_spec, - start_of_episode=tf.TensorSpec([], tf.bool, 'start_of_episode'), - extras=extras_spec) + return Step( + *environment_spec, + start_of_episode=tf.TensorSpec([], tf.bool, "start_of_episode"), + extras=extras_spec, + ) def _last_n(n: int, step_spec: Step) -> Trajectory: - """Constructs a sequence with the last n elements of all the Step fields.""" - return Trajectory(*tree.map_structure(lambda x: x[-n:], step_spec)) + """Constructs a sequence with the last n elements of all the Step fields.""" + return Trajectory(*tree.map_structure(lambda x: x[-n:], step_spec)) def create_sequence_config( @@ -208,7 +226,7 @@ def create_sequence_config( end_of_episode_behavior: EndBehavior = EndBehavior.TRUNCATE, sequence_pattern: Callable[[int, Step], Trajectory] = _last_n, ) -> List[sw.Config]: - """Generates configs that produces the same behaviour as `SequenceAdder`. + """Generates configs that produces the same behaviour as `SequenceAdder`. NOTE! ZERO_PAD is not supported as the same behaviour can be achieved by writing with TRUNCATE and then adding padding in the dataset pipeline on the @@ -237,113 +255,119 @@ def create_sequence_config( ValueError: If sequence_length is <= 0. NotImplementedError: If `end_of_episod_behavior` is `ZERO_PAD`. """ - if sequence_length <= 0: - raise ValueError(f'sequence_length must be > 0 but got {sequence_length}.') - - if end_of_episode_behavior == EndBehavior.ZERO_PAD: - raise NotImplementedError( - 'Zero-padding is not supported. Please use TRUNCATE instead.') - - if end_of_episode_behavior == EndBehavior.CONTINUE: - raise NotImplementedError('Merging episodes is not supported.') - - def _sequence_pattern(n: int) -> sw.Pattern: - return sw.pattern_from_transform(step_spec, - lambda step: sequence_pattern(n, step)) - - # The base config is considered for all but the last step in the episode. No - # trajectories are created for the first `sequence_step-1` steps and then a - # new trajectory is inserted every `period` steps. - base_config = sw.create_config( - pattern=_sequence_pattern(sequence_length), - table=table, - conditions=[ - sw.Condition.step_index() >= sequence_length - 1, - sw.Condition.step_index() % period == (sequence_length - 1) % period, - ]) - - end_of_episode_configs = [] - if end_of_episode_behavior == EndBehavior.WRITE: - # Simply write a trajectory in exactly the same way as the base config. The - # only difference here is that we ALWAYS create a trajectory even if it - # doesn't align with the `period`. The exceptions to the rule are episodes - # that are shorter than `sequence_length` steps which are completely - # ignored. 
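# Illustrative sketch, not part of the diff: how the helpers in this module are
# typically wired together. The toy environment spec, queue size and parameter
# values below are assumptions chosen only for the example.
import numpy as np
import reverb
from reverb import structured_writer as sw

from acme import specs
from acme.adders import reverb as adders
from acme.adders.reverb import structured

env_spec = specs.EnvironmentSpec(
    observations=specs.Array((), np.int32),
    actions=specs.Array((), np.int32),
    rewards=specs.Array((), np.float32),
    discounts=specs.Array((), np.float32),
)
step_spec = structured.create_step_spec(env_spec)
configs = structured.create_sequence_config(
    step_spec=step_spec,
    sequence_length=3,
    period=2,
    end_of_episode_behavior=adders.EndBehavior.TRUNCATE,
)
# The signature inferred from the configs is what a replay table would declare.
signature = sw.infer_signature(configs, step_spec)

server = reverb.Server([reverb.Table.queue(adders.DEFAULT_PRIORITY_TABLE, 1000)])
adder = structured.StructuredAdder(
    client=reverb.Client(f"localhost:{server.port}"),
    max_in_flight_items=5,
    configs=configs,
    step_spec=step_spec,
)
# Episodes are then streamed with adder.add_first(timestep) followed by
# adder.add(action, next_timestep=timestep), exactly as with SequenceAdder.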
- config = sw.create_config( + if sequence_length <= 0: + raise ValueError(f"sequence_length must be > 0 but got {sequence_length}.") + + if end_of_episode_behavior == EndBehavior.ZERO_PAD: + raise NotImplementedError( + "Zero-padding is not supported. Please use TRUNCATE instead." + ) + + if end_of_episode_behavior == EndBehavior.CONTINUE: + raise NotImplementedError("Merging episodes is not supported.") + + def _sequence_pattern(n: int) -> sw.Pattern: + return sw.pattern_from_transform( + step_spec, lambda step: sequence_pattern(n, step) + ) + + # The base config is considered for all but the last step in the episode. No + # trajectories are created for the first `sequence_step-1` steps and then a + # new trajectory is inserted every `period` steps. + base_config = sw.create_config( pattern=_sequence_pattern(sequence_length), table=table, conditions=[ - sw.Condition.is_end_episode(), sw.Condition.step_index() >= sequence_length - 1, - ]) - end_of_episode_configs.append(config) - elif end_of_episode_behavior == EndBehavior.TRUNCATE: - # The first trajectory is written at step index `sequence_length - 1` and - # then written every `period` step. This means that the - # `step_index % period` will always be equal to the below value everytime a - # trajectory is written. - target = (sequence_length - 1) % period - - # When the episode ends we still want to capture the steps that has been - # appended since the last item was created. We do this by creating a config - # for all `step_index % period`, except `target`, and condition these - # configs so that they only are triggered when `end_episode` is called. - for x in range(period): - # When the last step is aligned with the period of the inserts then no - # action is required as the item was already generated by `base_config`. - if x == target: - continue - - # If we were to pad the trajectory then we'll need to continue adding - # padding until `step_index % period` is equal to `target` again. We can - # exploit this relation by conditioning the config to only be applied for - # a single value of `step_index % period`. This constraint means that we - # can infer the number of padding steps required until the next write - # would have occurred if the episode didn't end. - # - # Now if we assume that the padding instead is added on the dataset (or - # the trajectory is simply truncated) then we can infer from the above - # that the number of real steps in this padded trajectory will be the - # difference between `sequence_length` and number of pad steps. - num_pad_steps = (target - x) % period - unpadded_length = sequence_length - num_pad_steps - - config = sw.create_config( - pattern=_sequence_pattern(unpadded_length), - table=table, - conditions=[ - sw.Condition.is_end_episode(), - sw.Condition.step_index() % period == x, - sw.Condition.step_index() >= sequence_length, - ]) - end_of_episode_configs.append(config) - - # The above configs will capture the "remainder" of any episode that is at - # least `sequence_length` steps long. However, if the entire episode is - # shorter than `sequence_length` then data might still be lost. We avoid - # this by simply creating `sequence_length-1` configs that capture the last - # `x` steps iff the entire episode is `x` steps long. 
- for x in range(1, sequence_length): - config = sw.create_config( - pattern=_sequence_pattern(x), - table=table, - conditions=[ - sw.Condition.is_end_episode(), - sw.Condition.step_index() == x - 1, - ]) - end_of_episode_configs.append(config) - else: - raise ValueError( - f'Unexpected `end_of_episod_behavior`: {end_of_episode_behavior}') - - return [base_config] + end_of_episode_configs + sw.Condition.step_index() % period == (sequence_length - 1) % period, + ], + ) + + end_of_episode_configs = [] + if end_of_episode_behavior == EndBehavior.WRITE: + # Simply write a trajectory in exactly the same way as the base config. The + # only difference here is that we ALWAYS create a trajectory even if it + # doesn't align with the `period`. The exceptions to the rule are episodes + # that are shorter than `sequence_length` steps which are completely + # ignored. + config = sw.create_config( + pattern=_sequence_pattern(sequence_length), + table=table, + conditions=[ + sw.Condition.is_end_episode(), + sw.Condition.step_index() >= sequence_length - 1, + ], + ) + end_of_episode_configs.append(config) + elif end_of_episode_behavior == EndBehavior.TRUNCATE: + # The first trajectory is written at step index `sequence_length - 1` and + # then written every `period` step. This means that the + # `step_index % period` will always be equal to the below value everytime a + # trajectory is written. + target = (sequence_length - 1) % period + + # When the episode ends we still want to capture the steps that has been + # appended since the last item was created. We do this by creating a config + # for all `step_index % period`, except `target`, and condition these + # configs so that they only are triggered when `end_episode` is called. + for x in range(period): + # When the last step is aligned with the period of the inserts then no + # action is required as the item was already generated by `base_config`. + if x == target: + continue + + # If we were to pad the trajectory then we'll need to continue adding + # padding until `step_index % period` is equal to `target` again. We can + # exploit this relation by conditioning the config to only be applied for + # a single value of `step_index % period`. This constraint means that we + # can infer the number of padding steps required until the next write + # would have occurred if the episode didn't end. + # + # Now if we assume that the padding instead is added on the dataset (or + # the trajectory is simply truncated) then we can infer from the above + # that the number of real steps in this padded trajectory will be the + # difference between `sequence_length` and number of pad steps. + num_pad_steps = (target - x) % period + unpadded_length = sequence_length - num_pad_steps + + config = sw.create_config( + pattern=_sequence_pattern(unpadded_length), + table=table, + conditions=[ + sw.Condition.is_end_episode(), + sw.Condition.step_index() % period == x, + sw.Condition.step_index() >= sequence_length, + ], + ) + end_of_episode_configs.append(config) + + # The above configs will capture the "remainder" of any episode that is at + # least `sequence_length` steps long. However, if the entire episode is + # shorter than `sequence_length` then data might still be lost. We avoid + # this by simply creating `sequence_length-1` configs that capture the last + # `x` steps iff the entire episode is `x` steps long. 
+ for x in range(1, sequence_length): + config = sw.create_config( + pattern=_sequence_pattern(x), + table=table, + conditions=[ + sw.Condition.is_end_episode(), + sw.Condition.step_index() == x - 1, + ], + ) + end_of_episode_configs.append(config) + else: + raise ValueError( + f"Unexpected `end_of_episod_behavior`: {end_of_episode_behavior}" + ) + + return [base_config] + end_of_episode_configs def create_n_step_transition_config( - step_spec: Step, - n_step: int, - table: str = reverb_base.DEFAULT_PRIORITY_TABLE) -> List[sw.Config]: - """Generates configs that replicates the behaviour of NStepTransitionAdder. + step_spec: Step, n_step: int, table: str = reverb_base.DEFAULT_PRIORITY_TABLE +) -> List[sw.Config]: + """Generates configs that replicates the behaviour of NStepTransitionAdder. Please see the docstring of NStepTransitionAdder for more details. @@ -367,58 +391,57 @@ def create_n_step_transition_config( A list of configs for `StructuredAdder` to produce the described behaviour. """ - def _make_pattern(n: int): - ref_step = sw.create_reference_step(step_spec) - - get_first = lambda x: x[-(n + 1):-n] - get_all = lambda x: x[-(n + 1):-1] - get_first_and_last = lambda x: x[-(n + 1)::n] - - tmap = tree.map_structure - - # We use the exact same structure as we done when writing sequences except - # we trim the number of steps in each sub tree. This has the benefit that - # the postprocessing used to transform these items into N-step transition - # structures (cumulative rewards and discounts etc.) can be applied on - # full sequence items as well. The only difference being that the latter is - # more wasteful than the trimmed down version we write here. - return Trajectory( - observation=tmap(get_first_and_last, ref_step.observation), - action=tmap(get_first, ref_step.action), - reward=tmap(get_all, ref_step.reward), - discount=tmap(get_all, ref_step.discount), - start_of_episode=tmap(get_first, ref_step.start_of_episode), - extras=tmap(get_first, ref_step.extras)) - - # At the start of the episodes we'll add shorter transitions. - start_of_episode_configs = [] - for n in range(1, n_step): - config = sw.create_config( - pattern=_make_pattern(n), - table=table, - conditions=[ - sw.Condition.step_index() == n, - ], - ) - start_of_episode_configs.append(config) - - # During all other steps we'll add a full N-step transition. - base_config = sw.create_config(pattern=_make_pattern(n_step), table=table) - - # When the episode ends we'll add shorter transitions. - end_of_episode_configs = [] - for n in range(n_step - 1, 0, -1): - config = sw.create_config( - pattern=_make_pattern(n), - table=table, - conditions=[ - sw.Condition.is_end_episode(), - # If the entire episode is shorter than n_step then the episode - # start configs will already create an item that covers all the - # steps so we add this filter here to avoid adding it again. - sw.Condition.step_index() != n, - ], - ) - end_of_episode_configs.append(config) - - return start_of_episode_configs + [base_config] + end_of_episode_configs + def _make_pattern(n: int): + ref_step = sw.create_reference_step(step_spec) + + get_first = lambda x: x[-(n + 1) : -n] + get_all = lambda x: x[-(n + 1) : -1] + get_first_and_last = lambda x: x[-(n + 1) :: n] + + tmap = tree.map_structure + + # We use the exact same structure as we done when writing sequences except + # we trim the number of steps in each sub tree. 
This has the benefit that + # the postprocessing used to transform these items into N-step transition + # structures (cumulative rewards and discounts etc.) can be applied on + # full sequence items as well. The only difference being that the latter is + # more wasteful than the trimmed down version we write here. + return Trajectory( + observation=tmap(get_first_and_last, ref_step.observation), + action=tmap(get_first, ref_step.action), + reward=tmap(get_all, ref_step.reward), + discount=tmap(get_all, ref_step.discount), + start_of_episode=tmap(get_first, ref_step.start_of_episode), + extras=tmap(get_first, ref_step.extras), + ) + + # At the start of the episodes we'll add shorter transitions. + start_of_episode_configs = [] + for n in range(1, n_step): + config = sw.create_config( + pattern=_make_pattern(n), + table=table, + conditions=[sw.Condition.step_index() == n,], + ) + start_of_episode_configs.append(config) + + # During all other steps we'll add a full N-step transition. + base_config = sw.create_config(pattern=_make_pattern(n_step), table=table) + + # When the episode ends we'll add shorter transitions. + end_of_episode_configs = [] + for n in range(n_step - 1, 0, -1): + config = sw.create_config( + pattern=_make_pattern(n), + table=table, + conditions=[ + sw.Condition.is_end_episode(), + # If the entire episode is shorter than n_step then the episode + # start configs will already create an item that covers all the + # steps so we add this filter here to avoid adding it again. + sw.Condition.step_index() != n, + ], + ) + end_of_episode_configs.append(config) + + return start_of_episode_configs + [base_config] + end_of_episode_configs diff --git a/acme/adders/reverb/structured_test.py b/acme/adders/reverb/structured_test.py index 761536e138..0544446185 100644 --- a/acme/adders/reverb/structured_test.py +++ b/acme/adders/reverb/structured_test.py @@ -16,171 +16,187 @@ from typing import Sequence -from acme import types -from acme.adders.reverb import sequence as adders -from acme.adders.reverb import structured -from acme.adders.reverb import test_cases -from acme.adders.reverb import test_utils -from acme.utils import tree_utils import dm_env import numpy as np -from reverb import structured_writer as sw import tree +from absl.testing import absltest, parameterized +from reverb import structured_writer as sw -from absl.testing import absltest -from absl.testing import parameterized +from acme import types +from acme.adders.reverb import sequence as adders +from acme.adders.reverb import structured, test_cases, test_utils +from acme.utils import tree_utils class StructuredAdderTest(test_utils.AdderTestMixin, parameterized.TestCase): - - @parameterized.named_parameters(*test_cases.BASE_TEST_CASES_FOR_SEQUENCE_ADDER - ) - def test_sequence_adder(self, - sequence_length: int, - period: int, - first, - steps, - expected_sequences, - end_behavior: adders.EndBehavior, - repeat_episode_times: int = 1): - - env_spec, extras_spec = test_utils.get_specs(steps[0]) - step_spec = structured.create_step_spec(env_spec, extras_spec) - - should_pad_trajectory = end_behavior == adders.EndBehavior.ZERO_PAD - - def _maybe_zero_pad(flat_trajectory): - trajectory = tree.unflatten_as(step_spec, flat_trajectory) - - if not should_pad_trajectory: - return trajectory - - padding_length = sequence_length - flat_trajectory[0].shape[0] - if padding_length == 0: - return trajectory - - padding = tree.map_structure( - lambda x: np.zeros([padding_length, *x.shape[1:]], x.dtype), - trajectory) - - return 
tree.map_structure(lambda *x: np.concatenate(x), trajectory, - padding) - - # The StructuredAdder does not support adding padding steps as we assume - # that the padding will be added on the learner side. - if end_behavior == adders.EndBehavior.ZERO_PAD: - end_behavior = adders.EndBehavior.TRUNCATE - - configs = structured.create_sequence_config( - step_spec=step_spec, - sequence_length=sequence_length, - period=period, - end_of_episode_behavior=end_behavior) - adder = structured.StructuredAdder( - client=self.client, - max_in_flight_items=0, - configs=configs, - step_spec=step_spec) - - super().run_test_adder( - adder=adder, - first=first, - steps=steps, - expected_items=expected_sequences, - repeat_episode_times=repeat_episode_times, - end_behavior=end_behavior, - item_transform=_maybe_zero_pad, - signature=sw.infer_signature(configs, step_spec)) - - @parameterized.named_parameters(*test_cases.TEST_CASES_FOR_TRANSITION_ADDER) - def test_transition_adder(self, n_step: int, additional_discount: float, - first: dm_env.TimeStep, - steps: Sequence[dm_env.TimeStep], - expected_transitions: Sequence[types.Transition]): - - env_spec, extras_spec = test_utils.get_specs(steps[0]) - step_spec = structured.create_step_spec(env_spec, extras_spec) - - def _as_n_step_transition(flat_trajectory): - trajectory = tree.unflatten_as(step_spec, flat_trajectory) - - rewards, discount = _compute_cumulative_quantities( - rewards=trajectory.reward, - discounts=trajectory.discount, - additional_discount=additional_discount, - n_step=tree.flatten(trajectory.reward)[0].shape[0]) - - tmap = tree.map_structure - return types.Transition( - observation=tmap(lambda x: x[0], trajectory.observation), - action=tmap(lambda x: x[0], trajectory.action), - reward=rewards, - discount=discount, - next_observation=tmap(lambda x: x[-1], trajectory.observation), - extras=tmap(lambda x: x[0], trajectory.extras)) - - configs = structured.create_n_step_transition_config( - step_spec=step_spec, n_step=n_step) - - adder = structured.StructuredAdder( - client=self.client, - max_in_flight_items=0, - configs=configs, - step_spec=step_spec) - - super().run_test_adder( - adder=adder, - first=first, - steps=steps, - expected_items=expected_transitions, - stack_sequence_fields=False, - item_transform=_as_n_step_transition, - signature=sw.infer_signature(configs, step_spec)) - - -def _compute_cumulative_quantities(rewards: types.NestedArray, - discounts: types.NestedArray, - additional_discount: float, n_step: int): - """Stolen from TransitionAdder.""" - - # Give the same tree structure to the n-step return accumulator, - # n-step discount accumulator, and self.discount, so that they can be - # iterated in parallel using tree.map_structure. - rewards, discounts, self_discount = tree_utils.broadcast_structures( - rewards, discounts, additional_discount) - flat_rewards = tree.flatten(rewards) - flat_discounts = tree.flatten(discounts) - flat_self_discount = tree.flatten(self_discount) - - # Copy total_discount as it is otherwise read-only. - total_discount = [np.copy(a[0]) for a in flat_discounts] - - # Broadcast n_step_return to have the broadcasted shape of - # reward * discount. - n_step_return = [ - np.copy(np.broadcast_to(r[0], - np.broadcast(r[0], d).shape)) - for r, d in zip(flat_rewards, total_discount) - ] - - # NOTE: total_discount will have one less self_discount applied to it than - # the value of self._n_step. This is so that when the learner/update uses - # an additional discount we don't apply it twice. 
Inside the following loop - # we will apply this right before summing up the n_step_return. - for i in range(1, n_step): - for nsr, td, r, d, sd in zip(n_step_return, total_discount, flat_rewards, - flat_discounts, flat_self_discount): - # Equivalent to: `total_discount *= self._discount`. - td *= sd - # Equivalent to: `n_step_return += reward[i] * total_discount`. - nsr += r[i] * td - # Equivalent to: `total_discount *= discount[i]`. - td *= d[i] - - n_step_return = tree.unflatten_as(rewards, n_step_return) - total_discount = tree.unflatten_as(rewards, total_discount) - return n_step_return, total_discount - - -if __name__ == '__main__': - absltest.main() + @parameterized.named_parameters(*test_cases.BASE_TEST_CASES_FOR_SEQUENCE_ADDER) + def test_sequence_adder( + self, + sequence_length: int, + period: int, + first, + steps, + expected_sequences, + end_behavior: adders.EndBehavior, + repeat_episode_times: int = 1, + ): + + env_spec, extras_spec = test_utils.get_specs(steps[0]) + step_spec = structured.create_step_spec(env_spec, extras_spec) + + should_pad_trajectory = end_behavior == adders.EndBehavior.ZERO_PAD + + def _maybe_zero_pad(flat_trajectory): + trajectory = tree.unflatten_as(step_spec, flat_trajectory) + + if not should_pad_trajectory: + return trajectory + + padding_length = sequence_length - flat_trajectory[0].shape[0] + if padding_length == 0: + return trajectory + + padding = tree.map_structure( + lambda x: np.zeros([padding_length, *x.shape[1:]], x.dtype), trajectory + ) + + return tree.map_structure(lambda *x: np.concatenate(x), trajectory, padding) + + # The StructuredAdder does not support adding padding steps as we assume + # that the padding will be added on the learner side. + if end_behavior == adders.EndBehavior.ZERO_PAD: + end_behavior = adders.EndBehavior.TRUNCATE + + configs = structured.create_sequence_config( + step_spec=step_spec, + sequence_length=sequence_length, + period=period, + end_of_episode_behavior=end_behavior, + ) + adder = structured.StructuredAdder( + client=self.client, + max_in_flight_items=0, + configs=configs, + step_spec=step_spec, + ) + + super().run_test_adder( + adder=adder, + first=first, + steps=steps, + expected_items=expected_sequences, + repeat_episode_times=repeat_episode_times, + end_behavior=end_behavior, + item_transform=_maybe_zero_pad, + signature=sw.infer_signature(configs, step_spec), + ) + + @parameterized.named_parameters(*test_cases.TEST_CASES_FOR_TRANSITION_ADDER) + def test_transition_adder( + self, + n_step: int, + additional_discount: float, + first: dm_env.TimeStep, + steps: Sequence[dm_env.TimeStep], + expected_transitions: Sequence[types.Transition], + ): + + env_spec, extras_spec = test_utils.get_specs(steps[0]) + step_spec = structured.create_step_spec(env_spec, extras_spec) + + def _as_n_step_transition(flat_trajectory): + trajectory = tree.unflatten_as(step_spec, flat_trajectory) + + rewards, discount = _compute_cumulative_quantities( + rewards=trajectory.reward, + discounts=trajectory.discount, + additional_discount=additional_discount, + n_step=tree.flatten(trajectory.reward)[0].shape[0], + ) + + tmap = tree.map_structure + return types.Transition( + observation=tmap(lambda x: x[0], trajectory.observation), + action=tmap(lambda x: x[0], trajectory.action), + reward=rewards, + discount=discount, + next_observation=tmap(lambda x: x[-1], trajectory.observation), + extras=tmap(lambda x: x[0], trajectory.extras), + ) + + configs = structured.create_n_step_transition_config( + step_spec=step_spec, n_step=n_step + ) 
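# Illustrative aside, not part of the diff: a reading of the slicing used by
# `_make_pattern` in `create_n_step_transition_config`. For n_step=2 each written
# item covers the last n + 1 = 3 steps of the history, trimmed per field; with a
# history column h = [h0, h1, h2]:
#   observation: h[-3::2]  -> (h0, h2)   i.e. o_t and o_{t+n}
#   action:      h[-3:-2]  -> (h0,)      i.e. a_t
#   reward:      h[-3:-1]  -> (h0, h1)   i.e. r_t, r_{t+1}
#   discount:    h[-3:-1]  -> (h0, h1)   i.e. d_t, d_{t+1}
#   extras:      h[-3:-2]  -> (h0,)      i.e. e_t
# `_as_n_step_transition` above then folds the reward and discount slices into the
# cumulative n-step return and total discount of a `types.Transition`.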
+ + adder = structured.StructuredAdder( + client=self.client, + max_in_flight_items=0, + configs=configs, + step_spec=step_spec, + ) + + super().run_test_adder( + adder=adder, + first=first, + steps=steps, + expected_items=expected_transitions, + stack_sequence_fields=False, + item_transform=_as_n_step_transition, + signature=sw.infer_signature(configs, step_spec), + ) + + +def _compute_cumulative_quantities( + rewards: types.NestedArray, + discounts: types.NestedArray, + additional_discount: float, + n_step: int, +): + """Stolen from TransitionAdder.""" + + # Give the same tree structure to the n-step return accumulator, + # n-step discount accumulator, and self.discount, so that they can be + # iterated in parallel using tree.map_structure. + rewards, discounts, self_discount = tree_utils.broadcast_structures( + rewards, discounts, additional_discount + ) + flat_rewards = tree.flatten(rewards) + flat_discounts = tree.flatten(discounts) + flat_self_discount = tree.flatten(self_discount) + + # Copy total_discount as it is otherwise read-only. + total_discount = [np.copy(a[0]) for a in flat_discounts] + + # Broadcast n_step_return to have the broadcasted shape of + # reward * discount. + n_step_return = [ + np.copy(np.broadcast_to(r[0], np.broadcast(r[0], d).shape)) + for r, d in zip(flat_rewards, total_discount) + ] + + # NOTE: total_discount will have one less self_discount applied to it than + # the value of self._n_step. This is so that when the learner/update uses + # an additional discount we don't apply it twice. Inside the following loop + # we will apply this right before summing up the n_step_return. + for i in range(1, n_step): + for nsr, td, r, d, sd in zip( + n_step_return, + total_discount, + flat_rewards, + flat_discounts, + flat_self_discount, + ): + # Equivalent to: `total_discount *= self._discount`. + td *= sd + # Equivalent to: `n_step_return += reward[i] * total_discount`. + nsr += r[i] * td + # Equivalent to: `total_discount *= discount[i]`. + td *= d[i] + + n_step_return = tree.unflatten_as(rewards, n_step_return) + total_discount = tree.unflatten_as(rewards, total_discount) + return n_step_return, total_discount + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/adders/reverb/test_cases.py b/acme/adders/reverb/test_cases.py index fac0ffb3b1..d0a4d4c2f9 100644 --- a/acme/adders/reverb/test_cases.py +++ b/acme/adders/reverb/test_cases.py @@ -14,11 +14,12 @@ """Test cases used by multiple test files.""" -from acme import types -from acme.adders.reverb import sequence as sequence_adder import dm_env import numpy as np +from acme import types +from acme.adders.reverb import sequence as sequence_adder + # Define the main set of test cases; these are given as parameterized tests to # the test_adder method and describe a trajectory to add to replay and the # expected transitions that should result from this trajectory. The expected @@ -26,7 +27,7 @@ # next_observation, extras). 
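# Worked example, not part of the diff: deriving one expected transition of the
# "TwoStep" case below (n_step=2, additional_discount=1.0, and, as in that case,
# a reward of 1.0 and a discount of 0.5 on every step). For the transition that
# starts at observation 1 and ends at observation 3:
#   n_step_return  = r_t + additional_discount * d_t * r_{t+1}
#                  = 1.0 + 1.0 * 0.5 * 1.0 = 1.5
#   total_discount = d_t * d_{t+1} = 0.5 * 0.5 = 0.25
# which matches the expected types.Transition(1, 0, 1.5, 0.25, 3).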
TEST_CASES_FOR_TRANSITION_ADDER = [ dict( - testcase_name='OneStepFinalReward', + testcase_name="OneStepFinalReward", n_step=1, additional_discount=1.0, first=dm_env.restart(1), @@ -39,84 +40,58 @@ types.Transition(1, 0, 0.0, 1.0, 2), types.Transition(2, 0, 0.0, 1.0, 3), types.Transition(3, 0, 1.0, 0.0, 4), - )), + ), + ), dict( - testcase_name='OneStepDict', + testcase_name="OneStepDict", n_step=1, additional_discount=1.0, - first=dm_env.restart({'foo': 1}), + first=dm_env.restart({"foo": 1}), steps=( - (0, dm_env.transition(reward=0.0, observation={'foo': 2})), - (0, dm_env.transition(reward=0.0, observation={'foo': 3})), - (0, dm_env.termination(reward=1.0, observation={'foo': 4})), + (0, dm_env.transition(reward=0.0, observation={"foo": 2})), + (0, dm_env.transition(reward=0.0, observation={"foo": 3})), + (0, dm_env.termination(reward=1.0, observation={"foo": 4})), ), expected_transitions=( - types.Transition({'foo': 1}, 0, 0.0, 1.0, {'foo': 2}), - types.Transition({'foo': 2}, 0, 0.0, 1.0, {'foo': 3}), - types.Transition({'foo': 3}, 0, 1.0, 0.0, {'foo': 4}), - )), + types.Transition({"foo": 1}, 0, 0.0, 1.0, {"foo": 2}), + types.Transition({"foo": 2}, 0, 0.0, 1.0, {"foo": 3}), + types.Transition({"foo": 3}, 0, 1.0, 0.0, {"foo": 4}), + ), + ), dict( - testcase_name='OneStepExtras', + testcase_name="OneStepExtras", n_step=1, additional_discount=1.0, first=dm_env.restart(1), steps=( - ( - 0, - dm_env.transition(reward=0.0, observation=2), - { - 'state': 0 - }, - ), - ( - 0, - dm_env.transition(reward=0.0, observation=3), - { - 'state': 1 - }, - ), - ( - 0, - dm_env.termination(reward=1.0, observation=4), - { - 'state': 2 - }, - ), + (0, dm_env.transition(reward=0.0, observation=2), {"state": 0},), + (0, dm_env.transition(reward=0.0, observation=3), {"state": 1},), + (0, dm_env.termination(reward=1.0, observation=4), {"state": 2},), ), expected_transitions=( - types.Transition(1, 0, 0.0, 1.0, 2, {'state': 0}), - types.Transition(2, 0, 0.0, 1.0, 3, {'state': 1}), - types.Transition(3, 0, 1.0, 0.0, 4, {'state': 2}), - )), + types.Transition(1, 0, 0.0, 1.0, 2, {"state": 0}), + types.Transition(2, 0, 0.0, 1.0, 3, {"state": 1}), + types.Transition(3, 0, 1.0, 0.0, 4, {"state": 2}), + ), + ), dict( - testcase_name='OneStepExtrasZeroes', + testcase_name="OneStepExtrasZeroes", n_step=1, additional_discount=1.0, first=dm_env.restart(1), steps=( - ( - 0, - dm_env.transition(reward=0.0, observation=2), - np.zeros(1), - ), - ( - 0, - dm_env.transition(reward=0.0, observation=3), - np.zeros(1), - ), - ( - 0, - dm_env.termination(reward=1.0, observation=4), - np.zeros(1), - ), + (0, dm_env.transition(reward=0.0, observation=2), np.zeros(1),), + (0, dm_env.transition(reward=0.0, observation=3), np.zeros(1),), + (0, dm_env.termination(reward=1.0, observation=4), np.zeros(1),), ), expected_transitions=( types.Transition(1, 0, 0.0, 1.0, 2, np.zeros(1)), types.Transition(2, 0, 0.0, 1.0, 3, np.zeros(1)), types.Transition(3, 0, 1.0, 0.0, 4, np.zeros(1)), - )), + ), + ), dict( - testcase_name='TwoStep', + testcase_name="TwoStep", n_step=2, additional_discount=1.0, first=dm_env.restart(1), @@ -130,17 +105,16 @@ types.Transition(1, 0, 1.5, 0.25, 3), types.Transition(2, 0, 1.5, 0.00, 4), types.Transition(3, 0, 1.0, 0.00, 4), - )), + ), + ), dict( - testcase_name='TwoStepStructuredReward', + testcase_name="TwoStepStructuredReward", n_step=2, additional_discount=1.0, first=dm_env.restart(1), steps=( - (0, - dm_env.transition(reward=(1.0, 2.0), observation=2, discount=0.5)), - (0, - dm_env.transition(reward=(1.0, 2.0), 
observation=3, discount=0.5)), + (0, dm_env.transition(reward=(1.0, 2.0), observation=2, discount=0.5)), + (0, dm_env.transition(reward=(1.0, 2.0), observation=3, discount=0.5)), (0, dm_env.termination(reward=(1.0, 2.0), observation=4)), ), expected_transitions=( @@ -148,175 +122,205 @@ types.Transition(1, 0, (1.5, 3.0), (0.25, 0.25), 3), types.Transition(2, 0, (1.5, 3.0), (0.00, 0.00), 4), types.Transition(3, 0, (1.0, 2.0), (0.00, 0.00), 4), - )), + ), + ), dict( - testcase_name='TwoStepNDArrayReward', + testcase_name="TwoStepNDArrayReward", n_step=2, additional_discount=1.0, first=dm_env.restart(1), steps=( - (0, - dm_env.transition( - reward=np.array((1.0, 2.0)), observation=2, discount=0.5)), - (0, - dm_env.transition( - reward=np.array((1.0, 2.0)), observation=3, discount=0.5)), + ( + 0, + dm_env.transition( + reward=np.array((1.0, 2.0)), observation=2, discount=0.5 + ), + ), + ( + 0, + dm_env.transition( + reward=np.array((1.0, 2.0)), observation=3, discount=0.5 + ), + ), (0, dm_env.termination(reward=np.array((1.0, 2.0)), observation=4)), ), expected_transitions=( - types.Transition(1, 0, np.array((1.0, 2.0)), np.array((0.50, 0.50)), - 2), - types.Transition(1, 0, np.array((1.5, 3.0)), np.array((0.25, 0.25)), - 3), - types.Transition(2, 0, np.array((1.5, 3.0)), np.array((0.00, 0.00)), - 4), - types.Transition(3, 0, np.array((1.0, 2.0)), np.array((0.00, 0.00)), - 4), - )), + types.Transition(1, 0, np.array((1.0, 2.0)), np.array((0.50, 0.50)), 2), + types.Transition(1, 0, np.array((1.5, 3.0)), np.array((0.25, 0.25)), 3), + types.Transition(2, 0, np.array((1.5, 3.0)), np.array((0.00, 0.00)), 4), + types.Transition(3, 0, np.array((1.0, 2.0)), np.array((0.00, 0.00)), 4), + ), + ), dict( - testcase_name='TwoStepStructuredDiscount', + testcase_name="TwoStepStructuredDiscount", n_step=2, additional_discount=1.0, first=dm_env.restart(1), steps=( - (0, - dm_env.transition( - reward=1.0, observation=2, discount={ - 'a': 0.5, - 'b': 0.1 - })), - (0, - dm_env.transition( - reward=1.0, observation=3, discount={ - 'a': 0.5, - 'b': 0.1 - })), - (0, dm_env.termination(reward=1.0, - observation=4)._replace(discount={ - 'a': 0.0, - 'b': 0.0 - })), + ( + 0, + dm_env.transition( + reward=1.0, observation=2, discount={"a": 0.5, "b": 0.1} + ), + ), + ( + 0, + dm_env.transition( + reward=1.0, observation=3, discount={"a": 0.5, "b": 0.1} + ), + ), + ( + 0, + dm_env.termination(reward=1.0, observation=4)._replace( + discount={"a": 0.0, "b": 0.0} + ), + ), ), expected_transitions=( - types.Transition(1, 0, { - 'a': 1.0, - 'b': 1.0 - }, { - 'a': 0.50, - 'b': 0.10 - }, 2), - types.Transition(1, 0, { - 'a': 1.5, - 'b': 1.1 - }, { - 'a': 0.25, - 'b': 0.01 - }, 3), - types.Transition(2, 0, { - 'a': 1.5, - 'b': 1.1 - }, { - 'a': 0.00, - 'b': 0.00 - }, 4), - types.Transition(3, 0, { - 'a': 1.0, - 'b': 1.0 - }, { - 'a': 0.00, - 'b': 0.00 - }, 4), - )), + types.Transition(1, 0, {"a": 1.0, "b": 1.0}, {"a": 0.50, "b": 0.10}, 2), + types.Transition(1, 0, {"a": 1.5, "b": 1.1}, {"a": 0.25, "b": 0.01}, 3), + types.Transition(2, 0, {"a": 1.5, "b": 1.1}, {"a": 0.00, "b": 0.00}, 4), + types.Transition(3, 0, {"a": 1.0, "b": 1.0}, {"a": 0.00, "b": 0.00}, 4), + ), + ), dict( - testcase_name='TwoStepNDArrayDiscount', + testcase_name="TwoStepNDArrayDiscount", n_step=2, additional_discount=1.0, first=dm_env.restart(1), steps=( - (0, - dm_env.transition( - reward=1.0, observation=2, discount=np.array((0.5, 0.1)))), - (0, - dm_env.transition( - reward=1.0, observation=3, discount=np.array((0.5, 0.1)))), - (0, dm_env.termination( - 
reward=1.0, - observation=4)._replace(discount=np.array((0.0, 0.0)))), + ( + 0, + dm_env.transition( + reward=1.0, observation=2, discount=np.array((0.5, 0.1)) + ), + ), + ( + 0, + dm_env.transition( + reward=1.0, observation=3, discount=np.array((0.5, 0.1)) + ), + ), + ( + 0, + dm_env.termination(reward=1.0, observation=4)._replace( + discount=np.array((0.0, 0.0)) + ), + ), ), expected_transitions=( - types.Transition(1, 0, np.array((1.0, 1.0)), np.array((0.50, 0.10)), - 2), - types.Transition(1, 0, np.array((1.5, 1.1)), np.array((0.25, 0.01)), - 3), - types.Transition(2, 0, np.array((1.5, 1.1)), np.array((0.00, 0.00)), - 4), - types.Transition(3, 0, np.array((1.0, 1.0)), np.array((0.00, 0.00)), - 4), - )), + types.Transition(1, 0, np.array((1.0, 1.0)), np.array((0.50, 0.10)), 2), + types.Transition(1, 0, np.array((1.5, 1.1)), np.array((0.25, 0.01)), 3), + types.Transition(2, 0, np.array((1.5, 1.1)), np.array((0.00, 0.00)), 4), + types.Transition(3, 0, np.array((1.0, 1.0)), np.array((0.00, 0.00)), 4), + ), + ), dict( - testcase_name='TwoStepBroadcastedNDArrays', + testcase_name="TwoStepBroadcastedNDArrays", n_step=2, additional_discount=1.0, first=dm_env.restart(1), steps=( - (0, - dm_env.transition( - reward=np.array([[1.0, 2.0]]), - observation=2, - discount=np.array([[0.5], [0.1]]))), - (0, - dm_env.transition( - reward=np.array([[1.0, 2.0]]), - observation=3, - discount=np.array([[0.5], [0.1]]))), - (0, dm_env.termination( - reward=np.array([[1.0, 2.0]]), - observation=4)._replace(discount=np.array([[0.0], [0.0]]))), + ( + 0, + dm_env.transition( + reward=np.array([[1.0, 2.0]]), + observation=2, + discount=np.array([[0.5], [0.1]]), + ), + ), + ( + 0, + dm_env.transition( + reward=np.array([[1.0, 2.0]]), + observation=3, + discount=np.array([[0.5], [0.1]]), + ), + ), + ( + 0, + dm_env.termination( + reward=np.array([[1.0, 2.0]]), observation=4 + )._replace(discount=np.array([[0.0], [0.0]])), + ), ), expected_transitions=( - types.Transition(1, 0, np.array([[1.0, 2.0], [1.0, 2.0]]), - np.array([[0.50], [0.10]]), 2), - types.Transition(1, 0, np.array([[1.5, 3.0], [1.1, 2.2]]), - np.array([[0.25], [0.01]]), 3), - types.Transition(2, 0, np.array([[1.5, 3.0], [1.1, 2.2]]), - np.array([[0.00], [0.00]]), 4), - types.Transition(3, 0, np.array([[1.0, 2.0], [1.0, 2.0]]), - np.array([[0.00], [0.00]]), 4), - )), + types.Transition( + 1, 0, np.array([[1.0, 2.0], [1.0, 2.0]]), np.array([[0.50], [0.10]]), 2 + ), + types.Transition( + 1, 0, np.array([[1.5, 3.0], [1.1, 2.2]]), np.array([[0.25], [0.01]]), 3 + ), + types.Transition( + 2, 0, np.array([[1.5, 3.0], [1.1, 2.2]]), np.array([[0.00], [0.00]]), 4 + ), + types.Transition( + 3, 0, np.array([[1.0, 2.0], [1.0, 2.0]]), np.array([[0.00], [0.00]]), 4 + ), + ), + ), dict( - testcase_name='TwoStepStructuredBroadcastedNDArrays', + testcase_name="TwoStepStructuredBroadcastedNDArrays", n_step=2, additional_discount=1.0, first=dm_env.restart(1), steps=( - (0, - dm_env.transition( - reward={'a': np.array([[1.0, 2.0]])}, - observation=2, - discount=np.array([[0.5], [0.1]]))), - (0, - dm_env.transition( - reward={'a': np.array([[1.0, 2.0]])}, - observation=3, - discount=np.array([[0.5], [0.1]]))), - (0, - dm_env.termination( - reward={ - 'a': np.array([[1.0, 2.0]]) - }, observation=4)._replace(discount=np.array([[0.0], [0.0]]))), + ( + 0, + dm_env.transition( + reward={"a": np.array([[1.0, 2.0]])}, + observation=2, + discount=np.array([[0.5], [0.1]]), + ), + ), + ( + 0, + dm_env.transition( + reward={"a": np.array([[1.0, 2.0]])}, + observation=3, + 
discount=np.array([[0.5], [0.1]]), + ), + ), + ( + 0, + dm_env.termination( + reward={"a": np.array([[1.0, 2.0]])}, observation=4 + )._replace(discount=np.array([[0.0], [0.0]])), + ), ), expected_transitions=( - types.Transition(1, 0, {'a': np.array([[1.0, 2.0], [1.0, 2.0]])}, - {'a': np.array([[0.50], [0.10]])}, 2), - types.Transition(1, 0, {'a': np.array([[1.5, 3.0], [1.1, 2.2]])}, - {'a': np.array([[0.25], [0.01]])}, 3), - types.Transition(2, 0, {'a': np.array([[1.5, 3.0], [1.1, 2.2]])}, - {'a': np.array([[0.00], [0.00]])}, 4), - types.Transition(3, 0, {'a': np.array([[1.0, 2.0], [1.0, 2.0]])}, - {'a': np.array([[0.00], [0.00]])}, 4), - )), + types.Transition( + 1, + 0, + {"a": np.array([[1.0, 2.0], [1.0, 2.0]])}, + {"a": np.array([[0.50], [0.10]])}, + 2, + ), + types.Transition( + 1, + 0, + {"a": np.array([[1.5, 3.0], [1.1, 2.2]])}, + {"a": np.array([[0.25], [0.01]])}, + 3, + ), + types.Transition( + 2, + 0, + {"a": np.array([[1.5, 3.0], [1.1, 2.2]])}, + {"a": np.array([[0.00], [0.00]])}, + 4, + ), + types.Transition( + 3, + 0, + {"a": np.array([[1.0, 2.0], [1.0, 2.0]])}, + {"a": np.array([[0.00], [0.00]])}, + 4, + ), + ), + ), dict( - testcase_name='TwoStepWithExtras', + testcase_name="TwoStepWithExtras", n_step=2, additional_discount=1.0, first=dm_env.restart(1), @@ -324,33 +328,24 @@ ( 0, dm_env.transition(reward=1.0, observation=2, discount=0.5), - { - 'state': 0 - }, + {"state": 0}, ), ( 0, dm_env.transition(reward=1.0, observation=3, discount=0.5), - { - 'state': 1 - }, - ), - ( - 0, - dm_env.termination(reward=1.0, observation=4), - { - 'state': 2 - }, + {"state": 1}, ), + (0, dm_env.termination(reward=1.0, observation=4), {"state": 2},), ), expected_transitions=( - types.Transition(1, 0, 1.0, 0.50, 2, {'state': 0}), - types.Transition(1, 0, 1.5, 0.25, 3, {'state': 0}), - types.Transition(2, 0, 1.5, 0.00, 4, {'state': 1}), - types.Transition(3, 0, 1.0, 0.00, 4, {'state': 2}), - )), + types.Transition(1, 0, 1.0, 0.50, 2, {"state": 0}), + types.Transition(1, 0, 1.5, 0.25, 3, {"state": 0}), + types.Transition(2, 0, 1.5, 0.00, 4, {"state": 1}), + types.Transition(3, 0, 1.0, 0.00, 4, {"state": 2}), + ), + ), dict( - testcase_name='ThreeStepDiscounted', + testcase_name="ThreeStepDiscounted", n_step=3, additional_discount=0.4, first=dm_env.restart(1), @@ -365,9 +360,10 @@ types.Transition(1, 0, 1.24, 0.0, 4), types.Transition(2, 0, 1.20, 0.0, 4), types.Transition(3, 0, 1.00, 0.0, 4), - )), + ), + ), dict( - testcase_name='ThreeStepVaryingReward', + testcase_name="ThreeStepVaryingReward", n_step=3, additional_discount=0.5, first=dm_env.restart(1), @@ -384,16 +380,18 @@ types.Transition(2, 0, 3 + 0.5 * 5 + 0.25 * 7, 0.00, 5), types.Transition(3, 0, 5 + 0.5 * 7, 0.00, 5), types.Transition(4, 0, 7, 0.00, 5), - )), + ), + ), dict( - testcase_name='SingleTransitionEpisode', + testcase_name="SingleTransitionEpisode", n_step=4, additional_discount=1.0, first=dm_env.restart(1), steps=((0, dm_env.termination(reward=1.0, observation=2)),), - expected_transitions=(types.Transition(1, 0, 1.00, 0.0, 2),)), + expected_transitions=(types.Transition(1, 0, 1.00, 0.0, 2),), + ), dict( - testcase_name='EpisodeShorterThanN', + testcase_name="EpisodeShorterThanN", n_step=4, additional_discount=1.0, first=dm_env.restart(1), @@ -405,9 +403,10 @@ types.Transition(1, 0, 1.00, 1.0, 2), types.Transition(1, 0, 2.00, 0.0, 3), types.Transition(2, 0, 1.00, 0.0, 3), - )), + ), + ), dict( - testcase_name='EpisodeEqualToN', + testcase_name="EpisodeEqualToN", n_step=3, additional_discount=1.0, first=dm_env.restart(1), @@ 
-419,12 +418,13 @@ types.Transition(1, 0, 1.00, 1.0, 2), types.Transition(1, 0, 2.00, 0.0, 3), types.Transition(2, 0, 1.00, 0.0, 3), - )), + ), + ), ] BASE_TEST_CASES_FOR_SEQUENCE_ADDER = [ dict( - testcase_name='PeriodOne', + testcase_name="PeriodOne", sequence_length=3, period=1, first=dm_env.restart(1), @@ -455,7 +455,7 @@ end_behavior=sequence_adder.EndBehavior.ZERO_PAD, ), dict( - testcase_name='PeriodTwo', + testcase_name="PeriodTwo", sequence_length=3, period=2, first=dm_env.restart(1), @@ -481,7 +481,7 @@ end_behavior=sequence_adder.EndBehavior.ZERO_PAD, ), dict( - testcase_name='EarlyTerminationPeriodOne', + testcase_name="EarlyTerminationPeriodOne", sequence_length=3, period=1, first=dm_env.restart(1), @@ -495,11 +495,12 @@ (1, 0, 2.0, 1.0, True, ()), (2, 0, 3.0, 0.0, False, ()), (3, 0, 0.0, 0.0, False, ()), - ],), + ], + ), end_behavior=sequence_adder.EndBehavior.ZERO_PAD, ), dict( - testcase_name='EarlyTerminationPeriodTwo', + testcase_name="EarlyTerminationPeriodTwo", sequence_length=3, period=2, first=dm_env.restart(1), @@ -513,11 +514,12 @@ (1, 0, 2.0, 1.0, True, ()), (2, 0, 3.0, 0.0, False, ()), (3, 0, 0.0, 0.0, False, ()), - ],), + ], + ), end_behavior=sequence_adder.EndBehavior.ZERO_PAD, ), dict( - testcase_name='EarlyTerminationPaddingPeriodOne', + testcase_name="EarlyTerminationPaddingPeriodOne", sequence_length=4, period=1, first=dm_env.restart(1), @@ -532,11 +534,12 @@ (2, 0, 3.0, 0.0, False, ()), (3, 0, 0.0, 0.0, False, ()), (0, 0, 0.0, 0.0, False, ()), - ],), + ], + ), end_behavior=sequence_adder.EndBehavior.ZERO_PAD, ), dict( - testcase_name='EarlyTerminationPaddingPeriodTwo', + testcase_name="EarlyTerminationPaddingPeriodTwo", sequence_length=4, period=2, first=dm_env.restart(1), @@ -551,11 +554,12 @@ (2, 0, 3.0, 0.0, False, ()), (3, 0, 0.0, 0.0, False, ()), (0, 0, 0.0, 0.0, False, ()), - ],), + ], + ), end_behavior=sequence_adder.EndBehavior.ZERO_PAD, ), dict( - testcase_name='EarlyTerminationNoPadding', + testcase_name="EarlyTerminationNoPadding", sequence_length=4, period=1, first=dm_env.restart(1), @@ -569,11 +573,12 @@ (1, 0, 2.0, 1.0, True, ()), (2, 0, 3.0, 0.0, False, ()), (3, 0, 0.0, 0.0, False, ()), - ],), + ], + ), end_behavior=sequence_adder.EndBehavior.TRUNCATE, ), dict( - testcase_name='LongEpisodePadding', + testcase_name="LongEpisodePadding", sequence_length=3, period=3, first=dm_env.restart(1), @@ -607,7 +612,7 @@ end_behavior=sequence_adder.EndBehavior.ZERO_PAD, ), dict( - testcase_name='LongEpisodeNoPadding', + testcase_name="LongEpisodeNoPadding", sequence_length=3, period=3, first=dm_env.restart(1), @@ -632,15 +637,12 @@ (5, 0, 9.0, 1.0, False, ()), (6, 0, 11.0, 1.0, False, ()), ], - [ - (7, 0, 13.0, 0.0, False, ()), - (8, 0, 0.0, 0.0, False, ()), - ], + [(7, 0, 13.0, 0.0, False, ()), (8, 0, 0.0, 0.0, False, ()),], ), end_behavior=sequence_adder.EndBehavior.TRUNCATE, ), dict( - testcase_name='EndBehavior_WRITE', + testcase_name="EndBehavior_WRITE", sequence_length=3, period=2, first=dm_env.restart(1), @@ -675,7 +677,7 @@ TEST_CASES_FOR_SEQUENCE_ADDER = BASE_TEST_CASES_FOR_SEQUENCE_ADDER + [ dict( - testcase_name='NonBreakingSequenceOnEpisodeReset', + testcase_name="NonBreakingSequenceOnEpisodeReset", sequence_length=3, period=2, first=dm_env.restart(1), @@ -707,9 +709,10 @@ ], ), end_behavior=sequence_adder.EndBehavior.CONTINUE, - repeat_episode_times=1), + repeat_episode_times=1, + ), dict( - testcase_name='NonBreakingSequenceMultipleTerminatedEpisodes', + testcase_name="NonBreakingSequenceMultipleTerminatedEpisodes", sequence_length=3, 
period=2, first=dm_env.restart(1), @@ -775,9 +778,10 @@ ], ), end_behavior=sequence_adder.EndBehavior.CONTINUE, - repeat_episode_times=3), + repeat_episode_times=3, + ), dict( - testcase_name='NonBreakingSequenceMultipleTruncatedEpisodes', + testcase_name="NonBreakingSequenceMultipleTruncatedEpisodes", sequence_length=3, period=2, first=dm_env.restart(1), @@ -843,5 +847,6 @@ ], ), end_behavior=sequence_adder.EndBehavior.CONTINUE, - repeat_episode_times=3), + repeat_episode_times=3, + ), ] diff --git a/acme/adders/reverb/test_utils.py b/acme/adders/reverb/test_utils.py index 6ed9a9ac0b..2ebe55d18f 100644 --- a/acme/adders/reverb/test_utils.py +++ b/acme/adders/reverb/test_utils.py @@ -16,26 +16,25 @@ from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar, Union -from acme import specs -from acme import types -from acme.adders import base as adders_base -from acme.adders import reverb as adders -from acme.utils import tree_utils import dm_env import numpy as np import reverb import tensorflow as tf import tree - from absl.testing import absltest +from acme import specs, types +from acme.adders import base as adders_base +from acme.adders import reverb as adders +from acme.utils import tree_utils + StepWithExtra = Tuple[Any, dm_env.TimeStep, Any] StepWithoutExtra = Tuple[Any, dm_env.TimeStep] -Step = TypeVar('Step', StepWithExtra, StepWithoutExtra) +Step = TypeVar("Step", StepWithExtra, StepWithoutExtra) def make_trajectory(observations): - """Make a simple trajectory from a sequence of observations. + """Make a simple trajectory from a sequence of observations. Arguments: observations: a sequence of observations. @@ -45,112 +44,126 @@ def make_trajectory(observations): object and steps contains a list of (action, step) tuples. The length of steps is given by episode_length. 
""" - first = dm_env.restart(observations[0]) - middle = [(0, dm_env.transition(reward=0.0, observation=observation)) - for observation in observations[1:-1]] - last = (0, dm_env.termination(reward=0.0, observation=observations[-1])) - return first, middle + [last] + first = dm_env.restart(observations[0]) + middle = [ + (0, dm_env.transition(reward=0.0, observation=observation)) + for observation in observations[1:-1] + ] + last = (0, dm_env.termination(reward=0.0, observation=observations[-1])) + return first, middle + [last] def make_sequence(observations): - """Create a sequence of timesteps of the form `first, [second, ..., last]`.""" - first, steps = make_trajectory(observations) - observation = first.observation - sequence = [] - start_of_episode = True - for action, timestep in steps: - extras = () - sequence.append((observation, action, timestep.reward, timestep.discount, - start_of_episode, extras)) - observation = timestep.observation - start_of_episode = False - sequence.append((observation, 0, 0.0, 0.0, False, ())) - return sequence + """Create a sequence of timesteps of the form `first, [second, ..., last]`.""" + first, steps = make_trajectory(observations) + observation = first.observation + sequence = [] + start_of_episode = True + for action, timestep in steps: + extras = () + sequence.append( + ( + observation, + action, + timestep.reward, + timestep.discount, + start_of_episode, + extras, + ) + ) + observation = timestep.observation + start_of_episode = False + sequence.append((observation, 0, 0.0, 0.0, False, ())) + return sequence def _numeric_to_spec(x: Union[float, int, np.ndarray]): - if isinstance(x, np.ndarray): - return specs.Array(shape=x.shape, dtype=x.dtype) - elif isinstance(x, (float, int)): - return specs.Array(shape=(), dtype=type(x)) - else: - raise ValueError(f'Unsupported numeric: {type(x)}') + if isinstance(x, np.ndarray): + return specs.Array(shape=x.shape, dtype=x.dtype) + elif isinstance(x, (float, int)): + return specs.Array(shape=(), dtype=type(x)) + else: + raise ValueError(f"Unsupported numeric: {type(x)}") def get_specs(step): - """Infer spec from an example step.""" - env_spec = tree.map_structure( - _numeric_to_spec, - specs.EnvironmentSpec( - observations=step[1].observation, - actions=step[0], - rewards=step[1].reward, - discounts=step[1].discount)) - - has_extras = len(step) == 3 - if has_extras: - extras_spec = tree.map_structure(_numeric_to_spec, step[2]) - else: - extras_spec = () - - return env_spec, extras_spec + """Infer spec from an example step.""" + env_spec = tree.map_structure( + _numeric_to_spec, + specs.EnvironmentSpec( + observations=step[1].observation, + actions=step[0], + rewards=step[1].reward, + discounts=step[1].discount, + ), + ) + + has_extras = len(step) == 3 + if has_extras: + extras_spec = tree.map_structure(_numeric_to_spec, step[2]) + else: + extras_spec = () + + return env_spec, extras_spec class AdderTestMixin(absltest.TestCase): - """A helper mixin for testing Reverb adders. + """A helper mixin for testing Reverb adders. Note that any test inheriting from this mixin must also inherit from something that provides the Python unittest assert methods. 
""" - server: reverb.Server - client: reverb.Client - - @classmethod - def setUpClass(cls): - super().setUpClass() - - replay_table = reverb.Table.queue(adders.DEFAULT_PRIORITY_TABLE, 1000) - cls.server = reverb.Server([replay_table]) - cls.client = reverb.Client(f'localhost:{cls.server.port}') - - def tearDown(self): - self.client.reset(adders.DEFAULT_PRIORITY_TABLE) - super().tearDown() - - @classmethod - def tearDownClass(cls): - cls.server.stop() - super().tearDownClass() - - def num_episodes(self): - info = self.client.server_info(1)[adders.DEFAULT_PRIORITY_TABLE] - return info.num_episodes - - def num_items(self): - info = self.client.server_info(1)[adders.DEFAULT_PRIORITY_TABLE] - return info.current_size - - def items(self): - sampler = self.client.sample( - table=adders.DEFAULT_PRIORITY_TABLE, - num_samples=self.num_items(), - emit_timesteps=False) - return [sample.data for sample in sampler] # pytype: disable=attribute-error - - def run_test_adder( - self, - adder: adders_base.Adder, - first: dm_env.TimeStep, - steps: Sequence[Step], - expected_items: Sequence[Any], - signature: types.NestedSpec, - pack_expected_items: bool = False, - stack_sequence_fields: bool = True, - repeat_episode_times: int = 1, - end_behavior: adders.EndBehavior = adders.EndBehavior.ZERO_PAD, - item_transform: Optional[Callable[[Sequence[np.ndarray]], Any]] = None): - """Runs a unit test case for the adder. + server: reverb.Server + client: reverb.Client + + @classmethod + def setUpClass(cls): + super().setUpClass() + + replay_table = reverb.Table.queue(adders.DEFAULT_PRIORITY_TABLE, 1000) + cls.server = reverb.Server([replay_table]) + cls.client = reverb.Client(f"localhost:{cls.server.port}") + + def tearDown(self): + self.client.reset(adders.DEFAULT_PRIORITY_TABLE) + super().tearDown() + + @classmethod + def tearDownClass(cls): + cls.server.stop() + super().tearDownClass() + + def num_episodes(self): + info = self.client.server_info(1)[adders.DEFAULT_PRIORITY_TABLE] + return info.num_episodes + + def num_items(self): + info = self.client.server_info(1)[adders.DEFAULT_PRIORITY_TABLE] + return info.current_size + + def items(self): + sampler = self.client.sample( + table=adders.DEFAULT_PRIORITY_TABLE, + num_samples=self.num_items(), + emit_timesteps=False, + ) + return [sample.data for sample in sampler] # pytype: disable=attribute-error + + def run_test_adder( + self, + adder: adders_base.Adder, + first: dm_env.TimeStep, + steps: Sequence[Step], + expected_items: Sequence[Any], + signature: types.NestedSpec, + pack_expected_items: bool = False, + stack_sequence_fields: bool = True, + repeat_episode_times: int = 1, + end_behavior: adders.EndBehavior = adders.EndBehavior.ZERO_PAD, + item_transform: Optional[Callable[[Sequence[np.ndarray]], Any]] = None, + ): + """Runs a unit test case for the adder. Args: adder: The instance of `Adder` that is being tested. @@ -173,61 +186,64 @@ def run_test_adder( dataset pipeline on the learner in a real setup. """ - del pack_expected_items + del pack_expected_items - if not steps: - raise ValueError('At least one step must be given.') + if not steps: + raise ValueError("At least one step must be given.") - has_extras = len(steps[0]) == 3 - for _ in range(repeat_episode_times): - # Add all the data up to the final step. - adder.add_first(first) - for step in steps[:-1]: - action, ts = step[0], step[1] + has_extras = len(steps[0]) == 3 + for _ in range(repeat_episode_times): + # Add all the data up to the final step. 
+ adder.add_first(first) + for step in steps[:-1]: + action, ts = step[0], step[1] - if has_extras: - extras = step[2] - else: - extras = () + if has_extras: + extras = step[2] + else: + extras = () - adder.add(action, next_timestep=ts, extras=extras) + adder.add(action, next_timestep=ts, extras=extras) - # Add the final step. - adder.add(*steps[-1]) + # Add the final step. + adder.add(*steps[-1]) - # Force run the destructor to trigger the flushing of all pending items. - getattr(adder, '__del__', lambda: None)() + # Force run the destructor to trigger the flushing of all pending items. + getattr(adder, "__del__", lambda: None)() - # Ending the episode should close the writer. No new writer should yet have - # been created as it is constructed lazily. - if end_behavior is not adders.EndBehavior.CONTINUE: - self.assertEqual(self.num_episodes(), repeat_episode_times) + # Ending the episode should close the writer. No new writer should yet have + # been created as it is constructed lazily. + if end_behavior is not adders.EndBehavior.CONTINUE: + self.assertEqual(self.num_episodes(), repeat_episode_times) - # Make sure our expected and observed data match. - observed_items = self.items() + # Make sure our expected and observed data match. + observed_items = self.items() - # Check matching number of items. - self.assertEqual(len(expected_items), len(observed_items)) + # Check matching number of items. + self.assertEqual(len(expected_items), len(observed_items)) - # Check items are matching according to numpy's almost_equal. - for expected_item, observed_item in zip(expected_items, observed_items): - if stack_sequence_fields: - expected_item = tree_utils.stack_sequence_fields(expected_item) + # Check items are matching according to numpy's almost_equal. + for expected_item, observed_item in zip(expected_items, observed_items): + if stack_sequence_fields: + expected_item = tree_utils.stack_sequence_fields(expected_item) - # Apply the transformation which would be done by the dataset in a real - # setup. - if item_transform: - observed_item = item_transform(observed_item) + # Apply the transformation which would be done by the dataset in a real + # setup. + if item_transform: + observed_item = item_transform(observed_item) - tree.map_structure(np.testing.assert_array_almost_equal, - tree.flatten(expected_item), - tree.flatten(observed_item)) + tree.map_structure( + np.testing.assert_array_almost_equal, + tree.flatten(expected_item), + tree.flatten(observed_item), + ) - # Make sure the signature matches was is being written by Reverb. - def _check_signature(spec: tf.TensorSpec, value: np.ndarray): - self.assertTrue(spec.is_compatible_with(tf.convert_to_tensor(value))) + # Make sure the signature matches was is being written by Reverb. + def _check_signature(spec: tf.TensorSpec, value: np.ndarray): + self.assertTrue(spec.is_compatible_with(tf.convert_to_tensor(value))) - # Check that it is possible to unpack observed using the signature. - for item in observed_items: - tree.map_structure(_check_signature, tree.flatten(signature), - tree.flatten(item)) + # Check that it is possible to unpack observed using the signature. 
+ for item in observed_items: + tree.map_structure( + _check_signature, tree.flatten(signature), tree.flatten(item) + ) diff --git a/acme/adders/reverb/transition.py b/acme/adders/reverb/transition.py index fe3e16f75b..742b8d93ec 100644 --- a/acme/adders/reverb/transition.py +++ b/acme/adders/reverb/transition.py @@ -21,19 +21,17 @@ import copy from typing import Optional, Tuple -from acme import specs -from acme import types -from acme.adders.reverb import base -from acme.adders.reverb import utils -from acme.utils import tree_utils - import numpy as np import reverb import tree +from acme import specs, types +from acme.adders.reverb import base, utils +from acme.utils import tree_utils + class NStepTransitionAdder(base.ReverbAdder): - """An N-step transition adder. + """An N-step transition adder. This will buffer a sequence of N timesteps in order to form a single N-step transition which is added to reverb for future retrieval. @@ -83,16 +81,16 @@ class NStepTransitionAdder(base.ReverbAdder): if extras are provided, we get e_t, not e_{t+n}. """ - def __init__( - self, - client: reverb.Client, - n_step: int, - discount: float, - *, - priority_fns: Optional[base.PriorityFnMapping] = None, - max_in_flight_items: int = 5, - ): - """Creates an N-step transition adder. + def __init__( + self, + client: reverb.Client, + n_step: int, + discount: float, + *, + priority_fns: Optional[base.PriorityFnMapping] = None, + max_in_flight_items: int = 5, + ): + """Creates an N-step transition adder. Args: client: A `reverb.Client` to send the data to replay through. @@ -110,191 +108,207 @@ def __init__( Raises: ValueError: If n_step is less than 1. """ - # Makes the additional discount a float32, which means that it will be - # upcast if rewards/discounts are float64 and left alone otherwise. - self.n_step = n_step - self._discount = tree.map_structure(np.float32, discount) - self._first_idx = 0 - self._last_idx = 0 - - super().__init__( - client=client, - max_sequence_length=n_step + 1, - priority_fns=priority_fns, - max_in_flight_items=max_in_flight_items) - - def add(self, *args, **kwargs): - # Increment the indices for the start and end of the window for computing - # n-step returns. - if self._writer.episode_steps >= self.n_step: - self._first_idx += 1 - self._last_idx += 1 - - super().add(*args, **kwargs) - - def reset(self): - super().reset() - self._first_idx = 0 - self._last_idx = 0 - - @property - def _n_step(self) -> int: - """Effective n-step, which may vary at starts and ends of episodes.""" - return self._last_idx - self._first_idx - - def _write(self): - # Convenient getters for use in tree operations. - get_first = lambda x: x[self._first_idx] - get_last = lambda x: x[self._last_idx] - # Note: this getter is meant to be used on a TrajectoryWriter.history to - # obtain its numpy values. - get_all_np = lambda x: x[self._first_idx:self._last_idx].numpy() - - # Get the state, action, next_state, as well as possibly extras for the - # transition that is about to be written. - history = self._writer.history - s, a = tree.map_structure(get_first, - (history['observation'], history['action'])) - s_ = tree.map_structure(get_last, history['observation']) - - # Maybe get extras to add to the transition later. - if 'extras' in history: - extras = tree.map_structure(get_first, history['extras']) - - # Note: at the beginning of an episode we will add the initial N-1 - # transitions (of size 1, 2, ...) 
and at the end of an episode (when - # called from write_last) we will write the final transitions of size (N, - # N-1, ...). See the Note in the docstring. - # Get numpy view of the steps to be fed into the priority functions. - reward, discount = tree.map_structure( - get_all_np, (history['reward'], history['discount'])) - - # Compute discounted return and geometric discount over n steps. - n_step_return, total_discount = self._compute_cumulative_quantities( - reward, discount) - - # Append the computed n-step return and total discount. - # Note: if this call to _write() is within a call to _write_last(), then - # this is the only data being appended and so it is not a partial append. - self._writer.append( - dict(n_step_return=n_step_return, total_discount=total_discount), - partial_step=self._writer.episode_steps <= self._last_idx) - # This should be done immediately after self._writer.append so the history - # includes the recently appended data. - history = self._writer.history - - # Form the n-step transition by using the following: - # the first observation and action in the buffer, along with the cumulative - # reward and discount computed above. - n_step_return, total_discount = tree.map_structure( - lambda x: x[-1], (history['n_step_return'], history['total_discount'])) - transition = types.Transition( - observation=s, - action=a, - reward=n_step_return, - discount=total_discount, - next_observation=s_, - extras=(extras if 'extras' in history else ())) - - # Calculate the priority for this transition. - table_priorities = utils.calculate_priorities(self._priority_fns, - transition) - - # Insert the transition into replay along with its priority. - for table, priority in table_priorities.items(): - self._writer.create_item( - table=table, priority=priority, trajectory=transition) - self._writer.flush(self._max_in_flight_items) - - def _write_last(self): - # Write the remaining shorter transitions by alternating writing and - # incrementingfirst_idx. Note that last_idx will no longer be incremented - # once we're in this method's scope. - self._first_idx += 1 - while self._first_idx < self._last_idx: - self._write() - self._first_idx += 1 - - def _compute_cumulative_quantities( - self, rewards: types.NestedArray, discounts: types.NestedArray - ) -> Tuple[types.NestedArray, types.NestedArray]: - - # Give the same tree structure to the n-step return accumulator, - # n-step discount accumulator, and self.discount, so that they can be - # iterated in parallel using tree.map_structure. - rewards, discounts, self_discount = tree_utils.broadcast_structures( - rewards, discounts, self._discount) - flat_rewards = tree.flatten(rewards) - flat_discounts = tree.flatten(discounts) - flat_self_discount = tree.flatten(self_discount) - - # Copy total_discount as it is otherwise read-only. - total_discount = [np.copy(a[0]) for a in flat_discounts] - - # Broadcast n_step_return to have the broadcasted shape of - # reward * discount. - n_step_return = [ - np.copy(np.broadcast_to(r[0], - np.broadcast(r[0], d).shape)) - for r, d in zip(flat_rewards, total_discount) - ] - - # NOTE: total_discount will have one less self_discount applied to it than - # the value of self._n_step. This is so that when the learner/update uses - # an additional discount we don't apply it twice. Inside the following loop - # we will apply this right before summing up the n_step_return. 
- for i in range(1, self._n_step): - for nsr, td, r, d, sd in zip(n_step_return, total_discount, flat_rewards, - flat_discounts, flat_self_discount): - # Equivalent to: `total_discount *= self._discount`. - td *= sd - # Equivalent to: `n_step_return += reward[i] * total_discount`. - nsr += r[i] * td - # Equivalent to: `total_discount *= discount[i]`. - td *= d[i] - - n_step_return = tree.unflatten_as(rewards, n_step_return) - total_discount = tree.unflatten_as(rewards, total_discount) - return n_step_return, total_discount - - # TODO(bshahr): make this into a standalone method. Class methods should be - # used as alternative constructors or when modifying some global state, - # neither of which is done here. - @classmethod - def signature(cls, - environment_spec: specs.EnvironmentSpec, - extras_spec: types.NestedSpec = ()): - - # This function currently assumes that self._discount is a scalar. - # If it ever becomes a nested structure and/or a np.ndarray, this method - # will need to know its structure / shape. This is because the signature - # discount shape is the environment's discount shape and this adder's - # discount shape broadcasted together. Also, the reward shape is this - # signature discount shape broadcasted together with the environment - # reward shape. As long as self._discount is a scalar, it will not affect - # either the signature discount shape nor the signature reward shape, so we - # can ignore it. - - rewards_spec, step_discounts_spec = tree_utils.broadcast_structures( - environment_spec.rewards, environment_spec.discounts) - rewards_spec = tree.map_structure(_broadcast_specs, rewards_spec, - step_discounts_spec) - step_discounts_spec = tree.map_structure(copy.deepcopy, step_discounts_spec) - - transition_spec = types.Transition( - environment_spec.observations, - environment_spec.actions, - rewards_spec, - step_discounts_spec, - environment_spec.observations, # next_observation - extras_spec) - - return tree.map_structure_with_path(base.spec_like_to_tensor_spec, - transition_spec) + # Makes the additional discount a float32, which means that it will be + # upcast if rewards/discounts are float64 and left alone otherwise. + self.n_step = n_step + self._discount = tree.map_structure(np.float32, discount) + self._first_idx = 0 + self._last_idx = 0 + + super().__init__( + client=client, + max_sequence_length=n_step + 1, + priority_fns=priority_fns, + max_in_flight_items=max_in_flight_items, + ) + + def add(self, *args, **kwargs): + # Increment the indices for the start and end of the window for computing + # n-step returns. + if self._writer.episode_steps >= self.n_step: + self._first_idx += 1 + self._last_idx += 1 + + super().add(*args, **kwargs) + + def reset(self): + super().reset() + self._first_idx = 0 + self._last_idx = 0 + + @property + def _n_step(self) -> int: + """Effective n-step, which may vary at starts and ends of episodes.""" + return self._last_idx - self._first_idx + + def _write(self): + # Convenient getters for use in tree operations. + get_first = lambda x: x[self._first_idx] + get_last = lambda x: x[self._last_idx] + # Note: this getter is meant to be used on a TrajectoryWriter.history to + # obtain its numpy values. + get_all_np = lambda x: x[self._first_idx : self._last_idx].numpy() + + # Get the state, action, next_state, as well as possibly extras for the + # transition that is about to be written. 
+ history = self._writer.history + s, a = tree.map_structure( + get_first, (history["observation"], history["action"]) + ) + s_ = tree.map_structure(get_last, history["observation"]) + + # Maybe get extras to add to the transition later. + if "extras" in history: + extras = tree.map_structure(get_first, history["extras"]) + + # Note: at the beginning of an episode we will add the initial N-1 + # transitions (of size 1, 2, ...) and at the end of an episode (when + # called from write_last) we will write the final transitions of size (N, + # N-1, ...). See the Note in the docstring. + # Get numpy view of the steps to be fed into the priority functions. + reward, discount = tree.map_structure( + get_all_np, (history["reward"], history["discount"]) + ) + + # Compute discounted return and geometric discount over n steps. + n_step_return, total_discount = self._compute_cumulative_quantities( + reward, discount + ) + + # Append the computed n-step return and total discount. + # Note: if this call to _write() is within a call to _write_last(), then + # this is the only data being appended and so it is not a partial append. + self._writer.append( + dict(n_step_return=n_step_return, total_discount=total_discount), + partial_step=self._writer.episode_steps <= self._last_idx, + ) + # This should be done immediately after self._writer.append so the history + # includes the recently appended data. + history = self._writer.history + + # Form the n-step transition by using the following: + # the first observation and action in the buffer, along with the cumulative + # reward and discount computed above. + n_step_return, total_discount = tree.map_structure( + lambda x: x[-1], (history["n_step_return"], history["total_discount"]) + ) + transition = types.Transition( + observation=s, + action=a, + reward=n_step_return, + discount=total_discount, + next_observation=s_, + extras=(extras if "extras" in history else ()), + ) + + # Calculate the priority for this transition. + table_priorities = utils.calculate_priorities(self._priority_fns, transition) + + # Insert the transition into replay along with its priority. + for table, priority in table_priorities.items(): + self._writer.create_item( + table=table, priority=priority, trajectory=transition + ) + self._writer.flush(self._max_in_flight_items) + + def _write_last(self): + # Write the remaining shorter transitions by alternating writing and + # incrementingfirst_idx. Note that last_idx will no longer be incremented + # once we're in this method's scope. + self._first_idx += 1 + while self._first_idx < self._last_idx: + self._write() + self._first_idx += 1 + + def _compute_cumulative_quantities( + self, rewards: types.NestedArray, discounts: types.NestedArray + ) -> Tuple[types.NestedArray, types.NestedArray]: + + # Give the same tree structure to the n-step return accumulator, + # n-step discount accumulator, and self.discount, so that they can be + # iterated in parallel using tree.map_structure. + rewards, discounts, self_discount = tree_utils.broadcast_structures( + rewards, discounts, self._discount + ) + flat_rewards = tree.flatten(rewards) + flat_discounts = tree.flatten(discounts) + flat_self_discount = tree.flatten(self_discount) + + # Copy total_discount as it is otherwise read-only. + total_discount = [np.copy(a[0]) for a in flat_discounts] + + # Broadcast n_step_return to have the broadcasted shape of + # reward * discount. 
+ n_step_return = [ + np.copy(np.broadcast_to(r[0], np.broadcast(r[0], d).shape)) + for r, d in zip(flat_rewards, total_discount) + ] + + # NOTE: total_discount will have one less self_discount applied to it than + # the value of self._n_step. This is so that when the learner/update uses + # an additional discount we don't apply it twice. Inside the following loop + # we will apply this right before summing up the n_step_return. + for i in range(1, self._n_step): + for nsr, td, r, d, sd in zip( + n_step_return, + total_discount, + flat_rewards, + flat_discounts, + flat_self_discount, + ): + # Equivalent to: `total_discount *= self._discount`. + td *= sd + # Equivalent to: `n_step_return += reward[i] * total_discount`. + nsr += r[i] * td + # Equivalent to: `total_discount *= discount[i]`. + td *= d[i] + + n_step_return = tree.unflatten_as(rewards, n_step_return) + total_discount = tree.unflatten_as(rewards, total_discount) + return n_step_return, total_discount + + # TODO(bshahr): make this into a standalone method. Class methods should be + # used as alternative constructors or when modifying some global state, + # neither of which is done here. + @classmethod + def signature( + cls, environment_spec: specs.EnvironmentSpec, extras_spec: types.NestedSpec = () + ): + + # This function currently assumes that self._discount is a scalar. + # If it ever becomes a nested structure and/or a np.ndarray, this method + # will need to know its structure / shape. This is because the signature + # discount shape is the environment's discount shape and this adder's + # discount shape broadcasted together. Also, the reward shape is this + # signature discount shape broadcasted together with the environment + # reward shape. As long as self._discount is a scalar, it will not affect + # either the signature discount shape nor the signature reward shape, so we + # can ignore it. + + rewards_spec, step_discounts_spec = tree_utils.broadcast_structures( + environment_spec.rewards, environment_spec.discounts + ) + rewards_spec = tree.map_structure( + _broadcast_specs, rewards_spec, step_discounts_spec + ) + step_discounts_spec = tree.map_structure(copy.deepcopy, step_discounts_spec) + + transition_spec = types.Transition( + environment_spec.observations, + environment_spec.actions, + rewards_spec, + step_discounts_spec, + environment_spec.observations, # next_observation + extras_spec, + ) + + return tree.map_structure_with_path( + base.spec_like_to_tensor_spec, transition_spec + ) def _broadcast_specs(*args: specs.Array) -> specs.Array: - """Like np.broadcast, but for specs.Array. + """Like np.broadcast, but for specs.Array. Args: *args: one or more specs.Array instances. @@ -302,6 +316,6 @@ def _broadcast_specs(*args: specs.Array) -> specs.Array: Returns: A specs.Array with the broadcasted shape and dtype of the specs in *args. 
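For reference, the accumulation loop in `_compute_cumulative_quantities` above implements the following recurrence. This is a minimal standalone pure-Python sketch (illustrative names, scalar rewards and discounts; not part of the patch) showing why `total_discount` ends up with one fewer agent-discount factor than the number of accumulated steps:

```python
def n_step_return_and_discount(rewards, discounts, agent_discount):
    """Mirrors the accumulation loop above for scalar per-step values."""
    n_step_return = rewards[0]
    total_discount = discounts[0]
    for i in range(1, len(rewards)):
        total_discount *= agent_discount            # Fold in the agent discount...
        n_step_return += rewards[i] * total_discount
        total_discount *= discounts[i]              # ...then the environment discount.
    # total_discount carries one fewer agent_discount factor than steps, so the
    # learner can apply its own discount once more without double counting.
    return n_step_return, total_discount


# Three unit rewards, full environment discounts, agent discount 0.9:
ret, disc = n_step_return_and_discount([1.0, 1.0, 1.0], [1.0, 1.0, 1.0], 0.9)
print(round(ret, 4), round(disc, 4))  # 2.71 (= 1 + 0.9 + 0.81) and 0.81 (= 0.9 ** 2)
```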
""" - bc_info = np.broadcast(*tuple(a.generate_value() for a in args)) - dtype = np.result_type(*tuple(a.dtype for a in args)) - return specs.Array(shape=bc_info.shape, dtype=dtype) + bc_info = np.broadcast(*tuple(a.generate_value() for a in args)) + dtype = np.result_type(*tuple(a.dtype for a in args)) + return specs.Array(shape=bc_info.shape, dtype=dtype) diff --git a/acme/adders/reverb/transition_test.py b/acme/adders/reverb/transition_test.py index 0c668d704d..b1a25d0c77 100644 --- a/acme/adders/reverb/transition_test.py +++ b/acme/adders/reverb/transition_test.py @@ -14,30 +14,27 @@ """Tests for NStepTransition adders.""" -from acme.adders.reverb import test_cases -from acme.adders.reverb import test_utils -from acme.adders.reverb import transition as adders - -from absl.testing import absltest -from absl.testing import parameterized - +from absl.testing import absltest, parameterized -class NStepTransitionAdderTest(test_utils.AdderTestMixin, - parameterized.TestCase): - - @parameterized.named_parameters(*test_cases.TEST_CASES_FOR_TRANSITION_ADDER) - def test_adder(self, n_step, additional_discount, first, steps, - expected_transitions): - adder = adders.NStepTransitionAdder(self.client, n_step, - additional_discount) - super().run_test_adder( - adder=adder, - first=first, - steps=steps, - expected_items=expected_transitions, - stack_sequence_fields=False, - signature=adder.signature(*test_utils.get_specs(steps[0]))) +from acme.adders.reverb import test_cases, test_utils +from acme.adders.reverb import transition as adders -if __name__ == '__main__': - absltest.main() +class NStepTransitionAdderTest(test_utils.AdderTestMixin, parameterized.TestCase): + @parameterized.named_parameters(*test_cases.TEST_CASES_FOR_TRANSITION_ADDER) + def test_adder( + self, n_step, additional_discount, first, steps, expected_transitions + ): + adder = adders.NStepTransitionAdder(self.client, n_step, additional_discount) + super().run_test_adder( + adder=adder, + first=first, + steps=steps, + expected_items=expected_transitions, + stack_sequence_fields=False, + signature=adder.signature(*test_utils.get_specs(steps[0])), + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/adders/reverb/utils.py b/acme/adders/reverb/utils.py index 619712c71a..325fb8323b 100644 --- a/acme/adders/reverb/utils.py +++ b/acme/adders/reverb/utils.py @@ -16,16 +16,17 @@ from typing import Dict, Union -from acme import types -from acme.adders.reverb import base import jax import jax.numpy as jnp import numpy as np import tree +from acme import types +from acme.adders.reverb import base + def zeros_like(x: Union[np.ndarray, int, float, np.number]): - """Returns a zero-filled object of the same (d)type and shape as the input. + """Returns a zero-filled object of the same (d)type and shape as the input. The difference between this and `np.zeros_like()` is that this works well with `np.number`, `int`, `float`, and `jax.numpy.DeviceArray` objects without @@ -37,39 +38,41 @@ def zeros_like(x: Union[np.ndarray, int, float, np.number]): Returns: A zero-filed object of the same (d)type and shape as the input. 
""" - if isinstance(x, (int, float, np.number)): - return type(x)(0) - elif isinstance(x, jax.Array): - return jnp.zeros_like(x) - elif isinstance(x, np.ndarray): - return np.zeros_like(x) - else: - raise ValueError( - f'Input ({type(x)}) must be either a numpy array, an int, or a float.') - - -def final_step_like(step: base.Step, - next_observation: types.NestedArray) -> base.Step: - """Return a list of steps with the final step zero-filled.""" - # Make zero-filled components so we can fill out the last step. - zero_action, zero_reward, zero_discount, zero_extras = tree.map_structure( - zeros_like, (step.action, step.reward, step.discount, step.extras)) - - # Return a final step that only has next_observation. - return base.Step( - observation=next_observation, - action=zero_action, - reward=zero_reward, - discount=zero_discount, - start_of_episode=False, - extras=zero_extras) + if isinstance(x, (int, float, np.number)): + return type(x)(0) + elif isinstance(x, jax.Array): + return jnp.zeros_like(x) + elif isinstance(x, np.ndarray): + return np.zeros_like(x) + else: + raise ValueError( + f"Input ({type(x)}) must be either a numpy array, an int, or a float." + ) + + +def final_step_like(step: base.Step, next_observation: types.NestedArray) -> base.Step: + """Return a list of steps with the final step zero-filled.""" + # Make zero-filled components so we can fill out the last step. + zero_action, zero_reward, zero_discount, zero_extras = tree.map_structure( + zeros_like, (step.action, step.reward, step.discount, step.extras) + ) + + # Return a final step that only has next_observation. + return base.Step( + observation=next_observation, + action=zero_action, + reward=zero_reward, + discount=zero_discount, + start_of_episode=False, + extras=zero_extras, + ) def calculate_priorities( priority_fns: base.PriorityFnMapping, trajectory_or_transition: Union[base.Trajectory, types.Transition], ) -> Dict[str, float]: - """Helper used to calculate the priority of a Trajectory or Transition. + """Helper used to calculate the priority of a Trajectory or Transition. This helper converts the leaves of the Trajectory or Transition from `reverb.TrajectoryColumn` objects into numpy arrays. The converted Trajectory @@ -86,12 +89,13 @@ def calculate_priorities( A dictionary mapping from table names to the priority (a float) for the given collection Trajectory or Transition. 
""" - if any([priority_fn is not None for priority_fn in priority_fns.values()]): + if any([priority_fn is not None for priority_fn in priority_fns.values()]): - trajectory_or_transition = tree.map_structure(lambda col: col.numpy(), - trajectory_or_transition) + trajectory_or_transition = tree.map_structure( + lambda col: col.numpy(), trajectory_or_transition + ) - return { - table: (priority_fn(trajectory_or_transition) if priority_fn else 1.0) - for table, priority_fn in priority_fns.items() - } + return { + table: (priority_fn(trajectory_or_transition) if priority_fn else 1.0) + for table, priority_fn in priority_fns.items() + } diff --git a/acme/adders/wrappers.py b/acme/adders/wrappers.py index 9b26944e24..2cce2b9f36 100644 --- a/acme/adders/wrappers.py +++ b/acme/adders/wrappers.py @@ -16,47 +16,52 @@ from typing import Iterable +import dm_env + from acme import types from acme.adders import base -import dm_env class ForkingAdder(base.Adder): - """An adder that forks data into several other adders.""" + """An adder that forks data into several other adders.""" - def __init__(self, adders: Iterable[base.Adder]): - self._adders = adders + def __init__(self, adders: Iterable[base.Adder]): + self._adders = adders - def reset(self): - for adder in self._adders: - adder.reset() + def reset(self): + for adder in self._adders: + adder.reset() - def add_first(self, timestep: dm_env.TimeStep): - for adder in self._adders: - adder.add_first(timestep) + def add_first(self, timestep: dm_env.TimeStep): + for adder in self._adders: + adder.add_first(timestep) - def add(self, - action: types.NestedArray, - next_timestep: dm_env.TimeStep, - extras: types.NestedArray = ()): - for adder in self._adders: - adder.add(action, next_timestep, extras) + def add( + self, + action: types.NestedArray, + next_timestep: dm_env.TimeStep, + extras: types.NestedArray = (), + ): + for adder in self._adders: + adder.add(action, next_timestep, extras) class IgnoreExtrasAdder(base.Adder): - """An adder that ignores extras.""" + """An adder that ignores extras.""" - def __init__(self, adder: base.Adder): - self._adder = adder + def __init__(self, adder: base.Adder): + self._adder = adder - def reset(self): - self._adder.reset() + def reset(self): + self._adder.reset() - def add_first(self, timestep: dm_env.TimeStep): - self._adder.add_first(timestep) + def add_first(self, timestep: dm_env.TimeStep): + self._adder.add_first(timestep) - def add(self, - action: types.NestedArray, - next_timestep: dm_env.TimeStep, - extras: types.NestedArray = ()): - self._adder.add(action, next_timestep) + def add( + self, + action: types.NestedArray, + next_timestep: dm_env.TimeStep, + extras: types.NestedArray = (), + ): + self._adder.add(action, next_timestep) diff --git a/acme/agents/agent.py b/acme/agents/agent.py index 678b49611d..4e43481eff 100644 --- a/acme/agents/agent.py +++ b/acme/agents/agent.py @@ -17,31 +17,31 @@ import math from typing import List, Optional, Sequence -from acme import core -from acme import types import dm_env import numpy as np import reverb +from acme import core, types -def _calculate_num_learner_steps(num_observations: int, - min_observations: int, - observations_per_step: float) -> int: - """Calculates the number of learner steps to do at step=num_observations.""" - n = num_observations - min_observations - if n < 0: - # Do not do any learner steps until you have seen min_observations. - return 0 - if observations_per_step > 1: - # One batch every 1/obs_per_step observations, otherwise zero. 
- return int(n % int(observations_per_step) == 0) - else: - # Always return 1/obs_per_step batches every observation. - return int(1 / observations_per_step) + +def _calculate_num_learner_steps( + num_observations: int, min_observations: int, observations_per_step: float +) -> int: + """Calculates the number of learner steps to do at step=num_observations.""" + n = num_observations - min_observations + if n < 0: + # Do not do any learner steps until you have seen min_observations. + return 0 + if observations_per_step > 1: + # One batch every 1/obs_per_step observations, otherwise zero. + return int(n % int(observations_per_step) == 0) + else: + # Always return 1/obs_per_step batches every observation. + return int(1 / observations_per_step) class Agent(core.Actor, core.VariableSource): - """Agent class which combines acting and learning. + """Agent class which combines acting and learning. This provides an implementation of the `Actor` interface which acts and learns. It takes as input instances of both `acme.Actor` and `acme.Learner` @@ -57,80 +57,90 @@ class Agent(core.Actor, core.VariableSource): in order to allow the agent to take more than 1 learner step per action. """ - def __init__(self, actor: core.Actor, learner: core.Learner, - min_observations: Optional[int] = None, - observations_per_step: Optional[float] = None, - iterator: Optional[core.PrefetchingIterator] = None, - replay_tables: Optional[List[reverb.Table]] = None): - self._actor = actor - self._learner = learner - self._min_observations = min_observations - self._observations_per_step = observations_per_step - self._num_observations = 0 - self._iterator = iterator - self._replay_tables = replay_tables - self._batch_size_upper_bounds = [1_000_000_000] * len( - replay_tables) if replay_tables else None - - def select_action(self, observation: types.NestedArray) -> types.NestedArray: - return self._actor.select_action(observation) - - def observe_first(self, timestep: dm_env.TimeStep): - self._actor.observe_first(timestep) - - def observe(self, action: types.NestedArray, next_timestep: dm_env.TimeStep): - self._num_observations += 1 - self._actor.observe(action, next_timestep) - - def _has_data_for_training(self): - if self._iterator.ready(): - return True - for (table, batch_size) in zip(self._replay_tables, - self._batch_size_upper_bounds): - if not table.can_sample(batch_size): - return False - return True - - def update(self): - if self._iterator: - # Perform learner steps as long as iterator has data. - update_actor = False - while self._has_data_for_training(): - # Run learner steps (usually means gradient steps). - total_batches = self._iterator.retrieved_elements() - self._learner.step() - current_batches = self._iterator.retrieved_elements() - total_batches - assert current_batches == 1, ( - 'Learner step must retrieve exactly one element from the iterator' - f' (retrieved {current_batches}). Otherwise agent can deadlock. ' - 'Example cause is that your chosen agent' - 's Builder has a ' - '`make_learner` factory that prefetches the data but it ' - 'shouldn' - 't.') - self._batch_size_upper_bounds = [ - math.ceil(t.info.rate_limiter_info.sample_stats.completed / - (total_batches + 1)) for t in self._replay_tables - ] - update_actor = True - if update_actor: - # Update the actor weights only when learner was updated. - self._actor.update() - return - - # If dataset is not provided, follback to the old logic. - # TODO(stanczyk): Remove when not used. 
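As a worked example of `_calculate_num_learner_steps` above, the same logic handles both regimes: several observations per learner step (`observations_per_step > 1`) and several learner steps per observation (`observations_per_step < 1`). A standalone replica with illustrative numbers:

```python
def num_learner_steps(num_observations, min_observations, observations_per_step):
    # Same logic as _calculate_num_learner_steps above.
    n = num_observations - min_observations
    if n < 0:
        return 0  # Wait for min_observations before learning.
    if observations_per_step > 1:
        return int(n % int(observations_per_step) == 0)  # One step every k observations.
    return int(1 / observations_per_step)  # Several steps per observation.


# Two observations per learner step: learn on every other observation.
print([num_learner_steps(i, 0, 2.0) for i in range(6)])  # [1, 0, 1, 0, 1, 0]
# Half an observation per learner step: two learner steps per observation.
print([num_learner_steps(i, 0, 0.5) for i in range(3)])  # [2, 2, 2]
```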
- num_steps = _calculate_num_learner_steps( - num_observations=self._num_observations, - min_observations=self._min_observations, - observations_per_step=self._observations_per_step, - ) - for _ in range(num_steps): - # Run learner steps (usually means gradient steps). - self._learner.step() - if num_steps > 0: - # Update the actor weights when learner updates. - self._actor.update() - - def get_variables(self, names: Sequence[str]) -> List[List[np.ndarray]]: - return self._learner.get_variables(names) + def __init__( + self, + actor: core.Actor, + learner: core.Learner, + min_observations: Optional[int] = None, + observations_per_step: Optional[float] = None, + iterator: Optional[core.PrefetchingIterator] = None, + replay_tables: Optional[List[reverb.Table]] = None, + ): + self._actor = actor + self._learner = learner + self._min_observations = min_observations + self._observations_per_step = observations_per_step + self._num_observations = 0 + self._iterator = iterator + self._replay_tables = replay_tables + self._batch_size_upper_bounds = ( + [1_000_000_000] * len(replay_tables) if replay_tables else None + ) + + def select_action(self, observation: types.NestedArray) -> types.NestedArray: + return self._actor.select_action(observation) + + def observe_first(self, timestep: dm_env.TimeStep): + self._actor.observe_first(timestep) + + def observe(self, action: types.NestedArray, next_timestep: dm_env.TimeStep): + self._num_observations += 1 + self._actor.observe(action, next_timestep) + + def _has_data_for_training(self): + if self._iterator.ready(): + return True + for (table, batch_size) in zip( + self._replay_tables, self._batch_size_upper_bounds + ): + if not table.can_sample(batch_size): + return False + return True + + def update(self): + if self._iterator: + # Perform learner steps as long as iterator has data. + update_actor = False + while self._has_data_for_training(): + # Run learner steps (usually means gradient steps). + total_batches = self._iterator.retrieved_elements() + self._learner.step() + current_batches = self._iterator.retrieved_elements() - total_batches + assert current_batches == 1, ( + "Learner step must retrieve exactly one element from the iterator" + f" (retrieved {current_batches}). Otherwise agent can deadlock. " + "Example cause is that your chosen agent" + "s Builder has a " + "`make_learner` factory that prefetches the data but it " + "shouldn" + "t." + ) + self._batch_size_upper_bounds = [ + math.ceil( + t.info.rate_limiter_info.sample_stats.completed + / (total_batches + 1) + ) + for t in self._replay_tables + ] + update_actor = True + if update_actor: + # Update the actor weights only when learner was updated. + self._actor.update() + return + + # If dataset is not provided, follback to the old logic. + # TODO(stanczyk): Remove when not used. + num_steps = _calculate_num_learner_steps( + num_observations=self._num_observations, + min_observations=self._min_observations, + observations_per_step=self._observations_per_step, + ) + for _ in range(num_steps): + # Run learner steps (usually means gradient steps). + self._learner.step() + if num_steps > 0: + # Update the actor weights when learner updates. 
+ self._actor.update() + + def get_variables(self, names: Sequence[str]) -> List[List[np.ndarray]]: + return self._learner.get_variables(names) diff --git a/acme/agents/jax/actor_core.py b/acme/agents/jax/actor_core.py index 598cb001fb..42cd3268ff 100644 --- a/acme/agents/jax/actor_core.py +++ b/acme/agents/jax/actor_core.py @@ -17,154 +17,167 @@ import dataclasses from typing import Callable, Generic, Mapping, Tuple, TypeVar, Union -from acme import types -from acme.jax import networks as networks_lib -from acme.jax import utils -from acme.jax.types import PRNGKey import chex import jax import jax.numpy as jnp +from acme import types +from acme.jax import networks as networks_lib +from acme.jax import utils +from acme.jax.types import PRNGKey NoneType = type(None) # The state of the actor. This could include recurrent network state or any # other state which needs to be propagated through the select_action calls. -State = TypeVar('State') +State = TypeVar("State") # The extras to be passed to the observe method. -Extras = TypeVar('Extras') -RecurrentState = TypeVar('RecurrentState') +Extras = TypeVar("Extras") +RecurrentState = TypeVar("RecurrentState") SelectActionFn = Callable[ [networks_lib.Params, networks_lib.Observation, State], - Tuple[networks_lib.Action, State]] + Tuple[networks_lib.Action, State], +] @dataclasses.dataclass class ActorCore(Generic[State, Extras]): - """Pure functions that define the algorithm-specific actor functionality.""" - init: Callable[[PRNGKey], State] - select_action: SelectActionFn - get_extras: Callable[[State], Extras] + """Pure functions that define the algorithm-specific actor functionality.""" + + init: Callable[[PRNGKey], State] + select_action: SelectActionFn + get_extras: Callable[[State], Extras] # A simple feed forward policy which produces no extras and takes only an RNGKey # as a state. 
FeedForwardPolicy = Callable[ - [networks_lib.Params, PRNGKey, networks_lib.Observation], - networks_lib.Action] + [networks_lib.Params, PRNGKey, networks_lib.Observation], networks_lib.Action +] FeedForwardPolicyWithExtra = Callable[ [networks_lib.Params, PRNGKey, networks_lib.Observation], - Tuple[networks_lib.Action, types.NestedArray]] + Tuple[networks_lib.Action, types.NestedArray], +] -RecurrentPolicy = Callable[[ - networks_lib.Params, PRNGKey, networks_lib - .Observation, RecurrentState -], Tuple[networks_lib.Action, RecurrentState]] +RecurrentPolicy = Callable[ + [networks_lib.Params, PRNGKey, networks_lib.Observation, RecurrentState], + Tuple[networks_lib.Action, RecurrentState], +] Policy = Union[FeedForwardPolicy, FeedForwardPolicyWithExtra, RecurrentPolicy] def batched_feed_forward_to_actor_core( - policy: FeedForwardPolicy) -> ActorCore[PRNGKey, Tuple[()]]: - """A convenience adaptor from FeedForwardPolicy to ActorCore.""" + policy: FeedForwardPolicy, +) -> ActorCore[PRNGKey, Tuple[()]]: + """A convenience adaptor from FeedForwardPolicy to ActorCore.""" + + def select_action( + params: networks_lib.Params, + observation: networks_lib.Observation, + state: PRNGKey, + ): + rng = state + rng1, rng2 = jax.random.split(rng) + observation = utils.add_batch_dim(observation) + action = utils.squeeze_batch_dim(policy(params, rng1, observation)) + return action, rng2 - def select_action(params: networks_lib.Params, - observation: networks_lib.Observation, - state: PRNGKey): - rng = state - rng1, rng2 = jax.random.split(rng) - observation = utils.add_batch_dim(observation) - action = utils.squeeze_batch_dim(policy(params, rng1, observation)) - return action, rng2 + def init(rng: PRNGKey) -> PRNGKey: + return rng - def init(rng: PRNGKey) -> PRNGKey: - return rng + def get_extras(unused_rng: PRNGKey) -> Tuple[()]: + return () - def get_extras(unused_rng: PRNGKey) -> Tuple[()]: - return () - return ActorCore(init=init, select_action=select_action, - get_extras=get_extras) + return ActorCore(init=init, select_action=select_action, get_extras=get_extras) @chex.dataclass(frozen=True, mappable_dataclass=False) class SimpleActorCoreStateWithExtras: - rng: PRNGKey - extras: Mapping[str, jnp.ndarray] + rng: PRNGKey + extras: Mapping[str, jnp.ndarray] def unvectorize_select_action(actor_core: ActorCore) -> ActorCore: - """Makes an actor core's select_action method expect unbatched arguments.""" + """Makes an actor core's select_action method expect unbatched arguments.""" - def unvectorized_select_action( - params: networks_lib.Params, - observations: networks_lib.Observation, - state: State, - ) -> Tuple[networks_lib.Action, State]: - observations, state = utils.add_batch_dim((observations, state)) - actions, state = actor_core.select_action(params, observations, state) - return utils.squeeze_batch_dim((actions, state)) + def unvectorized_select_action( + params: networks_lib.Params, + observations: networks_lib.Observation, + state: State, + ) -> Tuple[networks_lib.Action, State]: + observations, state = utils.add_batch_dim((observations, state)) + actions, state = actor_core.select_action(params, observations, state) + return utils.squeeze_batch_dim((actions, state)) - return ActorCore( - init=actor_core.init, - select_action=unvectorized_select_action, - get_extras=actor_core.get_extras) + return ActorCore( + init=actor_core.init, + select_action=unvectorized_select_action, + get_extras=actor_core.get_extras, + ) def batched_feed_forward_with_extras_to_actor_core( - policy: 
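The feed-forward adaptor above keeps only a `PRNGKey` as the actor state and adds/removes the batch dimension around the policy call. Below is a self-contained sketch that mirrors that behaviour without importing Acme (the toy policy, parameters and shapes are purely illustrative):

```python
import jax
import jax.numpy as jnp


def toy_policy(params, key, observations):
    """A trivial batched feed-forward policy: argmax of a linear projection."""
    del key  # Deterministic for this sketch.
    return jnp.argmax(observations @ params, axis=-1)


def select_action(params, observation, state):
    # Mirrors batched_feed_forward_to_actor_core: split the key, batch the
    # observation, unbatch the action, and return the new key as the state.
    rng1, rng2 = jax.random.split(state)
    action = toy_policy(params, rng1, observation[None, ...])[0]
    return action, rng2


params = jnp.eye(4, 3)                     # Illustrative "network" parameters.
state = jax.random.PRNGKey(0)              # What the adaptor's init() would return.
action, state = select_action(params, jnp.array([0.1, 2.0, 0.3, 0.0]), state)
print(int(action))  # 1
```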
FeedForwardPolicyWithExtra + policy: FeedForwardPolicyWithExtra, ) -> ActorCore[SimpleActorCoreStateWithExtras, Mapping[str, jnp.ndarray]]: - """A convenience adaptor from FeedForwardPolicy to ActorCore.""" + """A convenience adaptor from FeedForwardPolicy to ActorCore.""" + + def select_action( + params: networks_lib.Params, + observation: networks_lib.Observation, + state: SimpleActorCoreStateWithExtras, + ): + rng = state.rng + rng1, rng2 = jax.random.split(rng) + observation = utils.add_batch_dim(observation) + action, extras = utils.squeeze_batch_dim(policy(params, rng1, observation)) + return action, SimpleActorCoreStateWithExtras(rng2, extras) - def select_action(params: networks_lib.Params, - observation: networks_lib.Observation, - state: SimpleActorCoreStateWithExtras): - rng = state.rng - rng1, rng2 = jax.random.split(rng) - observation = utils.add_batch_dim(observation) - action, extras = utils.squeeze_batch_dim(policy(params, rng1, observation)) - return action, SimpleActorCoreStateWithExtras(rng2, extras) + def init(rng: PRNGKey) -> SimpleActorCoreStateWithExtras: + return SimpleActorCoreStateWithExtras(rng, {}) - def init(rng: PRNGKey) -> SimpleActorCoreStateWithExtras: - return SimpleActorCoreStateWithExtras(rng, {}) + def get_extras(state: SimpleActorCoreStateWithExtras) -> Mapping[str, jnp.ndarray]: + return state.extras - def get_extras( - state: SimpleActorCoreStateWithExtras) -> Mapping[str, jnp.ndarray]: - return state.extras - return ActorCore(init=init, select_action=select_action, - get_extras=get_extras) + return ActorCore(init=init, select_action=select_action, get_extras=get_extras) @chex.dataclass(frozen=True, mappable_dataclass=False) class SimpleActorCoreRecurrentState(Generic[RecurrentState]): - rng: PRNGKey - recurrent_state: RecurrentState + rng: PRNGKey + recurrent_state: RecurrentState def batched_recurrent_to_actor_core( recurrent_policy: RecurrentPolicy, initial_core_state: RecurrentState -) -> ActorCore[SimpleActorCoreRecurrentState[RecurrentState], Mapping[ - str, jnp.ndarray]]: - """Returns ActorCore for a recurrent policy.""" - def select_action(params: networks_lib.Params, - observation: networks_lib.Observation, - state: SimpleActorCoreRecurrentState[RecurrentState]): - # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs. - rng = state.rng - rng, policy_rng = jax.random.split(rng) - observation = utils.add_batch_dim(observation) - recurrent_state = utils.add_batch_dim(state.recurrent_state) - action, new_recurrent_state = utils.squeeze_batch_dim(recurrent_policy( - params, policy_rng, observation, recurrent_state)) - return action, SimpleActorCoreRecurrentState(rng, new_recurrent_state) - - initial_core_state = utils.squeeze_batch_dim(initial_core_state) - def init(rng: PRNGKey) -> SimpleActorCoreRecurrentState[RecurrentState]: - return SimpleActorCoreRecurrentState(rng, initial_core_state) - - def get_extras( - state: SimpleActorCoreRecurrentState[RecurrentState] - ) -> Mapping[str, jnp.ndarray]: - return {'core_state': state.recurrent_state} - - return ActorCore(init=init, select_action=select_action, - get_extras=get_extras) +) -> ActorCore[ + SimpleActorCoreRecurrentState[RecurrentState], Mapping[str, jnp.ndarray] +]: + """Returns ActorCore for a recurrent policy.""" + + def select_action( + params: networks_lib.Params, + observation: networks_lib.Observation, + state: SimpleActorCoreRecurrentState[RecurrentState], + ): + # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs. 
+ rng = state.rng + rng, policy_rng = jax.random.split(rng) + observation = utils.add_batch_dim(observation) + recurrent_state = utils.add_batch_dim(state.recurrent_state) + action, new_recurrent_state = utils.squeeze_batch_dim( + recurrent_policy(params, policy_rng, observation, recurrent_state) + ) + return action, SimpleActorCoreRecurrentState(rng, new_recurrent_state) + + initial_core_state = utils.squeeze_batch_dim(initial_core_state) + + def init(rng: PRNGKey) -> SimpleActorCoreRecurrentState[RecurrentState]: + return SimpleActorCoreRecurrentState(rng, initial_core_state) + + def get_extras( + state: SimpleActorCoreRecurrentState[RecurrentState], + ) -> Mapping[str, jnp.ndarray]: + return {"core_state": state.recurrent_state} + + return ActorCore(init=init, select_action=select_action, get_extras=get_extras) diff --git a/acme/agents/jax/actors.py b/acme/agents/jax/actors.py index 20426c2e0d..70925f216e 100644 --- a/acme/agents/jax/actors.py +++ b/acme/agents/jax/actors.py @@ -16,36 +16,34 @@ from typing import Generic, Optional -from acme import adders -from acme import core -from acme import types -from acme.agents.jax import actor_core -from acme.jax import networks as network_lib -from acme.jax import utils -from acme.jax import variable_utils import dm_env import jax +from acme import adders, core, types +from acme.agents.jax import actor_core +from acme.jax import networks as network_lib +from acme.jax import utils, variable_utils + class GenericActor(core.Actor, Generic[actor_core.State, actor_core.Extras]): - """A generic actor implemented on top of ActorCore. + """A generic actor implemented on top of ActorCore. An actor based on a policy which takes observations and outputs actions. It also adds experiences to replay and updates the actor weights from the policy on the learner. """ - def __init__( - self, - actor: actor_core.ActorCore[actor_core.State, actor_core.Extras], - random_key: network_lib.PRNGKey, - variable_client: Optional[variable_utils.VariableClient], - adder: Optional[adders.Adder] = None, - jit: bool = True, - backend: Optional[str] = 'cpu', - per_episode_update: bool = False - ): - """Initializes a feed forward actor. + def __init__( + self, + actor: actor_core.ActorCore[actor_core.State, actor_core.Extras], + random_key: network_lib.PRNGKey, + variable_client: Optional[variable_utils.VariableClient], + adder: Optional[adders.Adder] = None, + jit: bool = True, + backend: Optional[str] = "cpu", + per_episode_update: bool = False, + ): + """Initializes a feed forward actor. Args: actor: actor core. @@ -57,43 +55,41 @@ def __init__( per_episode_update: if True, updates variable client params once at the beginning of each episode """ - self._random_key = random_key - self._variable_client = variable_client - self._adder = adder - self._state = None + self._random_key = random_key + self._variable_client = variable_client + self._adder = adder + self._state = None - # Unpack ActorCore, jitting if requested. - if jit: - self._init = jax.jit(actor.init, backend=backend) - self._policy = jax.jit(actor.select_action, backend=backend) - else: - self._init = actor.init - self._policy = actor.select_action - self._get_extras = actor.get_extras - self._per_episode_update = per_episode_update + # Unpack ActorCore, jitting if requested. 
+ if jit: + self._init = jax.jit(actor.init, backend=backend) + self._policy = jax.jit(actor.select_action, backend=backend) + else: + self._init = actor.init + self._policy = actor.select_action + self._get_extras = actor.get_extras + self._per_episode_update = per_episode_update - @property - def _params(self): - return self._variable_client.params if self._variable_client else [] + @property + def _params(self): + return self._variable_client.params if self._variable_client else [] - def select_action(self, - observation: network_lib.Observation) -> types.NestedArray: - action, self._state = self._policy(self._params, observation, self._state) - return utils.to_numpy(action) + def select_action(self, observation: network_lib.Observation) -> types.NestedArray: + action, self._state = self._policy(self._params, observation, self._state) + return utils.to_numpy(action) - def observe_first(self, timestep: dm_env.TimeStep): - self._random_key, key = jax.random.split(self._random_key) - self._state = self._init(key) - if self._adder: - self._adder.add_first(timestep) - if self._variable_client and self._per_episode_update: - self._variable_client.update_and_wait() + def observe_first(self, timestep: dm_env.TimeStep): + self._random_key, key = jax.random.split(self._random_key) + self._state = self._init(key) + if self._adder: + self._adder.add_first(timestep) + if self._variable_client and self._per_episode_update: + self._variable_client.update_and_wait() - def observe(self, action: network_lib.Action, next_timestep: dm_env.TimeStep): - if self._adder: - self._adder.add( - action, next_timestep, extras=self._get_extras(self._state)) + def observe(self, action: network_lib.Action, next_timestep: dm_env.TimeStep): + if self._adder: + self._adder.add(action, next_timestep, extras=self._get_extras(self._state)) - def update(self, wait: bool = False): - if self._variable_client and not self._per_episode_update: - self._variable_client.update(wait) + def update(self, wait: bool = False): + if self._variable_client and not self._per_episode_update: + self._variable_client.update(wait) diff --git a/acme/agents/jax/actors_test.py b/acme/agents/jax/actors_test.py index 941e7a20a2..5d9e356005 100644 --- a/acme/agents/jax/actors_test.py +++ b/acme/agents/jax/actors_test.py @@ -15,128 +15,123 @@ """Tests for actors.""" from typing import Optional, Tuple -from acme import environment_loop -from acme import specs -from acme.agents.jax import actor_core as actor_core_lib -from acme.agents.jax import actors -from acme.jax import utils -from acme.jax import variable_utils -from acme.testing import fakes import dm_env import haiku as hk import jax import jax.numpy as jnp import numpy as np +from absl.testing import absltest, parameterized -from absl.testing import absltest -from absl.testing import parameterized +from acme import environment_loop, specs +from acme.agents.jax import actor_core as actor_core_lib +from acme.agents.jax import actors +from acme.jax import utils, variable_utils +from acme.testing import fakes def _make_fake_env() -> dm_env.Environment: - env_spec = specs.EnvironmentSpec( - observations=specs.Array(shape=(10, 5), dtype=np.float32), - actions=specs.DiscreteArray(num_values=3), - rewards=specs.Array(shape=(), dtype=np.float32), - discounts=specs.BoundedArray( - shape=(), dtype=np.float32, minimum=0., maximum=1.), - ) - return fakes.Environment(env_spec, episode_length=10) + env_spec = specs.EnvironmentSpec( + observations=specs.Array(shape=(10, 5), dtype=np.float32), + 
actions=specs.DiscreteArray(num_values=3), + rewards=specs.Array(shape=(), dtype=np.float32), + discounts=specs.BoundedArray( + shape=(), dtype=np.float32, minimum=0.0, maximum=1.0 + ), + ) + return fakes.Environment(env_spec, episode_length=10) class ActorTest(parameterized.TestCase): - - @parameterized.named_parameters( - ('policy', False), - ('policy_with_extras', True)) - def test_feedforward(self, has_extras): - environment = _make_fake_env() - env_spec = specs.make_environment_spec(environment) - - def policy(inputs: jnp.ndarray): - action_values = hk.Sequential([ - hk.Flatten(), - hk.Linear(env_spec.actions.num_values), - ])( - inputs) - action = jnp.argmax(action_values, axis=-1) - if has_extras: - return action, (action_values,) - else: - return action - - policy = hk.transform(policy) - - rng = hk.PRNGSequence(1) - dummy_obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations)) - params = policy.init(next(rng), dummy_obs) - - variable_source = fakes.VariableSource(params) - variable_client = variable_utils.VariableClient(variable_source, 'policy') - - if has_extras: - actor_core = actor_core_lib.batched_feed_forward_with_extras_to_actor_core( - policy.apply) - else: - actor_core = actor_core_lib.batched_feed_forward_to_actor_core( - policy.apply) - actor = actors.GenericActor( - actor_core, - random_key=jax.random.PRNGKey(1), - variable_client=variable_client) - - loop = environment_loop.EnvironmentLoop(environment, actor) - loop.run(20) + @parameterized.named_parameters(("policy", False), ("policy_with_extras", True)) + def test_feedforward(self, has_extras): + environment = _make_fake_env() + env_spec = specs.make_environment_spec(environment) + + def policy(inputs: jnp.ndarray): + action_values = hk.Sequential( + [hk.Flatten(), hk.Linear(env_spec.actions.num_values),] + )(inputs) + action = jnp.argmax(action_values, axis=-1) + if has_extras: + return action, (action_values,) + else: + return action + + policy = hk.transform(policy) + + rng = hk.PRNGSequence(1) + dummy_obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations)) + params = policy.init(next(rng), dummy_obs) + + variable_source = fakes.VariableSource(params) + variable_client = variable_utils.VariableClient(variable_source, "policy") + + if has_extras: + actor_core = actor_core_lib.batched_feed_forward_with_extras_to_actor_core( + policy.apply + ) + else: + actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy.apply) + actor = actors.GenericActor( + actor_core, + random_key=jax.random.PRNGKey(1), + variable_client=variable_client, + ) + + loop = environment_loop.EnvironmentLoop(environment, actor) + loop.run(20) def _transform_without_rng(f): - return hk.without_apply_rng(hk.transform(f)) + return hk.without_apply_rng(hk.transform(f)) class RecurrentActorTest(absltest.TestCase): - - def test_recurrent(self): - environment = _make_fake_env() - env_spec = specs.make_environment_spec(environment) - output_size = env_spec.actions.num_values - obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations)) - rng = hk.PRNGSequence(1) - - @_transform_without_rng - def network(inputs: jnp.ndarray, state: hk.LSTMState): - return hk.DeepRNN([hk.Reshape([-1], preserve_dims=1), - hk.LSTM(output_size)])(inputs, state) - - @_transform_without_rng - def initial_state(batch_size: Optional[int] = None): - network = hk.DeepRNN([hk.Reshape([-1], preserve_dims=1), - hk.LSTM(output_size)]) - return network.initial_state(batch_size) - - initial_state = 
initial_state.apply(initial_state.init(next(rng)), 1) - params = network.init(next(rng), obs, initial_state) - - def policy( - params: jnp.ndarray, - key: jnp.ndarray, - observation: jnp.ndarray, - core_state: hk.LSTMState - ) -> Tuple[jnp.ndarray, hk.LSTMState]: - del key # Unused for test-case deterministic policy. - action_values, core_state = network.apply(params, observation, core_state) - actions = jnp.argmax(action_values, axis=-1) - return actions, core_state - - variable_source = fakes.VariableSource(params) - variable_client = variable_utils.VariableClient(variable_source, 'policy') - - actor_core = actor_core_lib.batched_recurrent_to_actor_core( - policy, initial_state) - actor = actors.GenericActor(actor_core, jax.random.PRNGKey(1), - variable_client) - - loop = environment_loop.EnvironmentLoop(environment, actor) - loop.run(20) - - -if __name__ == '__main__': - absltest.main() + def test_recurrent(self): + environment = _make_fake_env() + env_spec = specs.make_environment_spec(environment) + output_size = env_spec.actions.num_values + obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations)) + rng = hk.PRNGSequence(1) + + @_transform_without_rng + def network(inputs: jnp.ndarray, state: hk.LSTMState): + return hk.DeepRNN( + [hk.Reshape([-1], preserve_dims=1), hk.LSTM(output_size)] + )(inputs, state) + + @_transform_without_rng + def initial_state(batch_size: Optional[int] = None): + network = hk.DeepRNN( + [hk.Reshape([-1], preserve_dims=1), hk.LSTM(output_size)] + ) + return network.initial_state(batch_size) + + initial_state = initial_state.apply(initial_state.init(next(rng)), 1) + params = network.init(next(rng), obs, initial_state) + + def policy( + params: jnp.ndarray, + key: jnp.ndarray, + observation: jnp.ndarray, + core_state: hk.LSTMState, + ) -> Tuple[jnp.ndarray, hk.LSTMState]: + del key # Unused for test-case deterministic policy. 
+ action_values, core_state = network.apply(params, observation, core_state) + actions = jnp.argmax(action_values, axis=-1) + return actions, core_state + + variable_source = fakes.VariableSource(params) + variable_client = variable_utils.VariableClient(variable_source, "policy") + + actor_core = actor_core_lib.batched_recurrent_to_actor_core( + policy, initial_state + ) + actor = actors.GenericActor(actor_core, jax.random.PRNGKey(1), variable_client) + + loop = environment_loop.EnvironmentLoop(environment, actor) + loop.run(20) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/jax/ail/__init__.py b/acme/agents/jax/ail/__init__.py index df302c494e..40a933b8fb 100644 --- a/acme/agents/jax/ail/__init__.py +++ b/acme/agents/jax/ail/__init__.py @@ -14,18 +14,17 @@ """Implementations of a AIL agent.""" -from acme.agents.jax.ail import losses -from acme.agents.jax.ail import rewards +from acme.agents.jax.ail import losses, rewards from acme.agents.jax.ail.builder import AILBuilder from acme.agents.jax.ail.config import AILConfig -from acme.agents.jax.ail.dac import DACBuilder -from acme.agents.jax.ail.dac import DACConfig -from acme.agents.jax.ail.gail import GAILBuilder -from acme.agents.jax.ail.gail import GAILConfig +from acme.agents.jax.ail.dac import DACBuilder, DACConfig +from acme.agents.jax.ail.gail import GAILBuilder, GAILConfig from acme.agents.jax.ail.learning import AILLearner -from acme.agents.jax.ail.networks import AILNetworks -from acme.agents.jax.ail.networks import AIRLModule -from acme.agents.jax.ail.networks import compute_ail_reward -from acme.agents.jax.ail.networks import DiscriminatorMLP -from acme.agents.jax.ail.networks import DiscriminatorModule -from acme.agents.jax.ail.networks import make_discriminator +from acme.agents.jax.ail.networks import ( + AILNetworks, + AIRLModule, + DiscriminatorMLP, + DiscriminatorModule, + compute_ail_reward, + make_discriminator, +) diff --git a/acme/agents/jax/ail/builder.py b/acme/agents/jax/ail/builder.py index b71350c91a..9e0477deef 100644 --- a/acme/agents/jax/ail/builder.py +++ b/acme/agents/jax/ail/builder.py @@ -18,35 +18,30 @@ import itertools from typing import Callable, Generic, Iterator, List, Optional, Tuple -from acme import adders -from acme import core -from acme import specs -from acme import types +import jax +import numpy as np +import optax +import reverb +import tree +from reverb import rate_limiters + +from acme import adders, core, specs, types from acme.adders import reverb as adders_reverb from acme.agents.jax import builders from acme.agents.jax.ail import config as ail_config -from acme.agents.jax.ail import learning -from acme.agents.jax.ail import losses +from acme.agents.jax.ail import learning, losses from acme.agents.jax.ail import networks as ail_networks from acme.datasets import reverb as datasets from acme.jax import types as jax_types from acme.jax import utils from acme.jax.imitation_learning_types import DirectPolicyNetwork -from acme.utils import counting -from acme.utils import loggers -from acme.utils import reverb_utils -import jax -import numpy as np -import optax -import reverb -from reverb import rate_limiters -import tree +from acme.utils import counting, loggers, reverb_utils def _split_transitions( - transitions: types.Transition, - index: int) -> Tuple[types.Transition, types.Transition]: - """Splits the given transition on the first axis at the given index. 
+ transitions: types.Transition, index: int +) -> Tuple[types.Transition, types.Transition]: + """Splits the given transition on the first axis at the given index. Args: transitions: Transitions to split. @@ -56,13 +51,16 @@ def _split_transitions( A pair of transitions, the first containing elements before the index (exclusive) and the second after the index (inclusive) """ - return (tree.map_structure(lambda x: x[:index], transitions), - tree.map_structure(lambda x: x[index:], transitions)) + return ( + tree.map_structure(lambda x: x[:index], transitions), + tree.map_structure(lambda x: x[index:], transitions), + ) -def _rebatch(iterator: Iterator[types.Transition], - batch_size: int) -> Iterator[types.Transition]: - """Rebatch the itererator with the given batch size. +def _rebatch( + iterator: Iterator[types.Transition], batch_size: int +) -> Iterator[types.Transition]: + """Rebatch the itererator with the given batch size. Args: iterator: Iterator to rebatch. @@ -71,23 +69,22 @@ def _rebatch(iterator: Iterator[types.Transition], Yields: Transitions with the new batch size. """ - data = next(iterator) - while True: - while len(data.reward) < batch_size: - # Ensure we can get enough demonstrations. - next_data = next(iterator) - data = tree.map_structure(lambda *args: np.concatenate(list(args)), data, - next_data) - output, data = _split_transitions(data, batch_size) - yield output + data = next(iterator) + while True: + while len(data.reward) < batch_size: + # Ensure we can get enough demonstrations. + next_data = next(iterator) + data = tree.map_structure( + lambda *args: np.concatenate(list(args)), data, next_data + ) + output, data = _split_transitions(data, batch_size) + yield output def _mix_arrays( - replay: np.ndarray, - demo: np.ndarray, - index: int, - seed: int) -> np.ndarray: - """Mixes `replay` and `demo`. + replay: np.ndarray, demo: np.ndarray, index: int, seed: int +) -> np.ndarray: + """Mixes `replay` and `demo`. Args: replay: Replay data to mix. Only index element will be selected. @@ -98,18 +95,19 @@ def _mix_arrays( Returns: An array with replay elements up to 'index' and all the demos. """ - # We're throwing away some replay data here. We have to if we want to make - # sure the output info field is correct. - output = np.concatenate((replay[:index], demo)) - return np.random.default_rng(seed=seed).permutation(output) + # We're throwing away some replay data here. We have to if we want to make + # sure the output info field is correct. + output = np.concatenate((replay[:index], demo)) + return np.random.default_rng(seed=seed).permutation(output) def _generate_samples_with_demonstrations( demonstration_iterator: Iterator[types.Transition], replay_iterator: Iterator[reverb.ReplaySample], policy_to_expert_data_ratio: int, - batch_size) -> Iterator[reverb.ReplaySample]: - """Generator which creates the sample having demonstrations in them. + batch_size, +) -> Iterator[reverb.ReplaySample]: + """Generator which creates the sample having demonstrations in them. It takes the demonstrations and replay iterators and generates batches with same size as the replay iterator, such that each batches have the ratio of @@ -129,39 +127,44 @@ def _generate_samples_with_demonstrations( the current replay sample info and the batch size will be the same as the replay_iterator data batch size. 
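`_rebatch` above keeps concatenating demonstration batches until it can emit exactly `batch_size` rows, carrying the remainder over to the next yield. A toy sketch of the same idea, using a plain dict in place of `types.Transition` (the field name and batch sizes are illustrative):

```python
import numpy as np
import tree  # dm-tree


def split(batch, index):
    # Same idea as _split_transitions: cut every leaf at `index` along axis 0.
    return (
        tree.map_structure(lambda x: x[:index], batch),
        tree.map_structure(lambda x: x[index:], batch),
    )


def rebatch(iterator, batch_size):
    # Same idea as _rebatch: accumulate until batch_size rows are available.
    data = next(iterator)
    while True:
        while len(data["reward"]) < batch_size:
            data = tree.map_structure(
                lambda *xs: np.concatenate(xs), data, next(iterator)
            )
        output, data = split(data, batch_size)
        yield output


def demonstrations():
    while True:
        yield {"reward": np.arange(3, dtype=np.float32)}  # Batches of size 3.


it = rebatch(demonstrations(), batch_size=4)
print(next(it)["reward"])  # [0. 1. 2. 0.]
```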
""" - count = 0 - if batch_size % (policy_to_expert_data_ratio + 1) != 0: - raise ValueError( - 'policy_to_expert_data_ratio + 1 must divide the batch size but ' - f'{batch_size} % {policy_to_expert_data_ratio+1} !=0') - demo_insertion_size = batch_size // (policy_to_expert_data_ratio + 1) - policy_insertion_size = batch_size - demo_insertion_size - - demonstration_iterator = _rebatch(demonstration_iterator, demo_insertion_size) - for sample, demos in zip(replay_iterator, demonstration_iterator): - output_transitions = tree.map_structure( - functools.partial(_mix_arrays, - index=policy_insertion_size, - seed=count), - sample.data, demos) - count += 1 - yield reverb.ReplaySample(info=sample.info, data=output_transitions) - - -class AILBuilder(builders.ActorLearnerBuilder[ail_networks.AILNetworks, - DirectPolicyNetwork, - learning.AILSample], - Generic[ail_networks.DirectRLNetworks, DirectPolicyNetwork]): - """AIL Builder.""" - - def __init__( - self, - rl_agent: builders.ActorLearnerBuilder[ail_networks.DirectRLNetworks, - DirectPolicyNetwork, - reverb.ReplaySample], - config: ail_config.AILConfig, discriminator_loss: losses.Loss, - make_demonstrations: Callable[[int], Iterator[types.Transition]]): - """Implements a builder for AIL using rl_agent as forward RL algorithm. + count = 0 + if batch_size % (policy_to_expert_data_ratio + 1) != 0: + raise ValueError( + "policy_to_expert_data_ratio + 1 must divide the batch size but " + f"{batch_size} % {policy_to_expert_data_ratio+1} !=0" + ) + demo_insertion_size = batch_size // (policy_to_expert_data_ratio + 1) + policy_insertion_size = batch_size - demo_insertion_size + + demonstration_iterator = _rebatch(demonstration_iterator, demo_insertion_size) + for sample, demos in zip(replay_iterator, demonstration_iterator): + output_transitions = tree.map_structure( + functools.partial(_mix_arrays, index=policy_insertion_size, seed=count), + sample.data, + demos, + ) + count += 1 + yield reverb.ReplaySample(info=sample.info, data=output_transitions) + + +class AILBuilder( + builders.ActorLearnerBuilder[ + ail_networks.AILNetworks, DirectPolicyNetwork, learning.AILSample + ], + Generic[ail_networks.DirectRLNetworks, DirectPolicyNetwork], +): + """AIL Builder.""" + + def __init__( + self, + rl_agent: builders.ActorLearnerBuilder[ + ail_networks.DirectRLNetworks, DirectPolicyNetwork, reverb.ReplaySample + ], + config: ail_config.AILConfig, + discriminator_loss: losses.Loss, + make_demonstrations: Callable[[int], Iterator[types.Transition]], + ): + """Implements a builder for AIL using rl_agent as forward RL algorithm. Args: rl_agent: The standard RL agent used by AIL to optimize the generator. @@ -170,162 +173,191 @@ def __init__( make_demonstrations: A function that returns an iterator with demonstrations to be imitated. 
""" - self._rl_agent = rl_agent - self._config = config - self._discriminator_loss = discriminator_loss - self._make_demonstrations = make_demonstrations - - def make_learner(self, - random_key: jax_types.PRNGKey, - networks: ail_networks.AILNetworks, - dataset: Iterator[learning.AILSample], - logger_fn: loggers.LoggerFactory, - environment_spec: specs.EnvironmentSpec, - replay_client: Optional[reverb.Client] = None, - counter: Optional[counting.Counter] = None) -> core.Learner: - counter = counter or counting.Counter() - direct_rl_counter = counting.Counter(counter, 'direct_rl') - batch_size_per_learner_step = ail_config.get_per_learner_step_batch_size( - self._config) - - direct_rl_learner_key, discriminator_key = jax.random.split(random_key) - - direct_rl_learner = functools.partial( - self._rl_agent.make_learner, - direct_rl_learner_key, - networks.direct_rl_networks, - logger_fn=logger_fn, - environment_spec=environment_spec, - replay_client=replay_client, - counter=direct_rl_counter) - - discriminator_optimizer = ( - self._config.discriminator_optimizer or optax.adam(1e-5)) - - return learning.AILLearner( - counter, - direct_rl_learner_factory=direct_rl_learner, - loss_fn=self._discriminator_loss, - iterator=dataset, - discriminator_optimizer=discriminator_optimizer, - ail_network=networks, - discriminator_key=discriminator_key, - is_sequence_based=self._config.is_sequence_based, - num_sgd_steps_per_step=batch_size_per_learner_step // - self._config.discriminator_batch_size, - policy_variable_name=self._config.policy_variable_name, - logger=logger_fn('learner', steps_key=counter.get_steps_key())) - - def make_replay_tables( - self, - environment_spec: specs.EnvironmentSpec, - policy: DirectPolicyNetwork, - ) -> List[reverb.Table]: - replay_tables = self._rl_agent.make_replay_tables(environment_spec, policy) - if self._config.share_iterator: - return replay_tables - replay_tables.append( - reverb.Table( - name=self._config.replay_table_name, - sampler=reverb.selectors.Uniform(), - remover=reverb.selectors.Fifo(), - max_size=self._config.max_replay_size, - rate_limiter=rate_limiters.MinSize(self._config.min_replay_size), - signature=adders_reverb.NStepTransitionAdder.signature( - environment_spec))) - return replay_tables - - # This function does not expose all the iterators used by the learner when - # share_iterator is False, making further wrapping impossible. - # TODO(eorsini): Expose all iterators. - # Currently GAIL uses 3 iterators, instead we can make it use a single - # iterator and return this one here. The way to achieve this would be: - # * Create the 3 iterators here. - # * zip them and return them here. - # * upzip them in the learner (this step will not be necessary once we move to - # stateless learners) - # This should work fine as the 3 iterators are always iterated in parallel - # (i.e. at every step we call next once on each of them). - def make_dataset_iterator( - self, replay_client: reverb.Client) -> Iterator[learning.AILSample]: - batch_size_per_learner_step = ail_config.get_per_learner_step_batch_size( - self._config) - - iterator_demonstration = self._make_demonstrations( - batch_size_per_learner_step) - - direct_iterator = self._rl_agent.make_dataset_iterator(replay_client) - - if self._config.share_iterator: - # In order to reuse the iterator return values and not lose a 2x factor on - # sample efficiency, we need to use itertools.tee(). 
- discriminator_iterator, direct_iterator = itertools.tee(direct_iterator) - else: - discriminator_iterator = datasets.make_reverb_dataset( - table=self._config.replay_table_name, - server_address=replay_client.server_address, - batch_size=ail_config.get_per_learner_step_batch_size(self._config), - prefetch_size=self._config.prefetch_size).as_numpy_iterator() - - if self._config.policy_to_expert_data_ratio is not None: - iterator_demonstration, iterator_demonstration2 = itertools.tee( - iterator_demonstration) - direct_iterator = _generate_samples_with_demonstrations( - iterator_demonstration2, direct_iterator, - self._config.policy_to_expert_data_ratio, - self._config.direct_rl_batch_size) - - is_sequence_based = self._config.is_sequence_based - - # Don't flatten the discriminator batch if the iterator is not shared. - process_discriminator_sample = functools.partial( - reverb_utils.replay_sample_to_sars_transition, - is_sequence=is_sequence_based and self._config.share_iterator, - flatten_batch=is_sequence_based and self._config.share_iterator, - strip_last_transition=is_sequence_based and self._config.share_iterator) - - discriminator_iterator = ( - # Remove the extras to have the same nested structure as demonstrations. - process_discriminator_sample(sample)._replace(extras=()) - for sample in discriminator_iterator) - - return utils.device_put((learning.AILSample(*sample) for sample in zip( - discriminator_iterator, direct_iterator, iterator_demonstration)), - jax.devices()[0]) - - def make_adder( - self, replay_client: reverb.Client, - environment_spec: Optional[specs.EnvironmentSpec], - policy: Optional[DirectPolicyNetwork]) -> Optional[adders.Adder]: - direct_rl_adder = self._rl_agent.make_adder(replay_client, environment_spec, - policy) - if self._config.share_iterator: - return direct_rl_adder - ail_adder = adders_reverb.NStepTransitionAdder( - priority_fns={self._config.replay_table_name: None}, - client=replay_client, - n_step=1, - discount=self._config.discount) - - # Some direct rl algorithms (such as PPO), might be passing extra data - # which we won't be able to process here properly, so we need to ignore them - return adders.ForkingAdder( - [adders.IgnoreExtrasAdder(ail_adder), direct_rl_adder]) - - def make_actor( - self, - random_key: jax_types.PRNGKey, - policy: DirectPolicyNetwork, - environment_spec: specs.EnvironmentSpec, - variable_source: Optional[core.VariableSource] = None, - adder: Optional[adders.Adder] = None, - ) -> core.Actor: - return self._rl_agent.make_actor(random_key, policy, environment_spec, - variable_source, adder) - - def make_policy(self, - networks: ail_networks.AILNetworks, - environment_spec: specs.EnvironmentSpec, - evaluation: bool = False) -> DirectPolicyNetwork: - return self._rl_agent.make_policy(networks.direct_rl_networks, - environment_spec, evaluation) + self._rl_agent = rl_agent + self._config = config + self._discriminator_loss = discriminator_loss + self._make_demonstrations = make_demonstrations + + def make_learner( + self, + random_key: jax_types.PRNGKey, + networks: ail_networks.AILNetworks, + dataset: Iterator[learning.AILSample], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + counter = counter or counting.Counter() + direct_rl_counter = counting.Counter(counter, "direct_rl") + batch_size_per_learner_step = ail_config.get_per_learner_step_batch_size( + self._config + ) + + 
direct_rl_learner_key, discriminator_key = jax.random.split(random_key) + + direct_rl_learner = functools.partial( + self._rl_agent.make_learner, + direct_rl_learner_key, + networks.direct_rl_networks, + logger_fn=logger_fn, + environment_spec=environment_spec, + replay_client=replay_client, + counter=direct_rl_counter, + ) + + discriminator_optimizer = self._config.discriminator_optimizer or optax.adam( + 1e-5 + ) + + return learning.AILLearner( + counter, + direct_rl_learner_factory=direct_rl_learner, + loss_fn=self._discriminator_loss, + iterator=dataset, + discriminator_optimizer=discriminator_optimizer, + ail_network=networks, + discriminator_key=discriminator_key, + is_sequence_based=self._config.is_sequence_based, + num_sgd_steps_per_step=batch_size_per_learner_step + // self._config.discriminator_batch_size, + policy_variable_name=self._config.policy_variable_name, + logger=logger_fn("learner", steps_key=counter.get_steps_key()), + ) + + def make_replay_tables( + self, environment_spec: specs.EnvironmentSpec, policy: DirectPolicyNetwork, + ) -> List[reverb.Table]: + replay_tables = self._rl_agent.make_replay_tables(environment_spec, policy) + if self._config.share_iterator: + return replay_tables + replay_tables.append( + reverb.Table( + name=self._config.replay_table_name, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._config.max_replay_size, + rate_limiter=rate_limiters.MinSize(self._config.min_replay_size), + signature=adders_reverb.NStepTransitionAdder.signature( + environment_spec + ), + ) + ) + return replay_tables + + # This function does not expose all the iterators used by the learner when + # share_iterator is False, making further wrapping impossible. + # TODO(eorsini): Expose all iterators. + # Currently GAIL uses 3 iterators, instead we can make it use a single + # iterator and return this one here. The way to achieve this would be: + # * Create the 3 iterators here. + # * zip them and return them here. + # * upzip them in the learner (this step will not be necessary once we move to + # stateless learners) + # This should work fine as the 3 iterators are always iterated in parallel + # (i.e. at every step we call next once on each of them). + def make_dataset_iterator( + self, replay_client: reverb.Client + ) -> Iterator[learning.AILSample]: + batch_size_per_learner_step = ail_config.get_per_learner_step_batch_size( + self._config + ) + + iterator_demonstration = self._make_demonstrations(batch_size_per_learner_step) + + direct_iterator = self._rl_agent.make_dataset_iterator(replay_client) + + if self._config.share_iterator: + # In order to reuse the iterator return values and not lose a 2x factor on + # sample efficiency, we need to use itertools.tee(). 
+ discriminator_iterator, direct_iterator = itertools.tee(direct_iterator) + else: + discriminator_iterator = datasets.make_reverb_dataset( + table=self._config.replay_table_name, + server_address=replay_client.server_address, + batch_size=ail_config.get_per_learner_step_batch_size(self._config), + prefetch_size=self._config.prefetch_size, + ).as_numpy_iterator() + + if self._config.policy_to_expert_data_ratio is not None: + iterator_demonstration, iterator_demonstration2 = itertools.tee( + iterator_demonstration + ) + direct_iterator = _generate_samples_with_demonstrations( + iterator_demonstration2, + direct_iterator, + self._config.policy_to_expert_data_ratio, + self._config.direct_rl_batch_size, + ) + + is_sequence_based = self._config.is_sequence_based + + # Don't flatten the discriminator batch if the iterator is not shared. + process_discriminator_sample = functools.partial( + reverb_utils.replay_sample_to_sars_transition, + is_sequence=is_sequence_based and self._config.share_iterator, + flatten_batch=is_sequence_based and self._config.share_iterator, + strip_last_transition=is_sequence_based and self._config.share_iterator, + ) + + discriminator_iterator = ( + # Remove the extras to have the same nested structure as demonstrations. + process_discriminator_sample(sample)._replace(extras=()) + for sample in discriminator_iterator + ) + + return utils.device_put( + ( + learning.AILSample(*sample) + for sample in zip( + discriminator_iterator, direct_iterator, iterator_demonstration + ) + ), + jax.devices()[0], + ) + + def make_adder( + self, + replay_client: reverb.Client, + environment_spec: Optional[specs.EnvironmentSpec], + policy: Optional[DirectPolicyNetwork], + ) -> Optional[adders.Adder]: + direct_rl_adder = self._rl_agent.make_adder( + replay_client, environment_spec, policy + ) + if self._config.share_iterator: + return direct_rl_adder + ail_adder = adders_reverb.NStepTransitionAdder( + priority_fns={self._config.replay_table_name: None}, + client=replay_client, + n_step=1, + discount=self._config.discount, + ) + + # Some direct rl algorithms (such as PPO), might be passing extra data + # which we won't be able to process here properly, so we need to ignore them + return adders.ForkingAdder( + [adders.IgnoreExtrasAdder(ail_adder), direct_rl_adder] + ) + + def make_actor( + self, + random_key: jax_types.PRNGKey, + policy: DirectPolicyNetwork, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + adder: Optional[adders.Adder] = None, + ) -> core.Actor: + return self._rl_agent.make_actor( + random_key, policy, environment_spec, variable_source, adder + ) + + def make_policy( + self, + networks: ail_networks.AILNetworks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool = False, + ) -> DirectPolicyNetwork: + return self._rl_agent.make_policy( + networks.direct_rl_networks, environment_spec, evaluation + ) diff --git a/acme/agents/jax/ail/builder_test.py b/acme/agents/jax/ail/builder_test.py index 800dbc3f36..df9684080e 100644 --- a/acme/agents/jax/ail/builder_test.py +++ b/acme/agents/jax/ail/builder_test.py @@ -13,44 +13,52 @@ # limitations under the License. 
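The iterator sharing above hinges entirely on itertools.tee; a minimal, Reverb-free sketch of why both consumers still see every sample exactly once:

import itertools

source = iter(range(5))
discriminator_it, direct_rl_it = itertools.tee(source)

# tee buffers the underlying samples, so the two views replay the same data
# rather than each draining `source` independently.
assert list(zip(discriminator_it, direct_rl_it)) == [(i, i) for i in range(5)]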
"""Tests for the builder generator.""" -from acme import types -from acme.agents.jax.ail import builder import numpy as np import reverb - from absl.testing import absltest +from acme import types +from acme.agents.jax.ail import builder + _REWARD = np.zeros((3,)) class BuilderTest(absltest.TestCase): + def test_weighted_generator(self): + data0 = types.Transition(np.array([[1], [2], [3]]), (), _REWARD, (), ()) + it0 = iter([data0]) + + data1 = types.Transition(np.array([[4], [5], [6]]), (), _REWARD, (), ()) + data2 = types.Transition(np.array([[7], [8], [9]]), (), _REWARD, (), ()) + it1 = iter( + [ + reverb.ReplaySample( + info=reverb.SampleInfo( + *[() for _ in reverb.SampleInfo.tf_dtypes()] + ), + data=data1, + ), + reverb.ReplaySample( + info=reverb.SampleInfo( + *[() for _ in reverb.SampleInfo.tf_dtypes()] + ), + data=data2, + ), + ] + ) + + weighted_it = builder._generate_samples_with_demonstrations( + it0, it1, policy_to_expert_data_ratio=2, batch_size=3 + ) + + np.testing.assert_array_equal( + next(weighted_it).data.observation, np.array([[1], [4], [5]]) + ) + np.testing.assert_array_equal( + next(weighted_it).data.observation, np.array([[7], [8], [2]]) + ) + self.assertRaises(StopIteration, lambda: next(weighted_it)) + - def test_weighted_generator(self): - data0 = types.Transition(np.array([[1], [2], [3]]), (), _REWARD, (), ()) - it0 = iter([data0]) - - data1 = types.Transition(np.array([[4], [5], [6]]), (), _REWARD, (), ()) - data2 = types.Transition(np.array([[7], [8], [9]]), (), _REWARD, (), ()) - it1 = iter([ - reverb.ReplaySample( - info=reverb.SampleInfo( - *[() for _ in reverb.SampleInfo.tf_dtypes()]), - data=data1), - reverb.ReplaySample( - info=reverb.SampleInfo( - *[() for _ in reverb.SampleInfo.tf_dtypes()]), - data=data2) - ]) - - weighted_it = builder._generate_samples_with_demonstrations( - it0, it1, policy_to_expert_data_ratio=2, batch_size=3) - - np.testing.assert_array_equal( - next(weighted_it).data.observation, np.array([[1], [4], [5]])) - np.testing.assert_array_equal( - next(weighted_it).data.observation, np.array([[7], [8], [2]])) - self.assertRaises(StopIteration, lambda: next(weighted_it)) - - -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/jax/ail/config.py b/acme/agents/jax/ail/config.py index c5541de533..07abcfc76a 100644 --- a/acme/agents/jax/ail/config.py +++ b/acme/agents/jax/ail/config.py @@ -21,7 +21,7 @@ @dataclasses.dataclass class AILConfig: - """Configuration options for AIL. + """Configuration options for AIL. Attributes: direct_rl_batch_size: Batch size of a direct rl algorithm (measured in @@ -47,32 +47,33 @@ class AILConfig: expert transitions in the given proportions. policy_to_expert_data_ratio + 1 must divide the direct RL batch size. 
""" - direct_rl_batch_size: int - is_sequence_based: bool = False - share_iterator: bool = True - num_sgd_steps_per_step: int = 1 - discriminator_batch_size: int = 256 - policy_variable_name: Optional[str] = None - discriminator_optimizer: Optional[optax.GradientTransformation] = None - replay_table_name: str = 'ail_table' - prefetch_size: int = 4 - discount: float = 0.99 - min_replay_size: int = 1000 - max_replay_size: int = int(1e6) - policy_to_expert_data_ratio: Optional[int] = None - def __post_init__(self): - assert self.direct_rl_batch_size % self.discriminator_batch_size == 0 + direct_rl_batch_size: int + is_sequence_based: bool = False + share_iterator: bool = True + num_sgd_steps_per_step: int = 1 + discriminator_batch_size: int = 256 + policy_variable_name: Optional[str] = None + discriminator_optimizer: Optional[optax.GradientTransformation] = None + replay_table_name: str = "ail_table" + prefetch_size: int = 4 + discount: float = 0.99 + min_replay_size: int = 1000 + max_replay_size: int = int(1e6) + policy_to_expert_data_ratio: Optional[int] = None + + def __post_init__(self): + assert self.direct_rl_batch_size % self.discriminator_batch_size == 0 def get_per_learner_step_batch_size(config: AILConfig) -> int: - """Returns how many transitions should be sampled per direct learner step.""" - # If the iterators are tied, the discriminator learning batch size has to - # match the direct RL one. - if config.share_iterator: - assert (config.direct_rl_batch_size % config.discriminator_batch_size) == 0 - return config.direct_rl_batch_size - # Otherwise each iteration of the discriminator will sample a batch which will - # be split in num_sgd_steps_per_step batches, each of size - # discriminator_batch_size. - return config.discriminator_batch_size * config.num_sgd_steps_per_step + """Returns how many transitions should be sampled per direct learner step.""" + # If the iterators are tied, the discriminator learning batch size has to + # match the direct RL one. + if config.share_iterator: + assert (config.direct_rl_batch_size % config.discriminator_batch_size) == 0 + return config.direct_rl_batch_size + # Otherwise each iteration of the discriminator will sample a batch which will + # be split in num_sgd_steps_per_step batches, each of size + # discriminator_batch_size. + return config.discriminator_batch_size * config.num_sgd_steps_per_step diff --git a/acme/agents/jax/ail/dac.py b/acme/agents/jax/ail/dac.py index d048f6e166..b9b9b9a6ce 100644 --- a/acme/agents/jax/ail/dac.py +++ b/acme/agents/jax/ail/dac.py @@ -30,7 +30,7 @@ @dataclasses.dataclass class DACConfig: - """Configuration options specific to DAC. + """Configuration options specific to DAC. Attributes: ail_config: AIL config. @@ -39,27 +39,31 @@ class DACConfig: gradient_penalty_coefficient: Coefficient for the gradient penalty term in the discriminator loss. """ - ail_config: ail_config.AILConfig - td3_config: td3.TD3Config - entropy_coefficient: float = 1e-3 - gradient_penalty_coefficient: float = 10. 
+ ail_config: ail_config.AILConfig + td3_config: td3.TD3Config + entropy_coefficient: float = 1e-3 + gradient_penalty_coefficient: float = 10.0 -class DACBuilder(builder.AILBuilder[td3.TD3Networks, - actor_core_lib.FeedForwardPolicy]): - """DAC Builder.""" - def __init__(self, config: DACConfig, - make_demonstrations: Callable[[int], - Iterator[types.Transition]]): +class DACBuilder(builder.AILBuilder[td3.TD3Networks, actor_core_lib.FeedForwardPolicy]): + """DAC Builder.""" - td3_builder = td3.TD3Builder(config.td3_config) - dac_loss = losses.add_gradient_penalty( - losses.gail_loss(entropy_coefficient=config.entropy_coefficient), - gradient_penalty_coefficient=config.gradient_penalty_coefficient, - gradient_penalty_target=1.) - super().__init__( - td3_builder, - config=config.ail_config, - discriminator_loss=dac_loss, - make_demonstrations=make_demonstrations) + def __init__( + self, + config: DACConfig, + make_demonstrations: Callable[[int], Iterator[types.Transition]], + ): + + td3_builder = td3.TD3Builder(config.td3_config) + dac_loss = losses.add_gradient_penalty( + losses.gail_loss(entropy_coefficient=config.entropy_coefficient), + gradient_penalty_coefficient=config.gradient_penalty_coefficient, + gradient_penalty_target=1.0, + ) + super().__init__( + td3_builder, + config=config.ail_config, + discriminator_loss=dac_loss, + make_demonstrations=make_demonstrations, + ) diff --git a/acme/agents/jax/ail/gail.py b/acme/agents/jax/ail/gail.py index c5ba29042e..df05557093 100644 --- a/acme/agents/jax/ail/gail.py +++ b/acme/agents/jax/ail/gail.py @@ -30,23 +30,27 @@ @dataclasses.dataclass class GAILConfig: - """Configuration options specific to GAIL.""" - ail_config: ail_config.AILConfig - ppo_config: ppo.PPOConfig - - -class GAILBuilder(builder.AILBuilder[ppo.PPONetworks, - actor_core_lib.FeedForwardPolicyWithExtra] - ): - """GAIL Builder.""" - - def __init__(self, config: GAILConfig, - make_demonstrations: Callable[[int], - Iterator[types.Transition]]): - - ppo_builder = ppo.PPOBuilder(config.ppo_config) - super().__init__( - ppo_builder, - config=config.ail_config, - discriminator_loss=losses.gail_loss(), - make_demonstrations=make_demonstrations) + """Configuration options specific to GAIL.""" + + ail_config: ail_config.AILConfig + ppo_config: ppo.PPOConfig + + +class GAILBuilder( + builder.AILBuilder[ppo.PPONetworks, actor_core_lib.FeedForwardPolicyWithExtra] +): + """GAIL Builder.""" + + def __init__( + self, + config: GAILConfig, + make_demonstrations: Callable[[int], Iterator[types.Transition]], + ): + + ppo_builder = ppo.PPOBuilder(config.ppo_config) + super().__init__( + ppo_builder, + config=config.ail_config, + discriminator_loss=losses.gail_loss(), + make_demonstrations=make_demonstrations, + ) diff --git a/acme/agents/jax/ail/learning.py b/acme/agents/jax/ail/learning.py index 64f625ff3b..7464d9bf4a 100644 --- a/acme/agents/jax/ail/learning.py +++ b/acme/agents/jax/ail/learning.py @@ -18,54 +18,56 @@ import time from typing import Any, Callable, Iterator, List, NamedTuple, Optional, Tuple +import jax +import optax +import reverb + import acme from acme import types from acme.agents.jax.ail import losses from acme.agents.jax.ail import networks as ail_networks from acme.jax import networks as networks_lib from acme.jax import utils -from acme.utils import counting -from acme.utils import loggers -from acme.utils import reverb_utils -import jax -import optax -import reverb +from acme.utils import counting, loggers, reverb_utils class DiscriminatorTrainingState(NamedTuple): - 
"""Contains training state for the discriminator.""" - # State of the optimizer used to optimize the discriminator parameters. - optimizer_state: optax.OptState + """Contains training state for the discriminator.""" - # Parameters of the discriminator. - discriminator_params: networks_lib.Params + # State of the optimizer used to optimize the discriminator parameters. + optimizer_state: optax.OptState - # State of the discriminator - discriminator_state: losses.State + # Parameters of the discriminator. + discriminator_params: networks_lib.Params - # For AIRL variants, we need the policy params to compute the loss. - policy_params: Optional[networks_lib.Params] + # State of the discriminator + discriminator_state: losses.State - # Key for random number generation. - key: networks_lib.PRNGKey + # For AIRL variants, we need the policy params to compute the loss. + policy_params: Optional[networks_lib.Params] - # Training step of the discriminator. - steps: int + # Key for random number generation. + key: networks_lib.PRNGKey + + # Training step of the discriminator. + steps: int class TrainingState(NamedTuple): - """Contains training state of the AIL learner.""" - rewarder_state: DiscriminatorTrainingState - learner_state: Any + """Contains training state of the AIL learner.""" + + rewarder_state: DiscriminatorTrainingState + learner_state: Any def ail_update_step( - state: DiscriminatorTrainingState, data: Tuple[types.Transition, - types.Transition], + state: DiscriminatorTrainingState, + data: Tuple[types.Transition, types.Transition], optimizer: optax.GradientTransformation, ail_network: ail_networks.AILNetworks, - loss_fn: losses.Loss) -> Tuple[DiscriminatorTrainingState, losses.Metrics]: - """Run an update steps on the given transitions. + loss_fn: losses.Loss, +) -> Tuple[DiscriminatorTrainingState, losses.Metrics]: + """Run an update steps on the given transitions. Args: state: The learner state. @@ -77,65 +79,71 @@ def ail_update_step( Returns: A new state and metrics. """ - demo_transitions, rb_transitions = data - key, discriminator_key, loss_key = jax.random.split(state.key, 3) - - def compute_loss( - discriminator_params: networks_lib.Params) -> losses.LossOutput: - discriminator_fn = functools.partial( - ail_network.discriminator_network.apply, - discriminator_params, - state.policy_params, - is_training=True, - rng=discriminator_key) - return loss_fn(discriminator_fn, state.discriminator_state, - demo_transitions, rb_transitions, loss_key) - - loss_grad = jax.grad(compute_loss, has_aux=True) - - grads, (loss, new_discriminator_state) = loss_grad(state.discriminator_params) - - update, optimizer_state = optimizer.update( - grads, - state.optimizer_state, - params=state.discriminator_params) - discriminator_params = optax.apply_updates(state.discriminator_params, update) - - new_state = DiscriminatorTrainingState( - optimizer_state=optimizer_state, - discriminator_params=discriminator_params, - discriminator_state=new_discriminator_state, - policy_params=state.policy_params, # Not modified. 
- key=key, - steps=state.steps + 1, - ) - return new_state, loss + demo_transitions, rb_transitions = data + key, discriminator_key, loss_key = jax.random.split(state.key, 3) + + def compute_loss(discriminator_params: networks_lib.Params) -> losses.LossOutput: + discriminator_fn = functools.partial( + ail_network.discriminator_network.apply, + discriminator_params, + state.policy_params, + is_training=True, + rng=discriminator_key, + ) + return loss_fn( + discriminator_fn, + state.discriminator_state, + demo_transitions, + rb_transitions, + loss_key, + ) + + loss_grad = jax.grad(compute_loss, has_aux=True) + + grads, (loss, new_discriminator_state) = loss_grad(state.discriminator_params) + + update, optimizer_state = optimizer.update( + grads, state.optimizer_state, params=state.discriminator_params + ) + discriminator_params = optax.apply_updates(state.discriminator_params, update) + + new_state = DiscriminatorTrainingState( + optimizer_state=optimizer_state, + discriminator_params=discriminator_params, + discriminator_state=new_discriminator_state, + policy_params=state.policy_params, # Not modified. + key=key, + steps=state.steps + 1, + ) + return new_state, loss class AILSample(NamedTuple): - discriminator_sample: types.Transition - direct_sample: reverb.ReplaySample - demonstration_sample: types.Transition + discriminator_sample: types.Transition + direct_sample: reverb.ReplaySample + demonstration_sample: types.Transition class AILLearner(acme.Learner): - """AIL learner.""" - - def __init__( - self, - counter: counting.Counter, - direct_rl_learner_factory: Callable[[Iterator[reverb.ReplaySample]], - acme.Learner], - loss_fn: losses.Loss, - iterator: Iterator[AILSample], - discriminator_optimizer: optax.GradientTransformation, - ail_network: ail_networks.AILNetworks, - discriminator_key: networks_lib.PRNGKey, - is_sequence_based: bool, - num_sgd_steps_per_step: int = 1, - policy_variable_name: Optional[str] = None, - logger: Optional[loggers.Logger] = None): - """AIL Learner. + """AIL learner.""" + + def __init__( + self, + counter: counting.Counter, + direct_rl_learner_factory: Callable[ + [Iterator[reverb.ReplaySample]], acme.Learner + ], + loss_fn: losses.Loss, + iterator: Iterator[AILSample], + discriminator_optimizer: optax.GradientTransformation, + ail_network: ail_networks.AILNetworks, + discriminator_key: networks_lib.PRNGKey, + is_sequence_based: bool, + num_sgd_steps_per_step: int = 1, + policy_variable_name: Optional[str] = None, + logger: Optional[loggers.Logger] = None, + ): + """AIL Learner. Args: counter: Counter. @@ -154,70 +162,75 @@ def __init__( direct_rl policy parameters. logger: Logger. """ - self._is_sequence_based = is_sequence_based - - state_key, networks_key = jax.random.split(discriminator_key) - - # Generator expression that works the same as an iterator. - # https://pymbook.readthedocs.io/en/latest/igd.html#generator-expressions - iterator, direct_rl_iterator = itertools.tee(iterator) - direct_rl_iterator = ( - self._process_sample(sample.direct_sample) - for sample in direct_rl_iterator) - self._direct_rl_learner = direct_rl_learner_factory(direct_rl_iterator) - - self._iterator = iterator - - if policy_variable_name is not None: - - def get_policy_params(): - return self._direct_rl_learner.get_variables([policy_variable_name])[0] - - self._get_policy_params = get_policy_params - - else: - self._get_policy_params = lambda: None - - # General learner book-keeping and loggers. 
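ail_update_step above is the standard optax recipe (grad, update, apply_updates) specialized to the discriminator; a stripped-down sketch of that recipe on a toy quadratic, for readers who have not seen it:

import jax
import jax.numpy as jnp
import optax

params = jnp.array(3.0)
optimizer = optax.adam(1e-1)
opt_state = optimizer.init(params)


def compute_loss(p):
    return (p - 1.0) ** 2


for _ in range(5):
    grads = jax.grad(compute_loss)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params=params)
    params = optax.apply_updates(params, updates)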
- self._counter = counter or counting.Counter() - self._logger = logger or loggers.make_default_logger( - 'learner', - asynchronous=True, - serialize_fn=utils.fetch_devicearray, - steps_key=self._counter.get_steps_key()) - - # Use the JIT compiler. - self._update_step = functools.partial( - ail_update_step, - optimizer=discriminator_optimizer, - ail_network=ail_network, - loss_fn=loss_fn) - self._update_step = utils.process_multiple_batches(self._update_step, - num_sgd_steps_per_step) - self._update_step = jax.jit(self._update_step) - - discriminator_params, discriminator_state = ( - ail_network.discriminator_network.init(networks_key)) - self._state = DiscriminatorTrainingState( - optimizer_state=discriminator_optimizer.init(discriminator_params), - discriminator_params=discriminator_params, - discriminator_state=discriminator_state, - policy_params=self._get_policy_params(), - key=state_key, - steps=0, - ) - - # Do not record timestamps until after the first learning step is done. - # This is to avoid including the time it takes for actors to come online and - # fill the replay buffer. - self._timestamp = None - - self._get_reward = jax.jit( - functools.partial( - ail_networks.compute_ail_reward, networks=ail_network)) - - def _process_sample(self, sample: reverb.ReplaySample) -> reverb.ReplaySample: - """Updates the reward of the replay sample. + self._is_sequence_based = is_sequence_based + + state_key, networks_key = jax.random.split(discriminator_key) + + # Generator expression that works the same as an iterator. + # https://pymbook.readthedocs.io/en/latest/igd.html#generator-expressions + iterator, direct_rl_iterator = itertools.tee(iterator) + direct_rl_iterator = ( + self._process_sample(sample.direct_sample) for sample in direct_rl_iterator + ) + self._direct_rl_learner = direct_rl_learner_factory(direct_rl_iterator) + + self._iterator = iterator + + if policy_variable_name is not None: + + def get_policy_params(): + return self._direct_rl_learner.get_variables([policy_variable_name])[0] + + self._get_policy_params = get_policy_params + + else: + self._get_policy_params = lambda: None + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger( + "learner", + asynchronous=True, + serialize_fn=utils.fetch_devicearray, + steps_key=self._counter.get_steps_key(), + ) + + # Use the JIT compiler. + self._update_step = functools.partial( + ail_update_step, + optimizer=discriminator_optimizer, + ail_network=ail_network, + loss_fn=loss_fn, + ) + self._update_step = utils.process_multiple_batches( + self._update_step, num_sgd_steps_per_step + ) + self._update_step = jax.jit(self._update_step) + + ( + discriminator_params, + discriminator_state, + ) = ail_network.discriminator_network.init(networks_key) + self._state = DiscriminatorTrainingState( + optimizer_state=discriminator_optimizer.init(discriminator_params), + discriminator_params=discriminator_params, + discriminator_state=discriminator_state, + policy_params=self._get_policy_params(), + key=state_key, + steps=0, + ) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. 
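A hedged sketch of the reward relabeling that _process_sample (just below) performs for the direct RL learner; the imitation reward here is a dummy stand-in for compute_ail_reward with a trained discriminator:

import numpy as np

from acme import types


def dummy_imitation_reward(observation: np.ndarray) -> np.ndarray:
    # Hypothetical reward: log-sigmoid of a score derived from the observation.
    return -np.log1p(np.exp(-observation.sum(axis=-1)))


def relabel(transitions: types.Transition) -> types.Transition:
    # Replace the environment reward with the discriminator-based one, keeping
    # the rest of the transition untouched.
    return transitions._replace(reward=dummy_imitation_reward(transitions.observation))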
+ self._timestamp = None + + self._get_reward = jax.jit( + functools.partial(ail_networks.compute_ail_reward, networks=ail_network) + ) + + def _process_sample(self, sample: reverb.ReplaySample) -> reverb.ReplaySample: + """Updates the reward of the replay sample. Args: sample: Replay sample to update the reward to. @@ -225,69 +238,75 @@ def _process_sample(self, sample: reverb.ReplaySample) -> reverb.ReplaySample: Returns: The replay sample with an updated reward. """ - transitions = reverb_utils.replay_sample_to_sars_transition( - sample, is_sequence=self._is_sequence_based) - rewards = self._get_reward(self._state.discriminator_params, - self._state.discriminator_state, - self._state.policy_params, transitions) - - return sample._replace(data=sample.data._replace(reward=rewards)) - - def step(self): - sample = next(self._iterator) - rb_transitions = sample.discriminator_sample - demo_transitions = sample.demonstration_sample - - if demo_transitions.reward.shape != rb_transitions.reward.shape: - raise ValueError( - 'Different shapes for demo transitions and rb_transitions: ' - f'{demo_transitions.reward.shape} != {rb_transitions.reward.shape}') - - # Update the parameters of the policy before doing a gradient step. - state = self._state._replace(policy_params=self._get_policy_params()) - self._state, metrics = self._update_step(state, - (demo_transitions, rb_transitions)) - - # The order is important for AIRL. - # In AIRL, the discriminator update depends on the logpi of the direct rl - # policy. - # When updating the discriminator, we want the logpi for which the - # transitions were made with and not an updated one. - # Get data from replay (dropping extras if any). Note there is no - # extra data here because we do not insert any into Reverb. - self._direct_rl_learner.step() - - # Compute elapsed time. - timestamp = time.time() - elapsed_time = timestamp - self._timestamp if self._timestamp else 0 - self._timestamp = timestamp - - # Increment counts and record the current time. - counts = self._counter.increment(steps=1, walltime=elapsed_time) - - # Attempts to write the logs. 
- self._logger.write({**metrics, **counts}) - - def get_variables(self, names: List[str]) -> List[Any]: - rewarder_dict = {'discriminator': self._state.discriminator_params} - - learner_names = [name for name in names if name not in rewarder_dict] - learner_dict = {} - if learner_names: - learner_dict = dict( - zip(learner_names, - self._direct_rl_learner.get_variables(learner_names))) - - variables = [ - rewarder_dict.get(name, learner_dict.get(name, None)) for name in names - ] - return variables - - def save(self) -> TrainingState: - return TrainingState( - rewarder_state=self._state, - learner_state=self._direct_rl_learner.save()) - - def restore(self, state: TrainingState): - self._state = state.rewarder_state - self._direct_rl_learner.restore(state.learner_state) + transitions = reverb_utils.replay_sample_to_sars_transition( + sample, is_sequence=self._is_sequence_based + ) + rewards = self._get_reward( + self._state.discriminator_params, + self._state.discriminator_state, + self._state.policy_params, + transitions, + ) + + return sample._replace(data=sample.data._replace(reward=rewards)) + + def step(self): + sample = next(self._iterator) + rb_transitions = sample.discriminator_sample + demo_transitions = sample.demonstration_sample + + if demo_transitions.reward.shape != rb_transitions.reward.shape: + raise ValueError( + "Different shapes for demo transitions and rb_transitions: " + f"{demo_transitions.reward.shape} != {rb_transitions.reward.shape}" + ) + + # Update the parameters of the policy before doing a gradient step. + state = self._state._replace(policy_params=self._get_policy_params()) + self._state, metrics = self._update_step( + state, (demo_transitions, rb_transitions) + ) + + # The order is important for AIRL. + # In AIRL, the discriminator update depends on the logpi of the direct rl + # policy. + # When updating the discriminator, we want the logpi for which the + # transitions were made with and not an updated one. + # Get data from replay (dropping extras if any). Note there is no + # extra data here because we do not insert any into Reverb. + self._direct_rl_learner.step() + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Increment counts and record the current time. + counts = self._counter.increment(steps=1, walltime=elapsed_time) + + # Attempts to write the logs. 
+ self._logger.write({**metrics, **counts}) + + def get_variables(self, names: List[str]) -> List[Any]: + rewarder_dict = {"discriminator": self._state.discriminator_params} + + learner_names = [name for name in names if name not in rewarder_dict] + learner_dict = {} + if learner_names: + learner_dict = dict( + zip(learner_names, self._direct_rl_learner.get_variables(learner_names)) + ) + + variables = [ + rewarder_dict.get(name, learner_dict.get(name, None)) for name in names + ] + return variables + + def save(self) -> TrainingState: + return TrainingState( + rewarder_state=self._state, learner_state=self._direct_rl_learner.save() + ) + + def restore(self, state: TrainingState): + self._state = state.rewarder_state + self._direct_rl_learner.restore(state.learner_state) diff --git a/acme/agents/jax/ail/learning_test.py b/acme/agents/jax/ail/learning_test.py index 9d1eb76db5..29a0800ca5 100644 --- a/acme/agents/jax/ail/learning_test.py +++ b/acme/agents/jax/ail/learning_test.py @@ -15,84 +15,89 @@ """Tests for the AIL learner.""" import functools -from acme import specs -from acme import types -from acme.agents.jax.ail import learning as ail_learning -from acme.agents.jax.ail import losses -from acme.agents.jax.ail import networks as ail_networks -from acme.jax import networks as networks_lib -from acme.jax import utils import haiku as hk import jax import numpy as np import optax - from absl.testing import absltest +from acme import specs, types +from acme.agents.jax.ail import learning as ail_learning +from acme.agents.jax.ail import losses +from acme.agents.jax.ail import networks as ail_networks +from acme.jax import networks as networks_lib +from acme.jax import utils -def _make_discriminator(spec): - def discriminator(*args, **kwargs) -> networks_lib.Logits: - return ail_networks.DiscriminatorModule( - environment_spec=spec, - use_action=False, - use_next_obs=False, - network_core=ail_networks.DiscriminatorMLP([]))(*args, **kwargs) - discriminator_transformed = hk.without_apply_rng( - hk.transform_with_state(discriminator)) - return ail_networks.make_discriminator( - environment_spec=spec, - discriminator_transformed=discriminator_transformed) +def _make_discriminator(spec): + def discriminator(*args, **kwargs) -> networks_lib.Logits: + return ail_networks.DiscriminatorModule( + environment_spec=spec, + use_action=False, + use_next_obs=False, + network_core=ail_networks.DiscriminatorMLP([]), + )(*args, **kwargs) + + discriminator_transformed = hk.without_apply_rng( + hk.transform_with_state(discriminator) + ) + return ail_networks.make_discriminator( + environment_spec=spec, discriminator_transformed=discriminator_transformed + ) class AilLearnerTest(absltest.TestCase): + def test_step(self): + simple_spec = specs.Array(shape=(), dtype=float) - def test_step(self): - simple_spec = specs.Array(shape=(), dtype=float) + spec = specs.EnvironmentSpec(simple_spec, simple_spec, simple_spec, simple_spec) - spec = specs.EnvironmentSpec(simple_spec, simple_spec, simple_spec, - simple_spec) + discriminator = _make_discriminator(spec) + ail_network = ail_networks.AILNetworks( + discriminator, imitation_reward_fn=lambda x: x, direct_rl_networks=None + ) - discriminator = _make_discriminator(spec) - ail_network = ail_networks.AILNetworks( - discriminator, imitation_reward_fn=lambda x: x, direct_rl_networks=None) + loss = losses.gail_loss() - loss = losses.gail_loss() + optimizer = optax.adam(0.01) - optimizer = optax.adam(.01) + step = jax.jit( + functools.partial( + ail_learning.ail_update_step, 
+ optimizer=optimizer, + ail_network=ail_network, + loss_fn=loss, + ) + ) - step = jax.jit(functools.partial( - ail_learning.ail_update_step, - optimizer=optimizer, - ail_network=ail_network, - loss_fn=loss)) + zero_transition = types.Transition( + np.array([0.0]), np.array([0.0]), 0.0, 0.0, np.array([0.0]) + ) + zero_transition = utils.add_batch_dim(zero_transition) - zero_transition = types.Transition( - np.array([0.]), np.array([0.]), 0., 0., np.array([0.])) - zero_transition = utils.add_batch_dim(zero_transition) + one_transition = types.Transition( + np.array([1.0]), np.array([0.0]), 0.0, 0.0, np.array([0.0]) + ) + one_transition = utils.add_batch_dim(one_transition) - one_transition = types.Transition( - np.array([1.]), np.array([0.]), 0., 0., np.array([0.])) - one_transition = utils.add_batch_dim(one_transition) + key = jax.random.PRNGKey(0) + discriminator_params, discriminator_state = discriminator.init(key) - key = jax.random.PRNGKey(0) - discriminator_params, discriminator_state = discriminator.init(key) - - state = ail_learning.DiscriminatorTrainingState( - optimizer_state=optimizer.init(discriminator_params), - discriminator_params=discriminator_params, - discriminator_state=discriminator_state, - policy_params=None, - key=key, - steps=0, - ) + state = ail_learning.DiscriminatorTrainingState( + optimizer_state=optimizer.init(discriminator_params), + discriminator_params=discriminator_params, + discriminator_state=discriminator_state, + policy_params=None, + key=key, + steps=0, + ) - expected_loss = [1.062, 1.057, 1.052] + expected_loss = [1.062, 1.057, 1.052] - for i in range(3): - state, loss = step(state, (one_transition, zero_transition)) - self.assertAlmostEqual(loss['total_loss'], expected_loss[i], places=3) + for i in range(3): + state, loss = step(state, (one_transition, zero_transition)) + self.assertAlmostEqual(loss["total_loss"], expected_loss[i], places=3) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/jax/ail/losses.py b/acme/agents/jax/ail/losses.py index 83f88723ff..87969d3801 100644 --- a/acme/agents/jax/ail/losses.py +++ b/acme/agents/jax/ail/losses.py @@ -17,13 +17,14 @@ import functools from typing import Callable, Dict, Optional, Tuple -from acme import types -from acme.jax import networks as networks_lib import jax import jax.numpy as jnp import tensorflow_probability as tfp import tree +from acme import types +from acme.jax import networks as networks_lib + tfp = tfp.experimental.substrates.jax tfd = tfp.distributions @@ -35,202 +36,240 @@ DiscriminatorFn = Callable[[State, types.Transition], DiscriminatorOutput] Metrics = Dict[str, float] LossOutput = Tuple[float, Tuple[Metrics, State]] -Loss = Callable[[ - DiscriminatorFn, State, types.Transition, types.Transition, networks_lib - .PRNGKey -], LossOutput] +Loss = Callable[ + [DiscriminatorFn, State, types.Transition, types.Transition, networks_lib.PRNGKey], + LossOutput, +] -def _binary_cross_entropy_loss(logit: jnp.ndarray, - label: jnp.ndarray) -> jnp.ndarray: - return label * jax.nn.softplus(-logit) + (1 - label) * jax.nn.softplus(logit) +def _binary_cross_entropy_loss(logit: jnp.ndarray, label: jnp.ndarray) -> jnp.ndarray: + return label * jax.nn.softplus(-logit) + (1 - label) * jax.nn.softplus(logit) @jax.vmap -def _weighted_average(x: jnp.ndarray, y: jnp.ndarray, - lambdas: jnp.ndarray) -> jnp.ndarray: - return lambdas * x + (1. 
- lambdas) * y +def _weighted_average( + x: jnp.ndarray, y: jnp.ndarray, lambdas: jnp.ndarray +) -> jnp.ndarray: + return lambdas * x + (1.0 - lambdas) * y def _label_data( rb_transitions: types.Transition, - demonstration_transitions: types.Transition, mixup_alpha: Optional[float], - key: networks_lib.PRNGKey) -> Tuple[types.Transition, jnp.ndarray]: - """Create a tuple data, labels by concatenating the rb and dem transitions.""" - data = tree.map_structure(lambda x, y: jnp.concatenate([x, y]), - rb_transitions, demonstration_transitions) - labels = jnp.concatenate([ - jnp.zeros(rb_transitions.reward.shape), - jnp.ones(demonstration_transitions.reward.shape) - ]) - - if mixup_alpha is not None: - lambda_key, mixup_key = jax.random.split(key) - - lambdas = tfd.Beta(mixup_alpha, mixup_alpha).sample( - len(labels), seed=lambda_key) - - shuffled_data = tree.map_structure( - lambda x: jax.random.permutation(key=mixup_key, x=x), data) - shuffled_labels = jax.random.permutation(key=mixup_key, x=labels) - - data = tree.map_structure(lambda x, y: _weighted_average(x, y, lambdas), - data, shuffled_data) - labels = _weighted_average(labels, shuffled_labels, lambdas) - - return data, labels + demonstration_transitions: types.Transition, + mixup_alpha: Optional[float], + key: networks_lib.PRNGKey, +) -> Tuple[types.Transition, jnp.ndarray]: + """Create a tuple data, labels by concatenating the rb and dem transitions.""" + data = tree.map_structure( + lambda x, y: jnp.concatenate([x, y]), rb_transitions, demonstration_transitions + ) + labels = jnp.concatenate( + [ + jnp.zeros(rb_transitions.reward.shape), + jnp.ones(demonstration_transitions.reward.shape), + ] + ) + + if mixup_alpha is not None: + lambda_key, mixup_key = jax.random.split(key) + + lambdas = tfd.Beta(mixup_alpha, mixup_alpha).sample( + len(labels), seed=lambda_key + ) + + shuffled_data = tree.map_structure( + lambda x: jax.random.permutation(key=mixup_key, x=x), data + ) + shuffled_labels = jax.random.permutation(key=mixup_key, x=labels) + + data = tree.map_structure( + lambda x, y: _weighted_average(x, y, lambdas), data, shuffled_data + ) + labels = _weighted_average(labels, shuffled_labels, lambdas) + + return data, labels def _logit_bernoulli_entropy(logits: networks_lib.Logits) -> jnp.ndarray: - return (1. 
- jax.nn.sigmoid(logits)) * logits - jax.nn.log_sigmoid(logits) - - -def gail_loss(entropy_coefficient: float = 0., - mixup_alpha: Optional[float] = None) -> Loss: - """Computes the standard GAIL loss.""" - - def loss_fn( - discriminator_fn: DiscriminatorFn, - discriminator_state: State, - demo_transitions: types.Transition, rb_transitions: types.Transition, - rng_key: networks_lib.PRNGKey) -> LossOutput: - - data, labels = _label_data( - rb_transitions=rb_transitions, - demonstration_transitions=demo_transitions, - mixup_alpha=mixup_alpha, - key=rng_key) - logits, discriminator_state = discriminator_fn(discriminator_state, data) - - classification_loss = jnp.mean(_binary_cross_entropy_loss(logits, labels)) - - entropy = jnp.mean(_logit_bernoulli_entropy(logits)) - entropy_loss = -entropy_coefficient * entropy - - total_loss = classification_loss + entropy_loss - - metrics = { - 'total_loss': total_loss, - 'entropy_loss': entropy_loss, - 'classification_loss': classification_loss - } - return total_loss, (metrics, discriminator_state) - - return loss_fn - - -def pugail_loss(positive_class_prior: float, - entropy_coefficient: float, - pugail_beta: Optional[float] = None) -> Loss: - """Computes the PUGAIL loss (https://arxiv.org/pdf/1911.00459.pdf).""" - - def loss_fn( - discriminator_fn: DiscriminatorFn, - discriminator_state: State, - demo_transitions: types.Transition, rb_transitions: types.Transition, - rng_key: networks_lib.PRNGKey) -> LossOutput: - del rng_key - - demo_logits, discriminator_state = discriminator_fn(discriminator_state, - demo_transitions) - rb_logits, discriminator_state = discriminator_fn(discriminator_state, - rb_transitions) - - # Quick Maths: - # output = logit(D) = ln(D) - ln(1-D) - # -softplus(-output) = ln(D) - # softplus(output) = -ln(1-D) - - # prior * -ln(D(expert)) - positive_loss = positive_class_prior * -jax.nn.log_sigmoid(demo_logits) - # -ln(1 - D(policy)) - prior * -ln(1 - D(expert)) - negative_loss = jax.nn.softplus( - rb_logits) - positive_class_prior * jax.nn.softplus(demo_logits) - if pugail_beta is not None: - negative_loss = jnp.clip(negative_loss, a_min=-1. 
* pugail_beta) - - classification_loss = jnp.mean(positive_loss + negative_loss) - - entropy = jnp.mean( - _logit_bernoulli_entropy(jnp.concatenate([demo_logits, rb_logits]))) - entropy_loss = -entropy_coefficient * entropy - - total_loss = classification_loss + entropy_loss - - metrics = { - 'total_loss': total_loss, - 'positive_loss': jnp.mean(positive_loss), - 'negative_loss': jnp.mean(negative_loss), - 'demo_logits': jnp.mean(demo_logits), - 'rb_logits': jnp.mean(rb_logits), - 'entropy_loss': entropy_loss, - 'classification_loss': classification_loss - } - return total_loss, (metrics, discriminator_state) - - return loss_fn - - -def _make_gradient_penalty_data(rb_transitions: types.Transition, - demonstration_transitions: types.Transition, - key: networks_lib.PRNGKey) -> types.Transition: - lambdas = tfd.Uniform().sample(len(rb_transitions.reward), seed=key) - return tree.map_structure(lambda x, y: _weighted_average(x, y, lambdas), - rb_transitions, demonstration_transitions) + return (1.0 - jax.nn.sigmoid(logits)) * logits - jax.nn.log_sigmoid(logits) + + +def gail_loss( + entropy_coefficient: float = 0.0, mixup_alpha: Optional[float] = None +) -> Loss: + """Computes the standard GAIL loss.""" + + def loss_fn( + discriminator_fn: DiscriminatorFn, + discriminator_state: State, + demo_transitions: types.Transition, + rb_transitions: types.Transition, + rng_key: networks_lib.PRNGKey, + ) -> LossOutput: + + data, labels = _label_data( + rb_transitions=rb_transitions, + demonstration_transitions=demo_transitions, + mixup_alpha=mixup_alpha, + key=rng_key, + ) + logits, discriminator_state = discriminator_fn(discriminator_state, data) + + classification_loss = jnp.mean(_binary_cross_entropy_loss(logits, labels)) + + entropy = jnp.mean(_logit_bernoulli_entropy(logits)) + entropy_loss = -entropy_coefficient * entropy + + total_loss = classification_loss + entropy_loss + + metrics = { + "total_loss": total_loss, + "entropy_loss": entropy_loss, + "classification_loss": classification_loss, + } + return total_loss, (metrics, discriminator_state) + + return loss_fn + + +def pugail_loss( + positive_class_prior: float, + entropy_coefficient: float, + pugail_beta: Optional[float] = None, +) -> Loss: + """Computes the PUGAIL loss (https://arxiv.org/pdf/1911.00459.pdf).""" + + def loss_fn( + discriminator_fn: DiscriminatorFn, + discriminator_state: State, + demo_transitions: types.Transition, + rb_transitions: types.Transition, + rng_key: networks_lib.PRNGKey, + ) -> LossOutput: + del rng_key + + demo_logits, discriminator_state = discriminator_fn( + discriminator_state, demo_transitions + ) + rb_logits, discriminator_state = discriminator_fn( + discriminator_state, rb_transitions + ) + + # Quick Maths: + # output = logit(D) = ln(D) - ln(1-D) + # -softplus(-output) = ln(D) + # softplus(output) = -ln(1-D) + + # prior * -ln(D(expert)) + positive_loss = positive_class_prior * -jax.nn.log_sigmoid(demo_logits) + # -ln(1 - D(policy)) - prior * -ln(1 - D(expert)) + negative_loss = jax.nn.softplus( + rb_logits + ) - positive_class_prior * jax.nn.softplus(demo_logits) + if pugail_beta is not None: + negative_loss = jnp.clip(negative_loss, a_min=-1.0 * pugail_beta) + + classification_loss = jnp.mean(positive_loss + negative_loss) + + entropy = jnp.mean( + _logit_bernoulli_entropy(jnp.concatenate([demo_logits, rb_logits])) + ) + entropy_loss = -entropy_coefficient * entropy + + total_loss = classification_loss + entropy_loss + + metrics = { + "total_loss": total_loss, + "positive_loss": jnp.mean(positive_loss), + 
"negative_loss": jnp.mean(negative_loss), + "demo_logits": jnp.mean(demo_logits), + "rb_logits": jnp.mean(rb_logits), + "entropy_loss": entropy_loss, + "classification_loss": classification_loss, + } + return total_loss, (metrics, discriminator_state) + + return loss_fn + + +def _make_gradient_penalty_data( + rb_transitions: types.Transition, + demonstration_transitions: types.Transition, + key: networks_lib.PRNGKey, +) -> types.Transition: + lambdas = tfd.Uniform().sample(len(rb_transitions.reward), seed=key) + return tree.map_structure( + lambda x, y: _weighted_average(x, y, lambdas), + rb_transitions, + demonstration_transitions, + ) @functools.partial(jax.vmap, in_axes=(0, None, None)) -def _compute_gradient_penalty(gradient_penalty_data: types.Transition, - discriminator_fn: Callable[[types.Transition], - float], - gradient_penalty_target: float) -> float: - """Computes a penalty based on the gradient norm on the data.""" - # The input should not be batched. - assert not gradient_penalty_data.reward.shape - discriminator_gradient_fn = jax.grad(discriminator_fn) - gradients = discriminator_gradient_fn(gradient_penalty_data) - gradients = tree.map_structure(lambda x: x.flatten(), gradients) - gradients = jnp.concatenate([gradients.observation, gradients.action, - gradients.next_observation]) - gradient_norms = jnp.linalg.norm(gradients + 1e-8) - k = gradient_penalty_target * jnp.ones_like(gradient_norms) - return jnp.mean(jnp.square(gradient_norms - k)) - - -def add_gradient_penalty(base_loss: Loss, - gradient_penalty_coefficient: float, - gradient_penalty_target: float) -> Loss: - """Adds a gradient penalty to the base_loss.""" - - if not gradient_penalty_coefficient: - return base_loss - - def loss_fn(discriminator_fn: DiscriminatorFn, - discriminator_state: State, - demo_transitions: types.Transition, - rb_transitions: types.Transition, - rng_key: networks_lib.PRNGKey) -> LossOutput: - super_key, gradient_penalty_key = jax.random.split(rng_key) - - partial_loss, (losses, discriminator_state) = base_loss( - discriminator_fn, discriminator_state, demo_transitions, rb_transitions, - super_key) - - gradient_penalty_data = _make_gradient_penalty_data( - rb_transitions=rb_transitions, - demonstration_transitions=demo_transitions, - key=gradient_penalty_key) - def apply_discriminator_fn(transitions: types.Transition) -> float: - logits, _ = discriminator_fn(discriminator_state, transitions) - return logits # pytype: disable=bad-return-type # jax-ndarray - gradient_penalty = gradient_penalty_coefficient * jnp.mean( - _compute_gradient_penalty(gradient_penalty_data, apply_discriminator_fn, - gradient_penalty_target)) - - losses['gradient_penalty'] = gradient_penalty - total_loss = partial_loss + gradient_penalty - losses['total_loss'] = total_loss - - return total_loss, (losses, discriminator_state) - - return loss_fn +def _compute_gradient_penalty( + gradient_penalty_data: types.Transition, + discriminator_fn: Callable[[types.Transition], float], + gradient_penalty_target: float, +) -> float: + """Computes a penalty based on the gradient norm on the data.""" + # The input should not be batched. 
+ assert not gradient_penalty_data.reward.shape + discriminator_gradient_fn = jax.grad(discriminator_fn) + gradients = discriminator_gradient_fn(gradient_penalty_data) + gradients = tree.map_structure(lambda x: x.flatten(), gradients) + gradients = jnp.concatenate( + [gradients.observation, gradients.action, gradients.next_observation] + ) + gradient_norms = jnp.linalg.norm(gradients + 1e-8) + k = gradient_penalty_target * jnp.ones_like(gradient_norms) + return jnp.mean(jnp.square(gradient_norms - k)) + + +def add_gradient_penalty( + base_loss: Loss, gradient_penalty_coefficient: float, gradient_penalty_target: float +) -> Loss: + """Adds a gradient penalty to the base_loss.""" + + if not gradient_penalty_coefficient: + return base_loss + + def loss_fn( + discriminator_fn: DiscriminatorFn, + discriminator_state: State, + demo_transitions: types.Transition, + rb_transitions: types.Transition, + rng_key: networks_lib.PRNGKey, + ) -> LossOutput: + super_key, gradient_penalty_key = jax.random.split(rng_key) + + partial_loss, (losses, discriminator_state) = base_loss( + discriminator_fn, + discriminator_state, + demo_transitions, + rb_transitions, + super_key, + ) + + gradient_penalty_data = _make_gradient_penalty_data( + rb_transitions=rb_transitions, + demonstration_transitions=demo_transitions, + key=gradient_penalty_key, + ) + + def apply_discriminator_fn(transitions: types.Transition) -> float: + logits, _ = discriminator_fn(discriminator_state, transitions) + return logits # pytype: disable=bad-return-type # jax-ndarray + + gradient_penalty = gradient_penalty_coefficient * jnp.mean( + _compute_gradient_penalty( + gradient_penalty_data, apply_discriminator_fn, gradient_penalty_target + ) + ) + + losses["gradient_penalty"] = gradient_penalty + total_loss = partial_loss + gradient_penalty + losses["total_loss"] = total_loss + + return total_loss, (losses, discriminator_state) + + return loss_fn diff --git a/acme/agents/jax/ail/losses_test.py b/acme/agents/jax/ail/losses_test.py index e38943a873..07a04afb3e 100644 --- a/acme/agents/jax/ail/losses_test.py +++ b/acme/agents/jax/ail/losses_test.py @@ -14,66 +14,71 @@ """Tests for the AIL discriminator losses.""" -from acme import types -from acme.agents.jax.ail import losses -from acme.jax import networks as networks_lib import jax import jax.numpy as jnp import tree - from absl.testing import absltest - -class AilLossTest(absltest.TestCase): - - def test_gradient_penalty(self): - - def dummy_discriminator( - transition: types.Transition) -> networks_lib.Logits: - return transition.observation + jnp.square(transition.action) - - zero_transition = types.Transition(0., 0., 0., 0., 0.) - zero_transition = tree.map_structure(lambda x: jnp.expand_dims(x, axis=0), - zero_transition) - self.assertEqual( - losses._compute_gradient_penalty(zero_transition, dummy_discriminator, - 0.), 1**2 + 0**2) - - one_transition = types.Transition(1., 1., 0., 0., 0.) - one_transition = tree.map_structure(lambda x: jnp.expand_dims(x, axis=0), - one_transition) - self.assertEqual( - losses._compute_gradient_penalty(one_transition, dummy_discriminator, - 0.), 1**2 + 2**2) - - def test_pugail(self): - - def dummy_discriminator( - state: losses.State, - transition: types.Transition) -> losses.DiscriminatorOutput: - return transition.observation, state - - zero_transition = types.Transition(.1, 0., 0., 0., 0.) - zero_transition = tree.map_structure(lambda x: jnp.expand_dims(x, axis=0), - zero_transition) - - one_transition = types.Transition(1., 0., 0., 0., 0.) 
- one_transition = tree.map_structure(lambda x: jnp.expand_dims(x, axis=0), - one_transition) - - prior = .7 - loss_fn = losses.pugail_loss( - positive_class_prior=prior, entropy_coefficient=0.) - loss, _ = loss_fn(dummy_discriminator, {}, one_transition, - zero_transition, ()) - - d_one = jax.nn.sigmoid(dummy_discriminator({}, one_transition)[0]) - d_zero = jax.nn.sigmoid(dummy_discriminator({}, zero_transition)[0]) - expected_loss = -prior * jnp.log( - d_one) + -jnp.log(1. - d_zero) - prior * -jnp.log(1 - d_one) - - self.assertAlmostEqual(loss, expected_loss, places=6) +from acme import types +from acme.agents.jax.ail import losses +from acme.jax import networks as networks_lib -if __name__ == '__main__': - absltest.main() +class AilLossTest(absltest.TestCase): + def test_gradient_penalty(self): + def dummy_discriminator(transition: types.Transition) -> networks_lib.Logits: + return transition.observation + jnp.square(transition.action) + + zero_transition = types.Transition(0.0, 0.0, 0.0, 0.0, 0.0) + zero_transition = tree.map_structure( + lambda x: jnp.expand_dims(x, axis=0), zero_transition + ) + self.assertEqual( + losses._compute_gradient_penalty(zero_transition, dummy_discriminator, 0.0), + 1 ** 2 + 0 ** 2, + ) + + one_transition = types.Transition(1.0, 1.0, 0.0, 0.0, 0.0) + one_transition = tree.map_structure( + lambda x: jnp.expand_dims(x, axis=0), one_transition + ) + self.assertEqual( + losses._compute_gradient_penalty(one_transition, dummy_discriminator, 0.0), + 1 ** 2 + 2 ** 2, + ) + + def test_pugail(self): + def dummy_discriminator( + state: losses.State, transition: types.Transition + ) -> losses.DiscriminatorOutput: + return transition.observation, state + + zero_transition = types.Transition(0.1, 0.0, 0.0, 0.0, 0.0) + zero_transition = tree.map_structure( + lambda x: jnp.expand_dims(x, axis=0), zero_transition + ) + + one_transition = types.Transition(1.0, 0.0, 0.0, 0.0, 0.0) + one_transition = tree.map_structure( + lambda x: jnp.expand_dims(x, axis=0), one_transition + ) + + prior = 0.7 + loss_fn = losses.pugail_loss( + positive_class_prior=prior, entropy_coefficient=0.0 + ) + loss, _ = loss_fn(dummy_discriminator, {}, one_transition, zero_transition, ()) + + d_one = jax.nn.sigmoid(dummy_discriminator({}, one_transition)[0]) + d_zero = jax.nn.sigmoid(dummy_discriminator({}, zero_transition)[0]) + expected_loss = ( + -prior * jnp.log(d_one) + + -jnp.log(1.0 - d_zero) + - prior * -jnp.log(1 - d_one) + ) + + self.assertAlmostEqual(loss, expected_loss, places=6) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/jax/ail/networks.py b/acme/agents/jax/ail/networks.py index 38eab6c721..7561097176 100644 --- a/acme/agents/jax/ail/networks.py +++ b/acme/agents/jax/ail/networks.py @@ -20,15 +20,15 @@ import functools from typing import Any, Callable, Generic, Iterable, Optional -from acme import specs -from acme import types -from acme.jax import networks as networks_lib -from acme.jax import utils -from acme.jax.imitation_learning_types import DirectRLNetworks import haiku as hk import jax -from jax import numpy as jnp import numpy as np +from jax import numpy as jnp + +from acme import specs, types +from acme.jax import networks as networks_lib +from acme.jax import utils +from acme.jax.imitation_learning_types import DirectRLNetworks # Function from discriminator logit to imitation reward. 
ImitationRewardFn = Callable[[networks_lib.Logits], jnp.ndarray] @@ -37,7 +37,7 @@ @dataclasses.dataclass class AILNetworks(Generic[DirectRLNetworks]): - """AIL networks data class. + """AIL networks data class. Attributes: discriminator_network: Networks which takes as input: @@ -48,17 +48,20 @@ class AILNetworks(Generic[DirectRLNetworks]): reward. direct_rl_networks: Networks of the direct RL algorithm. """ - discriminator_network: networks_lib.FeedForwardNetwork - imitation_reward_fn: ImitationRewardFn - direct_rl_networks: DirectRLNetworks + + discriminator_network: networks_lib.FeedForwardNetwork + imitation_reward_fn: ImitationRewardFn + direct_rl_networks: DirectRLNetworks -def compute_ail_reward(discriminator_params: networks_lib.Params, - discriminator_state: State, - policy_params: Optional[networks_lib.Params], - transitions: types.Transition, - networks: AILNetworks) -> jnp.ndarray: - """Computes the AIL reward for a given transition. +def compute_ail_reward( + discriminator_params: networks_lib.Params, + discriminator_state: State, + policy_params: Optional[networks_lib.Params], + transitions: types.Transition, + networks: AILNetworks, +) -> jnp.ndarray: + """Computes the AIL reward for a given transition. Args: discriminator_params: Parameters of the discriminator network. @@ -70,18 +73,19 @@ def compute_ail_reward(discriminator_params: networks_lib.Params, Returns: The rewards as an ndarray. """ - logits, _ = networks.discriminator_network.apply( - discriminator_params, - policy_params, - discriminator_state, - transitions, - is_training=False, - rng=None) - return networks.imitation_reward_fn(logits) + logits, _ = networks.discriminator_network.apply( + discriminator_params, + policy_params, + discriminator_state, + transitions, + is_training=False, + rng=None, + ) + return networks.imitation_reward_fn(logits) class SpectralNormalizedLinear(hk.Module): - """SpectralNormalizedLinear module. + """SpectralNormalizedLinear module. This is a Linear layer with a upper-bounded Lipschitz. It is used in iResNet. @@ -90,16 +94,16 @@ class SpectralNormalizedLinear(hk.Module): https://arxiv.org/pdf/1811.00995.pdf """ - def __init__( - self, - output_size: int, - lipschitz_coeff: float, - with_bias: bool = True, - w_init: Optional[hk.initializers.Initializer] = None, - b_init: Optional[hk.initializers.Initializer] = None, - name: Optional[str] = None, - ): - """Constructs the SpectralNormalizedLinear module. + def __init__( + self, + output_size: int, + lipschitz_coeff: float, + with_bias: bool = True, + w_init: Optional[hk.initializers.Initializer] = None, + b_init: Optional[hk.initializers.Initializer] = None, + name: Optional[str] = None, + ): + """Constructs the SpectralNormalizedLinear module. Args: output_size: Output dimensionality. @@ -111,85 +115,85 @@ def __init__( b_init: Optional initializer for bias. By default, zero. name: Name of the module. 
""" - super().__init__(name=name) - self.input_size = None - self.output_size = output_size - self.with_bias = with_bias - self.w_init = w_init - self.b_init = b_init or jnp.zeros - self.lipschitz_coeff = lipschitz_coeff - self.num_iterations = 100 - self.eps = 1e-6 - - def get_normalized_weights(self, - weights: jnp.ndarray, - renormalize: bool = False) -> jnp.ndarray: - - def _l2_normalize(x, axis=None, eps=1e-12): - return x * jax.lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps) - - output_size = self.output_size - dtype = weights.dtype - assert output_size == weights.shape[-1] - sigma = hk.get_state('sigma', (), init=jnp.ones) - if renormalize: - # Power iterations to compute spectral norm V*W*U^T. - u = hk.get_state( - 'u', (1, output_size), dtype, init=hk.initializers.RandomNormal()) - for _ in range(self.num_iterations): - v = _l2_normalize(jnp.matmul(u, weights.transpose()), eps=self.eps) - u = _l2_normalize(jnp.matmul(v, weights), eps=self.eps) - u = jax.lax.stop_gradient(u) - v = jax.lax.stop_gradient(v) - sigma = jnp.matmul(jnp.matmul(v, weights), jnp.transpose(u))[0, 0] - hk.set_state('u', u) - hk.set_state('v', v) - hk.set_state('sigma', sigma) - factor = jnp.maximum(1, sigma / self.lipschitz_coeff) - return weights / factor - - def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: - """Computes a linear transform of the input.""" - if not inputs.shape: - raise ValueError('Input must not be scalar.') - - input_size = self.input_size = inputs.shape[-1] - output_size = self.output_size - dtype = inputs.dtype - - w_init = self.w_init - if w_init is None: - stddev = 1. / np.sqrt(self.input_size) - w_init = hk.initializers.TruncatedNormal(stddev=stddev) - w = hk.get_parameter('w', [input_size, output_size], dtype, init=w_init) - w = self.get_normalized_weights(w, renormalize=True) - - out = jnp.dot(inputs, w) - - if self.with_bias: - b = hk.get_parameter('b', [self.output_size], dtype, init=self.b_init) - b = jnp.broadcast_to(b, out.shape) - out = out + b - - return out + super().__init__(name=name) + self.input_size = None + self.output_size = output_size + self.with_bias = with_bias + self.w_init = w_init + self.b_init = b_init or jnp.zeros + self.lipschitz_coeff = lipschitz_coeff + self.num_iterations = 100 + self.eps = 1e-6 + + def get_normalized_weights( + self, weights: jnp.ndarray, renormalize: bool = False + ) -> jnp.ndarray: + def _l2_normalize(x, axis=None, eps=1e-12): + return x * jax.lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps) + + output_size = self.output_size + dtype = weights.dtype + assert output_size == weights.shape[-1] + sigma = hk.get_state("sigma", (), init=jnp.ones) + if renormalize: + # Power iterations to compute spectral norm V*W*U^T. 
+ u = hk.get_state( + "u", (1, output_size), dtype, init=hk.initializers.RandomNormal() + ) + for _ in range(self.num_iterations): + v = _l2_normalize(jnp.matmul(u, weights.transpose()), eps=self.eps) + u = _l2_normalize(jnp.matmul(v, weights), eps=self.eps) + u = jax.lax.stop_gradient(u) + v = jax.lax.stop_gradient(v) + sigma = jnp.matmul(jnp.matmul(v, weights), jnp.transpose(u))[0, 0] + hk.set_state("u", u) + hk.set_state("v", v) + hk.set_state("sigma", sigma) + factor = jnp.maximum(1, sigma / self.lipschitz_coeff) + return weights / factor + + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + """Computes a linear transform of the input.""" + if not inputs.shape: + raise ValueError("Input must not be scalar.") + + input_size = self.input_size = inputs.shape[-1] + output_size = self.output_size + dtype = inputs.dtype + + w_init = self.w_init + if w_init is None: + stddev = 1.0 / np.sqrt(self.input_size) + w_init = hk.initializers.TruncatedNormal(stddev=stddev) + w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init) + w = self.get_normalized_weights(w, renormalize=True) + + out = jnp.dot(inputs, w) + + if self.with_bias: + b = hk.get_parameter("b", [self.output_size], dtype, init=self.b_init) + b = jnp.broadcast_to(b, out.shape) + out = out + b + + return out class DiscriminatorMLP(hk.Module): - """A multi-layer perceptron module.""" - - def __init__( - self, - hidden_layer_sizes: Iterable[int], - w_init: Optional[hk.initializers.Initializer] = None, - b_init: Optional[hk.initializers.Initializer] = None, - with_bias: bool = True, - activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.relu, - input_dropout_rate: float = 0., - hidden_dropout_rate: float = 0., - spectral_normalization_lipschitz_coeff: Optional[float] = None, - name: Optional[str] = None - ): - """Constructs an MLP. + """A multi-layer perceptron module.""" + + def __init__( + self, + hidden_layer_sizes: Iterable[int], + w_init: Optional[hk.initializers.Initializer] = None, + b_init: Optional[hk.initializers.Initializer] = None, + with_bias: bool = True, + activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.relu, + input_dropout_rate: float = 0.0, + hidden_dropout_rate: float = 0.0, + spectral_normalization_lipschitz_coeff: Optional[float] = None, + name: Optional[str] = None, + ): + """Constructs an MLP. Args: hidden_layer_sizes: Hiddent layer sizes. @@ -208,142 +212,157 @@ def __init__( Raises: ValueError: If ``with_bias`` is ``False`` and ``b_init`` is not ``None``. """ - if not with_bias and b_init is not None: - raise ValueError('When with_bias=False b_init must not be set.') - - super().__init__(name=name) - self._activation = activation - self._input_dropout_rate = input_dropout_rate - self._hidden_dropout_rate = hidden_dropout_rate - layer_sizes = list(hidden_layer_sizes) + [1] - - if spectral_normalization_lipschitz_coeff is not None: - layer_lipschitz_coeff = np.power(spectral_normalization_lipschitz_coeff, - 1. 
/ len(layer_sizes)) - layer_module = functools.partial( - SpectralNormalizedLinear, - lipschitz_coeff=layer_lipschitz_coeff, - w_init=w_init, - b_init=b_init, - with_bias=with_bias) - else: - layer_module = functools.partial( - hk.Linear, - w_init=w_init, - b_init=b_init, - with_bias=with_bias) - - layers = [] - for index, output_size in enumerate(layer_sizes): - layers.append( - layer_module(output_size=output_size, name=f'linear_{index}')) - self._layers = tuple(layers) - - def __call__( - self, - inputs: jnp.ndarray, - is_training: bool, - rng: Optional[networks_lib.PRNGKey], - ) -> networks_lib.Logits: - rng = hk.PRNGSequence(rng) if rng is not None else None - - out = inputs - for i, layer in enumerate(self._layers): - if is_training: - dropout_rate = ( - self._input_dropout_rate if i == 0 else self._hidden_dropout_rate) - out = hk.dropout(next(rng), dropout_rate, out) - out = layer(out) - if i < len(self._layers) - 1: - out = self._activation(out) - - return out + if not with_bias and b_init is not None: + raise ValueError("When with_bias=False b_init must not be set.") + + super().__init__(name=name) + self._activation = activation + self._input_dropout_rate = input_dropout_rate + self._hidden_dropout_rate = hidden_dropout_rate + layer_sizes = list(hidden_layer_sizes) + [1] + + if spectral_normalization_lipschitz_coeff is not None: + layer_lipschitz_coeff = np.power( + spectral_normalization_lipschitz_coeff, 1.0 / len(layer_sizes) + ) + layer_module = functools.partial( + SpectralNormalizedLinear, + lipschitz_coeff=layer_lipschitz_coeff, + w_init=w_init, + b_init=b_init, + with_bias=with_bias, + ) + else: + layer_module = functools.partial( + hk.Linear, w_init=w_init, b_init=b_init, with_bias=with_bias + ) + + layers = [] + for index, output_size in enumerate(layer_sizes): + layers.append(layer_module(output_size=output_size, name=f"linear_{index}")) + self._layers = tuple(layers) + + def __call__( + self, + inputs: jnp.ndarray, + is_training: bool, + rng: Optional[networks_lib.PRNGKey], + ) -> networks_lib.Logits: + rng = hk.PRNGSequence(rng) if rng is not None else None + + out = inputs + for i, layer in enumerate(self._layers): + if is_training: + dropout_rate = ( + self._input_dropout_rate if i == 0 else self._hidden_dropout_rate + ) + out = hk.dropout(next(rng), dropout_rate, out) + out = layer(out) + if i < len(self._layers) - 1: + out = self._activation(out) + + return out class DiscriminatorModule(hk.Module): - """Discriminator module that concatenates its inputs.""" - - def __init__(self, - environment_spec: specs.EnvironmentSpec, - use_action: bool, - use_next_obs: bool, - network_core: Callable[..., Any], - observation_embedding: Callable[[networks_lib.Observation], - jnp.ndarray] = lambda x: x, - name='discriminator'): - super().__init__(name=name) - self._use_action = use_action - self._environment_spec = environment_spec - self._use_next_obs = use_next_obs - self._network_core = network_core - self._observation_embedding = observation_embedding - - def __call__(self, observations: networks_lib.Observation, - actions: networks_lib.Action, - next_observations: networks_lib.Observation, is_training: bool, - rng: networks_lib.PRNGKey) -> networks_lib.Logits: - observations = self._observation_embedding(observations) - if self._use_next_obs: - next_observations = self._observation_embedding(next_observations) - data = jnp.concatenate([observations, next_observations], axis=-1) - else: - data = observations - if self._use_action: - action_spec = 
self._environment_spec.actions - if isinstance(action_spec, specs.DiscreteArray): - actions = jax.nn.one_hot(actions, - action_spec.num_values) - data = jnp.concatenate([data, actions], axis=-1) - output = self._network_core(data, is_training, rng) - output = jnp.squeeze(output, axis=-1) - return output + """Discriminator module that concatenates its inputs.""" + + def __init__( + self, + environment_spec: specs.EnvironmentSpec, + use_action: bool, + use_next_obs: bool, + network_core: Callable[..., Any], + observation_embedding: Callable[ + [networks_lib.Observation], jnp.ndarray + ] = lambda x: x, + name="discriminator", + ): + super().__init__(name=name) + self._use_action = use_action + self._environment_spec = environment_spec + self._use_next_obs = use_next_obs + self._network_core = network_core + self._observation_embedding = observation_embedding + + def __call__( + self, + observations: networks_lib.Observation, + actions: networks_lib.Action, + next_observations: networks_lib.Observation, + is_training: bool, + rng: networks_lib.PRNGKey, + ) -> networks_lib.Logits: + observations = self._observation_embedding(observations) + if self._use_next_obs: + next_observations = self._observation_embedding(next_observations) + data = jnp.concatenate([observations, next_observations], axis=-1) + else: + data = observations + if self._use_action: + action_spec = self._environment_spec.actions + if isinstance(action_spec, specs.DiscreteArray): + actions = jax.nn.one_hot(actions, action_spec.num_values) + data = jnp.concatenate([data, actions], axis=-1) + output = self._network_core(data, is_training, rng) + output = jnp.squeeze(output, axis=-1) + return output class AIRLModule(hk.Module): - """AIRL Module.""" - - def __init__(self, - environment_spec: specs.EnvironmentSpec, - use_action: bool, - use_next_obs: bool, - discount: float, - g_core: Callable[..., Any], - h_core: Callable[..., Any], - observation_embedding: Callable[[networks_lib.Observation], - jnp.ndarray] = lambda x: x, - name='airl'): - super().__init__(name=name) - self._environment_spec = environment_spec - self._use_action = use_action - self._use_next_obs = use_next_obs - self._discount = discount - self._g_core = g_core - self._h_core = h_core - self._observation_embedding = observation_embedding - - def __call__(self, observations: networks_lib.Observation, - actions: networks_lib.Action, - next_observations: networks_lib.Observation, - is_training: bool, - rng: networks_lib.PRNGKey) -> networks_lib.Logits: - g_output = DiscriminatorModule( - environment_spec=self._environment_spec, - use_action=self._use_action, - use_next_obs=self._use_next_obs, - network_core=self._g_core, - observation_embedding=self._observation_embedding, - name='airl_g')(observations, actions, next_observations, is_training, - rng) - h_module = DiscriminatorModule( - environment_spec=self._environment_spec, - use_action=False, - use_next_obs=False, - network_core=self._h_core, - observation_embedding=self._observation_embedding, - name='airl_h') - return (g_output + self._discount * h_module(next_observations, (), - (), is_training, rng) - - h_module(observations, (), (), is_training, rng)) + """AIRL Module.""" + + def __init__( + self, + environment_spec: specs.EnvironmentSpec, + use_action: bool, + use_next_obs: bool, + discount: float, + g_core: Callable[..., Any], + h_core: Callable[..., Any], + observation_embedding: Callable[ + [networks_lib.Observation], jnp.ndarray + ] = lambda x: x, + name="airl", + ): + super().__init__(name=name) + 
self._environment_spec = environment_spec + self._use_action = use_action + self._use_next_obs = use_next_obs + self._discount = discount + self._g_core = g_core + self._h_core = h_core + self._observation_embedding = observation_embedding + + def __call__( + self, + observations: networks_lib.Observation, + actions: networks_lib.Action, + next_observations: networks_lib.Observation, + is_training: bool, + rng: networks_lib.PRNGKey, + ) -> networks_lib.Logits: + g_output = DiscriminatorModule( + environment_spec=self._environment_spec, + use_action=self._use_action, + use_next_obs=self._use_next_obs, + network_core=self._g_core, + observation_embedding=self._observation_embedding, + name="airl_g", + )(observations, actions, next_observations, is_training, rng) + h_module = DiscriminatorModule( + environment_spec=self._environment_spec, + use_action=False, + use_next_obs=False, + network_core=self._h_core, + observation_embedding=self._observation_embedding, + name="airl_h", + ) + return ( + g_output + + self._discount * h_module(next_observations, (), (), is_training, rng) + - h_module(observations, (), (), is_training, rng) + ) # TODO(eorsini): Manipulate FeedForwardNetworks instead of transforms to @@ -351,11 +370,14 @@ def __call__(self, observations: networks_lib.Observation, def make_discriminator( environment_spec: specs.EnvironmentSpec, discriminator_transformed: hk.TransformedWithState, - logpi_fn: Optional[Callable[ - [networks_lib.Params, networks_lib.Observation, networks_lib.Action], - jnp.ndarray]] = None + logpi_fn: Optional[ + Callable[ + [networks_lib.Params, networks_lib.Observation, networks_lib.Action], + jnp.ndarray, + ] + ] = None, ) -> networks_lib.FeedForwardNetwork: - """Creates the discriminator network. + """Creates the discriminator network. Args: environment_spec: Environment spec @@ -367,33 +389,45 @@ def make_discriminator( The network. 
""" - def apply_fn(params: hk.Params, - policy_params: networks_lib.Params, - state: hk.State, - transitions: types.Transition, - is_training: bool, - rng: networks_lib.PRNGKey) -> networks_lib.Logits: - output, state = discriminator_transformed.apply( - params, state, transitions.observation, transitions.action, - transitions.next_observation, is_training, rng) - if logpi_fn is not None: - logpi = logpi_fn(policy_params, transitions.observation, - transitions.action) - - # Quick Maths: - # D = exp(output)/(exp(output) + pi(a|s)) - # logit(D) = log(D/(1-D)) = log(exp(output)/pi(a|s)) - # logit(D) = output - logpi - return output - logpi, state # pytype: disable=bad-return-type # jax-ndarray - return output, state # pytype: disable=bad-return-type # jax-ndarray - - dummy_obs = utils.zeros_like(environment_spec.observations) - dummy_obs = utils.add_batch_dim(dummy_obs) - dummy_actions = utils.zeros_like(environment_spec.actions) - dummy_actions = utils.add_batch_dim(dummy_actions) - - return networks_lib.FeedForwardNetwork( - # pylint: disable=g-long-lambda - init=lambda rng: discriminator_transformed.init( - rng, dummy_obs, dummy_actions, dummy_obs, False, rng), - apply=apply_fn) + def apply_fn( + params: hk.Params, + policy_params: networks_lib.Params, + state: hk.State, + transitions: types.Transition, + is_training: bool, + rng: networks_lib.PRNGKey, + ) -> networks_lib.Logits: + output, state = discriminator_transformed.apply( + params, + state, + transitions.observation, + transitions.action, + transitions.next_observation, + is_training, + rng, + ) + if logpi_fn is not None: + logpi = logpi_fn(policy_params, transitions.observation, transitions.action) + + # Quick Maths: + # D = exp(output)/(exp(output) + pi(a|s)) + # logit(D) = log(D/(1-D)) = log(exp(output)/pi(a|s)) + # logit(D) = output - logpi + return ( + output - logpi, + state, + ) # pytype: disable=bad-return-type # jax-ndarray + return output, state # pytype: disable=bad-return-type # jax-ndarray + + dummy_obs = utils.zeros_like(environment_spec.observations) + dummy_obs = utils.add_batch_dim(dummy_obs) + dummy_actions = utils.zeros_like(environment_spec.actions) + dummy_actions = utils.add_batch_dim(dummy_actions) + + return networks_lib.FeedForwardNetwork( + # pylint: disable=g-long-lambda + init=lambda rng: discriminator_transformed.init( + rng, dummy_obs, dummy_actions, dummy_obs, False, rng + ), + apply=apply_fn, + ) diff --git a/acme/agents/jax/ail/rewards.py b/acme/agents/jax/ail/rewards.py index ea737ad86d..8308fb289a 100644 --- a/acme/agents/jax/ail/rewards.py +++ b/acme/agents/jax/ail/rewards.py @@ -15,16 +15,17 @@ """AIL logits to AIL reward.""" from typing import Optional -from acme.agents.jax.ail import networks as ail_networks -from acme.jax import networks as networks_lib import jax import jax.numpy as jnp +from acme.agents.jax.ail import networks as ail_networks +from acme.jax import networks as networks_lib + def fairl_reward( - max_reward_magnitude: Optional[float] = None + max_reward_magnitude: Optional[float] = None, ) -> ail_networks.ImitationRewardFn: - """The FAIRL reward function (https://arxiv.org/pdf/1911.02256.pdf). + """The FAIRL reward function (https://arxiv.org/pdf/1911.02256.pdf). Args: max_reward_magnitude: Clipping value for the reward. @@ -33,22 +34,22 @@ def fairl_reward( The function from logit to imitation reward. 
""" - def imitation_reward(logits: networks_lib.Logits) -> float: - rewards = jnp.exp(jnp.clip(logits, a_max=20.)) * -logits - if max_reward_magnitude is not None: - # pylint: disable=invalid-unary-operand-type - rewards = jnp.clip( - rewards, a_min=-max_reward_magnitude, a_max=max_reward_magnitude) - return rewards # pytype: disable=bad-return-type # jax-types + def imitation_reward(logits: networks_lib.Logits) -> float: + rewards = jnp.exp(jnp.clip(logits, a_max=20.0)) * -logits + if max_reward_magnitude is not None: + # pylint: disable=invalid-unary-operand-type + rewards = jnp.clip( + rewards, a_min=-max_reward_magnitude, a_max=max_reward_magnitude + ) + return rewards # pytype: disable=bad-return-type # jax-types - return imitation_reward # pytype: disable=bad-return-type # jax-ndarray + return imitation_reward # pytype: disable=bad-return-type # jax-ndarray def gail_reward( - reward_balance: float = .5, - max_reward_magnitude: Optional[float] = None + reward_balance: float = 0.5, max_reward_magnitude: Optional[float] = None ) -> ail_networks.ImitationRewardFn: - """GAIL reward function (https://arxiv.org/pdf/1606.03476.pdf). + """GAIL reward function (https://arxiv.org/pdf/1606.03476.pdf). Args: reward_balance: 1 means log(D) reward, 0 means -log(1-D) and other values @@ -59,18 +60,19 @@ def gail_reward( The function from logit to imitation reward. """ - def imitation_reward(logits: networks_lib.Logits) -> float: - # Quick Maths: - # logits = ln(D) - ln(1-D) - # -softplus(-logits) = ln(D) - # softplus(logits) = -ln(1-D) - rewards = ( - reward_balance * -jax.nn.softplus(-logits) + - (1 - reward_balance) * jax.nn.softplus(logits)) - if max_reward_magnitude is not None: - # pylint: disable=invalid-unary-operand-type - rewards = jnp.clip( - rewards, a_min=-max_reward_magnitude, a_max=max_reward_magnitude) - return rewards + def imitation_reward(logits: networks_lib.Logits) -> float: + # Quick Maths: + # logits = ln(D) - ln(1-D) + # -softplus(-logits) = ln(D) + # softplus(logits) = -ln(1-D) + rewards = reward_balance * -jax.nn.softplus(-logits) + ( + 1 - reward_balance + ) * jax.nn.softplus(logits) + if max_reward_magnitude is not None: + # pylint: disable=invalid-unary-operand-type + rewards = jnp.clip( + rewards, a_min=-max_reward_magnitude, a_max=max_reward_magnitude + ) + return rewards - return imitation_reward # pytype: disable=bad-return-type # jax-ndarray + return imitation_reward # pytype: disable=bad-return-type # jax-ndarray diff --git a/acme/agents/jax/ars/__init__.py b/acme/agents/jax/ars/__init__.py index 38a58510b8..2af73b70cc 100644 --- a/acme/agents/jax/ars/__init__.py +++ b/acme/agents/jax/ars/__init__.py @@ -16,5 +16,4 @@ from acme.agents.jax.ars.builder import ARSBuilder from acme.agents.jax.ars.config import ARSConfig -from acme.agents.jax.ars.networks import make_networks -from acme.agents.jax.ars.networks import make_policy_network +from acme.agents.jax.ars.networks import make_networks, make_policy_network diff --git a/acme/agents/jax/ars/builder.py b/acme/agents/jax/ars/builder.py index 01fc8c3fcb..474d1be3a6 100644 --- a/acme/agents/jax/ars/builder.py +++ b/acme/agents/jax/ars/builder.py @@ -15,146 +15,167 @@ """ARS Builder.""" from typing import Dict, Iterator, List, Optional, Tuple +import jax +import jax.numpy as jnp +import numpy as np +import reverb + import acme -from acme import adders -from acme import core -from acme import specs +from acme import adders, core, specs from acme.adders import reverb as adders_reverb from acme.agents.jax import actor_core 
as actor_core_lib -from acme.agents.jax import actors -from acme.agents.jax import builders +from acme.agents.jax import actors, builders from acme.agents.jax.ars import config as ars_config from acme.agents.jax.ars import learning from acme.jax import networks as networks_lib -from acme.jax import running_statistics -from acme.jax import utils -from acme.jax import variable_utils -from acme.utils import counting -from acme.utils import loggers -import jax -import jax.numpy as jnp -import numpy as np -import reverb - - -def get_policy(policy_network: networks_lib.FeedForwardNetwork, - normalization_apply_fn) -> actor_core_lib.FeedForwardPolicy: - """Returns a function that computes actions.""" - - def apply( - params: networks_lib.Params, key: networks_lib.PRNGKey, - obs: networks_lib.Observation - ) -> Tuple[networks_lib.Action, Dict[str, jnp.ndarray]]: - del key - params_key, policy_params, normalization_params = params - normalized_obs = normalization_apply_fn(obs, normalization_params) - action = policy_network.apply(policy_params, normalized_obs) - return action, { - 'params_key': - jax.tree_map(lambda x: jnp.expand_dims(x, axis=0), params_key) - } - - return apply +from acme.jax import running_statistics, utils, variable_utils +from acme.utils import counting, loggers + + +def get_policy( + policy_network: networks_lib.FeedForwardNetwork, normalization_apply_fn +) -> actor_core_lib.FeedForwardPolicy: + """Returns a function that computes actions.""" + + def apply( + params: networks_lib.Params, + key: networks_lib.PRNGKey, + obs: networks_lib.Observation, + ) -> Tuple[networks_lib.Action, Dict[str, jnp.ndarray]]: + del key + params_key, policy_params, normalization_params = params + normalized_obs = normalization_apply_fn(obs, normalization_params) + action = policy_network.apply(policy_params, normalized_obs) + return ( + action, + { + "params_key": jax.tree_map( + lambda x: jnp.expand_dims(x, axis=0), params_key + ) + }, + ) + + return apply class ARSBuilder( - builders.ActorLearnerBuilder[networks_lib.FeedForwardNetwork, - Tuple[str, networks_lib.FeedForwardNetwork], - reverb.ReplaySample]): - """ARS Builder.""" - - def __init__( - self, - config: ars_config.ARSConfig, - spec: specs.EnvironmentSpec, - ): - self._config = config - self._spec = spec - - def make_learner( - self, - random_key: networks_lib.PRNGKey, - networks: networks_lib.FeedForwardNetwork, - dataset: Iterator[reverb.ReplaySample], - logger_fn: loggers.LoggerFactory, - environment_spec: specs.EnvironmentSpec, - replay_client: Optional[reverb.Client] = None, - counter: Optional[counting.Counter] = None, - ) -> core.Learner: - del environment_spec, replay_client - return learning.ARSLearner(self._spec, networks, random_key, self._config, - dataset, counter, logger_fn('learner')) - - def make_actor( - self, - random_key: networks_lib.PRNGKey, - policy: Tuple[str, networks_lib.FeedForwardNetwork], - environment_spec: specs.EnvironmentSpec, - variable_source: Optional[core.VariableSource] = None, - adder: Optional[adders.Adder] = None, - ) -> acme.Actor: - del environment_spec - assert variable_source is not None - - kname, policy = policy - - normalization_apply_fn = ( - running_statistics.normalize if self._config.normalize_observations else - (lambda a, b: a)) - policy_to_run = get_policy(policy, normalization_apply_fn) - - actor_core = actor_core_lib.batched_feed_forward_with_extras_to_actor_core( - policy_to_run) - variable_client = variable_utils.VariableClient(variable_source, kname, - device='cpu') - return 
actors.GenericActor( - actor_core, - random_key, - variable_client, - adder, - backend='cpu', - per_episode_update=True) - - def make_replay_tables( - self, - environment_spec: specs.EnvironmentSpec, - policy: Tuple[str, networks_lib.FeedForwardNetwork], - ) -> List[reverb.Table]: - """Create tables to insert data into.""" - del policy - extra_spec = { - 'params_key': (np.zeros(shape=(), dtype=np.int32), - np.zeros(shape=(), dtype=np.int32), - np.zeros(shape=(), dtype=np.bool_)), - } - signature = adders_reverb.EpisodeAdder.signature( - environment_spec, sequence_length=None, extras_spec=extra_spec) - return [ - reverb.Table.queue( - name=self._config.replay_table_name, - max_size=10000, # a big number - signature=signature) + builders.ActorLearnerBuilder[ + networks_lib.FeedForwardNetwork, + Tuple[str, networks_lib.FeedForwardNetwork], + reverb.ReplaySample, ] - - def make_dataset_iterator( - self, replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]: - """Create a dataset iterator to use for learning/updating the agent.""" - dataset = reverb.TrajectoryDataset.from_table_signature( - server_address=replay_client.server_address, - table=self._config.replay_table_name, - max_in_flight_samples_per_worker=1) - return utils.device_put(dataset.as_numpy_iterator(), jax.devices()[0]) - - def make_adder( - self, replay_client: reverb.Client, - environment_spec: Optional[specs.EnvironmentSpec], - policy: Optional[Tuple[str, networks_lib.FeedForwardNetwork]] - ) -> Optional[adders.Adder]: - """Create an adder which records data generated by the actor/environment.""" - del environment_spec, policy - - return adders_reverb.EpisodeAdder( - priority_fns={self._config.replay_table_name: None}, - client=replay_client, - max_sequence_length=2000, - ) +): + """ARS Builder.""" + + def __init__( + self, config: ars_config.ARSConfig, spec: specs.EnvironmentSpec, + ): + self._config = config + self._spec = spec + + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: networks_lib.FeedForwardNetwork, + dataset: Iterator[reverb.ReplaySample], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + del environment_spec, replay_client + return learning.ARSLearner( + self._spec, + networks, + random_key, + self._config, + dataset, + counter, + logger_fn("learner"), + ) + + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: Tuple[str, networks_lib.FeedForwardNetwork], + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + adder: Optional[adders.Adder] = None, + ) -> acme.Actor: + del environment_spec + assert variable_source is not None + + kname, policy = policy + + normalization_apply_fn = ( + running_statistics.normalize + if self._config.normalize_observations + else (lambda a, b: a) + ) + policy_to_run = get_policy(policy, normalization_apply_fn) + + actor_core = actor_core_lib.batched_feed_forward_with_extras_to_actor_core( + policy_to_run + ) + variable_client = variable_utils.VariableClient( + variable_source, kname, device="cpu" + ) + return actors.GenericActor( + actor_core, + random_key, + variable_client, + adder, + backend="cpu", + per_episode_update=True, + ) + + def make_replay_tables( + self, + environment_spec: specs.EnvironmentSpec, + policy: Tuple[str, networks_lib.FeedForwardNetwork], + ) -> List[reverb.Table]: + """Create tables to insert data into.""" + del 
policy + extra_spec = { + "params_key": ( + np.zeros(shape=(), dtype=np.int32), + np.zeros(shape=(), dtype=np.int32), + np.zeros(shape=(), dtype=np.bool_), + ), + } + signature = adders_reverb.EpisodeAdder.signature( + environment_spec, sequence_length=None, extras_spec=extra_spec + ) + return [ + reverb.Table.queue( + name=self._config.replay_table_name, + max_size=10000, # a big number + signature=signature, + ) + ] + + def make_dataset_iterator( + self, replay_client: reverb.Client + ) -> Iterator[reverb.ReplaySample]: + """Create a dataset iterator to use for learning/updating the agent.""" + dataset = reverb.TrajectoryDataset.from_table_signature( + server_address=replay_client.server_address, + table=self._config.replay_table_name, + max_in_flight_samples_per_worker=1, + ) + return utils.device_put(dataset.as_numpy_iterator(), jax.devices()[0]) + + def make_adder( + self, + replay_client: reverb.Client, + environment_spec: Optional[specs.EnvironmentSpec], + policy: Optional[Tuple[str, networks_lib.FeedForwardNetwork]], + ) -> Optional[adders.Adder]: + """Create an adder which records data generated by the actor/environment.""" + del environment_spec, policy + + return adders_reverb.EpisodeAdder( + priority_fns={self._config.replay_table_name: None}, + client=replay_client, + max_sequence_length=2000, + ) diff --git a/acme/agents/jax/ars/config.py b/acme/agents/jax/ars/config.py index 7658714ba6..1bd5b53a25 100644 --- a/acme/agents/jax/ars/config.py +++ b/acme/agents/jax/ars/config.py @@ -20,12 +20,13 @@ @dataclasses.dataclass class ARSConfig: - """Configuration options for ARS.""" - num_steps: int = 1000000 - normalize_observations: bool = True - step_size: float = 0.015 - num_directions: int = 60 - exploration_noise_std: float = 0.025 - top_directions: int = 20 - reward_shift: float = 1.0 - replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE + """Configuration options for ARS.""" + + num_steps: int = 1000000 + normalize_observations: bool = True + step_size: float = 0.015 + num_directions: int = 60 + exploration_noise_std: float = 0.025 + top_directions: int = 20 + reward_shift: float = 1.0 + replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE diff --git a/acme/agents/jax/ars/learning.py b/acme/agents/jax/ars/learning.py index 7ad7f3cf40..97c52435d8 100644 --- a/acme/agents/jax/ars/learning.py +++ b/acme/agents/jax/ars/learning.py @@ -19,264 +19,303 @@ import time from typing import Any, Deque, Dict, Iterator, List, NamedTuple, Optional +import jax +import numpy as np +import reverb + import acme from acme import specs from acme.adders import reverb as acme_reverb from acme.agents.jax.ars import config as ars_config from acme.agents.jax.ars import networks as ars_networks from acme.jax import networks as networks_lib -from acme.jax import running_statistics -from acme.jax import utils -from acme.utils import counting -from acme.utils import loggers -import jax -import numpy as np -import reverb +from acme.jax import running_statistics, utils +from acme.utils import counting, loggers class PerturbationKey(NamedTuple): - training_iteration: int - perturbation_id: int - is_opposite: bool + training_iteration: int + perturbation_id: int + is_opposite: bool class EvaluationResult(NamedTuple): - total_reward: float - observation: networks_lib.Observation + total_reward: float + observation: networks_lib.Observation class EvaluationRequest(NamedTuple): - key: PerturbationKey - policy_params: networks_lib.Params - normalization_params: networks_lib.Params + key: 
PerturbationKey + policy_params: networks_lib.Params + normalization_params: networks_lib.Params class TrainingState(NamedTuple): - """Contains training state for the learner.""" - key: networks_lib.PRNGKey - normalizer_params: networks_lib.Params - policy_params: networks_lib.Params - training_iteration: int + """Contains training state for the learner.""" + + key: networks_lib.PRNGKey + normalizer_params: networks_lib.Params + policy_params: networks_lib.Params + training_iteration: int class EvaluationState(NamedTuple): - """Contains training state for the learner.""" - key: networks_lib.PRNGKey - evaluation_queue: Deque[EvaluationRequest] - received_results: Dict[PerturbationKey, EvaluationResult] - noises: List[networks_lib.Params] + """Contains training state for the learner.""" + + key: networks_lib.PRNGKey + evaluation_queue: Deque[EvaluationRequest] + received_results: Dict[PerturbationKey, EvaluationResult] + noises: List[networks_lib.Params] class ARSLearner(acme.Learner): - """ARS learner.""" - - _state: TrainingState - - def __init__( - self, - spec: specs.EnvironmentSpec, - networks: networks_lib.FeedForwardNetwork, - rng: networks_lib.PRNGKey, - config: ars_config.ARSConfig, - iterator: Iterator[reverb.ReplaySample], - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None): - - self._config = config - self._lock = threading.Lock() - - # General learner book-keeping and loggers. - self._counter = counter or counting.Counter() - self._logger = logger or loggers.make_default_logger( - 'learner', - asynchronous=True, - serialize_fn=utils.fetch_devicearray, - steps_key=self._counter.get_steps_key()) - - # Iterator on demonstration transitions. - self._iterator = iterator - - if self._config.normalize_observations: - normalizer_params = running_statistics.init_state(spec.observations) - self._normalizer_update_fn = running_statistics.update - else: - normalizer_params = () - self._normalizer_update_fn = lambda a, b: a - - rng1, rng2, tmp = jax.random.split(rng, 3) - # Create initial state. - self._training_state = TrainingState( - key=rng1, - policy_params=networks.init(tmp), - normalizer_params=normalizer_params, - training_iteration=0) - self._evaluation_state = EvaluationState( - key=rng2, - evaluation_queue=collections.deque(), - received_results={}, - noises=[]) - - # Do not record timestamps until after the first learning step is done. - # This is to avoid including the time it takes for actors to come online and - # fill the replay buffer. 
- self._timestamp = None - - def _generate_perturbations(self): - with self._lock: - rng, noise_key = jax.random.split(self._evaluation_state.key) - self._evaluation_state = EvaluationState( - key=rng, - evaluation_queue=collections.deque(), - received_results={}, - noises=[]) - - all_noise = jax.random.normal( - noise_key, - shape=(self._config.num_directions,) + - self._training_state.policy_params.shape, - dtype=self._training_state.policy_params.dtype) - for i in range(self._config.num_directions): - noise = all_noise[i] - self._evaluation_state.noises.append(noise) - for direction in (-1, 1): - self._evaluation_state.evaluation_queue.append( - EvaluationRequest( - PerturbationKey(self._training_state.training_iteration, i, - direction == -1), - self._training_state.policy_params + - direction * noise * self._config.exploration_noise_std, - self._training_state.normalizer_params)) - - def _read_results(self): - while len(self._evaluation_state.received_results - ) != self._config.num_directions * 2: - data = next(self._iterator).data - data = acme_reverb.Step(*data) - - # validation - params_key = data.extras['params_key'] - training_step, perturbation_id, is_opposite = params_key - # If the incoming data does not correspond to the current iteration, - # we simply ignore it. - if not np.all( - training_step[:-1] == self._training_state.training_iteration): - continue - - # The whole episode should be run with the same policy, so let's check - # for that. - assert np.all(perturbation_id[:-1] == perturbation_id[0]) - assert np.all(is_opposite[:-1] == is_opposite[0]) - - perturbation_id = perturbation_id[0].item() - is_opposite = is_opposite[0].item() - - total_reward = np.sum(data.reward - self._config.reward_shift) - k = PerturbationKey(self._training_state.training_iteration, - perturbation_id, is_opposite) - if k in self._evaluation_state.received_results: - continue - self._evaluation_state.received_results[k] = EvaluationResult( - total_reward, data.observation) - - def _update_model(self) -> int: - # Update normalization params. - real_actor_steps = 0 - normalizer_params = self._training_state.normalizer_params - for _, value in self._evaluation_state.received_results.items(): - real_actor_steps += value.observation.shape[0] - 1 - normalizer_params = self._normalizer_update_fn(normalizer_params, - value.observation) - - # Keep only top directions. - top_directions = [] - for i in range(self._config.num_directions): - reward_forward = self._evaluation_state.received_results[PerturbationKey( - self._training_state.training_iteration, i, False)].total_reward - reward_reverse = self._evaluation_state.received_results[PerturbationKey( - self._training_state.training_iteration, i, True)].total_reward - top_directions.append((max(reward_forward, reward_reverse), i)) - top_directions.sort() - top_directions = top_directions[-self._config.top_directions:] - - # Compute reward_std. - reward = [] - for _, i in top_directions: - reward.append(self._evaluation_state.received_results[PerturbationKey( - self._training_state.training_iteration, i, False)].total_reward) - reward.append(self._evaluation_state.received_results[PerturbationKey( - self._training_state.training_iteration, i, True)].total_reward) - reward_std = np.std(reward) - - # Compute new policy params. 
- policy_params = self._training_state.policy_params - curr_sum = np.zeros_like(policy_params) - for _, i in top_directions: - reward_forward = self._evaluation_state.received_results[PerturbationKey( - self._training_state.training_iteration, i, False)].total_reward - reward_reverse = self._evaluation_state.received_results[PerturbationKey( - self._training_state.training_iteration, i, True)].total_reward - curr_sum += self._evaluation_state.noises[i] * ( - reward_forward - reward_reverse) - - policy_params = policy_params + self._config.step_size / ( - self._config.top_directions * reward_std) * curr_sum - - self._training_state = TrainingState( - key=self._training_state.key, - normalizer_params=normalizer_params, - policy_params=policy_params, - training_iteration=self._training_state.training_iteration) - return real_actor_steps - - def step(self): - self._training_state = self._training_state._replace( - training_iteration=self._training_state.training_iteration + 1) - self._generate_perturbations() - self._read_results() - real_actor_steps = self._update_model() - - # Compute elapsed time. - timestamp = time.time() - elapsed_time = timestamp - self._timestamp if self._timestamp else 0 - self._timestamp = timestamp - - # Increment counts and record the current time - counts = self._counter.increment( - steps=1, - real_actor_steps=real_actor_steps, - learner_episodes=2 * self._config.num_directions, - walltime=elapsed_time) - - # Attempts to write the logs. - self._logger.write(counts) - - def get_variables(self, names: List[str]) -> List[Any]: - assert (names == [ars_networks.BEHAVIOR_PARAMS_NAME] or - names == [ars_networks.EVAL_PARAMS_NAME]) - if names == [ars_networks.EVAL_PARAMS_NAME]: - return [PerturbationKey(-1, -1, False), - self._training_state.policy_params, - self._training_state.normalizer_params] - should_sleep = False - while True: - if should_sleep: - time.sleep(0.1) - should_sleep = False - with self._lock: - if not self._evaluation_state.evaluation_queue: - should_sleep = True - continue - data = self._evaluation_state.evaluation_queue.pop() - # If this perturbation was already evaluated, we simply skip it. - if data.key in self._evaluation_state.received_results: - continue - # In case if an actor fails we still need to reevaluate the same - # perturbation, so we just add it to the end of the queue. - self._evaluation_state.evaluation_queue.append(data) - return [data] - - def save(self) -> TrainingState: - return self._training_state - - def restore(self, state: TrainingState): - self._training_state = state + """ARS learner.""" + + _state: TrainingState + + def __init__( + self, + spec: specs.EnvironmentSpec, + networks: networks_lib.FeedForwardNetwork, + rng: networks_lib.PRNGKey, + config: ars_config.ARSConfig, + iterator: Iterator[reverb.ReplaySample], + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + ): + + self._config = config + self._lock = threading.Lock() + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger( + "learner", + asynchronous=True, + serialize_fn=utils.fetch_devicearray, + steps_key=self._counter.get_steps_key(), + ) + + # Iterator on demonstration transitions. 
+ self._iterator = iterator + + if self._config.normalize_observations: + normalizer_params = running_statistics.init_state(spec.observations) + self._normalizer_update_fn = running_statistics.update + else: + normalizer_params = () + self._normalizer_update_fn = lambda a, b: a + + rng1, rng2, tmp = jax.random.split(rng, 3) + # Create initial state. + self._training_state = TrainingState( + key=rng1, + policy_params=networks.init(tmp), + normalizer_params=normalizer_params, + training_iteration=0, + ) + self._evaluation_state = EvaluationState( + key=rng2, + evaluation_queue=collections.deque(), + received_results={}, + noises=[], + ) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + def _generate_perturbations(self): + with self._lock: + rng, noise_key = jax.random.split(self._evaluation_state.key) + self._evaluation_state = EvaluationState( + key=rng, + evaluation_queue=collections.deque(), + received_results={}, + noises=[], + ) + + all_noise = jax.random.normal( + noise_key, + shape=(self._config.num_directions,) + + self._training_state.policy_params.shape, + dtype=self._training_state.policy_params.dtype, + ) + for i in range(self._config.num_directions): + noise = all_noise[i] + self._evaluation_state.noises.append(noise) + for direction in (-1, 1): + self._evaluation_state.evaluation_queue.append( + EvaluationRequest( + PerturbationKey( + self._training_state.training_iteration, + i, + direction == -1, + ), + self._training_state.policy_params + + direction * noise * self._config.exploration_noise_std, + self._training_state.normalizer_params, + ) + ) + + def _read_results(self): + while ( + len(self._evaluation_state.received_results) + != self._config.num_directions * 2 + ): + data = next(self._iterator).data + data = acme_reverb.Step(*data) + + # validation + params_key = data.extras["params_key"] + training_step, perturbation_id, is_opposite = params_key + # If the incoming data does not correspond to the current iteration, + # we simply ignore it. + if not np.all( + training_step[:-1] == self._training_state.training_iteration + ): + continue + + # The whole episode should be run with the same policy, so let's check + # for that. + assert np.all(perturbation_id[:-1] == perturbation_id[0]) + assert np.all(is_opposite[:-1] == is_opposite[0]) + + perturbation_id = perturbation_id[0].item() + is_opposite = is_opposite[0].item() + + total_reward = np.sum(data.reward - self._config.reward_shift) + k = PerturbationKey( + self._training_state.training_iteration, perturbation_id, is_opposite + ) + if k in self._evaluation_state.received_results: + continue + self._evaluation_state.received_results[k] = EvaluationResult( + total_reward, data.observation + ) + + def _update_model(self) -> int: + # Update normalization params. + real_actor_steps = 0 + normalizer_params = self._training_state.normalizer_params + for _, value in self._evaluation_state.received_results.items(): + real_actor_steps += value.observation.shape[0] - 1 + normalizer_params = self._normalizer_update_fn( + normalizer_params, value.observation + ) + + # Keep only top directions. 
+ top_directions = [] + for i in range(self._config.num_directions): + reward_forward = self._evaluation_state.received_results[ + PerturbationKey(self._training_state.training_iteration, i, False) + ].total_reward + reward_reverse = self._evaluation_state.received_results[ + PerturbationKey(self._training_state.training_iteration, i, True) + ].total_reward + top_directions.append((max(reward_forward, reward_reverse), i)) + top_directions.sort() + top_directions = top_directions[-self._config.top_directions :] + + # Compute reward_std. + reward = [] + for _, i in top_directions: + reward.append( + self._evaluation_state.received_results[ + PerturbationKey(self._training_state.training_iteration, i, False) + ].total_reward + ) + reward.append( + self._evaluation_state.received_results[ + PerturbationKey(self._training_state.training_iteration, i, True) + ].total_reward + ) + reward_std = np.std(reward) + + # Compute new policy params. + policy_params = self._training_state.policy_params + curr_sum = np.zeros_like(policy_params) + for _, i in top_directions: + reward_forward = self._evaluation_state.received_results[ + PerturbationKey(self._training_state.training_iteration, i, False) + ].total_reward + reward_reverse = self._evaluation_state.received_results[ + PerturbationKey(self._training_state.training_iteration, i, True) + ].total_reward + curr_sum += self._evaluation_state.noises[i] * ( + reward_forward - reward_reverse + ) + + policy_params = ( + policy_params + + self._config.step_size + / (self._config.top_directions * reward_std) + * curr_sum + ) + + self._training_state = TrainingState( + key=self._training_state.key, + normalizer_params=normalizer_params, + policy_params=policy_params, + training_iteration=self._training_state.training_iteration, + ) + return real_actor_steps + + def step(self): + self._training_state = self._training_state._replace( + training_iteration=self._training_state.training_iteration + 1 + ) + self._generate_perturbations() + self._read_results() + real_actor_steps = self._update_model() + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Increment counts and record the current time + counts = self._counter.increment( + steps=1, + real_actor_steps=real_actor_steps, + learner_episodes=2 * self._config.num_directions, + walltime=elapsed_time, + ) + + # Attempts to write the logs. + self._logger.write(counts) + + def get_variables(self, names: List[str]) -> List[Any]: + assert names == [ars_networks.BEHAVIOR_PARAMS_NAME] or names == [ + ars_networks.EVAL_PARAMS_NAME + ] + if names == [ars_networks.EVAL_PARAMS_NAME]: + return [ + PerturbationKey(-1, -1, False), + self._training_state.policy_params, + self._training_state.normalizer_params, + ] + should_sleep = False + while True: + if should_sleep: + time.sleep(0.1) + should_sleep = False + with self._lock: + if not self._evaluation_state.evaluation_queue: + should_sleep = True + continue + data = self._evaluation_state.evaluation_queue.pop() + # If this perturbation was already evaluated, we simply skip it. + if data.key in self._evaluation_state.received_results: + continue + # In case if an actor fails we still need to reevaluate the same + # perturbation, so we just add it to the end of the queue. 
+ self._evaluation_state.evaluation_queue.append(data) + return [data] + + def save(self) -> TrainingState: + return self._training_state + + def restore(self, state: TrainingState): + self._training_state = state diff --git a/acme/agents/jax/ars/networks.py b/acme/agents/jax/ars/networks.py index 55c0a5379a..48d139a37a 100644 --- a/acme/agents/jax/ars/networks.py +++ b/acme/agents/jax/ars/networks.py @@ -16,18 +16,17 @@ from typing import Tuple -from acme import specs -from acme.jax import networks as networks_lib import jax.numpy as jnp +from acme import specs +from acme.jax import networks as networks_lib -BEHAVIOR_PARAMS_NAME = 'policy' -EVAL_PARAMS_NAME = 'eval' +BEHAVIOR_PARAMS_NAME = "policy" +EVAL_PARAMS_NAME = "eval" -def make_networks( - spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork: - """Creates networks used by the agent. +def make_networks(spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork: + """Creates networks used by the agent. The model used by the ARS paper is a simple clipped linear model. @@ -38,15 +37,16 @@ def make_networks( A FeedForwardNetwork network. """ - obs_size = spec.observations.shape[0] - act_size = spec.actions.shape[0] - return networks_lib.FeedForwardNetwork( - init=lambda _: jnp.zeros((obs_size, act_size)), - apply=lambda matrix, obs: jnp.clip(jnp.matmul(obs, matrix), -1, 1)) + obs_size = spec.observations.shape[0] + act_size = spec.actions.shape[0] + return networks_lib.FeedForwardNetwork( + init=lambda _: jnp.zeros((obs_size, act_size)), + apply=lambda matrix, obs: jnp.clip(jnp.matmul(obs, matrix), -1, 1), + ) def make_policy_network( - network: networks_lib.FeedForwardNetwork, - eval_mode: bool = True) -> Tuple[str, networks_lib.FeedForwardNetwork]: - params_name = EVAL_PARAMS_NAME if eval_mode else BEHAVIOR_PARAMS_NAME - return (params_name, network) + network: networks_lib.FeedForwardNetwork, eval_mode: bool = True +) -> Tuple[str, networks_lib.FeedForwardNetwork]: + params_name = EVAL_PARAMS_NAME if eval_mode else BEHAVIOR_PARAMS_NAME + return (params_name, network) diff --git a/acme/agents/jax/bc/__init__.py b/acme/agents/jax/bc/__init__.py index dee94a5a9f..36fac831df 100644 --- a/acme/agents/jax/bc/__init__.py +++ b/acme/agents/jax/bc/__init__.py @@ -18,12 +18,10 @@ from acme.agents.jax.bc.builder import BCBuilder from acme.agents.jax.bc.config import BCConfig from acme.agents.jax.bc.learning import BCLearner -from acme.agents.jax.bc.losses import BCLoss -from acme.agents.jax.bc.losses import logp -from acme.agents.jax.bc.losses import mse -from acme.agents.jax.bc.losses import peerbc -from acme.agents.jax.bc.losses import rcal -from acme.agents.jax.bc.networks import BCNetworks -from acme.agents.jax.bc.networks import BCPolicyNetwork -from acme.agents.jax.bc.networks import convert_policy_value_to_bc_network -from acme.agents.jax.bc.networks import convert_to_bc_network +from acme.agents.jax.bc.losses import BCLoss, logp, mse, peerbc, rcal +from acme.agents.jax.bc.networks import ( + BCNetworks, + BCPolicyNetwork, + convert_policy_value_to_bc_network, + convert_to_bc_network, +) diff --git a/acme/agents/jax/bc/agent_test.py b/acme/agents/jax/bc/agent_test.py index e266b49a02..53fadf09f3 100644 --- a/acme/agents/jax/bc/agent_test.py +++ b/acme/agents/jax/bc/agent_test.py @@ -14,178 +14,186 @@ """Tests for the BC agent.""" -from acme import specs -from acme import types -from acme.agents.jax import bc -from acme.jax import networks as networks_lib -from acme.jax import types as jax_types -from acme.jax import utils 
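
Stepping back from the ARS files reformatted above: the policy built by `make_networks` is a clipped linear map, and `ARSLearner._update_model` applies the standard ARS rule, theta <- theta + step_size / (top_directions * reward_std) * sum over the top directions of (reward_forward - reward_reverse) * noise. The following self-contained numerical sketch mirrors that rule with toy returns; `evaluate`, the constants, and the quadratic reward are illustrative stand-ins, not Acme's API.

import numpy as np

rng = np.random.default_rng(0)
obs_size, act_size = 3, 2
theta = np.zeros((obs_size, act_size))  # clipped-linear policy parameters
step_size, noise_std, num_directions, top_b = 0.015, 0.025, 4, 2


def evaluate(params: np.ndarray) -> float:
    # Hypothetical episode return; a real actor would roll out
    # clip(obs @ params, -1, 1) in the environment.
    return float(-np.sum(np.square(params - 1.0)))


noises = rng.normal(size=(num_directions,) + theta.shape)
r_plus = np.array([evaluate(theta + noise_std * n) for n in noises])
r_minus = np.array([evaluate(theta - noise_std * n) for n in noises])

# Keep the directions whose best-case return is highest, as in _update_model.
top = np.argsort(np.maximum(r_plus, r_minus))[-top_b:]
reward_std = np.std(np.concatenate([r_plus[top], r_minus[top]]))
update = sum((r_plus[i] - r_minus[i]) * noises[i] for i in top)
theta = theta + step_size / (top_b * reward_std) * update
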
-from acme.testing import fakes import chex import haiku as hk import jax import jax.numpy as jnp -from jax.scipy import special import numpy as np import optax import rlax +from absl.testing import absltest, parameterized +from jax.scipy import special -from absl.testing import absltest -from absl.testing import parameterized - +from acme import specs, types +from acme.agents.jax import bc +from acme.jax import networks as networks_lib +from acme.jax import types as jax_types +from acme.jax import utils +from acme.testing import fakes -def make_networks(spec: specs.EnvironmentSpec, - discrete_actions: bool = False) -> bc.BCNetworks: - """Creates networks used by the agent.""" - if discrete_actions: - final_layer_size = spec.actions.num_values - else: - final_layer_size = np.prod(spec.actions.shape, dtype=int) +def make_networks( + spec: specs.EnvironmentSpec, discrete_actions: bool = False +) -> bc.BCNetworks: + """Creates networks used by the agent.""" - def _actor_fn(obs, is_training=False, key=None): - # is_training and key allows to defined train/test dependant modules - # like dropout. - del is_training - del key if discrete_actions: - network = hk.nets.MLP([64, 64, final_layer_size]) + final_layer_size = spec.actions.num_values else: - network = hk.Sequential([ - networks_lib.LayerNormMLP([64, 64], activate_final=True), - networks_lib.NormalTanhDistribution(final_layer_size), - ]) - return network(obs) - - policy = hk.without_apply_rng(hk.transform(_actor_fn)) + final_layer_size = np.prod(spec.actions.shape, dtype=int) + + def _actor_fn(obs, is_training=False, key=None): + # is_training and key allows to defined train/test dependant modules + # like dropout. + del is_training + del key + if discrete_actions: + network = hk.nets.MLP([64, 64, final_layer_size]) + else: + network = hk.Sequential( + [ + networks_lib.LayerNormMLP([64, 64], activate_final=True), + networks_lib.NormalTanhDistribution(final_layer_size), + ] + ) + return network(obs) + + policy = hk.without_apply_rng(hk.transform(_actor_fn)) + + # Create dummy observations and actions to create network parameters. + dummy_obs = utils.zeros_like(spec.observations) + dummy_obs = utils.add_batch_dim(dummy_obs) + policy_network = networks_lib.FeedForwardNetwork( + lambda key: policy.init(key, dummy_obs), policy.apply + ) + bc_policy_network = bc.convert_to_bc_network(policy_network) - # Create dummy observations and actions to create network parameters. 
- dummy_obs = utils.zeros_like(spec.observations) - dummy_obs = utils.add_batch_dim(dummy_obs) - policy_network = networks_lib.FeedForwardNetwork( - lambda key: policy.init(key, dummy_obs), policy.apply) - bc_policy_network = bc.convert_to_bc_network(policy_network) - - if discrete_actions: + if discrete_actions: - def sample_fn(logits: networks_lib.NetworkOutput, - key: jax_types.PRNGKey) -> networks_lib.Action: - return rlax.epsilon_greedy(epsilon=0.0).sample(key, logits) + def sample_fn( + logits: networks_lib.NetworkOutput, key: jax_types.PRNGKey + ) -> networks_lib.Action: + return rlax.epsilon_greedy(epsilon=0.0).sample(key, logits) - def log_prob(logits: networks_lib.NetworkOutput, - actions: networks_lib.Action) -> networks_lib.LogProb: - max_logits = jnp.max(logits, axis=-1, keepdims=True) - logits = logits - max_logits - logits_actions = jnp.sum( - jax.nn.one_hot(actions, spec.actions.num_values) * logits, axis=-1) + def log_prob( + logits: networks_lib.NetworkOutput, actions: networks_lib.Action + ) -> networks_lib.LogProb: + max_logits = jnp.max(logits, axis=-1, keepdims=True) + logits = logits - max_logits + logits_actions = jnp.sum( + jax.nn.one_hot(actions, spec.actions.num_values) * logits, axis=-1 + ) - log_prob = logits_actions - special.logsumexp(logits, axis=-1) - return log_prob + log_prob = logits_actions - special.logsumexp(logits, axis=-1) + return log_prob - else: + else: - def sample_fn(distribution: networks_lib.NetworkOutput, - key: jax_types.PRNGKey) -> networks_lib.Action: - return distribution.sample(seed=key) + def sample_fn( + distribution: networks_lib.NetworkOutput, key: jax_types.PRNGKey + ) -> networks_lib.Action: + return distribution.sample(seed=key) - def log_prob(distribuition: networks_lib.NetworkOutput, - actions: networks_lib.Action) -> networks_lib.LogProb: - return distribuition.log_prob(actions) + def log_prob( + distribuition: networks_lib.NetworkOutput, actions: networks_lib.Action + ) -> networks_lib.LogProb: + return distribuition.log_prob(actions) - return bc.BCNetworks(bc_policy_network, sample_fn, log_prob) + return bc.BCNetworks(bc_policy_network, sample_fn, log_prob) class BCTest(parameterized.TestCase): - - @parameterized.parameters( - ('logp',), - ('mse',), - ('peerbc',) - ) - def test_continuous_actions(self, loss_name): - with chex.fake_pmap_and_jit(): - num_sgd_steps_per_step = 1 - num_steps = 5 - - # Create a fake environment to test with. - environment = fakes.ContinuousEnvironment( - episode_length=10, bounded=True, action_dim=6) - - spec = specs.make_environment_spec(environment) - dataset_demonstration = fakes.transition_dataset(environment) - dataset_demonstration = dataset_demonstration.map( - lambda sample: types.Transition(*sample.data)) - dataset_demonstration = dataset_demonstration.batch(8).as_numpy_iterator() - - # Construct the agent. 
- networks = make_networks(spec) - - if loss_name == 'logp': - loss_fn = bc.logp() - elif loss_name == 'mse': - loss_fn = bc.mse() - elif loss_name == 'peerbc': - loss_fn = bc.peerbc(bc.logp(), zeta=0.1) - else: - raise ValueError - - learner = bc.BCLearner( - networks=networks, - random_key=jax.random.PRNGKey(0), - loss_fn=loss_fn, - optimizer=optax.adam(0.01), - prefetching_iterator=utils.sharded_prefetch(dataset_demonstration), - num_sgd_steps_per_step=num_sgd_steps_per_step) - - # Train the agent - for _ in range(num_steps): - learner.step() - - @parameterized.parameters( - ('logp',), - ('rcal',)) - def test_discrete_actions(self, loss_name): - with chex.fake_pmap_and_jit(): - - num_sgd_steps_per_step = 1 - num_steps = 5 - - # Create a fake environment to test with. - environment = fakes.DiscreteEnvironment( - num_actions=10, num_observations=100, obs_shape=(10,), - obs_dtype=np.float32) - - spec = specs.make_environment_spec(environment) - dataset_demonstration = fakes.transition_dataset(environment) - dataset_demonstration = dataset_demonstration.map( - lambda sample: types.Transition(*sample.data)) - dataset_demonstration = dataset_demonstration.batch(8).as_numpy_iterator() - - # Construct the agent. - networks = make_networks(spec, discrete_actions=True) - - if loss_name == 'logp': - loss_fn = bc.logp() - - elif loss_name == 'rcal': - base_loss_fn = bc.logp() - loss_fn = bc.rcal(base_loss_fn, discount=0.99, alpha=0.1) - - else: - raise ValueError - - learner = bc.BCLearner( - networks=networks, - random_key=jax.random.PRNGKey(0), - loss_fn=loss_fn, - optimizer=optax.adam(0.01), - prefetching_iterator=utils.sharded_prefetch(dataset_demonstration), - num_sgd_steps_per_step=num_sgd_steps_per_step) - - # Train the agent - for _ in range(num_steps): - learner.step() - - -if __name__ == '__main__': - absltest.main() + @parameterized.parameters(("logp",), ("mse",), ("peerbc",)) + def test_continuous_actions(self, loss_name): + with chex.fake_pmap_and_jit(): + num_sgd_steps_per_step = 1 + num_steps = 5 + + # Create a fake environment to test with. + environment = fakes.ContinuousEnvironment( + episode_length=10, bounded=True, action_dim=6 + ) + + spec = specs.make_environment_spec(environment) + dataset_demonstration = fakes.transition_dataset(environment) + dataset_demonstration = dataset_demonstration.map( + lambda sample: types.Transition(*sample.data) + ) + dataset_demonstration = dataset_demonstration.batch(8).as_numpy_iterator() + + # Construct the agent. + networks = make_networks(spec) + + if loss_name == "logp": + loss_fn = bc.logp() + elif loss_name == "mse": + loss_fn = bc.mse() + elif loss_name == "peerbc": + loss_fn = bc.peerbc(bc.logp(), zeta=0.1) + else: + raise ValueError + + learner = bc.BCLearner( + networks=networks, + random_key=jax.random.PRNGKey(0), + loss_fn=loss_fn, + optimizer=optax.adam(0.01), + prefetching_iterator=utils.sharded_prefetch(dataset_demonstration), + num_sgd_steps_per_step=num_sgd_steps_per_step, + ) + + # Train the agent + for _ in range(num_steps): + learner.step() + + @parameterized.parameters(("logp",), ("rcal",)) + def test_discrete_actions(self, loss_name): + with chex.fake_pmap_and_jit(): + + num_sgd_steps_per_step = 1 + num_steps = 5 + + # Create a fake environment to test with. 
+ environment = fakes.DiscreteEnvironment( + num_actions=10, + num_observations=100, + obs_shape=(10,), + obs_dtype=np.float32, + ) + + spec = specs.make_environment_spec(environment) + dataset_demonstration = fakes.transition_dataset(environment) + dataset_demonstration = dataset_demonstration.map( + lambda sample: types.Transition(*sample.data) + ) + dataset_demonstration = dataset_demonstration.batch(8).as_numpy_iterator() + + # Construct the agent. + networks = make_networks(spec, discrete_actions=True) + + if loss_name == "logp": + loss_fn = bc.logp() + + elif loss_name == "rcal": + base_loss_fn = bc.logp() + loss_fn = bc.rcal(base_loss_fn, discount=0.99, alpha=0.1) + + else: + raise ValueError + + learner = bc.BCLearner( + networks=networks, + random_key=jax.random.PRNGKey(0), + loss_fn=loss_fn, + optimizer=optax.adam(0.01), + prefetching_iterator=utils.sharded_prefetch(dataset_demonstration), + num_sgd_steps_per_step=num_sgd_steps_per_step, + ) + + # Train the agent + for _ in range(num_steps): + learner.step() + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/jax/bc/builder.py b/acme/agents/jax/bc/builder.py index 92476195f7..a4cfe6b4d3 100644 --- a/acme/agents/jax/bc/builder.py +++ b/acme/agents/jax/bc/builder.py @@ -15,37 +15,34 @@ """BC Builder.""" from typing import Iterator, Optional -from acme import core -from acme import specs -from acme import types +import jax +import optax + +from acme import core, specs, types from acme.agents.jax import actor_core as actor_core_lib -from acme.agents.jax import actors -from acme.agents.jax import builders +from acme.agents.jax import actors, builders from acme.agents.jax.bc import config as bc_config -from acme.agents.jax.bc import learning -from acme.agents.jax.bc import losses +from acme.agents.jax.bc import learning, losses from acme.agents.jax.bc import networks as bc_networks from acme.jax import networks as networks_lib -from acme.jax import utils -from acme.jax import variable_utils -from acme.utils import counting -from acme.utils import loggers -import jax -import optax +from acme.jax import utils, variable_utils +from acme.utils import counting, loggers -class BCBuilder(builders.OfflineBuilder[bc_networks.BCNetworks, - actor_core_lib.FeedForwardPolicy, - types.Transition]): - """BC Builder.""" +class BCBuilder( + builders.OfflineBuilder[ + bc_networks.BCNetworks, actor_core_lib.FeedForwardPolicy, types.Transition + ] +): + """BC Builder.""" - def __init__( - self, - config: bc_config.BCConfig, - loss_fn: losses.BCLoss, - loss_has_aux: bool = False, - ): - """Creates a BC learner, an evaluation policy and an eval actor. + def __init__( + self, + config: bc_config.BCConfig, + loss_fn: losses.BCLoss, + loss_has_aux: bool = False, + ): + """Creates a BC learner, an evaluation policy and an eval actor. Args: config: a config with BC hps. @@ -53,61 +50,69 @@ def __init__( loss_has_aux: Whether the loss function returns auxiliary metrics as a second argument. 
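For readers skimming the `bc/builder.py` hunk that starts above: the builder only needs a `BCConfig` and a loss at construction time. A hedged usage sketch, with illustrative hyperparameter values that are not prescribed by this diff:

```python
# Sketch only; relies on the BC exports shown in acme/agents/jax/bc/__init__.py above.
from acme.agents.jax import bc

config = bc.BCConfig(learning_rate=1e-4, num_sgd_steps_per_step=1)
builder = bc.BCBuilder(config, loss_fn=bc.logp())
# builder.make_learner(...) / builder.make_actor(...) then mirror the wiring
# exercised directly in agent_test.py above.
```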
""" - self._config = config - self._loss_fn = loss_fn - self._loss_has_aux = loss_has_aux + self._config = config + self._loss_fn = loss_fn + self._loss_has_aux = loss_has_aux - def make_learner( - self, - random_key: networks_lib.PRNGKey, - networks: bc_networks.BCNetworks, - dataset: Iterator[types.Transition], - logger_fn: loggers.LoggerFactory, - environment_spec: specs.EnvironmentSpec, - *, - counter: Optional[counting.Counter] = None, - ) -> core.Learner: - del environment_spec + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: bc_networks.BCNetworks, + dataset: Iterator[types.Transition], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + *, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + del environment_spec - return learning.BCLearner( - networks=networks, - random_key=random_key, - loss_fn=self._loss_fn, - optimizer=optax.adam(learning_rate=self._config.learning_rate), - prefetching_iterator=utils.sharded_prefetch(dataset), - num_sgd_steps_per_step=self._config.num_sgd_steps_per_step, - loss_has_aux=self._loss_has_aux, - logger=logger_fn('learner'), - counter=counter) + return learning.BCLearner( + networks=networks, + random_key=random_key, + loss_fn=self._loss_fn, + optimizer=optax.adam(learning_rate=self._config.learning_rate), + prefetching_iterator=utils.sharded_prefetch(dataset), + num_sgd_steps_per_step=self._config.num_sgd_steps_per_step, + loss_has_aux=self._loss_has_aux, + logger=logger_fn("learner"), + counter=counter, + ) - def make_actor( - self, - random_key: networks_lib.PRNGKey, - policy: actor_core_lib.FeedForwardPolicy, - environment_spec: specs.EnvironmentSpec, - variable_source: Optional[core.VariableSource] = None, - ) -> core.Actor: - del environment_spec - assert variable_source is not None - actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy) - variable_client = variable_utils.VariableClient( - variable_source, 'policy', device='cpu') - return actors.GenericActor( - actor_core, random_key, variable_client, backend='cpu') + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: actor_core_lib.FeedForwardPolicy, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + ) -> core.Actor: + del environment_spec + assert variable_source is not None + actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy) + variable_client = variable_utils.VariableClient( + variable_source, "policy", device="cpu" + ) + return actors.GenericActor( + actor_core, random_key, variable_client, backend="cpu" + ) - def make_policy(self, - networks: bc_networks.BCNetworks, - environment_spec: specs.EnvironmentSpec, - evaluation: bool = False) -> actor_core_lib.FeedForwardPolicy: - """Construct the policy.""" - del environment_spec, evaluation + def make_policy( + self, + networks: bc_networks.BCNetworks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool = False, + ) -> actor_core_lib.FeedForwardPolicy: + """Construct the policy.""" + del environment_spec, evaluation - def evaluation_policy( - params: networks_lib.Params, key: networks_lib.PRNGKey, - observation: networks_lib.Observation) -> networks_lib.Action: - apply_key, sample_key = jax.random.split(key) - network_output = networks.policy_network.apply( - params, observation, is_training=False, key=apply_key) - return networks.sample_fn(network_output, sample_key) + def evaluation_policy( + params: networks_lib.Params, + key: networks_lib.PRNGKey, + 
observation: networks_lib.Observation, + ) -> networks_lib.Action: + apply_key, sample_key = jax.random.split(key) + network_output = networks.policy_network.apply( + params, observation, is_training=False, key=apply_key + ) + return networks.sample_fn(network_output, sample_key) - return evaluation_policy + return evaluation_policy diff --git a/acme/agents/jax/bc/config.py b/acme/agents/jax/bc/config.py index 15fa1ff816..7d229589ab 100644 --- a/acme/agents/jax/bc/config.py +++ b/acme/agents/jax/bc/config.py @@ -18,11 +18,12 @@ @dataclasses.dataclass class BCConfig: - """Configuration options for BC. + """Configuration options for BC. Attributes: learning_rate: Learning rate. num_sgd_steps_per_step: How many gradient updates to perform per step. """ - learning_rate: float = 1e-4 - num_sgd_steps_per_step: int = 1 + + learning_rate: float = 1e-4 + num_sgd_steps_per_step: int = 1 diff --git a/acme/agents/jax/bc/learning.py b/acme/agents/jax/bc/learning.py index 11406bdaed..b252d38025 100644 --- a/acme/agents/jax/bc/learning.py +++ b/acme/agents/jax/bc/learning.py @@ -15,7 +15,11 @@ """BC learner implementation.""" import time -from typing import Dict, List, NamedTuple, Optional, Tuple, Union, Iterator +from typing import Dict, Iterator, List, NamedTuple, Optional, Tuple, Union + +import jax +import jax.numpy as jnp +import optax import acme from acme import types @@ -23,21 +27,18 @@ from acme.agents.jax.bc import networks as bc_networks from acme.jax import networks as networks_lib from acme.jax import utils -from acme.utils import counting -from acme.utils import loggers -import jax -import jax.numpy as jnp -import optax +from acme.utils import counting, loggers -_PMAP_AXIS_NAME = 'data' +_PMAP_AXIS_NAME = "data" class TrainingState(NamedTuple): - """Contains training state for the learner.""" - optimizer_state: optax.OptState - policy_params: networks_lib.Params - key: networks_lib.PRNGKey - steps: int + """Contains training state for the learner.""" + + optimizer_state: optax.OptState + policy_params: networks_lib.Params + key: networks_lib.PRNGKey + steps: int def _create_loss_metrics( @@ -45,51 +46,60 @@ def _create_loss_metrics( loss_result: Union[jnp.ndarray, Tuple[jnp.ndarray, loggers.LoggingData]], gradients: jnp.ndarray, ): - """Creates loss metrics for logging.""" - # Validate input. - if loss_has_aux and not (len(loss_result) == 2 and isinstance( - loss_result[0], jnp.ndarray) and isinstance(loss_result[1], dict)): - raise ValueError('Could not parse loss value and metrics from loss_fn\'s ' - 'output. Since loss_has_aux is enabled, loss_fn must ' - 'return loss_value and auxiliary metrics.') - - if not loss_has_aux and not isinstance(loss_result, jnp.ndarray): - raise ValueError(f'Loss returns type {loss_result}. However, it should ' - 'return a jnp.ndarray, given that loss_has_aux = False.') - - # Maybe unpack loss result. - if loss_has_aux: - loss, metrics = loss_result - else: - loss = loss_result - metrics = {} - - # Complete metrics dict and return it. - metrics['loss'] = loss - metrics['gradient_norm'] = optax.global_norm(gradients) - return metrics + """Creates loss metrics for logging.""" + # Validate input. + if loss_has_aux and not ( + len(loss_result) == 2 + and isinstance(loss_result[0], jnp.ndarray) + and isinstance(loss_result[1], dict) + ): + raise ValueError( + "Could not parse loss value and metrics from loss_fn's " + "output. Since loss_has_aux is enabled, loss_fn must " + "return loss_value and auxiliary metrics." 
+ ) + + if not loss_has_aux and not isinstance(loss_result, jnp.ndarray): + raise ValueError( + f"Loss returns type {loss_result}. However, it should " + "return a jnp.ndarray, given that loss_has_aux = False." + ) + + # Maybe unpack loss result. + if loss_has_aux: + loss, metrics = loss_result + else: + loss = loss_result + metrics = {} + + # Complete metrics dict and return it. + metrics["loss"] = loss + metrics["gradient_norm"] = optax.global_norm(gradients) + return metrics class BCLearner(acme.Learner): - """BC learner. + """BC learner. This is the learning component of a BC agent. It takes a Transitions iterator as input and implements update functionality to learn from this iterator. """ - _state: TrainingState - - def __init__(self, - networks: bc_networks.BCNetworks, - random_key: networks_lib.PRNGKey, - loss_fn: losses.BCLoss, - optimizer: optax.GradientTransformation, - prefetching_iterator: Iterator[types.Transition], - num_sgd_steps_per_step: int, - loss_has_aux: bool = False, - logger: Optional[loggers.Logger] = None, - counter: Optional[counting.Counter] = None): - """Behavior Cloning Learner. + _state: TrainingState + + def __init__( + self, + networks: bc_networks.BCNetworks, + random_key: networks_lib.PRNGKey, + loss_fn: losses.BCLoss, + optimizer: optax.GradientTransformation, + prefetching_iterator: Iterator[types.Transition], + num_sgd_steps_per_step: int, + loss_has_aux: bool = False, + logger: Optional[loggers.Logger] = None, + counter: Optional[counting.Counter] = None, + ): + """Behavior Cloning Learner. Args: networks: BC networks @@ -105,96 +115,97 @@ def __init__(self, logger: Logger. counter: Counter. """ - def sgd_step( - state: TrainingState, - transitions: types.Transition, - ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: - - loss_and_grad = jax.value_and_grad( - loss_fn, argnums=1, has_aux=loss_has_aux) - - # Compute losses and their gradients. - key, key_input = jax.random.split(state.key) - loss_result, gradients = loss_and_grad(networks, state.policy_params, - key_input, transitions) - - # Combine the gradient across all devices (by taking their mean). - gradients = jax.lax.pmean(gradients, axis_name=_PMAP_AXIS_NAME) - - # Compute and combine metrics across all devices. - metrics = _create_loss_metrics(loss_has_aux, loss_result, gradients) - metrics = jax.lax.pmean(metrics, axis_name=_PMAP_AXIS_NAME) - - policy_update, optimizer_state = optimizer.update(gradients, - state.optimizer_state, - state.policy_params) - policy_params = optax.apply_updates(state.policy_params, policy_update) - - new_state = TrainingState( - optimizer_state=optimizer_state, - policy_params=policy_params, - key=key, - steps=state.steps + 1, - ) - - return new_state, metrics - - # General learner book-keeping and loggers. - self._counter = counter or counting.Counter(prefix='learner') - self._logger = logger or loggers.make_default_logger( - 'learner', - asynchronous=True, - serialize_fn=utils.fetch_devicearray, - steps_key=self._counter.get_steps_key()) - - # Split the input batch to `num_sgd_steps_per_step` minibatches in order - # to achieve better performance on accelerators. - sgd_step = utils.process_multiple_batches(sgd_step, num_sgd_steps_per_step) - self._sgd_step = jax.pmap(sgd_step, axis_name=_PMAP_AXIS_NAME) - - random_key, init_key = jax.random.split(random_key) - policy_params = networks.policy_network.init(init_key) - optimizer_state = optimizer.init(policy_params) - - # Create initial state. 
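The `loss_has_aux` validation in `_create_loss_metrics` above expects an aux-enabled loss to return a `(scalar_loss, metrics_dict)` pair. A hedged sketch of such a loss (the function name and the extra metric are illustrative, not part of this diff), written against the same `networks.policy_network.apply` / `networks.log_prob` calls used by `bc.logp`:

```python
import jax.numpy as jnp


def logp_with_metrics(networks, params, key, transitions):
    """Illustrative BCLossWithAux: a scalar loss plus a logging dict."""
    logits = networks.policy_network.apply(
        params, transitions.observation, is_training=True, key=key
    )
    logp_action = networks.log_prob(logits, transitions.action)
    loss = -jnp.mean(logp_action)
    # The dict is what loss_has_aux=True asks _create_loss_metrics to merge in.
    return loss, {"mean_logp": jnp.mean(logp_action)}
```

Passed with `loss_has_aux=True` to `BCLearner` (or `BCBuilder`), this matches the `jax.value_and_grad(..., has_aux=loss_has_aux)` call in `sgd_step`.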
- state = TrainingState( - optimizer_state=optimizer_state, - policy_params=policy_params, - key=random_key, - steps=0, - ) - self._state = utils.replicate_in_all_devices(state) - - self._timestamp = None - - self._prefetching_iterator = prefetching_iterator - - def step(self): - # Get a batch of Transitions. - transitions = next(self._prefetching_iterator) - self._state, metrics = self._sgd_step(self._state, transitions) - metrics = utils.get_from_first_device(metrics) - - # Compute elapsed time. - timestamp = time.time() - elapsed_time = timestamp - self._timestamp if self._timestamp else 0 - self._timestamp = timestamp - - # Increment counts and record the current time - counts = self._counter.increment(steps=1, walltime=elapsed_time) - - # Attempts to write the logs. - self._logger.write({**metrics, **counts}) - - def get_variables(self, names: List[str]) -> List[networks_lib.Params]: - variables = { - 'policy': utils.get_from_first_device(self._state.policy_params), - } - return [variables[name] for name in names] - - def save(self) -> TrainingState: - # Serialize only the first replica of parameters and optimizer state. - return jax.tree_map(utils.get_from_first_device, self._state) - - def restore(self, state: TrainingState): - self._state = utils.replicate_in_all_devices(state) + + def sgd_step( + state: TrainingState, transitions: types.Transition, + ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: + + loss_and_grad = jax.value_and_grad(loss_fn, argnums=1, has_aux=loss_has_aux) + + # Compute losses and their gradients. + key, key_input = jax.random.split(state.key) + loss_result, gradients = loss_and_grad( + networks, state.policy_params, key_input, transitions + ) + + # Combine the gradient across all devices (by taking their mean). + gradients = jax.lax.pmean(gradients, axis_name=_PMAP_AXIS_NAME) + + # Compute and combine metrics across all devices. + metrics = _create_loss_metrics(loss_has_aux, loss_result, gradients) + metrics = jax.lax.pmean(metrics, axis_name=_PMAP_AXIS_NAME) + + policy_update, optimizer_state = optimizer.update( + gradients, state.optimizer_state, state.policy_params + ) + policy_params = optax.apply_updates(state.policy_params, policy_update) + + new_state = TrainingState( + optimizer_state=optimizer_state, + policy_params=policy_params, + key=key, + steps=state.steps + 1, + ) + + return new_state, metrics + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter(prefix="learner") + self._logger = logger or loggers.make_default_logger( + "learner", + asynchronous=True, + serialize_fn=utils.fetch_devicearray, + steps_key=self._counter.get_steps_key(), + ) + + # Split the input batch to `num_sgd_steps_per_step` minibatches in order + # to achieve better performance on accelerators. + sgd_step = utils.process_multiple_batches(sgd_step, num_sgd_steps_per_step) + self._sgd_step = jax.pmap(sgd_step, axis_name=_PMAP_AXIS_NAME) + + random_key, init_key = jax.random.split(random_key) + policy_params = networks.policy_network.init(init_key) + optimizer_state = optimizer.init(policy_params) + + # Create initial state. + state = TrainingState( + optimizer_state=optimizer_state, + policy_params=policy_params, + key=random_key, + steps=0, + ) + self._state = utils.replicate_in_all_devices(state) + + self._timestamp = None + + self._prefetching_iterator = prefetching_iterator + + def step(self): + # Get a batch of Transitions. 
+ transitions = next(self._prefetching_iterator) + self._state, metrics = self._sgd_step(self._state, transitions) + metrics = utils.get_from_first_device(metrics) + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Increment counts and record the current time + counts = self._counter.increment(steps=1, walltime=elapsed_time) + + # Attempts to write the logs. + self._logger.write({**metrics, **counts}) + + def get_variables(self, names: List[str]) -> List[networks_lib.Params]: + variables = { + "policy": utils.get_from_first_device(self._state.policy_params), + } + return [variables[name] for name in names] + + def save(self) -> TrainingState: + # Serialize only the first replica of parameters and optimizer state. + return jax.tree_map(utils.get_from_first_device, self._state) + + def restore(self, state: TrainingState): + self._state = utils.replicate_in_all_devices(state) diff --git a/acme/agents/jax/bc/losses.py b/acme/agents/jax/bc/losses.py index de39fbdad2..deed59cbc1 100644 --- a/acme/agents/jax/bc/losses.py +++ b/acme/agents/jax/bc/losses.py @@ -15,18 +15,20 @@ """Offline losses used in variants of BC.""" from typing import Callable, Optional, Tuple, Union +import jax +import jax.numpy as jnp + from acme import types from acme.agents.jax.bc import networks as bc_networks from acme.jax import networks as networks_lib from acme.jax import types as jax_types from acme.utils import loggers -import jax -import jax.numpy as jnp - loss_args = [ - bc_networks.BCNetworks, networks_lib.Params, networks_lib.PRNGKey, - types.Transition + bc_networks.BCNetworks, + networks_lib.Params, + networks_lib.PRNGKey, + types.Transition, ] BCLossWithoutAux = Callable[loss_args, jnp.ndarray] BCLossWithAux = Callable[loss_args, Tuple[jnp.ndarray, loggers.LoggingData]] @@ -34,36 +36,44 @@ def mse() -> BCLossWithoutAux: - """Mean Squared Error loss.""" - - def loss(networks: bc_networks.BCNetworks, params: networks_lib.Params, - key: jax_types.PRNGKey, - transitions: types.Transition) -> jnp.ndarray: - key, key_dropout = jax.random.split(key) - dist_params = networks.policy_network.apply( - params, transitions.observation, is_training=True, key=key_dropout) - action = networks.sample_fn(dist_params, key) - return jnp.mean(jnp.square(action - transitions.action)) + """Mean Squared Error loss.""" + + def loss( + networks: bc_networks.BCNetworks, + params: networks_lib.Params, + key: jax_types.PRNGKey, + transitions: types.Transition, + ) -> jnp.ndarray: + key, key_dropout = jax.random.split(key) + dist_params = networks.policy_network.apply( + params, transitions.observation, is_training=True, key=key_dropout + ) + action = networks.sample_fn(dist_params, key) + return jnp.mean(jnp.square(action - transitions.action)) - return loss + return loss def logp() -> BCLossWithoutAux: - """Log probability loss.""" - - def loss(networks: bc_networks.BCNetworks, params: networks_lib.Params, - key: jax_types.PRNGKey, - transitions: types.Transition) -> jnp.ndarray: - logits = networks.policy_network.apply( - params, transitions.observation, is_training=True, key=key) - logp_action = networks.log_prob(logits, transitions.action) - return -jnp.mean(logp_action) + """Log probability loss.""" + + def loss( + networks: bc_networks.BCNetworks, + params: networks_lib.Params, + key: jax_types.PRNGKey, + transitions: types.Transition, + ) -> jnp.ndarray: + logits = networks.policy_network.apply( + params, 
transitions.observation, is_training=True, key=key + ) + logp_action = networks.log_prob(logits, transitions.action) + return -jnp.mean(logp_action) - return loss + return loss def peerbc(base_loss_fn: BCLossWithoutAux, zeta: float) -> BCLossWithoutAux: - """Peer-BC loss from https://arxiv.org/pdf/2010.01748.pdf. + """Peer-BC loss from https://arxiv.org/pdf/2010.01748.pdf. Args: base_loss_fn: the base loss to add RCAL on top of. @@ -72,29 +82,35 @@ def peerbc(base_loss_fn: BCLossWithoutAux, zeta: float) -> BCLossWithoutAux: The loss. """ - def loss(networks: bc_networks.BCNetworks, params: networks_lib.Params, - key: jax_types.PRNGKey, - transitions: types.Transition) -> jnp.ndarray: - key_perm, key_bc_loss, key_permuted_loss = jax.random.split(key, 3) - - permutation_keys = jax.random.split(key_perm, transitions.action.shape[0]) - permuted_actions = jax.vmap( - jax.random.permutation, in_axes=(0, 0))(permutation_keys, - transitions.action) - permuted_transition = transitions._replace(action=permuted_actions) - bc_loss = base_loss_fn(networks, params, key_bc_loss, transitions) - permuted_loss = base_loss_fn(networks, params, key_permuted_loss, - permuted_transition) - return bc_loss - zeta * permuted_loss + def loss( + networks: bc_networks.BCNetworks, + params: networks_lib.Params, + key: jax_types.PRNGKey, + transitions: types.Transition, + ) -> jnp.ndarray: + key_perm, key_bc_loss, key_permuted_loss = jax.random.split(key, 3) + + permutation_keys = jax.random.split(key_perm, transitions.action.shape[0]) + permuted_actions = jax.vmap(jax.random.permutation, in_axes=(0, 0))( + permutation_keys, transitions.action + ) + permuted_transition = transitions._replace(action=permuted_actions) + bc_loss = base_loss_fn(networks, params, key_bc_loss, transitions) + permuted_loss = base_loss_fn( + networks, params, key_permuted_loss, permuted_transition + ) + return bc_loss - zeta * permuted_loss - return loss + return loss -def rcal(base_loss_fn: BCLossWithoutAux, - discount: float, - alpha: float, - num_bins: Optional[int] = None) -> BCLossWithoutAux: - """https://www.cristal.univ-lille.fr/~pietquin/pdf/AAMAS_2014_BPMGOP.pdf. +def rcal( + base_loss_fn: BCLossWithoutAux, + discount: float, + alpha: float, + num_bins: Optional[int] = None, +) -> BCLossWithoutAux: + """https://www.cristal.univ-lille.fr/~pietquin/pdf/AAMAS_2014_BPMGOP.pdf. Args: base_loss_fn: the base loss to add RCAL on top of. @@ -106,38 +122,42 @@ def rcal(base_loss_fn: BCLossWithoutAux, The loss function. """ - def loss(networks: bc_networks.BCNetworks, params: networks_lib.Params, - key: jax_types.PRNGKey, - transitions: types.Transition) -> jnp.ndarray: - - def logits_fn(key: jax_types.PRNGKey, - observations: networks_lib.Observation, - actions: Optional[networks_lib.Action] = None): - logits = networks.policy_network.apply( - params, observations, key=key, is_training=True) - if num_bins: - logits = jnp.reshape(logits, list(logits.shape[:-1]) + [-1, num_bins]) - if actions is None: - actions = jnp.argmax(logits, axis=-1) - logits_actions = jnp.sum( - jax.nn.one_hot(actions, logits.shape[-1]) * logits, axis=-1) - return logits_actions - - key, key1, key2 = jax.random.split(key, 3) - - logits_a_tm1 = logits_fn(key1, transitions.observation, transitions.action) - logits_a_t = logits_fn(key2, transitions.next_observation) - - # RCAL, by making a parallel between the logits of BC and Q-values, - # defines a regularization loss that encourages the implicit reward - # (inferred by inversing the Bellman Equation) to be sparse. 
- # NOTE: In case of discretized envs jnp.mean goes over batch and num_bins - # dimensions. - regularization_loss = jnp.mean( - jnp.abs(logits_a_tm1 - discount * logits_a_t) - ) - - loss = base_loss_fn(networks, params, key, transitions) - return loss + alpha * regularization_loss - - return loss + def loss( + networks: bc_networks.BCNetworks, + params: networks_lib.Params, + key: jax_types.PRNGKey, + transitions: types.Transition, + ) -> jnp.ndarray: + def logits_fn( + key: jax_types.PRNGKey, + observations: networks_lib.Observation, + actions: Optional[networks_lib.Action] = None, + ): + logits = networks.policy_network.apply( + params, observations, key=key, is_training=True + ) + if num_bins: + logits = jnp.reshape(logits, list(logits.shape[:-1]) + [-1, num_bins]) + if actions is None: + actions = jnp.argmax(logits, axis=-1) + logits_actions = jnp.sum( + jax.nn.one_hot(actions, logits.shape[-1]) * logits, axis=-1 + ) + return logits_actions + + key, key1, key2 = jax.random.split(key, 3) + + logits_a_tm1 = logits_fn(key1, transitions.observation, transitions.action) + logits_a_t = logits_fn(key2, transitions.next_observation) + + # RCAL, by making a parallel between the logits of BC and Q-values, + # defines a regularization loss that encourages the implicit reward + # (inferred by inversing the Bellman Equation) to be sparse. + # NOTE: In case of discretized envs jnp.mean goes over batch and num_bins + # dimensions. + regularization_loss = jnp.mean(jnp.abs(logits_a_tm1 - discount * logits_a_t)) + + loss = base_loss_fn(networks, params, key, transitions) + return loss + alpha * regularization_loss + + return loss diff --git a/acme/agents/jax/bc/networks.py b/acme/agents/jax/bc/networks.py index 1d08af13c3..85e274b711 100644 --- a/acme/agents/jax/bc/networks.py +++ b/acme/agents/jax/bc/networks.py @@ -17,26 +17,28 @@ import dataclasses from typing import Callable, Optional +from typing_extensions import Protocol + from acme.jax import networks as networks_lib from acme.jax import types -from typing_extensions import Protocol class ApplyFn(Protocol): - - def __call__(self, - params: networks_lib.Params, - observation: networks_lib.Observation, - *args, - is_training: bool, - key: Optional[types.PRNGKey] = None, - **kwargs) -> networks_lib.NetworkOutput: - ... + def __call__( + self, + params: networks_lib.Params, + observation: networks_lib.Observation, + *args, + is_training: bool, + key: Optional[types.PRNGKey] = None, + **kwargs + ) -> networks_lib.NetworkOutput: + ... @dataclasses.dataclass class BCPolicyNetwork: - """Holds a pair of pure functions defining a policy network for BC. + """Holds a pair of pure functions defining a policy network for BC. This is a feed-forward network taking params, obs, is_training, key as input. @@ -44,20 +46,22 @@ class BCPolicyNetwork: init: A pure function. Initializes and returns the networks parameters. apply: A pure function. Computes and returns the outputs of a forward pass. 
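Stepping back to the `bc/losses.py` hunk that ends just above: `peerbc` and `rcal` are wrappers over a base loss rather than standalone losses. A short sketch of the composition (the `zeta`, `discount`, and `alpha` values are the ones used in `agent_test.py`, shown here only as examples):

```python
from acme.agents.jax import bc

plain = bc.logp()                                      # base log-probability loss
peer = bc.peerbc(bc.logp(), zeta=0.1)                  # Peer-BC regularisation on top
sparse = bc.rcal(bc.logp(), discount=0.99, alpha=0.1)  # RCAL sparsity regularisation
# Any of these satisfies the BCLoss interface expected by BCLearner / BCBuilder.
```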
""" - init: Callable[[types.PRNGKey], networks_lib.Params] - apply: ApplyFn + + init: Callable[[types.PRNGKey], networks_lib.Params] + apply: ApplyFn -def identity_sample(output: networks_lib.NetworkOutput, - key: types.PRNGKey) -> networks_lib.Action: - """Placeholder sampling function for non-distributional networks.""" - del key - return output +def identity_sample( + output: networks_lib.NetworkOutput, key: types.PRNGKey +) -> networks_lib.Action: + """Placeholder sampling function for non-distributional networks.""" + del key + return output @dataclasses.dataclass class BCNetworks: - """The network and pure functions for the BC agent. + """The network and pure functions for the BC agent. Attributes: policy_network: The policy network. @@ -66,14 +70,16 @@ class BCNetworks: log_prob: A pure function. Computes log-probability for an action. Must be set for distributional networks. Otherwise None. """ - policy_network: BCPolicyNetwork - sample_fn: networks_lib.SampleFn = identity_sample - log_prob: Optional[networks_lib.LogProbFn] = None + + policy_network: BCPolicyNetwork + sample_fn: networks_lib.SampleFn = identity_sample + log_prob: Optional[networks_lib.LogProbFn] = None def convert_to_bc_network( - policy_network: networks_lib.FeedForwardNetwork) -> BCPolicyNetwork: - """Converts a policy network from SAC/TD3/D4PG/.. into a BC policy network. + policy_network: networks_lib.FeedForwardNetwork, +) -> BCPolicyNetwork: + """Converts a policy network from SAC/TD3/D4PG/.. into a BC policy network. Args: policy_network: FeedForwardNetwork taking the observation as input and @@ -83,21 +89,24 @@ def convert_to_bc_network( The BC policy network taking observation, is_training, key as input. """ - def apply(params: networks_lib.Params, - observation: networks_lib.Observation, - *args, - is_training: bool = False, - key: Optional[types.PRNGKey] = None, - **kwargs) -> networks_lib.NetworkOutput: - del is_training, key - return policy_network.apply(params, observation, *args, **kwargs) + def apply( + params: networks_lib.Params, + observation: networks_lib.Observation, + *args, + is_training: bool = False, + key: Optional[types.PRNGKey] = None, + **kwargs + ) -> networks_lib.NetworkOutput: + del is_training, key + return policy_network.apply(params, observation, *args, **kwargs) - return BCPolicyNetwork(policy_network.init, apply) + return BCPolicyNetwork(policy_network.init, apply) def convert_policy_value_to_bc_network( - policy_value_network: networks_lib.FeedForwardNetwork) -> BCPolicyNetwork: - """Converts a policy-value network (e.g. from PPO) into a BC policy network. + policy_value_network: networks_lib.FeedForwardNetwork, +) -> BCPolicyNetwork: + """Converts a policy-value network (e.g. from PPO) into a BC policy network. Args: policy_value_network: FeedForwardNetwork taking the observation as input. @@ -106,15 +115,16 @@ def convert_policy_value_to_bc_network( The BC policy network taking observation, is_training, key as input. 
""" - def apply(params: networks_lib.Params, - observation: networks_lib.Observation, - *args, - is_training: bool = False, - key: Optional[types.PRNGKey] = None, - **kwargs) -> networks_lib.NetworkOutput: - del is_training, key - actions, _ = policy_value_network.apply(params, observation, *args, - **kwargs) - return actions - - return BCPolicyNetwork(policy_value_network.init, apply) + def apply( + params: networks_lib.Params, + observation: networks_lib.Observation, + *args, + is_training: bool = False, + key: Optional[types.PRNGKey] = None, + **kwargs + ) -> networks_lib.NetworkOutput: + del is_training, key + actions, _ = policy_value_network.apply(params, observation, *args, **kwargs) + return actions + + return BCPolicyNetwork(policy_value_network.init, apply) diff --git a/acme/agents/jax/bc/pretraining.py b/acme/agents/jax/bc/pretraining.py index b761578b0e..6ff2346c69 100644 --- a/acme/agents/jax/bc/pretraining.py +++ b/acme/agents/jax/bc/pretraining.py @@ -15,22 +15,23 @@ """Tools to train a policy network with BC.""" from typing import Callable, Iterator +import jax +import optax + from acme import types -from acme.agents.jax.bc import learning -from acme.agents.jax.bc import losses +from acme.agents.jax.bc import learning, losses from acme.agents.jax.bc import networks as bc_networks from acme.jax import networks as networks_lib from acme.jax import utils -import jax -import optax -def train_with_bc(make_demonstrations: Callable[[int], - Iterator[types.Transition]], - networks: bc_networks.BCNetworks, - loss: losses.BCLoss, - num_steps: int = 100000) -> networks_lib.Params: - """Trains the given network with BC and returns the params. +def train_with_bc( + make_demonstrations: Callable[[int], Iterator[types.Transition]], + networks: bc_networks.BCNetworks, + loss: losses.BCLoss, + num_steps: int = 100000, +) -> networks_lib.Params: + """Trains the given network with BC and returns the params. Args: make_demonstrations: A function (batch_size) -> iterator with demonstrations @@ -42,22 +43,22 @@ def train_with_bc(make_demonstrations: Callable[[int], Returns: The trained network params. 
""" - demonstration_iterator = make_demonstrations(256) - prefetching_iterator = utils.sharded_prefetch( - demonstration_iterator, - buffer_size=2, - num_threads=jax.local_device_count()) - - learner = learning.BCLearner( - networks=networks, - random_key=jax.random.PRNGKey(0), - loss_fn=loss, - prefetching_iterator=prefetching_iterator, - optimizer=optax.adam(1e-4), - num_sgd_steps_per_step=1) - - # Train the agent - for _ in range(num_steps): - learner.step() - - return learner.get_variables(['policy'])[0] + demonstration_iterator = make_demonstrations(256) + prefetching_iterator = utils.sharded_prefetch( + demonstration_iterator, buffer_size=2, num_threads=jax.local_device_count() + ) + + learner = learning.BCLearner( + networks=networks, + random_key=jax.random.PRNGKey(0), + loss_fn=loss, + prefetching_iterator=prefetching_iterator, + optimizer=optax.adam(1e-4), + num_sgd_steps_per_step=1, + ) + + # Train the agent + for _ in range(num_steps): + learner.step() + + return learner.get_variables(["policy"])[0] diff --git a/acme/agents/jax/bc/pretraining_test.py b/acme/agents/jax/bc/pretraining_test.py index b298cc8c36..ef6cd8c76c 100644 --- a/acme/agents/jax/bc/pretraining_test.py +++ b/acme/agents/jax/bc/pretraining_test.py @@ -14,81 +14,84 @@ """Tests for bc_initialization.""" -from acme import specs -from acme.agents.jax import bc -from acme.agents.jax import sac -from acme.jax import networks as networks_lib -from acme.jax import utils -from acme.testing import fakes import haiku as hk import jax import numpy as np - from absl.testing import absltest +from acme import specs +from acme.agents.jax import bc, sac +from acme.jax import networks as networks_lib +from acme.jax import utils +from acme.testing import fakes + def make_networks(spec: specs.EnvironmentSpec) -> bc.BCNetworks: - """Creates networks used by the agent.""" + """Creates networks used by the agent.""" - final_layer_size = np.prod(spec.actions.shape, dtype=int) + final_layer_size = np.prod(spec.actions.shape, dtype=int) - def _actor_fn(obs, is_training=False, key=None): - # is_training and key allows to defined train/test dependant modules - # like dropout. - del is_training - del key - network = networks_lib.LayerNormMLP([64, 64, final_layer_size], - activate_final=False) - return jax.nn.tanh(network(obs)) + def _actor_fn(obs, is_training=False, key=None): + # is_training and key allows to defined train/test dependant modules + # like dropout. + del is_training + del key + network = networks_lib.LayerNormMLP( + [64, 64, final_layer_size], activate_final=False + ) + return jax.nn.tanh(network(obs)) - policy = hk.without_apply_rng(hk.transform(_actor_fn)) + policy = hk.without_apply_rng(hk.transform(_actor_fn)) - # Create dummy observations and actions to create network parameters. - dummy_obs = utils.zeros_like(spec.observations) - dummy_obs = utils.add_batch_dim(dummy_obs) - policy_network = bc.BCPolicyNetwork(lambda key: policy.init(key, dummy_obs), - policy.apply) + # Create dummy observations and actions to create network parameters. + dummy_obs = utils.zeros_like(spec.observations) + dummy_obs = utils.add_batch_dim(dummy_obs) + policy_network = bc.BCPolicyNetwork( + lambda key: policy.init(key, dummy_obs), policy.apply + ) - return bc.BCNetworks(policy_network) + return bc.BCNetworks(policy_network) class BcPretrainingTest(absltest.TestCase): + def test_bc_initialization(self): + # Create a fake environment to test with. 
+ environment = fakes.ContinuousEnvironment( + episode_length=10, bounded=True, action_dim=6 + ) + spec = specs.make_environment_spec(environment) - def test_bc_initialization(self): - # Create a fake environment to test with. - environment = fakes.ContinuousEnvironment( - episode_length=10, bounded=True, action_dim=6) - spec = specs.make_environment_spec(environment) + # Construct the agent. + nets = make_networks(spec) - # Construct the agent. - nets = make_networks(spec) + loss = bc.mse() - loss = bc.mse() + bc.pretraining.train_with_bc( + fakes.transition_iterator(environment), nets, loss, num_steps=100 + ) - bc.pretraining.train_with_bc( - fakes.transition_iterator(environment), nets, loss, num_steps=100) + def test_sac_to_bc_networks(self): + # Create a fake environment to test with. + environment = fakes.ContinuousEnvironment( + episode_length=10, bounded=True, action_dim=6 + ) + spec = specs.make_environment_spec(environment) - def test_sac_to_bc_networks(self): - # Create a fake environment to test with. - environment = fakes.ContinuousEnvironment( - episode_length=10, bounded=True, action_dim=6) - spec = specs.make_environment_spec(environment) + sac_nets = sac.make_networks(spec, hidden_layer_sizes=(4, 4)) + bc_nets = bc.convert_to_bc_network(sac_nets.policy_network) - sac_nets = sac.make_networks(spec, hidden_layer_sizes=(4, 4)) - bc_nets = bc.convert_to_bc_network(sac_nets.policy_network) - - rng = jax.random.PRNGKey(0) - dummy_obs = utils.zeros_like(spec.observations) - dummy_obs = utils.add_batch_dim(dummy_obs) + rng = jax.random.PRNGKey(0) + dummy_obs = utils.zeros_like(spec.observations) + dummy_obs = utils.add_batch_dim(dummy_obs) - sac_params = sac_nets.policy_network.init(rng) - sac_output = sac_nets.policy_network.apply(sac_params, dummy_obs) + sac_params = sac_nets.policy_network.init(rng) + sac_output = sac_nets.policy_network.apply(sac_params, dummy_obs) - bc_params = bc_nets.init(rng) - bc_output = bc_nets.apply(bc_params, dummy_obs, is_training=False, key=None) + bc_params = bc_nets.init(rng) + bc_output = bc_nets.apply(bc_params, dummy_obs, is_training=False, key=None) - np.testing.assert_array_equal(sac_output.mode(), bc_output.mode()) + np.testing.assert_array_equal(sac_output.mode(), bc_output.mode()) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/jax/builders.py b/acme/agents/jax/builders.py index 4f19519e44..5b79fda848 100644 --- a/acme/agents/jax/builders.py +++ b/acme/agents/jax/builders.py @@ -18,14 +18,12 @@ import dataclasses from typing import Generic, Iterator, List, Optional -from acme import adders -from acme import core -from acme import specs +import reverb + +from acme import adders, core, specs from acme.jax import networks as networks_lib from acme.jax import types as jax_types -from acme.utils import counting -from acme.utils import loggers -import reverb +from acme.utils import counting, loggers Networks = jax_types.Networks Policy = jax_types.Policy @@ -33,25 +31,25 @@ class OfflineBuilder(abc.ABC, Generic[Networks, Policy, Sample]): - """Interface for defining the components of an offline RL agent. + """Interface for defining the components of an offline RL agent. Implementations of this interface contain a complete specification of a concrete offline RL agent. An instance of this class can be used to build an offline RL agent that operates either locally or in a distributed setup. 
""" - @abc.abstractmethod - def make_learner( - self, - random_key: networks_lib.PRNGKey, - networks: Networks, - dataset: Iterator[Sample], - logger_fn: loggers.LoggerFactory, - environment_spec: specs.EnvironmentSpec, - *, - counter: Optional[counting.Counter] = None, - ) -> core.Learner: - """Creates an instance of the learner. + @abc.abstractmethod + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: Networks, + dataset: Iterator[Sample], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + *, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + """Creates an instance of the learner. Args: random_key: A key for random number generation. @@ -64,15 +62,15 @@ def make_learner( evaluator steps, etc.) distributed throughout the agent. """ - @abc.abstractmethod - def make_actor( - self, - random_key: networks_lib.PRNGKey, - policy: Policy, - environment_spec: specs.EnvironmentSpec, - variable_source: Optional[core.VariableSource] = None, - ) -> core.Actor: - """Create an actor instance to be used for evaluation. + @abc.abstractmethod + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: Policy, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + ) -> core.Actor: + """Create an actor instance to be used for evaluation. Args: random_key: A key for random number generation. @@ -82,11 +80,14 @@ def make_actor( variable_source: A source providing the necessary actor parameters. """ - @abc.abstractmethod - def make_policy(self, networks: Networks, - environment_spec: specs.EnvironmentSpec, - evaluation: bool) -> Policy: - """Creates the agent policy to be used for evaluation. + @abc.abstractmethod + def make_policy( + self, + networks: Networks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool, + ) -> Policy: + """Creates the agent policy to be used for evaluation. Args: networks: struct describing the networks needed to generate the policy. @@ -103,9 +104,10 @@ def make_policy(self, networks: Networks, """ -class ActorLearnerBuilder(OfflineBuilder[Networks, Policy, Sample], - Generic[Networks, Policy, Sample]): - """Defines an interface for defining the components of an RL agent. +class ActorLearnerBuilder( + OfflineBuilder[Networks, Policy, Sample], Generic[Networks, Policy, Sample] +): + """Defines an interface for defining the components of an RL agent. Implementations of this interface contain a complete specification of a concrete RL agent. An instance of this class can be used to build an @@ -113,13 +115,11 @@ class ActorLearnerBuilder(OfflineBuilder[Networks, Policy, Sample], distributed setup. """ - @abc.abstractmethod - def make_replay_tables( - self, - environment_spec: specs.EnvironmentSpec, - policy: Policy, - ) -> List[reverb.Table]: - """Create tables to insert data into. + @abc.abstractmethod + def make_replay_tables( + self, environment_spec: specs.EnvironmentSpec, policy: Policy, + ) -> List[reverb.Table]: + """Create tables to insert data into. Args: environment_spec: A container for all relevant environment specs. @@ -129,39 +129,36 @@ def make_replay_tables( The replay tables used to store the experience the agent uses to train. 
""" - @abc.abstractmethod - def make_dataset_iterator( - self, - replay_client: reverb.Client, - ) -> Iterator[Sample]: - """Create a dataset iterator to use for learning/updating the agent.""" - - @abc.abstractmethod - def make_adder( - self, - replay_client: reverb.Client, - environment_spec: Optional[specs.EnvironmentSpec], - policy: Optional[Policy], - ) -> Optional[adders.Adder]: - """Create an adder which records data generated by the actor/environment. + @abc.abstractmethod + def make_dataset_iterator(self, replay_client: reverb.Client,) -> Iterator[Sample]: + """Create a dataset iterator to use for learning/updating the agent.""" + + @abc.abstractmethod + def make_adder( + self, + replay_client: reverb.Client, + environment_spec: Optional[specs.EnvironmentSpec], + policy: Optional[Policy], + ) -> Optional[adders.Adder]: + """Create an adder which records data generated by the actor/environment. Args: replay_client: Reverb Client which points to the replay server. environment_spec: specs of the environment. policy: Agent's policy which can be used to extract the extras_spec. """ - # TODO(sabela): make the parameters non-optional. - - @abc.abstractmethod - def make_actor( - self, - random_key: networks_lib.PRNGKey, - policy: Policy, - environment_spec: specs.EnvironmentSpec, - variable_source: Optional[core.VariableSource] = None, - adder: Optional[adders.Adder] = None, - ) -> core.Actor: - """Create an actor instance. + # TODO(sabela): make the parameters non-optional. + + @abc.abstractmethod + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: Policy, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + adder: Optional[adders.Adder] = None, + ) -> core.Actor: + """Create an actor instance. Args: random_key: A key for random number generation. @@ -172,18 +169,18 @@ def make_actor( adder: How data is recorded (e.g. added to replay). """ - @abc.abstractmethod - def make_learner( - self, - random_key: networks_lib.PRNGKey, - networks: Networks, - dataset: Iterator[Sample], - logger_fn: loggers.LoggerFactory, - environment_spec: specs.EnvironmentSpec, - replay_client: Optional[reverb.Client] = None, - counter: Optional[counting.Counter] = None, - ) -> core.Learner: - """Creates an instance of the learner. + @abc.abstractmethod + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: Networks, + dataset: Iterator[Sample], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + """Creates an instance of the learner. Args: random_key: A key for random number generation. @@ -199,11 +196,13 @@ def make_learner( actor steps, etc.) distributed throughout the agent. """ - def make_policy(self, - networks: Networks, - environment_spec: specs.EnvironmentSpec, - evaluation: bool = False) -> Policy: - """Creates the agent policy. + def make_policy( + self, + networks: Networks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool = False, + ) -> Policy: + """Creates the agent policy. Creates the agent policy given the collection of network components and environment spec. An optional boolean can be given to indicate if the @@ -220,68 +219,74 @@ def make_policy(self, Returns: Behavior policy or evaluation policy for the agent. """ - # TODO(sabela): make abstract once all agents implement it. 
- del networks, environment_spec, evaluation - raise NotImplementedError + # TODO(sabela): make abstract once all agents implement it. + del networks, environment_spec, evaluation + raise NotImplementedError @dataclasses.dataclass(frozen=True) -class ActorLearnerBuilderWrapper(ActorLearnerBuilder[Networks, Policy, Sample], - Generic[Networks, Policy, Sample]): - """An empty wrapper for ActorLearnerBuilder.""" - - wrapped: ActorLearnerBuilder[Networks, Policy, Sample] - - def make_replay_tables( - self, - environment_spec: specs.EnvironmentSpec, - policy: Policy, - ) -> List[reverb.Table]: - return self.wrapped.make_replay_tables(environment_spec, policy) - - def make_dataset_iterator( - self, - replay_client: reverb.Client, - ) -> Iterator[Sample]: - return self.wrapped.make_dataset_iterator(replay_client) - - def make_adder( - self, - replay_client: reverb.Client, - environment_spec: Optional[specs.EnvironmentSpec], - policy: Optional[Policy], - ) -> Optional[adders.Adder]: - return self.wrapped.make_adder(replay_client, environment_spec, policy) - - def make_actor( - self, - random_key: networks_lib.PRNGKey, - policy: Policy, - environment_spec: specs.EnvironmentSpec, - variable_source: Optional[core.VariableSource] = None, - adder: Optional[adders.Adder] = None, - ) -> core.Actor: - return self.wrapped.make_actor(random_key, policy, environment_spec, - variable_source, adder) - - def make_learner( - self, - random_key: networks_lib.PRNGKey, - networks: Networks, - dataset: Iterator[Sample], - logger_fn: loggers.LoggerFactory, - environment_spec: specs.EnvironmentSpec, - replay_client: Optional[reverb.Client] = None, - counter: Optional[counting.Counter] = None, - ) -> core.Learner: - return self.wrapped.make_learner(random_key, networks, dataset, logger_fn, - environment_spec, replay_client, counter) - - def make_policy(self, - networks: Networks, - environment_spec: specs.EnvironmentSpec, - evaluation: bool = False) -> Policy: - return self.wrapped.make_policy(networks, environment_spec, evaluation) +class ActorLearnerBuilderWrapper( + ActorLearnerBuilder[Networks, Policy, Sample], Generic[Networks, Policy, Sample] +): + """An empty wrapper for ActorLearnerBuilder.""" + + wrapped: ActorLearnerBuilder[Networks, Policy, Sample] + + def make_replay_tables( + self, environment_spec: specs.EnvironmentSpec, policy: Policy, + ) -> List[reverb.Table]: + return self.wrapped.make_replay_tables(environment_spec, policy) + + def make_dataset_iterator(self, replay_client: reverb.Client,) -> Iterator[Sample]: + return self.wrapped.make_dataset_iterator(replay_client) + + def make_adder( + self, + replay_client: reverb.Client, + environment_spec: Optional[specs.EnvironmentSpec], + policy: Optional[Policy], + ) -> Optional[adders.Adder]: + return self.wrapped.make_adder(replay_client, environment_spec, policy) + + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: Policy, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + adder: Optional[adders.Adder] = None, + ) -> core.Actor: + return self.wrapped.make_actor( + random_key, policy, environment_spec, variable_source, adder + ) + + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: Networks, + dataset: Iterator[Sample], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + return self.wrapped.make_learner( + random_key, 
+ networks, + dataset, + logger_fn, + environment_spec, + replay_client, + counter, + ) + + def make_policy( + self, + networks: Networks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool = False, + ) -> Policy: + return self.wrapped.make_policy(networks, environment_spec, evaluation) # TODO(sinopalnikov): deprecated, migrate all users and remove. diff --git a/acme/agents/jax/bve/builder.py b/acme/agents/jax/bve/builder.py index 181c407138..654567d63d 100644 --- a/acme/agents/jax/bve/builder.py +++ b/acme/agents/jax/bve/builder.py @@ -15,71 +15,73 @@ """BVE Builder.""" from typing import Iterator, Optional -from acme import core -from acme import specs +import haiku as hk +import optax + +from acme import core, specs from acme.agents.jax import actor_core as actor_core_lib -from acme.agents.jax import actors -from acme.agents.jax import builders +from acme.agents.jax import actors, builders from acme.agents.jax.bve import losses from acme.agents.jax.bve import networks as bve_networks from acme.agents.jax.dqn import learning_lib from acme.jax import networks as networks_lib from acme.jax import types as jax_types -from acme.jax import utils -from acme.jax import variable_utils -from acme.utils import counting -from acme.utils import loggers -import haiku as hk -import optax +from acme.jax import utils, variable_utils +from acme.utils import counting, loggers -class BVEBuilder(builders.OfflineBuilder[bve_networks.BVENetworks, - actor_core_lib.ActorCore, - utils.PrefetchingSplit]): - """BVE Builder.""" +class BVEBuilder( + builders.OfflineBuilder[ + bve_networks.BVENetworks, actor_core_lib.ActorCore, utils.PrefetchingSplit + ] +): + """BVE Builder.""" - def __init__(self, config): - """Build a BVE agent. + def __init__(self, config): + """Build a BVE agent. Args: config: The config of the BVE agent. """ - self._config = config - - def make_learner(self, - random_key: jax_types.PRNGKey, - networks: bve_networks.BVENetworks, - dataset: Iterator[utils.PrefetchingSplit], - logger_fn: loggers.LoggerFactory, - environment_spec: specs.EnvironmentSpec, - counter: Optional[counting.Counter] = None) -> core.Learner: - del environment_spec - - loss_fn = losses.BVELoss( - discount=self._config.discount, - max_abs_reward=self._config.max_abs_reward, - huber_loss_parameter=self._config.huber_loss_parameter, - ) - - return learning_lib.SGDLearner( - network=networks.policy_network, - random_key=random_key, - optimizer=optax.adam( - self._config.learning_rate, eps=self._config.adam_eps), - target_update_period=self._config.target_update_period, - data_iterator=dataset, - loss_fn=loss_fn, - counter=counter, - num_sgd_steps_per_step=self._config.num_sgd_steps_per_step, - logger=logger_fn('learner')) - - def make_actor( - self, - random_key: jax_types.PRNGKey, - policy: actor_core_lib.ActorCore, - environment_spec: specs.EnvironmentSpec, - variable_source: Optional[core.VariableSource] = None) -> core.Actor: - """Create the actor for the BVE to perform online evals. 
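# A minimal sketch of how the delegation in ActorLearnerBuilderWrapper above can be
# reused: subclass it, override a single method, and inherit pass-through behaviour for
# everything else. `LoggingBuilder` is hypothetical, and the sketch assumes the wrapper
# is importable from acme.agents.jax.builders.
from typing import List

import reverb

from acme import specs
from acme.agents.jax.builders import ActorLearnerBuilderWrapper


class LoggingBuilder(ActorLearnerBuilderWrapper):
    """Logs replay-table creation; every other method still delegates to `wrapped`."""

    def make_replay_tables(
        self, environment_spec: specs.EnvironmentSpec, policy,
    ) -> List[reverb.Table]:
        tables = super().make_replay_tables(environment_spec, policy)
        print(f"Created {len(tables)} replay table(s).")
        return tables


# Usage, given any existing ActorLearnerBuilder instance `base_builder`:
#   builder = LoggingBuilder(wrapped=base_builder)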
+ self._config = config + + def make_learner( + self, + random_key: jax_types.PRNGKey, + networks: bve_networks.BVENetworks, + dataset: Iterator[utils.PrefetchingSplit], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + del environment_spec + + loss_fn = losses.BVELoss( + discount=self._config.discount, + max_abs_reward=self._config.max_abs_reward, + huber_loss_parameter=self._config.huber_loss_parameter, + ) + + return learning_lib.SGDLearner( + network=networks.policy_network, + random_key=random_key, + optimizer=optax.adam(self._config.learning_rate, eps=self._config.adam_eps), + target_update_period=self._config.target_update_period, + data_iterator=dataset, + loss_fn=loss_fn, + counter=counter, + num_sgd_steps_per_step=self._config.num_sgd_steps_per_step, + logger=logger_fn("learner"), + ) + + def make_actor( + self, + random_key: jax_types.PRNGKey, + policy: actor_core_lib.ActorCore, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + ) -> core.Actor: + """Create the actor for the BVE to perform online evals. Args: random_key: prng key. @@ -90,24 +92,29 @@ def make_actor( Returns: Return the actor for the evaluations. """ - del environment_spec - variable_client = variable_utils.VariableClient( - variable_source, 'policy', device='cpu') - return actors.GenericActor(policy, random_key, variable_client) - - def make_policy( - self, - networks: bve_networks.BVENetworks, - environment_spec: specs.EnvironmentSpec, - evaluation: Optional[bool] = False) -> actor_core_lib.ActorCore: - """Creates a policy.""" - del environment_spec, evaluation - - def behavior_policy( - params: hk.Params, key: jax_types.PRNGKey, - observation: networks_lib.Observation) -> networks_lib.Action: - network_output = networks.policy_network.apply( - params, observation, is_training=False) - return networks.sample_fn(network_output, key) - - return actor_core_lib.batched_feed_forward_to_actor_core(behavior_policy) + del environment_spec + variable_client = variable_utils.VariableClient( + variable_source, "policy", device="cpu" + ) + return actors.GenericActor(policy, random_key, variable_client) + + def make_policy( + self, + networks: bve_networks.BVENetworks, + environment_spec: specs.EnvironmentSpec, + evaluation: Optional[bool] = False, + ) -> actor_core_lib.ActorCore: + """Creates a policy.""" + del environment_spec, evaluation + + def behavior_policy( + params: hk.Params, + key: jax_types.PRNGKey, + observation: networks_lib.Observation, + ) -> networks_lib.Action: + network_output = networks.policy_network.apply( + params, observation, is_training=False + ) + return networks.sample_fn(network_output, key) + + return actor_core_lib.batched_feed_forward_to_actor_core(behavior_policy) diff --git a/acme/agents/jax/bve/config.py b/acme/agents/jax/bve/config.py index cdcfa0c5a9..559e62662d 100644 --- a/acme/agents/jax/bve/config.py +++ b/acme/agents/jax/bve/config.py @@ -22,7 +22,7 @@ @dataclasses.dataclass class BVEConfig: - """Configuration options for BVE agent. + """Configuration options for BVE agent. Attributes: epsilon: for use by epsilon-greedy policies. If multiple, the epsilons are @@ -41,18 +41,19 @@ class BVEConfig: num_sgd_steps_per_step: How many gradient updates to perform per learner step. """ - epsilon: Union[float, Sequence[float]] = 0.05 - # TODO(b/191706065): update all clients and remove this field. 
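# A toy pure policy with the (params, key, observation) -> action signature that
# batched_feed_forward_to_actor_core expects, mirroring behavior_policy above. The
# linear "network" is made up; it only illustrates that the function must be a pure,
# batched map over observations.
import jax.numpy as jnp

from acme.agents.jax import actor_core as actor_core_lib


def toy_policy(params, key, observation):
    del key  # A deterministic policy ignores the random key.
    q_values = observation @ params        # params: [obs_dim, num_actions]
    return jnp.argmax(q_values, axis=-1)   # greedy action per batch element


toy_actor_core = actor_core_lib.batched_feed_forward_to_actor_core(toy_policy)
# The resulting actor core can be driven the same way as the one built in make_policy,
# e.g. by actors.GenericActor together with a VariableClient.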
- seed: int = 1 - - # Learning rule - learning_rate: Union[float, Callable[[int], float]] = 3e-4 - adam_eps: float = 1e-8 # Eps for Adam optimizer. - discount: float = 0.99 # Discount rate applied to value per timestep. - target_update_period: int = 2500 # Update target network every period. - max_gradient_norm: float = np.inf # For gradient clipping. - max_abs_reward: float = 1. # Maximum absolute value to clip the rewards. - huber_loss_parameter: float = 1. # Huber loss delta parameter. - batch_size: int = 256 # Minibatch size. - prefetch_size = 500 # The amount of data to prefetch into the memory. - num_sgd_steps_per_step: int = 1 + + epsilon: Union[float, Sequence[float]] = 0.05 + # TODO(b/191706065): update all clients and remove this field. + seed: int = 1 + + # Learning rule + learning_rate: Union[float, Callable[[int], float]] = 3e-4 + adam_eps: float = 1e-8 # Eps for Adam optimizer. + discount: float = 0.99 # Discount rate applied to value per timestep. + target_update_period: int = 2500 # Update target network every period. + max_gradient_norm: float = np.inf # For gradient clipping. + max_abs_reward: float = 1.0 # Maximum absolute value to clip the rewards. + huber_loss_parameter: float = 1.0 # Huber loss delta parameter. + batch_size: int = 256 # Minibatch size. + prefetch_size = 500 # The amount of data to prefetch into the memory. + num_sgd_steps_per_step: int = 1 diff --git a/acme/agents/jax/bve/losses.py b/acme/agents/jax/bve/losses.py index cac6e278c8..003e29644f 100644 --- a/acme/agents/jax/bve/losses.py +++ b/acme/agents/jax/bve/losses.py @@ -16,18 +16,19 @@ import dataclasses from typing import Tuple -from acme import types -from acme.agents.jax import dqn -from acme.jax import networks as networks_lib import jax import jax.numpy as jnp import reverb import rlax +from acme import types +from acme.agents.jax import dqn +from acme.jax import networks as networks_lib + @dataclasses.dataclass class BVELoss(dqn.LossFn): - """This loss implements TD-loss to estimate behavior value. + """This loss implements TD-loss to estimate behavior value. This loss function uses the next action to learn with the SARSA tuples. It is intended to be used with dqn.SGDLearner. The method was proposed @@ -35,43 +36,51 @@ class BVELoss(dqn.LossFn): the extrapolation error in offline RL setting: https://arxiv.org/abs/2103.09575 """ - discount: float = 0.99 - max_abs_reward: float = 1. - huber_loss_parameter: float = 1. - def __call__( - self, - network: networks_lib.TypedFeedForwardNetwork, - params: networks_lib.Params, - target_params: networks_lib.Params, - batch: reverb.ReplaySample, - key: networks_lib.PRNGKey, - ) -> Tuple[jax.Array, dqn.LossExtra]: - """Calculate a loss on a single batch of data.""" - transitions: types.Transition = batch.data + discount: float = 0.99 + max_abs_reward: float = 1.0 + huber_loss_parameter: float = 1.0 + + def __call__( + self, + network: networks_lib.TypedFeedForwardNetwork, + params: networks_lib.Params, + target_params: networks_lib.Params, + batch: reverb.ReplaySample, + key: networks_lib.PRNGKey, + ) -> Tuple[jax.Array, dqn.LossExtra]: + """Calculate a loss on a single batch of data.""" + transitions: types.Transition = batch.data - # Forward pass. - key1, key2 = jax.random.split(key) - q_tm1 = network.apply( - params, transitions.observation, is_training=True, key=key1) - q_t_value = network.apply( - target_params, transitions.next_observation, is_training=True, key=key2) + # Forward pass. 
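# Usage sketch for BVEConfig as defined above: every annotated field can be overridden
# through the constructor. Note that `prefetch_size = 500` carries no annotation, so
# dataclasses treats it as a plain class attribute rather than a field, and it cannot be
# set via the constructor.
from acme.agents.jax.bve.config import BVEConfig

config = BVEConfig(learning_rate=1e-4, batch_size=128, num_sgd_steps_per_step=4)
assert config.discount == 0.99          # untouched fields keep their defaults
assert config.prefetch_size == 500      # class attribute, shared by all instances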
+ key1, key2 = jax.random.split(key) + q_tm1 = network.apply( + params, transitions.observation, is_training=True, key=key1 + ) + q_t_value = network.apply( + target_params, transitions.next_observation, is_training=True, key=key2 + ) - # Cast and clip rewards. - d_t = (transitions.discount * self.discount).astype(jnp.float32) - r_t = jnp.clip(transitions.reward, -self.max_abs_reward, - self.max_abs_reward).astype(jnp.float32) + # Cast and clip rewards. + d_t = (transitions.discount * self.discount).astype(jnp.float32) + r_t = jnp.clip( + transitions.reward, -self.max_abs_reward, self.max_abs_reward + ).astype(jnp.float32) - # Compute double Q-learning n-step TD-error. - batch_error = jax.vmap(rlax.sarsa) - next_action = transitions.extras['next_action'] - td_error = batch_error(q_tm1, transitions.action, r_t, d_t, q_t_value, - next_action) - batch_loss = rlax.huber_loss(td_error, self.huber_loss_parameter) + # Compute double Q-learning n-step TD-error. + batch_error = jax.vmap(rlax.sarsa) + next_action = transitions.extras["next_action"] + td_error = batch_error( + q_tm1, transitions.action, r_t, d_t, q_t_value, next_action + ) + batch_loss = rlax.huber_loss(td_error, self.huber_loss_parameter) - # Average: - loss = jnp.mean(batch_loss) # [] - metrics = {'td_error': td_error, 'batch_loss': batch_loss} - return loss, dqn.LossExtra( - metrics=metrics, - reverb_priorities=jnp.abs(td_error).astype(jnp.float64)) + # Average: + loss = jnp.mean(batch_loss) # [] + metrics = {"td_error": td_error, "batch_loss": batch_loss} + return ( + loss, + dqn.LossExtra( + metrics=metrics, reverb_priorities=jnp.abs(td_error).astype(jnp.float64) + ), + ) diff --git a/acme/agents/jax/bve/networks.py b/acme/agents/jax/bve/networks.py index c162efb6bb..864089bd0c 100644 --- a/acme/agents/jax/bve/networks.py +++ b/acme/agents/jax/bve/networks.py @@ -22,13 +22,14 @@ @dataclasses.dataclass class BVENetworks: - """The network and pure functions for the BVE agent. + """The network and pure functions for the BVE agent. Attributes: policy_network: The policy network. sample_fn: A pure function. Samples an action based on the network output. log_prob: A pure function. Computes log-probability for an action. 
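# A hand-rolled version of the quantities BVELoss delegates to rlax, on a single
# transition, to make the SARSA target explicit: td = r + d * Q(s', a') - Q(s, a),
# where a' is the next action stored in the dataset. All numbers are made up.
import jax.numpy as jnp

q_tm1 = jnp.array([1.0, 2.0, 0.5])   # Q(s, .)
q_t = jnp.array([0.0, 1.5, 3.0])     # Q(s', .)
a_tm1, next_action = 1, 2            # dataset action in s and in s'
r_t, d_t = 0.5, 0.99                 # clipped reward and discount * gamma

td_error = r_t + d_t * q_t[next_action] - q_tm1[a_tm1]

# Huber loss with delta = 1: quadratic near zero, linear in the tails.
delta = 1.0
abs_td = jnp.abs(td_error)
batch_loss = jnp.where(
    abs_td <= delta, 0.5 * td_error**2, delta * (abs_td - 0.5 * delta)
)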
""" - policy_network: networks_lib.TypedFeedForwardNetwork - sample_fn: networks_lib.SampleFn - log_prob: Optional[networks_lib.LogProbFn] = None + + policy_network: networks_lib.TypedFeedForwardNetwork + sample_fn: networks_lib.SampleFn + log_prob: Optional[networks_lib.LogProbFn] = None diff --git a/acme/agents/jax/cql/__init__.py b/acme/agents/jax/cql/__init__.py index 238ed87849..fa813dd48c 100644 --- a/acme/agents/jax/cql/__init__.py +++ b/acme/agents/jax/cql/__init__.py @@ -17,5 +17,4 @@ from acme.agents.jax.cql.builder import CQLBuilder from acme.agents.jax.cql.config import CQLConfig from acme.agents.jax.cql.learning import CQLLearner -from acme.agents.jax.cql.networks import CQLNetworks -from acme.agents.jax.cql.networks import make_networks +from acme.agents.jax.cql.networks import CQLNetworks, make_networks diff --git a/acme/agents/jax/cql/agent_test.py b/acme/agents/jax/cql/agent_test.py index 2babe68917..b5a8f8482c 100644 --- a/acme/agents/jax/cql/agent_test.py +++ b/acme/agents/jax/cql/agent_test.py @@ -14,49 +14,49 @@ """Tests for the CQL agent.""" -from acme import specs -from acme.agents.jax import cql -from acme.testing import fakes import jax import optax - from absl.testing import absltest +from acme import specs +from acme.agents.jax import cql +from acme.testing import fakes -class CQLTest(absltest.TestCase): - def test_train(self): - seed = 0 - num_iterations = 6 - batch_size = 64 - - # Create a fake environment to test with. - environment = fakes.ContinuousEnvironment( - episode_length=10, bounded=True, action_dim=6) - spec = specs.make_environment_spec(environment) - - # Construct the agent. - networks = cql.make_networks( - spec, hidden_layer_sizes=(8, 8)) - dataset = fakes.transition_iterator(environment) - key = jax.random.PRNGKey(seed) - learner = cql.CQLLearner( - batch_size, - networks, - key, - demonstrations=dataset(batch_size), - policy_optimizer=optax.adam(3e-5), - critic_optimizer=optax.adam(3e-4), - fixed_cql_coefficient=5., - cql_lagrange_threshold=None, - target_entropy=0.1, - num_bc_iters=2, - num_sgd_steps_per_step=1) - - # Train the agent - for _ in range(num_iterations): - learner.step() - - -if __name__ == '__main__': - absltest.main() +class CQLTest(absltest.TestCase): + def test_train(self): + seed = 0 + num_iterations = 6 + batch_size = 64 + + # Create a fake environment to test with. + environment = fakes.ContinuousEnvironment( + episode_length=10, bounded=True, action_dim=6 + ) + spec = specs.make_environment_spec(environment) + + # Construct the agent. 
+ networks = cql.make_networks(spec, hidden_layer_sizes=(8, 8)) + dataset = fakes.transition_iterator(environment) + key = jax.random.PRNGKey(seed) + learner = cql.CQLLearner( + batch_size, + networks, + key, + demonstrations=dataset(batch_size), + policy_optimizer=optax.adam(3e-5), + critic_optimizer=optax.adam(3e-4), + fixed_cql_coefficient=5.0, + cql_lagrange_threshold=None, + target_entropy=0.1, + num_bc_iters=2, + num_sgd_steps_per_step=1, + ) + + # Train the agent + for _ in range(num_iterations): + learner.step() + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/jax/cql/builder.py b/acme/agents/jax/cql/builder.py index 8fe8173c57..937c96dbf2 100644 --- a/acme/agents/jax/cql/builder.py +++ b/acme/agents/jax/cql/builder.py @@ -15,95 +15,101 @@ """CQL Builder.""" from typing import Iterator, Optional -from acme import core -from acme import specs -from acme import types +import optax + +from acme import core, specs, types from acme.agents.jax import actor_core as actor_core_lib -from acme.agents.jax import actors -from acme.agents.jax import builders +from acme.agents.jax import actors, builders from acme.agents.jax.cql import config as cql_config from acme.agents.jax.cql import learning from acme.agents.jax.cql import networks as cql_networks from acme.jax import networks as networks_lib from acme.jax import variable_utils -from acme.utils import counting -from acme.utils import loggers -import optax +from acme.utils import counting, loggers -class CQLBuilder(builders.OfflineBuilder[cql_networks.CQLNetworks, - actor_core_lib.FeedForwardPolicy, - types.Transition]): - """CQL Builder.""" +class CQLBuilder( + builders.OfflineBuilder[ + cql_networks.CQLNetworks, actor_core_lib.FeedForwardPolicy, types.Transition + ] +): + """CQL Builder.""" - def __init__( - self, - config: cql_config.CQLConfig, - ): - """Creates a CQL learner, an evaluation policy and an eval actor. + def __init__( + self, config: cql_config.CQLConfig, + ): + """Creates a CQL learner, an evaluation policy and an eval actor. Args: config: a config with CQL hps. 
""" - self._config = config + self._config = config - def make_learner( - self, - random_key: networks_lib.PRNGKey, - networks: cql_networks.CQLNetworks, - dataset: Iterator[types.Transition], - logger_fn: loggers.LoggerFactory, - environment_spec: specs.EnvironmentSpec, - *, - counter: Optional[counting.Counter] = None, - ) -> core.Learner: - del environment_spec + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: cql_networks.CQLNetworks, + dataset: Iterator[types.Transition], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + *, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + del environment_spec - return learning.CQLLearner( - batch_size=self._config.batch_size, - networks=networks, - random_key=random_key, - demonstrations=dataset, - policy_optimizer=optax.adam(self._config.policy_learning_rate), - critic_optimizer=optax.adam(self._config.critic_learning_rate), - tau=self._config.tau, - fixed_cql_coefficient=self._config.fixed_cql_coefficient, - cql_lagrange_threshold=self._config.cql_lagrange_threshold, - cql_num_samples=self._config.cql_num_samples, - num_sgd_steps_per_step=self._config.num_sgd_steps_per_step, - reward_scale=self._config.reward_scale, - discount=self._config.discount, - fixed_entropy_coefficient=self._config.fixed_entropy_coefficient, - target_entropy=self._config.target_entropy, - num_bc_iters=self._config.num_bc_iters, - logger=logger_fn('learner'), - counter=counter) + return learning.CQLLearner( + batch_size=self._config.batch_size, + networks=networks, + random_key=random_key, + demonstrations=dataset, + policy_optimizer=optax.adam(self._config.policy_learning_rate), + critic_optimizer=optax.adam(self._config.critic_learning_rate), + tau=self._config.tau, + fixed_cql_coefficient=self._config.fixed_cql_coefficient, + cql_lagrange_threshold=self._config.cql_lagrange_threshold, + cql_num_samples=self._config.cql_num_samples, + num_sgd_steps_per_step=self._config.num_sgd_steps_per_step, + reward_scale=self._config.reward_scale, + discount=self._config.discount, + fixed_entropy_coefficient=self._config.fixed_entropy_coefficient, + target_entropy=self._config.target_entropy, + num_bc_iters=self._config.num_bc_iters, + logger=logger_fn("learner"), + counter=counter, + ) - def make_actor( - self, - random_key: networks_lib.PRNGKey, - policy: actor_core_lib.FeedForwardPolicy, - environment_spec: specs.EnvironmentSpec, - variable_source: Optional[core.VariableSource] = None, - ) -> core.Actor: - del environment_spec - assert variable_source is not None - actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy) - variable_client = variable_utils.VariableClient( - variable_source, 'policy', device='cpu') - return actors.GenericActor( - actor_core, random_key, variable_client, backend='cpu') + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: actor_core_lib.FeedForwardPolicy, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + ) -> core.Actor: + del environment_spec + assert variable_source is not None + actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy) + variable_client = variable_utils.VariableClient( + variable_source, "policy", device="cpu" + ) + return actors.GenericActor( + actor_core, random_key, variable_client, backend="cpu" + ) - def make_policy(self, networks: cql_networks.CQLNetworks, - environment_spec: specs.EnvironmentSpec, - evaluation: bool) -> actor_core_lib.FeedForwardPolicy: - 
"""Construct the policy.""" - del environment_spec, evaluation + def make_policy( + self, + networks: cql_networks.CQLNetworks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool, + ) -> actor_core_lib.FeedForwardPolicy: + """Construct the policy.""" + del environment_spec, evaluation - def evaluation_policy( - params: networks_lib.Params, key: networks_lib.PRNGKey, - observation: networks_lib.Observation) -> networks_lib.Action: - dist_params = networks.policy_network.apply(params, observation) - return networks.sample_eval(dist_params, key) + def evaluation_policy( + params: networks_lib.Params, + key: networks_lib.PRNGKey, + observation: networks_lib.Observation, + ) -> networks_lib.Action: + dist_params = networks.policy_network.apply(params, observation) + return networks.sample_eval(dist_params, key) - return evaluation_policy + return evaluation_policy diff --git a/acme/agents/jax/cql/config.py b/acme/agents/jax/cql/config.py index 44b2b26b50..12b891e68c 100644 --- a/acme/agents/jax/cql/config.py +++ b/acme/agents/jax/cql/config.py @@ -19,7 +19,7 @@ @dataclasses.dataclass class CQLConfig: - """Configuration options for CQL. + """Configuration options for CQL. Attributes: batch_size: batch size. @@ -42,17 +42,18 @@ class CQLConfig: target_entropy: target entropy when using adapdative entropy bonus. num_bc_iters: number of BC steps for actor initialization. """ - batch_size: int = 256 - policy_learning_rate: float = 3e-5 - critic_learning_rate: float = 3e-4 - fixed_cql_coefficient: float = 5. - tau: float = 0.005 - fixed_cql_coefficient: Optional[float] = 5. - cql_lagrange_threshold: Optional[float] = None - cql_num_samples: int = 10 - num_sgd_steps_per_step: int = 1 - reward_scale: float = 1.0 - discount: float = 0.99 - fixed_entropy_coefficient: Optional[float] = 0. 
- target_entropy: Optional[float] = 0 - num_bc_iters: int = 50_000 + + batch_size: int = 256 + policy_learning_rate: float = 3e-5 + critic_learning_rate: float = 3e-4 + fixed_cql_coefficient: float = 5.0 + tau: float = 0.005 + fixed_cql_coefficient: Optional[float] = 5.0 + cql_lagrange_threshold: Optional[float] = None + cql_num_samples: int = 10 + num_sgd_steps_per_step: int = 1 + reward_scale: float = 1.0 + discount: float = 0.99 + fixed_entropy_coefficient: Optional[float] = 0.0 + target_entropy: Optional[float] = 0 + num_bc_iters: int = 50_000 diff --git a/acme/agents/jax/cql/learning.py b/acme/agents/jax/cql/learning.py index 82b1406861..0a241ebccf 100644 --- a/acme/agents/jax/cql/learning.py +++ b/acme/agents/jax/cql/learning.py @@ -16,77 +16,79 @@ import time from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Tuple -import acme -from acme import types -from acme.agents.jax.cql.networks import apply_and_sample_n -from acme.agents.jax.cql.networks import CQLNetworks -from acme.jax import networks as networks_lib -from acme.jax import utils -from acme.utils import counting -from acme.utils import loggers + import jax import jax.numpy as jnp import optax +import acme +from acme import types +from acme.agents.jax.cql.networks import CQLNetworks, apply_and_sample_n +from acme.jax import networks as networks_lib +from acme.jax import utils +from acme.utils import counting, loggers -_CQL_COEFFICIENT_MAX_VALUE = 1E6 +_CQL_COEFFICIENT_MAX_VALUE = 1e6 _CQL_GRAD_CLIPPING_VALUE = 40 class TrainingState(NamedTuple): - """Contains training state for the learner.""" - policy_optimizer_state: optax.OptState - critic_optimizer_state: optax.OptState - policy_params: networks_lib.Params - critic_params: networks_lib.Params - target_critic_params: networks_lib.Params - key: networks_lib.PRNGKey - - # Optimizer and value of the alpha parameter from SAC (entropy temperature). - # These fields are only used with an adaptive coefficient (when - # fixed_entropy_coefficeint is None in the CQLLearner) - alpha_optimizer_state: Optional[optax.OptState] = None - log_sac_alpha: Optional[networks_lib.Params] = None - - # Optimizer and value of the alpha parameter from CQL (regularization - # coefficient). - # These fields are only used with an adaptive coefficient (when - # fixed_cql_coefficiennt is None in the CQLLearner) - cql_optimizer_state: Optional[optax.OptState] = None - log_cql_alpha: Optional[networks_lib.Params] = None - - steps: int = 0 + """Contains training state for the learner.""" + + policy_optimizer_state: optax.OptState + critic_optimizer_state: optax.OptState + policy_params: networks_lib.Params + critic_params: networks_lib.Params + target_critic_params: networks_lib.Params + key: networks_lib.PRNGKey + + # Optimizer and value of the alpha parameter from SAC (entropy temperature). + # These fields are only used with an adaptive coefficient (when + # fixed_entropy_coefficeint is None in the CQLLearner) + alpha_optimizer_state: Optional[optax.OptState] = None + log_sac_alpha: Optional[networks_lib.Params] = None + + # Optimizer and value of the alpha parameter from CQL (regularization + # coefficient). + # These fields are only used with an adaptive coefficient (when + # fixed_cql_coefficiennt is None in the CQLLearner) + cql_optimizer_state: Optional[optax.OptState] = None + log_cql_alpha: Optional[networks_lib.Params] = None + + steps: int = 0 class CQLLearner(acme.Learner): - """CQL learner. + """CQL learner. 
Learning component of the Conservative Q-Learning algorithm from [Kumar et al., 2020] https://arxiv.org/abs/2006.04779. """ - _state: TrainingState - - def __init__(self, - batch_size: int, - networks: CQLNetworks, - random_key: networks_lib.PRNGKey, - demonstrations: Iterator[types.Transition], - policy_optimizer: optax.GradientTransformation, - critic_optimizer: optax.GradientTransformation, - tau: float = 0.005, - fixed_cql_coefficient: Optional[float] = None, - cql_lagrange_threshold: Optional[float] = None, - cql_num_samples: int = 10, - num_sgd_steps_per_step: int = 1, - reward_scale: float = 1.0, - discount: float = 0.99, - fixed_entropy_coefficient: Optional[float] = None, - target_entropy: Optional[float] = 0, - num_bc_iters: int = 50_000, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None): - """Initializes the CQL learner. + _state: TrainingState + + def __init__( + self, + batch_size: int, + networks: CQLNetworks, + random_key: networks_lib.PRNGKey, + demonstrations: Iterator[types.Transition], + policy_optimizer: optax.GradientTransformation, + critic_optimizer: optax.GradientTransformation, + tau: float = 0.005, + fixed_cql_coefficient: Optional[float] = None, + cql_lagrange_threshold: Optional[float] = None, + cql_num_samples: int = 10, + num_sgd_steps_per_step: int = 1, + reward_scale: float = 1.0, + discount: float = 0.99, + fixed_entropy_coefficient: Optional[float] = None, + target_entropy: Optional[float] = 0, + num_bc_iters: int = 50_000, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + ): + """Initializes the CQL learner. Args: batch_size: batch size. @@ -114,368 +116,440 @@ def __init__(self, counter: counter object used to keep track of steps. logger: logger object to be used by learner. """ - self._num_bc_iters = num_bc_iters - adaptive_entropy_coefficient = fixed_entropy_coefficient is None - action_spec = networks.environment_specs.actions - if adaptive_entropy_coefficient: - # sac_alpha is the temperature parameter that determines the relative - # importance of the entropy term versus the reward. 
- log_sac_alpha = jnp.asarray(0., dtype=jnp.float32) - alpha_optimizer = optax.adam(learning_rate=3e-4) - alpha_optimizer_state = alpha_optimizer.init(log_sac_alpha) - else: - if target_entropy: - raise ValueError('target_entropy should not be set when ' - 'fixed_entropy_coefficient is provided') - - adaptive_cql_coefficient = fixed_cql_coefficient is None - if adaptive_cql_coefficient: - log_cql_alpha = jnp.asarray(0., dtype=jnp.float32) - cql_optimizer = optax.adam(learning_rate=3e-4) - cql_optimizer_state = cql_optimizer.init(log_cql_alpha) - else: - if cql_lagrange_threshold: - raise ValueError('cql_lagrange_threshold should not be set when ' - 'fixed_cql_coefficient is provided') - - def alpha_loss(log_sac_alpha: jnp.ndarray, - policy_params: networks_lib.Params, - transitions: types.Transition, - key: jnp.ndarray) -> jnp.ndarray: - """Eq 18 from https://arxiv.org/pdf/1812.05905.pdf.""" - dist_params = networks.policy_network.apply(policy_params, - transitions.observation) - action = networks.sample(dist_params, key) - log_prob = networks.log_prob(dist_params, action) - sac_alpha = jnp.exp(log_sac_alpha) - sac_alpha_loss = sac_alpha * jax.lax.stop_gradient(-log_prob - - target_entropy) - return jnp.mean(sac_alpha_loss) - - def sac_critic_loss(q_old_action: jnp.ndarray, - policy_params: networks_lib.Params, - target_critic_params: networks_lib.Params, - transitions: types.Transition, - key: networks_lib.PRNGKey) -> jnp.ndarray: - """Computes the SAC part of the loss.""" - next_dist_params = networks.policy_network.apply( - policy_params, transitions.next_observation) - next_action = networks.sample(next_dist_params, key) - next_q = networks.critic_network.apply(target_critic_params, - transitions.next_observation, - next_action) - next_v = jnp.min(next_q, axis=-1) - target_q = jax.lax.stop_gradient(transitions.reward * reward_scale + - transitions.discount * discount * next_v) - return jnp.mean(jnp.square(q_old_action - jnp.expand_dims(target_q, -1))) - - def batched_critic(actions: jnp.ndarray, critic_params: networks_lib.Params, - observation: jnp.ndarray) -> jnp.ndarray: - """Applies the critic network to a batch of sampled actions.""" - actions = jax.lax.stop_gradient(actions) - tiled_actions = jnp.reshape(actions, (batch_size * cql_num_samples, -1)) - tiled_states = jnp.tile(observation, [cql_num_samples, 1]) - tiled_q = networks.critic_network.apply(critic_params, tiled_states, - tiled_actions) - return jnp.reshape(tiled_q, (cql_num_samples, batch_size, -1)) - - def cql_critic_loss(q_old_action: jnp.ndarray, - critic_params: networks_lib.Params, - policy_params: networks_lib.Params, - transitions: types.Transition, - key: networks_lib.PRNGKey) -> jnp.ndarray: - """Computes the CQL part of the loss.""" - # The CQL part of the loss is - # logsumexp(Q(s,·)) - Q(s,a), - # where s is the currrent state, and a the action in the dataset (so - # Q(s,a) is simply q_old_action. - # We need to estimate logsumexp(Q). This is done with importance sampling - # (IS). This function implements the unlabeled equation page 29, Appx. F, - # in https://arxiv.org/abs/2006.04779. - # Here, IS is done with the uniform distribution and the policy in the - # current state s. In their implementation, the authors also add the - # policy in the transiting state s': - # https://github.com/aviralkumar2907/CQL/blob/master/d4rl/rlkit/torch/sac/cql.py, - # (l. 233-236). 
- - key_policy, key_policy_next, key_uniform = jax.random.split(key, 3) - - def sampled_q(obs, key): - actions, log_probs = apply_and_sample_n( - key, networks, policy_params, obs, cql_num_samples) - return batched_critic(actions, critic_params, - transitions.observation) - jax.lax.stop_gradient( - jnp.expand_dims(log_probs, -1)) - - # Sample wrt policy in s - sampled_q_from_policy = sampled_q(transitions.observation, key_policy) - - # Sample wrt policy in s' - sampled_q_from_policy_next = sampled_q(transitions.next_observation, - key_policy_next) - - # Sample wrt uniform - actions_uniform = jax.random.uniform( - key_uniform, (cql_num_samples, batch_size) + action_spec.shape, - minval=action_spec.minimum, maxval=action_spec.maximum) - log_prob_uniform = -jnp.sum( - jnp.log(action_spec.maximum - action_spec.minimum)) - sampled_q_from_uniform = ( - batched_critic(actions_uniform, critic_params, - transitions.observation) - log_prob_uniform) - - # Combine the samplings - combined = jnp.concatenate( - (sampled_q_from_uniform, sampled_q_from_policy, - sampled_q_from_policy_next), - axis=0) - lse_q = jax.nn.logsumexp(combined, axis=0, b=1. / (3 * cql_num_samples)) - - return jnp.mean(lse_q - q_old_action) - - def critic_loss(critic_params: networks_lib.Params, - policy_params: networks_lib.Params, - target_critic_params: networks_lib.Params, - cql_alpha: jnp.ndarray, transitions: types.Transition, - key: networks_lib.PRNGKey) -> jnp.ndarray: - """Computes the full critic loss.""" - key_cql, key_sac = jax.random.split(key, 2) - q_old_action = networks.critic_network.apply(critic_params, - transitions.observation, - transitions.action) - cql_loss = cql_critic_loss(q_old_action, critic_params, policy_params, - transitions, key_cql) - sac_loss = sac_critic_loss(q_old_action, policy_params, - target_critic_params, transitions, key_sac) - return cql_alpha * cql_loss + sac_loss - - def cql_lagrange_loss(log_cql_alpha: jnp.ndarray, - critic_params: networks_lib.Params, - policy_params: networks_lib.Params, - transitions: types.Transition, - key: jnp.ndarray) -> jnp.ndarray: - """Computes the loss that optimizes the cql coefficient.""" - cql_alpha = jnp.exp(log_cql_alpha) - q_old_action = networks.critic_network.apply(critic_params, - transitions.observation, - transitions.action) - return -cql_alpha * ( - cql_critic_loss(q_old_action, critic_params, policy_params, - transitions, key) - cql_lagrange_threshold) - - def actor_loss(policy_params: networks_lib.Params, - critic_params: networks_lib.Params, sac_alpha: jnp.ndarray, - transitions: types.Transition, key: jnp.ndarray, - in_initial_bc_iters: bool) -> jnp.ndarray: - """Computes the loss for the policy.""" - dist_params = networks.policy_network.apply(policy_params, - transitions.observation) - if in_initial_bc_iters: - log_prob = networks.log_prob(dist_params, transitions.action) - actor_loss = -jnp.mean(log_prob) - else: - action = networks.sample(dist_params, key) - log_prob = networks.log_prob(dist_params, action) - q_action = networks.critic_network.apply(critic_params, - transitions.observation, - action) - min_q = jnp.min(q_action, axis=-1) - actor_loss = jnp.mean(sac_alpha * log_prob - min_q) - return actor_loss - - alpha_grad = jax.value_and_grad(alpha_loss) - cql_lagrange_grad = jax.value_and_grad(cql_lagrange_loss) - critic_grad = jax.value_and_grad(critic_loss) - actor_grad = jax.value_and_grad(actor_loss) - - def update_step( - state: TrainingState, - rb_transitions: types.Transition, - in_initial_bc_iters: bool, - ) -> 
Tuple[TrainingState, Dict[str, jnp.ndarray]]: - - key, key_alpha, key_critic, key_actor = jax.random.split(state.key, 4) - - if adaptive_entropy_coefficient: - alpha_loss, alpha_grads = alpha_grad(state.log_sac_alpha, - state.policy_params, - rb_transitions, key_alpha) - sac_alpha = jnp.exp(state.log_sac_alpha) - else: - sac_alpha = fixed_entropy_coefficient - - if adaptive_cql_coefficient: - cql_lagrange_loss, cql_lagrange_grads = cql_lagrange_grad( - state.log_cql_alpha, state.critic_params, state.policy_params, - rb_transitions, key_critic) - cql_lagrange_grads = jnp.clip(cql_lagrange_grads, - -_CQL_GRAD_CLIPPING_VALUE, - _CQL_GRAD_CLIPPING_VALUE) - cql_alpha = jnp.exp(state.log_cql_alpha) - cql_alpha = jnp.clip( - cql_alpha, a_min=0., a_max=_CQL_COEFFICIENT_MAX_VALUE) - else: - cql_alpha = fixed_cql_coefficient - - critic_loss, critic_grads = critic_grad(state.critic_params, - state.policy_params, - state.target_critic_params, - cql_alpha, rb_transitions, - key_critic) - actor_loss, actor_grads = actor_grad(state.policy_params, - state.critic_params, sac_alpha, - rb_transitions, key_actor, - in_initial_bc_iters) - - # Apply policy gradients - actor_update, policy_optimizer_state = policy_optimizer.update( - actor_grads, state.policy_optimizer_state) - policy_params = optax.apply_updates(state.policy_params, actor_update) - - # Apply critic gradients - critic_update, critic_optimizer_state = critic_optimizer.update( - critic_grads, state.critic_optimizer_state) - critic_params = optax.apply_updates(state.critic_params, critic_update) - - new_target_critic_params = jax.tree_map( - lambda x, y: x * (1 - tau) + y * tau, state.target_critic_params, - critic_params) - - metrics = { - 'critic_loss': critic_loss, - 'actor_loss': actor_loss, - } - - new_state = TrainingState( - policy_optimizer_state=policy_optimizer_state, - critic_optimizer_state=critic_optimizer_state, - policy_params=policy_params, - critic_params=critic_params, - target_critic_params=new_target_critic_params, - key=key, - alpha_optimizer_state=state.alpha_optimizer_state, - log_sac_alpha=state.log_sac_alpha, - steps=state.steps + 1, - ) - if adaptive_entropy_coefficient and (not in_initial_bc_iters): - # Apply sac_alpha gradients - alpha_update, alpha_optimizer_state = alpha_optimizer.update( - alpha_grads, state.alpha_optimizer_state) - log_sac_alpha = optax.apply_updates(state.log_sac_alpha, alpha_update) - metrics.update({ - 'alpha_loss': alpha_loss, - 'sac_alpha': jnp.exp(log_sac_alpha), - }) - new_state = new_state._replace( - alpha_optimizer_state=alpha_optimizer_state, - log_sac_alpha=log_sac_alpha) - else: - metrics['alpha_loss'] = 0. - metrics['sac_alpha'] = fixed_cql_coefficient - - if adaptive_cql_coefficient: - # Apply cql coeff gradients - cql_update, cql_optimizer_state = cql_optimizer.update( - cql_lagrange_grads, state.cql_optimizer_state) - log_cql_alpha = optax.apply_updates(state.log_cql_alpha, cql_update) - metrics.update({ - 'cql_lagrange_loss': cql_lagrange_loss, - 'cql_alpha': jnp.exp(log_cql_alpha), - }) - new_state = new_state._replace( - cql_optimizer_state=cql_optimizer_state, - log_cql_alpha=log_cql_alpha) - - return new_state, metrics - - # General learner book-keeping and loggers. - self._counter = counter or counting.Counter() - self._logger = logger or loggers.make_default_logger( - 'learner', - asynchronous=True, - serialize_fn=utils.fetch_devicearray, - steps_key=self._counter.get_steps_key()) - - # Iterator on demonstration transitions. 
- self._demonstrations = demonstrations - - # Use the JIT compiler. - update_step_in_initial_bc_iters = utils.process_multiple_batches( - lambda x, y: update_step(x, y, True), num_sgd_steps_per_step) - update_step_rest = utils.process_multiple_batches( - lambda x, y: update_step(x, y, False), num_sgd_steps_per_step) - - self._update_step_in_initial_bc_iters = jax.jit( - update_step_in_initial_bc_iters) - self._update_step_rest = jax.jit(update_step_rest) - - # Create initial state. - key_policy, key_q, training_state_key = jax.random.split(random_key, 3) - del random_key - policy_params = networks.policy_network.init(key_policy) - policy_optimizer_state = policy_optimizer.init(policy_params) - critic_params = networks.critic_network.init(key_q) - critic_optimizer_state = critic_optimizer.init(critic_params) - - self._state = TrainingState( - policy_optimizer_state=policy_optimizer_state, - critic_optimizer_state=critic_optimizer_state, - policy_params=policy_params, - critic_params=critic_params, - target_critic_params=critic_params, - key=training_state_key, - steps=0) - - if adaptive_entropy_coefficient: - self._state = self._state._replace( - alpha_optimizer_state=alpha_optimizer_state, - log_sac_alpha=log_sac_alpha) - if adaptive_cql_coefficient: - self._state = self._state._replace( - cql_optimizer_state=cql_optimizer_state, log_cql_alpha=log_cql_alpha) - - # Do not record timestamps until after the first learning step is done. - # This is to avoid including the time it takes for actors to come online and - # fill the replay buffer. - self._timestamp = None - - def step(self): - # Get data from replay (dropping extras if any). Note there is no - # extra data here because we do not insert any into Reverb. - transitions = next(self._demonstrations) - - counts = self._counter.get_counts() - if 'learner_steps' not in counts: - cur_step = 0 - else: - cur_step = counts['learner_steps'] - in_initial_bc_iters = cur_step < self._num_bc_iters - - if in_initial_bc_iters: - self._state, metrics = self._update_step_in_initial_bc_iters( - self._state, transitions) - else: - self._state, metrics = self._update_step_rest(self._state, transitions) - - # Compute elapsed time. - timestamp = time.time() - elapsed_time = timestamp - self._timestamp if self._timestamp else 0 - self._timestamp = timestamp - - # Increment counts and record the current time - counts = self._counter.increment(steps=1, walltime=elapsed_time) - - # Attempts to write the logs. - self._logger.write({**metrics, **counts}) - - def get_variables(self, names: List[str]) -> List[Any]: - variables = { - 'policy': self._state.policy_params, - } - return [variables[name] for name in names] - - def save(self) -> TrainingState: - return self._state - - def restore(self, state: TrainingState): - self._state = state + self._num_bc_iters = num_bc_iters + adaptive_entropy_coefficient = fixed_entropy_coefficient is None + action_spec = networks.environment_specs.actions + if adaptive_entropy_coefficient: + # sac_alpha is the temperature parameter that determines the relative + # importance of the entropy term versus the reward. 
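# A minimal sketch of the adaptive temperature objective used below (Eq. 18 of the SAC
# paper): loss(log_alpha) = exp(log_alpha) * stop_grad(-log_pi - target_entropy). When
# the sampled entropy estimate -log_pi falls below the target, the gradient with respect
# to log_alpha is negative, so a gradient-descent step raises alpha and strengthens the
# entropy bonus. The numbers are toy values.
import jax
import jax.numpy as jnp

target_entropy_example = 2.0
log_pi = -1.0  # entropy estimate -log_pi = 1.0, i.e. below target


def toy_alpha_loss(log_alpha):
    return jnp.exp(log_alpha) * jax.lax.stop_gradient(-log_pi - target_entropy_example)


grad = jax.grad(toy_alpha_loss)(jnp.asarray(0.0))  # exp(0) * (1.0 - 2.0) = -1.0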
+ log_sac_alpha = jnp.asarray(0.0, dtype=jnp.float32) + alpha_optimizer = optax.adam(learning_rate=3e-4) + alpha_optimizer_state = alpha_optimizer.init(log_sac_alpha) + else: + if target_entropy: + raise ValueError( + "target_entropy should not be set when " + "fixed_entropy_coefficient is provided" + ) + + adaptive_cql_coefficient = fixed_cql_coefficient is None + if adaptive_cql_coefficient: + log_cql_alpha = jnp.asarray(0.0, dtype=jnp.float32) + cql_optimizer = optax.adam(learning_rate=3e-4) + cql_optimizer_state = cql_optimizer.init(log_cql_alpha) + else: + if cql_lagrange_threshold: + raise ValueError( + "cql_lagrange_threshold should not be set when " + "fixed_cql_coefficient is provided" + ) + + def alpha_loss( + log_sac_alpha: jnp.ndarray, + policy_params: networks_lib.Params, + transitions: types.Transition, + key: jnp.ndarray, + ) -> jnp.ndarray: + """Eq 18 from https://arxiv.org/pdf/1812.05905.pdf.""" + dist_params = networks.policy_network.apply( + policy_params, transitions.observation + ) + action = networks.sample(dist_params, key) + log_prob = networks.log_prob(dist_params, action) + sac_alpha = jnp.exp(log_sac_alpha) + sac_alpha_loss = sac_alpha * jax.lax.stop_gradient( + -log_prob - target_entropy + ) + return jnp.mean(sac_alpha_loss) + + def sac_critic_loss( + q_old_action: jnp.ndarray, + policy_params: networks_lib.Params, + target_critic_params: networks_lib.Params, + transitions: types.Transition, + key: networks_lib.PRNGKey, + ) -> jnp.ndarray: + """Computes the SAC part of the loss.""" + next_dist_params = networks.policy_network.apply( + policy_params, transitions.next_observation + ) + next_action = networks.sample(next_dist_params, key) + next_q = networks.critic_network.apply( + target_critic_params, transitions.next_observation, next_action + ) + next_v = jnp.min(next_q, axis=-1) + target_q = jax.lax.stop_gradient( + transitions.reward * reward_scale + + transitions.discount * discount * next_v + ) + return jnp.mean(jnp.square(q_old_action - jnp.expand_dims(target_q, -1))) + + def batched_critic( + actions: jnp.ndarray, + critic_params: networks_lib.Params, + observation: jnp.ndarray, + ) -> jnp.ndarray: + """Applies the critic network to a batch of sampled actions.""" + actions = jax.lax.stop_gradient(actions) + tiled_actions = jnp.reshape(actions, (batch_size * cql_num_samples, -1)) + tiled_states = jnp.tile(observation, [cql_num_samples, 1]) + tiled_q = networks.critic_network.apply( + critic_params, tiled_states, tiled_actions + ) + return jnp.reshape(tiled_q, (cql_num_samples, batch_size, -1)) + + def cql_critic_loss( + q_old_action: jnp.ndarray, + critic_params: networks_lib.Params, + policy_params: networks_lib.Params, + transitions: types.Transition, + key: networks_lib.PRNGKey, + ) -> jnp.ndarray: + """Computes the CQL part of the loss.""" + # The CQL part of the loss is + # logsumexp(Q(s,·)) - Q(s,a), + # where s is the currrent state, and a the action in the dataset (so + # Q(s,a) is simply q_old_action. + # We need to estimate logsumexp(Q). This is done with importance sampling + # (IS). This function implements the unlabeled equation page 29, Appx. F, + # in https://arxiv.org/abs/2006.04779. + # Here, IS is done with the uniform distribution and the policy in the + # current state s. In their implementation, the authors also add the + # policy in the transiting state s': + # https://github.com/aviralkumar2907/CQL/blob/master/d4rl/rlkit/torch/sac/cql.py, + # (l. 233-236). 
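# A compact numerical sketch of the importance-sampling estimate described in the
# comment above: logsumexp over actions of Q(s, .) is approximated by
#   logsumexp_i( Q(s, a_i) - log q(a_i) ) - log(number of samples),
# with a_i drawn from the proposal distributions. This toy example pools two proposals
# (uniform and the policy at s); the code below additionally uses the policy at s',
# hence the b=1/(3 * cql_num_samples) weight in its jax.nn.logsumexp call. All values
# are made up.
import jax
import jax.numpy as jnp

n = 4
q_unif = jnp.array([0.1, 0.3, -0.2, 0.5])        # Q(s, a_i) for a_i ~ uniform
q_pi = jnp.array([0.8, 0.7, 0.9, 0.6])           # Q(s, a_i) for a_i ~ pi(.|s)
log_p_unif = jnp.full(n, jnp.log(0.5))           # density of 1-D uniform on [-1, 1]
log_p_pi = jnp.array([-0.3, -0.2, -0.4, -0.1])   # policy log-probs of its own samples

combined = jnp.concatenate([q_unif - log_p_unif, q_pi - log_p_pi])
# b = 1 / (total samples) turns the sum inside logsumexp into a mean.
lse_q = jax.nn.logsumexp(combined, b=1.0 / (2 * n))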
+ + key_policy, key_policy_next, key_uniform = jax.random.split(key, 3) + + def sampled_q(obs, key): + actions, log_probs = apply_and_sample_n( + key, networks, policy_params, obs, cql_num_samples + ) + return batched_critic( + actions, critic_params, transitions.observation + ) - jax.lax.stop_gradient(jnp.expand_dims(log_probs, -1)) + + # Sample wrt policy in s + sampled_q_from_policy = sampled_q(transitions.observation, key_policy) + + # Sample wrt policy in s' + sampled_q_from_policy_next = sampled_q( + transitions.next_observation, key_policy_next + ) + + # Sample wrt uniform + actions_uniform = jax.random.uniform( + key_uniform, + (cql_num_samples, batch_size) + action_spec.shape, + minval=action_spec.minimum, + maxval=action_spec.maximum, + ) + log_prob_uniform = -jnp.sum( + jnp.log(action_spec.maximum - action_spec.minimum) + ) + sampled_q_from_uniform = ( + batched_critic(actions_uniform, critic_params, transitions.observation) + - log_prob_uniform + ) + + # Combine the samplings + combined = jnp.concatenate( + ( + sampled_q_from_uniform, + sampled_q_from_policy, + sampled_q_from_policy_next, + ), + axis=0, + ) + lse_q = jax.nn.logsumexp(combined, axis=0, b=1.0 / (3 * cql_num_samples)) + + return jnp.mean(lse_q - q_old_action) + + def critic_loss( + critic_params: networks_lib.Params, + policy_params: networks_lib.Params, + target_critic_params: networks_lib.Params, + cql_alpha: jnp.ndarray, + transitions: types.Transition, + key: networks_lib.PRNGKey, + ) -> jnp.ndarray: + """Computes the full critic loss.""" + key_cql, key_sac = jax.random.split(key, 2) + q_old_action = networks.critic_network.apply( + critic_params, transitions.observation, transitions.action + ) + cql_loss = cql_critic_loss( + q_old_action, critic_params, policy_params, transitions, key_cql + ) + sac_loss = sac_critic_loss( + q_old_action, policy_params, target_critic_params, transitions, key_sac + ) + return cql_alpha * cql_loss + sac_loss + + def cql_lagrange_loss( + log_cql_alpha: jnp.ndarray, + critic_params: networks_lib.Params, + policy_params: networks_lib.Params, + transitions: types.Transition, + key: jnp.ndarray, + ) -> jnp.ndarray: + """Computes the loss that optimizes the cql coefficient.""" + cql_alpha = jnp.exp(log_cql_alpha) + q_old_action = networks.critic_network.apply( + critic_params, transitions.observation, transitions.action + ) + return -cql_alpha * ( + cql_critic_loss( + q_old_action, critic_params, policy_params, transitions, key + ) + - cql_lagrange_threshold + ) + + def actor_loss( + policy_params: networks_lib.Params, + critic_params: networks_lib.Params, + sac_alpha: jnp.ndarray, + transitions: types.Transition, + key: jnp.ndarray, + in_initial_bc_iters: bool, + ) -> jnp.ndarray: + """Computes the loss for the policy.""" + dist_params = networks.policy_network.apply( + policy_params, transitions.observation + ) + if in_initial_bc_iters: + log_prob = networks.log_prob(dist_params, transitions.action) + actor_loss = -jnp.mean(log_prob) + else: + action = networks.sample(dist_params, key) + log_prob = networks.log_prob(dist_params, action) + q_action = networks.critic_network.apply( + critic_params, transitions.observation, action + ) + min_q = jnp.min(q_action, axis=-1) + actor_loss = jnp.mean(sac_alpha * log_prob - min_q) + return actor_loss + + alpha_grad = jax.value_and_grad(alpha_loss) + cql_lagrange_grad = jax.value_and_grad(cql_lagrange_loss) + critic_grad = jax.value_and_grad(critic_loss) + actor_grad = jax.value_and_grad(actor_loss) + + def update_step( + state: TrainingState, 
+ rb_transitions: types.Transition, + in_initial_bc_iters: bool, + ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: + + key, key_alpha, key_critic, key_actor = jax.random.split(state.key, 4) + + if adaptive_entropy_coefficient: + alpha_loss, alpha_grads = alpha_grad( + state.log_sac_alpha, state.policy_params, rb_transitions, key_alpha + ) + sac_alpha = jnp.exp(state.log_sac_alpha) + else: + sac_alpha = fixed_entropy_coefficient + + if adaptive_cql_coefficient: + cql_lagrange_loss, cql_lagrange_grads = cql_lagrange_grad( + state.log_cql_alpha, + state.critic_params, + state.policy_params, + rb_transitions, + key_critic, + ) + cql_lagrange_grads = jnp.clip( + cql_lagrange_grads, + -_CQL_GRAD_CLIPPING_VALUE, + _CQL_GRAD_CLIPPING_VALUE, + ) + cql_alpha = jnp.exp(state.log_cql_alpha) + cql_alpha = jnp.clip( + cql_alpha, a_min=0.0, a_max=_CQL_COEFFICIENT_MAX_VALUE + ) + else: + cql_alpha = fixed_cql_coefficient + + critic_loss, critic_grads = critic_grad( + state.critic_params, + state.policy_params, + state.target_critic_params, + cql_alpha, + rb_transitions, + key_critic, + ) + actor_loss, actor_grads = actor_grad( + state.policy_params, + state.critic_params, + sac_alpha, + rb_transitions, + key_actor, + in_initial_bc_iters, + ) + + # Apply policy gradients + actor_update, policy_optimizer_state = policy_optimizer.update( + actor_grads, state.policy_optimizer_state + ) + policy_params = optax.apply_updates(state.policy_params, actor_update) + + # Apply critic gradients + critic_update, critic_optimizer_state = critic_optimizer.update( + critic_grads, state.critic_optimizer_state + ) + critic_params = optax.apply_updates(state.critic_params, critic_update) + + new_target_critic_params = jax.tree_map( + lambda x, y: x * (1 - tau) + y * tau, + state.target_critic_params, + critic_params, + ) + + metrics = { + "critic_loss": critic_loss, + "actor_loss": actor_loss, + } + + new_state = TrainingState( + policy_optimizer_state=policy_optimizer_state, + critic_optimizer_state=critic_optimizer_state, + policy_params=policy_params, + critic_params=critic_params, + target_critic_params=new_target_critic_params, + key=key, + alpha_optimizer_state=state.alpha_optimizer_state, + log_sac_alpha=state.log_sac_alpha, + steps=state.steps + 1, + ) + if adaptive_entropy_coefficient and (not in_initial_bc_iters): + # Apply sac_alpha gradients + alpha_update, alpha_optimizer_state = alpha_optimizer.update( + alpha_grads, state.alpha_optimizer_state + ) + log_sac_alpha = optax.apply_updates(state.log_sac_alpha, alpha_update) + metrics.update( + {"alpha_loss": alpha_loss, "sac_alpha": jnp.exp(log_sac_alpha),} + ) + new_state = new_state._replace( + alpha_optimizer_state=alpha_optimizer_state, + log_sac_alpha=log_sac_alpha, + ) + else: + metrics["alpha_loss"] = 0.0 + metrics["sac_alpha"] = fixed_cql_coefficient + + if adaptive_cql_coefficient: + # Apply cql coeff gradients + cql_update, cql_optimizer_state = cql_optimizer.update( + cql_lagrange_grads, state.cql_optimizer_state + ) + log_cql_alpha = optax.apply_updates(state.log_cql_alpha, cql_update) + metrics.update( + { + "cql_lagrange_loss": cql_lagrange_loss, + "cql_alpha": jnp.exp(log_cql_alpha), + } + ) + new_state = new_state._replace( + cql_optimizer_state=cql_optimizer_state, log_cql_alpha=log_cql_alpha + ) + + return new_state, metrics + + # General learner book-keeping and loggers. 
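# The target-network update in update_step above is a Polyak average with rate tau.
# The same jax.tree_map expression, applied to a toy parameter tree:
import jax
import jax.numpy as jnp

tau = 0.005
target_params = {"w": jnp.zeros(3), "b": jnp.zeros(())}
online_params = {"w": jnp.ones(3), "b": jnp.ones(())}

new_target_params = jax.tree_map(
    lambda x, y: x * (1 - tau) + y * tau, target_params, online_params
)
# Every target leaf moves 0.5% of the way toward the online value per update,
# e.g. new_target_params["w"] == [0.005, 0.005, 0.005].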
+ self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger( + "learner", + asynchronous=True, + serialize_fn=utils.fetch_devicearray, + steps_key=self._counter.get_steps_key(), + ) + + # Iterator on demonstration transitions. + self._demonstrations = demonstrations + + # Use the JIT compiler. + update_step_in_initial_bc_iters = utils.process_multiple_batches( + lambda x, y: update_step(x, y, True), num_sgd_steps_per_step + ) + update_step_rest = utils.process_multiple_batches( + lambda x, y: update_step(x, y, False), num_sgd_steps_per_step + ) + + self._update_step_in_initial_bc_iters = jax.jit(update_step_in_initial_bc_iters) + self._update_step_rest = jax.jit(update_step_rest) + + # Create initial state. + key_policy, key_q, training_state_key = jax.random.split(random_key, 3) + del random_key + policy_params = networks.policy_network.init(key_policy) + policy_optimizer_state = policy_optimizer.init(policy_params) + critic_params = networks.critic_network.init(key_q) + critic_optimizer_state = critic_optimizer.init(critic_params) + + self._state = TrainingState( + policy_optimizer_state=policy_optimizer_state, + critic_optimizer_state=critic_optimizer_state, + policy_params=policy_params, + critic_params=critic_params, + target_critic_params=critic_params, + key=training_state_key, + steps=0, + ) + + if adaptive_entropy_coefficient: + self._state = self._state._replace( + alpha_optimizer_state=alpha_optimizer_state, log_sac_alpha=log_sac_alpha + ) + if adaptive_cql_coefficient: + self._state = self._state._replace( + cql_optimizer_state=cql_optimizer_state, log_cql_alpha=log_cql_alpha + ) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + def step(self): + # Get data from replay (dropping extras if any). Note there is no + # extra data here because we do not insert any into Reverb. + transitions = next(self._demonstrations) + + counts = self._counter.get_counts() + if "learner_steps" not in counts: + cur_step = 0 + else: + cur_step = counts["learner_steps"] + in_initial_bc_iters = cur_step < self._num_bc_iters + + if in_initial_bc_iters: + self._state, metrics = self._update_step_in_initial_bc_iters( + self._state, transitions + ) + else: + self._state, metrics = self._update_step_rest(self._state, transitions) + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Increment counts and record the current time + counts = self._counter.increment(steps=1, walltime=elapsed_time) + + # Attempts to write the logs. 
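# A stripped-down sketch of the warm-up schedule implemented by step() above: for the
# first num_bc_iters learner steps the actor is trained with the behavioural-cloning
# branch of actor_loss, after which it switches to the CQL/SAC branch. `update_bc` and
# `update_rl` are hypothetical stand-ins for the two jitted update functions built in
# __init__.
def one_learner_step(state, transitions, cur_step, num_bc_iters, update_bc, update_rl):
    if cur_step < num_bc_iters:
        return update_bc(state, transitions)  # actor_loss = -mean log pi(a_data | s)
    return update_rl(state, transitions)      # actor_loss = mean(sac_alpha * log pi - min_q)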
+ self._logger.write({**metrics, **counts}) + + def get_variables(self, names: List[str]) -> List[Any]: + variables = { + "policy": self._state.policy_params, + } + return [variables[name] for name in names] + + def save(self) -> TrainingState: + return self._state + + def restore(self, state: TrainingState): + self._state = state diff --git a/acme/agents/jax/cql/networks.py b/acme/agents/jax/cql/networks.py index 0593489b50..d80a347e2b 100644 --- a/acme/agents/jax/cql/networks.py +++ b/acme/agents/jax/cql/networks.py @@ -16,45 +16,52 @@ import dataclasses from typing import Optional, Tuple +import jax +import jax.numpy as jnp + from acme import specs from acme.agents.jax import sac from acme.jax import networks as networks_lib -import jax -import jax.numpy as jnp @dataclasses.dataclass class CQLNetworks: - """Network and pure functions for the CQL agent.""" - policy_network: networks_lib.FeedForwardNetwork - critic_network: networks_lib.FeedForwardNetwork - log_prob: networks_lib.LogProbFn - sample: Optional[networks_lib.SampleFn] - sample_eval: Optional[networks_lib.SampleFn] - environment_specs: specs.EnvironmentSpec - - -def apply_and_sample_n(key: networks_lib.PRNGKey, - networks: CQLNetworks, - params: networks_lib.Params, obs: jnp.ndarray, - num_samples: int) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Applies the policy and samples num_samples actions.""" - dist_params = networks.policy_network.apply(params, obs) - sampled_actions = jnp.array([ - networks.sample(dist_params, key_n) - for key_n in jax.random.split(key, num_samples) - ]) - sampled_log_probs = networks.log_prob(dist_params, sampled_actions) - return sampled_actions, sampled_log_probs - - -def make_networks( - spec: specs.EnvironmentSpec, **kwargs) -> CQLNetworks: - sac_networks = sac.make_networks(spec, **kwargs) - return CQLNetworks( - policy_network=sac_networks.policy_network, - critic_network=sac_networks.q_network, - log_prob=sac_networks.log_prob, - sample=sac_networks.sample, - sample_eval=sac_networks.sample_eval, - environment_specs=spec) + """Network and pure functions for the CQL agent.""" + + policy_network: networks_lib.FeedForwardNetwork + critic_network: networks_lib.FeedForwardNetwork + log_prob: networks_lib.LogProbFn + sample: Optional[networks_lib.SampleFn] + sample_eval: Optional[networks_lib.SampleFn] + environment_specs: specs.EnvironmentSpec + + +def apply_and_sample_n( + key: networks_lib.PRNGKey, + networks: CQLNetworks, + params: networks_lib.Params, + obs: jnp.ndarray, + num_samples: int, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Applies the policy and samples num_samples actions.""" + dist_params = networks.policy_network.apply(params, obs) + sampled_actions = jnp.array( + [ + networks.sample(dist_params, key_n) + for key_n in jax.random.split(key, num_samples) + ] + ) + sampled_log_probs = networks.log_prob(dist_params, sampled_actions) + return sampled_actions, sampled_log_probs + + +def make_networks(spec: specs.EnvironmentSpec, **kwargs) -> CQLNetworks: + sac_networks = sac.make_networks(spec, **kwargs) + return CQLNetworks( + policy_network=sac_networks.policy_network, + critic_network=sac_networks.q_network, + log_prob=sac_networks.log_prob, + sample=sac_networks.sample, + sample_eval=sac_networks.sample_eval, + environment_specs=spec, + ) diff --git a/acme/agents/jax/crr/__init__.py b/acme/agents/jax/crr/__init__.py index d594cf2313..a748ca3f30 100644 --- a/acme/agents/jax/crr/__init__.py +++ b/acme/agents/jax/crr/__init__.py @@ -17,9 +17,10 @@ from acme.agents.jax.crr.builder import 
CRRBuilder from acme.agents.jax.crr.config import CRRConfig from acme.agents.jax.crr.learning import CRRLearner -from acme.agents.jax.crr.losses import policy_loss_coeff_advantage_exp -from acme.agents.jax.crr.losses import policy_loss_coeff_advantage_indicator -from acme.agents.jax.crr.losses import policy_loss_coeff_constant -from acme.agents.jax.crr.losses import PolicyLossCoeff -from acme.agents.jax.crr.networks import CRRNetworks -from acme.agents.jax.crr.networks import make_networks +from acme.agents.jax.crr.losses import ( + PolicyLossCoeff, + policy_loss_coeff_advantage_exp, + policy_loss_coeff_advantage_indicator, + policy_loss_coeff_constant, +) +from acme.agents.jax.crr.networks import CRRNetworks, make_networks diff --git a/acme/agents/jax/crr/agent_test.py b/acme/agents/jax/crr/agent_test.py index 2f92520625..efa69351ef 100644 --- a/acme/agents/jax/crr/agent_test.py +++ b/acme/agents/jax/crr/agent_test.py @@ -14,53 +14,55 @@ """Tests for the CRR agent.""" -from acme import specs -from acme.agents.jax import crr -from acme.testing import fakes import jax import optax +from absl.testing import absltest, parameterized -from absl.testing import absltest -from absl.testing import parameterized +from acme import specs +from acme.agents.jax import crr +from acme.testing import fakes class CRRTest(parameterized.TestCase): + @parameterized.named_parameters( + ("exp", crr.policy_loss_coeff_advantage_exp), + ("indicator", crr.policy_loss_coeff_advantage_indicator), + ("all", crr.policy_loss_coeff_constant), + ) + def test_train(self, policy_loss_coeff_fn): + seed = 0 + num_iterations = 5 + batch_size = 64 + grad_updates_per_batch = 1 - @parameterized.named_parameters( - ('exp', crr.policy_loss_coeff_advantage_exp), - ('indicator', crr.policy_loss_coeff_advantage_indicator), - ('all', crr.policy_loss_coeff_constant)) - def test_train(self, policy_loss_coeff_fn): - seed = 0 - num_iterations = 5 - batch_size = 64 - grad_updates_per_batch = 1 - - # Create a fake environment to test with. - environment = fakes.ContinuousEnvironment( - episode_length=10, bounded=True, action_dim=6) - spec = specs.make_environment_spec(environment) + # Create a fake environment to test with. + environment = fakes.ContinuousEnvironment( + episode_length=10, bounded=True, action_dim=6 + ) + spec = specs.make_environment_spec(environment) - # Construct the learner. - networks = crr.make_networks( - spec, policy_layer_sizes=(8, 8), critic_layer_sizes=(8, 8)) - key = jax.random.PRNGKey(seed) - dataset = fakes.transition_iterator(environment) - learner = crr.CRRLearner( - networks, - key, - discount=0.95, - target_update_period=2, - policy_loss_coeff_fn=policy_loss_coeff_fn, - iterator=dataset(batch_size * grad_updates_per_batch), - policy_optimizer=optax.adam(1e-4), - critic_optimizer=optax.adam(1e-4), - grad_updates_per_batch=grad_updates_per_batch) + # Construct the learner. + networks = crr.make_networks( + spec, policy_layer_sizes=(8, 8), critic_layer_sizes=(8, 8) + ) + key = jax.random.PRNGKey(seed) + dataset = fakes.transition_iterator(environment) + learner = crr.CRRLearner( + networks, + key, + discount=0.95, + target_update_period=2, + policy_loss_coeff_fn=policy_loss_coeff_fn, + iterator=dataset(batch_size * grad_updates_per_batch), + policy_optimizer=optax.adam(1e-4), + critic_optimizer=optax.adam(1e-4), + grad_updates_per_batch=grad_updates_per_batch, + ) - # Train the learner. - for _ in range(num_iterations): - learner.step() + # Train the learner. 
+ for _ in range(num_iterations): + learner.step() -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/jax/crr/builder.py b/acme/agents/jax/crr/builder.py index 7cdd9b534c..164b3b76ea 100644 --- a/acme/agents/jax/crr/builder.py +++ b/acme/agents/jax/crr/builder.py @@ -15,92 +15,98 @@ """CRR Builder.""" from typing import Iterator, Optional -from acme import core -from acme import specs -from acme import types +import optax + +from acme import core, specs, types from acme.agents.jax import actor_core as actor_core_lib -from acme.agents.jax import actors -from acme.agents.jax import builders +from acme.agents.jax import actors, builders from acme.agents.jax.crr import config as crr_config -from acme.agents.jax.crr import learning -from acme.agents.jax.crr import losses +from acme.agents.jax.crr import learning, losses from acme.agents.jax.crr import networks as crr_networks from acme.jax import networks as networks_lib from acme.jax import variable_utils -from acme.utils import counting -from acme.utils import loggers -import optax +from acme.utils import counting, loggers -class CRRBuilder(builders.OfflineBuilder[crr_networks.CRRNetworks, - actor_core_lib.FeedForwardPolicy, - types.Transition]): - """CRR Builder.""" +class CRRBuilder( + builders.OfflineBuilder[ + crr_networks.CRRNetworks, actor_core_lib.FeedForwardPolicy, types.Transition + ] +): + """CRR Builder.""" - def __init__( - self, - config: crr_config.CRRConfig, - policy_loss_coeff_fn: losses.PolicyLossCoeff, - ): - """Creates a CRR learner, an evaluation policy and an eval actor. + def __init__( + self, + config: crr_config.CRRConfig, + policy_loss_coeff_fn: losses.PolicyLossCoeff, + ): + """Creates a CRR learner, an evaluation policy and an eval actor. Args: config: a config with CRR hps. policy_loss_coeff_fn: set the loss function for the policy. 
""" - self._config = config - self._policy_loss_coeff_fn = policy_loss_coeff_fn + self._config = config + self._policy_loss_coeff_fn = policy_loss_coeff_fn - def make_learner( - self, - random_key: networks_lib.PRNGKey, - networks: crr_networks.CRRNetworks, - dataset: Iterator[types.Transition], - logger_fn: loggers.LoggerFactory, - environment_spec: specs.EnvironmentSpec, - *, - counter: Optional[counting.Counter] = None, - ) -> core.Learner: - del environment_spec + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: crr_networks.CRRNetworks, + dataset: Iterator[types.Transition], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + *, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + del environment_spec - return learning.CRRLearner( - networks=networks, - random_key=random_key, - discount=self._config.discount, - target_update_period=self._config.target_update_period, - iterator=dataset, - policy_loss_coeff_fn=self._policy_loss_coeff_fn, - policy_optimizer=optax.adam(self._config.learning_rate), - critic_optimizer=optax.adam(self._config.learning_rate), - use_sarsa_target=self._config.use_sarsa_target, - logger=logger_fn('learner'), - counter=counter) + return learning.CRRLearner( + networks=networks, + random_key=random_key, + discount=self._config.discount, + target_update_period=self._config.target_update_period, + iterator=dataset, + policy_loss_coeff_fn=self._policy_loss_coeff_fn, + policy_optimizer=optax.adam(self._config.learning_rate), + critic_optimizer=optax.adam(self._config.learning_rate), + use_sarsa_target=self._config.use_sarsa_target, + logger=logger_fn("learner"), + counter=counter, + ) - def make_actor( - self, - random_key: networks_lib.PRNGKey, - policy: actor_core_lib.FeedForwardPolicy, - environment_spec: specs.EnvironmentSpec, - variable_source: Optional[core.VariableSource] = None, - ) -> core.Actor: - del environment_spec - assert variable_source is not None - actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy) - variable_client = variable_utils.VariableClient( - variable_source, 'policy', device='cpu') - return actors.GenericActor( - actor_core, random_key, variable_client, backend='cpu') + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: actor_core_lib.FeedForwardPolicy, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + ) -> core.Actor: + del environment_spec + assert variable_source is not None + actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy) + variable_client = variable_utils.VariableClient( + variable_source, "policy", device="cpu" + ) + return actors.GenericActor( + actor_core, random_key, variable_client, backend="cpu" + ) - def make_policy(self, networks: crr_networks.CRRNetworks, - environment_spec: specs.EnvironmentSpec, - evaluation: bool) -> actor_core_lib.FeedForwardPolicy: - """Construct the policy.""" - del environment_spec, evaluation + def make_policy( + self, + networks: crr_networks.CRRNetworks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool, + ) -> actor_core_lib.FeedForwardPolicy: + """Construct the policy.""" + del environment_spec, evaluation - def evaluation_policy( - params: networks_lib.Params, key: networks_lib.PRNGKey, - observation: networks_lib.Observation) -> networks_lib.Action: - dist_params = networks.policy_network.apply(params, observation) - return networks.sample_eval(dist_params, key) + def evaluation_policy( + params: 
networks_lib.Params, + key: networks_lib.PRNGKey, + observation: networks_lib.Observation, + ) -> networks_lib.Action: + dist_params = networks.policy_network.apply(params, observation) + return networks.sample_eval(dist_params, key) - return evaluation_policy + return evaluation_policy diff --git a/acme/agents/jax/crr/config.py b/acme/agents/jax/crr/config.py index 9bfda702df..75cd4f606b 100644 --- a/acme/agents/jax/crr/config.py +++ b/acme/agents/jax/crr/config.py @@ -18,7 +18,7 @@ @dataclasses.dataclass class CRRConfig: - """Configuration options for CRR. + """Configuration options for CRR. Attributes: learning_rate: Learning rate. @@ -28,7 +28,8 @@ class CRRConfig: than sampled actions. Useful for 1-step offline RL (https://arxiv.org/pdf/2106.08909.pdf). """ - learning_rate: float = 3e-4 - discount: float = 0.99 - target_update_period: int = 100 - use_sarsa_target: bool = False + + learning_rate: float = 3e-4 + discount: float = 0.99 + target_update_period: int = 100 + use_sarsa_target: bool = False diff --git a/acme/agents/jax/crr/learning.py b/acme/agents/jax/crr/learning.py index 5502cfca66..3f7f658bc7 100644 --- a/acme/agents/jax/crr/learning.py +++ b/acme/agents/jax/crr/learning.py @@ -17,54 +17,57 @@ import time from typing import Dict, Iterator, List, NamedTuple, Optional, Tuple +import jax +import jax.numpy as jnp +import optax + import acme from acme import types from acme.agents.jax.crr.losses import PolicyLossCoeff from acme.agents.jax.crr.networks import CRRNetworks from acme.jax import networks as networks_lib from acme.jax import utils -from acme.utils import counting -from acme.utils import loggers -import jax -import jax.numpy as jnp -import optax +from acme.utils import counting, loggers class TrainingState(NamedTuple): - """Contains training state for the learner.""" - policy_params: networks_lib.Params - target_policy_params: networks_lib.Params - critic_params: networks_lib.Params - target_critic_params: networks_lib.Params - policy_opt_state: optax.OptState - critic_opt_state: optax.OptState - steps: int - key: networks_lib.PRNGKey + """Contains training state for the learner.""" + + policy_params: networks_lib.Params + target_policy_params: networks_lib.Params + critic_params: networks_lib.Params + target_critic_params: networks_lib.Params + policy_opt_state: optax.OptState + critic_opt_state: optax.OptState + steps: int + key: networks_lib.PRNGKey class CRRLearner(acme.Learner): - """Critic Regularized Regression (CRR) learner. + """Critic Regularized Regression (CRR) learner. This is the learning component of a CRR agent as described in https://arxiv.org/abs/2006.15134. """ - _state: TrainingState - - def __init__(self, - networks: CRRNetworks, - random_key: networks_lib.PRNGKey, - discount: float, - target_update_period: int, - policy_loss_coeff_fn: PolicyLossCoeff, - iterator: Iterator[types.Transition], - policy_optimizer: optax.GradientTransformation, - critic_optimizer: optax.GradientTransformation, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - grad_updates_per_batch: int = 1, - use_sarsa_target: bool = False): - """Initializes the CRR learner. 
+ _state: TrainingState + + def __init__( + self, + networks: CRRNetworks, + random_key: networks_lib.PRNGKey, + discount: float, + target_update_period: int, + policy_loss_coeff_fn: PolicyLossCoeff, + iterator: Iterator[types.Transition], + policy_optimizer: optax.GradientTransformation, + critic_optimizer: optax.GradientTransformation, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + grad_updates_per_batch: int = 1, + use_sarsa_target: bool = False, + ): + """Initializes the CRR learner. Args: networks: CRR networks. @@ -84,180 +87,197 @@ def __init__(self, When set to `True`, `target_policy_params` are unused. """ - critic_network = networks.critic_network - policy_network = networks.policy_network - - def policy_loss( - policy_params: networks_lib.Params, - critic_params: networks_lib.Params, - transition: types.Transition, - key: networks_lib.PRNGKey, - ) -> jnp.ndarray: - # Compute the loss coefficients. - coeff = policy_loss_coeff_fn(networks, policy_params, critic_params, - transition, key) - coeff = jax.lax.stop_gradient(coeff) - # Return the weighted loss. - dist_params = policy_network.apply(policy_params, transition.observation) - logp_action = networks.log_prob(dist_params, transition.action) - # Make sure there is no broadcasting. - logp_action *= coeff.flatten() - assert len(logp_action.shape) == 1 - return -jnp.mean(logp_action) - - def critic_loss( - critic_params: networks_lib.Params, - target_policy_params: networks_lib.Params, - target_critic_params: networks_lib.Params, - transition: types.Transition, - key: networks_lib.PRNGKey, - ): - # Sample the next action. - if use_sarsa_target: - # TODO(b/222674779): use N-steps Trajectories to get the next actions. - assert 'next_action' in transition.extras, ( - 'next actions should be given as extras for one step RL.') - next_action = transition.extras['next_action'] - else: - next_dist_params = policy_network.apply(target_policy_params, - transition.next_observation) - next_action = networks.sample(next_dist_params, key) - # Calculate the value of the next state and action. - next_q = critic_network.apply(target_critic_params, - transition.next_observation, next_action) - target_q = transition.reward + transition.discount * discount * next_q - target_q = jax.lax.stop_gradient(target_q) - - q = critic_network.apply(critic_params, transition.observation, - transition.action) - q_error = q - target_q - # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error. - # TODO(sertan): Replace with a distributional critic. CRR paper states - # that this may perform better. - return 0.5 * jnp.mean(jnp.square(q_error)) - - policy_loss_and_grad = jax.value_and_grad(policy_loss) - critic_loss_and_grad = jax.value_and_grad(critic_loss) - - def sgd_step( - state: TrainingState, - transitions: types.Transition, - ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: - - key, key_policy, key_critic = jax.random.split(state.key, 3) - - # Compute losses and their gradients. - policy_loss_value, policy_gradients = policy_loss_and_grad( - state.policy_params, state.critic_params, transitions, key_policy) - critic_loss_value, critic_gradients = critic_loss_and_grad( - state.critic_params, state.target_policy_params, - state.target_critic_params, transitions, key_critic) - - # Get optimizer updates and state. 
- policy_updates, policy_opt_state = policy_optimizer.update( - policy_gradients, state.policy_opt_state) - critic_updates, critic_opt_state = critic_optimizer.update( - critic_gradients, state.critic_opt_state) - - # Apply optimizer updates to parameters. - policy_params = optax.apply_updates(state.policy_params, policy_updates) - critic_params = optax.apply_updates(state.critic_params, critic_updates) - - steps = state.steps + 1 - - # Periodically update target networks. - target_policy_params, target_critic_params = optax.periodic_update( # pytype: disable=wrong-arg-types # numpy-scalars - (policy_params, critic_params), - (state.target_policy_params, state.target_critic_params), steps, - target_update_period) - - new_state = TrainingState( - policy_params=policy_params, - target_policy_params=target_policy_params, - critic_params=critic_params, - target_critic_params=target_critic_params, - policy_opt_state=policy_opt_state, - critic_opt_state=critic_opt_state, - steps=steps, - key=key, - ) - - metrics = { - 'policy_loss': policy_loss_value, - 'critic_loss': critic_loss_value, - } - - return new_state, metrics - - sgd_step = utils.process_multiple_batches(sgd_step, grad_updates_per_batch) - self._sgd_step = jax.jit(sgd_step) - - # General learner book-keeping and loggers. - self._counter = counter or counting.Counter() - self._logger = logger or loggers.make_default_logger( - 'learner', - asynchronous=True, - serialize_fn=utils.fetch_devicearray, - steps_key=self._counter.get_steps_key()) - - # Create prefetching dataset iterator. - self._iterator = iterator - - # Create the network parameters and copy into the target network parameters. - key, key_policy, key_critic = jax.random.split(random_key, 3) - initial_policy_params = policy_network.init(key_policy) - initial_critic_params = critic_network.init(key_critic) - initial_target_policy_params = initial_policy_params - initial_target_critic_params = initial_critic_params - - # Initialize optimizers. - initial_policy_opt_state = policy_optimizer.init(initial_policy_params) - initial_critic_opt_state = critic_optimizer.init(initial_critic_params) - - # Create initial state. - self._state = TrainingState( - policy_params=initial_policy_params, - target_policy_params=initial_target_policy_params, - critic_params=initial_critic_params, - target_critic_params=initial_target_critic_params, - policy_opt_state=initial_policy_opt_state, - critic_opt_state=initial_critic_opt_state, - steps=0, - key=key, - ) - - # Do not record timestamps until after the first learning step is done. - # This is to avoid including the time it takes for actors to come online and - # fill the replay buffer. - self._timestamp = None - - def step(self): - transitions = next(self._iterator) - - self._state, metrics = self._sgd_step(self._state, transitions) - - # Compute elapsed time. - timestamp = time.time() - elapsed_time = timestamp - self._timestamp if self._timestamp else 0 - self._timestamp = timestamp - - # Increment counts and record the current time - counts = self._counter.increment(steps=1, walltime=elapsed_time) - - # Attempts to write the logs. - self._logger.write({**metrics, **counts}) - - def get_variables(self, names: List[str]) -> List[networks_lib.Params]: - # We only expose the variables for the learned policy and critic. The target - # policy and critic are internal details. 
- variables = { - 'policy': self._state.target_policy_params, - 'critic': self._state.target_critic_params, - } - return [variables[name] for name in names] - - def save(self) -> TrainingState: - return self._state - - def restore(self, state: TrainingState): - self._state = state + critic_network = networks.critic_network + policy_network = networks.policy_network + + def policy_loss( + policy_params: networks_lib.Params, + critic_params: networks_lib.Params, + transition: types.Transition, + key: networks_lib.PRNGKey, + ) -> jnp.ndarray: + # Compute the loss coefficients. + coeff = policy_loss_coeff_fn( + networks, policy_params, critic_params, transition, key + ) + coeff = jax.lax.stop_gradient(coeff) + # Return the weighted loss. + dist_params = policy_network.apply(policy_params, transition.observation) + logp_action = networks.log_prob(dist_params, transition.action) + # Make sure there is no broadcasting. + logp_action *= coeff.flatten() + assert len(logp_action.shape) == 1 + return -jnp.mean(logp_action) + + def critic_loss( + critic_params: networks_lib.Params, + target_policy_params: networks_lib.Params, + target_critic_params: networks_lib.Params, + transition: types.Transition, + key: networks_lib.PRNGKey, + ): + # Sample the next action. + if use_sarsa_target: + # TODO(b/222674779): use N-steps Trajectories to get the next actions. + assert ( + "next_action" in transition.extras + ), "next actions should be given as extras for one step RL." + next_action = transition.extras["next_action"] + else: + next_dist_params = policy_network.apply( + target_policy_params, transition.next_observation + ) + next_action = networks.sample(next_dist_params, key) + # Calculate the value of the next state and action. + next_q = critic_network.apply( + target_critic_params, transition.next_observation, next_action + ) + target_q = transition.reward + transition.discount * discount * next_q + target_q = jax.lax.stop_gradient(target_q) + + q = critic_network.apply( + critic_params, transition.observation, transition.action + ) + q_error = q - target_q + # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error. + # TODO(sertan): Replace with a distributional critic. CRR paper states + # that this may perform better. + return 0.5 * jnp.mean(jnp.square(q_error)) + + policy_loss_and_grad = jax.value_and_grad(policy_loss) + critic_loss_and_grad = jax.value_and_grad(critic_loss) + + def sgd_step( + state: TrainingState, transitions: types.Transition, + ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: + + key, key_policy, key_critic = jax.random.split(state.key, 3) + + # Compute losses and their gradients. + policy_loss_value, policy_gradients = policy_loss_and_grad( + state.policy_params, state.critic_params, transitions, key_policy + ) + critic_loss_value, critic_gradients = critic_loss_and_grad( + state.critic_params, + state.target_policy_params, + state.target_critic_params, + transitions, + key_critic, + ) + + # Get optimizer updates and state. + policy_updates, policy_opt_state = policy_optimizer.update( + policy_gradients, state.policy_opt_state + ) + critic_updates, critic_opt_state = critic_optimizer.update( + critic_gradients, state.critic_opt_state + ) + + # Apply optimizer updates to parameters. + policy_params = optax.apply_updates(state.policy_params, policy_updates) + critic_params = optax.apply_updates(state.critic_params, critic_updates) + + steps = state.steps + 1 + + # Periodically update target networks. 
+ ( + target_policy_params, + target_critic_params, + ) = optax.periodic_update( # pytype: disable=wrong-arg-types # numpy-scalars + (policy_params, critic_params), + (state.target_policy_params, state.target_critic_params), + steps, + target_update_period, + ) + + new_state = TrainingState( + policy_params=policy_params, + target_policy_params=target_policy_params, + critic_params=critic_params, + target_critic_params=target_critic_params, + policy_opt_state=policy_opt_state, + critic_opt_state=critic_opt_state, + steps=steps, + key=key, + ) + + metrics = { + "policy_loss": policy_loss_value, + "critic_loss": critic_loss_value, + } + + return new_state, metrics + + sgd_step = utils.process_multiple_batches(sgd_step, grad_updates_per_batch) + self._sgd_step = jax.jit(sgd_step) + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger( + "learner", + asynchronous=True, + serialize_fn=utils.fetch_devicearray, + steps_key=self._counter.get_steps_key(), + ) + + # Create prefetching dataset iterator. + self._iterator = iterator + + # Create the network parameters and copy into the target network parameters. + key, key_policy, key_critic = jax.random.split(random_key, 3) + initial_policy_params = policy_network.init(key_policy) + initial_critic_params = critic_network.init(key_critic) + initial_target_policy_params = initial_policy_params + initial_target_critic_params = initial_critic_params + + # Initialize optimizers. + initial_policy_opt_state = policy_optimizer.init(initial_policy_params) + initial_critic_opt_state = critic_optimizer.init(initial_critic_params) + + # Create initial state. + self._state = TrainingState( + policy_params=initial_policy_params, + target_policy_params=initial_target_policy_params, + critic_params=initial_critic_params, + target_critic_params=initial_target_critic_params, + policy_opt_state=initial_policy_opt_state, + critic_opt_state=initial_critic_opt_state, + steps=0, + key=key, + ) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + def step(self): + transitions = next(self._iterator) + + self._state, metrics = self._sgd_step(self._state, transitions) + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Increment counts and record the current time + counts = self._counter.increment(steps=1, walltime=elapsed_time) + + # Attempts to write the logs. + self._logger.write({**metrics, **counts}) + + def get_variables(self, names: List[str]) -> List[networks_lib.Params]: + # We only expose the variables for the learned policy and critic. The target + # policy and critic are internal details. 
+ variables = { + "policy": self._state.target_policy_params, + "critic": self._state.target_critic_params, + } + return [variables[name] for name in names] + + def save(self) -> TrainingState: + return self._state + + def restore(self, state: TrainingState): + self._state = state diff --git a/acme/agents/jax/crr/losses.py b/acme/agents/jax/crr/losses.py index 336b10509f..b0cdc03ff1 100644 --- a/acme/agents/jax/crr/losses.py +++ b/acme/agents/jax/crr/losses.py @@ -16,44 +16,51 @@ from typing import Callable +import jax.numpy as jnp + from acme import types from acme.agents.jax.crr.networks import CRRNetworks from acme.jax import networks as networks_lib -import jax.numpy as jnp -PolicyLossCoeff = Callable[[ - CRRNetworks, - networks_lib.Params, - networks_lib.Params, - types.Transition, - networks_lib.PRNGKey, -], jnp.ndarray] +PolicyLossCoeff = Callable[ + [ + CRRNetworks, + networks_lib.Params, + networks_lib.Params, + types.Transition, + networks_lib.PRNGKey, + ], + jnp.ndarray, +] -def _compute_advantage(networks: CRRNetworks, - policy_params: networks_lib.Params, - critic_params: networks_lib.Params, - transition: types.Transition, - key: networks_lib.PRNGKey, - num_action_samples: int = 4) -> jnp.ndarray: - """Returns the advantage for the transition.""" - # Sample count actions. - replicated_observation = jnp.broadcast_to(transition.observation, - (num_action_samples,) + - transition.observation.shape) - dist_params = networks.policy_network.apply(policy_params, - replicated_observation) - actions = networks.sample(dist_params, key) - # Compute the state-action values for the sampled actions. - q_actions = networks.critic_network.apply(critic_params, - replicated_observation, actions) - # Take the mean as the state-value estimate. It is also possible to take the - # maximum, aka CRR(max); see table 1 in CRR paper. - q_estimate = jnp.mean(q_actions, axis=0) - # Compute the advantage. - q = networks.critic_network.apply(critic_params, transition.observation, - transition.action) - return q - q_estimate +def _compute_advantage( + networks: CRRNetworks, + policy_params: networks_lib.Params, + critic_params: networks_lib.Params, + transition: types.Transition, + key: networks_lib.PRNGKey, + num_action_samples: int = 4, +) -> jnp.ndarray: + """Returns the advantage for the transition.""" + # Sample count actions. + replicated_observation = jnp.broadcast_to( + transition.observation, (num_action_samples,) + transition.observation.shape + ) + dist_params = networks.policy_network.apply(policy_params, replicated_observation) + actions = networks.sample(dist_params, key) + # Compute the state-action values for the sampled actions. + q_actions = networks.critic_network.apply( + critic_params, replicated_observation, actions + ) + # Take the mean as the state-value estimate. It is also possible to take the + # maximum, aka CRR(max); see table 1 in CRR paper. + q_estimate = jnp.mean(q_actions, axis=0) + # Compute the advantage. 
+ q = networks.critic_network.apply( + critic_params, transition.observation, transition.action + ) + return q - q_estimate def policy_loss_coeff_advantage_exp( @@ -64,11 +71,13 @@ def policy_loss_coeff_advantage_exp( key: networks_lib.PRNGKey, num_action_samples: int = 4, beta: float = 1.0, - ratio_upper_bound: float = 20.0) -> jnp.ndarray: - """Exponential advantage weigting; see equation (4) in CRR paper.""" - advantage = _compute_advantage(networks, policy_params, critic_params, - transition, key, num_action_samples) - return jnp.minimum(jnp.exp(advantage / beta), ratio_upper_bound) + ratio_upper_bound: float = 20.0, +) -> jnp.ndarray: + """Exponential advantage weigting; see equation (4) in CRR paper.""" + advantage = _compute_advantage( + networks, policy_params, critic_params, transition, key, num_action_samples + ) + return jnp.minimum(jnp.exp(advantage / beta), ratio_upper_bound) def policy_loss_coeff_advantage_indicator( @@ -77,23 +86,27 @@ def policy_loss_coeff_advantage_indicator( critic_params: networks_lib.Params, transition: types.Transition, key: networks_lib.PRNGKey, - num_action_samples: int = 4) -> jnp.ndarray: - """Indicator advantage weighting; see equation (3) in CRR paper.""" - advantage = _compute_advantage(networks, policy_params, critic_params, - transition, key, num_action_samples) - return jnp.heaviside(advantage, 0.) + num_action_samples: int = 4, +) -> jnp.ndarray: + """Indicator advantage weighting; see equation (3) in CRR paper.""" + advantage = _compute_advantage( + networks, policy_params, critic_params, transition, key, num_action_samples + ) + return jnp.heaviside(advantage, 0.0) -def policy_loss_coeff_constant(networks: CRRNetworks, - policy_params: networks_lib.Params, - critic_params: networks_lib.Params, - transition: types.Transition, - key: networks_lib.PRNGKey, - value: float = 1.0) -> jnp.ndarray: - """Constant weights.""" - del networks - del policy_params - del critic_params - del transition - del key - return value # pytype: disable=bad-return-type # jax-ndarray +def policy_loss_coeff_constant( + networks: CRRNetworks, + policy_params: networks_lib.Params, + critic_params: networks_lib.Params, + transition: types.Transition, + key: networks_lib.PRNGKey, + value: float = 1.0, +) -> jnp.ndarray: + """Constant weights.""" + del networks + del policy_params + del critic_params + del transition + del key + return value # pytype: disable=bad-return-type # jax-ndarray diff --git a/acme/agents/jax/crr/networks.py b/acme/agents/jax/crr/networks.py index 6967791285..ead2ca2239 100644 --- a/acme/agents/jax/crr/networks.py +++ b/acme/agents/jax/crr/networks.py @@ -17,23 +17,25 @@ import dataclasses from typing import Callable, Tuple -from acme import specs -from acme.jax import networks as networks_lib -from acme.jax import utils import haiku as hk import jax import jax.numpy as jnp import numpy as np +from acme import specs +from acme.jax import networks as networks_lib +from acme.jax import utils + @dataclasses.dataclass class CRRNetworks: - """Network and pure functions for the CRR agent..""" - policy_network: networks_lib.FeedForwardNetwork - critic_network: networks_lib.FeedForwardNetwork - log_prob: networks_lib.LogProbFn - sample: networks_lib.SampleFn - sample_eval: networks_lib.SampleFn + """Network and pure functions for the CRR agent..""" + + policy_network: networks_lib.FeedForwardNetwork + critic_network: networks_lib.FeedForwardNetwork + log_prob: networks_lib.LogProbFn + sample: networks_lib.SampleFn + sample_eval: networks_lib.SampleFn def 
make_networks( @@ -42,45 +44,54 @@ def make_networks( critic_layer_sizes: Tuple[int, ...] = (256, 256), activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.relu, ) -> CRRNetworks: - """Creates networks used by the agent.""" - num_actions = np.prod(spec.actions.shape, dtype=int) + """Creates networks used by the agent.""" + num_actions = np.prod(spec.actions.shape, dtype=int) - # Create dummy observations and actions to create network parameters. - dummy_action = utils.add_batch_dim(utils.zeros_like(spec.actions)) - dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations)) + # Create dummy observations and actions to create network parameters. + dummy_action = utils.add_batch_dim(utils.zeros_like(spec.actions)) + dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations)) - def _policy_fn(obs: jnp.ndarray) -> jnp.ndarray: - network = hk.Sequential([ - hk.nets.MLP( - list(policy_layer_sizes), - w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'), - activation=activation, - activate_final=True), - networks_lib.NormalTanhDistribution(num_actions), - ]) - return network(obs) + def _policy_fn(obs: jnp.ndarray) -> jnp.ndarray: + network = hk.Sequential( + [ + hk.nets.MLP( + list(policy_layer_sizes), + w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), + activation=activation, + activate_final=True, + ), + networks_lib.NormalTanhDistribution(num_actions), + ] + ) + return network(obs) - policy = hk.without_apply_rng(hk.transform(_policy_fn)) - policy_network = networks_lib.FeedForwardNetwork( - lambda key: policy.init(key, dummy_obs), policy.apply) + policy = hk.without_apply_rng(hk.transform(_policy_fn)) + policy_network = networks_lib.FeedForwardNetwork( + lambda key: policy.init(key, dummy_obs), policy.apply + ) - def _critic_fn(obs, action): - network = hk.Sequential([ - hk.nets.MLP( - list(critic_layer_sizes) + [1], - w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'), - activation=activation), - ]) - data = jnp.concatenate([obs, action], axis=-1) - return network(data) + def _critic_fn(obs, action): + network = hk.Sequential( + [ + hk.nets.MLP( + list(critic_layer_sizes) + [1], + w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), + activation=activation, + ), + ] + ) + data = jnp.concatenate([obs, action], axis=-1) + return network(data) - critic = hk.without_apply_rng(hk.transform(_critic_fn)) - critic_network = networks_lib.FeedForwardNetwork( - lambda key: critic.init(key, dummy_obs, dummy_action), critic.apply) + critic = hk.without_apply_rng(hk.transform(_critic_fn)) + critic_network = networks_lib.FeedForwardNetwork( + lambda key: critic.init(key, dummy_obs, dummy_action), critic.apply + ) - return CRRNetworks( - policy_network=policy_network, - critic_network=critic_network, - log_prob=lambda params, actions: params.log_prob(actions), - sample=lambda params, key: params.sample(seed=key), - sample_eval=lambda params, key: params.mode()) + return CRRNetworks( + policy_network=policy_network, + critic_network=critic_network, + log_prob=lambda params, actions: params.log_prob(actions), + sample=lambda params, key: params.sample(seed=key), + sample_eval=lambda params, key: params.mode(), + ) diff --git a/acme/agents/jax/d4pg/__init__.py b/acme/agents/jax/d4pg/__init__.py index a9ea271e11..8d2416a736 100644 --- a/acme/agents/jax/d4pg/__init__.py +++ b/acme/agents/jax/d4pg/__init__.py @@ -17,8 +17,9 @@ from acme.agents.jax.d4pg.builder import D4PGBuilder from acme.agents.jax.d4pg.config import D4PGConfig from 
acme.agents.jax.d4pg.learning import D4PGLearner -from acme.agents.jax.d4pg.networks import D4PGNetworks -from acme.agents.jax.d4pg.networks import get_default_behavior_policy -from acme.agents.jax.d4pg.networks import get_default_eval_policy -from acme.agents.jax.d4pg.networks import make_networks - +from acme.agents.jax.d4pg.networks import ( + D4PGNetworks, + get_default_behavior_policy, + get_default_eval_policy, + make_networks, +) diff --git a/acme/agents/jax/d4pg/builder.py b/acme/agents/jax/d4pg/builder.py index feed96283d..12adb7b314 100644 --- a/acme/agents/jax/d4pg/builder.py +++ b/acme/agents/jax/d4pg/builder.py @@ -15,43 +15,41 @@ """D4PG Builder.""" from typing import Iterator, List, Optional +import jax +import optax +import reverb +import tensorflow as tf +import tree +from reverb import rate_limiters +from reverb import structured_writer as sw + import acme -from acme import adders -from acme import core -from acme import specs -from acme import types +from acme import adders, core, specs, types from acme.adders import reverb as adders_reverb from acme.adders.reverb import base as reverb_base from acme.agents.jax import actor_core as actor_core_lib -from acme.agents.jax import actors -from acme.agents.jax import builders +from acme.agents.jax import actors, builders from acme.agents.jax.d4pg import config as d4pg_config from acme.agents.jax.d4pg import learning from acme.agents.jax.d4pg import networks as d4pg_networks from acme.datasets import reverb as datasets from acme.jax import networks as networks_lib -from acme.jax import utils -from acme.jax import variable_utils -from acme.utils import counting -from acme.utils import loggers -import jax -import optax -import reverb -from reverb import rate_limiters -from reverb import structured_writer as sw -import tensorflow as tf -import tree +from acme.jax import utils, variable_utils +from acme.utils import counting, loggers -def _make_adder_config(step_spec: reverb_base.Step, n_step: int, - table: str) -> List[sw.Config]: - return adders_reverb.create_n_step_transition_config( - step_spec=step_spec, n_step=n_step, table=table) +def _make_adder_config( + step_spec: reverb_base.Step, n_step: int, table: str +) -> List[sw.Config]: + return adders_reverb.create_n_step_transition_config( + step_spec=step_spec, n_step=n_step, table=table + ) -def _as_n_step_transition(flat_trajectory: reverb.ReplaySample, - agent_discount: float) -> reverb.ReplaySample: - """Compute discounted return and total discount for N-step transitions. +def _as_n_step_transition( + flat_trajectory: reverb.ReplaySample, agent_discount: float +) -> reverb.ReplaySample: + """Compute discounted return and total discount for N-step transitions. For N greater than 1, transitions are of the form: @@ -92,188 +90,204 @@ def _as_n_step_transition(flat_trajectory: reverb.ReplaySample, Returns: A reverb.ReplaySample with computed discounted return and total discount. 
""" - trajectory = flat_trajectory.data - - def compute_discount_and_reward( - state: types.NestedTensor, - discount_and_reward: types.NestedTensor) -> types.NestedTensor: - compounded_discount, discounted_reward = state - return (agent_discount * discount_and_reward[0] * compounded_discount, - discounted_reward + discount_and_reward[1] * compounded_discount) - - initializer = (tf.constant(1, dtype=tf.float32), - tf.constant(0, dtype=tf.float32)) - elems = tf.stack((trajectory.discount, trajectory.reward), axis=-1) - total_discount, n_step_return = tf.scan( - compute_discount_and_reward, elems, initializer, reverse=True) - return reverb.ReplaySample( - info=flat_trajectory.info, - data=types.Transition( - observation=tree.map_structure(lambda x: x[0], - trajectory.observation), - action=tree.map_structure(lambda x: x[0], trajectory.action), - reward=n_step_return[0], - discount=total_discount[0], - next_observation=tree.map_structure(lambda x: x[-1], - trajectory.observation), - extras=tree.map_structure(lambda x: x[0], trajectory.extras))) - - -class D4PGBuilder(builders.ActorLearnerBuilder[d4pg_networks.D4PGNetworks, - actor_core_lib.ActorCore, - reverb.ReplaySample]): - """D4PG Builder.""" - - def __init__( - self, - config: d4pg_config.D4PGConfig, - ): - """Creates a D4PG learner, a behavior policy and an eval actor. - - Args: - config: a config with D4PG hps - """ - self._config = config - - def make_learner( - self, - random_key: networks_lib.PRNGKey, - networks: d4pg_networks.D4PGNetworks, - dataset: Iterator[reverb.ReplaySample], - logger_fn: loggers.LoggerFactory, - environment_spec: specs.EnvironmentSpec, - replay_client: Optional[reverb.Client] = None, - counter: Optional[counting.Counter] = None, - ) -> core.Learner: - del environment_spec, replay_client - - policy_optimizer = optax.adam(self._config.learning_rate) - critic_optimizer = optax.adam(self._config.learning_rate) - - if self._config.clipping: - policy_optimizer = optax.chain( - optax.clip_by_global_norm(40.), policy_optimizer) - critic_optimizer = optax.chain( - optax.clip_by_global_norm(40.), critic_optimizer) - - # The learner updates the parameters (and initializes them). - return learning.D4PGLearner( - policy_network=networks.policy_network, - critic_network=networks.critic_network, - random_key=random_key, - policy_optimizer=policy_optimizer, - critic_optimizer=critic_optimizer, - clipping=self._config.clipping, - discount=self._config.discount, - target_update_period=self._config.target_update_period, - iterator=dataset, - counter=counter, - logger=logger_fn('learner'), - num_sgd_steps_per_step=self._config.num_sgd_steps_per_step) - - def make_replay_tables( - self, - environment_spec: specs.EnvironmentSpec, - policy: actor_core_lib.ActorCore, - ) -> List[reverb.Table]: - """Create tables to insert data into.""" - dummy_actor_state = policy.init(jax.random.PRNGKey(0)) - extras_spec = policy.get_extras(dummy_actor_state) - step_spec = adders_reverb.create_step_spec( - environment_spec=environment_spec, extras_spec=extras_spec) - - # Create the rate limiter. 
- if self._config.samples_per_insert: - samples_per_insert_tolerance = ( - self._config.samples_per_insert_tolerance_rate * - self._config.samples_per_insert) - error_buffer = self._config.min_replay_size * samples_per_insert_tolerance - limiter = rate_limiters.SampleToInsertRatio( - min_size_to_sample=self._config.min_replay_size, - samples_per_insert=self._config.samples_per_insert, - error_buffer=error_buffer) - else: - limiter = rate_limiters.MinSize(self._config.min_replay_size) - return [ - reverb.Table( - name=self._config.replay_table_name, - sampler=reverb.selectors.Uniform(), - remover=reverb.selectors.Fifo(), - max_size=self._config.max_replay_size, - rate_limiter=limiter, - signature=sw.infer_signature( - configs=_make_adder_config(step_spec, self._config.n_step, - self._config.replay_table_name), - step_spec=step_spec)) - ] + trajectory = flat_trajectory.data + + def compute_discount_and_reward( + state: types.NestedTensor, discount_and_reward: types.NestedTensor + ) -> types.NestedTensor: + compounded_discount, discounted_reward = state + return ( + agent_discount * discount_and_reward[0] * compounded_discount, + discounted_reward + discount_and_reward[1] * compounded_discount, + ) + + initializer = (tf.constant(1, dtype=tf.float32), tf.constant(0, dtype=tf.float32)) + elems = tf.stack((trajectory.discount, trajectory.reward), axis=-1) + total_discount, n_step_return = tf.scan( + compute_discount_and_reward, elems, initializer, reverse=True + ) + return reverb.ReplaySample( + info=flat_trajectory.info, + data=types.Transition( + observation=tree.map_structure(lambda x: x[0], trajectory.observation), + action=tree.map_structure(lambda x: x[0], trajectory.action), + reward=n_step_return[0], + discount=total_discount[0], + next_observation=tree.map_structure( + lambda x: x[-1], trajectory.observation + ), + extras=tree.map_structure(lambda x: x[0], trajectory.extras), + ), + ) - def make_dataset_iterator( - self, - replay_client: reverb.Client, - ) -> Iterator[reverb.ReplaySample]: - """Create a dataset iterator to use for learning/updating the agent.""" - def postprocess( - flat_trajectory: reverb.ReplaySample) -> reverb.ReplaySample: - return _as_n_step_transition(flat_trajectory, self._config.discount) +class D4PGBuilder( + builders.ActorLearnerBuilder[ + d4pg_networks.D4PGNetworks, actor_core_lib.ActorCore, reverb.ReplaySample + ] +): + """D4PG Builder.""" - batch_size_per_device = self._config.batch_size // jax.device_count() + def __init__( + self, config: d4pg_config.D4PGConfig, + ): + """Creates a D4PG learner, a behavior policy and an eval actor. 
- dataset = datasets.make_reverb_dataset( - table=self._config.replay_table_name, - server_address=replay_client.server_address, - batch_size=batch_size_per_device * self._config.num_sgd_steps_per_step, - prefetch_size=self._config.prefetch_size, - postprocess=postprocess, - ) - return utils.multi_device_put(dataset.as_numpy_iterator(), - jax.local_devices()) - - def make_adder( - self, - replay_client: reverb.Client, - environment_spec: Optional[specs.EnvironmentSpec], - policy: Optional[actor_core_lib.ActorCore], - ) -> Optional[adders.Adder]: - """Create an adder which records data generated by the actor/environment.""" - if environment_spec is None or policy is None: - raise ValueError('`environment_spec` and `policy` cannot be None.') - dummy_actor_state = policy.init(jax.random.PRNGKey(0)) - extras_spec = policy.get_extras(dummy_actor_state) - step_spec = adders_reverb.create_step_spec( - environment_spec=environment_spec, extras_spec=extras_spec) - return adders_reverb.StructuredAdder( - client=replay_client, - max_in_flight_items=5, - configs=_make_adder_config(step_spec, self._config.n_step, - self._config.replay_table_name), - step_spec=step_spec) - - def make_actor( - self, - random_key: networks_lib.PRNGKey, - policy: actor_core_lib.ActorCore, - environment_spec: specs.EnvironmentSpec, - variable_source: Optional[core.VariableSource] = None, - adder: Optional[adders.Adder] = None, - ) -> acme.Actor: - del environment_spec - assert variable_source is not None - # Inference happens on CPU, so it's better to move variables there too. - variable_client = variable_utils.VariableClient( - variable_source, 'policy', device='cpu') - return actors.GenericActor( - policy, random_key, variable_client, adder, backend='cpu') - - def make_policy(self, - networks: d4pg_networks.D4PGNetworks, - environment_spec: specs.EnvironmentSpec, - evaluation: bool = False) -> actor_core_lib.ActorCore: - """Create the policy.""" - del environment_spec - if evaluation: - policy = d4pg_networks.get_default_eval_policy(networks) - else: - policy = d4pg_networks.get_default_behavior_policy(networks, self._config) - - return actor_core_lib.batched_feed_forward_to_actor_core(policy) + Args: + config: a config with D4PG hps + """ + self._config = config + + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: d4pg_networks.D4PGNetworks, + dataset: Iterator[reverb.ReplaySample], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + del environment_spec, replay_client + + policy_optimizer = optax.adam(self._config.learning_rate) + critic_optimizer = optax.adam(self._config.learning_rate) + + if self._config.clipping: + policy_optimizer = optax.chain( + optax.clip_by_global_norm(40.0), policy_optimizer + ) + critic_optimizer = optax.chain( + optax.clip_by_global_norm(40.0), critic_optimizer + ) + + # The learner updates the parameters (and initializes them). 
+ return learning.D4PGLearner( + policy_network=networks.policy_network, + critic_network=networks.critic_network, + random_key=random_key, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + clipping=self._config.clipping, + discount=self._config.discount, + target_update_period=self._config.target_update_period, + iterator=dataset, + counter=counter, + logger=logger_fn("learner"), + num_sgd_steps_per_step=self._config.num_sgd_steps_per_step, + ) + + def make_replay_tables( + self, environment_spec: specs.EnvironmentSpec, policy: actor_core_lib.ActorCore, + ) -> List[reverb.Table]: + """Create tables to insert data into.""" + dummy_actor_state = policy.init(jax.random.PRNGKey(0)) + extras_spec = policy.get_extras(dummy_actor_state) + step_spec = adders_reverb.create_step_spec( + environment_spec=environment_spec, extras_spec=extras_spec + ) + + # Create the rate limiter. + if self._config.samples_per_insert: + samples_per_insert_tolerance = ( + self._config.samples_per_insert_tolerance_rate + * self._config.samples_per_insert + ) + error_buffer = self._config.min_replay_size * samples_per_insert_tolerance + limiter = rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._config.min_replay_size, + samples_per_insert=self._config.samples_per_insert, + error_buffer=error_buffer, + ) + else: + limiter = rate_limiters.MinSize(self._config.min_replay_size) + return [ + reverb.Table( + name=self._config.replay_table_name, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._config.max_replay_size, + rate_limiter=limiter, + signature=sw.infer_signature( + configs=_make_adder_config( + step_spec, self._config.n_step, self._config.replay_table_name + ), + step_spec=step_spec, + ), + ) + ] + + def make_dataset_iterator( + self, replay_client: reverb.Client, + ) -> Iterator[reverb.ReplaySample]: + """Create a dataset iterator to use for learning/updating the agent.""" + + def postprocess(flat_trajectory: reverb.ReplaySample) -> reverb.ReplaySample: + return _as_n_step_transition(flat_trajectory, self._config.discount) + + batch_size_per_device = self._config.batch_size // jax.device_count() + + dataset = datasets.make_reverb_dataset( + table=self._config.replay_table_name, + server_address=replay_client.server_address, + batch_size=batch_size_per_device * self._config.num_sgd_steps_per_step, + prefetch_size=self._config.prefetch_size, + postprocess=postprocess, + ) + return utils.multi_device_put(dataset.as_numpy_iterator(), jax.local_devices()) + + def make_adder( + self, + replay_client: reverb.Client, + environment_spec: Optional[specs.EnvironmentSpec], + policy: Optional[actor_core_lib.ActorCore], + ) -> Optional[adders.Adder]: + """Create an adder which records data generated by the actor/environment.""" + if environment_spec is None or policy is None: + raise ValueError("`environment_spec` and `policy` cannot be None.") + dummy_actor_state = policy.init(jax.random.PRNGKey(0)) + extras_spec = policy.get_extras(dummy_actor_state) + step_spec = adders_reverb.create_step_spec( + environment_spec=environment_spec, extras_spec=extras_spec + ) + return adders_reverb.StructuredAdder( + client=replay_client, + max_in_flight_items=5, + configs=_make_adder_config( + step_spec, self._config.n_step, self._config.replay_table_name + ), + step_spec=step_spec, + ) + + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: actor_core_lib.ActorCore, + environment_spec: specs.EnvironmentSpec, + variable_source: 
Optional[core.VariableSource] = None, + adder: Optional[adders.Adder] = None, + ) -> acme.Actor: + del environment_spec + assert variable_source is not None + # Inference happens on CPU, so it's better to move variables there too. + variable_client = variable_utils.VariableClient( + variable_source, "policy", device="cpu" + ) + return actors.GenericActor( + policy, random_key, variable_client, adder, backend="cpu" + ) + + def make_policy( + self, + networks: d4pg_networks.D4PGNetworks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool = False, + ) -> actor_core_lib.ActorCore: + """Create the policy.""" + del environment_spec + if evaluation: + policy = d4pg_networks.get_default_eval_policy(networks) + else: + policy = d4pg_networks.get_default_behavior_policy(networks, self._config) + + return actor_core_lib.batched_feed_forward_to_actor_core(policy) diff --git a/acme/agents/jax/d4pg/config.py b/acme/agents/jax/d4pg/config.py index 338a1abc45..9a6d6d8df3 100644 --- a/acme/agents/jax/d4pg/config.py +++ b/acme/agents/jax/d4pg/config.py @@ -15,31 +15,33 @@ """Config classes for D4PG.""" import dataclasses from typing import Optional + from acme.adders import reverb as adders_reverb @dataclasses.dataclass class D4PGConfig: - """Configuration options for D4PG.""" - sigma: float = 0.3 - target_update_period: int = 100 - samples_per_insert: Optional[float] = 32.0 - - # Loss options - n_step: int = 5 - discount: float = 0.99 - batch_size: int = 256 - learning_rate: float = 1e-4 - clipping: bool = True - - # Replay options - min_replay_size: int = 1000 - max_replay_size: int = 1000000 - replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE - prefetch_size: int = 4 - # Rate to be used for the SampleToInsertRatio rate limitter tolerance. - # See a formula in make_replay_tables for more details. - samples_per_insert_tolerance_rate: float = 0.1 - - # How many gradient updates to perform per step. - num_sgd_steps_per_step: int = 1 + """Configuration options for D4PG.""" + + sigma: float = 0.3 + target_update_period: int = 100 + samples_per_insert: Optional[float] = 32.0 + + # Loss options + n_step: int = 5 + discount: float = 0.99 + batch_size: int = 256 + learning_rate: float = 1e-4 + clipping: bool = True + + # Replay options + min_replay_size: int = 1000 + max_replay_size: int = 1000000 + replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE + prefetch_size: int = 4 + # Rate to be used for the SampleToInsertRatio rate limitter tolerance. + # See a formula in make_replay_tables for more details. + samples_per_insert_tolerance_rate: float = 0.1 + + # How many gradient updates to perform per step. 
+ num_sgd_steps_per_step: int = 1 diff --git a/acme/agents/jax/d4pg/learning.py b/acme/agents/jax/d4pg/learning.py index b1b897225d..81e175005e 100644 --- a/acme/agents/jax/d4pg/learning.py +++ b/acme/agents/jax/d4pg/learning.py @@ -17,245 +17,280 @@ import time from typing import Dict, Iterator, List, NamedTuple, Optional, Tuple -import acme -from acme import types -from acme.jax import networks as networks_lib -from acme.jax import utils -from acme.utils import counting -from acme.utils import loggers import jax import jax.numpy as jnp import optax import reverb import rlax -_PMAP_AXIS_NAME = 'data' +import acme +from acme import types +from acme.jax import networks as networks_lib +from acme.jax import utils +from acme.utils import counting, loggers + +_PMAP_AXIS_NAME = "data" class TrainingState(NamedTuple): - """Contains training state for the learner.""" - policy_params: networks_lib.Params - target_policy_params: networks_lib.Params - critic_params: networks_lib.Params - target_critic_params: networks_lib.Params - policy_opt_state: optax.OptState - critic_opt_state: optax.OptState - steps: int + """Contains training state for the learner.""" + + policy_params: networks_lib.Params + target_policy_params: networks_lib.Params + critic_params: networks_lib.Params + target_critic_params: networks_lib.Params + policy_opt_state: optax.OptState + critic_opt_state: optax.OptState + steps: int class D4PGLearner(acme.Learner): - """D4PG learner. + """D4PG learner. This is the learning component of a D4PG agent. IE it takes a dataset as input and implements update functionality to learn from this dataset. """ - _state: TrainingState - - def __init__(self, - policy_network: networks_lib.FeedForwardNetwork, - critic_network: networks_lib.FeedForwardNetwork, - random_key: networks_lib.PRNGKey, - discount: float, - target_update_period: int, - iterator: Iterator[reverb.ReplaySample], - policy_optimizer: Optional[optax.GradientTransformation] = None, - critic_optimizer: Optional[optax.GradientTransformation] = None, - clipping: bool = True, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - jit: bool = True, - num_sgd_steps_per_step: int = 1): - - def critic_mean( - critic_params: networks_lib.Params, - observation: types.NestedArray, - action: types.NestedArray, - ) -> jnp.ndarray: - # We add batch dimension to make sure batch concat in critic_network - # works correctly. - observation = utils.add_batch_dim(observation) - action = utils.add_batch_dim(action) - # Computes the mean action-value estimate. - logits, atoms = critic_network.apply(critic_params, observation, action) - logits = utils.squeeze_batch_dim(logits) - probabilities = jax.nn.softmax(logits) - return jnp.sum(probabilities * atoms, axis=-1) - - def policy_loss( - policy_params: networks_lib.Params, - critic_params: networks_lib.Params, - o_t: types.NestedArray, - ) -> jnp.ndarray: - # Computes the discrete policy gradient loss. - dpg_a_t = policy_network.apply(policy_params, o_t) - grad_critic = jax.vmap( - jax.grad(critic_mean, argnums=2), in_axes=(None, 0, 0)) - dq_da = grad_critic(critic_params, o_t, dpg_a_t) - dqda_clipping = 1. 
if clipping else None - batch_dpg_learning = jax.vmap(rlax.dpg_loss, in_axes=(0, 0, None)) - loss = batch_dpg_learning(dpg_a_t, dq_da, dqda_clipping) - return jnp.mean(loss) - - def critic_loss( - critic_params: networks_lib.Params, - state: TrainingState, - transition: types.Transition, + _state: TrainingState + + def __init__( + self, + policy_network: networks_lib.FeedForwardNetwork, + critic_network: networks_lib.FeedForwardNetwork, + random_key: networks_lib.PRNGKey, + discount: float, + target_update_period: int, + iterator: Iterator[reverb.ReplaySample], + policy_optimizer: Optional[optax.GradientTransformation] = None, + critic_optimizer: Optional[optax.GradientTransformation] = None, + clipping: bool = True, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + jit: bool = True, + num_sgd_steps_per_step: int = 1, ): - # Computes the distributional critic loss. - q_tm1, atoms_tm1 = critic_network.apply(critic_params, - transition.observation, - transition.action) - a = policy_network.apply(state.target_policy_params, - transition.next_observation) - q_t, atoms_t = critic_network.apply(state.target_critic_params, - transition.next_observation, a) - batch_td_learning = jax.vmap( - rlax.categorical_td_learning, in_axes=(None, 0, 0, 0, None, 0)) - loss = batch_td_learning(atoms_tm1, q_tm1, transition.reward, - discount * transition.discount, atoms_t, q_t) - return jnp.mean(loss) - - def sgd_step( - state: TrainingState, - transitions: types.Transition, - ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: - - # TODO(jaslanides): Use a shared forward pass for efficiency. - policy_loss_and_grad = jax.value_and_grad(policy_loss) - critic_loss_and_grad = jax.value_and_grad(critic_loss) - - # Compute losses and their gradients. - policy_loss_value, policy_gradients = policy_loss_and_grad( - state.policy_params, state.critic_params, - transitions.next_observation) - critic_loss_value, critic_gradients = critic_loss_and_grad( - state.critic_params, state, transitions) - - # Average over all devices. - policy_loss_value, policy_gradients = jax.lax.pmean( - (policy_loss_value, policy_gradients), _PMAP_AXIS_NAME) - critic_loss_value, critic_gradients = jax.lax.pmean( - (critic_loss_value, critic_gradients), _PMAP_AXIS_NAME) - - # Get optimizer updates and state. - policy_updates, policy_opt_state = policy_optimizer.update( # pytype: disable=attribute-error - policy_gradients, state.policy_opt_state) - critic_updates, critic_opt_state = critic_optimizer.update( # pytype: disable=attribute-error - critic_gradients, state.critic_opt_state) - - # Apply optimizer updates to parameters. - policy_params = optax.apply_updates(state.policy_params, policy_updates) - critic_params = optax.apply_updates(state.critic_params, critic_updates) - - steps = state.steps + 1 - - # Periodically update target networks. 
- target_policy_params, target_critic_params = optax.periodic_update( # pytype: disable=wrong-arg-types # numpy-scalars - (policy_params, critic_params), - (state.target_policy_params, state.target_critic_params), steps, - self._target_update_period) - - new_state = TrainingState( - policy_params=policy_params, - critic_params=critic_params, - target_policy_params=target_policy_params, - target_critic_params=target_critic_params, - policy_opt_state=policy_opt_state, - critic_opt_state=critic_opt_state, - steps=steps, - ) - - metrics = { - 'policy_loss': policy_loss_value, - 'critic_loss': critic_loss_value, - } - - return new_state, metrics - - # General learner book-keeping and loggers. - self._counter = counter or counting.Counter() - self._logger = logger or loggers.make_default_logger( - 'learner', - asynchronous=True, - serialize_fn=utils.fetch_devicearray, - steps_key=self._counter.get_steps_key()) - - # Necessary to track when to update target networks. - self._target_update_period = target_update_period - - # Create prefetching dataset iterator. - self._iterator = iterator - - # Maybe use the JIT compiler. - sgd_step = utils.process_multiple_batches(sgd_step, num_sgd_steps_per_step) - self._sgd_step = ( - jax.pmap(sgd_step, _PMAP_AXIS_NAME, devices=jax.devices()) - if jit else sgd_step) - - # Create the network parameters and copy into the target network parameters. - key_policy, key_critic = jax.random.split(random_key) - initial_policy_params = policy_network.init(key_policy) - initial_critic_params = critic_network.init(key_critic) - initial_target_policy_params = initial_policy_params - initial_target_critic_params = initial_critic_params - - # Create optimizers if they aren't given. - critic_optimizer = critic_optimizer or optax.adam(1e-4) - policy_optimizer = policy_optimizer or optax.adam(1e-4) - - # Initialize optimizers. - initial_policy_opt_state = policy_optimizer.init(initial_policy_params) # pytype: disable=attribute-error - initial_critic_opt_state = critic_optimizer.init(initial_critic_params) # pytype: disable=attribute-error - - # Create the initial state and replicate it in all devices. - self._state = utils.replicate_in_all_devices( - TrainingState( - policy_params=initial_policy_params, - target_policy_params=initial_target_policy_params, - critic_params=initial_critic_params, - target_critic_params=initial_target_critic_params, - policy_opt_state=initial_policy_opt_state, - critic_opt_state=initial_critic_opt_state, - steps=0, - )) - - # Do not record timestamps until after the first learning step is done. - # This is to avoid including the time it takes for actors to come online and - # fill the replay buffer. - self._timestamp = None - - def step(self): - # Sample from replay and pack the data in a Transition. - sample = next(self._iterator) - transitions = types.Transition(*sample.data) - - self._state, metrics = self._sgd_step(self._state, transitions) - - # Take the metrics from the first device, since they've been pmeaned over - # all devices and are therefore identical. - metrics = utils.get_from_first_device(metrics) - - # Compute elapsed time. - timestamp = time.time() - elapsed_time = timestamp - self._timestamp if self._timestamp else 0 - self._timestamp = timestamp - - # Increment counts and record the current time - counts = self._counter.increment(steps=1, walltime=elapsed_time) - - # Attempts to write the logs. 
- self._logger.write({**metrics, **counts}) - - def get_variables(self, names: List[str]) -> List[networks_lib.Params]: - variables = { - 'policy': self._state.target_policy_params, - 'critic': self._state.target_critic_params, - } - return utils.get_from_first_device([variables[name] for name in names]) - - def save(self) -> TrainingState: - return utils.get_from_first_device(self._state) - - def restore(self, state: TrainingState): - self._state = utils.replicate_in_all_devices(state) + def critic_mean( + critic_params: networks_lib.Params, + observation: types.NestedArray, + action: types.NestedArray, + ) -> jnp.ndarray: + # We add batch dimension to make sure batch concat in critic_network + # works correctly. + observation = utils.add_batch_dim(observation) + action = utils.add_batch_dim(action) + # Computes the mean action-value estimate. + logits, atoms = critic_network.apply(critic_params, observation, action) + logits = utils.squeeze_batch_dim(logits) + probabilities = jax.nn.softmax(logits) + return jnp.sum(probabilities * atoms, axis=-1) + + def policy_loss( + policy_params: networks_lib.Params, + critic_params: networks_lib.Params, + o_t: types.NestedArray, + ) -> jnp.ndarray: + # Computes the discrete policy gradient loss. + dpg_a_t = policy_network.apply(policy_params, o_t) + grad_critic = jax.vmap( + jax.grad(critic_mean, argnums=2), in_axes=(None, 0, 0) + ) + dq_da = grad_critic(critic_params, o_t, dpg_a_t) + dqda_clipping = 1.0 if clipping else None + batch_dpg_learning = jax.vmap(rlax.dpg_loss, in_axes=(0, 0, None)) + loss = batch_dpg_learning(dpg_a_t, dq_da, dqda_clipping) + return jnp.mean(loss) + + def critic_loss( + critic_params: networks_lib.Params, + state: TrainingState, + transition: types.Transition, + ): + # Computes the distributional critic loss. + q_tm1, atoms_tm1 = critic_network.apply( + critic_params, transition.observation, transition.action + ) + a = policy_network.apply( + state.target_policy_params, transition.next_observation + ) + q_t, atoms_t = critic_network.apply( + state.target_critic_params, transition.next_observation, a + ) + batch_td_learning = jax.vmap( + rlax.categorical_td_learning, in_axes=(None, 0, 0, 0, None, 0) + ) + loss = batch_td_learning( + atoms_tm1, + q_tm1, + transition.reward, + discount * transition.discount, + atoms_t, + q_t, + ) + return jnp.mean(loss) + + def sgd_step( + state: TrainingState, transitions: types.Transition, + ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: + + # TODO(jaslanides): Use a shared forward pass for efficiency. + policy_loss_and_grad = jax.value_and_grad(policy_loss) + critic_loss_and_grad = jax.value_and_grad(critic_loss) + + # Compute losses and their gradients. + policy_loss_value, policy_gradients = policy_loss_and_grad( + state.policy_params, state.critic_params, transitions.next_observation + ) + critic_loss_value, critic_gradients = critic_loss_and_grad( + state.critic_params, state, transitions + ) + + # Average over all devices. + policy_loss_value, policy_gradients = jax.lax.pmean( + (policy_loss_value, policy_gradients), _PMAP_AXIS_NAME + ) + critic_loss_value, critic_gradients = jax.lax.pmean( + (critic_loss_value, critic_gradients), _PMAP_AXIS_NAME + ) + + # Get optimizer updates and state. 
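# Aside: the two-step optax pattern used just below (update, then apply). The
# optimizer, parameters and gradients here are stand-ins, not the learner's.
import jax.numpy as jnp
import optax

params = {"w": jnp.zeros(3)}
optimizer = optax.adam(1e-4)
opt_state = optimizer.init(params)

grads = {"w": jnp.ones(3)}  # pretend gradients from value_and_grad
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)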
+ ( + policy_updates, + policy_opt_state, + ) = policy_optimizer.update( # pytype: disable=attribute-error + policy_gradients, state.policy_opt_state + ) + ( + critic_updates, + critic_opt_state, + ) = critic_optimizer.update( # pytype: disable=attribute-error + critic_gradients, state.critic_opt_state + ) + + # Apply optimizer updates to parameters. + policy_params = optax.apply_updates(state.policy_params, policy_updates) + critic_params = optax.apply_updates(state.critic_params, critic_updates) + + steps = state.steps + 1 + + # Periodically update target networks. + ( + target_policy_params, + target_critic_params, + ) = optax.periodic_update( # pytype: disable=wrong-arg-types # numpy-scalars + (policy_params, critic_params), + (state.target_policy_params, state.target_critic_params), + steps, + self._target_update_period, + ) + + new_state = TrainingState( + policy_params=policy_params, + critic_params=critic_params, + target_policy_params=target_policy_params, + target_critic_params=target_critic_params, + policy_opt_state=policy_opt_state, + critic_opt_state=critic_opt_state, + steps=steps, + ) + + metrics = { + "policy_loss": policy_loss_value, + "critic_loss": critic_loss_value, + } + + return new_state, metrics + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger( + "learner", + asynchronous=True, + serialize_fn=utils.fetch_devicearray, + steps_key=self._counter.get_steps_key(), + ) + + # Necessary to track when to update target networks. + self._target_update_period = target_update_period + + # Create prefetching dataset iterator. + self._iterator = iterator + + # Maybe use the JIT compiler. + sgd_step = utils.process_multiple_batches(sgd_step, num_sgd_steps_per_step) + self._sgd_step = ( + jax.pmap(sgd_step, _PMAP_AXIS_NAME, devices=jax.devices()) + if jit + else sgd_step + ) + + # Create the network parameters and copy into the target network parameters. + key_policy, key_critic = jax.random.split(random_key) + initial_policy_params = policy_network.init(key_policy) + initial_critic_params = critic_network.init(key_critic) + initial_target_policy_params = initial_policy_params + initial_target_critic_params = initial_critic_params + + # Create optimizers if they aren't given. + critic_optimizer = critic_optimizer or optax.adam(1e-4) + policy_optimizer = policy_optimizer or optax.adam(1e-4) + + # Initialize optimizers. + initial_policy_opt_state = policy_optimizer.init( + initial_policy_params + ) # pytype: disable=attribute-error + initial_critic_opt_state = critic_optimizer.init( + initial_critic_params + ) # pytype: disable=attribute-error + + # Create the initial state and replicate it in all devices. + self._state = utils.replicate_in_all_devices( + TrainingState( + policy_params=initial_policy_params, + target_policy_params=initial_target_policy_params, + critic_params=initial_critic_params, + target_critic_params=initial_target_critic_params, + policy_opt_state=initial_policy_opt_state, + critic_opt_state=initial_critic_opt_state, + steps=0, + ) + ) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + def step(self): + # Sample from replay and pack the data in a Transition. 
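# Aside: what "pack the data in a Transition" amounts to. The NamedTuple below
# is a stand-in with an assumed field order, not the library's types.Transition.
from typing import Any, NamedTuple

class FakeTransition(NamedTuple):
    observation: Any
    action: Any
    reward: Any
    discount: Any
    next_observation: Any

# Replay returns the stored fields as a flat tuple; unpacking them restores
# the named structure the loss functions expect.
sample_data = ([0.1, 0.2], 1, 0.0, 0.99, [0.3, 0.4])
transition = FakeTransition(*sample_data)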
+ sample = next(self._iterator) + transitions = types.Transition(*sample.data) + + self._state, metrics = self._sgd_step(self._state, transitions) + + # Take the metrics from the first device, since they've been pmeaned over + # all devices and are therefore identical. + metrics = utils.get_from_first_device(metrics) + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Increment counts and record the current time + counts = self._counter.increment(steps=1, walltime=elapsed_time) + + # Attempts to write the logs. + self._logger.write({**metrics, **counts}) + + def get_variables(self, names: List[str]) -> List[networks_lib.Params]: + variables = { + "policy": self._state.target_policy_params, + "critic": self._state.target_critic_params, + } + return utils.get_from_first_device([variables[name] for name in names]) + + def save(self) -> TrainingState: + return utils.get_from_first_device(self._state) + + def restore(self, state: TrainingState): + self._state = utils.replicate_in_all_devices(state) diff --git a/acme/agents/jax/d4pg/networks.py b/acme/agents/jax/d4pg/networks.py index d598035ec6..b00d947481 100644 --- a/acme/agents/jax/d4pg/networks.py +++ b/acme/agents/jax/d4pg/networks.py @@ -17,93 +17,109 @@ import dataclasses from typing import Sequence -from acme import specs -from acme import types -from acme.agents.jax import actor_core as actor_core_lib -from acme.agents.jax.d4pg import config as d4pg_config -from acme.jax import networks as networks_lib -from acme.jax import utils import haiku as hk import jax.numpy as jnp import numpy as np import rlax +from acme import specs, types +from acme.agents.jax import actor_core as actor_core_lib +from acme.agents.jax.d4pg import config as d4pg_config +from acme.jax import networks as networks_lib +from acme.jax import utils + @dataclasses.dataclass class D4PGNetworks: - """Network and pure functions for the D4PG agent..""" - policy_network: networks_lib.FeedForwardNetwork - critic_network: networks_lib.FeedForwardNetwork + """Network and pure functions for the D4PG agent..""" + + policy_network: networks_lib.FeedForwardNetwork + critic_network: networks_lib.FeedForwardNetwork def get_default_behavior_policy( - networks: D4PGNetworks, - config: d4pg_config.D4PGConfig) -> actor_core_lib.FeedForwardPolicy: - """Selects action according to the training policy.""" - def behavior_policy(params: networks_lib.Params, key: networks_lib.PRNGKey, - observation: types.NestedArray): - action = networks.policy_network.apply(params, observation) - if config.sigma != 0: - action = rlax.add_gaussian_noise(key, action, config.sigma) - return action - - return behavior_policy - - -def get_default_eval_policy( - networks: D4PGNetworks) -> actor_core_lib.FeedForwardPolicy: - """Selects action according to the training policy.""" - def behavior_policy(params: networks_lib.Params, key: networks_lib.PRNGKey, - observation: types.NestedArray): - del key - action = networks.policy_network.apply(params, observation) - return action - return behavior_policy + networks: D4PGNetworks, config: d4pg_config.D4PGConfig +) -> actor_core_lib.FeedForwardPolicy: + """Selects action according to the training policy.""" + + def behavior_policy( + params: networks_lib.Params, + key: networks_lib.PRNGKey, + observation: types.NestedArray, + ): + action = networks.policy_network.apply(params, observation) + if config.sigma != 0: + action = rlax.add_gaussian_noise(key, action, 
config.sigma) + return action + + return behavior_policy + + +def get_default_eval_policy(networks: D4PGNetworks) -> actor_core_lib.FeedForwardPolicy: + """Selects action according to the training policy.""" + + def behavior_policy( + params: networks_lib.Params, + key: networks_lib.PRNGKey, + observation: types.NestedArray, + ): + del key + action = networks.policy_network.apply(params, observation) + return action + + return behavior_policy def make_networks( spec: specs.EnvironmentSpec, policy_layer_sizes: Sequence[int] = (300, 200), critic_layer_sizes: Sequence[int] = (400, 300), - vmin: float = -150., - vmax: float = 150., + vmin: float = -150.0, + vmax: float = 150.0, num_atoms: int = 51, ) -> D4PGNetworks: - """Creates networks used by the agent.""" - - action_spec = spec.actions - - num_dimensions = np.prod(action_spec.shape, dtype=int) - critic_atoms = jnp.linspace(vmin, vmax, num_atoms) - - def _actor_fn(obs): - network = hk.Sequential([ - utils.batch_concat, - networks_lib.LayerNormMLP(policy_layer_sizes, activate_final=True), - networks_lib.NearZeroInitializedLinear(num_dimensions), - networks_lib.TanhToSpec(action_spec), - ]) - return network(obs) - - def _critic_fn(obs, action): - network = hk.Sequential([ - utils.batch_concat, - networks_lib.LayerNormMLP(layer_sizes=[*critic_layer_sizes, num_atoms]), - ]) - value = network([obs, action]) - return value, critic_atoms - - policy = hk.without_apply_rng(hk.transform(_actor_fn)) - critic = hk.without_apply_rng(hk.transform(_critic_fn)) - - # Create dummy observations and actions to create network parameters. - dummy_action = utils.zeros_like(spec.actions) - dummy_obs = utils.zeros_like(spec.observations) - dummy_action = utils.add_batch_dim(dummy_action) - dummy_obs = utils.add_batch_dim(dummy_obs) - - return D4PGNetworks( - policy_network=networks_lib.FeedForwardNetwork( - lambda rng: policy.init(rng, dummy_obs), policy.apply), - critic_network=networks_lib.FeedForwardNetwork( - lambda rng: critic.init(rng, dummy_obs, dummy_action), critic.apply)) + """Creates networks used by the agent.""" + + action_spec = spec.actions + + num_dimensions = np.prod(action_spec.shape, dtype=int) + critic_atoms = jnp.linspace(vmin, vmax, num_atoms) + + def _actor_fn(obs): + network = hk.Sequential( + [ + utils.batch_concat, + networks_lib.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks_lib.NearZeroInitializedLinear(num_dimensions), + networks_lib.TanhToSpec(action_spec), + ] + ) + return network(obs) + + def _critic_fn(obs, action): + network = hk.Sequential( + [ + utils.batch_concat, + networks_lib.LayerNormMLP(layer_sizes=[*critic_layer_sizes, num_atoms]), + ] + ) + value = network([obs, action]) + return value, critic_atoms + + policy = hk.without_apply_rng(hk.transform(_actor_fn)) + critic = hk.without_apply_rng(hk.transform(_critic_fn)) + + # Create dummy observations and actions to create network parameters. 
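# Aside: initializing a Haiku network from spec-shaped dummy inputs, mirroring
# the zeros_like/add_batch_dim lines that follow. The network and the
# observation shape here are hypothetical.
import haiku as hk
import jax
import jax.numpy as jnp

def _toy_actor_fn(obs):
    return hk.nets.MLP([64, 4])(obs)

actor = hk.without_apply_rng(hk.transform(_toy_actor_fn))
dummy_obs = jnp.zeros((1, 8))  # [batch=1, obs_dim=8], all-zeros placeholder
params = actor.init(jax.random.PRNGKey(0), dummy_obs)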
+ dummy_action = utils.zeros_like(spec.actions) + dummy_obs = utils.zeros_like(spec.observations) + dummy_action = utils.add_batch_dim(dummy_action) + dummy_obs = utils.add_batch_dim(dummy_obs) + + return D4PGNetworks( + policy_network=networks_lib.FeedForwardNetwork( + lambda rng: policy.init(rng, dummy_obs), policy.apply + ), + critic_network=networks_lib.FeedForwardNetwork( + lambda rng: critic.init(rng, dummy_obs, dummy_action), critic.apply + ), + ) diff --git a/acme/agents/jax/dqn/__init__.py b/acme/agents/jax/dqn/__init__.py index 4f84f493a9..39a0ac6f7d 100644 --- a/acme/agents/jax/dqn/__init__.py +++ b/acme/agents/jax/dqn/__init__.py @@ -14,21 +14,21 @@ """Implementation of a deep Q-networks (DQN) agent.""" -from acme.agents.jax.dqn.actor import behavior_policy -from acme.agents.jax.dqn.actor import default_behavior_policy -from acme.agents.jax.dqn.actor import DQNPolicy -from acme.agents.jax.dqn.actor import Epsilon -from acme.agents.jax.dqn.actor import EpsilonPolicy -from acme.agents.jax.dqn.builder import DistributionalDQNBuilder -from acme.agents.jax.dqn.builder import DQNBuilder +from acme.agents.jax.dqn.actor import ( + DQNPolicy, + Epsilon, + EpsilonPolicy, + behavior_policy, + default_behavior_policy, +) +from acme.agents.jax.dqn.builder import DistributionalDQNBuilder, DQNBuilder from acme.agents.jax.dqn.config import DQNConfig from acme.agents.jax.dqn.learning import DQNLearner -from acme.agents.jax.dqn.learning_lib import LossExtra -from acme.agents.jax.dqn.learning_lib import LossFn -from acme.agents.jax.dqn.learning_lib import ReverbUpdate -from acme.agents.jax.dqn.learning_lib import SGDLearner -from acme.agents.jax.dqn.losses import PrioritizedCategoricalDoubleQLearning -from acme.agents.jax.dqn.losses import PrioritizedDoubleQLearning -from acme.agents.jax.dqn.losses import QLearning -from acme.agents.jax.dqn.losses import QrDqn +from acme.agents.jax.dqn.learning_lib import LossExtra, LossFn, ReverbUpdate, SGDLearner +from acme.agents.jax.dqn.losses import ( + PrioritizedCategoricalDoubleQLearning, + PrioritizedDoubleQLearning, + QLearning, + QrDqn, +) from acme.agents.jax.dqn.networks import DQNNetworks diff --git a/acme/agents/jax/dqn/actor.py b/acme/agents/jax/dqn/actor.py index 13995077e2..261dd8fe9a 100644 --- a/acme/agents/jax/dqn/actor.py +++ b/acme/agents/jax/dqn/actor.py @@ -16,34 +16,35 @@ from typing import Callable, Sequence -from acme.agents.jax import actor_core as actor_core_lib -from acme.agents.jax.dqn import networks as dqn_networks -from acme.jax import networks as networks_lib -from acme.jax import utils import chex import jax import jax.numpy as jnp +from acme.agents.jax import actor_core as actor_core_lib +from acme.agents.jax.dqn import networks as dqn_networks +from acme.jax import networks as networks_lib +from acme.jax import utils Epsilon = float -EpsilonPolicy = Callable[[ - networks_lib.Params, networks_lib.PRNGKey, networks_lib - .Observation, Epsilon -], networks_lib.Action] +EpsilonPolicy = Callable[ + [networks_lib.Params, networks_lib.PRNGKey, networks_lib.Observation, Epsilon], + networks_lib.Action, +] @chex.dataclass(frozen=True, mappable_dataclass=False) class EpsilonActorState: - rng: networks_lib.PRNGKey - epsilon: jnp.ndarray + rng: networks_lib.PRNGKey + epsilon: jnp.ndarray DQNPolicy = actor_core_lib.ActorCore[EpsilonActorState, None] -def alternating_epsilons_actor_core(policy_network: EpsilonPolicy, - epsilons: Sequence[float]) -> DQNPolicy: - """Returns actor components for alternating epsilon exploration. 
+def alternating_epsilons_actor_core( + policy_network: EpsilonPolicy, epsilons: Sequence[float] +) -> DQNPolicy: + """Returns actor components for alternating epsilon exploration. Args: policy_network: A feedforward action selecting function. @@ -52,45 +53,56 @@ def alternating_epsilons_actor_core(policy_network: EpsilonPolicy, Returns: A feedforward policy. """ - epsilons = jnp.array(epsilons) - - def apply_and_sample(params: networks_lib.Params, - observation: networks_lib.Observation, - state: EpsilonActorState): - random_key, key = jax.random.split(state.rng) - actions = policy_network(params, key, observation, state.epsilon) # pytype: disable=wrong-arg-types # jax-ndarray - return (actions.astype(jnp.int32), - EpsilonActorState(rng=random_key, epsilon=state.epsilon)) - - def policy_init(random_key: networks_lib.PRNGKey): - random_key, key = jax.random.split(random_key) - epsilon = jax.random.choice(key, epsilons) - return EpsilonActorState(rng=random_key, epsilon=epsilon) - - return actor_core_lib.ActorCore( - init=policy_init, select_action=apply_and_sample, - get_extras=lambda _: None) + epsilons = jnp.array(epsilons) + + def apply_and_sample( + params: networks_lib.Params, + observation: networks_lib.Observation, + state: EpsilonActorState, + ): + random_key, key = jax.random.split(state.rng) + actions = policy_network( + params, key, observation, state.epsilon + ) # pytype: disable=wrong-arg-types # jax-ndarray + return ( + actions.astype(jnp.int32), + EpsilonActorState(rng=random_key, epsilon=state.epsilon), + ) + + def policy_init(random_key: networks_lib.PRNGKey): + random_key, key = jax.random.split(random_key) + epsilon = jax.random.choice(key, epsilons) + return EpsilonActorState(rng=random_key, epsilon=epsilon) + + return actor_core_lib.ActorCore( + init=policy_init, select_action=apply_and_sample, get_extras=lambda _: None + ) def behavior_policy(networks: dqn_networks.DQNNetworks) -> EpsilonPolicy: - """A policy with parameterized epsilon-greedy exploration.""" - - def apply_and_sample(params: networks_lib.Params, key: networks_lib.PRNGKey, - observation: networks_lib.Observation, epsilon: Epsilon - ) -> networks_lib.Action: - # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs. - observation = utils.add_batch_dim(observation) - action_values = networks.policy_network.apply( - params, observation, is_training=False) - action_values = utils.squeeze_batch_dim(action_values) - return networks.sample_fn(action_values, key, epsilon) - - return apply_and_sample - - -def default_behavior_policy(networks: dqn_networks.DQNNetworks, - epsilon: Epsilon) -> EpsilonPolicy: - """A policy with a fixed-epsilon epsilon-greedy exploration. + """A policy with parameterized epsilon-greedy exploration.""" + + def apply_and_sample( + params: networks_lib.Params, + key: networks_lib.PRNGKey, + observation: networks_lib.Observation, + epsilon: Epsilon, + ) -> networks_lib.Action: + # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs. + observation = utils.add_batch_dim(observation) + action_values = networks.policy_network.apply( + params, observation, is_training=False + ) + action_values = utils.squeeze_batch_dim(action_values) + return networks.sample_fn(action_values, key, epsilon) + + return apply_and_sample + + +def default_behavior_policy( + networks: dqn_networks.DQNNetworks, epsilon: Epsilon +) -> EpsilonPolicy: + """A policy with a fixed-epsilon epsilon-greedy exploration. DEPRECATED: use behavior_policy instead. 
Args: @@ -99,16 +111,20 @@ def default_behavior_policy(networks: dqn_networks.DQNNetworks, Returns: epsilon-greedy behavior policy with fixed epsilon """ - # TODO(lukstafi): remove this function and migrate its users. - - def apply_and_sample(params: networks_lib.Params, key: networks_lib.PRNGKey, - observation: networks_lib.Observation, _: Epsilon - ) -> networks_lib.Action: - # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs. - observation = utils.add_batch_dim(observation) - action_values = networks.policy_network.apply( - params, observation, is_training=False) - action_values = utils.squeeze_batch_dim(action_values) - return networks.sample_fn(action_values, key, epsilon) - - return apply_and_sample + # TODO(lukstafi): remove this function and migrate its users. + + def apply_and_sample( + params: networks_lib.Params, + key: networks_lib.PRNGKey, + observation: networks_lib.Observation, + _: Epsilon, + ) -> networks_lib.Action: + # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs. + observation = utils.add_batch_dim(observation) + action_values = networks.policy_network.apply( + params, observation, is_training=False + ) + action_values = utils.squeeze_batch_dim(action_values) + return networks.sample_fn(action_values, key, epsilon) + + return apply_and_sample diff --git a/acme/agents/jax/dqn/builder.py b/acme/agents/jax/dqn/builder.py index 2e11f4a648..9cd5c5f0b2 100644 --- a/acme/agents/jax/dqn/builder.py +++ b/acme/agents/jax/dqn/builder.py @@ -15,198 +15,209 @@ """DQN Builder.""" from typing import Iterator, List, Optional, Sequence -from acme import adders -from acme import core -from acme import specs +import jax +import optax +import reverb +from reverb import rate_limiters + +from acme import adders, core, specs from acme.adders import reverb as adders_reverb -from acme.agents.jax import actors -from acme.agents.jax import builders +from acme.agents.jax import actors, builders from acme.agents.jax.dqn import actor as dqn_actor from acme.agents.jax.dqn import config as dqn_config from acme.agents.jax.dqn import learning_lib from acme.agents.jax.dqn import networks as dqn_networks from acme.datasets import reverb as datasets from acme.jax import networks as networks_lib -from acme.jax import utils -from acme.jax import variable_utils -from acme.utils import counting -from acme.utils import loggers -import jax -import optax -import reverb -from reverb import rate_limiters +from acme.jax import utils, variable_utils +from acme.utils import counting, loggers -class DQNBuilder(builders.ActorLearnerBuilder[dqn_networks.DQNNetworks, - dqn_actor.DQNPolicy, - utils.PrefetchingSplit]): - """DQN Builder.""" +class DQNBuilder( + builders.ActorLearnerBuilder[ + dqn_networks.DQNNetworks, dqn_actor.DQNPolicy, utils.PrefetchingSplit + ] +): + """DQN Builder.""" - def __init__(self, - config: dqn_config.DQNConfig, - loss_fn: learning_lib.LossFn, - actor_backend: Optional[str] = 'cpu'): - """Creates DQN learner and the behavior policies. + def __init__( + self, + config: dqn_config.DQNConfig, + loss_fn: learning_lib.LossFn, + actor_backend: Optional[str] = "cpu", + ): + """Creates DQN learner and the behavior policies. Args: config: DQN config. loss_fn: A loss function. actor_backend: Which backend to use when jitting the policy. 
""" - self._config = config - self._loss_fn = loss_fn - self._actor_backend = actor_backend - - def make_learner( - self, - random_key: networks_lib.PRNGKey, - networks: dqn_networks.DQNNetworks, - dataset: Iterator[utils.PrefetchingSplit], - logger_fn: loggers.LoggerFactory, - environment_spec: Optional[specs.EnvironmentSpec], - replay_client: Optional[reverb.Client] = None, - counter: Optional[counting.Counter] = None, - ) -> core.Learner: - del environment_spec - - return learning_lib.SGDLearner( - network=networks.policy_network, - random_key=random_key, - optimizer=optax.adam( - self._config.learning_rate, eps=self._config.adam_eps), - target_update_period=self._config.target_update_period, - data_iterator=dataset, - loss_fn=self._loss_fn, - replay_client=replay_client, - replay_table_name=self._config.replay_table_name, - counter=counter, - num_sgd_steps_per_step=self._config.num_sgd_steps_per_step, - logger=logger_fn('learner')) - - def make_actor( - self, - random_key: networks_lib.PRNGKey, - policy: dqn_actor.DQNPolicy, - environment_spec: Optional[specs.EnvironmentSpec], - variable_source: Optional[core.VariableSource] = None, - adder: Optional[adders.Adder] = None, - ) -> core.Actor: - del environment_spec - assert variable_source is not None - # Inference happens on CPU, so it's better to move variables there too. - variable_client = variable_utils.VariableClient( - variable_source, '', device='cpu') - return actors.GenericActor( - actor=policy, - random_key=random_key, - variable_client=variable_client, - adder=adder, - backend=self._actor_backend) - - def make_replay_tables( - self, - environment_spec: specs.EnvironmentSpec, - policy: dqn_actor.DQNPolicy, - ) -> List[reverb.Table]: - """Creates reverb tables for the algorithm.""" - del policy - samples_per_insert_tolerance = ( - self._config.samples_per_insert_tolerance_rate * - self._config.samples_per_insert) - error_buffer = self._config.min_replay_size * samples_per_insert_tolerance - limiter = rate_limiters.SampleToInsertRatio( - min_size_to_sample=self._config.min_replay_size, - samples_per_insert=self._config.samples_per_insert, - error_buffer=error_buffer) - return [ - reverb.Table( - name=self._config.replay_table_name, - sampler=reverb.selectors.Prioritized( - self._config.priority_exponent), - remover=reverb.selectors.Fifo(), - max_size=self._config.max_replay_size, - rate_limiter=limiter, - signature=adders_reverb.NStepTransitionAdder.signature( - environment_spec)) - ] + self._config = config + self._loss_fn = loss_fn + self._actor_backend = actor_backend + + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: dqn_networks.DQNNetworks, + dataset: Iterator[utils.PrefetchingSplit], + logger_fn: loggers.LoggerFactory, + environment_spec: Optional[specs.EnvironmentSpec], + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + del environment_spec + + return learning_lib.SGDLearner( + network=networks.policy_network, + random_key=random_key, + optimizer=optax.adam(self._config.learning_rate, eps=self._config.adam_eps), + target_update_period=self._config.target_update_period, + data_iterator=dataset, + loss_fn=self._loss_fn, + replay_client=replay_client, + replay_table_name=self._config.replay_table_name, + counter=counter, + num_sgd_steps_per_step=self._config.num_sgd_steps_per_step, + logger=logger_fn("learner"), + ) + + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: dqn_actor.DQNPolicy, + environment_spec: 
Optional[specs.EnvironmentSpec], + variable_source: Optional[core.VariableSource] = None, + adder: Optional[adders.Adder] = None, + ) -> core.Actor: + del environment_spec + assert variable_source is not None + # Inference happens on CPU, so it's better to move variables there too. + variable_client = variable_utils.VariableClient( + variable_source, "", device="cpu" + ) + return actors.GenericActor( + actor=policy, + random_key=random_key, + variable_client=variable_client, + adder=adder, + backend=self._actor_backend, + ) + + def make_replay_tables( + self, environment_spec: specs.EnvironmentSpec, policy: dqn_actor.DQNPolicy, + ) -> List[reverb.Table]: + """Creates reverb tables for the algorithm.""" + del policy + samples_per_insert_tolerance = ( + self._config.samples_per_insert_tolerance_rate + * self._config.samples_per_insert + ) + error_buffer = self._config.min_replay_size * samples_per_insert_tolerance + limiter = rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._config.min_replay_size, + samples_per_insert=self._config.samples_per_insert, + error_buffer=error_buffer, + ) + return [ + reverb.Table( + name=self._config.replay_table_name, + sampler=reverb.selectors.Prioritized(self._config.priority_exponent), + remover=reverb.selectors.Fifo(), + max_size=self._config.max_replay_size, + rate_limiter=limiter, + signature=adders_reverb.NStepTransitionAdder.signature( + environment_spec + ), + ) + ] + + @property + def batch_size_per_device(self) -> int: + """Splits the batch size across local devices.""" + + # Account for the number of SGD steps per step. + batch_size = self._config.batch_size * self._config.num_sgd_steps_per_step + + num_devices = jax.local_device_count() + # TODO(bshahr): Using jax.device_count will not be valid when colocating + # learning and inference. + + if batch_size % num_devices != 0: + raise ValueError( + "The DQN learner received a batch size that is not divisible by the " + f"number of available learner devices. Got: batch_size={batch_size}, " + f"num_devices={num_devices}." 
+ ) + + return batch_size // num_devices + + def make_dataset_iterator( + self, replay_client: reverb.Client, + ) -> Iterator[utils.PrefetchingSplit]: + """Creates a dataset iterator to use for learning.""" + dataset = datasets.make_reverb_dataset( + table=self._config.replay_table_name, + server_address=replay_client.server_address, + batch_size=self.batch_size_per_device, + prefetch_size=self._config.prefetch_size, + ) + + return utils.multi_device_put( + dataset.as_numpy_iterator(), + jax.local_devices(), + split_fn=utils.keep_key_on_host, + ) + + def make_adder( + self, + replay_client: reverb.Client, + environment_spec: Optional[specs.EnvironmentSpec], + policy: Optional[dqn_actor.DQNPolicy], + ) -> Optional[adders.Adder]: + """Creates an adder which handles observations.""" + del environment_spec, policy + return adders_reverb.NStepTransitionAdder( + priority_fns={self._config.replay_table_name: None}, + client=replay_client, + n_step=self._config.n_step, + discount=self._config.discount, + ) + + def _policy_epsilons(self, evaluation: bool) -> Sequence[float]: + if evaluation and self._config.eval_epsilon: + epsilon = self._config.eval_epsilon + else: + epsilon = self._config.epsilon + epsilons = epsilon if isinstance(epsilon, Sequence) else (epsilon,) + return epsilons + + def make_policy( + self, + networks: dqn_networks.DQNNetworks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool = False, + ) -> dqn_actor.DQNPolicy: + """Creates the policy.""" + del environment_spec - @property - def batch_size_per_device(self) -> int: - """Splits the batch size across local devices.""" - - # Account for the number of SGD steps per step. - batch_size = self._config.batch_size * self._config.num_sgd_steps_per_step - - num_devices = jax.local_device_count() - # TODO(bshahr): Using jax.device_count will not be valid when colocating - # learning and inference. - - if batch_size % num_devices != 0: - raise ValueError( - 'The DQN learner received a batch size that is not divisible by the ' - f'number of available learner devices. 
Got: batch_size={batch_size}, ' - f'num_devices={num_devices}.') - - return batch_size // num_devices - - def make_dataset_iterator( - self, - replay_client: reverb.Client, - ) -> Iterator[utils.PrefetchingSplit]: - """Creates a dataset iterator to use for learning.""" - dataset = datasets.make_reverb_dataset( - table=self._config.replay_table_name, - server_address=replay_client.server_address, - batch_size=self.batch_size_per_device, - prefetch_size=self._config.prefetch_size) - - return utils.multi_device_put( - dataset.as_numpy_iterator(), - jax.local_devices(), - split_fn=utils.keep_key_on_host) - - def make_adder( - self, - replay_client: reverb.Client, - environment_spec: Optional[specs.EnvironmentSpec], - policy: Optional[dqn_actor.DQNPolicy], - ) -> Optional[adders.Adder]: - """Creates an adder which handles observations.""" - del environment_spec, policy - return adders_reverb.NStepTransitionAdder( - priority_fns={self._config.replay_table_name: None}, - client=replay_client, - n_step=self._config.n_step, - discount=self._config.discount) - - def _policy_epsilons(self, evaluation: bool) -> Sequence[float]: - if evaluation and self._config.eval_epsilon: - epsilon = self._config.eval_epsilon - else: - epsilon = self._config.epsilon - epsilons = epsilon if isinstance(epsilon, Sequence) else (epsilon,) - return epsilons - - def make_policy(self, - networks: dqn_networks.DQNNetworks, - environment_spec: specs.EnvironmentSpec, - evaluation: bool = False) -> dqn_actor.DQNPolicy: - """Creates the policy.""" - del environment_spec - - return dqn_actor.alternating_epsilons_actor_core( - dqn_actor.behavior_policy(networks), - epsilons=self._policy_epsilons(evaluation)) + return dqn_actor.alternating_epsilons_actor_core( + dqn_actor.behavior_policy(networks), + epsilons=self._policy_epsilons(evaluation), + ) class DistributionalDQNBuilder(DQNBuilder): - """Distributional DQN Builder.""" + """Distributional DQN Builder.""" - def make_policy(self, - networks: dqn_networks.DQNNetworks, - environment_spec: specs.EnvironmentSpec, - evaluation: bool = False) -> dqn_actor.DQNPolicy: - """Creates the policy. + def make_policy( + self, + networks: dqn_networks.DQNNetworks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool = False, + ) -> dqn_actor.DQNPolicy: + """Creates the policy. Expects network head which returns a tuple with the first entry representing q-values. @@ -225,18 +236,25 @@ def make_policy(self, Returns: Behavior policy or evaluation policy for the agent. 
""" - del environment_spec + del environment_spec - def get_action_values(params: networks_lib.Params, - observation: networks_lib.Observation, *args, - **kwargs) -> networks_lib.NetworkOutput: - return networks.policy_network.apply(params, observation, *args, - **kwargs)[0] + def get_action_values( + params: networks_lib.Params, + observation: networks_lib.Observation, + *args, + **kwargs, + ) -> networks_lib.NetworkOutput: + return networks.policy_network.apply(params, observation, *args, **kwargs)[ + 0 + ] - typed_network = networks_lib.TypedFeedForwardNetwork( - init=networks.policy_network.init, apply=get_action_values) - behavior_policy = dqn_actor.behavior_policy( - dqn_networks.DQNNetworks(policy_network=typed_network)) + typed_network = networks_lib.TypedFeedForwardNetwork( + init=networks.policy_network.init, apply=get_action_values + ) + behavior_policy = dqn_actor.behavior_policy( + dqn_networks.DQNNetworks(policy_network=typed_network) + ) - return dqn_actor.alternating_epsilons_actor_core( - behavior_policy, epsilons=self._policy_epsilons(evaluation)) + return dqn_actor.alternating_epsilons_actor_core( + behavior_policy, epsilons=self._policy_epsilons(evaluation) + ) diff --git a/acme/agents/jax/dqn/config.py b/acme/agents/jax/dqn/config.py index 3200f02367..c2f10d6926 100644 --- a/acme/agents/jax/dqn/config.py +++ b/acme/agents/jax/dqn/config.py @@ -17,14 +17,15 @@ import dataclasses from typing import Callable, Optional, Sequence, Union -from acme.adders import reverb as adders_reverb import jax.numpy as jnp import numpy as np +from acme.adders import reverb as adders_reverb + @dataclasses.dataclass class DQNConfig: - """Configuration options for DQN agent. + """Configuration options for DQN agent. Attributes: epsilon: for use by epsilon-greedy policies. If multiple, the epsilons are @@ -52,36 +53,36 @@ class DQNConfig: num_sgd_steps_per_step: How many gradient updates to perform per learner step. """ - epsilon: Union[float, Sequence[float]] = 0.05 - eval_epsilon: Optional[float] = None - # TODO(b/191706065): update all clients and remove this field. - seed: int = 1 - # Learning rule - learning_rate: Union[float, Callable[[int], float]] = 1e-3 - adam_eps: float = 1e-8 # Eps for Adam optimizer. - discount: float = 0.99 # Discount rate applied to value per timestep. - n_step: int = 5 # N-step TD learning. - target_update_period: int = 100 # Update target network every period. - max_gradient_norm: float = np.inf # For gradient clipping. + epsilon: Union[float, Sequence[float]] = 0.05 + eval_epsilon: Optional[float] = None + # TODO(b/191706065): update all clients and remove this field. + seed: int = 1 + + # Learning rule + learning_rate: Union[float, Callable[[int], float]] = 1e-3 + adam_eps: float = 1e-8 # Eps for Adam optimizer. + discount: float = 0.99 # Discount rate applied to value per timestep. + n_step: int = 5 # N-step TD learning. + target_update_period: int = 100 # Update target network every period. + max_gradient_norm: float = np.inf # For gradient clipping. 
- # Replay options - batch_size: int = 256 - min_replay_size: int = 1_000 - max_replay_size: int = 1_000_000 - replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE - importance_sampling_exponent: float = 0.2 - priority_exponent: float = 0.6 - prefetch_size: int = 4 - samples_per_insert: float = 0.5 - samples_per_insert_tolerance_rate: float = 0.1 + # Replay options + batch_size: int = 256 + min_replay_size: int = 1_000 + max_replay_size: int = 1_000_000 + replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE + importance_sampling_exponent: float = 0.2 + priority_exponent: float = 0.6 + prefetch_size: int = 4 + samples_per_insert: float = 0.5 + samples_per_insert_tolerance_rate: float = 0.1 - num_sgd_steps_per_step: int = 1 + num_sgd_steps_per_step: int = 1 -def logspace_epsilons(num_epsilons: int, epsilon: float = 0.017 - ) -> Sequence[float]: - """`num_epsilons` of logspace-distributed values, with median `epsilon`.""" - if num_epsilons <= 1: - return (epsilon,) - return jnp.logspace(1, 8, num_epsilons, base=epsilon ** (2./9.)) +def logspace_epsilons(num_epsilons: int, epsilon: float = 0.017) -> Sequence[float]: + """`num_epsilons` of logspace-distributed values, with median `epsilon`.""" + if num_epsilons <= 1: + return (epsilon,) + return jnp.logspace(1, 8, num_epsilons, base=epsilon ** (2.0 / 9.0)) diff --git a/acme/agents/jax/dqn/learning.py b/acme/agents/jax/dqn/learning.py index a6375c24e3..be889bf342 100644 --- a/acme/agents/jax/dqn/learning.py +++ b/acme/agents/jax/dqn/learning.py @@ -16,56 +16,57 @@ from typing import Iterator, Optional +import optax +import reverb + from acme.adders import reverb as adders -from acme.agents.jax.dqn import learning_lib -from acme.agents.jax.dqn import losses +from acme.agents.jax.dqn import learning_lib, losses from acme.jax import networks as networks_lib from acme.jax import utils -from acme.utils import counting -from acme.utils import loggers -import optax -import reverb +from acme.utils import counting, loggers class DQNLearner(learning_lib.SGDLearner): - """DQN learner. + """DQN learner. We are in the process of migrating towards a more general SGDLearner to allow for easy configuration of the loss. This is maintained now for compatibility. 
""" - def __init__(self, - network: networks_lib.TypedFeedForwardNetwork, - discount: float, - importance_sampling_exponent: float, - target_update_period: int, - iterator: Iterator[utils.PrefetchingSplit], - optimizer: optax.GradientTransformation, - random_key: networks_lib.PRNGKey, - max_abs_reward: float = 1., - huber_loss_parameter: float = 1., - replay_client: Optional[reverb.Client] = None, - replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - num_sgd_steps_per_step: int = 1): - """Initializes the learner.""" - loss_fn = losses.PrioritizedDoubleQLearning( - discount=discount, - importance_sampling_exponent=importance_sampling_exponent, - max_abs_reward=max_abs_reward, - huber_loss_parameter=huber_loss_parameter, - ) - super().__init__( - network=network, - loss_fn=loss_fn, - optimizer=optimizer, - data_iterator=iterator, - target_update_period=target_update_period, - random_key=random_key, - replay_client=replay_client, - replay_table_name=replay_table_name, - counter=counter, - logger=logger, - num_sgd_steps_per_step=num_sgd_steps_per_step, - ) + def __init__( + self, + network: networks_lib.TypedFeedForwardNetwork, + discount: float, + importance_sampling_exponent: float, + target_update_period: int, + iterator: Iterator[utils.PrefetchingSplit], + optimizer: optax.GradientTransformation, + random_key: networks_lib.PRNGKey, + max_abs_reward: float = 1.0, + huber_loss_parameter: float = 1.0, + replay_client: Optional[reverb.Client] = None, + replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + num_sgd_steps_per_step: int = 1, + ): + """Initializes the learner.""" + loss_fn = losses.PrioritizedDoubleQLearning( + discount=discount, + importance_sampling_exponent=importance_sampling_exponent, + max_abs_reward=max_abs_reward, + huber_loss_parameter=huber_loss_parameter, + ) + super().__init__( + network=network, + loss_fn=loss_fn, + optimizer=optimizer, + data_iterator=iterator, + target_update_period=target_update_period, + random_key=random_key, + replay_client=replay_client, + replay_table_name=replay_table_name, + counter=counter, + logger=logger, + num_sgd_steps_per_step=num_sgd_steps_per_step, + ) diff --git a/acme/agents/jax/dqn/learning_lib.py b/acme/agents/jax/dqn/learning_lib.py index af860d9cc0..dcda9c4df3 100644 --- a/acme/agents/jax/dqn/learning_lib.py +++ b/acme/agents/jax/dqn/learning_lib.py @@ -18,13 +18,6 @@ import time from typing import Dict, Iterator, List, NamedTuple, Optional, Tuple -import acme -from acme.adders import reverb as adders -from acme.jax import networks as networks_lib -from acme.jax import utils -from acme.utils import async_utils -from acme.utils import counting -from acme.utils import loggers import jax import jax.numpy as jnp import optax @@ -32,190 +25,214 @@ import tree import typing_extensions +import acme +from acme.adders import reverb as adders +from acme.jax import networks as networks_lib +from acme.jax import utils +from acme.utils import async_utils, counting, loggers # The pmap axis name. Data means data parallelization. 
-PMAP_AXIS_NAME = 'data' +PMAP_AXIS_NAME = "data" class ReverbUpdate(NamedTuple): - """Tuple for updating reverb priority information.""" - keys: jnp.ndarray - priorities: jnp.ndarray + """Tuple for updating reverb priority information.""" + + keys: jnp.ndarray + priorities: jnp.ndarray class LossExtra(NamedTuple): - """Extra information that is returned along with loss value.""" - metrics: Dict[str, jax.Array] - # New optional updated priorities for the samples. - reverb_priorities: Optional[jax.Array] = None + """Extra information that is returned along with loss value.""" + + metrics: Dict[str, jax.Array] + # New optional updated priorities for the samples. + reverb_priorities: Optional[jax.Array] = None class LossFn(typing_extensions.Protocol): - """A LossFn calculates a loss on a single batch of data.""" + """A LossFn calculates a loss on a single batch of data.""" - def __call__( - self, - network: networks_lib.TypedFeedForwardNetwork, - params: networks_lib.Params, - target_params: networks_lib.Params, - batch: reverb.ReplaySample, - key: networks_lib.PRNGKey, - ) -> Tuple[jax.Array, LossExtra]: - """Calculates a loss on a single batch of data.""" + def __call__( + self, + network: networks_lib.TypedFeedForwardNetwork, + params: networks_lib.Params, + target_params: networks_lib.Params, + batch: reverb.ReplaySample, + key: networks_lib.PRNGKey, + ) -> Tuple[jax.Array, LossExtra]: + """Calculates a loss on a single batch of data.""" class TrainingState(NamedTuple): - """Holds the agent's training state.""" - params: networks_lib.Params - target_params: networks_lib.Params - opt_state: optax.OptState - steps: int - rng_key: networks_lib.PRNGKey + """Holds the agent's training state.""" + + params: networks_lib.Params + target_params: networks_lib.Params + opt_state: optax.OptState + steps: int + rng_key: networks_lib.PRNGKey class SGDLearner(acme.Learner): - """An Acme learner based around SGD on batches. + """An Acme learner based around SGD on batches. This learner currently supports optional prioritized replay and assumes a TrainingState as described above. """ - def __init__(self, - network: networks_lib.TypedFeedForwardNetwork, - loss_fn: LossFn, - optimizer: optax.GradientTransformation, - data_iterator: Iterator[utils.PrefetchingSplit], - target_update_period: int, - random_key: networks_lib.PRNGKey, - replay_client: Optional[reverb.Client] = None, - replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - num_sgd_steps_per_step: int = 1): - """Initialize the SGD learner.""" - self.network = network - - # Internalize the loss_fn with network. - self._loss = jax.jit(functools.partial(loss_fn, self.network)) - - # SGD performs the loss, optimizer update and periodic target net update. - def sgd_step(state: TrainingState, - batch: reverb.ReplaySample) -> Tuple[TrainingState, LossExtra]: - next_rng_key, rng_key = jax.random.split(state.rng_key) - # Implements one SGD step of the loss and updates training state - (loss, extra), grads = jax.value_and_grad( - self._loss, has_aux=True)(state.params, state.target_params, batch, - rng_key) - - loss = jax.lax.pmean(loss, axis_name=PMAP_AXIS_NAME) - # Average gradients over pmap replicas before optimizer update. 
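# Aside: the jax.value_and_grad(..., has_aux=True) pattern used a few lines
# above — the loss function returns (scalar_loss, extras), and the gradient is
# taken only with respect to the first argument. Everything below is a toy
# stand-in for illustration.
import jax
import jax.numpy as jnp

def toy_loss_fn(params, x):
    pred = params["w"] * x
    loss = jnp.mean((pred - 1.0) ** 2)
    return loss, {"pred_mean": jnp.mean(pred)}

toy_params = {"w": jnp.asarray(0.5)}
(loss, aux), grads = jax.value_and_grad(toy_loss_fn, has_aux=True)(toy_params, jnp.ones(4))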
- grads = jax.lax.pmean(grads, axis_name=PMAP_AXIS_NAME) - # Apply the optimizer updates - updates, new_opt_state = optimizer.update(grads, state.opt_state) - new_params = optax.apply_updates(state.params, updates) - - extra.metrics.update({'total_loss': loss}) - - # Periodically update target networks. - steps = state.steps + 1 - target_params = optax.periodic_update(new_params, state.target_params, # pytype: disable=wrong-arg-types # numpy-scalars - steps, target_update_period) - - new_training_state = TrainingState( - new_params, target_params, new_opt_state, steps, next_rng_key) - return new_training_state, extra - - def postprocess_aux(extra: LossExtra) -> LossExtra: - reverb_priorities = jax.tree_util.tree_map( - lambda a: jnp.reshape(a, (-1, *a.shape[2:])), extra.reverb_priorities) - return extra._replace( - metrics=jax.tree_util.tree_map(jnp.mean, extra.metrics), - reverb_priorities=reverb_priorities) - - self._num_sgd_steps_per_step = num_sgd_steps_per_step - sgd_step = utils.process_multiple_batches(sgd_step, num_sgd_steps_per_step, - postprocess_aux) - self._sgd_step = jax.pmap( - sgd_step, axis_name=PMAP_AXIS_NAME, devices=jax.devices()) - - # Internalise agent components - self._data_iterator = data_iterator - self._target_update_period = target_update_period - self._counter = counter or counting.Counter() - self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.) - - # Do not record timestamps until after the first learning step is done. - # This is to avoid including the time it takes for actors to come online and - # fill the replay buffer. - self._timestamp = None - - # Initialize the network parameters - key_params, key_target, key_state = jax.random.split(random_key, 3) - initial_params = self.network.init(key_params) - initial_target_params = self.network.init(key_target) - state = TrainingState( - params=initial_params, - target_params=initial_target_params, - opt_state=optimizer.init(initial_params), - steps=0, - rng_key=key_state, - ) - self._state = utils.replicate_in_all_devices(state, jax.local_devices()) - - # Update replay priorities - def update_priorities(reverb_update: ReverbUpdate) -> None: - if replay_client is None: - return - keys, priorities = tree.map_structure( - # Fetch array and combine device and batch dimensions. - lambda x: utils.fetch_devicearray(x).reshape((-1,) + x.shape[2:]), - (reverb_update.keys, reverb_update.priorities)) - replay_client.mutate_priorities( - table=replay_table_name, - updates=dict(zip(keys, priorities))) - self._replay_client = replay_client - self._async_priority_updater = async_utils.AsyncExecutor(update_priorities) - - self._current_step = 0 - - def step(self): - """Takes one SGD step on the learner.""" - with jax.profiler.StepTraceAnnotation('step', step_num=self._current_step): - prefetching_split = next(self._data_iterator) - # In this case the host property of the prefetching split contains only - # replay keys and the device property is the prefetched full original - # sample. Key is on host since it's uint64 type. - reverb_keys = prefetching_split.host - batch: reverb.ReplaySample = prefetching_split.device - - self._state, extra = self._sgd_step(self._state, batch) - # Compute elapsed time. 
- timestamp = time.time() - elapsed = timestamp - self._timestamp if self._timestamp else 0 - self._timestamp = timestamp - - if self._replay_client and extra.reverb_priorities is not None: - reverb_update = ReverbUpdate(reverb_keys, extra.reverb_priorities) - self._async_priority_updater.put(reverb_update) - - steps_per_sec = (self._num_sgd_steps_per_step / elapsed) if elapsed else 0 - self._current_step, metrics = utils.get_from_first_device( - (self._state.steps, extra.metrics)) - metrics['steps_per_second'] = steps_per_sec - - # Update our counts and record it. - result = self._counter.increment( - steps=self._num_sgd_steps_per_step, walltime=elapsed) - result.update(metrics) - self._logger.write(result) - - def get_variables(self, names: List[str]) -> List[networks_lib.Params]: - # Return first replica of parameters. - return utils.get_from_first_device([self._state.params]) - - def save(self) -> TrainingState: - # Serialize only the first replica of parameters and optimizer state. - return utils.get_from_first_device(self._state) - - def restore(self, state: TrainingState): - self._state = utils.replicate_in_all_devices(state, jax.local_devices()) + def __init__( + self, + network: networks_lib.TypedFeedForwardNetwork, + loss_fn: LossFn, + optimizer: optax.GradientTransformation, + data_iterator: Iterator[utils.PrefetchingSplit], + target_update_period: int, + random_key: networks_lib.PRNGKey, + replay_client: Optional[reverb.Client] = None, + replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + num_sgd_steps_per_step: int = 1, + ): + """Initialize the SGD learner.""" + self.network = network + + # Internalize the loss_fn with network. + self._loss = jax.jit(functools.partial(loss_fn, self.network)) + + # SGD performs the loss, optimizer update and periodic target net update. + def sgd_step( + state: TrainingState, batch: reverb.ReplaySample + ) -> Tuple[TrainingState, LossExtra]: + next_rng_key, rng_key = jax.random.split(state.rng_key) + # Implements one SGD step of the loss and updates training state + (loss, extra), grads = jax.value_and_grad(self._loss, has_aux=True)( + state.params, state.target_params, batch, rng_key + ) + + loss = jax.lax.pmean(loss, axis_name=PMAP_AXIS_NAME) + # Average gradients over pmap replicas before optimizer update. + grads = jax.lax.pmean(grads, axis_name=PMAP_AXIS_NAME) + # Apply the optimizer updates + updates, new_opt_state = optimizer.update(grads, state.opt_state) + new_params = optax.apply_updates(state.params, updates) + + extra.metrics.update({"total_loss": loss}) + + # Periodically update target networks. 
+ steps = state.steps + 1 + target_params = optax.periodic_update( + new_params, + state.target_params, # pytype: disable=wrong-arg-types # numpy-scalars + steps, + target_update_period, + ) + + new_training_state = TrainingState( + new_params, target_params, new_opt_state, steps, next_rng_key + ) + return new_training_state, extra + + def postprocess_aux(extra: LossExtra) -> LossExtra: + reverb_priorities = jax.tree_util.tree_map( + lambda a: jnp.reshape(a, (-1, *a.shape[2:])), extra.reverb_priorities + ) + return extra._replace( + metrics=jax.tree_util.tree_map(jnp.mean, extra.metrics), + reverb_priorities=reverb_priorities, + ) + + self._num_sgd_steps_per_step = num_sgd_steps_per_step + sgd_step = utils.process_multiple_batches( + sgd_step, num_sgd_steps_per_step, postprocess_aux + ) + self._sgd_step = jax.pmap( + sgd_step, axis_name=PMAP_AXIS_NAME, devices=jax.devices() + ) + + # Internalise agent components + self._data_iterator = data_iterator + self._target_update_period = target_update_period + self._counter = counter or counting.Counter() + self._logger = logger or loggers.TerminalLogger("learner", time_delta=1.0) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + # Initialize the network parameters + key_params, key_target, key_state = jax.random.split(random_key, 3) + initial_params = self.network.init(key_params) + initial_target_params = self.network.init(key_target) + state = TrainingState( + params=initial_params, + target_params=initial_target_params, + opt_state=optimizer.init(initial_params), + steps=0, + rng_key=key_state, + ) + self._state = utils.replicate_in_all_devices(state, jax.local_devices()) + + # Update replay priorities + def update_priorities(reverb_update: ReverbUpdate) -> None: + if replay_client is None: + return + keys, priorities = tree.map_structure( + # Fetch array and combine device and batch dimensions. + lambda x: utils.fetch_devicearray(x).reshape((-1,) + x.shape[2:]), + (reverb_update.keys, reverb_update.priorities), + ) + replay_client.mutate_priorities( + table=replay_table_name, updates=dict(zip(keys, priorities)) + ) + + self._replay_client = replay_client + self._async_priority_updater = async_utils.AsyncExecutor(update_priorities) + + self._current_step = 0 + + def step(self): + """Takes one SGD step on the learner.""" + with jax.profiler.StepTraceAnnotation("step", step_num=self._current_step): + prefetching_split = next(self._data_iterator) + # In this case the host property of the prefetching split contains only + # replay keys and the device property is the prefetched full original + # sample. Key is on host since it's uint64 type. + reverb_keys = prefetching_split.host + batch: reverb.ReplaySample = prefetching_split.device + + self._state, extra = self._sgd_step(self._state, batch) + # Compute elapsed time. + timestamp = time.time() + elapsed = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + if self._replay_client and extra.reverb_priorities is not None: + reverb_update = ReverbUpdate(reverb_keys, extra.reverb_priorities) + self._async_priority_updater.put(reverb_update) + + steps_per_sec = (self._num_sgd_steps_per_step / elapsed) if elapsed else 0 + self._current_step, metrics = utils.get_from_first_device( + (self._state.steps, extra.metrics) + ) + metrics["steps_per_second"] = steps_per_sec + + # Update our counts and record it. 
+ result = self._counter.increment( + steps=self._num_sgd_steps_per_step, walltime=elapsed + ) + result.update(metrics) + self._logger.write(result) + + def get_variables(self, names: List[str]) -> List[networks_lib.Params]: + # Return first replica of parameters. + return utils.get_from_first_device([self._state.params]) + + def save(self) -> TrainingState: + # Serialize only the first replica of parameters and optimizer state. + return utils.get_from_first_device(self._state) + + def restore(self, state: TrainingState): + self._state = utils.replicate_in_all_devices(state, jax.local_devices()) diff --git a/acme/agents/jax/dqn/losses.py b/acme/agents/jax/dqn/losses.py index 26c769fac0..86340f752e 100644 --- a/acme/agents/jax/dqn/losses.py +++ b/acme/agents/jax/dqn/losses.py @@ -16,333 +16,370 @@ import dataclasses from typing import Tuple -from acme import types -from acme.agents.jax.dqn import learning_lib -from acme.jax import networks as networks_lib import chex import jax import jax.numpy as jnp import reverb import rlax +from acme import types +from acme.agents.jax.dqn import learning_lib +from acme.jax import networks as networks_lib + @dataclasses.dataclass class PrioritizedDoubleQLearning(learning_lib.LossFn): - """Clipped double q learning with prioritization on TD error.""" - discount: float = 0.99 - importance_sampling_exponent: float = 0.2 - max_abs_reward: float = 1. - huber_loss_parameter: float = 1. - - def __call__( - self, - network: networks_lib.TypedFeedForwardNetwork, - params: networks_lib.Params, - target_params: networks_lib.Params, - batch: reverb.ReplaySample, - key: networks_lib.PRNGKey, - ) -> Tuple[jax.Array, learning_lib.LossExtra]: - """Calculate a loss on a single batch of data.""" - transitions: types.Transition = batch.data - probs = batch.info.probability - - # Forward pass. - key1, key2, key3 = jax.random.split(key, 3) - q_tm1 = network.apply( - params, transitions.observation, is_training=True, key=key1) - q_t_value = network.apply( - target_params, transitions.next_observation, is_training=True, key=key2) - q_t_selector = network.apply( - params, transitions.next_observation, is_training=True, key=key3) - - # Cast and clip rewards. - d_t = (transitions.discount * self.discount).astype(jnp.float32) - r_t = jnp.clip(transitions.reward, -self.max_abs_reward, - self.max_abs_reward).astype(jnp.float32) - - # Compute double Q-learning n-step TD-error. - batch_error = jax.vmap(rlax.double_q_learning) - td_error = batch_error(q_tm1, transitions.action, r_t, d_t, q_t_value, - q_t_selector) - batch_loss = rlax.huber_loss(td_error, self.huber_loss_parameter) - - # Importance weighting. - importance_weights = (1. / probs).astype(jnp.float32) - importance_weights **= self.importance_sampling_exponent - importance_weights /= jnp.max(importance_weights) - - # Reweight. 
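
The reweighting step (identical in the old and reformatted versions of this loss) is the standard prioritized-replay correction: each sample gets weight (1 / p_i) ** importance_sampling_exponent, normalised by the largest weight in the batch so the loss scale stays bounded. A small standalone example with toy sampling probabilities:

import jax.numpy as jnp

probs = jnp.asarray([0.5, 0.25, 0.05])     # sampling probabilities from replay
beta = 0.2                                  # importance_sampling_exponent

weights = (1.0 / probs) ** beta
weights /= jnp.max(weights)                 # rarest sample ends up with weight 1.0
assert float(weights[2]) == 1.0 and float(weights[0]) < 1.0
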
- loss = jnp.mean(importance_weights * batch_loss) # [] - extra = learning_lib.LossExtra( - metrics={}, reverb_priorities=jnp.abs(td_error).astype(jnp.float64)) - return loss, extra + """Clipped double q learning with prioritization on TD error.""" + + discount: float = 0.99 + importance_sampling_exponent: float = 0.2 + max_abs_reward: float = 1.0 + huber_loss_parameter: float = 1.0 + + def __call__( + self, + network: networks_lib.TypedFeedForwardNetwork, + params: networks_lib.Params, + target_params: networks_lib.Params, + batch: reverb.ReplaySample, + key: networks_lib.PRNGKey, + ) -> Tuple[jax.Array, learning_lib.LossExtra]: + """Calculate a loss on a single batch of data.""" + transitions: types.Transition = batch.data + probs = batch.info.probability + + # Forward pass. + key1, key2, key3 = jax.random.split(key, 3) + q_tm1 = network.apply( + params, transitions.observation, is_training=True, key=key1 + ) + q_t_value = network.apply( + target_params, transitions.next_observation, is_training=True, key=key2 + ) + q_t_selector = network.apply( + params, transitions.next_observation, is_training=True, key=key3 + ) + + # Cast and clip rewards. + d_t = (transitions.discount * self.discount).astype(jnp.float32) + r_t = jnp.clip( + transitions.reward, -self.max_abs_reward, self.max_abs_reward + ).astype(jnp.float32) + + # Compute double Q-learning n-step TD-error. + batch_error = jax.vmap(rlax.double_q_learning) + td_error = batch_error( + q_tm1, transitions.action, r_t, d_t, q_t_value, q_t_selector + ) + batch_loss = rlax.huber_loss(td_error, self.huber_loss_parameter) + + # Importance weighting. + importance_weights = (1.0 / probs).astype(jnp.float32) + importance_weights **= self.importance_sampling_exponent + importance_weights /= jnp.max(importance_weights) + + # Reweight. + loss = jnp.mean(importance_weights * batch_loss) # [] + extra = learning_lib.LossExtra( + metrics={}, reverb_priorities=jnp.abs(td_error).astype(jnp.float64) + ) + return loss, extra @dataclasses.dataclass class QrDqn(learning_lib.LossFn): - """Quantile Regression DQN. + """Quantile Regression DQN. https://arxiv.org/abs/1710.10044 """ - num_atoms: int = 51 - huber_param: float = 1.0 - - def __call__( - self, - network: networks_lib.TypedFeedForwardNetwork, - params: networks_lib.Params, - target_params: networks_lib.Params, - batch: reverb.ReplaySample, - key: networks_lib.PRNGKey, - ) -> Tuple[jax.Array, learning_lib.LossExtra]: - """Calculate a loss on a single batch of data.""" - transitions: types.Transition = batch.data - key1, key2 = jax.random.split(key) - _, dist_q_tm1 = network.apply( - params, transitions.observation, is_training=True, key=key1) - _, dist_q_target_t = network.apply( - target_params, transitions.next_observation, is_training=True, key=key2) - batch_size = len(transitions.observation) - chex.assert_shape( - dist_q_tm1, ( - batch_size, - None, - self.num_atoms, - ), - custom_message=f'Expected (batch_size, num_actions, num_atoms), got: {dist_q_tm1.shape}', - include_default_message=True) - chex.assert_shape( - dist_q_target_t, ( - batch_size, - None, - self.num_atoms, - ), - custom_message=f'Expected (batch_size, num_actions, num_atoms), got: {dist_q_target_t.shape}', - include_default_message=True) - # Swap distribution and action dimension, since - # rlax.quantile_q_learning expects it that way. 
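
Two details of the QR-DQN loss are worth spelling out: the quantile targets are the bucket midpoints (i + 0.5) / num_atoms, and rlax.quantile_q_learning expects each example's distribution as [num_atoms, num_actions], hence the axis swap. A shape-only sketch with made-up sizes:

import jax.numpy as jnp

num_atoms, num_actions, batch = 51, 4, 32

# Bucket midpoints used as quantile targets: (i + 0.5) / N for i = 0..N-1.
quantiles = (jnp.arange(num_atoms, dtype=jnp.float32) + 0.5) / num_atoms
assert quantiles.shape == (num_atoms,)

# The network head emits [batch, num_actions, num_atoms]; after the swap the
# vmapped rlax call sees [num_atoms, num_actions] per example.
dist_q = jnp.zeros((batch, num_actions, num_atoms))
assert jnp.swapaxes(dist_q, 1, 2).shape == (batch, num_atoms, num_actions)
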
- dist_q_tm1 = jnp.swapaxes(dist_q_tm1, 1, 2) - dist_q_target_t = jnp.swapaxes(dist_q_target_t, 1, 2) - quantiles = ( - (jnp.arange(self.num_atoms, dtype=jnp.float32) + 0.5) / self.num_atoms) - batch_quantile_q_learning = jax.vmap( - rlax.quantile_q_learning, in_axes=(0, None, 0, 0, 0, 0, 0, None)) - losses = batch_quantile_q_learning( - dist_q_tm1, - quantiles, - transitions.action, - transitions.reward, - transitions.discount, - dist_q_target_t, # No double Q-learning here. - dist_q_target_t, - self.huber_param, - ) - loss = jnp.mean(losses) - chex.assert_shape(losses, (batch_size,)) - extra = learning_lib.LossExtra(metrics={'mean_loss': loss}) - return loss, extra + + num_atoms: int = 51 + huber_param: float = 1.0 + + def __call__( + self, + network: networks_lib.TypedFeedForwardNetwork, + params: networks_lib.Params, + target_params: networks_lib.Params, + batch: reverb.ReplaySample, + key: networks_lib.PRNGKey, + ) -> Tuple[jax.Array, learning_lib.LossExtra]: + """Calculate a loss on a single batch of data.""" + transitions: types.Transition = batch.data + key1, key2 = jax.random.split(key) + _, dist_q_tm1 = network.apply( + params, transitions.observation, is_training=True, key=key1 + ) + _, dist_q_target_t = network.apply( + target_params, transitions.next_observation, is_training=True, key=key2 + ) + batch_size = len(transitions.observation) + chex.assert_shape( + dist_q_tm1, + (batch_size, None, self.num_atoms,), + custom_message=f"Expected (batch_size, num_actions, num_atoms), got: {dist_q_tm1.shape}", + include_default_message=True, + ) + chex.assert_shape( + dist_q_target_t, + (batch_size, None, self.num_atoms,), + custom_message=f"Expected (batch_size, num_actions, num_atoms), got: {dist_q_target_t.shape}", + include_default_message=True, + ) + # Swap distribution and action dimension, since + # rlax.quantile_q_learning expects it that way. + dist_q_tm1 = jnp.swapaxes(dist_q_tm1, 1, 2) + dist_q_target_t = jnp.swapaxes(dist_q_target_t, 1, 2) + quantiles = ( + jnp.arange(self.num_atoms, dtype=jnp.float32) + 0.5 + ) / self.num_atoms + batch_quantile_q_learning = jax.vmap( + rlax.quantile_q_learning, in_axes=(0, None, 0, 0, 0, 0, 0, None) + ) + losses = batch_quantile_q_learning( + dist_q_tm1, + quantiles, + transitions.action, + transitions.reward, + transitions.discount, + dist_q_target_t, # No double Q-learning here. + dist_q_target_t, + self.huber_param, + ) + loss = jnp.mean(losses) + chex.assert_shape(losses, (batch_size,)) + extra = learning_lib.LossExtra(metrics={"mean_loss": loss}) + return loss, extra @dataclasses.dataclass class PrioritizedCategoricalDoubleQLearning(learning_lib.LossFn): - """Categorical double q learning with prioritization on TD error.""" - discount: float = 0.99 - importance_sampling_exponent: float = 0.2 - max_abs_reward: float = 1. - - def __call__( - self, - network: networks_lib.TypedFeedForwardNetwork, - params: networks_lib.Params, - target_params: networks_lib.Params, - batch: reverb.ReplaySample, - key: networks_lib.PRNGKey, - ) -> Tuple[jax.Array, learning_lib.LossExtra]: - """Calculate a loss on a single batch of data.""" - transitions: types.Transition = batch.data - probs = batch.info.probability - - # Forward pass. 
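
The network used by this loss returns a (q_values, logits, atoms) triple; in a C51-style head the Q-value of each action is the expectation of a categorical distribution over a fixed support of return atoms. A toy version of that expectation (made-up atoms and uniform logits, not Acme's network):

import jax
import jax.numpy as jnp

atoms = jnp.linspace(-1.0, 1.0, 5)       # support of the return distribution
logits = jnp.zeros((3, 5))               # [num_actions, num_atoms]: uniform distributions

q_values = jnp.sum(jax.nn.softmax(logits, axis=-1) * atoms, axis=-1)
assert q_values.shape == (3,)
assert bool(jnp.allclose(q_values, 0.0, atol=1e-6))   # symmetric support -> Q = 0
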
- key1, key2, key3 = jax.random.split(key, 3) - _, logits_tm1, atoms_tm1 = network.apply( - params, transitions.observation, is_training=True, key=key1) - _, logits_t, atoms_t = network.apply( - target_params, transitions.next_observation, is_training=True, key=key2) - q_t_selector, _, _ = network.apply( - params, transitions.next_observation, is_training=True, key=key3) - - # Cast and clip rewards. - d_t = (transitions.discount * self.discount).astype(jnp.float32) - r_t = jnp.clip(transitions.reward, -self.max_abs_reward, - self.max_abs_reward).astype(jnp.float32) - - # Compute categorical double Q-learning loss. - batch_loss_fn = jax.vmap( - rlax.categorical_double_q_learning, - in_axes=(None, 0, 0, 0, 0, None, 0, 0)) - batch_loss = batch_loss_fn(atoms_tm1, logits_tm1, transitions.action, r_t, - d_t, atoms_t, logits_t, q_t_selector) - - # Importance weighting. - importance_weights = (1. / probs).astype(jnp.float32) - importance_weights **= self.importance_sampling_exponent - importance_weights /= jnp.max(importance_weights) - - # Reweight. - loss = jnp.mean(importance_weights * batch_loss) # [] - extra = learning_lib.LossExtra( - metrics={}, reverb_priorities=jnp.abs(batch_loss).astype(jnp.float64)) - return loss, extra + """Categorical double q learning with prioritization on TD error.""" + + discount: float = 0.99 + importance_sampling_exponent: float = 0.2 + max_abs_reward: float = 1.0 + + def __call__( + self, + network: networks_lib.TypedFeedForwardNetwork, + params: networks_lib.Params, + target_params: networks_lib.Params, + batch: reverb.ReplaySample, + key: networks_lib.PRNGKey, + ) -> Tuple[jax.Array, learning_lib.LossExtra]: + """Calculate a loss on a single batch of data.""" + transitions: types.Transition = batch.data + probs = batch.info.probability + + # Forward pass. + key1, key2, key3 = jax.random.split(key, 3) + _, logits_tm1, atoms_tm1 = network.apply( + params, transitions.observation, is_training=True, key=key1 + ) + _, logits_t, atoms_t = network.apply( + target_params, transitions.next_observation, is_training=True, key=key2 + ) + q_t_selector, _, _ = network.apply( + params, transitions.next_observation, is_training=True, key=key3 + ) + + # Cast and clip rewards. + d_t = (transitions.discount * self.discount).astype(jnp.float32) + r_t = jnp.clip( + transitions.reward, -self.max_abs_reward, self.max_abs_reward + ).astype(jnp.float32) + + # Compute categorical double Q-learning loss. + batch_loss_fn = jax.vmap( + rlax.categorical_double_q_learning, in_axes=(None, 0, 0, 0, 0, None, 0, 0) + ) + batch_loss = batch_loss_fn( + atoms_tm1, + logits_tm1, + transitions.action, + r_t, + d_t, + atoms_t, + logits_t, + q_t_selector, + ) + + # Importance weighting. + importance_weights = (1.0 / probs).astype(jnp.float32) + importance_weights **= self.importance_sampling_exponent + importance_weights /= jnp.max(importance_weights) + + # Reweight. + loss = jnp.mean(importance_weights * batch_loss) # [] + extra = learning_lib.LossExtra( + metrics={}, reverb_priorities=jnp.abs(batch_loss).astype(jnp.float64) + ) + return loss, extra @dataclasses.dataclass class QLearning(learning_lib.LossFn): - """Deep q learning. + """Deep q learning. This matches the original DQN loss: https://arxiv.org/abs/1312.5602. It differs by two aspects that improve it on the optimization side - it uses Adam instead of RMSProp as an optimizer - it uses a square loss instead of the Huber one. """ - discount: float = 0.99 - max_abs_reward: float = 1. 
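
For a single transition, the TD error that the vmapped rlax.q_learning call in this loss computes is r_t + d_t * max_a q_t(a) - q_tm1(a_tm1), and the loss is simply its square (no Huber clipping, as the docstring notes). A hand-rolled one-transition check with toy numbers:

import jax.numpy as jnp

q_tm1 = jnp.asarray([1.0, 2.0, 3.0])    # Q(s, .) from the online network
a_tm1 = 1                                # action taken
r_t, d_t = 0.5, 0.99                     # clipped reward, discount
q_t = jnp.asarray([0.0, 1.0, 2.0])       # Q(s', .) from the target network

td_error = r_t + d_t * jnp.max(q_t) - q_tm1[a_tm1]
loss = jnp.square(td_error)
assert bool(jnp.isclose(td_error, 0.5 + 0.99 * 2.0 - 2.0))
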
- - def __call__( - self, - network: networks_lib.TypedFeedForwardNetwork, - params: networks_lib.Params, - target_params: networks_lib.Params, - batch: reverb.ReplaySample, - key: networks_lib.PRNGKey, - ) -> Tuple[jax.Array, learning_lib.LossExtra]: - """Calculate a loss on a single batch of data.""" - transitions: types.Transition = batch.data - - # Forward pass. - key1, key2 = jax.random.split(key) - q_tm1 = network.apply( - params, transitions.observation, is_training=True, key=key1) - q_t = network.apply( - target_params, transitions.next_observation, is_training=True, key=key2) - - # Cast and clip rewards. - d_t = (transitions.discount * self.discount).astype(jnp.float32) - r_t = jnp.clip(transitions.reward, -self.max_abs_reward, - self.max_abs_reward).astype(jnp.float32) - - # Compute Q-learning TD-error. - batch_error = jax.vmap(rlax.q_learning) - td_error = batch_error(q_tm1, transitions.action, r_t, d_t, q_t) - batch_loss = jnp.square(td_error) - - loss = jnp.mean(batch_loss) - extra = learning_lib.LossExtra(metrics={}) - return loss, extra + + discount: float = 0.99 + max_abs_reward: float = 1.0 + + def __call__( + self, + network: networks_lib.TypedFeedForwardNetwork, + params: networks_lib.Params, + target_params: networks_lib.Params, + batch: reverb.ReplaySample, + key: networks_lib.PRNGKey, + ) -> Tuple[jax.Array, learning_lib.LossExtra]: + """Calculate a loss on a single batch of data.""" + transitions: types.Transition = batch.data + + # Forward pass. + key1, key2 = jax.random.split(key) + q_tm1 = network.apply( + params, transitions.observation, is_training=True, key=key1 + ) + q_t = network.apply( + target_params, transitions.next_observation, is_training=True, key=key2 + ) + + # Cast and clip rewards. + d_t = (transitions.discount * self.discount).astype(jnp.float32) + r_t = jnp.clip( + transitions.reward, -self.max_abs_reward, self.max_abs_reward + ).astype(jnp.float32) + + # Compute Q-learning TD-error. + batch_error = jax.vmap(rlax.q_learning) + td_error = batch_error(q_tm1, transitions.action, r_t, d_t, q_t) + batch_loss = jnp.square(td_error) + + loss = jnp.mean(batch_loss) + extra = learning_lib.LossExtra(metrics={}) + return loss, extra @dataclasses.dataclass class RegularizedQLearning(learning_lib.LossFn): - """Regularized Q-learning. + """Regularized Q-learning. Implements DQNReg loss function: https://arxiv.org/abs/2101.03958. This is almost identical to QLearning except: 1) Adds a regularization term; 2) Uses vanilla TD error without huber loss. 3) No reward clipping. """ - discount: float = 0.99 - regularizer_coeff = 0.1 - - def __call__( - self, - network: networks_lib.TypedFeedForwardNetwork, - params: networks_lib.Params, - target_params: networks_lib.Params, - batch: reverb.ReplaySample, - key: networks_lib.PRNGKey, - ) -> Tuple[jax.Array, learning_lib.LossExtra]: - """Calculate a loss on a single batch of data.""" - transitions: types.Transition = batch.data - - # Forward pass. - key1, key2 = jax.random.split(key) - q_tm1 = network.apply( - params, transitions.observation, is_training=True, key=key1) - q_t = network.apply( - target_params, transitions.next_observation, is_training=True, key=key2) - - d_t = (transitions.discount * self.discount).astype(jnp.float32) - - # Compute Q-learning TD-error. 
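
The DQNReg objective assembled below is loss = regularizer_coeff * mean(Q(s, a)) + mean(0.5 * td ** 2), i.e. a vanilla squared TD loss plus a penalty on the chosen-action Q-values. A batched toy computation (made-up Q-values and TD errors):

import jax
import jax.numpy as jnp

q_tm1 = jnp.asarray([[1.0, 2.0], [0.5, 0.0]])   # [batch, num_actions]
actions = jnp.asarray([0, 1])
td_error = jnp.asarray([0.2, -0.4])              # pretend vanilla TD errors
regularizer_coeff = 0.1

chosen_q = jax.vmap(lambda q, a: q[a])(q_tm1, actions)   # mirrors the `select` helper
loss = regularizer_coeff * jnp.mean(chosen_q) + jnp.mean(0.5 * jnp.square(td_error))
assert bool(jnp.isclose(loss, 0.1))
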
- batch_error = jax.vmap(rlax.q_learning) - td_error = batch_error( - q_tm1, transitions.action, transitions.reward, d_t, q_t) - td_error = 0.5 * jnp.square(td_error) - - def select(qtm1, action): - return qtm1[action] - q_regularizer = jax.vmap(select)(q_tm1, transitions.action) - - loss = self.regularizer_coeff * jnp.mean(q_regularizer) + jnp.mean(td_error) - extra = learning_lib.LossExtra(metrics={}) - return loss, extra + + discount: float = 0.99 + regularizer_coeff = 0.1 + + def __call__( + self, + network: networks_lib.TypedFeedForwardNetwork, + params: networks_lib.Params, + target_params: networks_lib.Params, + batch: reverb.ReplaySample, + key: networks_lib.PRNGKey, + ) -> Tuple[jax.Array, learning_lib.LossExtra]: + """Calculate a loss on a single batch of data.""" + transitions: types.Transition = batch.data + + # Forward pass. + key1, key2 = jax.random.split(key) + q_tm1 = network.apply( + params, transitions.observation, is_training=True, key=key1 + ) + q_t = network.apply( + target_params, transitions.next_observation, is_training=True, key=key2 + ) + + d_t = (transitions.discount * self.discount).astype(jnp.float32) + + # Compute Q-learning TD-error. + batch_error = jax.vmap(rlax.q_learning) + td_error = batch_error(q_tm1, transitions.action, transitions.reward, d_t, q_t) + td_error = 0.5 * jnp.square(td_error) + + def select(qtm1, action): + return qtm1[action] + + q_regularizer = jax.vmap(select)(q_tm1, transitions.action) + + loss = self.regularizer_coeff * jnp.mean(q_regularizer) + jnp.mean(td_error) + extra = learning_lib.LossExtra(metrics={}) + return loss, extra @dataclasses.dataclass class MunchausenQLearning(learning_lib.LossFn): - """Munchausen q learning. + """Munchausen q learning. Implements M-DQN: https://arxiv.org/abs/2007.14430. """ - entropy_temperature: float = 0.03 # tau parameter - munchausen_coefficient: float = 0.9 # alpha parameter - clip_value_min: float = -1e3 - discount: float = 0.99 - max_abs_reward: float = 1. - huber_loss_parameter: float = 1. - - def __call__( - self, - network: networks_lib.TypedFeedForwardNetwork, - params: networks_lib.Params, - target_params: networks_lib.Params, - batch: reverb.ReplaySample, - key: networks_lib.PRNGKey, - ) -> Tuple[jax.Array, learning_lib.LossExtra]: - """Calculate a loss on a single batch of data.""" - transitions: types.Transition = batch.data - - # Forward pass. - key1, key2, key3 = jax.random.split(key, 3) - q_online_s = network.apply( - params, transitions.observation, is_training=True, key=key1) - action_one_hot = jax.nn.one_hot(transitions.action, q_online_s.shape[-1]) - q_online_sa = jnp.sum(action_one_hot * q_online_s, axis=-1) - q_target_s = network.apply( - target_params, transitions.observation, is_training=True, key=key2) - q_target_next = network.apply( - target_params, transitions.next_observation, is_training=True, key=key3) - - # Cast and clip rewards. - d_t = (transitions.discount * self.discount).astype(jnp.float32) - r_t = jnp.clip(transitions.reward, -self.max_abs_reward, - self.max_abs_reward).astype(jnp.float32) - - # Munchausen term : tau * log_pi(a|s) - munchausen_term = self.entropy_temperature * jax.nn.log_softmax( - q_target_s / self.entropy_temperature, axis=-1) - munchausen_term_a = jnp.sum(action_one_hot * munchausen_term, axis=-1) - munchausen_term_a = jnp.clip(munchausen_term_a, - a_min=self.clip_value_min, - a_max=0.) 
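
Putting the M-DQN target together: the Munchausen bonus is tau * log pi(a|s), computed from the target network's Q-values via log_softmax and clipped to [clip_value_min, 0], and the bootstrap value is the soft maximum tau * logsumexp(q(s') / tau). A one-transition sketch with toy Q-values:

import jax
import jax.numpy as jnp

tau, alpha, clip_min = 0.03, 0.9, -1e3            # temperature, coefficient, clip
q_target_s = jnp.asarray([1.0, 1.5, 0.5])          # target-net Q(s, .)
q_target_next = jnp.asarray([0.2, 0.4, 0.1])       # target-net Q(s', .)
action, r_t, d_t = 1, 0.5, 0.99

log_pi = jax.nn.log_softmax(q_target_s / tau)
munchausen = jnp.clip(tau * log_pi[action], a_min=clip_min, a_max=0.0)
next_v = tau * jax.nn.logsumexp(q_target_next / tau)
target_q = r_t + alpha * munchausen + d_t * next_v
assert float(munchausen) <= 0.0
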
- - # Soft Bellman operator applied to q - next_v = self.entropy_temperature * jax.nn.logsumexp( - q_target_next / self.entropy_temperature, axis=-1) - target_q = jax.lax.stop_gradient(r_t + self.munchausen_coefficient * - munchausen_term_a + d_t * next_v) - - batch_loss = rlax.huber_loss(target_q - q_online_sa, - self.huber_loss_parameter) - loss = jnp.mean(batch_loss) - - extra = learning_lib.LossExtra(metrics={}) - return loss, extra + + entropy_temperature: float = 0.03 # tau parameter + munchausen_coefficient: float = 0.9 # alpha parameter + clip_value_min: float = -1e3 + discount: float = 0.99 + max_abs_reward: float = 1.0 + huber_loss_parameter: float = 1.0 + + def __call__( + self, + network: networks_lib.TypedFeedForwardNetwork, + params: networks_lib.Params, + target_params: networks_lib.Params, + batch: reverb.ReplaySample, + key: networks_lib.PRNGKey, + ) -> Tuple[jax.Array, learning_lib.LossExtra]: + """Calculate a loss on a single batch of data.""" + transitions: types.Transition = batch.data + + # Forward pass. + key1, key2, key3 = jax.random.split(key, 3) + q_online_s = network.apply( + params, transitions.observation, is_training=True, key=key1 + ) + action_one_hot = jax.nn.one_hot(transitions.action, q_online_s.shape[-1]) + q_online_sa = jnp.sum(action_one_hot * q_online_s, axis=-1) + q_target_s = network.apply( + target_params, transitions.observation, is_training=True, key=key2 + ) + q_target_next = network.apply( + target_params, transitions.next_observation, is_training=True, key=key3 + ) + + # Cast and clip rewards. + d_t = (transitions.discount * self.discount).astype(jnp.float32) + r_t = jnp.clip( + transitions.reward, -self.max_abs_reward, self.max_abs_reward + ).astype(jnp.float32) + + # Munchausen term : tau * log_pi(a|s) + munchausen_term = self.entropy_temperature * jax.nn.log_softmax( + q_target_s / self.entropy_temperature, axis=-1 + ) + munchausen_term_a = jnp.sum(action_one_hot * munchausen_term, axis=-1) + munchausen_term_a = jnp.clip( + munchausen_term_a, a_min=self.clip_value_min, a_max=0.0 + ) + + # Soft Bellman operator applied to q + next_v = self.entropy_temperature * jax.nn.logsumexp( + q_target_next / self.entropy_temperature, axis=-1 + ) + target_q = jax.lax.stop_gradient( + r_t + self.munchausen_coefficient * munchausen_term_a + d_t * next_v + ) + + batch_loss = rlax.huber_loss(target_q - q_online_sa, self.huber_loss_parameter) + loss = jnp.mean(batch_loss) + + extra = learning_lib.LossExtra(metrics={}) + return loss, extra diff --git a/acme/agents/jax/dqn/networks.py b/acme/agents/jax/dqn/networks.py index 7005fdf873..3a411e415c 100644 --- a/acme/agents/jax/dqn/networks.py +++ b/acme/agents/jax/dqn/networks.py @@ -17,36 +17,40 @@ import dataclasses from typing import Callable, Optional +import rlax + from acme.jax import networks as networks_lib from acme.jax import types -import rlax Epsilon = float -EpsilonPolicy = Callable[[ - networks_lib.Params, networks_lib.PRNGKey, networks_lib.Observation, Epsilon -], networks_lib.Action] -EpsilonSampleFn = Callable[[networks_lib.NetworkOutput, types.PRNGKey, Epsilon], - networks_lib.Action] +EpsilonPolicy = Callable[ + [networks_lib.Params, networks_lib.PRNGKey, networks_lib.Observation, Epsilon], + networks_lib.Action, +] +EpsilonSampleFn = Callable[ + [networks_lib.NetworkOutput, types.PRNGKey, Epsilon], networks_lib.Action +] EpsilonLogProbFn = Callable[ - [networks_lib.NetworkOutput, networks_lib.Action, Epsilon], - networks_lib.LogProb] + [networks_lib.NetworkOutput, networks_lib.Action, Epsilon], 
networks_lib.LogProb +] -def default_sample_fn(action_values: networks_lib.NetworkOutput, - key: types.PRNGKey, - epsilon: Epsilon) -> networks_lib.Action: - return rlax.epsilon_greedy(epsilon).sample(key, action_values) +def default_sample_fn( + action_values: networks_lib.NetworkOutput, key: types.PRNGKey, epsilon: Epsilon +) -> networks_lib.Action: + return rlax.epsilon_greedy(epsilon).sample(key, action_values) @dataclasses.dataclass class DQNNetworks: - """The network and pure functions for the DQN agent. + """The network and pure functions for the DQN agent. Attributes: policy_network: The policy network. sample_fn: A pure function. Samples an action based on the network output. log_prob: A pure function. Computes log-probability for an action. """ - policy_network: networks_lib.TypedFeedForwardNetwork - sample_fn: EpsilonSampleFn = default_sample_fn - log_prob: Optional[EpsilonLogProbFn] = None + + policy_network: networks_lib.TypedFeedForwardNetwork + sample_fn: EpsilonSampleFn = default_sample_fn + log_prob: Optional[EpsilonLogProbFn] = None diff --git a/acme/agents/jax/dqn/rainbow.py b/acme/agents/jax/dqn/rainbow.py index c5773ea200..a910027e1b 100644 --- a/acme/agents/jax/dqn/rainbow.py +++ b/acme/agents/jax/dqn/rainbow.py @@ -17,6 +17,7 @@ import dataclasses from typing import Callable +import rlax from acme import specs from acme.agents.jax.dqn import actor as dqn_actor @@ -25,21 +26,21 @@ from acme.agents.jax.dqn import losses from acme.jax import networks as networks_lib from acme.jax import utils -import rlax -NetworkFactory = Callable[[specs.EnvironmentSpec], - networks_lib.FeedForwardNetwork] +NetworkFactory = Callable[[specs.EnvironmentSpec], networks_lib.FeedForwardNetwork] @dataclasses.dataclass class RainbowConfig(dqn_config.DQNConfig): - """(Additional) configuration options for RainbowDQN.""" - max_abs_reward: float = 1.0 # For clipping reward + """(Additional) configuration options for RainbowDQN.""" + + max_abs_reward: float = 1.0 # For clipping reward def apply_policy_and_sample( - network: networks_lib.FeedForwardNetwork,) -> dqn_actor.EpsilonPolicy: - """Returns a function that computes actions. + network: networks_lib.FeedForwardNetwork, +) -> dqn_actor.EpsilonPolicy: + """Returns a function that computes actions. Note that this differs from default_behavior_policy with that it expects c51-style network head which returns a tuple with the first entry @@ -52,19 +53,20 @@ def apply_policy_and_sample( A feedforward policy. """ - def apply_and_sample(params, key, obs, epsilon): - # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs. - obs = utils.add_batch_dim(obs) - action_values = network.apply(params, obs)[0] - action_values = utils.squeeze_batch_dim(action_values) - return rlax.epsilon_greedy(epsilon).sample(key, action_values) + def apply_and_sample(params, key, obs, epsilon): + # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs. + obs = utils.add_batch_dim(obs) + action_values = network.apply(params, obs)[0] + action_values = utils.squeeze_batch_dim(action_values) + return rlax.epsilon_greedy(epsilon).sample(key, action_values) - return apply_and_sample + return apply_and_sample -def eval_policy(network: networks_lib.FeedForwardNetwork, - eval_epsilon: float) -> dqn_actor.EpsilonPolicy: - """Returns a function that computes actions. +def eval_policy( + network: networks_lib.FeedForwardNetwork, eval_epsilon: float +) -> dqn_actor.EpsilonPolicy: + """Returns a function that computes actions. 
Note that this differs from default_behavior_policy with that it expects c51-style network head which returns a tuple with the first entry @@ -77,19 +79,19 @@ def eval_policy(network: networks_lib.FeedForwardNetwork, Returns: A feedforward policy. """ - policy = apply_policy_and_sample(network) + policy = apply_policy_and_sample(network) - def apply_and_sample(params, key, obs, _): - return policy(params, key, obs, eval_epsilon) + def apply_and_sample(params, key, obs, _): + return policy(params, key, obs, eval_epsilon) - return apply_and_sample + return apply_and_sample def make_builder(config: RainbowConfig): - """Returns a DQNBuilder with a pre-built loss function.""" - loss_fn = losses.PrioritizedCategoricalDoubleQLearning( - discount=config.discount, - importance_sampling_exponent=config.importance_sampling_exponent, - max_abs_reward=config.max_abs_reward, - ) - return builder.DQNBuilder(config, loss_fn=loss_fn) + """Returns a DQNBuilder with a pre-built loss function.""" + loss_fn = losses.PrioritizedCategoricalDoubleQLearning( + discount=config.discount, + importance_sampling_exponent=config.importance_sampling_exponent, + max_abs_reward=config.max_abs_reward, + ) + return builder.DQNBuilder(config, loss_fn=loss_fn) diff --git a/acme/agents/jax/impala/__init__.py b/acme/agents/jax/impala/__init__.py index 1f0c3a1b46..270171152c 100644 --- a/acme/agents/jax/impala/__init__.py +++ b/acme/agents/jax/impala/__init__.py @@ -17,5 +17,4 @@ from acme.agents.jax.impala.builder import IMPALABuilder from acme.agents.jax.impala.config import IMPALAConfig from acme.agents.jax.impala.learning import IMPALALearner -from acme.agents.jax.impala.networks import IMPALANetworks -from acme.agents.jax.impala.networks import make_atari_networks +from acme.agents.jax.impala.networks import IMPALANetworks, make_atari_networks diff --git a/acme/agents/jax/impala/acting.py b/acme/agents/jax/impala/acting.py index e308a20ac8..8ea33207ce 100644 --- a/acme/agents/jax/impala/acting.py +++ b/acme/agents/jax/impala/acting.py @@ -16,29 +16,30 @@ from typing import Generic, Mapping, Tuple +import chex +import jax +import jax.numpy as jnp + from acme import specs from acme.agents.jax import actor_core as actor_core_lib from acme.agents.jax.impala import networks as impala_networks from acme.jax import networks as networks_lib from acme.jax import types as jax_types -import chex -import jax -import jax.numpy as jnp - ImpalaExtras = Mapping[str, jnp.ndarray] @chex.dataclass(frozen=True, mappable_dataclass=False) class ImpalaActorState(Generic[actor_core_lib.RecurrentState]): - rng: jax_types.PRNGKey - logits: networks_lib.Logits - recurrent_state: actor_core_lib.RecurrentState - prev_recurrent_state: actor_core_lib.RecurrentState + rng: jax_types.PRNGKey + logits: networks_lib.Logits + recurrent_state: actor_core_lib.RecurrentState + prev_recurrent_state: actor_core_lib.RecurrentState ImpalaPolicy = actor_core_lib.ActorCore[ - ImpalaActorState[actor_core_lib.RecurrentState], ImpalaExtras] + ImpalaActorState[actor_core_lib.RecurrentState], ImpalaExtras +] def get_actor_core( @@ -46,50 +47,51 @@ def get_actor_core( environment_spec: specs.EnvironmentSpec, evaluation: bool = False, ) -> ImpalaPolicy: - """Creates an Impala ActorCore.""" - - dummy_logits = jnp.zeros(environment_spec.actions.num_values) - - def init( - rng: jax_types.PRNGKey - ) -> ImpalaActorState[actor_core_lib.RecurrentState]: - rng, init_state_rng = jax.random.split(rng) - initial_state = networks.init_recurrent_state(init_state_rng, None) - return 
ImpalaActorState( - rng=rng, - logits=dummy_logits, - recurrent_state=initial_state, - prev_recurrent_state=initial_state) - - def select_action( - params: networks_lib.Params, - observation: networks_lib.Observation, - state: ImpalaActorState[actor_core_lib.RecurrentState], - ) -> Tuple[networks_lib.Action, - ImpalaActorState[actor_core_lib.RecurrentState]]: - - rng, apply_rng, policy_rng = jax.random.split(state.rng, 3) - (logits, _), new_recurrent_state = networks.apply( - params, - apply_rng, - observation, - state.recurrent_state, + """Creates an Impala ActorCore.""" + + dummy_logits = jnp.zeros(environment_spec.actions.num_values) + + def init(rng: jax_types.PRNGKey) -> ImpalaActorState[actor_core_lib.RecurrentState]: + rng, init_state_rng = jax.random.split(rng) + initial_state = networks.init_recurrent_state(init_state_rng, None) + return ImpalaActorState( + rng=rng, + logits=dummy_logits, + recurrent_state=initial_state, + prev_recurrent_state=initial_state, + ) + + def select_action( + params: networks_lib.Params, + observation: networks_lib.Observation, + state: ImpalaActorState[actor_core_lib.RecurrentState], + ) -> Tuple[networks_lib.Action, ImpalaActorState[actor_core_lib.RecurrentState]]: + + rng, apply_rng, policy_rng = jax.random.split(state.rng, 3) + (logits, _), new_recurrent_state = networks.apply( + params, apply_rng, observation, state.recurrent_state, + ) + + if evaluation: + action = jnp.argmax(logits, axis=-1) + else: + action = jax.random.categorical(policy_rng, logits) + + return ( + action, + ImpalaActorState( + rng=rng, + logits=logits, + recurrent_state=new_recurrent_state, + prev_recurrent_state=state.recurrent_state, + ), + ) + + def get_extras( + state: ImpalaActorState[actor_core_lib.RecurrentState], + ) -> ImpalaExtras: + return {"logits": state.logits, "core_state": state.prev_recurrent_state} + + return actor_core_lib.ActorCore( + init=init, select_action=select_action, get_extras=get_extras ) - - if evaluation: - action = jnp.argmax(logits, axis=-1) - else: - action = jax.random.categorical(policy_rng, logits) - - return action, ImpalaActorState( - rng=rng, - logits=logits, - recurrent_state=new_recurrent_state, - prev_recurrent_state=state.recurrent_state) - - def get_extras( - state: ImpalaActorState[actor_core_lib.RecurrentState]) -> ImpalaExtras: - return {'logits': state.logits, 'core_state': state.prev_recurrent_state} - - return actor_core_lib.ActorCore( - init=init, select_action=select_action, get_extras=get_extras) diff --git a/acme/agents/jax/impala/builder.py b/acme/agents/jax/impala/builder.py index 96f429e958..4d25248d86 100644 --- a/acme/agents/jax/impala/builder.py +++ b/acme/agents/jax/impala/builder.py @@ -16,10 +16,12 @@ from typing import Any, Callable, Generic, Iterator, List, Optional +import jax +import optax +import reverb + import acme -from acme import adders -from acme import core -from acme import specs +from acme import adders, core, specs from acme.adders import reverb as reverb_adders from acme.agents.jax import actor_core as actor_core_lib from acme.agents.jax import actors as actors_lib @@ -30,163 +32,170 @@ from acme.agents.jax.impala import networks as impala_networks from acme.datasets import reverb as datasets from acme.jax import networks as networks_lib -from acme.jax import utils -from acme.jax import variable_utils -from acme.utils import counting -from acme.utils import loggers -import jax -import optax -import reverb - - -class IMPALABuilder(Generic[actor_core_lib.RecurrentState], - 
builders.ActorLearnerBuilder[impala_networks.IMPALANetworks, - acting.ImpalaPolicy, - reverb.ReplaySample]): - """IMPALA Builder.""" - - def __init__( - self, - config: impala_config.IMPALAConfig, - table_extension: Optional[Callable[[], Any]] = None, - ): - """Creates an IMPALA learner.""" - self._config = config - self._sequence_length = self._config.sequence_length - self._table_extension = table_extension - - def make_replay_tables( - self, - environment_spec: specs.EnvironmentSpec, - policy: acting.ImpalaPolicy, - ) -> List[reverb.Table]: - """The queue; use XData or INFO log.""" - dummy_actor_state = policy.init(jax.random.PRNGKey(0)) - signature = reverb_adders.SequenceAdder.signature( - environment_spec, - policy.get_extras(dummy_actor_state), - sequence_length=self._config.sequence_length) - - # Maybe create rate limiter. - # Setting the samples_per_insert ratio less than the default of 1.0, allows - # the agent to drop data for the benefit of using data from most up-to-date - # policies to compute its learner updates. - samples_per_insert = self._config.samples_per_insert - if samples_per_insert: - if samples_per_insert > 1.0 or samples_per_insert <= 0.0: - raise ValueError( - 'Impala requires a samples_per_insert ratio in the range (0, 1],' - f' but received {samples_per_insert}.') - limiter = reverb.rate_limiters.SampleToInsertRatio( - samples_per_insert=samples_per_insert, - min_size_to_sample=1, - error_buffer=self._config.batch_size) - else: - limiter = reverb.rate_limiters.MinSize(1) - - table_extensions = [] - if self._table_extension is not None: - table_extensions = [self._table_extension()] - queue = reverb.Table( - name=self._config.replay_table_name, - sampler=reverb.selectors.Uniform(), - remover=reverb.selectors.Fifo(), - max_size=self._config.max_queue_size, - max_times_sampled=1, - rate_limiter=limiter, - extensions=table_extensions, - signature=signature) - return [queue] - - def make_dataset_iterator( - self, replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]: - """Creates a dataset.""" - batch_size_per_learner = self._config.batch_size // jax.process_count() - batch_size_per_device, ragged = divmod(self._config.batch_size, - jax.device_count()) - if ragged: - raise ValueError( - 'Learner batch size must be divisible by total number of devices!') - - dataset = datasets.make_reverb_dataset( - table=self._config.replay_table_name, - server_address=replay_client.server_address, - batch_size=batch_size_per_device, - num_parallel_calls=None, - max_in_flight_samples_per_worker=2 * batch_size_per_learner) - - return utils.multi_device_put(dataset.as_numpy_iterator(), - jax.local_devices()) - - def make_adder( - self, - replay_client: reverb.Client, - environment_spec: Optional[specs.EnvironmentSpec], - policy: Optional[acting.ImpalaPolicy], - ) -> Optional[adders.Adder]: - """Creates an adder which handles observations.""" - del environment_spec, policy - # Note that the last transition in the sequence is used for bootstrapping - # only and is ignored otherwise. So we need to make sure that sequences - # overlap on one transition, thus "-1" in the period length computation. 
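
Concretely: with the config's default sequence_length of 20 and sequence_period left as None, the adder writes with period 19, so consecutive sequences share exactly one step, the bootstrap step of one being the first step of the next. A quick check of that overlap (illustrative step indices):

sequence_length = 20
period = sequence_length - 1                              # default when sequence_period is None

first = list(range(0, sequence_length))                   # steps 0..19
second = list(range(period, period + sequence_length))    # steps 19..38
assert set(first) & set(second) == {19}
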
- return reverb_adders.SequenceAdder( - client=replay_client, - priority_fns={self._config.replay_table_name: None}, - period=self._config.sequence_period or (self._sequence_length - 1), - sequence_length=self._sequence_length, - ) - - def make_learner( - self, - random_key: networks_lib.PRNGKey, - networks: impala_networks.IMPALANetworks, - dataset: Iterator[reverb.ReplaySample], - logger_fn: loggers.LoggerFactory, - environment_spec: specs.EnvironmentSpec, - replay_client: Optional[reverb.Client] = None, - counter: Optional[counting.Counter] = None, - ) -> core.Learner: - del environment_spec, replay_client - - optimizer = optax.chain( - optax.clip_by_global_norm(self._config.max_gradient_norm), - optax.adam( - self._config.learning_rate, - b1=self._config.adam_momentum_decay, - b2=self._config.adam_variance_decay, - eps=self._config.adam_eps, - eps_root=self._config.adam_eps_root)) - - return learning.IMPALALearner( - networks=networks, - iterator=dataset, - optimizer=optimizer, - random_key=random_key, - discount=self._config.discount, - entropy_cost=self._config.entropy_cost, - baseline_cost=self._config.baseline_cost, - max_abs_reward=self._config.max_abs_reward, - counter=counter, - logger=logger_fn('learner'), - ) - - def make_actor( - self, - random_key: networks_lib.PRNGKey, - policy: acting.ImpalaPolicy, - environment_spec: specs.EnvironmentSpec, - variable_source: Optional[core.VariableSource] = None, - adder: Optional[adders.Adder] = None, - ) -> acme.Actor: - del environment_spec - variable_client = variable_utils.VariableClient( - client=variable_source, - key='network', - update_period=self._config.variable_update_period) - return actors_lib.GenericActor(policy, random_key, variable_client, adder) - - def make_policy(self, - networks: impala_networks.IMPALANetworks, - environment_spec: specs.EnvironmentSpec, - evaluation: bool = False) -> acting.ImpalaPolicy: - return acting.get_actor_core(networks, environment_spec, evaluation) +from acme.jax import utils, variable_utils +from acme.utils import counting, loggers + + +class IMPALABuilder( + Generic[actor_core_lib.RecurrentState], + builders.ActorLearnerBuilder[ + impala_networks.IMPALANetworks, acting.ImpalaPolicy, reverb.ReplaySample + ], +): + """IMPALA Builder.""" + + def __init__( + self, + config: impala_config.IMPALAConfig, + table_extension: Optional[Callable[[], Any]] = None, + ): + """Creates an IMPALA learner.""" + self._config = config + self._sequence_length = self._config.sequence_length + self._table_extension = table_extension + + def make_replay_tables( + self, environment_spec: specs.EnvironmentSpec, policy: acting.ImpalaPolicy, + ) -> List[reverb.Table]: + """The queue; use XData or INFO log.""" + dummy_actor_state = policy.init(jax.random.PRNGKey(0)) + signature = reverb_adders.SequenceAdder.signature( + environment_spec, + policy.get_extras(dummy_actor_state), + sequence_length=self._config.sequence_length, + ) + + # Maybe create rate limiter. + # Setting the samples_per_insert ratio less than the default of 1.0, allows + # the agent to drop data for the benefit of using data from most up-to-date + # policies to compute its learner updates. + samples_per_insert = self._config.samples_per_insert + if samples_per_insert: + if samples_per_insert > 1.0 or samples_per_insert <= 0.0: + raise ValueError( + "Impala requires a samples_per_insert ratio in the range (0, 1]," + f" but received {samples_per_insert}." 
+ ) + limiter = reverb.rate_limiters.SampleToInsertRatio( + samples_per_insert=samples_per_insert, + min_size_to_sample=1, + error_buffer=self._config.batch_size, + ) + else: + limiter = reverb.rate_limiters.MinSize(1) + + table_extensions = [] + if self._table_extension is not None: + table_extensions = [self._table_extension()] + queue = reverb.Table( + name=self._config.replay_table_name, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._config.max_queue_size, + max_times_sampled=1, + rate_limiter=limiter, + extensions=table_extensions, + signature=signature, + ) + return [queue] + + def make_dataset_iterator( + self, replay_client: reverb.Client + ) -> Iterator[reverb.ReplaySample]: + """Creates a dataset.""" + batch_size_per_learner = self._config.batch_size // jax.process_count() + batch_size_per_device, ragged = divmod( + self._config.batch_size, jax.device_count() + ) + if ragged: + raise ValueError( + "Learner batch size must be divisible by total number of devices!" + ) + + dataset = datasets.make_reverb_dataset( + table=self._config.replay_table_name, + server_address=replay_client.server_address, + batch_size=batch_size_per_device, + num_parallel_calls=None, + max_in_flight_samples_per_worker=2 * batch_size_per_learner, + ) + + return utils.multi_device_put(dataset.as_numpy_iterator(), jax.local_devices()) + + def make_adder( + self, + replay_client: reverb.Client, + environment_spec: Optional[specs.EnvironmentSpec], + policy: Optional[acting.ImpalaPolicy], + ) -> Optional[adders.Adder]: + """Creates an adder which handles observations.""" + del environment_spec, policy + # Note that the last transition in the sequence is used for bootstrapping + # only and is ignored otherwise. So we need to make sure that sequences + # overlap on one transition, thus "-1" in the period length computation. 
+ return reverb_adders.SequenceAdder( + client=replay_client, + priority_fns={self._config.replay_table_name: None}, + period=self._config.sequence_period or (self._sequence_length - 1), + sequence_length=self._sequence_length, + ) + + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: impala_networks.IMPALANetworks, + dataset: Iterator[reverb.ReplaySample], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + del environment_spec, replay_client + + optimizer = optax.chain( + optax.clip_by_global_norm(self._config.max_gradient_norm), + optax.adam( + self._config.learning_rate, + b1=self._config.adam_momentum_decay, + b2=self._config.adam_variance_decay, + eps=self._config.adam_eps, + eps_root=self._config.adam_eps_root, + ), + ) + + return learning.IMPALALearner( + networks=networks, + iterator=dataset, + optimizer=optimizer, + random_key=random_key, + discount=self._config.discount, + entropy_cost=self._config.entropy_cost, + baseline_cost=self._config.baseline_cost, + max_abs_reward=self._config.max_abs_reward, + counter=counter, + logger=logger_fn("learner"), + ) + + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: acting.ImpalaPolicy, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + adder: Optional[adders.Adder] = None, + ) -> acme.Actor: + del environment_spec + variable_client = variable_utils.VariableClient( + client=variable_source, + key="network", + update_period=self._config.variable_update_period, + ) + return actors_lib.GenericActor(policy, random_key, variable_client, adder) + + def make_policy( + self, + networks: impala_networks.IMPALANetworks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool = False, + ) -> acting.ImpalaPolicy: + return acting.get_actor_core(networks, environment_spec, evaluation) diff --git a/acme/agents/jax/impala/config.py b/acme/agents/jax/impala/config.py index 161dd9c5a8..11e09bd3b2 100644 --- a/acme/agents/jax/impala/config.py +++ b/acme/agents/jax/impala/config.py @@ -16,48 +16,52 @@ import dataclasses from typing import Optional, Union -from acme import types -from acme.adders import reverb as adders_reverb import numpy as np import optax +from acme import types +from acme.adders import reverb as adders_reverb + @dataclasses.dataclass class IMPALAConfig: - """Configuration options for IMPALA.""" - seed: int = 0 - discount: float = 0.99 - sequence_length: int = 20 - sequence_period: Optional[int] = None - variable_update_period: int = 1000 - - # Optimizer configuration. - batch_size: int = 32 - learning_rate: Union[float, optax.Schedule] = 2e-4 - adam_momentum_decay: float = 0.0 - adam_variance_decay: float = 0.99 - adam_eps: float = 1e-8 - adam_eps_root: float = 0.0 - max_gradient_norm: float = 40.0 - - # Loss configuration. 
- baseline_cost: float = 0.5 - entropy_cost: float = 0.01 - max_abs_reward: float = np.inf - - # Replay options - replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE - num_prefetch_threads: Optional[int] = None - samples_per_insert: Optional[float] = 1.0 - max_queue_size: Union[int, types.Batches] = types.Batches(10) - - def __post_init__(self): - if isinstance(self.max_queue_size, types.Batches): - self.max_queue_size *= self.batch_size - assert self.max_queue_size > self.batch_size + 1, (""" + """Configuration options for IMPALA.""" + + seed: int = 0 + discount: float = 0.99 + sequence_length: int = 20 + sequence_period: Optional[int] = None + variable_update_period: int = 1000 + + # Optimizer configuration. + batch_size: int = 32 + learning_rate: Union[float, optax.Schedule] = 2e-4 + adam_momentum_decay: float = 0.0 + adam_variance_decay: float = 0.99 + adam_eps: float = 1e-8 + adam_eps_root: float = 0.0 + max_gradient_norm: float = 40.0 + + # Loss configuration. + baseline_cost: float = 0.5 + entropy_cost: float = 0.01 + max_abs_reward: float = np.inf + + # Replay options + replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE + num_prefetch_threads: Optional[int] = None + samples_per_insert: Optional[float] = 1.0 + max_queue_size: Union[int, types.Batches] = types.Batches(10) + + def __post_init__(self): + if isinstance(self.max_queue_size, types.Batches): + self.max_queue_size *= self.batch_size + assert ( + self.max_queue_size > self.batch_size + 1 + ), """ max_queue_size must be strictly larger than the batch size: - during the last step in an episode we might write 2 sequences to Reverb at once (that's how SequenceAdder works) - Reverb does insertion/sampling in multiple threads, so data is added asynchronously at unpredictable times. 
Therefore we need - additional buffer size in order to avoid deadlocks.""") + additional buffer size in order to avoid deadlocks.""" diff --git a/acme/agents/jax/impala/learning.py b/acme/agents/jax/impala/learning.py index 5228048575..35146fc55f 100644 --- a/acme/agents/jax/impala/learning.py +++ b/acme/agents/jax/impala/learning.py @@ -17,144 +17,154 @@ import time from typing import Dict, Iterator, List, NamedTuple, Optional, Sequence, Tuple +import jax +import jax.numpy as jnp +import numpy as np +import optax +import reverb from absl import logging + import acme from acme.agents.jax.impala import networks as impala_networks from acme.jax import losses from acme.jax import networks as networks_lib from acme.jax import utils -from acme.utils import counting -from acme.utils import loggers -import jax -import jax.numpy as jnp -import numpy as np -import optax -import reverb +from acme.utils import counting, loggers -_PMAP_AXIS_NAME = 'data' +_PMAP_AXIS_NAME = "data" class TrainingState(NamedTuple): - """Training state consists of network parameters and optimiser state.""" - params: networks_lib.Params - opt_state: optax.OptState + """Training state consists of network parameters and optimiser state.""" + + params: networks_lib.Params + opt_state: optax.OptState class IMPALALearner(acme.Learner): - """Learner for an importanced-weighted advantage actor-critic.""" - - def __init__( - self, - networks: impala_networks.IMPALANetworks, - iterator: Iterator[reverb.ReplaySample], - optimizer: optax.GradientTransformation, - random_key: networks_lib.PRNGKey, - discount: float = 0.99, - entropy_cost: float = 0.0, - baseline_cost: float = 1.0, - max_abs_reward: float = np.inf, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - devices: Optional[Sequence[jax.Device]] = None, - prefetch_size: int = 2, - ): - local_devices = jax.local_devices() - process_id = jax.process_index() - logging.info('Learner process id: %s. Devices passed: %s', process_id, - devices) - logging.info('Learner process id: %s. Local devices from JAX API: %s', - process_id, local_devices) - self._devices = devices or local_devices - self._local_devices = [d for d in self._devices if d in local_devices] - - self._iterator = iterator - - def unroll_without_rng( - params: networks_lib.Params, observations: networks_lib.Observation, - initial_state: networks_lib.RecurrentState - ) -> Tuple[networks_lib.NetworkOutput, networks_lib.RecurrentState]: - unused_rng = jax.random.PRNGKey(0) - return networks.unroll(params, unused_rng, observations, initial_state) - - loss_fn = losses.impala_loss( - # TODO(b/244319884): Consider supporting the use of RNG in impala_loss. - unroll_fn=unroll_without_rng, - discount=discount, - max_abs_reward=max_abs_reward, - baseline_cost=baseline_cost, - entropy_cost=entropy_cost) - - @jax.jit - def sgd_step( - state: TrainingState, sample: reverb.ReplaySample - ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: - """Computes an SGD step, returning new state and metrics for logging.""" - - # Compute gradients. - grad_fn = jax.value_and_grad(loss_fn, has_aux=True) - (loss_value, metrics), gradients = grad_fn(state.params, sample) - - # Average gradients over pmap replicas before optimizer update. - gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME) - - # Apply updates. 
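
The optimizer this learner receives from the IMPALA builder earlier in this diff is an optax chain of global-norm clipping followed by Adam, and each SGD step uses the usual two-phase optax pattern: optimizer.update to transform the gradients, then optax.apply_updates to apply them. A self-contained toy step on a quadratic (the 40.0 and 2e-4 mirror the config defaults; everything else is arbitrary):

import jax
import jax.numpy as jnp
import optax

optimizer = optax.chain(optax.clip_by_global_norm(40.0), optax.adam(2e-4))

params = {"w": jnp.asarray(3.0)}
opt_state = optimizer.init(params)

def loss_fn(p):
    return jnp.square(p["w"])                           # toy quadratic loss

grads = jax.grad(loss_fn)(params)
updates, opt_state = optimizer.update(grads, opt_state)
new_params = optax.apply_updates(params, updates)
assert float(new_params["w"]) < 3.0                     # moved towards the minimum
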
- updates, new_opt_state = optimizer.update(gradients, state.opt_state) - new_params = optax.apply_updates(state.params, updates) - - metrics.update({ - 'loss': loss_value, - 'param_norm': optax.global_norm(new_params), - 'param_updates_norm': optax.global_norm(updates), - }) - - new_state = TrainingState(params=new_params, opt_state=new_opt_state) - - return new_state, metrics - - def make_initial_state(key: jnp.ndarray) -> TrainingState: - """Initialises the training state (parameters and optimiser state).""" - initial_params = networks.init(key) - return TrainingState( - params=initial_params, opt_state=optimizer.init(initial_params)) - - # Initialise training state (parameters and optimiser state). - state = make_initial_state(random_key) - self._state = utils.replicate_in_all_devices(state, self._local_devices) - - self._sgd_step = jax.pmap( - sgd_step, axis_name=_PMAP_AXIS_NAME, devices=self._devices) - - # Set up logging/counting. - self._counter = counter or counting.Counter() - self._logger = logger or loggers.make_default_logger( - 'learner', steps_key=self._counter.get_steps_key()) - - def step(self): - """Does a step of SGD and logs the results.""" - samples = next(self._iterator) - - # Do a batch of SGD. - start = time.time() - self._state, results = self._sgd_step(self._state, samples) - - # Take results from first replica. - # NOTE: This measure will be a noisy estimate for the purposes of the logs - # as it does not pmean over all devices. - results = utils.get_from_first_device(results) - - # Update our counts and record them. - counts = self._counter.increment(steps=1, time_elapsed=time.time() - start) - - # Maybe write logs. - self._logger.write({**results, **counts}) - - def get_variables(self, names: Sequence[str]) -> List[networks_lib.Params]: - # Return first replica of parameters. - return utils.get_from_first_device([self._state.params], as_numpy=False) - - def save(self) -> TrainingState: - # Serialize only the first replica of parameters and optimizer state. - return utils.get_from_first_device(self._state) - - def restore(self, state: TrainingState): - self._state = utils.replicate_in_all_devices(state, self._local_devices) + """Learner for an importanced-weighted advantage actor-critic.""" + + def __init__( + self, + networks: impala_networks.IMPALANetworks, + iterator: Iterator[reverb.ReplaySample], + optimizer: optax.GradientTransformation, + random_key: networks_lib.PRNGKey, + discount: float = 0.99, + entropy_cost: float = 0.0, + baseline_cost: float = 1.0, + max_abs_reward: float = np.inf, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + devices: Optional[Sequence[jax.Device]] = None, + prefetch_size: int = 2, + ): + local_devices = jax.local_devices() + process_id = jax.process_index() + logging.info("Learner process id: %s. Devices passed: %s", process_id, devices) + logging.info( + "Learner process id: %s. 
Local devices from JAX API: %s", + process_id, + local_devices, + ) + self._devices = devices or local_devices + self._local_devices = [d for d in self._devices if d in local_devices] + + self._iterator = iterator + + def unroll_without_rng( + params: networks_lib.Params, + observations: networks_lib.Observation, + initial_state: networks_lib.RecurrentState, + ) -> Tuple[networks_lib.NetworkOutput, networks_lib.RecurrentState]: + unused_rng = jax.random.PRNGKey(0) + return networks.unroll(params, unused_rng, observations, initial_state) + + loss_fn = losses.impala_loss( + # TODO(b/244319884): Consider supporting the use of RNG in impala_loss. + unroll_fn=unroll_without_rng, + discount=discount, + max_abs_reward=max_abs_reward, + baseline_cost=baseline_cost, + entropy_cost=entropy_cost, + ) + + @jax.jit + def sgd_step( + state: TrainingState, sample: reverb.ReplaySample + ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: + """Computes an SGD step, returning new state and metrics for logging.""" + + # Compute gradients. + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) + (loss_value, metrics), gradients = grad_fn(state.params, sample) + + # Average gradients over pmap replicas before optimizer update. + gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME) + + # Apply updates. + updates, new_opt_state = optimizer.update(gradients, state.opt_state) + new_params = optax.apply_updates(state.params, updates) + + metrics.update( + { + "loss": loss_value, + "param_norm": optax.global_norm(new_params), + "param_updates_norm": optax.global_norm(updates), + } + ) + + new_state = TrainingState(params=new_params, opt_state=new_opt_state) + + return new_state, metrics + + def make_initial_state(key: jnp.ndarray) -> TrainingState: + """Initialises the training state (parameters and optimiser state).""" + initial_params = networks.init(key) + return TrainingState( + params=initial_params, opt_state=optimizer.init(initial_params) + ) + + # Initialise training state (parameters and optimiser state). + state = make_initial_state(random_key) + self._state = utils.replicate_in_all_devices(state, self._local_devices) + + self._sgd_step = jax.pmap( + sgd_step, axis_name=_PMAP_AXIS_NAME, devices=self._devices + ) + + # Set up logging/counting. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger( + "learner", steps_key=self._counter.get_steps_key() + ) + + def step(self): + """Does a step of SGD and logs the results.""" + samples = next(self._iterator) + + # Do a batch of SGD. + start = time.time() + self._state, results = self._sgd_step(self._state, samples) + + # Take results from first replica. + # NOTE: This measure will be a noisy estimate for the purposes of the logs + # as it does not pmean over all devices. + results = utils.get_from_first_device(results) + + # Update our counts and record them. + counts = self._counter.increment(steps=1, time_elapsed=time.time() - start) + + # Maybe write logs. + self._logger.write({**results, **counts}) + + def get_variables(self, names: Sequence[str]) -> List[networks_lib.Params]: + # Return first replica of parameters. + return utils.get_from_first_device([self._state.params], as_numpy=False) + + def save(self) -> TrainingState: + # Serialize only the first replica of parameters and optimizer state. 
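
Checkpointing works here because after a pmapped update every replica holds identical parameters, so saving the first replica loses nothing and restoring simply re-broadcasts it. Roughly what the acme.jax.utils helpers used below do, sketched with plain JAX (single-host, illustrative only):

import jax
import jax.numpy as jnp

params = {"w": jnp.ones(2)}

# Replicate across local devices: every leaf gains a leading [num_devices] axis.
replicated = jax.device_put_replicated(params, jax.local_devices())
assert replicated["w"].shape == (jax.local_device_count(), 2)

# Take the first replica back out, e.g. for checkpointing.
first_replica = jax.tree_util.tree_map(lambda x: x[0], replicated)
assert first_replica["w"].shape == (2,)
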
+ return utils.get_from_first_device(self._state) + + def restore(self, state: TrainingState): + self._state = utils.replicate_in_all_devices(state, self._local_devices) diff --git a/acme/agents/jax/impala/networks.py b/acme/agents/jax/impala/networks.py index 4bfbd5139e..583711d9f5 100644 --- a/acme/agents/jax/impala/networks.py +++ b/acme/agents/jax/impala/networks.py @@ -17,14 +17,13 @@ from acme import specs from acme.jax import networks as networks_lib - IMPALANetworks = networks_lib.UnrollableNetwork def make_atari_networks(env_spec: specs.EnvironmentSpec) -> IMPALANetworks: - """Builds default IMPALA networks for Atari games.""" + """Builds default IMPALA networks for Atari games.""" - def make_core_module() -> networks_lib.DeepIMPALAAtariNetwork: - return networks_lib.DeepIMPALAAtariNetwork(env_spec.actions.num_values) + def make_core_module() -> networks_lib.DeepIMPALAAtariNetwork: + return networks_lib.DeepIMPALAAtariNetwork(env_spec.actions.num_values) - return networks_lib.make_unrollable_network(env_spec, make_core_module) + return networks_lib.make_unrollable_network(env_spec, make_core_module) diff --git a/acme/agents/jax/impala/types.py b/acme/agents/jax/impala/types.py index 6763a71d4e..2917b07725 100644 --- a/acme/agents/jax/impala/types.py +++ b/acme/agents/jax/impala/types.py @@ -15,17 +15,16 @@ """Some types/assumptions used in the IMPALA agent.""" from typing import Callable, Tuple +import jax.numpy as jnp + from acme.agents.jax.actor_core import RecurrentState from acme.jax import networks from acme.jax import types as jax_types -import jax.numpy as jnp # Only simple observations & discrete action spaces for now. Observation = jnp.ndarray Action = int Outputs = Tuple[Tuple[networks.Logits, networks.Value], RecurrentState] -PolicyValueInitFn = Callable[[networks.PRNGKey, RecurrentState], - networks.Params] -PolicyValueFn = Callable[[networks.Params, Observation, RecurrentState], - Outputs] +PolicyValueInitFn = Callable[[networks.PRNGKey, RecurrentState], networks.Params] +PolicyValueFn = Callable[[networks.Params, Observation, RecurrentState], Outputs] RecurrentStateFn = Callable[[jax_types.PRNGKey], RecurrentState] diff --git a/acme/agents/jax/lfd/__init__.py b/acme/agents/jax/lfd/__init__.py index 873ce23a1d..d373aab8e7 100644 --- a/acme/agents/jax/lfd/__init__.py +++ b/acme/agents/jax/lfd/__init__.py @@ -14,10 +14,7 @@ """Lfd agents.""" -from acme.agents.jax.lfd.builder import LfdBuilder -from acme.agents.jax.lfd.builder import LfdStep +from acme.agents.jax.lfd.builder import LfdBuilder, LfdStep from acme.agents.jax.lfd.config import LfdConfig -from acme.agents.jax.lfd.sacfd import SACfDBuilder -from acme.agents.jax.lfd.sacfd import SACfDConfig -from acme.agents.jax.lfd.td3fd import TD3fDBuilder -from acme.agents.jax.lfd.td3fd import TD3fDConfig +from acme.agents.jax.lfd.sacfd import SACfDBuilder, SACfDConfig +from acme.agents.jax.lfd.td3fd import TD3fDBuilder, TD3fDConfig diff --git a/acme/agents/jax/lfd/builder.py b/acme/agents/jax/lfd/builder.py index 2f544f4180..fd5c704127 100644 --- a/acme/agents/jax/lfd/builder.py +++ b/acme/agents/jax/lfd/builder.py @@ -16,31 +16,34 @@ from typing import Any, Callable, Generic, Iterator, Tuple +import dm_env + from acme.agents.jax import builders from acme.agents.jax.lfd import config as lfd_config from acme.agents.jax.lfd import lfd_adder -import dm_env - LfdStep = Tuple[Any, dm_env.TimeStep] -class LfdBuilder(builders.ActorLearnerBuilder[builders.Networks, - builders.Policy, - builders.Sample,], - 
Generic[builders.Networks, builders.Policy, builders.Sample]): - """Builder that enables Learning From demonstrations. +class LfdBuilder( + builders.ActorLearnerBuilder[builders.Networks, builders.Policy, builders.Sample,], + Generic[builders.Networks, builders.Policy, builders.Sample], +): + """Builder that enables Learning From demonstrations. This builder is not self contained and requires an underlying builder implementing an off-policy algorithm. """ - def __init__(self, builder: builders.ActorLearnerBuilder[builders.Networks, - builders.Policy, - builders.Sample], - demonstrations_factory: Callable[[], Iterator[LfdStep]], - config: lfd_config.LfdConfig): - """LfdBuilder constructor. + def __init__( + self, + builder: builders.ActorLearnerBuilder[ + builders.Networks, builders.Policy, builders.Sample + ], + demonstrations_factory: Callable[[], Iterator[LfdStep]], + config: lfd_config.LfdConfig, + ): + """LfdBuilder constructor. Args: builder: The underlying builder implementing the off-policy algorithm. @@ -53,28 +56,30 @@ def __init__(self, builder: builders.ActorLearnerBuilder[builders.Networks, as the number of actors being used. config: LfD configuration. """ - self._builder = builder - self._demonstrations_factory = demonstrations_factory - self._config = config + self._builder = builder + self._demonstrations_factory = demonstrations_factory + self._config = config - def make_replay_tables(self, *args, **kwargs): - return self._builder.make_replay_tables(*args, **kwargs) + def make_replay_tables(self, *args, **kwargs): + return self._builder.make_replay_tables(*args, **kwargs) - def make_dataset_iterator(self, *args, **kwargs): - return self._builder.make_dataset_iterator(*args, **kwargs) + def make_dataset_iterator(self, *args, **kwargs): + return self._builder.make_dataset_iterator(*args, **kwargs) - def make_adder(self, *args, **kwargs): - demonstrations = self._demonstrations_factory() - return lfd_adder.LfdAdder(self._builder.make_adder(*args, **kwargs), - demonstrations, - self._config.initial_insert_count, - self._config.demonstration_ratio) + def make_adder(self, *args, **kwargs): + demonstrations = self._demonstrations_factory() + return lfd_adder.LfdAdder( + self._builder.make_adder(*args, **kwargs), + demonstrations, + self._config.initial_insert_count, + self._config.demonstration_ratio, + ) - def make_actor(self, *args, **kwargs): - return self._builder.make_actor(*args, **kwargs) + def make_actor(self, *args, **kwargs): + return self._builder.make_actor(*args, **kwargs) - def make_learner(self, *args, **kwargs): - return self._builder.make_learner(*args, **kwargs) + def make_learner(self, *args, **kwargs): + return self._builder.make_learner(*args, **kwargs) - def make_policy(self, *args, **kwargs): - return self._builder.make_policy(*args, **kwargs) + def make_policy(self, *args, **kwargs): + return self._builder.make_policy(*args, **kwargs) diff --git a/acme/agents/jax/lfd/config.py b/acme/agents/jax/lfd/config.py index 2d6caf302f..d3fc7be3ac 100644 --- a/acme/agents/jax/lfd/config.py +++ b/acme/agents/jax/lfd/config.py @@ -19,7 +19,7 @@ @dataclasses.dataclass class LfdConfig: - """Configuration options for LfD. + """Configuration options for LfD. Attributes: initial_insert_count: Number of steps of demonstrations to add to the replay @@ -33,5 +33,6 @@ class LfdConfig: Note also that this ratio is only a target ratio since the granularity is the episode. 
""" - initial_insert_count: int = 0 - demonstration_ratio: float = 0.01 + + initial_insert_count: int = 0 + demonstration_ratio: float = 0.01 diff --git a/acme/agents/jax/lfd/lfd_adder.py b/acme/agents/jax/lfd/lfd_adder.py index 9d53c96e26..4151ca7b8a 100644 --- a/acme/agents/jax/lfd/lfd_adder.py +++ b/acme/agents/jax/lfd/lfd_adder.py @@ -26,24 +26,26 @@ from typing import Any, Iterator, Tuple -from acme import adders -from acme import types import dm_env +from acme import adders, types + class LfdAdder(adders.Adder): - """Adder which adds from time to time some demonstrations. + """Adder which adds from time to time some demonstrations. Lfd stands for Learning From Demonstrations and is the same technique as the one used in R2D3. """ - def __init__(self, - adder: adders.Adder, - demonstrations: Iterator[Tuple[Any, dm_env.TimeStep]], - initial_insert_count: int, - demonstration_ratio: float): - """LfdAdder constructor. + def __init__( + self, + adder: adders.Adder, + demonstrations: Iterator[Tuple[Any, dm_env.TimeStep]], + initial_insert_count: int, + demonstration_ratio: float, + ): + """LfdAdder constructor. Args: adder: The underlying adder used to add mixed episodes. @@ -63,55 +65,58 @@ def __init__(self, Note also that this ratio is only a target ratio since the granularity is the episode. """ - self._adder = adder - self._demonstrations = demonstrations - self._demonstration_ratio = demonstration_ratio - if demonstration_ratio < 0 or demonstration_ratio >= 1.: - raise ValueError('Invalid demonstration ratio.') - - # Number of demonstration steps that should have been added to the replay - # buffer to meet the target demonstration ratio minus what has been really - # added. - # As a consequence: - # - when this delta is zero, the effective ratio exactly matches the desired - # ratio - # - when it is positive, more demonstrations need to be added to - # reestablish the balance - # The initial value is set so that after exactly initial_insert_count - # inserts of demonstration steps, _delta_demonstration_step_count will be - # zero. - self._delta_demonstration_step_count = ( - (1. - self._demonstration_ratio) * initial_insert_count) - - def reset(self): - self._adder.reset() - - def _add_demonstration_episode(self): - _, timestep = next(self._demonstrations) - if not timestep.first(): - raise ValueError('Expecting the start of an episode.') - self._adder.add_first(timestep) - self._delta_demonstration_step_count -= (1. - self._demonstration_ratio) - while not timestep.last(): - action, timestep = next(self._demonstrations) - self._adder.add(action, timestep) - self._delta_demonstration_step_count -= (1. - self._demonstration_ratio) - - # Reset is being called periodically to reset the connection to reverb. - # TODO(damienv, bshahr): Make the reset an internal detail of the reverb - # adder and remove it from the adder API. 
- self._adder.reset() - - def add_first(self, timestep: dm_env.TimeStep): - while self._delta_demonstration_step_count > 0.: - self._add_demonstration_episode() - - self._adder.add_first(timestep) - self._delta_demonstration_step_count += self._demonstration_ratio - - def add(self, - action: types.NestedArray, - next_timestep: dm_env.TimeStep, - extras: types.NestedArray = ()): - self._adder.add(action, next_timestep) - self._delta_demonstration_step_count += self._demonstration_ratio + self._adder = adder + self._demonstrations = demonstrations + self._demonstration_ratio = demonstration_ratio + if demonstration_ratio < 0 or demonstration_ratio >= 1.0: + raise ValueError("Invalid demonstration ratio.") + + # Number of demonstration steps that should have been added to the replay + # buffer to meet the target demonstration ratio minus what has been really + # added. + # As a consequence: + # - when this delta is zero, the effective ratio exactly matches the desired + # ratio + # - when it is positive, more demonstrations need to be added to + # reestablish the balance + # The initial value is set so that after exactly initial_insert_count + # inserts of demonstration steps, _delta_demonstration_step_count will be + # zero. + self._delta_demonstration_step_count = ( + 1.0 - self._demonstration_ratio + ) * initial_insert_count + + def reset(self): + self._adder.reset() + + def _add_demonstration_episode(self): + _, timestep = next(self._demonstrations) + if not timestep.first(): + raise ValueError("Expecting the start of an episode.") + self._adder.add_first(timestep) + self._delta_demonstration_step_count -= 1.0 - self._demonstration_ratio + while not timestep.last(): + action, timestep = next(self._demonstrations) + self._adder.add(action, timestep) + self._delta_demonstration_step_count -= 1.0 - self._demonstration_ratio + + # Reset is being called periodically to reset the connection to reverb. + # TODO(damienv, bshahr): Make the reset an internal detail of the reverb + # adder and remove it from the adder API. 
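# A worked, pure-Python illustration (no Acme imports) of the bookkeeping
# implemented above: `delta` is the number of demonstration steps still owed
# to replay in order to reach the target ratio. The constants mirror the
# values used in lfd_adder_test.py below and are otherwise arbitrary.
demonstration_ratio = 0.2
initial_insert_count = 50
demo_episode_len, agent_episode_len = 10, 5

delta = (1.0 - demonstration_ratio) * initial_insert_count
demo_steps = agent_steps = 0
for _ in range(100):              # Agent episodes.
    while delta > 0.0:            # Demonstrations owed: insert whole episodes.
        demo_steps += demo_episode_len
        delta -= (1.0 - demonstration_ratio) * demo_episode_len
    agent_steps += agent_episode_len
    delta += demonstration_ratio * agent_episode_len

# Once the initial burst of inserts is discounted, the fraction of
# demonstration steps settles near demonstration_ratio, which is what
# lfd_adder_test.py asserts (up to one episode of granularity).
effective_ratio = demo_steps / (demo_steps + agent_steps)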
+ self._adder.reset() + + def add_first(self, timestep: dm_env.TimeStep): + while self._delta_demonstration_step_count > 0.0: + self._add_demonstration_episode() + + self._adder.add_first(timestep) + self._delta_demonstration_step_count += self._demonstration_ratio + + def add( + self, + action: types.NestedArray, + next_timestep: dm_env.TimeStep, + extras: types.NestedArray = (), + ): + self._adder.add(action, next_timestep) + self._delta_demonstration_step_count += self._demonstration_ratio diff --git a/acme/agents/jax/lfd/lfd_adder_test.py b/acme/agents/jax/lfd/lfd_adder_test.py index 1e8f926ead..8700a25db5 100644 --- a/acme/agents/jax/lfd/lfd_adder_test.py +++ b/acme/agents/jax/lfd/lfd_adder_test.py @@ -16,128 +16,145 @@ import collections -from acme import adders -from acme import types -from acme.agents.jax.lfd import lfd_adder import dm_env import numpy as np - from absl.testing import absltest +from acme import adders, types +from acme.agents.jax.lfd import lfd_adder + class TestStatisticsAdder(adders.Adder): + def __init__(self): + self.counts = collections.defaultdict(int) - def __init__(self): - self.counts = collections.defaultdict(int) - - def reset(self): - pass + def reset(self): + pass - def add_first(self, timestep: dm_env.TimeStep): - self.counts[int(timestep.observation[0])] += 1 + def add_first(self, timestep: dm_env.TimeStep): + self.counts[int(timestep.observation[0])] += 1 - def add(self, - action: types.NestedArray, - next_timestep: dm_env.TimeStep, - extras: types.NestedArray = ()): - del action - del extras - self.counts[int(next_timestep.observation[0])] += 1 + def add( + self, + action: types.NestedArray, + next_timestep: dm_env.TimeStep, + extras: types.NestedArray = (), + ): + del action + del extras + self.counts[int(next_timestep.observation[0])] += 1 class LfdAdderTest(absltest.TestCase): - - def setUp(self): - super().setUp() - self._demonstration_episode_type = 1 - self._demonstration_episode_length = 10 - self._collected_episode_type = 2 - self._collected_episode_length = 5 - - def generate_episode(self, episode_type, episode_index, length): - episode = [] - action_dim = 8 - obs_dim = 16 - for k in range(length): - if k == 0: - action = None - else: - action = np.concatenate([ - np.asarray([episode_type, episode_index], dtype=np.float32), - np.random.uniform(0., 1., (action_dim - 2,))]) - observation = np.concatenate([ - np.asarray([episode_type, episode_index], dtype=np.float32), - np.random.uniform(0., 1., (obs_dim - 2,))]) - if k == 0: - timestep = dm_env.restart(observation) - elif k == length - 1: - timestep = dm_env.termination(0., observation) - else: - timestep = dm_env.transition(0., observation, 1.) 
- episode.append((action, timestep)) - return episode - - def generate_demonstration(self): - episode_index = 0 - while True: - episode = self.generate_episode(self._demonstration_episode_type, - episode_index, - self._demonstration_episode_length) - for x in episode: - yield x - episode_index += 1 - - def test_adder(self): - stats_adder = TestStatisticsAdder() - demonstration_ratio = 0.2 - initial_insert_count = 50 - adder = lfd_adder.LfdAdder( - stats_adder, - self.generate_demonstration(), - initial_insert_count=initial_insert_count, - demonstration_ratio=demonstration_ratio) - - num_episodes = 100 - for episode_index in range(num_episodes): - episode = self.generate_episode(self._collected_episode_type, - episode_index, - self._collected_episode_length) - for k, (action, timestep) in enumerate(episode): - if k == 0: - adder.add_first(timestep) - if episode_index == 0: - self.assertGreaterEqual( - stats_adder.counts[self._demonstration_episode_type], - initial_insert_count - self._demonstration_episode_length) - self.assertLessEqual( - stats_adder.counts[self._demonstration_episode_type], - initial_insert_count + self._demonstration_episode_length) - else: - adder.add(action, timestep) - - # Only 2 types of episodes. - self.assertLen(stats_adder.counts, 2) - - total_count = (stats_adder.counts[self._demonstration_episode_type] + - stats_adder.counts[self._collected_episode_type]) - # The demonstration ratio does not account for the initial demonstration - # insertion. Computes a ratio that takes it into account. - target_ratio = ( - demonstration_ratio * (float)(total_count - initial_insert_count) - + initial_insert_count) / (float)(total_count) - # Effective ratio of demonstrations. - effective_ratio = ( - float(stats_adder.counts[self._demonstration_episode_type]) / - float(total_count)) - # Only full episodes can be fed to the adder so the effective ratio - # might be slightly different from the requested demonstration ratio. 
- min_ratio = (target_ratio - - self._demonstration_episode_length / float(total_count)) - max_ratio = (target_ratio + - self._demonstration_episode_length / float(total_count)) - self.assertGreaterEqual(effective_ratio, min_ratio) - self.assertLessEqual(effective_ratio, max_ratio) - - -if __name__ == '__main__': - absltest.main() + def setUp(self): + super().setUp() + self._demonstration_episode_type = 1 + self._demonstration_episode_length = 10 + self._collected_episode_type = 2 + self._collected_episode_length = 5 + + def generate_episode(self, episode_type, episode_index, length): + episode = [] + action_dim = 8 + obs_dim = 16 + for k in range(length): + if k == 0: + action = None + else: + action = np.concatenate( + [ + np.asarray([episode_type, episode_index], dtype=np.float32), + np.random.uniform(0.0, 1.0, (action_dim - 2,)), + ] + ) + observation = np.concatenate( + [ + np.asarray([episode_type, episode_index], dtype=np.float32), + np.random.uniform(0.0, 1.0, (obs_dim - 2,)), + ] + ) + if k == 0: + timestep = dm_env.restart(observation) + elif k == length - 1: + timestep = dm_env.termination(0.0, observation) + else: + timestep = dm_env.transition(0.0, observation, 1.0) + episode.append((action, timestep)) + return episode + + def generate_demonstration(self): + episode_index = 0 + while True: + episode = self.generate_episode( + self._demonstration_episode_type, + episode_index, + self._demonstration_episode_length, + ) + for x in episode: + yield x + episode_index += 1 + + def test_adder(self): + stats_adder = TestStatisticsAdder() + demonstration_ratio = 0.2 + initial_insert_count = 50 + adder = lfd_adder.LfdAdder( + stats_adder, + self.generate_demonstration(), + initial_insert_count=initial_insert_count, + demonstration_ratio=demonstration_ratio, + ) + + num_episodes = 100 + for episode_index in range(num_episodes): + episode = self.generate_episode( + self._collected_episode_type, + episode_index, + self._collected_episode_length, + ) + for k, (action, timestep) in enumerate(episode): + if k == 0: + adder.add_first(timestep) + if episode_index == 0: + self.assertGreaterEqual( + stats_adder.counts[self._demonstration_episode_type], + initial_insert_count - self._demonstration_episode_length, + ) + self.assertLessEqual( + stats_adder.counts[self._demonstration_episode_type], + initial_insert_count + self._demonstration_episode_length, + ) + else: + adder.add(action, timestep) + + # Only 2 types of episodes. + self.assertLen(stats_adder.counts, 2) + + total_count = ( + stats_adder.counts[self._demonstration_episode_type] + + stats_adder.counts[self._collected_episode_type] + ) + # The demonstration ratio does not account for the initial demonstration + # insertion. Computes a ratio that takes it into account. + target_ratio = ( + demonstration_ratio * (float)(total_count - initial_insert_count) + + initial_insert_count + ) / (float)(total_count) + # Effective ratio of demonstrations. + effective_ratio = float( + stats_adder.counts[self._demonstration_episode_type] + ) / float(total_count) + # Only full episodes can be fed to the adder so the effective ratio + # might be slightly different from the requested demonstration ratio. 
+ min_ratio = target_ratio - self._demonstration_episode_length / float( + total_count + ) + max_ratio = target_ratio + self._demonstration_episode_length / float( + total_count + ) + self.assertGreaterEqual(effective_ratio, min_ratio) + self.assertLessEqual(effective_ratio, max_ratio) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/jax/lfd/sacfd.py b/acme/agents/jax/lfd/sacfd.py index bd4c2d4928..56acee2749 100644 --- a/acme/agents/jax/lfd/sacfd.py +++ b/acme/agents/jax/lfd/sacfd.py @@ -17,31 +17,37 @@ import dataclasses from typing import Callable, Iterator +import reverb + from acme.agents.jax import actor_core as actor_core_lib from acme.agents.jax import sac -from acme.agents.jax.lfd import builder -from acme.agents.jax.lfd import config -import reverb +from acme.agents.jax.lfd import builder, config @dataclasses.dataclass class SACfDConfig: - """Configuration options specific to SAC with demonstrations. + """Configuration options specific to SAC with demonstrations. Attributes: lfd_config: LfD config. sac_config: SAC config. """ - lfd_config: config.LfdConfig - sac_config: sac.SACConfig + + lfd_config: config.LfdConfig + sac_config: sac.SACConfig -class SACfDBuilder(builder.LfdBuilder[sac.SACNetworks, - actor_core_lib.FeedForwardPolicy, - reverb.ReplaySample]): - """Builder for SAC agent learning from demonstrations.""" +class SACfDBuilder( + builder.LfdBuilder[ + sac.SACNetworks, actor_core_lib.FeedForwardPolicy, reverb.ReplaySample + ] +): + """Builder for SAC agent learning from demonstrations.""" - def __init__(self, sac_fd_config: SACfDConfig, - lfd_iterator_fn: Callable[[], Iterator[builder.LfdStep]]): - sac_builder = sac.SACBuilder(sac_fd_config.sac_config) - super().__init__(sac_builder, lfd_iterator_fn, sac_fd_config.lfd_config) + def __init__( + self, + sac_fd_config: SACfDConfig, + lfd_iterator_fn: Callable[[], Iterator[builder.LfdStep]], + ): + sac_builder = sac.SACBuilder(sac_fd_config.sac_config) + super().__init__(sac_builder, lfd_iterator_fn, sac_fd_config.lfd_config) diff --git a/acme/agents/jax/lfd/td3fd.py b/acme/agents/jax/lfd/td3fd.py index 531bfe35dc..1bc259ebc3 100644 --- a/acme/agents/jax/lfd/td3fd.py +++ b/acme/agents/jax/lfd/td3fd.py @@ -17,31 +17,37 @@ import dataclasses from typing import Callable, Iterator +import reverb + from acme.agents.jax import actor_core as actor_core_lib from acme.agents.jax import td3 -from acme.agents.jax.lfd import builder -from acme.agents.jax.lfd import config -import reverb +from acme.agents.jax.lfd import builder, config @dataclasses.dataclass class TD3fDConfig: - """Configuration options specific to TD3 with demonstrations. + """Configuration options specific to TD3 with demonstrations. Attributes: lfd_config: LfD config. td3_config: TD3 config. 
""" - lfd_config: config.LfdConfig - td3_config: td3.TD3Config + + lfd_config: config.LfdConfig + td3_config: td3.TD3Config -class TD3fDBuilder(builder.LfdBuilder[td3.TD3Networks, - actor_core_lib.FeedForwardPolicy, - reverb.ReplaySample]): - """Builder for TD3 agent learning from demonstrations.""" +class TD3fDBuilder( + builder.LfdBuilder[ + td3.TD3Networks, actor_core_lib.FeedForwardPolicy, reverb.ReplaySample + ] +): + """Builder for TD3 agent learning from demonstrations.""" - def __init__(self, td3_fd_config: TD3fDConfig, - lfd_iterator_fn: Callable[[], Iterator[builder.LfdStep]]): - td3_builder = td3.TD3Builder(td3_fd_config.td3_config) - super().__init__(td3_builder, lfd_iterator_fn, td3_fd_config.lfd_config) + def __init__( + self, + td3_fd_config: TD3fDConfig, + lfd_iterator_fn: Callable[[], Iterator[builder.LfdStep]], + ): + td3_builder = td3.TD3Builder(td3_fd_config.td3_config) + super().__init__(td3_builder, lfd_iterator_fn, td3_fd_config.lfd_config) diff --git a/acme/agents/jax/mbop/__init__.py b/acme/agents/jax/mbop/__init__.py index 769f96f51d..b1ff9cb6b1 100644 --- a/acme/agents/jax/mbop/__init__.py +++ b/acme/agents/jax/mbop/__init__.py @@ -14,38 +14,52 @@ """Implementation of the Model-Based Offline Planning (MBOP) agent.""" -from acme.agents.jax.mbop.acting import ActorCore -from acme.agents.jax.mbop.acting import make_actor -from acme.agents.jax.mbop.acting import make_actor_core -from acme.agents.jax.mbop.acting import make_ensemble_actor_core +from acme.agents.jax.mbop.acting import ( + ActorCore, + make_actor, + make_actor_core, + make_ensemble_actor_core, +) from acme.agents.jax.mbop.builder import MBOPBuilder from acme.agents.jax.mbop.config import MBOPConfig -from acme.agents.jax.mbop.dataset import EPISODE_RETURN -from acme.agents.jax.mbop.dataset import episodes_to_timestep_batched_transitions -from acme.agents.jax.mbop.dataset import get_normalization_stats -from acme.agents.jax.mbop.dataset import N_STEP_RETURN -from acme.agents.jax.mbop.learning import LoggerFn -from acme.agents.jax.mbop.learning import make_ensemble_regressor_learner -from acme.agents.jax.mbop.learning import MakeNStepReturnLearner -from acme.agents.jax.mbop.learning import MakePolicyPriorLearner -from acme.agents.jax.mbop.learning import MakeWorldModelLearner -from acme.agents.jax.mbop.learning import MBOPLearner -from acme.agents.jax.mbop.learning import TrainingState -from acme.agents.jax.mbop.losses import MBOPLosses -from acme.agents.jax.mbop.losses import policy_prior_loss -from acme.agents.jax.mbop.losses import TransitionLoss -from acme.agents.jax.mbop.losses import world_model_loss -from acme.agents.jax.mbop.models import make_ensemble_n_step_return -from acme.agents.jax.mbop.models import make_ensemble_policy_prior -from acme.agents.jax.mbop.models import make_ensemble_world_model -from acme.agents.jax.mbop.models import MakeNStepReturn -from acme.agents.jax.mbop.models import MakePolicyPrior -from acme.agents.jax.mbop.models import MakeWorldModel -from acme.agents.jax.mbop.mppi import mppi_planner -from acme.agents.jax.mbop.mppi import MPPIConfig -from acme.agents.jax.mbop.mppi import return_top_k_average -from acme.agents.jax.mbop.mppi import return_weighted_average -from acme.agents.jax.mbop.networks import make_networks -from acme.agents.jax.mbop.networks import make_policy_prior_network -from acme.agents.jax.mbop.networks import make_world_model_network -from acme.agents.jax.mbop.networks import MBOPNetworks +from acme.agents.jax.mbop.dataset import ( + EPISODE_RETURN, + 
N_STEP_RETURN, + episodes_to_timestep_batched_transitions, + get_normalization_stats, +) +from acme.agents.jax.mbop.learning import ( + LoggerFn, + MakeNStepReturnLearner, + MakePolicyPriorLearner, + MakeWorldModelLearner, + MBOPLearner, + TrainingState, + make_ensemble_regressor_learner, +) +from acme.agents.jax.mbop.losses import ( + MBOPLosses, + TransitionLoss, + policy_prior_loss, + world_model_loss, +) +from acme.agents.jax.mbop.models import ( + MakeNStepReturn, + MakePolicyPrior, + MakeWorldModel, + make_ensemble_n_step_return, + make_ensemble_policy_prior, + make_ensemble_world_model, +) +from acme.agents.jax.mbop.mppi import ( + MPPIConfig, + mppi_planner, + return_top_k_average, + return_weighted_average, +) +from acme.agents.jax.mbop.networks import ( + MBOPNetworks, + make_networks, + make_policy_prior_network, + make_world_model_network, +) diff --git a/acme/agents/jax/mbop/acting.py b/acme/agents/jax/mbop/acting.py index 8d06b66fa2..3401e6de80 100644 --- a/acme/agents/jax/mbop/acting.py +++ b/acme/agents/jax/mbop/acting.py @@ -16,26 +16,23 @@ from typing import List, Mapping, Optional, Tuple -from acme import adders -from acme import core -from acme import specs +import jax +from jax import numpy as jnp + +from acme import adders, core, specs from acme.agents.jax import actor_core as actor_core_lib from acme.agents.jax import actors -from acme.agents.jax.mbop import models -from acme.agents.jax.mbop import mppi +from acme.agents.jax.mbop import models, mppi from acme.agents.jax.mbop import networks as mbop_networks from acme.jax import networks as networks_lib -from acme.jax import running_statistics -from acme.jax import variable_utils -import jax -from jax import numpy as jnp +from acme.jax import running_statistics, variable_utils # Recurrent state is the trajectory. Trajectory = jnp.ndarray ActorCore = actor_core_lib.ActorCore[ - actor_core_lib.SimpleActorCoreRecurrentState[Trajectory], - Mapping[str, jnp.ndarray]] + actor_core_lib.SimpleActorCoreRecurrentState[Trajectory], Mapping[str, jnp.ndarray] +] def make_actor_core( @@ -46,7 +43,7 @@ def make_actor_core( environment_spec: specs.EnvironmentSpec, mean_std: Optional[running_statistics.NestedMeanStd] = None, ) -> ActorCore: - """Creates an actor core wrapping the MBOP-configured MPPI planner. + """Creates an actor core wrapping the MBOP-configured MPPI planner. Args: mppi_config: Planner hyperparameters. @@ -60,79 +57,92 @@ def make_actor_core( A recurrent actor core. 
""" - if mean_std is not None: - mean_std_observation = running_statistics.NestedMeanStd( - mean=mean_std.mean.observation, std=mean_std.std.observation) - mean_std_action = running_statistics.NestedMeanStd( - mean=mean_std.mean.action, std=mean_std.std.action) - mean_std_reward = running_statistics.NestedMeanStd( - mean=mean_std.mean.reward, std=mean_std.std.reward) - mean_std_n_step_return = running_statistics.NestedMeanStd( - mean=mean_std.mean.extras['n_step_return'], - std=mean_std.std.extras['n_step_return']) - - def denormalized_world_model( - params: networks_lib.Params, observation_t: networks_lib.Observation, - action_t: networks_lib.Action - ) -> Tuple[networks_lib.Observation, networks_lib.Value]: - """Denormalizes the reward for proper weighting in the planner.""" - observation_tp1, normalized_reward_t = world_model( - params, observation_t, action_t) - reward_t = running_statistics.denormalize(normalized_reward_t, - mean_std_reward) - return observation_tp1, reward_t - - planner_world_model = denormalized_world_model - - def denormalized_n_step_return( - params: networks_lib.Params, observation_t: networks_lib.Observation, - action_t: networks_lib.Action) -> networks_lib.Value: - """Denormalize the n-step return for proper weighting in the planner.""" - normalized_n_step_return_t = n_step_return(params, observation_t, - action_t) - return running_statistics.denormalize(normalized_n_step_return_t, - mean_std_n_step_return) - - planner_n_step_return = denormalized_n_step_return - else: - planner_world_model = world_model - planner_n_step_return = n_step_return - - def recurrent_policy( - params_list: List[networks_lib.Params], - random_key: networks_lib.PRNGKey, - observation: networks_lib.Observation, - previous_trajectory: Trajectory, - ) -> Tuple[networks_lib.Action, Trajectory]: - # Note that splitting the random key is handled by GenericActor. - if mean_std is not None: - observation = running_statistics.normalize( - observation, mean_std=mean_std_observation) - trajectory = mppi.mppi_planner( - config=mppi_config, - world_model=planner_world_model, - policy_prior=policy_prior, - n_step_return=planner_n_step_return, - world_model_params=params_list[0], - policy_prior_params=params_list[1], - n_step_return_params=params_list[2], - random_key=random_key, - observation=observation, - previous_trajectory=previous_trajectory) - action = trajectory[0, ...] 
if mean_std is not None: - action = running_statistics.denormalize(action, mean_std=mean_std_action) - return (action, trajectory) - - batched_policy = jax.vmap(recurrent_policy, in_axes=(None, None, 0, 0)) - batched_policy = jax.jit(batched_policy) - - initial_trajectory = mppi.get_initial_trajectory( - config=mppi_config, env_spec=environment_spec) - initial_trajectory = jnp.expand_dims(initial_trajectory, axis=0) - - return actor_core_lib.batched_recurrent_to_actor_core(batched_policy, - initial_trajectory) + mean_std_observation = running_statistics.NestedMeanStd( + mean=mean_std.mean.observation, std=mean_std.std.observation + ) + mean_std_action = running_statistics.NestedMeanStd( + mean=mean_std.mean.action, std=mean_std.std.action + ) + mean_std_reward = running_statistics.NestedMeanStd( + mean=mean_std.mean.reward, std=mean_std.std.reward + ) + mean_std_n_step_return = running_statistics.NestedMeanStd( + mean=mean_std.mean.extras["n_step_return"], + std=mean_std.std.extras["n_step_return"], + ) + + def denormalized_world_model( + params: networks_lib.Params, + observation_t: networks_lib.Observation, + action_t: networks_lib.Action, + ) -> Tuple[networks_lib.Observation, networks_lib.Value]: + """Denormalizes the reward for proper weighting in the planner.""" + observation_tp1, normalized_reward_t = world_model( + params, observation_t, action_t + ) + reward_t = running_statistics.denormalize( + normalized_reward_t, mean_std_reward + ) + return observation_tp1, reward_t + + planner_world_model = denormalized_world_model + + def denormalized_n_step_return( + params: networks_lib.Params, + observation_t: networks_lib.Observation, + action_t: networks_lib.Action, + ) -> networks_lib.Value: + """Denormalize the n-step return for proper weighting in the planner.""" + normalized_n_step_return_t = n_step_return(params, observation_t, action_t) + return running_statistics.denormalize( + normalized_n_step_return_t, mean_std_n_step_return + ) + + planner_n_step_return = denormalized_n_step_return + else: + planner_world_model = world_model + planner_n_step_return = n_step_return + + def recurrent_policy( + params_list: List[networks_lib.Params], + random_key: networks_lib.PRNGKey, + observation: networks_lib.Observation, + previous_trajectory: Trajectory, + ) -> Tuple[networks_lib.Action, Trajectory]: + # Note that splitting the random key is handled by GenericActor. + if mean_std is not None: + observation = running_statistics.normalize( + observation, mean_std=mean_std_observation + ) + trajectory = mppi.mppi_planner( + config=mppi_config, + world_model=planner_world_model, + policy_prior=policy_prior, + n_step_return=planner_n_step_return, + world_model_params=params_list[0], + policy_prior_params=params_list[1], + n_step_return_params=params_list[2], + random_key=random_key, + observation=observation, + previous_trajectory=previous_trajectory, + ) + action = trajectory[0, ...] 
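# A small sketch (separate from the planner code above) of the batching
# pattern applied to `recurrent_policy`: jax.vmap with in_axes=(None, None, 0, 0)
# broadcasts the parameters and RNG key while mapping over the leading batch
# dimension of the observation and recurrent state. The policy body here is a
# placeholder, not the MPPI planner.
import jax
import jax.numpy as jnp

def toy_recurrent_policy(params, key, observation, prev_state):
    del key  # A real policy would consume the RNG key.
    action = jnp.tanh(observation @ params) + prev_state
    return action, action  # (action, next recurrent state)

params = jnp.ones((3, 2))
key = jax.random.PRNGKey(0)
observations = jnp.ones((4, 3))   # Batch of 4 observations.
states = jnp.zeros((4, 2))        # Matching batch of recurrent states.

batched_policy = jax.jit(jax.vmap(toy_recurrent_policy, in_axes=(None, None, 0, 0)))
actions, next_states = batched_policy(params, key, observations, states)
assert actions.shape == (4, 2)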
+ if mean_std is not None: + action = running_statistics.denormalize(action, mean_std=mean_std_action) + return (action, trajectory) + + batched_policy = jax.vmap(recurrent_policy, in_axes=(None, None, 0, 0)) + batched_policy = jax.jit(batched_policy) + + initial_trajectory = mppi.get_initial_trajectory( + config=mppi_config, env_spec=environment_spec + ) + initial_trajectory = jnp.expand_dims(initial_trajectory, axis=0) + + return actor_core_lib.batched_recurrent_to_actor_core( + batched_policy, initial_trajectory + ) def make_ensemble_actor_core( @@ -142,7 +152,7 @@ def make_ensemble_actor_core( mean_std: Optional[running_statistics.NestedMeanStd] = None, use_round_robin: bool = True, ) -> ActorCore: - """Creates an actor core that uses ensemble models. + """Creates an actor core that uses ensemble models. Args: networks: MBOP networks. @@ -155,23 +165,29 @@ def make_ensemble_actor_core( Returns: A recurrent actor core. """ - world_model = models.make_ensemble_world_model(networks.world_model_network) - policy_prior = models.make_ensemble_policy_prior( - networks.policy_prior_network, - environment_spec, - use_round_robin=use_round_robin) - n_step_return = models.make_ensemble_n_step_return( - networks.n_step_return_network) - - return make_actor_core(mppi_config, world_model, policy_prior, n_step_return, - environment_spec, mean_std) - - -def make_actor(actor_core: ActorCore, - random_key: networks_lib.PRNGKey, - variable_source: core.VariableSource, - adder: Optional[adders.Adder] = None) -> core.Actor: - """Creates an MBOP actor from an actor core. + world_model = models.make_ensemble_world_model(networks.world_model_network) + policy_prior = models.make_ensemble_policy_prior( + networks.policy_prior_network, environment_spec, use_round_robin=use_round_robin + ) + n_step_return = models.make_ensemble_n_step_return(networks.n_step_return_network) + + return make_actor_core( + mppi_config, + world_model, + policy_prior, + n_step_return, + environment_spec, + mean_std, + ) + + +def make_actor( + actor_core: ActorCore, + random_key: networks_lib.PRNGKey, + variable_source: core.VariableSource, + adder: Optional[adders.Adder] = None, +) -> core.Actor: + """Creates an MBOP actor from an actor core. Args: actor_core: An MBOP actor core. @@ -185,9 +201,11 @@ def make_actor(actor_core: ActorCore, Returns: A recurrent actor. 
""" - variable_client = variable_utils.VariableClient( - client=variable_source, - key=['world_model-policy', 'policy_prior-policy', 'n_step_return-policy']) - - return actors.GenericActor( - actor_core, random_key, variable_client, adder, backend=None) + variable_client = variable_utils.VariableClient( + client=variable_source, + key=["world_model-policy", "policy_prior-policy", "n_step_return-policy"], + ) + + return actors.GenericActor( + actor_core, random_key, variable_client, adder, backend=None + ) diff --git a/acme/agents/jax/mbop/agent_test.py b/acme/agents/jax/mbop/agent_test.py index db0fcadc3e..e723c50a59 100644 --- a/acme/agents/jax/mbop/agent_test.py +++ b/acme/agents/jax/mbop/agent_test.py @@ -16,77 +16,89 @@ import functools -from acme import specs -from acme import types -from acme.agents.jax.mbop import learning -from acme.agents.jax.mbop import losses as mbop_losses -from acme.agents.jax.mbop import networks as mbop_networks -from acme.testing import fakes -from acme.utils import loggers import chex import jax import optax import rlds - from absl.testing import absltest +from acme import specs, types +from acme.agents.jax.mbop import learning +from acme.agents.jax.mbop import losses as mbop_losses +from acme.agents.jax.mbop import networks as mbop_networks +from acme.testing import fakes +from acme.utils import loggers -class MBOPTest(absltest.TestCase): - def test_learner(self): - with chex.fake_pmap_and_jit(): - num_sgd_steps_per_step = 1 - num_steps = 5 - num_networks = 7 - - # Create a fake environment to test with. - environment = fakes.ContinuousEnvironment( - episode_length=10, bounded=True, observation_dim=3, action_dim=2) - - spec = specs.make_environment_spec(environment) - dataset = fakes.transition_dataset(environment) - - # Add dummy n-step return to the transitions. - def _add_dummy_n_step_return(sample): - return types.Transition(*sample.data)._replace( - extras={'n_step_return': 1.0}) - - dataset = dataset.map(_add_dummy_n_step_return) - # Convert into time-batched format with previous, current and next - # transitions. - dataset = rlds.transformations.batch(dataset, 3) - dataset = dataset.batch(8).as_numpy_iterator() - - # Use the default networks and losses. - networks = mbop_networks.make_networks(spec) - losses = mbop_losses.MBOPLosses() - - def logger_fn(label: str, steps_key: str): - return loggers.make_default_logger(label, steps_key=steps_key) - - def make_learner_fn(name, logger_fn, counter, rng_key, dataset, network, - loss): - return learning.make_ensemble_regressor_learner(name, num_networks, - logger_fn, counter, - rng_key, dataset, - network, loss, - optax.adam(0.01), - num_sgd_steps_per_step) - - learner = learning.MBOPLearner( - networks, losses, dataset, jax.random.PRNGKey(0), logger_fn, - functools.partial(make_learner_fn, 'world_model'), - functools.partial(make_learner_fn, 'policy_prior'), - functools.partial(make_learner_fn, 'n_step_return')) - - # Train the agent - for _ in range(num_steps): - learner.step() - - # Save and restore. - learner_state = learner.save() - learner.restore(learner_state) - - -if __name__ == '__main__': - absltest.main() +class MBOPTest(absltest.TestCase): + def test_learner(self): + with chex.fake_pmap_and_jit(): + num_sgd_steps_per_step = 1 + num_steps = 5 + num_networks = 7 + + # Create a fake environment to test with. 
+ environment = fakes.ContinuousEnvironment( + episode_length=10, bounded=True, observation_dim=3, action_dim=2 + ) + + spec = specs.make_environment_spec(environment) + dataset = fakes.transition_dataset(environment) + + # Add dummy n-step return to the transitions. + def _add_dummy_n_step_return(sample): + return types.Transition(*sample.data)._replace( + extras={"n_step_return": 1.0} + ) + + dataset = dataset.map(_add_dummy_n_step_return) + # Convert into time-batched format with previous, current and next + # transitions. + dataset = rlds.transformations.batch(dataset, 3) + dataset = dataset.batch(8).as_numpy_iterator() + + # Use the default networks and losses. + networks = mbop_networks.make_networks(spec) + losses = mbop_losses.MBOPLosses() + + def logger_fn(label: str, steps_key: str): + return loggers.make_default_logger(label, steps_key=steps_key) + + def make_learner_fn( + name, logger_fn, counter, rng_key, dataset, network, loss + ): + return learning.make_ensemble_regressor_learner( + name, + num_networks, + logger_fn, + counter, + rng_key, + dataset, + network, + loss, + optax.adam(0.01), + num_sgd_steps_per_step, + ) + + learner = learning.MBOPLearner( + networks, + losses, + dataset, + jax.random.PRNGKey(0), + logger_fn, + functools.partial(make_learner_fn, "world_model"), + functools.partial(make_learner_fn, "policy_prior"), + functools.partial(make_learner_fn, "n_step_return"), + ) + + # Train the agent + for _ in range(num_steps): + learner.step() + + # Save and restore. + learner_state = learner.save() + learner.restore(learner_state) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/jax/mbop/builder.py b/acme/agents/jax/mbop/builder.py index 21622d7793..952c9121fa 100644 --- a/acme/agents/jax/mbop/builder.py +++ b/acme/agents/jax/mbop/builder.py @@ -16,9 +16,9 @@ import functools from typing import Iterator, Optional -from acme import core -from acme import specs -from acme import types +import optax + +from acme import core, specs, types from acme.agents.jax import builders from acme.agents.jax.mbop import acting from acme.agents.jax.mbop import config as mbop_config @@ -27,111 +27,116 @@ from acme.agents.jax.mbop import networks as mbop_networks from acme.jax import networks as networks_lib from acme.jax import running_statistics -from acme.utils import counting -from acme.utils import loggers -import optax +from acme.utils import counting, loggers -class MBOPBuilder(builders.OfflineBuilder[mbop_networks.MBOPNetworks, - acting.ActorCore, types.Transition]): - """MBOP Builder. +class MBOPBuilder( + builders.OfflineBuilder[ + mbop_networks.MBOPNetworks, acting.ActorCore, types.Transition + ] +): + """MBOP Builder. This builder uses ensemble regressor learners for the world model, policy prior and the n-step return models with fixed learning rates. The ensembles and the learning rate are configured in the config. """ - def __init__( - self, - config: mbop_config.MBOPConfig, - losses: mbop_losses.MBOPLosses, - mean_std: Optional[running_statistics.NestedMeanStd] = None, - ): - """Initializes an MBOP builder. + def __init__( + self, + config: mbop_config.MBOPConfig, + losses: mbop_losses.MBOPLosses, + mean_std: Optional[running_statistics.NestedMeanStd] = None, + ): + """Initializes an MBOP builder. Args: config: a config with MBOP hyperparameters. losses: MBOP losses. mean_std: NestedMeanStd used to normalize the samples. 
""" - self._config = config - self._losses = losses - self._mean_std = mean_std + self._config = config + self._losses = losses + self._mean_std = mean_std - def make_learner( - self, - random_key: networks_lib.PRNGKey, - networks: mbop_networks.MBOPNetworks, - dataset: Iterator[types.Transition], - logger_fn: loggers.LoggerFactory, - environment_spec: specs.EnvironmentSpec, - counter: Optional[counting.Counter] = None, - ) -> core.Learner: - """See base class.""" - - def make_ensemble_regressor_learner( - name: str, + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: mbop_networks.MBOPNetworks, + dataset: Iterator[types.Transition], logger_fn: loggers.LoggerFactory, - counter: counting.Counter, - rng_key: networks_lib.PRNGKey, - iterator: Iterator[types.Transition], - network: networks_lib.FeedForwardNetwork, - loss: mbop_losses.TransitionLoss, + environment_spec: specs.EnvironmentSpec, + counter: Optional[counting.Counter] = None, ) -> core.Learner: - """Creates an ensemble regressor learner.""" - return learning.make_ensemble_regressor_learner( - name, - self._config.num_networks, - logger_fn, - counter, - rng_key, - iterator, - network, - loss, - optax.adam(self._config.learning_rate), - self._config.num_sgd_steps_per_step, - ) + """See base class.""" + + def make_ensemble_regressor_learner( + name: str, + logger_fn: loggers.LoggerFactory, + counter: counting.Counter, + rng_key: networks_lib.PRNGKey, + iterator: Iterator[types.Transition], + network: networks_lib.FeedForwardNetwork, + loss: mbop_losses.TransitionLoss, + ) -> core.Learner: + """Creates an ensemble regressor learner.""" + return learning.make_ensemble_regressor_learner( + name, + self._config.num_networks, + logger_fn, + counter, + rng_key, + iterator, + network, + loss, + optax.adam(self._config.learning_rate), + self._config.num_sgd_steps_per_step, + ) - make_world_model_learner = functools.partial( - make_ensemble_regressor_learner, 'world_model') - make_policy_prior_learner = functools.partial( - make_ensemble_regressor_learner, 'policy_prior') - make_n_step_return_learner = functools.partial( - make_ensemble_regressor_learner, 'n_step_return') - counter = counter or counting.Counter(time_delta=0.) 
- return learning.MBOPLearner( - networks, - self._losses, - dataset, - random_key, - logger_fn, - make_world_model_learner, - make_policy_prior_learner, - make_n_step_return_learner, - counter, - ) + make_world_model_learner = functools.partial( + make_ensemble_regressor_learner, "world_model" + ) + make_policy_prior_learner = functools.partial( + make_ensemble_regressor_learner, "policy_prior" + ) + make_n_step_return_learner = functools.partial( + make_ensemble_regressor_learner, "n_step_return" + ) + counter = counter or counting.Counter(time_delta=0.0) + return learning.MBOPLearner( + networks, + self._losses, + dataset, + random_key, + logger_fn, + make_world_model_learner, + make_policy_prior_learner, + make_n_step_return_learner, + counter, + ) - def make_actor( - self, - random_key: networks_lib.PRNGKey, - policy: acting.ActorCore, - environment_spec: specs.EnvironmentSpec, - variable_source: Optional[core.VariableSource] = None, - ) -> core.Actor: - """See base class.""" - del environment_spec - return acting.make_actor(policy, random_key, variable_source) + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: acting.ActorCore, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + ) -> core.Actor: + """See base class.""" + del environment_spec + return acting.make_actor(policy, random_key, variable_source) - def make_policy( - self, - networks: mbop_networks.MBOPNetworks, - environment_spec: specs.EnvironmentSpec, - evaluation: bool, - ) -> acting.ActorCore: - """See base class.""" - return acting.make_ensemble_actor_core( - networks, - self._config.mppi_config, - environment_spec, - self._mean_std, - use_round_robin=not evaluation) + def make_policy( + self, + networks: mbop_networks.MBOPNetworks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool, + ) -> acting.ActorCore: + """See base class.""" + return acting.make_ensemble_actor_core( + networks, + self._config.mppi_config, + environment_spec, + self._mean_std, + use_round_robin=not evaluation, + ) diff --git a/acme/agents/jax/mbop/config.py b/acme/agents/jax/mbop/config.py index 8028430114..ebbc23538d 100644 --- a/acme/agents/jax/mbop/config.py +++ b/acme/agents/jax/mbop/config.py @@ -21,7 +21,7 @@ @dataclasses.dataclass(frozen=True) class MBOPConfig: - """Configuration options for the MBOP agent. + """Configuration options for the MBOP agent. Attributes: mppi_config: Planner hyperparameters. @@ -30,7 +30,8 @@ class MBOPConfig: num_sgd_steps_per_step: How many gradient updates to perform per learner step. """ - mppi_config: mppi.MPPIConfig = mppi.MPPIConfig() - learning_rate: float = 3e-4 - num_networks: int = 5 - num_sgd_steps_per_step: int = 1 + + mppi_config: mppi.MPPIConfig = mppi.MPPIConfig() + learning_rate: float = 3e-4 + num_networks: int = 5 + num_sgd_steps_per_step: int = 1 diff --git a/acme/agents/jax/mbop/dataset.py b/acme/agents/jax/mbop/dataset.py index 22bfdf2065..9f86d806f3 100644 --- a/acme/agents/jax/mbop/dataset.py +++ b/acme/agents/jax/mbop/dataset.py @@ -18,19 +18,20 @@ import itertools from typing import Iterator, Optional -from acme import types -from acme.jax import running_statistics import jax import jax.numpy as jnp import rlds import tensorflow as tf import tree +from acme import types +from acme.jax import running_statistics + # Keys in extras dictionary of the transitions. # Total return over n-steps. 
-N_STEP_RETURN: str = 'n_step_return' +N_STEP_RETURN: str = "n_step_return" # Total return of the episode that the transition belongs to. -EPISODE_RETURN: str = 'episode_return' +EPISODE_RETURN: str = "episode_return" # Indices of the time-batched transitions. PREVIOUS: int = 0 @@ -39,28 +40,29 @@ def _append_n_step_return(output, n_step_return): - """Append n-step return to an output step.""" - output[N_STEP_RETURN] = n_step_return - return output + """Append n-step return to an output step.""" + output[N_STEP_RETURN] = n_step_return + return output def _append_episode_return(output, episode_return): - """Append episode return to an output step.""" - output[EPISODE_RETURN] = episode_return - return output + """Append episode return to an output step.""" + output[EPISODE_RETURN] = episode_return + return output def _expand_scalars(output): - """If rewards are scalar, expand them.""" - return tree.map_structure(tf.experimental.numpy.atleast_1d, output) + """If rewards are scalar, expand them.""" + return tree.map_structure(tf.experimental.numpy.atleast_1d, output) def episode_to_timestep_batch( episode: rlds.BatchedStep, return_horizon: int = 0, drop_return_horizon: bool = False, - calculate_episode_return: bool = False) -> tf.data.Dataset: - """Converts an episode into multi-timestep batches. + calculate_episode_return: bool = False, +) -> tf.data.Dataset: + """Converts an episode into multi-timestep batches. Args: episode: Batched steps as provided directly by RLDS. @@ -99,60 +101,61 @@ def episode_to_timestep_batch( [2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]], dtype=float32)> ``` """ - steps = episode[rlds.STEPS] - - if drop_return_horizon: - episode_length = steps.cardinality() - steps = steps.take(episode_length - return_horizon) - - # Calculate n-step return: - rewards = steps.map(lambda step: step[rlds.REWARD]) - batched_rewards = rlds.transformations.batch( - rewards, size=return_horizon, shift=1, stride=1, drop_remainder=True) - returns = batched_rewards.map(tf.math.reduce_sum) - output = tf.data.Dataset.zip((steps, returns)).map(_append_n_step_return) - - # Calculate total episode return for potential filtering, use total # of steps - # to calculate return. - if calculate_episode_return: - dtype = jnp.float64 if jax.config.jax_enable_x64 else jnp.float32 - # Need to redefine this here to avoid a tf.data crash. + steps = episode[rlds.STEPS] + + if drop_return_horizon: + episode_length = steps.cardinality() + steps = steps.take(episode_length - return_horizon) + + # Calculate n-step return: rewards = steps.map(lambda step: step[rlds.REWARD]) - episode_return = rewards.reduce(dtype(0), lambda x, y: x + y) - output = output.map( - functools.partial( - _append_episode_return, episode_return=episode_return)) + batched_rewards = rlds.transformations.batch( + rewards, size=return_horizon, shift=1, stride=1, drop_remainder=True + ) + returns = batched_rewards.map(tf.math.reduce_sum) + output = tf.data.Dataset.zip((steps, returns)).map(_append_n_step_return) - output = output.map(_expand_scalars) + # Calculate total episode return for potential filtering, use total # of steps + # to calculate return. + if calculate_episode_return: + dtype = jnp.float64 if jax.config.jax_enable_x64 else jnp.float32 + # Need to redefine this here to avoid a tf.data crash. 
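# A tiny, self-contained check of the sliding-window n-step return computed
# in episode_to_timestep_batch, using plain tf.data in place of
# rlds.transformations.batch. The reward values match the sample episode used
# in dataset_test.py below (return_horizon=2).
import tensorflow as tf

rewards = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0, 5.0])
windows = rewards.window(size=2, shift=1, drop_remainder=True)
windows = windows.flat_map(lambda w: w.batch(2))
n_step_returns = windows.map(tf.math.reduce_sum)
print(list(n_step_returns.as_numpy_iterator()))  # [3.0, 5.0, 7.0, 9.0]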
+ rewards = steps.map(lambda step: step[rlds.REWARD]) + episode_return = rewards.reduce(dtype(0), lambda x, y: x + y) + output = output.map( + functools.partial(_append_episode_return, episode_return=episode_return) + ) - output = rlds.transformations.batch( - output, size=3, shift=1, drop_remainder=True) - return output + output = output.map(_expand_scalars) + + output = rlds.transformations.batch(output, size=3, shift=1, drop_remainder=True) + return output def _step_to_transition(rlds_step: rlds.BatchedStep) -> types.Transition: - """Converts batched RLDS steps to batched transitions.""" - return types.Transition( - observation=rlds_step[rlds.OBSERVATION], - action=rlds_step[rlds.ACTION], - reward=rlds_step[rlds.REWARD], - discount=rlds_step[rlds.DISCOUNT], - # We provide next_observation if an algorithm needs it, however note that - # it will only contain s_t and s_t+1, so will be one element short of all - # other attributes (which contain s_t-1, s_t, s_t+1). - next_observation=tree.map_structure(lambda x: x[1:], - rlds_step[rlds.OBSERVATION]), - extras={ - N_STEP_RETURN: rlds_step[N_STEP_RETURN], - }) + """Converts batched RLDS steps to batched transitions.""" + return types.Transition( + observation=rlds_step[rlds.OBSERVATION], + action=rlds_step[rlds.ACTION], + reward=rlds_step[rlds.REWARD], + discount=rlds_step[rlds.DISCOUNT], + # We provide next_observation if an algorithm needs it, however note that + # it will only contain s_t and s_t+1, so will be one element short of all + # other attributes (which contain s_t-1, s_t, s_t+1). + next_observation=tree.map_structure( + lambda x: x[1:], rlds_step[rlds.OBSERVATION] + ), + extras={N_STEP_RETURN: rlds_step[N_STEP_RETURN],}, + ) def episodes_to_timestep_batched_transitions( episode_dataset: tf.data.Dataset, return_horizon: int = 10, drop_return_horizon: bool = False, - min_return_filter: Optional[float] = None) -> tf.data.Dataset: - """Process an existing dataset converting it to episode to 3-transitions. + min_return_filter: Optional[float] = None, +) -> tf.data.Dataset: + """Process an existing dataset converting it to episode to 3-transitions. A 3-transition is an Transition with each attribute having an extra dimension of size 3, representing 3 consecutive timesteps. Each 3-step object will be @@ -168,33 +171,35 @@ def episodes_to_timestep_batched_transitions( Returns: A tf.data.Dataset of 3-transitions. 
""" - dataset = episode_dataset.interleave( - functools.partial( - episode_to_timestep_batch, - return_horizon=return_horizon, - drop_return_horizon=drop_return_horizon, - calculate_episode_return=min_return_filter is not None), - num_parallel_calls=tf.data.experimental.AUTOTUNE, - deterministic=False) + dataset = episode_dataset.interleave( + functools.partial( + episode_to_timestep_batch, + return_horizon=return_horizon, + drop_return_horizon=drop_return_horizon, + calculate_episode_return=min_return_filter is not None, + ), + num_parallel_calls=tf.data.experimental.AUTOTUNE, + deterministic=False, + ) - if min_return_filter is not None: + if min_return_filter is not None: - def filter_on_return(step): - return step[EPISODE_RETURN][0][0] > min_return_filter + def filter_on_return(step): + return step[EPISODE_RETURN][0][0] > min_return_filter - dataset = dataset.filter(filter_on_return) + dataset = dataset.filter(filter_on_return) - dataset = dataset.map( - _step_to_transition, num_parallel_calls=tf.data.experimental.AUTOTUNE) + dataset = dataset.map( + _step_to_transition, num_parallel_calls=tf.data.experimental.AUTOTUNE + ) - return dataset + return dataset def get_normalization_stats( - iterator: Iterator[types.Transition], - num_normalization_batches: int = 50 + iterator: Iterator[types.Transition], num_normalization_batches: int = 50 ) -> running_statistics.RunningStatisticsState: - """Precomputes normalization statistics over a fixed number of batches. + """Precomputes normalization statistics over a fixed number of batches. The iterator should contain batches of 3-transitions, i.e. with two leading dimensions, the first one denoting the batch dimension and the second one the @@ -208,13 +213,13 @@ def get_normalization_stats( Returns: RunningStatisticsState containing the normalization statistics. 
""" - # Set up normalization: - example = next(iterator) - unbatched_single_example = jax.tree_map(lambda x: x[0, PREVIOUS, :], example) - mean_std = running_statistics.init_state(unbatched_single_example) + # Set up normalization: + example = next(iterator) + unbatched_single_example = jax.tree_map(lambda x: x[0, PREVIOUS, :], example) + mean_std = running_statistics.init_state(unbatched_single_example) - for batch in itertools.islice(iterator, num_normalization_batches - 1): - example = jax.tree_map(lambda x: x[:, PREVIOUS, :], batch) - mean_std = running_statistics.update(mean_std, example) + for batch in itertools.islice(iterator, num_normalization_batches - 1): + example = jax.tree_map(lambda x: x[:, PREVIOUS, :], batch) + mean_std = running_statistics.update(mean_std, example) - return mean_std + return mean_std diff --git a/acme/agents/jax/mbop/dataset_test.py b/acme/agents/jax/mbop/dataset_test.py index e0f1a93813..d604aa7aa5 100644 --- a/acme/agents/jax/mbop/dataset_test.py +++ b/acme/agents/jax/mbop/dataset_test.py @@ -14,181 +14,153 @@ """Tests for dataset.""" -from acme.agents.jax.mbop import dataset as dataset_lib import rlds -from rlds.transformations import transformations_testlib import tensorflow as tf - from absl.testing import absltest +from rlds.transformations import transformations_testlib +from acme.agents.jax.mbop import dataset as dataset_lib -def sample_episode() -> rlds.Episode: - """Returns a sample episode.""" - steps = { - rlds.OBSERVATION: [ - [1, 1], - [2, 2], - [3, 3], - [4, 4], - [5, 5], - ], - rlds.ACTION: [[1], [2], [3], [4], [5]], - rlds.REWARD: [1.0, 2.0, 3.0, 4.0, 5.0], - rlds.DISCOUNT: [1, 1, 1, 1, 1], - rlds.IS_FIRST: [True, False, False, False, False], - rlds.IS_LAST: [False, False, False, False, True], - rlds.IS_TERMINAL: [False, False, False, False, True], - } - return {rlds.STEPS: tf.data.Dataset.from_tensor_slices(steps)} - - -class DatasetTest(transformations_testlib.TransformationsTest): - - def test_episode_to_timestep_batch(self): - batched = dataset_lib.episode_to_timestep_batch( - sample_episode(), return_horizon=2) - - # Scalars should be expanded and the n-step return should be present. Each - # element of a step should be a triplet containing the previous, current and - # next values of the corresponding fields. Since the return horizon is 2 and - # the number of steps in the episode is 5, there can be only 2 triplets for - # time steps 1 and 2. 
- expected_steps = { - rlds.OBSERVATION: [ - [[1, 1], [2, 2], [3, 3]], - [[2, 2], [3, 3], [4, 4]], - ], - rlds.ACTION: [ - [[1], [2], [3]], - [[2], [3], [4]], - ], - rlds.REWARD: [ - [[1.0], [2.0], [3.0]], - [[2.0], [3.0], [4.0]], - ], - rlds.DISCOUNT: [ - [[1], [1], [1]], - [[1], [1], [1]], - ], - rlds.IS_FIRST: [ - [[True], [False], [False]], - [[False], [False], [False]], - ], - rlds.IS_LAST: [ - [[False], [False], [False]], - [[False], [False], [False]], - ], - rlds.IS_TERMINAL: [ - [[False], [False], [False]], - [[False], [False], [False]], - ], - dataset_lib.N_STEP_RETURN: [ - [[3.0], [5.0], [7.0]], - [[5.0], [7.0], [9.0]], - ], - } - - self.expect_equal_datasets( - batched, tf.data.Dataset.from_tensor_slices(expected_steps)) - - def test_episode_to_timestep_batch_episode_return(self): - batched = dataset_lib.episode_to_timestep_batch( - sample_episode(), return_horizon=3, calculate_episode_return=True) - - expected_steps = { - rlds.OBSERVATION: [[[1, 1], [2, 2], [3, 3]]], - rlds.ACTION: [[[1], [2], [3]]], - rlds.REWARD: [[[1.0], [2.0], [3.0]]], - rlds.DISCOUNT: [[[1], [1], [1]]], - rlds.IS_FIRST: [[[True], [False], [False]]], - rlds.IS_LAST: [[[False], [False], [False]]], - rlds.IS_TERMINAL: [[[False], [False], [False]]], - dataset_lib.N_STEP_RETURN: [[[6.0], [9.0], [12.0]]], - # This should match to the sum of the rewards in the input. - dataset_lib.EPISODE_RETURN: [[[15.0], [15.0], [15.0]]], - } - - self.expect_equal_datasets( - batched, tf.data.Dataset.from_tensor_slices(expected_steps)) - - def test_episode_to_timestep_batch_no_return_horizon(self): - batched = dataset_lib.episode_to_timestep_batch( - sample_episode(), return_horizon=1) - - expected_steps = { - rlds.OBSERVATION: [ - [[1, 1], [2, 2], [3, 3]], - [[2, 2], [3, 3], [4, 4]], - [[3, 3], [4, 4], [5, 5]], - ], - rlds.ACTION: [ - [[1], [2], [3]], - [[2], [3], [4]], - [[3], [4], [5]], - ], - rlds.REWARD: [ - [[1.0], [2.0], [3.0]], - [[2.0], [3.0], [4.0]], - [[3.0], [4.0], [5.0]], - ], - rlds.DISCOUNT: [ - [[1], [1], [1]], - [[1], [1], [1]], - [[1], [1], [1]], - ], - rlds.IS_FIRST: [ - [[True], [False], [False]], - [[False], [False], [False]], - [[False], [False], [False]], - ], - rlds.IS_LAST: [ - [[False], [False], [False]], - [[False], [False], [False]], - [[False], [False], [True]], - ], - rlds.IS_TERMINAL: [ - [[False], [False], [False]], - [[False], [False], [False]], - [[False], [False], [True]], - ], - # n-step return should be equal to the rewards. - dataset_lib.N_STEP_RETURN: [ - [[1.0], [2.0], [3.0]], - [[2.0], [3.0], [4.0]], - [[3.0], [4.0], [5.0]], - ], - } - - self.expect_equal_datasets( - batched, tf.data.Dataset.from_tensor_slices(expected_steps)) - def test_episode_to_timestep_batch_drop_return_horizon(self): +def sample_episode() -> rlds.Episode: + """Returns a sample episode.""" steps = { - rlds.OBSERVATION: [[1], [2], [3], [4], [5], [6]], - rlds.REWARD: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], - } - episode = {rlds.STEPS: tf.data.Dataset.from_tensor_slices(steps)} - - batched = dataset_lib.episode_to_timestep_batch( - episode, - return_horizon=2, - calculate_episode_return=True, - drop_return_horizon=True) - - # The two steps of the episode should be dropped. There will be 4 steps left - # and since the return horizon is 2, only a single 3-batched step should be - # emitted. The episode return should be the sum of the rewards of the first - # 4 steps. 
- expected_steps = { - rlds.OBSERVATION: [[[1], [2], [3]]], - rlds.REWARD: [[[1.0], [2.0], [3.0]]], - dataset_lib.N_STEP_RETURN: [[[3.0], [5.0], [7.0]]], - dataset_lib.EPISODE_RETURN: [[[10.0], [10.0], [10.0]]], + rlds.OBSERVATION: [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5],], + rlds.ACTION: [[1], [2], [3], [4], [5]], + rlds.REWARD: [1.0, 2.0, 3.0, 4.0, 5.0], + rlds.DISCOUNT: [1, 1, 1, 1, 1], + rlds.IS_FIRST: [True, False, False, False, False], + rlds.IS_LAST: [False, False, False, False, True], + rlds.IS_TERMINAL: [False, False, False, False, True], } + return {rlds.STEPS: tf.data.Dataset.from_tensor_slices(steps)} - self.expect_equal_datasets( - batched, tf.data.Dataset.from_tensor_slices(expected_steps)) - -if __name__ == '__main__': - absltest.main() +class DatasetTest(transformations_testlib.TransformationsTest): + def test_episode_to_timestep_batch(self): + batched = dataset_lib.episode_to_timestep_batch( + sample_episode(), return_horizon=2 + ) + + # Scalars should be expanded and the n-step return should be present. Each + # element of a step should be a triplet containing the previous, current and + # next values of the corresponding fields. Since the return horizon is 2 and + # the number of steps in the episode is 5, there can be only 2 triplets for + # time steps 1 and 2. + expected_steps = { + rlds.OBSERVATION: [[[1, 1], [2, 2], [3, 3]], [[2, 2], [3, 3], [4, 4]],], + rlds.ACTION: [[[1], [2], [3]], [[2], [3], [4]],], + rlds.REWARD: [[[1.0], [2.0], [3.0]], [[2.0], [3.0], [4.0]],], + rlds.DISCOUNT: [[[1], [1], [1]], [[1], [1], [1]],], + rlds.IS_FIRST: [[[True], [False], [False]], [[False], [False], [False]],], + rlds.IS_LAST: [[[False], [False], [False]], [[False], [False], [False]],], + rlds.IS_TERMINAL: [ + [[False], [False], [False]], + [[False], [False], [False]], + ], + dataset_lib.N_STEP_RETURN: [[[3.0], [5.0], [7.0]], [[5.0], [7.0], [9.0]],], + } + + self.expect_equal_datasets( + batched, tf.data.Dataset.from_tensor_slices(expected_steps) + ) + + def test_episode_to_timestep_batch_episode_return(self): + batched = dataset_lib.episode_to_timestep_batch( + sample_episode(), return_horizon=3, calculate_episode_return=True + ) + + expected_steps = { + rlds.OBSERVATION: [[[1, 1], [2, 2], [3, 3]]], + rlds.ACTION: [[[1], [2], [3]]], + rlds.REWARD: [[[1.0], [2.0], [3.0]]], + rlds.DISCOUNT: [[[1], [1], [1]]], + rlds.IS_FIRST: [[[True], [False], [False]]], + rlds.IS_LAST: [[[False], [False], [False]]], + rlds.IS_TERMINAL: [[[False], [False], [False]]], + dataset_lib.N_STEP_RETURN: [[[6.0], [9.0], [12.0]]], + # This should match to the sum of the rewards in the input. 
+ dataset_lib.EPISODE_RETURN: [[[15.0], [15.0], [15.0]]], + } + + self.expect_equal_datasets( + batched, tf.data.Dataset.from_tensor_slices(expected_steps) + ) + + def test_episode_to_timestep_batch_no_return_horizon(self): + batched = dataset_lib.episode_to_timestep_batch( + sample_episode(), return_horizon=1 + ) + + expected_steps = { + rlds.OBSERVATION: [ + [[1, 1], [2, 2], [3, 3]], + [[2, 2], [3, 3], [4, 4]], + [[3, 3], [4, 4], [5, 5]], + ], + rlds.ACTION: [[[1], [2], [3]], [[2], [3], [4]], [[3], [4], [5]],], + rlds.REWARD: [ + [[1.0], [2.0], [3.0]], + [[2.0], [3.0], [4.0]], + [[3.0], [4.0], [5.0]], + ], + rlds.DISCOUNT: [[[1], [1], [1]], [[1], [1], [1]], [[1], [1], [1]],], + rlds.IS_FIRST: [ + [[True], [False], [False]], + [[False], [False], [False]], + [[False], [False], [False]], + ], + rlds.IS_LAST: [ + [[False], [False], [False]], + [[False], [False], [False]], + [[False], [False], [True]], + ], + rlds.IS_TERMINAL: [ + [[False], [False], [False]], + [[False], [False], [False]], + [[False], [False], [True]], + ], + # n-step return should be equal to the rewards. + dataset_lib.N_STEP_RETURN: [ + [[1.0], [2.0], [3.0]], + [[2.0], [3.0], [4.0]], + [[3.0], [4.0], [5.0]], + ], + } + + self.expect_equal_datasets( + batched, tf.data.Dataset.from_tensor_slices(expected_steps) + ) + + def test_episode_to_timestep_batch_drop_return_horizon(self): + steps = { + rlds.OBSERVATION: [[1], [2], [3], [4], [5], [6]], + rlds.REWARD: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + } + episode = {rlds.STEPS: tf.data.Dataset.from_tensor_slices(steps)} + + batched = dataset_lib.episode_to_timestep_batch( + episode, + return_horizon=2, + calculate_episode_return=True, + drop_return_horizon=True, + ) + + # The two steps of the episode should be dropped. There will be 4 steps left + # and since the return horizon is 2, only a single 3-batched step should be + # emitted. The episode return should be the sum of the rewards of the first + # 4 steps. + expected_steps = { + rlds.OBSERVATION: [[[1], [2], [3]]], + rlds.REWARD: [[[1.0], [2.0], [3.0]]], + dataset_lib.N_STEP_RETURN: [[[3.0], [5.0], [7.0]]], + dataset_lib.EPISODE_RETURN: [[[10.0], [10.0], [10.0]]], + } + + self.expect_equal_datasets( + batched, tf.data.Dataset.from_tensor_slices(expected_steps) + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/jax/mbop/ensemble.py b/acme/agents/jax/mbop/ensemble.py index c7ccc412c8..131d695e9d 100644 --- a/acme/agents/jax/mbop/ensemble.py +++ b/acme/agents/jax/mbop/ensemble.py @@ -14,15 +14,16 @@ """Module to provide ensembling support on top of a base network.""" import functools -from typing import (Any, Callable) +from typing import Any, Callable -from acme.jax import networks import jax import jax.numpy as jnp +from acme.jax import networks + def _split_batch_dimension(new_batch: int, data: jnp.ndarray) -> jnp.ndarray: - """Splits the batch dimension and introduces new one with size `new_batch`. + """Splits the batch dimension and introduces new one with size `new_batch`. The result has two batch dimensions, first one of size `new_batch`, second one of size `data.shape[0]/new_batch`. It expects that `data.shape[0]` is @@ -36,22 +37,25 @@ def _split_batch_dimension(new_batch: int, data: jnp.ndarray) -> jnp.ndarray: jnp.ndarray with extra batch dimension at start and updated second dimension. 
""" - # The first dimension will be used for allocating to a specific ensemble - # member, and the second dimension is the parallelized batch dimension, and - # the remaining dimensions are passed as-is to the wrapped network. - # We use Fortan (F) order so that each input batch i is allocated to - # ensemble member k = i % new_batch. - return jnp.reshape(data, (new_batch, -1) + data.shape[1:], order='F') + # The first dimension will be used for allocating to a specific ensemble + # member, and the second dimension is the parallelized batch dimension, and + # the remaining dimensions are passed as-is to the wrapped network. + # We use Fortan (F) order so that each input batch i is allocated to + # ensemble member k = i % new_batch. + return jnp.reshape(data, (new_batch, -1) + data.shape[1:], order="F") def _repeat_n(new_batch: int, data: jnp.ndarray) -> jnp.ndarray: - """Create new batch dimension of size `new_batch` by repeating `data`.""" - return jnp.broadcast_to(data, (new_batch,) + data.shape) + """Create new batch dimension of size `new_batch` by repeating `data`.""" + return jnp.broadcast_to(data, (new_batch,) + data.shape) -def ensemble_init(base_init: Callable[[networks.PRNGKey], networks.Params], - num_networks: int, rnd: jnp.ndarray): - """Initializes the ensemble parameters. +def ensemble_init( + base_init: Callable[[networks.PRNGKey], networks.Params], + num_networks: int, + rnd: jnp.ndarray, +): + """Initializes the ensemble parameters. Args: base_init: An init function that takes only a PRNGKey, if a network's init @@ -63,13 +67,17 @@ def ensemble_init(base_init: Callable[[networks.PRNGKey], networks.Params], Returns: `params` for the set of ensemble networks. """ - rnds = jax.random.split(rnd, num_networks) - return jax.vmap(base_init)(rnds) + rnds = jax.random.split(rnd, num_networks) + return jax.vmap(base_init)(rnds) -def apply_round_robin(base_apply: Callable[[networks.Params, Any], Any], - params: networks.Params, *args, **kwargs) -> Any: - """Passes the input in a round-robin manner. +def apply_round_robin( + base_apply: Callable[[networks.Params, Any], Any], + params: networks.Params, + *args, + **kwargs +) -> Any: + """Passes the input in a round-robin manner. The round-robin application means that each element of the input batch will be passed through a single ensemble member in a deterministic round-robin @@ -96,24 +104,28 @@ def apply_round_robin(base_apply: Callable[[networks.Params, Any], Any], pytree of the round-robin application. Output shape will be [initial_batch_size, ]. """ - # `num_networks` is the size of the batch dimension in `params`. - num_networks = jax.tree_util.tree_leaves(params)[0].shape[0] - - # Reshape args and kwargs for the round-robin: - args = jax.tree_map( - functools.partial(_split_batch_dimension, num_networks), args) - kwargs = jax.tree_map( - functools.partial(_split_batch_dimension, num_networks), kwargs) - # `out.shape` is `(num_networks, initial_batch_size/num_networks, ...) - out = jax.vmap(base_apply)(params, *args, **kwargs) - # Reshape to [initial_batch_size, ]. Using the 'F' order - # forces the original values to the last dimension. - return jax.tree_map(lambda x: x.reshape((-1,) + x.shape[2:], order='F'), out) - - -def apply_all(base_apply: Callable[[networks.Params, Any], Any], - params: networks.Params, *args, **kwargs) -> Any: - """Pass the input to all ensemble members. + # `num_networks` is the size of the batch dimension in `params`. 
+ num_networks = jax.tree_util.tree_leaves(params)[0].shape[0] + + # Reshape args and kwargs for the round-robin: + args = jax.tree_map(functools.partial(_split_batch_dimension, num_networks), args) + kwargs = jax.tree_map( + functools.partial(_split_batch_dimension, num_networks), kwargs + ) + # `out.shape` is `(num_networks, initial_batch_size/num_networks, ...) + out = jax.vmap(base_apply)(params, *args, **kwargs) + # Reshape to [initial_batch_size, ]. Using the 'F' order + # forces the original values to the last dimension. + return jax.tree_map(lambda x: x.reshape((-1,) + x.shape[2:], order="F"), out) + + +def apply_all( + base_apply: Callable[[networks.Params, Any], Any], + params: networks.Params, + *args, + **kwargs +) -> Any: + """Pass the input to all ensemble members. Inputs can either have a batch dimension which will get implicitly vmapped over, or can be a single vector which will get sent to all ensemble members. @@ -130,18 +142,22 @@ def apply_all(base_apply: Callable[[networks.Params, Any], Any], pytree of the resulting output of passing input to all ensemble members. Output shape will be [num_members, batch_size, ]. """ - # `num_networks` is the size of the batch dimension in `params`. - num_networks = jax.tree_util.tree_leaves(params)[0].shape[0] + # `num_networks` is the size of the batch dimension in `params`. + num_networks = jax.tree_util.tree_leaves(params)[0].shape[0] - args = jax.tree_map(functools.partial(_repeat_n, num_networks), args) - kwargs = jax.tree_map(functools.partial(_repeat_n, num_networks), kwargs) - # `out` is of shape `(num_networks, batch_size, )`. - return jax.vmap(base_apply)(params, *args, **kwargs) + args = jax.tree_map(functools.partial(_repeat_n, num_networks), args) + kwargs = jax.tree_map(functools.partial(_repeat_n, num_networks), kwargs) + # `out` is of shape `(num_networks, batch_size, )`. + return jax.vmap(base_apply)(params, *args, **kwargs) -def apply_mean(base_apply: Callable[[networks.Params, Any], Any], - params: networks.Params, *args, **kwargs) -> Any: - """Calculates the mean over all ensemble members for each batch element. +def apply_mean( + base_apply: Callable[[networks.Params, Any], Any], + params: networks.Params, + *args, + **kwargs +) -> Any: + """Calculates the mean over all ensemble members for each batch element. Args: base_apply: Base network `apply` function that will be used for averaging. @@ -154,13 +170,16 @@ def apply_mean(base_apply: Callable[[networks.Params, Any], Any], pytree of the average over all ensembles for each element. 
Output shape will be [batch_size, ] """ - out = apply_all(base_apply, params, *args, **kwargs) - return jax.tree_map(functools.partial(jnp.mean, axis=0), out) - - -def make_ensemble(base_network: networks.FeedForwardNetwork, - ensemble_apply: Callable[..., Any], - num_networks: int) -> networks.FeedForwardNetwork: - return networks.FeedForwardNetwork( - init=functools.partial(ensemble_init, base_network.init, num_networks), - apply=functools.partial(ensemble_apply, base_network.apply)) + out = apply_all(base_apply, params, *args, **kwargs) + return jax.tree_map(functools.partial(jnp.mean, axis=0), out) + + +def make_ensemble( + base_network: networks.FeedForwardNetwork, + ensemble_apply: Callable[..., Any], + num_networks: int, +) -> networks.FeedForwardNetwork: + return networks.FeedForwardNetwork( + init=functools.partial(ensemble_init, base_network.init, num_networks), + apply=functools.partial(ensemble_apply, base_network.apply), + ) diff --git a/acme/agents/jax/mbop/ensemble_test.py b/acme/agents/jax/mbop/ensemble_test.py index 9890a78121..4919cc5679 100644 --- a/acme/agents/jax/mbop/ensemble_test.py +++ b/acme/agents/jax/mbop/ensemble_test.py @@ -17,313 +17,329 @@ import functools from typing import Any -from acme.agents.jax.mbop import ensemble -from acme.jax import networks -from flax import linen as nn import jax import jax.numpy as jnp import numpy as np - from absl.testing import absltest +from flax import linen as nn +from acme.agents.jax.mbop import ensemble +from acme.jax import networks -class RandomFFN(nn.Module): - @nn.compact - def __call__(self, x): - return nn.Dense(15)(x) +class RandomFFN(nn.Module): + @nn.compact + def __call__(self, x): + return nn.Dense(15)(x) def params_adding_ffn(x: jnp.ndarray) -> networks.FeedForwardNetwork: - """Apply adds the parameters to the inputs.""" - return networks.FeedForwardNetwork( - init=lambda key, x=x: jax.random.uniform(key, x.shape), - apply=lambda params, x: params + x) + """Apply adds the parameters to the inputs.""" + return networks.FeedForwardNetwork( + init=lambda key, x=x: jax.random.uniform(key, x.shape), + apply=lambda params, x: params + x, + ) def funny_args_ffn(x: jnp.ndarray) -> networks.FeedForwardNetwork: - """Apply takes additional parameters, returns `params + x + foo - bar`.""" - return networks.FeedForwardNetwork( - init=lambda key, x=x: jax.random.uniform(key, x.shape), - apply=lambda params, x, foo, bar: params + x + foo - bar) + """Apply takes additional parameters, returns `params + x + foo - bar`.""" + return networks.FeedForwardNetwork( + init=lambda key, x=x: jax.random.uniform(key, x.shape), + apply=lambda params, x, foo, bar: params + x + foo - bar, + ) def struct_params_adding_ffn(sx: Any) -> networks.FeedForwardNetwork: - """Like params_adding_ffn, but with pytree inputs, preserves structure.""" + """Like params_adding_ffn, but with pytree inputs, preserves structure.""" - def init_fn(key, sx=sx): - return jax.tree_map(lambda x: jax.random.uniform(key, x.shape), sx) + def init_fn(key, sx=sx): + return jax.tree_map(lambda x: jax.random.uniform(key, x.shape), sx) - def apply_fn(params, x): - return jax.tree_map(lambda p, v: p + v, params, x) + def apply_fn(params, x): + return jax.tree_map(lambda p, v: p + v, params, x) - return networks.FeedForwardNetwork(init=init_fn, apply=apply_fn) + return networks.FeedForwardNetwork(init=init_fn, apply=apply_fn) class EnsembleTest(absltest.TestCase): - - def test_ensemble_init(self): - x = jnp.ones(10) # Base input - - wrapped_ffn = params_adding_ffn(x) - - 
rr_ensemble = ensemble.make_ensemble( - wrapped_ffn, ensemble.apply_round_robin, num_networks=3) - key = jax.random.PRNGKey(0) - params = rr_ensemble.init(key) - - self.assertTupleEqual(params.shape, (3,) + x.shape) - - # The ensemble dimension is the lead dimension. - self.assertFalse((params[0, ...] == params[1, ...]).all()) - - def test_apply_all(self): - x = jnp.ones(10) # Base input - bx = jnp.ones((7, 10)) # Batched input - - wrapped_ffn = params_adding_ffn(x) - - rr_ensemble = ensemble.make_ensemble( - wrapped_ffn, ensemble.apply_all, num_networks=3) - key = jax.random.PRNGKey(0) - params = rr_ensemble.init(key) - self.assertTupleEqual(params.shape, (3,) + x.shape) - - y = rr_ensemble.apply(params, x) - self.assertTupleEqual(y.shape, (3,) + x.shape) - np.testing.assert_allclose(params, y - jnp.broadcast_to(x, (3,) + x.shape)) - - by = rr_ensemble.apply(params, bx) - # Note: the batch dimension is no longer the leading dimension. - self.assertTupleEqual(by.shape, (3,) + bx.shape) - - def test_apply_round_robin(self): - x = jnp.ones(10) # Base input - bx = jnp.ones((7, 10)) # Batched input - - wrapped_ffn = params_adding_ffn(x) - - rr_ensemble = ensemble.make_ensemble( - wrapped_ffn, ensemble.apply_round_robin, num_networks=3) - key = jax.random.PRNGKey(0) - params = rr_ensemble.init(key) - self.assertTupleEqual(params.shape, (3,) + x.shape) - - y = rr_ensemble.apply(params, jnp.broadcast_to(x, (3,) + x.shape)) - self.assertTupleEqual(y.shape, (3,) + x.shape) - np.testing.assert_allclose(params, y - x) - - # Note: the ensemble dimension must lead, the batch dimension is no longer - # the leading dimension. - by = rr_ensemble.apply( - params, jnp.broadcast_to(jnp.expand_dims(bx, axis=0), (3,) + bx.shape)) - self.assertTupleEqual(by.shape, (3,) + bx.shape) - - # If num_networks=3, then `round_robin(params, input)[4]` should be equal - # to `apply(params[1], input[4])`, etc. - yy = rr_ensemble.apply(params, jnp.broadcast_to(x, (6,) + x.shape)) - self.assertTupleEqual(yy.shape, (6,) + x.shape) - np.testing.assert_allclose( - jnp.concatenate([params, params], axis=0), - yy - jnp.expand_dims(x, axis=0)) - - def test_apply_mean(self): - x = jnp.ones(10) # Base input - bx = jnp.ones((7, 10)) # Batched input - - wrapped_ffn = params_adding_ffn(x) - - rr_ensemble = ensemble.make_ensemble( - wrapped_ffn, ensemble.apply_mean, num_networks=3) - key = jax.random.PRNGKey(0) - params = rr_ensemble.init(key) - self.assertTupleEqual(params.shape, (3,) + x.shape) - self.assertFalse((params[0, ...] 
== params[1, ...]).all()) - - y = rr_ensemble.apply(params, x) - self.assertTupleEqual(y.shape, x.shape) - np.testing.assert_allclose( - jnp.mean(params, axis=0), y - x, atol=1E-5, rtol=1E-5) - - by = rr_ensemble.apply(params, bx) - self.assertTupleEqual(by.shape, bx.shape) - - def test_apply_all_multiargs(self): - x = jnp.ones(10) # Base input - - wrapped_ffn = funny_args_ffn(x) - - rr_ensemble = ensemble.make_ensemble( - wrapped_ffn, ensemble.apply_all, num_networks=3) - key = jax.random.PRNGKey(0) - params = rr_ensemble.init(key) - self.assertTupleEqual(params.shape, (3,) + x.shape) - - y = rr_ensemble.apply(params, x, 2 * x, x) - self.assertTupleEqual(y.shape, (3,) + x.shape) - np.testing.assert_allclose( - params, - y - jnp.broadcast_to(2 * x, (3,) + x.shape), - atol=1E-5, - rtol=1E-5) - - y = rr_ensemble.apply(params, x, bar=x, foo=2 * x) - self.assertTupleEqual(y.shape, (3,) + x.shape) - np.testing.assert_allclose( - params, - y - jnp.broadcast_to(2 * x, (3,) + x.shape), - atol=1E-5, - rtol=1E-5) - - def test_apply_all_structured(self): - x = jnp.ones(10) - sx = [(3 * x, 2 * x), 5 * x] # Base input - - wrapped_ffn = struct_params_adding_ffn(sx) - - rr_ensemble = ensemble.make_ensemble( - wrapped_ffn, ensemble.apply_all, num_networks=3) - key = jax.random.PRNGKey(0) - params = rr_ensemble.init(key) - - y = rr_ensemble.apply(params, sx) - ex = jnp.broadcast_to(x, (3,) + x.shape) - np.testing.assert_allclose(y[0][0], params[0][0] + 3 * ex) - - def test_apply_round_robin_multiargs(self): - x = jnp.ones(10) # Base input - - wrapped_ffn = funny_args_ffn(x) - - rr_ensemble = ensemble.make_ensemble( - wrapped_ffn, ensemble.apply_round_robin, num_networks=3) - key = jax.random.PRNGKey(0) - params = rr_ensemble.init(key) - self.assertTupleEqual(params.shape, (3,) + x.shape) - - ex = jnp.broadcast_to(x, (3,) + x.shape) - y = rr_ensemble.apply(params, ex, 2 * ex, ex) - self.assertTupleEqual(y.shape, (3,) + x.shape) - np.testing.assert_allclose( - params, - y - jnp.broadcast_to(2 * x, (3,) + x.shape), - atol=1E-5, - rtol=1E-5) - - y = rr_ensemble.apply(params, ex, bar=ex, foo=2 * ex) - self.assertTupleEqual(y.shape, (3,) + x.shape) - np.testing.assert_allclose( - params, - y - jnp.broadcast_to(2 * x, (3,) + x.shape), - atol=1E-5, - rtol=1E-5) - - def test_apply_round_robin_structured(self): - x = jnp.ones(10) - sx = [(3 * x, 2 * x), 5 * x] # Base input - - wrapped_ffn = struct_params_adding_ffn(sx) - - rr_ensemble = ensemble.make_ensemble( - wrapped_ffn, ensemble.apply_round_robin, num_networks=3) - key = jax.random.PRNGKey(0) - params = rr_ensemble.init(key) - - ex = jnp.broadcast_to(x, (3,) + x.shape) - esx = [(3 * ex, 2 * ex), 5 * ex] - y = rr_ensemble.apply(params, esx) - np.testing.assert_allclose(y[0][0], params[0][0] + 3 * ex) - - def test_apply_mean_multiargs(self): - x = jnp.ones(10) # Base input - - wrapped_ffn = funny_args_ffn(x) - - rr_ensemble = ensemble.make_ensemble( - wrapped_ffn, ensemble.apply_mean, num_networks=3) - key = jax.random.PRNGKey(0) - params = rr_ensemble.init(key) - self.assertTupleEqual(params.shape, (3,) + x.shape) - - y = rr_ensemble.apply(params, x, 2 * x, x) - self.assertTupleEqual(y.shape, x.shape) - np.testing.assert_allclose( - jnp.mean(params, axis=0), y - 2 * x, atol=1E-5, rtol=1E-5) - - y = rr_ensemble.apply(params, x, bar=x, foo=2 * x) - self.assertTupleEqual(y.shape, x.shape) - np.testing.assert_allclose( - jnp.mean(params, axis=0), y - 2 * x, atol=1E-5, rtol=1E-5) - - def test_apply_mean_structured(self): - x = jnp.ones(10) - sx = [(3 * x, 2 * x), 5 * 
x] # Base input - - wrapped_ffn = struct_params_adding_ffn(sx) - - rr_ensemble = ensemble.make_ensemble( - wrapped_ffn, ensemble.apply_mean, num_networks=3) - key = jax.random.PRNGKey(0) - params = rr_ensemble.init(key) - - y = rr_ensemble.apply(params, sx) - np.testing.assert_allclose( - y[0][0], jnp.mean(params[0][0], axis=0) + 3 * x, atol=1E-5, rtol=1E-5) - - def test_round_robin_random(self): - x = jnp.ones(10) # Base input - bx = jnp.ones((9, 10)) # Batched input - ffn = RandomFFN() - wrapped_ffn = networks.FeedForwardNetwork( - init=functools.partial(ffn.init, x=x), apply=ffn.apply) - rr_ensemble = ensemble.make_ensemble( - wrapped_ffn, ensemble.apply_round_robin, num_networks=3) - - key = jax.random.PRNGKey(0) - params = rr_ensemble.init(key) - out = rr_ensemble.apply(params, bx) - # The output should be the same every 3 rows: - blocks = jnp.split(out, 3, axis=0) - np.testing.assert_array_equal(blocks[0], blocks[1]) - np.testing.assert_array_equal(blocks[0], blocks[2]) - self.assertTrue((out[0] != out[1]).any()) - - for i in range(9): - np.testing.assert_allclose( - out[i], - ffn.apply(jax.tree_map(lambda p, i=i: p[i % 3], params), bx[i]), - atol=1E-5, - rtol=1E-5) - - def test_mean_random(self): - x = jnp.ones(10) - bx = jnp.ones((9, 10)) - ffn = RandomFFN() - wrapped_ffn = networks.FeedForwardNetwork( - init=functools.partial(ffn.init, x=x), apply=ffn.apply) - mean_ensemble = ensemble.make_ensemble( - wrapped_ffn, ensemble.apply_mean, num_networks=3) - key = jax.random.PRNGKey(0) - params = mean_ensemble.init(key) - single_output = mean_ensemble.apply(params, x) - self.assertEqual(single_output.shape, (15,)) - batch_output = mean_ensemble.apply(params, bx) - # Make sure all rows are equal: - np.testing.assert_allclose( - jnp.broadcast_to(batch_output[0], batch_output.shape), - batch_output, - atol=1E-5, - rtol=1E-5) - - # Check results explicitly: - all_members = jnp.concatenate([ - jnp.expand_dims( - ffn.apply(jax.tree_map(lambda p, i=i: p[i], params), bx), axis=0) - for i in range(3) - ]) - batch_means = jnp.mean(all_members, axis=0) - np.testing.assert_allclose(batch_output, batch_means, atol=1E-5, rtol=1E-5) - - -if __name__ == '__main__': - absltest.main() + def test_ensemble_init(self): + x = jnp.ones(10) # Base input + + wrapped_ffn = params_adding_ffn(x) + + rr_ensemble = ensemble.make_ensemble( + wrapped_ffn, ensemble.apply_round_robin, num_networks=3 + ) + key = jax.random.PRNGKey(0) + params = rr_ensemble.init(key) + + self.assertTupleEqual(params.shape, (3,) + x.shape) + + # The ensemble dimension is the lead dimension. + self.assertFalse((params[0, ...] == params[1, ...]).all()) + + def test_apply_all(self): + x = jnp.ones(10) # Base input + bx = jnp.ones((7, 10)) # Batched input + + wrapped_ffn = params_adding_ffn(x) + + rr_ensemble = ensemble.make_ensemble( + wrapped_ffn, ensemble.apply_all, num_networks=3 + ) + key = jax.random.PRNGKey(0) + params = rr_ensemble.init(key) + self.assertTupleEqual(params.shape, (3,) + x.shape) + + y = rr_ensemble.apply(params, x) + self.assertTupleEqual(y.shape, (3,) + x.shape) + np.testing.assert_allclose(params, y - jnp.broadcast_to(x, (3,) + x.shape)) + + by = rr_ensemble.apply(params, bx) + # Note: the batch dimension is no longer the leading dimension. 
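# Shape-only usage sketch of the three application modes, using a toy base
# network (mirroring params_adding_ffn above) whose params have the same shape
# as its input and whose apply is simply params + x.
import jax
import jax.numpy as jnp

from acme.agents.jax.mbop import ensemble
from acme.jax import networks

x = jnp.ones(10)
base = networks.FeedForwardNetwork(
    init=lambda key: jax.random.uniform(key, x.shape),
    apply=lambda params, v: params + v,
)
key = jax.random.PRNGKey(0)

all_net = ensemble.make_ensemble(base, ensemble.apply_all, num_networks=3)
params = all_net.init(key)                             # (3, 10): leading ensemble axis
assert all_net.apply(params, x).shape == (3, 10)       # every member sees x

rr_net = ensemble.make_ensemble(base, ensemble.apply_round_robin, num_networks=3)
assert rr_net.apply(params, jnp.ones((6, 10))).shape == (6, 10)  # row i -> member i % 3

mean_net = ensemble.make_ensemble(base, ensemble.apply_mean, num_networks=3)
assert mean_net.apply(params, x).shape == (10,)        # member outputs averaged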
+ self.assertTupleEqual(by.shape, (3,) + bx.shape) + + def test_apply_round_robin(self): + x = jnp.ones(10) # Base input + bx = jnp.ones((7, 10)) # Batched input + + wrapped_ffn = params_adding_ffn(x) + + rr_ensemble = ensemble.make_ensemble( + wrapped_ffn, ensemble.apply_round_robin, num_networks=3 + ) + key = jax.random.PRNGKey(0) + params = rr_ensemble.init(key) + self.assertTupleEqual(params.shape, (3,) + x.shape) + + y = rr_ensemble.apply(params, jnp.broadcast_to(x, (3,) + x.shape)) + self.assertTupleEqual(y.shape, (3,) + x.shape) + np.testing.assert_allclose(params, y - x) + + # Note: the ensemble dimension must lead, the batch dimension is no longer + # the leading dimension. + by = rr_ensemble.apply( + params, jnp.broadcast_to(jnp.expand_dims(bx, axis=0), (3,) + bx.shape) + ) + self.assertTupleEqual(by.shape, (3,) + bx.shape) + + # If num_networks=3, then `round_robin(params, input)[4]` should be equal + # to `apply(params[1], input[4])`, etc. + yy = rr_ensemble.apply(params, jnp.broadcast_to(x, (6,) + x.shape)) + self.assertTupleEqual(yy.shape, (6,) + x.shape) + np.testing.assert_allclose( + jnp.concatenate([params, params], axis=0), yy - jnp.expand_dims(x, axis=0) + ) + + def test_apply_mean(self): + x = jnp.ones(10) # Base input + bx = jnp.ones((7, 10)) # Batched input + + wrapped_ffn = params_adding_ffn(x) + + rr_ensemble = ensemble.make_ensemble( + wrapped_ffn, ensemble.apply_mean, num_networks=3 + ) + key = jax.random.PRNGKey(0) + params = rr_ensemble.init(key) + self.assertTupleEqual(params.shape, (3,) + x.shape) + self.assertFalse((params[0, ...] == params[1, ...]).all()) + + y = rr_ensemble.apply(params, x) + self.assertTupleEqual(y.shape, x.shape) + np.testing.assert_allclose( + jnp.mean(params, axis=0), y - x, atol=1e-5, rtol=1e-5 + ) + + by = rr_ensemble.apply(params, bx) + self.assertTupleEqual(by.shape, bx.shape) + + def test_apply_all_multiargs(self): + x = jnp.ones(10) # Base input + + wrapped_ffn = funny_args_ffn(x) + + rr_ensemble = ensemble.make_ensemble( + wrapped_ffn, ensemble.apply_all, num_networks=3 + ) + key = jax.random.PRNGKey(0) + params = rr_ensemble.init(key) + self.assertTupleEqual(params.shape, (3,) + x.shape) + + y = rr_ensemble.apply(params, x, 2 * x, x) + self.assertTupleEqual(y.shape, (3,) + x.shape) + np.testing.assert_allclose( + params, y - jnp.broadcast_to(2 * x, (3,) + x.shape), atol=1e-5, rtol=1e-5 + ) + + y = rr_ensemble.apply(params, x, bar=x, foo=2 * x) + self.assertTupleEqual(y.shape, (3,) + x.shape) + np.testing.assert_allclose( + params, y - jnp.broadcast_to(2 * x, (3,) + x.shape), atol=1e-5, rtol=1e-5 + ) + + def test_apply_all_structured(self): + x = jnp.ones(10) + sx = [(3 * x, 2 * x), 5 * x] # Base input + + wrapped_ffn = struct_params_adding_ffn(sx) + + rr_ensemble = ensemble.make_ensemble( + wrapped_ffn, ensemble.apply_all, num_networks=3 + ) + key = jax.random.PRNGKey(0) + params = rr_ensemble.init(key) + + y = rr_ensemble.apply(params, sx) + ex = jnp.broadcast_to(x, (3,) + x.shape) + np.testing.assert_allclose(y[0][0], params[0][0] + 3 * ex) + + def test_apply_round_robin_multiargs(self): + x = jnp.ones(10) # Base input + + wrapped_ffn = funny_args_ffn(x) + + rr_ensemble = ensemble.make_ensemble( + wrapped_ffn, ensemble.apply_round_robin, num_networks=3 + ) + key = jax.random.PRNGKey(0) + params = rr_ensemble.init(key) + self.assertTupleEqual(params.shape, (3,) + x.shape) + + ex = jnp.broadcast_to(x, (3,) + x.shape) + y = rr_ensemble.apply(params, ex, 2 * ex, ex) + self.assertTupleEqual(y.shape, (3,) + x.shape) + 
np.testing.assert_allclose( + params, y - jnp.broadcast_to(2 * x, (3,) + x.shape), atol=1e-5, rtol=1e-5 + ) + + y = rr_ensemble.apply(params, ex, bar=ex, foo=2 * ex) + self.assertTupleEqual(y.shape, (3,) + x.shape) + np.testing.assert_allclose( + params, y - jnp.broadcast_to(2 * x, (3,) + x.shape), atol=1e-5, rtol=1e-5 + ) + + def test_apply_round_robin_structured(self): + x = jnp.ones(10) + sx = [(3 * x, 2 * x), 5 * x] # Base input + + wrapped_ffn = struct_params_adding_ffn(sx) + + rr_ensemble = ensemble.make_ensemble( + wrapped_ffn, ensemble.apply_round_robin, num_networks=3 + ) + key = jax.random.PRNGKey(0) + params = rr_ensemble.init(key) + + ex = jnp.broadcast_to(x, (3,) + x.shape) + esx = [(3 * ex, 2 * ex), 5 * ex] + y = rr_ensemble.apply(params, esx) + np.testing.assert_allclose(y[0][0], params[0][0] + 3 * ex) + + def test_apply_mean_multiargs(self): + x = jnp.ones(10) # Base input + + wrapped_ffn = funny_args_ffn(x) + + rr_ensemble = ensemble.make_ensemble( + wrapped_ffn, ensemble.apply_mean, num_networks=3 + ) + key = jax.random.PRNGKey(0) + params = rr_ensemble.init(key) + self.assertTupleEqual(params.shape, (3,) + x.shape) + + y = rr_ensemble.apply(params, x, 2 * x, x) + self.assertTupleEqual(y.shape, x.shape) + np.testing.assert_allclose( + jnp.mean(params, axis=0), y - 2 * x, atol=1e-5, rtol=1e-5 + ) + + y = rr_ensemble.apply(params, x, bar=x, foo=2 * x) + self.assertTupleEqual(y.shape, x.shape) + np.testing.assert_allclose( + jnp.mean(params, axis=0), y - 2 * x, atol=1e-5, rtol=1e-5 + ) + + def test_apply_mean_structured(self): + x = jnp.ones(10) + sx = [(3 * x, 2 * x), 5 * x] # Base input + + wrapped_ffn = struct_params_adding_ffn(sx) + + rr_ensemble = ensemble.make_ensemble( + wrapped_ffn, ensemble.apply_mean, num_networks=3 + ) + key = jax.random.PRNGKey(0) + params = rr_ensemble.init(key) + + y = rr_ensemble.apply(params, sx) + np.testing.assert_allclose( + y[0][0], jnp.mean(params[0][0], axis=0) + 3 * x, atol=1e-5, rtol=1e-5 + ) + + def test_round_robin_random(self): + x = jnp.ones(10) # Base input + bx = jnp.ones((9, 10)) # Batched input + ffn = RandomFFN() + wrapped_ffn = networks.FeedForwardNetwork( + init=functools.partial(ffn.init, x=x), apply=ffn.apply + ) + rr_ensemble = ensemble.make_ensemble( + wrapped_ffn, ensemble.apply_round_robin, num_networks=3 + ) + + key = jax.random.PRNGKey(0) + params = rr_ensemble.init(key) + out = rr_ensemble.apply(params, bx) + # The output should be the same every 3 rows: + blocks = jnp.split(out, 3, axis=0) + np.testing.assert_array_equal(blocks[0], blocks[1]) + np.testing.assert_array_equal(blocks[0], blocks[2]) + self.assertTrue((out[0] != out[1]).any()) + + for i in range(9): + np.testing.assert_allclose( + out[i], + ffn.apply(jax.tree_map(lambda p, i=i: p[i % 3], params), bx[i]), + atol=1e-5, + rtol=1e-5, + ) + + def test_mean_random(self): + x = jnp.ones(10) + bx = jnp.ones((9, 10)) + ffn = RandomFFN() + wrapped_ffn = networks.FeedForwardNetwork( + init=functools.partial(ffn.init, x=x), apply=ffn.apply + ) + mean_ensemble = ensemble.make_ensemble( + wrapped_ffn, ensemble.apply_mean, num_networks=3 + ) + key = jax.random.PRNGKey(0) + params = mean_ensemble.init(key) + single_output = mean_ensemble.apply(params, x) + self.assertEqual(single_output.shape, (15,)) + batch_output = mean_ensemble.apply(params, bx) + # Make sure all rows are equal: + np.testing.assert_allclose( + jnp.broadcast_to(batch_output[0], batch_output.shape), + batch_output, + atol=1e-5, + rtol=1e-5, + ) + + # Check results explicitly: + all_members = 
jnp.concatenate( + [ + jnp.expand_dims( + ffn.apply(jax.tree_map(lambda p, i=i: p[i], params), bx), axis=0 + ) + for i in range(3) + ] + ) + batch_means = jnp.mean(all_members, axis=0) + np.testing.assert_allclose(batch_output, batch_means, atol=1e-5, rtol=1e-5) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/jax/mbop/learning.py b/acme/agents/jax/mbop/learning.py index 82d7101745..dddda847c2 100644 --- a/acme/agents/jax/mbop/learning.py +++ b/acme/agents/jax/mbop/learning.py @@ -20,8 +20,11 @@ import time from typing import Any, Callable, Iterator, List, Optional -from acme import core -from acme import types +import jax +import jax.numpy as jnp +import optax + +from acme import core, types from acme.agents.jax import bc from acme.agents.jax.mbop import ensemble from acme.agents.jax.mbop import losses as mbop_losses @@ -29,52 +32,58 @@ from acme.jax import networks as networks_lib from acme.jax import types as jax_types from acme.jax import utils -from acme.utils import counting -from acme.utils import loggers -import jax -import jax.numpy as jnp -import optax +from acme.utils import counting, loggers @dataclasses.dataclass class TrainingState: - """States of the world model, policy prior and n-step return learners.""" - world_model: Any - policy_prior: Any - n_step_return: Any + """States of the world model, policy prior and n-step return learners.""" + + world_model: Any + policy_prior: Any + n_step_return: Any LoggerFn = Callable[[str, str], loggers.Logger] # Creates a world model learner. -MakeWorldModelLearner = Callable[[ - LoggerFn, - counting.Counter, - jax_types.PRNGKey, - Iterator[types.Transition], - mbop_networks.WorldModelNetwork, - mbop_losses.TransitionLoss, -], core.Learner] +MakeWorldModelLearner = Callable[ + [ + LoggerFn, + counting.Counter, + jax_types.PRNGKey, + Iterator[types.Transition], + mbop_networks.WorldModelNetwork, + mbop_losses.TransitionLoss, + ], + core.Learner, +] # Creates a policy prior learner. -MakePolicyPriorLearner = Callable[[ - LoggerFn, - counting.Counter, - jax_types.PRNGKey, - Iterator[types.Transition], - mbop_networks.PolicyPriorNetwork, - mbop_losses.TransitionLoss, -], core.Learner] +MakePolicyPriorLearner = Callable[ + [ + LoggerFn, + counting.Counter, + jax_types.PRNGKey, + Iterator[types.Transition], + mbop_networks.PolicyPriorNetwork, + mbop_losses.TransitionLoss, + ], + core.Learner, +] # Creates an n-step return model learner. -MakeNStepReturnLearner = Callable[[ - LoggerFn, - counting.Counter, - jax_types.PRNGKey, - Iterator[types.Transition], - mbop_networks.NStepReturnNetwork, - mbop_losses.TransitionLoss, -], core.Learner] +MakeNStepReturnLearner = Callable[ + [ + LoggerFn, + counting.Counter, + jax_types.PRNGKey, + Iterator[types.Transition], + mbop_networks.NStepReturnNetwork, + mbop_losses.TransitionLoss, + ], + core.Learner, +] def make_ensemble_regressor_learner( @@ -89,7 +98,7 @@ def make_ensemble_regressor_learner( optimizer: optax.GradientTransformation, num_sgd_steps_per_step: int, ): - """Creates an ensemble regressor learner from the base network. + """Creates an ensemble regressor learner from the base network. Args: name: Name of the learner used for logging and counters. @@ -107,51 +116,58 @@ def make_ensemble_regressor_learner( Returns: An ensemble regressor learner. 
""" - mbop_ensemble = ensemble.make_ensemble(base_network, ensemble.apply_all, - num_networks) - local_counter = counting.Counter(parent=counter, prefix=name) - local_logger = logger_fn(name, - local_counter.get_steps_key()) if logger_fn else None - - def loss_fn(networks: bc.BCNetworks, params: networks_lib.Params, - key: jax_types.PRNGKey, - transitions: types.Transition) -> jnp.ndarray: - del key - return loss( - functools.partial(networks.policy_network.apply, params), transitions) - - bc_policy_network = bc.convert_to_bc_network(mbop_ensemble) - bc_networks = bc.BCNetworks(bc_policy_network) - - # This is effectively a regressor learner. - return bc.BCLearner( - bc_networks, - rng_key, - loss_fn, - optimizer, - iterator, - num_sgd_steps_per_step, - logger=local_logger, - counter=local_counter) + mbop_ensemble = ensemble.make_ensemble( + base_network, ensemble.apply_all, num_networks + ) + local_counter = counting.Counter(parent=counter, prefix=name) + local_logger = logger_fn(name, local_counter.get_steps_key()) if logger_fn else None + + def loss_fn( + networks: bc.BCNetworks, + params: networks_lib.Params, + key: jax_types.PRNGKey, + transitions: types.Transition, + ) -> jnp.ndarray: + del key + return loss( + functools.partial(networks.policy_network.apply, params), transitions + ) + + bc_policy_network = bc.convert_to_bc_network(mbop_ensemble) + bc_networks = bc.BCNetworks(bc_policy_network) + + # This is effectively a regressor learner. + return bc.BCLearner( + bc_networks, + rng_key, + loss_fn, + optimizer, + iterator, + num_sgd_steps_per_step, + logger=local_logger, + counter=local_counter, + ) class MBOPLearner(core.Learner): - """Model-Based Offline Planning (MBOP) learner. + """Model-Based Offline Planning (MBOP) learner. See https://arxiv.org/abs/2008.05556 for more information. """ - def __init__(self, - networks: mbop_networks.MBOPNetworks, - losses: mbop_losses.MBOPLosses, - iterator: Iterator[types.Transition], - rng_key: jax_types.PRNGKey, - logger_fn: LoggerFn, - make_world_model_learner: MakeWorldModelLearner, - make_policy_prior_learner: MakePolicyPriorLearner, - make_n_step_return_learner: MakeNStepReturnLearner, - counter: Optional[counting.Counter] = None): - """Creates an MBOP learner. + def __init__( + self, + networks: mbop_networks.MBOPNetworks, + losses: mbop_losses.MBOPLosses, + iterator: Iterator[types.Transition], + rng_key: jax_types.PRNGKey, + logger_fn: LoggerFn, + make_world_model_learner: MakeWorldModelLearner, + make_policy_prior_learner: MakePolicyPriorLearner, + make_n_step_return_learner: MakeNStepReturnLearner, + counter: Optional[counting.Counter] = None, + ): + """Creates an MBOP learner. Args: networks: One network per model. @@ -165,71 +181,89 @@ def __init__(self, make_n_step_return_learner: Function to create the n-step return learner. counter: Parent counter object. """ - # General learner book-keeping and loggers. - self._counter = counter or counting.Counter() - self._logger = logger_fn('mbop', 'steps') - - # Prepare iterators for the learners, to not split the data (preserve sample - # efficiency). 
- sharded_prefetching_dataset = utils.sharded_prefetch(iterator) - world_model_iterator, policy_prior_iterator, n_step_return_iterator = ( - itertools.tee(sharded_prefetching_dataset, 3)) - - world_model_key, policy_prior_key, n_step_return_key = jax.random.split( - rng_key, 3) - - self._world_model = make_world_model_learner(logger_fn, self._counter, - world_model_key, - world_model_iterator, - networks.world_model_network, - losses.world_model_loss) - self._policy_prior = make_policy_prior_learner( - logger_fn, self._counter, policy_prior_key, policy_prior_iterator, - networks.policy_prior_network, losses.policy_prior_loss) - self._n_step_return = make_n_step_return_learner( - logger_fn, self._counter, n_step_return_key, n_step_return_iterator, - networks.n_step_return_network, losses.n_step_return_loss) - # Start recording timestamps after the first learning step to not report - # "warmup" time. - self._timestamp = None - self._learners = { - 'world_model': self._world_model, - 'policy_prior': self._policy_prior, - 'n_step_return': self._n_step_return - } - - def step(self): - # Step the world model, policy learner and n-step return learners. - self._world_model.step() - self._policy_prior.step() - self._n_step_return.step() - - # Compute the elapsed time. - timestamp = time.time() - elapsed_time = timestamp - self._timestamp if self._timestamp else 0 - self._timestamp = timestamp - # Increment counts and record the current time. - counts = self._counter.increment(steps=1, walltime=elapsed_time) - # Attempt to write the logs. - self._logger.write({**counts}) - - def get_variables(self, names: List[str]) -> List[types.NestedArray]: - variables = [] - for name in names: - # Variables will be prefixed by the learner names. If separator is not - # found, learner_name=name, which is OK. - learner_name, _, variable_name = name.partition('-') - learner = self._learners[learner_name] - variables.extend(learner.get_variables([variable_name])) - return variables - - def save(self) -> TrainingState: - return TrainingState( - world_model=self._world_model.save(), - policy_prior=self._policy_prior.save(), - n_step_return=self._n_step_return.save()) - - def restore(self, state: TrainingState): - self._world_model.restore(state.world_model) - self._policy_prior.restore(state.policy_prior) - self._n_step_return.restore(state.n_step_return) + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger_fn("mbop", "steps") + + # Prepare iterators for the learners, to not split the data (preserve sample + # efficiency). 
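# Sketch of the sharing pattern used just below: itertools.tee lets the three
# learners consume the same stream of batches instead of splitting it.
import itertools

batches = iter(["batch0", "batch1", "batch2"])   # stand-in for the prefetched iterator
wm_it, pp_it, nsr_it = itertools.tee(batches, 3)
assert next(wm_it) == next(pp_it) == next(nsr_it) == "batch0"
# tee buffers items until every copy has read them, so the three learners
# should advance at roughly the same rate.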
+ sharded_prefetching_dataset = utils.sharded_prefetch(iterator) + ( + world_model_iterator, + policy_prior_iterator, + n_step_return_iterator, + ) = itertools.tee(sharded_prefetching_dataset, 3) + + world_model_key, policy_prior_key, n_step_return_key = jax.random.split( + rng_key, 3 + ) + + self._world_model = make_world_model_learner( + logger_fn, + self._counter, + world_model_key, + world_model_iterator, + networks.world_model_network, + losses.world_model_loss, + ) + self._policy_prior = make_policy_prior_learner( + logger_fn, + self._counter, + policy_prior_key, + policy_prior_iterator, + networks.policy_prior_network, + losses.policy_prior_loss, + ) + self._n_step_return = make_n_step_return_learner( + logger_fn, + self._counter, + n_step_return_key, + n_step_return_iterator, + networks.n_step_return_network, + losses.n_step_return_loss, + ) + # Start recording timestamps after the first learning step to not report + # "warmup" time. + self._timestamp = None + self._learners = { + "world_model": self._world_model, + "policy_prior": self._policy_prior, + "n_step_return": self._n_step_return, + } + + def step(self): + # Step the world model, policy learner and n-step return learners. + self._world_model.step() + self._policy_prior.step() + self._n_step_return.step() + + # Compute the elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + # Increment counts and record the current time. + counts = self._counter.increment(steps=1, walltime=elapsed_time) + # Attempt to write the logs. + self._logger.write({**counts}) + + def get_variables(self, names: List[str]) -> List[types.NestedArray]: + variables = [] + for name in names: + # Variables will be prefixed by the learner names. If separator is not + # found, learner_name=name, which is OK. + learner_name, _, variable_name = name.partition("-") + learner = self._learners[learner_name] + variables.extend(learner.get_variables([variable_name])) + return variables + + def save(self) -> TrainingState: + return TrainingState( + world_model=self._world_model.save(), + policy_prior=self._policy_prior.save(), + n_step_return=self._n_step_return.save(), + ) + + def restore(self, state: TrainingState): + self._world_model.restore(state.world_model) + self._policy_prior.restore(state.policy_prior) + self._n_step_return.restore(state.n_step_return) diff --git a/acme/agents/jax/mbop/losses.py b/acme/agents/jax/mbop/losses.py index 4ec911f431..b022938d69 100644 --- a/acme/agents/jax/mbop/losses.py +++ b/acme/agents/jax/mbop/losses.py @@ -17,31 +17,36 @@ import dataclasses from typing import Any, Callable, Optional, Tuple, Union +import jax +import jax.numpy as jnp + from acme import types from acme.agents.jax.mbop import dataset from acme.jax import networks -import jax -import jax.numpy as jnp # The apply function takes an observation (and an action) as arguments, and is # usually a network with bound parameters. 
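# Sketch of the convention the losses in this module assume: `apply_fn` has its
# parameters pre-bound (e.g. via functools.partial) and every batched field
# carries a (previous, current, next) axis at position 1. The PREVIOUS/CURRENT
# index values (0 and 1) are assumptions here, not taken from this diff.
import functools
import jax.numpy as jnp

def toy_policy_apply(params, observation_t, action_tm1):   # stand-in network
    del observation_t
    return params * action_tm1

bound_apply = functools.partial(toy_policy_apply, 2.0)      # what the learner passes in

actions = jnp.ones((8, 3, 4))                               # [batch, 3, action_dim]
action_tm1 = actions[:, 0, ...]                             # dataset.PREVIOUS assumed == 0
action_t = actions[:, 1, ...]                               # dataset.CURRENT assumed == 1
loss = jnp.mean(jnp.square(bound_apply(None, action_tm1) - action_t))   # cf. mse below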
TransitionApplyFn = Callable[[networks.Observation, networks.Action], Any] ObservationOnlyTransitionApplyFn = Callable[[networks.Observation], Any] -TransitionLoss = Callable[[ - Union[TransitionApplyFn, ObservationOnlyTransitionApplyFn], types.Transition -], jnp.ndarray] +TransitionLoss = Callable[ + [Union[TransitionApplyFn, ObservationOnlyTransitionApplyFn], types.Transition], + jnp.ndarray, +] def mse(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: - """MSE distance.""" - return jnp.mean(jnp.square(a - b)) + """MSE distance.""" + return jnp.mean(jnp.square(a - b)) -def world_model_loss(apply_fn: Callable[[networks.Observation, networks.Action], - Tuple[networks.Observation, - jnp.ndarray]], - steps: types.Transition) -> jnp.ndarray: - """Returns the loss for the world model. +def world_model_loss( + apply_fn: Callable[ + [networks.Observation, networks.Action], + Tuple[networks.Observation, jnp.ndarray], + ], + steps: types.Transition, +) -> jnp.ndarray: + """Returns the loss for the world model. Args: apply_fn: applies a transition model (o_t, a_t) -> (o_t+1, r), expects the @@ -53,28 +58,31 @@ def world_model_loss(apply_fn: Callable[[networks.Observation, networks.Action], Returns: A scalar loss value as jnp.ndarray. """ - observation_t = jax.tree_map(lambda obs: obs[:, dataset.CURRENT, ...], - steps.observation) - action_t = steps.action[:, dataset.CURRENT, ...] - observation_tp1 = jax.tree_map(lambda obs: obs[:, dataset.NEXT, ...], - steps.observation) - reward_t = steps.reward[:, dataset.CURRENT, ...] - (predicted_observation_tp1, - predicted_reward_t) = apply_fn(observation_t, action_t) - # predicted_* variables may have an extra outer dimension due to ensembling, - # the mse loss still works due to broadcasting however. - if len(observation_tp1.shape) != len(reward_t.shape): - # The rewards from the transitions may not have the last singular dimension. - reward_t = jnp.expand_dims(reward_t, axis=-1) - return mse( - jnp.concatenate([predicted_observation_tp1, predicted_reward_t], axis=-1), - jnp.concatenate([observation_tp1, reward_t], axis=-1)) + observation_t = jax.tree_map( + lambda obs: obs[:, dataset.CURRENT, ...], steps.observation + ) + action_t = steps.action[:, dataset.CURRENT, ...] + observation_tp1 = jax.tree_map( + lambda obs: obs[:, dataset.NEXT, ...], steps.observation + ) + reward_t = steps.reward[:, dataset.CURRENT, ...] + (predicted_observation_tp1, predicted_reward_t) = apply_fn(observation_t, action_t) + # predicted_* variables may have an extra outer dimension due to ensembling, + # the mse loss still works due to broadcasting however. + if len(observation_tp1.shape) != len(reward_t.shape): + # The rewards from the transitions may not have the last singular dimension. + reward_t = jnp.expand_dims(reward_t, axis=-1) + return mse( + jnp.concatenate([predicted_observation_tp1, predicted_reward_t], axis=-1), + jnp.concatenate([observation_tp1, reward_t], axis=-1), + ) def policy_prior_loss( - apply_fn: Callable[[networks.Observation, networks.Action], - networks.Action], steps: types.Transition): - """Returns the loss for the policy prior. + apply_fn: Callable[[networks.Observation, networks.Action], networks.Action], + steps: types.Transition, +): + """Returns the loss for the policy prior. Args: apply_fn: applies a policy prior (o_t, a_t) -> a_t+1, expects the leading @@ -86,18 +94,21 @@ def policy_prior_loss( Returns: A scalar loss value as jnp.ndarray. 
""" - observation_t = jax.tree_map(lambda obs: obs[:, dataset.CURRENT, ...], - steps.observation) - action_tm1 = steps.action[:, dataset.PREVIOUS, ...] - action_t = steps.action[:, dataset.CURRENT, ...] + observation_t = jax.tree_map( + lambda obs: obs[:, dataset.CURRENT, ...], steps.observation + ) + action_tm1 = steps.action[:, dataset.PREVIOUS, ...] + action_t = steps.action[:, dataset.CURRENT, ...] - predicted_action_t = apply_fn(observation_t, action_tm1) - return mse(predicted_action_t, action_t) + predicted_action_t = apply_fn(observation_t, action_tm1) + return mse(predicted_action_t, action_t) -def return_loss(apply_fn: Callable[[networks.Observation, networks.Action], - jnp.ndarray], steps: types.Transition): - """Returns the loss for the n-step return model. +def return_loss( + apply_fn: Callable[[networks.Observation, networks.Action], jnp.ndarray], + steps: types.Transition, +): + """Returns the loss for the n-step return model. Args: apply_fn: applies an n-step return model (o_t, a_t) -> r, expects the @@ -109,18 +120,20 @@ def return_loss(apply_fn: Callable[[networks.Observation, networks.Action], Returns: A scalar loss value as jnp.ndarray. """ - observation_t = jax.tree_map(lambda obs: obs[:, dataset.CURRENT, ...], - steps.observation) - action_t = steps.action[:, dataset.CURRENT, ...] - n_step_return_t = steps.extras[dataset.N_STEP_RETURN][:, dataset.CURRENT, ...] + observation_t = jax.tree_map( + lambda obs: obs[:, dataset.CURRENT, ...], steps.observation + ) + action_t = steps.action[:, dataset.CURRENT, ...] + n_step_return_t = steps.extras[dataset.N_STEP_RETURN][:, dataset.CURRENT, ...] - predicted_n_step_return_t = apply_fn(observation_t, action_t) - return mse(predicted_n_step_return_t, n_step_return_t) + predicted_n_step_return_t = apply_fn(observation_t, action_t) + return mse(predicted_n_step_return_t, n_step_return_t) @dataclasses.dataclass class MBOPLosses: - """Losses for the world model, policy prior and the n-step return.""" - world_model_loss: Optional[TransitionLoss] = world_model_loss - policy_prior_loss: Optional[TransitionLoss] = policy_prior_loss - n_step_return_loss: Optional[TransitionLoss] = return_loss + """Losses for the world model, policy prior and the n-step return.""" + + world_model_loss: Optional[TransitionLoss] = world_model_loss + policy_prior_loss: Optional[TransitionLoss] = policy_prior_loss + n_step_return_loss: Optional[TransitionLoss] = return_loss diff --git a/acme/agents/jax/mbop/models.py b/acme/agents/jax/mbop/models.py index 874895747c..85ab0ba88d 100644 --- a/acme/agents/jax/mbop/models.py +++ b/acme/agents/jax/mbop/models.py @@ -17,33 +17,37 @@ import functools from typing import Callable, Generic, Optional, Tuple +import chex +import jax + from acme import specs from acme.agents.jax import actor_core from acme.agents.jax.mbop import ensemble from acme.agents.jax.mbop import networks as mbop_networks -from acme.jax import networks -from acme.jax import utils -import chex -import jax +from acme.jax import networks, utils # World, policy prior and n-step return models. These are backed by the # corresponding networks. 
-WorldModel = Callable[[networks.Params, networks.Observation, networks.Action], - Tuple[networks.Observation, networks.Value]] +WorldModel = Callable[ + [networks.Params, networks.Observation, networks.Action], + Tuple[networks.Observation, networks.Value], +] MakeWorldModel = Callable[[mbop_networks.WorldModelNetwork], WorldModel] PolicyPrior = actor_core.ActorCore MakePolicyPrior = Callable[ - [mbop_networks.PolicyPriorNetwork, specs.EnvironmentSpec], PolicyPrior] + [mbop_networks.PolicyPriorNetwork, specs.EnvironmentSpec], PolicyPrior +] -NStepReturn = Callable[[networks.Params, networks.Observation, networks.Action], - networks.Value] +NStepReturn = Callable[ + [networks.Params, networks.Observation, networks.Action], networks.Value +] MakeNStepReturn = Callable[[mbop_networks.NStepReturnNetwork], NStepReturn] @chex.dataclass(frozen=True, mappable_dataclass=False) class PolicyPriorState(Generic[actor_core.RecurrentState]): - """State of a policy prior. + """State of a policy prior. Attributes: rng: Random key. @@ -51,9 +55,10 @@ class PolicyPriorState(Generic[actor_core.RecurrentState]): recurrent_state: Recurrent state. It will be none for non-recurrent, e.g. feed forward, policies. """ - rng: networks.PRNGKey - action_tm1: networks.Action - recurrent_state: Optional[actor_core.RecurrentState] = None + + rng: networks.PRNGKey + action_tm1: networks.Action + recurrent_state: Optional[actor_core.RecurrentState] = None FeedForwardPolicyState = PolicyPriorState[actor_core.NoneType] @@ -62,7 +67,7 @@ class PolicyPriorState(Generic[actor_core.RecurrentState]): def feed_forward_policy_prior_to_actor_core( policy: actor_core.RecurrentPolicy, initial_action_tm1: networks.Action ) -> actor_core.ActorCore[PolicyPriorState, actor_core.NoneType]: - """A convenience adaptor from a feed forward policy prior to ActorCore. + """A convenience adaptor from a feed forward policy prior to ActorCore. Args: policy: A feed forward policy prior. In the planner and other components, @@ -77,34 +82,39 @@ def feed_forward_policy_prior_to_actor_core( an ActorCore representing the feed forward policy prior. 
""" - def select_action(params: networks.Params, observation: networks.Observation, - state: FeedForwardPolicyState): - rng, policy_rng = jax.random.split(state.rng) - action = policy(params, policy_rng, observation, state.action_tm1) - return action, PolicyPriorState(rng, action) + def select_action( + params: networks.Params, + observation: networks.Observation, + state: FeedForwardPolicyState, + ): + rng, policy_rng = jax.random.split(state.rng) + action = policy(params, policy_rng, observation, state.action_tm1) + return action, PolicyPriorState(rng, action) - def init(rng: networks.PRNGKey) -> FeedForwardPolicyState: - return PolicyPriorState(rng, initial_action_tm1) + def init(rng: networks.PRNGKey) -> FeedForwardPolicyState: + return PolicyPriorState(rng, initial_action_tm1) - def get_extras(unused_state: FeedForwardPolicyState) -> actor_core.NoneType: - return None + def get_extras(unused_state: FeedForwardPolicyState) -> actor_core.NoneType: + return None - return actor_core.ActorCore( - init=init, select_action=select_action, get_extras=get_extras) + return actor_core.ActorCore( + init=init, select_action=select_action, get_extras=get_extras + ) def make_ensemble_world_model( - world_model_network: mbop_networks.WorldModelNetwork) -> WorldModel: - """Creates an ensemble world model from its network.""" - return functools.partial(ensemble.apply_round_robin, - world_model_network.apply) + world_model_network: mbop_networks.WorldModelNetwork, +) -> WorldModel: + """Creates an ensemble world model from its network.""" + return functools.partial(ensemble.apply_round_robin, world_model_network.apply) def make_ensemble_policy_prior( policy_prior_network: mbop_networks.PolicyPriorNetwork, spec: specs.EnvironmentSpec, - use_round_robin: bool = True) -> PolicyPrior: - """Creates an ensemble policy prior from its network. + use_round_robin: bool = True, +) -> PolicyPrior: + """Creates an ensemble policy prior from its network. Args: policy_prior_network: The policy prior network. @@ -116,26 +126,32 @@ def make_ensemble_policy_prior( A policy prior. """ - def _policy_prior(params: networks.Params, key: networks.PRNGKey, - observation_t: networks.Observation, - action_tm1: networks.Action) -> networks.Action: - # Regressor policies are deterministic. - del key - apply_fn = ( - ensemble.apply_round_robin if use_round_robin else ensemble.apply_mean) - return apply_fn( - policy_prior_network.apply, - params, - observation_t=observation_t, - action_tm1=action_tm1) - - dummy_action = utils.zeros_like(spec.actions) - dummy_action = utils.add_batch_dim(dummy_action) - - return feed_forward_policy_prior_to_actor_core(_policy_prior, dummy_action) + def _policy_prior( + params: networks.Params, + key: networks.PRNGKey, + observation_t: networks.Observation, + action_tm1: networks.Action, + ) -> networks.Action: + # Regressor policies are deterministic. 
+ del key + apply_fn = ( + ensemble.apply_round_robin if use_round_robin else ensemble.apply_mean + ) + return apply_fn( + policy_prior_network.apply, + params, + observation_t=observation_t, + action_tm1=action_tm1, + ) + + dummy_action = utils.zeros_like(spec.actions) + dummy_action = utils.add_batch_dim(dummy_action) + + return feed_forward_policy_prior_to_actor_core(_policy_prior, dummy_action) def make_ensemble_n_step_return( - n_step_return_network: mbop_networks.NStepReturnNetwork) -> NStepReturn: - """Creates an ensemble n-step return model from its network.""" - return functools.partial(ensemble.apply_mean, n_step_return_network.apply) + n_step_return_network: mbop_networks.NStepReturnNetwork, +) -> NStepReturn: + """Creates an ensemble n-step return model from its network.""" + return functools.partial(ensemble.apply_mean, n_step_return_network.apply) diff --git a/acme/agents/jax/mbop/mppi.py b/acme/agents/jax/mbop/mppi.py index 1731a9940a..f988c60c35 100644 --- a/acme/agents/jax/mbop/mppi.py +++ b/acme/agents/jax/mbop/mppi.py @@ -29,12 +29,13 @@ import functools from typing import Callable, Optional +import jax +import jax.numpy as jnp +from jax import random + from acme import specs from acme.agents.jax.mbop import models from acme.jax import networks -import jax -from jax import random -import jax.numpy as jnp # Function that takes (n_trajectories, horizon, action_dim) tensor of action # trajectories and (n_trajectories) vector of corresponding cumulative rewards, @@ -43,10 +44,10 @@ ActionAggregationFn = Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray] -def return_weighted_average(action_trajectories: jnp.ndarray, - cum_reward: jnp.ndarray, - kappa: float) -> jnp.ndarray: - r"""Calculates return-weighted average over all trajectories. +def return_weighted_average( + action_trajectories: jnp.ndarray, cum_reward: jnp.ndarray, kappa: float +) -> jnp.ndarray: + r"""Calculates return-weighted average over all trajectories. This will calculate the return-weighted average over a set of trajectories as defined on l.17 of Alg. 2 in the MBOP paper: @@ -66,18 +67,18 @@ def return_weighted_average(action_trajectories: jnp.ndarray, Single action trajectory corresponding to the return-weighted average of the trajectories. """ - # Substract maximum reward to avoid NaNs: - cum_reward = cum_reward - cum_reward.max() - # Remove the batch dimension of cum_reward allows for an implicit broadcast in - # jnp.average: - exp_cum_reward = jnp.exp(kappa * jnp.squeeze(cum_reward)) - return jnp.average(action_trajectories, weights=exp_cum_reward, axis=0) + # Substract maximum reward to avoid NaNs: + cum_reward = cum_reward - cum_reward.max() + # Remove the batch dimension of cum_reward allows for an implicit broadcast in + # jnp.average: + exp_cum_reward = jnp.exp(kappa * jnp.squeeze(cum_reward)) + return jnp.average(action_trajectories, weights=exp_cum_reward, axis=0) -def return_top_k_average(action_trajectories: jnp.ndarray, - cum_reward: jnp.ndarray, - k: int = 10) -> jnp.ndarray: - r"""Calculates the top-k average over all trajectories. +def return_top_k_average( + action_trajectories: jnp.ndarray, cum_reward: jnp.ndarray, k: int = 10 +) -> jnp.ndarray: + r"""Calculates the top-k average over all trajectories. This will calculate the top-k average over a set of trajectories as defined in the POIR Paper: @@ -95,14 +96,15 @@ def return_top_k_average(action_trajectories: jnp.ndarray, Single action trajectory corresponding to the average of the k best trajectories. 
""" - top_k_trajectories = action_trajectories[jnp.argsort( - jnp.squeeze(cum_reward))[-int(k):]] - return jnp.mean(top_k_trajectories, axis=0) + top_k_trajectories = action_trajectories[ + jnp.argsort(jnp.squeeze(cum_reward))[-int(k) :] + ] + return jnp.mean(top_k_trajectories, axis=0) @dataclasses.dataclass class MPPIConfig: - """Config dataclass for MPPI-style planning, used in mppi.py. + """Config dataclass for MPPI-style planning, used in mppi.py. These variables correspond to different parameters of `MBOP-Trajopt` as defined in MBOP [https://arxiv.org/abs/2008.05556] (Alg. 2). @@ -118,23 +120,25 @@ class MPPIConfig: action_aggregation_fn: Function that aggregates action trajectories and returns a single action trajectory. """ - sigma: float = 0.8 - beta: float = 0.2 - horizon: int = 15 - n_trajectories: int = 1000 - previous_trajectory_clip: Optional[float] = None - action_aggregation_fn: ActionAggregationFn = ( - functools.partial(return_weighted_average, kappa=0.5)) + + sigma: float = 0.8 + beta: float = 0.2 + horizon: int = 15 + n_trajectories: int = 1000 + previous_trajectory_clip: Optional[float] = None + action_aggregation_fn: ActionAggregationFn = ( + functools.partial(return_weighted_average, kappa=0.5) + ) def get_initial_trajectory(config: MPPIConfig, env_spec: specs.EnvironmentSpec): - """Returns the initial empty trajectory `T_0`.""" - return jnp.zeros((max(1, config.horizon),) + env_spec.actions.shape) + """Returns the initial empty trajectory `T_0`.""" + return jnp.zeros((max(1, config.horizon),) + env_spec.actions.shape) def _repeat_n(new_batch: int, data: jnp.ndarray) -> jnp.ndarray: - """Create new batch dimension of size `new_batch` by repeating `data`.""" - return jnp.broadcast_to(data, (new_batch,) + data.shape) + """Create new batch dimension of size `new_batch` by repeating `data`.""" + return jnp.broadcast_to(data, (new_batch,) + data.shape) def mppi_planner( @@ -149,7 +153,7 @@ def mppi_planner( observation: networks.Observation, previous_trajectory: jnp.ndarray, ) -> jnp.ndarray: - """MPPI-extended trajectory optimizer. + """MPPI-extended trajectory optimizer. This implements the trajectory optimizer described in MBOP [https://arxiv.org/abs/2008.05556] (Alg. 2) which is an extended version of @@ -179,73 +183,78 @@ def mppi_planner( Returns: jnp.ndarray: Average action trajectory of shape [horizon, action_dims]. """ - action_trajectory_tm1 = previous_trajectory - policy_prior_state = policy_prior.init(random_key) - - # Broadcast so that we have n_trajectories copies of each: - observation_t = jax.tree_map( - functools.partial(_repeat_n, config.n_trajectories), observation) - action_tm1 = jnp.broadcast_to(action_trajectory_tm1[0], - (config.n_trajectories,) + - action_trajectory_tm1[0].shape) - - if config.previous_trajectory_clip is not None: - action_tm1 = jnp.clip( - action_tm1, - a_min=-config.previous_trajectory_clip, - a_max=config.previous_trajectory_clip) - - # First check if planning is unnecessary: - if config.horizon == 0: - if hasattr(policy_prior_state, 'action_tm1'): - policy_prior_state = policy_prior_state.replace(action_tm1=action_tm1) - action_set, _ = policy_prior.select_action(policy_prior_params, - observation_t, - policy_prior_state) - # Need to re-create an action trajectory from a single action. 
- return jnp.broadcast_to( - jnp.mean(action_set, axis=0), (1, action_set.shape[-1])) - - # Accumulators for returns and trajectories: - cum_reward = jnp.zeros((config.n_trajectories, 1)) - - # Generate noise once: - random_key, noise_key = random.split(random_key) - action_noise = config.sigma * random.normal(noise_key, ( - (config.horizon,) + action_tm1.shape)) - - # Initialize empty set of action trajectories for concatenation in loop: - action_trajectories = jnp.zeros((config.n_trajectories, 0) + - action_trajectory_tm1[0].shape) - - for t in range(config.horizon): - # Query policy prior for proposed action: - if hasattr(policy_prior_state, 'action_tm1'): - policy_prior_state = policy_prior_state.replace(action_tm1=action_tm1) - action_t, policy_prior_state = policy_prior.select_action( - policy_prior_params, observation_t, policy_prior_state) - # Add action noise: - action_t = action_t + action_noise[t] - # Mix action with previous trajectory's corresponding action: - action_t = (1 - - config.beta) * action_t + config.beta * action_trajectory_tm1[t] - - # Query world model to get next observation and reward: - observation_tp1, reward_t = world_model(world_model_params, observation_t, - action_t) - cum_reward += reward_t - - # Insert actions into trajectory matrix: - action_trajectories = jnp.concatenate( - [action_trajectories, - jnp.expand_dims(action_t, axis=1)], axis=1) - # Bump variable timesteps for next loop: - observation_t = observation_tp1 - action_tm1 = action_t - - # De-normalize and append the final n_step return prediction: - n_step_return_t = n_step_return(n_step_return_params, observation_t, action_t) - cum_reward += n_step_return_t - - # Average the set of `n_trajectories` trajectories into a single trajectory. - return config.action_aggregation_fn(action_trajectories, cum_reward) + action_trajectory_tm1 = previous_trajectory + policy_prior_state = policy_prior.init(random_key) + + # Broadcast so that we have n_trajectories copies of each: + observation_t = jax.tree_map( + functools.partial(_repeat_n, config.n_trajectories), observation + ) + action_tm1 = jnp.broadcast_to( + action_trajectory_tm1[0], + (config.n_trajectories,) + action_trajectory_tm1[0].shape, + ) + + if config.previous_trajectory_clip is not None: + action_tm1 = jnp.clip( + action_tm1, + a_min=-config.previous_trajectory_clip, + a_max=config.previous_trajectory_clip, + ) + + # First check if planning is unnecessary: + if config.horizon == 0: + if hasattr(policy_prior_state, "action_tm1"): + policy_prior_state = policy_prior_state.replace(action_tm1=action_tm1) + action_set, _ = policy_prior.select_action( + policy_prior_params, observation_t, policy_prior_state + ) + # Need to re-create an action trajectory from a single action. 
+ return jnp.broadcast_to(jnp.mean(action_set, axis=0), (1, action_set.shape[-1])) + + # Accumulators for returns and trajectories: + cum_reward = jnp.zeros((config.n_trajectories, 1)) + + # Generate noise once: + random_key, noise_key = random.split(random_key) + action_noise = config.sigma * random.normal( + noise_key, ((config.horizon,) + action_tm1.shape) + ) + + # Initialize empty set of action trajectories for concatenation in loop: + action_trajectories = jnp.zeros( + (config.n_trajectories, 0) + action_trajectory_tm1[0].shape + ) + + for t in range(config.horizon): + # Query policy prior for proposed action: + if hasattr(policy_prior_state, "action_tm1"): + policy_prior_state = policy_prior_state.replace(action_tm1=action_tm1) + action_t, policy_prior_state = policy_prior.select_action( + policy_prior_params, observation_t, policy_prior_state + ) + # Add action noise: + action_t = action_t + action_noise[t] + # Mix action with previous trajectory's corresponding action: + action_t = (1 - config.beta) * action_t + config.beta * action_trajectory_tm1[t] + + # Query world model to get next observation and reward: + observation_tp1, reward_t = world_model( + world_model_params, observation_t, action_t + ) + cum_reward += reward_t + + # Insert actions into trajectory matrix: + action_trajectories = jnp.concatenate( + [action_trajectories, jnp.expand_dims(action_t, axis=1)], axis=1 + ) + # Bump variable timesteps for next loop: + observation_t = observation_tp1 + action_tm1 = action_t + + # De-normalize and append the final n_step return prediction: + n_step_return_t = n_step_return(n_step_return_params, observation_t, action_t) + cum_reward += n_step_return_t + + # Average the set of `n_trajectories` trajectories into a single trajectory. + return config.action_aggregation_fn(action_trajectories, cum_reward) diff --git a/acme/agents/jax/mbop/mppi_test.py b/acme/agents/jax/mbop/mppi_test.py index e0d80a8c49..5cd366cc04 100644 --- a/acme/agents/jax/mbop/mppi_test.py +++ b/acme/agents/jax/mbop/mppi_test.py @@ -16,140 +16,141 @@ import functools from typing import Any -from acme import specs -from acme.agents.jax.mbop import ensemble -from acme.agents.jax.mbop import models -from acme.agents.jax.mbop import mppi -from acme.jax import networks as networks_lib import jax import jax.numpy as jnp import numpy as np +from absl.testing import absltest, parameterized -from absl.testing import absltest -from absl.testing import parameterized +from acme import specs +from acme.agents.jax.mbop import ensemble, models, mppi +from acme.jax import networks as networks_lib def get_fake_world_model() -> networks_lib.FeedForwardNetwork: + def apply(params: Any, observation_t: jnp.ndarray, action_t: jnp.ndarray): + del params + return observation_t, jnp.ones((action_t.shape[0], 1,)) - def apply(params: Any, observation_t: jnp.ndarray, action_t: jnp.ndarray): - del params - return observation_t, jnp.ones(( - action_t.shape[0], - 1, - )) - - return networks_lib.FeedForwardNetwork(init=lambda: None, apply=apply) + return networks_lib.FeedForwardNetwork(init=lambda: None, apply=apply) def get_fake_policy_prior() -> networks_lib.FeedForwardNetwork: - return networks_lib.FeedForwardNetwork( - init=lambda: None, - apply=lambda params, observation_t, action_tm1: action_tm1) + return networks_lib.FeedForwardNetwork( + init=lambda: None, apply=lambda params, observation_t, action_tm1: action_tm1 + ) def get_fake_n_step_return() -> networks_lib.FeedForwardNetwork: + def apply(params, observation_t, action_t): + del 
params, action_t + return jnp.ones((observation_t.shape[0], 1)) - def apply(params, observation_t, action_t): - del params, action_t - return jnp.ones((observation_t.shape[0], 1)) - - return networks_lib.FeedForwardNetwork(init=lambda: None, apply=apply) + return networks_lib.FeedForwardNetwork(init=lambda: None, apply=apply) class WeightedAverageTests(parameterized.TestCase): - - @parameterized.parameters((np.array([1, 1, 1]), 1), (np.array([0, 1, 0]), 10), - (np.array([-1, 1, -1]), 4), - (np.array([-10, 30, 0]), -0.5)) - def test_weighted_averages(self, cum_reward, kappa): - """Compares method with a local version of the exp-weighted averaging.""" - action_trajectories = jnp.reshape( - jnp.arange(3 * 10 * 4), (3, 10, 4), order='F') - averaged_trajectory = mppi.return_weighted_average( - action_trajectories=action_trajectories, - cum_reward=cum_reward, - kappa=kappa) - exp_weights = jnp.exp(kappa * cum_reward) - # Verify single-value averaging lines up with the global averaging call: - for i in range(10): - for j in range(4): - np.testing.assert_allclose( - averaged_trajectory[i, j], - jnp.sum(exp_weights * action_trajectories[:, i, j]) / - jnp.sum(exp_weights), - atol=1E-5, - rtol=1E-5) + @parameterized.parameters( + (np.array([1, 1, 1]), 1), + (np.array([0, 1, 0]), 10), + (np.array([-1, 1, -1]), 4), + (np.array([-10, 30, 0]), -0.5), + ) + def test_weighted_averages(self, cum_reward, kappa): + """Compares method with a local version of the exp-weighted averaging.""" + action_trajectories = jnp.reshape(jnp.arange(3 * 10 * 4), (3, 10, 4), order="F") + averaged_trajectory = mppi.return_weighted_average( + action_trajectories=action_trajectories, cum_reward=cum_reward, kappa=kappa + ) + exp_weights = jnp.exp(kappa * cum_reward) + # Verify single-value averaging lines up with the global averaging call: + for i in range(10): + for j in range(4): + np.testing.assert_allclose( + averaged_trajectory[i, j], + jnp.sum(exp_weights * action_trajectories[:, i, j]) + / jnp.sum(exp_weights), + atol=1e-5, + rtol=1e-5, + ) class MPPITest(parameterized.TestCase): - """This tests the MPPI planner to make sure it is correctly rolling out. + """This tests the MPPI planner to make sure it is correctly rolling out. It does not check the actual performance of the planner, as this would be a bit more complicated to set up. """ - # TODO(dulacarnold): Look at how we can check this is actually finding an - # optimal path through the model. 
- - def setUp(self): - super().setUp() - self.state_dims = 8 - self.action_dims = 4 - self.params = { - 'world': jnp.ones((3,)), - 'policy': jnp.ones((3,)), - 'value': jnp.ones((3,)) - } - self.env_spec = specs.EnvironmentSpec( - observations=specs.Array(shape=(self.state_dims,), dtype=float), - actions=specs.Array(shape=(self.action_dims,), dtype=float), - rewards=specs.Array(shape=(1,), dtype=float, name='reward'), - discounts=specs.BoundedArray( - shape=(), dtype=float, minimum=0., maximum=1., name='discount')) - - @parameterized.named_parameters(('NO-PLAN', 0), ('NORMAL', 10)) - def test_planner_init(self, horizon: int): - world_model = get_fake_world_model() - rr_world_model = functools.partial(ensemble.apply_round_robin, - world_model.apply) - policy_prior = get_fake_policy_prior() - - def _rr_policy_prior(params, key, observation_t, action_tm1): - del key - return ensemble.apply_round_robin( - policy_prior.apply, - params, - observation_t=observation_t, - action_tm1=action_tm1) - - rr_policy_prior = models.feed_forward_policy_prior_to_actor_core( - _rr_policy_prior, jnp.zeros((1, self.action_dims))) - - n_step_return = get_fake_n_step_return() - n_step_return = functools.partial(ensemble.apply_mean, n_step_return.apply) - - config = mppi.MPPIConfig( - sigma=1, - beta=0.2, - horizon=horizon, - n_trajectories=9, - action_aggregation_fn=functools.partial( - mppi.return_weighted_average, kappa=1)) - previous_trajectory = mppi.get_initial_trajectory(config, self.env_spec) - key = jax.random.PRNGKey(0) - for _ in range(5): - previous_trajectory = mppi.mppi_planner( - config, - world_model=rr_world_model, - policy_prior=rr_policy_prior, - n_step_return=n_step_return, - world_model_params=self.params, - policy_prior_params=self.params, - n_step_return_params=self.params, - random_key=key, - observation=jnp.ones(self.state_dims), - previous_trajectory=previous_trajectory) - - -if __name__ == '__main__': - absltest.main() + # TODO(dulacarnold): Look at how we can check this is actually finding an + # optimal path through the model. 
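As a quick illustration of `return_weighted_average` (exercised by the weighted-average tests above): with a very small `kappa` the exponential weights flatten out and the result approaches a plain mean over trajectories, while a large `kappa` concentrates nearly all the weight on the highest-return trajectory. This is a minimal editorial sketch, not part of the patch; it assumes `acme` and `jax` are importable, and the trajectories and returns are made-up values.

import jax.numpy as jnp

from acme.agents.jax.mbop import mppi

# Two action trajectories of shape [n_trajectories=2, horizon=3, action_dim=1].
trajectories = jnp.arange(6, dtype=jnp.float32).reshape(2, 3, 1)
# Cumulative return per trajectory, shape [2, 1]; the second trajectory is better.
returns = jnp.array([[0.0], [1.0]])

near_mean = mppi.return_weighted_average(trajectories, returns, kappa=1e-6)
greedy = mppi.return_weighted_average(trajectories, returns, kappa=50.0)
# near_mean is close to trajectories.mean(axis=0); greedy is close to
# trajectories[1], the trajectory with the higher return.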
+ + def setUp(self): + super().setUp() + self.state_dims = 8 + self.action_dims = 4 + self.params = { + "world": jnp.ones((3,)), + "policy": jnp.ones((3,)), + "value": jnp.ones((3,)), + } + self.env_spec = specs.EnvironmentSpec( + observations=specs.Array(shape=(self.state_dims,), dtype=float), + actions=specs.Array(shape=(self.action_dims,), dtype=float), + rewards=specs.Array(shape=(1,), dtype=float, name="reward"), + discounts=specs.BoundedArray( + shape=(), dtype=float, minimum=0.0, maximum=1.0, name="discount" + ), + ) + + @parameterized.named_parameters(("NO-PLAN", 0), ("NORMAL", 10)) + def test_planner_init(self, horizon: int): + world_model = get_fake_world_model() + rr_world_model = functools.partial( + ensemble.apply_round_robin, world_model.apply + ) + policy_prior = get_fake_policy_prior() + + def _rr_policy_prior(params, key, observation_t, action_tm1): + del key + return ensemble.apply_round_robin( + policy_prior.apply, + params, + observation_t=observation_t, + action_tm1=action_tm1, + ) + + rr_policy_prior = models.feed_forward_policy_prior_to_actor_core( + _rr_policy_prior, jnp.zeros((1, self.action_dims)) + ) + + n_step_return = get_fake_n_step_return() + n_step_return = functools.partial(ensemble.apply_mean, n_step_return.apply) + + config = mppi.MPPIConfig( + sigma=1, + beta=0.2, + horizon=horizon, + n_trajectories=9, + action_aggregation_fn=functools.partial( + mppi.return_weighted_average, kappa=1 + ), + ) + previous_trajectory = mppi.get_initial_trajectory(config, self.env_spec) + key = jax.random.PRNGKey(0) + for _ in range(5): + previous_trajectory = mppi.mppi_planner( + config, + world_model=rr_world_model, + policy_prior=rr_policy_prior, + n_step_return=n_step_return, + world_model_params=self.params, + policy_prior_params=self.params, + n_step_return_params=self.params, + random_key=key, + observation=jnp.ones(self.state_dims), + previous_trajectory=previous_trajectory, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/jax/mbop/networks.py b/acme/agents/jax/mbop/networks.py index 76967d62b8..e144cfddb6 100644 --- a/acme/agents/jax/mbop/networks.py +++ b/acme/agents/jax/mbop/networks.py @@ -17,13 +17,13 @@ import dataclasses from typing import Any, Tuple -from acme import specs -from acme.jax import networks -from acme.jax import utils import haiku as hk import jax.numpy as jnp import numpy as np +from acme import specs +from acme.jax import networks, utils + # The term network is used in a general sense, e.g. for the CRR policy prior, it # will be a dataclass that encapsulates the networks used by the CRR (learner). WorldModelNetwork = Any @@ -33,16 +33,17 @@ @dataclasses.dataclass class MBOPNetworks: - """Container class to hold MBOP networks.""" - world_model_network: WorldModelNetwork - policy_prior_network: PolicyPriorNetwork - n_step_return_network: NStepReturnNetwork + """Container class to hold MBOP networks.""" + + world_model_network: WorldModelNetwork + policy_prior_network: PolicyPriorNetwork + n_step_return_network: NStepReturnNetwork def make_network_from_module( - module: hk.Transformed, - spec: specs.EnvironmentSpec) -> networks.FeedForwardNetwork: - """Creates a network with dummy init arguments using the specified module. + module: hk.Transformed, spec: specs.EnvironmentSpec +) -> networks.FeedForwardNetwork: + """Creates a network with dummy init arguments using the specified module. 
Args: module: Module that expects one batch axis and one features axis for its @@ -53,86 +54,91 @@ def make_network_from_module( FeedForwardNetwork whose `init` method only takes a random key, and `apply` takes an observation and action and produces an output. """ - dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations)) - dummy_action = utils.add_batch_dim(utils.zeros_like(spec.actions)) - return networks.FeedForwardNetwork( - lambda key: module.init(key, dummy_obs, dummy_action), module.apply) + dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations)) + dummy_action = utils.add_batch_dim(utils.zeros_like(spec.actions)) + return networks.FeedForwardNetwork( + lambda key: module.init(key, dummy_obs, dummy_action), module.apply + ) def make_world_model_network( spec: specs.EnvironmentSpec, hidden_layer_sizes: Tuple[int, ...] = (64, 64) ) -> networks.FeedForwardNetwork: - """Creates a world model network used by the agent.""" + """Creates a world model network used by the agent.""" - observation_size = np.prod(spec.observations.shape, dtype=int) + observation_size = np.prod(spec.observations.shape, dtype=int) - def _world_model_fn(observation_t, action_t, is_training=False, key=None): - # is_training and key allows to defined train/test dependant modules - # like dropout. - del is_training - del key - network = hk.nets.MLP(hidden_layer_sizes + (observation_size + 1,)) - # World model returns both an observation and a reward. - observation_tp1, reward_t = jnp.split( - network(jnp.concatenate([observation_t, action_t], axis=-1)), - [observation_size], - axis=-1) - return observation_tp1, reward_t + def _world_model_fn(observation_t, action_t, is_training=False, key=None): + # is_training and key allows to defined train/test dependant modules + # like dropout. + del is_training + del key + network = hk.nets.MLP(hidden_layer_sizes + (observation_size + 1,)) + # World model returns both an observation and a reward. + observation_tp1, reward_t = jnp.split( + network(jnp.concatenate([observation_t, action_t], axis=-1)), + [observation_size], + axis=-1, + ) + return observation_tp1, reward_t - world_model = hk.without_apply_rng(hk.transform(_world_model_fn)) - return make_network_from_module(world_model, spec) + world_model = hk.without_apply_rng(hk.transform(_world_model_fn)) + return make_network_from_module(world_model, spec) def make_policy_prior_network( spec: specs.EnvironmentSpec, hidden_layer_sizes: Tuple[int, ...] = (64, 64) ) -> networks.FeedForwardNetwork: - """Creates a policy prior network used by the agent.""" + """Creates a policy prior network used by the agent.""" - action_size = np.prod(spec.actions.shape, dtype=int) + action_size = np.prod(spec.actions.shape, dtype=int) - def _policy_prior_fn(observation_t, action_tm1, is_training=False, key=None): - # is_training and key allows to defined train/test dependant modules - # like dropout. - del is_training - del key - network = hk.nets.MLP(hidden_layer_sizes + (action_size,)) - # Policy prior returns an action. - return network(jnp.concatenate([observation_t, action_tm1], axis=-1)) + def _policy_prior_fn(observation_t, action_tm1, is_training=False, key=None): + # is_training and key allows to defined train/test dependant modules + # like dropout. + del is_training + del key + network = hk.nets.MLP(hidden_layer_sizes + (action_size,)) + # Policy prior returns an action. 
+ return network(jnp.concatenate([observation_t, action_tm1], axis=-1)) - policy_prior = hk.without_apply_rng(hk.transform(_policy_prior_fn)) - return make_network_from_module(policy_prior, spec) + policy_prior = hk.without_apply_rng(hk.transform(_policy_prior_fn)) + return make_network_from_module(policy_prior, spec) def make_n_step_return_network( spec: specs.EnvironmentSpec, hidden_layer_sizes: Tuple[int, ...] = (64, 64) ) -> networks.FeedForwardNetwork: - """Creates an N-step return network used by the agent.""" + """Creates an N-step return network used by the agent.""" - def _n_step_return_fn(observation_t, action_t, is_training=False, key=None): - # is_training and key allows to defined train/test dependant modules - # like dropout. - del is_training - del key - network = hk.nets.MLP(hidden_layer_sizes + (1,)) - return network(jnp.concatenate([observation_t, action_t], axis=-1)) + def _n_step_return_fn(observation_t, action_t, is_training=False, key=None): + # is_training and key allows to defined train/test dependant modules + # like dropout. + del is_training + del key + network = hk.nets.MLP(hidden_layer_sizes + (1,)) + return network(jnp.concatenate([observation_t, action_t], axis=-1)) - n_step_return = hk.without_apply_rng(hk.transform(_n_step_return_fn)) - return make_network_from_module(n_step_return, spec) + n_step_return = hk.without_apply_rng(hk.transform(_n_step_return_fn)) + return make_network_from_module(n_step_return, spec) def make_networks( - spec: specs.EnvironmentSpec, - hidden_layer_sizes: Tuple[int, ...] = (64, 64), + spec: specs.EnvironmentSpec, hidden_layer_sizes: Tuple[int, ...] = (64, 64), ) -> MBOPNetworks: - """Creates networks used by the agent.""" - world_model_network = make_world_model_network( - spec, hidden_layer_sizes=hidden_layer_sizes) - policy_prior_network = make_policy_prior_network( - spec, hidden_layer_sizes=hidden_layer_sizes) - n_step_return_network = make_n_step_return_network( - spec, hidden_layer_sizes=hidden_layer_sizes) - - return MBOPNetworks( - world_model_network=world_model_network, - policy_prior_network=policy_prior_network, - n_step_return_network=n_step_return_network) + """Creates networks used by the agent.""" + world_model_network = make_world_model_network( + spec, hidden_layer_sizes=hidden_layer_sizes + ) + policy_prior_network = make_policy_prior_network( + spec, hidden_layer_sizes=hidden_layer_sizes + ) + n_step_return_network = make_n_step_return_network( + spec, hidden_layer_sizes=hidden_layer_sizes + ) + + return MBOPNetworks( + world_model_network=world_model_network, + policy_prior_network=policy_prior_network, + n_step_return_network=n_step_return_network, + ) diff --git a/acme/agents/jax/mpo/__init__.py b/acme/agents/jax/mpo/__init__.py index f957ebceb6..ab428c92a0 100644 --- a/acme/agents/jax/mpo/__init__.py +++ b/acme/agents/jax/mpo/__init__.py @@ -14,14 +14,14 @@ """MPO agent module.""" -from acme.agents.jax.mpo.acting import ActorState -from acme.agents.jax.mpo.acting import make_actor_core +from acme.agents.jax.mpo.acting import ActorState, make_actor_core from acme.agents.jax.mpo.builder import MPOBuilder from acme.agents.jax.mpo.config import MPOConfig from acme.agents.jax.mpo.learning import MPOLearner -from acme.agents.jax.mpo.networks import make_control_networks -from acme.agents.jax.mpo.networks import MPONetworks -from acme.agents.jax.mpo.types import CategoricalPolicyLossConfig -from acme.agents.jax.mpo.types import CriticType -from acme.agents.jax.mpo.types import GaussianPolicyLossConfig -from 
acme.agents.jax.mpo.types import PolicyLossConfig +from acme.agents.jax.mpo.networks import MPONetworks, make_control_networks +from acme.agents.jax.mpo.types import ( + CategoricalPolicyLossConfig, + CriticType, + GaussianPolicyLossConfig, + PolicyLossConfig, +) diff --git a/acme/agents/jax/mpo/acting.py b/acme/agents/jax/mpo/acting.py index d21ae72ee5..014c212f3b 100644 --- a/acme/agents/jax/mpo/acting.py +++ b/acme/agents/jax/mpo/acting.py @@ -16,74 +16,84 @@ from typing import Mapping, NamedTuple, Tuple, Union -from acme.agents.jax import actor_core as actor_core_lib -from acme.agents.jax.mpo import networks -from acme.agents.jax.mpo import types -from acme.jax import types as jax_types import haiku as hk import jax import jax.numpy as jnp import numpy as np +from acme.agents.jax import actor_core as actor_core_lib +from acme.agents.jax.mpo import networks, types +from acme.jax import types as jax_types + class ActorState(NamedTuple): - key: jax_types.PRNGKey - core_state: hk.LSTMState - prev_core_state: hk.LSTMState - log_prob: Union[jnp.ndarray, Tuple[()]] = () - - -def make_actor_core(mpo_networks: networks.MPONetworks, - stochastic: bool = True, - store_core_state: bool = False, - store_log_prob: bool = True) -> actor_core_lib.ActorCore: - """Returns a MPO ActorCore from the MPONetworks.""" - - def init(key: jax_types.PRNGKey) -> ActorState: - next_key, key = jax.random.split(key, 2) - batch_size = None - params_initial_state = mpo_networks.torso.initial_state_fn_init( - key, batch_size) - core_state = mpo_networks.torso.initial_state_fn(params_initial_state, - batch_size) - return ActorState( - key=next_key, - core_state=core_state, - prev_core_state=core_state, - log_prob=np.zeros(shape=(), dtype=np.float32) if store_log_prob else ()) - - def select_action(params: networks.MPONetworkParams, - observations: types.Observation, - state: ActorState) -> Tuple[types.Action, ActorState]: - - next_key, key = jax.random.split(state.key, 2) - - # Embed observations and apply stateful core (e.g. recurrent, transformer). - embeddings, core_state = mpo_networks.torso.apply(params.torso, - observations, - state.core_state) - - # Get the action distribution for these observations. - policy = mpo_networks.policy_head_apply(params, embeddings) - actions = policy.sample(seed=key) if stochastic else policy.mode() - - return actions, ActorState( - key=next_key, - core_state=core_state, - prev_core_state=state.core_state, - # Compute log-probabilities for use in off-policy correction schemes. 
- log_prob=policy.log_prob(actions) if store_log_prob else ()) - - def get_extras(state: ActorState) -> Mapping[str, jnp.ndarray]: - extras = {} - - if store_core_state: - extras['core_state'] = state.prev_core_state - - if store_log_prob: - extras['log_prob'] = state.log_prob - - return extras # pytype: disable=bad-return-type # jax-ndarray - - return actor_core_lib.ActorCore( - init=init, select_action=select_action, get_extras=get_extras) + key: jax_types.PRNGKey + core_state: hk.LSTMState + prev_core_state: hk.LSTMState + log_prob: Union[jnp.ndarray, Tuple[()]] = () + + +def make_actor_core( + mpo_networks: networks.MPONetworks, + stochastic: bool = True, + store_core_state: bool = False, + store_log_prob: bool = True, +) -> actor_core_lib.ActorCore: + """Returns a MPO ActorCore from the MPONetworks.""" + + def init(key: jax_types.PRNGKey) -> ActorState: + next_key, key = jax.random.split(key, 2) + batch_size = None + params_initial_state = mpo_networks.torso.initial_state_fn_init(key, batch_size) + core_state = mpo_networks.torso.initial_state_fn( + params_initial_state, batch_size + ) + return ActorState( + key=next_key, + core_state=core_state, + prev_core_state=core_state, + log_prob=np.zeros(shape=(), dtype=np.float32) if store_log_prob else (), + ) + + def select_action( + params: networks.MPONetworkParams, + observations: types.Observation, + state: ActorState, + ) -> Tuple[types.Action, ActorState]: + + next_key, key = jax.random.split(state.key, 2) + + # Embed observations and apply stateful core (e.g. recurrent, transformer). + embeddings, core_state = mpo_networks.torso.apply( + params.torso, observations, state.core_state + ) + + # Get the action distribution for these observations. + policy = mpo_networks.policy_head_apply(params, embeddings) + actions = policy.sample(seed=key) if stochastic else policy.mode() + + return ( + actions, + ActorState( + key=next_key, + core_state=core_state, + prev_core_state=state.core_state, + # Compute log-probabilities for use in off-policy correction schemes. + log_prob=policy.log_prob(actions) if store_log_prob else (), + ), + ) + + def get_extras(state: ActorState) -> Mapping[str, jnp.ndarray]: + extras = {} + + if store_core_state: + extras["core_state"] = state.prev_core_state + + if store_log_prob: + extras["log_prob"] = state.log_prob + + return extras # pytype: disable=bad-return-type # jax-ndarray + + return actor_core_lib.ActorCore( + init=init, select_action=select_action, get_extras=get_extras + ) diff --git a/acme/agents/jax/mpo/builder.py b/acme/agents/jax/mpo/builder.py index 45abbca5b5..54be8985c8 100644 --- a/acme/agents/jax/mpo/builder.py +++ b/acme/agents/jax/mpo/builder.py @@ -17,14 +17,21 @@ import functools from typing import Iterator, List, Optional +import chex +import jax +import optax +import reverb + +# Acme loves Reverb. 
+import tensorflow as tf +import tree from absl import logging -from acme import core -from acme import specs + +from acme import core, specs from acme.adders import base from acme.adders import reverb as adders from acme.agents.jax import actor_core as actor_core_lib -from acme.agents.jax import actors -from acme.agents.jax import builders +from acme.agents.jax import actors, builders from acme.agents.jax.mpo import acting from acme.agents.jax.mpo import config as mpo_config from acme.agents.jax.mpo import learning @@ -34,309 +41,336 @@ from acme.datasets import reverb as datasets from acme.jax import observation_stacking as obs_stacking from acme.jax import types as jax_types -from acme.jax import utils -from acme.jax import variable_utils -from acme.utils import counting -from acme.utils import loggers -import chex -import jax -import optax -import reverb -# Acme loves Reverb. -import tensorflow as tf -import tree +from acme.jax import utils, variable_utils +from acme.utils import counting, loggers -_POLICY_KEY = 'policy' -_QUEUE_TABLE_NAME = 'queue_table' +_POLICY_KEY = "policy" +_QUEUE_TABLE_NAME = "queue_table" class MPOBuilder(builders.ActorLearnerBuilder): - """Builder class for MPO agent components.""" - - def __init__(self, - config: mpo_config.MPOConfig, - *, - sgd_steps_per_learner_step: int = 8, - max_learner_steps: Optional[int] = None): - self.config = config - self.sgd_steps_per_learner_step = sgd_steps_per_learner_step - self._max_learner_steps = max_learner_steps - - def make_policy( - self, - networks: mpo_networks.MPONetworks, - environment_spec: specs.EnvironmentSpec, - evaluation: bool = False, - ) -> actor_core_lib.ActorCore: - actor_core = acting.make_actor_core( - networks, - stochastic=not evaluation, - store_core_state=self.config.use_stale_state, - store_log_prob=self.config.use_retrace) - - # Maybe wrap the actor core to perform actor-side observation stacking. - if self.config.num_stacked_observations > 1: - actor_core = obs_stacking.wrap_actor_core( - actor_core, - observation_spec=environment_spec.observations, - num_stacked_observations=self.config.num_stacked_observations) - - return actor_core - - def make_actor( - self, - random_key: jax_types.PRNGKey, - policy: actor_core_lib.ActorCore, - environment_spec: specs.EnvironmentSpec, - variable_source: Optional[core.VariableSource] = None, - adder: Optional[base.Adder] = None, - ) -> core.Actor: - - del environment_spec # This actor doesn't need the spec beyond the policy. - variable_client = variable_utils.VariableClient( - client=variable_source, - key=_POLICY_KEY, - update_period=self.config.variable_update_period) - - return actors.GenericActor( - actor=policy, - random_key=random_key, - variable_client=variable_client, - adder=adder, - backend='cpu') - - def make_learner(self, - random_key: jax_types.PRNGKey, - networks: mpo_networks.MPONetworks, - dataset: Iterator[reverb.ReplaySample], - logger_fn: loggers.LoggerFactory, - environment_spec: specs.EnvironmentSpec, - replay_client: Optional[reverb.Client] = None, - counter: Optional[counting.Counter] = None) -> core.Learner: - # Set defaults. - del replay_client # Unused as we do not update priorities. - learning_rate = self.config.learning_rate - - # Make sure we can split the batches evenly across all accelerator devices. - num_learner_devices = jax.device_count() - if self.config.batch_size % num_learner_devices > 0: - raise ValueError( - 'Batch size must divide evenly by the number of learner devices.' 
- f' Passed a batch size of {self.config.batch_size} and the number of' - f' available learner devices is {num_learner_devices}. Specifically,' - f' devices: {jax.devices()}.') - - agent_environment_spec = environment_spec - if self.config.num_stacked_observations > 1: - # Adjust the observation spec for the agent-side frame-stacking. - # Note: this is only for the ActorCore's benefit, the adders want the true - # environment spec. - agent_environment_spec = obs_stacking.get_adjusted_environment_spec( - agent_environment_spec, self.config.num_stacked_observations) - - if self.config.use_cosine_lr_decay: - learning_rate = optax.warmup_cosine_decay_schedule( - init_value=0., - peak_value=self.config.learning_rate, - warmup_steps=self.config.cosine_lr_decay_warmup_steps, - decay_steps=self._max_learner_steps) - - optimizer = optax.adamw( - learning_rate, - b1=self.config.adam_b1, - b2=self.config.adam_b2, - weight_decay=self.config.weight_decay) - # TODO(abef): move LR scheduling and optimizer creation into launcher. - - loss_scales_config = mpo_types.LossScalesConfig( - policy=self.config.policy_loss_scale, - critic=self.config.critic_loss_scale, - rollout=mpo_types.RolloutLossScalesConfig( - policy=self.config.rollout_policy_loss_scale, - bc_policy=self.config.rollout_bc_policy_loss_scale, - critic=self.config.rollout_critic_loss_scale, - reward=self.config.rollout_reward_loss_scale, - )) - - logger = logger_fn( - 'learner', - steps_key=counter.get_steps_key() if counter else 'learner_steps') - - with chex.fake_pmap_and_jit(not self.config.jit_learner, - not self.config.jit_learner): - learner = learning.MPOLearner( - iterator=dataset, - networks=networks, - environment_spec=agent_environment_spec, - critic_type=self.config.critic_type, - discrete_policy=self.config.discrete_policy, - random_key=random_key, - discount=self.config.discount, - num_samples=self.config.num_samples, - policy_eval_stochastic=self.config.policy_eval_stochastic, - policy_eval_num_val_samples=self.config.policy_eval_num_val_samples, - policy_loss_config=self.config.policy_loss_config, - loss_scales=loss_scales_config, - target_update_period=self.config.target_update_period, - target_update_rate=self.config.target_update_rate, - experience_type=self.config.experience_type, - use_online_policy_to_bootstrap=( - self.config.use_online_policy_to_bootstrap), - use_stale_state=self.config.use_stale_state, - use_retrace=self.config.use_retrace, - retrace_lambda=self.config.retrace_lambda, - model_rollout_length=self.config.model_rollout_length, - sgd_steps_per_learner_step=self.sgd_steps_per_learner_step, - optimizer=optimizer, - learning_rate=learning_rate, - dual_optimizer=optax.adam(self.config.dual_learning_rate), - grad_norm_clip=self.config.grad_norm_clip, - reward_clip=self.config.reward_clip, - value_tx_pair=self.config.value_tx_pair, - counter=counter, - logger=logger, - devices=jax.devices(), - ) - return learner - - def make_replay_tables( - self, - environment_spec: specs.EnvironmentSpec, - policy: actor_core_lib.ActorCore, # Used to get accurate extras_spec. 
- ) -> List[reverb.Table]: - dummy_actor_state = policy.init(jax.random.PRNGKey(0)) - extras_spec = policy.get_extras(dummy_actor_state) - - if isinstance(self.config.experience_type, mpo_types.FromTransitions): - signature = adders.NStepTransitionAdder.signature(environment_spec, - extras_spec) - elif isinstance(self.config.experience_type, mpo_types.FromSequences): - sequence_length = ( - self.config.experience_type.sequence_length + - self.config.num_stacked_observations - 1) - signature = adders.SequenceAdder.signature( - environment_spec, extras_spec, sequence_length=sequence_length) - # TODO(bshahr): This way of obtaining the signature is error-prone. Find a - # programmatic way via make_adder. - - # Create the rate limiter. - if self.config.samples_per_insert: - # Create enough of an error buffer to give a 10% tolerance in rate. - samples_per_insert_tolerance = 0.1 * self.config.samples_per_insert - error_buffer = self.config.min_replay_size * samples_per_insert_tolerance - limiter = reverb.rate_limiters.SampleToInsertRatio( - min_size_to_sample=self.config.min_replay_size, - samples_per_insert=self.config.samples_per_insert, - error_buffer=max(error_buffer, 2 * self.config.samples_per_insert)) - else: - limiter = reverb.rate_limiters.MinSize(self.config.min_replay_size) - - # Reverb loves Acme. - replay_extensions = [] - queue_extensions = [] - - - # Create replay tables. - tables = [] - if self.config.replay_fraction > 0: - replay_table = reverb.Table( - name=adders.DEFAULT_PRIORITY_TABLE, - sampler=reverb.selectors.Uniform(), - remover=reverb.selectors.Fifo(), - max_size=self.config.max_replay_size, - rate_limiter=limiter, - extensions=replay_extensions, - signature=signature) - tables.append(replay_table) - logging.info( - 'Creating off-policy replay buffer with replay fraction %g ' - 'of batch %d', self.config.replay_fraction, self.config.batch_size) - - if self.config.replay_fraction < 1: - # Create a FIFO queue. This will provide the rate limitation if used. - queue = reverb.Table.queue( - name=_QUEUE_TABLE_NAME, - max_size=self.config.online_queue_capacity, - extensions=queue_extensions, - signature=signature) - tables.append(queue) - logging.info( - 'Creating online replay queue with queue fraction %g ' - 'of batch %d', 1.0 - self.config.replay_fraction, - self.config.batch_size) - - return tables - - def make_adder( - self, - replay_client: reverb.Client, - environment_spec: Optional[specs.EnvironmentSpec], - policy: Optional[actor_core_lib.ActorCore], - ) -> Optional[base.Adder]: - del environment_spec, policy - # Specify the tables to insert into but don't use prioritization. 
- priority_fns = {} - if self.config.replay_fraction > 0: - priority_fns[adders.DEFAULT_PRIORITY_TABLE] = None - if self.config.replay_fraction < 1: - priority_fns[_QUEUE_TABLE_NAME] = None - - if isinstance(self.config.experience_type, mpo_types.FromTransitions): - return adders.NStepTransitionAdder( - client=replay_client, - n_step=self.config.experience_type.n_step, - discount=self.config.discount, - priority_fns=priority_fns) - elif isinstance(self.config.experience_type, mpo_types.FromSequences): - sequence_length = ( - self.config.experience_type.sequence_length + - self.config.num_stacked_observations - 1) - return adders.SequenceAdder( - client=replay_client, - sequence_length=sequence_length, - period=self.config.experience_type.sequence_period, - end_of_episode_behavior=adders.EndBehavior.WRITE, - max_in_flight_items=1, - priority_fns=priority_fns) - - def make_dataset_iterator( - self, replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]: - - if self.config.num_stacked_observations > 1: - maybe_stack_observations = functools.partial( - obs_stacking.stack_reverb_observation, - stack_size=self.config.num_stacked_observations) - else: - maybe_stack_observations = None - - dataset = datasets.make_reverb_dataset( - server_address=replay_client.server_address, - batch_size=self.config.batch_size // jax.device_count(), - table={ - adders.DEFAULT_PRIORITY_TABLE: self.config.replay_fraction, - _QUEUE_TABLE_NAME: 1. - self.config.replay_fraction, - }, - num_parallel_calls=max(16, 4 * jax.local_device_count()), - max_in_flight_samples_per_worker=(2 * self.sgd_steps_per_learner_step * - self.config.batch_size // - jax.device_count()), - postprocess=maybe_stack_observations) - - if self.config.observation_transform: - # Augment dataset with random translations, simulated by pad-and-crop. - transform = img_aug.make_transform( - observation_transform=self.config.observation_transform, - transform_next_observation=isinstance(self.config.experience_type, - mpo_types.FromTransitions)) - dataset = dataset.map( - transform, num_parallel_calls=16, deterministic=False) - - # Batch and then flatten to feed multiple SGD steps per learner step. - if self.sgd_steps_per_learner_step > 1: - dataset = dataset.batch( - self.sgd_steps_per_learner_step, drop_remainder=True) - batch_flatten = lambda t: tf.reshape(t, [-1] + t.shape[2:].as_list()) - dataset = dataset.map(lambda x: tree.map_structure(batch_flatten, x)) - - return utils.multi_device_put(dataset.as_numpy_iterator(), - jax.local_devices()) + """Builder class for MPO agent components.""" + + def __init__( + self, + config: mpo_config.MPOConfig, + *, + sgd_steps_per_learner_step: int = 8, + max_learner_steps: Optional[int] = None, + ): + self.config = config + self.sgd_steps_per_learner_step = sgd_steps_per_learner_step + self._max_learner_steps = max_learner_steps + + def make_policy( + self, + networks: mpo_networks.MPONetworks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool = False, + ) -> actor_core_lib.ActorCore: + actor_core = acting.make_actor_core( + networks, + stochastic=not evaluation, + store_core_state=self.config.use_stale_state, + store_log_prob=self.config.use_retrace, + ) + + # Maybe wrap the actor core to perform actor-side observation stacking. 
+ if self.config.num_stacked_observations > 1: + actor_core = obs_stacking.wrap_actor_core( + actor_core, + observation_spec=environment_spec.observations, + num_stacked_observations=self.config.num_stacked_observations, + ) + + return actor_core + + def make_actor( + self, + random_key: jax_types.PRNGKey, + policy: actor_core_lib.ActorCore, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + adder: Optional[base.Adder] = None, + ) -> core.Actor: + + del environment_spec # This actor doesn't need the spec beyond the policy. + variable_client = variable_utils.VariableClient( + client=variable_source, + key=_POLICY_KEY, + update_period=self.config.variable_update_period, + ) + + return actors.GenericActor( + actor=policy, + random_key=random_key, + variable_client=variable_client, + adder=adder, + backend="cpu", + ) + + def make_learner( + self, + random_key: jax_types.PRNGKey, + networks: mpo_networks.MPONetworks, + dataset: Iterator[reverb.ReplaySample], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + # Set defaults. + del replay_client # Unused as we do not update priorities. + learning_rate = self.config.learning_rate + + # Make sure we can split the batches evenly across all accelerator devices. + num_learner_devices = jax.device_count() + if self.config.batch_size % num_learner_devices > 0: + raise ValueError( + "Batch size must divide evenly by the number of learner devices." + f" Passed a batch size of {self.config.batch_size} and the number of" + f" available learner devices is {num_learner_devices}. Specifically," + f" devices: {jax.devices()}." + ) + + agent_environment_spec = environment_spec + if self.config.num_stacked_observations > 1: + # Adjust the observation spec for the agent-side frame-stacking. + # Note: this is only for the ActorCore's benefit, the adders want the true + # environment spec. + agent_environment_spec = obs_stacking.get_adjusted_environment_spec( + agent_environment_spec, self.config.num_stacked_observations + ) + + if self.config.use_cosine_lr_decay: + learning_rate = optax.warmup_cosine_decay_schedule( + init_value=0.0, + peak_value=self.config.learning_rate, + warmup_steps=self.config.cosine_lr_decay_warmup_steps, + decay_steps=self._max_learner_steps, + ) + + optimizer = optax.adamw( + learning_rate, + b1=self.config.adam_b1, + b2=self.config.adam_b2, + weight_decay=self.config.weight_decay, + ) + # TODO(abef): move LR scheduling and optimizer creation into launcher. 
+ + loss_scales_config = mpo_types.LossScalesConfig( + policy=self.config.policy_loss_scale, + critic=self.config.critic_loss_scale, + rollout=mpo_types.RolloutLossScalesConfig( + policy=self.config.rollout_policy_loss_scale, + bc_policy=self.config.rollout_bc_policy_loss_scale, + critic=self.config.rollout_critic_loss_scale, + reward=self.config.rollout_reward_loss_scale, + ), + ) + + logger = logger_fn( + "learner", steps_key=counter.get_steps_key() if counter else "learner_steps" + ) + + with chex.fake_pmap_and_jit( + not self.config.jit_learner, not self.config.jit_learner + ): + learner = learning.MPOLearner( + iterator=dataset, + networks=networks, + environment_spec=agent_environment_spec, + critic_type=self.config.critic_type, + discrete_policy=self.config.discrete_policy, + random_key=random_key, + discount=self.config.discount, + num_samples=self.config.num_samples, + policy_eval_stochastic=self.config.policy_eval_stochastic, + policy_eval_num_val_samples=self.config.policy_eval_num_val_samples, + policy_loss_config=self.config.policy_loss_config, + loss_scales=loss_scales_config, + target_update_period=self.config.target_update_period, + target_update_rate=self.config.target_update_rate, + experience_type=self.config.experience_type, + use_online_policy_to_bootstrap=( + self.config.use_online_policy_to_bootstrap + ), + use_stale_state=self.config.use_stale_state, + use_retrace=self.config.use_retrace, + retrace_lambda=self.config.retrace_lambda, + model_rollout_length=self.config.model_rollout_length, + sgd_steps_per_learner_step=self.sgd_steps_per_learner_step, + optimizer=optimizer, + learning_rate=learning_rate, + dual_optimizer=optax.adam(self.config.dual_learning_rate), + grad_norm_clip=self.config.grad_norm_clip, + reward_clip=self.config.reward_clip, + value_tx_pair=self.config.value_tx_pair, + counter=counter, + logger=logger, + devices=jax.devices(), + ) + return learner + + def make_replay_tables( + self, + environment_spec: specs.EnvironmentSpec, + policy: actor_core_lib.ActorCore, # Used to get accurate extras_spec. + ) -> List[reverb.Table]: + dummy_actor_state = policy.init(jax.random.PRNGKey(0)) + extras_spec = policy.get_extras(dummy_actor_state) + + if isinstance(self.config.experience_type, mpo_types.FromTransitions): + signature = adders.NStepTransitionAdder.signature( + environment_spec, extras_spec + ) + elif isinstance(self.config.experience_type, mpo_types.FromSequences): + sequence_length = ( + self.config.experience_type.sequence_length + + self.config.num_stacked_observations + - 1 + ) + signature = adders.SequenceAdder.signature( + environment_spec, extras_spec, sequence_length=sequence_length + ) + # TODO(bshahr): This way of obtaining the signature is error-prone. Find a + # programmatic way via make_adder. + + # Create the rate limiter. + if self.config.samples_per_insert: + # Create enough of an error buffer to give a 10% tolerance in rate. + samples_per_insert_tolerance = 0.1 * self.config.samples_per_insert + error_buffer = self.config.min_replay_size * samples_per_insert_tolerance + limiter = reverb.rate_limiters.SampleToInsertRatio( + min_size_to_sample=self.config.min_replay_size, + samples_per_insert=self.config.samples_per_insert, + error_buffer=max(error_buffer, 2 * self.config.samples_per_insert), + ) + else: + limiter = reverb.rate_limiters.MinSize(self.config.min_replay_size) + + # Reverb loves Acme. + replay_extensions = [] + queue_extensions = [] + + # Create replay tables. 
+ tables = [] + if self.config.replay_fraction > 0: + replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self.config.max_replay_size, + rate_limiter=limiter, + extensions=replay_extensions, + signature=signature, + ) + tables.append(replay_table) + logging.info( + "Creating off-policy replay buffer with replay fraction %g " + "of batch %d", + self.config.replay_fraction, + self.config.batch_size, + ) + + if self.config.replay_fraction < 1: + # Create a FIFO queue. This will provide the rate limitation if used. + queue = reverb.Table.queue( + name=_QUEUE_TABLE_NAME, + max_size=self.config.online_queue_capacity, + extensions=queue_extensions, + signature=signature, + ) + tables.append(queue) + logging.info( + "Creating online replay queue with queue fraction %g " "of batch %d", + 1.0 - self.config.replay_fraction, + self.config.batch_size, + ) + + return tables + + def make_adder( + self, + replay_client: reverb.Client, + environment_spec: Optional[specs.EnvironmentSpec], + policy: Optional[actor_core_lib.ActorCore], + ) -> Optional[base.Adder]: + del environment_spec, policy + # Specify the tables to insert into but don't use prioritization. + priority_fns = {} + if self.config.replay_fraction > 0: + priority_fns[adders.DEFAULT_PRIORITY_TABLE] = None + if self.config.replay_fraction < 1: + priority_fns[_QUEUE_TABLE_NAME] = None + + if isinstance(self.config.experience_type, mpo_types.FromTransitions): + return adders.NStepTransitionAdder( + client=replay_client, + n_step=self.config.experience_type.n_step, + discount=self.config.discount, + priority_fns=priority_fns, + ) + elif isinstance(self.config.experience_type, mpo_types.FromSequences): + sequence_length = ( + self.config.experience_type.sequence_length + + self.config.num_stacked_observations + - 1 + ) + return adders.SequenceAdder( + client=replay_client, + sequence_length=sequence_length, + period=self.config.experience_type.sequence_period, + end_of_episode_behavior=adders.EndBehavior.WRITE, + max_in_flight_items=1, + priority_fns=priority_fns, + ) + + def make_dataset_iterator( + self, replay_client: reverb.Client + ) -> Iterator[reverb.ReplaySample]: + + if self.config.num_stacked_observations > 1: + maybe_stack_observations = functools.partial( + obs_stacking.stack_reverb_observation, + stack_size=self.config.num_stacked_observations, + ) + else: + maybe_stack_observations = None + + dataset = datasets.make_reverb_dataset( + server_address=replay_client.server_address, + batch_size=self.config.batch_size // jax.device_count(), + table={ + adders.DEFAULT_PRIORITY_TABLE: self.config.replay_fraction, + _QUEUE_TABLE_NAME: 1.0 - self.config.replay_fraction, + }, + num_parallel_calls=max(16, 4 * jax.local_device_count()), + max_in_flight_samples_per_worker=( + 2 + * self.sgd_steps_per_learner_step + * self.config.batch_size + // jax.device_count() + ), + postprocess=maybe_stack_observations, + ) + + if self.config.observation_transform: + # Augment dataset with random translations, simulated by pad-and-crop. + transform = img_aug.make_transform( + observation_transform=self.config.observation_transform, + transform_next_observation=isinstance( + self.config.experience_type, mpo_types.FromTransitions + ), + ) + dataset = dataset.map(transform, num_parallel_calls=16, deterministic=False) + + # Batch and then flatten to feed multiple SGD steps per learner step. 
+ if self.sgd_steps_per_learner_step > 1: + dataset = dataset.batch( + self.sgd_steps_per_learner_step, drop_remainder=True + ) + batch_flatten = lambda t: tf.reshape(t, [-1] + t.shape[2:].as_list()) + dataset = dataset.map(lambda x: tree.map_structure(batch_flatten, x)) + + return utils.multi_device_put(dataset.as_numpy_iterator(), jax.local_devices()) diff --git a/acme/agents/jax/mpo/categorical_mpo.py b/acme/agents/jax/mpo/categorical_mpo.py index 0faea932f7..74b5d2e09e 100644 --- a/acme/agents/jax/mpo/categorical_mpo.py +++ b/acme/agents/jax/mpo/categorical_mpo.py @@ -36,46 +36,50 @@ class CategoricalMPOParams(NamedTuple): - """NamedTuple to store trainable loss parameters.""" - log_temperature: jnp.ndarray - log_alpha: jnp.ndarray + """NamedTuple to store trainable loss parameters.""" + + log_temperature: jnp.ndarray + log_alpha: jnp.ndarray class CategoricalMPOStats(NamedTuple): - """NamedTuple to store loss statistics.""" - dual_alpha: float - dual_temperature: float + """NamedTuple to store loss statistics.""" + + dual_alpha: float + dual_temperature: float - loss_e_step: float - loss_m_step: float - loss_dual: float + loss_e_step: float + loss_m_step: float + loss_dual: float - loss_policy: float - loss_alpha: float - loss_temperature: float + loss_policy: float + loss_alpha: float + loss_temperature: float - kl_q_rel: float - kl_mean_rel: float + kl_q_rel: float + kl_mean_rel: float - q_min: float - q_max: float + q_min: float + q_max: float - entropy_online: float - entropy_target: float + entropy_online: float + entropy_target: float class CategoricalMPO: - """MPO loss for a categorical policy (Abdolmaleki et al., 2018). + """MPO loss for a categorical policy (Abdolmaleki et al., 2018). (Abdolmaleki et al., 2018): https://arxiv.org/pdf/1812.02256.pdf """ - def __init__(self, - epsilon: float, - epsilon_policy: float, - init_log_temperature: float, - init_log_alpha: float): - """Initializes the MPO loss for discrete (categorical) policies. + def __init__( + self, + epsilon: float, + epsilon_policy: float, + init_log_temperature: float, + init_log_alpha: float, + ): + """Initializes the MPO loss for discrete (categorical) policies. Args: epsilon: KL constraint on the non-parametric auxiliary policy, the one @@ -88,30 +92,31 @@ def __init__(self, (rather than an exp) will be used to transform this. """ - # MPO constraint thresholds. - self._epsilon = epsilon - self._epsilon_policy = epsilon_policy - - # Initial values for the constraints' dual variables. - self._init_log_temperature = init_log_temperature - self._init_log_alpha = init_log_alpha - - def init_params(self, action_dim: int, dtype: DType = jnp.float32): - """Creates an initial set of parameters.""" - del action_dim # Unused. - return CategoricalMPOParams( - log_temperature=jnp.full([1], self._init_log_temperature, dtype=dtype), - log_alpha=jnp.full([1], self._init_log_alpha, dtype=dtype)) - - def __call__( - self, - params: CategoricalMPOParams, - online_action_distribution: distrax.Categorical, - target_action_distribution: distrax.Categorical, - actions: jnp.ndarray, # Unused. - q_values: jnp.ndarray, # Shape [D, B]. - ) -> Tuple[jnp.ndarray, CategoricalMPOStats]: - """Computes the MPO loss for a categorical policy. + # MPO constraint thresholds. + self._epsilon = epsilon + self._epsilon_policy = epsilon_policy + + # Initial values for the constraints' dual variables. 
+ self._init_log_temperature = init_log_temperature + self._init_log_alpha = init_log_alpha + + def init_params(self, action_dim: int, dtype: DType = jnp.float32): + """Creates an initial set of parameters.""" + del action_dim # Unused. + return CategoricalMPOParams( + log_temperature=jnp.full([1], self._init_log_temperature, dtype=dtype), + log_alpha=jnp.full([1], self._init_log_alpha, dtype=dtype), + ) + + def __call__( + self, + params: CategoricalMPOParams, + online_action_distribution: distrax.Categorical, + target_action_distribution: distrax.Categorical, + actions: jnp.ndarray, # Unused. + q_values: jnp.ndarray, # Shape [D, B]. + ) -> Tuple[jnp.ndarray, CategoricalMPOStats]: + """Computes the MPO loss for a categorical policy. Args: params: parameters tracking the temperature and the dual variables. @@ -128,65 +133,73 @@ def __call__( Stats, for diagnostics and tracking performance. """ - q_values = jnp.transpose(q_values) # [D, B] --> [B, D]. - - # Transform dual variables from log-space. - # Note: using softplus instead of exponential for numerical stability. - temperature = get_temperature_from_params(params) - alpha = jax.nn.softplus(params.log_alpha) + _MPO_FLOAT_EPSILON - - # Compute the E-step logits and the temperature loss, used to adapt the - # tempering of Q-values. - logits_e_step, loss_temperature = compute_weights_and_temperature_loss( # pytype: disable=wrong-arg-types # jax-ndarray - q_values=q_values, logits=target_action_distribution.logits, - epsilon=self._epsilon, temperature=temperature) - action_distribution_e_step = distrax.Categorical(logits=logits_e_step) - - # Only needed for diagnostics: Compute estimated actualized KL between the - # non-parametric and current target policies. - kl_nonparametric = action_distribution_e_step.kl_divergence( - target_action_distribution) - - # Compute the policy loss. - loss_policy = action_distribution_e_step.cross_entropy( - online_action_distribution) - loss_policy = jnp.mean(loss_policy) - - # Compute the regularization. - kl = target_action_distribution.kl_divergence(online_action_distribution) - mean_kl = jnp.mean(kl, axis=0) - loss_kl = jax.lax.stop_gradient(alpha) * mean_kl - - # Compute the dual loss. - loss_alpha = alpha * (self._epsilon_policy - jax.lax.stop_gradient(mean_kl)) - - # Combine losses. - loss_dual = loss_alpha + loss_temperature - loss = loss_policy + loss_kl + loss_dual - - # Create statistics. - stats = CategoricalMPOStats( - # Dual Variables. - dual_alpha=jnp.mean(alpha), - dual_temperature=jnp.mean(temperature), - # Losses. - loss_e_step=loss_policy, - loss_m_step=loss_kl, - loss_dual=loss_dual, - loss_policy=jnp.mean(loss), - loss_alpha=jnp.mean(loss_alpha), - loss_temperature=jnp.mean(loss_temperature), - # KL measurements. - kl_q_rel=jnp.mean(kl_nonparametric) / self._epsilon, - kl_mean_rel=mean_kl / self._epsilon_policy, - # Q measurements. - q_min=jnp.mean(jnp.min(q_values, axis=0)), - q_max=jnp.mean(jnp.max(q_values, axis=0)), - entropy_online=jnp.mean(online_action_distribution.entropy()), - entropy_target=jnp.mean(target_action_distribution.entropy()) - ) - - return loss, stats + q_values = jnp.transpose(q_values) # [D, B] --> [B, D]. + + # Transform dual variables from log-space. + # Note: using softplus instead of exponential for numerical stability. + temperature = get_temperature_from_params(params) + alpha = jax.nn.softplus(params.log_alpha) + _MPO_FLOAT_EPSILON + + # Compute the E-step logits and the temperature loss, used to adapt the + # tempering of Q-values. 
+ ( + logits_e_step, + loss_temperature, + ) = compute_weights_and_temperature_loss( # pytype: disable=wrong-arg-types # jax-ndarray + q_values=q_values, + logits=target_action_distribution.logits, + epsilon=self._epsilon, + temperature=temperature, + ) + action_distribution_e_step = distrax.Categorical(logits=logits_e_step) + + # Only needed for diagnostics: Compute estimated actualized KL between the + # non-parametric and current target policies. + kl_nonparametric = action_distribution_e_step.kl_divergence( + target_action_distribution + ) + + # Compute the policy loss. + loss_policy = action_distribution_e_step.cross_entropy( + online_action_distribution + ) + loss_policy = jnp.mean(loss_policy) + + # Compute the regularization. + kl = target_action_distribution.kl_divergence(online_action_distribution) + mean_kl = jnp.mean(kl, axis=0) + loss_kl = jax.lax.stop_gradient(alpha) * mean_kl + + # Compute the dual loss. + loss_alpha = alpha * (self._epsilon_policy - jax.lax.stop_gradient(mean_kl)) + + # Combine losses. + loss_dual = loss_alpha + loss_temperature + loss = loss_policy + loss_kl + loss_dual + + # Create statistics. + stats = CategoricalMPOStats( + # Dual Variables. + dual_alpha=jnp.mean(alpha), + dual_temperature=jnp.mean(temperature), + # Losses. + loss_e_step=loss_policy, + loss_m_step=loss_kl, + loss_dual=loss_dual, + loss_policy=jnp.mean(loss), + loss_alpha=jnp.mean(loss_alpha), + loss_temperature=jnp.mean(loss_temperature), + # KL measurements. + kl_q_rel=jnp.mean(kl_nonparametric) / self._epsilon, + kl_mean_rel=mean_kl / self._epsilon_policy, + # Q measurements. + q_min=jnp.mean(jnp.min(q_values, axis=0)), + q_max=jnp.mean(jnp.max(q_values, axis=0)), + entropy_online=jnp.mean(online_action_distribution.entropy()), + entropy_target=jnp.mean(target_action_distribution.entropy()), + ) + + return loss, stats def compute_weights_and_temperature_loss( @@ -195,7 +208,7 @@ def compute_weights_and_temperature_loss( epsilon: float, temperature: jnp.ndarray, ) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Computes normalized importance weights for the policy optimization. + """Computes normalized importance weights for the policy optimization. Args: q_values: Q-values associated with the actions sampled from the target @@ -214,28 +227,28 @@ def compute_weights_and_temperature_loss( Temperature loss, used to adapt the temperature. """ - # Temper the given Q-values using the current temperature. - tempered_q_values = jax.lax.stop_gradient(q_values) / temperature + # Temper the given Q-values using the current temperature. + tempered_q_values = jax.lax.stop_gradient(q_values) / temperature - # Compute the E-step normalized logits. - unnormalized_logits = tempered_q_values + jax.nn.log_softmax(logits, axis=-1) - logits_e_step = jax.nn.log_softmax(unnormalized_logits, axis=-1) + # Compute the E-step normalized logits. + unnormalized_logits = tempered_q_values + jax.nn.log_softmax(logits, axis=-1) + logits_e_step = jax.nn.log_softmax(unnormalized_logits, axis=-1) - # Compute the temperature loss (dual of the E-step optimization problem). - # Note that the log normalizer will be the same for all actions, so we choose - # only the first one. - log_normalizer = unnormalized_logits[:, 0] - logits_e_step[:, 0] - loss_temperature = temperature * (epsilon + jnp.mean(log_normalizer)) + # Compute the temperature loss (dual of the E-step optimization problem). + # Note that the log normalizer will be the same for all actions, so we choose + # only the first one. 
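# Illustrative sketch: a self-contained JAX re-derivation of the categorical MPO
# E-step weighting and dual losses computed above, written with raw log-softmax
# instead of distrax so it runs standalone. The [B, D] shapes, epsilon values and
# logits are assumptions.
import jax
import jax.numpy as jnp

epsilon, epsilon_policy = 0.1, 1e-3
temperature, alpha = jnp.asarray(1.0), jnp.asarray(1.0)
q_values = jnp.array([[1.0, 0.0, -1.0], [0.5, 0.5, 0.0]])  # [B, D]
target_logits = jnp.zeros_like(q_values)                    # Uniform target policy.
online_logits = jnp.array([[0.2, 0.1, -0.3], [0.0, 0.0, 0.0]])

# E-step: reweight the target policy by exp(Q / temperature), in log space.
unnormalized = q_values / temperature + jax.nn.log_softmax(target_logits, axis=-1)
logits_e_step = jax.nn.log_softmax(unnormalized, axis=-1)
log_normalizer = unnormalized[:, 0] - logits_e_step[:, 0]
loss_temperature = temperature * (epsilon + jnp.mean(log_normalizer))

# M-step: cross-entropy to the E-step policy plus an alpha-weighted KL trust region.
probs_e_step = jnp.exp(logits_e_step)
loss_policy = jnp.mean(
    -jnp.sum(probs_e_step * jax.nn.log_softmax(online_logits), axis=-1))
target_probs = jax.nn.softmax(target_logits)
kl = jnp.sum(
    target_probs
    * (jax.nn.log_softmax(target_logits) - jax.nn.log_softmax(online_logits)),
    axis=-1)
loss_alpha = alpha * (epsilon_policy - jax.lax.stop_gradient(jnp.mean(kl)))
loss = (loss_policy + jax.lax.stop_gradient(alpha) * jnp.mean(kl)
        + loss_alpha + loss_temperature)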
+ log_normalizer = unnormalized_logits[:, 0] - logits_e_step[:, 0] + loss_temperature = temperature * (epsilon + jnp.mean(log_normalizer)) - return logits_e_step, loss_temperature + return logits_e_step, loss_temperature -def clip_categorical_mpo_params( - params: CategoricalMPOParams) -> CategoricalMPOParams: - return params._replace( - log_temperature=jnp.maximum(_MIN_LOG_TEMPERATURE, params.log_temperature), - log_alpha=jnp.maximum(_MIN_LOG_ALPHA, params.log_alpha)) +def clip_categorical_mpo_params(params: CategoricalMPOParams) -> CategoricalMPOParams: + return params._replace( + log_temperature=jnp.maximum(_MIN_LOG_TEMPERATURE, params.log_temperature), + log_alpha=jnp.maximum(_MIN_LOG_ALPHA, params.log_alpha), + ) def get_temperature_from_params(params: CategoricalMPOParams) -> float: - return jax.nn.softplus(params.log_temperature) + _MPO_FLOAT_EPSILON + return jax.nn.softplus(params.log_temperature) + _MPO_FLOAT_EPSILON diff --git a/acme/agents/jax/mpo/config.py b/acme/agents/jax/mpo/config.py index 3e1d70ab31..187b46e0f7 100644 --- a/acme/agents/jax/mpo/config.py +++ b/acme/agents/jax/mpo/config.py @@ -17,135 +17,152 @@ import dataclasses from typing import Callable, Optional, Union -from acme import types -from acme.agents.jax.mpo import types as mpo_types import numpy as np import rlax +from acme import types +from acme.agents.jax.mpo import types as mpo_types + @dataclasses.dataclass class MPOConfig: - """MPO agent configuration.""" - - batch_size: int = 256 # Total batch size across all learner devices. - discount: float = 0.99 - discrete_policy: bool = False - - # Specification of the type of experience the learner will consume. - experience_type: mpo_types.ExperienceType = mpo_types.FromTransitions( - n_step=5) - num_stacked_observations: int = 1 - # Optional data-augmentation transformation for observations. - observation_transform: Optional[Callable[[types.NestedTensor], - types.NestedTensor]] = None - - # Specification of replay, e.g., min/max size, pure or mixed. - # NOTE: When replay_fraction = 1.0, this reverts to pure replay and the online - # queue is not created. - replay_fraction: float = 1.0 # Fraction of replay data (vs online) per batch. - samples_per_insert: Optional[float] = 32.0 - min_replay_size: int = 1_000 - max_replay_size: int = 1_000_000 - online_queue_capacity: int = 0 # If not set, will use 4 * online_batch_size. - - # Critic training configuration. - critic_type: mpo_types.CriticType = mpo_types.CriticType.MIXTURE_OF_GAUSSIANS - value_tx_pair: rlax.TxPair = rlax.IDENTITY_PAIR - use_retrace: bool = False - retrace_lambda: float = 0.95 - reward_clip: float = np.float32('inf') # pytype: disable=annotation-type-mismatch # numpy-scalars - use_online_policy_to_bootstrap: bool = False - use_stale_state: bool = False - - # Policy training configuration. - num_samples: int = 20 # Number of MPO action samples. - policy_loss_config: Optional[mpo_types.PolicyLossConfig] = None - policy_eval_stochastic: bool = True - policy_eval_num_val_samples: int = 128 - - # Optimizer configuration. - learning_rate: Union[float, Callable[[int], float]] = 1e-4 - dual_learning_rate: Union[float, Callable[[int], float]] = 1e-2 - grad_norm_clip: float = 40. - adam_b1: float = 0.9 - adam_b2: float = 0.999 - weight_decay: float = 0.0 - use_cosine_lr_decay: bool = False - cosine_lr_decay_warmup_steps: int = 3000 - - # Set the target update period or rate depending on whether you want a - # periodic or incremental (exponential weighted average) target update. 
- # Exactly one must be specified (not None). - target_update_period: Optional[int] = 100 - target_update_rate: Optional[float] = None - variable_update_period: int = 1000 - - # Configuring the mixture of policy and critic losses. - policy_loss_scale: float = 1.0 - critic_loss_scale: float = 1.0 - - # Optional roll-out loss configuration (off by default). - model_rollout_length: int = 0 - rollout_policy_loss_scale: float = 1.0 - rollout_bc_policy_loss_scale: float = 1.0 - rollout_critic_loss_scale: float = 1.0 - rollout_reward_loss_scale: float = 1.0 - - jit_learner: bool = True - - def __post_init__(self): - if ((self.target_update_period and self.target_update_rate) or - (self.target_update_period is None and - self.target_update_rate is None)): - raise ValueError( - 'Exactly one of target_update_{period|rate} must be set.' - f' Received target_update_period={self.target_update_period} and' - f' target_update_rate={self.target_update_rate}.') - - online_batch_size = int(self.batch_size * (1. - self.replay_fraction)) - if not self.online_queue_capacity: - # Note: larger capacities mean the online data is more "stale". This seems - # a reasonable default for now. - self.online_queue_capacity = int(4 * online_batch_size) - self.online_queue_capacity = max(self.online_queue_capacity, - online_batch_size + 1) - - if self.samples_per_insert is not None and self.replay_fraction < 1: - raise ValueError( - 'Cannot set samples_per_insert when using a mixed replay (i.e when ' - '0 < replay_fraction < 1). Received:\n' - f'\tsamples_per_insert={self.samples_per_insert} and\n' - f'\treplay_fraction={self.replay_fraction}.') - - if (0 < self.replay_fraction < 1 and - self.min_replay_size > self.online_queue_capacity): - raise ValueError('When mixing replay with an online queue, min replay ' - 'size must not be larger than the queue capacity.') - - if (isinstance(self.experience_type, mpo_types.FromTransitions) and - self.num_stacked_observations > 1): - raise ValueError( - 'Agent-side frame-stacking is currently only supported when learning ' - 'from sequences. Consider environment-side frame-stacking instead.') - - if self.critic_type == mpo_types.CriticType.CATEGORICAL: - if self.model_rollout_length > 0: - raise ValueError( - 'Model rollouts are not supported for the Categorical critic') - if not isinstance(self.experience_type, mpo_types.FromTransitions): - raise ValueError( - 'Categorical critic only supports experience_type=FromTransitions') - if self.use_retrace: - raise ValueError('retrace is not supported for the Categorical critic') - - if self.model_rollout_length > 0 and not self.discrete_policy: - if (self.rollout_policy_loss_scale or self.rollout_bc_policy_loss_scale): - raise ValueError('Policy rollout losses are only supported in the ' - 'discrete policy case.') + """MPO agent configuration.""" + + batch_size: int = 256 # Total batch size across all learner devices. + discount: float = 0.99 + discrete_policy: bool = False + + # Specification of the type of experience the learner will consume. + experience_type: mpo_types.ExperienceType = mpo_types.FromTransitions(n_step=5) + num_stacked_observations: int = 1 + # Optional data-augmentation transformation for observations. + observation_transform: Optional[ + Callable[[types.NestedTensor], types.NestedTensor] + ] = None + + # Specification of replay, e.g., min/max size, pure or mixed. + # NOTE: When replay_fraction = 1.0, this reverts to pure replay and the online + # queue is not created. 
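# Illustrative sketch of one mixed-replay configuration that satisfies the
# __post_init__ checks shown in this hunk: when 0 < replay_fraction < 1,
# samples_per_insert must be None and min_replay_size must not exceed the online
# queue capacity. The concrete numbers are assumptions, not recommended defaults.
from acme.agents.jax.mpo import config as mpo_config

config = mpo_config.MPOConfig(
    batch_size=256,            # 128 online + 128 replay samples per batch.
    replay_fraction=0.5,
    samples_per_insert=None,   # Required when mixing replay with the online queue.
    min_replay_size=500,
    online_queue_capacity=1_024,
)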
+ replay_fraction: float = 1.0 # Fraction of replay data (vs online) per batch. + samples_per_insert: Optional[float] = 32.0 + min_replay_size: int = 1_000 + max_replay_size: int = 1_000_000 + online_queue_capacity: int = 0 # If not set, will use 4 * online_batch_size. + + # Critic training configuration. + critic_type: mpo_types.CriticType = mpo_types.CriticType.MIXTURE_OF_GAUSSIANS + value_tx_pair: rlax.TxPair = rlax.IDENTITY_PAIR + use_retrace: bool = False + retrace_lambda: float = 0.95 + reward_clip: float = np.float32( + "inf" + ) # pytype: disable=annotation-type-mismatch # numpy-scalars + use_online_policy_to_bootstrap: bool = False + use_stale_state: bool = False + + # Policy training configuration. + num_samples: int = 20 # Number of MPO action samples. + policy_loss_config: Optional[mpo_types.PolicyLossConfig] = None + policy_eval_stochastic: bool = True + policy_eval_num_val_samples: int = 128 + + # Optimizer configuration. + learning_rate: Union[float, Callable[[int], float]] = 1e-4 + dual_learning_rate: Union[float, Callable[[int], float]] = 1e-2 + grad_norm_clip: float = 40.0 + adam_b1: float = 0.9 + adam_b2: float = 0.999 + weight_decay: float = 0.0 + use_cosine_lr_decay: bool = False + cosine_lr_decay_warmup_steps: int = 3000 + + # Set the target update period or rate depending on whether you want a + # periodic or incremental (exponential weighted average) target update. + # Exactly one must be specified (not None). + target_update_period: Optional[int] = 100 + target_update_rate: Optional[float] = None + variable_update_period: int = 1000 + + # Configuring the mixture of policy and critic losses. + policy_loss_scale: float = 1.0 + critic_loss_scale: float = 1.0 + + # Optional roll-out loss configuration (off by default). + model_rollout_length: int = 0 + rollout_policy_loss_scale: float = 1.0 + rollout_bc_policy_loss_scale: float = 1.0 + rollout_critic_loss_scale: float = 1.0 + rollout_reward_loss_scale: float = 1.0 + + jit_learner: bool = True + + def __post_init__(self): + if (self.target_update_period and self.target_update_rate) or ( + self.target_update_period is None and self.target_update_rate is None + ): + raise ValueError( + "Exactly one of target_update_{period|rate} must be set." + f" Received target_update_period={self.target_update_period} and" + f" target_update_rate={self.target_update_rate}." + ) + + online_batch_size = int(self.batch_size * (1.0 - self.replay_fraction)) + if not self.online_queue_capacity: + # Note: larger capacities mean the online data is more "stale". This seems + # a reasonable default for now. + self.online_queue_capacity = int(4 * online_batch_size) + self.online_queue_capacity = max( + self.online_queue_capacity, online_batch_size + 1 + ) + + if self.samples_per_insert is not None and self.replay_fraction < 1: + raise ValueError( + "Cannot set samples_per_insert when using a mixed replay (i.e when " + "0 < replay_fraction < 1). Received:\n" + f"\tsamples_per_insert={self.samples_per_insert} and\n" + f"\treplay_fraction={self.replay_fraction}." + ) + + if ( + 0 < self.replay_fraction < 1 + and self.min_replay_size > self.online_queue_capacity + ): + raise ValueError( + "When mixing replay with an online queue, min replay " + "size must not be larger than the queue capacity." + ) + + if ( + isinstance(self.experience_type, mpo_types.FromTransitions) + and self.num_stacked_observations > 1 + ): + raise ValueError( + "Agent-side frame-stacking is currently only supported when learning " + "from sequences. 
Consider environment-side frame-stacking instead." + ) + + if self.critic_type == mpo_types.CriticType.CATEGORICAL: + if self.model_rollout_length > 0: + raise ValueError( + "Model rollouts are not supported for the Categorical critic" + ) + if not isinstance(self.experience_type, mpo_types.FromTransitions): + raise ValueError( + "Categorical critic only supports experience_type=FromTransitions" + ) + if self.use_retrace: + raise ValueError("retrace is not supported for the Categorical critic") + + if self.model_rollout_length > 0 and not self.discrete_policy: + if self.rollout_policy_loss_scale or self.rollout_bc_policy_loss_scale: + raise ValueError( + "Policy rollout losses are only supported in the " + "discrete policy case." + ) def _compute_spi_from_replay_fraction(replay_fraction: float) -> float: - """Computes an estimated samples_per_insert from a replay_fraction. + """Computes an estimated samples_per_insert from a replay_fraction. Assumes actors simultaneously add to both the queue and replay in a mixed replay setup. Since the online queue sets samples_per_insert = 1, then the @@ -166,11 +183,11 @@ def _compute_spi_from_replay_fraction(replay_fraction: float) -> float: An estimate of the samples_per_insert value to produce comparable runs in the pure replay setting. """ - return 1 / (1 - replay_fraction) + return 1 / (1 - replay_fraction) -def _compute_num_inserts_per_actor_step(samples_per_insert: float, - batch_size: int, - sequence_period: int = 1) -> float: - """Estimate the number inserts per actor steps.""" - return sequence_period * batch_size / samples_per_insert +def _compute_num_inserts_per_actor_step( + samples_per_insert: float, batch_size: int, sequence_period: int = 1 +) -> float: + """Estimate the number inserts per actor steps.""" + return sequence_period * batch_size / samples_per_insert diff --git a/acme/agents/jax/mpo/learning.py b/acme/agents/jax/mpo/learning.py index 681b0e63f0..79cd967e34 100644 --- a/acme/agents/jax/mpo/learning.py +++ b/acme/agents/jax/mpo/learning.py @@ -17,12 +17,31 @@ import dataclasses import functools import time -from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Sequence, Tuple, Union +from typing import ( + Any, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Sequence, + Tuple, + Union, +) +import chex +import jax +import jax.numpy as jnp +import numpy as np +import optax +import reverb +import rlax +import tree from absl import logging + import acme -from acme import specs -from acme import types +import acme.jax.losses.mpo as continuous_losses +from acme import specs, types from acme.adders import reverb as adders from acme.agents.jax.mpo import categorical_mpo as discrete_losses from acme.agents.jax.mpo import networks as mpo_networks @@ -32,716 +51,771 @@ from acme.jax import networks as network_lib from acme.jax import types as jax_types from acme.jax import utils -import acme.jax.losses.mpo as continuous_losses -from acme.utils import counting -from acme.utils import loggers -import chex -import jax -import jax.numpy as jnp -import numpy as np -import optax -import reverb -import rlax -import tree +from acme.utils import counting, loggers -_PMAP_AXIS_NAME = 'data' +_PMAP_AXIS_NAME = "data" CriticType = mpo_types.CriticType class TrainingState(NamedTuple): - """Contains training state for the learner.""" - params: mpo_networks.MPONetworkParams - target_params: mpo_networks.MPONetworkParams - dual_params: mpo_types.DualParams - opt_state: optax.OptState - dual_opt_state: optax.OptState - steps: int - 
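# Illustrative sketch of the arithmetic behind the two replay helpers in config.py
# above, with toy numbers: replay_fraction = 0.75 corresponds to roughly
# samples_per_insert = 4 in a pure-replay setup, and with batch_size = 256 that
# implies about 64 inserts per actor step.
replay_fraction = 0.75
samples_per_insert = 1 / (1 - replay_fraction)           # 4.0
inserts_per_actor_step = 1 * 256 / samples_per_insert    # 64.0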
random_key: jax_types.PRNGKey - - -def softmax_cross_entropy( - logits: chex.Array, target_probs: chex.Array) -> chex.Array: - """Compute cross entropy loss between logits and target probabilities.""" - chex.assert_equal_shape([target_probs, logits]) - return -jnp.sum(target_probs * jax.nn.log_softmax(logits), axis=-1) - - -def top1_accuracy_tiebreak(logits: chex.Array, - targets: chex.Array, - *, - rng: jax_types.PRNGKey, - eps: float = 1e-6) -> chex.Array: - """Compute the top-1 accuracy with an argmax of targets (random tie-break).""" - noise = jax.random.uniform(rng, shape=targets.shape, - minval=-eps, maxval=eps) - acc = jnp.argmax(logits, axis=-1) == jnp.argmax(targets + noise, axis=-1) - return jnp.mean(acc) + """Contains training state for the learner.""" + + params: mpo_networks.MPONetworkParams + target_params: mpo_networks.MPONetworkParams + dual_params: mpo_types.DualParams + opt_state: optax.OptState + dual_opt_state: optax.OptState + steps: int + random_key: jax_types.PRNGKey + + +def softmax_cross_entropy(logits: chex.Array, target_probs: chex.Array) -> chex.Array: + """Compute cross entropy loss between logits and target probabilities.""" + chex.assert_equal_shape([target_probs, logits]) + return -jnp.sum(target_probs * jax.nn.log_softmax(logits), axis=-1) + + +def top1_accuracy_tiebreak( + logits: chex.Array, + targets: chex.Array, + *, + rng: jax_types.PRNGKey, + eps: float = 1e-6, +) -> chex.Array: + """Compute the top-1 accuracy with an argmax of targets (random tie-break).""" + noise = jax.random.uniform(rng, shape=targets.shape, minval=-eps, maxval=eps) + acc = jnp.argmax(logits, axis=-1) == jnp.argmax(targets + noise, axis=-1) + return jnp.mean(acc) class MPOLearner(acme.Learner): - """MPO learner (discrete or continuous, distributional or not).""" - - _state: TrainingState - - def __init__( # pytype: disable=annotation-type-mismatch # numpy-scalars - self, - critic_type: CriticType, - discrete_policy: bool, - environment_spec: specs.EnvironmentSpec, - networks: mpo_networks.MPONetworks, - random_key: jax_types.PRNGKey, - discount: float, - num_samples: int, - iterator: Iterator[reverb.ReplaySample], - experience_type: mpo_types.ExperienceType, - loss_scales: mpo_types.LossScalesConfig, - target_update_period: Optional[int] = 100, - target_update_rate: Optional[float] = None, - sgd_steps_per_learner_step: int = 20, - policy_eval_stochastic: bool = True, - policy_eval_num_val_samples: int = 128, - policy_loss_config: Optional[mpo_types.PolicyLossConfig] = None, - use_online_policy_to_bootstrap: bool = False, - use_stale_state: bool = False, - use_retrace: bool = False, - retrace_lambda: float = 0.95, - model_rollout_length: int = 0, - optimizer: Optional[optax.GradientTransformation] = None, - learning_rate: Optional[Union[float, optax.Schedule]] = None, - dual_optimizer: Optional[optax.GradientTransformation] = None, - grad_norm_clip: float = 40.0, - reward_clip: float = np.float32('inf'), - value_tx_pair: rlax.TxPair = rlax.IDENTITY_PAIR, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - devices: Optional[Sequence[jax.Device]] = None, - ): - self._critic_type = critic_type - self._discrete_policy = discrete_policy - - process_id = jax.process_index() - local_devices = jax.local_devices() - self._devices = devices or local_devices - logging.info('Learner process id: %s. Devices passed: %s', process_id, - devices) - logging.info('Learner process id: %s. 
Local devices from JAX API: %s', - process_id, local_devices) - self._local_devices = [d for d in self._devices if d in local_devices] - - # Store networks. - self._networks = networks - - # General learner book-keeping and loggers. - self._counter = counter or counting.Counter() - self._logger = logger - - # Other learner parameters. - self._discount = discount - self._num_samples = num_samples - self._sgd_steps_per_learner_step = sgd_steps_per_learner_step - - self._policy_eval_stochastic = policy_eval_stochastic - self._policy_eval_num_val_samples = policy_eval_num_val_samples - - self._reward_clip_range = sorted([-reward_clip, reward_clip]) - self._tx_pair = value_tx_pair - self._loss_scales = loss_scales - self._use_online_policy_to_bootstrap = use_online_policy_to_bootstrap - self._model_rollout_length = model_rollout_length - - self._use_retrace = use_retrace - self._retrace_lambda = retrace_lambda - if use_retrace and critic_type == CriticType.MIXTURE_OF_GAUSSIANS: - logging.warning( - 'Warning! Retrace has not been tested with the MoG critic.') - self._use_stale_state = use_stale_state - - self._experience_type = experience_type - if isinstance(self._experience_type, mpo_types.FromTransitions): - # Each n=5-step transition will be converted to a length 2 sequence before - # being passed to the loss, so we do n=1 step bootstrapping on the - # resulting sequence to get n=5-step bootstrapping as intended. - self._n_step_for_sequence_bootstrap = 1 - self._td_lambda = 1.0 - elif isinstance(self._experience_type, mpo_types.FromSequences): - self._n_step_for_sequence_bootstrap = self._experience_type.n_step - self._td_lambda = self._experience_type.td_lambda - - # Necessary to track when to update target networks. - self._target_update_period = target_update_period - self._target_update_rate = target_update_rate - # Assert one and only one of target update period or rate is defined. - if ((target_update_period and target_update_rate) or - (target_update_period is None and target_update_rate is None)): - raise ValueError( - 'Exactly one of target_update_{period|rate} must be set.' - f' Received target_update_period={target_update_period} and' - f' target_update_rate={target_update_rate}.') - - # Create policy loss. - if self._discrete_policy: - policy_loss_config = ( - policy_loss_config or mpo_types.CategoricalPolicyLossConfig()) - self._policy_loss_module = discrete_losses.CategoricalMPO( - **dataclasses.asdict(policy_loss_config)) - else: - policy_loss_config = ( - policy_loss_config or mpo_types.GaussianPolicyLossConfig()) - self._policy_loss_module = continuous_losses.MPO( - **dataclasses.asdict(policy_loss_config)) - - self._policy_loss_module.__call__ = jax.named_call( - self._policy_loss_module.__call__, name='policy_loss') - - # Create the dynamics model rollout loss. - if model_rollout_length > 0: - if not discrete_policy and (self._loss_scales.rollout.policy or - self._loss_scales.rollout.bc_policy): - raise ValueError('Policy rollout losses are only supported in the ' - 'discrete policy case.') - self._model_rollout_loss_fn = rollout_loss.RolloutLoss( - dynamics_model=networks.dynamics_model, - model_rollout_length=model_rollout_length, - loss_scales=loss_scales, - distributional_loss_fn=self._distributional_loss) - - # Create optimizers if they aren't given. 
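# Illustrative sketch: the learner accepts any optax.GradientTransformation, so a
# caller can override the defaults created below. This assumes the default is
# roughly "clip by global norm, then Adam"; the body of _get_default_optimizer is
# not shown in this diff.
import optax

custom_optimizer = optax.chain(
    optax.clip_by_global_norm(40.0),
    optax.adam(1e-4, b1=0.9, b2=0.999),
)
custom_dual_optimizer = optax.chain(
    optax.clip_by_global_norm(40.0),
    optax.adam(1e-2),
)
# These could then be passed as MPOLearner(..., optimizer=custom_optimizer,
# dual_optimizer=custom_dual_optimizer).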
- self._optimizer = optimizer or _get_default_optimizer(1e-4, grad_norm_clip) - self._dual_optimizer = dual_optimizer or _get_default_optimizer( - 1e-2, grad_norm_clip) - self._lr_schedule = learning_rate if callable(learning_rate) else None - - self._action_spec = environment_spec.actions - - # Initialize random key for the rest of training. - random_key, key = jax.random.split(random_key) - - # Initialize network parameters, ignoring the dummy initial state. - network_params, _ = mpo_networks.init_params( - self._networks, - environment_spec, - key, - add_batch_dim=True, - dynamics_rollout_length=self._model_rollout_length) - - # Get action dims (unused in the discrete case). - dummy_action = utils.zeros_like(environment_spec.actions) - dummy_action_concat = utils.batch_concat(dummy_action, num_batch_dims=0) - - if isinstance(self._policy_loss_module, discrete_losses.CategoricalMPO): - self._dual_clip_fn = discrete_losses.clip_categorical_mpo_params - elif isinstance(self._policy_loss_module, continuous_losses.MPO): - is_constraining = self._policy_loss_module.per_dim_constraining - self._dual_clip_fn = lambda dp: continuous_losses.clip_mpo_params( # pylint: disable=g-long-lambda # pytype: disable=wrong-arg-types # numpy-scalars - dp, - per_dim_constraining=is_constraining) - - # Create dual parameters. In the discrete case, the action dim is unused. - dual_params = self._policy_loss_module.init_params( - action_dim=dummy_action_concat.shape[-1], dtype=jnp.float32) - - # Initialize optimizers. - opt_state = self._optimizer.init(network_params) - dual_opt_state = self._dual_optimizer.init(dual_params) - - # Initialise training state (parameters and optimiser state). - state = TrainingState( - params=network_params, - target_params=network_params, - dual_params=dual_params, - opt_state=opt_state, - dual_opt_state=dual_opt_state, - steps=0, - random_key=random_key, - ) - self._state = utils.replicate_in_all_devices(state, self._local_devices) - - # Log how many parameters the network has. - sizes = tree.map_structure(jnp.size, network_params)._asdict() - num_params_by_component_str = ' | '.join( - [f'{key}: {sum(tree.flatten(size))}' for key, size in sizes.items()]) - logging.info('Number of params by network component: %s', - num_params_by_component_str) - logging.info('Total number of params: %d', - sum(tree.flatten(sizes.values()))) - - # Combine multiple SGD steps and pmap across devices. - sgd_steps = utils.process_multiple_batches(self._sgd_step, - self._sgd_steps_per_learner_step) - self._sgd_steps = jax.pmap( - sgd_steps, axis_name=_PMAP_AXIS_NAME, devices=self._devices) - - self._iterator = iterator - - # Do not record timestamps until after the first learning step is done. - # This is to avoid including the time it takes for actors to come online and - # fill the replay buffer. - self._timestamp = None - self._current_step = 0 - - def _distributional_loss(self, prediction: mpo_types.DistributionLike, - target: chex.Array): - """Compute the critic loss given the prediction and target.""" - # TODO(abef): break this function into separate functions for each critic. - chex.assert_rank(target, 3) # [N, Z, T] except for Categorical is [1, T, L] - if self._critic_type == CriticType.MIXTURE_OF_GAUSSIANS: - # Sample-based cross-entropy loss. - loss = -prediction.log_prob(target[..., jnp.newaxis]) - loss = jnp.mean(loss, axis=[0, 1]) # [T] - elif self._critic_type == CriticType.NONDISTRIBUTIONAL: - # TD error. 
- prediction = prediction.squeeze(axis=-1) # [T] - loss = 0.5 * jnp.square(target - prediction) - chex.assert_equal_shape([target, loss]) # Check broadcasting. - elif self._critic_type == mpo_types.CriticType.CATEGORICAL_2HOT: - # Cross-entropy loss (two-hot categorical). - target = jnp.mean(target, axis=(0, 1)) # [N, Z, T] -> [T] - # TODO(abef): Compute target differently? (e.g., do mean cross ent.). - target_probs = rlax.transform_to_2hot( # [T, L] - target, - min_value=prediction.values.min(), - max_value=prediction.values.max(), - num_bins=prediction.logits.shape[-1]) - logits = jnp.squeeze(prediction.logits, axis=1) # [T, L] - chex.assert_equal_shape([target_probs, logits]) - loss = jax.vmap(rlax.categorical_cross_entropy)(target_probs, logits) - elif self._critic_type == mpo_types.CriticType.CATEGORICAL: - loss = jax.vmap(rlax.categorical_cross_entropy)(jnp.squeeze( - target, axis=0), jnp.squeeze(prediction.logits, axis=1)) - return jnp.mean(loss) # [T] -> [] - - def _compute_predictions(self, params: mpo_networks.MPONetworkParams, - sequence: adders.Step) -> mpo_types.ModelOutputs: - """Compute model predictions at observed and rolled out states.""" - - # Initialize the core states, possibly to the recorded stale state. - if self._use_stale_state: - initial_state = utils.maybe_recover_lstm_type( - sequence.extras['core_state']) - initial_state = tree.map_structure(lambda x: x[0], initial_state) - else: - initial_state = self._networks.torso.initial_state_fn( - params.torso_initial_state, None) - - # Unroll the online core network. Note that this may pass the embeddings - # unchanged if, say, the core is an hk.IdentityCore. - state_embedding, _ = self._networks.torso_unroll( # [T, ...] - params, sequence.observation, initial_state) - - # Compute the root policy and critic outputs; [T, ...] and [T-1, ...]. - policy = self._networks.policy_head_apply(params, state_embedding) - q_value = self._networks.critic_head_apply( - params, state_embedding[:-1], sequence.action[:-1]) - - return mpo_types.ModelOutputs( - policy=policy, # [T, ...] - q_value=q_value, # [T-1, ...] - reward=None, - embedding=state_embedding) # [T, ...] - - def _compute_targets( - self, - target_params: mpo_networks.MPONetworkParams, - dual_params: mpo_types.DualParams, - sequence: adders.Step, - online_policy: types.NestedArray, # TODO(abef): remove this. - key: jax_types.PRNGKey) -> mpo_types.LossTargets: - """Compute the targets needed to train the agent.""" - - # Initialize the core states, possibly to the recorded stale state. - if self._use_stale_state: - initial_state = utils.maybe_recover_lstm_type( - sequence.extras['core_state']) - initial_state = tree.map_structure(lambda x: x[0], initial_state) - else: - initial_state = self._networks.torso.initial_state_fn( - target_params.torso_initial_state, None) - - # Unroll the target core network. Note that this may pass the embeddings - # unchanged if, say, the core is an hk.IdentityCore. - target_state_embedding, _ = self._networks.torso_unroll( - target_params, sequence.observation, initial_state) # [T, ...] - - # Compute the action distribution from target policy network. - target_policy = self._networks.policy_head_apply( - target_params, target_state_embedding) # [T, ...] - - # Maybe reward clip. - clipped_reward = jnp.clip(sequence.reward, *self._reward_clip_range) # [T] - # TODO(abef): when to clip rewards, if at all, if learning dynamics model? 
- - @jax.named_call - @jax.vmap - def critic_mean_fn(action_: jnp.ndarray) -> jnp.ndarray: - """Compute mean of target critic distribution.""" - critic_output = self._networks.critic_head_apply( - target_params, target_state_embedding, action_) - if self._critic_type != CriticType.NONDISTRIBUTIONAL: - critic_output = critic_output.mean() - return critic_output - - @jax.named_call - @jax.vmap - def critic_sample_fn(action_: jnp.ndarray, - seed_: jnp.ndarray) -> jnp.ndarray: - """Sample from the target critic distribution.""" - z_distribution = self._networks.critic_head_apply( - target_params, target_state_embedding, action_) - z_samples = z_distribution.sample( - self._policy_eval_num_val_samples, seed=seed_) - return z_samples # [Z, T, 1] - - if self._discrete_policy: - # Use all actions to improve policy (no sampling); N = num_actions. - a_improvement = jnp.arange(self._action_spec.num_values) # [N] - seq_len = target_state_embedding.shape[0] # T - a_improvement = jnp.tile(a_improvement[..., None], [1, seq_len]) # [N, T] - else: - # Sample actions to improve policy; [N=num_samples, T]. - a_improvement = target_policy.sample(self._num_samples, seed=key) - - # TODO(abef): use model to get q_improvement = r + gamma*V? - - # Compute the mean Q-values used in policy improvement; [N, T]. - q_improvement = critic_mean_fn(a_improvement).squeeze(axis=-1) - - # Policy to use for policy evaluation and bootstrapping. - if self._use_online_policy_to_bootstrap: - policy_to_evaluate = online_policy - chex.assert_equal(online_policy.batch_shape, target_policy.batch_shape) - else: - policy_to_evaluate = target_policy - - # Action(s) to use for policy evaluation; shape [N, T]. - if self._policy_eval_stochastic: - a_evaluation = policy_to_evaluate.sample(self._num_samples, seed=key) - else: - a_evaluation = policy_to_evaluate.mode() - a_evaluation = jnp.expand_dims(a_evaluation, axis=0) # [N=1, T] - - # TODO(abef): policy_eval_stochastic=False makes our targets more "greedy" - - # Add a stopgrad in case we use the online policy for evaluation. - a_evaluation = jax.lax.stop_gradient(a_evaluation) - - if self._critic_type == CriticType.MIXTURE_OF_GAUSSIANS: - # Produce Z return samples for every N action sample; [N, Z, T, 1]. - seeds = jax.random.split(key, num=a_evaluation.shape[0]) - z_samples = critic_sample_fn(a_evaluation, seeds) - else: - normalized_weights = 1. / a_evaluation.shape[0] - z_samples = critic_mean_fn(a_evaluation) # [N, T, 1] - - # When policy_eval_stochastic == True, this corresponds to expected SARSA. - # Otherwise, normalized_weights = 1.0 and N = 1 so the sum is a no-op. - z_samples = jnp.sum(normalized_weights * z_samples, axis=0, keepdims=True) - z_samples = jnp.expand_dims(z_samples, axis=1) # [N, Z=1, T, 1] - - # Slice to t = 1...T and transform into raw reward space; [N, Z, T]. - z_samples_itx = self._tx_pair.apply_inv(z_samples.squeeze(axis=-1)) - - # Compute the value estimate by averaging the sampled returns in the raw - # reward space; shape [N=1, Z=1, T]. - value_target_itx = jnp.mean(z_samples_itx, axis=(0, 1), keepdims=True) - - if self._use_retrace: - # Warning! Retrace has not been tested with the MoG critic. - log_rhos = ( - target_policy.log_prob(sequence.action) - sequence.extras['log_prob']) - - # Compute Q-values; expand and squeeze because critic_mean_fn is vmapped. - q_t = critic_mean_fn(jnp.expand_dims(sequence.action, axis=0)).squeeze(0) - q_t = q_t.squeeze(-1) # Also squeeze trailing scalar dimension; [T]. - - # Compute retrace targets. 
- # These targets use the rewards and discounts as in normal TD-learning but - # they use a mix of bootstrapped values V(s') and Q(s', a'), weighing the - # latter based on how likely a' is under the current policy (s' and a' are - # samples from replay). - # See [Munos et al., 2016](https://arxiv.org/abs/1606.02647) for more. - q_value_target_itx = rlax.general_off_policy_returns_from_q_and_v( - q_t=self._tx_pair.apply_inv(q_t[1:-1]), - v_t=jnp.squeeze(value_target_itx, axis=(0, 1))[1:], - r_t=clipped_reward[:-1], - discount_t=self._discount * sequence.discount[:-1], - c_t=self._retrace_lambda * jnp.minimum(1.0, jnp.exp(log_rhos[1:-1]))) - - # Expand dims to the expected [N=1, Z=1, T-1]. - q_value_target_itx = jnp.expand_dims(q_value_target_itx, axis=(0, 1)) - else: - # Compute bootstrap target from sequences. vmap return computation across - # N action and Z return samples; shape [N, Z, T-1]. - n_step_return_fn = functools.partial( - rlax.n_step_bootstrapped_returns, - r_t=clipped_reward[:-1], - discount_t=self._discount * sequence.discount[:-1], - n=self._n_step_for_sequence_bootstrap, - lambda_t=self._td_lambda) - n_step_return_vfn = jax.vmap(jax.vmap(n_step_return_fn)) - q_value_target_itx = n_step_return_vfn(v_t=z_samples_itx[..., 1:]) - - # Transform back to the canonical space and stop gradients. - q_value_target = jax.lax.stop_gradient( - self._tx_pair.apply(q_value_target_itx)) - reward_target = jax.lax.stop_gradient(self._tx_pair.apply(clipped_reward)) - value_target = jax.lax.stop_gradient(self._tx_pair.apply(value_target_itx)) - - if self._critic_type == mpo_types.CriticType.CATEGORICAL: - - @jax.vmap - def get_logits_and_values( - action: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: - critic_output = self._networks.critic_head_apply( - target_params, target_state_embedding[1:], action) - return critic_output.logits, critic_output.values - - z_t_logits, z_t_values = get_logits_and_values(a_evaluation[:, 1:]) - z_t_logits = jnp.squeeze(z_t_logits, axis=2) # [N, T-1, L] - z_t_values = z_t_values[0] # Values are identical at each N; [L]. - - gamma = self._discount * sequence.discount[:-1, None] # [T-1, 1] - r_t = clipped_reward[:-1, None] # [T-1, 1] - atoms_itx = self._tx_pair.apply_inv(z_t_values)[None, ...] # [1, L] - z_target_atoms = self._tx_pair.apply(r_t + gamma * atoms_itx) # [T-1, L] - # Note: this is n=1-step TD unless using experience=FromTransitions(n>1). - z_target_probs = jax.nn.softmax(z_t_logits) # [N, T-1, L] - z_target_atoms = jax.lax.broadcast( - z_target_atoms, z_target_probs.shape[:1]) # [N, T-1, L] - project_fn = functools.partial( - rlax.categorical_l2_project, z_q=z_t_values) - z_target = jax.vmap(jax.vmap(project_fn))(z_target_atoms, z_target_probs) - z_target = jnp.mean(z_target, axis=0) - q_value_target = jax.lax.stop_gradient(z_target[None, ...]) # [1, T-1, L] - # TODO(abef): make q_v_target shape align with expected [N, Z, T-1] shape? - - targets = mpo_types.LossTargets( - policy=target_policy, # [T, ...] - a_improvement=a_improvement, # [N, T] - q_improvement=q_improvement, # [N, T] - q_value=q_value_target, # [N, Z, T-1] ([1, T-1, L] for CATEGORICAL) - value=value_target[..., :-1], # [N=1, Z=1, T-1] - reward=reward_target, # [T] - embedding=target_state_embedding) # [T, ...] - - return targets - - def _loss_fn( - self, - params: mpo_networks.MPONetworkParams, - dual_params: mpo_types.DualParams, - # TODO(bshahr): clean up types: Step is not a great type for sequences. 
- sequence: adders.Step, - target_params: mpo_networks.MPONetworkParams, - key: jax_types.PRNGKey) -> Tuple[jnp.ndarray, mpo_types.LogDict]: - # Compute the model predictions at the root and for the rollouts. - predictions = self._compute_predictions(params=params, sequence=sequence) - - # Compute the targets to use for the losses. - targets = self._compute_targets( - target_params=target_params, - dual_params=dual_params, - sequence=sequence, - online_policy=predictions.policy, - key=key) - - # TODO(abef): mask policy loss at terminal states or use uniform targets - # is_terminal = sequence.discount == 0. - - # Compute MPO policy loss on each state in the sequence. - policy_loss, policy_stats = self._policy_loss_module( - params=dual_params, - online_action_distribution=predictions.policy, # [T, ...]. - target_action_distribution=targets.policy, # [T, ...]. - actions=targets.a_improvement, # Unused in discrete case; [N, T]. - q_values=targets.q_improvement) # [N, T] - - # Compute the critic loss on the states in the sequence. - critic_loss = self._distributional_loss( - prediction=predictions.q_value, # [T-1, 1, ...] - target=targets.q_value) # [N, Z, T-1] - - loss = (self._loss_scales.policy * policy_loss + - self._loss_scales.critic * critic_loss) - loss_logging_dict = { - 'loss': loss, - 'root_policy_loss': policy_loss, - 'root_critic_loss': critic_loss, - 'policy_loss': policy_loss, - 'critic_loss': critic_loss, - } - - # Append MPO statistics. - loss_logging_dict.update( - {f'policy/root/{k}': v for k, v in policy_stats._asdict().items()}) - - # Compute rollout losses. - if self._model_rollout_length > 0: - model_rollout_loss, rollout_logs = self._model_rollout_loss_fn( - params, dual_params, sequence, predictions.embedding, targets, key) - loss += model_rollout_loss - loss_logging_dict.update(rollout_logs) - loss_logging_dict.update({ - 'policy_loss': policy_loss + rollout_logs['rollout_policy_loss'], - 'critic_loss': critic_loss + rollout_logs['rollout_critic_loss'], - 'loss': loss}) - - return loss, loss_logging_dict - - def _sgd_step( - self, - state: TrainingState, - transitions: Union[types.Transition, adders.Step], - ) -> Tuple[TrainingState, Dict[str, Any]]: - """Perform one parameter update step.""" - - if isinstance(transitions, types.Transition): - sequences = mpo_utils.make_sequences_from_transitions(transitions) - if self._model_rollout_length > 0: - raise ValueError('model rollouts not yet supported from transitions') - else: - sequences = transitions - - # Get next random_key and `batch_size` keys. - batch_size = sequences.reward.shape[0] - keys = jax.random.split(state.random_key, num=batch_size+1) - random_key, keys = keys[0], keys[1:] - - # Vmap over the batch dimension when learning from sequences. - loss_vfn = jax.vmap(self._loss_fn, in_axes=(None, None, 0, None, 0)) - safe_mean = lambda x: jnp.mean(x) if x is not None else x - # TODO(bshahr): Consider cleaning this up via acme.tree_utils.tree_map. - loss_fn = lambda *a, **k: tree.map_structure(safe_mean, loss_vfn(*a, **k)) - - loss_and_grad = jax.value_and_grad(loss_fn, argnums=(0, 1), has_aux=True) - - # Compute the loss and gradient. - (_, loss_log_dict), all_gradients = loss_and_grad( - state.params, state.dual_params, sequences, state.target_params, keys) - - # Average gradients across replicas. - gradients, dual_gradients = jax.lax.pmean(all_gradients, _PMAP_AXIS_NAME) - - # Compute gradient norms before clipping. 
- gradients_norm = optax.global_norm(gradients) - dual_gradients_norm = optax.global_norm(dual_gradients) - - # Get optimizer updates and state. - updates, opt_state = self._optimizer.update( - gradients, state.opt_state, state.params) - dual_updates, dual_opt_state = self._dual_optimizer.update( - dual_gradients, state.dual_opt_state, state.dual_params) - - # Apply optimizer updates to parameters. - params = optax.apply_updates(state.params, updates) - dual_params = optax.apply_updates(state.dual_params, dual_updates) - - # Clip dual params at some minimum value. - dual_params = self._dual_clip_fn(dual_params) - - steps = state.steps + 1 - - # Periodically update target networks. - if self._target_update_period: - target_params = optax.periodic_update(params, state.target_params, steps, # pytype: disable=wrong-arg-types # numpy-scalars - self._target_update_period) - elif self._target_update_rate: - target_params = optax.incremental_update(params, state.target_params, - self._target_update_rate) - - new_state = TrainingState( # pytype: disable=wrong-arg-types # numpy-scalars - params=params, - target_params=target_params, - dual_params=dual_params, - opt_state=opt_state, - dual_opt_state=dual_opt_state, - steps=steps, - random_key=random_key, - ) - - # Log the metrics from this learner step. - metrics = {f'loss/{k}': v for k, v in loss_log_dict.items()} - - metrics.update({ - 'opt/grad_norm': gradients_norm, - 'opt/param_norm': optax.global_norm(params)}) - if callable(self._lr_schedule): - metrics['opt/learning_rate'] = self._lr_schedule(state.steps) # pylint: disable=not-callable - - dual_metrics = { - 'opt/dual_grad_norm': dual_gradients_norm, - 'opt/dual_param_norm': optax.global_norm(dual_params), - 'params/dual/log_temperature_avg': dual_params.log_temperature} - if isinstance(dual_params, continuous_losses.MPOParams): - dual_metrics.update({ - 'params/dual/log_alpha_mean_avg': dual_params.log_alpha_mean, - 'params/dual/log_alpha_stddev_avg': dual_params.log_alpha_stddev}) - if dual_params.log_penalty_temperature is not None: - dual_metrics['params/dual/log_penalty_temp_mean'] = ( - dual_params.log_penalty_temperature) - elif isinstance(dual_params, discrete_losses.CategoricalMPOParams): - dual_metrics['params/dual/log_alpha_avg'] = dual_params.log_alpha - metrics.update(jax.tree_map(jnp.mean, dual_metrics)) - - return new_state, metrics - - def step(self): - """Perform one learner step, which in general does multiple SGD steps.""" - with jax.profiler.StepTraceAnnotation('step', step_num=self._current_step): - # Get data from replay (dropping extras if any). Note there is no - # extra data here because we do not insert any into Reverb. - sample = next(self._iterator) - if isinstance(self._experience_type, mpo_types.FromTransitions): - minibatch = types.Transition(*sample.data) - elif isinstance(self._experience_type, mpo_types.FromSequences): - minibatch = adders.Step(*sample.data) - - self._state, metrics = self._sgd_steps(self._state, minibatch) - self._current_step, metrics = mpo_utils.get_from_first_device( - (self._state.steps, metrics)) - - # Compute elapsed time. 
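# Illustrative sketch of the two target-network update modes used above, on a toy
# parameter tree: optax.periodic_update copies the online params every
# target_update_period steps, while optax.incremental_update takes an exponential
# moving average with step size target_update_rate.
import jax.numpy as jnp
import optax

params = {"w": jnp.ones(3)}
target_params = {"w": jnp.zeros(3)}

periodic = optax.periodic_update(params, target_params, 100, 100)  # Copy at step 100.
incremental = optax.incremental_update(params, target_params, 0.005)  # EMA update.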
- timestamp = time.time() - elapsed_time = timestamp - self._timestamp if self._timestamp else 0 - self._timestamp = timestamp - - # Increment counts and record the current time - counts = self._counter.increment( - steps=self._sgd_steps_per_learner_step, walltime=elapsed_time) - - if elapsed_time > 0: - metrics['steps_per_second'] = ( - self._sgd_steps_per_learner_step / elapsed_time) - else: - metrics['steps_per_second'] = 0. - - # Attempts to write the logs. - if self._logger: - self._logger.write({**metrics, **counts}) - - def get_variables(self, names: List[str]) -> network_lib.Params: - params = mpo_utils.get_from_first_device(self._state.target_params) - - variables = { - 'policy_head': params.policy_head, - 'critic_head': params.critic_head, - 'torso': params.torso, - 'network': params, - 'policy': params._replace(critic_head={}), - 'critic': params._replace(policy_head={}), - } - return [variables[name] for name in names] - - def save(self) -> TrainingState: - return jax.tree_map(mpo_utils.get_from_first_device, self._state) - - def restore(self, state: TrainingState): - self._state = utils.replicate_in_all_devices(state, self._local_devices) + """MPO learner (discrete or continuous, distributional or not).""" + + _state: TrainingState + + def __init__( # pytype: disable=annotation-type-mismatch # numpy-scalars + self, + critic_type: CriticType, + discrete_policy: bool, + environment_spec: specs.EnvironmentSpec, + networks: mpo_networks.MPONetworks, + random_key: jax_types.PRNGKey, + discount: float, + num_samples: int, + iterator: Iterator[reverb.ReplaySample], + experience_type: mpo_types.ExperienceType, + loss_scales: mpo_types.LossScalesConfig, + target_update_period: Optional[int] = 100, + target_update_rate: Optional[float] = None, + sgd_steps_per_learner_step: int = 20, + policy_eval_stochastic: bool = True, + policy_eval_num_val_samples: int = 128, + policy_loss_config: Optional[mpo_types.PolicyLossConfig] = None, + use_online_policy_to_bootstrap: bool = False, + use_stale_state: bool = False, + use_retrace: bool = False, + retrace_lambda: float = 0.95, + model_rollout_length: int = 0, + optimizer: Optional[optax.GradientTransformation] = None, + learning_rate: Optional[Union[float, optax.Schedule]] = None, + dual_optimizer: Optional[optax.GradientTransformation] = None, + grad_norm_clip: float = 40.0, + reward_clip: float = np.float32("inf"), + value_tx_pair: rlax.TxPair = rlax.IDENTITY_PAIR, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + devices: Optional[Sequence[jax.Device]] = None, + ): + self._critic_type = critic_type + self._discrete_policy = discrete_policy + + process_id = jax.process_index() + local_devices = jax.local_devices() + self._devices = devices or local_devices + logging.info("Learner process id: %s. Devices passed: %s", process_id, devices) + logging.info( + "Learner process id: %s. Local devices from JAX API: %s", + process_id, + local_devices, + ) + self._local_devices = [d for d in self._devices if d in local_devices] + + # Store networks. + self._networks = networks + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger + + # Other learner parameters. 
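# Illustrative sketch: the learner is also a variable source, so clients can fetch
# sub-networks by the names defined in get_variables above. The helper below is
# hypothetical and assumes `learner` is a constructed MPOLearner.
def fetch_policy_params(learner: "MPOLearner"):
    # "policy" maps to the target params with the critic head replaced by {}.
    policy_params, = learner.get_variables(["policy"])
    return policy_params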
+ self._discount = discount + self._num_samples = num_samples + self._sgd_steps_per_learner_step = sgd_steps_per_learner_step + + self._policy_eval_stochastic = policy_eval_stochastic + self._policy_eval_num_val_samples = policy_eval_num_val_samples + + self._reward_clip_range = sorted([-reward_clip, reward_clip]) + self._tx_pair = value_tx_pair + self._loss_scales = loss_scales + self._use_online_policy_to_bootstrap = use_online_policy_to_bootstrap + self._model_rollout_length = model_rollout_length + + self._use_retrace = use_retrace + self._retrace_lambda = retrace_lambda + if use_retrace and critic_type == CriticType.MIXTURE_OF_GAUSSIANS: + logging.warning("Warning! Retrace has not been tested with the MoG critic.") + self._use_stale_state = use_stale_state + + self._experience_type = experience_type + if isinstance(self._experience_type, mpo_types.FromTransitions): + # Each n=5-step transition will be converted to a length 2 sequence before + # being passed to the loss, so we do n=1 step bootstrapping on the + # resulting sequence to get n=5-step bootstrapping as intended. + self._n_step_for_sequence_bootstrap = 1 + self._td_lambda = 1.0 + elif isinstance(self._experience_type, mpo_types.FromSequences): + self._n_step_for_sequence_bootstrap = self._experience_type.n_step + self._td_lambda = self._experience_type.td_lambda + + # Necessary to track when to update target networks. + self._target_update_period = target_update_period + self._target_update_rate = target_update_rate + # Assert one and only one of target update period or rate is defined. + if (target_update_period and target_update_rate) or ( + target_update_period is None and target_update_rate is None + ): + raise ValueError( + "Exactly one of target_update_{period|rate} must be set." + f" Received target_update_period={target_update_period} and" + f" target_update_rate={target_update_rate}." + ) + + # Create policy loss. + if self._discrete_policy: + policy_loss_config = ( + policy_loss_config or mpo_types.CategoricalPolicyLossConfig() + ) + self._policy_loss_module = discrete_losses.CategoricalMPO( + **dataclasses.asdict(policy_loss_config) + ) + else: + policy_loss_config = ( + policy_loss_config or mpo_types.GaussianPolicyLossConfig() + ) + self._policy_loss_module = continuous_losses.MPO( + **dataclasses.asdict(policy_loss_config) + ) + + self._policy_loss_module.__call__ = jax.named_call( + self._policy_loss_module.__call__, name="policy_loss" + ) + + # Create the dynamics model rollout loss. + if model_rollout_length > 0: + if not discrete_policy and ( + self._loss_scales.rollout.policy or self._loss_scales.rollout.bc_policy + ): + raise ValueError( + "Policy rollout losses are only supported in the " + "discrete policy case." + ) + self._model_rollout_loss_fn = rollout_loss.RolloutLoss( + dynamics_model=networks.dynamics_model, + model_rollout_length=model_rollout_length, + loss_scales=loss_scales, + distributional_loss_fn=self._distributional_loss, + ) + + # Create optimizers if they aren't given. + self._optimizer = optimizer or _get_default_optimizer(1e-4, grad_norm_clip) + self._dual_optimizer = dual_optimizer or _get_default_optimizer( + 1e-2, grad_norm_clip + ) + self._lr_schedule = learning_rate if callable(learning_rate) else None + + self._action_spec = environment_spec.actions + + # Initialize random key for the rest of training. + random_key, key = jax.random.split(random_key) + + # Initialize network parameters, ignoring the dummy initial state. 
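# Illustrative sketch of the two experience types branched on above.
# FromTransitions(n_step=5) matches the constructor call used in config.py; the
# keyword names for FromSequences are an assumption based on the n_step and
# td_lambda fields read in this hunk.
from acme.agents.jax.mpo import types as mpo_types

transition_experience = mpo_types.FromTransitions(n_step=5)
sequence_experience = mpo_types.FromSequences(n_step=5, td_lambda=0.95)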
+ network_params, _ = mpo_networks.init_params( + self._networks, + environment_spec, + key, + add_batch_dim=True, + dynamics_rollout_length=self._model_rollout_length, + ) + + # Get action dims (unused in the discrete case). + dummy_action = utils.zeros_like(environment_spec.actions) + dummy_action_concat = utils.batch_concat(dummy_action, num_batch_dims=0) + + if isinstance(self._policy_loss_module, discrete_losses.CategoricalMPO): + self._dual_clip_fn = discrete_losses.clip_categorical_mpo_params + elif isinstance(self._policy_loss_module, continuous_losses.MPO): + is_constraining = self._policy_loss_module.per_dim_constraining + self._dual_clip_fn = lambda dp: continuous_losses.clip_mpo_params( # pylint: disable=g-long-lambda # pytype: disable=wrong-arg-types # numpy-scalars + dp, per_dim_constraining=is_constraining + ) + + # Create dual parameters. In the discrete case, the action dim is unused. + dual_params = self._policy_loss_module.init_params( + action_dim=dummy_action_concat.shape[-1], dtype=jnp.float32 + ) + + # Initialize optimizers. + opt_state = self._optimizer.init(network_params) + dual_opt_state = self._dual_optimizer.init(dual_params) + + # Initialise training state (parameters and optimiser state). + state = TrainingState( + params=network_params, + target_params=network_params, + dual_params=dual_params, + opt_state=opt_state, + dual_opt_state=dual_opt_state, + steps=0, + random_key=random_key, + ) + self._state = utils.replicate_in_all_devices(state, self._local_devices) + + # Log how many parameters the network has. + sizes = tree.map_structure(jnp.size, network_params)._asdict() + num_params_by_component_str = " | ".join( + [f"{key}: {sum(tree.flatten(size))}" for key, size in sizes.items()] + ) + logging.info( + "Number of params by network component: %s", num_params_by_component_str + ) + logging.info("Total number of params: %d", sum(tree.flatten(sizes.values()))) + + # Combine multiple SGD steps and pmap across devices. + sgd_steps = utils.process_multiple_batches( + self._sgd_step, self._sgd_steps_per_learner_step + ) + self._sgd_steps = jax.pmap( + sgd_steps, axis_name=_PMAP_AXIS_NAME, devices=self._devices + ) + + self._iterator = iterator + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + self._current_step = 0 + + def _distributional_loss( + self, prediction: mpo_types.DistributionLike, target: chex.Array + ): + """Compute the critic loss given the prediction and target.""" + # TODO(abef): break this function into separate functions for each critic. + chex.assert_rank(target, 3) # [N, Z, T] except for Categorical is [1, T, L] + if self._critic_type == CriticType.MIXTURE_OF_GAUSSIANS: + # Sample-based cross-entropy loss. + loss = -prediction.log_prob(target[..., jnp.newaxis]) + loss = jnp.mean(loss, axis=[0, 1]) # [T] + elif self._critic_type == CriticType.NONDISTRIBUTIONAL: + # TD error. + prediction = prediction.squeeze(axis=-1) # [T] + loss = 0.5 * jnp.square(target - prediction) + chex.assert_equal_shape([target, loss]) # Check broadcasting. + elif self._critic_type == mpo_types.CriticType.CATEGORICAL_2HOT: + # Cross-entropy loss (two-hot categorical). + target = jnp.mean(target, axis=(0, 1)) # [N, Z, T] -> [T] + # TODO(abef): Compute target differently? (e.g., do mean cross ent.). 
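# Illustrative sketch of the replicate/pmap pattern used above, reduced to plain
# JAX: the state is replicated across local devices, the SGD step is pmapped with
# a named axis, and gradients are averaged across that axis with lax.pmean (as in
# _sgd_step earlier in this hunk). The toy loss and shapes are assumptions.
import jax
import jax.numpy as jnp

def sgd_step(params, batch):
    grads = jax.grad(lambda p: jnp.mean((p * batch) ** 2))(params)
    grads = jax.lax.pmean(grads, axis_name="data")  # Average across replicas.
    return params - 1e-2 * grads

num_devices = jax.local_device_count()
replicated_params = jnp.stack([jnp.ones(4)] * num_devices)            # [D, 4]
sharded_batch = jnp.arange(num_devices * 4.0).reshape(num_devices, 4)  # [D, 4]
new_params = jax.pmap(sgd_step, axis_name="data")(replicated_params, sharded_batch)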
+ target_probs = rlax.transform_to_2hot( # [T, L] + target, + min_value=prediction.values.min(), + max_value=prediction.values.max(), + num_bins=prediction.logits.shape[-1], + ) + logits = jnp.squeeze(prediction.logits, axis=1) # [T, L] + chex.assert_equal_shape([target_probs, logits]) + loss = jax.vmap(rlax.categorical_cross_entropy)(target_probs, logits) + elif self._critic_type == mpo_types.CriticType.CATEGORICAL: + loss = jax.vmap(rlax.categorical_cross_entropy)( + jnp.squeeze(target, axis=0), jnp.squeeze(prediction.logits, axis=1) + ) + return jnp.mean(loss) # [T] -> [] + + def _compute_predictions( + self, params: mpo_networks.MPONetworkParams, sequence: adders.Step + ) -> mpo_types.ModelOutputs: + """Compute model predictions at observed and rolled out states.""" + + # Initialize the core states, possibly to the recorded stale state. + if self._use_stale_state: + initial_state = utils.maybe_recover_lstm_type(sequence.extras["core_state"]) + initial_state = tree.map_structure(lambda x: x[0], initial_state) + else: + initial_state = self._networks.torso.initial_state_fn( + params.torso_initial_state, None + ) + + # Unroll the online core network. Note that this may pass the embeddings + # unchanged if, say, the core is an hk.IdentityCore. + state_embedding, _ = self._networks.torso_unroll( # [T, ...] + params, sequence.observation, initial_state + ) + + # Compute the root policy and critic outputs; [T, ...] and [T-1, ...]. + policy = self._networks.policy_head_apply(params, state_embedding) + q_value = self._networks.critic_head_apply( + params, state_embedding[:-1], sequence.action[:-1] + ) + + return mpo_types.ModelOutputs( + policy=policy, # [T, ...] + q_value=q_value, # [T-1, ...] + reward=None, + embedding=state_embedding, + ) # [T, ...] + + def _compute_targets( + self, + target_params: mpo_networks.MPONetworkParams, + dual_params: mpo_types.DualParams, + sequence: adders.Step, + online_policy: types.NestedArray, # TODO(abef): remove this. + key: jax_types.PRNGKey, + ) -> mpo_types.LossTargets: + """Compute the targets needed to train the agent.""" + + # Initialize the core states, possibly to the recorded stale state. + if self._use_stale_state: + initial_state = utils.maybe_recover_lstm_type(sequence.extras["core_state"]) + initial_state = tree.map_structure(lambda x: x[0], initial_state) + else: + initial_state = self._networks.torso.initial_state_fn( + target_params.torso_initial_state, None + ) + + # Unroll the target core network. Note that this may pass the embeddings + # unchanged if, say, the core is an hk.IdentityCore. + target_state_embedding, _ = self._networks.torso_unroll( + target_params, sequence.observation, initial_state + ) # [T, ...] + + # Compute the action distribution from target policy network. + target_policy = self._networks.policy_head_apply( + target_params, target_state_embedding + ) # [T, ...] + + # Maybe reward clip. + clipped_reward = jnp.clip(sequence.reward, *self._reward_clip_range) # [T] + # TODO(abef): when to clip rewards, if at all, if learning dynamics model? 
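# Illustrative sketch of the two-hot critic target construction used above, on toy
# values: each scalar target is projected onto the two nearest atoms of a fixed
# support and trained with a categorical cross-entropy. Support bounds, bin count
# and the dummy logits are assumptions.
import jax
import jax.numpy as jnp
import rlax

targets = jnp.array([0.3, -1.2])  # [T] scalar regression targets.
logits = jnp.zeros((2, 11))       # [T, L] dummy critic logits.
target_probs = rlax.transform_to_2hot(
    targets, min_value=-2.0, max_value=2.0, num_bins=11)               # [T, L]
loss = jax.vmap(rlax.categorical_cross_entropy)(target_probs, logits)  # [T]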
+ + @jax.named_call + @jax.vmap + def critic_mean_fn(action_: jnp.ndarray) -> jnp.ndarray: + """Compute mean of target critic distribution.""" + critic_output = self._networks.critic_head_apply( + target_params, target_state_embedding, action_ + ) + if self._critic_type != CriticType.NONDISTRIBUTIONAL: + critic_output = critic_output.mean() + return critic_output + + @jax.named_call + @jax.vmap + def critic_sample_fn(action_: jnp.ndarray, seed_: jnp.ndarray) -> jnp.ndarray: + """Sample from the target critic distribution.""" + z_distribution = self._networks.critic_head_apply( + target_params, target_state_embedding, action_ + ) + z_samples = z_distribution.sample( + self._policy_eval_num_val_samples, seed=seed_ + ) + return z_samples # [Z, T, 1] + + if self._discrete_policy: + # Use all actions to improve policy (no sampling); N = num_actions. + a_improvement = jnp.arange(self._action_spec.num_values) # [N] + seq_len = target_state_embedding.shape[0] # T + a_improvement = jnp.tile(a_improvement[..., None], [1, seq_len]) # [N, T] + else: + # Sample actions to improve policy; [N=num_samples, T]. + a_improvement = target_policy.sample(self._num_samples, seed=key) + + # TODO(abef): use model to get q_improvement = r + gamma*V? + + # Compute the mean Q-values used in policy improvement; [N, T]. + q_improvement = critic_mean_fn(a_improvement).squeeze(axis=-1) + + # Policy to use for policy evaluation and bootstrapping. + if self._use_online_policy_to_bootstrap: + policy_to_evaluate = online_policy + chex.assert_equal(online_policy.batch_shape, target_policy.batch_shape) + else: + policy_to_evaluate = target_policy + + # Action(s) to use for policy evaluation; shape [N, T]. + if self._policy_eval_stochastic: + a_evaluation = policy_to_evaluate.sample(self._num_samples, seed=key) + else: + a_evaluation = policy_to_evaluate.mode() + a_evaluation = jnp.expand_dims(a_evaluation, axis=0) # [N=1, T] + + # TODO(abef): policy_eval_stochastic=False makes our targets more "greedy" + + # Add a stopgrad in case we use the online policy for evaluation. + a_evaluation = jax.lax.stop_gradient(a_evaluation) + + if self._critic_type == CriticType.MIXTURE_OF_GAUSSIANS: + # Produce Z return samples for every N action sample; [N, Z, T, 1]. + seeds = jax.random.split(key, num=a_evaluation.shape[0]) + z_samples = critic_sample_fn(a_evaluation, seeds) + else: + normalized_weights = 1.0 / a_evaluation.shape[0] + z_samples = critic_mean_fn(a_evaluation) # [N, T, 1] + + # When policy_eval_stochastic == True, this corresponds to expected SARSA. + # Otherwise, normalized_weights = 1.0 and N = 1 so the sum is a no-op. + z_samples = jnp.sum(normalized_weights * z_samples, axis=0, keepdims=True) + z_samples = jnp.expand_dims(z_samples, axis=1) # [N, Z=1, T, 1] + + # Slice to t = 1...T and transform into raw reward space; [N, Z, T]. + z_samples_itx = self._tx_pair.apply_inv(z_samples.squeeze(axis=-1)) + + # Compute the value estimate by averaging the sampled returns in the raw + # reward space; shape [N=1, Z=1, T]. + value_target_itx = jnp.mean(z_samples_itx, axis=(0, 1), keepdims=True) + + if self._use_retrace: + # Warning! Retrace has not been tested with the MoG critic. + log_rhos = ( + target_policy.log_prob(sequence.action) - sequence.extras["log_prob"] + ) + + # Compute Q-values; expand and squeeze because critic_mean_fn is vmapped. + q_t = critic_mean_fn(jnp.expand_dims(sequence.action, axis=0)).squeeze(0) + q_t = q_t.squeeze(-1) # Also squeeze trailing scalar dimension; [T]. + + # Compute retrace targets. 
+ # These targets use the rewards and discounts as in normal TD-learning but + # they use a mix of bootstrapped values V(s') and Q(s', a'), weighing the + # latter based on how likely a' is under the current policy (s' and a' are + # samples from replay). + # See [Munos et al., 2016](https://arxiv.org/abs/1606.02647) for more. + q_value_target_itx = rlax.general_off_policy_returns_from_q_and_v( + q_t=self._tx_pair.apply_inv(q_t[1:-1]), + v_t=jnp.squeeze(value_target_itx, axis=(0, 1))[1:], + r_t=clipped_reward[:-1], + discount_t=self._discount * sequence.discount[:-1], + c_t=self._retrace_lambda * jnp.minimum(1.0, jnp.exp(log_rhos[1:-1])), + ) + + # Expand dims to the expected [N=1, Z=1, T-1]. + q_value_target_itx = jnp.expand_dims(q_value_target_itx, axis=(0, 1)) + else: + # Compute bootstrap target from sequences. vmap return computation across + # N action and Z return samples; shape [N, Z, T-1]. + n_step_return_fn = functools.partial( + rlax.n_step_bootstrapped_returns, + r_t=clipped_reward[:-1], + discount_t=self._discount * sequence.discount[:-1], + n=self._n_step_for_sequence_bootstrap, + lambda_t=self._td_lambda, + ) + n_step_return_vfn = jax.vmap(jax.vmap(n_step_return_fn)) + q_value_target_itx = n_step_return_vfn(v_t=z_samples_itx[..., 1:]) + + # Transform back to the canonical space and stop gradients. + q_value_target = jax.lax.stop_gradient(self._tx_pair.apply(q_value_target_itx)) + reward_target = jax.lax.stop_gradient(self._tx_pair.apply(clipped_reward)) + value_target = jax.lax.stop_gradient(self._tx_pair.apply(value_target_itx)) + + if self._critic_type == mpo_types.CriticType.CATEGORICAL: + + @jax.vmap + def get_logits_and_values( + action: jnp.ndarray, + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + critic_output = self._networks.critic_head_apply( + target_params, target_state_embedding[1:], action + ) + return critic_output.logits, critic_output.values + + z_t_logits, z_t_values = get_logits_and_values(a_evaluation[:, 1:]) + z_t_logits = jnp.squeeze(z_t_logits, axis=2) # [N, T-1, L] + z_t_values = z_t_values[0] # Values are identical at each N; [L]. + + gamma = self._discount * sequence.discount[:-1, None] # [T-1, 1] + r_t = clipped_reward[:-1, None] # [T-1, 1] + atoms_itx = self._tx_pair.apply_inv(z_t_values)[None, ...] # [1, L] + z_target_atoms = self._tx_pair.apply(r_t + gamma * atoms_itx) # [T-1, L] + # Note: this is n=1-step TD unless using experience=FromTransitions(n>1). + z_target_probs = jax.nn.softmax(z_t_logits) # [N, T-1, L] + z_target_atoms = jax.lax.broadcast( + z_target_atoms, z_target_probs.shape[:1] + ) # [N, T-1, L] + project_fn = functools.partial(rlax.categorical_l2_project, z_q=z_t_values) + z_target = jax.vmap(jax.vmap(project_fn))(z_target_atoms, z_target_probs) + z_target = jnp.mean(z_target, axis=0) + q_value_target = jax.lax.stop_gradient(z_target[None, ...]) # [1, T-1, L] + # TODO(abef): make q_v_target shape align with expected [N, Z, T-1] shape? + + targets = mpo_types.LossTargets( + policy=target_policy, # [T, ...] + a_improvement=a_improvement, # [N, T] + q_improvement=q_improvement, # [N, T] + q_value=q_value_target, # [N, Z, T-1] ([1, T-1, L] for CATEGORICAL) + value=value_target[..., :-1], # [N=1, Z=1, T-1] + reward=reward_target, # [T] + embedding=target_state_embedding, + ) # [T, ...] + + return targets + + def _loss_fn( + self, + params: mpo_networks.MPONetworkParams, + dual_params: mpo_types.DualParams, + # TODO(bshahr): clean up types: Step is not a great type for sequences. 
+ sequence: adders.Step, + target_params: mpo_networks.MPONetworkParams, + key: jax_types.PRNGKey, + ) -> Tuple[jnp.ndarray, mpo_types.LogDict]: + # Compute the model predictions at the root and for the rollouts. + predictions = self._compute_predictions(params=params, sequence=sequence) + + # Compute the targets to use for the losses. + targets = self._compute_targets( + target_params=target_params, + dual_params=dual_params, + sequence=sequence, + online_policy=predictions.policy, + key=key, + ) + + # TODO(abef): mask policy loss at terminal states or use uniform targets + # is_terminal = sequence.discount == 0. + + # Compute MPO policy loss on each state in the sequence. + policy_loss, policy_stats = self._policy_loss_module( + params=dual_params, + online_action_distribution=predictions.policy, # [T, ...]. + target_action_distribution=targets.policy, # [T, ...]. + actions=targets.a_improvement, # Unused in discrete case; [N, T]. + q_values=targets.q_improvement, + ) # [N, T] + + # Compute the critic loss on the states in the sequence. + critic_loss = self._distributional_loss( + prediction=predictions.q_value, target=targets.q_value # [T-1, 1, ...] + ) # [N, Z, T-1] + + loss = ( + self._loss_scales.policy * policy_loss + + self._loss_scales.critic * critic_loss + ) + loss_logging_dict = { + "loss": loss, + "root_policy_loss": policy_loss, + "root_critic_loss": critic_loss, + "policy_loss": policy_loss, + "critic_loss": critic_loss, + } + + # Append MPO statistics. + loss_logging_dict.update( + {f"policy/root/{k}": v for k, v in policy_stats._asdict().items()} + ) + + # Compute rollout losses. + if self._model_rollout_length > 0: + model_rollout_loss, rollout_logs = self._model_rollout_loss_fn( + params, dual_params, sequence, predictions.embedding, targets, key + ) + loss += model_rollout_loss + loss_logging_dict.update(rollout_logs) + loss_logging_dict.update( + { + "policy_loss": policy_loss + rollout_logs["rollout_policy_loss"], + "critic_loss": critic_loss + rollout_logs["rollout_critic_loss"], + "loss": loss, + } + ) + + return loss, loss_logging_dict + + def _sgd_step( + self, state: TrainingState, transitions: Union[types.Transition, adders.Step], + ) -> Tuple[TrainingState, Dict[str, Any]]: + """Perform one parameter update step.""" + + if isinstance(transitions, types.Transition): + sequences = mpo_utils.make_sequences_from_transitions(transitions) + if self._model_rollout_length > 0: + raise ValueError("model rollouts not yet supported from transitions") + else: + sequences = transitions + + # Get next random_key and `batch_size` keys. + batch_size = sequences.reward.shape[0] + keys = jax.random.split(state.random_key, num=batch_size + 1) + random_key, keys = keys[0], keys[1:] + + # Vmap over the batch dimension when learning from sequences. + loss_vfn = jax.vmap(self._loss_fn, in_axes=(None, None, 0, None, 0)) + safe_mean = lambda x: jnp.mean(x) if x is not None else x + # TODO(bshahr): Consider cleaning this up via acme.tree_utils.tree_map. + loss_fn = lambda *a, **k: tree.map_structure(safe_mean, loss_vfn(*a, **k)) + + loss_and_grad = jax.value_and_grad(loss_fn, argnums=(0, 1), has_aux=True) + + # Compute the loss and gradient. + (_, loss_log_dict), all_gradients = loss_and_grad( + state.params, state.dual_params, sequences, state.target_params, keys + ) + + # Average gradients across replicas. + gradients, dual_gradients = jax.lax.pmean(all_gradients, _PMAP_AXIS_NAME) + + # Compute gradient norms before clipping. 
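+        # (Clipping, when enabled, is already folded into the optimizer via
+        # optax.clip_by_global_norm; see _get_default_optimizer below.)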
+ gradients_norm = optax.global_norm(gradients) + dual_gradients_norm = optax.global_norm(dual_gradients) + + # Get optimizer updates and state. + updates, opt_state = self._optimizer.update( + gradients, state.opt_state, state.params + ) + dual_updates, dual_opt_state = self._dual_optimizer.update( + dual_gradients, state.dual_opt_state, state.dual_params + ) + + # Apply optimizer updates to parameters. + params = optax.apply_updates(state.params, updates) + dual_params = optax.apply_updates(state.dual_params, dual_updates) + + # Clip dual params at some minimum value. + dual_params = self._dual_clip_fn(dual_params) + + steps = state.steps + 1 + + # Periodically update target networks. + if self._target_update_period: + target_params = optax.periodic_update( + params, + state.target_params, + steps, # pytype: disable=wrong-arg-types # numpy-scalars + self._target_update_period, + ) + elif self._target_update_rate: + target_params = optax.incremental_update( + params, state.target_params, self._target_update_rate + ) + + new_state = TrainingState( # pytype: disable=wrong-arg-types # numpy-scalars + params=params, + target_params=target_params, + dual_params=dual_params, + opt_state=opt_state, + dual_opt_state=dual_opt_state, + steps=steps, + random_key=random_key, + ) + + # Log the metrics from this learner step. + metrics = {f"loss/{k}": v for k, v in loss_log_dict.items()} + + metrics.update( + { + "opt/grad_norm": gradients_norm, + "opt/param_norm": optax.global_norm(params), + } + ) + if callable(self._lr_schedule): + metrics["opt/learning_rate"] = self._lr_schedule( + state.steps + ) # pylint: disable=not-callable + + dual_metrics = { + "opt/dual_grad_norm": dual_gradients_norm, + "opt/dual_param_norm": optax.global_norm(dual_params), + "params/dual/log_temperature_avg": dual_params.log_temperature, + } + if isinstance(dual_params, continuous_losses.MPOParams): + dual_metrics.update( + { + "params/dual/log_alpha_mean_avg": dual_params.log_alpha_mean, + "params/dual/log_alpha_stddev_avg": dual_params.log_alpha_stddev, + } + ) + if dual_params.log_penalty_temperature is not None: + dual_metrics[ + "params/dual/log_penalty_temp_mean" + ] = dual_params.log_penalty_temperature + elif isinstance(dual_params, discrete_losses.CategoricalMPOParams): + dual_metrics["params/dual/log_alpha_avg"] = dual_params.log_alpha + metrics.update(jax.tree_map(jnp.mean, dual_metrics)) + + return new_state, metrics + + def step(self): + """Perform one learner step, which in general does multiple SGD steps.""" + with jax.profiler.StepTraceAnnotation("step", step_num=self._current_step): + # Get data from replay (dropping extras if any). Note there is no + # extra data here because we do not insert any into Reverb. + sample = next(self._iterator) + if isinstance(self._experience_type, mpo_types.FromTransitions): + minibatch = types.Transition(*sample.data) + elif isinstance(self._experience_type, mpo_types.FromSequences): + minibatch = adders.Step(*sample.data) + + self._state, metrics = self._sgd_steps(self._state, minibatch) + self._current_step, metrics = mpo_utils.get_from_first_device( + (self._state.steps, metrics) + ) + + # Compute elapsed time. 
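+            # Elapsed time feeds the steps_per_second metric and the counter's
+            # walltime below.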
+ timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Increment counts and record the current time + counts = self._counter.increment( + steps=self._sgd_steps_per_learner_step, walltime=elapsed_time + ) + + if elapsed_time > 0: + metrics["steps_per_second"] = ( + self._sgd_steps_per_learner_step / elapsed_time + ) + else: + metrics["steps_per_second"] = 0.0 + + # Attempts to write the logs. + if self._logger: + self._logger.write({**metrics, **counts}) + + def get_variables(self, names: List[str]) -> network_lib.Params: + params = mpo_utils.get_from_first_device(self._state.target_params) + + variables = { + "policy_head": params.policy_head, + "critic_head": params.critic_head, + "torso": params.torso, + "network": params, + "policy": params._replace(critic_head={}), + "critic": params._replace(policy_head={}), + } + return [variables[name] for name in names] + + def save(self) -> TrainingState: + return jax.tree_map(mpo_utils.get_from_first_device, self._state) + + def restore(self, state: TrainingState): + self._state = utils.replicate_in_all_devices(state, self._local_devices) def _get_default_optimizer( - learning_rate: float, - max_grad_norm: Optional[float] = None) -> optax.GradientTransformation: - optimizer = optax.adam(learning_rate) - if max_grad_norm and max_grad_norm > 0: - optimizer = optax.chain(optax.clip_by_global_norm(max_grad_norm), optimizer) - return optimizer + learning_rate: float, max_grad_norm: Optional[float] = None +) -> optax.GradientTransformation: + optimizer = optax.adam(learning_rate) + if max_grad_norm and max_grad_norm > 0: + optimizer = optax.chain(optax.clip_by_global_norm(max_grad_norm), optimizer) + return optimizer diff --git a/acme/agents/jax/mpo/networks.py b/acme/agents/jax/mpo/networks.py index 3097cdeae5..96624ac53e 100644 --- a/acme/agents/jax/mpo/networks.py +++ b/acme/agents/jax/mpo/networks.py @@ -17,10 +17,6 @@ import dataclasses from typing import Callable, NamedTuple, Optional, Sequence, Tuple, Union -from acme import specs -from acme.agents.jax.mpo import types -from acme.jax import networks as networks_lib -from acme.jax import utils import chex import haiku as hk import haiku.initializers as hk_init @@ -29,57 +25,73 @@ import numpy as np import tensorflow_probability.substrates.jax as tfp +from acme import specs +from acme.agents.jax.mpo import types +from acme.jax import networks as networks_lib +from acme.jax import utils + tfd = tfp.distributions DistributionOrArray = Union[tfd.Distribution, jnp.ndarray] class MPONetworkParams(NamedTuple): - policy_head: Optional[hk.Params] = None - critic_head: Optional[hk.Params] = None - torso: Optional[hk.Params] = None - torso_initial_state: Optional[hk.Params] = None - dynamics_model: Union[hk.Params, Tuple[()]] = () - dynamics_model_initial_state: Union[hk.Params, Tuple[()]] = () + policy_head: Optional[hk.Params] = None + critic_head: Optional[hk.Params] = None + torso: Optional[hk.Params] = None + torso_initial_state: Optional[hk.Params] = None + dynamics_model: Union[hk.Params, Tuple[()]] = () + dynamics_model_initial_state: Union[hk.Params, Tuple[()]] = () @dataclasses.dataclass class UnrollableNetwork: - """Network that can unroll over an input sequence.""" - init: Callable[[networks_lib.PRNGKey, types.Observation, hk.LSTMState], - hk.Params] - apply: Callable[[hk.Params, types.Observation, hk.LSTMState], - Tuple[jnp.ndarray, hk.LSTMState]] - unroll: Callable[[hk.Params, types.Observation, hk.LSTMState], - 
Tuple[jnp.ndarray, hk.LSTMState]] - initial_state_fn_init: Callable[[networks_lib.PRNGKey, Optional[int]], - hk.Params] - initial_state_fn: Callable[[hk.Params, Optional[int]], hk.LSTMState] + """Network that can unroll over an input sequence.""" + + init: Callable[[networks_lib.PRNGKey, types.Observation, hk.LSTMState], hk.Params] + apply: Callable[ + [hk.Params, types.Observation, hk.LSTMState], Tuple[jnp.ndarray, hk.LSTMState] + ] + unroll: Callable[ + [hk.Params, types.Observation, hk.LSTMState], Tuple[jnp.ndarray, hk.LSTMState] + ] + initial_state_fn_init: Callable[[networks_lib.PRNGKey, Optional[int]], hk.Params] + initial_state_fn: Callable[[hk.Params, Optional[int]], hk.LSTMState] @dataclasses.dataclass class MPONetworks: - """Network for the MPO agent.""" - policy_head: Optional[hk.Transformed] = None - critic_head: Optional[hk.Transformed] = None - torso: Optional[UnrollableNetwork] = None - dynamics_model: Optional[UnrollableNetwork] = None - - def policy_head_apply(self, params: MPONetworkParams, - obs_embedding: types.ObservationEmbedding): - return self.policy_head.apply(params.policy_head, obs_embedding) - - def critic_head_apply(self, params: MPONetworkParams, - obs_embedding: types.ObservationEmbedding, - actions: types.Action): - return self.critic_head.apply(params.critic_head, obs_embedding, actions) - - def torso_unroll(self, params: MPONetworkParams, - observations: types.Observation, state: hk.LSTMState): - return self.torso.unroll(params.torso, observations, state) - - def dynamics_model_unroll(self, params: MPONetworkParams, - actions: types.Action, state: hk.LSTMState): - return self.dynamics_model.unroll(params.dynamics_model, actions, state) + """Network for the MPO agent.""" + + policy_head: Optional[hk.Transformed] = None + critic_head: Optional[hk.Transformed] = None + torso: Optional[UnrollableNetwork] = None + dynamics_model: Optional[UnrollableNetwork] = None + + def policy_head_apply( + self, params: MPONetworkParams, obs_embedding: types.ObservationEmbedding + ): + return self.policy_head.apply(params.policy_head, obs_embedding) + + def critic_head_apply( + self, + params: MPONetworkParams, + obs_embedding: types.ObservationEmbedding, + actions: types.Action, + ): + return self.critic_head.apply(params.critic_head, obs_embedding, actions) + + def torso_unroll( + self, + params: MPONetworkParams, + observations: types.Observation, + state: hk.LSTMState, + ): + return self.torso.unroll(params.torso, observations, state) + + def dynamics_model_unroll( + self, params: MPONetworkParams, actions: types.Action, state: hk.LSTMState + ): + return self.dynamics_model.unroll(params.dynamics_model, actions, state) def init_params( @@ -89,58 +101,61 @@ def init_params( add_batch_dim: bool = False, dynamics_rollout_length: int = 0, ) -> Tuple[MPONetworkParams, hk.LSTMState]: - """Initialize the parameters of a MPO network.""" - - rng_keys = jax.random.split(random_key, 6) - - # Create a dummy observation/action to initialize network parameters. - observations, actions = utils.zeros_like((spec.observations, spec.actions)) - - # Add batch dimensions if necessary by the scope that is calling this init. - if add_batch_dim: - observations, actions = utils.add_batch_dim((observations, actions)) - - # Initialize the state torso parameters and create a dummy core state. 
- batch_size = 1 if add_batch_dim else None - params_torso_initial_state = networks.torso.initial_state_fn_init( - rng_keys[0], batch_size) - state = networks.torso.initial_state_fn( - params_torso_initial_state, batch_size) - - # Initialize the core and unroll one step to create a dummy core output. - # The input to the core is the current action and the next observation. - params_torso = networks.torso.init(rng_keys[1], observations, state) - embeddings, _ = networks.torso.apply(params_torso, observations, state) - - # Initialize the policy and critic heads by passing in the dummy embedding. - params_policy_head, params_critic_head = {}, {} # Cannot be None for BIT. - if networks.policy_head: - params_policy_head = networks.policy_head.init(rng_keys[2], embeddings) - if networks.critic_head: - params_critic_head = networks.critic_head.init(rng_keys[3], embeddings, - actions) - - # Initialize the recurrent dynamics model if it exists. - if networks.dynamics_model and dynamics_rollout_length > 0: - params_dynamics_initial_state = networks.dynamics_model.initial_state_fn_init( - rng_keys[4], embeddings) - dynamics_state = networks.dynamics_model.initial_state_fn( - params_dynamics_initial_state, embeddings) - params_dynamics = networks.dynamics_model.init( - rng_keys[5], actions, dynamics_state) - else: - params_dynamics_initial_state = () - params_dynamics = () - - params = MPONetworkParams( - policy_head=params_policy_head, - critic_head=params_critic_head, - torso=params_torso, - torso_initial_state=params_torso_initial_state, - dynamics_model=params_dynamics, - dynamics_model_initial_state=params_dynamics_initial_state) - - return params, state + """Initialize the parameters of a MPO network.""" + + rng_keys = jax.random.split(random_key, 6) + + # Create a dummy observation/action to initialize network parameters. + observations, actions = utils.zeros_like((spec.observations, spec.actions)) + + # Add batch dimensions if necessary by the scope that is calling this init. + if add_batch_dim: + observations, actions = utils.add_batch_dim((observations, actions)) + + # Initialize the state torso parameters and create a dummy core state. + batch_size = 1 if add_batch_dim else None + params_torso_initial_state = networks.torso.initial_state_fn_init( + rng_keys[0], batch_size + ) + state = networks.torso.initial_state_fn(params_torso_initial_state, batch_size) + + # Initialize the core and unroll one step to create a dummy core output. + # The input to the core is the current action and the next observation. + params_torso = networks.torso.init(rng_keys[1], observations, state) + embeddings, _ = networks.torso.apply(params_torso, observations, state) + + # Initialize the policy and critic heads by passing in the dummy embedding. + params_policy_head, params_critic_head = {}, {} # Cannot be None for BIT. + if networks.policy_head: + params_policy_head = networks.policy_head.init(rng_keys[2], embeddings) + if networks.critic_head: + params_critic_head = networks.critic_head.init(rng_keys[3], embeddings, actions) + + # Initialize the recurrent dynamics model if it exists. 
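+    # The dynamics model unrolls from a state embedding given a sequence of
+    # actions, so it is initialized with a dummy embedding rather than a raw
+    # observation.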
+ if networks.dynamics_model and dynamics_rollout_length > 0: + params_dynamics_initial_state = networks.dynamics_model.initial_state_fn_init( + rng_keys[4], embeddings + ) + dynamics_state = networks.dynamics_model.initial_state_fn( + params_dynamics_initial_state, embeddings + ) + params_dynamics = networks.dynamics_model.init( + rng_keys[5], actions, dynamics_state + ) + else: + params_dynamics_initial_state = () + params_dynamics = () + + params = MPONetworkParams( + policy_head=params_policy_head, + critic_head=params_critic_head, + torso=params_torso, + torso_initial_state=params_torso_initial_state, + dynamics_model=params_dynamics, + dynamics_model_initial_state=params_dynamics_initial_state, + ) + + return params, state def make_unrollable_network( @@ -148,40 +163,43 @@ def make_unrollable_network( make_feedforward_module: Optional[Callable[[], hk.SupportsCall]] = None, make_initial_state_fn: Optional[Callable[[], hk.SupportsCall]] = None, ) -> UnrollableNetwork: - """Produces an UnrollableNetwork and a state initializing hk.Transformed.""" - - def default_initial_state_fn(batch_size: Optional[int] = None) -> jnp.ndarray: - return make_core_module().initial_state(batch_size) - - def _apply_core_fn(observation: types.Observation, - state: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: - if make_feedforward_module: - observation = make_feedforward_module()(observation) - return make_core_module()(observation, state) - - def _unroll_core_fn(observation: types.Observation, - state: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: - if make_feedforward_module: - observation = make_feedforward_module()(observation) - return hk.dynamic_unroll(make_core_module(), observation, state) - - if make_initial_state_fn: - initial_state_fn = make_initial_state_fn() - else: - initial_state_fn = default_initial_state_fn - - # Transform module functions into pure functions. - hk_initial_state_fn = hk.without_apply_rng(hk.transform(initial_state_fn)) - apply_core = hk.without_apply_rng(hk.transform(_apply_core_fn)) - unroll_core = hk.without_apply_rng(hk.transform(_unroll_core_fn)) - - # Pack all core network pure functions into a single convenient container. - return UnrollableNetwork( - init=apply_core.init, - apply=apply_core.apply, - unroll=unroll_core.apply, - initial_state_fn_init=hk_initial_state_fn.init, - initial_state_fn=hk_initial_state_fn.apply) + """Produces an UnrollableNetwork and a state initializing hk.Transformed.""" + + def default_initial_state_fn(batch_size: Optional[int] = None) -> jnp.ndarray: + return make_core_module().initial_state(batch_size) + + def _apply_core_fn( + observation: types.Observation, state: jnp.ndarray + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + if make_feedforward_module: + observation = make_feedforward_module()(observation) + return make_core_module()(observation, state) + + def _unroll_core_fn( + observation: types.Observation, state: jnp.ndarray + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + if make_feedforward_module: + observation = make_feedforward_module()(observation) + return hk.dynamic_unroll(make_core_module(), observation, state) + + if make_initial_state_fn: + initial_state_fn = make_initial_state_fn() + else: + initial_state_fn = default_initial_state_fn + + # Transform module functions into pure functions. 
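+    # hk.without_apply_rng removes the (unused) RNG argument from the pure
+    # apply functions, since these modules need no randomness at apply time.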
+ hk_initial_state_fn = hk.without_apply_rng(hk.transform(initial_state_fn)) + apply_core = hk.without_apply_rng(hk.transform(_apply_core_fn)) + unroll_core = hk.without_apply_rng(hk.transform(_unroll_core_fn)) + + # Pack all core network pure functions into a single convenient container. + return UnrollableNetwork( + init=apply_core.init, + apply=apply_core.apply, + unroll=unroll_core.apply, + initial_state_fn_init=hk_initial_state_fn.init, + initial_state_fn=hk_initial_state_fn.apply, + ) def make_control_networks( @@ -195,81 +213,85 @@ def make_control_networks( mog_init_scale: float = 1e-3, # Used by MoG critic. mog_num_components: int = 5, # Used by MoG critic. categorical_num_bins: int = 51, # Used by CATEGORICAL* critics. - vmin: float = -150., # Used by CATEGORICAL* critics. - vmax: float = 150., # Used by CATEGORICAL* critics. + vmin: float = -150.0, # Used by CATEGORICAL* critics. + vmax: float = 150.0, # Used by CATEGORICAL* critics. ) -> MPONetworks: - """Creates MPONetworks to be used DM Control suite tasks.""" - - # Unpack the environment spec to get appropriate shapes, dtypes, etc. - num_dimensions = np.prod(environment_spec.actions.shape, dtype=int) - - # Factory to create the core hk.Module. Must be a factory as the module must - # be initialized within a hk.transform scope. - if with_recurrence: - make_core_module = lambda: GRUWithSkip(16) - else: - make_core_module = hk.IdentityCore - - def policy_fn(observation: types.NestedArray) -> tfd.Distribution: - embedding = networks_lib.LayerNormMLP( - policy_layer_sizes, activate_final=True)( - observation) - return networks_lib.MultivariateNormalDiagHead( - num_dimensions, init_scale=policy_init_scale)( - embedding) - - def critic_fn(observation: types.NestedArray, - action: types.NestedArray) -> DistributionOrArray: - # Action is clipped to avoid critic extrapolations outside the spec range. - clipped_action = networks_lib.ClipToSpec(environment_spec.actions)(action) - inputs = jnp.concatenate([observation, clipped_action], axis=-1) - embedding = networks_lib.LayerNormMLP( - critic_layer_sizes, activate_final=True)( - inputs) - - if critic_type == types.CriticType.MIXTURE_OF_GAUSSIANS: - return networks_lib.GaussianMixture( - num_dimensions=1, - num_components=mog_num_components, - multivariate=False, - init_scale=mog_init_scale, - append_singleton_event_dim=False, - reinterpreted_batch_ndims=0)( - embedding) - elif critic_type in (types.CriticType.CATEGORICAL, - types.CriticType.CATEGORICAL_2HOT): - return networks_lib.CategoricalCriticHead( - num_bins=categorical_num_bins, vmin=vmin, vmax=vmax)( - embedding) - else: - return hk.Linear( - output_size=1, w_init=hk_init.TruncatedNormal(0.01))( - embedding) + """Creates MPONetworks to be used DM Control suite tasks.""" - # Create unrollable torso. - torso = make_unrollable_network(make_core_module=make_core_module) + # Unpack the environment spec to get appropriate shapes, dtypes, etc. + num_dimensions = np.prod(environment_spec.actions.shape, dtype=int) - # Create MPONetworks to add functionality required by the agent. - return MPONetworks( - policy_head=hk.without_apply_rng(hk.transform(policy_fn)), - critic_head=hk.without_apply_rng(hk.transform(critic_fn)), - torso=torso) + # Factory to create the core hk.Module. Must be a factory as the module must + # be initialized within a hk.transform scope. 
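+    # GRUWithSkip wraps a small GRU with a skip-connection from input to
+    # output, whereas hk.IdentityCore simply passes the embedding through,
+    # yielding a purely feedforward agent.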
+ if with_recurrence: + make_core_module = lambda: GRUWithSkip(16) + else: + make_core_module = hk.IdentityCore + + def policy_fn(observation: types.NestedArray) -> tfd.Distribution: + embedding = networks_lib.LayerNormMLP(policy_layer_sizes, activate_final=True)( + observation + ) + return networks_lib.MultivariateNormalDiagHead( + num_dimensions, init_scale=policy_init_scale + )(embedding) + + def critic_fn( + observation: types.NestedArray, action: types.NestedArray + ) -> DistributionOrArray: + # Action is clipped to avoid critic extrapolations outside the spec range. + clipped_action = networks_lib.ClipToSpec(environment_spec.actions)(action) + inputs = jnp.concatenate([observation, clipped_action], axis=-1) + embedding = networks_lib.LayerNormMLP(critic_layer_sizes, activate_final=True)( + inputs + ) + + if critic_type == types.CriticType.MIXTURE_OF_GAUSSIANS: + return networks_lib.GaussianMixture( + num_dimensions=1, + num_components=mog_num_components, + multivariate=False, + init_scale=mog_init_scale, + append_singleton_event_dim=False, + reinterpreted_batch_ndims=0, + )(embedding) + elif critic_type in ( + types.CriticType.CATEGORICAL, + types.CriticType.CATEGORICAL_2HOT, + ): + return networks_lib.CategoricalCriticHead( + num_bins=categorical_num_bins, vmin=vmin, vmax=vmax + )(embedding) + else: + return hk.Linear(output_size=1, w_init=hk_init.TruncatedNormal(0.01))( + embedding + ) + + # Create unrollable torso. + torso = make_unrollable_network(make_core_module=make_core_module) + + # Create MPONetworks to add functionality required by the agent. + return MPONetworks( + policy_head=hk.without_apply_rng(hk.transform(policy_fn)), + critic_head=hk.without_apply_rng(hk.transform(critic_fn)), + torso=torso, + ) def add_batch(nest, batch_size: Optional[int]): - """Adds a batch dimension at axis 0 to the leaves of a nested structure.""" - broadcast = lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape) - return jax.tree_map(broadcast, nest) + """Adds a batch dimension at axis 0 to the leaves of a nested structure.""" + broadcast = lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape) + return jax.tree_map(broadcast, nest) def w_init_identity(shape: Sequence[int], dtype) -> jnp.ndarray: - chex.assert_equal(len(shape), 2) - chex.assert_equal(shape[0], shape[1]) - return jnp.eye(shape[0], dtype=dtype) + chex.assert_equal(len(shape), 2) + chex.assert_equal(shape[0], shape[1]) + return jnp.eye(shape[0], dtype=dtype) class IdentityRNN(hk.RNNCore): - r"""Basic fully-connected RNN core with identity initialization. + r"""Basic fully-connected RNN core with identity initialization. Given :math:`x_t` and the previous hidden state :math:`h_{t-1}` the core computes @@ -281,66 +303,70 @@ class IdentityRNN(hk.RNNCore): https://arxiv.org/pdf/1504.00941.pdf """ - def __init__(self, - hidden_size: int, - hidden_scale: float = 1e-2, - name: Optional[str] = None): - """Constructs a vanilla RNN core. + def __init__( + self, hidden_size: int, hidden_scale: float = 1e-2, name: Optional[str] = None + ): + """Constructs a vanilla RNN core. Args: hidden_size: Hidden layer size. hidden_scale: Scalar multiplying the hidden-to-hidden matmul. name: Name of the module. 
""" - super().__init__(name=name) - self._initial_state = jnp.zeros([hidden_size]) - self._hidden_scale = hidden_scale - self._input_to_hidden = hk.Linear(hidden_size) - self._hidden_to_hidden = hk.Linear( - hidden_size, with_bias=True, w_init=w_init_identity) - - def __call__(self, inputs: jnp.ndarray, prev_state: jnp.ndarray): - out = jax.nn.relu( - self._input_to_hidden(inputs) + - self._hidden_scale * self._hidden_to_hidden(prev_state)) - return out, out - - def initial_state(self, batch_size: Optional[int]): - state = self._initial_state - if batch_size is not None: - state = add_batch(state, batch_size) - return state + super().__init__(name=name) + self._initial_state = jnp.zeros([hidden_size]) + self._hidden_scale = hidden_scale + self._input_to_hidden = hk.Linear(hidden_size) + self._hidden_to_hidden = hk.Linear( + hidden_size, with_bias=True, w_init=w_init_identity + ) + + def __call__(self, inputs: jnp.ndarray, prev_state: jnp.ndarray): + out = jax.nn.relu( + self._input_to_hidden(inputs) + + self._hidden_scale * self._hidden_to_hidden(prev_state) + ) + return out, out + + def initial_state(self, batch_size: Optional[int]): + state = self._initial_state + if batch_size is not None: + state = add_batch(state, batch_size) + return state class GRU(hk.GRU): - """GRU with an identity initialization.""" - - def __init__(self, hidden_size: int, name: Optional[str] = None): + """GRU with an identity initialization.""" - def b_init(unused_size: Sequence[int], dtype) -> jnp.ndarray: - """Initializes the biases so the GRU ignores the state and acts as a tanh.""" - return jnp.concatenate([ - +2 * jnp.ones([hidden_size], dtype=dtype), - -2 * jnp.ones([hidden_size], dtype=dtype), - jnp.zeros([hidden_size], dtype=dtype) - ]) + def __init__(self, hidden_size: int, name: Optional[str] = None): + def b_init(unused_size: Sequence[int], dtype) -> jnp.ndarray: + """Initializes the biases so the GRU ignores the state and acts as a tanh.""" + return jnp.concatenate( + [ + +2 * jnp.ones([hidden_size], dtype=dtype), + -2 * jnp.ones([hidden_size], dtype=dtype), + jnp.zeros([hidden_size], dtype=dtype), + ] + ) - super().__init__(hidden_size=hidden_size, b_init=b_init, name=name) + super().__init__(hidden_size=hidden_size, b_init=b_init, name=name) class GRUWithSkip(hk.GRU): - """GRU with a skip-connection from input to output.""" + """GRU with a skip-connection from input to output.""" - def __call__(self, inputs: jnp.ndarray, prev_state: jnp.ndarray): - outputs, state = super().__call__(inputs, prev_state) - outputs = jnp.concatenate([inputs, outputs], axis=-1) - return outputs, state + def __call__(self, inputs: jnp.ndarray, prev_state: jnp.ndarray): + outputs, state = super().__call__(inputs, prev_state) + outputs = jnp.concatenate([inputs, outputs], axis=-1) + return outputs, state class Conv2DLSTMWithSkip(hk.Conv2DLSTM): - """Conv2DLSTM with a skip-connection from input to output.""" - - def __call__(self, inputs: jnp.ndarray, state: jnp.ndarray): - outputs, state = super().__call__(inputs, state) # pytype: disable=wrong-arg-types # jax-ndarray - outputs = jnp.concatenate([inputs, outputs], axis=-1) - return outputs, state + """Conv2DLSTM with a skip-connection from input to output.""" + + def __call__(self, inputs: jnp.ndarray, state: jnp.ndarray): + outputs, state = super().__call__( + inputs, state + ) # pytype: disable=wrong-arg-types # jax-ndarray + outputs = jnp.concatenate([inputs, outputs], axis=-1) + return outputs, state diff --git a/acme/agents/jax/mpo/rollout_loss.py 
b/acme/agents/jax/mpo/rollout_loss.py index 35e292d821..a2cde7e128 100644 --- a/acme/agents/jax/mpo/rollout_loss.py +++ b/acme/agents/jax/mpo/rollout_loss.py @@ -16,6 +16,11 @@ from typing import Tuple +import chex +import jax +import jax.numpy as jnp +import rlax + from acme import types from acme.adders import reverb as adders from acme.agents.jax.mpo import categorical_mpo as discrete_losses @@ -23,49 +28,40 @@ from acme.agents.jax.mpo import types as mpo_types from acme.agents.jax.mpo import utils as mpo_utils from acme.jax import networks as network_lib -import chex -import jax -import jax.numpy as jnp -import rlax -def softmax_cross_entropy( - logits: chex.Array, target_probs: chex.Array) -> chex.Array: - """Compute cross entropy loss between logits and target probabilities.""" - chex.assert_equal_shape([target_probs, logits]) - return -jnp.sum(target_probs * jax.nn.log_softmax(logits), axis=-1) +def softmax_cross_entropy(logits: chex.Array, target_probs: chex.Array) -> chex.Array: + """Compute cross entropy loss between logits and target probabilities.""" + chex.assert_equal_shape([target_probs, logits]) + return -jnp.sum(target_probs * jax.nn.log_softmax(logits), axis=-1) def top1_accuracy_tiebreak( - logits: chex.Array, - targets: chex.Array, - *, - rng: chex.PRNGKey, - eps: float = 1e-6) -> chex.Array: - """Compute the top-1 accuracy with an argmax of targets (random tie-break).""" - noise = jax.random.uniform(rng, shape=targets.shape, - minval=-eps, maxval=eps) - acc = jnp.argmax(logits, axis=-1) == jnp.argmax(targets + noise, axis=-1) - return jnp.mean(acc) + logits: chex.Array, targets: chex.Array, *, rng: chex.PRNGKey, eps: float = 1e-6 +) -> chex.Array: + """Compute the top-1 accuracy with an argmax of targets (random tie-break).""" + noise = jax.random.uniform(rng, shape=targets.shape, minval=-eps, maxval=eps) + acc = jnp.argmax(logits, axis=-1) == jnp.argmax(targets + noise, axis=-1) + return jnp.mean(acc) class RolloutLoss: - """A MuZero/Muesli-style loss on the rollouts of the dynamics model.""" - - def __init__( - self, - dynamics_model: mpo_networks.UnrollableNetwork, - model_rollout_length: int, - loss_scales: mpo_types.LossScalesConfig, - distributional_loss_fn: mpo_types.DistributionalLossFn, - ): - self._dynamics_model = dynamics_model - self._model_rollout_length = model_rollout_length - self._loss_scales = loss_scales - self._distributional_loss_fn = distributional_loss_fn - - def _rolling_window(self, x: chex.Array, axis: int = 0) -> chex.Array: - """A convenient tree-mapped and configured call to rolling window. + """A MuZero/Muesli-style loss on the rollouts of the dynamics model.""" + + def __init__( + self, + dynamics_model: mpo_networks.UnrollableNetwork, + model_rollout_length: int, + loss_scales: mpo_types.LossScalesConfig, + distributional_loss_fn: mpo_types.DistributionalLossFn, + ): + self._dynamics_model = dynamics_model + self._model_rollout_length = model_rollout_length + self._loss_scales = loss_scales + self._distributional_loss_fn = distributional_loss_fn + + def _rolling_window(self, x: chex.Array, axis: int = 0) -> chex.Array: + """A convenient tree-mapped and configured call to rolling window. Stacks R = T - K + 1 action slices of length K = model_rollout_length from tensor x: [..., 0:K; ...; T-K:T, ...]. @@ -79,159 +75,191 @@ def _rolling_window(self, x: chex.Array, axis: int = 0) -> chex.Array: A tensor containing the stacked slices [0:K, ... T-K:T] from an axis of x with shape [..., K, R, ...] for input shape [..., T, ...]. 
""" - def rw(y): - return mpo_utils.rolling_window( - y, window=self._model_rollout_length, axis=axis, time_major=True) - - return mpo_utils.tree_map_distribution(rw, x) - - def _compute_model_rollout_predictions( - self, params: mpo_networks.MPONetworkParams, - state_embeddings: types.NestedArray, - action_sequence: types.NestedArray) -> mpo_types.ModelOutputs: - """Roll out the dynamics model for each embedding state.""" - assert self._model_rollout_length > 0 - # Stack the R=T-K+1 action slices of length K: [0:K; ...; T-K:T]; [K, R]. - rollout_actions = self._rolling_window(action_sequence) - - # Create batch of root states (embeddings) s_t for t \in {0, ..., R}. - num_rollouts = action_sequence.shape[0] - self._model_rollout_length + 1 - root_state = self._dynamics_model.initial_state_fn( - params.dynamics_model_initial_state, state_embeddings[:num_rollouts]) - # TODO(abef): randomly choose (fewer?) root unroll states, as in Muesli? - - # Roll out K steps forward in time for each root embedding; [K, R, ...]. - # For example, policy_rollout[k, t] is the step-k prediction starting from - # state s_t (and same for value_rollout and reward_rollout). Thus, for - # valid values of k, t, and i, policy_rollout[k, t] and - # policy_rollout[k-i, t+i] share the same target. - (policy_rollout, value_rollout, reward_rollout, - embedding_rollout), _ = self._dynamics_model.unroll( - params.dynamics_model, rollout_actions, root_state) - # TODO(abef): try using the same params for both the root & rollout heads. - - chex.assert_shape([rollout_actions, embedding_rollout], - (self._model_rollout_length, num_rollouts, ...)) - - # Create the outputs but drop the rollout that uses action a_{T-1} (and - # thus contains state s_T) for the policy, value, and embedding because we - # don't have targets for s_T (but we do know them for the final reward). - # Also drop the rollout with s_{T-1} for the value because we don't have - # targets for that either. - return mpo_types.ModelOutputs( - policy=policy_rollout[:, :-1], # [K, R-1, ...] - value=value_rollout[:, :-2], # [K, R-2, ...] - reward=reward_rollout, # [K, R, ...] - embedding=embedding_rollout[:, :-1]) # [K, R-1, ...] - - def __call__( - self, - params: mpo_networks.MPONetworkParams, - dual_params: mpo_types.DualParams, - sequence: adders.Step, - state_embeddings: types.NestedArray, - targets: mpo_types.LossTargets, - key: network_lib.PRNGKey, - ) -> Tuple[jnp.ndarray, mpo_types.LogDict]: - - num_rollouts = sequence.reward.shape[0] - self._model_rollout_length + 1 - indices = jnp.arange(num_rollouts) - - # Create rollout predictions. - rollout = self._compute_model_rollout_predictions( - params=params, state_embeddings=state_embeddings, - action_sequence=sequence.action) - - # Create rollout target tensors. The rollouts will not contain the policy - # and value at t=0 because they start after taking the first action in - # the sequence, so drop those when creating the targets. They will contain - # the reward at t=0, however, because of how the sequences are stored. - # Rollout target shapes: - # - value: [N, Z, T-2] -> [N, Z, K, R-2], - # - reward: [T] -> [K, R]. - value_targets = self._rolling_window(targets.value[..., 1:], axis=-1) - reward_targets = self._rolling_window(targets.reward)[None, None, ...] - - # Define the value and reward rollout loss functions. - def value_loss_fn(root_idx) -> jnp.ndarray: - return self._distributional_loss_fn( - rollout.value[:, root_idx], # [K, R-2, ...] 
- value_targets[..., root_idx]) # [..., K, R-2] - - def reward_loss_fn(root_idx) -> jnp.ndarray: - return self._distributional_loss_fn( - rollout.reward[:, root_idx], # [K, R, ...] - reward_targets[..., root_idx]) # [..., K, R] - - # Reward and value losses. - critic_loss = jnp.mean(jax.vmap(value_loss_fn)(indices[:-2])) - reward_loss = jnp.mean(jax.vmap(reward_loss_fn)(indices)) - - # Define the MPO policy rollout loss. - mpo_policy_loss = 0 - if self._loss_scales.rollout.policy: - # Rollout target shapes: - # - policy: [T-1, ...] -> [K, R-1, ...], - # - q_improvement: [N, T-1] -> [N, K, R-1]. - policy_targets = self._rolling_window(targets.policy[1:]) - q_improvement = self._rolling_window(targets.q_improvement[:, 1:], axis=1) - - def policy_loss_fn(root_idx) -> jnp.ndarray: - chex.assert_shape((rollout.policy.logits, policy_targets.logits), # pytype: disable=attribute-error # numpy-scalars - (self._model_rollout_length, num_rollouts - 1, None)) - chex.assert_shape(q_improvement, - (None, self._model_rollout_length, num_rollouts - 1)) - # Compute MPO's E-step unnormalized logits. - temperature = discrete_losses.get_temperature_from_params(dual_params) - policy_target_probs = jax.nn.softmax( - jnp.transpose(q_improvement[..., root_idx]) / temperature + - jax.nn.log_softmax(policy_targets[:, root_idx].logits, axis=-1)) # pytype: disable=attribute-error # numpy-scalars - return softmax_cross_entropy(rollout.policy[:, root_idx].logits, # pytype: disable=bad-return-type # numpy-scalars - jax.lax.stop_gradient(policy_target_probs)) - - # Compute the MPO loss and add it to the overall rollout policy loss. - mpo_policy_loss = jax.vmap(policy_loss_fn)(indices[:-1]) - mpo_policy_loss = jnp.mean(mpo_policy_loss) - - # Define the BC policy rollout loss (only supported for discrete policies). - bc_policy_loss, bc_policy_acc = 0, 0 - if self._loss_scales.rollout.bc_policy: - num_actions = rollout.policy.logits.shape[-1] # A - bc_targets = self._rolling_window( # [T-1, A] -> [K, R-1, A] - rlax.one_hot(sequence.action[1:], num_actions)) - - def bc_policy_loss_fn(root_idx) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Self-behavior-cloning loss (cross entropy on rollout actions).""" + + def rw(y): + return mpo_utils.rolling_window( + y, window=self._model_rollout_length, axis=axis, time_major=True + ) + + return mpo_utils.tree_map_distribution(rw, x) + + def _compute_model_rollout_predictions( + self, + params: mpo_networks.MPONetworkParams, + state_embeddings: types.NestedArray, + action_sequence: types.NestedArray, + ) -> mpo_types.ModelOutputs: + """Roll out the dynamics model for each embedding state.""" + assert self._model_rollout_length > 0 + # Stack the R=T-K+1 action slices of length K: [0:K; ...; T-K:T]; [K, R]. + rollout_actions = self._rolling_window(action_sequence) + + # Create batch of root states (embeddings) s_t for t \in {0, ..., R}. + num_rollouts = action_sequence.shape[0] - self._model_rollout_length + 1 + root_state = self._dynamics_model.initial_state_fn( + params.dynamics_model_initial_state, state_embeddings[:num_rollouts] + ) + # TODO(abef): randomly choose (fewer?) root unroll states, as in Muesli? + + # Roll out K steps forward in time for each root embedding; [K, R, ...]. + # For example, policy_rollout[k, t] is the step-k prediction starting from + # state s_t (and same for value_rollout and reward_rollout). Thus, for + # valid values of k, t, and i, policy_rollout[k, t] and + # policy_rollout[k-i, t+i] share the same target. 
+ ( + (policy_rollout, value_rollout, reward_rollout, embedding_rollout), + _, + ) = self._dynamics_model.unroll( + params.dynamics_model, rollout_actions, root_state + ) + # TODO(abef): try using the same params for both the root & rollout heads. + chex.assert_shape( - (rollout.policy.logits, bc_targets), - (self._model_rollout_length, num_rollouts - 1, num_actions)) - loss = softmax_cross_entropy(rollout.policy.logits[:, root_idx], - bc_targets[:, root_idx]) - top1_accuracy = top1_accuracy_tiebreak( - rollout.policy.logits[:, root_idx], - bc_targets[:, root_idx], - rng=key) - return loss, top1_accuracy # pytype: disable=bad-return-type # numpy-scalars - - # Compute each rollout loss by vmapping over the rollouts. - bc_policy_loss, bc_policy_acc = jax.vmap(bc_policy_loss_fn)(indices[:-1]) - bc_policy_loss = jnp.mean(bc_policy_loss) - bc_policy_acc = jnp.mean(bc_policy_acc) - - # Combine losses. - loss = ( - self._loss_scales.rollout.policy * mpo_policy_loss + - self._loss_scales.rollout.bc_policy * bc_policy_loss + - self._loss_scales.critic * self._loss_scales.rollout.critic * - critic_loss + self._loss_scales.rollout.reward * reward_loss) - - logging_dict = { - 'rollout_critic_loss': critic_loss, - 'rollout_reward_loss': reward_loss, - 'rollout_policy_loss': mpo_policy_loss, - 'rollout_bc_policy_loss': bc_policy_loss, - 'rollout_bc_accuracy': bc_policy_acc, - 'rollout_loss': loss, - } - - return loss, logging_dict # pytype: disable=bad-return-type # jax-ndarray + [rollout_actions, embedding_rollout], + (self._model_rollout_length, num_rollouts, ...), + ) + + # Create the outputs but drop the rollout that uses action a_{T-1} (and + # thus contains state s_T) for the policy, value, and embedding because we + # don't have targets for s_T (but we do know them for the final reward). + # Also drop the rollout with s_{T-1} for the value because we don't have + # targets for that either. + return mpo_types.ModelOutputs( + policy=policy_rollout[:, :-1], # [K, R-1, ...] + value=value_rollout[:, :-2], # [K, R-2, ...] + reward=reward_rollout, # [K, R, ...] + embedding=embedding_rollout[:, :-1], + ) # [K, R-1, ...] + + def __call__( + self, + params: mpo_networks.MPONetworkParams, + dual_params: mpo_types.DualParams, + sequence: adders.Step, + state_embeddings: types.NestedArray, + targets: mpo_types.LossTargets, + key: network_lib.PRNGKey, + ) -> Tuple[jnp.ndarray, mpo_types.LogDict]: + + num_rollouts = sequence.reward.shape[0] - self._model_rollout_length + 1 + indices = jnp.arange(num_rollouts) + + # Create rollout predictions. + rollout = self._compute_model_rollout_predictions( + params=params, + state_embeddings=state_embeddings, + action_sequence=sequence.action, + ) + + # Create rollout target tensors. The rollouts will not contain the policy + # and value at t=0 because they start after taking the first action in + # the sequence, so drop those when creating the targets. They will contain + # the reward at t=0, however, because of how the sequences are stored. + # Rollout target shapes: + # - value: [N, Z, T-2] -> [N, Z, K, R-2], + # - reward: [T] -> [K, R]. + value_targets = self._rolling_window(targets.value[..., 1:], axis=-1) + reward_targets = self._rolling_window(targets.reward)[None, None, ...] + + # Define the value and reward rollout loss functions. + def value_loss_fn(root_idx) -> jnp.ndarray: + return self._distributional_loss_fn( + rollout.value[:, root_idx], # [K, R-2, ...] 
+ value_targets[..., root_idx], + ) # [..., K, R-2] + + def reward_loss_fn(root_idx) -> jnp.ndarray: + return self._distributional_loss_fn( + rollout.reward[:, root_idx], # [K, R, ...] + reward_targets[..., root_idx], + ) # [..., K, R] + + # Reward and value losses. + critic_loss = jnp.mean(jax.vmap(value_loss_fn)(indices[:-2])) + reward_loss = jnp.mean(jax.vmap(reward_loss_fn)(indices)) + + # Define the MPO policy rollout loss. + mpo_policy_loss = 0 + if self._loss_scales.rollout.policy: + # Rollout target shapes: + # - policy: [T-1, ...] -> [K, R-1, ...], + # - q_improvement: [N, T-1] -> [N, K, R-1]. + policy_targets = self._rolling_window(targets.policy[1:]) + q_improvement = self._rolling_window(targets.q_improvement[:, 1:], axis=1) + + def policy_loss_fn(root_idx) -> jnp.ndarray: + chex.assert_shape( + ( + rollout.policy.logits, + policy_targets.logits, + ), # pytype: disable=attribute-error # numpy-scalars + (self._model_rollout_length, num_rollouts - 1, None), + ) + chex.assert_shape( + q_improvement, (None, self._model_rollout_length, num_rollouts - 1) + ) + # Compute MPO's E-step unnormalized logits. + temperature = discrete_losses.get_temperature_from_params(dual_params) + policy_target_probs = jax.nn.softmax( + jnp.transpose(q_improvement[..., root_idx]) / temperature + + jax.nn.log_softmax(policy_targets[:, root_idx].logits, axis=-1) + ) # pytype: disable=attribute-error # numpy-scalars + return softmax_cross_entropy( + rollout.policy[ + :, root_idx + ].logits, # pytype: disable=bad-return-type # numpy-scalars + jax.lax.stop_gradient(policy_target_probs), + ) + + # Compute the MPO loss and add it to the overall rollout policy loss. + mpo_policy_loss = jax.vmap(policy_loss_fn)(indices[:-1]) + mpo_policy_loss = jnp.mean(mpo_policy_loss) + + # Define the BC policy rollout loss (only supported for discrete policies). + bc_policy_loss, bc_policy_acc = 0, 0 + if self._loss_scales.rollout.bc_policy: + num_actions = rollout.policy.logits.shape[-1] # A + bc_targets = self._rolling_window( # [T-1, A] -> [K, R-1, A] + rlax.one_hot(sequence.action[1:], num_actions) + ) + + def bc_policy_loss_fn(root_idx) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Self-behavior-cloning loss (cross entropy on rollout actions).""" + chex.assert_shape( + (rollout.policy.logits, bc_targets), + (self._model_rollout_length, num_rollouts - 1, num_actions), + ) + loss = softmax_cross_entropy( + rollout.policy.logits[:, root_idx], bc_targets[:, root_idx] + ) + top1_accuracy = top1_accuracy_tiebreak( + rollout.policy.logits[:, root_idx], bc_targets[:, root_idx], rng=key + ) + return ( + loss, + top1_accuracy, + ) # pytype: disable=bad-return-type # numpy-scalars + + # Compute each rollout loss by vmapping over the rollouts. + bc_policy_loss, bc_policy_acc = jax.vmap(bc_policy_loss_fn)(indices[:-1]) + bc_policy_loss = jnp.mean(bc_policy_loss) + bc_policy_acc = jnp.mean(bc_policy_acc) + + # Combine losses. 
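+        # Note that the critic term is scaled by both the global critic scale
+        # and the rollout-specific critic scale, whereas the policy and reward
+        # terms use only their rollout-specific scales.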
+ loss = ( + self._loss_scales.rollout.policy * mpo_policy_loss + + self._loss_scales.rollout.bc_policy * bc_policy_loss + + self._loss_scales.critic * self._loss_scales.rollout.critic * critic_loss + + self._loss_scales.rollout.reward * reward_loss + ) + + logging_dict = { + "rollout_critic_loss": critic_loss, + "rollout_reward_loss": reward_loss, + "rollout_policy_loss": mpo_policy_loss, + "rollout_bc_policy_loss": bc_policy_loss, + "rollout_bc_accuracy": bc_policy_acc, + "rollout_loss": loss, + } + + return loss, logging_dict # pytype: disable=bad-return-type # jax-ndarray diff --git a/acme/agents/jax/mpo/types.py b/acme/agents/jax/mpo/types.py index 035eb76cfc..c8a8a02716 100644 --- a/acme/agents/jax/mpo/types.py +++ b/acme/agents/jax/mpo/types.py @@ -18,12 +18,12 @@ import enum from typing import Callable, Mapping, Optional, Union -from acme import types -from acme.agents.jax.mpo import categorical_mpo as discrete_losses -import acme.jax.losses.mpo as continuous_losses import distrax import jax.numpy as jnp +import acme.jax.losses.mpo as continuous_losses +from acme import types +from acme.agents.jax.mpo import categorical_mpo as discrete_losses # TODO(bshahr): consider upstreaming these to core types. NestedArray = types.NestedArray @@ -33,73 +33,77 @@ RNGKey = jnp.ndarray Entropy = jnp.ndarray LogProb = jnp.ndarray -ExperienceType = Union['FromTransitions', 'FromSequences'] +ExperienceType = Union["FromTransitions", "FromSequences"] DistributionLike = distrax.DistributionLike DistributionOrArray = Union[DistributionLike, jnp.ndarray] LogDict = Mapping[str, jnp.ndarray] -PolicyStats = Union[ - discrete_losses.CategoricalMPOStats, continuous_losses.MPOStats] -DualParams = Union[continuous_losses.MPOParams, - discrete_losses.CategoricalMPOParams] +PolicyStats = Union[discrete_losses.CategoricalMPOStats, continuous_losses.MPOStats] +DualParams = Union[continuous_losses.MPOParams, discrete_losses.CategoricalMPOParams] DistributionalLossFn = Callable[[DistributionLike, jnp.ndarray], jnp.ndarray] @dataclasses.dataclass class FromTransitions: - """Configuration for learning from n-step transitions.""" - n_step: int = 1 - # TODO(bshahr): consider adding the discount here. + """Configuration for learning from n-step transitions.""" + + n_step: int = 1 + # TODO(bshahr): consider adding the discount here. @dataclasses.dataclass class FromSequences: - """Configuration for learning from sequences.""" - sequence_length: int = 2 - sequence_period: int = 1 - # Configuration of how to bootstrap from these sequences. - n_step: Optional[int] = 5 - # Lambda used to discount future rewards as in TD(lambda), Retrace, etc. - td_lambda: Optional[float] = 1.0 + """Configuration for learning from sequences.""" + + sequence_length: int = 2 + sequence_period: int = 1 + # Configuration of how to bootstrap from these sequences. + n_step: Optional[int] = 5 + # Lambda used to discount future rewards as in TD(lambda), Retrace, etc. 
+ td_lambda: Optional[float] = 1.0 class CriticType(enum.Enum): - """Types of critic that are supported.""" - NONDISTRIBUTIONAL = 'nondistributional' - MIXTURE_OF_GAUSSIANS = 'mixture_of_gaussians' - CATEGORICAL_2HOT = 'categorical_2hot' - CATEGORICAL = 'categorical' + """Types of critic that are supported.""" + + NONDISTRIBUTIONAL = "nondistributional" + MIXTURE_OF_GAUSSIANS = "mixture_of_gaussians" + CATEGORICAL_2HOT = "categorical_2hot" + CATEGORICAL = "categorical" class RnnCoreType(enum.Enum): - """Types of core that are supported for rnn.""" - IDENTITY = 'identity' - GRU = 'gru' + """Types of core that are supported for rnn.""" + + IDENTITY = "identity" + GRU = "gru" @dataclasses.dataclass class GaussianPolicyLossConfig: - """Configuration for the continuous (Gaussian) policy loss.""" - epsilon: float = 0.1 - epsilon_penalty: float = 0.001 - epsilon_mean: float = 0.0025 - epsilon_stddev: float = 1e-6 - init_log_temperature: float = 10. - init_log_alpha_mean: float = 10. - init_log_alpha_stddev: float = 1000. - action_penalization: bool = True - per_dim_constraining: bool = True + """Configuration for the continuous (Gaussian) policy loss.""" + + epsilon: float = 0.1 + epsilon_penalty: float = 0.001 + epsilon_mean: float = 0.0025 + epsilon_stddev: float = 1e-6 + init_log_temperature: float = 10.0 + init_log_alpha_mean: float = 10.0 + init_log_alpha_stddev: float = 1000.0 + action_penalization: bool = True + per_dim_constraining: bool = True @dataclasses.dataclass class CategoricalPolicyLossConfig: - """Configuration for the discrete (categorical) policy loss.""" - epsilon: float = 0.1 - epsilon_policy: float = 0.0025 - init_log_temperature: float = 3. - init_log_alpha: float = 3. + """Configuration for the discrete (categorical) policy loss.""" + + epsilon: float = 0.1 + epsilon_policy: float = 0.0025 + init_log_temperature: float = 3.0 + init_log_alpha: float = 3.0 PolicyLossConfig = Union[GaussianPolicyLossConfig, CategoricalPolicyLossConfig] @@ -107,43 +111,47 @@ class CategoricalPolicyLossConfig: @dataclasses.dataclass(frozen=True) class RolloutLossScalesConfig: - """Configuration for scaling the rollout losses used in the learner.""" - policy: float = 1.0 - bc_policy: float = 1.0 - critic: float = 1.0 - reward: float = 1.0 + """Configuration for scaling the rollout losses used in the learner.""" + + policy: float = 1.0 + bc_policy: float = 1.0 + critic: float = 1.0 + reward: float = 1.0 @dataclasses.dataclass(frozen=True) class LossScalesConfig: - """Configuration for scaling the rollout losses used in the learner.""" - policy: float = 1.0 - critic: float = 1.0 - rollout: Optional[RolloutLossScalesConfig] = None + """Configuration for scaling the rollout losses used in the learner.""" + + policy: float = 1.0 + critic: float = 1.0 + rollout: Optional[RolloutLossScalesConfig] = None @dataclasses.dataclass(frozen=True) class ModelOutputs: - """Container for the outputs of the model.""" - policy: Optional[types.NestedArray] = None - q_value: Optional[types.NestedArray] = None - value: Optional[types.NestedArray] = None - reward: Optional[types.NestedArray] = None - embedding: Optional[types.NestedArray] = None - recurrent_state: Optional[types.NestedArray] = None + """Container for the outputs of the model.""" + + policy: Optional[types.NestedArray] = None + q_value: Optional[types.NestedArray] = None + value: Optional[types.NestedArray] = None + reward: Optional[types.NestedArray] = None + embedding: Optional[types.NestedArray] = None + recurrent_state: Optional[types.NestedArray] = 
None @dataclasses.dataclass(frozen=True) class LossTargets: - """Container for the targets used to compute the model loss.""" - # Policy targets. - policy: Optional[types.NestedArray] = None - a_improvement: Optional[types.NestedArray] = None - q_improvement: Optional[types.NestedArray] = None - - # Value targets. - q_value: Optional[types.NestedArray] = None - value: Optional[types.NestedArray] = None - reward: Optional[types.NestedArray] = None - - embedding: Optional[types.NestedArray] = None + """Container for the targets used to compute the model loss.""" + + # Policy targets. + policy: Optional[types.NestedArray] = None + a_improvement: Optional[types.NestedArray] = None + q_improvement: Optional[types.NestedArray] = None + + # Value targets. + q_value: Optional[types.NestedArray] = None + value: Optional[types.NestedArray] = None + reward: Optional[types.NestedArray] = None + + embedding: Optional[types.NestedArray] = None diff --git a/acme/agents/jax/mpo/utils.py b/acme/agents/jax/mpo/utils.py index 88fd73f065..49e8b84241 100644 --- a/acme/agents/jax/mpo/utils.py +++ b/acme/agents/jax/mpo/utils.py @@ -16,40 +16,38 @@ from typing import Callable -from acme import types -from acme.adders import reverb as adders -from acme.agents.jax.mpo import types as mpo_types import distrax import jax import jax.numpy as jnp import numpy as np - import tensorflow_probability.substrates.jax as tfp + +from acme import types +from acme.adders import reverb as adders +from acme.agents.jax.mpo import types as mpo_types + tfd = tfp.distributions def _fetch_devicearray(x): - if isinstance(x, jax.Array): - return np.asarray(x) - return x + if isinstance(x, jax.Array): + return np.asarray(x) + return x def get_from_first_device(nest, as_numpy: bool = True): - """Gets the first array of a nest of `jax.pxla.ShardedDeviceArray`s.""" - # TODO(abef): remove this when fake_pmap is fixed or acme error is removed. + """Gets the first array of a nest of `jax.pxla.ShardedDeviceArray`s.""" + # TODO(abef): remove this when fake_pmap is fixed or acme error is removed. - def _slice_and_maybe_to_numpy(x): - x = x[0] - return _fetch_devicearray(x) if as_numpy else x + def _slice_and_maybe_to_numpy(x): + x = x[0] + return _fetch_devicearray(x) if as_numpy else x - return jax.tree_map(_slice_and_maybe_to_numpy, nest) + return jax.tree_map(_slice_and_maybe_to_numpy, nest) -def rolling_window(x: jnp.ndarray, - window: int, - axis: int = 0, - time_major: bool = True): - """Stack the N=T-W+1 length W slices [0:W, 1:W+1, ..., T-W:T] from a tensor. +def rolling_window(x: jnp.ndarray, window: int, axis: int = 0, time_major: bool = True): + """Stack the N=T-W+1 length W slices [0:W, 1:W+1, ..., T-W:T] from a tensor. Args: x: The tensor to select rolling slices from (along specified axis), with @@ -62,47 +60,50 @@ def rolling_window(x: jnp.ndarray, Returns: A tensor containing the stacked slices [0:W, ... T-W:T] from an axis of x. """ - sequence_length = x.shape[axis] - starts = jnp.arange(sequence_length - window + 1) - ends = jnp.arange(window) - if time_major: - idx = starts[None, :] + ends[:, None] # Output will be [..., W, N, ...]. - else: - idx = starts[:, None] + ends[None, :] # Output will be [..., N, W, ...]. - out = jnp.take(x, idx, axis=axis) - return out + sequence_length = x.shape[axis] + starts = jnp.arange(sequence_length - window + 1) + ends = jnp.arange(window) + if time_major: + idx = starts[None, :] + ends[:, None] # Output will be [..., W, N, ...]. 
+ else: + idx = starts[:, None] + ends[None, :] # Output will be [..., N, W, ...]. + out = jnp.take(x, idx, axis=axis) + return out def tree_map_distribution( f: Callable[[mpo_types.DistributionOrArray], mpo_types.DistributionOrArray], - x: mpo_types.DistributionOrArray) -> mpo_types.DistributionOrArray: - """Apply a jax function to a distribution by treating it as tree.""" - if isinstance(x, distrax.Distribution): - safe_f = lambda y: f(y) if isinstance(y, jnp.ndarray) else y - nil, tree_data = x.tree_flatten() - new_tree_data = jax.tree_map(safe_f, tree_data) - new_x = x.tree_unflatten(new_tree_data, nil) - return new_x - elif isinstance(x, tfd.Distribution): - return jax.tree_map(f, x) - else: - return f(x) + x: mpo_types.DistributionOrArray, +) -> mpo_types.DistributionOrArray: + """Apply a jax function to a distribution by treating it as tree.""" + if isinstance(x, distrax.Distribution): + safe_f = lambda y: f(y) if isinstance(y, jnp.ndarray) else y + nil, tree_data = x.tree_flatten() + new_tree_data = jax.tree_map(safe_f, tree_data) + new_x = x.tree_unflatten(new_tree_data, nil) + return new_x + elif isinstance(x, tfd.Distribution): + return jax.tree_map(f, x) + else: + return f(x) def make_sequences_from_transitions( - transitions: types.Transition, - num_batch_dims: int = 1) -> adders.Step: - """Convert a batch of transitions into a batch of 1-step sequences.""" - stack = lambda x, y: jnp.stack((x, y), axis=num_batch_dims) - duplicate = lambda x: stack(x, x) - observation = jax.tree_map(stack, transitions.observation, - transitions.next_observation) - reward = duplicate(transitions.reward) - - return adders.Step( - observation=observation, - action=duplicate(transitions.action), - reward=reward, - discount=duplicate(transitions.discount), - start_of_episode=jnp.zeros_like(reward, dtype=jnp.bool_), - extras=jax.tree_map(duplicate, transitions.extras)) + transitions: types.Transition, num_batch_dims: int = 1 +) -> adders.Step: + """Convert a batch of transitions into a batch of 1-step sequences.""" + stack = lambda x, y: jnp.stack((x, y), axis=num_batch_dims) + duplicate = lambda x: stack(x, x) + observation = jax.tree_map( + stack, transitions.observation, transitions.next_observation + ) + reward = duplicate(transitions.reward) + + return adders.Step( + observation=observation, + action=duplicate(transitions.action), + reward=reward, + discount=duplicate(transitions.discount), + start_of_episode=jnp.zeros_like(reward, dtype=jnp.bool_), + extras=jax.tree_map(duplicate, transitions.extras), + ) diff --git a/acme/agents/jax/multiagent/decentralized/__init__.py b/acme/agents/jax/multiagent/decentralized/__init__.py index 64bc8e632e..0ec3996a7b 100644 --- a/acme/agents/jax/multiagent/decentralized/__init__.py +++ b/acme/agents/jax/multiagent/decentralized/__init__.py @@ -14,10 +14,16 @@ """Decentralized multiagent configuration.""" -from acme.agents.jax.multiagent.decentralized.builder import DecentralizedMultiAgentBuilder -from acme.agents.jax.multiagent.decentralized.config import DecentralizedMultiagentConfig -from acme.agents.jax.multiagent.decentralized.factories import builder_factory -from acme.agents.jax.multiagent.decentralized.factories import default_config_factory -from acme.agents.jax.multiagent.decentralized.factories import DefaultSupportedAgent -from acme.agents.jax.multiagent.decentralized.factories import network_factory -from acme.agents.jax.multiagent.decentralized.factories import policy_network_factory +from acme.agents.jax.multiagent.decentralized.builder import ( 
+ DecentralizedMultiAgentBuilder, +) +from acme.agents.jax.multiagent.decentralized.config import ( + DecentralizedMultiagentConfig, +) +from acme.agents.jax.multiagent.decentralized.factories import ( + DefaultSupportedAgent, + builder_factory, + default_config_factory, + network_factory, + policy_network_factory, +) diff --git a/acme/agents/jax/multiagent/decentralized/actor.py b/acme/agents/jax/multiagent/decentralized/actor.py index 6f5df02c79..7a26fdbf04 100644 --- a/acme/agents/jax/multiagent/decentralized/actor.py +++ b/acme/agents/jax/multiagent/decentralized/actor.py @@ -16,43 +16,47 @@ from typing import Dict +import dm_env + from acme import core from acme.jax import networks from acme.multiagent import types as ma_types from acme.multiagent import utils as ma_utils -import dm_env class SimultaneousActingMultiAgentActor(core.Actor): - """Simultaneous-move actor (see README.md for expected environment interface).""" + """Simultaneous-move actor (see README.md for expected environment interface).""" - def __init__(self, actors: Dict[ma_types.AgentID, core.Actor]): - """Initializer. + def __init__(self, actors: Dict[ma_types.AgentID, core.Actor]): + """Initializer. Args: actors: a dict specifying sub-actors. """ - self._actors = actors - - def select_action( - self, observation: Dict[ma_types.AgentID, networks.Observation] - ) -> Dict[ma_types.AgentID, networks.Action]: - return { - actor_id: actor.select_action(observation[actor_id]) - for actor_id, actor in self._actors.items() - } - - def observe_first(self, timestep: dm_env.TimeStep): - for actor_id, actor in self._actors.items(): - sub_timestep = ma_utils.get_agent_timestep(timestep, actor_id) - actor.observe_first(sub_timestep) - - def observe(self, actions: Dict[ma_types.AgentID, networks.Action], - next_timestep: dm_env.TimeStep): - for actor_id, actor in self._actors.items(): - sub_next_timestep = ma_utils.get_agent_timestep(next_timestep, actor_id) - actor.observe(actions[actor_id], sub_next_timestep) - - def update(self, wait: bool = False): - for actor in self._actors.values(): - actor.update(wait=wait) + self._actors = actors + + def select_action( + self, observation: Dict[ma_types.AgentID, networks.Observation] + ) -> Dict[ma_types.AgentID, networks.Action]: + return { + actor_id: actor.select_action(observation[actor_id]) + for actor_id, actor in self._actors.items() + } + + def observe_first(self, timestep: dm_env.TimeStep): + for actor_id, actor in self._actors.items(): + sub_timestep = ma_utils.get_agent_timestep(timestep, actor_id) + actor.observe_first(sub_timestep) + + def observe( + self, + actions: Dict[ma_types.AgentID, networks.Action], + next_timestep: dm_env.TimeStep, + ): + for actor_id, actor in self._actors.items(): + sub_next_timestep = ma_utils.get_agent_timestep(next_timestep, actor_id) + actor.observe(actions[actor_id], sub_next_timestep) + + def update(self, wait: bool = False): + for actor in self._actors.values(): + actor.update(wait=wait) diff --git a/acme/agents/jax/multiagent/decentralized/builder.py b/acme/agents/jax/multiagent/decentralized/builder.py index 8a4624e7f9..17645e51a6 100644 --- a/acme/agents/jax/multiagent/decentralized/builder.py +++ b/acme/agents/jax/multiagent/decentralized/builder.py @@ -16,51 +16,51 @@ from typing import Dict, Iterator, List, Mapping, Optional, Sequence -from acme import adders -from acme import core -from acme import specs -from acme import types +import jax +import reverb + +from acme import adders, core, specs, types from acme.agents.jax import 
builders as acme_builders from acme.agents.jax.multiagent.decentralized import actor -from acme.agents.jax.multiagent.decentralized import factories as decentralized_factories +from acme.agents.jax.multiagent.decentralized import ( + factories as decentralized_factories, +) from acme.agents.jax.multiagent.decentralized import learner_set from acme.jax import networks as networks_lib from acme.multiagent import types as ma_types from acme.multiagent import utils as ma_utils -from acme.utils import counting -from acme.utils import iterator_utils -from acme.utils import loggers -import jax -import reverb - +from acme.utils import counting, iterator_utils, loggers -VARIABLE_SEPARATOR = '-' +VARIABLE_SEPARATOR = "-" class PrefixedVariableSource(core.VariableSource): - """Wraps a variable source to add a pre-defined prefix to all names.""" + """Wraps a variable source to add a pre-defined prefix to all names.""" - def __init__(self, source: core.VariableSource, prefix: str): - self._source = source - self._prefix = prefix + def __init__(self, source: core.VariableSource, prefix: str): + self._source = source + self._prefix = prefix - def get_variables(self, names: Sequence[str]) -> List[types.NestedArray]: - return self._source.get_variables([self._prefix + name for name in names]) + def get_variables(self, names: Sequence[str]) -> List[types.NestedArray]: + return self._source.get_variables([self._prefix + name for name in names]) class DecentralizedMultiAgentBuilder( acme_builders.GenericActorLearnerBuilder[ ma_types.MultiAgentNetworks, ma_types.MultiAgentPolicyNetworks, - ma_types.MultiAgentSample]): - """Builder for decentralized multiagent setup.""" - - def __init__( - self, - agent_types: Dict[ma_types.AgentID, ma_types.GenericAgent], - agent_configs: Dict[ma_types.AgentID, ma_types.AgentConfig], - init_policy_network_fn: Optional[ma_types.InitPolicyNetworkFn] = None): - """Initializer. + ma_types.MultiAgentSample, + ] +): + """Builder for decentralized multiagent setup.""" + + def __init__( + self, + agent_types: Dict[ma_types.AgentID, ma_types.GenericAgent], + agent_configs: Dict[ma_types.AgentID, ma_types.AgentConfig], + init_policy_network_fn: Optional[ma_types.InitPolicyNetworkFn] = None, + ): + """Initializer. Args: agent_types: Dict mapping agent IDs to their types. @@ -69,19 +69,20 @@ def __init__( function. """ - self._agent_types = agent_types - self._agent_configs = agent_configs - self._builders = decentralized_factories.builder_factory( - agent_types, agent_configs) - self._num_agents = len(self._builders) - self._init_policy_network_fn = init_policy_network_fn - - def make_replay_tables( - self, - environment_spec: specs.EnvironmentSpec, - policy: ma_types.MultiAgentPolicyNetworks, - ) -> List[reverb.Table]: - """Returns replay tables for all agents. + self._agent_types = agent_types + self._agent_configs = agent_configs + self._builders = decentralized_factories.builder_factory( + agent_types, agent_configs + ) + self._num_agents = len(self._builders) + self._init_policy_network_fn = init_policy_network_fn + + def make_replay_tables( + self, + environment_spec: specs.EnvironmentSpec, + policy: ma_types.MultiAgentPolicyNetworks, + ) -> List[reverb.Table]: + """Returns replay tables for all agents. Args: environment_spec: the (multiagent) environment spec, which will be @@ -89,35 +90,36 @@ def make_replay_tables( policy: the (multiagent) mapping from agent ID to the corresponding agent's policy, used to get the correct extras_spec. 
""" - replay_tables = [] - for agent_id, builder in self._builders.items(): - single_agent_spec = ma_utils.get_agent_spec(environment_spec, agent_id) - replay_tables += builder.make_replay_tables(single_agent_spec, - policy[agent_id]) - return replay_tables - - def make_dataset_iterator( - self, - replay_client: reverb.Client) -> Iterator[ma_types.MultiAgentSample]: - # Zipping stores sub-iterators in the order dictated by - # self._builders.values(), which are insertion-ordered in Python3.7+. - # Hence, later unzipping (in make_learner()) and accessing the iterators - # via the same self._builders.items() dict ordering should be safe. - return zip(*[ - b.make_dataset_iterator(replay_client) for b in self._builders.values() - ]) - - def make_learner( - self, - random_key: networks_lib.PRNGKey, - networks: ma_types.MultiAgentNetworks, - dataset: Iterator[ma_types.MultiAgentSample], - logger_fn: loggers.LoggerFactory, - environment_spec: Optional[specs.EnvironmentSpec] = None, - replay_client: Optional[reverb.Client] = None, - counter: Optional[counting.Counter] = None - ) -> learner_set.SynchronousDecentralizedLearnerSet: - """Returns multiagent learner set. + replay_tables = [] + for agent_id, builder in self._builders.items(): + single_agent_spec = ma_utils.get_agent_spec(environment_spec, agent_id) + replay_tables += builder.make_replay_tables( + single_agent_spec, policy[agent_id] + ) + return replay_tables + + def make_dataset_iterator( + self, replay_client: reverb.Client + ) -> Iterator[ma_types.MultiAgentSample]: + # Zipping stores sub-iterators in the order dictated by + # self._builders.values(), which are insertion-ordered in Python3.7+. + # Hence, later unzipping (in make_learner()) and accessing the iterators + # via the same self._builders.items() dict ordering should be safe. + return zip( + *[b.make_dataset_iterator(replay_client) for b in self._builders.values()] + ) + + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: ma_types.MultiAgentNetworks, + dataset: Iterator[ma_types.MultiAgentSample], + logger_fn: loggers.LoggerFactory, + environment_spec: Optional[specs.EnvironmentSpec] = None, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + ) -> learner_set.SynchronousDecentralizedLearnerSet: + """Returns multiagent learner set. Args: random_key: random key. @@ -131,59 +133,62 @@ def make_learner( counter: a Counter which allows for recording of counts (learner steps, actor steps, etc.) distributed throughout the agent. 
""" - parent_counter = counter or counting.Counter() - sub_learners = {} - unzipped_dataset = iterator_utils.unzip_iterators( - dataset, num_sub_iterators=self._num_agents) - - def make_logger_fn(agent_id: str) -> loggers.LoggerFactory: - """Returns a logger factory for the subagent with the given id.""" - - def logger_factory( - label: loggers.LoggerLabel, - steps_key: Optional[loggers.LoggerStepsKey] = None, - instance: Optional[loggers.TaskInstance] = None) -> loggers.Logger: - return logger_fn(f'{label}{agent_id}', steps_key, instance) - - return logger_factory - - for i_dataset, (agent_id, builder) in enumerate(self._builders.items()): - counter = counting.Counter(parent_counter, prefix=f'{agent_id}') - single_agent_spec = ma_utils.get_agent_spec(environment_spec, agent_id) - random_key, learner_key = jax.random.split(random_key) - sub_learners[agent_id] = builder.make_learner( - learner_key, - networks[agent_id], - unzipped_dataset[i_dataset], - logger_fn=make_logger_fn(agent_id), - environment_spec=single_agent_spec, - replay_client=replay_client, - counter=counter) - return learner_set.SynchronousDecentralizedLearnerSet( - sub_learners, separator=VARIABLE_SEPARATOR) - - def make_adder( # Internal pytype check. - self, - replay_client: reverb.Client, - environment_spec: Optional[specs.EnvironmentSpec] = None, - policy: Optional[ma_types.MultiAgentPolicyNetworks] = None, - ) -> Mapping[ma_types.AgentID, Optional[adders.Adder]]: - del environment_spec, policy # Unused. - return { - agent_id: - b.make_adder(replay_client, environment_spec=None, policy=None) - for agent_id, b in self._builders.items() - } - - def make_actor( # Internal pytype check. - self, - random_key: networks_lib.PRNGKey, - policy: ma_types.MultiAgentPolicyNetworks, - environment_spec: specs.EnvironmentSpec, - variable_source: Optional[core.VariableSource] = None, - adder: Optional[Mapping[ma_types.AgentID, adders.Adder]] = None, - ) -> core.Actor: - """Returns simultaneous-acting multiagent actor instance. + parent_counter = counter or counting.Counter() + sub_learners = {} + unzipped_dataset = iterator_utils.unzip_iterators( + dataset, num_sub_iterators=self._num_agents + ) + + def make_logger_fn(agent_id: str) -> loggers.LoggerFactory: + """Returns a logger factory for the subagent with the given id.""" + + def logger_factory( + label: loggers.LoggerLabel, + steps_key: Optional[loggers.LoggerStepsKey] = None, + instance: Optional[loggers.TaskInstance] = None, + ) -> loggers.Logger: + return logger_fn(f"{label}{agent_id}", steps_key, instance) + + return logger_factory + + for i_dataset, (agent_id, builder) in enumerate(self._builders.items()): + counter = counting.Counter(parent_counter, prefix=f"{agent_id}") + single_agent_spec = ma_utils.get_agent_spec(environment_spec, agent_id) + random_key, learner_key = jax.random.split(random_key) + sub_learners[agent_id] = builder.make_learner( + learner_key, + networks[agent_id], + unzipped_dataset[i_dataset], + logger_fn=make_logger_fn(agent_id), + environment_spec=single_agent_spec, + replay_client=replay_client, + counter=counter, + ) + return learner_set.SynchronousDecentralizedLearnerSet( + sub_learners, separator=VARIABLE_SEPARATOR + ) + + def make_adder( # Internal pytype check. + self, + replay_client: reverb.Client, + environment_spec: Optional[specs.EnvironmentSpec] = None, + policy: Optional[ma_types.MultiAgentPolicyNetworks] = None, + ) -> Mapping[ma_types.AgentID, Optional[adders.Adder]]: + del environment_spec, policy # Unused. 
+ return { + agent_id: b.make_adder(replay_client, environment_spec=None, policy=None) + for agent_id, b in self._builders.items() + } + + def make_actor( # Internal pytype check. + self, + random_key: networks_lib.PRNGKey, + policy: ma_types.MultiAgentPolicyNetworks, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + adder: Optional[Mapping[ma_types.AgentID, adders.Adder]] = None, + ) -> core.Actor: + """Returns simultaneous-acting multiagent actor instance. Args: random_key: random key. @@ -195,32 +200,38 @@ def make_actor( # Internal pytype check. variables from variable_source. adder: how data is recorded (e.g., added to replay) for each actor. """ - if adder is None: - adder = {agent_id: None for agent_id in policy.keys()} - - sub_actors = {} - for agent_id, builder in self._builders.items(): - single_agent_spec = ma_utils.get_agent_spec(environment_spec, agent_id) - random_key, actor_key = jax.random.split(random_key) - # Adds a prefix to each sub-actor's variable names to ensure the correct - # sub-learner is queried for variables. - sub_variable_source = PrefixedVariableSource( - variable_source, f'{agent_id}{VARIABLE_SEPARATOR}') - sub_actors[agent_id] = builder.make_actor(actor_key, policy[agent_id], - single_agent_spec, - sub_variable_source, - adder[agent_id]) - return actor.SimultaneousActingMultiAgentActor(sub_actors) - - def make_policy( - self, - networks: ma_types.MultiAgentNetworks, - environment_spec: specs.EnvironmentSpec, - evaluation: bool = False) -> ma_types.MultiAgentPolicyNetworks: - return decentralized_factories.policy_network_factory( - networks, - environment_spec, - self._agent_types, - self._agent_configs, - eval_mode=evaluation, - init_policy_network_fn=self._init_policy_network_fn) + if adder is None: + adder = {agent_id: None for agent_id in policy.keys()} + + sub_actors = {} + for agent_id, builder in self._builders.items(): + single_agent_spec = ma_utils.get_agent_spec(environment_spec, agent_id) + random_key, actor_key = jax.random.split(random_key) + # Adds a prefix to each sub-actor's variable names to ensure the correct + # sub-learner is queried for variables. 
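+            # SynchronousDecentralizedLearnerSet later splits this prefix off to
+            # route each variable request to the matching sub-learner.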
+ sub_variable_source = PrefixedVariableSource( + variable_source, f"{agent_id}{VARIABLE_SEPARATOR}" + ) + sub_actors[agent_id] = builder.make_actor( + actor_key, + policy[agent_id], + single_agent_spec, + sub_variable_source, + adder[agent_id], + ) + return actor.SimultaneousActingMultiAgentActor(sub_actors) + + def make_policy( + self, + networks: ma_types.MultiAgentNetworks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool = False, + ) -> ma_types.MultiAgentPolicyNetworks: + return decentralized_factories.policy_network_factory( + networks, + environment_spec, + self._agent_types, + self._agent_configs, + eval_mode=evaluation, + init_policy_network_fn=self._init_policy_network_fn, + ) diff --git a/acme/agents/jax/multiagent/decentralized/config.py b/acme/agents/jax/multiagent/decentralized/config.py index 9c0cc5698d..50ab69dc3b 100644 --- a/acme/agents/jax/multiagent/decentralized/config.py +++ b/acme/agents/jax/multiagent/decentralized/config.py @@ -22,7 +22,8 @@ @dataclasses.dataclass class DecentralizedMultiagentConfig: - """Configuration options for decentralized multiagent.""" - sub_agent_configs: Dict[types.AgentID, types.AgentConfig] - batch_size: int = 256 - prefetch_size: int = 2 + """Configuration options for decentralized multiagent.""" + + sub_agent_configs: Dict[types.AgentID, types.AgentConfig] + batch_size: int = 256 + prefetch_size: int = 2 diff --git a/acme/agents/jax/multiagent/decentralized/factories.py b/acme/agents/jax/multiagent/decentralized/factories.py index b215d54d7d..a10d231a71 100644 --- a/acme/agents/jax/multiagent/decentralized/factories.py +++ b/acme/agents/jax/multiagent/decentralized/factories.py @@ -23,32 +23,31 @@ from acme import specs from acme.adders import reverb as adders_reverb from acme.agents.jax import builders as jax_builders -from acme.agents.jax import ppo -from acme.agents.jax import sac -from acme.agents.jax import td3 +from acme.agents.jax import ppo, sac, td3 from acme.multiagent import types as ma_types from acme.multiagent import utils as ma_utils class DefaultSupportedAgent(enum.Enum): - """Agents which have default initializers supported below.""" - TD3 = 'TD3' - SAC = 'SAC' - PPO = 'PPO' + """Agents which have default initializers supported below.""" + + TD3 = "TD3" + SAC = "SAC" + PPO = "PPO" def init_default_network( - agent_type: DefaultSupportedAgent, - agent_spec: specs.EnvironmentSpec) -> ma_types.Networks: - """Returns default networks for a single agent.""" - if agent_type == DefaultSupportedAgent.TD3: - return td3.make_networks(agent_spec) - elif agent_type == DefaultSupportedAgent.SAC: - return sac.make_networks(agent_spec) - elif agent_type == DefaultSupportedAgent.PPO: - return ppo.make_networks(agent_spec) - else: - raise ValueError(f'Unsupported agent type: {agent_type}.') + agent_type: DefaultSupportedAgent, agent_spec: specs.EnvironmentSpec +) -> ma_types.Networks: + """Returns default networks for a single agent.""" + if agent_type == DefaultSupportedAgent.TD3: + return td3.make_networks(agent_spec) + elif agent_type == DefaultSupportedAgent.SAC: + return sac.make_networks(agent_spec) + elif agent_type == DefaultSupportedAgent.PPO: + return ppo.make_networks(agent_spec) + else: + raise ValueError(f"Unsupported agent type: {agent_type}.") def init_default_policy_network( @@ -56,58 +55,57 @@ def init_default_policy_network( network: ma_types.Networks, agent_spec: specs.EnvironmentSpec, config: ma_types.AgentConfig, - eval_mode: ma_types.EvalMode = False) -> ma_types.PolicyNetwork: - """Returns default 
policy network for a single agent.""" - if agent_type == DefaultSupportedAgent.TD3: - sigma = 0. if eval_mode else config.sigma - return td3.get_default_behavior_policy( - network, agent_spec.actions, sigma=sigma) - elif agent_type == DefaultSupportedAgent.SAC: - return sac.apply_policy_and_sample(network, eval_mode=eval_mode) - elif agent_type == DefaultSupportedAgent.PPO: - return ppo.make_inference_fn(network, evaluation=eval_mode) - else: - raise ValueError(f'Unsupported agent type: {agent_type}.') + eval_mode: ma_types.EvalMode = False, +) -> ma_types.PolicyNetwork: + """Returns default policy network for a single agent.""" + if agent_type == DefaultSupportedAgent.TD3: + sigma = 0.0 if eval_mode else config.sigma + return td3.get_default_behavior_policy(network, agent_spec.actions, sigma=sigma) + elif agent_type == DefaultSupportedAgent.SAC: + return sac.apply_policy_and_sample(network, eval_mode=eval_mode) + elif agent_type == DefaultSupportedAgent.PPO: + return ppo.make_inference_fn(network, evaluation=eval_mode) + else: + raise ValueError(f"Unsupported agent type: {agent_type}.") def init_default_builder( - agent_type: DefaultSupportedAgent, - agent_config: ma_types.AgentConfig, + agent_type: DefaultSupportedAgent, agent_config: ma_types.AgentConfig, ) -> jax_builders.GenericActorLearnerBuilder: - """Returns default builder for a single agent.""" - if agent_type == DefaultSupportedAgent.TD3: - assert isinstance(agent_config, td3.TD3Config) - return td3.TD3Builder(agent_config) - elif agent_type == DefaultSupportedAgent.SAC: - assert isinstance(agent_config, sac.SACConfig) - return sac.SACBuilder(agent_config) - elif agent_type == DefaultSupportedAgent.PPO: - assert isinstance(agent_config, ppo.PPOConfig) - return ppo.PPOBuilder(agent_config) - else: - raise ValueError(f'Unsupported agent type: {agent_type}.') + """Returns default builder for a single agent.""" + if agent_type == DefaultSupportedAgent.TD3: + assert isinstance(agent_config, td3.TD3Config) + return td3.TD3Builder(agent_config) + elif agent_type == DefaultSupportedAgent.SAC: + assert isinstance(agent_config, sac.SACConfig) + return sac.SACBuilder(agent_config) + elif agent_type == DefaultSupportedAgent.PPO: + assert isinstance(agent_config, ppo.PPOConfig) + return ppo.PPOBuilder(agent_config) + else: + raise ValueError(f"Unsupported agent type: {agent_type}.") def init_default_config( - agent_type: DefaultSupportedAgent, - config_overrides: Dict[str, Any]) -> ma_types.AgentConfig: - """Returns default config for a single agent.""" - if agent_type == DefaultSupportedAgent.TD3: - return td3.TD3Config(**config_overrides) - elif agent_type == DefaultSupportedAgent.SAC: - return sac.SACConfig(**config_overrides) - elif agent_type == DefaultSupportedAgent.PPO: - return ppo.PPOConfig(**config_overrides) - else: - raise ValueError(f'Unsupported agent type: {agent_type}.') + agent_type: DefaultSupportedAgent, config_overrides: Dict[str, Any] +) -> ma_types.AgentConfig: + """Returns default config for a single agent.""" + if agent_type == DefaultSupportedAgent.TD3: + return td3.TD3Config(**config_overrides) + elif agent_type == DefaultSupportedAgent.SAC: + return sac.SACConfig(**config_overrides) + elif agent_type == DefaultSupportedAgent.PPO: + return ppo.PPOConfig(**config_overrides) + else: + raise ValueError(f"Unsupported agent type: {agent_type}.") def default_config_factory( agent_types: Dict[ma_types.AgentID, DefaultSupportedAgent], batch_size: int, - config_overrides: Optional[Dict[ma_types.AgentID, Dict[str, Any]]] = 
None + config_overrides: Optional[Dict[ma_types.AgentID, Dict[str, Any]]] = None, ) -> Dict[ma_types.AgentID, ma_types.AgentConfig]: - """Returns default configs for all agents. + """Returns default configs for all agents. Args: agent_types: dict mapping agent IDs to their type. @@ -116,30 +114,30 @@ def default_config_factory( config overrides. This should include any mandatory config parameters for the agents that do not have default values. """ - configs = {} - for agent_id, agent_type in agent_types.items(): - agent_config_overrides = dict( - # batch_size is required by LocalLayout, which is shared amongst - # the agents. Hence, we enforce a shared batch_size in builders. - batch_size=batch_size, - # Unique replay_table_name per agent. - replay_table_name=f'{adders_reverb.DEFAULT_PRIORITY_TABLE}_agent{agent_id}' - ) - if config_overrides is not None and agent_id in config_overrides: - agent_config_overrides = { - **config_overrides[agent_id], - **agent_config_overrides # Comes second to ensure batch_size override - } - configs[agent_id] = init_default_config(agent_type, agent_config_overrides) - return configs + configs = {} + for agent_id, agent_type in agent_types.items(): + agent_config_overrides = dict( + # batch_size is required by LocalLayout, which is shared amongst + # the agents. Hence, we enforce a shared batch_size in builders. + batch_size=batch_size, + # Unique replay_table_name per agent. + replay_table_name=f"{adders_reverb.DEFAULT_PRIORITY_TABLE}_agent{agent_id}", + ) + if config_overrides is not None and agent_id in config_overrides: + agent_config_overrides = { + **config_overrides[agent_id], + **agent_config_overrides, # Comes second to ensure batch_size override + } + configs[agent_id] = init_default_config(agent_type, agent_config_overrides) + return configs def network_factory( environment_spec: specs.EnvironmentSpec, agent_types: Dict[ma_types.AgentID, ma_types.GenericAgent], - init_network_fn: Optional[ma_types.InitNetworkFn] = None + init_network_fn: Optional[ma_types.InitNetworkFn] = None, ) -> ma_types.MultiAgentNetworks: - """Returns networks for all agents. + """Returns networks for all agents. Args: environment_spec: environment spec. @@ -148,12 +146,12 @@ def network_factory( for all sub-agents. If this is not supplied, a default network initializer is used (if it is supported for the designated agent type). """ - init_fn = init_network_fn or init_default_network - networks = {} - for agent_id, agent_type in agent_types.items(): - single_agent_spec = ma_utils.get_agent_spec(environment_spec, agent_id) - networks[agent_id] = init_fn(agent_type, single_agent_spec) - return networks + init_fn = init_network_fn or init_default_network + networks = {} + for agent_id, agent_type in agent_types.items(): + single_agent_spec = ma_utils.get_agent_spec(environment_spec, agent_id) + networks[agent_id] = init_fn(agent_type, single_agent_spec) + return networks def policy_network_factory( @@ -162,9 +160,9 @@ def policy_network_factory( agent_types: Dict[ma_types.AgentID, ma_types.GenericAgent], agent_configs: Dict[ma_types.AgentID, ma_types.AgentConfig], eval_mode: ma_types.EvalMode, - init_policy_network_fn: Optional[ma_types.InitPolicyNetworkFn] = None + init_policy_network_fn: Optional[ma_types.InitPolicyNetworkFn] = None, ) -> ma_types.MultiAgentPolicyNetworks: - """Returns default policy networks for all agents. + """Returns default policy networks for all agents. Args: networks: dict mapping agent IDs to their networks. 
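A minimal usage sketch for the factories in this file, assuming a two-agent setup; the multiagent `environment_spec` and the string agent IDs "0" and "1" are placeholders rather than part of this change:

    from acme.agents.jax.multiagent.decentralized import (
        DecentralizedMultiAgentBuilder,
        DefaultSupportedAgent,
        default_config_factory,
        network_factory,
    )

    # Two heterogeneous sub-agents, keyed by (placeholder) agent IDs.
    agent_types = {"0": DefaultSupportedAgent.PPO, "1": DefaultSupportedAgent.SAC}

    # Shared batch size; a unique replay table name is derived for each agent ID.
    agent_configs = default_config_factory(agent_types, batch_size=256)

    # Per-agent networks, each built from that agent's slice of the environment spec.
    networks = network_factory(environment_spec, agent_types)

    # The builder assembles per-agent replay tables, learners and actors; its
    # make_policy delegates to policy_network_factory with these types and configs.
    builder = DecentralizedMultiAgentBuilder(agent_types, agent_configs)
    policies = builder.make_policy(networks, environment_spec, evaluation=False)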
@@ -178,24 +176,28 @@ def policy_network_factory( policy network initializer is used (if it is supported for the designated agent type). """ - init_fn = init_policy_network_fn or init_default_policy_network - policy_networks = {} - for agent_id, agent_type in agent_types.items(): - single_agent_spec = ma_utils.get_agent_spec(environment_spec, agent_id) - policy_networks[agent_id] = init_fn(agent_type, networks[agent_id], - single_agent_spec, - agent_configs[agent_id], eval_mode) - return policy_networks + init_fn = init_policy_network_fn or init_default_policy_network + policy_networks = {} + for agent_id, agent_type in agent_types.items(): + single_agent_spec = ma_utils.get_agent_spec(environment_spec, agent_id) + policy_networks[agent_id] = init_fn( + agent_type, + networks[agent_id], + single_agent_spec, + agent_configs[agent_id], + eval_mode, + ) + return policy_networks def builder_factory( agent_types: Dict[ma_types.AgentID, ma_types.GenericAgent], agent_configs: Dict[ma_types.AgentID, ma_types.AgentConfig], - init_builder_fn: Optional[ma_types.InitBuilderFn] = None + init_builder_fn: Optional[ma_types.InitBuilderFn] = None, ) -> Dict[ma_types.AgentID, jax_builders.GenericActorLearnerBuilder]: - """Returns default policy networks for all agents.""" - init_fn = init_builder_fn or init_default_builder - builders = {} - for agent_id, agent_type in agent_types.items(): - builders[agent_id] = init_fn(agent_type, agent_configs[agent_id]) - return builders + """Returns default policy networks for all agents.""" + init_fn = init_builder_fn or init_default_builder + builders = {} + for agent_id, agent_type in agent_types.items(): + builders[agent_id] = init_fn(agent_type, agent_configs[agent_id]) + return builders diff --git a/acme/agents/jax/multiagent/decentralized/learner_set.py b/acme/agents/jax/multiagent/decentralized/learner_set.py index 110bb3b930..89a131a9a4 100644 --- a/acme/agents/jax/multiagent/decentralized/learner_set.py +++ b/acme/agents/jax/multiagent/decentralized/learner_set.py @@ -17,9 +17,7 @@ import dataclasses from typing import Any, Dict, List -from acme import core -from acme import types - +from acme import core, types from acme.multiagent import types as ma_types LearnerState = Any @@ -27,32 +25,33 @@ @dataclasses.dataclass class SynchronousDecentralizedLearnerSetState: - """State of a SynchronousDecentralizedLearnerSet.""" - # States of the learners keyed by their names. - learner_states: Dict[ma_types.AgentID, LearnerState] + """State of a SynchronousDecentralizedLearnerSet.""" + + # States of the learners keyed by their names. + learner_states: Dict[ma_types.AgentID, LearnerState] class SynchronousDecentralizedLearnerSet(core.Learner): - """Creates a composed learner which wraps a set of local agent learners.""" + """Creates a composed learner which wraps a set of local agent learners.""" - def __init__(self, - learners: Dict[ma_types.AgentID, core.Learner], - separator: str = '-'): - """Initializer. + def __init__( + self, learners: Dict[ma_types.AgentID, core.Learner], separator: str = "-" + ): + """Initializer. Args: learners: a dict specifying the learners for all sub-agents. separator: separator character used to disambiguate sub-learner variables. 
""" - self._learners = learners - self._separator = separator + self._learners = learners + self._separator = separator - def step(self): - for learner in self._learners.values(): - learner.step() + def step(self): + for learner in self._learners.values(): + learner.step() - def get_variables(self, names: List[str]) -> List[types.NestedArray]: - """Return the named variables as a collection of (nested) numpy arrays. + def get_variables(self, names: List[str]) -> List[types.NestedArray]: + """Return the named variables as a collection of (nested) numpy arrays. The variable names should be prefixed with the name of the child learners using the separator specified in the constructor, e.g. learner1/var. @@ -67,19 +66,21 @@ def get_variables(self, names: List[str]) -> List[types.NestedArray]: A list of (nested) numpy arrays `variables` such that `variables[i]` corresponds to the collection named by `names[i]`. """ - variables = [] - for name in names: - # Note: if separator is not found, learner_name=name, which is OK. - learner_id, _, variable_name = name.partition(self._separator) - learner = self._learners[learner_id] - variables.extend(learner.get_variables([variable_name])) - return variables - - def save(self) -> SynchronousDecentralizedLearnerSetState: - return SynchronousDecentralizedLearnerSetState(learner_states={ - name: learner.save() for name, learner in self._learners.items() - }) - - def restore(self, state: SynchronousDecentralizedLearnerSetState): - for name, learner in self._learners.items(): - learner.restore(state.learner_states[name]) + variables = [] + for name in names: + # Note: if separator is not found, learner_name=name, which is OK. + learner_id, _, variable_name = name.partition(self._separator) + learner = self._learners[learner_id] + variables.extend(learner.get_variables([variable_name])) + return variables + + def save(self) -> SynchronousDecentralizedLearnerSetState: + return SynchronousDecentralizedLearnerSetState( + learner_states={ + name: learner.save() for name, learner in self._learners.items() + } + ) + + def restore(self, state: SynchronousDecentralizedLearnerSetState): + for name, learner in self._learners.items(): + learner.restore(state.learner_states[name]) diff --git a/acme/agents/jax/normalization.py b/acme/agents/jax/normalization.py index a98a26e729..bf5ed6448d 100644 --- a/acme/agents/jax/normalization.py +++ b/acme/agents/jax/normalization.py @@ -18,327 +18,355 @@ import functools from typing import Any, Callable, Generic, Iterator, List, Optional, Tuple -import acme -from acme import adders -from acme import core -from acme import specs -from acme import types -from acme.agents.jax import builders -from acme.jax import networks as networks_lib -from acme.jax import running_statistics -from acme.jax import variable_utils -from acme.jax.types import Networks, Policy # pylint: disable=g-multiple-import -from acme.utils import counting -from acme.utils import loggers import dm_env import jax import reverb from typing_extensions import Protocol -_NORMALIZATION_VARIABLES = 'normalization_variables' +import acme +from acme import adders, core, specs, types +from acme.agents.jax import builders +from acme.jax import networks as networks_lib +from acme.jax import running_statistics, variable_utils +from acme.jax.types import Networks, Policy # pylint: disable=g-multiple-import +from acme.utils import counting, loggers + +_NORMALIZATION_VARIABLES = "normalization_variables" # Wrapping the network instead might look more straightforward, but then # 
different implementations would be needed for feed-forward and # recurrent networks. class NormalizationActorWrapper(core.Actor): - """An actor wrapper that normalizes observations before applying policy.""" - - def __init__(self, - wrapped_actor: core.Actor, - variable_source: core.VariableSource, - max_abs_observation: Optional[float], - update_period: int = 1, - backend: Optional[str] = None): - self._wrapped_actor = wrapped_actor - self._variable_client = variable_utils.VariableClient( - variable_source, - key=_NORMALIZATION_VARIABLES, - update_period=update_period, - device=backend) - self._apply_normalization = jax.jit( - functools.partial( - running_statistics.normalize, max_abs_value=max_abs_observation), - backend=backend) - - def select_action(self, observation: types.NestedArray) -> types.NestedArray: - self._variable_client.update() - observation_stats = self._variable_client.params - observation = self._apply_normalization(observation, observation_stats) - return self._wrapped_actor.select_action(observation) - - def observe_first(self, timestep: dm_env.TimeStep): - return self._wrapped_actor.observe_first(timestep) - - def observe( - self, - action: types.NestedArray, - next_timestep: dm_env.TimeStep, - ): - return self._wrapped_actor.observe(action, next_timestep) - - def update(self, wait: bool = False): - return self._wrapped_actor.update(wait) + """An actor wrapper that normalizes observations before applying policy.""" + + def __init__( + self, + wrapped_actor: core.Actor, + variable_source: core.VariableSource, + max_abs_observation: Optional[float], + update_period: int = 1, + backend: Optional[str] = None, + ): + self._wrapped_actor = wrapped_actor + self._variable_client = variable_utils.VariableClient( + variable_source, + key=_NORMALIZATION_VARIABLES, + update_period=update_period, + device=backend, + ) + self._apply_normalization = jax.jit( + functools.partial( + running_statistics.normalize, max_abs_value=max_abs_observation + ), + backend=backend, + ) + + def select_action(self, observation: types.NestedArray) -> types.NestedArray: + self._variable_client.update() + observation_stats = self._variable_client.params + observation = self._apply_normalization(observation, observation_stats) + return self._wrapped_actor.select_action(observation) + + def observe_first(self, timestep: dm_env.TimeStep): + return self._wrapped_actor.observe_first(timestep) + + def observe( + self, action: types.NestedArray, next_timestep: dm_env.TimeStep, + ): + return self._wrapped_actor.observe(action, next_timestep) + + def update(self, wait: bool = False): + return self._wrapped_actor.update(wait) @dataclasses.dataclass class NormalizationLearnerWrapperState: - wrapped_learner_state: Any - observation_running_statistics: running_statistics.RunningStatisticsState + wrapped_learner_state: Any + observation_running_statistics: running_statistics.RunningStatisticsState class NormalizationLearnerWrapper(core.Learner, core.Saveable): - """A learner wrapper that normalizes observations using running statistics.""" - - def __init__(self, learner_factory: Callable[[Iterator[reverb.ReplaySample]], - acme.Learner], - iterator: Iterator[reverb.ReplaySample], - environment_spec: specs.EnvironmentSpec, - max_abs_observation: Optional[float]): - - def normalize_sample( - observation_statistics: running_statistics.RunningStatisticsState, - sample: reverb.ReplaySample - ) -> Tuple[running_statistics.RunningStatisticsState, reverb.ReplaySample]: - observation = sample.data.observation - 
observation_statistics = running_statistics.update( - observation_statistics, observation) - observation = running_statistics.normalize( - observation, - observation_statistics, - max_abs_value=max_abs_observation) - sample = reverb.ReplaySample( - sample.info, sample.data._replace(observation=observation)) - if hasattr(sample.data, 'next_observation'): - next_observation = running_statistics.normalize( - sample.data.next_observation, - observation_statistics, - max_abs_value=max_abs_observation) - sample = reverb.ReplaySample( - sample.info, - sample.data._replace(next_observation=next_observation)) - - return observation_statistics, sample - - self._observation_running_statistics = running_statistics.init_state( - environment_spec.observations) - self._normalize_sample = jax.jit(normalize_sample) - - normalizing_iterator = ( - self._normalize_sample_and_update(sample) for sample in iterator) - self._wrapped_learner = learner_factory(normalizing_iterator) - - def _normalize_sample_and_update( - self, sample: reverb.ReplaySample) -> reverb.ReplaySample: - self._observation_running_statistics, sample = self._normalize_sample( - self._observation_running_statistics, sample) - return sample - - def step(self): - self._wrapped_learner.step() - - def get_variables(self, names: List[str]) -> List[types.NestedArray]: - stats = self._observation_running_statistics - # Make sure to only pass mean and std to minimize trafic. - mean_std = running_statistics.NestedMeanStd(mean=stats.mean, std=stats.std) - normalization_variables = {_NORMALIZATION_VARIABLES: mean_std} - - learner_names = [ - name for name in names if name not in normalization_variables - ] - learner_variables = dict( - zip(learner_names, self._wrapped_learner.get_variables( - learner_names))) if learner_names else {} - - return [ - normalization_variables.get(name, learner_variables.get(name, None)) - for name in names - ] - - def save(self) -> NormalizationLearnerWrapperState: - return NormalizationLearnerWrapperState( - wrapped_learner_state=self._wrapped_learner.save(), - observation_running_statistics=self._observation_running_statistics) + """A learner wrapper that normalizes observations using running statistics.""" - def restore(self, state: NormalizationLearnerWrapperState): - self._wrapped_learner.restore(state.wrapped_learner_state) - self._observation_running_statistics = state.observation_running_statistics + def __init__( + self, + learner_factory: Callable[[Iterator[reverb.ReplaySample]], acme.Learner], + iterator: Iterator[reverb.ReplaySample], + environment_spec: specs.EnvironmentSpec, + max_abs_observation: Optional[float], + ): + def normalize_sample( + observation_statistics: running_statistics.RunningStatisticsState, + sample: reverb.ReplaySample, + ) -> Tuple[running_statistics.RunningStatisticsState, reverb.ReplaySample]: + observation = sample.data.observation + observation_statistics = running_statistics.update( + observation_statistics, observation + ) + observation = running_statistics.normalize( + observation, observation_statistics, max_abs_value=max_abs_observation + ) + sample = reverb.ReplaySample( + sample.info, sample.data._replace(observation=observation) + ) + if hasattr(sample.data, "next_observation"): + next_observation = running_statistics.normalize( + sample.data.next_observation, + observation_statistics, + max_abs_value=max_abs_observation, + ) + sample = reverb.ReplaySample( + sample.info, sample.data._replace(next_observation=next_observation) + ) + + return observation_statistics, sample + + 
self._observation_running_statistics = running_statistics.init_state( + environment_spec.observations + ) + self._normalize_sample = jax.jit(normalize_sample) + + normalizing_iterator = ( + self._normalize_sample_and_update(sample) for sample in iterator + ) + self._wrapped_learner = learner_factory(normalizing_iterator) + + def _normalize_sample_and_update( + self, sample: reverb.ReplaySample + ) -> reverb.ReplaySample: + self._observation_running_statistics, sample = self._normalize_sample( + self._observation_running_statistics, sample + ) + return sample + + def step(self): + self._wrapped_learner.step() + + def get_variables(self, names: List[str]) -> List[types.NestedArray]: + stats = self._observation_running_statistics + # Make sure to only pass mean and std to minimize trafic. + mean_std = running_statistics.NestedMeanStd(mean=stats.mean, std=stats.std) + normalization_variables = {_NORMALIZATION_VARIABLES: mean_std} + + learner_names = [name for name in names if name not in normalization_variables] + learner_variables = ( + dict(zip(learner_names, self._wrapped_learner.get_variables(learner_names))) + if learner_names + else {} + ) + + return [ + normalization_variables.get(name, learner_variables.get(name, None)) + for name in names + ] + + def save(self) -> NormalizationLearnerWrapperState: + return NormalizationLearnerWrapperState( + wrapped_learner_state=self._wrapped_learner.save(), + observation_running_statistics=self._observation_running_statistics, + ) + + def restore(self, state: NormalizationLearnerWrapperState): + self._wrapped_learner.restore(state.wrapped_learner_state) + self._observation_running_statistics = state.observation_running_statistics @dataclasses.dataclass -class NormalizationBuilder(Generic[Networks, Policy], - builders.ActorLearnerBuilder[Networks, Policy, - reverb.ReplaySample]): - """Builder wrapper that normalizes observations using running mean/std.""" - builder: builders.ActorLearnerBuilder[Networks, Policy, reverb.ReplaySample] - max_abs_observation: Optional[float] = 10.0 - statistics_update_period: int = 100 - - def make_replay_tables( - self, - environment_spec: specs.EnvironmentSpec, - policy: Policy, - ) -> List[reverb.Table]: - return self.builder.make_replay_tables(environment_spec, policy) - - def make_dataset_iterator( - self, replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]: - return self.builder.make_dataset_iterator(replay_client) - - def make_adder(self, replay_client: reverb.Client, - environment_spec: Optional[specs.EnvironmentSpec], - policy: Optional[Policy]) -> Optional[adders.Adder]: - return self.builder.make_adder(replay_client, environment_spec, policy) - - def make_learner( - self, - random_key: networks_lib.PRNGKey, - networks: Networks, - dataset: Iterator[reverb.ReplaySample], - logger_fn: loggers.LoggerFactory, - environment_spec: specs.EnvironmentSpec, - replay_client: Optional[reverb.Client] = None, - counter: Optional[counting.Counter] = None, - ) -> core.Learner: - - learner_factory = functools.partial( - self.builder.make_learner, - random_key, - networks, - logger_fn=logger_fn, - environment_spec=environment_spec, - replay_client=replay_client, - counter=counter) - - return NormalizationLearnerWrapper( - learner_factory=learner_factory, - iterator=dataset, - environment_spec=environment_spec, - max_abs_observation=self.max_abs_observation) - - def make_actor( - self, - random_key: networks_lib.PRNGKey, - policy: Policy, - environment_spec: specs.EnvironmentSpec, - variable_source: 
Optional[core.VariableSource] = None, - adder: Optional[adders.Adder] = None, - ) -> core.Actor: - actor = self.builder.make_actor(random_key, policy, environment_spec, - variable_source, adder) - return NormalizationActorWrapper( - actor, - variable_source, - max_abs_observation=self.max_abs_observation, - update_period=self.statistics_update_period, - backend='cpu') - - def make_policy(self, - networks: Networks, - environment_spec: specs.EnvironmentSpec, - evaluation: bool = False) -> Policy: - return self.builder.make_policy( - networks=networks, - environment_spec=environment_spec, - evaluation=evaluation) - - -@dataclasses.dataclass(frozen=True) -class NormalizationConfig: - """Configuration for normalization based on running statistics. - - Attributes: - max_abs: Maximum value for clipping. - statistics_update_period: How often to update running statistics used for - normalization. - """ - max_abs: int = 10 - statistics_update_period: int = 100 - - -class InputNormalizerConfig(Protocol): - """Protocol for the config of the agent that uses the normalization decorator. - - If the agent builder is decorated with the `input_normalization_builder` - the agent config class must implement this protocol. - """ - - @property - def input_normalization(self) -> Optional[NormalizationConfig]: - ... +class NormalizationBuilder( + Generic[Networks, Policy], + builders.ActorLearnerBuilder[Networks, Policy, reverb.ReplaySample], +): + """Builder wrapper that normalizes observations using running mean/std.""" - -def input_normalization_builder( - actor_learner_builder_class: Callable[[InputNormalizerConfig], - builders.ActorLearnerBuilder]): - """Builder class decorator that adds support for input normalization.""" - - # TODO(b/247075349): find a way to use ActorLearnerBuilderWrapper here. 
- class InputNormalizationBuilder( - Generic[builders.Networks, builders.Policy, builders.Sample], - builders.ActorLearnerBuilder[builders.Networks, builders.Policy, - builders.Sample]): - """Builder wrapper that adds input normalization based on the config.""" - - def __init__(self, config: InputNormalizerConfig): - builder = actor_learner_builder_class(config) - if config.input_normalization: - builder = NormalizationBuilder( - builder, - max_abs_observation=config.input_normalization.max_abs, - statistics_update_period=config.input_normalization - .statistics_update_period) - self.wrapped = builder + builder: builders.ActorLearnerBuilder[Networks, Policy, reverb.ReplaySample] + max_abs_observation: Optional[float] = 10.0 + statistics_update_period: int = 100 def make_replay_tables( - self, - environment_spec: specs.EnvironmentSpec, - policy: builders.Policy, + self, environment_spec: specs.EnvironmentSpec, policy: Policy, ) -> List[reverb.Table]: - return self.wrapped.make_replay_tables(environment_spec, policy) + return self.builder.make_replay_tables(environment_spec, policy) def make_dataset_iterator( - self, - replay_client: reverb.Client, - ) -> Iterator[builders.Sample]: - return self.wrapped.make_dataset_iterator(replay_client) + self, replay_client: reverb.Client + ) -> Iterator[reverb.ReplaySample]: + return self.builder.make_dataset_iterator(replay_client) def make_adder( self, replay_client: reverb.Client, environment_spec: Optional[specs.EnvironmentSpec], - policy: Optional[builders.Policy], + policy: Optional[Policy], ) -> Optional[adders.Adder]: - return self.wrapped.make_adder(replay_client, environment_spec, policy) + return self.builder.make_adder(replay_client, environment_spec, policy) + + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: Networks, + dataset: Iterator[reverb.ReplaySample], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + + learner_factory = functools.partial( + self.builder.make_learner, + random_key, + networks, + logger_fn=logger_fn, + environment_spec=environment_spec, + replay_client=replay_client, + counter=counter, + ) + + return NormalizationLearnerWrapper( + learner_factory=learner_factory, + iterator=dataset, + environment_spec=environment_spec, + max_abs_observation=self.max_abs_observation, + ) def make_actor( self, random_key: networks_lib.PRNGKey, - policy: builders.Policy, + policy: Policy, environment_spec: specs.EnvironmentSpec, variable_source: Optional[core.VariableSource] = None, adder: Optional[adders.Adder] = None, ) -> core.Actor: - return self.wrapped.make_actor(random_key, policy, environment_spec, - variable_source, adder) - - def make_learner( + actor = self.builder.make_actor( + random_key, policy, environment_spec, variable_source, adder + ) + return NormalizationActorWrapper( + actor, + variable_source, + max_abs_observation=self.max_abs_observation, + update_period=self.statistics_update_period, + backend="cpu", + ) + + def make_policy( self, - random_key: networks_lib.PRNGKey, networks: Networks, - dataset: Iterator[builders.Sample], - logger_fn: loggers.LoggerFactory, environment_spec: specs.EnvironmentSpec, - replay_client: Optional[reverb.Client] = None, - counter: Optional[counting.Counter] = None, - ) -> core.Learner: - return self.wrapped.make_learner(random_key, networks, dataset, logger_fn, - environment_spec, replay_client, counter) + 
evaluation: bool = False, + ) -> Policy: + return self.builder.make_policy( + networks=networks, environment_spec=environment_spec, evaluation=evaluation + ) + + +@dataclasses.dataclass(frozen=True) +class NormalizationConfig: + """Configuration for normalization based on running statistics. + + Attributes: + max_abs: Maximum value for clipping. + statistics_update_period: How often to update running statistics used for + normalization. + """ + + max_abs: int = 10 + statistics_update_period: int = 100 - def make_policy(self, - networks: builders.Networks, - environment_spec: specs.EnvironmentSpec, - evaluation: bool = False) -> builders.Policy: - return self.wrapped.make_policy(networks, environment_spec, evaluation) - return InputNormalizationBuilder +class InputNormalizerConfig(Protocol): + """Protocol for the config of the agent that uses the normalization decorator. + + If the agent builder is decorated with the `input_normalization_builder` + the agent config class must implement this protocol. + """ + + @property + def input_normalization(self) -> Optional[NormalizationConfig]: + ... + + +def input_normalization_builder( + actor_learner_builder_class: Callable[ + [InputNormalizerConfig], builders.ActorLearnerBuilder + ] +): + """Builder class decorator that adds support for input normalization.""" + + # TODO(b/247075349): find a way to use ActorLearnerBuilderWrapper here. + class InputNormalizationBuilder( + Generic[builders.Networks, builders.Policy, builders.Sample], + builders.ActorLearnerBuilder[ + builders.Networks, builders.Policy, builders.Sample + ], + ): + """Builder wrapper that adds input normalization based on the config.""" + + def __init__(self, config: InputNormalizerConfig): + builder = actor_learner_builder_class(config) + if config.input_normalization: + builder = NormalizationBuilder( + builder, + max_abs_observation=config.input_normalization.max_abs, + statistics_update_period=config.input_normalization.statistics_update_period, + ) + self.wrapped = builder + + def make_replay_tables( + self, environment_spec: specs.EnvironmentSpec, policy: builders.Policy, + ) -> List[reverb.Table]: + return self.wrapped.make_replay_tables(environment_spec, policy) + + def make_dataset_iterator( + self, replay_client: reverb.Client, + ) -> Iterator[builders.Sample]: + return self.wrapped.make_dataset_iterator(replay_client) + + def make_adder( + self, + replay_client: reverb.Client, + environment_spec: Optional[specs.EnvironmentSpec], + policy: Optional[builders.Policy], + ) -> Optional[adders.Adder]: + return self.wrapped.make_adder(replay_client, environment_spec, policy) + + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: builders.Policy, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + adder: Optional[adders.Adder] = None, + ) -> core.Actor: + return self.wrapped.make_actor( + random_key, policy, environment_spec, variable_source, adder + ) + + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: Networks, + dataset: Iterator[builders.Sample], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + return self.wrapped.make_learner( + random_key, + networks, + dataset, + logger_fn, + environment_spec, + replay_client, + counter, + ) + + def make_policy( + self, + networks: builders.Networks, + environment_spec: specs.EnvironmentSpec, + 
evaluation: bool = False, + ) -> builders.Policy: + return self.wrapped.make_policy(networks, environment_spec, evaluation) + + return InputNormalizationBuilder diff --git a/acme/agents/jax/ppo/__init__.py b/acme/agents/jax/ppo/__init__.py index 4185e6345d..dfa023d6ee 100644 --- a/acme/agents/jax/ppo/__init__.py +++ b/acme/agents/jax/ppo/__init__.py @@ -17,17 +17,21 @@ from acme.agents.jax.ppo.builder import PPOBuilder from acme.agents.jax.ppo.config import PPOConfig from acme.agents.jax.ppo.learning import PPOLearner -from acme.agents.jax.ppo.networks import EntropyFn -from acme.agents.jax.ppo.networks import make_categorical_ppo_networks -from acme.agents.jax.ppo.networks import make_continuous_networks -from acme.agents.jax.ppo.networks import make_discrete_networks -from acme.agents.jax.ppo.networks import make_inference_fn -from acme.agents.jax.ppo.networks import make_mvn_diag_ppo_networks -from acme.agents.jax.ppo.networks import make_networks -from acme.agents.jax.ppo.networks import make_ppo_networks -from acme.agents.jax.ppo.networks import make_tanh_normal_ppo_networks -from acme.agents.jax.ppo.networks import PPONetworks -from acme.agents.jax.ppo.normalization import build_ema_mean_std_normalizer -from acme.agents.jax.ppo.normalization import build_mean_std_normalizer -from acme.agents.jax.ppo.normalization import NormalizationFns -from acme.agents.jax.ppo.normalization import NormalizedGenericActor +from acme.agents.jax.ppo.networks import ( + EntropyFn, + PPONetworks, + make_categorical_ppo_networks, + make_continuous_networks, + make_discrete_networks, + make_inference_fn, + make_mvn_diag_ppo_networks, + make_networks, + make_ppo_networks, + make_tanh_normal_ppo_networks, +) +from acme.agents.jax.ppo.normalization import ( + NormalizationFns, + NormalizedGenericActor, + build_ema_mean_std_normalizer, + build_mean_std_normalizer, +) diff --git a/acme/agents/jax/ppo/builder.py b/acme/agents/jax/ppo/builder.py index 9dc7c4a5e3..330e83558b 100644 --- a/acme/agents/jax/ppo/builder.py +++ b/acme/agents/jax/ppo/builder.py @@ -15,68 +15,69 @@ """PPO Builder.""" from typing import Iterator, List, Optional -from acme import adders -from acme import core -from acme import specs +import jax +import numpy as np +import optax +import reverb + +from acme import adders, core, specs from acme.adders import reverb as adders_reverb from acme.agents.jax import actor_core as actor_core_lib -from acme.agents.jax import actors -from acme.agents.jax import builders +from acme.agents.jax import actors, builders from acme.agents.jax.ppo import config as ppo_config from acme.agents.jax.ppo import learning from acme.agents.jax.ppo import networks as ppo_networks from acme.agents.jax.ppo import normalization from acme.jax import networks as networks_lib -from acme.jax import utils -from acme.jax import variable_utils -from acme.utils import counting -from acme.utils import loggers -import jax -import numpy as np -import optax -import reverb +from acme.jax import utils, variable_utils +from acme.utils import counting, loggers class PPOBuilder( - builders.ActorLearnerBuilder[ppo_networks.PPONetworks, - actor_core_lib.FeedForwardPolicyWithExtra, - reverb.ReplaySample]): - """PPO Builder.""" - - def __init__( - self, - config: ppo_config.PPOConfig, - ): - """Creates PPO builder.""" - self._config = config - - # An extra step is used for bootstrapping when computing advantages. 
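# --- Illustrative sketch (editorial addition, not part of this diff) ---------
# Why `sequence_length = unroll_length + 1` and, in `make_adder` below,
# `period = sequence_length - 1`: the extra step only carries the bootstrap
# value, and consecutive sequences overlap by exactly one step so no
# transition is lost. A small numeric check, assuming unroll_length = 4 for
# readability (the default in PPOConfig is 8).
unroll_length = 4
sequence_length = unroll_length + 1  # one extra step for bootstrapping
period = sequence_length - 1         # a new sequence starts every `period` steps

sequences = [
    list(range(start, start + sequence_length)) for start in range(0, 12, period)
]
print(sequences)
# [[0, 1, 2, 3, 4], [4, 5, 6, 7, 8], [8, 9, 10, 11, 12]]
# Step 4 closes the first sequence (bootstrap only) and reappears as the first
# learnable step of the second one.
# -----------------------------------------------------------------------------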
- self._sequence_length = config.unroll_length + 1 - - def make_replay_tables( - self, - environment_spec: specs.EnvironmentSpec, - policy: actor_core_lib.FeedForwardPolicyWithExtra, - ) -> List[reverb.Table]: - """Creates reverb tables for the algorithm.""" - del policy - # params_num_sgd_steps is used to track how old the actor parameters are - extra_spec = { - 'log_prob': np.ones(shape=(), dtype=np.float32), - 'params_num_sgd_steps': np.ones(shape=(), dtype=np.float32), - } - signature = adders_reverb.SequenceAdder.signature( - environment_spec, extra_spec, sequence_length=self._sequence_length) - return [ - reverb.Table.queue( - name=self._config.replay_table_name, - max_size=self._config.batch_size, - signature=signature) + builders.ActorLearnerBuilder[ + ppo_networks.PPONetworks, + actor_core_lib.FeedForwardPolicyWithExtra, + reverb.ReplaySample, ] - - def make_dataset_iterator( - self, replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]: - """Creates a dataset. +): + """PPO Builder.""" + + def __init__( + self, config: ppo_config.PPOConfig, + ): + """Creates PPO builder.""" + self._config = config + + # An extra step is used for bootstrapping when computing advantages. + self._sequence_length = config.unroll_length + 1 + + def make_replay_tables( + self, + environment_spec: specs.EnvironmentSpec, + policy: actor_core_lib.FeedForwardPolicyWithExtra, + ) -> List[reverb.Table]: + """Creates reverb tables for the algorithm.""" + del policy + # params_num_sgd_steps is used to track how old the actor parameters are + extra_spec = { + "log_prob": np.ones(shape=(), dtype=np.float32), + "params_num_sgd_steps": np.ones(shape=(), dtype=np.float32), + } + signature = adders_reverb.SequenceAdder.signature( + environment_spec, extra_spec, sequence_length=self._sequence_length + ) + return [ + reverb.Table.queue( + name=self._config.replay_table_name, + max_size=self._config.batch_size, + signature=signature, + ) + ] + + def make_dataset_iterator( + self, replay_client: reverb.Client + ) -> Iterator[reverb.ReplaySample]: + """Creates a dataset. The iterator batch size is computed as follows: @@ -101,141 +102,154 @@ def make_dataset_iterator( Returns: A replay buffer iterator to be used by the local devices. """ - iterator_batch_size, ragged = divmod(self._config.batch_size, - jax.device_count()) - if ragged: - raise ValueError( - 'Learner batch size must be divisible by total number of devices!') - - # We don't use datasets.make_reverb_dataset() here to avoid interleaving - # and prefetching, that doesn't work well with can_sample() check on update. - # NOTE: Value for max_in_flight_samples_per_worker comes from a - # recommendation here: https://git.io/JYzXB - dataset = reverb.TrajectoryDataset.from_table_signature( - server_address=replay_client.server_address, - table=self._config.replay_table_name, - max_in_flight_samples_per_worker=( - 2 * self._config.batch_size // jax.process_count() - ), - ) - dataset = dataset.batch(iterator_batch_size, drop_remainder=True) - dataset = dataset.as_numpy_iterator() - return utils.multi_device_put(iterable=dataset, devices=jax.local_devices()) - - def make_adder( - self, - replay_client: reverb.Client, - environment_spec: Optional[specs.EnvironmentSpec], - policy: Optional[actor_core_lib.FeedForwardPolicyWithExtra], - ) -> Optional[adders.Adder]: - """Creates an adder which handles observations.""" - del environment_spec, policy - # Note that the last transition in the sequence is used for bootstrapping - # only and is ignored otherwise. 
So we need to make sure that sequences - # overlap on one transition, thus "-1" in the period length computation. - return adders_reverb.SequenceAdder( - client=replay_client, - priority_fns={self._config.replay_table_name: None}, - period=self._sequence_length - 1, - sequence_length=self._sequence_length, - ) - - def make_learner( - self, - random_key: networks_lib.PRNGKey, - networks: ppo_networks.PPONetworks, - dataset: Iterator[reverb.ReplaySample], - logger_fn: loggers.LoggerFactory, - environment_spec: specs.EnvironmentSpec, - replay_client: Optional[reverb.Client] = None, - counter: Optional[counting.Counter] = None, - ) -> core.Learner: - del replay_client - - if callable(self._config.learning_rate): - optimizer = optax.chain( - optax.clip_by_global_norm(self._config.max_gradient_norm), - optax.scale_by_adam(eps=self._config.adam_epsilon), - optax.scale_by_schedule(self._config.learning_rate), optax.scale(-1)) # pytype: disable=wrong-arg-types # numpy-scalars - else: - optimizer = optax.chain( - optax.clip_by_global_norm(self._config.max_gradient_norm), - optax.scale_by_adam(eps=self._config.adam_epsilon), - optax.scale(-self._config.learning_rate)) - - obs_normalization_fns = None - if self._config.obs_normalization_fns_factory is not None: - obs_normalization_fns = self._config.obs_normalization_fns_factory( - environment_spec.observations) - - return learning.PPOLearner( - ppo_networks=networks, - iterator=dataset, - discount=self._config.discount, - entropy_cost=self._config.entropy_cost, - value_cost=self._config.value_cost, - ppo_clipping_epsilon=self._config.ppo_clipping_epsilon, - normalize_advantage=self._config.normalize_advantage, - normalize_value=self._config.normalize_value, - normalization_ema_tau=self._config.normalization_ema_tau, - clip_value=self._config.clip_value, - value_clipping_epsilon=self._config.value_clipping_epsilon, - max_abs_reward=self._config.max_abs_reward, - gae_lambda=self._config.gae_lambda, - counter=counter, - random_key=random_key, - optimizer=optimizer, - num_epochs=self._config.num_epochs, - num_minibatches=self._config.num_minibatches, - logger=logger_fn('learner'), - log_global_norm_metrics=self._config.log_global_norm_metrics, - metrics_logging_period=self._config.metrics_logging_period, - pmap_axis_name=self._config.pmap_axis_name, - obs_normalization_fns=obs_normalization_fns, - ) - - def make_actor( - self, - random_key: networks_lib.PRNGKey, - policy: actor_core_lib.FeedForwardPolicyWithExtra, - environment_spec: specs.EnvironmentSpec, - variable_source: Optional[core.VariableSource] = None, - adder: Optional[adders.Adder] = None, - ) -> core.Actor: - assert variable_source is not None - actor_core = actor_core_lib.batched_feed_forward_with_extras_to_actor_core( - policy) - if self._config.obs_normalization_fns_factory is not None: - variable_client = variable_utils.VariableClient( - variable_source, ['params', 'obs_normalization_params'], - device='cpu', - update_period=self._config.variable_update_period) - obs_normalization_fns = self._config.obs_normalization_fns_factory( - environment_spec.observations) - actor = normalization.NormalizedGenericActor( - actor_core, - obs_normalization_fns, - random_key, - variable_client, - adder, - jit=True, - backend='cpu', - per_episode_update=False, - ) - else: - variable_client = variable_utils.VariableClient( - variable_source, - 'params', - device='cpu', - update_period=self._config.variable_update_period) - actor = actors.GenericActor( - actor_core, random_key, variable_client, adder, 
backend='cpu') - return actor - - def make_policy( - self, - networks: ppo_networks.PPONetworks, - environment_spec: specs.EnvironmentSpec, - evaluation: bool = False) -> actor_core_lib.FeedForwardPolicyWithExtra: - del environment_spec - return ppo_networks.make_inference_fn(networks, evaluation) + iterator_batch_size, ragged = divmod( + self._config.batch_size, jax.device_count() + ) + if ragged: + raise ValueError( + "Learner batch size must be divisible by total number of devices!" + ) + + # We don't use datasets.make_reverb_dataset() here to avoid interleaving + # and prefetching, that doesn't work well with can_sample() check on update. + # NOTE: Value for max_in_flight_samples_per_worker comes from a + # recommendation here: https://git.io/JYzXB + dataset = reverb.TrajectoryDataset.from_table_signature( + server_address=replay_client.server_address, + table=self._config.replay_table_name, + max_in_flight_samples_per_worker=( + 2 * self._config.batch_size // jax.process_count() + ), + ) + dataset = dataset.batch(iterator_batch_size, drop_remainder=True) + dataset = dataset.as_numpy_iterator() + return utils.multi_device_put(iterable=dataset, devices=jax.local_devices()) + + def make_adder( + self, + replay_client: reverb.Client, + environment_spec: Optional[specs.EnvironmentSpec], + policy: Optional[actor_core_lib.FeedForwardPolicyWithExtra], + ) -> Optional[adders.Adder]: + """Creates an adder which handles observations.""" + del environment_spec, policy + # Note that the last transition in the sequence is used for bootstrapping + # only and is ignored otherwise. So we need to make sure that sequences + # overlap on one transition, thus "-1" in the period length computation. + return adders_reverb.SequenceAdder( + client=replay_client, + priority_fns={self._config.replay_table_name: None}, + period=self._sequence_length - 1, + sequence_length=self._sequence_length, + ) + + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: ppo_networks.PPONetworks, + dataset: Iterator[reverb.ReplaySample], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + del replay_client + + if callable(self._config.learning_rate): + optimizer = optax.chain( + optax.clip_by_global_norm(self._config.max_gradient_norm), + optax.scale_by_adam(eps=self._config.adam_epsilon), + optax.scale_by_schedule(self._config.learning_rate), + optax.scale(-1), + ) # pytype: disable=wrong-arg-types # numpy-scalars + else: + optimizer = optax.chain( + optax.clip_by_global_norm(self._config.max_gradient_norm), + optax.scale_by_adam(eps=self._config.adam_epsilon), + optax.scale(-self._config.learning_rate), + ) + + obs_normalization_fns = None + if self._config.obs_normalization_fns_factory is not None: + obs_normalization_fns = self._config.obs_normalization_fns_factory( + environment_spec.observations + ) + + return learning.PPOLearner( + ppo_networks=networks, + iterator=dataset, + discount=self._config.discount, + entropy_cost=self._config.entropy_cost, + value_cost=self._config.value_cost, + ppo_clipping_epsilon=self._config.ppo_clipping_epsilon, + normalize_advantage=self._config.normalize_advantage, + normalize_value=self._config.normalize_value, + normalization_ema_tau=self._config.normalization_ema_tau, + clip_value=self._config.clip_value, + value_clipping_epsilon=self._config.value_clipping_epsilon, + max_abs_reward=self._config.max_abs_reward, + 
gae_lambda=self._config.gae_lambda, + counter=counter, + random_key=random_key, + optimizer=optimizer, + num_epochs=self._config.num_epochs, + num_minibatches=self._config.num_minibatches, + logger=logger_fn("learner"), + log_global_norm_metrics=self._config.log_global_norm_metrics, + metrics_logging_period=self._config.metrics_logging_period, + pmap_axis_name=self._config.pmap_axis_name, + obs_normalization_fns=obs_normalization_fns, + ) + + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: actor_core_lib.FeedForwardPolicyWithExtra, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + adder: Optional[adders.Adder] = None, + ) -> core.Actor: + assert variable_source is not None + actor_core = actor_core_lib.batched_feed_forward_with_extras_to_actor_core( + policy + ) + if self._config.obs_normalization_fns_factory is not None: + variable_client = variable_utils.VariableClient( + variable_source, + ["params", "obs_normalization_params"], + device="cpu", + update_period=self._config.variable_update_period, + ) + obs_normalization_fns = self._config.obs_normalization_fns_factory( + environment_spec.observations + ) + actor = normalization.NormalizedGenericActor( + actor_core, + obs_normalization_fns, + random_key, + variable_client, + adder, + jit=True, + backend="cpu", + per_episode_update=False, + ) + else: + variable_client = variable_utils.VariableClient( + variable_source, + "params", + device="cpu", + update_period=self._config.variable_update_period, + ) + actor = actors.GenericActor( + actor_core, random_key, variable_client, adder, backend="cpu" + ) + return actor + + def make_policy( + self, + networks: ppo_networks.PPONetworks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool = False, + ) -> actor_core_lib.FeedForwardPolicyWithExtra: + del environment_spec + return ppo_networks.make_inference_fn(networks, evaluation) diff --git a/acme/agents/jax/ppo/config.py b/acme/agents/jax/ppo/config.py index 94c411b316..13896b52a1 100644 --- a/acme/agents/jax/ppo/config.py +++ b/acme/agents/jax/ppo/config.py @@ -14,7 +14,7 @@ """PPO config.""" import dataclasses -from typing import Callable, Union, Optional +from typing import Callable, Optional, Union from acme import types from acme.adders import reverb as adders_reverb @@ -23,7 +23,7 @@ @dataclasses.dataclass class PPOConfig: - """Configuration options for PPO. + """Configuration options for PPO. Attributes: unroll_length: Length of sequences added to the replay buffer. @@ -61,28 +61,30 @@ class PPOConfig: normalization functions. Setting to None (default) disables observation normalization. """ - unroll_length: int = 8 - num_minibatches: int = 8 - num_epochs: int = 2 - batch_size: int = 256 - replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE - ppo_clipping_epsilon: float = 0.2 - normalize_advantage: bool = False - normalize_value: bool = False - normalization_ema_tau: float = 0.995 - clip_value: bool = False - value_clipping_epsilon: float = 0.2 - max_abs_reward: Optional[float] = None - gae_lambda: float = 0.95 - discount: float = 0.99 - learning_rate: Union[float, Callable[[int], float]] = 3e-4 - adam_epsilon: float = 1e-7 - entropy_cost: float = 3e-4 - value_cost: float = 1. 
- max_gradient_norm: float = 0.5 - variable_update_period: int = 1 - log_global_norm_metrics: bool = False - metrics_logging_period: int = 100 - pmap_axis_name: str = 'devices' - obs_normalization_fns_factory: Optional[Callable[ - [types.NestedSpec], normalization.NormalizationFns]] = None + + unroll_length: int = 8 + num_minibatches: int = 8 + num_epochs: int = 2 + batch_size: int = 256 + replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE + ppo_clipping_epsilon: float = 0.2 + normalize_advantage: bool = False + normalize_value: bool = False + normalization_ema_tau: float = 0.995 + clip_value: bool = False + value_clipping_epsilon: float = 0.2 + max_abs_reward: Optional[float] = None + gae_lambda: float = 0.95 + discount: float = 0.99 + learning_rate: Union[float, Callable[[int], float]] = 3e-4 + adam_epsilon: float = 1e-7 + entropy_cost: float = 3e-4 + value_cost: float = 1.0 + max_gradient_norm: float = 0.5 + variable_update_period: int = 1 + log_global_norm_metrics: bool = False + metrics_logging_period: int = 100 + pmap_axis_name: str = "devices" + obs_normalization_fns_factory: Optional[ + Callable[[types.NestedSpec], normalization.NormalizationFns] + ] = None diff --git a/acme/agents/jax/ppo/learning.py b/acme/agents/jax/ppo/learning.py index 101f249b9b..9336df7eef 100644 --- a/acme/agents/jax/ppo/learning.py +++ b/acme/agents/jax/ppo/learning.py @@ -16,503 +16,550 @@ from typing import Dict, Iterator, List, NamedTuple, Optional, Tuple -import acme -from acme import types -from acme.agents.jax.ppo import networks -from acme.agents.jax.ppo import normalization -from acme.jax import networks as networks_lib -from acme.jax.utils import get_from_first_device -from acme.utils import counting -from acme.utils import loggers import jax import jax.numpy as jnp import optax import reverb import rlax +import acme +from acme import types +from acme.agents.jax.ppo import networks, normalization +from acme.jax import networks as networks_lib +from acme.jax.utils import get_from_first_device +from acme.utils import counting, loggers PPOParams = networks.PPOParams class Batch(NamedTuple): - """A batch of data; all shapes are expected to be [B, ...].""" - observations: types.NestedArray - actions: jnp.ndarray - advantages: jnp.ndarray + """A batch of data; all shapes are expected to be [B, ...].""" - # Target value estimate used to bootstrap the value function. - target_values: jnp.ndarray + observations: types.NestedArray + actions: jnp.ndarray + advantages: jnp.ndarray - # Value estimate and action log-prob at behavior time. - behavior_values: jnp.ndarray - behavior_log_probs: jnp.ndarray + # Target value estimate used to bootstrap the value function. + target_values: jnp.ndarray + + # Value estimate and action log-prob at behavior time. + behavior_values: jnp.ndarray + behavior_log_probs: jnp.ndarray class TrainingState(NamedTuple): - """Training state for the PPO learner.""" - params: PPOParams - opt_state: optax.OptState - random_key: networks_lib.PRNGKey + """Training state for the PPO learner.""" + + params: PPOParams + opt_state: optax.OptState + random_key: networks_lib.PRNGKey - # Optional counter used for exponential moving average zero debiasing - # Using float32 as it covers a larger range than int32. If using int64 we - # would need to do jax_enable_x64. - ema_counter: Optional[jnp.float32] = None + # Optional counter used for exponential moving average zero debiasing + # Using float32 as it covers a larger range than int32. 
If using int64 we + # would need to do jax_enable_x64. + ema_counter: Optional[jnp.float32] = None - # Optional parameter for maintaining a running estimate of the scale of - # advantage estimates - biased_advantage_scale: Optional[networks_lib.Params] = None - advantage_scale: Optional[networks_lib.Params] = None + # Optional parameter for maintaining a running estimate of the scale of + # advantage estimates + biased_advantage_scale: Optional[networks_lib.Params] = None + advantage_scale: Optional[networks_lib.Params] = None - # Optional parameter for maintaining a running estimate of the mean and - # standard deviation of value estimates - biased_value_first_moment: Optional[networks_lib.Params] = None - biased_value_second_moment: Optional[networks_lib.Params] = None - value_mean: Optional[networks_lib.Params] = None - value_std: Optional[networks_lib.Params] = None + # Optional parameter for maintaining a running estimate of the mean and + # standard deviation of value estimates + biased_value_first_moment: Optional[networks_lib.Params] = None + biased_value_second_moment: Optional[networks_lib.Params] = None + value_mean: Optional[networks_lib.Params] = None + value_std: Optional[networks_lib.Params] = None - # Optional parameters for observation normalization - obs_normalization_params: Optional[normalization.NormalizationParams] = None + # Optional parameters for observation normalization + obs_normalization_params: Optional[normalization.NormalizationParams] = None class PPOLearner(acme.Learner): - """Learner for PPO.""" - - def __init__( - self, - ppo_networks: networks.PPONetworks, - iterator: Iterator[reverb.ReplaySample], - optimizer: optax.GradientTransformation, - random_key: networks_lib.PRNGKey, - ppo_clipping_epsilon: float = 0.2, - normalize_advantage: bool = True, - normalize_value: bool = False, - normalization_ema_tau: float = 0.995, - clip_value: bool = False, - value_clipping_epsilon: float = 0.2, - max_abs_reward: Optional[float] = None, - gae_lambda: float = 0.95, - discount: float = 0.99, - entropy_cost: float = 0., - value_cost: float = 1., - num_epochs: int = 4, - num_minibatches: int = 1, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - log_global_norm_metrics: bool = False, - metrics_logging_period: int = 100, - pmap_axis_name: str = 'devices', - obs_normalization_fns: Optional[normalization.NormalizationFns] = None, - ): - self.local_learner_devices = jax.local_devices() - self.num_local_learner_devices = jax.local_device_count() - self.learner_devices = jax.devices() - self.num_epochs = num_epochs - self.num_minibatches = num_minibatches - self.metrics_logging_period = metrics_logging_period - self._num_full_update_steps = 0 - self._iterator = iterator - - normalize_obs = obs_normalization_fns is not None - if normalize_obs: - assert obs_normalization_fns is not None - - # Set up logging/counting. 
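# --- Illustrative sketch (editorial addition, not part of this diff) ---------
# How the `biased_*` / debiased pairs in `TrainingState` behave: the running
# statistics start at zero, so a plain EMA is biased towards zero early in
# training; dividing by (1 - tau**t), the `zero_debias` factor computed in
# `single_device_update`, removes that bias, as in Adam's bias correction.
# Toy numbers only, with a constant batch statistic of 5.0 and tau = 0.995.
tau = 0.995
biased = 0.0
for t in range(1, 6):
    batch_stat = 5.0
    biased = tau * biased + (1.0 - tau) * batch_stat
    debiased = biased / (1.0 - tau ** t)
    print(f"t={t}: biased={biased:.4f}  debiased={debiased:.4f}")
# The biased estimate crawls up from 0 while the debiased one is 5.0 from the
# very first update.
# -----------------------------------------------------------------------------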
- self._counter = counter or counting.Counter() - self._logger = logger or loggers.make_default_logger('learner') - - def ppo_loss( - params: networks_lib.Params, - observations: networks_lib.Observation, - actions: networks_lib.Action, - advantages: jnp.ndarray, - target_values: networks_lib.Value, - behavior_values: networks_lib.Value, - behavior_log_probs: networks_lib.LogProb, - value_mean: jnp.ndarray, - value_std: jnp.ndarray, - key: networks_lib.PRNGKey, - ) -> Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]: - """PPO loss for the policy and the critic.""" - distribution_params, values = ppo_networks.network.apply( - params, observations) - if normalize_value: - # values = values * jnp.fmax(value_std, 1e-6) + value_mean - target_values = (target_values - value_mean) / jnp.fmax(value_std, 1e-6) - policy_log_probs = ppo_networks.log_prob(distribution_params, actions) - key, sub_key = jax.random.split(key) - policy_entropies = ppo_networks.entropy(distribution_params, sub_key) - - # Compute the policy losses - rhos = jnp.exp(policy_log_probs - behavior_log_probs) - clipped_ppo_policy_loss = rlax.clipped_surrogate_pg_loss( - rhos, advantages, ppo_clipping_epsilon) - policy_entropy_loss = -jnp.mean(policy_entropies) - total_policy_loss = ( - clipped_ppo_policy_loss + entropy_cost * policy_entropy_loss) - - # Compute the critic losses - unclipped_value_loss = (values - target_values)**2 - - if clip_value: - # Clip values to reduce variablility during critic training. - clipped_values = behavior_values + jnp.clip(values - behavior_values, - -value_clipping_epsilon, - value_clipping_epsilon) - clipped_value_error = target_values - clipped_values - clipped_value_loss = clipped_value_error ** 2 - value_loss = jnp.mean(jnp.fmax(unclipped_value_loss, - clipped_value_loss)) - else: - # For Mujoco envs clipping hurts a lot. 
Evidenced by Figure 43 in - # https://arxiv.org/pdf/2006.05990.pdf - value_loss = jnp.mean(unclipped_value_loss) - - total_ppo_loss = total_policy_loss + value_cost * value_loss - return total_ppo_loss, { # pytype: disable=bad-return-type # numpy-scalars - 'loss_total': total_ppo_loss, - 'loss_policy_total': total_policy_loss, - 'loss_policy_pg': clipped_ppo_policy_loss, - 'loss_policy_entropy': policy_entropy_loss, - 'loss_critic': value_loss, - } - - ppo_loss_grad = jax.grad(ppo_loss, has_aux=True) - - def sgd_step(state: TrainingState, minibatch: Batch): - observations = minibatch.observations - actions = minibatch.actions - advantages = minibatch.advantages - target_values = minibatch.target_values - behavior_values = minibatch.behavior_values - behavior_log_probs = minibatch.behavior_log_probs - key, sub_key = jax.random.split(state.random_key) - - loss_grad, metrics = ppo_loss_grad( - state.params.model_params, - observations, - actions, - advantages, - target_values, - behavior_values, - behavior_log_probs, - state.value_mean, - state.value_std, - sub_key, - ) - - # Apply updates - loss_grad = jax.lax.pmean(loss_grad, axis_name=pmap_axis_name) - updates, opt_state = optimizer.update(loss_grad, state.opt_state) - model_params = optax.apply_updates(state.params.model_params, updates) - params = PPOParams( - model_params=model_params, - num_sgd_steps=state.params.num_sgd_steps + 1) - - if log_global_norm_metrics: - metrics['norm_grad'] = optax.global_norm(loss_grad) - metrics['norm_updates'] = optax.global_norm(updates) - - state = state._replace(params=params, opt_state=opt_state, random_key=key) - - return state, metrics - - def epoch_update( - carry: Tuple[TrainingState, Batch], - unused_t: Tuple[()], + """Learner for PPO.""" + + def __init__( + self, + ppo_networks: networks.PPONetworks, + iterator: Iterator[reverb.ReplaySample], + optimizer: optax.GradientTransformation, + random_key: networks_lib.PRNGKey, + ppo_clipping_epsilon: float = 0.2, + normalize_advantage: bool = True, + normalize_value: bool = False, + normalization_ema_tau: float = 0.995, + clip_value: bool = False, + value_clipping_epsilon: float = 0.2, + max_abs_reward: Optional[float] = None, + gae_lambda: float = 0.95, + discount: float = 0.99, + entropy_cost: float = 0.0, + value_cost: float = 1.0, + num_epochs: int = 4, + num_minibatches: int = 1, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + log_global_norm_metrics: bool = False, + metrics_logging_period: int = 100, + pmap_axis_name: str = "devices", + obs_normalization_fns: Optional[normalization.NormalizationFns] = None, ): - state, carry_batch = carry - - # Shuffling into minibatches - batch_size = carry_batch.advantages.shape[0] - key, sub_key = jax.random.split(state.random_key) - # TODO(kamyar) For effiency could use same permutation for all epochs - permuted_batch = jax.tree_util.tree_map( - lambda x: jax.random.permutation( # pylint: disable=g-long-lambda - sub_key, - x, - axis=0, - independent=False), - carry_batch) - state = state._replace(random_key=key) - minibatches = jax.tree_util.tree_map( - lambda x: jnp.reshape( # pylint: disable=g-long-lambda - x, - [ # pylint: disable=g-long-lambda - num_minibatches, batch_size // num_minibatches - ] + list(x.shape[1:])), - permuted_batch) - - # Scan over the minibatches - state, metrics = jax.lax.scan( - sgd_step, state, minibatches, length=num_minibatches) - metrics = jax.tree_util.tree_map(jnp.mean, metrics) - - return (state, carry_batch), metrics - - 
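# --- Illustrative sketch (editorial addition, not part of this diff) ---------
# The shuffle-and-split pattern used by `epoch_update` above, on a toy batch:
# the same permutation is applied to the leading (batch) axis of every leaf,
# and the result is reshaped to [num_minibatches, batch_size // num_minibatches,
# ...] so that `jax.lax.scan` can step over minibatches. Shapes are made up.
import jax
import jax.numpy as jnp

batch_size, num_minibatches = 8, 2
key = jax.random.PRNGKey(0)
batch = {
    "advantages": jnp.arange(batch_size, dtype=jnp.float32),
    "observations": jnp.ones((batch_size, 3)),
}

permuted = jax.tree_util.tree_map(
    lambda x: jax.random.permutation(key, x, axis=0, independent=False), batch
)
minibatches = jax.tree_util.tree_map(
    lambda x: jnp.reshape(
        x, [num_minibatches, batch_size // num_minibatches] + list(x.shape[1:])
    ),
    permuted,
)
print(jax.tree_util.tree_map(lambda x: x.shape, minibatches))
# {'advantages': (2, 4), 'observations': (2, 4, 3)}
# -----------------------------------------------------------------------------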
vmapped_network_apply = jax.vmap( - ppo_networks.network.apply, in_axes=(None, 0), out_axes=0) - - def single_device_update( - state: TrainingState, - trajectories: types.NestedArray, - ): - params_num_sgd_steps_before_update = state.params.num_sgd_steps - - # Update the EMA counter and obtain the zero debiasing multiplier - if normalize_advantage or normalize_value: - ema_counter = state.ema_counter + 1 - state = state._replace(ema_counter=ema_counter) - zero_debias = 1. / (1. - jnp.power(normalization_ema_tau, ema_counter)) - - # Extract the data. - data = trajectories.data - observations, actions, rewards, termination, extra = (data.observation, - data.action, - data.reward, - data.discount, - data.extras) - - if normalize_obs: - obs_norm_params = obs_normalization_fns.update( - state.obs_normalization_params, observations, pmap_axis_name) - state = state._replace(obs_normalization_params=obs_norm_params) - observations = obs_normalization_fns.normalize( - observations, state.obs_normalization_params) - - if max_abs_reward is not None: - # Apply reward clipping. - rewards = jnp.clip(rewards, -1. * max_abs_reward, max_abs_reward) - discounts = termination * discount - behavior_log_probs = extra['log_prob'] - _, behavior_values = vmapped_network_apply(state.params.model_params, - observations) - - if normalize_value: - batch_value_first_moment = jnp.mean(behavior_values) - batch_value_second_moment = jnp.mean(behavior_values**2) - batch_value_first_moment, batch_value_second_moment = jax.lax.pmean( - (batch_value_first_moment, batch_value_second_moment), - axis_name=pmap_axis_name) - - biased_value_first_moment = ( - normalization_ema_tau * state.biased_value_first_moment + - (1. - normalization_ema_tau) * batch_value_first_moment) - biased_value_second_moment = ( - normalization_ema_tau * state.biased_value_second_moment + - (1. - normalization_ema_tau) * batch_value_second_moment) - - value_mean = biased_value_first_moment * zero_debias - value_second_moment = biased_value_second_moment * zero_debias - value_std = jnp.sqrt(jax.nn.relu(value_second_moment - value_mean**2)) - - state = state._replace( - biased_value_first_moment=biased_value_first_moment, - biased_value_second_moment=biased_value_second_moment, - value_mean=value_mean, - value_std=value_std, + self.local_learner_devices = jax.local_devices() + self.num_local_learner_devices = jax.local_device_count() + self.learner_devices = jax.devices() + self.num_epochs = num_epochs + self.num_minibatches = num_minibatches + self.metrics_logging_period = metrics_logging_period + self._num_full_update_steps = 0 + self._iterator = iterator + + normalize_obs = obs_normalization_fns is not None + if normalize_obs: + assert obs_normalization_fns is not None + + # Set up logging/counting. 
+ self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger("learner") + + def ppo_loss( + params: networks_lib.Params, + observations: networks_lib.Observation, + actions: networks_lib.Action, + advantages: jnp.ndarray, + target_values: networks_lib.Value, + behavior_values: networks_lib.Value, + behavior_log_probs: networks_lib.LogProb, + value_mean: jnp.ndarray, + value_std: jnp.ndarray, + key: networks_lib.PRNGKey, + ) -> Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]: + """PPO loss for the policy and the critic.""" + distribution_params, values = ppo_networks.network.apply( + params, observations + ) + if normalize_value: + # values = values * jnp.fmax(value_std, 1e-6) + value_mean + target_values = (target_values - value_mean) / jnp.fmax(value_std, 1e-6) + policy_log_probs = ppo_networks.log_prob(distribution_params, actions) + key, sub_key = jax.random.split(key) + policy_entropies = ppo_networks.entropy(distribution_params, sub_key) + + # Compute the policy losses + rhos = jnp.exp(policy_log_probs - behavior_log_probs) + clipped_ppo_policy_loss = rlax.clipped_surrogate_pg_loss( + rhos, advantages, ppo_clipping_epsilon + ) + policy_entropy_loss = -jnp.mean(policy_entropies) + total_policy_loss = ( + clipped_ppo_policy_loss + entropy_cost * policy_entropy_loss + ) + + # Compute the critic losses + unclipped_value_loss = (values - target_values) ** 2 + + if clip_value: + # Clip values to reduce variablility during critic training. + clipped_values = behavior_values + jnp.clip( + values - behavior_values, + -value_clipping_epsilon, + value_clipping_epsilon, + ) + clipped_value_error = target_values - clipped_values + clipped_value_loss = clipped_value_error ** 2 + value_loss = jnp.mean( + jnp.fmax(unclipped_value_loss, clipped_value_loss) + ) + else: + # For Mujoco envs clipping hurts a lot. 
Evidenced by Figure 43 in + # https://arxiv.org/pdf/2006.05990.pdf + value_loss = jnp.mean(unclipped_value_loss) + + total_ppo_loss = total_policy_loss + value_cost * value_loss + return ( + total_ppo_loss, + { # pytype: disable=bad-return-type # numpy-scalars + "loss_total": total_ppo_loss, + "loss_policy_total": total_policy_loss, + "loss_policy_pg": clipped_ppo_policy_loss, + "loss_policy_entropy": policy_entropy_loss, + "loss_critic": value_loss, + }, + ) + + ppo_loss_grad = jax.grad(ppo_loss, has_aux=True) + + def sgd_step(state: TrainingState, minibatch: Batch): + observations = minibatch.observations + actions = minibatch.actions + advantages = minibatch.advantages + target_values = minibatch.target_values + behavior_values = minibatch.behavior_values + behavior_log_probs = minibatch.behavior_log_probs + key, sub_key = jax.random.split(state.random_key) + + loss_grad, metrics = ppo_loss_grad( + state.params.model_params, + observations, + actions, + advantages, + target_values, + behavior_values, + behavior_log_probs, + state.value_mean, + state.value_std, + sub_key, + ) + + # Apply updates + loss_grad = jax.lax.pmean(loss_grad, axis_name=pmap_axis_name) + updates, opt_state = optimizer.update(loss_grad, state.opt_state) + model_params = optax.apply_updates(state.params.model_params, updates) + params = PPOParams( + model_params=model_params, num_sgd_steps=state.params.num_sgd_steps + 1 + ) + + if log_global_norm_metrics: + metrics["norm_grad"] = optax.global_norm(loss_grad) + metrics["norm_updates"] = optax.global_norm(updates) + + state = state._replace(params=params, opt_state=opt_state, random_key=key) + + return state, metrics + + def epoch_update( + carry: Tuple[TrainingState, Batch], unused_t: Tuple[()], + ): + state, carry_batch = carry + + # Shuffling into minibatches + batch_size = carry_batch.advantages.shape[0] + key, sub_key = jax.random.split(state.random_key) + # TODO(kamyar) For effiency could use same permutation for all epochs + permuted_batch = jax.tree_util.tree_map( + lambda x: jax.random.permutation( # pylint: disable=g-long-lambda + sub_key, x, axis=0, independent=False + ), + carry_batch, + ) + state = state._replace(random_key=key) + minibatches = jax.tree_util.tree_map( + lambda x: jnp.reshape( # pylint: disable=g-long-lambda + x, + [ # pylint: disable=g-long-lambda + num_minibatches, + batch_size // num_minibatches, + ] + + list(x.shape[1:]), + ), + permuted_batch, + ) + + # Scan over the minibatches + state, metrics = jax.lax.scan( + sgd_step, state, minibatches, length=num_minibatches + ) + metrics = jax.tree_util.tree_map(jnp.mean, metrics) + + return (state, carry_batch), metrics + + vmapped_network_apply = jax.vmap( + ppo_networks.network.apply, in_axes=(None, 0), out_axes=0 ) - behavior_values = behavior_values * jnp.fmax(state.value_std, - 1e-6) + state.value_mean - - behavior_values = jax.lax.stop_gradient(behavior_values) - - # Compute GAE using rlax - vmapped_rlax_truncated_generalized_advantage_estimation = jax.vmap( - rlax.truncated_generalized_advantage_estimation, - in_axes=(0, 0, None, 0)) - advantages = vmapped_rlax_truncated_generalized_advantage_estimation( - rewards[:, :-1], discounts[:, :-1], gae_lambda, behavior_values) - advantages = jax.lax.stop_gradient(advantages) - target_values = behavior_values[:, :-1] + advantages - target_values = jax.lax.stop_gradient(target_values) - - # Exclude the last step - it was only used for bootstrapping. - # The shape is [num_sequences, num_steps, ..] 
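# --- Illustrative sketch (editorial addition, not part of this diff) ---------
# The pessimistic clipped surrogate objective at the heart of `ppo_loss`,
# written out by hand for a few toy numbers rather than via rlax: the
# probability ratio rho is clipped to [1 - eps, 1 + eps] and the elementwise
# minimum of the clipped and unclipped terms is maximised, i.e. its negation
# is minimised.
import numpy as np

eps = 0.2
advantages = np.array([1.0, -1.0, 2.0])
log_rhos = np.array([0.3, 0.3, -0.4])  # new log-prob minus behaviour log-prob
rhos = np.exp(log_rhos)

unclipped = rhos * advantages
clipped = np.clip(rhos, 1.0 - eps, 1.0 + eps) * advantages
policy_loss = -np.mean(np.minimum(unclipped, clipped))
print(policy_loss)  # approximately -0.397 for these numbers
# -----------------------------------------------------------------------------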
- (observations, actions, behavior_log_probs, behavior_values) = ( - jax.tree_util.tree_map( - lambda x: x[:, :-1], - (observations, actions, behavior_log_probs, behavior_values), - ) - ) - - # Shuffle the data and break into minibatches - batch_size = advantages.shape[0] * advantages.shape[1] - batch = Batch( - observations=observations, - actions=actions, - advantages=advantages, - target_values=target_values, - behavior_values=behavior_values, - behavior_log_probs=behavior_log_probs) - batch = jax.tree_util.tree_map( - lambda x: jnp.reshape(x, [batch_size] + list(x.shape[2:])), batch) - - if normalize_advantage: - batch_advantage_scale = jnp.mean(jnp.abs(batch.advantages)) - batch_advantage_scale = jax.lax.pmean(batch_advantage_scale, - pmap_axis_name) - - # update the running statistics - biased_advantage_scale = ( - normalization_ema_tau * state.biased_advantage_scale + - (1. - normalization_ema_tau) * batch_advantage_scale) - advantage_scale = biased_advantage_scale * zero_debias - state = state._replace( - biased_advantage_scale=biased_advantage_scale, - advantage_scale=advantage_scale) - - # scale the advantages - scaled_advantages = batch.advantages / jnp.fmax(state.advantage_scale, - 1e-6) - batch = batch._replace(advantages=scaled_advantages) - - # Scan desired number of epoch updates - (state, _), metrics = jax.lax.scan( - epoch_update, (state, batch), (), length=num_epochs) - metrics = jax.tree_util.tree_map(jnp.mean, metrics) - - if normalize_advantage: - metrics['advantage_scale'] = state.advantage_scale - - if normalize_value: - metrics['value_mean'] = value_mean - metrics['value_std'] = value_std - - delta_params_sgd_steps = ( - data.extras['params_num_sgd_steps'][:, 0] - - params_num_sgd_steps_before_update) - metrics['delta_params_sgd_steps_min'] = jnp.min(delta_params_sgd_steps) - metrics['delta_params_sgd_steps_max'] = jnp.max(delta_params_sgd_steps) - metrics['delta_params_sgd_steps_mean'] = jnp.mean(delta_params_sgd_steps) - metrics['delta_params_sgd_steps_std'] = jnp.std(delta_params_sgd_steps) - - return state, metrics - - pmapped_update_step = jax.pmap( - single_device_update, - axis_name=pmap_axis_name, - devices=self.learner_devices) - - def full_update_step( - state: TrainingState, - trajectories: types.NestedArray, - ): - state, metrics = pmapped_update_step(state, trajectories) - return state, metrics - - self._full_update_step = full_update_step - - def make_initial_state(key: networks_lib.PRNGKey) -> TrainingState: - """Initialises the training state (parameters and optimiser state).""" - all_keys = jax.random.split(key, num=self.num_local_learner_devices + 1) - key_init, key_state = all_keys[0], all_keys[1:] - key_state = [key_state[i] for i in range(self.num_local_learner_devices)] - key_state = jax.device_put_sharded(key_state, self.local_learner_devices) - - initial_params = ppo_networks.network.init(key_init) - initial_opt_state = optimizer.init(initial_params) - # Using float32 as it covers a larger range than int32. If using int64 we - # would need to do jax_enable_x64. 
- params_num_sgd_steps = jnp.zeros(shape=(), dtype=jnp.float32) - - initial_params = jax.device_put_replicated(initial_params, - self.local_learner_devices) - initial_opt_state = jax.device_put_replicated(initial_opt_state, - self.local_learner_devices) - params_num_sgd_steps = jax.device_put_replicated( - params_num_sgd_steps, self.local_learner_devices) - - ema_counter = jnp.float32(0) - ema_counter = jax.device_put_replicated(ema_counter, - self.local_learner_devices) - - init_state = TrainingState( - params=PPOParams( - model_params=initial_params, num_sgd_steps=params_num_sgd_steps), - opt_state=initial_opt_state, - random_key=key_state, - ema_counter=ema_counter, - ) - - if normalize_advantage: - biased_advantage_scale = jax.device_put_replicated( - jnp.zeros([]), self.local_learner_devices) - advantage_scale = jax.device_put_replicated( - jnp.zeros([]), self.local_learner_devices) - - init_state = init_state._replace( - biased_advantage_scale=biased_advantage_scale, - advantage_scale=advantage_scale) - - if normalize_value: - biased_value_first_moment = jax.device_put_replicated( - jnp.zeros([]), self.local_learner_devices) - value_mean = biased_value_first_moment - - biased_value_second_moment = jax.device_put_replicated( - jnp.zeros([]), self.local_learner_devices) - value_second_moment = biased_value_second_moment - value_std = jnp.sqrt(jax.nn.relu(value_second_moment - value_mean**2)) - - init_state = init_state._replace( - biased_value_first_moment=biased_value_first_moment, - biased_value_second_moment=biased_value_second_moment, - value_mean=value_mean, - value_std=value_std) - - if normalize_obs: - obs_norm_params = obs_normalization_fns.init() # pytype: disable=attribute-error - obs_norm_params = jax.device_put_replicated(obs_norm_params, - self.local_learner_devices) - init_state = init_state._replace( - obs_normalization_params=obs_norm_params) - - return init_state - - # Initialise training state (parameters and optimizer state). - self._state = make_initial_state(random_key) - self._cached_state = get_from_first_device(self._state, as_numpy=True) - - def step(self): - """Does a learner step and logs the results. + def single_device_update( + state: TrainingState, trajectories: types.NestedArray, + ): + params_num_sgd_steps_before_update = state.params.num_sgd_steps + + # Update the EMA counter and obtain the zero debiasing multiplier + if normalize_advantage or normalize_value: + ema_counter = state.ema_counter + 1 + state = state._replace(ema_counter=ema_counter) + zero_debias = 1.0 / ( + 1.0 - jnp.power(normalization_ema_tau, ema_counter) + ) + + # Extract the data. + data = trajectories.data + observations, actions, rewards, termination, extra = ( + data.observation, + data.action, + data.reward, + data.discount, + data.extras, + ) + + if normalize_obs: + obs_norm_params = obs_normalization_fns.update( + state.obs_normalization_params, observations, pmap_axis_name + ) + state = state._replace(obs_normalization_params=obs_norm_params) + observations = obs_normalization_fns.normalize( + observations, state.obs_normalization_params + ) + + if max_abs_reward is not None: + # Apply reward clipping. 
+ rewards = jnp.clip(rewards, -1.0 * max_abs_reward, max_abs_reward) + discounts = termination * discount + behavior_log_probs = extra["log_prob"] + _, behavior_values = vmapped_network_apply( + state.params.model_params, observations + ) + + if normalize_value: + batch_value_first_moment = jnp.mean(behavior_values) + batch_value_second_moment = jnp.mean(behavior_values ** 2) + batch_value_first_moment, batch_value_second_moment = jax.lax.pmean( + (batch_value_first_moment, batch_value_second_moment), + axis_name=pmap_axis_name, + ) + + biased_value_first_moment = ( + normalization_ema_tau * state.biased_value_first_moment + + (1.0 - normalization_ema_tau) * batch_value_first_moment + ) + biased_value_second_moment = ( + normalization_ema_tau * state.biased_value_second_moment + + (1.0 - normalization_ema_tau) * batch_value_second_moment + ) + + value_mean = biased_value_first_moment * zero_debias + value_second_moment = biased_value_second_moment * zero_debias + value_std = jnp.sqrt(jax.nn.relu(value_second_moment - value_mean ** 2)) + + state = state._replace( + biased_value_first_moment=biased_value_first_moment, + biased_value_second_moment=biased_value_second_moment, + value_mean=value_mean, + value_std=value_std, + ) + + behavior_values = ( + behavior_values * jnp.fmax(state.value_std, 1e-6) + state.value_mean + ) + + behavior_values = jax.lax.stop_gradient(behavior_values) + + # Compute GAE using rlax + vmapped_rlax_truncated_generalized_advantage_estimation = jax.vmap( + rlax.truncated_generalized_advantage_estimation, in_axes=(0, 0, None, 0) + ) + advantages = vmapped_rlax_truncated_generalized_advantage_estimation( + rewards[:, :-1], discounts[:, :-1], gae_lambda, behavior_values + ) + advantages = jax.lax.stop_gradient(advantages) + target_values = behavior_values[:, :-1] + advantages + target_values = jax.lax.stop_gradient(target_values) + + # Exclude the last step - it was only used for bootstrapping. + # The shape is [num_sequences, num_steps, ..] 
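# --- Illustrative sketch (editorial addition, not part of this diff) ---------
# A plain reference implementation of truncated GAE for a single sequence; it
# should match what the vmapped rlax call above computes per sequence: with
# T + 1 value estimates and T rewards/discounts, each advantage is the
# lambda-weighted sum of TD errors, and the final value acts only as the
# bootstrap (which is why the last step is dropped from the training batch
# just below). Toy numbers only.
import numpy as np


def gae(rewards, discounts, values, lam):
    # rewards/discounts have length T, values has length T + 1.
    advantages = np.zeros_like(rewards)
    gae_t = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + discounts[t] * values[t + 1] - values[t]
        gae_t = delta + discounts[t] * lam * gae_t
        advantages[t] = gae_t
    return advantages


rewards = np.array([1.0, 0.0, 1.0])
discounts = np.array([0.99, 0.99, 0.99])  # termination * discount
values = np.array([0.5, 0.4, 0.6, 0.2])   # includes the bootstrap value
advantages = gae(rewards, discounts, values, lam=0.95)
print(advantages)
print(values[:-1] + advantages)  # the corresponding target values
# -----------------------------------------------------------------------------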
+ ( + observations, + actions, + behavior_log_probs, + behavior_values, + ) = jax.tree_util.tree_map( + lambda x: x[:, :-1], + (observations, actions, behavior_log_probs, behavior_values), + ) + + # Shuffle the data and break into minibatches + batch_size = advantages.shape[0] * advantages.shape[1] + batch = Batch( + observations=observations, + actions=actions, + advantages=advantages, + target_values=target_values, + behavior_values=behavior_values, + behavior_log_probs=behavior_log_probs, + ) + batch = jax.tree_util.tree_map( + lambda x: jnp.reshape(x, [batch_size] + list(x.shape[2:])), batch + ) + + if normalize_advantage: + batch_advantage_scale = jnp.mean(jnp.abs(batch.advantages)) + batch_advantage_scale = jax.lax.pmean( + batch_advantage_scale, pmap_axis_name + ) + + # update the running statistics + biased_advantage_scale = ( + normalization_ema_tau * state.biased_advantage_scale + + (1.0 - normalization_ema_tau) * batch_advantage_scale + ) + advantage_scale = biased_advantage_scale * zero_debias + state = state._replace( + biased_advantage_scale=biased_advantage_scale, + advantage_scale=advantage_scale, + ) + + # scale the advantages + scaled_advantages = batch.advantages / jnp.fmax( + state.advantage_scale, 1e-6 + ) + batch = batch._replace(advantages=scaled_advantages) + + # Scan desired number of epoch updates + (state, _), metrics = jax.lax.scan( + epoch_update, (state, batch), (), length=num_epochs + ) + metrics = jax.tree_util.tree_map(jnp.mean, metrics) + + if normalize_advantage: + metrics["advantage_scale"] = state.advantage_scale + + if normalize_value: + metrics["value_mean"] = value_mean + metrics["value_std"] = value_std + + delta_params_sgd_steps = ( + data.extras["params_num_sgd_steps"][:, 0] + - params_num_sgd_steps_before_update + ) + metrics["delta_params_sgd_steps_min"] = jnp.min(delta_params_sgd_steps) + metrics["delta_params_sgd_steps_max"] = jnp.max(delta_params_sgd_steps) + metrics["delta_params_sgd_steps_mean"] = jnp.mean(delta_params_sgd_steps) + metrics["delta_params_sgd_steps_std"] = jnp.std(delta_params_sgd_steps) + + return state, metrics + + pmapped_update_step = jax.pmap( + single_device_update, axis_name=pmap_axis_name, devices=self.learner_devices + ) + + def full_update_step( + state: TrainingState, trajectories: types.NestedArray, + ): + state, metrics = pmapped_update_step(state, trajectories) + return state, metrics + + self._full_update_step = full_update_step + + def make_initial_state(key: networks_lib.PRNGKey) -> TrainingState: + """Initialises the training state (parameters and optimiser state).""" + all_keys = jax.random.split(key, num=self.num_local_learner_devices + 1) + key_init, key_state = all_keys[0], all_keys[1:] + key_state = [key_state[i] for i in range(self.num_local_learner_devices)] + key_state = jax.device_put_sharded(key_state, self.local_learner_devices) + + initial_params = ppo_networks.network.init(key_init) + initial_opt_state = optimizer.init(initial_params) + # Using float32 as it covers a larger range than int32. If using int64 we + # would need to do jax_enable_x64. 
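# --- Illustrative sketch (editorial addition, not part of this diff) ---------
# `make_initial_state` replicates every leaf of the learner state across the
# local devices so that the `jax.pmap`-ed update can keep one shard per device.
# On a single-device host this simply adds a leading axis of size 1.
import jax
import jax.numpy as jnp

devices = jax.local_devices()
replicated = jax.device_put_replicated(jnp.zeros([]), devices)
print(len(devices), replicated.shape)  # e.g. "1 (1,)" on a one-device host
# -----------------------------------------------------------------------------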
+ params_num_sgd_steps = jnp.zeros(shape=(), dtype=jnp.float32) + + initial_params = jax.device_put_replicated( + initial_params, self.local_learner_devices + ) + initial_opt_state = jax.device_put_replicated( + initial_opt_state, self.local_learner_devices + ) + params_num_sgd_steps = jax.device_put_replicated( + params_num_sgd_steps, self.local_learner_devices + ) + + ema_counter = jnp.float32(0) + ema_counter = jax.device_put_replicated( + ema_counter, self.local_learner_devices + ) + + init_state = TrainingState( + params=PPOParams( + model_params=initial_params, num_sgd_steps=params_num_sgd_steps + ), + opt_state=initial_opt_state, + random_key=key_state, + ema_counter=ema_counter, + ) + + if normalize_advantage: + biased_advantage_scale = jax.device_put_replicated( + jnp.zeros([]), self.local_learner_devices + ) + advantage_scale = jax.device_put_replicated( + jnp.zeros([]), self.local_learner_devices + ) + + init_state = init_state._replace( + biased_advantage_scale=biased_advantage_scale, + advantage_scale=advantage_scale, + ) + + if normalize_value: + biased_value_first_moment = jax.device_put_replicated( + jnp.zeros([]), self.local_learner_devices + ) + value_mean = biased_value_first_moment + + biased_value_second_moment = jax.device_put_replicated( + jnp.zeros([]), self.local_learner_devices + ) + value_second_moment = biased_value_second_moment + value_std = jnp.sqrt(jax.nn.relu(value_second_moment - value_mean ** 2)) + + init_state = init_state._replace( + biased_value_first_moment=biased_value_first_moment, + biased_value_second_moment=biased_value_second_moment, + value_mean=value_mean, + value_std=value_std, + ) + + if normalize_obs: + obs_norm_params = ( + obs_normalization_fns.init() + ) # pytype: disable=attribute-error + obs_norm_params = jax.device_put_replicated( + obs_norm_params, self.local_learner_devices + ) + init_state = init_state._replace( + obs_normalization_params=obs_norm_params + ) + + return init_state + + # Initialise training state (parameters and optimizer state). + self._state = make_initial_state(random_key) + self._cached_state = get_from_first_device(self._state, as_numpy=True) + + def step(self): + """Does a learner step and logs the results. One learner step consists of (possibly multiple) epochs of PPO updates on a batch of NxT steps collected by the actors. """ - sample = next(self._iterator) - self._state, results = self._full_update_step(self._state, sample) - self._cached_state = get_from_first_device(self._state, as_numpy=True) - - # Update our counts and record it. - counts = self._counter.increment(steps=self.num_epochs * - self.num_minibatches) - - # Snapshot and attempt to write logs. - if self._num_full_update_steps % self.metrics_logging_period == 0: - results = jax.tree_util.tree_map(jnp.mean, results) - self._logger.write({**results, **counts}) - - self._num_full_update_steps += 1 - - def get_variables(self, names: List[str]) -> List[networks_lib.Params]: - variables = self._cached_state - return [getattr(variables, name) for name in names] - - def save(self) -> TrainingState: - return self._cached_state - - def restore(self, state: TrainingState): - # TODO(kamyar) Should the random_key come from self._state instead? 
- random_key = state.random_key - random_key = jax.random.split( - random_key, num=self.num_local_learner_devices) - random_key = jax.device_put_sharded( - [random_key[i] for i in range(self.num_local_learner_devices)], - self.local_learner_devices) - - state = jax.device_put_replicated(state, self.local_learner_devices) - state = state._replace(random_key=random_key) - self._state = state - self._cached_state = get_from_first_device(self._state, as_numpy=True) + sample = next(self._iterator) + self._state, results = self._full_update_step(self._state, sample) + self._cached_state = get_from_first_device(self._state, as_numpy=True) + + # Update our counts and record it. + counts = self._counter.increment(steps=self.num_epochs * self.num_minibatches) + + # Snapshot and attempt to write logs. + if self._num_full_update_steps % self.metrics_logging_period == 0: + results = jax.tree_util.tree_map(jnp.mean, results) + self._logger.write({**results, **counts}) + + self._num_full_update_steps += 1 + + def get_variables(self, names: List[str]) -> List[networks_lib.Params]: + variables = self._cached_state + return [getattr(variables, name) for name in names] + + def save(self) -> TrainingState: + return self._cached_state + + def restore(self, state: TrainingState): + # TODO(kamyar) Should the random_key come from self._state instead? + random_key = state.random_key + random_key = jax.random.split(random_key, num=self.num_local_learner_devices) + random_key = jax.device_put_sharded( + [random_key[i] for i in range(self.num_local_learner_devices)], + self.local_learner_devices, + ) + + state = jax.device_put_replicated(state, self.local_learner_devices) + state = state._replace(random_key=random_key) + self._state = state + self._cached_state = get_from_first_device(self._state, as_numpy=True) diff --git a/acme/agents/jax/ppo/networks.py b/acme/agents/jax/ppo/networks.py index 99b009a269..41ac1aafbe 100644 --- a/acme/agents/jax/ppo/networks.py +++ b/acme/agents/jax/ppo/networks.py @@ -17,51 +17,53 @@ import dataclasses from typing import Callable, NamedTuple, Optional, Sequence -from acme import specs -from acme.agents.jax import actor_core as actor_core_lib -from acme.jax import networks as networks_lib -from acme.jax import utils import haiku as hk import jax import jax.numpy as jnp import numpy as np import tensorflow_probability +from acme import specs +from acme.agents.jax import actor_core as actor_core_lib +from acme.jax import networks as networks_lib +from acme.jax import utils + tfp = tensorflow_probability.substrates.jax tfd = tfp.distributions -EntropyFn = Callable[ - [networks_lib.Params, networks_lib.PRNGKey], networks_lib.Entropy -] +EntropyFn = Callable[[networks_lib.Params, networks_lib.PRNGKey], networks_lib.Entropy] class MVNDiagParams(NamedTuple): - """Parameters for a diagonal multi-variate normal distribution.""" - loc: jnp.ndarray - scale_diag: jnp.ndarray + """Parameters for a diagonal multi-variate normal distribution.""" + + loc: jnp.ndarray + scale_diag: jnp.ndarray class TanhNormalParams(NamedTuple): - """Parameters for a tanh squashed diagonal MVN distribution.""" - loc: jnp.ndarray - scale: jnp.ndarray + """Parameters for a tanh squashed diagonal MVN distribution.""" + + loc: jnp.ndarray + scale: jnp.ndarray class CategoricalParams(NamedTuple): - """Parameters for a categorical distribution.""" - logits: jnp.ndarray + """Parameters for a categorical distribution.""" + + logits: jnp.ndarray class PPOParams(NamedTuple): - model_params: networks_lib.Params - # Using 
float32 as it covers a larger range than int32. If using int64 we - # would need to do jax_enable_x64. - num_sgd_steps: jnp.float32 + model_params: networks_lib.Params + # Using float32 as it covers a larger range than int32. If using int64 we + # would need to do jax_enable_x64. + num_sgd_steps: jnp.float32 @dataclasses.dataclass class PPONetworks: - """Network and pure functions for the PPO agent. + """Network and pure functions for the PPO agent. If 'network' returns tfd.Distribution, you can use make_ppo_networks() to create this object properly. @@ -73,56 +75,57 @@ class PPONetworks: make_continuous_networks() for an example where the network does not return a tfd.Distribution object. """ - network: networks_lib.FeedForwardNetwork - log_prob: networks_lib.LogProbFn - entropy: EntropyFn - sample: networks_lib.SampleFn - sample_eval: Optional[networks_lib.SampleFn] = None + network: networks_lib.FeedForwardNetwork + log_prob: networks_lib.LogProbFn + entropy: EntropyFn + sample: networks_lib.SampleFn + sample_eval: Optional[networks_lib.SampleFn] = None -def make_inference_fn( - ppo_networks: PPONetworks, - evaluation: bool = False) -> actor_core_lib.FeedForwardPolicyWithExtra: - """Returns a function to be used for inference by a PPO actor.""" - - def inference( - params: networks_lib.Params, - key: networks_lib.PRNGKey, - observations: networks_lib.Observation, - ): - dist_params, _ = ppo_networks.network.apply(params.model_params, - observations) - if evaluation and ppo_networks.sample_eval: - actions = ppo_networks.sample_eval(dist_params, key) - else: - actions = ppo_networks.sample(dist_params, key) - if evaluation: - return actions, {} - log_prob = ppo_networks.log_prob(dist_params, actions) - extras = { - 'log_prob': log_prob, - # Add batch dimension. - 'params_num_sgd_steps': params.num_sgd_steps[None, ...] - } - return actions, extras - return inference +def make_inference_fn( + ppo_networks: PPONetworks, evaluation: bool = False +) -> actor_core_lib.FeedForwardPolicyWithExtra: + """Returns a function to be used for inference by a PPO actor.""" + + def inference( + params: networks_lib.Params, + key: networks_lib.PRNGKey, + observations: networks_lib.Observation, + ): + dist_params, _ = ppo_networks.network.apply(params.model_params, observations) + if evaluation and ppo_networks.sample_eval: + actions = ppo_networks.sample_eval(dist_params, key) + else: + actions = ppo_networks.sample(dist_params, key) + if evaluation: + return actions, {} + log_prob = ppo_networks.log_prob(dist_params, actions) + extras = { + "log_prob": log_prob, + # Add batch dimension. + "params_num_sgd_steps": params.num_sgd_steps[None, ...], + } + return actions, extras + + return inference def make_networks( spec: specs.EnvironmentSpec, hidden_layer_sizes: Sequence[int] = (256, 256) ) -> PPONetworks: - if isinstance(spec.actions, specs.DiscreteArray): - return make_discrete_networks(spec, hidden_layer_sizes) - else: - return make_continuous_networks( - spec, - policy_layer_sizes=hidden_layer_sizes, - value_layer_sizes=hidden_layer_sizes) + if isinstance(spec.actions, specs.DiscreteArray): + return make_discrete_networks(spec, hidden_layer_sizes) + else: + return make_continuous_networks( + spec, + policy_layer_sizes=hidden_layer_sizes, + value_layer_sizes=hidden_layer_sizes, + ) def make_ppo_networks(network: networks_lib.FeedForwardNetwork) -> PPONetworks: - """Constructs a PPONetworks instance from the given FeedForwardNetwork. + """Constructs a PPONetworks instance from the given FeedForwardNetwork. 
This method assumes that the network returns a tfd.Distribution. Sometimes it may be preferable to have networks that do not return tfd.Distribution @@ -137,17 +140,17 @@ def make_ppo_networks(network: networks_lib.FeedForwardNetwork) -> PPONetworks: Returns: A PPONetworks instance with pure functions wrapping the input network. """ - return PPONetworks( - network=network, - log_prob=lambda distribution, action: distribution.log_prob(action), - entropy=lambda distribution, key=None: distribution.entropy(), - sample=lambda distribution, key: distribution.sample(seed=key), - sample_eval=lambda distribution, key: distribution.mode()) + return PPONetworks( + network=network, + log_prob=lambda distribution, action: distribution.log_prob(action), + entropy=lambda distribution, key=None: distribution.entropy(), + sample=lambda distribution, key: distribution.sample(seed=key), + sample_eval=lambda distribution, key: distribution.mode(), + ) -def make_mvn_diag_ppo_networks( - network: networks_lib.FeedForwardNetwork) -> PPONetworks: - """Constructs a PPONetworks for MVN Diag policy from the FeedForwardNetwork. +def make_mvn_diag_ppo_networks(network: networks_lib.FeedForwardNetwork) -> PPONetworks: + """Constructs a PPONetworks for MVN Diag policy from the FeedForwardNetwork. Args: network: a transformed Haiku network (or equivalent in other libraries) that @@ -157,37 +160,43 @@ def make_mvn_diag_ppo_networks( A PPONetworks instance with pure functions wrapping the input network. """ - def log_prob(params: MVNDiagParams, action): - return tfd.MultivariateNormalDiag( - loc=params.loc, scale_diag=params.scale_diag).log_prob(action) - - def entropy( - params: MVNDiagParams, key: networks_lib.PRNGKey - ) -> networks_lib.Entropy: - del key - return tfd.MultivariateNormalDiag( - loc=params.loc, scale_diag=params.scale_diag).entropy() - - def sample(params: MVNDiagParams, key: networks_lib.PRNGKey): - return tfd.MultivariateNormalDiag( - loc=params.loc, scale_diag=params.scale_diag).sample(seed=key) - - def sample_eval(params: MVNDiagParams, key: networks_lib.PRNGKey): - del key - return tfd.MultivariateNormalDiag( - loc=params.loc, scale_diag=params.scale_diag).mode() - - return PPONetworks( - network=network, - log_prob=log_prob, - entropy=entropy, - sample=sample, - sample_eval=sample_eval) + def log_prob(params: MVNDiagParams, action): + return tfd.MultivariateNormalDiag( + loc=params.loc, scale_diag=params.scale_diag + ).log_prob(action) + + def entropy( + params: MVNDiagParams, key: networks_lib.PRNGKey + ) -> networks_lib.Entropy: + del key + return tfd.MultivariateNormalDiag( + loc=params.loc, scale_diag=params.scale_diag + ).entropy() + + def sample(params: MVNDiagParams, key: networks_lib.PRNGKey): + return tfd.MultivariateNormalDiag( + loc=params.loc, scale_diag=params.scale_diag + ).sample(seed=key) + + def sample_eval(params: MVNDiagParams, key: networks_lib.PRNGKey): + del key + return tfd.MultivariateNormalDiag( + loc=params.loc, scale_diag=params.scale_diag + ).mode() + + return PPONetworks( + network=network, + log_prob=log_prob, + entropy=entropy, + sample=sample, + sample_eval=sample_eval, + ) def make_tanh_normal_ppo_networks( - network: networks_lib.FeedForwardNetwork) -> PPONetworks: - """Constructs a PPONetworks for Tanh MVN Diag policy from the FeedForwardNetwork. + network: networks_lib.FeedForwardNetwork, +) -> PPONetworks: + """Constructs a PPONetworks for Tanh MVN Diag policy from the FeedForwardNetwork. 
Args: network: a transformed Haiku network (or equivalent in other libraries) that @@ -197,38 +206,40 @@ def make_tanh_normal_ppo_networks( A PPONetworks instance with pure functions wrapping the input network. """ - def build_distribution(params: TanhNormalParams): - distribution = tfd.Normal(loc=params.loc, scale=params.scale) - distribution = tfd.Independent( - networks_lib.TanhTransformedDistribution(distribution), - reinterpreted_batch_ndims=1) - return distribution - - def log_prob(params: TanhNormalParams, action): - distribution = build_distribution(params) - return distribution.log_prob(action) - - def entropy( - params: TanhNormalParams, key: networks_lib.PRNGKey - ) -> networks_lib.Entropy: - distribution = build_distribution(params) - return distribution.entropy(seed=key) - - def sample(params: TanhNormalParams, key: networks_lib.PRNGKey): - distribution = build_distribution(params) - return distribution.sample(seed=key) - - def sample_eval(params: TanhNormalParams, key: networks_lib.PRNGKey): - del key - distribution = build_distribution(params) - return distribution.mode() - - return PPONetworks( - network=network, - log_prob=log_prob, - entropy=entropy, - sample=sample, - sample_eval=sample_eval) + def build_distribution(params: TanhNormalParams): + distribution = tfd.Normal(loc=params.loc, scale=params.scale) + distribution = tfd.Independent( + networks_lib.TanhTransformedDistribution(distribution), + reinterpreted_batch_ndims=1, + ) + return distribution + + def log_prob(params: TanhNormalParams, action): + distribution = build_distribution(params) + return distribution.log_prob(action) + + def entropy( + params: TanhNormalParams, key: networks_lib.PRNGKey + ) -> networks_lib.Entropy: + distribution = build_distribution(params) + return distribution.entropy(seed=key) + + def sample(params: TanhNormalParams, key: networks_lib.PRNGKey): + distribution = build_distribution(params) + return distribution.sample(seed=key) + + def sample_eval(params: TanhNormalParams, key: networks_lib.PRNGKey): + del key + distribution = build_distribution(params) + return distribution.mode() + + return PPONetworks( + network=network, + log_prob=log_prob, + entropy=entropy, + sample=sample, + sample_eval=sample_eval, + ) def make_discrete_networks( @@ -236,7 +247,7 @@ def make_discrete_networks( hidden_layer_sizes: Sequence[int] = (512,), use_conv: bool = True, ) -> PPONetworks: - """Creates networks used by the agent for discrete action environments. + """Creates networks used by the agent for discrete action environments. Args: environment_spec: Environment spec used to define number of actions. @@ -246,33 +257,35 @@ def make_discrete_networks( PPONetworks """ - num_actions = environment_spec.actions.num_values - - def forward_fn(inputs): - layers = [] - if use_conv: - layers.extend([networks_lib.AtariTorso()]) - layers.extend([hk.nets.MLP(hidden_layer_sizes, activate_final=True)]) - trunk = hk.Sequential(layers) - h = utils.batch_concat(inputs) - h = trunk(h) - logits = hk.Linear(num_actions)(h) - values = hk.Linear(1)(h) - values = jnp.squeeze(values, axis=-1) - return (CategoricalParams(logits=logits), values) - - forward_fn = hk.without_apply_rng(hk.transform(forward_fn)) - dummy_obs = utils.zeros_like(environment_spec.observations) - dummy_obs = utils.add_batch_dim(dummy_obs) # Dummy 'sequence' dim. - network = networks_lib.FeedForwardNetwork( - lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply) - # Create PPONetworks to add functionality required by the agent. 
- return make_categorical_ppo_networks(network) # pylint:disable=undefined-variable + num_actions = environment_spec.actions.num_values + + def forward_fn(inputs): + layers = [] + if use_conv: + layers.extend([networks_lib.AtariTorso()]) + layers.extend([hk.nets.MLP(hidden_layer_sizes, activate_final=True)]) + trunk = hk.Sequential(layers) + h = utils.batch_concat(inputs) + h = trunk(h) + logits = hk.Linear(num_actions)(h) + values = hk.Linear(1)(h) + values = jnp.squeeze(values, axis=-1) + return (CategoricalParams(logits=logits), values) + + forward_fn = hk.without_apply_rng(hk.transform(forward_fn)) + dummy_obs = utils.zeros_like(environment_spec.observations) + dummy_obs = utils.add_batch_dim(dummy_obs) # Dummy 'sequence' dim. + network = networks_lib.FeedForwardNetwork( + lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply + ) + # Create PPONetworks to add functionality required by the agent. + return make_categorical_ppo_networks(network) # pylint:disable=undefined-variable def make_categorical_ppo_networks( - network: networks_lib.FeedForwardNetwork) -> PPONetworks: - """Constructs a PPONetworks for Categorical Policy from FeedForwardNetwork. + network: networks_lib.FeedForwardNetwork, +) -> PPONetworks: + """Constructs a PPONetworks for Categorical Policy from FeedForwardNetwork. Args: network: a transformed Haiku network (or equivalent in other libraries) that @@ -282,28 +295,29 @@ def make_categorical_ppo_networks( A PPONetworks instance with pure functions wrapping the input network. """ - def log_prob(params: CategoricalParams, action): - return tfd.Categorical(logits=params.logits).log_prob(action) + def log_prob(params: CategoricalParams, action): + return tfd.Categorical(logits=params.logits).log_prob(action) - def entropy( - params: CategoricalParams, key: networks_lib.PRNGKey - ) -> networks_lib.Entropy: - del key - return tfd.Categorical(logits=params.logits).entropy() + def entropy( + params: CategoricalParams, key: networks_lib.PRNGKey + ) -> networks_lib.Entropy: + del key + return tfd.Categorical(logits=params.logits).entropy() - def sample(params: CategoricalParams, key: networks_lib.PRNGKey): - return tfd.Categorical(logits=params.logits).sample(seed=key) + def sample(params: CategoricalParams, key: networks_lib.PRNGKey): + return tfd.Categorical(logits=params.logits).sample(seed=key) - def sample_eval(params: CategoricalParams, key: networks_lib.PRNGKey): - del key - return tfd.Categorical(logits=params.logits).mode() + def sample_eval(params: CategoricalParams, key: networks_lib.PRNGKey): + del key + return tfd.Categorical(logits=params.logits).mode() - return PPONetworks( - network=network, - log_prob=log_prob, - entropy=entropy, - sample=sample, - sample_eval=sample_eval) + return PPONetworks( + network=network, + log_prob=log_prob, + entropy=entropy, + sample=sample, + sample_eval=sample_eval, + ) def make_continuous_networks( @@ -312,72 +326,74 @@ def make_continuous_networks( value_layer_sizes: Sequence[int] = (64, 64), use_tanh_gaussian_policy: bool = True, ) -> PPONetworks: - """Creates PPONetworks to be used for continuous action environments.""" - - # Get total number of action dimensions from action spec. 
- num_dimensions = np.prod(environment_spec.actions.shape, dtype=int) - - def forward_fn(inputs: networks_lib.Observation): - - def _policy_network(obs: networks_lib.Observation): - h = utils.batch_concat(obs) - h = hk.nets.MLP(policy_layer_sizes, activate_final=True)(h) - - # tfd distributions have a weird bug in jax when vmapping is used, so the - # safer implementation in general is for the policy network to output the - # distribution parameters, and for the distribution to be constructed - # in a method such as make_ppo_networks above - if not use_tanh_gaussian_policy: - # Following networks_lib.MultivariateNormalDiagHead - init_scale = 0.3 - min_scale = 1e-6 - w_init = hk.initializers.VarianceScaling(1e-4) - b_init = hk.initializers.Constant(0.) - loc_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init) - scale_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init) - - loc = loc_layer(h) - scale = jax.nn.softplus(scale_layer(h)) - scale *= init_scale / jax.nn.softplus(0.) - scale += min_scale - - return MVNDiagParams(loc=loc, scale_diag=scale) - - # Following networks_lib.NormalTanhDistribution - min_scale = 1e-3 - w_init = hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform') - b_init = hk.initializers.Constant(0.) - loc_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init) - scale_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init) - - loc = loc_layer(h) - scale = scale_layer(h) - scale = jax.nn.softplus(scale) + min_scale - - return TanhNormalParams(loc=loc, scale=scale) - - value_network = hk.Sequential([ - utils.batch_concat, - hk.nets.MLP(value_layer_sizes, activate_final=True), - hk.Linear(1), - lambda x: jnp.squeeze(x, axis=-1) - ]) - - policy_output = _policy_network(inputs) - value = value_network(inputs) - return (policy_output, value) - - # Transform into pure functions. - forward_fn = hk.without_apply_rng(hk.transform(forward_fn)) - - dummy_obs = utils.zeros_like(environment_spec.observations) - dummy_obs = utils.add_batch_dim(dummy_obs) # Dummy 'sequence' dim. - network = networks_lib.FeedForwardNetwork( - lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply) - - # Create PPONetworks to add functionality required by the agent. - - if not use_tanh_gaussian_policy: - return make_mvn_diag_ppo_networks(network) - - return make_tanh_normal_ppo_networks(network) + """Creates PPONetworks to be used for continuous action environments.""" + + # Get total number of action dimensions from action spec. 
+ num_dimensions = np.prod(environment_spec.actions.shape, dtype=int) + + def forward_fn(inputs: networks_lib.Observation): + def _policy_network(obs: networks_lib.Observation): + h = utils.batch_concat(obs) + h = hk.nets.MLP(policy_layer_sizes, activate_final=True)(h) + + # tfd distributions have a weird bug in jax when vmapping is used, so the + # safer implementation in general is for the policy network to output the + # distribution parameters, and for the distribution to be constructed + # in a method such as make_ppo_networks above + if not use_tanh_gaussian_policy: + # Following networks_lib.MultivariateNormalDiagHead + init_scale = 0.3 + min_scale = 1e-6 + w_init = hk.initializers.VarianceScaling(1e-4) + b_init = hk.initializers.Constant(0.0) + loc_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init) + scale_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init) + + loc = loc_layer(h) + scale = jax.nn.softplus(scale_layer(h)) + scale *= init_scale / jax.nn.softplus(0.0) + scale += min_scale + + return MVNDiagParams(loc=loc, scale_diag=scale) + + # Following networks_lib.NormalTanhDistribution + min_scale = 1e-3 + w_init = hk.initializers.VarianceScaling(1.0, "fan_in", "uniform") + b_init = hk.initializers.Constant(0.0) + loc_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init) + scale_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init) + + loc = loc_layer(h) + scale = scale_layer(h) + scale = jax.nn.softplus(scale) + min_scale + + return TanhNormalParams(loc=loc, scale=scale) + + value_network = hk.Sequential( + [ + utils.batch_concat, + hk.nets.MLP(value_layer_sizes, activate_final=True), + hk.Linear(1), + lambda x: jnp.squeeze(x, axis=-1), + ] + ) + + policy_output = _policy_network(inputs) + value = value_network(inputs) + return (policy_output, value) + + # Transform into pure functions. + forward_fn = hk.without_apply_rng(hk.transform(forward_fn)) + + dummy_obs = utils.zeros_like(environment_spec.observations) + dummy_obs = utils.add_batch_dim(dummy_obs) # Dummy 'sequence' dim. + network = networks_lib.FeedForwardNetwork( + lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply + ) + + # Create PPONetworks to add functionality required by the agent. + + if not use_tanh_gaussian_policy: + return make_mvn_diag_ppo_networks(network) + + return make_tanh_normal_ppo_networks(network) diff --git a/acme/agents/jax/ppo/normalization.py b/acme/agents/jax/ppo/normalization.py index 1413f047fd..4d82672188 100644 --- a/acme/agents/jax/ppo/normalization.py +++ b/acme/agents/jax/ppo/normalization.py @@ -16,55 +16,56 @@ import dataclasses from typing import Any, Callable, Generic, NamedTuple, Optional -from acme import adders -from acme import types -from acme.agents.jax import actor_core -from acme.agents.jax import actors -from acme.jax import networks as network_lib -from acme.jax import running_statistics -from acme.jax import utils -from acme.jax import variable_utils import jax import jax.numpy as jnp +from acme import adders, types +from acme.agents.jax import actor_core, actors +from acme.jax import networks as network_lib +from acme.jax import running_statistics, utils, variable_utils + NormalizationParams = Any RunningStatisticsState = running_statistics.RunningStatisticsState @dataclasses.dataclass class NormalizationFns: - """Holds pure functions for normalization. + """Holds pure functions for normalization. 
Attributes: init: A pure function: ``params = init()`` normalize: A pure function: ``norm_x = normalize(x, params)`` update: A pure function: ``params = update(params, x, pmap_axis_name)`` """ - # Returns the initial parameters for the normalization utility. - init: Callable[[], NormalizationParams] - # Returns the normalized input nested array. - normalize: Callable[[types.NestedArray, NormalizationParams], - types.NestedArray] - # Returns updates normalization parameters. - update: Callable[[NormalizationParams, types.NestedArray, Optional[str]], - NormalizationParams] - - -class NormalizedGenericActor(actors.GenericActor[actor_core.State, - actor_core.Extras], - Generic[actor_core.State, actor_core.Extras]): - """A GenericActor that uses observation normalization.""" - - def __init__(self, - actor: actor_core.ActorCore[actor_core.State, actor_core.Extras], - normalization_fns: NormalizationFns, - random_key: network_lib.PRNGKey, - variable_client: Optional[variable_utils.VariableClient], - adder: Optional[adders.Adder] = None, - jit: bool = True, - backend: Optional[str] = 'cpu', - per_episode_update: bool = False): - """Initializes a feed forward actor. + + # Returns the initial parameters for the normalization utility. + init: Callable[[], NormalizationParams] + # Returns the normalized input nested array. + normalize: Callable[[types.NestedArray, NormalizationParams], types.NestedArray] + # Returns updates normalization parameters. + update: Callable[ + [NormalizationParams, types.NestedArray, Optional[str]], NormalizationParams + ] + + +class NormalizedGenericActor( + actors.GenericActor[actor_core.State, actor_core.Extras], + Generic[actor_core.State, actor_core.Extras], +): + """A GenericActor that uses observation normalization.""" + + def __init__( + self, + actor: actor_core.ActorCore[actor_core.State, actor_core.Extras], + normalization_fns: NormalizationFns, + random_key: network_lib.PRNGKey, + variable_client: Optional[variable_utils.VariableClient], + adder: Optional[adders.Adder] = None, + jit: bool = True, + backend: Optional[str] = "cpu", + per_episode_update: bool = False, + ): + """Initializes a feed forward actor. Args: actor: actor core. 
@@ -79,34 +80,35 @@ def __init__(self, per_episode_update: if True, updates variable client params once at the beginning of each episode """ - super().__init__(actor, random_key, variable_client, adder, jit, backend, - per_episode_update) - if jit: - self._apply_normalization = jax.jit( - normalization_fns.normalize, backend=backend) - else: - self._apply_normalization = normalization_fns.normalize - - def select_action(self, - observation: network_lib.Observation) -> types.NestedArray: - policy_params, obs_norm_params = tuple(self._params) - observation = self._apply_normalization(observation, obs_norm_params) - action, self._state = self._policy(policy_params, observation, self._state) - return utils.to_numpy(action) + super().__init__( + actor, random_key, variable_client, adder, jit, backend, per_episode_update + ) + if jit: + self._apply_normalization = jax.jit( + normalization_fns.normalize, backend=backend + ) + else: + self._apply_normalization = normalization_fns.normalize + + def select_action(self, observation: network_lib.Observation) -> types.NestedArray: + policy_params, obs_norm_params = tuple(self._params) + observation = self._apply_normalization(observation, obs_norm_params) + action, self._state = self._policy(policy_params, observation, self._state) + return utils.to_numpy(action) class EMAMeanStdNormalizerParams(NamedTuple): - """Using technique form Adam optimizer paper for computing running stats.""" - ema_counter: jnp.int32 - biased_first_moment: types.NestedArray - biased_second_moment: types.NestedArray + """Using technique form Adam optimizer paper for computing running stats.""" + + ema_counter: jnp.int32 + biased_first_moment: types.NestedArray + biased_second_moment: types.NestedArray def build_ema_mean_std_normalizer( - nested_spec: types.NestedSpec, - tau: float = 0.995, - epsilon: float = 1e-6,) -> NormalizationFns: - """Builds pure functions used for normalizing based on EMA mean and std. + nested_spec: types.NestedSpec, tau: float = 0.995, epsilon: float = 1e-6, +) -> NormalizationFns: + """Builds pure functions used for normalizing based on EMA mean and std. The built normalizer functions can be used to normalize nested arrays that have a structure corresponding to nested_spec. Currently only supports @@ -120,107 +122,111 @@ def build_ema_mean_std_normalizer( Returns: NormalizationFns to be used for normalization """ - nested_dims = jax.tree_util.tree_map(lambda x: len(x.shape), nested_spec) - - def init() -> EMAMeanStdNormalizerParams: - first_moment = utils.zeros_like(nested_spec) - second_moment = utils.zeros_like(nested_spec) - - return EMAMeanStdNormalizerParams( - ema_counter=jnp.int32(0), - biased_first_moment=first_moment, - biased_second_moment=second_moment, - ) - - def _normalize_leaf(x: jnp.array, ema_counter: jnp.int32, - biased_first_moment: jnp.array, - biased_second_moment: jnp.array) -> jnp.ndarray: - zero_debias = 1. / (1. 
- jnp.power(tau, ema_counter)) - mean = biased_first_moment * zero_debias - second_moment = biased_second_moment * zero_debias - std = jnp.sqrt(jax.nn.relu(second_moment - mean**2)) - - mean = jnp.broadcast_to(mean, x.shape) - std = jnp.broadcast_to(std, x.shape) - return (x - mean) / jnp.fmax(std, epsilon) - - def _normalize(nested_array: types.NestedArray, - params: EMAMeanStdNormalizerParams) -> types.NestedArray: - ema_counter = params.ema_counter - normalized_nested_array = jax.tree_util.tree_map( - lambda x, f, s: _normalize_leaf(x, ema_counter, f, s), - nested_array, - params.biased_first_moment, - params.biased_second_moment) - return normalized_nested_array - - def normalize(nested_array: types.NestedArray, - params: EMAMeanStdNormalizerParams) -> types.NestedArray: - ema_counter = params.ema_counter - norm_obs = jax.lax.cond( - ema_counter > 0, - _normalize, - lambda o, p: o, - nested_array, params) - return norm_obs - - def _compute_first_moment(x: jnp.ndarray, ndim: int): - reduce_axes = tuple(range(len(x.shape) - ndim)) - first_moment = jnp.mean(x, axis=reduce_axes) - return first_moment - - def _compute_second_moment(x: jnp.ndarray, ndim: int): - reduce_axes = tuple(range(len(x.shape) - ndim)) - second_moment = jnp.mean(x**2, axis=reduce_axes) - return second_moment - - def update( - params: EMAMeanStdNormalizerParams, - nested_array: types.NestedArray, - pmap_axis_name: Optional[str] = None) -> EMAMeanStdNormalizerParams: - # compute the stats - first_moment = jax.tree_util.tree_map( - _compute_first_moment, nested_array, nested_dims) - second_moment = jax.tree_util.tree_map( - _compute_second_moment, nested_array, nested_dims) - - # propagate across devices - if pmap_axis_name is not None: - first_moment, second_moment = jax.lax.pmean( - (first_moment, second_moment), axis_name=pmap_axis_name) - - # update running statistics - new_first_moment = jax.tree_util.tree_map( - lambda x, y: tau * x + # pylint: disable=g-long-lambda - (1. - tau) * y, - params.biased_first_moment, - first_moment) - new_second_moment = jax.tree_util.tree_map( - lambda x, y: tau * x + # pylint: disable=g-long-lambda - (1. 
- tau) * y, - params.biased_second_moment, - second_moment) - - # update ema_counter and return updated params - new_params = EMAMeanStdNormalizerParams( - ema_counter=params.ema_counter + 1, - biased_first_moment=new_first_moment, - biased_second_moment=new_second_moment, - ) - - return new_params - - return NormalizationFns( - init=init, - normalize=normalize, - update=update, - ) + nested_dims = jax.tree_util.tree_map(lambda x: len(x.shape), nested_spec) + + def init() -> EMAMeanStdNormalizerParams: + first_moment = utils.zeros_like(nested_spec) + second_moment = utils.zeros_like(nested_spec) + + return EMAMeanStdNormalizerParams( + ema_counter=jnp.int32(0), + biased_first_moment=first_moment, + biased_second_moment=second_moment, + ) + + def _normalize_leaf( + x: jnp.array, + ema_counter: jnp.int32, + biased_first_moment: jnp.array, + biased_second_moment: jnp.array, + ) -> jnp.ndarray: + zero_debias = 1.0 / (1.0 - jnp.power(tau, ema_counter)) + mean = biased_first_moment * zero_debias + second_moment = biased_second_moment * zero_debias + std = jnp.sqrt(jax.nn.relu(second_moment - mean ** 2)) + + mean = jnp.broadcast_to(mean, x.shape) + std = jnp.broadcast_to(std, x.shape) + return (x - mean) / jnp.fmax(std, epsilon) + + def _normalize( + nested_array: types.NestedArray, params: EMAMeanStdNormalizerParams + ) -> types.NestedArray: + ema_counter = params.ema_counter + normalized_nested_array = jax.tree_util.tree_map( + lambda x, f, s: _normalize_leaf(x, ema_counter, f, s), + nested_array, + params.biased_first_moment, + params.biased_second_moment, + ) + return normalized_nested_array + + def normalize( + nested_array: types.NestedArray, params: EMAMeanStdNormalizerParams + ) -> types.NestedArray: + ema_counter = params.ema_counter + norm_obs = jax.lax.cond( + ema_counter > 0, _normalize, lambda o, p: o, nested_array, params + ) + return norm_obs + + def _compute_first_moment(x: jnp.ndarray, ndim: int): + reduce_axes = tuple(range(len(x.shape) - ndim)) + first_moment = jnp.mean(x, axis=reduce_axes) + return first_moment + + def _compute_second_moment(x: jnp.ndarray, ndim: int): + reduce_axes = tuple(range(len(x.shape) - ndim)) + second_moment = jnp.mean(x ** 2, axis=reduce_axes) + return second_moment + + def update( + params: EMAMeanStdNormalizerParams, + nested_array: types.NestedArray, + pmap_axis_name: Optional[str] = None, + ) -> EMAMeanStdNormalizerParams: + # compute the stats + first_moment = jax.tree_util.tree_map( + _compute_first_moment, nested_array, nested_dims + ) + second_moment = jax.tree_util.tree_map( + _compute_second_moment, nested_array, nested_dims + ) + + # propagate across devices + if pmap_axis_name is not None: + first_moment, second_moment = jax.lax.pmean( + (first_moment, second_moment), axis_name=pmap_axis_name + ) + + # update running statistics + new_first_moment = jax.tree_util.tree_map( + lambda x, y: tau * x + (1.0 - tau) * y, # pylint: disable=g-long-lambda + params.biased_first_moment, + first_moment, + ) + new_second_moment = jax.tree_util.tree_map( + lambda x, y: tau * x + (1.0 - tau) * y, # pylint: disable=g-long-lambda + params.biased_second_moment, + second_moment, + ) + + # update ema_counter and return updated params + new_params = EMAMeanStdNormalizerParams( + ema_counter=params.ema_counter + 1, + biased_first_moment=new_first_moment, + biased_second_moment=new_second_moment, + ) + + return new_params + + return NormalizationFns(init=init, normalize=normalize, update=update,) def build_mean_std_normalizer( - nested_spec: types.NestedSpec, - 
max_abs_value: Optional[float] = None) -> NormalizationFns: - """Builds pure functions used for normalizing based on mean and std. + nested_spec: types.NestedSpec, max_abs_value: Optional[float] = None +) -> NormalizationFns: + """Builds pure functions used for normalizing based on mean and std. Arguments: nested_spec: A nested spec where all leaves have float dtype @@ -232,23 +238,23 @@ def build_mean_std_normalizer( NormalizationFns to be used for normalization """ - def init() -> RunningStatisticsState: - return running_statistics.init_state(nested_spec) - - def normalize( - nested_array: types.NestedArray, - params: RunningStatisticsState) -> types.NestedArray: - return running_statistics.normalize( - nested_array, params, max_abs_value=max_abs_value) - - def update( - params: RunningStatisticsState, - nested_array: types.NestedArray, - pmap_axis_name: Optional[str]) -> RunningStatisticsState: - return running_statistics.update( - params, nested_array, pmap_axis_name=pmap_axis_name) - - return NormalizationFns( - init=init, - normalize=normalize, - update=update) + def init() -> RunningStatisticsState: + return running_statistics.init_state(nested_spec) + + def normalize( + nested_array: types.NestedArray, params: RunningStatisticsState + ) -> types.NestedArray: + return running_statistics.normalize( + nested_array, params, max_abs_value=max_abs_value + ) + + def update( + params: RunningStatisticsState, + nested_array: types.NestedArray, + pmap_axis_name: Optional[str], + ) -> RunningStatisticsState: + return running_statistics.update( + params, nested_array, pmap_axis_name=pmap_axis_name + ) + + return NormalizationFns(init=init, normalize=normalize, update=update) diff --git a/acme/agents/jax/pwil/__init__.py b/acme/agents/jax/pwil/__init__.py index 95b492f3f5..d884121686 100644 --- a/acme/agents/jax/pwil/__init__.py +++ b/acme/agents/jax/pwil/__init__.py @@ -15,5 +15,4 @@ """PWIL agent.""" from acme.agents.jax.pwil.builder import PWILBuilder -from acme.agents.jax.pwil.config import PWILConfig -from acme.agents.jax.pwil.config import PWILDemonstrations +from acme.agents.jax.pwil.config import PWILConfig, PWILDemonstrations diff --git a/acme/agents/jax/pwil/adder.py b/acme/agents/jax/pwil/adder.py index 3aba5ec100..a7867b66c6 100644 --- a/acme/agents/jax/pwil/adder.py +++ b/acme/agents/jax/pwil/adder.py @@ -14,36 +14,43 @@ """Reward-substituting adder wrapper.""" -from acme import adders -from acme import types -from acme.agents.jax.pwil import rewarder import dm_env +from acme import adders, types +from acme.agents.jax.pwil import rewarder + class PWILAdder(adders.Adder): - """Adder wrapper substituting PWIL rewards.""" - - def __init__(self, direct_rl_adder: adders.Adder, - pwil_rewarder: rewarder.WassersteinDistanceRewarder): - self._adder = direct_rl_adder - self._rewarder = pwil_rewarder - self._latest_observation = None - - def add_first(self, timestep: dm_env.TimeStep): - self._rewarder.reset() - self._latest_observation = timestep.observation - self._adder.add_first(timestep) - - def add(self, - action: types.NestedArray, - next_timestep: dm_env.TimeStep, - extras: types.NestedArray = ()): - updated_timestep = next_timestep._replace( - reward=self._rewarder.append_and_compute_reward( - observation=self._latest_observation, action=action)) - self._latest_observation = next_timestep.observation - self._adder.add(action, updated_timestep, extras) - - def reset(self): - self._latest_observation = None - self._adder.reset() + """Adder wrapper substituting PWIL rewards.""" + + def 
__init__( + self, + direct_rl_adder: adders.Adder, + pwil_rewarder: rewarder.WassersteinDistanceRewarder, + ): + self._adder = direct_rl_adder + self._rewarder = pwil_rewarder + self._latest_observation = None + + def add_first(self, timestep: dm_env.TimeStep): + self._rewarder.reset() + self._latest_observation = timestep.observation + self._adder.add_first(timestep) + + def add( + self, + action: types.NestedArray, + next_timestep: dm_env.TimeStep, + extras: types.NestedArray = (), + ): + updated_timestep = next_timestep._replace( + reward=self._rewarder.append_and_compute_reward( + observation=self._latest_observation, action=action + ) + ) + self._latest_observation = next_timestep.observation + self._adder.add(action, updated_timestep, extras) + + def reset(self): + self._latest_observation = None + self._adder.reset() diff --git a/acme/agents/jax/pwil/builder.py b/acme/agents/jax/pwil/builder.py index edf49e2642..a0eb27c9b0 100644 --- a/acme/agents/jax/pwil/builder.py +++ b/acme/agents/jax/pwil/builder.py @@ -17,29 +17,31 @@ import threading from typing import Callable, Generic, Iterator, List, Optional, Sequence -from acme import adders -from acme import core -from acme import specs -from acme import types +import dm_env +import numpy as np +import reverb + +from acme import adders, core, specs, types from acme.agents.jax import builders from acme.agents.jax.pwil import adder as pwil_adder from acme.agents.jax.pwil import config as pwil_config from acme.agents.jax.pwil import rewarder from acme.jax import networks as networks_lib -from acme.jax.imitation_learning_types import DirectPolicyNetwork, DirectRLNetworks # pylint: disable=g-multiple-import +from acme.jax.imitation_learning_types import ( # pylint: disable=g-multiple-import + DirectPolicyNetwork, + DirectRLNetworks, +) from acme.jax.types import PRNGKey -from acme.utils import counting -from acme.utils import loggers -import dm_env -import numpy as np -import reverb +from acme.utils import counting, loggers -def _prefill_with_demonstrations(adder: adders.Adder, - demonstrations: Sequence[types.Transition], - reward: Optional[float], - min_num_transitions: int = 0) -> None: - """Fill the adder's replay buffer with expert transitions. +def _prefill_with_demonstrations( + adder: adders.Adder, + demonstrations: Sequence[types.Transition], + reward: Optional[float], + min_num_transitions: int = 0, +) -> None: + """Fill the adder's replay buffer with expert transitions. Assumes that the demonstrations dataset stores transitions in order. @@ -52,58 +54,63 @@ def _prefill_with_demonstrations(adder: adders.Adder, min_num_transitions are added, the processing is interrupted at the nearest episode end. """ - if not demonstrations: - return - - reward = np.float32(reward) if reward is not None else reward - remaining_transitions = min_num_transitions - step_type = None - action = None - ts = dm_env.TimeStep(None, None, None, None) # Unused. - while remaining_transitions > 0: - # In case we share the adder or demonstrations don't end with - # end-of-episode, reset the adder prior to add_first. + if not demonstrations: + return + + reward = np.float32(reward) if reward is not None else reward + remaining_transitions = min_num_transitions + step_type = None + action = None + ts = dm_env.TimeStep(None, None, None, None) # Unused. + while remaining_transitions > 0: + # In case we share the adder or demonstrations don't end with + # end-of-episode, reset the adder prior to add_first. 
+ adder.reset() + for transition_num, transition in enumerate(demonstrations): + remaining_transitions -= 1 + discount = np.float32(1.0) + ts_reward = reward if reward is not None else transition.reward + if step_type == dm_env.StepType.LAST or transition_num == 0: + ts = dm_env.TimeStep( + dm_env.StepType.FIRST, ts_reward, discount, transition.observation + ) + adder.add_first(ts) + + observation = transition.next_observation + action = transition.action + if transition.discount == 0.0 or transition_num == len(demonstrations) - 1: + step_type = dm_env.StepType.LAST + discount = np.float32(0.0) + else: + step_type = dm_env.StepType.MID + ts = dm_env.TimeStep(step_type, ts_reward, discount, observation) + adder.add(action, ts) + if remaining_transitions <= 0: + # Note: we could check `step_type == dm_env.StepType.LAST` to stop at an + # episode end if possible. + break + + # Explicitly finalize the Reverb client writes. adder.reset() - for transition_num, transition in enumerate(demonstrations): - remaining_transitions -= 1 - discount = np.float32(1.0) - ts_reward = reward if reward is not None else transition.reward - if step_type == dm_env.StepType.LAST or transition_num == 0: - ts = dm_env.TimeStep(dm_env.StepType.FIRST, ts_reward, discount, - transition.observation) - adder.add_first(ts) - - observation = transition.next_observation - action = transition.action - if transition.discount == 0. or transition_num == len(demonstrations) - 1: - step_type = dm_env.StepType.LAST - discount = np.float32(0.0) - else: - step_type = dm_env.StepType.MID - ts = dm_env.TimeStep(step_type, ts_reward, discount, observation) - adder.add(action, ts) - if remaining_transitions <= 0: - # Note: we could check `step_type == dm_env.StepType.LAST` to stop at an - # episode end if possible. - break - - # Explicitly finalize the Reverb client writes. - adder.reset() - - -class PWILBuilder(builders.ActorLearnerBuilder[DirectRLNetworks, - DirectPolicyNetwork, - reverb.ReplaySample], - Generic[DirectRLNetworks, DirectPolicyNetwork]): - """PWIL Agent builder.""" - - def __init__(self, - rl_agent: builders.ActorLearnerBuilder[DirectRLNetworks, - DirectPolicyNetwork, - reverb.ReplaySample], - config: pwil_config.PWILConfig, - demonstrations_fn: Callable[[], pwil_config.PWILDemonstrations]): - """Initialize the agent. + + +class PWILBuilder( + builders.ActorLearnerBuilder[ + DirectRLNetworks, DirectPolicyNetwork, reverb.ReplaySample + ], + Generic[DirectRLNetworks, DirectPolicyNetwork], +): + """PWIL Agent builder.""" + + def __init__( + self, + rl_agent: builders.ActorLearnerBuilder[ + DirectRLNetworks, DirectPolicyNetwork, reverb.ReplaySample + ], + config: pwil_config.PWILConfig, + demonstrations_fn: Callable[[], pwil_config.PWILDemonstrations], + ): + """Initialize the agent. Args: rl_agent: the standard RL algorithm. @@ -111,93 +118,100 @@ def __init__(self, demonstrations_fn: A function that returns an iterator over contiguous demonstration transitions, and the average demonstration episode length. 
""" - self._rl_agent = rl_agent - self._config = config - self._demonstrations_fn = demonstrations_fn - super().__init__() - - def make_learner( - self, - random_key: networks_lib.PRNGKey, - networks: DirectRLNetworks, - dataset: Iterator[reverb.ReplaySample], - logger_fn: loggers.LoggerFactory, - environment_spec: specs.EnvironmentSpec, - replay_client: Optional[reverb.Client] = None, - counter: Optional[counting.Counter] = None, - ) -> core.Learner: - return self._rl_agent.make_learner( - random_key=random_key, - networks=networks, - dataset=dataset, - logger_fn=logger_fn, - environment_spec=environment_spec, - replay_client=replay_client, - counter=counter) - - def make_replay_tables( - self, - environment_spec: specs.EnvironmentSpec, - policy: DirectPolicyNetwork, - ) -> List[reverb.Table]: - return self._rl_agent.make_replay_tables(environment_spec, policy) - - def make_dataset_iterator( # pytype: disable=signature-mismatch # overriding-return-type-checks - self, - replay_client: reverb.Client) -> Optional[Iterator[reverb.ReplaySample]]: - # make_dataset_iterator is only called once (per learner), to pass the - # iterator to make_learner. By using adders we ensure the transition types - # (e.g. n-step transitions) that the direct RL agent expects. - if self._config.num_transitions_rb > 0: - - def prefill_thread(): - # Populating the replay buffer with the direct RL agent guarantees that - # a constant reward will be used, not the imitation reward. - prefill_reward = ( - self._config.alpha - if self._config.prefill_constant_reward else None) - _prefill_with_demonstrations( - adder=self._rl_agent.make_adder(replay_client, None, None), - demonstrations=list(self._demonstrations_fn().demonstrations), - min_num_transitions=self._config.num_transitions_rb, - reward=prefill_reward) - # Populate the replay buffer in a separate thread, so that the learner - # can sample from the buffer, to avoid blocking on the buffer being full. 
- threading.Thread(target=prefill_thread, daemon=True).start() - - return self._rl_agent.make_dataset_iterator(replay_client) - - def make_adder( - self, - replay_client: reverb.Client, - environment_spec: Optional[specs.EnvironmentSpec], - policy: Optional[DirectPolicyNetwork], - ) -> Optional[adders.Adder]: - """Creates the adder substituting imitation reward.""" - pwil_demonstrations = self._demonstrations_fn() - return pwil_adder.PWILAdder( - direct_rl_adder=self._rl_agent.make_adder(replay_client, - environment_spec, policy), - pwil_rewarder=rewarder.WassersteinDistanceRewarder( - demonstrations_it=pwil_demonstrations.demonstrations, - episode_length=pwil_demonstrations.episode_length, - use_actions_for_distance=self._config.use_actions_for_distance, - alpha=self._config.alpha, - beta=self._config.beta)) - - def make_actor( - self, - random_key: PRNGKey, - policy: DirectPolicyNetwork, - environment_spec: specs.EnvironmentSpec, - variable_source: Optional[core.VariableSource] = None, - adder: Optional[adders.Adder] = None, - ) -> core.Actor: - return self._rl_agent.make_actor(random_key, policy, environment_spec, - variable_source, adder) - - def make_policy(self, - networks: DirectRLNetworks, - environment_spec: specs.EnvironmentSpec, - evaluation: bool = False) -> DirectPolicyNetwork: - return self._rl_agent.make_policy(networks, environment_spec, evaluation) + self._rl_agent = rl_agent + self._config = config + self._demonstrations_fn = demonstrations_fn + super().__init__() + + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: DirectRLNetworks, + dataset: Iterator[reverb.ReplaySample], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + return self._rl_agent.make_learner( + random_key=random_key, + networks=networks, + dataset=dataset, + logger_fn=logger_fn, + environment_spec=environment_spec, + replay_client=replay_client, + counter=counter, + ) + + def make_replay_tables( + self, environment_spec: specs.EnvironmentSpec, policy: DirectPolicyNetwork, + ) -> List[reverb.Table]: + return self._rl_agent.make_replay_tables(environment_spec, policy) + + def make_dataset_iterator( # pytype: disable=signature-mismatch # overriding-return-type-checks + self, replay_client: reverb.Client + ) -> Optional[Iterator[reverb.ReplaySample]]: + # make_dataset_iterator is only called once (per learner), to pass the + # iterator to make_learner. By using adders we ensure the transition types + # (e.g. n-step transitions) that the direct RL agent expects. + if self._config.num_transitions_rb > 0: + + def prefill_thread(): + # Populating the replay buffer with the direct RL agent guarantees that + # a constant reward will be used, not the imitation reward. + prefill_reward = ( + self._config.alpha if self._config.prefill_constant_reward else None + ) + _prefill_with_demonstrations( + adder=self._rl_agent.make_adder(replay_client, None, None), + demonstrations=list(self._demonstrations_fn().demonstrations), + min_num_transitions=self._config.num_transitions_rb, + reward=prefill_reward, + ) + + # Populate the replay buffer in a separate thread, so that the learner + # can sample from the buffer, to avoid blocking on the buffer being full. 
+ threading.Thread(target=prefill_thread, daemon=True).start() + + return self._rl_agent.make_dataset_iterator(replay_client) + + def make_adder( + self, + replay_client: reverb.Client, + environment_spec: Optional[specs.EnvironmentSpec], + policy: Optional[DirectPolicyNetwork], + ) -> Optional[adders.Adder]: + """Creates the adder substituting imitation reward.""" + pwil_demonstrations = self._demonstrations_fn() + return pwil_adder.PWILAdder( + direct_rl_adder=self._rl_agent.make_adder( + replay_client, environment_spec, policy + ), + pwil_rewarder=rewarder.WassersteinDistanceRewarder( + demonstrations_it=pwil_demonstrations.demonstrations, + episode_length=pwil_demonstrations.episode_length, + use_actions_for_distance=self._config.use_actions_for_distance, + alpha=self._config.alpha, + beta=self._config.beta, + ), + ) + + def make_actor( + self, + random_key: PRNGKey, + policy: DirectPolicyNetwork, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + adder: Optional[adders.Adder] = None, + ) -> core.Actor: + return self._rl_agent.make_actor( + random_key, policy, environment_spec, variable_source, adder + ) + + def make_policy( + self, + networks: DirectRLNetworks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool = False, + ) -> DirectPolicyNetwork: + return self._rl_agent.make_policy(networks, environment_spec, evaluation) diff --git a/acme/agents/jax/pwil/config.py b/acme/agents/jax/pwil/config.py index 83cdd12003..a4fa5032e6 100644 --- a/acme/agents/jax/pwil/config.py +++ b/acme/agents/jax/pwil/config.py @@ -21,35 +21,36 @@ @dataclasses.dataclass class PWILConfig: - """Configuration options for PWIL. + """Configuration options for PWIL. The default values correspond to the experiment setup from the PWIL publication http://arxiv.org/abs/2006.04678. """ - # Number of transitions to fill the replay buffer with for pretraining. - num_transitions_rb: int = 50000 + # Number of transitions to fill the replay buffer with for pretraining. + num_transitions_rb: int = 50000 - # If False, uses only observations for computing the distance; if True, also - # uses the actions. - use_actions_for_distance: bool = True + # If False, uses only observations for computing the distance; if True, also + # uses the actions. + use_actions_for_distance: bool = True - # Scaling for the reward function, see equation (6) in - # http://arxiv.org/abs/2006.04678. - alpha: float = 5. + # Scaling for the reward function, see equation (6) in + # http://arxiv.org/abs/2006.04678. + alpha: float = 5.0 - # Controls the kernel size of the reward function, see equation (6) - # in http://arxiv.org/abs/2006.04678. - beta: float = 5. + # Controls the kernel size of the reward function, see equation (6) + # in http://arxiv.org/abs/2006.04678. + beta: float = 5.0 - # When False, uses the reward signal from the dataset during prefilling. - prefill_constant_reward: bool = True + # When False, uses the reward signal from the dataset during prefilling. 
+ prefill_constant_reward: bool = True - num_sgd_steps_per_step: int = 1 + num_sgd_steps_per_step: int = 1 @dataclasses.dataclass class PWILDemonstrations: - """Unbatched, unshuffled transitions with approximate episode length.""" - demonstrations: Iterator[types.Transition] - episode_length: int + """Unbatched, unshuffled transitions with approximate episode length.""" + + demonstrations: Iterator[types.Transition] + episode_length: int diff --git a/acme/agents/jax/pwil/rewarder.py b/acme/agents/jax/pwil/rewarder.py index b94a41bdfc..ec92cee967 100644 --- a/acme/agents/jax/pwil/rewarder.py +++ b/acme/agents/jax/pwil/rewarder.py @@ -16,26 +16,29 @@ from typing import Iterator -from acme import types import jax import jax.numpy as jnp import numpy as np +from acme import types + class WassersteinDistanceRewarder: - """Computes PWIL rewards along a trajectory. + """Computes PWIL rewards along a trajectory. The rewards measure similarity to the demonstration transitions and are based on a greedy approximation to the Wasserstein distance between trajectories. """ - def __init__(self, - demonstrations_it: Iterator[types.Transition], - episode_length: int, - use_actions_for_distance: bool = False, - alpha: float = 5., - beta: float = 5.): - """Initializes the rewarder. + def __init__( + self, + demonstrations_it: Iterator[types.Transition], + episode_length: int, + use_actions_for_distance: bool = False, + alpha: float = 5.0, + beta: float = 5.0, + ): + """Initializes the rewarder. Args: demonstrations_it: An iterator over acme.types.Transition. @@ -45,30 +48,30 @@ def __init__(self, alpha: float scaling the reward function. beta: float controling the kernel size of the reward function. """ - self._episode_length = episode_length + self._episode_length = episode_length - self._use_actions_for_distance = use_actions_for_distance - self._vectorized_demonstrations = self._vectorize(demonstrations_it) + self._use_actions_for_distance = use_actions_for_distance + self._vectorized_demonstrations = self._vectorize(demonstrations_it) - # Observations and actions are flat. - atom_dims = self._vectorized_demonstrations.shape[1] - self._reward_sigma = beta * self._episode_length / np.sqrt(atom_dims) - self._reward_scale = alpha + # Observations and actions are flat. + atom_dims = self._vectorized_demonstrations.shape[1] + self._reward_sigma = beta * self._episode_length / np.sqrt(atom_dims) + self._reward_scale = alpha - self._std = np.std(self._vectorized_demonstrations, axis=0, dtype='float64') - # The std is set to 1 if the observation values are below a threshold. - # This prevents normalizing observation values that are constant (which can - # be problematic with e.g. demonstrations coming from a different version - # of the environment and where the constant values are slightly different). - self._std = (self._std < 1e-6) + self._std + self._std = np.std(self._vectorized_demonstrations, axis=0, dtype="float64") + # The std is set to 1 if the observation values are below a threshold. + # This prevents normalizing observation values that are constant (which can + # be problematic with e.g. demonstrations coming from a different version + # of the environment and where the constant values are slightly different). 
+ self._std = (self._std < 1e-6) + self._std - self.expert_atoms = self._vectorized_demonstrations / self._std - self._compute_norm = jax.jit(lambda a, b: jnp.linalg.norm(a - b, axis=1), - device=jax.devices('cpu')[0]) + self.expert_atoms = self._vectorized_demonstrations / self._std + self._compute_norm = jax.jit( + lambda a, b: jnp.linalg.norm(a - b, axis=1), device=jax.devices("cpu")[0] + ) - def _vectorize(self, - demonstrations_it: Iterator[types.Transition]) -> np.ndarray: - """Converts filtered expert demonstrations to numpy array. + def _vectorize(self, demonstrations_it: Iterator[types.Transition]) -> np.ndarray: + """Converts filtered expert demonstrations to numpy array. Args: demonstrations_it: list of expert demonstrations @@ -78,23 +81,24 @@ def _vectorize(self, [num_expert_transitions, dim_observation] if not use_actions_for_distance [num_expert_transitions, (dim_observation + dim_action)] otherwise """ - if self._use_actions_for_distance: - demonstrations = [ - np.concatenate([t.observation, t.action]) for t in demonstrations_it - ] - else: - demonstrations = [t.observation for t in demonstrations_it] - return np.array(demonstrations) - - def reset(self) -> None: - """Makes all expert transitions available and initialize weights.""" - num_expert_atoms = len(self.expert_atoms) - self._all_expert_weights_zero = False - self.expert_weights = np.ones(num_expert_atoms) / num_expert_atoms - - def append_and_compute_reward(self, observation: jnp.ndarray, - action: jnp.ndarray) -> np.float32: - """Computes reward and updates state, advancing it along a trajectory. + if self._use_actions_for_distance: + demonstrations = [ + np.concatenate([t.observation, t.action]) for t in demonstrations_it + ] + else: + demonstrations = [t.observation for t in demonstrations_it] + return np.array(demonstrations) + + def reset(self) -> None: + """Makes all expert transitions available and initialize weights.""" + num_expert_atoms = len(self.expert_atoms) + self._all_expert_weights_zero = False + self.expert_weights = np.ones(num_expert_atoms) / num_expert_atoms + + def append_and_compute_reward( + self, observation: jnp.ndarray, action: jnp.ndarray + ) -> np.float32: + """Computes reward and updates state, advancing it along a trajectory. Subsequent calls to append_and_compute_reward assume inputs are subsequent trajectory points. @@ -108,50 +112,50 @@ def append_and_compute_reward(self, observation: jnp.ndarray, the reward value: the return contribution from the trajectory point. """ - # If we run out of demonstrations, penalize further action. - if self._all_expert_weights_zero: - return np.float32(0.) - - # Scale observation and action. - if self._use_actions_for_distance: - agent_atom = np.concatenate([observation, action]) - else: - agent_atom = observation - agent_atom /= self._std - - cost = 0. - # A special marker for records with zero expert weight. Has to be large so - # that argmin will not return it. - DELETED = 1e10 # pylint: disable=invalid-name - # As we match the expert's weights with the agent's weights, we might - # raise an error due to float precision, we substract a small epsilon from - # the agent's weights to prevent that. - weight = 1. / self._episode_length - 1e-6 - norms = np.array(self._compute_norm(self.expert_atoms, agent_atom)) - # We need to mask out states with zero weight, so that 'argmin' would not - # return them. - adjusted_norms = (1 - np.sign(self.expert_weights)) * DELETED + norms - while weight > 0: - # Get closest expert state action to agent's state action. 
- argmin = adjusted_norms.argmin() - effective_weight = min(weight, self.expert_weights[argmin]) - - if adjusted_norms[argmin] >= DELETED: - self._all_expert_weights_zero = True - break - - # Update cost and weights. - weight -= effective_weight - self.expert_weights[argmin] -= effective_weight - cost += effective_weight * norms[argmin] - adjusted_norms[argmin] = DELETED - - if weight > 0: - # We have a 'partial' cost if we ran out of demonstrations in the reward - # computation loop. We assign a high cost (infinite) in this case which - # makes the reward equal to 0. - reward = np.array(0.) - else: - reward = self._reward_scale * np.exp(-self._reward_sigma * cost) - - return reward.astype('float32') + # If we run out of demonstrations, penalize further action. + if self._all_expert_weights_zero: + return np.float32(0.0) + + # Scale observation and action. + if self._use_actions_for_distance: + agent_atom = np.concatenate([observation, action]) + else: + agent_atom = observation + agent_atom /= self._std + + cost = 0.0 + # A special marker for records with zero expert weight. Has to be large so + # that argmin will not return it. + DELETED = 1e10 # pylint: disable=invalid-name + # As we match the expert's weights with the agent's weights, we might + # raise an error due to float precision, we substract a small epsilon from + # the agent's weights to prevent that. + weight = 1.0 / self._episode_length - 1e-6 + norms = np.array(self._compute_norm(self.expert_atoms, agent_atom)) + # We need to mask out states with zero weight, so that 'argmin' would not + # return them. + adjusted_norms = (1 - np.sign(self.expert_weights)) * DELETED + norms + while weight > 0: + # Get closest expert state action to agent's state action. + argmin = adjusted_norms.argmin() + effective_weight = min(weight, self.expert_weights[argmin]) + + if adjusted_norms[argmin] >= DELETED: + self._all_expert_weights_zero = True + break + + # Update cost and weights. + weight -= effective_weight + self.expert_weights[argmin] -= effective_weight + cost += effective_weight * norms[argmin] + adjusted_norms[argmin] = DELETED + + if weight > 0: + # We have a 'partial' cost if we ran out of demonstrations in the reward + # computation loop. We assign a high cost (infinite) in this case which + # makes the reward equal to 0. 
+ reward = np.array(0.0) + else: + reward = self._reward_scale * np.exp(-self._reward_sigma * cost) + + return reward.astype("float32") diff --git a/acme/agents/jax/r2d2/__init__.py b/acme/agents/jax/r2d2/__init__.py index 63b1ddf223..06850aa836 100644 --- a/acme/agents/jax/r2d2/__init__.py +++ b/acme/agents/jax/r2d2/__init__.py @@ -14,11 +14,8 @@ """Implementation of an R2D2 agent.""" -from acme.agents.jax.r2d2.actor import EpsilonRecurrentPolicy -from acme.agents.jax.r2d2.actor import make_behavior_policy +from acme.agents.jax.r2d2.actor import EpsilonRecurrentPolicy, make_behavior_policy from acme.agents.jax.r2d2.builder import R2D2Builder from acme.agents.jax.r2d2.config import R2D2Config -from acme.agents.jax.r2d2.learning import R2D2Learner -from acme.agents.jax.r2d2.learning import R2D2ReplaySample -from acme.agents.jax.r2d2.networks import make_atari_networks -from acme.agents.jax.r2d2.networks import R2D2Networks +from acme.agents.jax.r2d2.learning import R2D2Learner, R2D2ReplaySample +from acme.agents.jax.r2d2.networks import R2D2Networks, make_atari_networks diff --git a/acme/agents/jax/r2d2/actor.py b/acme/agents/jax/r2d2/actor.py index 0206dbcf12..628403bc87 100644 --- a/acme/agents/jax/r2d2/actor.py +++ b/acme/agents/jax/r2d2/actor.py @@ -16,35 +16,43 @@ from typing import Callable, Generic, Mapping, Optional, Tuple -from acme import types -from acme.agents.jax import actor_core as actor_core_lib -from acme.agents.jax.r2d2 import config as r2d2_config -from acme.agents.jax.r2d2 import networks as r2d2_networks -from acme.jax import networks as networks_lib import chex import jax import jax.numpy as jnp import numpy as np import rlax +from acme import types +from acme.agents.jax import actor_core as actor_core_lib +from acme.agents.jax.r2d2 import config as r2d2_config +from acme.agents.jax.r2d2 import networks as r2d2_networks +from acme.jax import networks as networks_lib + Epsilon = float R2D2Extras = Mapping[str, jnp.ndarray] -EpsilonRecurrentPolicy = Callable[[ - networks_lib.Params, networks_lib.PRNGKey, networks_lib - .Observation, actor_core_lib.RecurrentState, Epsilon -], Tuple[networks_lib.Action, actor_core_lib.RecurrentState]] +EpsilonRecurrentPolicy = Callable[ + [ + networks_lib.Params, + networks_lib.PRNGKey, + networks_lib.Observation, + actor_core_lib.RecurrentState, + Epsilon, + ], + Tuple[networks_lib.Action, actor_core_lib.RecurrentState], +] @chex.dataclass(frozen=True, mappable_dataclass=False) class R2D2ActorState(Generic[actor_core_lib.RecurrentState]): - rng: networks_lib.PRNGKey - epsilon: jnp.ndarray - recurrent_state: actor_core_lib.RecurrentState - prev_recurrent_state: actor_core_lib.RecurrentState + rng: networks_lib.PRNGKey + epsilon: jnp.ndarray + recurrent_state: actor_core_lib.RecurrentState + prev_recurrent_state: actor_core_lib.RecurrentState R2D2Policy = actor_core_lib.ActorCore[ - R2D2ActorState[actor_core_lib.RecurrentState], R2D2Extras] + R2D2ActorState[actor_core_lib.RecurrentState], R2D2Extras +] def get_actor_core( @@ -52,65 +60,82 @@ def get_actor_core( num_epsilons: Optional[int], evaluation_epsilon: Optional[float] = None, ) -> R2D2Policy: - """Returns ActorCore for R2D2.""" - - if (not num_epsilons and evaluation_epsilon is None) or (num_epsilons and - evaluation_epsilon): - raise ValueError( - 'Exactly one of `num_epsilons` or `evaluation_epsilon` must be ' - f'specified. 
Received num_epsilon={num_epsilons} and ' - f'evaluation_epsilon={evaluation_epsilon}.') - - def select_action(params: networks_lib.Params, - observation: networks_lib.Observation, - state: R2D2ActorState[actor_core_lib.RecurrentState]): - rng, policy_rng = jax.random.split(state.rng) - - q_values, recurrent_state = networks.apply(params, policy_rng, observation, - state.recurrent_state) - action = rlax.epsilon_greedy(state.epsilon).sample(policy_rng, q_values) - - return action, R2D2ActorState( - rng=rng, - epsilon=state.epsilon, - recurrent_state=recurrent_state, - prev_recurrent_state=state.recurrent_state) - - def init( - rng: networks_lib.PRNGKey - ) -> R2D2ActorState[actor_core_lib.RecurrentState]: - rng, epsilon_rng, state_rng = jax.random.split(rng, 3) - if num_epsilons: - epsilon = jax.random.choice(epsilon_rng, - np.logspace(1, 3, num_epsilons, base=0.1)) - else: - epsilon = evaluation_epsilon - initial_core_state = networks.init_recurrent_state(state_rng, None) - return R2D2ActorState( - rng=rng, - epsilon=epsilon, - recurrent_state=initial_core_state, - prev_recurrent_state=initial_core_state) - - def get_extras( - state: R2D2ActorState[actor_core_lib.RecurrentState]) -> R2D2Extras: - return {'core_state': state.prev_recurrent_state} - - return actor_core_lib.ActorCore(init=init, select_action=select_action, - get_extras=get_extras) + """Returns ActorCore for R2D2.""" + + if (not num_epsilons and evaluation_epsilon is None) or ( + num_epsilons and evaluation_epsilon + ): + raise ValueError( + "Exactly one of `num_epsilons` or `evaluation_epsilon` must be " + f"specified. Received num_epsilon={num_epsilons} and " + f"evaluation_epsilon={evaluation_epsilon}." + ) + + def select_action( + params: networks_lib.Params, + observation: networks_lib.Observation, + state: R2D2ActorState[actor_core_lib.RecurrentState], + ): + rng, policy_rng = jax.random.split(state.rng) + + q_values, recurrent_state = networks.apply( + params, policy_rng, observation, state.recurrent_state + ) + action = rlax.epsilon_greedy(state.epsilon).sample(policy_rng, q_values) + + return ( + action, + R2D2ActorState( + rng=rng, + epsilon=state.epsilon, + recurrent_state=recurrent_state, + prev_recurrent_state=state.recurrent_state, + ), + ) + + def init( + rng: networks_lib.PRNGKey, + ) -> R2D2ActorState[actor_core_lib.RecurrentState]: + rng, epsilon_rng, state_rng = jax.random.split(rng, 3) + if num_epsilons: + epsilon = jax.random.choice( + epsilon_rng, np.logspace(1, 3, num_epsilons, base=0.1) + ) + else: + epsilon = evaluation_epsilon + initial_core_state = networks.init_recurrent_state(state_rng, None) + return R2D2ActorState( + rng=rng, + epsilon=epsilon, + recurrent_state=initial_core_state, + prev_recurrent_state=initial_core_state, + ) + + def get_extras(state: R2D2ActorState[actor_core_lib.RecurrentState]) -> R2D2Extras: + return {"core_state": state.prev_recurrent_state} + + return actor_core_lib.ActorCore( + init=init, select_action=select_action, get_extras=get_extras + ) # TODO(bshahr): Deprecate this in favour of R2D2Builder.make_policy. 
-def make_behavior_policy(networks: r2d2_networks.R2D2Networks, - config: r2d2_config.R2D2Config, - evaluation: bool = False) -> EpsilonRecurrentPolicy: - """Selects action according to the policy.""" - - def behavior_policy(params: networks_lib.Params, key: networks_lib.PRNGKey, - observation: types.NestedArray, - core_state: types.NestedArray, epsilon: float): - q_values, core_state = networks.apply(params, key, observation, core_state) - epsilon = config.evaluation_epsilon if evaluation else epsilon - return rlax.epsilon_greedy(epsilon).sample(key, q_values), core_state - - return behavior_policy +def make_behavior_policy( + networks: r2d2_networks.R2D2Networks, + config: r2d2_config.R2D2Config, + evaluation: bool = False, +) -> EpsilonRecurrentPolicy: + """Selects action according to the policy.""" + + def behavior_policy( + params: networks_lib.Params, + key: networks_lib.PRNGKey, + observation: types.NestedArray, + core_state: types.NestedArray, + epsilon: float, + ): + q_values, core_state = networks.apply(params, key, observation, core_state) + epsilon = config.evaluation_epsilon if evaluation else epsilon + return rlax.epsilon_greedy(epsilon).sample(key, q_values), core_state + + return behavior_policy diff --git a/acme/agents/jax/r2d2/builder.py b/acme/agents/jax/r2d2/builder.py index 3637110201..c5fbeed960 100644 --- a/acme/agents/jax/r2d2/builder.py +++ b/acme/agents/jax/r2d2/builder.py @@ -15,32 +15,28 @@ """R2D2 Builder.""" from typing import Generic, Iterator, List, Optional +import jax +import optax +import reverb +import tensorflow as tf +import tree +from reverb import structured_writer as sw + import acme -from acme import adders -from acme import core -from acme import specs +from acme import adders, core, specs from acme.adders import reverb as adders_reverb from acme.adders.reverb import base as reverb_base from acme.adders.reverb import structured from acme.agents.jax import actor_core as actor_core_lib -from acme.agents.jax import actors -from acme.agents.jax import builders +from acme.agents.jax import actors, builders from acme.agents.jax.r2d2 import actor as r2d2_actor from acme.agents.jax.r2d2 import config as r2d2_config from acme.agents.jax.r2d2 import learning as r2d2_learning from acme.agents.jax.r2d2 import networks as r2d2_networks from acme.datasets import reverb as datasets from acme.jax import networks as networks_lib -from acme.jax import utils -from acme.jax import variable_utils -from acme.utils import counting -from acme.utils import loggers -import jax -import optax -import reverb -from reverb import structured_writer as sw -import tensorflow as tf -import tree +from acme.jax import utils, variable_utils +from acme.utils import counting, loggers # TODO(b/450949030): extrac the private functions to a library once other agents # reuse them. @@ -50,214 +46,235 @@ # We have to check if this requires moving _zero_pad to the adder. 
-def _build_sequence(length: int, - step_spec: reverb_base.Step) -> reverb_base.Trajectory: - """Constructs the sequence using only the first value of core_state.""" - step_dict = step_spec._asdict() - extras_dict = step_dict.pop('extras') - return reverb_base.Trajectory( - **tree.map_structure(lambda x: x[-length:], step_dict), - extras=tree.map_structure(lambda x: x[-length], extras_dict)) +def _build_sequence(length: int, step_spec: reverb_base.Step) -> reverb_base.Trajectory: + """Constructs the sequence using only the first value of core_state.""" + step_dict = step_spec._asdict() + extras_dict = step_dict.pop("extras") + return reverb_base.Trajectory( + **tree.map_structure(lambda x: x[-length:], step_dict), + extras=tree.map_structure(lambda x: x[-length], extras_dict) + ) def _zero_pad(sequence_length: int) -> datasets.Transform: - """Adds zero padding to the right so all samples have the same length.""" + """Adds zero padding to the right so all samples have the same length.""" - def _zero_pad_transform(sample: reverb.ReplaySample) -> reverb.ReplaySample: - trajectory: reverb_base.Trajectory = sample.data + def _zero_pad_transform(sample: reverb.ReplaySample) -> reverb.ReplaySample: + trajectory: reverb_base.Trajectory = sample.data - # Split steps and extras data (the extras won't be padded as they only - # contain one element) - trajectory_steps = trajectory._asdict() - trajectory_extras = trajectory_steps.pop('extras') + # Split steps and extras data (the extras won't be padded as they only + # contain one element) + trajectory_steps = trajectory._asdict() + trajectory_extras = trajectory_steps.pop("extras") - unpadded_length = len(tree.flatten(trajectory_steps)[0]) + unpadded_length = len(tree.flatten(trajectory_steps)[0]) - # Do nothing if the sequence is already full. - if unpadded_length != sequence_length: - to_pad = sequence_length - unpadded_length - pad = lambda x: tf.pad(x, [[0, to_pad]] + [[0, 0]] * (len(x.shape) - 1)) + # Do nothing if the sequence is already full. + if unpadded_length != sequence_length: + to_pad = sequence_length - unpadded_length + pad = lambda x: tf.pad(x, [[0, to_pad]] + [[0, 0]] * (len(x.shape) - 1)) - trajectory_steps = tree.map_structure(pad, trajectory_steps) + trajectory_steps = tree.map_structure(pad, trajectory_steps) - # Set the shape to be statically known, and checks it at runtime. - def _ensure_shape(x): - shape = tf.TensorShape([sequence_length]).concatenate(x.shape[1:]) - return tf.ensure_shape(x, shape) + # Set the shape to be statically known, and checks it at runtime. 
+ def _ensure_shape(x): + shape = tf.TensorShape([sequence_length]).concatenate(x.shape[1:]) + return tf.ensure_shape(x, shape) - trajectory_steps = tree.map_structure(_ensure_shape, trajectory_steps) - return reverb.ReplaySample( - info=sample.info, - data=reverb_base.Trajectory( - **trajectory_steps, extras=trajectory_extras)) + trajectory_steps = tree.map_structure(_ensure_shape, trajectory_steps) + return reverb.ReplaySample( + info=sample.info, + data=reverb_base.Trajectory(**trajectory_steps, extras=trajectory_extras), + ) - return _zero_pad_transform + return _zero_pad_transform -def _make_adder_config(step_spec: reverb_base.Step, seq_len: int, - seq_period: int) -> List[sw.Config]: - return structured.create_sequence_config( - step_spec=step_spec, - sequence_length=seq_len, - period=seq_period, - end_of_episode_behavior=adders_reverb.EndBehavior.TRUNCATE, - sequence_pattern=_build_sequence) +def _make_adder_config( + step_spec: reverb_base.Step, seq_len: int, seq_period: int +) -> List[sw.Config]: + return structured.create_sequence_config( + step_spec=step_spec, + sequence_length=seq_len, + period=seq_period, + end_of_episode_behavior=adders_reverb.EndBehavior.TRUNCATE, + sequence_pattern=_build_sequence, + ) -class R2D2Builder(Generic[actor_core_lib.RecurrentState], - builders.ActorLearnerBuilder[r2d2_networks.R2D2Networks, - r2d2_actor.R2D2Policy, - r2d2_learning.R2D2ReplaySample]): - """R2D2 Builder. +class R2D2Builder( + Generic[actor_core_lib.RecurrentState], + builders.ActorLearnerBuilder[ + r2d2_networks.R2D2Networks, + r2d2_actor.R2D2Policy, + r2d2_learning.R2D2ReplaySample, + ], +): + """R2D2 Builder. This is constructs all of the components for Recurrent Experience Replay in Distributed Reinforcement Learning (Kapturowski et al.) https://openreview.net/pdf?id=r1lyTjAqYX. """ - def __init__(self, config: r2d2_config.R2D2Config): - """Creates a R2D2 learner, a behavior policy and an eval actor.""" - self._config = config - self._sequence_length = ( - self._config.burn_in_length + self._config.trace_length + 1) - - @property - def _batch_size_per_device(self) -> int: - """Splits batch size across all learner devices evenly.""" - # TODO(bshahr): Using jax.device_count will not be valid when colocating - # learning and inference. - return self._config.batch_size // jax.device_count() - - def make_learner( - self, - random_key: networks_lib.PRNGKey, - networks: r2d2_networks.R2D2Networks, - dataset: Iterator[r2d2_learning.R2D2ReplaySample], - logger_fn: loggers.LoggerFactory, - environment_spec: specs.EnvironmentSpec, - replay_client: Optional[reverb.Client] = None, - counter: Optional[counting.Counter] = None, - ) -> core.Learner: - del environment_spec - - # The learner updates the parameters (and initializes them). 
- return r2d2_learning.R2D2Learner( - networks=networks, - batch_size=self._batch_size_per_device, - random_key=random_key, - burn_in_length=self._config.burn_in_length, - discount=self._config.discount, - importance_sampling_exponent=( - self._config.importance_sampling_exponent), - max_priority_weight=self._config.max_priority_weight, - target_update_period=self._config.target_update_period, - iterator=dataset, - optimizer=optax.adam(self._config.learning_rate), - bootstrap_n=self._config.bootstrap_n, - tx_pair=self._config.tx_pair, - clip_rewards=self._config.clip_rewards, - replay_client=replay_client, - counter=counter, - logger=logger_fn('learner')) - - def make_replay_tables( - self, - environment_spec: specs.EnvironmentSpec, - policy: r2d2_actor.R2D2Policy, - ) -> List[reverb.Table]: - """Create tables to insert data into.""" - dummy_actor_state = policy.init(jax.random.PRNGKey(0)) - extras_spec = policy.get_extras(dummy_actor_state) - step_spec = structured.create_step_spec( - environment_spec=environment_spec, extras_spec=extras_spec) - if self._config.samples_per_insert: - samples_per_insert_tolerance = ( - self._config.samples_per_insert_tolerance_rate * - self._config.samples_per_insert) - error_buffer = self._config.min_replay_size * samples_per_insert_tolerance - limiter = reverb.rate_limiters.SampleToInsertRatio( - min_size_to_sample=self._config.min_replay_size, - samples_per_insert=self._config.samples_per_insert, - error_buffer=error_buffer) - else: - limiter = reverb.rate_limiters.MinSize(self._config.min_replay_size) - return [ - reverb.Table( - name=self._config.replay_table_name, - sampler=reverb.selectors.Prioritized( - self._config.priority_exponent), - remover=reverb.selectors.Fifo(), - max_size=self._config.max_replay_size, - rate_limiter=limiter, - signature=sw.infer_signature( - configs=_make_adder_config(step_spec, self._sequence_length, - self._config.sequence_period), - step_spec=step_spec)) - ] - - def make_dataset_iterator( - self, - replay_client: reverb.Client) -> Iterator[r2d2_learning.R2D2ReplaySample]: - """Create a dataset iterator to use for learning/updating the agent.""" - batch_size_per_learner = self._config.batch_size // jax.process_count() - dataset = datasets.make_reverb_dataset( - table=self._config.replay_table_name, - server_address=replay_client.server_address, - batch_size=self._batch_size_per_device, - num_parallel_calls=None, - max_in_flight_samples_per_worker=2 * batch_size_per_learner, - postprocess=_zero_pad(self._sequence_length), - ) - - return utils.multi_device_put( - dataset.as_numpy_iterator(), - devices=jax.local_devices(), - split_fn=utils.keep_key_on_host) - - def make_adder( - self, replay_client: reverb.Client, - environment_spec: Optional[specs.EnvironmentSpec], - policy: Optional[r2d2_actor.R2D2Policy]) -> Optional[adders.Adder]: - """Create an adder which records data generated by the actor/environment.""" - if environment_spec is None or policy is None: - raise ValueError('`environment_spec` and `policy` cannot be None.') - dummy_actor_state = policy.init(jax.random.PRNGKey(0)) - extras_spec = policy.get_extras(dummy_actor_state) - step_spec = structured.create_step_spec( - environment_spec=environment_spec, extras_spec=extras_spec) - return structured.StructuredAdder( - client=replay_client, - max_in_flight_items=5, - configs=_make_adder_config(step_spec, self._sequence_length, - self._config.sequence_period), - step_spec=step_spec) - - def make_actor( - self, - random_key: networks_lib.PRNGKey, - policy: 
r2d2_actor.R2D2Policy, - environment_spec: specs.EnvironmentSpec, - variable_source: Optional[core.VariableSource] = None, - adder: Optional[adders.Adder] = None, - ) -> acme.Actor: - del environment_spec - # Create variable client. - variable_client = variable_utils.VariableClient( - variable_source, - key='actor_variables', - update_period=self._config.variable_update_period) - - return actors.GenericActor( - policy, random_key, variable_client, adder, backend='cpu') - - def make_policy(self, - networks: r2d2_networks.R2D2Networks, - environment_spec: specs.EnvironmentSpec, - evaluation: bool = False) -> r2d2_actor.R2D2Policy: - if evaluation: - return r2d2_actor.get_actor_core( - networks, - num_epsilons=None, - evaluation_epsilon=self._config.evaluation_epsilon) - else: - return r2d2_actor.get_actor_core(networks, self._config.num_epsilons) + def __init__(self, config: r2d2_config.R2D2Config): + """Creates a R2D2 learner, a behavior policy and an eval actor.""" + self._config = config + self._sequence_length = ( + self._config.burn_in_length + self._config.trace_length + 1 + ) + + @property + def _batch_size_per_device(self) -> int: + """Splits batch size across all learner devices evenly.""" + # TODO(bshahr): Using jax.device_count will not be valid when colocating + # learning and inference. + return self._config.batch_size // jax.device_count() + + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: r2d2_networks.R2D2Networks, + dataset: Iterator[r2d2_learning.R2D2ReplaySample], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + del environment_spec + + # The learner updates the parameters (and initializes them). 
+ return r2d2_learning.R2D2Learner( + networks=networks, + batch_size=self._batch_size_per_device, + random_key=random_key, + burn_in_length=self._config.burn_in_length, + discount=self._config.discount, + importance_sampling_exponent=(self._config.importance_sampling_exponent), + max_priority_weight=self._config.max_priority_weight, + target_update_period=self._config.target_update_period, + iterator=dataset, + optimizer=optax.adam(self._config.learning_rate), + bootstrap_n=self._config.bootstrap_n, + tx_pair=self._config.tx_pair, + clip_rewards=self._config.clip_rewards, + replay_client=replay_client, + counter=counter, + logger=logger_fn("learner"), + ) + + def make_replay_tables( + self, environment_spec: specs.EnvironmentSpec, policy: r2d2_actor.R2D2Policy, + ) -> List[reverb.Table]: + """Create tables to insert data into.""" + dummy_actor_state = policy.init(jax.random.PRNGKey(0)) + extras_spec = policy.get_extras(dummy_actor_state) + step_spec = structured.create_step_spec( + environment_spec=environment_spec, extras_spec=extras_spec + ) + if self._config.samples_per_insert: + samples_per_insert_tolerance = ( + self._config.samples_per_insert_tolerance_rate + * self._config.samples_per_insert + ) + error_buffer = self._config.min_replay_size * samples_per_insert_tolerance + limiter = reverb.rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._config.min_replay_size, + samples_per_insert=self._config.samples_per_insert, + error_buffer=error_buffer, + ) + else: + limiter = reverb.rate_limiters.MinSize(self._config.min_replay_size) + return [ + reverb.Table( + name=self._config.replay_table_name, + sampler=reverb.selectors.Prioritized(self._config.priority_exponent), + remover=reverb.selectors.Fifo(), + max_size=self._config.max_replay_size, + rate_limiter=limiter, + signature=sw.infer_signature( + configs=_make_adder_config( + step_spec, self._sequence_length, self._config.sequence_period + ), + step_spec=step_spec, + ), + ) + ] + + def make_dataset_iterator( + self, replay_client: reverb.Client + ) -> Iterator[r2d2_learning.R2D2ReplaySample]: + """Create a dataset iterator to use for learning/updating the agent.""" + batch_size_per_learner = self._config.batch_size // jax.process_count() + dataset = datasets.make_reverb_dataset( + table=self._config.replay_table_name, + server_address=replay_client.server_address, + batch_size=self._batch_size_per_device, + num_parallel_calls=None, + max_in_flight_samples_per_worker=2 * batch_size_per_learner, + postprocess=_zero_pad(self._sequence_length), + ) + + return utils.multi_device_put( + dataset.as_numpy_iterator(), + devices=jax.local_devices(), + split_fn=utils.keep_key_on_host, + ) + + def make_adder( + self, + replay_client: reverb.Client, + environment_spec: Optional[specs.EnvironmentSpec], + policy: Optional[r2d2_actor.R2D2Policy], + ) -> Optional[adders.Adder]: + """Create an adder which records data generated by the actor/environment.""" + if environment_spec is None or policy is None: + raise ValueError("`environment_spec` and `policy` cannot be None.") + dummy_actor_state = policy.init(jax.random.PRNGKey(0)) + extras_spec = policy.get_extras(dummy_actor_state) + step_spec = structured.create_step_spec( + environment_spec=environment_spec, extras_spec=extras_spec + ) + return structured.StructuredAdder( + client=replay_client, + max_in_flight_items=5, + configs=_make_adder_config( + step_spec, self._sequence_length, self._config.sequence_period + ), + step_spec=step_spec, + ) + + def make_actor( + self, + random_key: 
networks_lib.PRNGKey, + policy: r2d2_actor.R2D2Policy, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + adder: Optional[adders.Adder] = None, + ) -> acme.Actor: + del environment_spec + # Create variable client. + variable_client = variable_utils.VariableClient( + variable_source, + key="actor_variables", + update_period=self._config.variable_update_period, + ) + + return actors.GenericActor( + policy, random_key, variable_client, adder, backend="cpu" + ) + + def make_policy( + self, + networks: r2d2_networks.R2D2Networks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool = False, + ) -> r2d2_actor.R2D2Policy: + if evaluation: + return r2d2_actor.get_actor_core( + networks, + num_epsilons=None, + evaluation_epsilon=self._config.evaluation_epsilon, + ) + else: + return r2d2_actor.get_actor_core(networks, self._config.num_epsilons) diff --git a/acme/agents/jax/r2d2/config.py b/acme/agents/jax/r2d2/config.py index 2fc52d908c..d3ade9c6e0 100644 --- a/acme/agents/jax/r2d2/config.py +++ b/acme/agents/jax/r2d2/config.py @@ -15,39 +15,41 @@ """PPO config.""" import dataclasses -from acme.adders import reverb as adders_reverb import rlax +from acme.adders import reverb as adders_reverb + @dataclasses.dataclass class R2D2Config: - """Configuration options for R2D2 agent.""" - discount: float = 0.997 - target_update_period: int = 2500 - evaluation_epsilon: float = 0. - num_epsilons: int = 256 - variable_update_period: int = 400 - - # Learner options - burn_in_length: int = 40 - trace_length: int = 80 - sequence_period: int = 40 - learning_rate: float = 1e-3 - bootstrap_n: int = 5 - clip_rewards: bool = False - tx_pair: rlax.TxPair = rlax.SIGNED_HYPERBOLIC_PAIR - - # Replay options - samples_per_insert_tolerance_rate: float = 0.1 - samples_per_insert: float = 4.0 - min_replay_size: int = 50_000 - max_replay_size: int = 100_000 - batch_size: int = 64 - prefetch_size: int = 2 - num_parallel_calls: int = 16 - replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE - - # Priority options - importance_sampling_exponent: float = 0.6 - priority_exponent: float = 0.9 - max_priority_weight: float = 0.9 + """Configuration options for R2D2 agent.""" + + discount: float = 0.997 + target_update_period: int = 2500 + evaluation_epsilon: float = 0.0 + num_epsilons: int = 256 + variable_update_period: int = 400 + + # Learner options + burn_in_length: int = 40 + trace_length: int = 80 + sequence_period: int = 40 + learning_rate: float = 1e-3 + bootstrap_n: int = 5 + clip_rewards: bool = False + tx_pair: rlax.TxPair = rlax.SIGNED_HYPERBOLIC_PAIR + + # Replay options + samples_per_insert_tolerance_rate: float = 0.1 + samples_per_insert: float = 4.0 + min_replay_size: int = 50_000 + max_replay_size: int = 100_000 + batch_size: int = 64 + prefetch_size: int = 2 + num_parallel_calls: int = 16 + replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE + + # Priority options + importance_sampling_exponent: float = 0.6 + priority_exponent: float = 0.9 + max_priority_weight: float = 0.9 diff --git a/acme/agents/jax/r2d2/learning.py b/acme/agents/jax/r2d2/learning.py index 8ac7b475ab..8e1d93752a 100644 --- a/acme/agents/jax/r2d2/learning.py +++ b/acme/agents/jax/r2d2/learning.py @@ -18,259 +18,269 @@ import time from typing import Dict, Iterator, List, NamedTuple, Optional, Tuple -from absl import logging -import acme -from acme.adders import reverb as adders -from acme.agents.jax.r2d2 import networks as r2d2_networks -from acme.jax import networks as 
networks_lib -from acme.jax import utils -from acme.utils import async_utils -from acme.utils import counting -from acme.utils import loggers import jax import jax.numpy as jnp import optax import reverb import rlax import tree +from absl import logging -_PMAP_AXIS_NAME = 'data' +import acme +from acme.adders import reverb as adders +from acme.agents.jax.r2d2 import networks as r2d2_networks +from acme.jax import networks as networks_lib +from acme.jax import utils +from acme.utils import async_utils, counting, loggers + +_PMAP_AXIS_NAME = "data" # This type allows splitting a sample between the host and device, which avoids # putting item keys (uint64) on device for the purposes of priority updating. R2D2ReplaySample = utils.PrefetchingSplit class TrainingState(NamedTuple): - """Holds the agent's training state.""" - params: networks_lib.Params - target_params: networks_lib.Params - opt_state: optax.OptState - steps: int - random_key: networks_lib.PRNGKey + """Holds the agent's training state.""" + + params: networks_lib.Params + target_params: networks_lib.Params + opt_state: optax.OptState + steps: int + random_key: networks_lib.PRNGKey class R2D2Learner(acme.Learner): - """R2D2 learner.""" - - def __init__(self, - networks: r2d2_networks.R2D2Networks, - batch_size: int, - random_key: networks_lib.PRNGKey, - burn_in_length: int, - discount: float, - importance_sampling_exponent: float, - max_priority_weight: float, - target_update_period: int, - iterator: Iterator[R2D2ReplaySample], - optimizer: optax.GradientTransformation, - bootstrap_n: int = 5, - tx_pair: rlax.TxPair = rlax.SIGNED_HYPERBOLIC_PAIR, - clip_rewards: bool = False, - max_abs_reward: float = 1., - use_core_state: bool = True, - prefetch_size: int = 2, - replay_client: Optional[reverb.Client] = None, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None): - """Initializes the learner.""" - - def loss( - params: networks_lib.Params, - target_params: networks_lib.Params, - key_grad: networks_lib.PRNGKey, - sample: reverb.ReplaySample - ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]: - """Computes mean transformed N-step loss for a batch of sequences.""" - - # Get core state & warm it up on observations for a burn-in period. - if use_core_state: - # Replay core state. - # NOTE: We may need to recover the type of the hk.LSTMState if the user - # specifies a dynamically unrolled RNN as it will strictly enforce the - # match between input/output state types. - online_state = utils.maybe_recover_lstm_type( - sample.data.extras.get('core_state')) - else: - key_grad, initial_state_rng = jax.random.split(key_grad) - online_state = networks.init_recurrent_state(initial_state_rng, - batch_size) - target_state = online_state - - # Convert sample data to sequence-major format [T, B, ...]. - data = utils.batch_to_sequence(sample.data) - - # Maybe burn the core state in. - if burn_in_length: - burn_obs = jax.tree_map(lambda x: x[:burn_in_length], data.observation) - key_grad, key1, key2 = jax.random.split(key_grad, 3) - _, online_state = networks.unroll(params, key1, burn_obs, online_state) - _, target_state = networks.unroll(target_params, key2, burn_obs, - target_state) - - # Only get data to learn on from after the end of the burn in period. - data = jax.tree_map(lambda seq: seq[burn_in_length:], data) - - # Unroll on sequences to get online and target Q-Values. 
- key1, key2 = jax.random.split(key_grad) - online_q, _ = networks.unroll(params, key1, data.observation, - online_state) - target_q, _ = networks.unroll(target_params, key2, data.observation, - target_state) - - # Get value-selector actions from online Q-values for double Q-learning. - selector_actions = jnp.argmax(online_q, axis=-1) - # Preprocess discounts & rewards. - discounts = (data.discount * discount).astype(online_q.dtype) - rewards = data.reward - if clip_rewards: - rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward) - rewards = rewards.astype(online_q.dtype) - - # Get N-step transformed TD error and loss. - batch_td_error_fn = jax.vmap( - functools.partial( - rlax.transformed_n_step_q_learning, - n=bootstrap_n, - tx_pair=tx_pair), - in_axes=1, - out_axes=1) - batch_td_error = batch_td_error_fn( - online_q[:-1], - data.action[:-1], - target_q[1:], - selector_actions[1:], - rewards[:-1], - discounts[:-1]) - batch_loss = 0.5 * jnp.square(batch_td_error).sum(axis=0) - - # Importance weighting. - probs = sample.info.probability - importance_weights = (1. / (probs + 1e-6)).astype(online_q.dtype) - importance_weights **= importance_sampling_exponent - importance_weights /= jnp.max(importance_weights) - mean_loss = jnp.mean(importance_weights * batch_loss) - - # Calculate priorities as a mixture of max and mean sequence errors. - abs_td_error = jnp.abs(batch_td_error).astype(online_q.dtype) - max_priority = max_priority_weight * jnp.max(abs_td_error, axis=0) - mean_priority = (1 - max_priority_weight) * jnp.mean(abs_td_error, axis=0) - priorities = (max_priority + mean_priority) - - return mean_loss, priorities - - def sgd_step( - state: TrainingState, - samples: reverb.ReplaySample - ) -> Tuple[TrainingState, jnp.ndarray, Dict[str, jnp.ndarray]]: - """Performs an update step, averaging over pmap replicas.""" - - # Compute loss and gradients. - grad_fn = jax.value_and_grad(loss, has_aux=True) - key, key_grad = jax.random.split(state.random_key) - (loss_value, priorities), gradients = grad_fn(state.params, - state.target_params, - key_grad, - samples) - - # Average gradients over pmap replicas before optimizer update. - gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME) - - # Apply optimizer updates. - updates, new_opt_state = optimizer.update(gradients, state.opt_state) - new_params = optax.apply_updates(state.params, updates) - - # Periodically update target networks. - steps = state.steps + 1 - target_params = optax.periodic_update(new_params, state.target_params, # pytype: disable=wrong-arg-types # numpy-scalars - steps, self._target_update_period) - - new_state = TrainingState( - params=new_params, - target_params=target_params, - opt_state=new_opt_state, - steps=steps, - random_key=key) - return new_state, priorities, {'loss': loss_value} - - def update_priorities( - keys_and_priorities: Tuple[jnp.ndarray, jnp.ndarray]): - keys, priorities = keys_and_priorities - keys, priorities = tree.map_structure( - # Fetch array and combine device and batch dimensions. - lambda x: utils.fetch_devicearray(x).reshape((-1,) + x.shape[2:]), - (keys, priorities)) - replay_client.mutate_priorities( # pytype: disable=attribute-error - table=adders.DEFAULT_PRIORITY_TABLE, - updates=dict(zip(keys, priorities))) - - # Internalise components, hyperparameters, logger, counter, and methods. 
- self._iterator = iterator - self._replay_client = replay_client - self._target_update_period = target_update_period - self._counter = counter or counting.Counter() - self._logger = logger or loggers.make_default_logger( - 'learner', - asynchronous=True, - serialize_fn=utils.fetch_devicearray, - time_delta=1., - steps_key=self._counter.get_steps_key()) - - self._sgd_step = jax.pmap(sgd_step, axis_name=_PMAP_AXIS_NAME) - self._async_priority_updater = async_utils.AsyncExecutor(update_priorities) - - # Initialise and internalise training state (parameters/optimiser state). - random_key, key_init = jax.random.split(random_key) - initial_params = networks.init(key_init) - opt_state = optimizer.init(initial_params) - - # Log how many parameters the network has. - sizes = tree.map_structure(jnp.size, initial_params) - logging.info('Total number of params: %d', - sum(tree.flatten(sizes.values()))) - - state = TrainingState( - params=initial_params, - target_params=initial_params, - opt_state=opt_state, - steps=jnp.array(0), - random_key=random_key) - # Replicate parameters. - self._state = utils.replicate_in_all_devices(state) - - def step(self): - prefetching_split = next(self._iterator) - # The split_sample method passed to utils.sharded_prefetch specifies what - # parts of the objects returned by the original iterator are kept in the - # host and what parts are prefetched on-device. - # In this case the host property of the prefetching split contains only the - # replay keys and the device property is the prefetched full original - # sample. - keys = prefetching_split.host - samples: reverb.ReplaySample = prefetching_split.device - - # Do a batch of SGD. - start = time.time() - self._state, priorities, metrics = self._sgd_step(self._state, samples) - # Take metrics from first replica. - metrics = utils.get_from_first_device(metrics) - # Update our counts and record it. - counts = self._counter.increment(steps=1, time_elapsed=time.time() - start) - - # Update priorities in replay. - if self._replay_client: - self._async_priority_updater.put((keys, priorities)) - - # Attempt to write logs. - self._logger.write({**metrics, **counts}) - - def get_variables(self, names: List[str]) -> List[networks_lib.Params]: - del names # There's only one available set of params in this agent. - # Return first replica of parameters. - return utils.get_from_first_device([self._state.params]) - - def save(self) -> TrainingState: - # Serialize only the first replica of parameters and optimizer state. 
- return utils.get_from_first_device(self._state) - - def restore(self, state: TrainingState): - self._state = utils.replicate_in_all_devices(state) + """R2D2 learner.""" + + def __init__( + self, + networks: r2d2_networks.R2D2Networks, + batch_size: int, + random_key: networks_lib.PRNGKey, + burn_in_length: int, + discount: float, + importance_sampling_exponent: float, + max_priority_weight: float, + target_update_period: int, + iterator: Iterator[R2D2ReplaySample], + optimizer: optax.GradientTransformation, + bootstrap_n: int = 5, + tx_pair: rlax.TxPair = rlax.SIGNED_HYPERBOLIC_PAIR, + clip_rewards: bool = False, + max_abs_reward: float = 1.0, + use_core_state: bool = True, + prefetch_size: int = 2, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + ): + """Initializes the learner.""" + + def loss( + params: networks_lib.Params, + target_params: networks_lib.Params, + key_grad: networks_lib.PRNGKey, + sample: reverb.ReplaySample, + ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]: + """Computes mean transformed N-step loss for a batch of sequences.""" + + # Get core state & warm it up on observations for a burn-in period. + if use_core_state: + # Replay core state. + # NOTE: We may need to recover the type of the hk.LSTMState if the user + # specifies a dynamically unrolled RNN as it will strictly enforce the + # match between input/output state types. + online_state = utils.maybe_recover_lstm_type( + sample.data.extras.get("core_state") + ) + else: + key_grad, initial_state_rng = jax.random.split(key_grad) + online_state = networks.init_recurrent_state( + initial_state_rng, batch_size + ) + target_state = online_state + + # Convert sample data to sequence-major format [T, B, ...]. + data = utils.batch_to_sequence(sample.data) + + # Maybe burn the core state in. + if burn_in_length: + burn_obs = jax.tree_map(lambda x: x[:burn_in_length], data.observation) + key_grad, key1, key2 = jax.random.split(key_grad, 3) + _, online_state = networks.unroll(params, key1, burn_obs, online_state) + _, target_state = networks.unroll( + target_params, key2, burn_obs, target_state + ) + + # Only get data to learn on from after the end of the burn in period. + data = jax.tree_map(lambda seq: seq[burn_in_length:], data) + + # Unroll on sequences to get online and target Q-Values. + key1, key2 = jax.random.split(key_grad) + online_q, _ = networks.unroll(params, key1, data.observation, online_state) + target_q, _ = networks.unroll( + target_params, key2, data.observation, target_state + ) + + # Get value-selector actions from online Q-values for double Q-learning. + selector_actions = jnp.argmax(online_q, axis=-1) + # Preprocess discounts & rewards. + discounts = (data.discount * discount).astype(online_q.dtype) + rewards = data.reward + if clip_rewards: + rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward) + rewards = rewards.astype(online_q.dtype) + + # Get N-step transformed TD error and loss. + batch_td_error_fn = jax.vmap( + functools.partial( + rlax.transformed_n_step_q_learning, n=bootstrap_n, tx_pair=tx_pair + ), + in_axes=1, + out_axes=1, + ) + batch_td_error = batch_td_error_fn( + online_q[:-1], + data.action[:-1], + target_q[1:], + selector_actions[1:], + rewards[:-1], + discounts[:-1], + ) + batch_loss = 0.5 * jnp.square(batch_td_error).sum(axis=0) + + # Importance weighting. 
+ probs = sample.info.probability + importance_weights = (1.0 / (probs + 1e-6)).astype(online_q.dtype) + importance_weights **= importance_sampling_exponent + importance_weights /= jnp.max(importance_weights) + mean_loss = jnp.mean(importance_weights * batch_loss) + + # Calculate priorities as a mixture of max and mean sequence errors. + abs_td_error = jnp.abs(batch_td_error).astype(online_q.dtype) + max_priority = max_priority_weight * jnp.max(abs_td_error, axis=0) + mean_priority = (1 - max_priority_weight) * jnp.mean(abs_td_error, axis=0) + priorities = max_priority + mean_priority + + return mean_loss, priorities + + def sgd_step( + state: TrainingState, samples: reverb.ReplaySample + ) -> Tuple[TrainingState, jnp.ndarray, Dict[str, jnp.ndarray]]: + """Performs an update step, averaging over pmap replicas.""" + + # Compute loss and gradients. + grad_fn = jax.value_and_grad(loss, has_aux=True) + key, key_grad = jax.random.split(state.random_key) + (loss_value, priorities), gradients = grad_fn( + state.params, state.target_params, key_grad, samples + ) + + # Average gradients over pmap replicas before optimizer update. + gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME) + + # Apply optimizer updates. + updates, new_opt_state = optimizer.update(gradients, state.opt_state) + new_params = optax.apply_updates(state.params, updates) + + # Periodically update target networks. + steps = state.steps + 1 + target_params = optax.periodic_update( + new_params, + state.target_params, # pytype: disable=wrong-arg-types # numpy-scalars + steps, + self._target_update_period, + ) + + new_state = TrainingState( + params=new_params, + target_params=target_params, + opt_state=new_opt_state, + steps=steps, + random_key=key, + ) + return new_state, priorities, {"loss": loss_value} + + def update_priorities(keys_and_priorities: Tuple[jnp.ndarray, jnp.ndarray]): + keys, priorities = keys_and_priorities + keys, priorities = tree.map_structure( + # Fetch array and combine device and batch dimensions. + lambda x: utils.fetch_devicearray(x).reshape((-1,) + x.shape[2:]), + (keys, priorities), + ) + replay_client.mutate_priorities( # pytype: disable=attribute-error + table=adders.DEFAULT_PRIORITY_TABLE, updates=dict(zip(keys, priorities)) + ) + + # Internalise components, hyperparameters, logger, counter, and methods. + self._iterator = iterator + self._replay_client = replay_client + self._target_update_period = target_update_period + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger( + "learner", + asynchronous=True, + serialize_fn=utils.fetch_devicearray, + time_delta=1.0, + steps_key=self._counter.get_steps_key(), + ) + + self._sgd_step = jax.pmap(sgd_step, axis_name=_PMAP_AXIS_NAME) + self._async_priority_updater = async_utils.AsyncExecutor(update_priorities) + + # Initialise and internalise training state (parameters/optimiser state). + random_key, key_init = jax.random.split(random_key) + initial_params = networks.init(key_init) + opt_state = optimizer.init(initial_params) + + # Log how many parameters the network has. + sizes = tree.map_structure(jnp.size, initial_params) + logging.info("Total number of params: %d", sum(tree.flatten(sizes.values()))) + + state = TrainingState( + params=initial_params, + target_params=initial_params, + opt_state=opt_state, + steps=jnp.array(0), + random_key=random_key, + ) + # Replicate parameters. 
+ self._state = utils.replicate_in_all_devices(state) + + def step(self): + prefetching_split = next(self._iterator) + # The split_sample method passed to utils.sharded_prefetch specifies what + # parts of the objects returned by the original iterator are kept in the + # host and what parts are prefetched on-device. + # In this case the host property of the prefetching split contains only the + # replay keys and the device property is the prefetched full original + # sample. + keys = prefetching_split.host + samples: reverb.ReplaySample = prefetching_split.device + + # Do a batch of SGD. + start = time.time() + self._state, priorities, metrics = self._sgd_step(self._state, samples) + # Take metrics from first replica. + metrics = utils.get_from_first_device(metrics) + # Update our counts and record it. + counts = self._counter.increment(steps=1, time_elapsed=time.time() - start) + + # Update priorities in replay. + if self._replay_client: + self._async_priority_updater.put((keys, priorities)) + + # Attempt to write logs. + self._logger.write({**metrics, **counts}) + + def get_variables(self, names: List[str]) -> List[networks_lib.Params]: + del names # There's only one available set of params in this agent. + # Return first replica of parameters. + return utils.get_from_first_device([self._state.params]) + + def save(self) -> TrainingState: + # Serialize only the first replica of parameters and optimizer state. + return utils.get_from_first_device(self._state) + + def restore(self, state: TrainingState): + self._state = utils.replicate_in_all_devices(state) diff --git a/acme/agents/jax/r2d2/networks.py b/acme/agents/jax/r2d2/networks.py index e19b328dc8..afbc1b880e 100644 --- a/acme/agents/jax/r2d2/networks.py +++ b/acme/agents/jax/r2d2/networks.py @@ -17,14 +17,13 @@ from acme import specs from acme.jax import networks as networks_lib - R2D2Networks = networks_lib.UnrollableNetwork def make_atari_networks(env_spec: specs.EnvironmentSpec) -> R2D2Networks: - """Builds default R2D2 networks for Atari games.""" + """Builds default R2D2 networks for Atari games.""" - def make_core_module() -> networks_lib.R2D2AtariNetwork: - return networks_lib.R2D2AtariNetwork(env_spec.actions.num_values) + def make_core_module() -> networks_lib.R2D2AtariNetwork: + return networks_lib.R2D2AtariNetwork(env_spec.actions.num_values) - return networks_lib.make_unrollable_network(env_spec, make_core_module) + return networks_lib.make_unrollable_network(env_spec, make_core_module) diff --git a/acme/agents/jax/rnd/__init__.py b/acme/agents/jax/rnd/__init__.py index cb09a0ec9c..1de80a4c32 100644 --- a/acme/agents/jax/rnd/__init__.py +++ b/acme/agents/jax/rnd/__init__.py @@ -16,11 +16,15 @@ from acme.agents.jax.rnd.builder import RNDBuilder from acme.agents.jax.rnd.config import RNDConfig -from acme.agents.jax.rnd.learning import rnd_loss -from acme.agents.jax.rnd.learning import rnd_update_step -from acme.agents.jax.rnd.learning import RNDLearner -from acme.agents.jax.rnd.learning import RNDTrainingState -from acme.agents.jax.rnd.networks import compute_rnd_reward -from acme.agents.jax.rnd.networks import make_networks -from acme.agents.jax.rnd.networks import rnd_reward_fn -from acme.agents.jax.rnd.networks import RNDNetworks +from acme.agents.jax.rnd.learning import ( + RNDLearner, + RNDTrainingState, + rnd_loss, + rnd_update_step, +) +from acme.agents.jax.rnd.networks import ( + RNDNetworks, + compute_rnd_reward, + make_networks, + rnd_reward_fn, +) diff --git a/acme/agents/jax/rnd/builder.py 
b/acme/agents/jax/rnd/builder.py index 67fc986f44..693c48270f 100644 --- a/acme/agents/jax/rnd/builder.py +++ b/acme/agents/jax/rnd/builder.py @@ -16,9 +16,11 @@ from typing import Callable, Generic, Iterator, List, Optional -from acme import adders -from acme import core -from acme import specs +import jax +import optax +import reverb + +from acme import adders, core, specs from acme.agents.jax import actor_core as actor_core_lib from acme.agents.jax import builders from acme.agents.jax.rnd import config as rnd_config @@ -26,108 +28,114 @@ from acme.agents.jax.rnd import networks as rnd_networks from acme.jax import networks as networks_lib from acme.jax.types import Policy -from acme.utils import counting -from acme.utils import loggers -import jax -import optax -import reverb +from acme.utils import counting, loggers -class RNDBuilder(Generic[rnd_networks.DirectRLNetworks, Policy], - builders.ActorLearnerBuilder[rnd_networks.RNDNetworks, Policy, - reverb.ReplaySample]): - """RND Builder.""" +class RNDBuilder( + Generic[rnd_networks.DirectRLNetworks, Policy], + builders.ActorLearnerBuilder[rnd_networks.RNDNetworks, Policy, reverb.ReplaySample], +): + """RND Builder.""" - def __init__( - self, - rl_agent: builders.ActorLearnerBuilder[rnd_networks.DirectRLNetworks, - Policy, reverb.ReplaySample], - config: rnd_config.RNDConfig, - logger_fn: Callable[[], loggers.Logger] = lambda: None, - ): - """Implements a builder for RND using rl_agent as forward RL algorithm. + def __init__( + self, + rl_agent: builders.ActorLearnerBuilder[ + rnd_networks.DirectRLNetworks, Policy, reverb.ReplaySample + ], + config: rnd_config.RNDConfig, + logger_fn: Callable[[], loggers.Logger] = lambda: None, + ): + """Implements a builder for RND using rl_agent as forward RL algorithm. Args: rl_agent: The standard RL agent used by RND to optimize the generator. config: A config with RND HPs. logger_fn: a logger factory for the rl_agent's learner. 
""" - self._rl_agent = rl_agent - self._config = config - self._logger_fn = logger_fn - - def make_learner( - self, - random_key: networks_lib.PRNGKey, - networks: rnd_networks.RNDNetworks[rnd_networks.DirectRLNetworks], - dataset: Iterator[reverb.ReplaySample], - logger_fn: loggers.LoggerFactory, - environment_spec: specs.EnvironmentSpec, - replay_client: Optional[reverb.Client] = None, - counter: Optional[counting.Counter] = None, - ) -> core.Learner: - direct_rl_learner_key, rnd_learner_key = jax.random.split(random_key) - - counter = counter or counting.Counter() - direct_rl_counter = counting.Counter(counter, 'direct_rl') - - def direct_rl_learner_factory( - networks: rnd_networks.DirectRLNetworks, - dataset: Iterator[reverb.ReplaySample]) -> core.Learner: - return self._rl_agent.make_learner( - direct_rl_learner_key, - networks, - dataset, - logger_fn=lambda name: self._logger_fn(), - environment_spec=environment_spec, - replay_client=replay_client, - counter=direct_rl_counter) - - optimizer = optax.adam(learning_rate=self._config.predictor_learning_rate) - - return rnd_learning.RNDLearner( - direct_rl_learner_factory=direct_rl_learner_factory, - iterator=dataset, - optimizer=optimizer, - rnd_network=networks, - rng_key=rnd_learner_key, - is_sequence_based=self._config.is_sequence_based, - grad_updates_per_batch=self._config.num_sgd_steps_per_step, - counter=counter, - logger=logger_fn('learner')) - - def make_replay_tables( - self, - environment_spec: specs.EnvironmentSpec, - policy: Policy, - ) -> List[reverb.Table]: - return self._rl_agent.make_replay_tables(environment_spec, policy) - - def make_dataset_iterator( # pytype: disable=signature-mismatch # overriding-return-type-checks - self, - replay_client: reverb.Client) -> Optional[Iterator[reverb.ReplaySample]]: - return self._rl_agent.make_dataset_iterator(replay_client) - - def make_adder(self, replay_client: reverb.Client, - environment_spec: Optional[specs.EnvironmentSpec], - policy: Optional[Policy]) -> Optional[adders.Adder]: - return self._rl_agent.make_adder(replay_client, environment_spec, policy) - - def make_actor( - self, - random_key: networks_lib.PRNGKey, - policy: Policy, - environment_spec: specs.EnvironmentSpec, - variable_source: Optional[core.VariableSource] = None, - adder: Optional[adders.Adder] = None, - ) -> core.Actor: - return self._rl_agent.make_actor(random_key, policy, environment_spec, - variable_source, adder) - - def make_policy(self, - networks: rnd_networks.RNDNetworks, - environment_spec: specs.EnvironmentSpec, - evaluation: bool = False) -> actor_core_lib.FeedForwardPolicy: - """Construct the policy.""" - return self._rl_agent.make_policy(networks.direct_rl_networks, - environment_spec, evaluation) + self._rl_agent = rl_agent + self._config = config + self._logger_fn = logger_fn + + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: rnd_networks.RNDNetworks[rnd_networks.DirectRLNetworks], + dataset: Iterator[reverb.ReplaySample], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + direct_rl_learner_key, rnd_learner_key = jax.random.split(random_key) + + counter = counter or counting.Counter() + direct_rl_counter = counting.Counter(counter, "direct_rl") + + def direct_rl_learner_factory( + networks: rnd_networks.DirectRLNetworks, + dataset: Iterator[reverb.ReplaySample], + ) -> core.Learner: + return self._rl_agent.make_learner( + 
direct_rl_learner_key, + networks, + dataset, + logger_fn=lambda name: self._logger_fn(), + environment_spec=environment_spec, + replay_client=replay_client, + counter=direct_rl_counter, + ) + + optimizer = optax.adam(learning_rate=self._config.predictor_learning_rate) + + return rnd_learning.RNDLearner( + direct_rl_learner_factory=direct_rl_learner_factory, + iterator=dataset, + optimizer=optimizer, + rnd_network=networks, + rng_key=rnd_learner_key, + is_sequence_based=self._config.is_sequence_based, + grad_updates_per_batch=self._config.num_sgd_steps_per_step, + counter=counter, + logger=logger_fn("learner"), + ) + + def make_replay_tables( + self, environment_spec: specs.EnvironmentSpec, policy: Policy, + ) -> List[reverb.Table]: + return self._rl_agent.make_replay_tables(environment_spec, policy) + + def make_dataset_iterator( # pytype: disable=signature-mismatch # overriding-return-type-checks + self, replay_client: reverb.Client + ) -> Optional[Iterator[reverb.ReplaySample]]: + return self._rl_agent.make_dataset_iterator(replay_client) + + def make_adder( + self, + replay_client: reverb.Client, + environment_spec: Optional[specs.EnvironmentSpec], + policy: Optional[Policy], + ) -> Optional[adders.Adder]: + return self._rl_agent.make_adder(replay_client, environment_spec, policy) + + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: Policy, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + adder: Optional[adders.Adder] = None, + ) -> core.Actor: + return self._rl_agent.make_actor( + random_key, policy, environment_spec, variable_source, adder + ) + + def make_policy( + self, + networks: rnd_networks.RNDNetworks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool = False, + ) -> actor_core_lib.FeedForwardPolicy: + """Construct the policy.""" + return self._rl_agent.make_policy( + networks.direct_rl_networks, environment_spec, evaluation + ) diff --git a/acme/agents/jax/rnd/config.py b/acme/agents/jax/rnd/config.py index db50c84338..7694788496 100644 --- a/acme/agents/jax/rnd/config.py +++ b/acme/agents/jax/rnd/config.py @@ -18,13 +18,13 @@ @dataclasses.dataclass class RNDConfig: - """Configuration options for RND.""" + """Configuration options for RND.""" - # Learning rate for the predictor. - predictor_learning_rate: float = 1e-4 + # Learning rate for the predictor. + predictor_learning_rate: float = 1e-4 - # If True, the direct rl algorithm is using the SequenceAdder data format. - is_sequence_based: bool = False + # If True, the direct rl algorithm is using the SequenceAdder data format. + is_sequence_based: bool = False - # How many gradient updates to perform per step. - num_sgd_steps_per_step: int = 1 + # How many gradient updates to perform per step. 
+    num_sgd_steps_per_step: int = 1
diff --git a/acme/agents/jax/rnd/learning.py b/acme/agents/jax/rnd/learning.py
index 168c213289..85c419cf27 100644
--- a/acme/agents/jax/rnd/learning.py
+++ b/acme/agents/jax/rnd/learning.py
@@ -18,43 +18,45 @@ import time
 from typing import Any, Callable, Dict, Iterator, List, NamedTuple, Optional, Tuple
 
+import jax
+import jax.numpy as jnp
+import optax
+import reverb
+
 import acme
 from acme import types
 from acme.agents.jax.rnd import networks as rnd_networks
 from acme.jax import networks as networks_lib
 from acme.jax import utils
-from acme.utils import counting
-from acme.utils import loggers
-from acme.utils import reverb_utils
-import jax
-import jax.numpy as jnp
-import optax
-import reverb
+from acme.utils import counting, loggers, reverb_utils
 
 
 class RNDTrainingState(NamedTuple):
-  """Contains training state for the learner."""
-  optimizer_state: optax.OptState
-  params: networks_lib.Params
-  target_params: networks_lib.Params
-  steps: int
+    """Contains training state for the learner."""
+
+    optimizer_state: optax.OptState
+    params: networks_lib.Params
+    target_params: networks_lib.Params
+    steps: int
 
 
 class GlobalTrainingState(NamedTuple):
-  """Contains training state of the RND learner."""
-  rewarder_state: RNDTrainingState
-  learner_state: Any
+    """Contains training state of the RND learner."""
+
+    rewarder_state: RNDTrainingState
+    learner_state: Any
 
 
-RNDLoss = Callable[[networks_lib.Params, networks_lib.Params, types.Transition],
-                   float]
+RNDLoss = Callable[[networks_lib.Params, networks_lib.Params, types.Transition], float]
 
 
 def rnd_update_step(
-    state: RNDTrainingState, transitions: types.Transition,
-    loss_fn: RNDLoss, optimizer: optax.GradientTransformation
+    state: RNDTrainingState,
+    transitions: types.Transition,
+    loss_fn: RNDLoss,
+    optimizer: optax.GradientTransformation,
 ) -> Tuple[RNDTrainingState, Dict[str, jnp.ndarray]]:
-  """Run an update steps on the given transitions.
+    """Runs an update step on the given transitions.
 
   Args:
     state: The learner state.
@@ -65,21 +67,20 @@ def rnd_update_step(
  Returns:
    A new state and metrics.
  """
-  loss, grads = jax.value_and_grad(loss_fn)(
-      state.params,
-      state.target_params,
-      transitions=transitions)
+    loss, grads = jax.value_and_grad(loss_fn)(
+        state.params, state.target_params, transitions=transitions
+    )
 
-  update, optimizer_state = optimizer.update(grads, state.optimizer_state)
-  params = optax.apply_updates(state.params, update)
+    update, optimizer_state = optimizer.update(grads, state.optimizer_state)
+    params = optax.apply_updates(state.params, update)
 
-  new_state = RNDTrainingState(
-      optimizer_state=optimizer_state,
-      params=params,
-      target_params=state.target_params,
-      steps=state.steps + 1,
-  )
-  return new_state, {'rnd_loss': loss}
+    new_state = RNDTrainingState(
+        optimizer_state=optimizer_state,
+        params=params,
+        target_params=state.target_params,
+        steps=state.steps + 1,
+    )
+    return new_state, {"rnd_loss": loss}
 
 
 def rnd_loss(
@@ -88,7 +89,7 @@ def rnd_loss(
     transitions: types.Transition,
     networks: rnd_networks.RNDNetworks,
 ) -> float:
-  """The Random Network Distillation loss.
+    """The Random Network Distillation loss.
 
   See https://arxiv.org/pdf/1810.12894.pdf A.2
 
@@ -101,77 +102,83 @@ def rnd_loss(
  Returns:
    The MSE loss as a float.
""" - target_output = networks.target.apply(target_params, - transitions.observation, - transitions.action) - predictor_output = networks.predictor.apply(predictor_params, - transitions.observation, - transitions.action) - return jnp.mean(jnp.square(target_output - predictor_output)) + target_output = networks.target.apply( + target_params, transitions.observation, transitions.action + ) + predictor_output = networks.predictor.apply( + predictor_params, transitions.observation, transitions.action + ) + return jnp.mean(jnp.square(target_output - predictor_output)) class RNDLearner(acme.Learner): - """RND learner.""" - - def __init__( - self, - direct_rl_learner_factory: Callable[[Any, Iterator[reverb.ReplaySample]], - acme.Learner], - iterator: Iterator[reverb.ReplaySample], - optimizer: optax.GradientTransformation, - rnd_network: rnd_networks.RNDNetworks, - rng_key: jnp.ndarray, - grad_updates_per_batch: int, - is_sequence_based: bool, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None): - self._is_sequence_based = is_sequence_based - - target_key, predictor_key = jax.random.split(rng_key) - target_params = rnd_network.target.init(target_key) - predictor_params = rnd_network.predictor.init(predictor_key) - optimizer_state = optimizer.init(predictor_params) - - self._state = RNDTrainingState( - optimizer_state=optimizer_state, - params=predictor_params, - target_params=target_params, - steps=0) - - # General learner book-keeping and loggers. - self._counter = counter or counting.Counter() - self._logger = logger or loggers.make_default_logger( - 'learner', - asynchronous=True, - serialize_fn=utils.fetch_devicearray, - steps_key=self._counter.get_steps_key()) - - loss = functools.partial(rnd_loss, networks=rnd_network) - self._update = functools.partial(rnd_update_step, - loss_fn=loss, - optimizer=optimizer) - self._update = utils.process_multiple_batches(self._update, - grad_updates_per_batch) - self._update = jax.jit(self._update) - - self._get_reward = jax.jit( - functools.partial( - rnd_networks.compute_rnd_reward, networks=rnd_network)) - - # Generator expression that works the same as an iterator. - # https://pymbook.readthedocs.io/en/latest/igd.html#generator-expressions - updated_iterator = (self._process_sample(sample) for sample in iterator) - - self._direct_rl_learner = direct_rl_learner_factory( - rnd_network.direct_rl_networks, updated_iterator) - - # Do not record timestamps until after the first learning step is done. - # This is to avoid including the time it takes for actors to come online and - # fill the replay buffer. - self._timestamp = None - - def _process_sample(self, sample: reverb.ReplaySample) -> reverb.ReplaySample: - """Uses the replay sample to train and update its reward. 
+ """RND learner.""" + + def __init__( + self, + direct_rl_learner_factory: Callable[ + [Any, Iterator[reverb.ReplaySample]], acme.Learner + ], + iterator: Iterator[reverb.ReplaySample], + optimizer: optax.GradientTransformation, + rnd_network: rnd_networks.RNDNetworks, + rng_key: jnp.ndarray, + grad_updates_per_batch: int, + is_sequence_based: bool, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + ): + self._is_sequence_based = is_sequence_based + + target_key, predictor_key = jax.random.split(rng_key) + target_params = rnd_network.target.init(target_key) + predictor_params = rnd_network.predictor.init(predictor_key) + optimizer_state = optimizer.init(predictor_params) + + self._state = RNDTrainingState( + optimizer_state=optimizer_state, + params=predictor_params, + target_params=target_params, + steps=0, + ) + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger( + "learner", + asynchronous=True, + serialize_fn=utils.fetch_devicearray, + steps_key=self._counter.get_steps_key(), + ) + + loss = functools.partial(rnd_loss, networks=rnd_network) + self._update = functools.partial( + rnd_update_step, loss_fn=loss, optimizer=optimizer + ) + self._update = utils.process_multiple_batches( + self._update, grad_updates_per_batch + ) + self._update = jax.jit(self._update) + + self._get_reward = jax.jit( + functools.partial(rnd_networks.compute_rnd_reward, networks=rnd_network) + ) + + # Generator expression that works the same as an iterator. + # https://pymbook.readthedocs.io/en/latest/igd.html#generator-expressions + updated_iterator = (self._process_sample(sample) for sample in iterator) + + self._direct_rl_learner = direct_rl_learner_factory( + rnd_network.direct_rl_networks, updated_iterator + ) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + def _process_sample(self, sample: reverb.ReplaySample) -> reverb.ReplaySample: + """Uses the replay sample to train and update its reward. Args: sample: Replay sample to train on. @@ -179,51 +186,53 @@ def _process_sample(self, sample: reverb.ReplaySample) -> reverb.ReplaySample: Returns: The sample replay sample with an updated reward. """ - transitions = reverb_utils.replay_sample_to_sars_transition( - sample, is_sequence=self._is_sequence_based) - self._state, metrics = self._update(self._state, transitions) - rewards = self._get_reward(self._state.params, self._state.target_params, - transitions) - - # Compute elapsed time. - timestamp = time.time() - elapsed_time = timestamp - self._timestamp if self._timestamp else 0 - self._timestamp = timestamp - - # Increment counts and record the current time - counts = self._counter.increment(steps=1, walltime=elapsed_time) - - # Attempts to write the logs. 
- self._logger.write({**metrics, **counts}) - - return sample._replace(data=sample.data._replace(reward=rewards)) - - def step(self): - self._direct_rl_learner.step() - - def get_variables(self, names: List[str]) -> List[Any]: - rnd_variables = { - 'target_params': self._state.target_params, - 'predictor_params': self._state.params - } - - learner_names = [name for name in names if name not in rnd_variables] - learner_dict = {} - if learner_names: - learner_dict = dict( - zip(learner_names, - self._direct_rl_learner.get_variables(learner_names))) - - variables = [ - rnd_variables.get(name, learner_dict.get(name, None)) for name in names - ] - return variables - - def save(self) -> GlobalTrainingState: - return GlobalTrainingState( - rewarder_state=self._state, - learner_state=self._direct_rl_learner.save()) - - def restore(self, state: GlobalTrainingState): - self._state = state.rewarder_state - self._direct_rl_learner.restore(state.learner_state) + transitions = reverb_utils.replay_sample_to_sars_transition( + sample, is_sequence=self._is_sequence_based + ) + self._state, metrics = self._update(self._state, transitions) + rewards = self._get_reward( + self._state.params, self._state.target_params, transitions + ) + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Increment counts and record the current time + counts = self._counter.increment(steps=1, walltime=elapsed_time) + + # Attempts to write the logs. + self._logger.write({**metrics, **counts}) + + return sample._replace(data=sample.data._replace(reward=rewards)) + + def step(self): + self._direct_rl_learner.step() + + def get_variables(self, names: List[str]) -> List[Any]: + rnd_variables = { + "target_params": self._state.target_params, + "predictor_params": self._state.params, + } + + learner_names = [name for name in names if name not in rnd_variables] + learner_dict = {} + if learner_names: + learner_dict = dict( + zip(learner_names, self._direct_rl_learner.get_variables(learner_names)) + ) + + variables = [ + rnd_variables.get(name, learner_dict.get(name, None)) for name in names + ] + return variables + + def save(self) -> GlobalTrainingState: + return GlobalTrainingState( + rewarder_state=self._state, learner_state=self._direct_rl_learner.save() + ) + + def restore(self, state: GlobalTrainingState): + self._state = state.rewarder_state + self._direct_rl_learner.restore(state.learner_state) diff --git a/acme/agents/jax/rnd/networks.py b/acme/agents/jax/rnd/networks.py index c81ebc1ea0..6d1b030d04 100644 --- a/acme/agents/jax/rnd/networks.py +++ b/acme/agents/jax/rnd/networks.py @@ -18,27 +18,28 @@ import functools from typing import Callable, Generic, Tuple, TypeVar -from acme import specs -from acme import types -from acme.jax import networks as networks_lib -from acme.jax import utils import haiku as hk import jax.numpy as jnp +from acme import specs, types +from acme.jax import networks as networks_lib +from acme.jax import utils -DirectRLNetworks = TypeVar('DirectRLNetworks') +DirectRLNetworks = TypeVar("DirectRLNetworks") @dataclasses.dataclass class RNDNetworks(Generic[DirectRLNetworks]): - """Container of RND networks factories.""" - target: networks_lib.FeedForwardNetwork - predictor: networks_lib.FeedForwardNetwork - # Function from predictor output, target output, and original reward to reward - get_reward: Callable[ - [networks_lib.NetworkOutput, networks_lib.NetworkOutput, jnp.ndarray], - jnp.ndarray] - 
direct_rl_networks: DirectRLNetworks = None + """Container of RND networks factories.""" + + target: networks_lib.FeedForwardNetwork + predictor: networks_lib.FeedForwardNetwork + # Function from predictor output, target output, and original reward to reward + get_reward: Callable[ + [networks_lib.NetworkOutput, networks_lib.NetworkOutput, jnp.ndarray], + jnp.ndarray, + ] + direct_rl_networks: DirectRLNetworks = None # See Appendix A.2 of https://arxiv.org/pdf/1810.12894.pdf @@ -49,10 +50,11 @@ def rnd_reward_fn( intrinsic_reward_coefficient: float = 1.0, extrinsic_reward_coefficient: float = 0.0, ) -> jnp.ndarray: - intrinsic_reward = jnp.mean( - jnp.square(predictor_output - target_output), axis=-1) - return (intrinsic_reward_coefficient * intrinsic_reward + - extrinsic_reward_coefficient * original_reward) + intrinsic_reward = jnp.mean(jnp.square(predictor_output - target_output), axis=-1) + return ( + intrinsic_reward_coefficient * intrinsic_reward + + extrinsic_reward_coefficient * original_reward + ) def make_networks( @@ -62,7 +64,7 @@ def make_networks( intrinsic_reward_coefficient: float = 1.0, extrinsic_reward_coefficient: float = 0.0, ) -> RNDNetworks[DirectRLNetworks]: - """Creates networks used by the agent and returns RNDNetworks. + """Creates networks used by the agent and returns RNDNetworks. Args: spec: Environment spec. @@ -75,36 +77,42 @@ def make_networks( The RND networks. """ - def _rnd_fn(obs, act): - # RND does not use the action but other variants like RED do. - del act - network = networks_lib.LayerNormMLP(list(layer_sizes)) - return network(obs) - - target = hk.without_apply_rng(hk.transform(_rnd_fn)) - predictor = hk.without_apply_rng(hk.transform(_rnd_fn)) - - # Create dummy observations and actions to create network parameters. - dummy_obs = utils.zeros_like(spec.observations) - dummy_obs = utils.add_batch_dim(dummy_obs) - - return RNDNetworks( - target=networks_lib.FeedForwardNetwork( - lambda key: target.init(key, dummy_obs, ()), target.apply), - predictor=networks_lib.FeedForwardNetwork( - lambda key: predictor.init(key, dummy_obs, ()), predictor.apply), - direct_rl_networks=direct_rl_networks, - get_reward=functools.partial( - rnd_reward_fn, - intrinsic_reward_coefficient=intrinsic_reward_coefficient, - extrinsic_reward_coefficient=extrinsic_reward_coefficient)) - - -def compute_rnd_reward(predictor_params: networks_lib.Params, - target_params: networks_lib.Params, - transitions: types.Transition, - networks: RNDNetworks) -> jnp.ndarray: - """Computes the intrinsic RND reward for a given transition. + def _rnd_fn(obs, act): + # RND does not use the action but other variants like RED do. + del act + network = networks_lib.LayerNormMLP(list(layer_sizes)) + return network(obs) + + target = hk.without_apply_rng(hk.transform(_rnd_fn)) + predictor = hk.without_apply_rng(hk.transform(_rnd_fn)) + + # Create dummy observations and actions to create network parameters. 
+ dummy_obs = utils.zeros_like(spec.observations) + dummy_obs = utils.add_batch_dim(dummy_obs) + + return RNDNetworks( + target=networks_lib.FeedForwardNetwork( + lambda key: target.init(key, dummy_obs, ()), target.apply + ), + predictor=networks_lib.FeedForwardNetwork( + lambda key: predictor.init(key, dummy_obs, ()), predictor.apply + ), + direct_rl_networks=direct_rl_networks, + get_reward=functools.partial( + rnd_reward_fn, + intrinsic_reward_coefficient=intrinsic_reward_coefficient, + extrinsic_reward_coefficient=extrinsic_reward_coefficient, + ), + ) + + +def compute_rnd_reward( + predictor_params: networks_lib.Params, + target_params: networks_lib.Params, + transitions: types.Transition, + networks: RNDNetworks, +) -> jnp.ndarray: + """Computes the intrinsic RND reward for a given transition. Args: predictor_params: Parameters of the predictor network. @@ -115,10 +123,10 @@ def compute_rnd_reward(predictor_params: networks_lib.Params, Returns: The rewards as an ndarray. """ - target_output = networks.target.apply(target_params, transitions.observation, - transitions.action) - predictor_output = networks.predictor.apply(predictor_params, - transitions.observation, - transitions.action) - return networks.get_reward(predictor_output, target_output, - transitions.reward) + target_output = networks.target.apply( + target_params, transitions.observation, transitions.action + ) + predictor_output = networks.predictor.apply( + predictor_params, transitions.observation, transitions.action + ) + return networks.get_reward(predictor_output, target_output, transitions.reward) diff --git a/acme/agents/jax/sac/__init__.py b/acme/agents/jax/sac/__init__.py index 38f38d47b4..6199294be1 100644 --- a/acme/agents/jax/sac/__init__.py +++ b/acme/agents/jax/sac/__init__.py @@ -15,10 +15,11 @@ """SAC agent.""" from acme.agents.jax.sac.builder import SACBuilder -from acme.agents.jax.sac.config import SACConfig -from acme.agents.jax.sac.config import target_entropy_from_env_spec +from acme.agents.jax.sac.config import SACConfig, target_entropy_from_env_spec from acme.agents.jax.sac.learning import SACLearner -from acme.agents.jax.sac.networks import apply_policy_and_sample -from acme.agents.jax.sac.networks import default_models_to_snapshot -from acme.agents.jax.sac.networks import make_networks -from acme.agents.jax.sac.networks import SACNetworks +from acme.agents.jax.sac.networks import ( + SACNetworks, + apply_policy_and_sample, + default_models_to_snapshot, + make_networks, +) diff --git a/acme/agents/jax/sac/builder.py b/acme/agents/jax/sac/builder.py index b3d0f6c46c..3237b4f4a4 100644 --- a/acme/agents/jax/sac/builder.py +++ b/acme/agents/jax/sac/builder.py @@ -15,148 +15,156 @@ """SAC Builder.""" from typing import Iterator, List, Optional +import jax +import optax +import reverb +from reverb import rate_limiters + import acme -from acme import adders -from acme import core -from acme import specs +from acme import adders, core, specs from acme.adders import reverb as adders_reverb from acme.agents.jax import actor_core as actor_core_lib -from acme.agents.jax import actors -from acme.agents.jax import builders -from acme.agents.jax import normalization +from acme.agents.jax import actors, builders, normalization from acme.agents.jax.sac import config as sac_config from acme.agents.jax.sac import learning from acme.agents.jax.sac import networks as sac_networks from acme.datasets import reverb as datasets from acme.jax import networks as networks_lib -from acme.jax import utils -from acme.jax import 
variable_utils -from acme.utils import counting -from acme.utils import loggers -import jax -import optax -import reverb -from reverb import rate_limiters +from acme.jax import utils, variable_utils +from acme.utils import counting, loggers @normalization.input_normalization_builder -class SACBuilder(builders.ActorLearnerBuilder[sac_networks.SACNetworks, - actor_core_lib.FeedForwardPolicy, - reverb.ReplaySample]): - """SAC Builder.""" +class SACBuilder( + builders.ActorLearnerBuilder[ + sac_networks.SACNetworks, actor_core_lib.FeedForwardPolicy, reverb.ReplaySample + ] +): + """SAC Builder.""" - def __init__( - self, - config: sac_config.SACConfig, - ): - """Creates a SAC learner, a behavior policy and an eval actor. + def __init__( + self, config: sac_config.SACConfig, + ): + """Creates a SAC learner, a behavior policy and an eval actor. Args: config: a config with SAC hps """ - self._config = config + self._config = config - def make_learner( - self, - random_key: networks_lib.PRNGKey, - networks: sac_networks.SACNetworks, - dataset: Iterator[reverb.ReplaySample], - logger_fn: loggers.LoggerFactory, - environment_spec: specs.EnvironmentSpec, - replay_client: Optional[reverb.Client] = None, - counter: Optional[counting.Counter] = None, - ) -> core.Learner: - del environment_spec, replay_client + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: sac_networks.SACNetworks, + dataset: Iterator[reverb.ReplaySample], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + del environment_spec, replay_client - # Create optimizers - policy_optimizer = optax.adam(learning_rate=self._config.learning_rate) - q_optimizer = optax.adam(learning_rate=self._config.learning_rate) + # Create optimizers + policy_optimizer = optax.adam(learning_rate=self._config.learning_rate) + q_optimizer = optax.adam(learning_rate=self._config.learning_rate) - return learning.SACLearner( - networks=networks, - tau=self._config.tau, - discount=self._config.discount, - entropy_coefficient=self._config.entropy_coefficient, - target_entropy=self._config.target_entropy, - rng=random_key, - reward_scale=self._config.reward_scale, - num_sgd_steps_per_step=self._config.num_sgd_steps_per_step, - policy_optimizer=policy_optimizer, - q_optimizer=q_optimizer, - iterator=dataset, - logger=logger_fn('learner'), - counter=counter) + return learning.SACLearner( + networks=networks, + tau=self._config.tau, + discount=self._config.discount, + entropy_coefficient=self._config.entropy_coefficient, + target_entropy=self._config.target_entropy, + rng=random_key, + reward_scale=self._config.reward_scale, + num_sgd_steps_per_step=self._config.num_sgd_steps_per_step, + policy_optimizer=policy_optimizer, + q_optimizer=q_optimizer, + iterator=dataset, + logger=logger_fn("learner"), + counter=counter, + ) - def make_actor( - self, - random_key: networks_lib.PRNGKey, - policy: actor_core_lib.FeedForwardPolicy, - environment_spec: specs.EnvironmentSpec, - variable_source: Optional[core.VariableSource] = None, - adder: Optional[adders.Adder] = None, - ) -> acme.Actor: - del environment_spec - assert variable_source is not None - actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy) - variable_client = variable_utils.VariableClient( - variable_source, 'policy', device='cpu') - return actors.GenericActor( - actor_core, random_key, variable_client, adder, backend='cpu') + def 
make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: actor_core_lib.FeedForwardPolicy, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + adder: Optional[adders.Adder] = None, + ) -> acme.Actor: + del environment_spec + assert variable_source is not None + actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy) + variable_client = variable_utils.VariableClient( + variable_source, "policy", device="cpu" + ) + return actors.GenericActor( + actor_core, random_key, variable_client, adder, backend="cpu" + ) - def make_replay_tables( - self, - environment_spec: specs.EnvironmentSpec, - policy: actor_core_lib.FeedForwardPolicy, - ) -> List[reverb.Table]: - """Create tables to insert data into.""" - del policy - samples_per_insert_tolerance = ( - self._config.samples_per_insert_tolerance_rate * - self._config.samples_per_insert) - error_buffer = self._config.min_replay_size * samples_per_insert_tolerance - limiter = rate_limiters.SampleToInsertRatio( - min_size_to_sample=self._config.min_replay_size, - samples_per_insert=self._config.samples_per_insert, - error_buffer=error_buffer) - return [ - reverb.Table( - name=self._config.replay_table_name, - sampler=reverb.selectors.Uniform(), - remover=reverb.selectors.Fifo(), - max_size=self._config.max_replay_size, - rate_limiter=limiter, - signature=adders_reverb.NStepTransitionAdder.signature( - environment_spec)) - ] + def make_replay_tables( + self, + environment_spec: specs.EnvironmentSpec, + policy: actor_core_lib.FeedForwardPolicy, + ) -> List[reverb.Table]: + """Create tables to insert data into.""" + del policy + samples_per_insert_tolerance = ( + self._config.samples_per_insert_tolerance_rate + * self._config.samples_per_insert + ) + error_buffer = self._config.min_replay_size * samples_per_insert_tolerance + limiter = rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._config.min_replay_size, + samples_per_insert=self._config.samples_per_insert, + error_buffer=error_buffer, + ) + return [ + reverb.Table( + name=self._config.replay_table_name, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._config.max_replay_size, + rate_limiter=limiter, + signature=adders_reverb.NStepTransitionAdder.signature( + environment_spec + ), + ) + ] - def make_dataset_iterator( - self, replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]: - """Create a dataset iterator to use for learning/updating the agent.""" - dataset = datasets.make_reverb_dataset( - table=self._config.replay_table_name, - server_address=replay_client.server_address, - batch_size=(self._config.batch_size * - self._config.num_sgd_steps_per_step), - prefetch_size=self._config.prefetch_size) - return utils.device_put(dataset.as_numpy_iterator(), jax.devices()[0]) + def make_dataset_iterator( + self, replay_client: reverb.Client + ) -> Iterator[reverb.ReplaySample]: + """Create a dataset iterator to use for learning/updating the agent.""" + dataset = datasets.make_reverb_dataset( + table=self._config.replay_table_name, + server_address=replay_client.server_address, + batch_size=(self._config.batch_size * self._config.num_sgd_steps_per_step), + prefetch_size=self._config.prefetch_size, + ) + return utils.device_put(dataset.as_numpy_iterator(), jax.devices()[0]) - def make_adder( - self, replay_client: reverb.Client, - environment_spec: Optional[specs.EnvironmentSpec], - policy: Optional[actor_core_lib.FeedForwardPolicy] - ) -> Optional[adders.Adder]: - """Create 
an adder which records data generated by the actor/environment.""" - del environment_spec, policy - return adders_reverb.NStepTransitionAdder( - priority_fns={self._config.replay_table_name: None}, - client=replay_client, - n_step=self._config.n_step, - discount=self._config.discount) + def make_adder( + self, + replay_client: reverb.Client, + environment_spec: Optional[specs.EnvironmentSpec], + policy: Optional[actor_core_lib.FeedForwardPolicy], + ) -> Optional[adders.Adder]: + """Create an adder which records data generated by the actor/environment.""" + del environment_spec, policy + return adders_reverb.NStepTransitionAdder( + priority_fns={self._config.replay_table_name: None}, + client=replay_client, + n_step=self._config.n_step, + discount=self._config.discount, + ) - def make_policy(self, - networks: sac_networks.SACNetworks, - environment_spec: specs.EnvironmentSpec, - evaluation: bool = False) -> actor_core_lib.FeedForwardPolicy: - """Construct the policy.""" - del environment_spec - return sac_networks.apply_policy_and_sample(networks, eval_mode=evaluation) + def make_policy( + self, + networks: sac_networks.SACNetworks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool = False, + ) -> actor_core_lib.FeedForwardPolicy: + """Construct the policy.""" + del environment_spec + return sac_networks.apply_policy_and_sample(networks, eval_mode=evaluation) diff --git a/acme/agents/jax/sac/config.py b/acme/agents/jax/sac/config.py index a60e69b295..4e08044640 100644 --- a/acme/agents/jax/sac/config.py +++ b/acme/agents/jax/sac/config.py @@ -16,49 +16,50 @@ import dataclasses from typing import Any, Optional +import numpy as onp + from acme import specs from acme.adders import reverb as adders_reverb from acme.agents.jax import normalization -import numpy as onp @dataclasses.dataclass class SACConfig(normalization.InputNormalizerConfig): - """Configuration options for SAC.""" - # Loss options - batch_size: int = 256 - learning_rate: float = 3e-4 - reward_scale: float = 1 - discount: float = 0.99 - n_step: int = 1 - # Coefficient applied to the entropy bonus. If None, an adaptative - # coefficient will be used. - entropy_coefficient: Optional[float] = None - target_entropy: float = 0.0 - # Target smoothing coefficient. - tau: float = 0.005 - - # Replay options - min_replay_size: int = 10000 - max_replay_size: int = 1000000 - replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE - prefetch_size: int = 4 - samples_per_insert: float = 256 - # Rate to be used for the SampleToInsertRatio rate limitter tolerance. - # See a formula in make_replay_tables for more details. - samples_per_insert_tolerance_rate: float = 0.1 - - # How many gradient updates to perform per step. - num_sgd_steps_per_step: int = 1 - - input_normalization: Optional[normalization.NormalizationConfig] = None + """Configuration options for SAC.""" + + # Loss options + batch_size: int = 256 + learning_rate: float = 3e-4 + reward_scale: float = 1 + discount: float = 0.99 + n_step: int = 1 + # Coefficient applied to the entropy bonus. If None, an adaptative + # coefficient will be used. + entropy_coefficient: Optional[float] = None + target_entropy: float = 0.0 + # Target smoothing coefficient. + tau: float = 0.005 + + # Replay options + min_replay_size: int = 10000 + max_replay_size: int = 1000000 + replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE + prefetch_size: int = 4 + samples_per_insert: float = 256 + # Rate to be used for the SampleToInsertRatio rate limitter tolerance. 
+ # See a formula in make_replay_tables for more details. + samples_per_insert_tolerance_rate: float = 0.1 + + # How many gradient updates to perform per step. + num_sgd_steps_per_step: int = 1 + + input_normalization: Optional[normalization.NormalizationConfig] = None def target_entropy_from_env_spec( - spec: specs.EnvironmentSpec, - target_entropy_per_dimension: Optional[float] = None, + spec: specs.EnvironmentSpec, target_entropy_per_dimension: Optional[float] = None, ) -> float: - """A heuristic to determine a target entropy. + """A heuristic to determine a target entropy. If target_entropy_per_dimension is not specified, the target entropy is computed as "-num_actions", otherwise it is @@ -72,28 +73,29 @@ def target_entropy_from_env_spec( target entropy """ - def get_num_actions(action_spec: Any) -> float: - """Returns a number of actions in the spec.""" - if isinstance(action_spec, specs.BoundedArray): - return onp.prod(action_spec.shape, dtype=int) - elif isinstance(action_spec, tuple): - return sum(get_num_actions(subspace) for subspace in action_spec) + def get_num_actions(action_spec: Any) -> float: + """Returns a number of actions in the spec.""" + if isinstance(action_spec, specs.BoundedArray): + return onp.prod(action_spec.shape, dtype=int) + elif isinstance(action_spec, tuple): + return sum(get_num_actions(subspace) for subspace in action_spec) + else: + raise ValueError("Unknown action space type.") + + num_actions = get_num_actions(spec.actions) + if target_entropy_per_dimension is None: + if not isinstance(spec.actions, specs.BoundedArray) or isinstance( + spec.actions, specs.DiscreteArray + ): + raise ValueError( + "Only accept BoundedArrays for automatic " + f"target_entropy, got: {spec.actions}" + ) + if not onp.all(spec.actions.minimum == -1.0): + raise ValueError(f"Minimum expected to be -1, got: {spec.actions.minimum}") + if not onp.all(spec.actions.maximum == 1.0): + raise ValueError(f"Maximum expected to be 1, got: {spec.actions.maximum}") + + return -num_actions else: - raise ValueError('Unknown action space type.') - - num_actions = get_num_actions(spec.actions) - if target_entropy_per_dimension is None: - if not isinstance(spec.actions, specs.BoundedArray) or isinstance( - spec.actions, specs.DiscreteArray): - raise ValueError('Only accept BoundedArrays for automatic ' - f'target_entropy, got: {spec.actions}') - if not onp.all(spec.actions.minimum == -1.): - raise ValueError( - f'Minimum expected to be -1, got: {spec.actions.minimum}') - if not onp.all(spec.actions.maximum == 1.): - raise ValueError( - f'Maximum expected to be 1, got: {spec.actions.maximum}') - - return -num_actions - else: - return target_entropy_per_dimension * num_actions + return target_entropy_per_dimension * num_actions diff --git a/acme/agents/jax/sac/learning.py b/acme/agents/jax/sac/learning.py index c10b11e1ce..11f7fb06e1 100644 --- a/acme/agents/jax/sac/learning.py +++ b/acme/agents/jax/sac/learning.py @@ -17,52 +17,54 @@ import time from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Tuple +import jax +import jax.numpy as jnp +import optax +import reverb + import acme from acme import types from acme.agents.jax.sac import networks as sac_networks from acme.jax import networks as networks_lib from acme.jax import utils -from acme.utils import counting -from acme.utils import loggers -import jax -import jax.numpy as jnp -import optax -import reverb +from acme.utils import counting, loggers class TrainingState(NamedTuple): - """Contains training state for the learner.""" 
- policy_optimizer_state: optax.OptState - q_optimizer_state: optax.OptState - policy_params: networks_lib.Params - q_params: networks_lib.Params - target_q_params: networks_lib.Params - key: networks_lib.PRNGKey - alpha_optimizer_state: Optional[optax.OptState] = None - alpha_params: Optional[networks_lib.Params] = None + """Contains training state for the learner.""" + + policy_optimizer_state: optax.OptState + q_optimizer_state: optax.OptState + policy_params: networks_lib.Params + q_params: networks_lib.Params + target_q_params: networks_lib.Params + key: networks_lib.PRNGKey + alpha_optimizer_state: Optional[optax.OptState] = None + alpha_params: Optional[networks_lib.Params] = None class SACLearner(acme.Learner): - """SAC learner.""" - - _state: TrainingState - - def __init__( - self, - networks: sac_networks.SACNetworks, - rng: jnp.ndarray, - iterator: Iterator[reverb.ReplaySample], - policy_optimizer: optax.GradientTransformation, - q_optimizer: optax.GradientTransformation, - tau: float = 0.005, - reward_scale: float = 1.0, - discount: float = 0.99, - entropy_coefficient: Optional[float] = None, - target_entropy: float = 0, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - num_sgd_steps_per_step: int = 1): - """Initialize the SAC learner. + """SAC learner.""" + + _state: TrainingState + + def __init__( + self, + networks: sac_networks.SACNetworks, + rng: jnp.ndarray, + iterator: Iterator[reverb.ReplaySample], + policy_optimizer: optax.GradientTransformation, + q_optimizer: optax.GradientTransformation, + tau: float = 0.005, + reward_scale: float = 1.0, + discount: float = 0.99, + entropy_coefficient: Optional[float] = None, + target_entropy: float = 0, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + num_sgd_steps_per_step: int = 1, + ): + """Initialize the SAC learner. Args: networks: SAC networks @@ -81,209 +83,238 @@ def __init__( logger: logger object to be used by learner. num_sgd_steps_per_step: number of sgd steps to perform per learner 'step'. """ - adaptive_entropy_coefficient = entropy_coefficient is None - if adaptive_entropy_coefficient: - # alpha is the temperature parameter that determines the relative - # importance of the entropy term versus the reward. 
- log_alpha = jnp.asarray(0., dtype=jnp.float32) - alpha_optimizer = optax.adam(learning_rate=3e-4) - alpha_optimizer_state = alpha_optimizer.init(log_alpha) - else: - if target_entropy: - raise ValueError('target_entropy should not be set when ' - 'entropy_coefficient is provided') - - def alpha_loss(log_alpha: jnp.ndarray, - policy_params: networks_lib.Params, - transitions: types.Transition, - key: networks_lib.PRNGKey) -> jnp.ndarray: - """Eq 18 from https://arxiv.org/pdf/1812.05905.pdf.""" - dist_params = networks.policy_network.apply( - policy_params, transitions.observation) - action = networks.sample(dist_params, key) - log_prob = networks.log_prob(dist_params, action) - alpha = jnp.exp(log_alpha) - alpha_loss = alpha * jax.lax.stop_gradient(-log_prob - target_entropy) - return jnp.mean(alpha_loss) - - def critic_loss(q_params: networks_lib.Params, - policy_params: networks_lib.Params, - target_q_params: networks_lib.Params, - alpha: jnp.ndarray, - transitions: types.Transition, - key: networks_lib.PRNGKey) -> jnp.ndarray: - q_old_action = networks.q_network.apply( - q_params, transitions.observation, transitions.action) - next_dist_params = networks.policy_network.apply( - policy_params, transitions.next_observation) - next_action = networks.sample(next_dist_params, key) - next_log_prob = networks.log_prob(next_dist_params, next_action) - next_q = networks.q_network.apply( - target_q_params, transitions.next_observation, next_action) - next_v = jnp.min(next_q, axis=-1) - alpha * next_log_prob - target_q = jax.lax.stop_gradient(transitions.reward * reward_scale + - transitions.discount * discount * next_v) - q_error = q_old_action - jnp.expand_dims(target_q, -1) - q_loss = 0.5 * jnp.mean(jnp.square(q_error)) - return q_loss - - def actor_loss(policy_params: networks_lib.Params, - q_params: networks_lib.Params, - alpha: jnp.ndarray, - transitions: types.Transition, - key: networks_lib.PRNGKey) -> jnp.ndarray: - dist_params = networks.policy_network.apply( - policy_params, transitions.observation) - action = networks.sample(dist_params, key) - log_prob = networks.log_prob(dist_params, action) - q_action = networks.q_network.apply( - q_params, transitions.observation, action) - min_q = jnp.min(q_action, axis=-1) - actor_loss = alpha * log_prob - min_q - return jnp.mean(actor_loss) - - alpha_grad = jax.value_and_grad(alpha_loss) - critic_grad = jax.value_and_grad(critic_loss) - actor_grad = jax.value_and_grad(actor_loss) - - def update_step( - state: TrainingState, - transitions: types.Transition, - ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: - - key, key_alpha, key_critic, key_actor = jax.random.split(state.key, 4) - if adaptive_entropy_coefficient: - alpha_loss, alpha_grads = alpha_grad(state.alpha_params, - state.policy_params, transitions, - key_alpha) - alpha = jnp.exp(state.alpha_params) - else: - alpha = entropy_coefficient - critic_loss, critic_grads = critic_grad(state.q_params, - state.policy_params, - state.target_q_params, alpha, - transitions, key_critic) - actor_loss, actor_grads = actor_grad(state.policy_params, state.q_params, - alpha, transitions, key_actor) - - # Apply policy gradients - actor_update, policy_optimizer_state = policy_optimizer.update( - actor_grads, state.policy_optimizer_state) - policy_params = optax.apply_updates(state.policy_params, actor_update) - - # Apply critic gradients - critic_update, q_optimizer_state = q_optimizer.update( - critic_grads, state.q_optimizer_state) - q_params = optax.apply_updates(state.q_params, critic_update) - - 
new_target_q_params = jax.tree_map(lambda x, y: x * (1 - tau) + y * tau, - state.target_q_params, q_params) - - metrics = { - 'critic_loss': critic_loss, - 'actor_loss': actor_loss, - } - - new_state = TrainingState( - policy_optimizer_state=policy_optimizer_state, - q_optimizer_state=q_optimizer_state, - policy_params=policy_params, - q_params=q_params, - target_q_params=new_target_q_params, - key=key, - ) - if adaptive_entropy_coefficient: - # Apply alpha gradients - alpha_update, alpha_optimizer_state = alpha_optimizer.update( - alpha_grads, state.alpha_optimizer_state) - alpha_params = optax.apply_updates(state.alpha_params, alpha_update) - metrics.update({ - 'alpha_loss': alpha_loss, - 'alpha': jnp.exp(alpha_params), - }) - new_state = new_state._replace( - alpha_optimizer_state=alpha_optimizer_state, - alpha_params=alpha_params) - - metrics['rewards_mean'] = jnp.mean( - jnp.abs(jnp.mean(transitions.reward, axis=0))) - metrics['rewards_std'] = jnp.std(transitions.reward, axis=0) - - return new_state, metrics - - # General learner book-keeping and loggers. - self._counter = counter or counting.Counter() - self._logger = logger or loggers.make_default_logger( - 'learner', - asynchronous=True, - serialize_fn=utils.fetch_devicearray, - steps_key=self._counter.get_steps_key()) - - # Iterator on demonstration transitions. - self._iterator = iterator - - update_step = utils.process_multiple_batches(update_step, - num_sgd_steps_per_step) - # Use the JIT compiler. - self._update_step = jax.jit(update_step) - - def make_initial_state(key: networks_lib.PRNGKey) -> TrainingState: - """Initialises the training state (parameters and optimiser state).""" - key_policy, key_q, key = jax.random.split(key, 3) - - policy_params = networks.policy_network.init(key_policy) - policy_optimizer_state = policy_optimizer.init(policy_params) - - q_params = networks.q_network.init(key_q) - q_optimizer_state = q_optimizer.init(q_params) - - state = TrainingState( - policy_optimizer_state=policy_optimizer_state, - q_optimizer_state=q_optimizer_state, - policy_params=policy_params, - q_params=q_params, - target_q_params=q_params, - key=key) - - if adaptive_entropy_coefficient: - state = state._replace(alpha_optimizer_state=alpha_optimizer_state, - alpha_params=log_alpha) - return state - - # Create initial state. - self._state = make_initial_state(rng) - - # Do not record timestamps until after the first learning step is done. - # This is to avoid including the time it takes for actors to come online and - # fill the replay buffer. - self._timestamp = None - - def step(self): - sample = next(self._iterator) - transitions = types.Transition(*sample.data) - - self._state, metrics = self._update_step(self._state, transitions) - - # Compute elapsed time. - timestamp = time.time() - elapsed_time = timestamp - self._timestamp if self._timestamp else 0 - self._timestamp = timestamp - - # Increment counts and record the current time - counts = self._counter.increment(steps=1, walltime=elapsed_time) - - # Attempts to write the logs. 
- self._logger.write({**metrics, **counts}) - - def get_variables(self, names: List[str]) -> List[Any]: - variables = { - 'policy': self._state.policy_params, - 'critic': self._state.q_params, - } - return [variables[name] for name in names] - - def save(self) -> TrainingState: - return self._state - - def restore(self, state: TrainingState): - self._state = state + adaptive_entropy_coefficient = entropy_coefficient is None + if adaptive_entropy_coefficient: + # alpha is the temperature parameter that determines the relative + # importance of the entropy term versus the reward. + log_alpha = jnp.asarray(0.0, dtype=jnp.float32) + alpha_optimizer = optax.adam(learning_rate=3e-4) + alpha_optimizer_state = alpha_optimizer.init(log_alpha) + else: + if target_entropy: + raise ValueError( + "target_entropy should not be set when " + "entropy_coefficient is provided" + ) + + def alpha_loss( + log_alpha: jnp.ndarray, + policy_params: networks_lib.Params, + transitions: types.Transition, + key: networks_lib.PRNGKey, + ) -> jnp.ndarray: + """Eq 18 from https://arxiv.org/pdf/1812.05905.pdf.""" + dist_params = networks.policy_network.apply( + policy_params, transitions.observation + ) + action = networks.sample(dist_params, key) + log_prob = networks.log_prob(dist_params, action) + alpha = jnp.exp(log_alpha) + alpha_loss = alpha * jax.lax.stop_gradient(-log_prob - target_entropy) + return jnp.mean(alpha_loss) + + def critic_loss( + q_params: networks_lib.Params, + policy_params: networks_lib.Params, + target_q_params: networks_lib.Params, + alpha: jnp.ndarray, + transitions: types.Transition, + key: networks_lib.PRNGKey, + ) -> jnp.ndarray: + q_old_action = networks.q_network.apply( + q_params, transitions.observation, transitions.action + ) + next_dist_params = networks.policy_network.apply( + policy_params, transitions.next_observation + ) + next_action = networks.sample(next_dist_params, key) + next_log_prob = networks.log_prob(next_dist_params, next_action) + next_q = networks.q_network.apply( + target_q_params, transitions.next_observation, next_action + ) + next_v = jnp.min(next_q, axis=-1) - alpha * next_log_prob + target_q = jax.lax.stop_gradient( + transitions.reward * reward_scale + + transitions.discount * discount * next_v + ) + q_error = q_old_action - jnp.expand_dims(target_q, -1) + q_loss = 0.5 * jnp.mean(jnp.square(q_error)) + return q_loss + + def actor_loss( + policy_params: networks_lib.Params, + q_params: networks_lib.Params, + alpha: jnp.ndarray, + transitions: types.Transition, + key: networks_lib.PRNGKey, + ) -> jnp.ndarray: + dist_params = networks.policy_network.apply( + policy_params, transitions.observation + ) + action = networks.sample(dist_params, key) + log_prob = networks.log_prob(dist_params, action) + q_action = networks.q_network.apply( + q_params, transitions.observation, action + ) + min_q = jnp.min(q_action, axis=-1) + actor_loss = alpha * log_prob - min_q + return jnp.mean(actor_loss) + + alpha_grad = jax.value_and_grad(alpha_loss) + critic_grad = jax.value_and_grad(critic_loss) + actor_grad = jax.value_and_grad(actor_loss) + + def update_step( + state: TrainingState, transitions: types.Transition, + ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: + + key, key_alpha, key_critic, key_actor = jax.random.split(state.key, 4) + if adaptive_entropy_coefficient: + alpha_loss, alpha_grads = alpha_grad( + state.alpha_params, state.policy_params, transitions, key_alpha + ) + alpha = jnp.exp(state.alpha_params) + else: + alpha = entropy_coefficient + critic_loss, 
critic_grads = critic_grad( + state.q_params, + state.policy_params, + state.target_q_params, + alpha, + transitions, + key_critic, + ) + actor_loss, actor_grads = actor_grad( + state.policy_params, state.q_params, alpha, transitions, key_actor + ) + + # Apply policy gradients + actor_update, policy_optimizer_state = policy_optimizer.update( + actor_grads, state.policy_optimizer_state + ) + policy_params = optax.apply_updates(state.policy_params, actor_update) + + # Apply critic gradients + critic_update, q_optimizer_state = q_optimizer.update( + critic_grads, state.q_optimizer_state + ) + q_params = optax.apply_updates(state.q_params, critic_update) + + new_target_q_params = jax.tree_map( + lambda x, y: x * (1 - tau) + y * tau, state.target_q_params, q_params + ) + + metrics = { + "critic_loss": critic_loss, + "actor_loss": actor_loss, + } + + new_state = TrainingState( + policy_optimizer_state=policy_optimizer_state, + q_optimizer_state=q_optimizer_state, + policy_params=policy_params, + q_params=q_params, + target_q_params=new_target_q_params, + key=key, + ) + if adaptive_entropy_coefficient: + # Apply alpha gradients + alpha_update, alpha_optimizer_state = alpha_optimizer.update( + alpha_grads, state.alpha_optimizer_state + ) + alpha_params = optax.apply_updates(state.alpha_params, alpha_update) + metrics.update( + {"alpha_loss": alpha_loss, "alpha": jnp.exp(alpha_params),} + ) + new_state = new_state._replace( + alpha_optimizer_state=alpha_optimizer_state, + alpha_params=alpha_params, + ) + + metrics["rewards_mean"] = jnp.mean( + jnp.abs(jnp.mean(transitions.reward, axis=0)) + ) + metrics["rewards_std"] = jnp.std(transitions.reward, axis=0) + + return new_state, metrics + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger( + "learner", + asynchronous=True, + serialize_fn=utils.fetch_devicearray, + steps_key=self._counter.get_steps_key(), + ) + + # Iterator on demonstration transitions. + self._iterator = iterator + + update_step = utils.process_multiple_batches( + update_step, num_sgd_steps_per_step + ) + # Use the JIT compiler. + self._update_step = jax.jit(update_step) + + def make_initial_state(key: networks_lib.PRNGKey) -> TrainingState: + """Initialises the training state (parameters and optimiser state).""" + key_policy, key_q, key = jax.random.split(key, 3) + + policy_params = networks.policy_network.init(key_policy) + policy_optimizer_state = policy_optimizer.init(policy_params) + + q_params = networks.q_network.init(key_q) + q_optimizer_state = q_optimizer.init(q_params) + + state = TrainingState( + policy_optimizer_state=policy_optimizer_state, + q_optimizer_state=q_optimizer_state, + policy_params=policy_params, + q_params=q_params, + target_q_params=q_params, + key=key, + ) + + if adaptive_entropy_coefficient: + state = state._replace( + alpha_optimizer_state=alpha_optimizer_state, alpha_params=log_alpha + ) + return state + + # Create initial state. + self._state = make_initial_state(rng) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + def step(self): + sample = next(self._iterator) + transitions = types.Transition(*sample.data) + + self._state, metrics = self._update_step(self._state, transitions) + + # Compute elapsed time. 
+ timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Increment counts and record the current time + counts = self._counter.increment(steps=1, walltime=elapsed_time) + + # Attempts to write the logs. + self._logger.write({**metrics, **counts}) + + def get_variables(self, names: List[str]) -> List[Any]: + variables = { + "policy": self._state.policy_params, + "critic": self._state.q_params, + } + return [variables[name] for name in names] + + def save(self) -> TrainingState: + return self._state + + def restore(self, state: TrainingState): + self._state = state diff --git a/acme/agents/jax/sac/networks.py b/acme/agents/jax/sac/networks.py index 10ebfbb2bf..a0c452e5fc 100644 --- a/acme/agents/jax/sac/networks.py +++ b/acme/agents/jax/sac/networks.py @@ -17,127 +17,139 @@ import dataclasses from typing import Optional, Tuple -from acme import core -from acme import specs -from acme.agents.jax import actor_core as actor_core_lib -from acme.jax import networks as networks_lib -from acme.jax import types -from acme.jax import utils import haiku as hk import jax import jax.numpy as jnp import numpy as np +from acme import core, specs +from acme.agents.jax import actor_core as actor_core_lib +from acme.jax import networks as networks_lib +from acme.jax import types, utils + @dataclasses.dataclass class SACNetworks: - """Network and pure functions for the SAC agent..""" - policy_network: networks_lib.FeedForwardNetwork - q_network: networks_lib.FeedForwardNetwork - log_prob: networks_lib.LogProbFn - sample: networks_lib.SampleFn - sample_eval: Optional[networks_lib.SampleFn] = None - - -def default_models_to_snapshot( - networks: SACNetworks, - spec: specs.EnvironmentSpec): - """Defines default models to be snapshotted.""" - dummy_obs = utils.zeros_like(spec.observations) - dummy_action = utils.zeros_like(spec.actions) - dummy_key = jax.random.PRNGKey(0) - - def q_network( - source: core.VariableSource) -> types.ModelToSnapshot: - params = source.get_variables(['critic'])[0] - return types.ModelToSnapshot( - networks.q_network.apply, params, - {'obs': dummy_obs, 'action': dummy_action}) - - def default_training_actor( - source: core.VariableSource) -> types.ModelToSnapshot: - params = source.get_variables(['policy'])[0] - return types.ModelToSnapshot(apply_policy_and_sample(networks, False), - params, - {'key': dummy_key, 'obs': dummy_obs}) - - def default_eval_actor( - source: core.VariableSource) -> types.ModelToSnapshot: - params = source.get_variables(['policy'])[0] - return types.ModelToSnapshot( - apply_policy_and_sample(networks, True), params, - {'key': dummy_key, 'obs': dummy_obs}) - - return { - 'q_network': q_network, - 'default_training_actor': default_training_actor, - 'default_eval_actor': default_eval_actor, - } + """Network and pure functions for the SAC agent..""" + + policy_network: networks_lib.FeedForwardNetwork + q_network: networks_lib.FeedForwardNetwork + log_prob: networks_lib.LogProbFn + sample: networks_lib.SampleFn + sample_eval: Optional[networks_lib.SampleFn] = None + + +def default_models_to_snapshot(networks: SACNetworks, spec: specs.EnvironmentSpec): + """Defines default models to be snapshotted.""" + dummy_obs = utils.zeros_like(spec.observations) + dummy_action = utils.zeros_like(spec.actions) + dummy_key = jax.random.PRNGKey(0) + + def q_network(source: core.VariableSource) -> types.ModelToSnapshot: + params = source.get_variables(["critic"])[0] + return types.ModelToSnapshot( + 
networks.q_network.apply, params, {"obs": dummy_obs, "action": dummy_action} + ) + + def default_training_actor(source: core.VariableSource) -> types.ModelToSnapshot: + params = source.get_variables(["policy"])[0] + return types.ModelToSnapshot( + apply_policy_and_sample(networks, False), + params, + {"key": dummy_key, "obs": dummy_obs}, + ) + + def default_eval_actor(source: core.VariableSource) -> types.ModelToSnapshot: + params = source.get_variables(["policy"])[0] + return types.ModelToSnapshot( + apply_policy_and_sample(networks, True), + params, + {"key": dummy_key, "obs": dummy_obs}, + ) + + return { + "q_network": q_network, + "default_training_actor": default_training_actor, + "default_eval_actor": default_eval_actor, + } def apply_policy_and_sample( - networks: SACNetworks, - eval_mode: bool = False) -> actor_core_lib.FeedForwardPolicy: - """Returns a function that computes actions.""" - sample_fn = networks.sample if not eval_mode else networks.sample_eval - if not sample_fn: - raise ValueError('sample function is not provided') + networks: SACNetworks, eval_mode: bool = False +) -> actor_core_lib.FeedForwardPolicy: + """Returns a function that computes actions.""" + sample_fn = networks.sample if not eval_mode else networks.sample_eval + if not sample_fn: + raise ValueError("sample function is not provided") + + def apply_and_sample(params, key, obs): + return sample_fn(networks.policy_network.apply(params, obs), key) - def apply_and_sample(params, key, obs): - return sample_fn(networks.policy_network.apply(params, obs), key) - return apply_and_sample + return apply_and_sample def make_networks( - spec: specs.EnvironmentSpec, - hidden_layer_sizes: Tuple[int, ...] = (256, 256)) -> SACNetworks: - """Creates networks used by the agent.""" - - num_dimensions = np.prod(spec.actions.shape, dtype=int) - - def _actor_fn(obs): - network = hk.Sequential([ - hk.nets.MLP( - list(hidden_layer_sizes), - w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'), - activation=jax.nn.relu, - activate_final=True), - networks_lib.NormalTanhDistribution(num_dimensions), - ]) - return network(obs) - - def _critic_fn(obs, action): - network1 = hk.Sequential([ - hk.nets.MLP( - list(hidden_layer_sizes) + [1], - w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'), - activation=jax.nn.relu), - ]) - network2 = hk.Sequential([ - hk.nets.MLP( - list(hidden_layer_sizes) + [1], - w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'), - activation=jax.nn.relu), - ]) - input_ = jnp.concatenate([obs, action], axis=-1) - value1 = network1(input_) - value2 = network2(input_) - return jnp.concatenate([value1, value2], axis=-1) - - policy = hk.without_apply_rng(hk.transform(_actor_fn)) - critic = hk.without_apply_rng(hk.transform(_critic_fn)) - - # Create dummy observations and actions to create network parameters. - dummy_action = utils.zeros_like(spec.actions) - dummy_obs = utils.zeros_like(spec.observations) - dummy_action = utils.add_batch_dim(dummy_action) - dummy_obs = utils.add_batch_dim(dummy_obs) - - return SACNetworks( - policy_network=networks_lib.FeedForwardNetwork( - lambda key: policy.init(key, dummy_obs), policy.apply), - q_network=networks_lib.FeedForwardNetwork( - lambda key: critic.init(key, dummy_obs, dummy_action), critic.apply), - log_prob=lambda params, actions: params.log_prob(actions), - sample=lambda params, key: params.sample(seed=key), - sample_eval=lambda params, key: params.mode()) + spec: specs.EnvironmentSpec, hidden_layer_sizes: Tuple[int, ...] 
= (256, 256) +) -> SACNetworks: + """Creates networks used by the agent.""" + + num_dimensions = np.prod(spec.actions.shape, dtype=int) + + def _actor_fn(obs): + network = hk.Sequential( + [ + hk.nets.MLP( + list(hidden_layer_sizes), + w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), + activation=jax.nn.relu, + activate_final=True, + ), + networks_lib.NormalTanhDistribution(num_dimensions), + ] + ) + return network(obs) + + def _critic_fn(obs, action): + network1 = hk.Sequential( + [ + hk.nets.MLP( + list(hidden_layer_sizes) + [1], + w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), + activation=jax.nn.relu, + ), + ] + ) + network2 = hk.Sequential( + [ + hk.nets.MLP( + list(hidden_layer_sizes) + [1], + w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), + activation=jax.nn.relu, + ), + ] + ) + input_ = jnp.concatenate([obs, action], axis=-1) + value1 = network1(input_) + value2 = network2(input_) + return jnp.concatenate([value1, value2], axis=-1) + + policy = hk.without_apply_rng(hk.transform(_actor_fn)) + critic = hk.without_apply_rng(hk.transform(_critic_fn)) + + # Create dummy observations and actions to create network parameters. + dummy_action = utils.zeros_like(spec.actions) + dummy_obs = utils.zeros_like(spec.observations) + dummy_action = utils.add_batch_dim(dummy_action) + dummy_obs = utils.add_batch_dim(dummy_obs) + + return SACNetworks( + policy_network=networks_lib.FeedForwardNetwork( + lambda key: policy.init(key, dummy_obs), policy.apply + ), + q_network=networks_lib.FeedForwardNetwork( + lambda key: critic.init(key, dummy_obs, dummy_action), critic.apply + ), + log_prob=lambda params, actions: params.log_prob(actions), + sample=lambda params, key: params.sample(seed=key), + sample_eval=lambda params, key: params.mode(), + ) diff --git a/acme/agents/jax/sqil/builder.py b/acme/agents/jax/sqil/builder.py index 0c497737f5..ba02e4b89d 100644 --- a/acme/agents/jax/sqil/builder.py +++ b/acme/agents/jax/sqil/builder.py @@ -16,27 +16,27 @@ from typing import Callable, Generic, Iterator, List, Optional -from acme import adders -from acme import core -from acme import specs -from acme import types -from acme.agents.jax import builders -from acme.jax import networks as networks_lib -from acme.jax import utils -from acme.jax.imitation_learning_types import DirectPolicyNetwork, DirectRLNetworks # pylint: disable=g-multiple-import -from acme.utils import counting -from acme.utils import loggers import jax import numpy as np import reverb import tree +from acme import adders, core, specs, types +from acme.agents.jax import builders +from acme.jax import networks as networks_lib +from acme.jax import utils +from acme.jax.imitation_learning_types import ( # pylint: disable=g-multiple-import + DirectPolicyNetwork, + DirectRLNetworks, +) +from acme.utils import counting, loggers + def _generate_sqil_samples( demonstration_iterator: Iterator[types.Transition], - replay_iterator: Iterator[reverb.ReplaySample] + replay_iterator: Iterator[reverb.ReplaySample], ) -> Iterator[reverb.ReplaySample]: - """Generator which creates the sample iterator for SQIL. + """Generator which creates the sample iterator for SQIL. Args: demonstration_iterator: Iterator of demonstrations. @@ -46,42 +46,49 @@ def _generate_sqil_samples( Samples having a mix of demonstrations with reward 1 and replay samples with reward 0. 
""" - for demonstrations, replay_sample in zip(demonstration_iterator, - replay_iterator): - demonstrations = demonstrations._replace( - reward=np.ones_like(demonstrations.reward)) - - replay_transitions = replay_sample.data - replay_transitions = replay_transitions._replace( - reward=np.zeros_like(replay_transitions.reward)) - - double_batch = tree.map_structure(lambda x, y: np.concatenate([x, y]), - demonstrations, replay_transitions) - - # Split the double batch in an interleaving fashion. - # e.g [1, 2, 3, 4 ,5 ,6] -> [1, 3, 5] and [2, 4, 6] - yield reverb.ReplaySample( - info=replay_sample.info, - data=tree.map_structure(lambda x: x[0::2], double_batch)) - yield reverb.ReplaySample( - info=replay_sample.info, - data=tree.map_structure(lambda x: x[1::2], double_batch)) - - -class SQILBuilder(Generic[DirectRLNetworks, DirectPolicyNetwork], - builders.ActorLearnerBuilder[DirectRLNetworks, - DirectPolicyNetwork, - reverb.ReplaySample]): - """SQIL Builder (https://openreview.net/pdf?id=S1xKd24twB).""" - - def __init__(self, - rl_agent: builders.ActorLearnerBuilder[DirectRLNetworks, - DirectPolicyNetwork, - reverb.ReplaySample], - rl_agent_batch_size: int, - make_demonstrations: Callable[[int], - Iterator[types.Transition]]): - """Builds a SQIL agent. + for demonstrations, replay_sample in zip(demonstration_iterator, replay_iterator): + demonstrations = demonstrations._replace( + reward=np.ones_like(demonstrations.reward) + ) + + replay_transitions = replay_sample.data + replay_transitions = replay_transitions._replace( + reward=np.zeros_like(replay_transitions.reward) + ) + + double_batch = tree.map_structure( + lambda x, y: np.concatenate([x, y]), demonstrations, replay_transitions + ) + + # Split the double batch in an interleaving fashion. + # e.g [1, 2, 3, 4 ,5 ,6] -> [1, 3, 5] and [2, 4, 6] + yield reverb.ReplaySample( + info=replay_sample.info, + data=tree.map_structure(lambda x: x[0::2], double_batch), + ) + yield reverb.ReplaySample( + info=replay_sample.info, + data=tree.map_structure(lambda x: x[1::2], double_batch), + ) + + +class SQILBuilder( + Generic[DirectRLNetworks, DirectPolicyNetwork], + builders.ActorLearnerBuilder[ + DirectRLNetworks, DirectPolicyNetwork, reverb.ReplaySample + ], +): + """SQIL Builder (https://openreview.net/pdf?id=S1xKd24twB).""" + + def __init__( + self, + rl_agent: builders.ActorLearnerBuilder[ + DirectRLNetworks, DirectPolicyNetwork, reverb.ReplaySample + ], + rl_agent_batch_size: int, + make_demonstrations: Callable[[int], Iterator[types.Transition]], + ): + """Builds a SQIL agent. Args: rl_agent: An off policy direct RL agent.. @@ -89,43 +96,42 @@ def __init__(self, make_demonstrations: A function that returns an infinite iterator with demonstrations. 
""" - self._rl_agent = rl_agent - self._rl_agent_batch_size = rl_agent_batch_size - self._make_demonstrations = make_demonstrations - - def make_learner( - self, - random_key: networks_lib.PRNGKey, - networks: DirectRLNetworks, - dataset: Iterator[reverb.ReplaySample], - logger_fn: loggers.LoggerFactory, - environment_spec: Optional[specs.EnvironmentSpec] = None, - replay_client: Optional[reverb.Client] = None, - counter: Optional[counting.Counter] = None, - ) -> core.Learner: - """Creates the learner.""" - counter = counter or counting.Counter() - direct_rl_counter = counting.Counter(counter, 'direct_rl') - return self._rl_agent.make_learner( - random_key, - networks, - dataset=dataset, - logger_fn=logger_fn, - environment_spec=environment_spec, - replay_client=replay_client, - counter=direct_rl_counter) - - def make_replay_tables( - self, - environment_spec: specs.EnvironmentSpec, - policy: DirectPolicyNetwork, - ) -> List[reverb.Table]: - return self._rl_agent.make_replay_tables(environment_spec, policy) - - def make_dataset_iterator( # pytype: disable=signature-mismatch # overriding-return-type-checks - self, - replay_client: reverb.Client) -> Optional[Iterator[reverb.ReplaySample]]: - """The returned iterator returns batches with both expert and policy data. + self._rl_agent = rl_agent + self._rl_agent_batch_size = rl_agent_batch_size + self._make_demonstrations = make_demonstrations + + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: DirectRLNetworks, + dataset: Iterator[reverb.ReplaySample], + logger_fn: loggers.LoggerFactory, + environment_spec: Optional[specs.EnvironmentSpec] = None, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + """Creates the learner.""" + counter = counter or counting.Counter() + direct_rl_counter = counting.Counter(counter, "direct_rl") + return self._rl_agent.make_learner( + random_key, + networks, + dataset=dataset, + logger_fn=logger_fn, + environment_spec=environment_spec, + replay_client=replay_client, + counter=direct_rl_counter, + ) + + def make_replay_tables( + self, environment_spec: specs.EnvironmentSpec, policy: DirectPolicyNetwork, + ) -> List[reverb.Table]: + return self._rl_agent.make_replay_tables(environment_spec, policy) + + def make_dataset_iterator( # pytype: disable=signature-mismatch # overriding-return-type-checks + self, replay_client: reverb.Client + ) -> Optional[Iterator[reverb.ReplaySample]]: + """The returned iterator returns batches with both expert and policy data. Batch items will alternate between expert data and policy data. @@ -135,36 +141,41 @@ def make_dataset_iterator( # pytype: disable=signature-mismatch # overriding-r Returns: The Replay sample iterator. """ - # TODO(eorsini): Make sure we have the exact same format as the rl_agent's - # adder writes in. 
- demonstration_iterator = self._make_demonstrations( - self._rl_agent_batch_size) - - rb_iterator = self._rl_agent.make_dataset_iterator(replay_client) - - return utils.device_put( - _generate_sqil_samples(demonstration_iterator, rb_iterator), - jax.devices()[0]) - - def make_adder( - self, replay_client: reverb.Client, - environment_spec: Optional[specs.EnvironmentSpec], - policy: Optional[DirectPolicyNetwork]) -> Optional[adders.Adder]: - return self._rl_agent.make_adder(replay_client, environment_spec, policy) - - def make_actor( - self, - random_key: networks_lib.PRNGKey, - policy: DirectPolicyNetwork, - environment_spec: specs.EnvironmentSpec, - variable_source: Optional[core.VariableSource] = None, - adder: Optional[adders.Adder] = None, - ) -> core.Actor: - return self._rl_agent.make_actor(random_key, policy, environment_spec, - variable_source, adder) - - def make_policy(self, - networks: DirectRLNetworks, - environment_spec: specs.EnvironmentSpec, - evaluation: bool = False) -> DirectPolicyNetwork: - return self._rl_agent.make_policy(networks, environment_spec, evaluation) + # TODO(eorsini): Make sure we have the exact same format as the rl_agent's + # adder writes in. + demonstration_iterator = self._make_demonstrations(self._rl_agent_batch_size) + + rb_iterator = self._rl_agent.make_dataset_iterator(replay_client) + + return utils.device_put( + _generate_sqil_samples(demonstration_iterator, rb_iterator), + jax.devices()[0], + ) + + def make_adder( + self, + replay_client: reverb.Client, + environment_spec: Optional[specs.EnvironmentSpec], + policy: Optional[DirectPolicyNetwork], + ) -> Optional[adders.Adder]: + return self._rl_agent.make_adder(replay_client, environment_spec, policy) + + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: DirectPolicyNetwork, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + adder: Optional[adders.Adder] = None, + ) -> core.Actor: + return self._rl_agent.make_actor( + random_key, policy, environment_spec, variable_source, adder + ) + + def make_policy( + self, + networks: DirectRLNetworks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool = False, + ) -> DirectPolicyNetwork: + return self._rl_agent.make_policy(networks, environment_spec, evaluation) diff --git a/acme/agents/jax/sqil/builder_test.py b/acme/agents/jax/sqil/builder_test.py index 4a2608f080..5da3e4e1da 100644 --- a/acme/agents/jax/sqil/builder_test.py +++ b/acme/agents/jax/sqil/builder_test.py @@ -14,31 +14,32 @@ """Tests for the SQIL iterator.""" -from acme import types -from acme.agents.jax.sqil import builder import numpy as np import reverb - from absl.testing import absltest +from acme import types +from acme.agents.jax.sqil import builder + class BuilderTest(absltest.TestCase): + def test_sqil_iterator(self): + demonstrations = [types.Transition(np.array([[1], [2], [3]]), (), (), (), ())] + replay = [ + reverb.ReplaySample( + info=(), + data=types.Transition(np.array([[4], [5], [6]]), (), (), (), ()), + ) + ] + sqil_it = builder._generate_sqil_samples(iter(demonstrations), iter(replay)) + np.testing.assert_array_equal( + next(sqil_it).data.observation, np.array([[1], [3], [5]]) + ) + np.testing.assert_array_equal( + next(sqil_it).data.observation, np.array([[2], [4], [6]]) + ) + self.assertRaises(StopIteration, lambda: next(sqil_it)) - def test_sqil_iterator(self): - demonstrations = [ - types.Transition(np.array([[1], [2], [3]]), (), (), (), ()) - ] - replay = [ - reverb.ReplaySample( - 
info=(), - data=types.Transition(np.array([[4], [5], [6]]), (), (), (), ())) - ] - sqil_it = builder._generate_sqil_samples(iter(demonstrations), iter(replay)) - np.testing.assert_array_equal( - next(sqil_it).data.observation, np.array([[1], [3], [5]])) - np.testing.assert_array_equal( - next(sqil_it).data.observation, np.array([[2], [4], [6]])) - self.assertRaises(StopIteration, lambda: next(sqil_it)) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/jax/td3/__init__.py b/acme/agents/jax/td3/__init__.py index 1e3f9387a5..6994979c15 100644 --- a/acme/agents/jax/td3/__init__.py +++ b/acme/agents/jax/td3/__init__.py @@ -17,6 +17,8 @@ from acme.agents.jax.td3.builder import TD3Builder from acme.agents.jax.td3.config import TD3Config from acme.agents.jax.td3.learning import TD3Learner -from acme.agents.jax.td3.networks import get_default_behavior_policy -from acme.agents.jax.td3.networks import make_networks -from acme.agents.jax.td3.networks import TD3Networks +from acme.agents.jax.td3.networks import ( + TD3Networks, + get_default_behavior_policy, + make_networks, +) diff --git a/acme/agents/jax/td3/builder.py b/acme/agents/jax/td3/builder.py index 1149bd9de4..0bc2a03c4c 100644 --- a/acme/agents/jax/td3/builder.py +++ b/acme/agents/jax/td3/builder.py @@ -15,151 +15,164 @@ """TD3 Builder.""" from typing import Iterator, List, Optional -from acme import adders -from acme import core -from acme import specs +import jax +import optax +import reverb +from reverb import rate_limiters + +from acme import adders, core, specs from acme.adders import reverb as adders_reverb from acme.agents.jax import actor_core as actor_core_lib -from acme.agents.jax import actors -from acme.agents.jax import builders +from acme.agents.jax import actors, builders from acme.agents.jax.td3 import config as td3_config from acme.agents.jax.td3 import learning from acme.agents.jax.td3 import networks as td3_networks from acme.datasets import reverb as datasets from acme.jax import networks as networks_lib -from acme.jax import utils -from acme.jax import variable_utils -from acme.utils import counting -from acme.utils import loggers -import jax -import optax -import reverb -from reverb import rate_limiters +from acme.jax import utils, variable_utils +from acme.utils import counting, loggers -class TD3Builder(builders.ActorLearnerBuilder[td3_networks.TD3Networks, - actor_core_lib.FeedForwardPolicy, - reverb.ReplaySample]): - """TD3 Builder.""" +class TD3Builder( + builders.ActorLearnerBuilder[ + td3_networks.TD3Networks, actor_core_lib.FeedForwardPolicy, reverb.ReplaySample + ] +): + """TD3 Builder.""" - def __init__( - self, - config: td3_config.TD3Config, - ): - """Creates a TD3 learner, a behavior policy and an eval actor. + def __init__( + self, config: td3_config.TD3Config, + ): + """Creates a TD3 learner, a behavior policy and an eval actor. 
Args: config: a config with TD3 hps """ - self._config = config - - def make_learner( - self, - random_key: networks_lib.PRNGKey, - networks: td3_networks.TD3Networks, - dataset: Iterator[reverb.ReplaySample], - logger_fn: loggers.LoggerFactory, - environment_spec: specs.EnvironmentSpec, - replay_client: Optional[reverb.Client] = None, - counter: Optional[counting.Counter] = None, - ) -> core.Learner: - del environment_spec, replay_client - - critic_optimizer = optax.adam(self._config.critic_learning_rate) - twin_critic_optimizer = optax.adam(self._config.critic_learning_rate) - policy_optimizer = optax.adam(self._config.policy_learning_rate) - - if self._config.policy_gradient_clipping is not None: - policy_optimizer = optax.chain( - optax.clip_by_global_norm(self._config.policy_gradient_clipping), - policy_optimizer) - - return learning.TD3Learner( - networks=networks, - random_key=random_key, - discount=self._config.discount, - target_sigma=self._config.target_sigma, - noise_clip=self._config.noise_clip, - policy_optimizer=policy_optimizer, - critic_optimizer=critic_optimizer, - twin_critic_optimizer=twin_critic_optimizer, - num_sgd_steps_per_step=self._config.num_sgd_steps_per_step, - bc_alpha=self._config.bc_alpha, - iterator=dataset, - logger=logger_fn('learner'), - counter=counter) - - def make_actor( - self, - random_key: networks_lib.PRNGKey, - policy: actor_core_lib.FeedForwardPolicy, - environment_spec: specs.EnvironmentSpec, - variable_source: Optional[core.VariableSource] = None, - adder: Optional[adders.Adder] = None, - ) -> core.Actor: - del environment_spec - assert variable_source is not None - actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy) - # Inference happens on CPU, so it's better to move variables there too. 
- variable_client = variable_utils.VariableClient(variable_source, 'policy', - device='cpu') - return actors.GenericActor( - actor_core, random_key, variable_client, adder, backend='cpu') - - def make_replay_tables( - self, - environment_spec: specs.EnvironmentSpec, - policy: actor_core_lib.FeedForwardPolicy, - ) -> List[reverb.Table]: - """Creates reverb tables for the algorithm.""" - del policy - samples_per_insert_tolerance = ( - self._config.samples_per_insert_tolerance_rate * - self._config.samples_per_insert) - error_buffer = self._config.min_replay_size * samples_per_insert_tolerance - limiter = rate_limiters.SampleToInsertRatio( - min_size_to_sample=self._config.min_replay_size, - samples_per_insert=self._config.samples_per_insert, - error_buffer=error_buffer) - return [reverb.Table( - name=self._config.replay_table_name, - sampler=reverb.selectors.Uniform(), - remover=reverb.selectors.Fifo(), - max_size=self._config.max_replay_size, - rate_limiter=limiter, - signature=adders_reverb.NStepTransitionAdder.signature( - environment_spec))] - - def make_dataset_iterator( - self, replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]: - """Creates a dataset iterator to use for learning.""" - dataset = datasets.make_reverb_dataset( - table=self._config.replay_table_name, - server_address=replay_client.server_address, - batch_size=( - self._config.batch_size * self._config.num_sgd_steps_per_step), - prefetch_size=self._config.prefetch_size, - transition_adder=True) - return utils.device_put(dataset.as_numpy_iterator(), jax.devices()[0]) - - def make_adder( - self, replay_client: reverb.Client, - environment_spec: Optional[specs.EnvironmentSpec], - policy: Optional[actor_core_lib.FeedForwardPolicy] - ) -> Optional[adders.Adder]: - """Creates an adder which handles observations.""" - del environment_spec, policy - return adders_reverb.NStepTransitionAdder( - priority_fns={self._config.replay_table_name: None}, - client=replay_client, - n_step=self._config.n_step, - discount=self._config.discount) - - def make_policy(self, - networks: td3_networks.TD3Networks, - environment_spec: specs.EnvironmentSpec, - evaluation: bool = False) -> actor_core_lib.FeedForwardPolicy: - """Creates a policy.""" - sigma = 0 if evaluation else self._config.sigma - return td3_networks.get_default_behavior_policy( - networks=networks, action_specs=environment_spec.actions, sigma=sigma) + self._config = config + + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: td3_networks.TD3Networks, + dataset: Iterator[reverb.ReplaySample], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + del environment_spec, replay_client + + critic_optimizer = optax.adam(self._config.critic_learning_rate) + twin_critic_optimizer = optax.adam(self._config.critic_learning_rate) + policy_optimizer = optax.adam(self._config.policy_learning_rate) + + if self._config.policy_gradient_clipping is not None: + policy_optimizer = optax.chain( + optax.clip_by_global_norm(self._config.policy_gradient_clipping), + policy_optimizer, + ) + + return learning.TD3Learner( + networks=networks, + random_key=random_key, + discount=self._config.discount, + target_sigma=self._config.target_sigma, + noise_clip=self._config.noise_clip, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + twin_critic_optimizer=twin_critic_optimizer, + 
num_sgd_steps_per_step=self._config.num_sgd_steps_per_step, + bc_alpha=self._config.bc_alpha, + iterator=dataset, + logger=logger_fn("learner"), + counter=counter, + ) + + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: actor_core_lib.FeedForwardPolicy, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + adder: Optional[adders.Adder] = None, + ) -> core.Actor: + del environment_spec + assert variable_source is not None + actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy) + # Inference happens on CPU, so it's better to move variables there too. + variable_client = variable_utils.VariableClient( + variable_source, "policy", device="cpu" + ) + return actors.GenericActor( + actor_core, random_key, variable_client, adder, backend="cpu" + ) + + def make_replay_tables( + self, + environment_spec: specs.EnvironmentSpec, + policy: actor_core_lib.FeedForwardPolicy, + ) -> List[reverb.Table]: + """Creates reverb tables for the algorithm.""" + del policy + samples_per_insert_tolerance = ( + self._config.samples_per_insert_tolerance_rate + * self._config.samples_per_insert + ) + error_buffer = self._config.min_replay_size * samples_per_insert_tolerance + limiter = rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._config.min_replay_size, + samples_per_insert=self._config.samples_per_insert, + error_buffer=error_buffer, + ) + return [ + reverb.Table( + name=self._config.replay_table_name, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._config.max_replay_size, + rate_limiter=limiter, + signature=adders_reverb.NStepTransitionAdder.signature( + environment_spec + ), + ) + ] + + def make_dataset_iterator( + self, replay_client: reverb.Client + ) -> Iterator[reverb.ReplaySample]: + """Creates a dataset iterator to use for learning.""" + dataset = datasets.make_reverb_dataset( + table=self._config.replay_table_name, + server_address=replay_client.server_address, + batch_size=(self._config.batch_size * self._config.num_sgd_steps_per_step), + prefetch_size=self._config.prefetch_size, + transition_adder=True, + ) + return utils.device_put(dataset.as_numpy_iterator(), jax.devices()[0]) + + def make_adder( + self, + replay_client: reverb.Client, + environment_spec: Optional[specs.EnvironmentSpec], + policy: Optional[actor_core_lib.FeedForwardPolicy], + ) -> Optional[adders.Adder]: + """Creates an adder which handles observations.""" + del environment_spec, policy + return adders_reverb.NStepTransitionAdder( + priority_fns={self._config.replay_table_name: None}, + client=replay_client, + n_step=self._config.n_step, + discount=self._config.discount, + ) + + def make_policy( + self, + networks: td3_networks.TD3Networks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool = False, + ) -> actor_core_lib.FeedForwardPolicy: + """Creates a policy.""" + sigma = 0 if evaluation else self._config.sigma + return td3_networks.get_default_behavior_policy( + networks=networks, action_specs=environment_spec.actions, sigma=sigma + ) diff --git a/acme/agents/jax/td3/config.py b/acme/agents/jax/td3/config.py index cb4e51e079..2745ab9dfe 100644 --- a/acme/agents/jax/td3/config.py +++ b/acme/agents/jax/td3/config.py @@ -16,45 +16,46 @@ import dataclasses from typing import Optional, Union -from acme.adders import reverb as adders_reverb import optax +from acme.adders import reverb as adders_reverb + @dataclasses.dataclass class TD3Config: - """Configuration options for TD3.""" - - # 
Loss options - batch_size: int = 256 - policy_learning_rate: Union[optax.Schedule, float] = 3e-4 - critic_learning_rate: Union[optax.Schedule, float] = 3e-4 - # Policy gradient clipping is not part of the original TD3 implementation, - # used e.g. in DAC https://arxiv.org/pdf/1809.02925.pdf - policy_gradient_clipping: Optional[float] = None - discount: float = 0.99 - n_step: int = 1 - - # TD3 specific options (https://arxiv.org/pdf/1802.09477.pdf) - sigma: float = 0.1 - delay: int = 2 - target_sigma: float = 0.2 - noise_clip: float = 0.5 - tau: float = 0.005 - - # Replay options - min_replay_size: int = 1000 - max_replay_size: int = 1000000 - replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE - prefetch_size: int = 4 - samples_per_insert: float = 256 - # Rate to be used for the SampleToInsertRatio rate limiter tolerance. - # See a formula in make_replay_tables for more details. - samples_per_insert_tolerance_rate: float = 0.1 - - # How many gradient updates to perform per step. - num_sgd_steps_per_step: int = 1 - - # Offline RL options - # if bc_alpha: if given, will add a bc regularization term to the policy loss, - # (https://arxiv.org/pdf/2106.06860.pdf), useful for offline training. - bc_alpha: Optional[float] = None + """Configuration options for TD3.""" + + # Loss options + batch_size: int = 256 + policy_learning_rate: Union[optax.Schedule, float] = 3e-4 + critic_learning_rate: Union[optax.Schedule, float] = 3e-4 + # Policy gradient clipping is not part of the original TD3 implementation, + # used e.g. in DAC https://arxiv.org/pdf/1809.02925.pdf + policy_gradient_clipping: Optional[float] = None + discount: float = 0.99 + n_step: int = 1 + + # TD3 specific options (https://arxiv.org/pdf/1802.09477.pdf) + sigma: float = 0.1 + delay: int = 2 + target_sigma: float = 0.2 + noise_clip: float = 0.5 + tau: float = 0.005 + + # Replay options + min_replay_size: int = 1000 + max_replay_size: int = 1000000 + replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE + prefetch_size: int = 4 + samples_per_insert: float = 256 + # Rate to be used for the SampleToInsertRatio rate limiter tolerance. + # See a formula in make_replay_tables for more details. + samples_per_insert_tolerance_rate: float = 0.1 + + # How many gradient updates to perform per step. + num_sgd_steps_per_step: int = 1 + + # Offline RL options + # if bc_alpha: if given, will add a bc regularization term to the policy loss, + # (https://arxiv.org/pdf/2106.06860.pdf), useful for offline training. 
+ bc_alpha: Optional[float] = None diff --git a/acme/agents/jax/td3/learning.py b/acme/agents/jax/td3/learning.py index 4459c9d7d6..22f45d920b 100644 --- a/acme/agents/jax/td3/learning.py +++ b/acme/agents/jax/td3/learning.py @@ -17,58 +17,61 @@ import time from typing import Dict, Iterator, List, NamedTuple, Optional, Tuple -import acme -from acme import types -from acme.agents.jax.td3 import networks as td3_networks -from acme.jax import networks as networks_lib -from acme.jax import utils -from acme.utils import counting -from acme.utils import loggers import jax import jax.numpy as jnp import optax import reverb import rlax +import acme +from acme import types +from acme.agents.jax.td3 import networks as td3_networks +from acme.jax import networks as networks_lib +from acme.jax import utils +from acme.utils import counting, loggers + class TrainingState(NamedTuple): - """Contains training state for the learner.""" - policy_params: networks_lib.Params - target_policy_params: networks_lib.Params - critic_params: networks_lib.Params - target_critic_params: networks_lib.Params - twin_critic_params: networks_lib.Params - target_twin_critic_params: networks_lib.Params - policy_opt_state: optax.OptState - critic_opt_state: optax.OptState - twin_critic_opt_state: optax.OptState - steps: int - random_key: networks_lib.PRNGKey + """Contains training state for the learner.""" + + policy_params: networks_lib.Params + target_policy_params: networks_lib.Params + critic_params: networks_lib.Params + target_critic_params: networks_lib.Params + twin_critic_params: networks_lib.Params + target_twin_critic_params: networks_lib.Params + policy_opt_state: optax.OptState + critic_opt_state: optax.OptState + twin_critic_opt_state: optax.OptState + steps: int + random_key: networks_lib.PRNGKey class TD3Learner(acme.Learner): - """TD3 learner.""" - - _state: TrainingState - - def __init__(self, - networks: td3_networks.TD3Networks, - random_key: networks_lib.PRNGKey, - discount: float, - iterator: Iterator[reverb.ReplaySample], - policy_optimizer: optax.GradientTransformation, - critic_optimizer: optax.GradientTransformation, - twin_critic_optimizer: optax.GradientTransformation, - delay: int = 2, - target_sigma: float = 0.2, - noise_clip: float = 0.5, - tau: float = 0.005, - use_sarsa_target: bool = False, - bc_alpha: Optional[float] = None, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - num_sgd_steps_per_step: int = 1): - """Initializes the TD3 learner. + """TD3 learner.""" + + _state: TrainingState + + def __init__( + self, + networks: td3_networks.TD3Networks, + random_key: networks_lib.PRNGKey, + discount: float, + iterator: Iterator[reverb.ReplaySample], + policy_optimizer: optax.GradientTransformation, + critic_optimizer: optax.GradientTransformation, + twin_critic_optimizer: optax.GradientTransformation, + delay: int = 2, + target_sigma: float = 0.2, + noise_clip: float = 0.5, + tau: float = 0.005, + use_sarsa_target: bool = False, + bc_alpha: Optional[float] = None, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + num_sgd_steps_per_step: int = 1, + ): + """Initializes the TD3 learner. Args: networks: TD3 networks. @@ -98,236 +101,257 @@ def __init__(self, num_sgd_steps_per_step: number of sgd steps to perform per learner 'step'. 
""" - def policy_loss( - policy_params: networks_lib.Params, - critic_params: networks_lib.Params, - transition: types.NestedArray, - ) -> jnp.ndarray: - # Computes the discrete policy gradient loss. - action = networks.policy_network.apply( - policy_params, transition.observation) - grad_critic = jax.vmap( - jax.grad(networks.critic_network.apply, argnums=2), - in_axes=(None, 0, 0)) - dq_da = grad_critic(critic_params, transition.observation, action) - batch_dpg_learning = jax.vmap(rlax.dpg_loss, in_axes=(0, 0)) - loss = jnp.mean(batch_dpg_learning(action, dq_da)) - if bc_alpha is not None: - # BC regularization for offline RL - q_sa = networks.critic_network.apply(critic_params, - transition.observation, action) - bc_factor = jax.lax.stop_gradient(bc_alpha / jnp.mean(jnp.abs(q_sa))) - loss += jnp.mean(jnp.square(action - transition.action)) / bc_factor - return loss - - def critic_loss( - critic_params: networks_lib.Params, - state: TrainingState, - transition: types.Transition, - random_key: jnp.ndarray, - ): - # Computes the critic loss. - q_tm1 = networks.critic_network.apply( - critic_params, transition.observation, transition.action) - - if use_sarsa_target: - # TODO(b/222674779): use N-steps Trajectories to get the next actions. - assert 'next_action' in transition.extras, ( - 'next actions should be given as extras for one step RL.') - action = transition.extras['next_action'] - else: - action = networks.policy_network.apply(state.target_policy_params, - transition.next_observation) - action = networks.add_policy_noise(action, random_key, - target_sigma, noise_clip) - - q_t = networks.critic_network.apply( - state.target_critic_params, - transition.next_observation, - action) - twin_q_t = networks.twin_critic_network.apply( - state.target_twin_critic_params, - transition.next_observation, - action) - - q_t = jnp.minimum(q_t, twin_q_t) - - target_q_tm1 = transition.reward + discount * transition.discount * q_t - td_error = jax.lax.stop_gradient(target_q_tm1) - q_tm1 - - return jnp.mean(jnp.square(td_error)) - - def update_step( - state: TrainingState, - transitions: types.Transition, - ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: - - random_key, key_critic, key_twin = jax.random.split(state.random_key, 3) - - # Updates on the critic: compute the gradients, and update using - # Polyak averaging. - critic_loss_and_grad = jax.value_and_grad(critic_loss) - critic_loss_value, critic_gradients = critic_loss_and_grad( - state.critic_params, state, transitions, key_critic) - critic_updates, critic_opt_state = critic_optimizer.update( - critic_gradients, state.critic_opt_state) - critic_params = optax.apply_updates(state.critic_params, critic_updates) - # In the original authors' implementation the critic target update is - # delayed similarly to the policy update which we found empirically to - # perform slightly worse. - target_critic_params = optax.incremental_update( - new_tensors=critic_params, - old_tensors=state.target_critic_params, - step_size=tau) - - # Updates on the twin critic: compute the gradients, and update using - # Polyak averaging. 
- twin_critic_loss_value, twin_critic_gradients = critic_loss_and_grad( - state.twin_critic_params, state, transitions, key_twin) - twin_critic_updates, twin_critic_opt_state = twin_critic_optimizer.update( - twin_critic_gradients, state.twin_critic_opt_state) - twin_critic_params = optax.apply_updates(state.twin_critic_params, - twin_critic_updates) - # In the original authors' implementation the twin critic target update is - # delayed similarly to the policy update which we found empirically to - # perform slightly worse. - target_twin_critic_params = optax.incremental_update( - new_tensors=twin_critic_params, - old_tensors=state.target_twin_critic_params, - step_size=tau) - - # Updates on the policy: compute the gradients, and update using - # Polyak averaging (if delay enabled, the update might not be applied). - policy_loss_and_grad = jax.value_and_grad(policy_loss) - policy_loss_value, policy_gradients = policy_loss_and_grad( - state.policy_params, state.critic_params, transitions) - def update_policy_step(): - policy_updates, policy_opt_state = policy_optimizer.update( - policy_gradients, state.policy_opt_state) - policy_params = optax.apply_updates(state.policy_params, policy_updates) - target_policy_params = optax.incremental_update( - new_tensors=policy_params, - old_tensors=state.target_policy_params, - step_size=tau) - return policy_params, target_policy_params, policy_opt_state - - # The update on the policy is applied every `delay` steps. - current_policy_state = (state.policy_params, state.target_policy_params, - state.policy_opt_state) - policy_params, target_policy_params, policy_opt_state = jax.lax.cond( - state.steps % delay == 0, - lambda _: update_policy_step(), - lambda _: current_policy_state, - operand=None) - - steps = state.steps + 1 - - new_state = TrainingState( - policy_params=policy_params, - critic_params=critic_params, - twin_critic_params=twin_critic_params, - target_policy_params=target_policy_params, - target_critic_params=target_critic_params, - target_twin_critic_params=target_twin_critic_params, - policy_opt_state=policy_opt_state, - critic_opt_state=critic_opt_state, - twin_critic_opt_state=twin_critic_opt_state, - steps=steps, - random_key=random_key, - ) - - metrics = { - 'policy_loss': policy_loss_value, - 'critic_loss': critic_loss_value, - 'twin_critic_loss': twin_critic_loss_value, - } - - return new_state, metrics - - # General learner book-keeping and loggers. - self._counter = counter or counting.Counter() - self._logger = logger or loggers.make_default_logger( - 'learner', - asynchronous=True, - serialize_fn=utils.fetch_devicearray, - steps_key=self._counter.get_steps_key()) - - # Create prefetching dataset iterator. - self._iterator = iterator - - # Faster sgd step - update_step = utils.process_multiple_batches(update_step, - num_sgd_steps_per_step) - # Use the JIT compiler. - self._update_step = jax.jit(update_step) - - (key_init_policy, key_init_twin, key_init_target, key_state - ) = jax.random.split(random_key, 4) - # Create the network parameters and copy into the target network parameters. - initial_policy_params = networks.policy_network.init(key_init_policy) - initial_critic_params = networks.critic_network.init(key_init_twin) - initial_twin_critic_params = networks.twin_critic_network.init( - key_init_target) - - initial_target_policy_params = initial_policy_params - initial_target_critic_params = initial_critic_params - initial_target_twin_critic_params = initial_twin_critic_params - - # Initialize optimizers. 
- initial_policy_opt_state = policy_optimizer.init(initial_policy_params) - initial_critic_opt_state = critic_optimizer.init(initial_critic_params) - initial_twin_critic_opt_state = twin_critic_optimizer.init( - initial_twin_critic_params) - - # Create initial state. - self._state = TrainingState( - policy_params=initial_policy_params, - target_policy_params=initial_target_policy_params, - critic_params=initial_critic_params, - twin_critic_params=initial_twin_critic_params, - target_critic_params=initial_target_critic_params, - target_twin_critic_params=initial_target_twin_critic_params, - policy_opt_state=initial_policy_opt_state, - critic_opt_state=initial_critic_opt_state, - twin_critic_opt_state=initial_twin_critic_opt_state, - steps=0, - random_key=key_state - ) - - # Do not record timestamps until after the first learning step is done. - # This is to avoid including the time it takes for actors to come online and - # fill the replay buffer. - self._timestamp = None - - def step(self): - # Get data from replay (dropping extras if any). Note there is no - # extra data here because we do not insert any into Reverb. - sample = next(self._iterator) - transitions = types.Transition(*sample.data) - - self._state, metrics = self._update_step(self._state, transitions) - - # Compute elapsed time. - timestamp = time.time() - elapsed_time = timestamp - self._timestamp if self._timestamp else 0 - self._timestamp = timestamp - - # Increment counts and record the current time - counts = self._counter.increment(steps=1, walltime=elapsed_time) - - # Attempts to write the logs. - self._logger.write({**metrics, **counts}) - - def get_variables(self, names: List[str]) -> List[networks_lib.Params]: - variables = { - 'policy': self._state.policy_params, - 'critic': self._state.critic_params, - 'twin_critic': self._state.twin_critic_params, - } - return [variables[name] for name in names] - - def save(self) -> TrainingState: - return self._state - - def restore(self, state: TrainingState): - self._state = state + def policy_loss( + policy_params: networks_lib.Params, + critic_params: networks_lib.Params, + transition: types.NestedArray, + ) -> jnp.ndarray: + # Computes the discrete policy gradient loss. + action = networks.policy_network.apply( + policy_params, transition.observation + ) + grad_critic = jax.vmap( + jax.grad(networks.critic_network.apply, argnums=2), in_axes=(None, 0, 0) + ) + dq_da = grad_critic(critic_params, transition.observation, action) + batch_dpg_learning = jax.vmap(rlax.dpg_loss, in_axes=(0, 0)) + loss = jnp.mean(batch_dpg_learning(action, dq_da)) + if bc_alpha is not None: + # BC regularization for offline RL + q_sa = networks.critic_network.apply( + critic_params, transition.observation, action + ) + bc_factor = jax.lax.stop_gradient(bc_alpha / jnp.mean(jnp.abs(q_sa))) + loss += jnp.mean(jnp.square(action - transition.action)) / bc_factor + return loss + + def critic_loss( + critic_params: networks_lib.Params, + state: TrainingState, + transition: types.Transition, + random_key: jnp.ndarray, + ): + # Computes the critic loss. + q_tm1 = networks.critic_network.apply( + critic_params, transition.observation, transition.action + ) + + if use_sarsa_target: + # TODO(b/222674779): use N-steps Trajectories to get the next actions. + assert ( + "next_action" in transition.extras + ), "next actions should be given as extras for one step RL." 
+ action = transition.extras["next_action"] + else: + action = networks.policy_network.apply( + state.target_policy_params, transition.next_observation + ) + action = networks.add_policy_noise( + action, random_key, target_sigma, noise_clip + ) + + q_t = networks.critic_network.apply( + state.target_critic_params, transition.next_observation, action + ) + twin_q_t = networks.twin_critic_network.apply( + state.target_twin_critic_params, transition.next_observation, action + ) + + q_t = jnp.minimum(q_t, twin_q_t) + + target_q_tm1 = transition.reward + discount * transition.discount * q_t + td_error = jax.lax.stop_gradient(target_q_tm1) - q_tm1 + + return jnp.mean(jnp.square(td_error)) + + def update_step( + state: TrainingState, transitions: types.Transition, + ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: + + random_key, key_critic, key_twin = jax.random.split(state.random_key, 3) + + # Updates on the critic: compute the gradients, and update using + # Polyak averaging. + critic_loss_and_grad = jax.value_and_grad(critic_loss) + critic_loss_value, critic_gradients = critic_loss_and_grad( + state.critic_params, state, transitions, key_critic + ) + critic_updates, critic_opt_state = critic_optimizer.update( + critic_gradients, state.critic_opt_state + ) + critic_params = optax.apply_updates(state.critic_params, critic_updates) + # In the original authors' implementation the critic target update is + # delayed similarly to the policy update which we found empirically to + # perform slightly worse. + target_critic_params = optax.incremental_update( + new_tensors=critic_params, + old_tensors=state.target_critic_params, + step_size=tau, + ) + + # Updates on the twin critic: compute the gradients, and update using + # Polyak averaging. + twin_critic_loss_value, twin_critic_gradients = critic_loss_and_grad( + state.twin_critic_params, state, transitions, key_twin + ) + twin_critic_updates, twin_critic_opt_state = twin_critic_optimizer.update( + twin_critic_gradients, state.twin_critic_opt_state + ) + twin_critic_params = optax.apply_updates( + state.twin_critic_params, twin_critic_updates + ) + # In the original authors' implementation the twin critic target update is + # delayed similarly to the policy update which we found empirically to + # perform slightly worse. + target_twin_critic_params = optax.incremental_update( + new_tensors=twin_critic_params, + old_tensors=state.target_twin_critic_params, + step_size=tau, + ) + + # Updates on the policy: compute the gradients, and update using + # Polyak averaging (if delay enabled, the update might not be applied). + policy_loss_and_grad = jax.value_and_grad(policy_loss) + policy_loss_value, policy_gradients = policy_loss_and_grad( + state.policy_params, state.critic_params, transitions + ) + + def update_policy_step(): + policy_updates, policy_opt_state = policy_optimizer.update( + policy_gradients, state.policy_opt_state + ) + policy_params = optax.apply_updates(state.policy_params, policy_updates) + target_policy_params = optax.incremental_update( + new_tensors=policy_params, + old_tensors=state.target_policy_params, + step_size=tau, + ) + return policy_params, target_policy_params, policy_opt_state + + # The update on the policy is applied every `delay` steps. 
+ current_policy_state = ( + state.policy_params, + state.target_policy_params, + state.policy_opt_state, + ) + policy_params, target_policy_params, policy_opt_state = jax.lax.cond( + state.steps % delay == 0, + lambda _: update_policy_step(), + lambda _: current_policy_state, + operand=None, + ) + + steps = state.steps + 1 + + new_state = TrainingState( + policy_params=policy_params, + critic_params=critic_params, + twin_critic_params=twin_critic_params, + target_policy_params=target_policy_params, + target_critic_params=target_critic_params, + target_twin_critic_params=target_twin_critic_params, + policy_opt_state=policy_opt_state, + critic_opt_state=critic_opt_state, + twin_critic_opt_state=twin_critic_opt_state, + steps=steps, + random_key=random_key, + ) + + metrics = { + "policy_loss": policy_loss_value, + "critic_loss": critic_loss_value, + "twin_critic_loss": twin_critic_loss_value, + } + + return new_state, metrics + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger( + "learner", + asynchronous=True, + serialize_fn=utils.fetch_devicearray, + steps_key=self._counter.get_steps_key(), + ) + + # Create prefetching dataset iterator. + self._iterator = iterator + + # Faster sgd step + update_step = utils.process_multiple_batches( + update_step, num_sgd_steps_per_step + ) + # Use the JIT compiler. + self._update_step = jax.jit(update_step) + + (key_init_policy, key_init_twin, key_init_target, key_state) = jax.random.split( + random_key, 4 + ) + # Create the network parameters and copy into the target network parameters. + initial_policy_params = networks.policy_network.init(key_init_policy) + initial_critic_params = networks.critic_network.init(key_init_twin) + initial_twin_critic_params = networks.twin_critic_network.init(key_init_target) + + initial_target_policy_params = initial_policy_params + initial_target_critic_params = initial_critic_params + initial_target_twin_critic_params = initial_twin_critic_params + + # Initialize optimizers. + initial_policy_opt_state = policy_optimizer.init(initial_policy_params) + initial_critic_opt_state = critic_optimizer.init(initial_critic_params) + initial_twin_critic_opt_state = twin_critic_optimizer.init( + initial_twin_critic_params + ) + + # Create initial state. + self._state = TrainingState( + policy_params=initial_policy_params, + target_policy_params=initial_target_policy_params, + critic_params=initial_critic_params, + twin_critic_params=initial_twin_critic_params, + target_critic_params=initial_target_critic_params, + target_twin_critic_params=initial_target_twin_critic_params, + policy_opt_state=initial_policy_opt_state, + critic_opt_state=initial_critic_opt_state, + twin_critic_opt_state=initial_twin_critic_opt_state, + steps=0, + random_key=key_state, + ) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + def step(self): + # Get data from replay (dropping extras if any). Note there is no + # extra data here because we do not insert any into Reverb. + sample = next(self._iterator) + transitions = types.Transition(*sample.data) + + self._state, metrics = self._update_step(self._state, transitions) + + # Compute elapsed time. 
+ timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Increment counts and record the current time + counts = self._counter.increment(steps=1, walltime=elapsed_time) + + # Attempts to write the logs. + self._logger.write({**metrics, **counts}) + + def get_variables(self, names: List[str]) -> List[networks_lib.Params]: + variables = { + "policy": self._state.policy_params, + "critic": self._state.critic_params, + "twin_critic": self._state.twin_critic_params, + } + return [variables[name] for name in names] + + def save(self) -> TrainingState: + return self._state + + def restore(self, state: TrainingState): + self._state = state diff --git a/acme/agents/jax/td3/networks.py b/acme/agents/jax/td3/networks.py index aa478b03c1..4fc829e0f7 100644 --- a/acme/agents/jax/td3/networks.py +++ b/acme/agents/jax/td3/networks.py @@ -16,45 +16,53 @@ import dataclasses from typing import Callable, Sequence -from acme import specs -from acme import types -from acme.agents.jax import actor_core as actor_core_lib -from acme.jax import networks as networks_lib -from acme.jax import utils import haiku as hk import jax import jax.numpy as jnp import numpy as np +from acme import specs, types +from acme.agents.jax import actor_core as actor_core_lib +from acme.jax import networks as networks_lib +from acme.jax import utils + @dataclasses.dataclass class TD3Networks: - """Network and pure functions for the TD3 agent.""" - policy_network: networks_lib.FeedForwardNetwork - critic_network: networks_lib.FeedForwardNetwork - twin_critic_network: networks_lib.FeedForwardNetwork - add_policy_noise: Callable[[types.NestedArray, networks_lib.PRNGKey, - float, float], types.NestedArray] + """Network and pure functions for the TD3 agent.""" + + policy_network: networks_lib.FeedForwardNetwork + critic_network: networks_lib.FeedForwardNetwork + twin_critic_network: networks_lib.FeedForwardNetwork + add_policy_noise: Callable[ + [types.NestedArray, networks_lib.PRNGKey, float, float], types.NestedArray + ] def get_default_behavior_policy( - networks: TD3Networks, action_specs: specs.BoundedArray, - sigma: float) -> actor_core_lib.FeedForwardPolicy: - """Selects action according to the policy.""" - def behavior_policy(params: networks_lib.Params, key: networks_lib.PRNGKey, - observation: types.NestedArray): - action = networks.policy_network.apply(params, observation) - noise = jax.random.normal(key, shape=action.shape) * sigma - noisy_action = jnp.clip(action + noise, - action_specs.minimum, action_specs.maximum) - return noisy_action - return behavior_policy + networks: TD3Networks, action_specs: specs.BoundedArray, sigma: float +) -> actor_core_lib.FeedForwardPolicy: + """Selects action according to the policy.""" + + def behavior_policy( + params: networks_lib.Params, + key: networks_lib.PRNGKey, + observation: types.NestedArray, + ): + action = networks.policy_network.apply(params, observation) + noise = jax.random.normal(key, shape=action.shape) * sigma + noisy_action = jnp.clip( + action + noise, action_specs.minimum, action_specs.maximum + ) + return noisy_action + + return behavior_policy def make_networks( - spec: specs.EnvironmentSpec, - hidden_layer_sizes: Sequence[int] = (256, 256)) -> TD3Networks: - """Creates networks used by the agent. + spec: specs.EnvironmentSpec, hidden_layer_sizes: Sequence[int] = (256, 256) +) -> TD3Networks: + """Creates networks used by the agent. 
The networks used are based on LayerNormMLP, which is different than the MLP with relu activation described in TD3 (which empirically performs worse). @@ -67,52 +75,60 @@ def make_networks( network: TD3Networks """ - action_specs = spec.actions - num_dimensions = np.prod(action_specs.shape, dtype=int) - - def add_policy_noise(action: types.NestedArray, - key: networks_lib.PRNGKey, - target_sigma: float, - noise_clip: float) -> types.NestedArray: - """Adds action noise to bootstrapped Q-value estimate in critic loss.""" - noise = jax.random.normal(key=key, shape=action_specs.shape) * target_sigma - noise = jnp.clip(noise, -noise_clip, noise_clip) - return jnp.clip(action + noise, action_specs.minimum, action_specs.maximum) - - def _actor_fn(obs: types.NestedArray) -> types.NestedArray: - network = hk.Sequential([ - networks_lib.LayerNormMLP(hidden_layer_sizes, - activate_final=True), - networks_lib.NearZeroInitializedLinear(num_dimensions), - networks_lib.TanhToSpec(spec.actions), - ]) - return network(obs) - - def _critic_fn(obs: types.NestedArray, - action: types.NestedArray) -> types.NestedArray: - network1 = hk.Sequential([ - networks_lib.LayerNormMLP(list(hidden_layer_sizes) + [1]), - ]) - input_ = jnp.concatenate([obs, action], axis=-1) - value = network1(input_) - return jnp.squeeze(value) - - policy = hk.without_apply_rng(hk.transform(_actor_fn)) - critic = hk.without_apply_rng(hk.transform(_critic_fn)) - - # Create dummy observations and actions to create network parameters. - dummy_action = utils.zeros_like(spec.actions) - dummy_obs = utils.zeros_like(spec.observations) - dummy_action = utils.add_batch_dim(dummy_action) - dummy_obs = utils.add_batch_dim(dummy_obs) - - network = TD3Networks( - policy_network=networks_lib.FeedForwardNetwork( - lambda key: policy.init(key, dummy_obs), policy.apply), - critic_network=networks_lib.FeedForwardNetwork( - lambda key: critic.init(key, dummy_obs, dummy_action), critic.apply), - twin_critic_network=networks_lib.FeedForwardNetwork( - lambda key: critic.init(key, dummy_obs, dummy_action), critic.apply), - add_policy_noise=add_policy_noise) - - return network + action_specs = spec.actions + num_dimensions = np.prod(action_specs.shape, dtype=int) + + def add_policy_noise( + action: types.NestedArray, + key: networks_lib.PRNGKey, + target_sigma: float, + noise_clip: float, + ) -> types.NestedArray: + """Adds action noise to bootstrapped Q-value estimate in critic loss.""" + noise = jax.random.normal(key=key, shape=action_specs.shape) * target_sigma + noise = jnp.clip(noise, -noise_clip, noise_clip) + return jnp.clip(action + noise, action_specs.minimum, action_specs.maximum) + + def _actor_fn(obs: types.NestedArray) -> types.NestedArray: + network = hk.Sequential( + [ + networks_lib.LayerNormMLP(hidden_layer_sizes, activate_final=True), + networks_lib.NearZeroInitializedLinear(num_dimensions), + networks_lib.TanhToSpec(spec.actions), + ] + ) + return network(obs) + + def _critic_fn( + obs: types.NestedArray, action: types.NestedArray + ) -> types.NestedArray: + network1 = hk.Sequential( + [networks_lib.LayerNormMLP(list(hidden_layer_sizes) + [1]),] + ) + input_ = jnp.concatenate([obs, action], axis=-1) + value = network1(input_) + return jnp.squeeze(value) + + policy = hk.without_apply_rng(hk.transform(_actor_fn)) + critic = hk.without_apply_rng(hk.transform(_critic_fn)) + + # Create dummy observations and actions to create network parameters. 
+ dummy_action = utils.zeros_like(spec.actions) + dummy_obs = utils.zeros_like(spec.observations) + dummy_action = utils.add_batch_dim(dummy_action) + dummy_obs = utils.add_batch_dim(dummy_obs) + + network = TD3Networks( + policy_network=networks_lib.FeedForwardNetwork( + lambda key: policy.init(key, dummy_obs), policy.apply + ), + critic_network=networks_lib.FeedForwardNetwork( + lambda key: critic.init(key, dummy_obs, dummy_action), critic.apply + ), + twin_critic_network=networks_lib.FeedForwardNetwork( + lambda key: critic.init(key, dummy_obs, dummy_action), critic.apply + ), + add_policy_noise=add_policy_noise, + ) + + return network diff --git a/acme/agents/jax/value_dice/__init__.py b/acme/agents/jax/value_dice/__init__.py index d86640aa9f..eae19b7905 100644 --- a/acme/agents/jax/value_dice/__init__.py +++ b/acme/agents/jax/value_dice/__init__.py @@ -17,6 +17,8 @@ from acme.agents.jax.value_dice.builder import ValueDiceBuilder from acme.agents.jax.value_dice.config import ValueDiceConfig from acme.agents.jax.value_dice.learning import ValueDiceLearner -from acme.agents.jax.value_dice.networks import apply_policy_and_sample -from acme.agents.jax.value_dice.networks import make_networks -from acme.agents.jax.value_dice.networks import ValueDiceNetworks +from acme.agents.jax.value_dice.networks import ( + ValueDiceNetworks, + apply_policy_and_sample, + make_networks, +) diff --git a/acme/agents/jax/value_dice/builder.py b/acme/agents/jax/value_dice/builder.py index 98880ccaba..c72841a5ab 100644 --- a/acme/agents/jax/value_dice/builder.py +++ b/acme/agents/jax/value_dice/builder.py @@ -16,144 +16,158 @@ from typing import Callable, Iterator, List, Optional -from acme import adders -from acme import core -from acme import specs -from acme import types +import jax +import optax +import reverb +from reverb import rate_limiters + +from acme import adders, core, specs, types from acme.adders import reverb as adders_reverb from acme.agents.jax import actor_core as actor_core_lib -from acme.agents.jax import actors -from acme.agents.jax import builders +from acme.agents.jax import actors, builders from acme.agents.jax.value_dice import config as value_dice_config from acme.agents.jax.value_dice import learning from acme.agents.jax.value_dice import networks as value_dice_networks from acme.datasets import reverb as datasets from acme.jax import networks as networks_lib -from acme.jax import utils -from acme.jax import variable_utils -from acme.utils import counting -from acme.utils import loggers -import jax -import optax -import reverb -from reverb import rate_limiters +from acme.jax import utils, variable_utils +from acme.utils import counting, loggers class ValueDiceBuilder( - builders.ActorLearnerBuilder[value_dice_networks.ValueDiceNetworks, - actor_core_lib.FeedForwardPolicy, - reverb.ReplaySample]): - """ValueDice Builder. + builders.ActorLearnerBuilder[ + value_dice_networks.ValueDiceNetworks, + actor_core_lib.FeedForwardPolicy, + reverb.ReplaySample, + ] +): + """ValueDice Builder. This builder is an entry point for online version of ValueDice. For offline please use the ValueDiceLearner directly. 
""" - def __init__(self, config: value_dice_config.ValueDiceConfig, - make_demonstrations: Callable[[int], - Iterator[types.Transition]]): - self._make_demonstrations = make_demonstrations - self._config = config + def __init__( + self, + config: value_dice_config.ValueDiceConfig, + make_demonstrations: Callable[[int], Iterator[types.Transition]], + ): + self._make_demonstrations = make_demonstrations + self._config = config - def make_learner( - self, - random_key: networks_lib.PRNGKey, - networks: value_dice_networks.ValueDiceNetworks, - dataset: Iterator[reverb.ReplaySample], - logger_fn: loggers.LoggerFactory, - environment_spec: specs.EnvironmentSpec, - replay_client: Optional[reverb.Client] = None, - counter: Optional[counting.Counter] = None, - ) -> core.Learner: - del environment_spec, replay_client - iterator_demonstration = self._make_demonstrations( - self._config.batch_size * self._config.num_sgd_steps_per_step) - policy_optimizer = optax.adam( - learning_rate=self._config.policy_learning_rate) - nu_optimizer = optax.adam(learning_rate=self._config.nu_learning_rate) - return learning.ValueDiceLearner( - networks=networks, - policy_optimizer=policy_optimizer, - nu_optimizer=nu_optimizer, - discount=self._config.discount, - rng=random_key, - alpha=self._config.alpha, - policy_reg_scale=self._config.policy_reg_scale, - nu_reg_scale=self._config.nu_reg_scale, - num_sgd_steps_per_step=self._config.num_sgd_steps_per_step, - iterator_replay=dataset, - iterator_demonstrations=iterator_demonstration, - logger=logger_fn('learner'), - counter=counter, - ) + def make_learner( + self, + random_key: networks_lib.PRNGKey, + networks: value_dice_networks.ValueDiceNetworks, + dataset: Iterator[reverb.ReplaySample], + logger_fn: loggers.LoggerFactory, + environment_spec: specs.EnvironmentSpec, + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + ) -> core.Learner: + del environment_spec, replay_client + iterator_demonstration = self._make_demonstrations( + self._config.batch_size * self._config.num_sgd_steps_per_step + ) + policy_optimizer = optax.adam(learning_rate=self._config.policy_learning_rate) + nu_optimizer = optax.adam(learning_rate=self._config.nu_learning_rate) + return learning.ValueDiceLearner( + networks=networks, + policy_optimizer=policy_optimizer, + nu_optimizer=nu_optimizer, + discount=self._config.discount, + rng=random_key, + alpha=self._config.alpha, + policy_reg_scale=self._config.policy_reg_scale, + nu_reg_scale=self._config.nu_reg_scale, + num_sgd_steps_per_step=self._config.num_sgd_steps_per_step, + iterator_replay=dataset, + iterator_demonstrations=iterator_demonstration, + logger=logger_fn("learner"), + counter=counter, + ) - def make_replay_tables( - self, - environment_spec: specs.EnvironmentSpec, - policy: actor_core_lib.FeedForwardPolicy, - ) -> List[reverb.Table]: - del policy - samples_per_insert_tolerance = ( - self._config.samples_per_insert_tolerance_rate * - self._config.samples_per_insert) - error_buffer = self._config.min_replay_size * samples_per_insert_tolerance - limiter = rate_limiters.SampleToInsertRatio( - min_size_to_sample=self._config.min_replay_size, - samples_per_insert=self._config.samples_per_insert, - error_buffer=error_buffer) - return [reverb.Table( - name=self._config.replay_table_name, - sampler=reverb.selectors.Uniform(), - remover=reverb.selectors.Fifo(), - max_size=self._config.max_replay_size, - rate_limiter=limiter, - signature=adders_reverb.NStepTransitionAdder.signature( - 
environment_spec))] + def make_replay_tables( + self, + environment_spec: specs.EnvironmentSpec, + policy: actor_core_lib.FeedForwardPolicy, + ) -> List[reverb.Table]: + del policy + samples_per_insert_tolerance = ( + self._config.samples_per_insert_tolerance_rate + * self._config.samples_per_insert + ) + error_buffer = self._config.min_replay_size * samples_per_insert_tolerance + limiter = rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._config.min_replay_size, + samples_per_insert=self._config.samples_per_insert, + error_buffer=error_buffer, + ) + return [ + reverb.Table( + name=self._config.replay_table_name, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._config.max_replay_size, + rate_limiter=limiter, + signature=adders_reverb.NStepTransitionAdder.signature( + environment_spec + ), + ) + ] - def make_dataset_iterator( - self, replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]: - """Creates a dataset iterator to use for learning.""" - dataset = datasets.make_reverb_dataset( - table=self._config.replay_table_name, - server_address=replay_client.server_address, - batch_size=( - self._config.batch_size * self._config.num_sgd_steps_per_step), - prefetch_size=self._config.prefetch_size) - return utils.device_put(dataset.as_numpy_iterator(), jax.devices()[0]) + def make_dataset_iterator( + self, replay_client: reverb.Client + ) -> Iterator[reverb.ReplaySample]: + """Creates a dataset iterator to use for learning.""" + dataset = datasets.make_reverb_dataset( + table=self._config.replay_table_name, + server_address=replay_client.server_address, + batch_size=(self._config.batch_size * self._config.num_sgd_steps_per_step), + prefetch_size=self._config.prefetch_size, + ) + return utils.device_put(dataset.as_numpy_iterator(), jax.devices()[0]) - def make_adder( - self, replay_client: reverb.Client, - environment_spec: Optional[specs.EnvironmentSpec], - policy: Optional[actor_core_lib.FeedForwardPolicy] - ) -> Optional[adders.Adder]: - del environment_spec, policy - return adders_reverb.NStepTransitionAdder( - priority_fns={self._config.replay_table_name: None}, - client=replay_client, - n_step=1, - discount=self._config.discount) + def make_adder( + self, + replay_client: reverb.Client, + environment_spec: Optional[specs.EnvironmentSpec], + policy: Optional[actor_core_lib.FeedForwardPolicy], + ) -> Optional[adders.Adder]: + del environment_spec, policy + return adders_reverb.NStepTransitionAdder( + priority_fns={self._config.replay_table_name: None}, + client=replay_client, + n_step=1, + discount=self._config.discount, + ) - def make_actor( - self, - random_key: networks_lib.PRNGKey, - policy: actor_core_lib.FeedForwardPolicy, - environment_spec: specs.EnvironmentSpec, - variable_source: Optional[core.VariableSource] = None, - adder: Optional[adders.Adder] = None, - ) -> core.Actor: - del environment_spec - assert variable_source is not None - actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy) - # Inference happens on CPU, so it's better to move variables there too. 
- variable_client = variable_utils.VariableClient(variable_source, 'policy', - device='cpu') - return actors.GenericActor( - actor_core, random_key, variable_client, adder, backend='cpu') + def make_actor( + self, + random_key: networks_lib.PRNGKey, + policy: actor_core_lib.FeedForwardPolicy, + environment_spec: specs.EnvironmentSpec, + variable_source: Optional[core.VariableSource] = None, + adder: Optional[adders.Adder] = None, + ) -> core.Actor: + del environment_spec + assert variable_source is not None + actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy) + # Inference happens on CPU, so it's better to move variables there too. + variable_client = variable_utils.VariableClient( + variable_source, "policy", device="cpu" + ) + return actors.GenericActor( + actor_core, random_key, variable_client, adder, backend="cpu" + ) - def make_policy(self, - networks: value_dice_networks.ValueDiceNetworks, - environment_spec: specs.EnvironmentSpec, - evaluation: bool = False) -> actor_core_lib.FeedForwardPolicy: - del environment_spec - return value_dice_networks.apply_policy_and_sample( - networks, eval_mode=evaluation) + def make_policy( + self, + networks: value_dice_networks.ValueDiceNetworks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool = False, + ) -> actor_core_lib.FeedForwardPolicy: + del environment_spec + return value_dice_networks.apply_policy_and_sample( + networks, eval_mode=evaluation + ) diff --git a/acme/agents/jax/value_dice/config.py b/acme/agents/jax/value_dice/config.py index 7f8a28dad7..e1c810c09b 100644 --- a/acme/agents/jax/value_dice/config.py +++ b/acme/agents/jax/value_dice/config.py @@ -21,25 +21,25 @@ @dataclasses.dataclass class ValueDiceConfig: - """Configuration options for ValueDice.""" - - policy_learning_rate: float = 1e-5 - nu_learning_rate: float = 1e-3 - discount: float = .99 - batch_size: int = 256 - alpha: float = 0.05 - policy_reg_scale: float = 1e-4 - nu_reg_scale: float = 10.0 - - # Replay options - replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE - samples_per_insert: float = 256 * 4 - # Rate to be used for the SampleToInsertRatio rate limitter tolerance. - # See a formula in make_replay_tables for more details. - samples_per_insert_tolerance_rate: float = 0.1 - min_replay_size: int = 1000 - max_replay_size: int = 1000000 - prefetch_size: int = 4 - - # How many gradient updates to perform per step. - num_sgd_steps_per_step: int = 1 + """Configuration options for ValueDice.""" + + policy_learning_rate: float = 1e-5 + nu_learning_rate: float = 1e-3 + discount: float = 0.99 + batch_size: int = 256 + alpha: float = 0.05 + policy_reg_scale: float = 1e-4 + nu_reg_scale: float = 10.0 + + # Replay options + replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE + samples_per_insert: float = 256 * 4 + # Rate to be used for the SampleToInsertRatio rate limitter tolerance. + # See a formula in make_replay_tables for more details. + samples_per_insert_tolerance_rate: float = 0.1 + min_replay_size: int = 1000 + max_replay_size: int = 1000000 + prefetch_size: int = 4 + + # How many gradient updates to perform per step. 
+ num_sgd_steps_per_step: int = 1 diff --git a/acme/agents/jax/value_dice/learning.py b/acme/agents/jax/value_dice/learning.py index 33d49e08cf..65c921d4c5 100644 --- a/acme/agents/jax/value_dice/learning.py +++ b/acme/agents/jax/value_dice/learning.py @@ -18,31 +18,32 @@ import time from typing import Any, Dict, Iterator, List, Mapping, NamedTuple, Optional, Tuple +import jax +import jax.numpy as jnp +import optax +import reverb + import acme from acme import types from acme.agents.jax.value_dice import networks as value_dice_networks from acme.jax import networks as networks_lib from acme.jax import utils -from acme.utils import counting -from acme.utils import loggers -import jax -import jax.numpy as jnp -import optax -import reverb +from acme.utils import counting, loggers class TrainingState(NamedTuple): - """Contains training state for the learner.""" - policy_optimizer_state: optax.OptState - policy_params: networks_lib.Params - nu_optimizer_state: optax.OptState - nu_params: networks_lib.Params - key: jnp.ndarray - steps: int + """Contains training state for the learner.""" + + policy_optimizer_state: optax.OptState + policy_params: networks_lib.Params + nu_optimizer_state: optax.OptState + nu_params: networks_lib.Params + key: jnp.ndarray + steps: int def _orthogonal_regularization_loss(params: networks_lib.Params): - """Orthogonal regularization. + """Orthogonal regularization. See equation (3) in https://arxiv.org/abs/1809.11096. @@ -52,278 +53,302 @@ def _orthogonal_regularization_loss(params: networks_lib.Params): Returns: A regularization loss term. """ - reg_loss = 0 - for key in params: - if isinstance(params[key], Mapping): - reg_loss += _orthogonal_regularization_loss(params[key]) - continue - variable = params[key] - assert len(variable.shape) in [1, 2, 4] - if len(variable.shape) == 1: - # This is a bias so do not apply regularization. - continue - if len(variable.shape) == 4: - # CNN - variable = jnp.reshape(variable, (-1, variable.shape[-1])) - prod = jnp.matmul(jnp.transpose(variable), variable) - reg_loss += jnp.sum(jnp.square(prod * (1 - jnp.eye(prod.shape[0])))) - return reg_loss + reg_loss = 0 + for key in params: + if isinstance(params[key], Mapping): + reg_loss += _orthogonal_regularization_loss(params[key]) + continue + variable = params[key] + assert len(variable.shape) in [1, 2, 4] + if len(variable.shape) == 1: + # This is a bias so do not apply regularization. 
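# (For reference on the regularizer computed by this loop: for every non-bias
#  weight matrix W it accumulates || W^T W * (1 - I) ||_F^2, i.e. the squared
#  off-diagonal entries of the Gram matrix, matching equation (3) of
#  https://arxiv.org/abs/1809.11096; convolution kernels are first reshaped to
#  2-D, and the beta scale from the paper is applied by the caller through
#  policy_reg_scale.)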
+ continue + if len(variable.shape) == 4: + # CNN + variable = jnp.reshape(variable, (-1, variable.shape[-1])) + prod = jnp.matmul(jnp.transpose(variable), variable) + reg_loss += jnp.sum(jnp.square(prod * (1 - jnp.eye(prod.shape[0])))) + return reg_loss class ValueDiceLearner(acme.Learner): - """ValueDice learner.""" - - _state: TrainingState - - def __init__(self, - networks: value_dice_networks.ValueDiceNetworks, - policy_optimizer: optax.GradientTransformation, - nu_optimizer: optax.GradientTransformation, - discount: float, - rng: jnp.ndarray, - iterator_replay: Iterator[reverb.ReplaySample], - iterator_demonstrations: Iterator[types.Transition], - alpha: float = 0.05, - policy_reg_scale: float = 1e-4, - nu_reg_scale: float = 10.0, - num_sgd_steps_per_step: int = 1, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None): - - rng, policy_key, nu_key = jax.random.split(rng, 3) - policy_init_params = networks.policy_network.init(policy_key) - policy_optimizer_state = policy_optimizer.init(policy_init_params) - - nu_init_params = networks.nu_network.init(nu_key) - nu_optimizer_state = nu_optimizer.init(nu_init_params) - - def compute_losses( - policy_params: networks_lib.Params, - nu_params: networks_lib.Params, - key: jnp.ndarray, - replay_o_tm1: types.NestedArray, - replay_a_tm1: types.NestedArray, - replay_o_t: types.NestedArray, - demo_o_tm1: types.NestedArray, - demo_a_tm1: types.NestedArray, - demo_o_t: types.NestedArray, - ) -> jnp.ndarray: - # TODO(damienv, hussenot): what to do with the discounts ? - - def policy(obs, key): - dist_params = networks.policy_network.apply(policy_params, obs) - return networks.sample(dist_params, key) - - key1, key2, key3, key4 = jax.random.split(key, 4) - - # Predicted actions. - demo_o_t0 = demo_o_tm1 - policy_demo_a_t0 = policy(demo_o_t0, key1) - policy_demo_a_t = policy(demo_o_t, key2) - policy_replay_a_t = policy(replay_o_t, key3) - - replay_a_tm1 = networks.encode_action(replay_a_tm1) - demo_a_tm1 = networks.encode_action(demo_a_tm1) - policy_demo_a_t0 = networks.encode_action(policy_demo_a_t0) - policy_demo_a_t = networks.encode_action(policy_demo_a_t) - policy_replay_a_t = networks.encode_action(policy_replay_a_t) - - # "Value function" nu over the expert states. - nu_demo_t0 = networks.nu_network.apply(nu_params, demo_o_t0, - policy_demo_a_t0) - nu_demo_tm1 = networks.nu_network.apply(nu_params, demo_o_tm1, demo_a_tm1) - nu_demo_t = networks.nu_network.apply(nu_params, demo_o_t, - policy_demo_a_t) - nu_demo_diff = nu_demo_tm1 - discount * nu_demo_t - - # "Value function" nu over the replay buffer states. - nu_replay_tm1 = networks.nu_network.apply(nu_params, replay_o_tm1, - replay_a_tm1) - nu_replay_t = networks.nu_network.apply(nu_params, replay_o_t, - policy_replay_a_t) - nu_replay_diff = nu_replay_tm1 - discount * nu_replay_t - - # Linear part of the loss. - linear_loss_demo = jnp.mean(nu_demo_t0 * (1.0 - discount)) - linear_loss_rb = jnp.mean(nu_replay_diff) - linear_loss = (linear_loss_demo * (1 - alpha) + linear_loss_rb * alpha) - - # Non linear part of the loss. 
- nu_replay_demo_diff = jnp.concatenate([nu_demo_diff, nu_replay_diff], - axis=0) - replay_demo_weights = jnp.concatenate([ - jnp.ones_like(nu_demo_diff) * (1 - alpha), - jnp.ones_like(nu_replay_diff) * alpha - ], - axis=0) - replay_demo_weights /= jnp.mean(replay_demo_weights) - non_linear_loss = jnp.sum( - jax.lax.stop_gradient( - utils.weighted_softmax(nu_replay_demo_diff, replay_demo_weights, - axis=0)) * - nu_replay_demo_diff) - - # Final loss. - loss = (non_linear_loss - linear_loss) - - # Regularized policy loss. - if policy_reg_scale > 0.: - policy_reg = _orthogonal_regularization_loss(policy_params) - else: - policy_reg = 0. - - # Gradient penality on nu - if nu_reg_scale > 0.0: - batch_size = demo_o_tm1.shape[0] - c = jax.random.uniform(key4, shape=(batch_size,)) - shape_o = [ - dim if i == 0 else 1 for i, dim in enumerate(replay_o_tm1.shape) - ] - shape_a = [ - dim if i == 0 else 1 for i, dim in enumerate(replay_a_tm1.shape) - ] - c_o = jnp.reshape(c, shape_o) - c_a = jnp.reshape(c, shape_a) - mixed_o_tm1 = c_o * demo_o_tm1 + (1 - c_o) * replay_o_tm1 - mixed_a_tm1 = c_a * demo_a_tm1 + (1 - c_a) * replay_a_tm1 - mixed_o_t = c_o * demo_o_t + (1 - c_o) * replay_o_t - mixed_policy_a_t = c_a * policy_demo_a_t + (1 - c_a) * policy_replay_a_t - mixed_o = jnp.concatenate([mixed_o_tm1, mixed_o_t], axis=0) - mixed_a = jnp.concatenate([mixed_a_tm1, mixed_policy_a_t], axis=0) - - def sum_nu(o, a): - return jnp.sum(networks.nu_network.apply(nu_params, o, a)) - - nu_grad_o_fn = jax.grad(sum_nu, argnums=0) - nu_grad_a_fn = jax.grad(sum_nu, argnums=1) - nu_grad_o = nu_grad_o_fn(mixed_o, mixed_a) - nu_grad_a = nu_grad_a_fn(mixed_o, mixed_a) - nu_grad = jnp.concatenate([ - jnp.reshape(nu_grad_o, [batch_size, -1]), - jnp.reshape(nu_grad_a, [batch_size, -1])], axis=-1) - # TODO(damienv, hussenot): check for the need of eps - # (like in the original value dice code). - nu_grad_penalty = jnp.mean( - jnp.square( - jnp.linalg.norm(nu_grad + 1e-8, axis=-1, keepdims=True) - 1)) - else: - nu_grad_penalty = 0.0 - - policy_loss = -loss + policy_reg_scale * policy_reg - nu_loss = loss + nu_reg_scale * nu_grad_penalty - - return policy_loss, nu_loss # pytype: disable=bad-return-type # jax-ndarray - - def sgd_step( - state: TrainingState, - data: Tuple[types.Transition, types.Transition] - ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: - replay_transitions, demo_transitions = data - key, key_loss = jax.random.split(state.key) - compute_losses_with_input = functools.partial( - compute_losses, - replay_o_tm1=replay_transitions.observation, - replay_a_tm1=replay_transitions.action, - replay_o_t=replay_transitions.next_observation, - demo_o_tm1=demo_transitions.observation, - demo_a_tm1=demo_transitions.action, - demo_o_t=demo_transitions.next_observation, - key=key_loss) - (policy_loss_value, nu_loss_value), vjpfun = jax.vjp( - compute_losses_with_input, - state.policy_params, state.nu_params) - policy_gradients, _ = vjpfun((1.0, 0.0)) - _, nu_gradients = vjpfun((0.0, 1.0)) - - # Update optimizers. 
- policy_update, policy_optimizer_state = policy_optimizer.update( - policy_gradients, state.policy_optimizer_state) - policy_params = optax.apply_updates(state.policy_params, policy_update) - - nu_update, nu_optimizer_state = nu_optimizer.update( - nu_gradients, state.nu_optimizer_state) - nu_params = optax.apply_updates(state.nu_params, nu_update) - - new_state = TrainingState( - policy_optimizer_state=policy_optimizer_state, - policy_params=policy_params, - nu_optimizer_state=nu_optimizer_state, - nu_params=nu_params, - key=key, - steps=state.steps + 1, - ) - - metrics = { - 'policy_loss': policy_loss_value, - 'nu_loss': nu_loss_value, - } - - return new_state, metrics - - # General learner book-keeping and loggers. - self._counter = counter or counting.Counter() - self._logger = logger or loggers.make_default_logger( - 'learner', - asynchronous=True, - serialize_fn=utils.fetch_devicearray, - steps_key=self._counter.get_steps_key()) - - # Iterator on demonstration transitions. - self._iterator_demonstrations = iterator_demonstrations - self._iterator_replay = iterator_replay - - self._sgd_step = jax.jit(utils.process_multiple_batches( - sgd_step, num_sgd_steps_per_step)) - - # Create initial state. - self._state = TrainingState( - policy_optimizer_state=policy_optimizer_state, - policy_params=policy_init_params, - nu_optimizer_state=nu_optimizer_state, - nu_params=nu_init_params, - key=rng, - steps=0, - ) - - # Do not record timestamps until after the first learning step is done. - # This is to avoid including the time it takes for actors to come online and - # fill the replay buffer. - self._timestamp = None - - def step(self): - # Get data from replay (dropping extras if any). Note there is no - # extra data here because we do not insert any into Reverb. - # TODO(raveman): Add a support for offline training, where we do not consume - # data from the replay buffer. - sample = next(self._iterator_replay) - replay_transitions = types.Transition(*sample.data) - - # Get a batch of Transitions from the demonstration. - demonstration_transitions = next(self._iterator_demonstrations) - - self._state, metrics = self._sgd_step( - self._state, (replay_transitions, demonstration_transitions)) - - # Compute elapsed time. - timestamp = time.time() - elapsed_time = timestamp - self._timestamp if self._timestamp else 0 - self._timestamp = timestamp - - # Increment counts and record the current time - counts = self._counter.increment(steps=1, walltime=elapsed_time) - - # Attempts to write the logs. 
- self._logger.write({**metrics, **counts}) - - def get_variables(self, names: List[str]) -> List[Any]: - variables = { - 'policy': self._state.policy_params, - 'nu': self._state.nu_params, - } - return [variables[name] for name in names] - - def save(self) -> TrainingState: - return self._state - - def restore(self, state: TrainingState): - self._state = state + """ValueDice learner.""" + + _state: TrainingState + + def __init__( + self, + networks: value_dice_networks.ValueDiceNetworks, + policy_optimizer: optax.GradientTransformation, + nu_optimizer: optax.GradientTransformation, + discount: float, + rng: jnp.ndarray, + iterator_replay: Iterator[reverb.ReplaySample], + iterator_demonstrations: Iterator[types.Transition], + alpha: float = 0.05, + policy_reg_scale: float = 1e-4, + nu_reg_scale: float = 10.0, + num_sgd_steps_per_step: int = 1, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + ): + + rng, policy_key, nu_key = jax.random.split(rng, 3) + policy_init_params = networks.policy_network.init(policy_key) + policy_optimizer_state = policy_optimizer.init(policy_init_params) + + nu_init_params = networks.nu_network.init(nu_key) + nu_optimizer_state = nu_optimizer.init(nu_init_params) + + def compute_losses( + policy_params: networks_lib.Params, + nu_params: networks_lib.Params, + key: jnp.ndarray, + replay_o_tm1: types.NestedArray, + replay_a_tm1: types.NestedArray, + replay_o_t: types.NestedArray, + demo_o_tm1: types.NestedArray, + demo_a_tm1: types.NestedArray, + demo_o_t: types.NestedArray, + ) -> jnp.ndarray: + # TODO(damienv, hussenot): what to do with the discounts ? + + def policy(obs, key): + dist_params = networks.policy_network.apply(policy_params, obs) + return networks.sample(dist_params, key) + + key1, key2, key3, key4 = jax.random.split(key, 4) + + # Predicted actions. + demo_o_t0 = demo_o_tm1 + policy_demo_a_t0 = policy(demo_o_t0, key1) + policy_demo_a_t = policy(demo_o_t, key2) + policy_replay_a_t = policy(replay_o_t, key3) + + replay_a_tm1 = networks.encode_action(replay_a_tm1) + demo_a_tm1 = networks.encode_action(demo_a_tm1) + policy_demo_a_t0 = networks.encode_action(policy_demo_a_t0) + policy_demo_a_t = networks.encode_action(policy_demo_a_t) + policy_replay_a_t = networks.encode_action(policy_replay_a_t) + + # "Value function" nu over the expert states. + nu_demo_t0 = networks.nu_network.apply( + nu_params, demo_o_t0, policy_demo_a_t0 + ) + nu_demo_tm1 = networks.nu_network.apply(nu_params, demo_o_tm1, demo_a_tm1) + nu_demo_t = networks.nu_network.apply(nu_params, demo_o_t, policy_demo_a_t) + nu_demo_diff = nu_demo_tm1 - discount * nu_demo_t + + # "Value function" nu over the replay buffer states. + nu_replay_tm1 = networks.nu_network.apply( + nu_params, replay_o_tm1, replay_a_tm1 + ) + nu_replay_t = networks.nu_network.apply( + nu_params, replay_o_t, policy_replay_a_t + ) + nu_replay_diff = nu_replay_tm1 - discount * nu_replay_t + + # Linear part of the loss. + linear_loss_demo = jnp.mean(nu_demo_t0 * (1.0 - discount)) + linear_loss_rb = jnp.mean(nu_replay_diff) + linear_loss = linear_loss_demo * (1 - alpha) + linear_loss_rb * alpha + + # Non linear part of the loss. 
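# The sum(stop_gradient(weighted_softmax(...)) * ...) expression constructed
# below is a gradient surrogate for the weighted log-sum-exp of the nu
# differences: the gradient of log sum_i w_i exp(x_i) with respect to x is
# exactly the weighted softmax, so freezing the softmax and taking an inner
# product with x reproduces the same gradient. A minimal standalone check,
# separate from the patch and assuming utils.weighted_softmax(x, w) equals
# softmax(x + log w):
import jax
import jax.numpy as jnp

def log_weighted_sum_exp(x, w):
    # log sum_i w_i exp(x_i); an additive normalising constant would not
    # change the gradient with respect to x.
    return jax.scipy.special.logsumexp(x, b=w)

def surrogate(x, w):
    p = jax.nn.softmax(x + jnp.log(w))  # weighted softmax
    return jnp.sum(jax.lax.stop_gradient(p) * x)

x = jnp.array([0.5, -1.0, 2.0])
w = jnp.array([0.25, 0.25, 0.5])
print(jax.grad(log_weighted_sum_exp)(x, w))  # same gradient...
print(jax.grad(surrogate)(x, w))             # ...from the surrogate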
+ nu_replay_demo_diff = jnp.concatenate( + [nu_demo_diff, nu_replay_diff], axis=0 + ) + replay_demo_weights = jnp.concatenate( + [ + jnp.ones_like(nu_demo_diff) * (1 - alpha), + jnp.ones_like(nu_replay_diff) * alpha, + ], + axis=0, + ) + replay_demo_weights /= jnp.mean(replay_demo_weights) + non_linear_loss = jnp.sum( + jax.lax.stop_gradient( + utils.weighted_softmax( + nu_replay_demo_diff, replay_demo_weights, axis=0 + ) + ) + * nu_replay_demo_diff + ) + + # Final loss. + loss = non_linear_loss - linear_loss + + # Regularized policy loss. + if policy_reg_scale > 0.0: + policy_reg = _orthogonal_regularization_loss(policy_params) + else: + policy_reg = 0.0 + + # Gradient penality on nu + if nu_reg_scale > 0.0: + batch_size = demo_o_tm1.shape[0] + c = jax.random.uniform(key4, shape=(batch_size,)) + shape_o = [ + dim if i == 0 else 1 for i, dim in enumerate(replay_o_tm1.shape) + ] + shape_a = [ + dim if i == 0 else 1 for i, dim in enumerate(replay_a_tm1.shape) + ] + c_o = jnp.reshape(c, shape_o) + c_a = jnp.reshape(c, shape_a) + mixed_o_tm1 = c_o * demo_o_tm1 + (1 - c_o) * replay_o_tm1 + mixed_a_tm1 = c_a * demo_a_tm1 + (1 - c_a) * replay_a_tm1 + mixed_o_t = c_o * demo_o_t + (1 - c_o) * replay_o_t + mixed_policy_a_t = c_a * policy_demo_a_t + (1 - c_a) * policy_replay_a_t + mixed_o = jnp.concatenate([mixed_o_tm1, mixed_o_t], axis=0) + mixed_a = jnp.concatenate([mixed_a_tm1, mixed_policy_a_t], axis=0) + + def sum_nu(o, a): + return jnp.sum(networks.nu_network.apply(nu_params, o, a)) + + nu_grad_o_fn = jax.grad(sum_nu, argnums=0) + nu_grad_a_fn = jax.grad(sum_nu, argnums=1) + nu_grad_o = nu_grad_o_fn(mixed_o, mixed_a) + nu_grad_a = nu_grad_a_fn(mixed_o, mixed_a) + nu_grad = jnp.concatenate( + [ + jnp.reshape(nu_grad_o, [batch_size, -1]), + jnp.reshape(nu_grad_a, [batch_size, -1]), + ], + axis=-1, + ) + # TODO(damienv, hussenot): check for the need of eps + # (like in the original value dice code). + nu_grad_penalty = jnp.mean( + jnp.square( + jnp.linalg.norm(nu_grad + 1e-8, axis=-1, keepdims=True) - 1 + ) + ) + else: + nu_grad_penalty = 0.0 + + policy_loss = -loss + policy_reg_scale * policy_reg + nu_loss = loss + nu_reg_scale * nu_grad_penalty + + return ( + policy_loss, + nu_loss, + ) # pytype: disable=bad-return-type # jax-ndarray + + def sgd_step( + state: TrainingState, data: Tuple[types.Transition, types.Transition] + ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: + replay_transitions, demo_transitions = data + key, key_loss = jax.random.split(state.key) + compute_losses_with_input = functools.partial( + compute_losses, + replay_o_tm1=replay_transitions.observation, + replay_a_tm1=replay_transitions.action, + replay_o_t=replay_transitions.next_observation, + demo_o_tm1=demo_transitions.observation, + demo_a_tm1=demo_transitions.action, + demo_o_t=demo_transitions.next_observation, + key=key_loss, + ) + (policy_loss_value, nu_loss_value), vjpfun = jax.vjp( + compute_losses_with_input, state.policy_params, state.nu_params + ) + policy_gradients, _ = vjpfun((1.0, 0.0)) + _, nu_gradients = vjpfun((0.0, 1.0)) + + # Update optimizers. 
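# A note on the jax.vjp call above: compute_losses returns the pair
# (policy_loss, nu_loss), and pulling back the cotangents (1.0, 0.0) and
# (0.0, 1.0) yields the policy-loss gradient w.r.t. policy_params and the
# nu-loss gradient w.r.t. nu_params from a single forward pass. A toy
# standalone illustration (the function below is made up for the example):
import jax

def two_losses(a, b):
    return a * b, a + 2.0 * b

(l1, l2), pullback = jax.vjp(two_losses, 3.0, 4.0)
dl1_da, _ = pullback((1.0, 0.0))  # d(a * b)/da at (3, 4)     -> 4.0
_, dl2_db = pullback((0.0, 1.0))  # d(a + 2 * b)/db at (3, 4) -> 2.0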
+ policy_update, policy_optimizer_state = policy_optimizer.update( + policy_gradients, state.policy_optimizer_state + ) + policy_params = optax.apply_updates(state.policy_params, policy_update) + + nu_update, nu_optimizer_state = nu_optimizer.update( + nu_gradients, state.nu_optimizer_state + ) + nu_params = optax.apply_updates(state.nu_params, nu_update) + + new_state = TrainingState( + policy_optimizer_state=policy_optimizer_state, + policy_params=policy_params, + nu_optimizer_state=nu_optimizer_state, + nu_params=nu_params, + key=key, + steps=state.steps + 1, + ) + + metrics = { + "policy_loss": policy_loss_value, + "nu_loss": nu_loss_value, + } + + return new_state, metrics + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger( + "learner", + asynchronous=True, + serialize_fn=utils.fetch_devicearray, + steps_key=self._counter.get_steps_key(), + ) + + # Iterator on demonstration transitions. + self._iterator_demonstrations = iterator_demonstrations + self._iterator_replay = iterator_replay + + self._sgd_step = jax.jit( + utils.process_multiple_batches(sgd_step, num_sgd_steps_per_step) + ) + + # Create initial state. + self._state = TrainingState( + policy_optimizer_state=policy_optimizer_state, + policy_params=policy_init_params, + nu_optimizer_state=nu_optimizer_state, + nu_params=nu_init_params, + key=rng, + steps=0, + ) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + def step(self): + # Get data from replay (dropping extras if any). Note there is no + # extra data here because we do not insert any into Reverb. + # TODO(raveman): Add a support for offline training, where we do not consume + # data from the replay buffer. + sample = next(self._iterator_replay) + replay_transitions = types.Transition(*sample.data) + + # Get a batch of Transitions from the demonstration. + demonstration_transitions = next(self._iterator_demonstrations) + + self._state, metrics = self._sgd_step( + self._state, (replay_transitions, demonstration_transitions) + ) + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Increment counts and record the current time + counts = self._counter.increment(steps=1, walltime=elapsed_time) + + # Attempts to write the logs. 
+ self._logger.write({**metrics, **counts}) + + def get_variables(self, names: List[str]) -> List[Any]: + variables = { + "policy": self._state.policy_params, + "nu": self._state.nu_params, + } + return [variables[name] for name in names] + + def save(self) -> TrainingState: + return self._state + + def restore(self, state: TrainingState): + self._state = state diff --git a/acme/agents/jax/value_dice/networks.py b/acme/agents/jax/value_dice/networks.py index 479a0f8593..217b2f24d6 100644 --- a/acme/agents/jax/value_dice/networks.py +++ b/acme/agents/jax/value_dice/networks.py @@ -18,83 +18,95 @@ import dataclasses from typing import Callable, Optional, Tuple -from acme import specs -from acme.agents.jax import actor_core as actor_core_lib -from acme.jax import networks as networks_lib -from acme.jax import utils import haiku as hk import jax import jax.numpy as jnp import numpy as np +from acme import specs +from acme.agents.jax import actor_core as actor_core_lib +from acme.jax import networks as networks_lib +from acme.jax import utils + @dataclasses.dataclass class ValueDiceNetworks: - """ValueDice networks.""" - policy_network: networks_lib.FeedForwardNetwork - nu_network: networks_lib.FeedForwardNetwork - # Functions for actors and evaluators, resp., to sample actions. - sample: networks_lib.SampleFn - sample_eval: Optional[networks_lib.SampleFn] = None - # Function that transforms an action before a mixture is applied, typically - # the identity for continuous actions and one-hot encoding for discrete - # actions. - encode_action: Callable[[networks_lib.Action], jnp.ndarray] = lambda x: x + """ValueDice networks.""" + + policy_network: networks_lib.FeedForwardNetwork + nu_network: networks_lib.FeedForwardNetwork + # Functions for actors and evaluators, resp., to sample actions. + sample: networks_lib.SampleFn + sample_eval: Optional[networks_lib.SampleFn] = None + # Function that transforms an action before a mixture is applied, typically + # the identity for continuous actions and one-hot encoding for discrete + # actions. + encode_action: Callable[[networks_lib.Action], jnp.ndarray] = lambda x: x def apply_policy_and_sample( - networks: ValueDiceNetworks, - eval_mode: bool = False) -> actor_core_lib.FeedForwardPolicy: - """Returns a function that computes actions.""" - sample_fn = networks.sample if not eval_mode else networks.sample_eval - if not sample_fn: - raise ValueError('sample function is not provided') + networks: ValueDiceNetworks, eval_mode: bool = False +) -> actor_core_lib.FeedForwardPolicy: + """Returns a function that computes actions.""" + sample_fn = networks.sample if not eval_mode else networks.sample_eval + if not sample_fn: + raise ValueError("sample function is not provided") + + def apply_and_sample(params, key, obs): + return sample_fn(networks.policy_network.apply(params, obs), key) - def apply_and_sample(params, key, obs): - return sample_fn(networks.policy_network.apply(params, obs), key) - return apply_and_sample + return apply_and_sample def make_networks( - spec: specs.EnvironmentSpec, - hidden_layer_sizes: Tuple[int, ...] 
= (256, 256)) -> ValueDiceNetworks: - """Creates networks used by the agent.""" - - num_dimensions = np.prod(spec.actions.shape, dtype=int) - - def _actor_fn(obs): - network = hk.Sequential([ - hk.nets.MLP( - list(hidden_layer_sizes), - w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'), - activation=jax.nn.relu, - activate_final=True), - networks_lib.NormalTanhDistribution(num_dimensions), - ]) - return network(obs) - - def _nu_fn(obs, action): - network = hk.Sequential([ - hk.nets.MLP( - list(hidden_layer_sizes) + [1], - w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'), - activation=jax.nn.relu), - ]) - return network(jnp.concatenate([obs, action], axis=-1)) - - policy = hk.without_apply_rng(hk.transform(_actor_fn)) - nu = hk.without_apply_rng(hk.transform(_nu_fn)) - - # Create dummy observations and actions to create network parameters. - dummy_action = utils.zeros_like(spec.actions) - dummy_obs = utils.zeros_like(spec.observations) - dummy_action = utils.add_batch_dim(dummy_action) - dummy_obs = utils.add_batch_dim(dummy_obs) - - return ValueDiceNetworks( - policy_network=networks_lib.FeedForwardNetwork( - lambda key: policy.init(key, dummy_obs), policy.apply), - nu_network=networks_lib.FeedForwardNetwork( - lambda key: nu.init(key, dummy_obs, dummy_action), nu.apply), - sample=lambda params, key: params.sample(seed=key), - sample_eval=lambda params, key: params.mode()) + spec: specs.EnvironmentSpec, hidden_layer_sizes: Tuple[int, ...] = (256, 256) +) -> ValueDiceNetworks: + """Creates networks used by the agent.""" + + num_dimensions = np.prod(spec.actions.shape, dtype=int) + + def _actor_fn(obs): + network = hk.Sequential( + [ + hk.nets.MLP( + list(hidden_layer_sizes), + w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), + activation=jax.nn.relu, + activate_final=True, + ), + networks_lib.NormalTanhDistribution(num_dimensions), + ] + ) + return network(obs) + + def _nu_fn(obs, action): + network = hk.Sequential( + [ + hk.nets.MLP( + list(hidden_layer_sizes) + [1], + w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), + activation=jax.nn.relu, + ), + ] + ) + return network(jnp.concatenate([obs, action], axis=-1)) + + policy = hk.without_apply_rng(hk.transform(_actor_fn)) + nu = hk.without_apply_rng(hk.transform(_nu_fn)) + + # Create dummy observations and actions to create network parameters. 
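# (Why the dummies created below are needed: hk.transform builds parameters
#  lazily, so `init` has to be traced with example inputs of the right shape
#  and dtype. A minimal sketch of the same pattern, separate from this patch:)
import haiku as hk
import jax
import jax.numpy as jnp

model = hk.without_apply_rng(hk.transform(lambda x: hk.Linear(2)(x)))
dummy = jnp.zeros((1, 3))                          # batch of one, feature size 3
params = model.init(jax.random.PRNGKey(0), dummy)  # shapes are taken from `dummy`
out = model.apply(params, jnp.ones((1, 3)))        # -> shape (1, 2)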
+ dummy_action = utils.zeros_like(spec.actions) + dummy_obs = utils.zeros_like(spec.observations) + dummy_action = utils.add_batch_dim(dummy_action) + dummy_obs = utils.add_batch_dim(dummy_obs) + + return ValueDiceNetworks( + policy_network=networks_lib.FeedForwardNetwork( + lambda key: policy.init(key, dummy_obs), policy.apply + ), + nu_network=networks_lib.FeedForwardNetwork( + lambda key: nu.init(key, dummy_obs, dummy_action), nu.apply + ), + sample=lambda params, key: params.sample(seed=key), + sample_eval=lambda params, key: params.mode(), + ) diff --git a/acme/agents/replay.py b/acme/agents/replay.py index 7f275766b1..13f6639191 100644 --- a/acme/agents/replay.py +++ b/acme/agents/replay.py @@ -17,21 +17,20 @@ import dataclasses from typing import Any, Callable, Dict, Iterator, Optional +import reverb + from acme import adders as adders_lib -from acme import datasets -from acme import specs -from acme import types +from acme import datasets, specs, types from acme.adders import reverb as adders -import reverb @dataclasses.dataclass class ReverbReplay: - server: reverb.Server - adder: adders_lib.Adder - data_iterator: Iterator[reverb.ReplaySample] - client: Optional[reverb.Client] = None - can_sample: Callable[[], bool] = lambda: True + server: reverb.Server + adder: adders_lib.Adder + data_iterator: Iterator[reverb.ReplaySample] + client: Optional[reverb.Client] = None + can_sample: Callable[[], bool] = lambda: True def make_reverb_prioritized_nstep_replay( @@ -41,47 +40,47 @@ def make_reverb_prioritized_nstep_replay( batch_size: int = 32, max_replay_size: int = 100_000, min_replay_size: int = 1, - discount: float = 1., + discount: float = 1.0, prefetch_size: int = 4, # TODO(iosband): rationalize prefetch size. replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE, priority_exponent: Optional[float] = None, # If None, default to uniform. ) -> ReverbReplay: - """Creates a single-process replay infrastructure from an environment spec.""" - # Parsing priority exponent to determine uniform vs prioritized replay - if priority_exponent is None: - sampler = reverb.selectors.Uniform() - priority_fns = {replay_table_name: lambda x: 1.} - else: - sampler = reverb.selectors.Prioritized(priority_exponent) - priority_fns = None - - # Create a replay server to add data to. This uses no limiter behavior in - # order to allow the Agent interface to handle it. - replay_table = reverb.Table( - name=replay_table_name, - sampler=sampler, - remover=reverb.selectors.Fifo(), - max_size=max_replay_size, - rate_limiter=reverb.rate_limiters.MinSize(min_replay_size), - signature=adders.NStepTransitionAdder.signature(environment_spec, - extra_spec), - ) - server = reverb.Server([replay_table], port=None) - - # The adder is used to insert observations into replay. - address = f'localhost:{server.port}' - client = reverb.Client(address) - adder = adders.NStepTransitionAdder( - client, n_step, discount, priority_fns=priority_fns) - - # The dataset provides an interface to sample from replay. 
- data_iterator = datasets.make_reverb_dataset( - table=replay_table_name, - server_address=address, - batch_size=batch_size, - prefetch_size=prefetch_size, - ).as_numpy_iterator() - return ReverbReplay(server, adder, data_iterator, client=client) + """Creates a single-process replay infrastructure from an environment spec.""" + # Parsing priority exponent to determine uniform vs prioritized replay + if priority_exponent is None: + sampler = reverb.selectors.Uniform() + priority_fns = {replay_table_name: lambda x: 1.0} + else: + sampler = reverb.selectors.Prioritized(priority_exponent) + priority_fns = None + + # Create a replay server to add data to. This uses no limiter behavior in + # order to allow the Agent interface to handle it. + replay_table = reverb.Table( + name=replay_table_name, + sampler=sampler, + remover=reverb.selectors.Fifo(), + max_size=max_replay_size, + rate_limiter=reverb.rate_limiters.MinSize(min_replay_size), + signature=adders.NStepTransitionAdder.signature(environment_spec, extra_spec), + ) + server = reverb.Server([replay_table], port=None) + + # The adder is used to insert observations into replay. + address = f"localhost:{server.port}" + client = reverb.Client(address) + adder = adders.NStepTransitionAdder( + client, n_step, discount, priority_fns=priority_fns + ) + + # The dataset provides an interface to sample from replay. + data_iterator = datasets.make_reverb_dataset( + table=replay_table_name, + server_address=address, + batch_size=batch_size, + prefetch_size=prefetch_size, + ).as_numpy_iterator() + return ReverbReplay(server, adder, data_iterator, client=client) def make_reverb_online_queue( @@ -93,32 +92,33 @@ def make_reverb_online_queue( batch_size: int, replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE, ) -> ReverbReplay: - """Creates a single process queue from an environment spec and extra_spec.""" - signature = adders.SequenceAdder.signature(environment_spec, extra_spec) - queue = reverb.Table.queue( - name=replay_table_name, max_size=max_queue_size, signature=signature) - server = reverb.Server([queue], port=None) - can_sample = lambda: queue.can_sample(batch_size) - - # Component to add things into replay. - address = f'localhost:{server.port}' - adder = adders.SequenceAdder( - client=reverb.Client(address), - period=sequence_period, - sequence_length=sequence_length, - ) - - # The dataset object to learn from. - # We don't use datasets.make_reverb_dataset() here to avoid interleaving - # and prefetching, that doesn't work well with can_sample() check on update. - dataset = reverb.TrajectoryDataset.from_table_signature( - server_address=address, - table=replay_table_name, - max_in_flight_samples_per_worker=1, - ) - dataset = dataset.batch(batch_size, drop_remainder=True) - data_iterator = dataset.as_numpy_iterator() - return ReverbReplay(server, adder, data_iterator, can_sample=can_sample) + """Creates a single process queue from an environment spec and extra_spec.""" + signature = adders.SequenceAdder.signature(environment_spec, extra_spec) + queue = reverb.Table.queue( + name=replay_table_name, max_size=max_queue_size, signature=signature + ) + server = reverb.Server([queue], port=None) + can_sample = lambda: queue.can_sample(batch_size) + + # Component to add things into replay. + address = f"localhost:{server.port}" + adder = adders.SequenceAdder( + client=reverb.Client(address), + period=sequence_period, + sequence_length=sequence_length, + ) + + # The dataset object to learn from. 
+ # We don't use datasets.make_reverb_dataset() here to avoid interleaving + # and prefetching, that doesn't work well with can_sample() check on update. + dataset = reverb.TrajectoryDataset.from_table_signature( + server_address=address, + table=replay_table_name, + max_in_flight_samples_per_worker=1, + ) + dataset = dataset.batch(batch_size, drop_remainder=True) + data_iterator = dataset.as_numpy_iterator() + return ReverbReplay(server, adder, data_iterator, can_sample=can_sample) def make_reverb_prioritized_sequence_replay( @@ -127,42 +127,42 @@ def make_reverb_prioritized_sequence_replay( batch_size: int = 32, max_replay_size: int = 100_000, min_replay_size: int = 1, - priority_exponent: float = 0., + priority_exponent: float = 0.0, burn_in_length: int = 40, sequence_length: int = 80, sequence_period: int = 40, replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE, prefetch_size: int = 4, ) -> ReverbReplay: - """Single-process replay for sequence data from an environment spec.""" - # Create a replay server to add data to. This uses no limiter behavior in - # order to allow the Agent interface to handle it. - replay_table = reverb.Table( - name=replay_table_name, - sampler=reverb.selectors.Prioritized(priority_exponent), - remover=reverb.selectors.Fifo(), - max_size=max_replay_size, - rate_limiter=reverb.rate_limiters.MinSize(min_replay_size), - signature=adders.SequenceAdder.signature(environment_spec, extra_spec), - ) - server = reverb.Server([replay_table], port=None) - - # The adder is used to insert observations into replay. - address = f'localhost:{server.port}' - client = reverb.Client(address) - sequence_length = burn_in_length + sequence_length + 1 - adder = adders.SequenceAdder( - client=client, - period=sequence_period, - sequence_length=sequence_length, - delta_encoded=True, - ) - - # The dataset provides an interface to sample from replay. - data_iterator = datasets.make_reverb_dataset( - table=replay_table_name, - server_address=address, - batch_size=batch_size, - prefetch_size=prefetch_size, - ).as_numpy_iterator() - return ReverbReplay(server, adder, data_iterator, client) + """Single-process replay for sequence data from an environment spec.""" + # Create a replay server to add data to. This uses no limiter behavior in + # order to allow the Agent interface to handle it. + replay_table = reverb.Table( + name=replay_table_name, + sampler=reverb.selectors.Prioritized(priority_exponent), + remover=reverb.selectors.Fifo(), + max_size=max_replay_size, + rate_limiter=reverb.rate_limiters.MinSize(min_replay_size), + signature=adders.SequenceAdder.signature(environment_spec, extra_spec), + ) + server = reverb.Server([replay_table], port=None) + + # The adder is used to insert observations into replay. + address = f"localhost:{server.port}" + client = reverb.Client(address) + sequence_length = burn_in_length + sequence_length + 1 + adder = adders.SequenceAdder( + client=client, + period=sequence_period, + sequence_length=sequence_length, + delta_encoded=True, + ) + + # The dataset provides an interface to sample from replay. 
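# (With the defaults above, burn_in_length=40 and sequence_length=80, each
#  stored sequence holds 40 + 80 + 1 = 121 steps, and with sequence_period=40
#  a new sequence starts every 40 environment steps, so consecutive stored
#  sequences overlap by roughly 81 steps.)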
+ data_iterator = datasets.make_reverb_dataset( + table=replay_table_name, + server_address=address, + batch_size=batch_size, + prefetch_size=prefetch_size, + ).as_numpy_iterator() + return ReverbReplay(server, adder, data_iterator, client) diff --git a/acme/agents/tf/__init__.py b/acme/agents/tf/__init__.py index 240cb71526..de867df849 100644 --- a/acme/agents/tf/__init__.py +++ b/acme/agents/tf/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/acme/agents/tf/actors.py b/acme/agents/tf/actors.py index 1f934be4e2..d778881b6a 100644 --- a/acme/agents/tf/actors.py +++ b/acme/agents/tf/actors.py @@ -16,36 +16,35 @@ from typing import Optional, Tuple -from acme import adders -from acme import core -from acme import types -# Internal imports. -from acme.tf import utils as tf2_utils -from acme.tf import variable_utils as tf2_variable_utils - import dm_env import sonnet as snt import tensorflow as tf import tensorflow_probability as tfp +from acme import adders, core, types + +# Internal imports. +from acme.tf import utils as tf2_utils +from acme.tf import variable_utils as tf2_variable_utils + tfd = tfp.distributions class FeedForwardActor(core.Actor): - """A feed-forward actor. + """A feed-forward actor. An actor based on a feed-forward policy which takes non-batched observations and outputs non-batched actions. It also allows adding experiences to replay and updating the weights from the policy on the learner. """ - def __init__( - self, - policy_network: snt.Module, - adder: Optional[adders.Adder] = None, - variable_client: Optional[tf2_variable_utils.VariableClient] = None, - ): - """Initializes the actor. + def __init__( + self, + policy_network: snt.Module, + adder: Optional[adders.Adder] = None, + variable_client: Optional[tf2_variable_utils.VariableClient] = None, + ): + """Initializes the actor. Args: policy_network: the policy to run. @@ -55,46 +54,46 @@ def __init__( of the policy to the actor copy (in case they are separate). """ - # Store these for later use. - self._adder = adder - self._variable_client = variable_client - self._policy_network = policy_network + # Store these for later use. + self._adder = adder + self._variable_client = variable_client + self._policy_network = policy_network - @tf.function - def _policy(self, observation: types.NestedTensor) -> types.NestedTensor: - # Add a dummy batch dimension and as a side effect convert numpy to TF. - batched_observation = tf2_utils.add_batch_dim(observation) + @tf.function + def _policy(self, observation: types.NestedTensor) -> types.NestedTensor: + # Add a dummy batch dimension and as a side effect convert numpy to TF. + batched_observation = tf2_utils.add_batch_dim(observation) - # Compute the policy, conditioned on the observation. - policy = self._policy_network(batched_observation) + # Compute the policy, conditioned on the observation. + policy = self._policy_network(batched_observation) - # Sample from the policy if it is stochastic. - action = policy.sample() if isinstance(policy, tfd.Distribution) else policy + # Sample from the policy if it is stochastic. + action = policy.sample() if isinstance(policy, tfd.Distribution) else policy - return action + return action - def select_action(self, observation: types.NestedArray) -> types.NestedArray: - # Pass the observation through the policy network. 
- action = self._policy(observation) + def select_action(self, observation: types.NestedArray) -> types.NestedArray: + # Pass the observation through the policy network. + action = self._policy(observation) - # Return a numpy array with squeezed out batch dimension. - return tf2_utils.to_numpy_squeeze(action) + # Return a numpy array with squeezed out batch dimension. + return tf2_utils.to_numpy_squeeze(action) - def observe_first(self, timestep: dm_env.TimeStep): - if self._adder: - self._adder.add_first(timestep) + def observe_first(self, timestep: dm_env.TimeStep): + if self._adder: + self._adder.add_first(timestep) - def observe(self, action: types.NestedArray, next_timestep: dm_env.TimeStep): - if self._adder: - self._adder.add(action, next_timestep) + def observe(self, action: types.NestedArray, next_timestep: dm_env.TimeStep): + if self._adder: + self._adder.add(action, next_timestep) - def update(self, wait: bool = False): - if self._variable_client: - self._variable_client.update(wait) + def update(self, wait: bool = False): + if self._variable_client: + self._variable_client.update(wait) class RecurrentActor(core.Actor): - """A recurrent actor. + """A recurrent actor. An actor based on a recurrent policy which takes non-batched observations and outputs non-batched actions, and keeps track of the recurrent state inside. It @@ -102,14 +101,14 @@ class RecurrentActor(core.Actor): policy on the learner. """ - def __init__( - self, - policy_network: snt.RNNCore, - adder: Optional[adders.Adder] = None, - variable_client: Optional[tf2_variable_utils.VariableClient] = None, - store_recurrent_state: bool = True, - ): - """Initializes the actor. + def __init__( + self, + policy_network: snt.RNNCore, + adder: Optional[adders.Adder] = None, + variable_client: Optional[tf2_variable_utils.VariableClient] = None, + store_recurrent_state: bool = True, + ): + """Initializes the actor. Args: policy_network: the (recurrent) policy to run. @@ -119,68 +118,67 @@ def __init__( of the policy to the actor copy (in case they are separate). store_recurrent_state: Whether to pass the recurrent state to the adder. """ - # Store these for later use. - self._adder = adder - self._variable_client = variable_client - self._network = policy_network - self._state = None - self._prev_state = None - self._store_recurrent_state = store_recurrent_state + # Store these for later use. + self._adder = adder + self._variable_client = variable_client + self._network = policy_network + self._state = None + self._prev_state = None + self._store_recurrent_state = store_recurrent_state + + @tf.function + def _policy( + self, observation: types.NestedTensor, state: types.NestedTensor, + ) -> Tuple[types.NestedTensor, types.NestedTensor]: - @tf.function - def _policy( - self, - observation: types.NestedTensor, - state: types.NestedTensor, - ) -> Tuple[types.NestedTensor, types.NestedTensor]: + # Add a dummy batch dimension and as a side effect convert numpy to TF. + batched_observation = tf2_utils.add_batch_dim(observation) - # Add a dummy batch dimension and as a side effect convert numpy to TF. - batched_observation = tf2_utils.add_batch_dim(observation) + # Compute the policy, conditioned on the observation. + policy, new_state = self._network(batched_observation, state) - # Compute the policy, conditioned on the observation. - policy, new_state = self._network(batched_observation, state) + # Sample from the policy if it is stochastic. 
+ action = policy.sample() if isinstance(policy, tfd.Distribution) else policy - # Sample from the policy if it is stochastic. - action = policy.sample() if isinstance(policy, tfd.Distribution) else policy + return action, new_state - return action, new_state + def select_action(self, observation: types.NestedArray) -> types.NestedArray: + # Initialize the RNN state if necessary. + if self._state is None: + self._state = self._network.initial_state(1) - def select_action(self, observation: types.NestedArray) -> types.NestedArray: - # Initialize the RNN state if necessary. - if self._state is None: - self._state = self._network.initial_state(1) + # Step the recurrent policy forward given the current observation and state. + policy_output, new_state = self._policy(observation, self._state) - # Step the recurrent policy forward given the current observation and state. - policy_output, new_state = self._policy(observation, self._state) + # Bookkeeping of recurrent states for the observe method. + self._prev_state = self._state + self._state = new_state - # Bookkeeping of recurrent states for the observe method. - self._prev_state = self._state - self._state = new_state + # Return a numpy array with squeezed out batch dimension. + return tf2_utils.to_numpy_squeeze(policy_output) - # Return a numpy array with squeezed out batch dimension. - return tf2_utils.to_numpy_squeeze(policy_output) + def observe_first(self, timestep: dm_env.TimeStep): + if self._adder: + self._adder.add_first(timestep) - def observe_first(self, timestep: dm_env.TimeStep): - if self._adder: - self._adder.add_first(timestep) + # Set the state to None so that we re-initialize at the next policy call. + self._state = None - # Set the state to None so that we re-initialize at the next policy call. - self._state = None + def observe(self, action: types.NestedArray, next_timestep: dm_env.TimeStep): + if not self._adder: + return - def observe(self, action: types.NestedArray, next_timestep: dm_env.TimeStep): - if not self._adder: - return + if not self._store_recurrent_state: + self._adder.add(action, next_timestep) + return - if not self._store_recurrent_state: - self._adder.add(action, next_timestep) - return + numpy_state = tf2_utils.to_numpy_squeeze(self._prev_state) + self._adder.add(action, next_timestep, extras=(numpy_state,)) - numpy_state = tf2_utils.to_numpy_squeeze(self._prev_state) - self._adder.add(action, next_timestep, extras=(numpy_state,)) + def update(self, wait: bool = False): + if self._variable_client: + self._variable_client.update(wait) - def update(self, wait: bool = False): - if self._variable_client: - self._variable_client.update(wait) # Internal class 1. # Internal class 2. 
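# (A note on RecurrentActor above: the state written to replay through
#  extras=(numpy_state,) is self._prev_state, i.e. the recurrent state that
#  conditioned the selected action rather than the post-step state, so a
#  learner can rebuild the unroll exactly as the actor experienced it.)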
diff --git a/acme/agents/tf/actors_test.py b/acme/agents/tf/actors_test.py index a8d4e35fd1..d1b85cc5a6 100644 --- a/acme/agents/tf/actors_test.py +++ b/acme/agents/tf/actors_test.py @@ -14,59 +14,62 @@ """Tests for actors_tf2.""" -from acme import environment_loop -from acme import specs -from acme.agents.tf import actors -from acme.testing import fakes import dm_env import numpy as np import sonnet as snt import tensorflow as tf - from absl.testing import absltest +from acme import environment_loop, specs +from acme.agents.tf import actors +from acme.testing import fakes + def _make_fake_env() -> dm_env.Environment: - env_spec = specs.EnvironmentSpec( - observations=specs.Array(shape=(10, 5), dtype=np.float32), - actions=specs.DiscreteArray(num_values=3), - rewards=specs.Array(shape=(), dtype=np.float32), - discounts=specs.BoundedArray( - shape=(), dtype=np.float32, minimum=0., maximum=1.), - ) - return fakes.Environment(env_spec, episode_length=10) + env_spec = specs.EnvironmentSpec( + observations=specs.Array(shape=(10, 5), dtype=np.float32), + actions=specs.DiscreteArray(num_values=3), + rewards=specs.Array(shape=(), dtype=np.float32), + discounts=specs.BoundedArray( + shape=(), dtype=np.float32, minimum=0.0, maximum=1.0 + ), + ) + return fakes.Environment(env_spec, episode_length=10) class ActorTest(absltest.TestCase): + def test_feedforward(self): + environment = _make_fake_env() + env_spec = specs.make_environment_spec(environment) - def test_feedforward(self): - environment = _make_fake_env() - env_spec = specs.make_environment_spec(environment) - - network = snt.Sequential([ - snt.Flatten(), - snt.Linear(env_spec.actions.num_values), - lambda x: tf.argmax(x, axis=-1, output_type=env_spec.actions.dtype), - ]) + network = snt.Sequential( + [ + snt.Flatten(), + snt.Linear(env_spec.actions.num_values), + lambda x: tf.argmax(x, axis=-1, output_type=env_spec.actions.dtype), + ] + ) - actor = actors.FeedForwardActor(network) - loop = environment_loop.EnvironmentLoop(environment, actor) - loop.run(20) + actor = actors.FeedForwardActor(network) + loop = environment_loop.EnvironmentLoop(environment, actor) + loop.run(20) - def test_recurrent(self): - environment = _make_fake_env() - env_spec = specs.make_environment_spec(environment) + def test_recurrent(self): + environment = _make_fake_env() + env_spec = specs.make_environment_spec(environment) - network = snt.DeepRNN([ - snt.Flatten(), - snt.Linear(env_spec.actions.num_values), - lambda x: tf.argmax(x, axis=-1, output_type=env_spec.actions.dtype), - ]) + network = snt.DeepRNN( + [ + snt.Flatten(), + snt.Linear(env_spec.actions.num_values), + lambda x: tf.argmax(x, axis=-1, output_type=env_spec.actions.dtype), + ] + ) - actor = actors.RecurrentActor(network) - loop = environment_loop.EnvironmentLoop(environment, actor) - loop.run(20) + actor = actors.RecurrentActor(network) + loop = environment_loop.EnvironmentLoop(environment, actor) + loop.run(20) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/tf/bc/learning.py b/acme/agents/tf/bc/learning.py index 05c692ebdc..b240fa334c 100644 --- a/acme/agents/tf/bc/learning.py +++ b/acme/agents/tf/bc/learning.py @@ -16,32 +16,34 @@ from typing import Dict, List, Optional +import numpy as np +import sonnet as snt +import tensorflow as tf + import acme from acme import types from acme.tf import savers as tf2_savers from acme.tf import utils as tf2_utils -from acme.utils import counting -from acme.utils import loggers -import numpy 
as np -import sonnet as snt -import tensorflow as tf +from acme.utils import counting, loggers class BCLearner(acme.Learner, tf2_savers.TFSaveable): - """BC learner. + """BC learner. This is the learning component of a BC agent. IE it takes a dataset as input and implements update functionality to learn from this dataset. """ - def __init__(self, - network: snt.Module, - learning_rate: float, - dataset: tf.data.Dataset, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - checkpoint: bool = True): - """Initializes the learner. + def __init__( + self, + network: snt.Module, + learning_rate: float, + dataset: tf.data.Dataset, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = True, + ): + """Initializes the learner. Args: network: the BC network (the one being optimized) @@ -52,72 +54,73 @@ def __init__(self, checkpoint: boolean indicating whether to checkpoint the learner. """ - self._counter = counter or counting.Counter() - self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.) + self._counter = counter or counting.Counter() + self._logger = logger or loggers.TerminalLogger("learner", time_delta=1.0) - # Get an iterator over the dataset. - self._iterator = iter(dataset) # pytype: disable=wrong-arg-types - # TODO(b/155086959): Fix type stubs and remove. + # Get an iterator over the dataset. + self._iterator = iter(dataset) # pytype: disable=wrong-arg-types + # TODO(b/155086959): Fix type stubs and remove. - self._network = network - self._optimizer = snt.optimizers.Adam(learning_rate) + self._network = network + self._optimizer = snt.optimizers.Adam(learning_rate) - self._variables: List[List[tf.Tensor]] = [network.trainable_variables] - self._num_steps = tf.Variable(0, dtype=tf.int32) + self._variables: List[List[tf.Tensor]] = [network.trainable_variables] + self._num_steps = tf.Variable(0, dtype=tf.int32) - # Create a snapshotter object. - if checkpoint: - self._snapshotter = tf2_savers.Snapshotter( - objects_to_save={'network': network}, time_delta_minutes=60.) - else: - self._snapshotter = None + # Create a snapshotter object. + if checkpoint: + self._snapshotter = tf2_savers.Snapshotter( + objects_to_save={"network": network}, time_delta_minutes=60.0 + ) + else: + self._snapshotter = None - @tf.function - def _step(self) -> Dict[str, tf.Tensor]: - """Do a step of SGD and update the priorities.""" + @tf.function + def _step(self) -> Dict[str, tf.Tensor]: + """Do a step of SGD and update the priorities.""" - # Pull out the data needed for updates/priorities. - inputs = next(self._iterator) - transitions: types.Transition = inputs.data + # Pull out the data needed for updates/priorities. + inputs = next(self._iterator) + transitions: types.Transition = inputs.data - with tf.GradientTape() as tape: - # Evaluate our networks. - logits = self._network(transitions.observation) - cce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) - loss = cce(transitions.action, logits) + with tf.GradientTape() as tape: + # Evaluate our networks. 
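# For discrete actions, behavioural cloning reduces to classification: the
# loss computed below is -log softmax(logits)[demonstrated_action], averaged
# over the batch. A small standalone example with made-up numbers:
import tensorflow as tf

logits = tf.constant([[2.0, 0.5, -1.0]])  # one observation, three actions
action = tf.constant([0])                 # the demonstrated action index
cce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
print(float(cce(action, logits)))         # roughly 0.24 = -log softmax(logits)[0]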
+ logits = self._network(transitions.observation) + cce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) + loss = cce(transitions.action, logits) - gradients = tape.gradient(loss, self._network.trainable_variables) - self._optimizer.apply(gradients, self._network.trainable_variables) + gradients = tape.gradient(loss, self._network.trainable_variables) + self._optimizer.apply(gradients, self._network.trainable_variables) - self._num_steps.assign_add(1) + self._num_steps.assign_add(1) - # Compute the global norm of the gradients for logging. - global_gradient_norm = tf.linalg.global_norm(gradients) - fetches = {'loss': loss, 'gradient_norm': global_gradient_norm} + # Compute the global norm of the gradients for logging. + global_gradient_norm = tf.linalg.global_norm(gradients) + fetches = {"loss": loss, "gradient_norm": global_gradient_norm} - return fetches + return fetches - def step(self): - # Do a batch of SGD. - result = self._step() + def step(self): + # Do a batch of SGD. + result = self._step() - # Update our counts and record it. - counts = self._counter.increment(steps=1) - result.update(counts) + # Update our counts and record it. + counts = self._counter.increment(steps=1) + result.update(counts) - # Snapshot and attempt to write logs. - if self._snapshotter is not None: - self._snapshotter.save() - self._logger.write(result) + # Snapshot and attempt to write logs. + if self._snapshotter is not None: + self._snapshotter.save() + self._logger.write(result) - def get_variables(self, names: List[str]) -> List[np.ndarray]: - return tf2_utils.to_numpy(self._variables) + def get_variables(self, names: List[str]) -> List[np.ndarray]: + return tf2_utils.to_numpy(self._variables) - @property - def state(self): - """Returns the stateful parts of the learner for checkpointing.""" - return { - 'network': self._network, - 'optimizer': self._optimizer, - 'num_steps': self._num_steps - } + @property + def state(self): + """Returns the stateful parts of the learner for checkpointing.""" + return { + "network": self._network, + "optimizer": self._optimizer, + "num_steps": self._num_steps, + } diff --git a/acme/agents/tf/bcq/discrete_learning.py b/acme/agents/tf/bcq/discrete_learning.py index 4b4e234092..a6d69d70ce 100644 --- a/acme/agents/tf/bcq/discrete_learning.py +++ b/acme/agents/tf/bcq/discrete_learning.py @@ -20,44 +20,43 @@ import copy from typing import Dict, List, Optional -from acme import core -from acme import types +import numpy as np +import reverb +import sonnet as snt +import tensorflow as tf +import trfl + +from acme import core, types from acme.adders import reverb as adders from acme.agents.tf import bc from acme.tf import losses from acme.tf import savers as tf2_savers from acme.tf import utils as tf2_utils from acme.tf.networks import discrete as discrete_networks -from acme.utils import counting -from acme.utils import loggers -import numpy as np -import reverb -import sonnet as snt -import tensorflow as tf -import trfl +from acme.utils import counting, loggers class _InternalBCQLearner(core.Learner, tf2_savers.TFSaveable): - """Internal BCQ learner. + """Internal BCQ learner. This implements the Q-learning component in the discrete BCQ algorithm. 
""" - def __init__( - self, - network: discrete_networks.DiscreteFilteredQNetwork, - discount: float, - importance_sampling_exponent: float, - learning_rate: float, - target_update_period: int, - dataset: tf.data.Dataset, - huber_loss_parameter: float = 1., - replay_client: Optional[reverb.TFClient] = None, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - checkpoint: bool = False, - ): - """Initializes the learner. + def __init__( + self, + network: discrete_networks.DiscreteFilteredQNetwork, + discount: float, + importance_sampling_exponent: float, + learning_rate: float, + target_update_period: int, + dataset: tf.data.Dataset, + huber_loss_parameter: float = 1.0, + replay_client: Optional[reverb.TFClient] = None, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = False, + ): + """Initializes the learner. Args: network: BCQ network @@ -76,185 +75,195 @@ def __init__( checkpoint: boolean indicating whether to checkpoint the learner. """ - # Internalise agent components (replay buffer, networks, optimizer). - # TODO(b/155086959): Fix type stubs and remove. - self._iterator = iter(dataset) # pytype: disable=wrong-arg-types - self._network = network - self._q_network = network.q_network - self._target_q_network = copy.deepcopy(network.q_network) - self._optimizer = snt.optimizers.Adam(learning_rate) - self._replay_client = replay_client - - # Internalise the hyperparameters. - self._discount = discount - self._target_update_period = target_update_period - self._importance_sampling_exponent = importance_sampling_exponent - self._huber_loss_parameter = huber_loss_parameter - - # Learner state. - self._variables = [self._network.trainable_variables] - self._num_steps = tf.Variable(0, dtype=tf.int32) - - # Internalise logging/counting objects. - self._counter = counter or counting.Counter() - self._logger = logger or loggers.make_default_logger('learner', - save_data=False) - - # Create a snapshotter object. - if checkpoint: - self._snapshotter = tf2_savers.Snapshotter( - objects_to_save={'network': network}, time_delta_minutes=60.) - else: - self._snapshotter = None - - @tf.function - def _step(self) -> Dict[str, tf.Tensor]: - """Do a step of SGD and update the priorities.""" - - # Pull out the data needed for updates/priorities. - inputs = next(self._iterator) - transitions: types.Transition = inputs.data - keys, probs = inputs.info[:2] - - with tf.GradientTape() as tape: - # Evaluate our networks. - q_tm1 = self._q_network(transitions.observation) - q_t_value = self._target_q_network(transitions.next_observation) - q_t_selector = self._network(transitions.next_observation) - - # The rewards and discounts have to have the same type as network values. - r_t = tf.cast(transitions.reward, q_tm1.dtype) - r_t = tf.clip_by_value(r_t, -1., 1.) - d_t = tf.cast(transitions.discount, q_tm1.dtype) * tf.cast( - self._discount, q_tm1.dtype) - - # Compute the loss. - _, extra = trfl.double_qlearning(q_tm1, transitions.action, r_t, d_t, - q_t_value, q_t_selector) - loss = losses.huber(extra.td_error, self._huber_loss_parameter) - - # Get the importance weights. - importance_weights = 1. / probs # [B] - importance_weights **= self._importance_sampling_exponent - importance_weights /= tf.reduce_max(importance_weights) - - # Reweight. - loss *= tf.cast(importance_weights, loss.dtype) # [B] - loss = tf.reduce_mean(loss, axis=[0]) # [] - - # Do a step of SGD. 
- gradients = tape.gradient(loss, self._network.trainable_variables) - self._optimizer.apply(gradients, self._network.trainable_variables) - - # Update the priorities in the replay buffer. - if self._replay_client: - priorities = tf.cast(tf.abs(extra.td_error), tf.float64) - self._replay_client.update_priorities( - table=adders.DEFAULT_PRIORITY_TABLE, keys=keys, priorities=priorities) - - # Periodically update the target network. - if tf.math.mod(self._num_steps, self._target_update_period) == 0: - for src, dest in zip(self._q_network.variables, - self._target_q_network.variables): - dest.assign(src) - self._num_steps.assign_add(1) - - # Compute the global norm of the gradients for logging. - global_gradient_norm = tf.linalg.global_norm(gradients) - - # Compute statistics of the Q-values for logging. - max_q = tf.reduce_max(q_t_value) - min_q = tf.reduce_min(q_t_value) - mean_q, var_q = tf.nn.moments(q_t_value, [0, 1]) - - # Report loss & statistics for logging. - fetches = { - 'gradient_norm': global_gradient_norm, - 'loss': loss, - 'max_q': max_q, - 'mean_q': mean_q, - 'min_q': min_q, - 'var_q': var_q, - } - - return fetches - - def step(self): - # Do a batch of SGD. - result = self._step() - - # Update our counts and record it. - counts = self._counter.increment(steps=1) - result.update(counts) - - # Snapshot and attempt to write logs. - if self._snapshotter is not None: - self._snapshotter.save() - self._logger.write(result) - - def get_variables(self, names: List[str]) -> List[np.ndarray]: - return tf2_utils.to_numpy(self._variables) - - @property - def state(self): - """Returns the stateful parts of the learner for checkpointing.""" - return { - 'network': self._network, - 'target_q_network': self._target_q_network, - 'optimizer': self._optimizer, - 'num_steps': self._num_steps - } + # Internalise agent components (replay buffer, networks, optimizer). + # TODO(b/155086959): Fix type stubs and remove. + self._iterator = iter(dataset) # pytype: disable=wrong-arg-types + self._network = network + self._q_network = network.q_network + self._target_q_network = copy.deepcopy(network.q_network) + self._optimizer = snt.optimizers.Adam(learning_rate) + self._replay_client = replay_client + + # Internalise the hyperparameters. + self._discount = discount + self._target_update_period = target_update_period + self._importance_sampling_exponent = importance_sampling_exponent + self._huber_loss_parameter = huber_loss_parameter + + # Learner state. + self._variables = [self._network.trainable_variables] + self._num_steps = tf.Variable(0, dtype=tf.int32) + + # Internalise logging/counting objects. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger("learner", save_data=False) + + # Create a snapshotter object. + if checkpoint: + self._snapshotter = tf2_savers.Snapshotter( + objects_to_save={"network": network}, time_delta_minutes=60.0 + ) + else: + self._snapshotter = None + + @tf.function + def _step(self) -> Dict[str, tf.Tensor]: + """Do a step of SGD and update the priorities.""" + + # Pull out the data needed for updates/priorities. + inputs = next(self._iterator) + transitions: types.Transition = inputs.data + keys, probs = inputs.info[:2] + + with tf.GradientTape() as tape: + # Evaluate our networks. 
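+            # q_t_selector below comes from the filtered network, so the argmax
+            # inside double_qlearning is effectively restricted to actions the
+            # generative (BC) head considers likely enough (the BCQ constraint),
+            # while q_t_value from the target network evaluates that choice
+            # (double Q-learning).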
+ q_tm1 = self._q_network(transitions.observation) + q_t_value = self._target_q_network(transitions.next_observation) + q_t_selector = self._network(transitions.next_observation) + + # The rewards and discounts have to have the same type as network values. + r_t = tf.cast(transitions.reward, q_tm1.dtype) + r_t = tf.clip_by_value(r_t, -1.0, 1.0) + d_t = tf.cast(transitions.discount, q_tm1.dtype) * tf.cast( + self._discount, q_tm1.dtype + ) + + # Compute the loss. + _, extra = trfl.double_qlearning( + q_tm1, transitions.action, r_t, d_t, q_t_value, q_t_selector + ) + loss = losses.huber(extra.td_error, self._huber_loss_parameter) + + # Get the importance weights. + importance_weights = 1.0 / probs # [B] + importance_weights **= self._importance_sampling_exponent + importance_weights /= tf.reduce_max(importance_weights) + + # Reweight. + loss *= tf.cast(importance_weights, loss.dtype) # [B] + loss = tf.reduce_mean(loss, axis=[0]) # [] + + # Do a step of SGD. + gradients = tape.gradient(loss, self._network.trainable_variables) + self._optimizer.apply(gradients, self._network.trainable_variables) + + # Update the priorities in the replay buffer. + if self._replay_client: + priorities = tf.cast(tf.abs(extra.td_error), tf.float64) + self._replay_client.update_priorities( + table=adders.DEFAULT_PRIORITY_TABLE, keys=keys, priorities=priorities + ) + + # Periodically update the target network. + if tf.math.mod(self._num_steps, self._target_update_period) == 0: + for src, dest in zip( + self._q_network.variables, self._target_q_network.variables + ): + dest.assign(src) + self._num_steps.assign_add(1) + + # Compute the global norm of the gradients for logging. + global_gradient_norm = tf.linalg.global_norm(gradients) + + # Compute statistics of the Q-values for logging. + max_q = tf.reduce_max(q_t_value) + min_q = tf.reduce_min(q_t_value) + mean_q, var_q = tf.nn.moments(q_t_value, [0, 1]) + + # Report loss & statistics for logging. + fetches = { + "gradient_norm": global_gradient_norm, + "loss": loss, + "max_q": max_q, + "mean_q": mean_q, + "min_q": min_q, + "var_q": var_q, + } + + return fetches + + def step(self): + # Do a batch of SGD. + result = self._step() + + # Update our counts and record it. + counts = self._counter.increment(steps=1) + result.update(counts) + + # Snapshot and attempt to write logs. + if self._snapshotter is not None: + self._snapshotter.save() + self._logger.write(result) + + def get_variables(self, names: List[str]) -> List[np.ndarray]: + return tf2_utils.to_numpy(self._variables) + + @property + def state(self): + """Returns the stateful parts of the learner for checkpointing.""" + return { + "network": self._network, + "target_q_network": self._target_q_network, + "optimizer": self._optimizer, + "num_steps": self._num_steps, + } class DiscreteBCQLearner(core.Learner, tf2_savers.TFSaveable): - """Discrete BCQ learner. + """Discrete BCQ learner. This learner combines supervised BC learning and Q learning to implement the discrete BCQ algorithm as described in https://arxiv.org/pdf/1910.01708.pdf. """ - def __init__(self, - network: discrete_networks.DiscreteFilteredQNetwork, - dataset: tf.data.Dataset, - learning_rate: float, - counter: Optional[counting.Counter] = None, - bc_logger: Optional[loggers.Logger] = None, - bcq_logger: Optional[loggers.Logger] = None, - **bcq_learner_kwargs): - counter = counter or counting.Counter() - self._bc_logger = bc_logger or loggers.TerminalLogger('bc_learner', - time_delta=1.) 
- self._bcq_logger = bcq_logger or loggers.TerminalLogger('bcq_learner', - time_delta=1.) - - self._bc_learner = bc.BCLearner( - network=network.g_network, - learning_rate=learning_rate, - dataset=dataset, - counter=counting.Counter(counter, 'bc'), - logger=self._bc_logger, - checkpoint=False) - self._bcq_learner = _InternalBCQLearner( - network=network, - learning_rate=learning_rate, - dataset=dataset, - counter=counting.Counter(counter, 'bcq'), - logger=self._bcq_logger, - **bcq_learner_kwargs) - - def get_variables(self, names): - return self._bcq_learner.get_variables(names) - - @property - def state(self): - bc_state = self._bc_learner.state - bc_state.pop('network') # No need to checkpoint the BC network. - bcq_state = self._bcq_learner.state - state = dict() - state.update({f'bc_{k}': v for k, v in bc_state.items()}) - state.update({f'bcq_{k}': v for k, v in bcq_state.items()}) - return state - - def step(self): - self._bc_learner.step() - self._bcq_learner.step() + def __init__( + self, + network: discrete_networks.DiscreteFilteredQNetwork, + dataset: tf.data.Dataset, + learning_rate: float, + counter: Optional[counting.Counter] = None, + bc_logger: Optional[loggers.Logger] = None, + bcq_logger: Optional[loggers.Logger] = None, + **bcq_learner_kwargs, + ): + counter = counter or counting.Counter() + self._bc_logger = bc_logger or loggers.TerminalLogger( + "bc_learner", time_delta=1.0 + ) + self._bcq_logger = bcq_logger or loggers.TerminalLogger( + "bcq_learner", time_delta=1.0 + ) + + self._bc_learner = bc.BCLearner( + network=network.g_network, + learning_rate=learning_rate, + dataset=dataset, + counter=counting.Counter(counter, "bc"), + logger=self._bc_logger, + checkpoint=False, + ) + self._bcq_learner = _InternalBCQLearner( + network=network, + learning_rate=learning_rate, + dataset=dataset, + counter=counting.Counter(counter, "bcq"), + logger=self._bcq_logger, + **bcq_learner_kwargs, + ) + + def get_variables(self, names): + return self._bcq_learner.get_variables(names) + + @property + def state(self): + bc_state = self._bc_learner.state + bc_state.pop("network") # No need to checkpoint the BC network. + bcq_state = self._bcq_learner.state + state = dict() + state.update({f"bc_{k}": v for k, v in bc_state.items()}) + state.update({f"bcq_{k}": v for k, v in bcq_state.items()}) + return state + + def step(self): + self._bc_learner.step() + self._bcq_learner.step() diff --git a/acme/agents/tf/bcq/discrete_learning_test.py b/acme/agents/tf/bcq/discrete_learning_test.py index 8169f10c06..89ec7dfc81 100644 --- a/acme/agents/tf/bcq/discrete_learning_test.py +++ b/acme/agents/tf/bcq/discrete_learning_test.py @@ -14,68 +14,65 @@ """Tests for discrete BCQ learner.""" +import numpy as np +import sonnet as snt +from absl.testing import absltest + from acme import specs from acme.agents.tf import bcq from acme.testing import fakes from acme.tf import utils as tf2_utils from acme.tf.networks import discrete as discrete_networks from acme.utils import counting -import numpy as np -import sonnet as snt - -from absl.testing import absltest def _make_network(action_spec: specs.DiscreteArray) -> snt.Module: - return snt.Sequential([ - snt.Flatten(), - snt.nets.MLP([50, 50, action_spec.num_values]), - ]) + return snt.Sequential( + [snt.Flatten(), snt.nets.MLP([50, 50, action_spec.num_values]),] + ) class DiscreteBCQLearnerTest(absltest.TestCase): + def test_full_learner(self): + # Create dataset. 
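+        # The fake environment and dataset below only exercise the learner's
+        # plumbing; a single step suffices to check counters and learner state.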
+ environment = fakes.DiscreteEnvironment( + num_actions=5, num_observations=10, obs_dtype=np.float32, episode_length=10 + ) + spec = specs.make_environment_spec(environment) + dataset = fakes.transition_dataset(environment).batch(2) - def test_full_learner(self): - # Create dataset. - environment = fakes.DiscreteEnvironment( - num_actions=5, - num_observations=10, - obs_dtype=np.float32, - episode_length=10) - spec = specs.make_environment_spec(environment) - dataset = fakes.transition_dataset(environment).batch(2) - - # Build network. - g_network = _make_network(spec.actions) - q_network = _make_network(spec.actions) - network = discrete_networks.DiscreteFilteredQNetwork(g_network=g_network, - q_network=q_network, - threshold=0.5) - tf2_utils.create_variables(network, [spec.observations]) + # Build network. + g_network = _make_network(spec.actions) + q_network = _make_network(spec.actions) + network = discrete_networks.DiscreteFilteredQNetwork( + g_network=g_network, q_network=q_network, threshold=0.5 + ) + tf2_utils.create_variables(network, [spec.observations]) - # Build learner. - counter = counting.Counter() - learner = bcq.DiscreteBCQLearner( - network=network, - dataset=dataset, - learning_rate=1e-4, - discount=0.99, - importance_sampling_exponent=0.2, - target_update_period=100, - counter=counter) + # Build learner. + counter = counting.Counter() + learner = bcq.DiscreteBCQLearner( + network=network, + dataset=dataset, + learning_rate=1e-4, + discount=0.99, + importance_sampling_exponent=0.2, + target_update_period=100, + counter=counter, + ) - # Run a learner step. - learner.step() + # Run a learner step. + learner.step() - # Check counts from BC and BCQ learners. - counts = counter.get_counts() - self.assertEqual(1, counts['bc_steps']) - self.assertEqual(1, counts['bcq_steps']) + # Check counts from BC and BCQ learners. + counts = counter.get_counts() + self.assertEqual(1, counts["bc_steps"]) + self.assertEqual(1, counts["bcq_steps"]) - # Check learner state. - self.assertEqual(1, learner.state['bc_num_steps'].numpy()) - self.assertEqual(1, learner.state['bcq_num_steps'].numpy()) + # Check learner state. + self.assertEqual(1, learner.state["bc_num_steps"].numpy()) + self.assertEqual(1, learner.state["bcq_num_steps"].numpy()) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/tf/crr/recurrent_learning.py b/acme/agents/tf/crr/recurrent_learning.py index 9369b6951b..678214ca73 100644 --- a/acme/agents/tf/crr/recurrent_learning.py +++ b/acme/agents/tf/crr/recurrent_learning.py @@ -18,51 +18,52 @@ import time from typing import Dict, List, Optional -from acme import core -from acme.tf import losses -from acme.tf import networks -from acme.tf import savers as tf2_savers -from acme.tf import utils as tf2_utils -from acme.utils import counting -from acme.utils import loggers import numpy as np import reverb import sonnet as snt import tensorflow as tf import tree +from acme import core +from acme.tf import losses, networks +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.utils import counting, loggers + class RCRRLearner(core.Learner): - """Recurrent CRR learner. + """Recurrent CRR learner. This is the learning component of a RCRR agent. It takes a dataset as input and implements update functionality to learn from this dataset. 
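
  The policy update is advantage-weighted regression in the spirit of CRR
  (https://arxiv.org/abs/2006.15134): the log-likelihood of each dataset action
  is reweighted by exp(advantage / beta), a binary indicator, or a constant,
  depending on `policy_improvement_modes`.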
""" - def __init__(self, - policy_network: snt.RNNCore, - critic_network: networks.CriticDeepRNN, - target_policy_network: snt.RNNCore, - target_critic_network: networks.CriticDeepRNN, - dataset: tf.data.Dataset, - accelerator_strategy: Optional[tf.distribute.Strategy] = None, - behavior_network: Optional[snt.Module] = None, - cwp_network: Optional[snt.Module] = None, - policy_optimizer: Optional[snt.Optimizer] = None, - critic_optimizer: Optional[snt.Optimizer] = None, - discount: float = 0.99, - target_update_period: int = 100, - num_action_samples_td_learning: int = 1, - num_action_samples_policy_weight: int = 4, - baseline_reduce_function: str = 'mean', - clipping: bool = True, - policy_improvement_modes: str = 'exp', - ratio_upper_bound: float = 20., - beta: float = 1.0, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - checkpoint: bool = False): - """Initializes the learner. + def __init__( + self, + policy_network: snt.RNNCore, + critic_network: networks.CriticDeepRNN, + target_policy_network: snt.RNNCore, + target_critic_network: networks.CriticDeepRNN, + dataset: tf.data.Dataset, + accelerator_strategy: Optional[tf.distribute.Strategy] = None, + behavior_network: Optional[snt.Module] = None, + cwp_network: Optional[snt.Module] = None, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + discount: float = 0.99, + target_update_period: int = 100, + num_action_samples_td_learning: int = 1, + num_action_samples_policy_weight: int = 4, + baseline_reduce_function: str = "mean", + clipping: bool = True, + policy_improvement_modes: str = "exp", + ratio_upper_bound: float = 20.0, + beta: float = 1.0, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = False, + ): + """Initializes the learner. Args: policy_network: the online (optimized) policy. @@ -109,299 +110,326 @@ def __init__(self, checkpoint: boolean indicating whether to checkpoint the learner. """ - if accelerator_strategy is None: - accelerator_strategy = snt.distribute.Replicator() - self._accelerator_strategy = accelerator_strategy - self._policy_improvement_modes = policy_improvement_modes - self._ratio_upper_bound = ratio_upper_bound - self._num_action_samples_td_learning = num_action_samples_td_learning - self._num_action_samples_policy_weight = num_action_samples_policy_weight - self._baseline_reduce_function = baseline_reduce_function - self._beta = beta - - # When running on TPUs we have to know the amount of memory required (and - # thus the sequence length) at the graph compilation stage. At the moment, - # the only way to get it is to sample from the dataset, since the dataset - # does not have any metadata, see b/160672927 to track this upcoming - # feature. - sample = next(dataset.as_numpy_iterator()) - self._sequence_length = sample.action.shape[1] - - self._counter = counter or counting.Counter() - self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.) - self._discount = discount - self._clipping = clipping - - self._target_update_period = target_update_period - - with self._accelerator_strategy.scope(): - # Necessary to track when to update target networks. - self._num_steps = tf.Variable(0, dtype=tf.int32) - - # (Maybe) distributing the dataset across multiple accelerators. - distributed_dataset = self._accelerator_strategy.experimental_distribute_dataset( - dataset) - self._iterator = iter(distributed_dataset) - - # Create the optimizers. 
- self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) - self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) - - # Store online and target networks. - self._policy_network = policy_network - self._critic_network = critic_network - self._target_policy_network = target_policy_network - self._target_critic_network = target_critic_network - - # Expose the variables. - self._variables = { - 'critic': self._target_critic_network.variables, - 'policy': self._target_policy_network.variables, - } - - # Create a checkpointer object. - self._checkpointer = None - self._snapshotter = None - if checkpoint: - self._checkpointer = tf2_savers.Checkpointer( - objects_to_save={ - 'counter': self._counter, - 'policy': self._policy_network, - 'critic': self._critic_network, - 'target_policy': self._target_policy_network, - 'target_critic': self._target_critic_network, - 'policy_optimizer': self._policy_optimizer, - 'critic_optimizer': self._critic_optimizer, - 'num_steps': self._num_steps, - }, - time_delta_minutes=30.) - - raw_policy = snt.DeepRNN( - [policy_network, networks.StochasticSamplingHead()]) - critic_mean = networks.CriticDeepRNN( - [critic_network, networks.StochasticMeanHead()]) - objects_to_save = { - 'raw_policy': raw_policy, - 'critic': critic_mean, - } - if behavior_network is not None: - objects_to_save['policy'] = behavior_network - if cwp_network is not None: - objects_to_save['cwp_policy'] = cwp_network - self._snapshotter = tf2_savers.Snapshotter( - objects_to_save=objects_to_save, time_delta_minutes=30) - # Timestamp to keep track of the wall time. - self._walltime_timestamp = time.time() - - def _step(self, sample: reverb.ReplaySample) -> Dict[str, tf.Tensor]: - # Transpose batch and sequence axes, i.e. [B, T, ...] to [T, B, ...]. - sample = tf2_utils.batch_to_sequence(sample) - observations = sample.observation - actions = sample.action - rewards = sample.reward - discounts = sample.discount - - dtype = rewards.dtype - - # Cast the additional discount to match the environment discount dtype. - discount = tf.cast(self._discount, dtype=discounts.dtype) - - # Loss cumulants across time. These cannot be python mutable objects. - critic_loss = 0. - policy_loss = 0. - - # Each transition induces a policy loss, which we then weight using - # the `policy_loss_coef_t`; shape [B], see https://arxiv.org/abs/2006.15134. - # `policy_loss_coef` is a scalar average of these coefficients across - # the batch and sequence length dimensions. - policy_loss_coef = 0. - - per_device_batch_size = actions.shape[1] - - # Initialize recurrent states. - critic_state = self._critic_network.initial_state(per_device_batch_size) - target_critic_state = critic_state - policy_state = self._policy_network.initial_state(per_device_batch_size) - target_policy_state = policy_state - - with tf.GradientTape(persistent=True) as tape: - for t in range(1, self._sequence_length): - o_tm1 = tree.map_structure(operator.itemgetter(t - 1), observations) - a_tm1 = tree.map_structure(operator.itemgetter(t - 1), actions) - r_t = tree.map_structure(operator.itemgetter(t - 1), rewards) - d_t = tree.map_structure(operator.itemgetter(t - 1), discounts) - o_t = tree.map_structure(operator.itemgetter(t), observations) - - if t != 1: - # By only updating the target critic state here we are forcing - # the target critic to ignore observations[0]. Otherwise, the - # target_critic will be unrolled for one more timestep than critic. 
- # The smaller the sequence length, the more problematic this is: if - # you use RNN on sequences of length 2, you would expect the code to - # never use recurrent connections. But if you don't skip updating the - # target_critic_state on observation[0] here, it won't be the case. - _, target_critic_state = self._target_critic_network( - o_tm1, a_tm1, target_critic_state) - - # ========================= Critic learning ============================ - q_tm1, next_critic_state = self._critic_network(o_tm1, a_tm1, - critic_state) - target_action_distribution, target_policy_state = self._target_policy_network( - o_t, target_policy_state) - - sampled_actions_t = target_action_distribution.sample( - self._num_action_samples_td_learning) - # [N, B, ...] - tiled_o_t = tf2_utils.tile_nested( - o_t, self._num_action_samples_td_learning) - tiled_target_critic_state = tf2_utils.tile_nested( - target_critic_state, self._num_action_samples_td_learning) - - # Compute the target critic's Q-value of the sampled actions. - sampled_q_t, _ = snt.BatchApply(self._target_critic_network)( - tiled_o_t, sampled_actions_t, tiled_target_critic_state) - - # Compute average logits by first reshaping them to [N, B, A] and then - # normalizing them across atoms. - new_shape = [self._num_action_samples_td_learning, r_t.shape[0], -1] - sampled_logits = tf.reshape(sampled_q_t.logits, new_shape) - sampled_logprobs = tf.math.log_softmax(sampled_logits, axis=-1) - averaged_logits = tf.reduce_logsumexp(sampled_logprobs, axis=0) - - # Construct the expected distributional value for bootstrapping. - q_t = networks.DiscreteValuedDistribution( - values=sampled_q_t.values, logits=averaged_logits) - critic_loss_t = losses.categorical(q_tm1, r_t, discount * d_t, q_t) - critic_loss_t = tf.reduce_mean(critic_loss_t) - - # ========================= Actor learning ============================= - action_distribution_tm1, policy_state = self._policy_network( - o_tm1, policy_state) - q_tm1_mean = q_tm1.mean() - - # Compute the estimate of the value function based on - # self._num_action_samples_policy_weight samples from the policy. - tiled_o_tm1 = tf2_utils.tile_nested( - o_tm1, self._num_action_samples_policy_weight) - tiled_critic_state = tf2_utils.tile_nested( - critic_state, self._num_action_samples_policy_weight) - action_tm1 = action_distribution_tm1.sample( - self._num_action_samples_policy_weight) - tiled_z_tm1, _ = snt.BatchApply(self._critic_network)( - tiled_o_tm1, action_tm1, tiled_critic_state) - tiled_v_tm1 = tf.reshape(tiled_z_tm1.mean(), - [self._num_action_samples_policy_weight, -1]) - - # Use mean, min, or max to aggregate Q(s, a_i), a_i ~ pi(s) into the - # final estimate of the value function. - if self._baseline_reduce_function == 'mean': - v_tm1_estimate = tf.reduce_mean(tiled_v_tm1, axis=0) - elif self._baseline_reduce_function == 'max': - v_tm1_estimate = tf.reduce_max(tiled_v_tm1, axis=0) - elif self._baseline_reduce_function == 'min': - v_tm1_estimate = tf.reduce_min(tiled_v_tm1, axis=0) - - # Assert that action_distribution_tm1 is a batch of multivariate - # distributions (in contrast to e.g. a [batch, action_size] collection - # of 1d distributions). 
- assert len(action_distribution_tm1.batch_shape) == 1 - policy_loss_batch = -action_distribution_tm1.log_prob(a_tm1) - - advantage = q_tm1_mean - v_tm1_estimate - if self._policy_improvement_modes == 'exp': - policy_loss_coef_t = tf.math.minimum( - tf.math.exp(advantage / self._beta), self._ratio_upper_bound) - elif self._policy_improvement_modes == 'binary': - policy_loss_coef_t = tf.cast(advantage > 0, dtype=dtype) - elif self._policy_improvement_modes == 'all': - # Regress against all actions (effectively pure BC). - policy_loss_coef_t = 1. - policy_loss_coef_t = tf.stop_gradient(policy_loss_coef_t) - - policy_loss_batch *= policy_loss_coef_t - policy_loss_t = tf.reduce_mean(policy_loss_batch) - - critic_state = next_critic_state - - critic_loss += critic_loss_t - policy_loss += policy_loss_t - policy_loss_coef += tf.reduce_mean(policy_loss_coef_t) # For logging. - - # Divide by sequence length to get mean losses. - critic_loss /= tf.cast(self._sequence_length, dtype=dtype) - policy_loss /= tf.cast(self._sequence_length, dtype=dtype) - policy_loss_coef /= tf.cast(self._sequence_length, dtype=dtype) - - # Compute gradients. - critic_gradients = tape.gradient(critic_loss, - self._critic_network.trainable_variables) - policy_gradients = tape.gradient(policy_loss, - self._policy_network.trainable_variables) - - # Delete the tape manually because of the persistent=True flag. - del tape - - # Sync gradients across GPUs or TPUs. - ctx = tf.distribute.get_replica_context() - critic_gradients = ctx.all_reduce('mean', critic_gradients) - policy_gradients = ctx.all_reduce('mean', policy_gradients) - - # Maybe clip gradients. - if self._clipping: - policy_gradients = tf.clip_by_global_norm(policy_gradients, 40.)[0] - critic_gradients = tf.clip_by_global_norm(critic_gradients, 40.)[0] - - # Apply gradients. - self._critic_optimizer.apply(critic_gradients, - self._critic_network.trainable_variables) - self._policy_optimizer.apply(policy_gradients, - self._policy_network.trainable_variables) - - source_variables = ( - self._critic_network.variables + self._policy_network.variables) - target_variables = ( - self._target_critic_network.variables + - self._target_policy_network.variables) - - # Make online -> target network update ops. - if tf.math.mod(self._num_steps, self._target_update_period) == 0: - for src, dest in zip(source_variables, target_variables): - dest.assign(src) - self._num_steps.assign_add(1) - - return { - 'critic_loss': critic_loss, - 'policy_loss': policy_loss, - 'policy_loss_coef': policy_loss_coef, - } - - @tf.function - def _replicated_step(self) -> Dict[str, tf.Tensor]: - sample = next(self._iterator) - fetches = self._accelerator_strategy.run(self._step, args=(sample,)) - mean = tf.distribute.ReduceOp.MEAN - return { - k: self._accelerator_strategy.reduce(mean, fetches[k], axis=None) - for k in fetches - } - - def step(self): - # Run the learning step. - with self._accelerator_strategy.scope(): - fetches = self._replicated_step() - - # Update our counts and record it. - new_timestamp = time.time() - time_passed = new_timestamp - self._walltime_timestamp - self._walltime_timestamp = new_timestamp - counts = self._counter.increment(steps=1, wall_time=time_passed) - fetches.update(counts) - - # Checkpoint and attempt to write the logs. 
- if self._checkpointer is not None: - self._checkpointer.save() - self._snapshotter.save() - self._logger.write(fetches) - - def get_variables(self, names: List[str]) -> List[List[np.ndarray]]: - return [tf2_utils.to_numpy(self._variables[name]) for name in names] + if accelerator_strategy is None: + accelerator_strategy = snt.distribute.Replicator() + self._accelerator_strategy = accelerator_strategy + self._policy_improvement_modes = policy_improvement_modes + self._ratio_upper_bound = ratio_upper_bound + self._num_action_samples_td_learning = num_action_samples_td_learning + self._num_action_samples_policy_weight = num_action_samples_policy_weight + self._baseline_reduce_function = baseline_reduce_function + self._beta = beta + + # When running on TPUs we have to know the amount of memory required (and + # thus the sequence length) at the graph compilation stage. At the moment, + # the only way to get it is to sample from the dataset, since the dataset + # does not have any metadata, see b/160672927 to track this upcoming + # feature. + sample = next(dataset.as_numpy_iterator()) + self._sequence_length = sample.action.shape[1] + + self._counter = counter or counting.Counter() + self._logger = logger or loggers.TerminalLogger("learner", time_delta=1.0) + self._discount = discount + self._clipping = clipping + + self._target_update_period = target_update_period + + with self._accelerator_strategy.scope(): + # Necessary to track when to update target networks. + self._num_steps = tf.Variable(0, dtype=tf.int32) + + # (Maybe) distributing the dataset across multiple accelerators. + distributed_dataset = self._accelerator_strategy.experimental_distribute_dataset( + dataset + ) + self._iterator = iter(distributed_dataset) + + # Create the optimizers. + self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) + self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) + + # Store online and target networks. + self._policy_network = policy_network + self._critic_network = critic_network + self._target_policy_network = target_policy_network + self._target_critic_network = target_critic_network + + # Expose the variables. + self._variables = { + "critic": self._target_critic_network.variables, + "policy": self._target_policy_network.variables, + } + + # Create a checkpointer object. + self._checkpointer = None + self._snapshotter = None + if checkpoint: + self._checkpointer = tf2_savers.Checkpointer( + objects_to_save={ + "counter": self._counter, + "policy": self._policy_network, + "critic": self._critic_network, + "target_policy": self._target_policy_network, + "target_critic": self._target_critic_network, + "policy_optimizer": self._policy_optimizer, + "critic_optimizer": self._critic_optimizer, + "num_steps": self._num_steps, + }, + time_delta_minutes=30.0, + ) + + raw_policy = snt.DeepRNN( + [policy_network, networks.StochasticSamplingHead()] + ) + critic_mean = networks.CriticDeepRNN( + [critic_network, networks.StochasticMeanHead()] + ) + objects_to_save = { + "raw_policy": raw_policy, + "critic": critic_mean, + } + if behavior_network is not None: + objects_to_save["policy"] = behavior_network + if cwp_network is not None: + objects_to_save["cwp_policy"] = cwp_network + self._snapshotter = tf2_savers.Snapshotter( + objects_to_save=objects_to_save, time_delta_minutes=30 + ) + # Timestamp to keep track of the wall time. 
+ self._walltime_timestamp = time.time() + + def _step(self, sample: reverb.ReplaySample) -> Dict[str, tf.Tensor]: + # Transpose batch and sequence axes, i.e. [B, T, ...] to [T, B, ...]. + sample = tf2_utils.batch_to_sequence(sample) + observations = sample.observation + actions = sample.action + rewards = sample.reward + discounts = sample.discount + + dtype = rewards.dtype + + # Cast the additional discount to match the environment discount dtype. + discount = tf.cast(self._discount, dtype=discounts.dtype) + + # Loss cumulants across time. These cannot be python mutable objects. + critic_loss = 0.0 + policy_loss = 0.0 + + # Each transition induces a policy loss, which we then weight using + # the `policy_loss_coef_t`; shape [B], see https://arxiv.org/abs/2006.15134. + # `policy_loss_coef` is a scalar average of these coefficients across + # the batch and sequence length dimensions. + policy_loss_coef = 0.0 + + per_device_batch_size = actions.shape[1] + + # Initialize recurrent states. + critic_state = self._critic_network.initial_state(per_device_batch_size) + target_critic_state = critic_state + policy_state = self._policy_network.initial_state(per_device_batch_size) + target_policy_state = policy_state + + with tf.GradientTape(persistent=True) as tape: + for t in range(1, self._sequence_length): + o_tm1 = tree.map_structure(operator.itemgetter(t - 1), observations) + a_tm1 = tree.map_structure(operator.itemgetter(t - 1), actions) + r_t = tree.map_structure(operator.itemgetter(t - 1), rewards) + d_t = tree.map_structure(operator.itemgetter(t - 1), discounts) + o_t = tree.map_structure(operator.itemgetter(t), observations) + + if t != 1: + # By only updating the target critic state here we are forcing + # the target critic to ignore observations[0]. Otherwise, the + # target_critic will be unrolled for one more timestep than critic. + # The smaller the sequence length, the more problematic this is: if + # you use RNN on sequences of length 2, you would expect the code to + # never use recurrent connections. But if you don't skip updating the + # target_critic_state on observation[0] here, it won't be the case. + _, target_critic_state = self._target_critic_network( + o_tm1, a_tm1, target_critic_state + ) + + # ========================= Critic learning ============================ + q_tm1, next_critic_state = self._critic_network( + o_tm1, a_tm1, critic_state + ) + ( + target_action_distribution, + target_policy_state, + ) = self._target_policy_network(o_t, target_policy_state) + + sampled_actions_t = target_action_distribution.sample( + self._num_action_samples_td_learning + ) + # [N, B, ...] + tiled_o_t = tf2_utils.tile_nested( + o_t, self._num_action_samples_td_learning + ) + tiled_target_critic_state = tf2_utils.tile_nested( + target_critic_state, self._num_action_samples_td_learning + ) + + # Compute the target critic's Q-value of the sampled actions. + sampled_q_t, _ = snt.BatchApply(self._target_critic_network)( + tiled_o_t, sampled_actions_t, tiled_target_critic_state + ) + + # Compute average logits by first reshaping them to [N, B, A] and then + # normalizing them across atoms. + new_shape = [self._num_action_samples_td_learning, r_t.shape[0], -1] + sampled_logits = tf.reshape(sampled_q_t.logits, new_shape) + sampled_logprobs = tf.math.log_softmax(sampled_logits, axis=-1) + averaged_logits = tf.reduce_logsumexp(sampled_logprobs, axis=0) + + # Construct the expected distributional value for bootstrapping. 
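+                # Averaging the per-action atom distributions (via logsumexp of
+                # the log-probabilities over the N samples) gives a Monte-Carlo
+                # estimate of E_{a ~ target policy}[Z(o_t, a)], the
+                # distributional analogue of the usual V(o_t) bootstrap target.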
+ q_t = networks.DiscreteValuedDistribution( + values=sampled_q_t.values, logits=averaged_logits + ) + critic_loss_t = losses.categorical(q_tm1, r_t, discount * d_t, q_t) + critic_loss_t = tf.reduce_mean(critic_loss_t) + + # ========================= Actor learning ============================= + action_distribution_tm1, policy_state = self._policy_network( + o_tm1, policy_state + ) + q_tm1_mean = q_tm1.mean() + + # Compute the estimate of the value function based on + # self._num_action_samples_policy_weight samples from the policy. + tiled_o_tm1 = tf2_utils.tile_nested( + o_tm1, self._num_action_samples_policy_weight + ) + tiled_critic_state = tf2_utils.tile_nested( + critic_state, self._num_action_samples_policy_weight + ) + action_tm1 = action_distribution_tm1.sample( + self._num_action_samples_policy_weight + ) + tiled_z_tm1, _ = snt.BatchApply(self._critic_network)( + tiled_o_tm1, action_tm1, tiled_critic_state + ) + tiled_v_tm1 = tf.reshape( + tiled_z_tm1.mean(), [self._num_action_samples_policy_weight, -1] + ) + + # Use mean, min, or max to aggregate Q(s, a_i), a_i ~ pi(s) into the + # final estimate of the value function. + if self._baseline_reduce_function == "mean": + v_tm1_estimate = tf.reduce_mean(tiled_v_tm1, axis=0) + elif self._baseline_reduce_function == "max": + v_tm1_estimate = tf.reduce_max(tiled_v_tm1, axis=0) + elif self._baseline_reduce_function == "min": + v_tm1_estimate = tf.reduce_min(tiled_v_tm1, axis=0) + + # Assert that action_distribution_tm1 is a batch of multivariate + # distributions (in contrast to e.g. a [batch, action_size] collection + # of 1d distributions). + assert len(action_distribution_tm1.batch_shape) == 1 + policy_loss_batch = -action_distribution_tm1.log_prob(a_tm1) + + advantage = q_tm1_mean - v_tm1_estimate + if self._policy_improvement_modes == "exp": + policy_loss_coef_t = tf.math.minimum( + tf.math.exp(advantage / self._beta), self._ratio_upper_bound + ) + elif self._policy_improvement_modes == "binary": + policy_loss_coef_t = tf.cast(advantage > 0, dtype=dtype) + elif self._policy_improvement_modes == "all": + # Regress against all actions (effectively pure BC). + policy_loss_coef_t = 1.0 + policy_loss_coef_t = tf.stop_gradient(policy_loss_coef_t) + + policy_loss_batch *= policy_loss_coef_t + policy_loss_t = tf.reduce_mean(policy_loss_batch) + + critic_state = next_critic_state + + critic_loss += critic_loss_t + policy_loss += policy_loss_t + policy_loss_coef += tf.reduce_mean(policy_loss_coef_t) # For logging. + + # Divide by sequence length to get mean losses. + critic_loss /= tf.cast(self._sequence_length, dtype=dtype) + policy_loss /= tf.cast(self._sequence_length, dtype=dtype) + policy_loss_coef /= tf.cast(self._sequence_length, dtype=dtype) + + # Compute gradients. + critic_gradients = tape.gradient( + critic_loss, self._critic_network.trainable_variables + ) + policy_gradients = tape.gradient( + policy_loss, self._policy_network.trainable_variables + ) + + # Delete the tape manually because of the persistent=True flag. + del tape + + # Sync gradients across GPUs or TPUs. + ctx = tf.distribute.get_replica_context() + critic_gradients = ctx.all_reduce("mean", critic_gradients) + policy_gradients = ctx.all_reduce("mean", policy_gradients) + + # Maybe clip gradients. + if self._clipping: + policy_gradients = tf.clip_by_global_norm(policy_gradients, 40.0)[0] + critic_gradients = tf.clip_by_global_norm(critic_gradients, 40.0)[0] + + # Apply gradients. 
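+        # The gradients were already averaged across replicas (and optionally
+        # clipped) above, so every replica applies an identical update here.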
+ self._critic_optimizer.apply( + critic_gradients, self._critic_network.trainable_variables + ) + self._policy_optimizer.apply( + policy_gradients, self._policy_network.trainable_variables + ) + + source_variables = ( + self._critic_network.variables + self._policy_network.variables + ) + target_variables = ( + self._target_critic_network.variables + + self._target_policy_network.variables + ) + + # Make online -> target network update ops. + if tf.math.mod(self._num_steps, self._target_update_period) == 0: + for src, dest in zip(source_variables, target_variables): + dest.assign(src) + self._num_steps.assign_add(1) + + return { + "critic_loss": critic_loss, + "policy_loss": policy_loss, + "policy_loss_coef": policy_loss_coef, + } + + @tf.function + def _replicated_step(self) -> Dict[str, tf.Tensor]: + sample = next(self._iterator) + fetches = self._accelerator_strategy.run(self._step, args=(sample,)) + mean = tf.distribute.ReduceOp.MEAN + return { + k: self._accelerator_strategy.reduce(mean, fetches[k], axis=None) + for k in fetches + } + + def step(self): + # Run the learning step. + with self._accelerator_strategy.scope(): + fetches = self._replicated_step() + + # Update our counts and record it. + new_timestamp = time.time() + time_passed = new_timestamp - self._walltime_timestamp + self._walltime_timestamp = new_timestamp + counts = self._counter.increment(steps=1, wall_time=time_passed) + fetches.update(counts) + + # Checkpoint and attempt to write the logs. + if self._checkpointer is not None: + self._checkpointer.save() + self._snapshotter.save() + self._logger.write(fetches) + + def get_variables(self, names: List[str]) -> List[List[np.ndarray]]: + return [tf2_utils.to_numpy(self._variables[name]) for name in names] diff --git a/acme/agents/tf/d4pg/agent.py b/acme/agents/tf/d4pg/agent.py index 96daa1aed6..712612c938 100644 --- a/acme/agents/tf/d4pg/agent.py +++ b/acme/agents/tf/d4pg/agent.py @@ -17,239 +17,228 @@ import copy import dataclasses import functools -from typing import Iterator, List, Optional, Tuple, Union, Sequence +from typing import Iterator, List, Optional, Sequence, Tuple, Union -from acme import adders -from acme import core -from acme import datasets -from acme import specs -from acme import types +import reverb +import sonnet as snt +import tensorflow as tf + +from acme import adders, core, datasets, specs, types from acme.adders import reverb as reverb_adders from acme.agents import agent from acme.agents.tf import actors from acme.agents.tf.d4pg import learning from acme.tf import networks as network_utils -from acme.tf import utils -from acme.tf import variable_utils -from acme.utils import counting -from acme.utils import loggers -import reverb -import sonnet as snt -import tensorflow as tf +from acme.tf import utils, variable_utils +from acme.utils import counting, loggers Replicator = Union[snt.distribute.Replicator, snt.distribute.TpuReplicator] @dataclasses.dataclass class D4PGConfig: - """Configuration options for the D4PG agent.""" - - accelerator: Optional[str] = None - discount: float = 0.99 - batch_size: int = 256 - prefetch_size: int = 4 - target_update_period: int = 100 - variable_update_period: int = 1000 - policy_optimizer: Optional[snt.Optimizer] = None - critic_optimizer: Optional[snt.Optimizer] = None - min_replay_size: int = 1000 - max_replay_size: int = 1000000 - samples_per_insert: Optional[float] = 32.0 - n_step: int = 5 - sigma: float = 0.3 - clipping: bool = True - replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE + 
"""Configuration options for the D4PG agent.""" + + accelerator: Optional[str] = None + discount: float = 0.99 + batch_size: int = 256 + prefetch_size: int = 4 + target_update_period: int = 100 + variable_update_period: int = 1000 + policy_optimizer: Optional[snt.Optimizer] = None + critic_optimizer: Optional[snt.Optimizer] = None + min_replay_size: int = 1000 + max_replay_size: int = 1000000 + samples_per_insert: Optional[float] = 32.0 + n_step: int = 5 + sigma: float = 0.3 + clipping: bool = True + replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE @dataclasses.dataclass class D4PGNetworks: - """Structure containing the networks for D4PG.""" - - policy_network: snt.Module - critic_network: snt.Module - observation_network: snt.Module - - def __init__( - self, - policy_network: snt.Module, - critic_network: snt.Module, - observation_network: types.TensorTransformation, - ): - # This method is implemented (rather than added by the dataclass decorator) - # in order to allow observation network to be passed as an arbitrary tensor - # transformation rather than as a snt Module. - # TODO(mwhoffman): use Protocol rather than Module/TensorTransformation. - self.policy_network = policy_network - self.critic_network = critic_network - self.observation_network = utils.to_sonnet_module(observation_network) - - def init(self, environment_spec: specs.EnvironmentSpec): - """Initialize the networks given an environment spec.""" - # Get observation and action specs. - act_spec = environment_spec.actions - obs_spec = environment_spec.observations - - # Create variables for the observation net and, as a side-effect, get a - # spec describing the embedding space. - emb_spec = utils.create_variables(self.observation_network, [obs_spec]) - - # Create variables for the policy and critic nets. - _ = utils.create_variables(self.policy_network, [emb_spec]) - _ = utils.create_variables(self.critic_network, [emb_spec, act_spec]) - - def make_policy( - self, - environment_spec: specs.EnvironmentSpec, - sigma: float = 0.0, - ) -> snt.Module: - """Create a single network which evaluates the policy.""" - # Stack the observation and policy networks. - stack = [ - self.observation_network, - self.policy_network, - ] - - # If a stochastic/non-greedy policy is requested, add Gaussian noise on - # top to enable a simple form of exploration. - # TODO(mwhoffman): Refactor this to remove it from the class. - if sigma > 0.0: - stack += [ - network_utils.ClippedGaussian(sigma), - network_utils.ClipToSpec(environment_spec.actions), - ] - - # Return a network which sequentially evaluates everything in the stack. - return snt.Sequential(stack) + """Structure containing the networks for D4PG.""" + + policy_network: snt.Module + critic_network: snt.Module + observation_network: snt.Module + + def __init__( + self, + policy_network: snt.Module, + critic_network: snt.Module, + observation_network: types.TensorTransformation, + ): + # This method is implemented (rather than added by the dataclass decorator) + # in order to allow observation network to be passed as an arbitrary tensor + # transformation rather than as a snt Module. + # TODO(mwhoffman): use Protocol rather than Module/TensorTransformation. + self.policy_network = policy_network + self.critic_network = critic_network + self.observation_network = utils.to_sonnet_module(observation_network) + + def init(self, environment_spec: specs.EnvironmentSpec): + """Initialize the networks given an environment spec.""" + # Get observation and action specs. 
+ act_spec = environment_spec.actions + obs_spec = environment_spec.observations + + # Create variables for the observation net and, as a side-effect, get a + # spec describing the embedding space. + emb_spec = utils.create_variables(self.observation_network, [obs_spec]) + + # Create variables for the policy and critic nets. + _ = utils.create_variables(self.policy_network, [emb_spec]) + _ = utils.create_variables(self.critic_network, [emb_spec, act_spec]) + + def make_policy( + self, environment_spec: specs.EnvironmentSpec, sigma: float = 0.0, + ) -> snt.Module: + """Create a single network which evaluates the policy.""" + # Stack the observation and policy networks. + stack = [ + self.observation_network, + self.policy_network, + ] + + # If a stochastic/non-greedy policy is requested, add Gaussian noise on + # top to enable a simple form of exploration. + # TODO(mwhoffman): Refactor this to remove it from the class. + if sigma > 0.0: + stack += [ + network_utils.ClippedGaussian(sigma), + network_utils.ClipToSpec(environment_spec.actions), + ] + + # Return a network which sequentially evaluates everything in the stack. + return snt.Sequential(stack) class D4PGBuilder: - """Builder for D4PG which constructs individual components of the agent.""" - - def __init__(self, config: D4PGConfig): - self._config = config - - def make_replay_tables( - self, - environment_spec: specs.EnvironmentSpec, - ) -> List[reverb.Table]: - """Create tables to insert data into.""" - if self._config.samples_per_insert is None: - # We will take a samples_per_insert ratio of None to mean that there is - # no limit, i.e. this only implies a min size limit. - limiter = reverb.rate_limiters.MinSize(self._config.min_replay_size) - - else: - # Create enough of an error buffer to give a 10% tolerance in rate. - samples_per_insert_tolerance = 0.1 * self._config.samples_per_insert - error_buffer = self._config.min_replay_size * samples_per_insert_tolerance - limiter = reverb.rate_limiters.SampleToInsertRatio( - min_size_to_sample=self._config.min_replay_size, - samples_per_insert=self._config.samples_per_insert, - error_buffer=error_buffer) - - replay_table = reverb.Table( - name=self._config.replay_table_name, - sampler=reverb.selectors.Uniform(), - remover=reverb.selectors.Fifo(), - max_size=self._config.max_replay_size, - rate_limiter=limiter, - signature=reverb_adders.NStepTransitionAdder.signature( - environment_spec)) - - return [replay_table] - - def make_dataset_iterator( - self, - reverb_client: reverb.Client, - ) -> Iterator[reverb.ReplaySample]: - """Create a dataset iterator to use for learning/updating the agent.""" - # The dataset provides an interface to sample from replay. - dataset = datasets.make_reverb_dataset( - table=self._config.replay_table_name, - server_address=reverb_client.server_address, - batch_size=self._config.batch_size, - prefetch_size=self._config.prefetch_size) - - replicator = get_replicator(self._config.accelerator) - dataset = replicator.experimental_distribute_dataset(dataset) - - # TODO(b/155086959): Fix type stubs and remove. 
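+            # For example, with the default config (min_replay_size=1000,
+            # samples_per_insert=32.0) the tolerance is 3.2 and the error
+            # buffer is 1000 * 3.2 = 3200.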
- return iter(dataset) # pytype: disable=wrong-arg-types - - def make_adder( - self, - replay_client: reverb.Client, - ) -> adders.Adder: - """Create an adder which records data generated by the actor/environment.""" - return reverb_adders.NStepTransitionAdder( - priority_fns={self._config.replay_table_name: lambda x: 1.}, - client=replay_client, - n_step=self._config.n_step, - discount=self._config.discount) - - def make_actor( - self, - policy_network: snt.Module, - adder: Optional[adders.Adder] = None, - variable_source: Optional[core.VariableSource] = None, - ): - """Create an actor instance.""" - if variable_source: - # Create the variable client responsible for keeping the actor up-to-date. - variable_client = variable_utils.VariableClient( - client=variable_source, - variables={'policy': policy_network.variables}, - update_period=self._config.variable_update_period, - ) - - # Make sure not to use a random policy after checkpoint restoration by - # assigning variables before running the environment loop. - variable_client.update_and_wait() - - else: - variable_client = None - - # Create the actor which defines how we take actions. - return actors.FeedForwardActor( - policy_network=policy_network, - adder=adder, - variable_client=variable_client, - ) - - def make_learner( - self, - networks: Tuple[D4PGNetworks, D4PGNetworks], - dataset: Iterator[reverb.ReplaySample], - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - checkpoint: bool = False, - ): - """Creates an instance of the learner.""" - online_networks, target_networks = networks - - # The learner updates the parameters (and initializes them). - return learning.D4PGLearner( - policy_network=online_networks.policy_network, - critic_network=online_networks.critic_network, - observation_network=online_networks.observation_network, - target_policy_network=target_networks.policy_network, - target_critic_network=target_networks.critic_network, - target_observation_network=target_networks.observation_network, - policy_optimizer=self._config.policy_optimizer, - critic_optimizer=self._config.critic_optimizer, - clipping=self._config.clipping, - discount=self._config.discount, - target_update_period=self._config.target_update_period, - dataset_iterator=dataset, - replicator=get_replicator(self._config.accelerator), - counter=counter, - logger=logger, - checkpoint=checkpoint, - ) + """Builder for D4PG which constructs individual components of the agent.""" + + def __init__(self, config: D4PGConfig): + self._config = config + + def make_replay_tables( + self, environment_spec: specs.EnvironmentSpec, + ) -> List[reverb.Table]: + """Create tables to insert data into.""" + if self._config.samples_per_insert is None: + # We will take a samples_per_insert ratio of None to mean that there is + # no limit, i.e. this only implies a min size limit. + limiter = reverb.rate_limiters.MinSize(self._config.min_replay_size) + + else: + # Create enough of an error buffer to give a 10% tolerance in rate. 
+ samples_per_insert_tolerance = 0.1 * self._config.samples_per_insert + error_buffer = self._config.min_replay_size * samples_per_insert_tolerance + limiter = reverb.rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._config.min_replay_size, + samples_per_insert=self._config.samples_per_insert, + error_buffer=error_buffer, + ) + + replay_table = reverb.Table( + name=self._config.replay_table_name, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._config.max_replay_size, + rate_limiter=limiter, + signature=reverb_adders.NStepTransitionAdder.signature(environment_spec), + ) + + return [replay_table] + + def make_dataset_iterator( + self, reverb_client: reverb.Client, + ) -> Iterator[reverb.ReplaySample]: + """Create a dataset iterator to use for learning/updating the agent.""" + # The dataset provides an interface to sample from replay. + dataset = datasets.make_reverb_dataset( + table=self._config.replay_table_name, + server_address=reverb_client.server_address, + batch_size=self._config.batch_size, + prefetch_size=self._config.prefetch_size, + ) + + replicator = get_replicator(self._config.accelerator) + dataset = replicator.experimental_distribute_dataset(dataset) + + # TODO(b/155086959): Fix type stubs and remove. + return iter(dataset) # pytype: disable=wrong-arg-types + + def make_adder(self, replay_client: reverb.Client,) -> adders.Adder: + """Create an adder which records data generated by the actor/environment.""" + return reverb_adders.NStepTransitionAdder( + priority_fns={self._config.replay_table_name: lambda x: 1.0}, + client=replay_client, + n_step=self._config.n_step, + discount=self._config.discount, + ) + + def make_actor( + self, + policy_network: snt.Module, + adder: Optional[adders.Adder] = None, + variable_source: Optional[core.VariableSource] = None, + ): + """Create an actor instance.""" + if variable_source: + # Create the variable client responsible for keeping the actor up-to-date. + variable_client = variable_utils.VariableClient( + client=variable_source, + variables={"policy": policy_network.variables}, + update_period=self._config.variable_update_period, + ) + + # Make sure not to use a random policy after checkpoint restoration by + # assigning variables before running the environment loop. + variable_client.update_and_wait() + + else: + variable_client = None + + # Create the actor which defines how we take actions. + return actors.FeedForwardActor( + policy_network=policy_network, adder=adder, variable_client=variable_client, + ) + + def make_learner( + self, + networks: Tuple[D4PGNetworks, D4PGNetworks], + dataset: Iterator[reverb.ReplaySample], + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = False, + ): + """Creates an instance of the learner.""" + online_networks, target_networks = networks + + # The learner updates the parameters (and initializes them). 
+ return learning.D4PGLearner( + policy_network=online_networks.policy_network, + critic_network=online_networks.critic_network, + observation_network=online_networks.observation_network, + target_policy_network=target_networks.policy_network, + target_critic_network=target_networks.critic_network, + target_observation_network=target_networks.observation_network, + policy_optimizer=self._config.policy_optimizer, + critic_optimizer=self._config.critic_optimizer, + clipping=self._config.clipping, + discount=self._config.discount, + target_update_period=self._config.target_update_period, + dataset_iterator=dataset, + replicator=get_replicator(self._config.accelerator), + counter=counter, + logger=logger, + checkpoint=checkpoint, + ) class D4PG(agent.Agent): - """D4PG Agent. + """D4PG Agent. This implements a single-process D4PG agent. This is an actor-critic algorithm that generates data via a behavior policy, inserts N-step transitions into @@ -257,31 +246,31 @@ class D4PG(agent.Agent): behavior) by sampling uniformly from this buffer. """ - def __init__( - self, - environment_spec: specs.EnvironmentSpec, - policy_network: snt.Module, - critic_network: snt.Module, - observation_network: types.TensorTransformation = tf.identity, - accelerator: Optional[str] = None, - discount: float = 0.99, - batch_size: int = 256, - prefetch_size: int = 4, - target_update_period: int = 100, - policy_optimizer: Optional[snt.Optimizer] = None, - critic_optimizer: Optional[snt.Optimizer] = None, - min_replay_size: int = 1000, - max_replay_size: int = 1000000, - samples_per_insert: float = 32.0, - n_step: int = 5, - sigma: float = 0.3, - clipping: bool = True, - replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - checkpoint: bool = True, - ): - """Initialize the agent. + def __init__( + self, + environment_spec: specs.EnvironmentSpec, + policy_network: snt.Module, + critic_network: snt.Module, + observation_network: types.TensorTransformation = tf.identity, + accelerator: Optional[str] = None, + discount: float = 0.99, + batch_size: int = 256, + prefetch_size: int = 4, + target_update_period: int = 100, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: float = 32.0, + n_step: int = 5, + sigma: float = 0.3, + clipping: bool = True, + replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = True, + ): + """Initialize the agent. Args: environment_spec: description of the actions, observations, etc. @@ -310,81 +299,84 @@ def __init__( logger: logger object to be used by learner. checkpoint: boolean indicating whether to checkpoint the learner. """ - if not accelerator: - accelerator = _get_first_available_accelerator_type(['TPU', 'GPU', 'CPU']) - - # Create the Builder object which will internally create agent components. - builder = D4PGBuilder( - # TODO(mwhoffman): pass the config dataclass in directly. - # TODO(mwhoffman): use the limiter rather than the workaround below. - # Right now this modifies min_replay_size and samples_per_insert so that - # they are not controlled by a limiter and are instead handled by the - # Agent base class (the above TODO directly references this behavior). 
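# A worked example of the workaround described above, assuming the default
# arguments (batch_size=256, min_replay_size=1000, samples_per_insert=32.0):
# the Agent base class is constructed with
#
#     min_observations      = max(256, 1000) = 1000
#     observations_per_step = 256 / 32.0     = 8.0
#
# i.e. learning starts once 1000 observations have been collected, after which
# one learner step is taken per 8 environment steps, preserving the intended
# ratio of 32 samples per insert (8 inserts * 32 samples = one batch of 256).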
- D4PGConfig( - accelerator=accelerator, - discount=discount, - batch_size=batch_size, - prefetch_size=prefetch_size, - target_update_period=target_update_period, - policy_optimizer=policy_optimizer, - critic_optimizer=critic_optimizer, - min_replay_size=1, # Let the Agent class handle this. - max_replay_size=max_replay_size, - samples_per_insert=None, # Let the Agent class handle this. - n_step=n_step, - sigma=sigma, - clipping=clipping, - replay_table_name=replay_table_name, - )) - - replicator = get_replicator(accelerator) - - with replicator.scope(): - # TODO(mwhoffman): pass the network dataclass in directly. - online_networks = D4PGNetworks(policy_network=policy_network, - critic_network=critic_network, - observation_network=observation_network) - - # Target networks are just a copy of the online networks. - target_networks = copy.deepcopy(online_networks) - - # Initialize the networks. - online_networks.init(environment_spec) - target_networks.init(environment_spec) - - # TODO(mwhoffman): either make this Dataclass or pass only one struct. - # The network struct passed to make_learner is just a tuple for the - # time-being (for backwards compatibility). - networks = (online_networks, target_networks) - - # Create the behavior policy. - policy_network = online_networks.make_policy(environment_spec, sigma) - - # Create the replay server and grab its address. - replay_tables = builder.make_replay_tables(environment_spec) - replay_server = reverb.Server(replay_tables, port=None) - replay_client = reverb.Client(f'localhost:{replay_server.port}') - - # Create actor, dataset, and learner for generating, storing, and consuming - # data respectively. - adder = builder.make_adder(replay_client) - actor = builder.make_actor(policy_network, adder) - dataset = builder.make_dataset_iterator(replay_client) - learner = builder.make_learner(networks, dataset, counter, logger, - checkpoint) - - super().__init__( - actor=actor, - learner=learner, - min_observations=max(batch_size, min_replay_size), - observations_per_step=float(batch_size) / samples_per_insert) - - # Save the replay so we don't garbage collect it. - self._replay_server = replay_server + if not accelerator: + accelerator = _get_first_available_accelerator_type(["TPU", "GPU", "CPU"]) + + # Create the Builder object which will internally create agent components. + builder = D4PGBuilder( + # TODO(mwhoffman): pass the config dataclass in directly. + # TODO(mwhoffman): use the limiter rather than the workaround below. + # Right now this modifies min_replay_size and samples_per_insert so that + # they are not controlled by a limiter and are instead handled by the + # Agent base class (the above TODO directly references this behavior). + D4PGConfig( + accelerator=accelerator, + discount=discount, + batch_size=batch_size, + prefetch_size=prefetch_size, + target_update_period=target_update_period, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + min_replay_size=1, # Let the Agent class handle this. + max_replay_size=max_replay_size, + samples_per_insert=None, # Let the Agent class handle this. + n_step=n_step, + sigma=sigma, + clipping=clipping, + replay_table_name=replay_table_name, + ) + ) + + replicator = get_replicator(accelerator) + + with replicator.scope(): + # TODO(mwhoffman): pass the network dataclass in directly. 
+ online_networks = D4PGNetworks( + policy_network=policy_network, + critic_network=critic_network, + observation_network=observation_network, + ) + + # Target networks are just a copy of the online networks. + target_networks = copy.deepcopy(online_networks) + + # Initialize the networks. + online_networks.init(environment_spec) + target_networks.init(environment_spec) + + # TODO(mwhoffman): either make this Dataclass or pass only one struct. + # The network struct passed to make_learner is just a tuple for the + # time-being (for backwards compatibility). + networks = (online_networks, target_networks) + + # Create the behavior policy. + policy_network = online_networks.make_policy(environment_spec, sigma) + + # Create the replay server and grab its address. + replay_tables = builder.make_replay_tables(environment_spec) + replay_server = reverb.Server(replay_tables, port=None) + replay_client = reverb.Client(f"localhost:{replay_server.port}") + + # Create actor, dataset, and learner for generating, storing, and consuming + # data respectively. + adder = builder.make_adder(replay_client) + actor = builder.make_actor(policy_network, adder) + dataset = builder.make_dataset_iterator(replay_client) + learner = builder.make_learner(networks, dataset, counter, logger, checkpoint) + + super().__init__( + actor=actor, + learner=learner, + min_observations=max(batch_size, min_replay_size), + observations_per_step=float(batch_size) / samples_per_insert, + ) + + # Save the replay so we don't garbage collect it. + self._replay_server = replay_server def _ensure_accelerator(accelerator: str) -> str: - """Checks for the existence of the expected accelerator type. + """Checks for the existence of the expected accelerator type. Args: accelerator: 'CPU', 'GPU' or 'TPU'. @@ -395,20 +387,23 @@ def _ensure_accelerator(accelerator: str) -> str: Raises: RuntimeError: Thrown if the expected accelerator isn't found. """ - devices = tf.config.get_visible_devices(device_type=accelerator) + devices = tf.config.get_visible_devices(device_type=accelerator) - if devices: - return accelerator - else: - error_messages = [f'Couldn\'t find any {accelerator} devices.', - 'tf.config.get_visible_devices() returned:'] - error_messages.extend([str(d) for d in devices]) - raise RuntimeError('\n'.join(error_messages)) + if devices: + return accelerator + else: + error_messages = [ + f"Couldn't find any {accelerator} devices.", + "tf.config.get_visible_devices() returned:", + ] + error_messages.extend([str(d) for d in devices]) + raise RuntimeError("\n".join(error_messages)) def _get_first_available_accelerator_type( - wishlist: Sequence[str] = ('TPU', 'GPU', 'CPU')) -> str: - """Returns the first available accelerator type listed in a wishlist. + wishlist: Sequence[str] = ("TPU", "GPU", "CPU") +) -> str: + """Returns the first available accelerator type listed in a wishlist. Args: wishlist: A sequence of elements from {'CPU', 'GPU', 'TPU'}, listed in @@ -420,25 +415,25 @@ def _get_first_available_accelerator_type( Raises: RuntimeError: Thrown if no accelerators from the `wishlist` are found. 
""" - get_visible_devices = tf.config.get_visible_devices + get_visible_devices = tf.config.get_visible_devices - for wishlist_device in wishlist: - devices = get_visible_devices(device_type=wishlist_device) - if devices: - return wishlist_device + for wishlist_device in wishlist: + devices = get_visible_devices(device_type=wishlist_device) + if devices: + return wishlist_device - available = ', '.join( - sorted(frozenset([d.type for d in get_visible_devices()]))) - raise RuntimeError( - 'Couldn\'t find any devices from {wishlist}.' + - f'Only the following types are available: {available}.') + available = ", ".join(sorted(frozenset([d.type for d in get_visible_devices()]))) + raise RuntimeError( + "Couldn't find any devices from {wishlist}." + + f"Only the following types are available: {available}." + ) # Only instantiate one replicator per (process, accelerator type), in case # a replicator stores state that needs to be carried between its method calls. @functools.lru_cache() def get_replicator(accelerator: Optional[str]) -> Replicator: - """Returns a replicator instance appropriate for the given accelerator. + """Returns a replicator instance appropriate for the given accelerator. This caches the instance using functools.cache, so that only one replicator is instantiated per process and argument value. @@ -451,13 +446,13 @@ def get_replicator(accelerator: Optional[str]) -> Replicator: A replicator, for replciating weights, datasets, and updates across one or more accelerators. """ - if accelerator: - accelerator = _ensure_accelerator(accelerator) - else: - accelerator = _get_first_available_accelerator_type() - - if accelerator == 'TPU': - tf.tpu.experimental.initialize_tpu_system() - return snt.distribute.TpuReplicator() - else: - return snt.distribute.Replicator() + if accelerator: + accelerator = _ensure_accelerator(accelerator) + else: + accelerator = _get_first_available_accelerator_type() + + if accelerator == "TPU": + tf.tpu.experimental.initialize_tpu_system() + return snt.distribute.TpuReplicator() + else: + return snt.distribute.Replicator() diff --git a/acme/agents/tf/d4pg/agent_distributed.py b/acme/agents/tf/d4pg/agent_distributed.py index 8ac8ff3bf4..2e24421e9a 100644 --- a/acme/agents/tf/d4pg/agent_distributed.py +++ b/acme/agents/tf/d4pg/agent_distributed.py @@ -17,252 +17,252 @@ import copy from typing import Callable, Dict, Optional -import acme -from acme import specs -from acme.agents.tf.d4pg import agent -from acme.tf import savers as tf2_savers -from acme.utils import counting -from acme.utils import loggers -from acme.utils import lp_utils import dm_env import launchpad as lp import reverb import sonnet as snt import tensorflow as tf +import acme +from acme import specs +from acme.agents.tf.d4pg import agent +from acme.tf import savers as tf2_savers +from acme.utils import counting, loggers, lp_utils + # Valid values of the "accelerator" argument. 
-_ACCELERATORS = ('CPU', 'GPU', 'TPU') +_ACCELERATORS = ("CPU", "GPU", "TPU") class DistributedD4PG: - """Program definition for D4PG.""" - - def __init__( - self, - environment_factory: Callable[[bool], dm_env.Environment], - network_factory: Callable[[specs.BoundedArray], Dict[str, snt.Module]], - accelerator: Optional[str] = None, - num_actors: int = 1, - num_caches: int = 0, - environment_spec: Optional[specs.EnvironmentSpec] = None, - batch_size: int = 256, - prefetch_size: int = 4, - min_replay_size: int = 1000, - max_replay_size: int = 1000000, - samples_per_insert: Optional[float] = 32.0, - n_step: int = 5, - sigma: float = 0.3, - clipping: bool = True, - discount: float = 0.99, - policy_optimizer: Optional[snt.Optimizer] = None, - critic_optimizer: Optional[snt.Optimizer] = None, - target_update_period: int = 100, - variable_update_period: int = 1000, - max_actor_steps: Optional[int] = None, - log_every: float = 10.0, - ): - - if accelerator is not None and accelerator not in _ACCELERATORS: - raise ValueError(f'Accelerator must be one of {_ACCELERATORS}, ' - f'not "{accelerator}".') - - if not environment_spec: - environment_spec = specs.make_environment_spec(environment_factory(False)) - - # TODO(mwhoffman): Make network_factory directly return the struct. - # TODO(mwhoffman): Make the factory take the entire spec. - def wrapped_network_factory(action_spec): - networks_dict = network_factory(action_spec) - networks = agent.D4PGNetworks( - policy_network=networks_dict.get('policy'), - critic_network=networks_dict.get('critic'), - observation_network=networks_dict.get('observation', tf.identity)) - return networks - - self._environment_factory = environment_factory - self._network_factory = wrapped_network_factory - self._environment_spec = environment_spec - self._sigma = sigma - self._num_actors = num_actors - self._num_caches = num_caches - self._max_actor_steps = max_actor_steps - self._log_every = log_every - self._accelerator = accelerator - self._variable_update_period = variable_update_period - - self._builder = agent.D4PGBuilder( - # TODO(mwhoffman): pass the config dataclass in directly. - # TODO(mwhoffman): use the limiter rather than the workaround below. - agent.D4PGConfig( - accelerator=accelerator, - discount=discount, - batch_size=batch_size, - prefetch_size=prefetch_size, - target_update_period=target_update_period, - variable_update_period=variable_update_period, - policy_optimizer=policy_optimizer, - critic_optimizer=critic_optimizer, - min_replay_size=min_replay_size, - max_replay_size=max_replay_size, - samples_per_insert=samples_per_insert, - n_step=n_step, - sigma=sigma, - clipping=clipping, - )) - - def replay(self): - """The replay storage.""" - return self._builder.make_replay_tables(self._environment_spec) - - def counter(self): - return tf2_savers.CheckpointingRunner(counting.Counter(), - time_delta_minutes=1, - subdirectory='counter') - - def coordinator(self, counter: counting.Counter): - return lp_utils.StepsLimiter(counter, self._max_actor_steps) - - def learner( - self, - replay: reverb.Client, - counter: counting.Counter, - ): - """The Learning part of the agent.""" - - # If we are running on multiple accelerator devices, this replicates - # weights and updates across devices. - replicator = agent.get_replicator(self._accelerator) - - with replicator.scope(): - # Create the networks to optimize (online) and target networks. 
- online_networks = self._network_factory(self._environment_spec.actions) - target_networks = copy.deepcopy(online_networks) - - # Initialize the networks. - online_networks.init(self._environment_spec) - target_networks.init(self._environment_spec) - - dataset = self._builder.make_dataset_iterator(replay) - - counter = counting.Counter(counter, 'learner') - logger = loggers.make_default_logger( - 'learner', time_delta=self._log_every, steps_key='learner_steps') - - return self._builder.make_learner( - networks=(online_networks, target_networks), - dataset=dataset, - counter=counter, - logger=logger, - checkpoint=True, - ) - - def actor( - self, - replay: reverb.Client, - variable_source: acme.VariableSource, - counter: counting.Counter, - ) -> acme.EnvironmentLoop: - """The actor process.""" - - # Create the behavior policy. - networks = self._network_factory(self._environment_spec.actions) - networks.init(self._environment_spec) - policy_network = networks.make_policy( - environment_spec=self._environment_spec, - sigma=self._sigma, - ) - - # Create the agent. - actor = self._builder.make_actor( - policy_network=policy_network, - adder=self._builder.make_adder(replay), - variable_source=variable_source, - ) - - # Create the environment. - environment = self._environment_factory(False) - - # Create logger and counter; actors will not spam bigtable. - counter = counting.Counter(counter, 'actor') - logger = loggers.make_default_logger( - 'actor', - save_data=False, - time_delta=self._log_every, - steps_key='actor_steps') - - # Create the loop to connect environment and agent. - return acme.EnvironmentLoop(environment, actor, counter, logger) - - def evaluator( - self, - variable_source: acme.VariableSource, - counter: counting.Counter, - logger: Optional[loggers.Logger] = None, - ): - """The evaluation process.""" - - # Create the behavior policy. - networks = self._network_factory(self._environment_spec.actions) - networks.init(self._environment_spec) - policy_network = networks.make_policy(self._environment_spec) - - # Create the agent. - actor = self._builder.make_actor( - policy_network=policy_network, - variable_source=variable_source, - ) - - # Make the environment. - environment = self._environment_factory(True) - - # Create logger and counter. - counter = counting.Counter(counter, 'evaluator') - logger = logger or loggers.make_default_logger( - 'evaluator', - time_delta=self._log_every, - steps_key='evaluator_steps', - ) - - # Create the run loop and return it. - return acme.EnvironmentLoop(environment, actor, counter, logger) - - def build(self, name='d4pg'): - """Build the distributed agent topology.""" - program = lp.Program(name=name) - - with program.group('replay'): - replay = program.add_node(lp.ReverbNode(self.replay)) - - with program.group('counter'): - counter = program.add_node(lp.CourierNode(self.counter)) - - if self._max_actor_steps: - with program.group('coordinator'): - _ = program.add_node(lp.CourierNode(self.coordinator, counter)) - - with program.group('learner'): - learner = program.add_node(lp.CourierNode(self.learner, replay, counter)) - - with program.group('evaluator'): - program.add_node(lp.CourierNode(self.evaluator, learner, counter)) - - if not self._num_caches: - # Use our learner as a single variable source. - sources = [learner] - else: - with program.group('cacher'): - # Create a set of learner caches. 
- sources = [] - for _ in range(self._num_caches): - cacher = program.add_node( - lp.CacherNode( - learner, refresh_interval_ms=2000, stale_after_ms=4000)) - sources.append(cacher) - - with program.group('actor'): - # Add actors which pull round-robin from our variable sources. - for actor_id in range(self._num_actors): - source = sources[actor_id % len(sources)] - program.add_node(lp.CourierNode(self.actor, replay, source, counter)) - - return program + """Program definition for D4PG.""" + + def __init__( + self, + environment_factory: Callable[[bool], dm_env.Environment], + network_factory: Callable[[specs.BoundedArray], Dict[str, snt.Module]], + accelerator: Optional[str] = None, + num_actors: int = 1, + num_caches: int = 0, + environment_spec: Optional[specs.EnvironmentSpec] = None, + batch_size: int = 256, + prefetch_size: int = 4, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: Optional[float] = 32.0, + n_step: int = 5, + sigma: float = 0.3, + clipping: bool = True, + discount: float = 0.99, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + target_update_period: int = 100, + variable_update_period: int = 1000, + max_actor_steps: Optional[int] = None, + log_every: float = 10.0, + ): + + if accelerator is not None and accelerator not in _ACCELERATORS: + raise ValueError( + f"Accelerator must be one of {_ACCELERATORS}, " f'not "{accelerator}".' + ) + + if not environment_spec: + environment_spec = specs.make_environment_spec(environment_factory(False)) + + # TODO(mwhoffman): Make network_factory directly return the struct. + # TODO(mwhoffman): Make the factory take the entire spec. + def wrapped_network_factory(action_spec): + networks_dict = network_factory(action_spec) + networks = agent.D4PGNetworks( + policy_network=networks_dict.get("policy"), + critic_network=networks_dict.get("critic"), + observation_network=networks_dict.get("observation", tf.identity), + ) + return networks + + self._environment_factory = environment_factory + self._network_factory = wrapped_network_factory + self._environment_spec = environment_spec + self._sigma = sigma + self._num_actors = num_actors + self._num_caches = num_caches + self._max_actor_steps = max_actor_steps + self._log_every = log_every + self._accelerator = accelerator + self._variable_update_period = variable_update_period + + self._builder = agent.D4PGBuilder( + # TODO(mwhoffman): pass the config dataclass in directly. + # TODO(mwhoffman): use the limiter rather than the workaround below. 
+ agent.D4PGConfig( + accelerator=accelerator, + discount=discount, + batch_size=batch_size, + prefetch_size=prefetch_size, + target_update_period=target_update_period, + variable_update_period=variable_update_period, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + min_replay_size=min_replay_size, + max_replay_size=max_replay_size, + samples_per_insert=samples_per_insert, + n_step=n_step, + sigma=sigma, + clipping=clipping, + ) + ) + + def replay(self): + """The replay storage.""" + return self._builder.make_replay_tables(self._environment_spec) + + def counter(self): + return tf2_savers.CheckpointingRunner( + counting.Counter(), time_delta_minutes=1, subdirectory="counter" + ) + + def coordinator(self, counter: counting.Counter): + return lp_utils.StepsLimiter(counter, self._max_actor_steps) + + def learner( + self, replay: reverb.Client, counter: counting.Counter, + ): + """The Learning part of the agent.""" + + # If we are running on multiple accelerator devices, this replicates + # weights and updates across devices. + replicator = agent.get_replicator(self._accelerator) + + with replicator.scope(): + # Create the networks to optimize (online) and target networks. + online_networks = self._network_factory(self._environment_spec.actions) + target_networks = copy.deepcopy(online_networks) + + # Initialize the networks. + online_networks.init(self._environment_spec) + target_networks.init(self._environment_spec) + + dataset = self._builder.make_dataset_iterator(replay) + + counter = counting.Counter(counter, "learner") + logger = loggers.make_default_logger( + "learner", time_delta=self._log_every, steps_key="learner_steps" + ) + + return self._builder.make_learner( + networks=(online_networks, target_networks), + dataset=dataset, + counter=counter, + logger=logger, + checkpoint=True, + ) + + def actor( + self, + replay: reverb.Client, + variable_source: acme.VariableSource, + counter: counting.Counter, + ) -> acme.EnvironmentLoop: + """The actor process.""" + + # Create the behavior policy. + networks = self._network_factory(self._environment_spec.actions) + networks.init(self._environment_spec) + policy_network = networks.make_policy( + environment_spec=self._environment_spec, sigma=self._sigma, + ) + + # Create the agent. + actor = self._builder.make_actor( + policy_network=policy_network, + adder=self._builder.make_adder(replay), + variable_source=variable_source, + ) + + # Create the environment. + environment = self._environment_factory(False) + + # Create logger and counter; actors will not spam bigtable. + counter = counting.Counter(counter, "actor") + logger = loggers.make_default_logger( + "actor", + save_data=False, + time_delta=self._log_every, + steps_key="actor_steps", + ) + + # Create the loop to connect environment and agent. + return acme.EnvironmentLoop(environment, actor, counter, logger) + + def evaluator( + self, + variable_source: acme.VariableSource, + counter: counting.Counter, + logger: Optional[loggers.Logger] = None, + ): + """The evaluation process.""" + + # Create the behavior policy. + networks = self._network_factory(self._environment_spec.actions) + networks.init(self._environment_spec) + policy_network = networks.make_policy(self._environment_spec) + + # Create the agent. + actor = self._builder.make_actor( + policy_network=policy_network, variable_source=variable_source, + ) + + # Make the environment. + environment = self._environment_factory(True) + + # Create logger and counter. 
+ counter = counting.Counter(counter, "evaluator") + logger = logger or loggers.make_default_logger( + "evaluator", time_delta=self._log_every, steps_key="evaluator_steps", + ) + + # Create the run loop and return it. + return acme.EnvironmentLoop(environment, actor, counter, logger) + + def build(self, name="d4pg"): + """Build the distributed agent topology.""" + program = lp.Program(name=name) + + with program.group("replay"): + replay = program.add_node(lp.ReverbNode(self.replay)) + + with program.group("counter"): + counter = program.add_node(lp.CourierNode(self.counter)) + + if self._max_actor_steps: + with program.group("coordinator"): + _ = program.add_node(lp.CourierNode(self.coordinator, counter)) + + with program.group("learner"): + learner = program.add_node(lp.CourierNode(self.learner, replay, counter)) + + with program.group("evaluator"): + program.add_node(lp.CourierNode(self.evaluator, learner, counter)) + + if not self._num_caches: + # Use our learner as a single variable source. + sources = [learner] + else: + with program.group("cacher"): + # Create a set of learner caches. + sources = [] + for _ in range(self._num_caches): + cacher = program.add_node( + lp.CacherNode( + learner, refresh_interval_ms=2000, stale_after_ms=4000 + ) + ) + sources.append(cacher) + + with program.group("actor"): + # Add actors which pull round-robin from our variable sources. + for actor_id in range(self._num_actors): + source = sources[actor_id % len(sources)] + program.add_node(lp.CourierNode(self.actor, replay, source, counter)) + + return program diff --git a/acme/agents/tf/d4pg/agent_distributed_test.py b/acme/agents/tf/d4pg/agent_distributed_test.py index cc71b8908a..a72ba4e10e 100644 --- a/acme/agents/tf/d4pg/agent_distributed_test.py +++ b/acme/agents/tf/d4pg/agent_distributed_test.py @@ -14,71 +14,75 @@ """Integration test for the distributed agent.""" +import launchpad as lp +import numpy as np +import sonnet as snt +from absl.testing import absltest + import acme from acme import specs from acme.agents.tf import d4pg from acme.testing import fakes from acme.tf import networks from acme.tf import utils as tf2_utils -import launchpad as lp -import numpy as np -import sonnet as snt - -from absl.testing import absltest def make_networks(action_spec: specs.BoundedArray): - """Simple networks for testing..""" - - num_dimensions = np.prod(action_spec.shape, dtype=int) - - policy_network = snt.Sequential([ - networks.LayerNormMLP([50], activate_final=True), - networks.NearZeroInitializedLinear(num_dimensions), - networks.TanhToSpec(action_spec) - ]) - # The multiplexer concatenates the (maybe transformed) observations/actions. 
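# A minimal launch sketch for the program assembled by build(), modelled on the
# integration test that follows (the environment factory is assumed to be
# user-provided; the test passes launch_type="test_mt" explicitly):
#
#     agent = DistributedD4PG(
#         environment_factory=my_environment_factory,  # assumed factory
#         network_factory=make_networks,
#         num_actors=2)
#     program = agent.build()
#     lp.launch(program)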
- critic_network = snt.Sequential([ - networks.CriticMultiplexer( - critic_network=networks.LayerNormMLP( - [50], activate_final=True)), - networks.DiscreteValuedHead(-1., 1., 10) - ]) - - return { - 'policy': policy_network, - 'critic': critic_network, - 'observation': tf2_utils.batch_concat, - } + """Simple networks for testing..""" + num_dimensions = np.prod(action_spec.shape, dtype=int) -class DistributedAgentTest(absltest.TestCase): - """Simple integration/smoke test for the distributed agent.""" - - def test_control_suite(self): - """Tests that the agent can run on the control suite without crashing.""" - - agent = d4pg.DistributedD4PG( - environment_factory=lambda x: fakes.ContinuousEnvironment(bounded=True), - network_factory=make_networks, - accelerator='CPU', - num_actors=2, - batch_size=32, - min_replay_size=32, - max_replay_size=1000, + policy_network = snt.Sequential( + [ + networks.LayerNormMLP([50], activate_final=True), + networks.NearZeroInitializedLinear(num_dimensions), + networks.TanhToSpec(action_spec), + ] + ) + # The multiplexer concatenates the (maybe transformed) observations/actions. + critic_network = snt.Sequential( + [ + networks.CriticMultiplexer( + critic_network=networks.LayerNormMLP([50], activate_final=True) + ), + networks.DiscreteValuedHead(-1.0, 1.0, 10), + ] ) - program = agent.build() - (learner_node,) = program.groups['learner'] - learner_node.disable_run() + return { + "policy": policy_network, + "critic": critic_network, + "observation": tf2_utils.batch_concat, + } + + +class DistributedAgentTest(absltest.TestCase): + """Simple integration/smoke test for the distributed agent.""" + + def test_control_suite(self): + """Tests that the agent can run on the control suite without crashing.""" + + agent = d4pg.DistributedD4PG( + environment_factory=lambda x: fakes.ContinuousEnvironment(bounded=True), + network_factory=make_networks, + accelerator="CPU", + num_actors=2, + batch_size=32, + min_replay_size=32, + max_replay_size=1000, + ) + program = agent.build() + + (learner_node,) = program.groups["learner"] + learner_node.disable_run() - lp.launch(program, launch_type='test_mt') + lp.launch(program, launch_type="test_mt") - learner: acme.Learner = learner_node.create_handle().dereference() + learner: acme.Learner = learner_node.create_handle().dereference() - for _ in range(5): - learner.step() + for _ in range(5): + learner.step() -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/tf/d4pg/agent_test.py b/acme/agents/tf/d4pg/agent_test.py index 10b89b9f94..d51059d110 100644 --- a/acme/agents/tf/d4pg/agent_test.py +++ b/acme/agents/tf/d4pg/agent_test.py @@ -17,75 +17,78 @@ import sys from typing import Dict, Sequence -import acme -from acme import specs -from acme import types -from acme.agents.tf import d4pg -from acme.testing import fakes -from acme.tf import networks import numpy as np import sonnet as snt import tensorflow as tf - from absl.testing import absltest +import acme +from acme import specs, types +from acme.agents.tf import d4pg +from acme.testing import fakes +from acme.tf import networks + def make_networks( action_spec: types.NestedSpec, policy_layer_sizes: Sequence[int] = (10, 10), critic_layer_sizes: Sequence[int] = (10, 10), - vmin: float = -150., - vmax: float = 150., + vmin: float = -150.0, + vmax: float = 150.0, num_atoms: int = 51, ) -> Dict[str, snt.Module]: - """Creates networks used by the agent.""" - - num_dimensions = np.prod(action_spec.shape, dtype=int) - 
policy_layer_sizes = list(policy_layer_sizes) + [num_dimensions] - - policy_network = snt.Sequential( - [networks.LayerNormMLP(policy_layer_sizes), tf.tanh]) - critic_network = snt.Sequential([ - networks.CriticMultiplexer( - critic_network=networks.LayerNormMLP( - critic_layer_sizes, activate_final=True)), - networks.DiscreteValuedHead(vmin, vmax, num_atoms) - ]) + """Creates networks used by the agent.""" - return { - 'policy': policy_network, - 'critic': critic_network, - } + num_dimensions = np.prod(action_spec.shape, dtype=int) + policy_layer_sizes = list(policy_layer_sizes) + [num_dimensions] - -class D4PGTest(absltest.TestCase): - - def test_d4pg(self): - # Create a fake environment to test with. - environment = fakes.ContinuousEnvironment(episode_length=10, bounded=True) - spec = specs.make_environment_spec(environment) - - # Create the networks. - agent_networks = make_networks(spec.actions) - - # Construct the agent. - agent = d4pg.D4PG( - environment_spec=spec, - accelerator='CPU', - policy_network=agent_networks['policy'], - critic_network=agent_networks['critic'], - batch_size=10, - samples_per_insert=2, - min_replay_size=10, + policy_network = snt.Sequential( + [networks.LayerNormMLP(policy_layer_sizes), tf.tanh] + ) + critic_network = snt.Sequential( + [ + networks.CriticMultiplexer( + critic_network=networks.LayerNormMLP( + critic_layer_sizes, activate_final=True + ) + ), + networks.DiscreteValuedHead(vmin, vmax, num_atoms), + ] ) - # Try running the environment loop. We have no assertions here because all - # we care about is that the agent runs without raising any errors. - loop = acme.EnvironmentLoop(environment, agent) - loop.run(num_episodes=2) - - # Imports check + return { + "policy": policy_network, + "critic": critic_network, + } -if __name__ == '__main__': - absltest.main() +class D4PGTest(absltest.TestCase): + def test_d4pg(self): + # Create a fake environment to test with. + environment = fakes.ContinuousEnvironment(episode_length=10, bounded=True) + spec = specs.make_environment_spec(environment) + + # Create the networks. + agent_networks = make_networks(spec.actions) + + # Construct the agent. + agent = d4pg.D4PG( + environment_spec=spec, + accelerator="CPU", + policy_network=agent_networks["policy"], + critic_network=agent_networks["critic"], + batch_size=10, + samples_per_insert=2, + min_replay_size=10, + ) + + # Try running the environment loop. We have no assertions here because all + # we care about is that the agent runs without raising any errors. 
+ loop = acme.EnvironmentLoop(environment, agent) + loop.run(num_episodes=2) + + # Imports check + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/tf/d4pg/learning.py b/acme/agents/tf/d4pg/learning.py index b167fb982d..1959452936 100644 --- a/acme/agents/tf/d4pg/learning.py +++ b/acme/agents/tf/d4pg/learning.py @@ -15,7 +15,13 @@ """D4PG learner implementation.""" import time -from typing import Dict, Iterator, List, Optional, Union, Sequence +from typing import Dict, Iterator, List, Optional, Sequence, Union + +import numpy as np +import reverb +import sonnet as snt +import tensorflow as tf +import tree import acme from acme import types @@ -23,44 +29,38 @@ from acme.tf import networks as acme_nets from acme.tf import savers as tf2_savers from acme.tf import utils as tf2_utils -from acme.utils import counting -from acme.utils import loggers -import numpy as np -import reverb -import sonnet as snt -import tensorflow as tf -import tree +from acme.utils import counting, loggers Replicator = Union[snt.distribute.Replicator, snt.distribute.TpuReplicator] class D4PGLearner(acme.Learner): - """D4PG learner. + """D4PG learner. This is the learning component of a D4PG agent. IE it takes a dataset as input and implements update functionality to learn from this dataset. """ - def __init__( - self, - policy_network: snt.Module, - critic_network: snt.Module, - target_policy_network: snt.Module, - target_critic_network: snt.Module, - discount: float, - target_update_period: int, - dataset_iterator: Iterator[reverb.ReplaySample], - replicator: Optional[Replicator] = None, - observation_network: types.TensorTransformation = lambda x: x, - target_observation_network: types.TensorTransformation = lambda x: x, - policy_optimizer: Optional[snt.Optimizer] = None, - critic_optimizer: Optional[snt.Optimizer] = None, - clipping: bool = True, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - checkpoint: bool = True, - ): - """Initializes the learner. + def __init__( + self, + policy_network: snt.Module, + critic_network: snt.Module, + target_policy_network: snt.Module, + target_critic_network: snt.Module, + discount: float, + target_update_period: int, + dataset_iterator: Iterator[reverb.ReplaySample], + replicator: Optional[Replicator] = None, + observation_network: types.TensorTransformation = lambda x: x, + target_observation_network: types.TensorTransformation = lambda x: x, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + clipping: bool = True, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = True, + ): + """Initializes the learner. Args: policy_network: the online (optimized) policy. @@ -87,232 +87,237 @@ def __init__( checkpoint: boolean indicating whether to checkpoint the learner. """ - # Store online and target networks. - self._policy_network = policy_network - self._critic_network = critic_network - self._target_policy_network = target_policy_network - self._target_critic_network = target_critic_network - - # Make sure observation networks are snt.Module's so they have variables. - self._observation_network = tf2_utils.to_sonnet_module(observation_network) - self._target_observation_network = tf2_utils.to_sonnet_module( - target_observation_network) - - # General learner book-keeping and loggers. 
- self._counter = counter or counting.Counter() - self._logger = logger or loggers.make_default_logger('learner') - - # Other learner parameters. - self._discount = discount - self._clipping = clipping - - # Replicates Variables across multiple accelerators - if not replicator: - accelerator = _get_first_available_accelerator_type() - if accelerator == 'TPU': - replicator = snt.distribute.TpuReplicator() - else: - replicator = snt.distribute.Replicator() - - self._replicator = replicator - - with replicator.scope(): - # Necessary to track when to update target networks. - self._num_steps = tf.Variable(0, dtype=tf.int32) - self._target_update_period = target_update_period - - # Create optimizers if they aren't given. - self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) - self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) - - # Batch dataset and create iterator. - self._iterator = dataset_iterator - - # Expose the variables. - policy_network_to_expose = snt.Sequential( - [self._target_observation_network, self._target_policy_network]) - self._variables = { - 'critic': self._target_critic_network.variables, - 'policy': policy_network_to_expose.variables, - } - - # Create a checkpointer and snapshotter objects. - self._checkpointer = None - self._snapshotter = None - - if checkpoint: - self._checkpointer = tf2_savers.Checkpointer( - subdirectory='d4pg_learner', - objects_to_save={ - 'counter': self._counter, - 'policy': self._policy_network, - 'critic': self._critic_network, - 'observation': self._observation_network, - 'target_policy': self._target_policy_network, - 'target_critic': self._target_critic_network, - 'target_observation': self._target_observation_network, - 'policy_optimizer': self._policy_optimizer, - 'critic_optimizer': self._critic_optimizer, - 'num_steps': self._num_steps, - }) - critic_mean = snt.Sequential( - [self._critic_network, acme_nets.StochasticMeanHead()]) - self._snapshotter = tf2_savers.Snapshotter( - objects_to_save={ - 'policy': self._policy_network, - 'critic': critic_mean, - }) - - # Do not record timestamps until after the first learning step is done. - # This is to avoid including the time it takes for actors to come online and - # fill the replay buffer. - self._timestamp = None - - @tf.function - def _step(self, sample) -> Dict[str, tf.Tensor]: - transitions: types.Transition = sample.data # Assuming ReverbSample. - - # Cast the additional discount to match the environment discount dtype. - discount = tf.cast(self._discount, dtype=transitions.discount.dtype) - - with tf.GradientTape(persistent=True) as tape: - # Maybe transform the observation before feeding into policy and critic. - # Transforming the observations this way at the start of the learning - # step effectively means that the policy and critic share observation - # network weights. - o_tm1 = self._observation_network(transitions.observation) - o_t = self._target_observation_network(transitions.next_observation) - # This stop_gradient prevents gradients to propagate into the target - # observation network. In addition, since the online policy network is - # evaluated at o_t, this also means the policy loss does not influence - # the observation network training. - o_t = tree.map_structure(tf.stop_gradient, o_t) - - # Critic learning. - q_tm1 = self._critic_network(o_tm1, transitions.action) - q_t = self._target_critic_network(o_t, self._target_policy_network(o_t)) - - # Critic loss. 
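# Note on the critic loss computed below: the critic heads are distributional
# (DiscreteValuedHead(vmin, vmax, num_atoms) in networks.py), so q_tm1 and q_t
# are value distributions over a fixed support. losses.categorical implements
# the categorical (C51-style) TD loss: the target r_t + (gamma * d_t) * Z(o_t, pi(o_t))
# is projected onto that support and the cross-entropy against Z(o_tm1, a_tm1)
# is minimised.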
- critic_loss = losses.categorical(q_tm1, transitions.reward, - discount * transitions.discount, q_t) - critic_loss = tf.reduce_mean(critic_loss, axis=[0]) - - # Actor learning. - dpg_a_t = self._policy_network(o_t) - dpg_z_t = self._critic_network(o_t, dpg_a_t) - dpg_q_t = dpg_z_t.mean() - - # Actor loss. If clipping is true use dqda clipping and clip the norm. - dqda_clipping = 1.0 if self._clipping else None - policy_loss = losses.dpg( - dpg_q_t, - dpg_a_t, - tape=tape, - dqda_clipping=dqda_clipping, - clip_norm=self._clipping) - policy_loss = tf.reduce_mean(policy_loss, axis=[0]) - - # Get trainable variables. - policy_variables = self._policy_network.trainable_variables - critic_variables = ( - # In this agent, the critic loss trains the observation network. - self._observation_network.trainable_variables + - self._critic_network.trainable_variables) - - # Compute gradients. - replica_context = tf.distribute.get_replica_context() - policy_gradients = _average_gradients_across_replicas( - replica_context, - tape.gradient(policy_loss, policy_variables)) - critic_gradients = _average_gradients_across_replicas( - replica_context, - tape.gradient(critic_loss, critic_variables)) - - # Delete the tape manually because of the persistent=True flag. - del tape - - # Maybe clip gradients. - if self._clipping: - policy_gradients = tf.clip_by_global_norm(policy_gradients, 40.)[0] - critic_gradients = tf.clip_by_global_norm(critic_gradients, 40.)[0] - - # Apply gradients. - self._policy_optimizer.apply(policy_gradients, policy_variables) - self._critic_optimizer.apply(critic_gradients, critic_variables) - - # Losses to track. - return { - 'critic_loss': critic_loss, - 'policy_loss': policy_loss, - } - - @tf.function - def _replicated_step(self): - # Update target network - online_variables = ( - *self._observation_network.variables, - *self._critic_network.variables, - *self._policy_network.variables, - ) - target_variables = ( - *self._target_observation_network.variables, - *self._target_critic_network.variables, - *self._target_policy_network.variables, - ) - - # Make online -> target network update ops. - if tf.math.mod(self._num_steps, self._target_update_period) == 0: - for src, dest in zip(online_variables, target_variables): - dest.assign(src) - self._num_steps.assign_add(1) - - # Get data from replay (dropping extras if any). Note there is no - # extra data here because we do not insert any into Reverb. - sample = next(self._iterator) - - # This mirrors the structure of the fetches returned by self._step(), - # but the Tensors are replaced with replicated Tensors, one per accelerator. - replicated_fetches = self._replicator.run(self._step, args=(sample,)) - - def reduce_mean_over_replicas(replicated_value): - """Averages a replicated_value across replicas.""" - # The "axis=None" arg means reduce across replicas, not internal axes. - return self._replicator.reduce( - reduce_op=tf.distribute.ReduceOp.MEAN, - value=replicated_value, - axis=None) - - fetches = tree.map_structure(reduce_mean_over_replicas, replicated_fetches) - - return fetches - - def step(self): - # Run the learning step. - fetches = self._replicated_step() - - # Compute elapsed time. - timestamp = time.time() - elapsed_time = timestamp - self._timestamp if self._timestamp else 0 - self._timestamp = timestamp - - # Update our counts and record it. - counts = self._counter.increment(steps=1, walltime=elapsed_time) - fetches.update(counts) - - # Checkpoint and attempt to write the logs. 
- if self._checkpointer is not None: - self._checkpointer.save() - if self._snapshotter is not None: - self._snapshotter.save() - self._logger.write(fetches) - - def get_variables(self, names: List[str]) -> List[List[np.ndarray]]: - return [tf2_utils.to_numpy(self._variables[name]) for name in names] + # Store online and target networks. + self._policy_network = policy_network + self._critic_network = critic_network + self._target_policy_network = target_policy_network + self._target_critic_network = target_critic_network + + # Make sure observation networks are snt.Module's so they have variables. + self._observation_network = tf2_utils.to_sonnet_module(observation_network) + self._target_observation_network = tf2_utils.to_sonnet_module( + target_observation_network + ) + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger("learner") + + # Other learner parameters. + self._discount = discount + self._clipping = clipping + + # Replicates Variables across multiple accelerators + if not replicator: + accelerator = _get_first_available_accelerator_type() + if accelerator == "TPU": + replicator = snt.distribute.TpuReplicator() + else: + replicator = snt.distribute.Replicator() + + self._replicator = replicator + + with replicator.scope(): + # Necessary to track when to update target networks. + self._num_steps = tf.Variable(0, dtype=tf.int32) + self._target_update_period = target_update_period + + # Create optimizers if they aren't given. + self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) + self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) + + # Batch dataset and create iterator. + self._iterator = dataset_iterator + + # Expose the variables. + policy_network_to_expose = snt.Sequential( + [self._target_observation_network, self._target_policy_network] + ) + self._variables = { + "critic": self._target_critic_network.variables, + "policy": policy_network_to_expose.variables, + } + + # Create a checkpointer and snapshotter objects. + self._checkpointer = None + self._snapshotter = None + + if checkpoint: + self._checkpointer = tf2_savers.Checkpointer( + subdirectory="d4pg_learner", + objects_to_save={ + "counter": self._counter, + "policy": self._policy_network, + "critic": self._critic_network, + "observation": self._observation_network, + "target_policy": self._target_policy_network, + "target_critic": self._target_critic_network, + "target_observation": self._target_observation_network, + "policy_optimizer": self._policy_optimizer, + "critic_optimizer": self._critic_optimizer, + "num_steps": self._num_steps, + }, + ) + critic_mean = snt.Sequential( + [self._critic_network, acme_nets.StochasticMeanHead()] + ) + self._snapshotter = tf2_savers.Snapshotter( + objects_to_save={"policy": self._policy_network, "critic": critic_mean,} + ) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + @tf.function + def _step(self, sample) -> Dict[str, tf.Tensor]: + transitions: types.Transition = sample.data # Assuming ReverbSample. + + # Cast the additional discount to match the environment discount dtype. + discount = tf.cast(self._discount, dtype=transitions.discount.dtype) + + with tf.GradientTape(persistent=True) as tape: + # Maybe transform the observation before feeding into policy and critic. 
+ # Transforming the observations this way at the start of the learning + # step effectively means that the policy and critic share observation + # network weights. + o_tm1 = self._observation_network(transitions.observation) + o_t = self._target_observation_network(transitions.next_observation) + # This stop_gradient prevents gradients to propagate into the target + # observation network. In addition, since the online policy network is + # evaluated at o_t, this also means the policy loss does not influence + # the observation network training. + o_t = tree.map_structure(tf.stop_gradient, o_t) + + # Critic learning. + q_tm1 = self._critic_network(o_tm1, transitions.action) + q_t = self._target_critic_network(o_t, self._target_policy_network(o_t)) + + # Critic loss. + critic_loss = losses.categorical( + q_tm1, transitions.reward, discount * transitions.discount, q_t + ) + critic_loss = tf.reduce_mean(critic_loss, axis=[0]) + + # Actor learning. + dpg_a_t = self._policy_network(o_t) + dpg_z_t = self._critic_network(o_t, dpg_a_t) + dpg_q_t = dpg_z_t.mean() + + # Actor loss. If clipping is true use dqda clipping and clip the norm. + dqda_clipping = 1.0 if self._clipping else None + policy_loss = losses.dpg( + dpg_q_t, + dpg_a_t, + tape=tape, + dqda_clipping=dqda_clipping, + clip_norm=self._clipping, + ) + policy_loss = tf.reduce_mean(policy_loss, axis=[0]) + + # Get trainable variables. + policy_variables = self._policy_network.trainable_variables + critic_variables = ( + # In this agent, the critic loss trains the observation network. + self._observation_network.trainable_variables + + self._critic_network.trainable_variables + ) + + # Compute gradients. + replica_context = tf.distribute.get_replica_context() + policy_gradients = _average_gradients_across_replicas( + replica_context, tape.gradient(policy_loss, policy_variables) + ) + critic_gradients = _average_gradients_across_replicas( + replica_context, tape.gradient(critic_loss, critic_variables) + ) + + # Delete the tape manually because of the persistent=True flag. + del tape + + # Maybe clip gradients. + if self._clipping: + policy_gradients = tf.clip_by_global_norm(policy_gradients, 40.0)[0] + critic_gradients = tf.clip_by_global_norm(critic_gradients, 40.0)[0] + + # Apply gradients. + self._policy_optimizer.apply(policy_gradients, policy_variables) + self._critic_optimizer.apply(critic_gradients, critic_variables) + + # Losses to track. + return { + "critic_loss": critic_loss, + "policy_loss": policy_loss, + } + + @tf.function + def _replicated_step(self): + # Update target network + online_variables = ( + *self._observation_network.variables, + *self._critic_network.variables, + *self._policy_network.variables, + ) + target_variables = ( + *self._target_observation_network.variables, + *self._target_critic_network.variables, + *self._target_policy_network.variables, + ) + + # Make online -> target network update ops. + if tf.math.mod(self._num_steps, self._target_update_period) == 0: + for src, dest in zip(online_variables, target_variables): + dest.assign(src) + self._num_steps.assign_add(1) + + # Get data from replay (dropping extras if any). Note there is no + # extra data here because we do not insert any into Reverb. + sample = next(self._iterator) + + # This mirrors the structure of the fetches returned by self._step(), + # but the Tensors are replaced with replicated Tensors, one per accelerator. 
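# A self-contained sketch of the replicate-then-reduce pattern used below,
# assuming a plain single-host snt.distribute.Replicator:
#
#     replicator = snt.distribute.Replicator()
#     per_replica = replicator.run(lambda: tf.constant(1.0))  # one value per device
#     averaged = replicator.reduce(
#         tf.distribute.ReduceOp.MEAN, per_replica, axis=None)  # mean across replicas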
+ replicated_fetches = self._replicator.run(self._step, args=(sample,)) + + def reduce_mean_over_replicas(replicated_value): + """Averages a replicated_value across replicas.""" + # The "axis=None" arg means reduce across replicas, not internal axes. + return self._replicator.reduce( + reduce_op=tf.distribute.ReduceOp.MEAN, value=replicated_value, axis=None + ) + + fetches = tree.map_structure(reduce_mean_over_replicas, replicated_fetches) + + return fetches + + def step(self): + # Run the learning step. + fetches = self._replicated_step() + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Update our counts and record it. + counts = self._counter.increment(steps=1, walltime=elapsed_time) + fetches.update(counts) + + # Checkpoint and attempt to write the logs. + if self._checkpointer is not None: + self._checkpointer.save() + if self._snapshotter is not None: + self._snapshotter.save() + self._logger.write(fetches) + + def get_variables(self, names: List[str]) -> List[List[np.ndarray]]: + return [tf2_utils.to_numpy(self._variables[name]) for name in names] def _get_first_available_accelerator_type( - wishlist: Sequence[str] = ('TPU', 'GPU', 'CPU')) -> str: - """Returns the first available accelerator type listed in a wishlist. + wishlist: Sequence[str] = ("TPU", "GPU", "CPU") +) -> str: + """Returns the first available accelerator type listed in a wishlist. Args: wishlist: A sequence of elements from {'CPU', 'GPU', 'TPU'}, listed in @@ -324,22 +329,22 @@ def _get_first_available_accelerator_type( Raises: RuntimeError: Thrown if no accelerators from the `wishlist` are found. """ - get_visible_devices = tf.config.get_visible_devices + get_visible_devices = tf.config.get_visible_devices - for wishlist_device in wishlist: - devices = get_visible_devices(device_type=wishlist_device) - if devices: - return wishlist_device + for wishlist_device in wishlist: + devices = get_visible_devices(device_type=wishlist_device) + if devices: + return wishlist_device - available = ', '.join( - sorted(frozenset([d.type for d in get_visible_devices()]))) - raise RuntimeError( - 'Couldn\'t find any devices from {wishlist}.' + - f'Only the following types are available: {available}.') + available = ", ".join(sorted(frozenset([d.type for d in get_visible_devices()]))) + raise RuntimeError( + "Couldn't find any devices from {wishlist}." + + f"Only the following types are available: {available}." + ) def _average_gradients_across_replicas(replica_context, gradients): - """Computes the average gradient across replicas. + """Computes the average gradient across replicas. This computes the gradient locally on this device, then copies over the gradients computed on the other replicas, and takes the average across @@ -356,17 +361,16 @@ def _average_gradients_across_replicas(replica_context, gradients): A list of (d_loss/d_varabiable)s. """ - # We must remove any Nones from gradients before passing them to all_reduce. - # Nones occur when you call tape.gradient(loss, variables) with some - # variables that don't affect the loss. - # See: https://github.com/tensorflow/tensorflow/issues/783 - gradients_without_nones = [g for g in gradients if g is not None] - original_indices = [i for i, g in enumerate(gradients) if g is not None] + # We must remove any Nones from gradients before passing them to all_reduce. 
+ # Nones occur when you call tape.gradient(loss, variables) with some + # variables that don't affect the loss. + # See: https://github.com/tensorflow/tensorflow/issues/783 + gradients_without_nones = [g for g in gradients if g is not None] + original_indices = [i for i, g in enumerate(gradients) if g is not None] - results_without_nones = replica_context.all_reduce('mean', - gradients_without_nones) - results = [None] * len(gradients) - for ii, result in zip(original_indices, results_without_nones): - results[ii] = result + results_without_nones = replica_context.all_reduce("mean", gradients_without_nones) + results = [None] * len(gradients) + for ii, result in zip(original_indices, results_without_nones): + results[ii] = result - return results + return results diff --git a/acme/agents/tf/d4pg/networks.py b/acme/agents/tf/d4pg/networks.py index d0a225c12c..703d98baec 100644 --- a/acme/agents/tf/d4pg/networks.py +++ b/acme/agents/tf/d4pg/networks.py @@ -16,48 +16,51 @@ from typing import Mapping, Sequence -from acme import specs -from acme import types -from acme.tf import networks -from acme.tf import utils as tf2_utils - import numpy as np import sonnet as snt +from acme import specs, types +from acme.tf import networks +from acme.tf import utils as tf2_utils + def make_default_networks( action_spec: specs.BoundedArray, policy_layer_sizes: Sequence[int] = (256, 256, 256), critic_layer_sizes: Sequence[int] = (512, 512, 256), - vmin: float = -150., - vmax: float = 150., + vmin: float = -150.0, + vmax: float = 150.0, num_atoms: int = 51, ) -> Mapping[str, types.TensorTransformation]: - """Creates networks used by the agent.""" - - # Get total number of action dimensions from action spec. - num_dimensions = np.prod(action_spec.shape, dtype=int) - - # Create the shared observation network; here simply a state-less operation. - observation_network = tf2_utils.batch_concat - - # Create the policy network. - policy_network = snt.Sequential([ - networks.LayerNormMLP(policy_layer_sizes, activate_final=True), - networks.NearZeroInitializedLinear(num_dimensions), - networks.TanhToSpec(action_spec), - ]) - - # Create the critic network. - critic_network = snt.Sequential([ - # The multiplexer concatenates the observations/actions. - networks.CriticMultiplexer(), - networks.LayerNormMLP(critic_layer_sizes, activate_final=True), - networks.DiscreteValuedHead(vmin, vmax, num_atoms), - ]) - - return { - 'policy': policy_network, - 'critic': critic_network, - 'observation': observation_network, - } + """Creates networks used by the agent.""" + + # Get total number of action dimensions from action spec. + num_dimensions = np.prod(action_spec.shape, dtype=int) + + # Create the shared observation network; here simply a state-less operation. + observation_network = tf2_utils.batch_concat + + # Create the policy network. + policy_network = snt.Sequential( + [ + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.NearZeroInitializedLinear(num_dimensions), + networks.TanhToSpec(action_spec), + ] + ) + + # Create the critic network. + critic_network = snt.Sequential( + [ + # The multiplexer concatenates the observations/actions. 
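# A usage sketch tying make_default_networks to the agent, assuming an existing
# environment_spec (only names defined in this package are used):
#
#     nets = make_default_networks(environment_spec.actions)
#     agent = D4PG(
#         environment_spec=environment_spec,
#         policy_network=nets["policy"],
#         critic_network=nets["critic"],
#         observation_network=nets["observation"])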
+ networks.CriticMultiplexer(), + networks.LayerNormMLP(critic_layer_sizes, activate_final=True), + networks.DiscreteValuedHead(vmin, vmax, num_atoms), + ] + ) + + return { + "policy": policy_network, + "critic": critic_network, + "observation": observation_network, + } diff --git a/acme/agents/tf/ddpg/agent.py b/acme/agents/tf/ddpg/agent.py index 4e6ea7915d..11ef0c5051 100644 --- a/acme/agents/tf/ddpg/agent.py +++ b/acme/agents/tf/ddpg/agent.py @@ -17,24 +17,22 @@ import copy from typing import Optional -from acme import datasets -from acme import specs -from acme import types +import reverb +import sonnet as snt +import tensorflow as tf + +from acme import datasets, specs, types from acme.adders import reverb as adders from acme.agents import agent from acme.agents.tf import actors from acme.agents.tf.ddpg import learning from acme.tf import networks from acme.tf import utils as tf2_utils -from acme.utils import counting -from acme.utils import loggers -import reverb -import sonnet as snt -import tensorflow as tf +from acme.utils import counting, loggers class DDPG(agent.Agent): - """DDPG Agent. + """DDPG Agent. This implements a single-process DDPG agent. This is an actor-critic algorithm that generates data via a behavior policy, inserts N-step transitions into @@ -42,26 +40,28 @@ class DDPG(agent.Agent): behavior) by sampling uniformly from this buffer. """ - def __init__(self, - environment_spec: specs.EnvironmentSpec, - policy_network: snt.Module, - critic_network: snt.Module, - observation_network: types.TensorTransformation = tf.identity, - discount: float = 0.99, - batch_size: int = 256, - prefetch_size: int = 4, - target_update_period: int = 100, - min_replay_size: int = 1000, - max_replay_size: int = 1000000, - samples_per_insert: float = 32.0, - n_step: int = 5, - sigma: float = 0.3, - clipping: bool = True, - logger: Optional[loggers.Logger] = None, - counter: Optional[counting.Counter] = None, - checkpoint: bool = True, - replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE): - """Initialize the agent. + def __init__( + self, + environment_spec: specs.EnvironmentSpec, + policy_network: snt.Module, + critic_network: snt.Module, + observation_network: types.TensorTransformation = tf.identity, + discount: float = 0.99, + batch_size: int = 256, + prefetch_size: int = 4, + target_update_period: int = 100, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: float = 32.0, + n_step: int = 5, + sigma: float = 0.3, + clipping: bool = True, + logger: Optional[loggers.Logger] = None, + counter: Optional[counting.Counter] = None, + checkpoint: bool = True, + replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE, + ): + """Initialize the agent. Args: environment_spec: description of the actions, observations, etc. @@ -86,88 +86,94 @@ def __init__(self, checkpoint: boolean indicating whether to checkpoint the learner. replay_table_name: string indicating what name to give the replay table. """ - # Create a replay server to add data to. This uses no limiter behavior in - # order to allow the Agent interface to handle it. - replay_table = reverb.Table( - name=replay_table_name, - sampler=reverb.selectors.Uniform(), - remover=reverb.selectors.Fifo(), - max_size=max_replay_size, - rate_limiter=reverb.rate_limiters.MinSize(1), - signature=adders.NStepTransitionAdder.signature(environment_spec)) - self._server = reverb.Server([replay_table], port=None) - - # The adder is used to insert observations into replay. 
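# A short usage sketch for the D4PG `make_default_networks` factory above; the
# six-dimensional bounded action spec is illustrative only.
import numpy as np
from acme import specs

action_spec = specs.BoundedArray(
    shape=(6,), dtype=np.float32, minimum=-1.0, maximum=1.0, name="action"
)
nets = make_default_networks(action_spec)
policy_network = nets["policy"]
critic_network = nets["critic"]
observation_network = nets["observation"]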
- address = f'localhost:{self._server.port}' - adder = adders.NStepTransitionAdder( - priority_fns={replay_table_name: lambda x: 1.}, - client=reverb.Client(address), - n_step=n_step, - discount=discount) - - # The dataset provides an interface to sample from replay. - dataset = datasets.make_reverb_dataset( - table=replay_table_name, - server_address=address, - batch_size=batch_size, - prefetch_size=prefetch_size) - - # Make sure observation network is a Sonnet Module. - observation_network = tf2_utils.to_sonnet_module(observation_network) - - # Get observation and action specs. - act_spec = environment_spec.actions - obs_spec = environment_spec.observations - emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) - - # Create target networks. - target_policy_network = copy.deepcopy(policy_network) - target_critic_network = copy.deepcopy(critic_network) - target_observation_network = copy.deepcopy(observation_network) - - # Create the behavior policy. - behavior_network = snt.Sequential([ - observation_network, - policy_network, - networks.ClippedGaussian(sigma), - networks.ClipToSpec(act_spec), - ]) - - # Create variables. - tf2_utils.create_variables(policy_network, [emb_spec]) - tf2_utils.create_variables(critic_network, [emb_spec, act_spec]) - tf2_utils.create_variables(target_policy_network, [emb_spec]) - tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec]) - tf2_utils.create_variables(target_observation_network, [obs_spec]) - - # Create the actor which defines how we take actions. - actor = actors.FeedForwardActor(behavior_network, adder=adder) - - # Create optimizers. - policy_optimizer = snt.optimizers.Adam(learning_rate=1e-4) - critic_optimizer = snt.optimizers.Adam(learning_rate=1e-4) - - # The learner updates the parameters (and initializes them). - learner = learning.DDPGLearner( - policy_network=policy_network, - critic_network=critic_network, - observation_network=observation_network, - target_policy_network=target_policy_network, - target_critic_network=target_critic_network, - target_observation_network=target_observation_network, - policy_optimizer=policy_optimizer, - critic_optimizer=critic_optimizer, - clipping=clipping, - discount=discount, - target_update_period=target_update_period, - dataset=dataset, - counter=counter, - logger=logger, - checkpoint=checkpoint, - ) - - super().__init__( - actor=actor, - learner=learner, - min_observations=max(batch_size, min_replay_size), - observations_per_step=float(batch_size) / samples_per_insert) + # Create a replay server to add data to. This uses no limiter behavior in + # order to allow the Agent interface to handle it. + replay_table = reverb.Table( + name=replay_table_name, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=max_replay_size, + rate_limiter=reverb.rate_limiters.MinSize(1), + signature=adders.NStepTransitionAdder.signature(environment_spec), + ) + self._server = reverb.Server([replay_table], port=None) + + # The adder is used to insert observations into replay. + address = f"localhost:{self._server.port}" + adder = adders.NStepTransitionAdder( + priority_fns={replay_table_name: lambda x: 1.0}, + client=reverb.Client(address), + n_step=n_step, + discount=discount, + ) + + # The dataset provides an interface to sample from replay. + dataset = datasets.make_reverb_dataset( + table=replay_table_name, + server_address=address, + batch_size=batch_size, + prefetch_size=prefetch_size, + ) + + # Make sure observation network is a Sonnet Module. 
+ observation_network = tf2_utils.to_sonnet_module(observation_network) + + # Get observation and action specs. + act_spec = environment_spec.actions + obs_spec = environment_spec.observations + emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) + + # Create target networks. + target_policy_network = copy.deepcopy(policy_network) + target_critic_network = copy.deepcopy(critic_network) + target_observation_network = copy.deepcopy(observation_network) + + # Create the behavior policy. + behavior_network = snt.Sequential( + [ + observation_network, + policy_network, + networks.ClippedGaussian(sigma), + networks.ClipToSpec(act_spec), + ] + ) + + # Create variables. + tf2_utils.create_variables(policy_network, [emb_spec]) + tf2_utils.create_variables(critic_network, [emb_spec, act_spec]) + tf2_utils.create_variables(target_policy_network, [emb_spec]) + tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec]) + tf2_utils.create_variables(target_observation_network, [obs_spec]) + + # Create the actor which defines how we take actions. + actor = actors.FeedForwardActor(behavior_network, adder=adder) + + # Create optimizers. + policy_optimizer = snt.optimizers.Adam(learning_rate=1e-4) + critic_optimizer = snt.optimizers.Adam(learning_rate=1e-4) + + # The learner updates the parameters (and initializes them). + learner = learning.DDPGLearner( + policy_network=policy_network, + critic_network=critic_network, + observation_network=observation_network, + target_policy_network=target_policy_network, + target_critic_network=target_critic_network, + target_observation_network=target_observation_network, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + clipping=clipping, + discount=discount, + target_update_period=target_update_period, + dataset=dataset, + counter=counter, + logger=logger, + checkpoint=checkpoint, + ) + + super().__init__( + actor=actor, + learner=learner, + min_observations=max(batch_size, min_replay_size), + observations_per_step=float(batch_size) / samples_per_insert, + ) diff --git a/acme/agents/tf/ddpg/agent_distributed.py b/acme/agents/tf/ddpg/agent_distributed.py index f9f852d48b..e3f0ff98ee 100644 --- a/acme/agents/tf/ddpg/agent_distributed.py +++ b/acme/agents/tf/ddpg/agent_distributed.py @@ -16,9 +16,14 @@ from typing import Callable, Dict, Optional +import dm_env +import launchpad as lp +import reverb +import sonnet as snt +import tensorflow as tf + import acme -from acme import datasets -from acme import specs +from acme import datasets, specs from acme.adders import reverb as adders from acme.agents.tf import actors from acme.agents.tf.ddpg import learning @@ -26,294 +31,298 @@ from acme.tf import savers as tf2_savers from acme.tf import utils as tf2_utils from acme.tf import variable_utils as tf2_variable_utils -from acme.utils import counting -from acme.utils import loggers -from acme.utils import lp_utils -import dm_env -import launchpad as lp -import reverb -import sonnet as snt -import tensorflow as tf +from acme.utils import counting, loggers, lp_utils class DistributedDDPG: - """Program definition for distributed DDPG (D3PG).""" - - def __init__( - self, - environment_factory: Callable[[bool], dm_env.Environment], - network_factory: Callable[[specs.BoundedArray], Dict[str, snt.Module]], - num_actors: int = 1, - num_caches: int = 0, - environment_spec: Optional[specs.EnvironmentSpec] = None, - batch_size: int = 256, - prefetch_size: int = 4, - min_replay_size: int = 1000, - max_replay_size: int = 1000000, - 
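# Worked numbers for the rate control handed to `super().__init__` above, using
# the constructor defaults: the agent only observes until the larger of
# `batch_size` and `min_replay_size` transitions is in replay, then performs
# one learner step per `batch_size / samples_per_insert` environment steps.
batch_size, min_replay_size, samples_per_insert = 256, 1000, 32.0
min_observations = max(batch_size, min_replay_size)             # 1000
observations_per_step = float(batch_size) / samples_per_insert  # 8.0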
samples_per_insert: Optional[float] = 32.0, - n_step: int = 5, - sigma: float = 0.3, - clipping: bool = True, - discount: float = 0.99, - target_update_period: int = 100, - variable_update_period: int = 1000, - max_actor_steps: Optional[int] = None, - log_every: float = 10.0, - ): - - if not environment_spec: - environment_spec = specs.make_environment_spec(environment_factory(False)) - - self._environment_factory = environment_factory - self._network_factory = network_factory - self._environment_spec = environment_spec - self._num_actors = num_actors - self._num_caches = num_caches - self._batch_size = batch_size - self._prefetch_size = prefetch_size - self._min_replay_size = min_replay_size - self._max_replay_size = max_replay_size - self._samples_per_insert = samples_per_insert - self._n_step = n_step - self._sigma = sigma - self._clipping = clipping - self._discount = discount - self._target_update_period = target_update_period - self._variable_update_period = variable_update_period - self._max_actor_steps = max_actor_steps - self._log_every = log_every - - def replay(self): - """The replay storage.""" - if self._samples_per_insert is not None: - # Create enough of an error buffer to give a 10% tolerance in rate. - samples_per_insert_tolerance = 0.1 * self._samples_per_insert - error_buffer = self._min_replay_size * samples_per_insert_tolerance - - limiter = reverb.rate_limiters.SampleToInsertRatio( - min_size_to_sample=self._min_replay_size, - samples_per_insert=self._samples_per_insert, - error_buffer=error_buffer) - else: - limiter = reverb.rate_limiters.MinSize(self._min_replay_size) - replay_table = reverb.Table( - name=adders.DEFAULT_PRIORITY_TABLE, - sampler=reverb.selectors.Uniform(), - remover=reverb.selectors.Fifo(), - max_size=self._max_replay_size, - rate_limiter=limiter, - signature=adders.NStepTransitionAdder.signature( - self._environment_spec)) - return [replay_table] - - def counter(self): - return tf2_savers.CheckpointingRunner(counting.Counter(), - time_delta_minutes=1, - subdirectory='counter') - - def coordinator(self, counter: counting.Counter, max_actor_steps: int): - return lp_utils.StepsLimiter(counter, max_actor_steps) - - def learner( - self, - replay: reverb.Client, - counter: counting.Counter, - ): - """The Learning part of the agent.""" - - act_spec = self._environment_spec.actions - obs_spec = self._environment_spec.observations - - # Create the networks to optimize (online) and target networks. - online_networks = self._network_factory(act_spec) - target_networks = self._network_factory(act_spec) - - # Make sure observation network is a Sonnet Module. - observation_network = online_networks.get('observation', tf.identity) - target_observation_network = target_networks.get('observation', tf.identity) - observation_network = tf2_utils.to_sonnet_module(observation_network) - target_observation_network = tf2_utils.to_sonnet_module( - target_observation_network) - - # Get embedding spec and create observation network variables. - emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) - - # Create variables. - tf2_utils.create_variables(online_networks['policy'], [emb_spec]) - tf2_utils.create_variables(online_networks['critic'], [emb_spec, act_spec]) - tf2_utils.create_variables(target_networks['policy'], [emb_spec]) - tf2_utils.create_variables(target_networks['critic'], [emb_spec, act_spec]) - tf2_utils.create_variables(target_observation_network, [obs_spec]) - - # The dataset object to learn from. 
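# The rate-limiter arithmetic from `replay()` above with the class defaults
# plugged in, shown standalone for illustration: the error buffer gives roughly
# a 10% tolerance around the target samples-per-insert ratio.
import reverb

min_replay_size, samples_per_insert = 1000, 32.0
error_buffer = min_replay_size * (0.1 * samples_per_insert)  # 3200.0
limiter = reverb.rate_limiters.SampleToInsertRatio(
    min_size_to_sample=min_replay_size,
    samples_per_insert=samples_per_insert,
    error_buffer=error_buffer,
)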
- dataset = datasets.make_reverb_dataset( - server_address=replay.server_address, - batch_size=self._batch_size, - prefetch_size=self._prefetch_size) - - # Create optimizers. - policy_optimizer = snt.optimizers.Adam(learning_rate=1e-4) - critic_optimizer = snt.optimizers.Adam(learning_rate=1e-4) - - counter = counting.Counter(counter, 'learner') - logger = loggers.make_default_logger( - 'learner', time_delta=self._log_every, steps_key='learner_steps') - - # Return the learning agent. - return learning.DDPGLearner( - policy_network=online_networks['policy'], - critic_network=online_networks['critic'], - observation_network=observation_network, - target_policy_network=target_networks['policy'], - target_critic_network=target_networks['critic'], - target_observation_network=target_observation_network, - discount=self._discount, - target_update_period=self._target_update_period, - dataset=dataset, - policy_optimizer=policy_optimizer, - critic_optimizer=critic_optimizer, - clipping=self._clipping, - counter=counter, - logger=logger, - ) - - def actor( - self, - replay: reverb.Client, - variable_source: acme.VariableSource, - counter: counting.Counter, - ): - """The actor process.""" - - action_spec = self._environment_spec.actions - observation_spec = self._environment_spec.observations - - # Create environment and behavior networks - environment = self._environment_factory(False) - agent_networks = self._network_factory(action_spec) - - # Create behavior network by adding some random dithering. - behavior_network = snt.Sequential([ - agent_networks.get('observation', tf.identity), - agent_networks.get('policy'), - networks.ClippedGaussian(self._sigma), - ]) - - # Ensure network variables are created. - tf2_utils.create_variables(behavior_network, [observation_spec]) - variables = {'policy': behavior_network.variables} - - # Create the variable client responsible for keeping the actor up-to-date. - variable_client = tf2_variable_utils.VariableClient( - variable_source, variables, update_period=self._variable_update_period) - - # Make sure not to use a random policy after checkpoint restoration by - # assigning variables before running the environment loop. - variable_client.update_and_wait() - - # Component to add things into replay. - adder = adders.NStepTransitionAdder( - client=replay, n_step=self._n_step, discount=self._discount) - - # Create the agent. - actor = actors.FeedForwardActor( - behavior_network, adder=adder, variable_client=variable_client) - - # Create logger and counter; actors will not spam bigtable. - counter = counting.Counter(counter, 'actor') - logger = loggers.make_default_logger( - 'actor', - save_data=False, - time_delta=self._log_every, - steps_key='actor_steps') - - # Create the loop to connect environment and agent. - return acme.EnvironmentLoop(environment, actor, counter, logger) - - def evaluator( - self, - variable_source: acme.VariableSource, - counter: counting.Counter, - ): - """The evaluation process.""" - - action_spec = self._environment_spec.actions - observation_spec = self._environment_spec.observations - - # Create environment and evaluator networks - environment = self._environment_factory(True) - agent_networks = self._network_factory(action_spec) - - # Create evaluator network. - evaluator_network = snt.Sequential([ - agent_networks.get('observation', tf.identity), - agent_networks.get('policy'), - ]) - - # Ensure network variables are created. 
- tf2_utils.create_variables(evaluator_network, [observation_spec]) - variables = {'policy': evaluator_network.variables} - - # Create the variable client responsible for keeping the actor up-to-date. - variable_client = tf2_variable_utils.VariableClient( - variable_source, variables, update_period=self._variable_update_period) - - # Make sure not to evaluate a random actor by assigning variables before - # running the environment loop. - variable_client.update_and_wait() - - # Create the evaluator; note it will not add experience to replay. - evaluator = actors.FeedForwardActor( - evaluator_network, variable_client=variable_client) - - # Create logger and counter. - counter = counting.Counter(counter, 'evaluator') - logger = loggers.make_default_logger( - 'evaluator', time_delta=self._log_every, steps_key='evaluator_steps') - - # Create the run loop and return it. - return acme.EnvironmentLoop( - environment, evaluator, counter, logger) - - def build(self, name='ddpg'): - """Build the distributed agent topology.""" - program = lp.Program(name=name) - - with program.group('replay'): - replay = program.add_node(lp.ReverbNode(self.replay)) - - with program.group('counter'): - counter = program.add_node(lp.CourierNode(self.counter)) - - if self._max_actor_steps: - _ = program.add_node( - lp.CourierNode(self.coordinator, counter, self._max_actor_steps)) - - with program.group('learner'): - learner = program.add_node( - lp.CourierNode(self.learner, replay, counter)) - - with program.group('evaluator'): - program.add_node( - lp.CourierNode(self.evaluator, learner, counter)) - - if not self._num_caches: - # Use our learner as a single variable source. - sources = [learner] - else: - with program.group('cacher'): - # Create a set of learner caches. - sources = [] - for _ in range(self._num_caches): - cacher = program.add_node( - lp.CacherNode( - learner, refresh_interval_ms=2000, stale_after_ms=4000)) - sources.append(cacher) - - with program.group('actor'): - # Add actors which pull round-robin from our variable sources. 
- for actor_id in range(self._num_actors): - source = sources[actor_id % len(sources)] - program.add_node(lp.CourierNode(self.actor, replay, source, counter)) - - return program + """Program definition for distributed DDPG (D3PG).""" + + def __init__( + self, + environment_factory: Callable[[bool], dm_env.Environment], + network_factory: Callable[[specs.BoundedArray], Dict[str, snt.Module]], + num_actors: int = 1, + num_caches: int = 0, + environment_spec: Optional[specs.EnvironmentSpec] = None, + batch_size: int = 256, + prefetch_size: int = 4, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: Optional[float] = 32.0, + n_step: int = 5, + sigma: float = 0.3, + clipping: bool = True, + discount: float = 0.99, + target_update_period: int = 100, + variable_update_period: int = 1000, + max_actor_steps: Optional[int] = None, + log_every: float = 10.0, + ): + + if not environment_spec: + environment_spec = specs.make_environment_spec(environment_factory(False)) + + self._environment_factory = environment_factory + self._network_factory = network_factory + self._environment_spec = environment_spec + self._num_actors = num_actors + self._num_caches = num_caches + self._batch_size = batch_size + self._prefetch_size = prefetch_size + self._min_replay_size = min_replay_size + self._max_replay_size = max_replay_size + self._samples_per_insert = samples_per_insert + self._n_step = n_step + self._sigma = sigma + self._clipping = clipping + self._discount = discount + self._target_update_period = target_update_period + self._variable_update_period = variable_update_period + self._max_actor_steps = max_actor_steps + self._log_every = log_every + + def replay(self): + """The replay storage.""" + if self._samples_per_insert is not None: + # Create enough of an error buffer to give a 10% tolerance in rate. + samples_per_insert_tolerance = 0.1 * self._samples_per_insert + error_buffer = self._min_replay_size * samples_per_insert_tolerance + + limiter = reverb.rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._min_replay_size, + samples_per_insert=self._samples_per_insert, + error_buffer=error_buffer, + ) + else: + limiter = reverb.rate_limiters.MinSize(self._min_replay_size) + replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._max_replay_size, + rate_limiter=limiter, + signature=adders.NStepTransitionAdder.signature(self._environment_spec), + ) + return [replay_table] + + def counter(self): + return tf2_savers.CheckpointingRunner( + counting.Counter(), time_delta_minutes=1, subdirectory="counter" + ) + + def coordinator(self, counter: counting.Counter, max_actor_steps: int): + return lp_utils.StepsLimiter(counter, max_actor_steps) + + def learner( + self, replay: reverb.Client, counter: counting.Counter, + ): + """The Learning part of the agent.""" + + act_spec = self._environment_spec.actions + obs_spec = self._environment_spec.observations + + # Create the networks to optimize (online) and target networks. + online_networks = self._network_factory(act_spec) + target_networks = self._network_factory(act_spec) + + # Make sure observation network is a Sonnet Module. 
+ observation_network = online_networks.get("observation", tf.identity) + target_observation_network = target_networks.get("observation", tf.identity) + observation_network = tf2_utils.to_sonnet_module(observation_network) + target_observation_network = tf2_utils.to_sonnet_module( + target_observation_network + ) + + # Get embedding spec and create observation network variables. + emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) + + # Create variables. + tf2_utils.create_variables(online_networks["policy"], [emb_spec]) + tf2_utils.create_variables(online_networks["critic"], [emb_spec, act_spec]) + tf2_utils.create_variables(target_networks["policy"], [emb_spec]) + tf2_utils.create_variables(target_networks["critic"], [emb_spec, act_spec]) + tf2_utils.create_variables(target_observation_network, [obs_spec]) + + # The dataset object to learn from. + dataset = datasets.make_reverb_dataset( + server_address=replay.server_address, + batch_size=self._batch_size, + prefetch_size=self._prefetch_size, + ) + + # Create optimizers. + policy_optimizer = snt.optimizers.Adam(learning_rate=1e-4) + critic_optimizer = snt.optimizers.Adam(learning_rate=1e-4) + + counter = counting.Counter(counter, "learner") + logger = loggers.make_default_logger( + "learner", time_delta=self._log_every, steps_key="learner_steps" + ) + + # Return the learning agent. + return learning.DDPGLearner( + policy_network=online_networks["policy"], + critic_network=online_networks["critic"], + observation_network=observation_network, + target_policy_network=target_networks["policy"], + target_critic_network=target_networks["critic"], + target_observation_network=target_observation_network, + discount=self._discount, + target_update_period=self._target_update_period, + dataset=dataset, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + clipping=self._clipping, + counter=counter, + logger=logger, + ) + + def actor( + self, + replay: reverb.Client, + variable_source: acme.VariableSource, + counter: counting.Counter, + ): + """The actor process.""" + + action_spec = self._environment_spec.actions + observation_spec = self._environment_spec.observations + + # Create environment and behavior networks + environment = self._environment_factory(False) + agent_networks = self._network_factory(action_spec) + + # Create behavior network by adding some random dithering. + behavior_network = snt.Sequential( + [ + agent_networks.get("observation", tf.identity), + agent_networks.get("policy"), + networks.ClippedGaussian(self._sigma), + ] + ) + + # Ensure network variables are created. + tf2_utils.create_variables(behavior_network, [observation_spec]) + variables = {"policy": behavior_network.variables} + + # Create the variable client responsible for keeping the actor up-to-date. + variable_client = tf2_variable_utils.VariableClient( + variable_source, variables, update_period=self._variable_update_period + ) + + # Make sure not to use a random policy after checkpoint restoration by + # assigning variables before running the environment loop. + variable_client.update_and_wait() + + # Component to add things into replay. + adder = adders.NStepTransitionAdder( + client=replay, n_step=self._n_step, discount=self._discount + ) + + # Create the agent. + actor = actors.FeedForwardActor( + behavior_network, adder=adder, variable_client=variable_client + ) + + # Create logger and counter; actors will not spam bigtable. 
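# A compact sketch of the actor-side synchronisation pattern used above,
# wrapped in a hypothetical helper so its inputs are explicit: build the
# variable client, then block on a first fetch so early actions do not run
# with randomly initialised weights.
from acme.tf import variable_utils as tf2_variable_utils

def make_synced_variable_client(variable_source, policy_network, update_period=1000):
    variable_client = tf2_variable_utils.VariableClient(
        variable_source,
        {"policy": policy_network.variables},
        update_period=update_period,
    )
    variable_client.update_and_wait()
    return variable_client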
+ counter = counting.Counter(counter, "actor") + logger = loggers.make_default_logger( + "actor", + save_data=False, + time_delta=self._log_every, + steps_key="actor_steps", + ) + + # Create the loop to connect environment and agent. + return acme.EnvironmentLoop(environment, actor, counter, logger) + + def evaluator( + self, variable_source: acme.VariableSource, counter: counting.Counter, + ): + """The evaluation process.""" + + action_spec = self._environment_spec.actions + observation_spec = self._environment_spec.observations + + # Create environment and evaluator networks + environment = self._environment_factory(True) + agent_networks = self._network_factory(action_spec) + + # Create evaluator network. + evaluator_network = snt.Sequential( + [ + agent_networks.get("observation", tf.identity), + agent_networks.get("policy"), + ] + ) + + # Ensure network variables are created. + tf2_utils.create_variables(evaluator_network, [observation_spec]) + variables = {"policy": evaluator_network.variables} + + # Create the variable client responsible for keeping the actor up-to-date. + variable_client = tf2_variable_utils.VariableClient( + variable_source, variables, update_period=self._variable_update_period + ) + + # Make sure not to evaluate a random actor by assigning variables before + # running the environment loop. + variable_client.update_and_wait() + + # Create the evaluator; note it will not add experience to replay. + evaluator = actors.FeedForwardActor( + evaluator_network, variable_client=variable_client + ) + + # Create logger and counter. + counter = counting.Counter(counter, "evaluator") + logger = loggers.make_default_logger( + "evaluator", time_delta=self._log_every, steps_key="evaluator_steps" + ) + + # Create the run loop and return it. + return acme.EnvironmentLoop(environment, evaluator, counter, logger) + + def build(self, name="ddpg"): + """Build the distributed agent topology.""" + program = lp.Program(name=name) + + with program.group("replay"): + replay = program.add_node(lp.ReverbNode(self.replay)) + + with program.group("counter"): + counter = program.add_node(lp.CourierNode(self.counter)) + + if self._max_actor_steps: + _ = program.add_node( + lp.CourierNode(self.coordinator, counter, self._max_actor_steps) + ) + + with program.group("learner"): + learner = program.add_node(lp.CourierNode(self.learner, replay, counter)) + + with program.group("evaluator"): + program.add_node(lp.CourierNode(self.evaluator, learner, counter)) + + if not self._num_caches: + # Use our learner as a single variable source. + sources = [learner] + else: + with program.group("cacher"): + # Create a set of learner caches. + sources = [] + for _ in range(self._num_caches): + cacher = program.add_node( + lp.CacherNode( + learner, refresh_interval_ms=2000, stale_after_ms=4000 + ) + ) + sources.append(cacher) + + with program.group("actor"): + # Add actors which pull round-robin from our variable sources. 
+ for actor_id in range(self._num_actors): + source = sources[actor_id % len(sources)] + program.add_node(lp.CourierNode(self.actor, replay, source, counter)) + + return program diff --git a/acme/agents/tf/ddpg/agent_distributed_test.py b/acme/agents/tf/ddpg/agent_distributed_test.py index 1b930a9710..1fa72a4237 100644 --- a/acme/agents/tf/ddpg/agent_distributed_test.py +++ b/acme/agents/tf/ddpg/agent_distributed_test.py @@ -14,76 +14,80 @@ """Integration test for the distributed agent.""" +import launchpad as lp +import numpy as np +import sonnet as snt +from absl.testing import absltest + import acme from acme import specs from acme.agents.tf import ddpg from acme.testing import fakes from acme.tf import networks from acme.tf import utils as tf2_utils -import launchpad as lp -import numpy as np -import sonnet as snt - -from absl.testing import absltest def make_networks(action_spec: specs.BoundedArray): - """Creates simple networks for testing..""" - - num_dimensions = np.prod(action_spec.shape, dtype=int) + """Creates simple networks for testing..""" - # Create the observation network shared between the policy and critic. - observation_network = tf2_utils.batch_concat + num_dimensions = np.prod(action_spec.shape, dtype=int) - # Create the policy network (head) and the evaluation network. - policy_network = snt.Sequential([ - networks.LayerNormMLP([50], activate_final=True), - networks.NearZeroInitializedLinear(num_dimensions), - networks.TanhToSpec(action_spec) - ]) - evaluator_network = snt.Sequential([observation_network, policy_network]) + # Create the observation network shared between the policy and critic. + observation_network = tf2_utils.batch_concat - # Create the critic network. - critic_network = snt.Sequential([ - # The multiplexer concatenates the observations/actions. - networks.CriticMultiplexer(), - networks.LayerNormMLP([50], activate_final=True), - networks.NearZeroInitializedLinear(1), - ]) + # Create the policy network (head) and the evaluation network. + policy_network = snt.Sequential( + [ + networks.LayerNormMLP([50], activate_final=True), + networks.NearZeroInitializedLinear(num_dimensions), + networks.TanhToSpec(action_spec), + ] + ) + evaluator_network = snt.Sequential([observation_network, policy_network]) + + # Create the critic network. + critic_network = snt.Sequential( + [ + # The multiplexer concatenates the observations/actions. 
+ networks.CriticMultiplexer(), + networks.LayerNormMLP([50], activate_final=True), + networks.NearZeroInitializedLinear(1), + ] + ) - return { - 'policy': policy_network, - 'critic': critic_network, - 'observation': observation_network, - 'evaluator': evaluator_network, - } + return { + "policy": policy_network, + "critic": critic_network, + "observation": observation_network, + "evaluator": evaluator_network, + } class DistributedAgentTest(absltest.TestCase): - """Simple integration/smoke test for the distributed agent.""" + """Simple integration/smoke test for the distributed agent.""" - def test_agent(self): + def test_agent(self): - agent = ddpg.DistributedDDPG( - environment_factory=lambda x: fakes.ContinuousEnvironment(bounded=True), - network_factory=make_networks, - num_actors=2, - batch_size=32, - min_replay_size=32, - max_replay_size=1000, - ) - program = agent.build() + agent = ddpg.DistributedDDPG( + environment_factory=lambda x: fakes.ContinuousEnvironment(bounded=True), + network_factory=make_networks, + num_actors=2, + batch_size=32, + min_replay_size=32, + max_replay_size=1000, + ) + program = agent.build() - (learner_node,) = program.groups['learner'] - learner_node.disable_run() + (learner_node,) = program.groups["learner"] + learner_node.disable_run() - lp.launch(program, launch_type='test_mt') + lp.launch(program, launch_type="test_mt") - learner: acme.Learner = learner_node.create_handle().dereference() + learner: acme.Learner = learner_node.create_handle().dereference() - for _ in range(5): - learner.step() + for _ in range(5): + learner.step() -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/tf/ddpg/agent_test.py b/acme/agents/tf/ddpg/agent_test.py index 9287e8c275..dc48070505 100644 --- a/acme/agents/tf/ddpg/agent_test.py +++ b/acme/agents/tf/ddpg/agent_test.py @@ -16,67 +16,67 @@ from typing import Dict, Sequence -import acme -from acme import specs -from acme import types -from acme.agents.tf import ddpg -from acme.testing import fakes -from acme.tf import networks import numpy as np import sonnet as snt import tensorflow as tf - from absl.testing import absltest +import acme +from acme import specs, types +from acme.agents.tf import ddpg +from acme.testing import fakes +from acme.tf import networks + def make_networks( action_spec: types.NestedSpec, policy_layer_sizes: Sequence[int] = (10, 10), critic_layer_sizes: Sequence[int] = (10, 10), ) -> Dict[str, snt.Module]: - """Creates networks used by the agent.""" - - num_dimensions = np.prod(action_spec.shape, dtype=int) - policy_layer_sizes = list(policy_layer_sizes) + [num_dimensions] - critic_layer_sizes = list(critic_layer_sizes) + [1] - - policy_network = snt.Sequential( - [networks.LayerNormMLP(policy_layer_sizes), tf.tanh]) - # The multiplexer concatenates the (maybe transformed) observations/actions. - critic_network = networks.CriticMultiplexer( - critic_network=networks.LayerNormMLP(critic_layer_sizes)) - - return { - 'policy': policy_network, - 'critic': critic_network, - } + """Creates networks used by the agent.""" + num_dimensions = np.prod(action_spec.shape, dtype=int) + policy_layer_sizes = list(policy_layer_sizes) + [num_dimensions] + critic_layer_sizes = list(critic_layer_sizes) + [1] -class DDPGTest(absltest.TestCase): - - def test_ddpg(self): - # Create a fake environment to test with. 
- environment = fakes.ContinuousEnvironment(episode_length=10, bounded=True) - spec = specs.make_environment_spec(environment) - - # Create the networks to optimize (online) and target networks. - agent_networks = make_networks(spec.actions) - - # Construct the agent. - agent = ddpg.DDPG( - environment_spec=spec, - policy_network=agent_networks['policy'], - critic_network=agent_networks['critic'], - batch_size=10, - samples_per_insert=2, - min_replay_size=10, + policy_network = snt.Sequential( + [networks.LayerNormMLP(policy_layer_sizes), tf.tanh] + ) + # The multiplexer concatenates the (maybe transformed) observations/actions. + critic_network = networks.CriticMultiplexer( + critic_network=networks.LayerNormMLP(critic_layer_sizes) ) - # Try running the environment loop. We have no assertions here because all - # we care about is that the agent runs without raising any errors. - loop = acme.EnvironmentLoop(environment, agent) - loop.run(num_episodes=2) + return { + "policy": policy_network, + "critic": critic_network, + } -if __name__ == '__main__': - absltest.main() +class DDPGTest(absltest.TestCase): + def test_ddpg(self): + # Create a fake environment to test with. + environment = fakes.ContinuousEnvironment(episode_length=10, bounded=True) + spec = specs.make_environment_spec(environment) + + # Create the networks to optimize (online) and target networks. + agent_networks = make_networks(spec.actions) + + # Construct the agent. + agent = ddpg.DDPG( + environment_spec=spec, + policy_network=agent_networks["policy"], + critic_network=agent_networks["critic"], + batch_size=10, + samples_per_insert=2, + min_replay_size=10, + ) + + # Try running the environment loop. We have no assertions here because all + # we care about is that the agent runs without raising any errors. + loop = acme.EnvironmentLoop(environment, agent) + loop.run(num_episodes=2) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/tf/ddpg/learning.py b/acme/agents/tf/ddpg/learning.py index 1c74200a66..38c8321609 100644 --- a/acme/agents/tf/ddpg/learning.py +++ b/acme/agents/tf/ddpg/learning.py @@ -17,46 +17,46 @@ import time from typing import List, Optional -import acme -from acme import types -from acme.tf import losses -from acme.tf import savers as tf2_savers -from acme.tf import utils as tf2_utils -from acme.utils import counting -from acme.utils import loggers import numpy as np import sonnet as snt import tensorflow as tf import tree import trfl +import acme +from acme import types +from acme.tf import losses +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.utils import counting, loggers + class DDPGLearner(acme.Learner): - """DDPG learner. + """DDPG learner. This is the learning component of a DDPG agent. IE it takes a dataset as input and implements update functionality to learn from this dataset. """ - def __init__( - self, - policy_network: snt.Module, - critic_network: snt.Module, - target_policy_network: snt.Module, - target_critic_network: snt.Module, - discount: float, - target_update_period: int, - dataset: tf.data.Dataset, - observation_network: types.TensorTransformation = lambda x: x, - target_observation_network: types.TensorTransformation = lambda x: x, - policy_optimizer: Optional[snt.Optimizer] = None, - critic_optimizer: Optional[snt.Optimizer] = None, - clipping: bool = True, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - checkpoint: bool = True, - ): - """Initializes the learner. 
+ def __init__( + self, + policy_network: snt.Module, + critic_network: snt.Module, + target_policy_network: snt.Module, + target_critic_network: snt.Module, + discount: float, + target_update_period: int, + dataset: tf.data.Dataset, + observation_network: types.TensorTransformation = lambda x: x, + target_observation_network: types.TensorTransformation = lambda x: x, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + clipping: bool = True, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = True, + ): + """Initializes the learner. Args: policy_network: the online (optimized) policy. @@ -80,178 +80,183 @@ def __init__( checkpoint: boolean indicating whether to checkpoint the learner. """ - # Store online and target networks. - self._policy_network = policy_network - self._critic_network = critic_network - self._target_policy_network = target_policy_network - self._target_critic_network = target_critic_network - - # Make sure observation networks are snt.Module's so they have variables. - self._observation_network = tf2_utils.to_sonnet_module(observation_network) - self._target_observation_network = tf2_utils.to_sonnet_module( - target_observation_network) - - # General learner book-keeping and loggers. - self._counter = counter or counting.Counter() - self._logger = logger or loggers.make_default_logger('learner') - - # Other learner parameters. - self._discount = discount - self._clipping = clipping - - # Necessary to track when to update target networks. - self._num_steps = tf.Variable(0, dtype=tf.int32) - self._target_update_period = target_update_period - - # Create an iterator to go through the dataset. - # TODO(b/155086959): Fix type stubs and remove. - self._iterator = iter(dataset) # pytype: disable=wrong-arg-types - - # Create optimizers if they aren't given. - self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) - self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) - - # Expose the variables. - policy_network_to_expose = snt.Sequential( - [self._target_observation_network, self._target_policy_network]) - self._variables = { - 'critic': target_critic_network.variables, - 'policy': policy_network_to_expose.variables, - } - - self._checkpointer = tf2_savers.Checkpointer( - time_delta_minutes=5, - objects_to_save={ - 'counter': self._counter, - 'policy': self._policy_network, - 'critic': self._critic_network, - 'target_policy': self._target_policy_network, - 'target_critic': self._target_critic_network, - 'policy_optimizer': self._policy_optimizer, - 'critic_optimizer': self._critic_optimizer, - 'num_steps': self._num_steps, - }, - enable_checkpointing=checkpoint, - ) - - # Do not record timestamps until after the first learning step is done. - # This is to avoid including the time it takes for actors to come online and - # fill the replay buffer. - self._timestamp = None - - @tf.function - def _step(self): - # Update target network. - online_variables = ( - *self._observation_network.variables, - *self._critic_network.variables, - *self._policy_network.variables, - ) - target_variables = ( - *self._target_observation_network.variables, - *self._target_critic_network.variables, - *self._target_policy_network.variables, - ) - - # Make online -> target network update ops. 
- if tf.math.mod(self._num_steps, self._target_update_period) == 0: - for src, dest in zip(online_variables, target_variables): - dest.assign(src) - self._num_steps.assign_add(1) - - # Get data from replay (dropping extras if any). Note there is no - # extra data here because we do not insert any into Reverb. - inputs = next(self._iterator) - transitions: types.Transition = inputs.data - - # Cast the additional discount to match the environment discount dtype. - discount = tf.cast(self._discount, dtype=transitions.discount.dtype) - - with tf.GradientTape(persistent=True) as tape: - # Maybe transform the observation before feeding into policy and critic. - # Transforming the observations this way at the start of the learning - # step effectively means that the policy and critic share observation - # network weights. - o_tm1 = self._observation_network(transitions.observation) - o_t = self._target_observation_network(transitions.next_observation) - # This stop_gradient prevents gradients to propagate into the target - # observation network. In addition, since the online policy network is - # evaluated at o_t, this also means the policy loss does not influence - # the observation network training. - o_t = tree.map_structure(tf.stop_gradient, o_t) - - # Critic learning. - q_tm1 = self._critic_network(o_tm1, transitions.action) - q_t = self._target_critic_network(o_t, self._target_policy_network(o_t)) - - # Squeeze into the shape expected by the td_learning implementation. - q_tm1 = tf.squeeze(q_tm1, axis=-1) # [B] - q_t = tf.squeeze(q_t, axis=-1) # [B] - - # Critic loss. - critic_loss = trfl.td_learning(q_tm1, transitions.reward, - discount * transitions.discount, q_t).loss - critic_loss = tf.reduce_mean(critic_loss, axis=0) - - # Actor learning. - dpg_a_t = self._policy_network(o_t) - dpg_q_t = self._critic_network(o_t, dpg_a_t) - - # Actor loss. If clipping is true use dqda clipping and clip the norm. - dqda_clipping = 1.0 if self._clipping else None - policy_loss = losses.dpg( - dpg_q_t, - dpg_a_t, - tape=tape, - dqda_clipping=dqda_clipping, - clip_norm=self._clipping) - policy_loss = tf.reduce_mean(policy_loss, axis=0) - - # Get trainable variables. - policy_variables = self._policy_network.trainable_variables - critic_variables = ( - # In this agent, the critic loss trains the observation network. - self._observation_network.trainable_variables + - self._critic_network.trainable_variables) - - # Compute gradients. - policy_gradients = tape.gradient(policy_loss, policy_variables) - critic_gradients = tape.gradient(critic_loss, critic_variables) - - # Delete the tape manually because of the persistent=True flag. - del tape - - # Maybe clip gradients. - if self._clipping: - policy_gradients = tf.clip_by_global_norm(policy_gradients, 40.)[0] - critic_gradients = tf.clip_by_global_norm(critic_gradients, 40.)[0] - - # Apply gradients. - self._policy_optimizer.apply(policy_gradients, policy_variables) - self._critic_optimizer.apply(critic_gradients, critic_variables) - - # Losses to track. - return { - 'critic_loss': critic_loss, - 'policy_loss': policy_loss, - } - - def step(self): - # Run the learning step. - fetches = self._step() - - # Compute elapsed time. - timestamp = time.time() - elapsed_time = timestamp - self._timestamp if self._timestamp else 0 - self._timestamp = timestamp - - # Update our counts and record it. - counts = self._counter.increment(steps=1, walltime=elapsed_time) - fetches.update(counts) - - # Checkpoint and attempt to write the logs. 
- self._checkpointer.save() - self._logger.write(fetches) - - def get_variables(self, names: List[str]) -> List[List[np.ndarray]]: - return [tf2_utils.to_numpy(self._variables[name]) for name in names] + # Store online and target networks. + self._policy_network = policy_network + self._critic_network = critic_network + self._target_policy_network = target_policy_network + self._target_critic_network = target_critic_network + + # Make sure observation networks are snt.Module's so they have variables. + self._observation_network = tf2_utils.to_sonnet_module(observation_network) + self._target_observation_network = tf2_utils.to_sonnet_module( + target_observation_network + ) + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger("learner") + + # Other learner parameters. + self._discount = discount + self._clipping = clipping + + # Necessary to track when to update target networks. + self._num_steps = tf.Variable(0, dtype=tf.int32) + self._target_update_period = target_update_period + + # Create an iterator to go through the dataset. + # TODO(b/155086959): Fix type stubs and remove. + self._iterator = iter(dataset) # pytype: disable=wrong-arg-types + + # Create optimizers if they aren't given. + self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) + self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) + + # Expose the variables. + policy_network_to_expose = snt.Sequential( + [self._target_observation_network, self._target_policy_network] + ) + self._variables = { + "critic": target_critic_network.variables, + "policy": policy_network_to_expose.variables, + } + + self._checkpointer = tf2_savers.Checkpointer( + time_delta_minutes=5, + objects_to_save={ + "counter": self._counter, + "policy": self._policy_network, + "critic": self._critic_network, + "target_policy": self._target_policy_network, + "target_critic": self._target_critic_network, + "policy_optimizer": self._policy_optimizer, + "critic_optimizer": self._critic_optimizer, + "num_steps": self._num_steps, + }, + enable_checkpointing=checkpoint, + ) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + @tf.function + def _step(self): + # Update target network. + online_variables = ( + *self._observation_network.variables, + *self._critic_network.variables, + *self._policy_network.variables, + ) + target_variables = ( + *self._target_observation_network.variables, + *self._target_critic_network.variables, + *self._target_policy_network.variables, + ) + + # Make online -> target network update ops. + if tf.math.mod(self._num_steps, self._target_update_period) == 0: + for src, dest in zip(online_variables, target_variables): + dest.assign(src) + self._num_steps.assign_add(1) + + # Get data from replay (dropping extras if any). Note there is no + # extra data here because we do not insert any into Reverb. + inputs = next(self._iterator) + transitions: types.Transition = inputs.data + + # Cast the additional discount to match the environment discount dtype. + discount = tf.cast(self._discount, dtype=transitions.discount.dtype) + + with tf.GradientTape(persistent=True) as tape: + # Maybe transform the observation before feeding into policy and critic. 
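# Note on the variables exposed above: a caller asking for "policy" receives
# the weights of the target observation network followed by the target policy
# network, i.e. the slower-moving target copies rather than the online ones.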
+ # Transforming the observations this way at the start of the learning + # step effectively means that the policy and critic share observation + # network weights. + o_tm1 = self._observation_network(transitions.observation) + o_t = self._target_observation_network(transitions.next_observation) + # This stop_gradient prevents gradients to propagate into the target + # observation network. In addition, since the online policy network is + # evaluated at o_t, this also means the policy loss does not influence + # the observation network training. + o_t = tree.map_structure(tf.stop_gradient, o_t) + + # Critic learning. + q_tm1 = self._critic_network(o_tm1, transitions.action) + q_t = self._target_critic_network(o_t, self._target_policy_network(o_t)) + + # Squeeze into the shape expected by the td_learning implementation. + q_tm1 = tf.squeeze(q_tm1, axis=-1) # [B] + q_t = tf.squeeze(q_t, axis=-1) # [B] + + # Critic loss. + critic_loss = trfl.td_learning( + q_tm1, transitions.reward, discount * transitions.discount, q_t + ).loss + critic_loss = tf.reduce_mean(critic_loss, axis=0) + + # Actor learning. + dpg_a_t = self._policy_network(o_t) + dpg_q_t = self._critic_network(o_t, dpg_a_t) + + # Actor loss. If clipping is true use dqda clipping and clip the norm. + dqda_clipping = 1.0 if self._clipping else None + policy_loss = losses.dpg( + dpg_q_t, + dpg_a_t, + tape=tape, + dqda_clipping=dqda_clipping, + clip_norm=self._clipping, + ) + policy_loss = tf.reduce_mean(policy_loss, axis=0) + + # Get trainable variables. + policy_variables = self._policy_network.trainable_variables + critic_variables = ( + # In this agent, the critic loss trains the observation network. + self._observation_network.trainable_variables + + self._critic_network.trainable_variables + ) + + # Compute gradients. + policy_gradients = tape.gradient(policy_loss, policy_variables) + critic_gradients = tape.gradient(critic_loss, critic_variables) + + # Delete the tape manually because of the persistent=True flag. + del tape + + # Maybe clip gradients. + if self._clipping: + policy_gradients = tf.clip_by_global_norm(policy_gradients, 40.0)[0] + critic_gradients = tf.clip_by_global_norm(critic_gradients, 40.0)[0] + + # Apply gradients. + self._policy_optimizer.apply(policy_gradients, policy_variables) + self._critic_optimizer.apply(critic_gradients, critic_variables) + + # Losses to track. + return { + "critic_loss": critic_loss, + "policy_loss": policy_loss, + } + + def step(self): + # Run the learning step. + fetches = self._step() + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Update our counts and record it. + counts = self._counter.increment(steps=1, walltime=elapsed_time) + fetches.update(counts) + + # Checkpoint and attempt to write the logs. 
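# For reference, a sketch of the per-element quantity minimised by the
# `trfl.td_learning` call above (trfl applies the stop-gradient to the target
# internally):
#
#     target   = r_t + (discount * d_t) * Q_target(o_t, pi_target(o_t))
#     td_error = target - Q(o_tm1, a_tm1)
#     loss     = 0.5 * td_error ** 2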
+ self._checkpointer.save() + self._logger.write(fetches) + + def get_variables(self, names: List[str]) -> List[List[np.ndarray]]: + return [tf2_utils.to_numpy(self._variables[name]) for name in names] diff --git a/acme/agents/tf/dmpo/agent.py b/acme/agents/tf/dmpo/agent.py index 8ca0c621eb..c81fd2a351 100644 --- a/acme/agents/tf/dmpo/agent.py +++ b/acme/agents/tf/dmpo/agent.py @@ -17,24 +17,22 @@ import copy from typing import Optional -from acme import datasets -from acme import specs -from acme import types +import reverb +import sonnet as snt +import tensorflow as tf + +from acme import datasets, specs, types from acme.adders import reverb as adders from acme.agents import agent from acme.agents.tf import actors from acme.agents.tf.dmpo import learning from acme.tf import networks from acme.tf import utils as tf2_utils -from acme.utils import counting -from acme.utils import loggers -import reverb -import sonnet as snt -import tensorflow as tf +from acme.utils import counting, loggers class DistributionalMPO(agent.Agent): - """Distributional MPO Agent. + """Distributional MPO Agent. This implements a single-process distributional MPO agent. This is an actor-critic algorithm that generates data via a behavior policy, inserts @@ -44,30 +42,32 @@ class DistributionalMPO(agent.Agent): critic (state-action value approximator). """ - def __init__(self, - environment_spec: specs.EnvironmentSpec, - policy_network: snt.Module, - critic_network: snt.Module, - observation_network: types.TensorTransformation = tf.identity, - discount: float = 0.99, - batch_size: int = 256, - prefetch_size: int = 4, - target_policy_update_period: int = 100, - target_critic_update_period: int = 100, - min_replay_size: int = 1000, - max_replay_size: int = 1000000, - samples_per_insert: float = 32.0, - policy_loss_module: Optional[snt.Module] = None, - policy_optimizer: Optional[snt.Optimizer] = None, - critic_optimizer: Optional[snt.Optimizer] = None, - n_step: int = 5, - num_samples: int = 20, - clipping: bool = True, - logger: Optional[loggers.Logger] = None, - counter: Optional[counting.Counter] = None, - checkpoint: bool = True, - replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE): - """Initialize the agent. + def __init__( + self, + environment_spec: specs.EnvironmentSpec, + policy_network: snt.Module, + critic_network: snt.Module, + observation_network: types.TensorTransformation = tf.identity, + discount: float = 0.99, + batch_size: int = 256, + prefetch_size: int = 4, + target_policy_update_period: int = 100, + target_critic_update_period: int = 100, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: float = 32.0, + policy_loss_module: Optional[snt.Module] = None, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + n_step: int = 5, + num_samples: int = 20, + clipping: bool = True, + logger: Optional[loggers.Logger] = None, + counter: Optional[counting.Counter] = None, + checkpoint: bool = True, + replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE, + ): + """Initialize the agent. Args: environment_spec: description of the actions, observations, etc. @@ -101,88 +101,88 @@ def __init__(self, replay_table_name: string indicating what name to give the replay table. """ - # Create a replay server to add data to. 
- replay_table = reverb.Table( - name=adders.DEFAULT_PRIORITY_TABLE, - sampler=reverb.selectors.Uniform(), - remover=reverb.selectors.Fifo(), - max_size=max_replay_size, - rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1), - signature=adders.NStepTransitionAdder.signature(environment_spec)) - self._server = reverb.Server([replay_table], port=None) - - # The adder is used to insert observations into replay. - address = f'localhost:{self._server.port}' - adder = adders.NStepTransitionAdder( - client=reverb.Client(address), - n_step=n_step, - discount=discount) - - # The dataset object to learn from. - dataset = datasets.make_reverb_dataset( - table=replay_table_name, - server_address=address, - batch_size=batch_size, - prefetch_size=prefetch_size) - - # Make sure observation network is a Sonnet Module. - observation_network = tf2_utils.to_sonnet_module(observation_network) - - # Create target networks before creating online/target network variables. - target_policy_network = copy.deepcopy(policy_network) - target_critic_network = copy.deepcopy(critic_network) - target_observation_network = copy.deepcopy(observation_network) - - # Get observation and action specs. - act_spec = environment_spec.actions - obs_spec = environment_spec.observations - emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) - - # Create the behavior policy. - behavior_network = snt.Sequential([ - observation_network, - policy_network, - networks.StochasticSamplingHead(), - ]) - - # Create variables. - tf2_utils.create_variables(policy_network, [emb_spec]) - tf2_utils.create_variables(critic_network, [emb_spec, act_spec]) - tf2_utils.create_variables(target_policy_network, [emb_spec]) - tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec]) - tf2_utils.create_variables(target_observation_network, [obs_spec]) - - # Create the actor which defines how we take actions. - actor = actors.FeedForwardActor( - policy_network=behavior_network, adder=adder) - - # Create optimizers. - policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) - critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) - - # The learner updates the parameters (and initializes them). - learner = learning.DistributionalMPOLearner( - policy_network=policy_network, - critic_network=critic_network, - observation_network=observation_network, - target_policy_network=target_policy_network, - target_critic_network=target_critic_network, - target_observation_network=target_observation_network, - policy_loss_module=policy_loss_module, - policy_optimizer=policy_optimizer, - critic_optimizer=critic_optimizer, - clipping=clipping, - discount=discount, - num_samples=num_samples, - target_policy_update_period=target_policy_update_period, - target_critic_update_period=target_critic_update_period, - dataset=dataset, - logger=logger, - counter=counter, - checkpoint=checkpoint) - - super().__init__( - actor=actor, - learner=learner, - min_observations=max(batch_size, min_replay_size), - observations_per_step=float(batch_size) / samples_per_insert) + # Create a replay server to add data to. 
+ replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=max_replay_size, + rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1), + signature=adders.NStepTransitionAdder.signature(environment_spec), + ) + self._server = reverb.Server([replay_table], port=None) + + # The adder is used to insert observations into replay. + address = f"localhost:{self._server.port}" + adder = adders.NStepTransitionAdder( + client=reverb.Client(address), n_step=n_step, discount=discount + ) + + # The dataset object to learn from. + dataset = datasets.make_reverb_dataset( + table=replay_table_name, + server_address=address, + batch_size=batch_size, + prefetch_size=prefetch_size, + ) + + # Make sure observation network is a Sonnet Module. + observation_network = tf2_utils.to_sonnet_module(observation_network) + + # Create target networks before creating online/target network variables. + target_policy_network = copy.deepcopy(policy_network) + target_critic_network = copy.deepcopy(critic_network) + target_observation_network = copy.deepcopy(observation_network) + + # Get observation and action specs. + act_spec = environment_spec.actions + obs_spec = environment_spec.observations + emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) + + # Create the behavior policy. + behavior_network = snt.Sequential( + [observation_network, policy_network, networks.StochasticSamplingHead(),] + ) + + # Create variables. + tf2_utils.create_variables(policy_network, [emb_spec]) + tf2_utils.create_variables(critic_network, [emb_spec, act_spec]) + tf2_utils.create_variables(target_policy_network, [emb_spec]) + tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec]) + tf2_utils.create_variables(target_observation_network, [obs_spec]) + + # Create the actor which defines how we take actions. + actor = actors.FeedForwardActor(policy_network=behavior_network, adder=adder) + + # Create optimizers. + policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) + critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) + + # The learner updates the parameters (and initializes them). 
+ learner = learning.DistributionalMPOLearner( + policy_network=policy_network, + critic_network=critic_network, + observation_network=observation_network, + target_policy_network=target_policy_network, + target_critic_network=target_critic_network, + target_observation_network=target_observation_network, + policy_loss_module=policy_loss_module, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + clipping=clipping, + discount=discount, + num_samples=num_samples, + target_policy_update_period=target_policy_update_period, + target_critic_update_period=target_critic_update_period, + dataset=dataset, + logger=logger, + counter=counter, + checkpoint=checkpoint, + ) + + super().__init__( + actor=actor, + learner=learner, + min_observations=max(batch_size, min_replay_size), + observations_per_step=float(batch_size) / samples_per_insert, + ) diff --git a/acme/agents/tf/dmpo/agent_distributed.py b/acme/agents/tf/dmpo/agent_distributed.py index 4fff2a17f1..51f6ae1f56 100644 --- a/acme/agents/tf/dmpo/agent_distributed.py +++ b/acme/agents/tf/dmpo/agent_distributed.py @@ -16,10 +16,14 @@ from typing import Callable, Dict, Optional, Sequence +import dm_env +import launchpad as lp +import reverb +import sonnet as snt +import tensorflow as tf + import acme -from acme import datasets -from acme import specs -from acme import types +from acme import datasets, specs, types from acme.adders import reverb as adders from acme.agents.tf import actors from acme.agents.tf.dmpo import learning @@ -28,331 +32,335 @@ from acme.tf import savers as tf2_savers from acme.tf import utils as tf2_utils from acme.tf import variable_utils as tf2_variable_utils -from acme.utils import counting -from acme.utils import loggers -from acme.utils import lp_utils +from acme.utils import counting, loggers, lp_utils from acme.utils import observers as observers_lib -import dm_env -import launchpad as lp -import reverb -import sonnet as snt -import tensorflow as tf class DistributedDistributionalMPO: - """Program definition for distributional MPO.""" - - def __init__( - self, - environment_factory: Callable[[bool], dm_env.Environment], - network_factory: Callable[[specs.BoundedArray], Dict[str, snt.Module]], - num_actors: int = 1, - num_caches: int = 0, - environment_spec: Optional[specs.EnvironmentSpec] = None, - batch_size: int = 256, - prefetch_size: int = 4, - observation_augmentation: Optional[types.TensorTransformation] = None, - min_replay_size: int = 1000, - max_replay_size: int = 1000000, - samples_per_insert: Optional[float] = 32.0, - n_step: int = 5, - num_samples: int = 20, - additional_discount: float = 0.99, - target_policy_update_period: int = 100, - target_critic_update_period: int = 100, - variable_update_period: int = 1000, - policy_loss_factory: Optional[Callable[[], snt.Module]] = None, - max_actor_steps: Optional[int] = None, - log_every: float = 10.0, - make_observers: Optional[Callable[ - [], Sequence[observers_lib.EnvLoopObserver]]] = None): - - if environment_spec is None: - environment_spec = specs.make_environment_spec(environment_factory(False)) - - self._environment_factory = environment_factory - self._network_factory = network_factory - self._policy_loss_factory = policy_loss_factory - self._environment_spec = environment_spec - self._num_actors = num_actors - self._num_caches = num_caches - self._batch_size = batch_size - self._prefetch_size = prefetch_size - self._observation_augmentation = observation_augmentation - self._min_replay_size = min_replay_size - 
self._max_replay_size = max_replay_size - self._samples_per_insert = samples_per_insert - self._n_step = n_step - self._additional_discount = additional_discount - self._num_samples = num_samples - self._target_policy_update_period = target_policy_update_period - self._target_critic_update_period = target_critic_update_period - self._variable_update_period = variable_update_period - self._max_actor_steps = max_actor_steps - self._log_every = log_every - self._make_observers = make_observers - - def replay(self): - """The replay storage.""" - if self._samples_per_insert is not None: - # Create enough of an error buffer to give a 10% tolerance in rate. - samples_per_insert_tolerance = 0.1 * self._samples_per_insert - error_buffer = self._min_replay_size * samples_per_insert_tolerance - - limiter = reverb.rate_limiters.SampleToInsertRatio( - min_size_to_sample=self._min_replay_size, - samples_per_insert=self._samples_per_insert, - error_buffer=error_buffer) - else: - limiter = reverb.rate_limiters.MinSize(self._min_replay_size) - replay_table = reverb.Table( - name=adders.DEFAULT_PRIORITY_TABLE, - sampler=reverb.selectors.Uniform(), - remover=reverb.selectors.Fifo(), - max_size=self._max_replay_size, - rate_limiter=limiter, - signature=adders.NStepTransitionAdder.signature( - self._environment_spec)) - return [replay_table] - - def counter(self): - return tf2_savers.CheckpointingRunner(counting.Counter(), - time_delta_minutes=1, - subdirectory='counter') - - def coordinator(self, counter: counting.Counter, max_actor_steps: int): - return lp_utils.StepsLimiter(counter, max_actor_steps) - - def learner( - self, - replay: reverb.Client, - counter: counting.Counter, - ): - """The Learning part of the agent.""" - - act_spec = self._environment_spec.actions - obs_spec = self._environment_spec.observations - - # Create online and target networks. - online_networks = self._network_factory(act_spec) - target_networks = self._network_factory(act_spec) - - # Make sure observation network is a Sonnet Module. - observation_network = online_networks.get('observation', tf.identity) - target_observation_network = target_networks.get('observation', tf.identity) - observation_network = tf2_utils.to_sonnet_module(observation_network) - target_observation_network = tf2_utils.to_sonnet_module( - target_observation_network) - - # Get embedding spec and create observation network variables. - emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) - - # Create variables. - tf2_utils.create_variables(online_networks['policy'], [emb_spec]) - tf2_utils.create_variables(online_networks['critic'], [emb_spec, act_spec]) - tf2_utils.create_variables(target_networks['policy'], [emb_spec]) - tf2_utils.create_variables(target_networks['critic'], [emb_spec, act_spec]) - tf2_utils.create_variables(target_observation_network, [obs_spec]) - - # The dataset object to learn from. - dataset = datasets.make_reverb_dataset(server_address=replay.server_address) - dataset = dataset.batch(self._batch_size, drop_remainder=True) - if self._observation_augmentation: - transform = image_augmentation.make_transform( - observation_transform=self._observation_augmentation) - dataset = dataset.map( - transform, num_parallel_calls=16, deterministic=False) - dataset = dataset.prefetch(self._prefetch_size) - - counter = counting.Counter(counter, 'learner') - logger = loggers.make_default_logger( - 'learner', time_delta=self._log_every, steps_key='learner_steps') - - # Create policy loss module if a factory is passed. 
- if self._policy_loss_factory: - policy_loss_module = self._policy_loss_factory() - else: - policy_loss_module = None - - # Return the learning agent. - return learning.DistributionalMPOLearner( - policy_network=online_networks['policy'], - critic_network=online_networks['critic'], - observation_network=observation_network, - target_policy_network=target_networks['policy'], - target_critic_network=target_networks['critic'], - target_observation_network=target_observation_network, - discount=self._additional_discount, - num_samples=self._num_samples, - target_policy_update_period=self._target_policy_update_period, - target_critic_update_period=self._target_critic_update_period, - policy_loss_module=policy_loss_module, - dataset=dataset, - counter=counter, - logger=logger) - - def actor( - self, - replay: reverb.Client, - variable_source: acme.VariableSource, - counter: counting.Counter, - actor_id: int, - ) -> acme.EnvironmentLoop: - """The actor process.""" - - action_spec = self._environment_spec.actions - observation_spec = self._environment_spec.observations - - # Create environment and target networks to act with. - environment = self._environment_factory(False) - agent_networks = self._network_factory(action_spec) - - # Make sure observation network is defined. - observation_network = agent_networks.get('observation', tf.identity) - - # Create a stochastic behavior policy. - behavior_network = snt.Sequential([ - observation_network, - agent_networks['policy'], - networks.StochasticSamplingHead(), - ]) - - # Ensure network variables are created. - tf2_utils.create_variables(behavior_network, [observation_spec]) - policy_variables = {'policy': behavior_network.variables} - - # Create the variable client responsible for keeping the actor up-to-date. - variable_client = tf2_variable_utils.VariableClient( - variable_source, - policy_variables, - update_period=self._variable_update_period) - - # Make sure not to use a random policy after checkpoint restoration by - # assigning variables before running the environment loop. - variable_client.update_and_wait() - - # Component to add things into replay. - adder = adders.NStepTransitionAdder( - client=replay, - n_step=self._n_step, - discount=self._additional_discount) - - # Create the agent. - actor = actors.FeedForwardActor( - policy_network=behavior_network, - adder=adder, - variable_client=variable_client) - - # Create logger and counter; only the first actor stores logs to bigtable. - save_data = actor_id == 0 - counter = counting.Counter(counter, 'actor') - logger = loggers.make_default_logger( - 'actor', - save_data=save_data, - time_delta=self._log_every, - steps_key='actor_steps') - observers = self._make_observers() if self._make_observers else () - - # Create the run loop and return it. - return acme.EnvironmentLoop( - environment, actor, counter, logger, observers=observers) - - def evaluator( - self, - variable_source: acme.VariableSource, - counter: counting.Counter, - ): - """The evaluation process.""" - - action_spec = self._environment_spec.actions - observation_spec = self._environment_spec.observations - - # Create environment and target networks to act with. - environment = self._environment_factory(True) - agent_networks = self._network_factory(action_spec) - - # Make sure observation network is defined. - observation_network = agent_networks.get('observation', tf.identity) - - # Create a stochastic behavior policy. 
- evaluator_network = snt.Sequential([ - observation_network, - agent_networks['policy'], - networks.StochasticMeanHead(), - ]) - - # Ensure network variables are created. - tf2_utils.create_variables(evaluator_network, [observation_spec]) - policy_variables = {'policy': evaluator_network.variables} - - # Create the variable client responsible for keeping the actor up-to-date. - variable_client = tf2_variable_utils.VariableClient( - variable_source, - policy_variables, - update_period=self._variable_update_period) - - # Make sure not to evaluate a random actor by assigning variables before - # running the environment loop. - variable_client.update_and_wait() - - # Create the agent. - evaluator = actors.FeedForwardActor( - policy_network=evaluator_network, variable_client=variable_client) - - # Create logger and counter. - counter = counting.Counter(counter, 'evaluator') - logger = loggers.make_default_logger( - 'evaluator', time_delta=self._log_every, steps_key='evaluator_steps') - observers = self._make_observers() if self._make_observers else () - - # Create the run loop and return it. - return acme.EnvironmentLoop( - environment, - evaluator, - counter, - logger, - observers=observers) - - def build(self, name='dmpo'): - """Build the distributed agent topology.""" - program = lp.Program(name=name) - - with program.group('replay'): - replay = program.add_node(lp.ReverbNode(self.replay)) - - with program.group('counter'): - counter = program.add_node(lp.CourierNode(self.counter)) - - if self._max_actor_steps: - _ = program.add_node( - lp.CourierNode(self.coordinator, counter, self._max_actor_steps)) - - with program.group('learner'): - learner = program.add_node( - lp.CourierNode(self.learner, replay, counter)) - - with program.group('evaluator'): - program.add_node( - lp.CourierNode(self.evaluator, learner, counter)) - - if not self._num_caches: - # Use our learner as a single variable source. - sources = [learner] - else: - with program.group('cacher'): - # Create a set of learner caches. - sources = [] - for _ in range(self._num_caches): - cacher = program.add_node( - lp.CacherNode( - learner, refresh_interval_ms=2000, stale_after_ms=4000)) - sources.append(cacher) - - with program.group('actor'): - # Add actors which pull round-robin from our variable sources. 
- for actor_id in range(self._num_actors): - source = sources[actor_id % len(sources)] - program.add_node( - lp.CourierNode(self.actor, replay, source, counter, actor_id)) - - return program + """Program definition for distributional MPO.""" + + def __init__( + self, + environment_factory: Callable[[bool], dm_env.Environment], + network_factory: Callable[[specs.BoundedArray], Dict[str, snt.Module]], + num_actors: int = 1, + num_caches: int = 0, + environment_spec: Optional[specs.EnvironmentSpec] = None, + batch_size: int = 256, + prefetch_size: int = 4, + observation_augmentation: Optional[types.TensorTransformation] = None, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: Optional[float] = 32.0, + n_step: int = 5, + num_samples: int = 20, + additional_discount: float = 0.99, + target_policy_update_period: int = 100, + target_critic_update_period: int = 100, + variable_update_period: int = 1000, + policy_loss_factory: Optional[Callable[[], snt.Module]] = None, + max_actor_steps: Optional[int] = None, + log_every: float = 10.0, + make_observers: Optional[ + Callable[[], Sequence[observers_lib.EnvLoopObserver]] + ] = None, + ): + + if environment_spec is None: + environment_spec = specs.make_environment_spec(environment_factory(False)) + + self._environment_factory = environment_factory + self._network_factory = network_factory + self._policy_loss_factory = policy_loss_factory + self._environment_spec = environment_spec + self._num_actors = num_actors + self._num_caches = num_caches + self._batch_size = batch_size + self._prefetch_size = prefetch_size + self._observation_augmentation = observation_augmentation + self._min_replay_size = min_replay_size + self._max_replay_size = max_replay_size + self._samples_per_insert = samples_per_insert + self._n_step = n_step + self._additional_discount = additional_discount + self._num_samples = num_samples + self._target_policy_update_period = target_policy_update_period + self._target_critic_update_period = target_critic_update_period + self._variable_update_period = variable_update_period + self._max_actor_steps = max_actor_steps + self._log_every = log_every + self._make_observers = make_observers + + def replay(self): + """The replay storage.""" + if self._samples_per_insert is not None: + # Create enough of an error buffer to give a 10% tolerance in rate. 
+ samples_per_insert_tolerance = 0.1 * self._samples_per_insert + error_buffer = self._min_replay_size * samples_per_insert_tolerance + + limiter = reverb.rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._min_replay_size, + samples_per_insert=self._samples_per_insert, + error_buffer=error_buffer, + ) + else: + limiter = reverb.rate_limiters.MinSize(self._min_replay_size) + replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._max_replay_size, + rate_limiter=limiter, + signature=adders.NStepTransitionAdder.signature(self._environment_spec), + ) + return [replay_table] + + def counter(self): + return tf2_savers.CheckpointingRunner( + counting.Counter(), time_delta_minutes=1, subdirectory="counter" + ) + + def coordinator(self, counter: counting.Counter, max_actor_steps: int): + return lp_utils.StepsLimiter(counter, max_actor_steps) + + def learner( + self, replay: reverb.Client, counter: counting.Counter, + ): + """The Learning part of the agent.""" + + act_spec = self._environment_spec.actions + obs_spec = self._environment_spec.observations + + # Create online and target networks. + online_networks = self._network_factory(act_spec) + target_networks = self._network_factory(act_spec) + + # Make sure observation network is a Sonnet Module. + observation_network = online_networks.get("observation", tf.identity) + target_observation_network = target_networks.get("observation", tf.identity) + observation_network = tf2_utils.to_sonnet_module(observation_network) + target_observation_network = tf2_utils.to_sonnet_module( + target_observation_network + ) + + # Get embedding spec and create observation network variables. + emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) + + # Create variables. + tf2_utils.create_variables(online_networks["policy"], [emb_spec]) + tf2_utils.create_variables(online_networks["critic"], [emb_spec, act_spec]) + tf2_utils.create_variables(target_networks["policy"], [emb_spec]) + tf2_utils.create_variables(target_networks["critic"], [emb_spec, act_spec]) + tf2_utils.create_variables(target_observation_network, [obs_spec]) + + # The dataset object to learn from. + dataset = datasets.make_reverb_dataset(server_address=replay.server_address) + dataset = dataset.batch(self._batch_size, drop_remainder=True) + if self._observation_augmentation: + transform = image_augmentation.make_transform( + observation_transform=self._observation_augmentation + ) + dataset = dataset.map(transform, num_parallel_calls=16, deterministic=False) + dataset = dataset.prefetch(self._prefetch_size) + + counter = counting.Counter(counter, "learner") + logger = loggers.make_default_logger( + "learner", time_delta=self._log_every, steps_key="learner_steps" + ) + + # Create policy loss module if a factory is passed. + if self._policy_loss_factory: + policy_loss_module = self._policy_loss_factory() + else: + policy_loss_module = None + + # Return the learning agent. 
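(The learner construction continues directly below.) As a numeric aside on the `SampleToInsertRatio` limiter configured at the top of `replay()` above: with the constructor defaults shown in this file (`min_replay_size=1000`, `samples_per_insert=32.0`), the 10% tolerance works out to an error buffer of 3200 items. A minimal sketch assuming those defaults:

```python
import reverb

# Constructor defaults from this file, assumed here purely for illustration.
min_replay_size = 1000
samples_per_insert = 32.0

# 10% tolerance on the sample/insert rate, as in replay() above.
samples_per_insert_tolerance = 0.1 * samples_per_insert        # 3.2
error_buffer = min_replay_size * samples_per_insert_tolerance  # 3200.0

limiter = reverb.rate_limiters.SampleToInsertRatio(
    min_size_to_sample=min_replay_size,
    samples_per_insert=samples_per_insert,
    error_buffer=error_buffer,
)
```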
+ return learning.DistributionalMPOLearner( + policy_network=online_networks["policy"], + critic_network=online_networks["critic"], + observation_network=observation_network, + target_policy_network=target_networks["policy"], + target_critic_network=target_networks["critic"], + target_observation_network=target_observation_network, + discount=self._additional_discount, + num_samples=self._num_samples, + target_policy_update_period=self._target_policy_update_period, + target_critic_update_period=self._target_critic_update_period, + policy_loss_module=policy_loss_module, + dataset=dataset, + counter=counter, + logger=logger, + ) + + def actor( + self, + replay: reverb.Client, + variable_source: acme.VariableSource, + counter: counting.Counter, + actor_id: int, + ) -> acme.EnvironmentLoop: + """The actor process.""" + + action_spec = self._environment_spec.actions + observation_spec = self._environment_spec.observations + + # Create environment and target networks to act with. + environment = self._environment_factory(False) + agent_networks = self._network_factory(action_spec) + + # Make sure observation network is defined. + observation_network = agent_networks.get("observation", tf.identity) + + # Create a stochastic behavior policy. + behavior_network = snt.Sequential( + [ + observation_network, + agent_networks["policy"], + networks.StochasticSamplingHead(), + ] + ) + + # Ensure network variables are created. + tf2_utils.create_variables(behavior_network, [observation_spec]) + policy_variables = {"policy": behavior_network.variables} + + # Create the variable client responsible for keeping the actor up-to-date. + variable_client = tf2_variable_utils.VariableClient( + variable_source, + policy_variables, + update_period=self._variable_update_period, + ) + + # Make sure not to use a random policy after checkpoint restoration by + # assigning variables before running the environment loop. + variable_client.update_and_wait() + + # Component to add things into replay. + adder = adders.NStepTransitionAdder( + client=replay, n_step=self._n_step, discount=self._additional_discount + ) + + # Create the agent. + actor = actors.FeedForwardActor( + policy_network=behavior_network, + adder=adder, + variable_client=variable_client, + ) + + # Create logger and counter; only the first actor stores logs to bigtable. + save_data = actor_id == 0 + counter = counting.Counter(counter, "actor") + logger = loggers.make_default_logger( + "actor", + save_data=save_data, + time_delta=self._log_every, + steps_key="actor_steps", + ) + observers = self._make_observers() if self._make_observers else () + + # Create the run loop and return it. + return acme.EnvironmentLoop( + environment, actor, counter, logger, observers=observers + ) + + def evaluator( + self, variable_source: acme.VariableSource, counter: counting.Counter, + ): + """The evaluation process.""" + + action_spec = self._environment_spec.actions + observation_spec = self._environment_spec.observations + + # Create environment and target networks to act with. + environment = self._environment_factory(True) + agent_networks = self._network_factory(action_spec) + + # Make sure observation network is defined. + observation_network = agent_networks.get("observation", tf.identity) + + # Create a stochastic behavior policy. + evaluator_network = snt.Sequential( + [ + observation_network, + agent_networks["policy"], + networks.StochasticMeanHead(), + ] + ) + + # Ensure network variables are created. 
+ tf2_utils.create_variables(evaluator_network, [observation_spec]) + policy_variables = {"policy": evaluator_network.variables} + + # Create the variable client responsible for keeping the actor up-to-date. + variable_client = tf2_variable_utils.VariableClient( + variable_source, + policy_variables, + update_period=self._variable_update_period, + ) + + # Make sure not to evaluate a random actor by assigning variables before + # running the environment loop. + variable_client.update_and_wait() + + # Create the agent. + evaluator = actors.FeedForwardActor( + policy_network=evaluator_network, variable_client=variable_client + ) + + # Create logger and counter. + counter = counting.Counter(counter, "evaluator") + logger = loggers.make_default_logger( + "evaluator", time_delta=self._log_every, steps_key="evaluator_steps" + ) + observers = self._make_observers() if self._make_observers else () + + # Create the run loop and return it. + return acme.EnvironmentLoop( + environment, evaluator, counter, logger, observers=observers + ) + + def build(self, name="dmpo"): + """Build the distributed agent topology.""" + program = lp.Program(name=name) + + with program.group("replay"): + replay = program.add_node(lp.ReverbNode(self.replay)) + + with program.group("counter"): + counter = program.add_node(lp.CourierNode(self.counter)) + + if self._max_actor_steps: + _ = program.add_node( + lp.CourierNode(self.coordinator, counter, self._max_actor_steps) + ) + + with program.group("learner"): + learner = program.add_node(lp.CourierNode(self.learner, replay, counter)) + + with program.group("evaluator"): + program.add_node(lp.CourierNode(self.evaluator, learner, counter)) + + if not self._num_caches: + # Use our learner as a single variable source. + sources = [learner] + else: + with program.group("cacher"): + # Create a set of learner caches. + sources = [] + for _ in range(self._num_caches): + cacher = program.add_node( + lp.CacherNode( + learner, refresh_interval_ms=2000, stale_after_ms=4000 + ) + ) + sources.append(cacher) + + with program.group("actor"): + # Add actors which pull round-robin from our variable sources. 
+ for actor_id in range(self._num_actors): + source = sources[actor_id % len(sources)] + program.add_node( + lp.CourierNode(self.actor, replay, source, counter, actor_id) + ) + + return program diff --git a/acme/agents/tf/dmpo/agent_distributed_test.py b/acme/agents/tf/dmpo/agent_distributed_test.py index 80085a42df..33429e7a52 100644 --- a/acme/agents/tf/dmpo/agent_distributed_test.py +++ b/acme/agents/tf/dmpo/agent_distributed_test.py @@ -16,82 +16,85 @@ from typing import Sequence +import launchpad as lp +import numpy as np +import sonnet as snt +from absl.testing import absltest + import acme from acme import specs from acme.agents.tf import dmpo from acme.testing import fakes from acme.tf import networks from acme.tf import utils as tf2_utils -import launchpad as lp -import numpy as np -import sonnet as snt - -from absl.testing import absltest def make_networks( action_spec: specs.BoundedArray, policy_layer_sizes: Sequence[int] = (50,), critic_layer_sizes: Sequence[int] = (50,), - vmin: float = -150., - vmax: float = 150., + vmin: float = -150.0, + vmax: float = 150.0, num_atoms: int = 51, ): - """Creates networks used by the agent.""" - - num_dimensions = np.prod(action_spec.shape, dtype=int) - - policy_network = snt.Sequential([ - networks.LayerNormMLP(policy_layer_sizes, activate_final=True), - networks.MultivariateNormalDiagHead( - num_dimensions, - tanh_mean=True, - init_scale=0.3, - fixed_scale=True, - use_tfd_independent=False) - ]) - - # The multiplexer concatenates the (maybe transformed) observations/actions. - critic_network = networks.CriticMultiplexer( - critic_network=networks.LayerNormMLP( - critic_layer_sizes, activate_final=True), - action_network=networks.ClipToSpec(action_spec)) - critic_network = snt.Sequential( - [critic_network, - networks.DiscreteValuedHead(vmin, vmax, num_atoms)]) - - return { - 'policy': policy_network, - 'critic': critic_network, - 'observation': tf2_utils.batch_concat, - } + """Creates networks used by the agent.""" + + num_dimensions = np.prod(action_spec.shape, dtype=int) + + policy_network = snt.Sequential( + [ + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, + tanh_mean=True, + init_scale=0.3, + fixed_scale=True, + use_tfd_independent=False, + ), + ] + ) + + # The multiplexer concatenates the (maybe transformed) observations/actions. 
+ critic_network = networks.CriticMultiplexer( + critic_network=networks.LayerNormMLP(critic_layer_sizes, activate_final=True), + action_network=networks.ClipToSpec(action_spec), + ) + critic_network = snt.Sequential( + [critic_network, networks.DiscreteValuedHead(vmin, vmax, num_atoms)] + ) + + return { + "policy": policy_network, + "critic": critic_network, + "observation": tf2_utils.batch_concat, + } class DistributedAgentTest(absltest.TestCase): - """Simple integration/smoke test for the distributed agent.""" + """Simple integration/smoke test for the distributed agent.""" - def test_agent(self): + def test_agent(self): - agent = dmpo.DistributedDistributionalMPO( - environment_factory=lambda x: fakes.ContinuousEnvironment(bounded=True), - network_factory=make_networks, - num_actors=2, - batch_size=32, - min_replay_size=32, - max_replay_size=1000, - ) - program = agent.build() + agent = dmpo.DistributedDistributionalMPO( + environment_factory=lambda x: fakes.ContinuousEnvironment(bounded=True), + network_factory=make_networks, + num_actors=2, + batch_size=32, + min_replay_size=32, + max_replay_size=1000, + ) + program = agent.build() - (learner_node,) = program.groups['learner'] - learner_node.disable_run() + (learner_node,) = program.groups["learner"] + learner_node.disable_run() - lp.launch(program, launch_type='test_mt') + lp.launch(program, launch_type="test_mt") - learner: acme.Learner = learner_node.create_handle().dereference() + learner: acme.Learner = learner_node.create_handle().dereference() - for _ in range(5): - learner.step() + for _ in range(5): + learner.step() -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/tf/dmpo/agent_test.py b/acme/agents/tf/dmpo/agent_test.py index 366a35cfad..2f144982a7 100644 --- a/acme/agents/tf/dmpo/agent_test.py +++ b/acme/agents/tf/dmpo/agent_test.py @@ -16,15 +16,15 @@ from typing import Dict, Sequence +import numpy as np +import sonnet as snt +from absl.testing import absltest + import acme from acme import specs from acme.agents.tf import dmpo from acme.testing import fakes from acme.tf import networks -import numpy as np -import sonnet as snt - -from absl.testing import absltest def make_networks( @@ -32,52 +32,57 @@ def make_networks( policy_layer_sizes: Sequence[int] = (300, 200), critic_layer_sizes: Sequence[int] = (400, 300), ) -> Dict[str, snt.Module]: - """Creates networks used by the agent.""" - - num_dimensions = np.prod(action_spec.shape, dtype=int) - critic_layer_sizes = list(critic_layer_sizes) - - policy_network = snt.Sequential([ - networks.LayerNormMLP(policy_layer_sizes), - networks.MultivariateNormalDiagHead(num_dimensions), - ]) - # The multiplexer concatenates the (maybe transformed) observations/actions. - critic_network = snt.Sequential([ - networks.CriticMultiplexer( - critic_network=networks.LayerNormMLP(critic_layer_sizes)), - networks.DiscreteValuedHead(0., 1., 10), - ]) - - return { - 'policy': policy_network, - 'critic': critic_network, - } + """Creates networks used by the agent.""" + + num_dimensions = np.prod(action_spec.shape, dtype=int) + critic_layer_sizes = list(critic_layer_sizes) + + policy_network = snt.Sequential( + [ + networks.LayerNormMLP(policy_layer_sizes), + networks.MultivariateNormalDiagHead(num_dimensions), + ] + ) + # The multiplexer concatenates the (maybe transformed) observations/actions. 
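(The critic construction that this comment introduces continues directly below.) To make the comment concrete, here is a minimal, hypothetical stand-in for the concatenation idea rather than the actual `networks.CriticMultiplexer`: flatten the observation and the action, concatenate them on the feature axis, and feed the result to an MLP.

```python
import sonnet as snt
import tensorflow as tf


class ConcatCritic(snt.Module):
    """Toy critic illustrating the 'multiplexer' idea: concat obs and action."""

    def __init__(self, hidden_sizes=(400, 300)):
        super().__init__(name="concat_critic")
        self._mlp = snt.nets.MLP(list(hidden_sizes) + [1])

    def __call__(self, observation: tf.Tensor, action: tf.Tensor) -> tf.Tensor:
        # Concatenate flattened observation and action before the value MLP.
        inputs = tf.concat(
            [snt.Flatten()(observation), snt.Flatten()(action)], axis=-1
        )
        return self._mlp(inputs)


# Batch of 2 observations (dim 5) and actions (dim 3) -> a [2, 1] value tensor.
values = ConcatCritic()(tf.zeros([2, 5]), tf.zeros([2, 3]))
```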
+ critic_network = snt.Sequential( + [ + networks.CriticMultiplexer( + critic_network=networks.LayerNormMLP(critic_layer_sizes) + ), + networks.DiscreteValuedHead(0.0, 1.0, 10), + ] + ) + + return { + "policy": policy_network, + "critic": critic_network, + } class DMPOTest(absltest.TestCase): - - def test_dmpo(self): - # Create a fake environment to test with. - environment = fakes.ContinuousEnvironment(episode_length=10) - spec = specs.make_environment_spec(environment) - - # Create networks. - agent_networks = make_networks(spec.actions) - - # Construct the agent. - agent = dmpo.DistributionalMPO( - spec, - policy_network=agent_networks['policy'], - critic_network=agent_networks['critic'], - batch_size=10, - samples_per_insert=2, - min_replay_size=10) - - # Try running the environment loop. We have no assertions here because all - # we care about is that the agent runs without raising any errors. - loop = acme.EnvironmentLoop(environment, agent) - loop.run(num_episodes=2) - - -if __name__ == '__main__': - absltest.main() + def test_dmpo(self): + # Create a fake environment to test with. + environment = fakes.ContinuousEnvironment(episode_length=10) + spec = specs.make_environment_spec(environment) + + # Create networks. + agent_networks = make_networks(spec.actions) + + # Construct the agent. + agent = dmpo.DistributionalMPO( + spec, + policy_network=agent_networks["policy"], + critic_network=agent_networks["critic"], + batch_size=10, + samples_per_insert=2, + min_replay_size=10, + ) + + # Try running the environment loop. We have no assertions here because all + # we care about is that the agent runs without raising any errors. + loop = acme.EnvironmentLoop(environment, agent) + loop.run(num_episodes=2) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/tf/dmpo/learning.py b/acme/agents/tf/dmpo/learning.py index 1812297fa7..0291492e7c 100644 --- a/acme/agents/tf/dmpo/learning.py +++ b/acme/agents/tf/dmpo/learning.py @@ -17,284 +17,295 @@ import time from typing import List, Optional +import numpy as np +import sonnet as snt +import tensorflow as tf + import acme from acme import types -from acme.tf import losses -from acme.tf import networks +from acme.tf import losses, networks from acme.tf import savers as tf2_savers from acme.tf import utils as tf2_utils -from acme.utils import counting -from acme.utils import loggers -import numpy as np -import sonnet as snt -import tensorflow as tf +from acme.utils import counting, loggers class DistributionalMPOLearner(acme.Learner): - """Distributional MPO learner.""" - - def __init__( - self, - policy_network: snt.Module, - critic_network: snt.Module, - target_policy_network: snt.Module, - target_critic_network: snt.Module, - discount: float, - num_samples: int, - target_policy_update_period: int, - target_critic_update_period: int, - dataset: tf.data.Dataset, - observation_network: types.TensorTransformation = tf.identity, - target_observation_network: types.TensorTransformation = tf.identity, - policy_loss_module: Optional[snt.Module] = None, - policy_optimizer: Optional[snt.Optimizer] = None, - critic_optimizer: Optional[snt.Optimizer] = None, - dual_optimizer: Optional[snt.Optimizer] = None, - clipping: bool = True, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - checkpoint: bool = True, - ): - - # Store online and target networks. 
- self._policy_network = policy_network - self._critic_network = critic_network - self._target_policy_network = target_policy_network - self._target_critic_network = target_critic_network - - # Make sure observation networks are snt.Module's so they have variables. - self._observation_network = tf2_utils.to_sonnet_module(observation_network) - self._target_observation_network = tf2_utils.to_sonnet_module( - target_observation_network) - - # General learner book-keeping and loggers. - self._counter = counter or counting.Counter() - self._logger = logger or loggers.make_default_logger('learner') - - # Other learner parameters. - self._discount = discount - self._num_samples = num_samples - self._clipping = clipping - - # Necessary to track when to update target networks. - self._num_steps = tf.Variable(0, dtype=tf.int32) - self._target_policy_update_period = target_policy_update_period - self._target_critic_update_period = target_critic_update_period - - # Batch dataset and create iterator. - # TODO(b/155086959): Fix type stubs and remove. - self._iterator = iter(dataset) # pytype: disable=wrong-arg-types - - self._policy_loss_module = policy_loss_module or losses.MPO( - epsilon=1e-1, - epsilon_penalty=1e-3, - epsilon_mean=2.5e-3, - epsilon_stddev=1e-6, - init_log_temperature=10., - init_log_alpha_mean=10., - init_log_alpha_stddev=1000.) - - # Create the optimizers. - self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) - self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) - self._dual_optimizer = dual_optimizer or snt.optimizers.Adam(1e-2) - - # Expose the variables. - policy_network_to_expose = snt.Sequential( - [self._target_observation_network, self._target_policy_network]) - self._variables = { - 'critic': self._target_critic_network.variables, - 'policy': policy_network_to_expose.variables, - } - - # Create a checkpointer and snapshotter object. - self._checkpointer = None - self._snapshotter = None - - if checkpoint: - self._checkpointer = tf2_savers.Checkpointer( - subdirectory='dmpo_learner', - objects_to_save={ - 'counter': self._counter, - 'policy': self._policy_network, - 'critic': self._critic_network, - 'observation': self._observation_network, - 'target_policy': self._target_policy_network, - 'target_critic': self._target_critic_network, - 'target_observation': self._target_observation_network, - 'policy_optimizer': self._policy_optimizer, - 'critic_optimizer': self._critic_optimizer, - 'dual_optimizer': self._dual_optimizer, - 'policy_loss_module': self._policy_loss_module, - 'num_steps': self._num_steps, - }) - - self._snapshotter = tf2_savers.Snapshotter( - objects_to_save={ - 'policy': - snt.Sequential([ - self._target_observation_network, - self._target_policy_network - ]), - }) - - # Do not record timestamps until after the first learning step is done. - # This is to avoid including the time it takes for actors to come online and - # fill the replay buffer. - self._timestamp = None - - @tf.function - def _step(self) -> types.NestedTensor: - # Update target network. - online_policy_variables = self._policy_network.variables - target_policy_variables = self._target_policy_network.variables - online_critic_variables = ( - *self._observation_network.variables, - *self._critic_network.variables, - ) - target_critic_variables = ( - *self._target_observation_network.variables, - *self._target_critic_network.variables, - ) - - # Make online policy -> target policy network update ops. 
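(The removed target-update ops continue below; their reformatted counterpart appears later in this file.) The pattern being reformatted here is a periodic hard copy of the online variables into the target networks, gated on a step counter. A standalone sketch of the same pattern with hypothetical variables:

```python
import tensorflow as tf

# Hypothetical variable lists standing in for the policy/critic variables.
online_vars = [tf.Variable(tf.random.normal([4, 4])), tf.Variable(tf.zeros([4]))]
target_vars = [tf.Variable(tf.zeros([4, 4])), tf.Variable(tf.zeros([4]))]

num_steps = tf.Variable(0, dtype=tf.int32)
target_update_period = 100


def maybe_update_target():
    # Hard-copy online -> target every `target_update_period` steps, mirroring
    # the modulo check in DistributionalMPOLearner._step.
    if tf.math.mod(num_steps, target_update_period) == 0:
        for src, dest in zip(online_vars, target_vars):
            dest.assign(src)
    num_steps.assign_add(1)


maybe_update_target()  # Step 0 copies, since 0 % 100 == 0.
```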
- if tf.math.mod(self._num_steps, self._target_policy_update_period) == 0: - for src, dest in zip(online_policy_variables, target_policy_variables): - dest.assign(src) - # Make online critic -> target critic network update ops. - if tf.math.mod(self._num_steps, self._target_critic_update_period) == 0: - for src, dest in zip(online_critic_variables, target_critic_variables): - dest.assign(src) - - self._num_steps.assign_add(1) - - # Get data from replay (dropping extras if any). Note there is no - # extra data here because we do not insert any into Reverb. - inputs = next(self._iterator) - transitions: types.Transition = inputs.data - - # Get batch size and scalar dtype. - batch_size = transitions.reward.shape[0] - - # Cast the additional discount to match the environment discount dtype. - discount = tf.cast(self._discount, dtype=transitions.discount.dtype) - - with tf.GradientTape(persistent=True) as tape: - # Maybe transform the observation before feeding into policy and critic. - # Transforming the observations this way at the start of the learning - # step effectively means that the policy and critic share observation - # network weights. - o_tm1 = self._observation_network(transitions.observation) - # This stop_gradient prevents gradients to propagate into the target - # observation network. In addition, since the online policy network is - # evaluated at o_t, this also means the policy loss does not influence - # the observation network training. - o_t = tf.stop_gradient( - self._target_observation_network(transitions.next_observation)) - - # Get online and target action distributions from policy networks. - online_action_distribution = self._policy_network(o_t) - target_action_distribution = self._target_policy_network(o_t) - - # Sample actions to evaluate policy; of size [N, B, ...]. - sampled_actions = target_action_distribution.sample(self._num_samples) - - # Tile embedded observations to feed into the target critic network. - # Note: this is more efficient than tiling before the embedding layer. - tiled_o_t = tf2_utils.tile_tensor(o_t, self._num_samples) # [N, B, ...] - - # Compute target-estimated distributional value of sampled actions at o_t. - sampled_q_t_distributions = self._target_critic_network( - # Merge batch dimensions; to shape [N*B, ...]. - snt.merge_leading_dims(tiled_o_t, num_dims=2), - snt.merge_leading_dims(sampled_actions, num_dims=2)) - - # Compute average logits by first reshaping them and normalizing them - # across atoms. - new_shape = [self._num_samples, batch_size, -1] # [N, B, A] - sampled_logits = tf.reshape(sampled_q_t_distributions.logits, new_shape) - sampled_logprobs = tf.math.log_softmax(sampled_logits, axis=-1) - averaged_logits = tf.reduce_logsumexp(sampled_logprobs, axis=0) - - # Construct the expected distributional value for bootstrapping. - q_t_distribution = networks.DiscreteValuedDistribution( - values=sampled_q_t_distributions.values, logits=averaged_logits) - - # Compute online critic value distribution of a_tm1 in state o_tm1. - q_tm1_distribution = self._critic_network(o_tm1, transitions.action) - - # Compute critic distributional loss. - critic_loss = losses.categorical(q_tm1_distribution, transitions.reward, - discount * transitions.discount, - q_t_distribution) - critic_loss = tf.reduce_mean(critic_loss) - - # Compute Q-values of sampled actions and reshape to [N, B]. - sampled_q_values = sampled_q_t_distributions.mean() - sampled_q_values = tf.reshape(sampled_q_values, (self._num_samples, -1)) - - # Compute MPO policy loss. 
- policy_loss, policy_stats = self._policy_loss_module( - online_action_distribution=online_action_distribution, - target_action_distribution=target_action_distribution, - actions=sampled_actions, - q_values=sampled_q_values) - - # For clarity, explicitly define which variables are trained by which loss. - critic_trainable_variables = ( - # In this agent, the critic loss trains the observation network. - self._observation_network.trainable_variables + - self._critic_network.trainable_variables) - policy_trainable_variables = self._policy_network.trainable_variables - # The following are the MPO dual variables, stored in the loss module. - dual_trainable_variables = self._policy_loss_module.trainable_variables - - # Compute gradients. - critic_gradients = tape.gradient(critic_loss, critic_trainable_variables) - policy_gradients, dual_gradients = tape.gradient( - policy_loss, (policy_trainable_variables, dual_trainable_variables)) - - # Delete the tape manually because of the persistent=True flag. - del tape - - # Maybe clip gradients. - if self._clipping: - policy_gradients = tuple(tf.clip_by_global_norm(policy_gradients, 40.)[0]) - critic_gradients = tuple(tf.clip_by_global_norm(critic_gradients, 40.)[0]) - - # Apply gradients. - self._critic_optimizer.apply(critic_gradients, critic_trainable_variables) - self._policy_optimizer.apply(policy_gradients, policy_trainable_variables) - self._dual_optimizer.apply(dual_gradients, dual_trainable_variables) - - # Losses to track. - fetches = { - 'critic_loss': critic_loss, - 'policy_loss': policy_loss, - } - fetches.update(policy_stats) # Log MPO stats. - - return fetches - - def step(self): - # Run the learning step. - fetches = self._step() - - # Compute elapsed time. - timestamp = time.time() - elapsed_time = timestamp - self._timestamp if self._timestamp else 0 - self._timestamp = timestamp - - # Update our counts and record it. - counts = self._counter.increment(steps=1, walltime=elapsed_time) - fetches.update(counts) - - # Checkpoint and attempt to write the logs. - if self._checkpointer is not None: - self._checkpointer.save() - if self._snapshotter is not None: - self._snapshotter.save() - self._logger.write(fetches) - - def get_variables(self, names: List[str]) -> List[List[np.ndarray]]: - return [tf2_utils.to_numpy(self._variables[name]) for name in names] + """Distributional MPO learner.""" + + def __init__( + self, + policy_network: snt.Module, + critic_network: snt.Module, + target_policy_network: snt.Module, + target_critic_network: snt.Module, + discount: float, + num_samples: int, + target_policy_update_period: int, + target_critic_update_period: int, + dataset: tf.data.Dataset, + observation_network: types.TensorTransformation = tf.identity, + target_observation_network: types.TensorTransformation = tf.identity, + policy_loss_module: Optional[snt.Module] = None, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + dual_optimizer: Optional[snt.Optimizer] = None, + clipping: bool = True, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = True, + ): + + # Store online and target networks. + self._policy_network = policy_network + self._critic_network = critic_network + self._target_policy_network = target_policy_network + self._target_critic_network = target_critic_network + + # Make sure observation networks are snt.Module's so they have variables. 
+ self._observation_network = tf2_utils.to_sonnet_module(observation_network) + self._target_observation_network = tf2_utils.to_sonnet_module( + target_observation_network + ) + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger("learner") + + # Other learner parameters. + self._discount = discount + self._num_samples = num_samples + self._clipping = clipping + + # Necessary to track when to update target networks. + self._num_steps = tf.Variable(0, dtype=tf.int32) + self._target_policy_update_period = target_policy_update_period + self._target_critic_update_period = target_critic_update_period + + # Batch dataset and create iterator. + # TODO(b/155086959): Fix type stubs and remove. + self._iterator = iter(dataset) # pytype: disable=wrong-arg-types + + self._policy_loss_module = policy_loss_module or losses.MPO( + epsilon=1e-1, + epsilon_penalty=1e-3, + epsilon_mean=2.5e-3, + epsilon_stddev=1e-6, + init_log_temperature=10.0, + init_log_alpha_mean=10.0, + init_log_alpha_stddev=1000.0, + ) + + # Create the optimizers. + self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) + self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) + self._dual_optimizer = dual_optimizer or snt.optimizers.Adam(1e-2) + + # Expose the variables. + policy_network_to_expose = snt.Sequential( + [self._target_observation_network, self._target_policy_network] + ) + self._variables = { + "critic": self._target_critic_network.variables, + "policy": policy_network_to_expose.variables, + } + + # Create a checkpointer and snapshotter object. + self._checkpointer = None + self._snapshotter = None + + if checkpoint: + self._checkpointer = tf2_savers.Checkpointer( + subdirectory="dmpo_learner", + objects_to_save={ + "counter": self._counter, + "policy": self._policy_network, + "critic": self._critic_network, + "observation": self._observation_network, + "target_policy": self._target_policy_network, + "target_critic": self._target_critic_network, + "target_observation": self._target_observation_network, + "policy_optimizer": self._policy_optimizer, + "critic_optimizer": self._critic_optimizer, + "dual_optimizer": self._dual_optimizer, + "policy_loss_module": self._policy_loss_module, + "num_steps": self._num_steps, + }, + ) + + self._snapshotter = tf2_savers.Snapshotter( + objects_to_save={ + "policy": snt.Sequential( + [self._target_observation_network, self._target_policy_network] + ), + } + ) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + @tf.function + def _step(self) -> types.NestedTensor: + # Update target network. + online_policy_variables = self._policy_network.variables + target_policy_variables = self._target_policy_network.variables + online_critic_variables = ( + *self._observation_network.variables, + *self._critic_network.variables, + ) + target_critic_variables = ( + *self._target_observation_network.variables, + *self._target_critic_network.variables, + ) + + # Make online policy -> target policy network update ops. + if tf.math.mod(self._num_steps, self._target_policy_update_period) == 0: + for src, dest in zip(online_policy_variables, target_policy_variables): + dest.assign(src) + # Make online critic -> target critic network update ops. 
+ if tf.math.mod(self._num_steps, self._target_critic_update_period) == 0: + for src, dest in zip(online_critic_variables, target_critic_variables): + dest.assign(src) + + self._num_steps.assign_add(1) + + # Get data from replay (dropping extras if any). Note there is no + # extra data here because we do not insert any into Reverb. + inputs = next(self._iterator) + transitions: types.Transition = inputs.data + + # Get batch size and scalar dtype. + batch_size = transitions.reward.shape[0] + + # Cast the additional discount to match the environment discount dtype. + discount = tf.cast(self._discount, dtype=transitions.discount.dtype) + + with tf.GradientTape(persistent=True) as tape: + # Maybe transform the observation before feeding into policy and critic. + # Transforming the observations this way at the start of the learning + # step effectively means that the policy and critic share observation + # network weights. + o_tm1 = self._observation_network(transitions.observation) + # This stop_gradient prevents gradients to propagate into the target + # observation network. In addition, since the online policy network is + # evaluated at o_t, this also means the policy loss does not influence + # the observation network training. + o_t = tf.stop_gradient( + self._target_observation_network(transitions.next_observation) + ) + + # Get online and target action distributions from policy networks. + online_action_distribution = self._policy_network(o_t) + target_action_distribution = self._target_policy_network(o_t) + + # Sample actions to evaluate policy; of size [N, B, ...]. + sampled_actions = target_action_distribution.sample(self._num_samples) + + # Tile embedded observations to feed into the target critic network. + # Note: this is more efficient than tiling before the embedding layer. + tiled_o_t = tf2_utils.tile_tensor(o_t, self._num_samples) # [N, B, ...] + + # Compute target-estimated distributional value of sampled actions at o_t. + sampled_q_t_distributions = self._target_critic_network( + # Merge batch dimensions; to shape [N*B, ...]. + snt.merge_leading_dims(tiled_o_t, num_dims=2), + snt.merge_leading_dims(sampled_actions, num_dims=2), + ) + + # Compute average logits by first reshaping them and normalizing them + # across atoms. + new_shape = [self._num_samples, batch_size, -1] # [N, B, A] + sampled_logits = tf.reshape(sampled_q_t_distributions.logits, new_shape) + sampled_logprobs = tf.math.log_softmax(sampled_logits, axis=-1) + averaged_logits = tf.reduce_logsumexp(sampled_logprobs, axis=0) + + # Construct the expected distributional value for bootstrapping. + q_t_distribution = networks.DiscreteValuedDistribution( + values=sampled_q_t_distributions.values, logits=averaged_logits + ) + + # Compute online critic value distribution of a_tm1 in state o_tm1. + q_tm1_distribution = self._critic_network(o_tm1, transitions.action) + + # Compute critic distributional loss. + critic_loss = losses.categorical( + q_tm1_distribution, + transitions.reward, + discount * transitions.discount, + q_t_distribution, + ) + critic_loss = tf.reduce_mean(critic_loss) + + # Compute Q-values of sampled actions and reshape to [N, B]. + sampled_q_values = sampled_q_t_distributions.mean() + sampled_q_values = tf.reshape(sampled_q_values, (self._num_samples, -1)) + + # Compute MPO policy loss. 
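(The MPO policy-loss call follows directly below.) On the logits averaging a few lines up: applying `log_softmax` per sampled action and then `reduce_logsumexp` over the sample axis produces logits whose softmax equals the mean of the N per-sample categorical distributions, which is the bootstrap distribution the critic loss needs; the extra `log N` offset cancels under normalization. A small self-contained check with hypothetical sizes:

```python
import numpy as np
import tensorflow as tf

N, B, A = 3, 2, 5                      # sampled actions, batch, atoms
logits = tf.random.normal([N, B, A])

# The learner's averaging step.
logprobs = tf.math.log_softmax(logits, axis=-1)
averaged_logits = tf.reduce_logsumexp(logprobs, axis=0)          # [B, A]

# Softmax of the averaged logits == mean of the per-sample distributions.
lhs = tf.nn.softmax(averaged_logits, axis=-1)
rhs = tf.reduce_mean(tf.nn.softmax(logits, axis=-1), axis=0)
np.testing.assert_allclose(lhs.numpy(), rhs.numpy(), atol=1e-5)
```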
+ policy_loss, policy_stats = self._policy_loss_module( + online_action_distribution=online_action_distribution, + target_action_distribution=target_action_distribution, + actions=sampled_actions, + q_values=sampled_q_values, + ) + + # For clarity, explicitly define which variables are trained by which loss. + critic_trainable_variables = ( + # In this agent, the critic loss trains the observation network. + self._observation_network.trainable_variables + + self._critic_network.trainable_variables + ) + policy_trainable_variables = self._policy_network.trainable_variables + # The following are the MPO dual variables, stored in the loss module. + dual_trainable_variables = self._policy_loss_module.trainable_variables + + # Compute gradients. + critic_gradients = tape.gradient(critic_loss, critic_trainable_variables) + policy_gradients, dual_gradients = tape.gradient( + policy_loss, (policy_trainable_variables, dual_trainable_variables) + ) + + # Delete the tape manually because of the persistent=True flag. + del tape + + # Maybe clip gradients. + if self._clipping: + policy_gradients = tuple(tf.clip_by_global_norm(policy_gradients, 40.0)[0]) + critic_gradients = tuple(tf.clip_by_global_norm(critic_gradients, 40.0)[0]) + + # Apply gradients. + self._critic_optimizer.apply(critic_gradients, critic_trainable_variables) + self._policy_optimizer.apply(policy_gradients, policy_trainable_variables) + self._dual_optimizer.apply(dual_gradients, dual_trainable_variables) + + # Losses to track. + fetches = { + "critic_loss": critic_loss, + "policy_loss": policy_loss, + } + fetches.update(policy_stats) # Log MPO stats. + + return fetches + + def step(self): + # Run the learning step. + fetches = self._step() + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Update our counts and record it. + counts = self._counter.increment(steps=1, walltime=elapsed_time) + fetches.update(counts) + + # Checkpoint and attempt to write the logs. + if self._checkpointer is not None: + self._checkpointer.save() + if self._snapshotter is not None: + self._snapshotter.save() + self._logger.write(fetches) + + def get_variables(self, names: List[str]) -> List[List[np.ndarray]]: + return [tf2_utils.to_numpy(self._variables[name]) for name in names] diff --git a/acme/agents/tf/dqfd/agent.py b/acme/agents/tf/dqfd/agent.py index 8b9d21aa62..014ced1363 100644 --- a/acme/agents/tf/dqfd/agent.py +++ b/acme/agents/tf/dqfd/agent.py @@ -19,47 +19,46 @@ import operator from typing import Optional -from acme import datasets -from acme import specs -from acme import types as acme_types -from acme.adders import reverb as adders -from acme.agents import agent -from acme.agents.tf import actors -from acme.agents.tf import dqn -from acme.tf import utils as tf2_utils import reverb import sonnet as snt import tensorflow as tf import tree import trfl +from acme import datasets, specs +from acme import types as acme_types +from acme.adders import reverb as adders +from acme.agents import agent +from acme.agents.tf import actors, dqn +from acme.tf import utils as tf2_utils + class DQfD(agent.Agent): - """DQfD agent. + """DQfD agent. This implements a single-process DQN agent that mixes demonstrations with actor experience. 
""" - def __init__( - self, - environment_spec: specs.EnvironmentSpec, - network: snt.Module, - demonstration_dataset: tf.data.Dataset, - demonstration_ratio: float, - batch_size: int = 256, - prefetch_size: int = 4, - target_update_period: int = 100, - samples_per_insert: float = 32.0, - min_replay_size: int = 1000, - max_replay_size: int = 1000000, - importance_sampling_exponent: float = 0.2, - n_step: int = 5, - epsilon: Optional[tf.Tensor] = None, - learning_rate: float = 1e-3, - discount: float = 0.99, - ): - """Initialize the agent. + def __init__( + self, + environment_spec: specs.EnvironmentSpec, + network: snt.Module, + demonstration_dataset: tf.data.Dataset, + demonstration_ratio: float, + batch_size: int = 256, + prefetch_size: int = 4, + target_update_period: int = 100, + samples_per_insert: float = 32.0, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + importance_sampling_exponent: float = 0.2, + n_step: int = 5, + epsilon: Optional[tf.Tensor] = None, + learning_rate: float = 1e-3, + discount: float = 0.99, + ): + """Initialize the agent. Args: environment_spec: description of the actions, observations, etc. @@ -86,84 +85,87 @@ def __init__( discount: discount to use for TD updates. """ - # Create a replay server to add data to. This uses no limiter behavior in - # order to allow the Agent interface to handle it. - replay_table = reverb.Table( - name=adders.DEFAULT_PRIORITY_TABLE, - sampler=reverb.selectors.Uniform(), - remover=reverb.selectors.Fifo(), - max_size=max_replay_size, - rate_limiter=reverb.rate_limiters.MinSize(1), - signature=adders.NStepTransitionAdder.signature(environment_spec)) - self._server = reverb.Server([replay_table], port=None) - - # The adder is used to insert observations into replay. - address = f'localhost:{self._server.port}' - adder = adders.NStepTransitionAdder( - client=reverb.Client(address), - n_step=n_step, - discount=discount) - - # The dataset provides an interface to sample from replay. - replay_client = reverb.TFClient(address) - dataset = datasets.make_reverb_dataset(server_address=address) - - # Combine with demonstration dataset. - transition = functools.partial(_n_step_transition_from_episode, - n_step=n_step, - discount=discount) - dataset_demos = demonstration_dataset.map(transition) - dataset = tf.data.experimental.sample_from_datasets( - [dataset, dataset_demos], - [1 - demonstration_ratio, demonstration_ratio]) - - # Batch and prefetch. - dataset = dataset.batch(batch_size, drop_remainder=True) - dataset = dataset.prefetch(prefetch_size) - - # Use constant 0.05 epsilon greedy policy by default. - if epsilon is None: - epsilon = tf.Variable(0.05, trainable=False) - policy_network = snt.Sequential([ - network, - lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(), - ]) - - # Create a target network. - target_network = copy.deepcopy(network) - - # Ensure that we create the variables before proceeding (maybe not needed). - tf2_utils.create_variables(network, [environment_spec.observations]) - tf2_utils.create_variables(target_network, [environment_spec.observations]) - - # Create the actor which defines how we take actions. - actor = actors.FeedForwardActor(policy_network, adder) - - # The learner updates the parameters (and initializes them). 
- learner = dqn.DQNLearner( - network=network, - target_network=target_network, - discount=discount, - importance_sampling_exponent=importance_sampling_exponent, - learning_rate=learning_rate, - target_update_period=target_update_period, - dataset=dataset, - replay_client=replay_client) - - super().__init__( - actor=actor, - learner=learner, - min_observations=max(batch_size, min_replay_size), - observations_per_step=float(batch_size) / samples_per_insert) - - -def _n_step_transition_from_episode(observations: acme_types.NestedTensor, - actions: tf.Tensor, - rewards: tf.Tensor, - discounts: tf.Tensor, - n_step: int, - discount: float): - """Produce Reverb-like N-step transition from a full episode. + # Create a replay server to add data to. This uses no limiter behavior in + # order to allow the Agent interface to handle it. + replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=max_replay_size, + rate_limiter=reverb.rate_limiters.MinSize(1), + signature=adders.NStepTransitionAdder.signature(environment_spec), + ) + self._server = reverb.Server([replay_table], port=None) + + # The adder is used to insert observations into replay. + address = f"localhost:{self._server.port}" + adder = adders.NStepTransitionAdder( + client=reverb.Client(address), n_step=n_step, discount=discount + ) + + # The dataset provides an interface to sample from replay. + replay_client = reverb.TFClient(address) + dataset = datasets.make_reverb_dataset(server_address=address) + + # Combine with demonstration dataset. + transition = functools.partial( + _n_step_transition_from_episode, n_step=n_step, discount=discount + ) + dataset_demos = demonstration_dataset.map(transition) + dataset = tf.data.experimental.sample_from_datasets( + [dataset, dataset_demos], [1 - demonstration_ratio, demonstration_ratio] + ) + + # Batch and prefetch. + dataset = dataset.batch(batch_size, drop_remainder=True) + dataset = dataset.prefetch(prefetch_size) + + # Use constant 0.05 epsilon greedy policy by default. + if epsilon is None: + epsilon = tf.Variable(0.05, trainable=False) + policy_network = snt.Sequential( + [network, lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(),] + ) + + # Create a target network. + target_network = copy.deepcopy(network) + + # Ensure that we create the variables before proceeding (maybe not needed). + tf2_utils.create_variables(network, [environment_spec.observations]) + tf2_utils.create_variables(target_network, [environment_spec.observations]) + + # Create the actor which defines how we take actions. + actor = actors.FeedForwardActor(policy_network, adder) + + # The learner updates the parameters (and initializes them). + learner = dqn.DQNLearner( + network=network, + target_network=target_network, + discount=discount, + importance_sampling_exponent=importance_sampling_exponent, + learning_rate=learning_rate, + target_update_period=target_update_period, + dataset=dataset, + replay_client=replay_client, + ) + + super().__init__( + actor=actor, + learner=learner, + min_observations=max(batch_size, min_replay_size), + observations_per_step=float(batch_size) / samples_per_insert, + ) + + +def _n_step_transition_from_episode( + observations: acme_types.NestedTensor, + actions: tf.Tensor, + rewards: tf.Tensor, + discounts: tf.Tensor, + n_step: int, + discount: float, +): + """Produce Reverb-like N-step transition from a full episode. Observations, actions, rewards and discounts have the same length. 
This function will ignore the first reward and discount and the last action. @@ -180,32 +182,33 @@ def _n_step_transition_from_episode(observations: acme_types.NestedTensor, (o_t, a_t, r_t, d_t, o_tp1) tuple. """ - max_index = tf.shape(rewards)[0] - 1 - first = tf.random.uniform(shape=(), minval=0, maxval=max_index - 1, - dtype=tf.int32) - last = tf.minimum(first + n_step, max_index) - - o_t = tree.map_structure(operator.itemgetter(first), observations) - a_t = tree.map_structure(operator.itemgetter(first), actions) - o_tp1 = tree.map_structure(operator.itemgetter(last), observations) - - # 0, 1, ..., n-1. - discount_range = tf.cast(tf.range(last - first), tf.float32) - # 1, g, ..., g^{n-1}. - additional_discounts = tf.pow(discount, discount_range) - # 1, d_t, d_t * d_{t+1}, ..., d_t * ... * d_{t+n-2}. - discounts = tf.concat([[1.], tf.math.cumprod(discounts[first:last-1])], 0) - # 1, g * d_t, ..., g^{n-1} * d_t * ... * d_{t+n-2}. - discounts *= additional_discounts - # r_t + g * d_t * r_{t+1} + ... + g^{n-1} * d_t * ... * d_{t+n-2} * r_{t+n-1} - # We have to shift rewards by one so last=max_index corresponds to transitions - # that include the last reward. - r_t = tf.reduce_sum(rewards[first+1:last+1] * discounts) - - # g^{n-1} * d_{t} * ... * d_{t+n-1}. - d_t = discounts[-1] - - info = tree.map_structure(lambda dtype: tf.ones([], dtype), - reverb.SampleInfo.tf_dtypes()) - return reverb.ReplaySample( - info=info, data=acme_types.Transition(o_t, a_t, r_t, d_t, o_tp1)) + max_index = tf.shape(rewards)[0] - 1 + first = tf.random.uniform(shape=(), minval=0, maxval=max_index - 1, dtype=tf.int32) + last = tf.minimum(first + n_step, max_index) + + o_t = tree.map_structure(operator.itemgetter(first), observations) + a_t = tree.map_structure(operator.itemgetter(first), actions) + o_tp1 = tree.map_structure(operator.itemgetter(last), observations) + + # 0, 1, ..., n-1. + discount_range = tf.cast(tf.range(last - first), tf.float32) + # 1, g, ..., g^{n-1}. + additional_discounts = tf.pow(discount, discount_range) + # 1, d_t, d_t * d_{t+1}, ..., d_t * ... * d_{t+n-2}. + discounts = tf.concat([[1.0], tf.math.cumprod(discounts[first : last - 1])], 0) + # 1, g * d_t, ..., g^{n-1} * d_t * ... * d_{t+n-2}. + discounts *= additional_discounts + #  r_t + g * d_t * r_{t+1} + ... + g^{n-1} * d_t * ... * d_{t+n-2} * r_{t+n-1} + # We have to shift rewards by one so last=max_index corresponds to transitions + # that include the last reward. + r_t = tf.reduce_sum(rewards[first + 1 : last + 1] * discounts) + + # g^{n-1} * d_{t} * ... * d_{t+n-1}. 
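[Editor's note] As a sanity check on the discount bookkeeping in `_n_step_transition_from_episode` above, here is a minimal NumPy sketch of the same n-step accumulation. The episode values, `gamma`, `first` and `last` are made-up illustration inputs, not taken from this diff.

```python
# Minimal NumPy sketch of the n-step accumulation above; all inputs are illustrative.
import numpy as np

rewards = np.array([0.0, 1.0, 2.0, 3.0, 4.0], dtype=np.float32)        # r_0 .. r_4
env_discounts = np.array([1.0, 1.0, 1.0, 1.0, 0.0], dtype=np.float32)  # d_0 .. d_4
gamma, first, last = 0.99, 1, 3                                        # window [first, last]

geometric = gamma ** np.arange(last - first, dtype=np.float32)         # 1, g, ..., g^{n-1}
env_part = np.concatenate([[1.0], np.cumprod(env_discounts[first:last - 1])])
weights = geometric * env_part                                         # 1, g*d_t, ...

r_t = np.sum(rewards[first + 1:last + 1] * weights)                    # n-step reward: 2 + 0.99 * 3
d_t = weights[-1]                                                      # accumulated discount: 0.99
```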
+ d_t = discounts[-1] + + info = tree.map_structure( + lambda dtype: tf.ones([], dtype), reverb.SampleInfo.tf_dtypes() + ) + return reverb.ReplaySample( + info=info, data=acme_types.Transition(o_t, a_t, r_t, d_t, o_tp1) + ) diff --git a/acme/agents/tf/dqfd/agent_test.py b/acme/agents/tf/dqfd/agent_test.py index 9b7d8c5cd0..faacbd6fd9 100644 --- a/acme/agents/tf/dqfd/agent_test.py +++ b/acme/agents/tf/dqfd/agent_test.py @@ -14,62 +14,59 @@ """Tests for DQN agent.""" +import dm_env +import numpy as np +import sonnet as snt +from absl.testing import absltest + import acme from acme import specs from acme.agents.tf.dqfd import agent as dqfd from acme.agents.tf.dqfd import bsuite_demonstrations from acme.testing import fakes -import dm_env -import numpy as np -import sonnet as snt - -from absl.testing import absltest def _make_network(action_spec: specs.DiscreteArray) -> snt.Module: - return snt.Sequential([ - snt.Flatten(), - snt.nets.MLP([50, 50, action_spec.num_values]), - ]) + return snt.Sequential( + [snt.Flatten(), snt.nets.MLP([50, 50, action_spec.num_values]),] + ) class DQfDTest(absltest.TestCase): + def test_dqfd(self): + # Create a fake environment to test with. + # TODO(b/152596848): Allow DQN to deal with integer observations. + environment = fakes.DiscreteEnvironment( + num_actions=5, num_observations=10, obs_dtype=np.float32, episode_length=10 + ) + spec = specs.make_environment_spec(environment) - def test_dqfd(self): - # Create a fake environment to test with. - # TODO(b/152596848): Allow DQN to deal with integer observations. - environment = fakes.DiscreteEnvironment( - num_actions=5, - num_observations=10, - obs_dtype=np.float32, - episode_length=10) - spec = specs.make_environment_spec(environment) - - # Build demonstrations. - dummy_action = np.zeros((), dtype=np.int32) - recorder = bsuite_demonstrations.DemonstrationRecorder() - timestep = environment.reset() - while timestep.step_type is not dm_env.StepType.LAST: - recorder.step(timestep, dummy_action) - timestep = environment.step(dummy_action) - recorder.step(timestep, dummy_action) - recorder.record_episode() + # Build demonstrations. + dummy_action = np.zeros((), dtype=np.int32) + recorder = bsuite_demonstrations.DemonstrationRecorder() + timestep = environment.reset() + while timestep.step_type is not dm_env.StepType.LAST: + recorder.step(timestep, dummy_action) + timestep = environment.step(dummy_action) + recorder.step(timestep, dummy_action) + recorder.record_episode() - # Construct the agent. - agent = dqfd.DQfD( - environment_spec=spec, - network=_make_network(spec.actions), - demonstration_dataset=recorder.make_tf_dataset(), - demonstration_ratio=0.5, - batch_size=10, - samples_per_insert=2, - min_replay_size=10) + # Construct the agent. + agent = dqfd.DQfD( + environment_spec=spec, + network=_make_network(spec.actions), + demonstration_dataset=recorder.make_tf_dataset(), + demonstration_ratio=0.5, + batch_size=10, + samples_per_insert=2, + min_replay_size=10, + ) - # Try running the environment loop. We have no assertions here because all - # we care about is that the agent runs without raising any errors. - loop = acme.EnvironmentLoop(environment, agent) - loop.run(num_episodes=10) + # Try running the environment loop. We have no assertions here because all + # we care about is that the agent runs without raising any errors. 
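[Editor's note] The DQfD constructor in `agent.py` above interleaves replay and demonstration data with `tf.data.experimental.sample_from_datasets`, weighted by `demonstration_ratio`. A toy sketch of that mixing, using stand-in datasets rather than real Acme replay or demonstration streams:

```python
# Toy illustration of ratio-based dataset mixing; both datasets are stand-ins.
import tensorflow as tf

replay = tf.data.Dataset.from_tensor_slices(tf.zeros([1000]))  # pretend replay samples
demos = tf.data.Dataset.from_tensor_slices(tf.ones([1000]))    # pretend demo samples
demonstration_ratio = 0.25

mixed = tf.data.experimental.sample_from_datasets(
    [replay, demos], [1 - demonstration_ratio, demonstration_ratio])

# Roughly one element in four of `mixed` is drawn from `demos`.
first_few = list(mixed.take(8).as_numpy_iterator())
```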
+ loop = acme.EnvironmentLoop(environment, agent) + loop.run(num_episodes=10) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/tf/dqfd/bsuite_demonstrations.py b/acme/agents/tf/dqfd/bsuite_demonstrations.py index 67c9a8d55d..12b73027c6 100644 --- a/acme/agents/tf/dqfd/bsuite_demonstrations.py +++ b/acme/agents/tf/dqfd/bsuite_demonstrations.py @@ -16,120 +16,128 @@ from typing import Any, List -from absl import flags -from bsuite.environments import deep_sea import dm_env import numpy as np import tensorflow as tf import tree +from absl import flags +from bsuite.environments import deep_sea FLAGS = flags.FLAGS def _nested_stack(sequence: List[Any]): - """Stack nested elements in a sequence.""" - return tree.map_structure(lambda *x: np.stack(x), *sequence) + """Stack nested elements in a sequence.""" + return tree.map_structure(lambda *x: np.stack(x), *sequence) class DemonstrationRecorder: - """Records demonstrations. + """Records demonstrations. A demonstration is a (observation, action, reward, discount) tuple where every element is a numpy array corresponding to a full episode. """ - def __init__(self): - self._demos = [] - self._reset_episode() - - def step(self, timestep: dm_env.TimeStep, action: np.ndarray): - reward = np.array(timestep.reward or 0, np.float32) - self._episode_reward += reward - self._episode.append((timestep.observation, action, reward, - np.array(timestep.discount or 0, np.float32))) - - def record_episode(self): - self._demos.append(_nested_stack(self._episode)) - self._reset_episode() - - def discard_episode(self): - self._reset_episode() - - def _reset_episode(self): - self._episode = [] - self._episode_reward = 0 - - @property - def episode_reward(self): - return self._episode_reward - - def make_tf_dataset(self): - types = tree.map_structure(lambda x: x.dtype, self._demos[0]) - shapes = tree.map_structure(lambda x: x.shape, self._demos[0]) - ds = tf.data.Dataset.from_generator(lambda: self._demos, types, shapes) - return ds.repeat().shuffle(len(self._demos)) - - -def _optimal_deep_sea_policy(environment: deep_sea.DeepSea, - timestep: dm_env.TimeStep): - action = environment._action_mapping[np.where(timestep.observation)] # pylint: disable=protected-access - return action[0].astype(np.int32) - - -def _run_optimal_deep_sea_episode(environment: deep_sea.DeepSea, - recorder: DemonstrationRecorder): - timestep = environment.reset() - while timestep.step_type is not dm_env.StepType.LAST: - action = _optimal_deep_sea_policy(environment, timestep) - recorder.step(timestep, action) - timestep = environment.step(action) - recorder.step(timestep, np.zeros_like(action)) + def __init__(self): + self._demos = [] + self._reset_episode() + + def step(self, timestep: dm_env.TimeStep, action: np.ndarray): + reward = np.array(timestep.reward or 0, np.float32) + self._episode_reward += reward + self._episode.append( + ( + timestep.observation, + action, + reward, + np.array(timestep.discount or 0, np.float32), + ) + ) + + def record_episode(self): + self._demos.append(_nested_stack(self._episode)) + self._reset_episode() + + def discard_episode(self): + self._reset_episode() + + def _reset_episode(self): + self._episode = [] + self._episode_reward = 0 + + @property + def episode_reward(self): + return self._episode_reward + + def make_tf_dataset(self): + types = tree.map_structure(lambda x: x.dtype, self._demos[0]) + shapes = tree.map_structure(lambda x: x.shape, self._demos[0]) + ds = 
tf.data.Dataset.from_generator(lambda: self._demos, types, shapes) + return ds.repeat().shuffle(len(self._demos)) + + +def _optimal_deep_sea_policy(environment: deep_sea.DeepSea, timestep: dm_env.TimeStep): + action = environment._action_mapping[ + np.where(timestep.observation) + ] # pylint: disable=protected-access + return action[0].astype(np.int32) + + +def _run_optimal_deep_sea_episode( + environment: deep_sea.DeepSea, recorder: DemonstrationRecorder +): + timestep = environment.reset() + while timestep.step_type is not dm_env.StepType.LAST: + action = _optimal_deep_sea_policy(environment, timestep) + recorder.step(timestep, action) + timestep = environment.step(action) + recorder.step(timestep, np.zeros_like(action)) def _make_deep_sea_dataset(environment: deep_sea.DeepSea): - """Make DeepSea demonstration dataset.""" + """Make DeepSea demonstration dataset.""" - recorder = DemonstrationRecorder() + recorder = DemonstrationRecorder() - _run_optimal_deep_sea_episode(environment, recorder) - assert recorder.episode_reward > 0 - recorder.record_episode() - return recorder.make_tf_dataset() + _run_optimal_deep_sea_episode(environment, recorder) + assert recorder.episode_reward > 0 + recorder.record_episode() + return recorder.make_tf_dataset() def _make_deep_sea_stochastic_dataset(environment: deep_sea.DeepSea): - """Make stochastic DeepSea demonstration dataset.""" + """Make stochastic DeepSea demonstration dataset.""" - recorder = DemonstrationRecorder() + recorder = DemonstrationRecorder() - # Use 10*size demos, 80% success, 20% failure. - num_demos = environment._size * 10 # pylint: disable=protected-access - num_failures = num_demos // 5 - num_successes = num_demos - num_failures + # Use 10*size demos, 80% success, 20% failure. + num_demos = environment._size * 10 # pylint: disable=protected-access + num_failures = num_demos // 5 + num_successes = num_demos - num_failures - successes_saved = 0 - failures_saved = 0 - while (successes_saved < num_successes) or (failures_saved < num_failures): - _run_optimal_deep_sea_episode(environment, recorder) + successes_saved = 0 + failures_saved = 0 + while (successes_saved < num_successes) or (failures_saved < num_failures): + _run_optimal_deep_sea_episode(environment, recorder) - if recorder.episode_reward > 0 and successes_saved < num_successes: - recorder.record_episode() - successes_saved += 1 - elif recorder.episode_reward <= 0 and failures_saved < num_failures: - recorder.record_episode() - failures_saved += 1 - else: - recorder.discard_episode() + if recorder.episode_reward > 0 and successes_saved < num_successes: + recorder.record_episode() + successes_saved += 1 + elif recorder.episode_reward <= 0 and failures_saved < num_failures: + recorder.record_episode() + failures_saved += 1 + else: + recorder.discard_episode() - return recorder.make_tf_dataset() + return recorder.make_tf_dataset() def make_dataset(environment: dm_env.Environment, stochastic: bool): - """Make bsuite demos for the current task.""" - - if not stochastic: - assert isinstance(environment, deep_sea.DeepSea) - return _make_deep_sea_dataset(environment) - else: - assert isinstance(environment, deep_sea.DeepSea) - return _make_deep_sea_stochastic_dataset(environment) + """Make bsuite demos for the current task.""" + + if not stochastic: + assert isinstance(environment, deep_sea.DeepSea) + return _make_deep_sea_dataset(environment) + else: + assert isinstance(environment, deep_sea.DeepSea) + return _make_deep_sea_stochastic_dataset(environment) diff --git 
a/acme/agents/tf/dqn/agent.py b/acme/agents/tf/dqn/agent.py index 77844c7ff7..a0fc9b07a8 100644 --- a/acme/agents/tf/dqn/agent.py +++ b/acme/agents/tf/dqn/agent.py @@ -17,8 +17,12 @@ import copy from typing import Optional -from acme import datasets -from acme import specs +import reverb +import sonnet as snt +import tensorflow as tf +import trfl + +from acme import datasets, specs from acme.adders import reverb as adders from acme.agents import agent from acme.agents.tf import actors @@ -26,14 +30,10 @@ from acme.tf import savers as tf2_savers from acme.tf import utils as tf2_utils from acme.utils import loggers -import reverb -import sonnet as snt -import tensorflow as tf -import trfl class DQN(agent.Agent): - """DQN agent. + """DQN agent. This implements a single-process DQN agent. This is a simple Q-learning algorithm that inserts N-step transitions into a replay buffer, and @@ -41,29 +41,29 @@ class DQN(agent.Agent): prioritization. """ - def __init__( - self, - environment_spec: specs.EnvironmentSpec, - network: snt.Module, - batch_size: int = 256, - prefetch_size: int = 4, - target_update_period: int = 100, - samples_per_insert: float = 32.0, - min_replay_size: int = 1000, - max_replay_size: int = 1000000, - importance_sampling_exponent: float = 0.2, - priority_exponent: float = 0.6, - n_step: int = 5, - epsilon: Optional[tf.Variable] = None, - learning_rate: float = 1e-3, - discount: float = 0.99, - logger: Optional[loggers.Logger] = None, - checkpoint: bool = True, - checkpoint_subpath: str = '~/acme', - policy_network: Optional[snt.Module] = None, - max_gradient_norm: Optional[float] = None, - ): - """Initialize the agent. + def __init__( + self, + environment_spec: specs.EnvironmentSpec, + network: snt.Module, + batch_size: int = 256, + prefetch_size: int = 4, + target_update_period: int = 100, + samples_per_insert: float = 32.0, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + importance_sampling_exponent: float = 0.2, + priority_exponent: float = 0.6, + n_step: int = 5, + epsilon: Optional[tf.Variable] = None, + learning_rate: float = 1e-3, + discount: float = 0.99, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = True, + checkpoint_subpath: str = "~/acme", + policy_network: Optional[snt.Module] = None, + max_gradient_norm: Optional[float] = None, + ): + """Initialize the agent. Args: environment_spec: description of the actions, observations, etc. @@ -96,82 +96,83 @@ def __init__( max_gradient_norm: used for gradient clipping. """ - # Create a replay server to add data to. This uses no limiter behavior in - # order to allow the Agent interface to handle it. - replay_table = reverb.Table( - name=adders.DEFAULT_PRIORITY_TABLE, - sampler=reverb.selectors.Prioritized(priority_exponent), - remover=reverb.selectors.Fifo(), - max_size=max_replay_size, - rate_limiter=reverb.rate_limiters.MinSize(1), - signature=adders.NStepTransitionAdder.signature(environment_spec)) - self._server = reverb.Server([replay_table], port=None) - - # The adder is used to insert observations into replay. - address = f'localhost:{self._server.port}' - adder = adders.NStepTransitionAdder( - client=reverb.Client(address), - n_step=n_step, - discount=discount) - - # The dataset provides an interface to sample from replay. - replay_client = reverb.Client(address) - dataset = datasets.make_reverb_dataset( - server_address=address, - batch_size=batch_size, - prefetch_size=prefetch_size) - - # Create epsilon greedy policy network by default. 
- if policy_network is None: - # Use constant 0.05 epsilon greedy policy by default. - if epsilon is None: - epsilon = tf.Variable(0.05, trainable=False) - policy_network = snt.Sequential([ - network, - lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(), - ]) - - # Create a target network. - target_network = copy.deepcopy(network) - - # Ensure that we create the variables before proceeding (maybe not needed). - tf2_utils.create_variables(network, [environment_spec.observations]) - tf2_utils.create_variables(target_network, [environment_spec.observations]) - - # Create the actor which defines how we take actions. - actor = actors.FeedForwardActor(policy_network, adder) - - # The learner updates the parameters (and initializes them). - learner = learning.DQNLearner( - network=network, - target_network=target_network, - discount=discount, - importance_sampling_exponent=importance_sampling_exponent, - learning_rate=learning_rate, - target_update_period=target_update_period, - dataset=dataset, - replay_client=replay_client, - max_gradient_norm=max_gradient_norm, - logger=logger, - checkpoint=checkpoint, - save_directory=checkpoint_subpath) - - if checkpoint: - self._checkpointer = tf2_savers.Checkpointer( - directory=checkpoint_subpath, - objects_to_save=learner.state, - subdirectory='dqn_learner', - time_delta_minutes=60.) - else: - self._checkpointer = None - - super().__init__( - actor=actor, - learner=learner, - min_observations=max(batch_size, min_replay_size), - observations_per_step=float(batch_size) / samples_per_insert) - - def update(self): - super().update() - if self._checkpointer is not None: - self._checkpointer.save() + # Create a replay server to add data to. This uses no limiter behavior in + # order to allow the Agent interface to handle it. + replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Prioritized(priority_exponent), + remover=reverb.selectors.Fifo(), + max_size=max_replay_size, + rate_limiter=reverb.rate_limiters.MinSize(1), + signature=adders.NStepTransitionAdder.signature(environment_spec), + ) + self._server = reverb.Server([replay_table], port=None) + + # The adder is used to insert observations into replay. + address = f"localhost:{self._server.port}" + adder = adders.NStepTransitionAdder( + client=reverb.Client(address), n_step=n_step, discount=discount + ) + + # The dataset provides an interface to sample from replay. + replay_client = reverb.Client(address) + dataset = datasets.make_reverb_dataset( + server_address=address, batch_size=batch_size, prefetch_size=prefetch_size + ) + + # Create epsilon greedy policy network by default. + if policy_network is None: + # Use constant 0.05 epsilon greedy policy by default. + if epsilon is None: + epsilon = tf.Variable(0.05, trainable=False) + policy_network = snt.Sequential( + [network, lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(),] + ) + + # Create a target network. + target_network = copy.deepcopy(network) + + # Ensure that we create the variables before proceeding (maybe not needed). + tf2_utils.create_variables(network, [environment_spec.observations]) + tf2_utils.create_variables(target_network, [environment_spec.observations]) + + # Create the actor which defines how we take actions. + actor = actors.FeedForwardActor(policy_network, adder) + + # The learner updates the parameters (and initializes them). 
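[Editor's note] For readers unfamiliar with the `trfl.epsilon_greedy` head wrapped around the Q-network above, this is a plain-NumPy sketch of the same action rule; the Q-values and epsilon below are illustrative, not values from the diff.

```python
import numpy as np

def epsilon_greedy_action(q_values: np.ndarray, epsilon: float,
                          rng: np.random.Generator) -> int:
    """With probability epsilon act uniformly at random, otherwise greedily."""
    if rng.random() < epsilon:
        return int(rng.integers(len(q_values)))
    return int(np.argmax(q_values))

rng = np.random.default_rng(0)
action = epsilon_greedy_action(np.array([0.1, 0.5, -0.2]), epsilon=0.05, rng=rng)
```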
+ learner = learning.DQNLearner( + network=network, + target_network=target_network, + discount=discount, + importance_sampling_exponent=importance_sampling_exponent, + learning_rate=learning_rate, + target_update_period=target_update_period, + dataset=dataset, + replay_client=replay_client, + max_gradient_norm=max_gradient_norm, + logger=logger, + checkpoint=checkpoint, + save_directory=checkpoint_subpath, + ) + + if checkpoint: + self._checkpointer = tf2_savers.Checkpointer( + directory=checkpoint_subpath, + objects_to_save=learner.state, + subdirectory="dqn_learner", + time_delta_minutes=60.0, + ) + else: + self._checkpointer = None + + super().__init__( + actor=actor, + learner=learner, + min_observations=max(batch_size, min_replay_size), + observations_per_step=float(batch_size) / samples_per_insert, + ) + + def update(self): + super().update() + if self._checkpointer is not None: + self._checkpointer.save() diff --git a/acme/agents/tf/dqn/agent_distributed.py b/acme/agents/tf/dqn/agent_distributed.py index 0e22dd61e2..a22f891cc6 100644 --- a/acme/agents/tf/dqn/agent_distributed.py +++ b/acme/agents/tf/dqn/agent_distributed.py @@ -17,256 +17,261 @@ import copy from typing import Callable, Optional +import dm_env +import launchpad as lp +import numpy as np +import reverb +import sonnet as snt +import trfl + import acme -from acme import datasets -from acme import specs +from acme import datasets, specs from acme.adders import reverb as adders from acme.agents.tf import actors from acme.agents.tf.dqn import learning from acme.tf import savers as tf2_savers from acme.tf import utils as tf2_utils from acme.tf import variable_utils as tf2_variable_utils -from acme.utils import counting -from acme.utils import loggers -from acme.utils import lp_utils -import dm_env -import launchpad as lp -import numpy as np -import reverb -import sonnet as snt -import trfl +from acme.utils import counting, loggers, lp_utils class DistributedDQN: - """Distributed DQN agent.""" - - def __init__( - self, - environment_factory: Callable[[bool], dm_env.Environment], - network_factory: Callable[[specs.DiscreteArray], snt.Module], - num_actors: int, - num_caches: int = 1, - batch_size: int = 256, - prefetch_size: int = 4, - target_update_period: int = 100, - samples_per_insert: float = 32.0, - min_replay_size: int = 1000, - max_replay_size: int = 1_000_000, - importance_sampling_exponent: float = 0.2, - priority_exponent: float = 0.6, - n_step: int = 5, - learning_rate: float = 1e-3, - evaluator_epsilon: float = 0., - max_actor_steps: Optional[int] = None, - discount: float = 0.99, - environment_spec: Optional[specs.EnvironmentSpec] = None, - variable_update_period: int = 1000, - ): - - assert num_caches >= 1 - - if environment_spec is None: - environment_spec = specs.make_environment_spec(environment_factory(False)) - - self._environment_factory = environment_factory - self._network_factory = network_factory - self._num_actors = num_actors - self._num_caches = num_caches - self._env_spec = environment_spec - self._batch_size = batch_size - self._prefetch_size = prefetch_size - self._target_update_period = target_update_period - self._samples_per_insert = samples_per_insert - self._min_replay_size = min_replay_size - self._max_replay_size = max_replay_size - self._importance_sampling_exponent = importance_sampling_exponent - self._priority_exponent = priority_exponent - self._n_step = n_step - self._learning_rate = learning_rate - self._evaluator_epsilon = evaluator_epsilon - self._max_actor_steps = 
max_actor_steps - self._discount = discount - self._variable_update_period = variable_update_period - - def replay(self): - """The replay storage.""" - if self._samples_per_insert: - limiter = reverb.rate_limiters.SampleToInsertRatio( - min_size_to_sample=self._min_replay_size, - samples_per_insert=self._samples_per_insert, - error_buffer=self._batch_size) - else: - limiter = reverb.rate_limiters.MinSize(self._min_replay_size) - replay_table = reverb.Table( - name=adders.DEFAULT_PRIORITY_TABLE, - sampler=reverb.selectors.Prioritized(self._priority_exponent), - remover=reverb.selectors.Fifo(), - max_size=self._max_replay_size, - rate_limiter=limiter, - signature=adders.NStepTransitionAdder.signature(self._env_spec)) - return [replay_table] - - def counter(self): - """Creates the master counter process.""" - return tf2_savers.CheckpointingRunner( - counting.Counter(), time_delta_minutes=1, subdirectory='counter') - - def coordinator(self, counter: counting.Counter, max_actor_steps: int): - return lp_utils.StepsLimiter(counter, max_actor_steps) - - def learner(self, replay: reverb.Client, counter: counting.Counter): - """The Learning part of the agent.""" - - # Create the networks. - network = self._network_factory(self._env_spec.actions) - target_network = copy.deepcopy(network) - - tf2_utils.create_variables(network, [self._env_spec.observations]) - tf2_utils.create_variables(target_network, [self._env_spec.observations]) - - # The dataset object to learn from. - replay_client = reverb.Client(replay.server_address) - dataset = datasets.make_reverb_dataset( - server_address=replay.server_address, - batch_size=self._batch_size, - prefetch_size=self._prefetch_size) - - logger = loggers.make_default_logger('learner', steps_key='learner_steps') - - # Return the learning agent. - counter = counting.Counter(counter, 'learner') - - learner = learning.DQNLearner( - network=network, - target_network=target_network, - discount=self._discount, - importance_sampling_exponent=self._importance_sampling_exponent, - learning_rate=self._learning_rate, - target_update_period=self._target_update_period, - dataset=dataset, - replay_client=replay_client, - counter=counter, - logger=logger) - return tf2_savers.CheckpointingRunner( - learner, subdirectory='dqn_learner', time_delta_minutes=60) - - def actor( - self, - replay: reverb.Client, - variable_source: acme.VariableSource, - counter: counting.Counter, - epsilon: float, - ) -> acme.EnvironmentLoop: - """The actor process.""" - environment = self._environment_factory(False) - network = self._network_factory(self._env_spec.actions) - - # Just inline the policy network here. - policy_network = snt.Sequential([ - network, - lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(), - ]) - - tf2_utils.create_variables(policy_network, [self._env_spec.observations]) - variable_client = tf2_variable_utils.VariableClient( - client=variable_source, - variables={'policy': policy_network.trainable_variables}, - update_period=self._variable_update_period) - - # Make sure not to use a random policy after checkpoint restoration by - # assigning variables before running the environment loop. - variable_client.update_and_wait() - - # Component to add things into replay. - adder = adders.NStepTransitionAdder( - client=replay, - n_step=self._n_step, - discount=self._discount, - ) - - # Create the agent. - actor = actors.FeedForwardActor(policy_network, adder, variable_client) - - # Create the loop to connect environment and agent. 
- counter = counting.Counter(counter, 'actor') - logger = loggers.make_default_logger( - 'actor', save_data=False, steps_key='actor_steps') - return acme.EnvironmentLoop(environment, actor, counter, logger) - - def evaluator( - self, - variable_source: acme.VariableSource, - counter: counting.Counter, - ): - """The evaluation process.""" - environment = self._environment_factory(True) - network = self._network_factory(self._env_spec.actions) - - # Just inline the policy network here. - policy_network = snt.Sequential([ - network, - lambda q: trfl.epsilon_greedy(q, self._evaluator_epsilon).sample(), - ]) - - tf2_utils.create_variables(policy_network, [self._env_spec.observations]) - - variable_client = tf2_variable_utils.VariableClient( - client=variable_source, - variables={'policy': policy_network.trainable_variables}, - update_period=self._variable_update_period) - - # Make sure not to use a random policy after checkpoint restoration by - # assigning variables before running the environment loop. - variable_client.update_and_wait() - - # Create the agent. - actor = actors.FeedForwardActor( - policy_network, variable_client=variable_client) - - # Create the run loop and return it. - logger = loggers.make_default_logger( - 'evaluator', steps_key='evaluator_steps') - counter = counting.Counter(counter, 'evaluator') - return acme.EnvironmentLoop( - environment, actor, counter=counter, logger=logger) - - def build(self, name='dqn'): - """Build the distributed agent topology.""" - program = lp.Program(name=name) - - with program.group('replay'): - replay = program.add_node(lp.ReverbNode(self.replay)) - - with program.group('counter'): - counter = program.add_node(lp.CourierNode(self.counter)) - - if self._max_actor_steps: - program.add_node( - lp.CourierNode(self.coordinator, counter, self._max_actor_steps)) - - with program.group('learner'): - learner = program.add_node(lp.CourierNode(self.learner, replay, counter)) - - with program.group('evaluator'): - program.add_node(lp.CourierNode(self.evaluator, learner, counter)) - - # Generate an epsilon for each actor. - epsilons = np.flip(np.logspace(1, 8, self._num_actors, base=0.4), axis=0) - - with program.group('cacher'): - # Create a set of learner caches. - sources = [] - for _ in range(self._num_caches): - cacher = program.add_node( - lp.CacherNode( - learner, refresh_interval_ms=2000, stale_after_ms=4000)) - sources.append(cacher) - - with program.group('actor'): - # Add actors which pull round-robin from our variable sources. 
- for actor_id, epsilon in enumerate(epsilons): - source = sources[actor_id % len(sources)] - program.add_node( - lp.CourierNode(self.actor, replay, source, counter, epsilon)) - - return program + """Distributed DQN agent.""" + + def __init__( + self, + environment_factory: Callable[[bool], dm_env.Environment], + network_factory: Callable[[specs.DiscreteArray], snt.Module], + num_actors: int, + num_caches: int = 1, + batch_size: int = 256, + prefetch_size: int = 4, + target_update_period: int = 100, + samples_per_insert: float = 32.0, + min_replay_size: int = 1000, + max_replay_size: int = 1_000_000, + importance_sampling_exponent: float = 0.2, + priority_exponent: float = 0.6, + n_step: int = 5, + learning_rate: float = 1e-3, + evaluator_epsilon: float = 0.0, + max_actor_steps: Optional[int] = None, + discount: float = 0.99, + environment_spec: Optional[specs.EnvironmentSpec] = None, + variable_update_period: int = 1000, + ): + + assert num_caches >= 1 + + if environment_spec is None: + environment_spec = specs.make_environment_spec(environment_factory(False)) + + self._environment_factory = environment_factory + self._network_factory = network_factory + self._num_actors = num_actors + self._num_caches = num_caches + self._env_spec = environment_spec + self._batch_size = batch_size + self._prefetch_size = prefetch_size + self._target_update_period = target_update_period + self._samples_per_insert = samples_per_insert + self._min_replay_size = min_replay_size + self._max_replay_size = max_replay_size + self._importance_sampling_exponent = importance_sampling_exponent + self._priority_exponent = priority_exponent + self._n_step = n_step + self._learning_rate = learning_rate + self._evaluator_epsilon = evaluator_epsilon + self._max_actor_steps = max_actor_steps + self._discount = discount + self._variable_update_period = variable_update_period + + def replay(self): + """The replay storage.""" + if self._samples_per_insert: + limiter = reverb.rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._min_replay_size, + samples_per_insert=self._samples_per_insert, + error_buffer=self._batch_size, + ) + else: + limiter = reverb.rate_limiters.MinSize(self._min_replay_size) + replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Prioritized(self._priority_exponent), + remover=reverb.selectors.Fifo(), + max_size=self._max_replay_size, + rate_limiter=limiter, + signature=adders.NStepTransitionAdder.signature(self._env_spec), + ) + return [replay_table] + + def counter(self): + """Creates the master counter process.""" + return tf2_savers.CheckpointingRunner( + counting.Counter(), time_delta_minutes=1, subdirectory="counter" + ) + + def coordinator(self, counter: counting.Counter, max_actor_steps: int): + return lp_utils.StepsLimiter(counter, max_actor_steps) + + def learner(self, replay: reverb.Client, counter: counting.Counter): + """The Learning part of the agent.""" + + # Create the networks. + network = self._network_factory(self._env_spec.actions) + target_network = copy.deepcopy(network) + + tf2_utils.create_variables(network, [self._env_spec.observations]) + tf2_utils.create_variables(target_network, [self._env_spec.observations]) + + # The dataset object to learn from. 
+ replay_client = reverb.Client(replay.server_address) + dataset = datasets.make_reverb_dataset( + server_address=replay.server_address, + batch_size=self._batch_size, + prefetch_size=self._prefetch_size, + ) + + logger = loggers.make_default_logger("learner", steps_key="learner_steps") + + # Return the learning agent. + counter = counting.Counter(counter, "learner") + + learner = learning.DQNLearner( + network=network, + target_network=target_network, + discount=self._discount, + importance_sampling_exponent=self._importance_sampling_exponent, + learning_rate=self._learning_rate, + target_update_period=self._target_update_period, + dataset=dataset, + replay_client=replay_client, + counter=counter, + logger=logger, + ) + return tf2_savers.CheckpointingRunner( + learner, subdirectory="dqn_learner", time_delta_minutes=60 + ) + + def actor( + self, + replay: reverb.Client, + variable_source: acme.VariableSource, + counter: counting.Counter, + epsilon: float, + ) -> acme.EnvironmentLoop: + """The actor process.""" + environment = self._environment_factory(False) + network = self._network_factory(self._env_spec.actions) + + # Just inline the policy network here. + policy_network = snt.Sequential( + [network, lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(),] + ) + + tf2_utils.create_variables(policy_network, [self._env_spec.observations]) + variable_client = tf2_variable_utils.VariableClient( + client=variable_source, + variables={"policy": policy_network.trainable_variables}, + update_period=self._variable_update_period, + ) + + # Make sure not to use a random policy after checkpoint restoration by + # assigning variables before running the environment loop. + variable_client.update_and_wait() + + # Component to add things into replay. + adder = adders.NStepTransitionAdder( + client=replay, n_step=self._n_step, discount=self._discount, + ) + + # Create the agent. + actor = actors.FeedForwardActor(policy_network, adder, variable_client) + + # Create the loop to connect environment and agent. + counter = counting.Counter(counter, "actor") + logger = loggers.make_default_logger( + "actor", save_data=False, steps_key="actor_steps" + ) + return acme.EnvironmentLoop(environment, actor, counter, logger) + + def evaluator( + self, variable_source: acme.VariableSource, counter: counting.Counter, + ): + """The evaluation process.""" + environment = self._environment_factory(True) + network = self._network_factory(self._env_spec.actions) + + # Just inline the policy network here. + policy_network = snt.Sequential( + [ + network, + lambda q: trfl.epsilon_greedy(q, self._evaluator_epsilon).sample(), + ] + ) + + tf2_utils.create_variables(policy_network, [self._env_spec.observations]) + + variable_client = tf2_variable_utils.VariableClient( + client=variable_source, + variables={"policy": policy_network.trainable_variables}, + update_period=self._variable_update_period, + ) + + # Make sure not to use a random policy after checkpoint restoration by + # assigning variables before running the environment loop. + variable_client.update_and_wait() + + # Create the agent. + actor = actors.FeedForwardActor(policy_network, variable_client=variable_client) + + # Create the run loop and return it. 
+ logger = loggers.make_default_logger("evaluator", steps_key="evaluator_steps") + counter = counting.Counter(counter, "evaluator") + return acme.EnvironmentLoop(environment, actor, counter=counter, logger=logger) + + def build(self, name="dqn"): + """Build the distributed agent topology.""" + program = lp.Program(name=name) + + with program.group("replay"): + replay = program.add_node(lp.ReverbNode(self.replay)) + + with program.group("counter"): + counter = program.add_node(lp.CourierNode(self.counter)) + + if self._max_actor_steps: + program.add_node( + lp.CourierNode(self.coordinator, counter, self._max_actor_steps) + ) + + with program.group("learner"): + learner = program.add_node(lp.CourierNode(self.learner, replay, counter)) + + with program.group("evaluator"): + program.add_node(lp.CourierNode(self.evaluator, learner, counter)) + + # Generate an epsilon for each actor. + epsilons = np.flip(np.logspace(1, 8, self._num_actors, base=0.4), axis=0) + + with program.group("cacher"): + # Create a set of learner caches. + sources = [] + for _ in range(self._num_caches): + cacher = program.add_node( + lp.CacherNode( + learner, refresh_interval_ms=2000, stale_after_ms=4000 + ) + ) + sources.append(cacher) + + with program.group("actor"): + # Add actors which pull round-robin from our variable sources. + for actor_id, epsilon in enumerate(epsilons): + source = sources[actor_id % len(sources)] + program.add_node( + lp.CourierNode(self.actor, replay, source, counter, epsilon) + ) + + return program diff --git a/acme/agents/tf/dqn/agent_distributed_test.py b/acme/agents/tf/dqn/agent_distributed_test.py index d6b4788a92..549d696b87 100644 --- a/acme/agents/tf/dqn/agent_distributed_test.py +++ b/acme/agents/tf/dqn/agent_distributed_test.py @@ -14,43 +14,43 @@ """Integration test for the distributed agent.""" +import launchpad as lp +from absl.testing import absltest + import acme from acme.agents.tf import dqn from acme.testing import fakes from acme.tf import networks -import launchpad as lp - -from absl.testing import absltest class DistributedAgentTest(absltest.TestCase): - """Simple integration/smoke test for the distributed agent.""" + """Simple integration/smoke test for the distributed agent.""" - def test_atari(self): - """Tests that the agent can run for some steps without crashing.""" - env_factory = lambda x: fakes.fake_atari_wrapped() - net_factory = lambda spec: networks.DQNAtariNetwork(spec.num_values) + def test_atari(self): + """Tests that the agent can run for some steps without crashing.""" + env_factory = lambda x: fakes.fake_atari_wrapped() + net_factory = lambda spec: networks.DQNAtariNetwork(spec.num_values) - agent = dqn.DistributedDQN( - environment_factory=env_factory, - network_factory=net_factory, - num_actors=2, - batch_size=32, - min_replay_size=32, - max_replay_size=1000, - ) - program = agent.build() + agent = dqn.DistributedDQN( + environment_factory=env_factory, + network_factory=net_factory, + num_actors=2, + batch_size=32, + min_replay_size=32, + max_replay_size=1000, + ) + program = agent.build() - (learner_node,) = program.groups['learner'] - learner_node.disable_run() + (learner_node,) = program.groups["learner"] + learner_node.disable_run() - lp.launch(program, launch_type='test_mt') + lp.launch(program, launch_type="test_mt") - learner: acme.Learner = learner_node.create_handle().dereference() + learner: acme.Learner = learner_node.create_handle().dereference() - for _ in range(5): - learner.step() + for _ in range(5): + learner.step() -if __name__ == 
'__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/tf/dqn/agent_test.py b/acme/agents/tf/dqn/agent_test.py index c4dbfe7c54..ddef0f4b8d 100644 --- a/acme/agents/tf/dqn/agent_test.py +++ b/acme/agents/tf/dqn/agent_test.py @@ -14,47 +14,44 @@ """Tests for DQN agent.""" +import numpy as np +import sonnet as snt +from absl.testing import absltest + import acme from acme import specs from acme.agents.tf import dqn from acme.testing import fakes -import numpy as np -import sonnet as snt - -from absl.testing import absltest def _make_network(action_spec: specs.DiscreteArray) -> snt.Module: - return snt.Sequential([ - snt.Flatten(), - snt.nets.MLP([50, 50, action_spec.num_values]), - ]) + return snt.Sequential( + [snt.Flatten(), snt.nets.MLP([50, 50, action_spec.num_values]),] + ) class DQNTest(absltest.TestCase): - - def test_dqn(self): - # Create a fake environment to test with. - environment = fakes.DiscreteEnvironment( - num_actions=5, - num_observations=10, - obs_dtype=np.float32, - episode_length=10) - spec = specs.make_environment_spec(environment) - - # Construct the agent. - agent = dqn.DQN( - environment_spec=spec, - network=_make_network(spec.actions), - batch_size=10, - samples_per_insert=2, - min_replay_size=10) - - # Try running the environment loop. We have no assertions here because all - # we care about is that the agent runs without raising any errors. - loop = acme.EnvironmentLoop(environment, agent) - loop.run(num_episodes=2) - - -if __name__ == '__main__': - absltest.main() + def test_dqn(self): + # Create a fake environment to test with. + environment = fakes.DiscreteEnvironment( + num_actions=5, num_observations=10, obs_dtype=np.float32, episode_length=10 + ) + spec = specs.make_environment_spec(environment) + + # Construct the agent. + agent = dqn.DQN( + environment_spec=spec, + network=_make_network(spec.actions), + batch_size=10, + samples_per_insert=2, + min_replay_size=10, + ) + + # Try running the environment loop. We have no assertions here because all + # we care about is that the agent runs without raising any errors. + loop = acme.EnvironmentLoop(environment, agent) + loop.run(num_episodes=2) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/tf/dqn/learning.py b/acme/agents/tf/dqn/learning.py index 00ce17af89..8953caf8fb 100644 --- a/acme/agents/tf/dqn/learning.py +++ b/acme/agents/tf/dqn/learning.py @@ -17,48 +17,48 @@ import time from typing import Dict, List, Optional, Union +import numpy as np +import reverb +import sonnet as snt +import tensorflow as tf +import trfl + import acme from acme import types from acme.adders import reverb as adders from acme.tf import losses from acme.tf import savers as tf2_savers from acme.tf import utils as tf2_utils -from acme.utils import counting -from acme.utils import loggers -import numpy as np -import reverb -import sonnet as snt -import tensorflow as tf -import trfl +from acme.utils import counting, loggers class DQNLearner(acme.Learner, tf2_savers.TFSaveable): - """DQN learner. + """DQN learner. This is the learning component of a DQN agent. It takes a dataset as input and implements update functionality to learn from this dataset. Optionally it takes a replay client as well to allow for updating of priorities. 
""" - def __init__( - self, - network: snt.Module, - target_network: snt.Module, - discount: float, - importance_sampling_exponent: float, - learning_rate: float, - target_update_period: int, - dataset: tf.data.Dataset, - max_abs_reward: Optional[float] = 1., - huber_loss_parameter: float = 1., - replay_client: Optional[Union[reverb.Client, reverb.TFClient]] = None, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - checkpoint: bool = True, - save_directory: str = '~/acme', - max_gradient_norm: Optional[float] = None, - ): - """Initializes the learner. + def __init__( + self, + network: snt.Module, + target_network: snt.Module, + discount: float, + importance_sampling_exponent: float, + learning_rate: float, + target_update_period: int, + dataset: tf.data.Dataset, + max_abs_reward: Optional[float] = 1.0, + huber_loss_parameter: float = 1.0, + replay_client: Optional[Union[reverb.Client, reverb.TFClient]] = None, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = True, + save_directory: str = "~/acme", + max_gradient_norm: Optional[float] = None, + ): + """Initializes the learner. Args: network: the online Q network (the one being optimized) @@ -82,157 +82,166 @@ def __init__( max_gradient_norm: used for gradient clipping. """ - # TODO(mwhoffman): stop allowing replay_client to be passed as a TFClient. - # This is just here for backwards compatability for agents which reuse this - # Learner and still pass a TFClient instance. - if isinstance(replay_client, reverb.TFClient): - # TODO(b/170419518): open source pytype does not understand this - # isinstance() check because it does not have a way of getting precise - # type information for pip-installed packages. - replay_client = reverb.Client(replay_client._server_address) # pytype: disable=attribute-error - - # Internalise agent components (replay buffer, networks, optimizer). - # TODO(b/155086959): Fix type stubs and remove. - self._iterator = iter(dataset) # pytype: disable=wrong-arg-types - self._network = network - self._target_network = target_network - self._optimizer = snt.optimizers.Adam(learning_rate) - self._replay_client = replay_client - - # Make sure to initialize the optimizer so that its variables (e.g. the Adam - # moments) are included in the state returned by the learner (which can then - # be checkpointed and restored). - self._optimizer._initialize(network.trainable_variables) # pylint: disable= protected-access - - # Internalise the hyperparameters. - self._discount = discount - self._target_update_period = target_update_period - self._importance_sampling_exponent = importance_sampling_exponent - self._max_abs_reward = max_abs_reward - self._huber_loss_parameter = huber_loss_parameter - if max_gradient_norm is None: - max_gradient_norm = 1e10 # A very large number. Infinity results in NaNs. - self._max_gradient_norm = tf.convert_to_tensor(max_gradient_norm) - - # Learner state. - self._variables: List[List[tf.Tensor]] = [network.trainable_variables] - self._num_steps = tf.Variable(0, dtype=tf.int32) - - # Internalise logging/counting objects. - self._counter = counter or counting.Counter() - self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.) - - # Create a snapshotter object. - if checkpoint: - self._snapshotter = tf2_savers.Snapshotter( - objects_to_save={'network': network}, - directory=save_directory, - time_delta_minutes=60.) 
- else: - self._snapshotter = None - - # Do not record timestamps until after the first learning step is done. - # This is to avoid including the time it takes for actors to come online and - # fill the replay buffer. - self._timestamp = None - - @tf.function - def _step(self) -> Dict[str, tf.Tensor]: - """Do a step of SGD and update the priorities.""" - - # Pull out the data needed for updates/priorities. - inputs = next(self._iterator) - transitions: types.Transition = inputs.data - keys, probs = inputs.info[:2] - - with tf.GradientTape() as tape: - # Evaluate our networks. - q_tm1 = self._network(transitions.observation) - q_t_value = self._target_network(transitions.next_observation) - q_t_selector = self._network(transitions.next_observation) - - # The rewards and discounts have to have the same type as network values. - r_t = tf.cast(transitions.reward, q_tm1.dtype) - if self._max_abs_reward: - r_t = tf.clip_by_value(r_t, -self._max_abs_reward, self._max_abs_reward) - d_t = tf.cast(transitions.discount, q_tm1.dtype) * tf.cast( - self._discount, q_tm1.dtype) - - # Compute the loss. - _, extra = trfl.double_qlearning(q_tm1, transitions.action, r_t, d_t, - q_t_value, q_t_selector) - loss = losses.huber(extra.td_error, self._huber_loss_parameter) - - # Get the importance weights. - importance_weights = 1. / probs # [B] - importance_weights **= self._importance_sampling_exponent - importance_weights /= tf.reduce_max(importance_weights) - - # Reweight. - loss *= tf.cast(importance_weights, loss.dtype) # [B] - loss = tf.reduce_mean(loss, axis=[0]) # [] - - # Do a step of SGD. - gradients = tape.gradient(loss, self._network.trainable_variables) - gradients, _ = tf.clip_by_global_norm(gradients, self._max_gradient_norm) - self._optimizer.apply(gradients, self._network.trainable_variables) - - # Get the priorities that we'll use to update. - priorities = tf.abs(extra.td_error) - - # Periodically update the target network. - if tf.math.mod(self._num_steps, self._target_update_period) == 0: - for src, dest in zip(self._network.variables, - self._target_network.variables): - dest.assign(src) - self._num_steps.assign_add(1) - - # Report loss & statistics for logging. - fetches = { - 'loss': loss, - 'keys': keys, - 'priorities': priorities, - } - - return fetches - - def step(self): - # Do a batch of SGD. - result = self._step() - - # Get the keys and priorities. - keys = result.pop('keys') - priorities = result.pop('priorities') - - # Update the priorities in the replay buffer. - if self._replay_client: - self._replay_client.mutate_priorities( - table=adders.DEFAULT_PRIORITY_TABLE, - updates=dict(zip(keys.numpy(), priorities.numpy()))) - - # Compute elapsed time. - timestamp = time.time() - elapsed_time = timestamp - self._timestamp if self._timestamp else 0 - self._timestamp = timestamp - - # Update our counts and record it. - counts = self._counter.increment(steps=1, walltime=elapsed_time) - result.update(counts) - - # Snapshot and attempt to write logs. - if self._snapshotter is not None: - self._snapshotter.save() - self._logger.write(result) - - def get_variables(self, names: List[str]) -> List[np.ndarray]: - return tf2_utils.to_numpy(self._variables) - - @property - def state(self): - """Returns the stateful parts of the learner for checkpointing.""" - return { - 'network': self._network, - 'target_network': self._target_network, - 'optimizer': self._optimizer, - 'num_steps': self._num_steps - } + # TODO(mwhoffman): stop allowing replay_client to be passed as a TFClient. 
+ # This is just here for backwards compatability for agents which reuse this + # Learner and still pass a TFClient instance. + if isinstance(replay_client, reverb.TFClient): + # TODO(b/170419518): open source pytype does not understand this + # isinstance() check because it does not have a way of getting precise + # type information for pip-installed packages. + replay_client = reverb.Client( + replay_client._server_address + ) # pytype: disable=attribute-error + + # Internalise agent components (replay buffer, networks, optimizer). + # TODO(b/155086959): Fix type stubs and remove. + self._iterator = iter(dataset) # pytype: disable=wrong-arg-types + self._network = network + self._target_network = target_network + self._optimizer = snt.optimizers.Adam(learning_rate) + self._replay_client = replay_client + + # Make sure to initialize the optimizer so that its variables (e.g. the Adam + # moments) are included in the state returned by the learner (which can then + # be checkpointed and restored). + self._optimizer._initialize( + network.trainable_variables + ) # pylint: disable= protected-access + + # Internalise the hyperparameters. + self._discount = discount + self._target_update_period = target_update_period + self._importance_sampling_exponent = importance_sampling_exponent + self._max_abs_reward = max_abs_reward + self._huber_loss_parameter = huber_loss_parameter + if max_gradient_norm is None: + max_gradient_norm = 1e10 # A very large number. Infinity results in NaNs. + self._max_gradient_norm = tf.convert_to_tensor(max_gradient_norm) + + # Learner state. + self._variables: List[List[tf.Tensor]] = [network.trainable_variables] + self._num_steps = tf.Variable(0, dtype=tf.int32) + + # Internalise logging/counting objects. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.TerminalLogger("learner", time_delta=1.0) + + # Create a snapshotter object. + if checkpoint: + self._snapshotter = tf2_savers.Snapshotter( + objects_to_save={"network": network}, + directory=save_directory, + time_delta_minutes=60.0, + ) + else: + self._snapshotter = None + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + @tf.function + def _step(self) -> Dict[str, tf.Tensor]: + """Do a step of SGD and update the priorities.""" + + # Pull out the data needed for updates/priorities. + inputs = next(self._iterator) + transitions: types.Transition = inputs.data + keys, probs = inputs.info[:2] + + with tf.GradientTape() as tape: + # Evaluate our networks. + q_tm1 = self._network(transitions.observation) + q_t_value = self._target_network(transitions.next_observation) + q_t_selector = self._network(transitions.next_observation) + + # The rewards and discounts have to have the same type as network values. + r_t = tf.cast(transitions.reward, q_tm1.dtype) + if self._max_abs_reward: + r_t = tf.clip_by_value(r_t, -self._max_abs_reward, self._max_abs_reward) + d_t = tf.cast(transitions.discount, q_tm1.dtype) * tf.cast( + self._discount, q_tm1.dtype + ) + + # Compute the loss. + _, extra = trfl.double_qlearning( + q_tm1, transitions.action, r_t, d_t, q_t_value, q_t_selector + ) + loss = losses.huber(extra.td_error, self._huber_loss_parameter) + + # Get the importance weights. 
+ importance_weights = 1.0 / probs # [B] + importance_weights **= self._importance_sampling_exponent + importance_weights /= tf.reduce_max(importance_weights) + + # Reweight. + loss *= tf.cast(importance_weights, loss.dtype) # [B] + loss = tf.reduce_mean(loss, axis=[0]) # [] + + # Do a step of SGD. + gradients = tape.gradient(loss, self._network.trainable_variables) + gradients, _ = tf.clip_by_global_norm(gradients, self._max_gradient_norm) + self._optimizer.apply(gradients, self._network.trainable_variables) + + # Get the priorities that we'll use to update. + priorities = tf.abs(extra.td_error) + + # Periodically update the target network. + if tf.math.mod(self._num_steps, self._target_update_period) == 0: + for src, dest in zip( + self._network.variables, self._target_network.variables + ): + dest.assign(src) + self._num_steps.assign_add(1) + + # Report loss & statistics for logging. + fetches = { + "loss": loss, + "keys": keys, + "priorities": priorities, + } + + return fetches + + def step(self): + # Do a batch of SGD. + result = self._step() + + # Get the keys and priorities. + keys = result.pop("keys") + priorities = result.pop("priorities") + + # Update the priorities in the replay buffer. + if self._replay_client: + self._replay_client.mutate_priorities( + table=adders.DEFAULT_PRIORITY_TABLE, + updates=dict(zip(keys.numpy(), priorities.numpy())), + ) + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Update our counts and record it. + counts = self._counter.increment(steps=1, walltime=elapsed_time) + result.update(counts) + + # Snapshot and attempt to write logs. + if self._snapshotter is not None: + self._snapshotter.save() + self._logger.write(result) + + def get_variables(self, names: List[str]) -> List[np.ndarray]: + return tf2_utils.to_numpy(self._variables) + + @property + def state(self): + """Returns the stateful parts of the learner for checkpointing.""" + return { + "network": self._network, + "target_network": self._target_network, + "optimizer": self._optimizer, + "num_steps": self._num_steps, + } diff --git a/acme/agents/tf/impala/acting.py b/acme/agents/tf/impala/acting.py index 748d9d5022..2e8b084143 100644 --- a/acme/agents/tf/impala/acting.py +++ b/acme/agents/tf/impala/acting.py @@ -16,81 +16,77 @@ from typing import Optional -from acme import adders -from acme import core -from acme import types -from acme.tf import utils as tf2_utils -from acme.tf import variable_utils as tf2_variable_utils - import dm_env import sonnet as snt import tensorflow as tf import tensorflow_probability as tfp +from acme import adders, core, types +from acme.tf import utils as tf2_utils +from acme.tf import variable_utils as tf2_variable_utils + tfd = tfp.distributions class IMPALAActor(core.Actor): - """A recurrent actor.""" + """A recurrent actor.""" - def __init__( - self, - network: snt.RNNCore, - adder: Optional[adders.Adder] = None, - variable_client: Optional[tf2_variable_utils.VariableClient] = None, - ): + def __init__( + self, + network: snt.RNNCore, + adder: Optional[adders.Adder] = None, + variable_client: Optional[tf2_variable_utils.VariableClient] = None, + ): - # Store these for later use. - self._adder = adder - self._variable_client = variable_client - self._network = network + # Store these for later use. 
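[Editor's note] Referring back to `DQNLearner._step` above: the prioritized-replay correction there (invert the sampling probabilities, raise to `importance_sampling_exponent`, normalise by the maximum) can be summarised with a small NumPy sketch. The probabilities and exponent are example values only.

```python
import numpy as np

probs = np.array([0.5, 0.25, 0.25])  # example sampling probabilities from replay
exponent = 0.2                       # importance_sampling_exponent
weights = (1.0 / probs) ** exponent
weights /= weights.max()             # normalise so the largest weight is 1
# -> ~[0.87, 1.0, 1.0]: over-sampled transitions are down-weighted in the loss.
```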
+ self._adder = adder + self._variable_client = variable_client + self._network = network - # TODO(b/152382420): Ideally we would call tf.function(network) instead but - # this results in an error when using acme RNN snapshots. - self._policy = tf.function(network.__call__) + # TODO(b/152382420): Ideally we would call tf.function(network) instead but + # this results in an error when using acme RNN snapshots. + self._policy = tf.function(network.__call__) - self._state = None - self._prev_state = None - self._prev_logits = None + self._state = None + self._prev_state = None + self._prev_logits = None - def select_action(self, observation: types.NestedArray) -> types.NestedArray: - # Add a dummy batch dimension and as a side effect convert numpy to TF. - batched_obs = tf2_utils.add_batch_dim(observation) + def select_action(self, observation: types.NestedArray) -> types.NestedArray: + # Add a dummy batch dimension and as a side effect convert numpy to TF. + batched_obs = tf2_utils.add_batch_dim(observation) - if self._state is None: - self._state = self._network.initial_state(1) + if self._state is None: + self._state = self._network.initial_state(1) - # Forward. - (logits, _), new_state = self._policy(batched_obs, self._state) + # Forward. + (logits, _), new_state = self._policy(batched_obs, self._state) - self._prev_logits = logits - self._prev_state = self._state - self._state = new_state + self._prev_logits = logits + self._prev_state = self._state + self._state = new_state - action = tfd.Categorical(logits).sample() - action = tf2_utils.to_numpy_squeeze(action) + action = tfd.Categorical(logits).sample() + action = tf2_utils.to_numpy_squeeze(action) - return action + return action - def observe_first(self, timestep: dm_env.TimeStep): - if self._adder: - self._adder.add_first(timestep) + def observe_first(self, timestep: dm_env.TimeStep): + if self._adder: + self._adder.add_first(timestep) - # Set the state to None so that we re-initialize at the next policy call. - self._state = None + # Set the state to None so that we re-initialize at the next policy call. 
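+        # select_action will rebuild it lazily via network.initial_state(1).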
+ self._state = None - def observe( - self, - action: types.NestedArray, - next_timestep: dm_env.TimeStep, - ): - if not self._adder: - return + def observe( + self, action: types.NestedArray, next_timestep: dm_env.TimeStep, + ): + if not self._adder: + return - extras = {'logits': self._prev_logits, 'core_state': self._prev_state} - extras = tf2_utils.to_numpy_squeeze(extras) - self._adder.add(action, next_timestep, extras) + extras = {"logits": self._prev_logits, "core_state": self._prev_state} + extras = tf2_utils.to_numpy_squeeze(extras) + self._adder.add(action, next_timestep, extras) - def update(self, wait: bool = False): - if self._variable_client: - self._variable_client.update(wait) + def update(self, wait: bool = False): + if self._variable_client: + self._variable_client.update(wait) diff --git a/acme/agents/tf/impala/agent.py b/acme/agents/tf/impala/agent.py index 807c58234e..1104ab725a 100644 --- a/acme/agents/tf/impala/agent.py +++ b/acme/agents/tf/impala/agent.py @@ -16,108 +16,105 @@ from typing import Optional -import acme -from acme import datasets -from acme import specs -from acme import types -from acme.adders import reverb as adders -from acme.agents.tf.impala import acting -from acme.agents.tf.impala import learning -from acme.tf import utils as tf2_utils -from acme.utils import counting -from acme.utils import loggers import dm_env import numpy as np import reverb import sonnet as snt import tensorflow as tf +import acme +from acme import datasets, specs, types +from acme.adders import reverb as adders +from acme.agents.tf.impala import acting, learning +from acme.tf import utils as tf2_utils +from acme.utils import counting, loggers + class IMPALA(acme.Actor): - """IMPALA Agent.""" - - def __init__( - self, - environment_spec: specs.EnvironmentSpec, - network: snt.RNNCore, - sequence_length: int, - sequence_period: int, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - discount: float = 0.99, - max_queue_size: int = 100000, - batch_size: int = 16, - learning_rate: float = 1e-3, - entropy_cost: float = 0.01, - baseline_cost: float = 0.5, - max_abs_reward: Optional[float] = None, - max_gradient_norm: Optional[float] = None, - ): - - num_actions = environment_spec.actions.num_values - self._logger = logger or loggers.TerminalLogger('agent') - - extra_spec = { - 'core_state': network.initial_state(1), - 'logits': tf.ones(shape=(1, num_actions), dtype=tf.float32) - } - # Remove batch dimensions. - extra_spec = tf2_utils.squeeze_batch_dim(extra_spec) - - queue = reverb.Table.queue( - name=adders.DEFAULT_PRIORITY_TABLE, - max_size=max_queue_size, - signature=adders.SequenceAdder.signature( - environment_spec, - extras_spec=extra_spec, - sequence_length=sequence_length)) - self._server = reverb.Server([queue], port=None) - self._can_sample = lambda: queue.can_sample(batch_size) - address = f'localhost:{self._server.port}' - - # Component to add things into replay. - adder = adders.SequenceAdder( - client=reverb.Client(address), - period=sequence_period, - sequence_length=sequence_length, - ) - - # The dataset object to learn from. 
- dataset = datasets.make_reverb_dataset( - server_address=address, - batch_size=batch_size) - - tf2_utils.create_variables(network, [environment_spec.observations]) - - self._actor = acting.IMPALAActor(network, adder) - self._learner = learning.IMPALALearner( - environment_spec=environment_spec, - network=network, - dataset=dataset, - counter=counter, - logger=logger, - discount=discount, - learning_rate=learning_rate, - entropy_cost=entropy_cost, - baseline_cost=baseline_cost, - max_gradient_norm=max_gradient_norm, - max_abs_reward=max_abs_reward, - ) - - def observe_first(self, timestep: dm_env.TimeStep): - self._actor.observe_first(timestep) - - def observe( - self, - action: types.NestedArray, - next_timestep: dm_env.TimeStep, - ): - self._actor.observe(action, next_timestep) - - def update(self, wait: bool = False): - # Run a number of learner steps (usually gradient steps). - while self._can_sample(): - self._learner.step() - - def select_action(self, observation: np.ndarray) -> int: - return self._actor.select_action(observation) + """IMPALA Agent.""" + + def __init__( + self, + environment_spec: specs.EnvironmentSpec, + network: snt.RNNCore, + sequence_length: int, + sequence_period: int, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + discount: float = 0.99, + max_queue_size: int = 100000, + batch_size: int = 16, + learning_rate: float = 1e-3, + entropy_cost: float = 0.01, + baseline_cost: float = 0.5, + max_abs_reward: Optional[float] = None, + max_gradient_norm: Optional[float] = None, + ): + + num_actions = environment_spec.actions.num_values + self._logger = logger or loggers.TerminalLogger("agent") + + extra_spec = { + "core_state": network.initial_state(1), + "logits": tf.ones(shape=(1, num_actions), dtype=tf.float32), + } + # Remove batch dimensions. + extra_spec = tf2_utils.squeeze_batch_dim(extra_spec) + + queue = reverb.Table.queue( + name=adders.DEFAULT_PRIORITY_TABLE, + max_size=max_queue_size, + signature=adders.SequenceAdder.signature( + environment_spec, + extras_spec=extra_spec, + sequence_length=sequence_length, + ), + ) + self._server = reverb.Server([queue], port=None) + self._can_sample = lambda: queue.can_sample(batch_size) + address = f"localhost:{self._server.port}" + + # Component to add things into replay. + adder = adders.SequenceAdder( + client=reverb.Client(address), + period=sequence_period, + sequence_length=sequence_length, + ) + + # The dataset object to learn from. + dataset = datasets.make_reverb_dataset( + server_address=address, batch_size=batch_size + ) + + tf2_utils.create_variables(network, [environment_spec.observations]) + + self._actor = acting.IMPALAActor(network, adder) + self._learner = learning.IMPALALearner( + environment_spec=environment_spec, + network=network, + dataset=dataset, + counter=counter, + logger=logger, + discount=discount, + learning_rate=learning_rate, + entropy_cost=entropy_cost, + baseline_cost=baseline_cost, + max_gradient_norm=max_gradient_norm, + max_abs_reward=max_abs_reward, + ) + + def observe_first(self, timestep: dm_env.TimeStep): + self._actor.observe_first(timestep) + + def observe( + self, action: types.NestedArray, next_timestep: dm_env.TimeStep, + ): + self._actor.observe(action, next_timestep) + + def update(self, wait: bool = False): + # Run a number of learner steps (usually gradient steps). 
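+        # The table backing replay is a bounded queue, so draining it here also
+        # keeps the actor's adder from blocking once the queue fills up.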
+ while self._can_sample(): + self._learner.step() + + def select_action(self, observation: np.ndarray) -> int: + return self._actor.select_action(observation) diff --git a/acme/agents/tf/impala/agent_distributed.py b/acme/agents/tf/impala/agent_distributed.py index 6002601b13..7b5fe617c4 100644 --- a/acme/agents/tf/impala/agent_distributed.py +++ b/acme/agents/tf/impala/agent_distributed.py @@ -16,215 +16,219 @@ from typing import Callable, Optional -import acme -from acme import datasets -from acme import specs -from acme.adders import reverb as adders -from acme.agents.tf.impala import acting -from acme.agents.tf.impala import learning -from acme.tf import savers as tf2_savers -from acme.tf import utils as tf2_utils -from acme.tf import variable_utils as tf2_variable_utils -from acme.utils import counting -from acme.utils import loggers import dm_env import launchpad as lp import reverb import sonnet as snt import tensorflow as tf +import acme +from acme import datasets, specs +from acme.adders import reverb as adders +from acme.agents.tf.impala import acting, learning +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.tf import variable_utils as tf2_variable_utils +from acme.utils import counting, loggers + class DistributedIMPALA: - """Program definition for IMPALA.""" - - def __init__(self, - environment_factory: Callable[[bool], dm_env.Environment], - network_factory: Callable[[specs.DiscreteArray], snt.RNNCore], - num_actors: int, - sequence_length: int, - sequence_period: int, - environment_spec: Optional[specs.EnvironmentSpec] = None, - batch_size: int = 256, - prefetch_size: int = 4, - max_queue_size: int = 10_000, - learning_rate: float = 1e-3, - discount: float = 0.99, - entropy_cost: float = 0.01, - baseline_cost: float = 0.5, - max_abs_reward: Optional[float] = None, - max_gradient_norm: Optional[float] = None, - variable_update_period: int = 1000, - save_logs: bool = False): - - if environment_spec is None: - environment_spec = specs.make_environment_spec(environment_factory(False)) - - self._environment_factory = environment_factory - self._network_factory = network_factory - self._environment_spec = environment_spec - self._num_actors = num_actors - self._batch_size = batch_size - self._prefetch_size = prefetch_size - self._sequence_length = sequence_length - self._max_queue_size = max_queue_size - self._sequence_period = sequence_period - self._discount = discount - self._learning_rate = learning_rate - self._entropy_cost = entropy_cost - self._baseline_cost = baseline_cost - self._max_abs_reward = max_abs_reward - self._max_gradient_norm = max_gradient_norm - self._variable_update_period = variable_update_period - self._save_logs = save_logs - - def queue(self): - """The queue.""" - num_actions = self._environment_spec.actions.num_values - network = self._network_factory(self._environment_spec.actions) - extra_spec = { - 'core_state': network.initial_state(1), - 'logits': tf.ones(shape=(1, num_actions), dtype=tf.float32) - } - # Remove batch dimensions. 
- extra_spec = tf2_utils.squeeze_batch_dim(extra_spec) - signature = adders.SequenceAdder.signature( - self._environment_spec, - extra_spec, - sequence_length=self._sequence_length) - queue = reverb.Table.queue( - name=adders.DEFAULT_PRIORITY_TABLE, - max_size=self._max_queue_size, - signature=signature) - return [queue] - - def counter(self): - """Creates the master counter process.""" - return tf2_savers.CheckpointingRunner( - counting.Counter(), time_delta_minutes=1, subdirectory='counter') - - def learner(self, queue: reverb.Client, counter: counting.Counter): - """The Learning part of the agent.""" - # Use architect and create the environment. - # Create the networks. - network = self._network_factory(self._environment_spec.actions) - tf2_utils.create_variables(network, [self._environment_spec.observations]) - - # The dataset object to learn from. - dataset = datasets.make_reverb_dataset( - server_address=queue.server_address, - batch_size=self._batch_size, - prefetch_size=self._prefetch_size) - - logger = loggers.make_default_logger('learner', steps_key='learner_steps') - counter = counting.Counter(counter, 'learner') - - # Return the learning agent. - learner = learning.IMPALALearner( - environment_spec=self._environment_spec, - network=network, - dataset=dataset, - discount=self._discount, - learning_rate=self._learning_rate, - entropy_cost=self._entropy_cost, - baseline_cost=self._baseline_cost, - max_abs_reward=self._max_abs_reward, - max_gradient_norm=self._max_gradient_norm, - counter=counter, - logger=logger, - ) - - return tf2_savers.CheckpointingRunner(learner, - time_delta_minutes=5, - subdirectory='impala_learner') - - def actor( - self, - replay: reverb.Client, - variable_source: acme.VariableSource, - counter: counting.Counter, - ) -> acme.EnvironmentLoop: - """The actor process.""" - environment = self._environment_factory(False) - network = self._network_factory(self._environment_spec.actions) - tf2_utils.create_variables(network, [self._environment_spec.observations]) - - # Component to add things into the queue. - adder = adders.SequenceAdder( - client=replay, - period=self._sequence_period, - sequence_length=self._sequence_length) - - variable_client = tf2_variable_utils.VariableClient( - client=variable_source, - variables={'policy': network.variables}, - update_period=self._variable_update_period) - - # Make sure not to use a random policy after checkpoint restoration by - # assigning variables before running the environment loop. - variable_client.update_and_wait() - - # Create the agent. - actor = acting.IMPALAActor( - network=network, - variable_client=variable_client, - adder=adder) - - counter = counting.Counter(counter, 'actor') - logger = loggers.make_default_logger( - 'actor', save_data=False, steps_key='actor_steps') - - # Create the loop to connect environment and agent. - return acme.EnvironmentLoop(environment, actor, counter, logger) - - def evaluator(self, variable_source: acme.VariableSource, - counter: counting.Counter): - """The evaluation process.""" - environment = self._environment_factory(True) - network = self._network_factory(self._environment_spec.actions) - tf2_utils.create_variables(network, [self._environment_spec.observations]) - - variable_client = tf2_variable_utils.VariableClient( - client=variable_source, - variables={'policy': network.variables}, - update_period=self._variable_update_period) - - # Make sure not to use a random policy after checkpoint restoration by - # assigning variables before running the environment loop. 
- variable_client.update_and_wait() - - # Create the agent. - actor = acting.IMPALAActor( - network=network, variable_client=variable_client) - - # Create the run loop and return it. - logger = loggers.make_default_logger( - 'evaluator', steps_key='evaluator_steps') - counter = counting.Counter(counter, 'evaluator') - return acme.EnvironmentLoop(environment, actor, counter, logger) - - def build(self, name='impala'): - """Build the distributed agent topology.""" - program = lp.Program(name=name) - - with program.group('replay'): - queue = program.add_node(lp.ReverbNode(self.queue)) - - with program.group('counter'): - counter = program.add_node(lp.CourierNode(self.counter)) - - with program.group('learner'): - learner = program.add_node( - lp.CourierNode(self.learner, queue, counter)) - - with program.group('evaluator'): - program.add_node(lp.CourierNode(self.evaluator, learner, counter)) - - with program.group('cacher'): - cacher = program.add_node( - lp.CacherNode(learner, refresh_interval_ms=2000, stale_after_ms=4000)) - - with program.group('actor'): - for _ in range(self._num_actors): - program.add_node(lp.CourierNode(self.actor, queue, cacher, counter)) - - return program + """Program definition for IMPALA.""" + + def __init__( + self, + environment_factory: Callable[[bool], dm_env.Environment], + network_factory: Callable[[specs.DiscreteArray], snt.RNNCore], + num_actors: int, + sequence_length: int, + sequence_period: int, + environment_spec: Optional[specs.EnvironmentSpec] = None, + batch_size: int = 256, + prefetch_size: int = 4, + max_queue_size: int = 10_000, + learning_rate: float = 1e-3, + discount: float = 0.99, + entropy_cost: float = 0.01, + baseline_cost: float = 0.5, + max_abs_reward: Optional[float] = None, + max_gradient_norm: Optional[float] = None, + variable_update_period: int = 1000, + save_logs: bool = False, + ): + + if environment_spec is None: + environment_spec = specs.make_environment_spec(environment_factory(False)) + + self._environment_factory = environment_factory + self._network_factory = network_factory + self._environment_spec = environment_spec + self._num_actors = num_actors + self._batch_size = batch_size + self._prefetch_size = prefetch_size + self._sequence_length = sequence_length + self._max_queue_size = max_queue_size + self._sequence_period = sequence_period + self._discount = discount + self._learning_rate = learning_rate + self._entropy_cost = entropy_cost + self._baseline_cost = baseline_cost + self._max_abs_reward = max_abs_reward + self._max_gradient_norm = max_gradient_norm + self._variable_update_period = variable_update_period + self._save_logs = save_logs + + def queue(self): + """The queue.""" + num_actions = self._environment_spec.actions.num_values + network = self._network_factory(self._environment_spec.actions) + extra_spec = { + "core_state": network.initial_state(1), + "logits": tf.ones(shape=(1, num_actions), dtype=tf.float32), + } + # Remove batch dimensions. 
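+        # The specs above were built with a dummy batch of one, so drop that
+        # leading dimension to get per-step shapes for the table signature.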
+ extra_spec = tf2_utils.squeeze_batch_dim(extra_spec) + signature = adders.SequenceAdder.signature( + self._environment_spec, extra_spec, sequence_length=self._sequence_length + ) + queue = reverb.Table.queue( + name=adders.DEFAULT_PRIORITY_TABLE, + max_size=self._max_queue_size, + signature=signature, + ) + return [queue] + + def counter(self): + """Creates the master counter process.""" + return tf2_savers.CheckpointingRunner( + counting.Counter(), time_delta_minutes=1, subdirectory="counter" + ) + + def learner(self, queue: reverb.Client, counter: counting.Counter): + """The Learning part of the agent.""" + # Use architect and create the environment. + # Create the networks. + network = self._network_factory(self._environment_spec.actions) + tf2_utils.create_variables(network, [self._environment_spec.observations]) + + # The dataset object to learn from. + dataset = datasets.make_reverb_dataset( + server_address=queue.server_address, + batch_size=self._batch_size, + prefetch_size=self._prefetch_size, + ) + + logger = loggers.make_default_logger("learner", steps_key="learner_steps") + counter = counting.Counter(counter, "learner") + + # Return the learning agent. + learner = learning.IMPALALearner( + environment_spec=self._environment_spec, + network=network, + dataset=dataset, + discount=self._discount, + learning_rate=self._learning_rate, + entropy_cost=self._entropy_cost, + baseline_cost=self._baseline_cost, + max_abs_reward=self._max_abs_reward, + max_gradient_norm=self._max_gradient_norm, + counter=counter, + logger=logger, + ) + + return tf2_savers.CheckpointingRunner( + learner, time_delta_minutes=5, subdirectory="impala_learner" + ) + + def actor( + self, + replay: reverb.Client, + variable_source: acme.VariableSource, + counter: counting.Counter, + ) -> acme.EnvironmentLoop: + """The actor process.""" + environment = self._environment_factory(False) + network = self._network_factory(self._environment_spec.actions) + tf2_utils.create_variables(network, [self._environment_spec.observations]) + + # Component to add things into the queue. + adder = adders.SequenceAdder( + client=replay, + period=self._sequence_period, + sequence_length=self._sequence_length, + ) + + variable_client = tf2_variable_utils.VariableClient( + client=variable_source, + variables={"policy": network.variables}, + update_period=self._variable_update_period, + ) + + # Make sure not to use a random policy after checkpoint restoration by + # assigning variables before running the environment loop. + variable_client.update_and_wait() + + # Create the agent. + actor = acting.IMPALAActor( + network=network, variable_client=variable_client, adder=adder + ) + + counter = counting.Counter(counter, "actor") + logger = loggers.make_default_logger( + "actor", save_data=False, steps_key="actor_steps" + ) + + # Create the loop to connect environment and agent. 
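+        # The loop alternates environment steps and actor steps, reporting
+        # episode statistics through the given counter and logger.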
+ return acme.EnvironmentLoop(environment, actor, counter, logger) + + def evaluator( + self, variable_source: acme.VariableSource, counter: counting.Counter + ): + """The evaluation process.""" + environment = self._environment_factory(True) + network = self._network_factory(self._environment_spec.actions) + tf2_utils.create_variables(network, [self._environment_spec.observations]) + + variable_client = tf2_variable_utils.VariableClient( + client=variable_source, + variables={"policy": network.variables}, + update_period=self._variable_update_period, + ) + + # Make sure not to use a random policy after checkpoint restoration by + # assigning variables before running the environment loop. + variable_client.update_and_wait() + + # Create the agent. + actor = acting.IMPALAActor(network=network, variable_client=variable_client) + + # Create the run loop and return it. + logger = loggers.make_default_logger("evaluator", steps_key="evaluator_steps") + counter = counting.Counter(counter, "evaluator") + return acme.EnvironmentLoop(environment, actor, counter, logger) + + def build(self, name="impala"): + """Build the distributed agent topology.""" + program = lp.Program(name=name) + + with program.group("replay"): + queue = program.add_node(lp.ReverbNode(self.queue)) + + with program.group("counter"): + counter = program.add_node(lp.CourierNode(self.counter)) + + with program.group("learner"): + learner = program.add_node(lp.CourierNode(self.learner, queue, counter)) + + with program.group("evaluator"): + program.add_node(lp.CourierNode(self.evaluator, learner, counter)) + + with program.group("cacher"): + cacher = program.add_node( + lp.CacherNode(learner, refresh_interval_ms=2000, stale_after_ms=4000) + ) + + with program.group("actor"): + for _ in range(self._num_actors): + program.add_node(lp.CourierNode(self.actor, queue, cacher, counter)) + + return program diff --git a/acme/agents/tf/impala/agent_distributed_test.py b/acme/agents/tf/impala/agent_distributed_test.py index 04e59d2f37..0877a818a7 100644 --- a/acme/agents/tf/impala/agent_distributed_test.py +++ b/acme/agents/tf/impala/agent_distributed_test.py @@ -14,43 +14,43 @@ """Integration test for the distributed agent.""" +import launchpad as lp +from absl.testing import absltest + import acme from acme.agents.tf import impala from acme.testing import fakes from acme.tf import networks -import launchpad as lp - -from absl.testing import absltest class DistributedAgentTest(absltest.TestCase): - """Simple integration/smoke test for the distributed agent.""" + """Simple integration/smoke test for the distributed agent.""" - def test_atari(self): - """Tests that the agent can run for some steps without crashing.""" - env_factory = lambda x: fakes.fake_atari_wrapped(oar_wrapper=True) - net_factory = lambda spec: networks.IMPALAAtariNetwork(spec.num_values) + def test_atari(self): + """Tests that the agent can run for some steps without crashing.""" + env_factory = lambda x: fakes.fake_atari_wrapped(oar_wrapper=True) + net_factory = lambda spec: networks.IMPALAAtariNetwork(spec.num_values) - agent = impala.DistributedIMPALA( - environment_factory=env_factory, - network_factory=net_factory, - num_actors=2, - batch_size=32, - sequence_length=5, - sequence_period=1, - ) - program = agent.build() + agent = impala.DistributedIMPALA( + environment_factory=env_factory, + network_factory=net_factory, + num_actors=2, + batch_size=32, + sequence_length=5, + sequence_period=1, + ) + program = agent.build() - (learner_node,) = program.groups['learner'] - 
learner_node.disable_run() + (learner_node,) = program.groups["learner"] + learner_node.disable_run() - lp.launch(program, launch_type='test_mt') + lp.launch(program, launch_type="test_mt") - learner: acme.Learner = learner_node.create_handle().dereference() + learner: acme.Learner = learner_node.create_handle().dereference() - for _ in range(5): - learner.step() + for _ in range(5): + learner.step() -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/tf/impala/agent_test.py b/acme/agents/tf/impala/agent_test.py index 71dd574019..f19457c290 100644 --- a/acme/agents/tf/impala/agent_test.py +++ b/acme/agents/tf/impala/agent_test.py @@ -14,53 +14,53 @@ """Tests for IMPALA agent.""" +import numpy as np +import sonnet as snt +from absl.testing import absltest + import acme from acme import specs from acme.agents.tf import impala from acme.testing import fakes from acme.tf import networks -import numpy as np -import sonnet as snt - -from absl.testing import absltest def _make_network(action_spec: specs.DiscreteArray) -> snt.RNNCore: - return snt.DeepRNN([ - snt.Flatten(), - snt.LSTM(20), - snt.nets.MLP([50, 50]), - networks.PolicyValueHead(action_spec.num_values), - ]) + return snt.DeepRNN( + [ + snt.Flatten(), + snt.LSTM(20), + snt.nets.MLP([50, 50]), + networks.PolicyValueHead(action_spec.num_values), + ] + ) class IMPALATest(absltest.TestCase): - # TODO(b/200509080): This test case is timing out. - @absltest.SkipTest - def test_impala(self): - # Create a fake environment to test with. - environment = fakes.DiscreteEnvironment( - num_actions=5, - num_observations=10, - obs_dtype=np.float32, - episode_length=10) - spec = specs.make_environment_spec(environment) + # TODO(b/200509080): This test case is timing out. + @absltest.SkipTest + def test_impala(self): + # Create a fake environment to test with. + environment = fakes.DiscreteEnvironment( + num_actions=5, num_observations=10, obs_dtype=np.float32, episode_length=10 + ) + spec = specs.make_environment_spec(environment) - # Construct the agent. - agent = impala.IMPALA( - environment_spec=spec, - network=_make_network(spec.actions), - sequence_length=3, - sequence_period=3, - batch_size=6, - ) + # Construct the agent. + agent = impala.IMPALA( + environment_spec=spec, + network=_make_network(spec.actions), + sequence_length=3, + sequence_period=3, + batch_size=6, + ) - # Try running the environment loop. We have no assertions here because all - # we care about is that the agent runs without raising any errors. - loop = acme.EnvironmentLoop(environment, agent) - loop.run(num_episodes=20) + # Try running the environment loop. We have no assertions here because all + # we care about is that the agent runs without raising any errors. 
+ loop = acme.EnvironmentLoop(environment, agent) + loop.run(num_episodes=20) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/tf/impala/learning.py b/acme/agents/tf/impala/learning.py index b2ea261471..4c7cb7d366 100644 --- a/acme/agents/tf/impala/learning.py +++ b/acme/agents/tf/impala/learning.py @@ -17,12 +17,6 @@ import time from typing import Dict, List, Mapping, Optional -import acme -from acme import specs -from acme.tf import savers as tf2_savers -from acme.tf import utils as tf2_utils -from acme.utils import counting -from acme.utils import loggers import numpy as np import reverb import sonnet as snt @@ -31,160 +25,174 @@ import tree import trfl +import acme +from acme import specs +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.utils import counting, loggers + tfd = tfp.distributions class IMPALALearner(acme.Learner, tf2_savers.TFSaveable): - """Learner for an importanced-weighted advantage actor-critic.""" - - def __init__( - self, - environment_spec: specs.EnvironmentSpec, - network: snt.RNNCore, - dataset: tf.data.Dataset, - learning_rate: float, - discount: float = 0.99, - entropy_cost: float = 0., - baseline_cost: float = 1., - max_abs_reward: Optional[float] = None, - max_gradient_norm: Optional[float] = None, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - ): - - # Internalise, optimizer, and dataset. - self._env_spec = environment_spec - self._optimizer = snt.optimizers.Adam(learning_rate=learning_rate) - self._network = network - self._variables = network.variables - # TODO(b/155086959): Fix type stubs and remove. - self._iterator = iter(dataset) # pytype: disable=wrong-arg-types - - # Hyperparameters. - self._discount = discount - self._entropy_cost = entropy_cost - self._baseline_cost = baseline_cost - - # Set up reward/gradient clipping. - if max_abs_reward is None: - max_abs_reward = np.inf - if max_gradient_norm is None: - max_gradient_norm = 1e10 # A very large number. Infinity results in NaNs. - self._max_abs_reward = tf.convert_to_tensor(max_abs_reward) - self._max_gradient_norm = tf.convert_to_tensor(max_gradient_norm) - - # Set up logging/counting. - self._counter = counter or counting.Counter() - self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.) - - self._snapshotter = tf2_savers.Snapshotter( - objects_to_save={'network': network}, time_delta_minutes=60.) - - # Do not record timestamps until after the first learning step is done. - # This is to avoid including the time it takes for actors to come online and - # fill the replay buffer. - self._timestamp = None - - @property - def state(self) -> Mapping[str, tf2_savers.Checkpointable]: - """Returns the stateful objects for checkpointing.""" - return { - 'network': self._network, - 'optimizer': self._optimizer, - } - - @tf.function - def _step(self) -> Dict[str, tf.Tensor]: - """Does an SGD step on a batch of sequences.""" - - # Retrieve a batch of data from replay. 
- inputs: reverb.ReplaySample = next(self._iterator) - data = tf2_utils.batch_to_sequence(inputs.data) - observations, actions, rewards, discounts, extra = (data.observation, - data.action, - data.reward, - data.discount, - data.extras) - core_state = tree.map_structure(lambda s: s[0], extra['core_state']) - - # - actions = actions[:-1] # [T-1] - rewards = rewards[:-1] # [T-1] - discounts = discounts[:-1] # [T-1] - - with tf.GradientTape() as tape: - # Unroll current policy over observations. - (logits, values), _ = snt.static_unroll(self._network, observations, - core_state) - - # Compute importance sampling weights: current policy / behavior policy. - behaviour_logits = extra['logits'] - pi_behaviour = tfd.Categorical(logits=behaviour_logits[:-1]) - pi_target = tfd.Categorical(logits=logits[:-1]) - log_rhos = pi_target.log_prob(actions) - pi_behaviour.log_prob(actions) - - # Optionally clip rewards. - rewards = tf.clip_by_value(rewards, - tf.cast(-self._max_abs_reward, rewards.dtype), - tf.cast(self._max_abs_reward, rewards.dtype)) - - # Critic loss. - vtrace_returns = trfl.vtrace_from_importance_weights( - log_rhos=tf.cast(log_rhos, tf.float32), - discounts=tf.cast(self._discount * discounts, tf.float32), - rewards=tf.cast(rewards, tf.float32), - values=tf.cast(values[:-1], tf.float32), - bootstrap_value=values[-1], - ) - critic_loss = tf.square(vtrace_returns.vs - values[:-1]) - - # Policy-gradient loss. - policy_gradient_loss = trfl.policy_gradient( - policies=pi_target, - actions=actions, - action_values=vtrace_returns.pg_advantages, - ) - - # Entropy regulariser. - entropy_loss = trfl.policy_entropy_loss(pi_target).loss - - # Combine weighted sum of actor & critic losses. - loss = tf.reduce_mean(policy_gradient_loss + - self._baseline_cost * critic_loss + - self._entropy_cost * entropy_loss) - - # Compute gradients and optionally apply clipping. - gradients = tape.gradient(loss, self._network.trainable_variables) - gradients, _ = tf.clip_by_global_norm(gradients, self._max_gradient_norm) - self._optimizer.apply(gradients, self._network.trainable_variables) - - metrics = { - 'loss': loss, - 'critic_loss': tf.reduce_mean(critic_loss), - 'entropy_loss': tf.reduce_mean(entropy_loss), - 'policy_gradient_loss': tf.reduce_mean(policy_gradient_loss), - } - - return metrics - - def step(self): - """Does a step of SGD and logs the results.""" - - # Do a batch of SGD. - results = self._step() - - # Compute elapsed time. - timestamp = time.time() - elapsed_time = timestamp - self._timestamp if self._timestamp else 0 - self._timestamp = timestamp - - # Update our counts and record it. - counts = self._counter.increment(steps=1, walltime=elapsed_time) - results.update(counts) - - # Snapshot and attempt to write logs. - self._snapshotter.save() - self._logger.write(results) - - def get_variables(self, names: List[str]) -> List[List[np.ndarray]]: - return [tf2_utils.to_numpy(self._variables)] + """Learner for an importanced-weighted advantage actor-critic.""" + + def __init__( + self, + environment_spec: specs.EnvironmentSpec, + network: snt.RNNCore, + dataset: tf.data.Dataset, + learning_rate: float, + discount: float = 0.99, + entropy_cost: float = 0.0, + baseline_cost: float = 1.0, + max_abs_reward: Optional[float] = None, + max_gradient_norm: Optional[float] = None, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + ): + + # Internalise, optimizer, and dataset. 
+ self._env_spec = environment_spec + self._optimizer = snt.optimizers.Adam(learning_rate=learning_rate) + self._network = network + self._variables = network.variables + # TODO(b/155086959): Fix type stubs and remove. + self._iterator = iter(dataset) # pytype: disable=wrong-arg-types + + # Hyperparameters. + self._discount = discount + self._entropy_cost = entropy_cost + self._baseline_cost = baseline_cost + + # Set up reward/gradient clipping. + if max_abs_reward is None: + max_abs_reward = np.inf + if max_gradient_norm is None: + max_gradient_norm = 1e10 # A very large number. Infinity results in NaNs. + self._max_abs_reward = tf.convert_to_tensor(max_abs_reward) + self._max_gradient_norm = tf.convert_to_tensor(max_gradient_norm) + + # Set up logging/counting. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.TerminalLogger("learner", time_delta=1.0) + + self._snapshotter = tf2_savers.Snapshotter( + objects_to_save={"network": network}, time_delta_minutes=60.0 + ) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + @property + def state(self) -> Mapping[str, tf2_savers.Checkpointable]: + """Returns the stateful objects for checkpointing.""" + return { + "network": self._network, + "optimizer": self._optimizer, + } + + @tf.function + def _step(self) -> Dict[str, tf.Tensor]: + """Does an SGD step on a batch of sequences.""" + + # Retrieve a batch of data from replay. + inputs: reverb.ReplaySample = next(self._iterator) + data = tf2_utils.batch_to_sequence(inputs.data) + observations, actions, rewards, discounts, extra = ( + data.observation, + data.action, + data.reward, + data.discount, + data.extras, + ) + core_state = tree.map_structure(lambda s: s[0], extra["core_state"]) + + # + actions = actions[:-1] # [T-1] + rewards = rewards[:-1] # [T-1] + discounts = discounts[:-1] # [T-1] + + with tf.GradientTape() as tape: + # Unroll current policy over observations. + (logits, values), _ = snt.static_unroll( + self._network, observations, core_state + ) + + # Compute importance sampling weights: current policy / behavior policy. + behaviour_logits = extra["logits"] + pi_behaviour = tfd.Categorical(logits=behaviour_logits[:-1]) + pi_target = tfd.Categorical(logits=logits[:-1]) + log_rhos = pi_target.log_prob(actions) - pi_behaviour.log_prob(actions) + + # Optionally clip rewards. + rewards = tf.clip_by_value( + rewards, + tf.cast(-self._max_abs_reward, rewards.dtype), + tf.cast(self._max_abs_reward, rewards.dtype), + ) + + # Critic loss. + vtrace_returns = trfl.vtrace_from_importance_weights( + log_rhos=tf.cast(log_rhos, tf.float32), + discounts=tf.cast(self._discount * discounts, tf.float32), + rewards=tf.cast(rewards, tf.float32), + values=tf.cast(values[:-1], tf.float32), + bootstrap_value=values[-1], + ) + critic_loss = tf.square(vtrace_returns.vs - values[:-1]) + + # Policy-gradient loss. + policy_gradient_loss = trfl.policy_gradient( + policies=pi_target, + actions=actions, + action_values=vtrace_returns.pg_advantages, + ) + + # Entropy regulariser. + entropy_loss = trfl.policy_entropy_loss(pi_target).loss + + # Combine weighted sum of actor & critic losses. + loss = tf.reduce_mean( + policy_gradient_loss + + self._baseline_cost * critic_loss + + self._entropy_cost * entropy_loss + ) + + # Compute gradients and optionally apply clipping. 
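+        # Gradients are clipped by their global norm; the default cap of 1e10
+        # effectively leaves them untouched.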
+ gradients = tape.gradient(loss, self._network.trainable_variables) + gradients, _ = tf.clip_by_global_norm(gradients, self._max_gradient_norm) + self._optimizer.apply(gradients, self._network.trainable_variables) + + metrics = { + "loss": loss, + "critic_loss": tf.reduce_mean(critic_loss), + "entropy_loss": tf.reduce_mean(entropy_loss), + "policy_gradient_loss": tf.reduce_mean(policy_gradient_loss), + } + + return metrics + + def step(self): + """Does a step of SGD and logs the results.""" + + # Do a batch of SGD. + results = self._step() + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Update our counts and record it. + counts = self._counter.increment(steps=1, walltime=elapsed_time) + results.update(counts) + + # Snapshot and attempt to write logs. + self._snapshotter.save() + self._logger.write(results) + + def get_variables(self, names: List[str]) -> List[List[np.ndarray]]: + return [tf2_utils.to_numpy(self._variables)] diff --git a/acme/agents/tf/iqn/learning.py b/acme/agents/tf/iqn/learning.py index aba47c61e5..de9c91ad21 100644 --- a/acme/agents/tf/iqn/learning.py +++ b/acme/agents/tf/iqn/learning.py @@ -16,40 +16,38 @@ from typing import Dict, List, Optional, Tuple -from acme import core -from acme import types -from acme.adders import reverb as adders -from acme.tf import losses -from acme.tf import networks -from acme.tf import savers as tf2_savers -from acme.tf import utils as tf2_utils -from acme.utils import counting -from acme.utils import loggers import numpy as np import reverb import sonnet as snt import tensorflow as tf +from acme import core, types +from acme.adders import reverb as adders +from acme.tf import losses, networks +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.utils import counting, loggers + class IQNLearner(core.Learner, tf2_savers.TFSaveable): - """Distributional DQN learner.""" - - def __init__( - self, - network: networks.IQNNetwork, - target_network: snt.Module, - discount: float, - importance_sampling_exponent: float, - learning_rate: float, - target_update_period: int, - dataset: tf.data.Dataset, - huber_loss_parameter: float = 1., - replay_client: Optional[reverb.TFClient] = None, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - checkpoint: bool = True, + """Distributional DQN learner.""" + + def __init__( + self, + network: networks.IQNNetwork, + target_network: snt.Module, + discount: float, + importance_sampling_exponent: float, + learning_rate: float, + target_update_period: int, + dataset: tf.data.Dataset, + huber_loss_parameter: float = 1.0, + replay_client: Optional[reverb.TFClient] = None, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = True, ): - """Initializes the learner. + """Initializes the learner. Args: network: the online Q network (the one being optimized) that outputs @@ -70,171 +68,170 @@ def __init__( checkpoint: boolean indicating whether to checkpoint the learner or not. """ - # Internalise agent components (replay buffer, networks, optimizer). - self._iterator = iter(dataset) # pytype: disable=wrong-arg-types - self._network = network - self._target_network = target_network - self._optimizer = snt.optimizers.Adam(learning_rate) - self._replay_client = replay_client - - # Internalise the hyperparameters. 
- self._discount = discount - self._target_update_period = target_update_period - self._importance_sampling_exponent = importance_sampling_exponent - self._huber_loss_parameter = huber_loss_parameter - - # Learner state. - self._variables: List[List[tf.Tensor]] = [network.trainable_variables] - self._num_steps = tf.Variable(0, dtype=tf.int32) - - # Internalise logging/counting objects. - self._counter = counter or counting.Counter() - self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.) - - # Create a snapshotter object. - if checkpoint: - self._checkpointer = tf2_savers.Checkpointer( - time_delta_minutes=5, - objects_to_save={ - 'network': self._network, - 'target_network': self._target_network, - 'optimizer': self._optimizer, - 'num_steps': self._num_steps - }) - self._snapshotter = tf2_savers.Snapshotter( - objects_to_save={'network': network}, time_delta_minutes=60.) - else: - self._checkpointer = None - self._snapshotter = None - - @tf.function - def _step(self) -> Dict[str, tf.Tensor]: - """Do a step of SGD and update the priorities.""" - - # Pull out the data needed for updates/priorities. - inputs = next(self._iterator) - transitions: types.Transition = inputs.data - keys, probs, *_ = inputs.info - - with tf.GradientTape() as tape: - loss, fetches = self._loss_and_fetches(transitions.observation, - transitions.action, - transitions.reward, - transitions.discount, - transitions.next_observation) - - # Get the importance weights. - importance_weights = 1. / probs # [B] - importance_weights **= self._importance_sampling_exponent - importance_weights /= tf.reduce_max(importance_weights) - - # Reweight. - loss *= tf.cast(importance_weights, loss.dtype) # [B] - loss = tf.reduce_mean(loss, axis=[0]) # [] - - # Do a step of SGD. - gradients = tape.gradient(loss, self._network.trainable_variables) - self._optimizer.apply(gradients, self._network.trainable_variables) - - # Update the priorities in the replay buffer. - if self._replay_client: - priorities = tf.clip_by_value(tf.abs(loss), -100, 100) - priorities = tf.cast(priorities, tf.float64) - self._replay_client.update_priorities( - table=adders.DEFAULT_PRIORITY_TABLE, keys=keys, priorities=priorities) - - # Periodically update the target network. - if tf.math.mod(self._num_steps, self._target_update_period) == 0: - for src, dest in zip(self._network.variables, - self._target_network.variables): - dest.assign(src) - self._num_steps.assign_add(1) - - # Report gradient norms. - fetches.update( - loss=loss, - gradient_norm=tf.linalg.global_norm(gradients)) - return fetches - - def step(self): - # Do a batch of SGD. - result = self._step() - - # Update our counts and record it. - counts = self._counter.increment(steps=1) - result.update(counts) - - # Checkpoint and attempt to write logs. - if self._checkpointer is not None: - self._checkpointer.save() - if self._snapshotter is not None: - self._snapshotter.save() - self._logger.write(result) - - def get_variables(self, names: List[str]) -> List[np.ndarray]: - return tf2_utils.to_numpy(self._variables) - - def _loss_and_fetches( - self, - o_tm1: tf.Tensor, - a_tm1: tf.Tensor, - r_t: tf.Tensor, - d_t: tf.Tensor, - o_t: tf.Tensor, - ) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]: - # Evaluate our networks. 
- _, dist_tm1, tau = self._network(o_tm1) - q_tm1 = _index_embs_with_actions(dist_tm1, a_tm1) - - q_selector, _, _ = self._target_network(o_t) - a_t = tf.argmax(q_selector, axis=1) - - _, dist_t, _ = self._target_network(o_t) - q_t = _index_embs_with_actions(dist_t, a_t) - - q_tm1 = losses.QuantileDistribution(values=q_tm1, - logits=tf.zeros_like(q_tm1)) - q_t = losses.QuantileDistribution(values=q_t, logits=tf.zeros_like(q_t)) - - # The rewards and discounts have to have the same type as network values. - r_t = tf.cast(r_t, tf.float32) - r_t = tf.clip_by_value(r_t, -1., 1.) - d_t = tf.cast(d_t, tf.float32) * tf.cast(self._discount, tf.float32) - - # Compute the loss. - loss_module = losses.NonUniformQuantileRegression( - self._huber_loss_parameter) - loss = loss_module(q_tm1, r_t, d_t, q_t, tau) - - # Compute statistics of the Q-values for logging. - max_q = tf.reduce_max(q_t.values) - min_q = tf.reduce_min(q_t.values) - mean_q, var_q = tf.nn.moments(q_t.values, [0, 1]) - fetches = { - 'max_q': max_q, - 'mean_q': mean_q, - 'min_q': min_q, - 'var_q': var_q, - } - - return loss, fetches - - @property - def state(self): - """Returns the stateful parts of the learner for checkpointing.""" - return { - 'network': self._network, - 'target_network': self._target_network, - 'optimizer': self._optimizer, - 'num_steps': self._num_steps - } - - -def _index_embs_with_actions( - embeddings: tf.Tensor, - actions: tf.Tensor, -) -> tf.Tensor: - """Slice an embedding Tensor with action indices. + # Internalise agent components (replay buffer, networks, optimizer). + self._iterator = iter(dataset) # pytype: disable=wrong-arg-types + self._network = network + self._target_network = target_network + self._optimizer = snt.optimizers.Adam(learning_rate) + self._replay_client = replay_client + + # Internalise the hyperparameters. + self._discount = discount + self._target_update_period = target_update_period + self._importance_sampling_exponent = importance_sampling_exponent + self._huber_loss_parameter = huber_loss_parameter + + # Learner state. + self._variables: List[List[tf.Tensor]] = [network.trainable_variables] + self._num_steps = tf.Variable(0, dtype=tf.int32) + + # Internalise logging/counting objects. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.TerminalLogger("learner", time_delta=1.0) + + # Create a snapshotter object. + if checkpoint: + self._checkpointer = tf2_savers.Checkpointer( + time_delta_minutes=5, + objects_to_save={ + "network": self._network, + "target_network": self._target_network, + "optimizer": self._optimizer, + "num_steps": self._num_steps, + }, + ) + self._snapshotter = tf2_savers.Snapshotter( + objects_to_save={"network": network}, time_delta_minutes=60.0 + ) + else: + self._checkpointer = None + self._snapshotter = None + + @tf.function + def _step(self) -> Dict[str, tf.Tensor]: + """Do a step of SGD and update the priorities.""" + + # Pull out the data needed for updates/priorities. + inputs = next(self._iterator) + transitions: types.Transition = inputs.data + keys, probs, *_ = inputs.info + + with tf.GradientTape() as tape: + loss, fetches = self._loss_and_fetches( + transitions.observation, + transitions.action, + transitions.reward, + transitions.discount, + transitions.next_observation, + ) + + # Get the importance weights. + importance_weights = 1.0 / probs # [B] + importance_weights **= self._importance_sampling_exponent + importance_weights /= tf.reduce_max(importance_weights) + + # Reweight. 
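+            # Scaling by the importance weights corrects (up to a constant) for
+            # the bias introduced by prioritized sampling before averaging.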
+ loss *= tf.cast(importance_weights, loss.dtype) # [B] + loss = tf.reduce_mean(loss, axis=[0]) # [] + + # Do a step of SGD. + gradients = tape.gradient(loss, self._network.trainable_variables) + self._optimizer.apply(gradients, self._network.trainable_variables) + + # Update the priorities in the replay buffer. + if self._replay_client: + priorities = tf.clip_by_value(tf.abs(loss), -100, 100) + priorities = tf.cast(priorities, tf.float64) + self._replay_client.update_priorities( + table=adders.DEFAULT_PRIORITY_TABLE, keys=keys, priorities=priorities + ) + + # Periodically update the target network. + if tf.math.mod(self._num_steps, self._target_update_period) == 0: + for src, dest in zip( + self._network.variables, self._target_network.variables + ): + dest.assign(src) + self._num_steps.assign_add(1) + + # Report gradient norms. + fetches.update(loss=loss, gradient_norm=tf.linalg.global_norm(gradients)) + return fetches + + def step(self): + # Do a batch of SGD. + result = self._step() + + # Update our counts and record it. + counts = self._counter.increment(steps=1) + result.update(counts) + + # Checkpoint and attempt to write logs. + if self._checkpointer is not None: + self._checkpointer.save() + if self._snapshotter is not None: + self._snapshotter.save() + self._logger.write(result) + + def get_variables(self, names: List[str]) -> List[np.ndarray]: + return tf2_utils.to_numpy(self._variables) + + def _loss_and_fetches( + self, + o_tm1: tf.Tensor, + a_tm1: tf.Tensor, + r_t: tf.Tensor, + d_t: tf.Tensor, + o_t: tf.Tensor, + ) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]: + # Evaluate our networks. + _, dist_tm1, tau = self._network(o_tm1) + q_tm1 = _index_embs_with_actions(dist_tm1, a_tm1) + + q_selector, _, _ = self._target_network(o_t) + a_t = tf.argmax(q_selector, axis=1) + + _, dist_t, _ = self._target_network(o_t) + q_t = _index_embs_with_actions(dist_t, a_t) + + q_tm1 = losses.QuantileDistribution(values=q_tm1, logits=tf.zeros_like(q_tm1)) + q_t = losses.QuantileDistribution(values=q_t, logits=tf.zeros_like(q_t)) + + # The rewards and discounts have to have the same type as network values. + r_t = tf.cast(r_t, tf.float32) + r_t = tf.clip_by_value(r_t, -1.0, 1.0) + d_t = tf.cast(d_t, tf.float32) * tf.cast(self._discount, tf.float32) + + # Compute the loss. + loss_module = losses.NonUniformQuantileRegression(self._huber_loss_parameter) + loss = loss_module(q_tm1, r_t, d_t, q_t, tau) + + # Compute statistics of the Q-values for logging. + max_q = tf.reduce_max(q_t.values) + min_q = tf.reduce_min(q_t.values) + mean_q, var_q = tf.nn.moments(q_t.values, [0, 1]) + fetches = { + "max_q": max_q, + "mean_q": mean_q, + "min_q": min_q, + "var_q": var_q, + } + + return loss, fetches + + @property + def state(self): + """Returns the stateful parts of the learner for checkpointing.""" + return { + "network": self._network, + "target_network": self._target_network, + "optimizer": self._optimizer, + "num_steps": self._num_steps, + } + + +def _index_embs_with_actions(embeddings: tf.Tensor, actions: tf.Tensor,) -> tf.Tensor: + """Slice an embedding Tensor with action indices. 
Take embeddings of the form [batch_size, num_actions, embed_dim] and actions of the form [batch_size], and return the sliced embeddings @@ -248,19 +245,20 @@ def _index_embs_with_actions( Returns: Tensor of embeddings indexed by actions """ - batch_size, num_actions, _ = embeddings.shape.as_list() - - # Values are the 'values' in a sparse tensor we will be setting - act_indx = tf.cast(actions, tf.int64)[:, None] - values = tf.ones([tf.size(actions)], dtype=tf.bool) - - # Create a range for each index into the batch - act_range = tf.range(0, batch_size, dtype=tf.int64)[:, None] - # Combine this into coordinates with the action indices - indices = tf.concat([act_range, act_indx], 1) - - actions_mask = tf.SparseTensor(indices, values, [batch_size, num_actions]) - actions_mask = tf.stop_gradient( - tf.sparse.to_dense(actions_mask, default_value=False)) - sliced_emb = tf.boolean_mask(embeddings, actions_mask) - return sliced_emb + batch_size, num_actions, _ = embeddings.shape.as_list() + + # Values are the 'values' in a sparse tensor we will be setting + act_indx = tf.cast(actions, tf.int64)[:, None] + values = tf.ones([tf.size(actions)], dtype=tf.bool) + + # Create a range for each index into the batch + act_range = tf.range(0, batch_size, dtype=tf.int64)[:, None] + # Combine this into coordinates with the action indices + indices = tf.concat([act_range, act_indx], 1) + + actions_mask = tf.SparseTensor(indices, values, [batch_size, num_actions]) + actions_mask = tf.stop_gradient( + tf.sparse.to_dense(actions_mask, default_value=False) + ) + sliced_emb = tf.boolean_mask(embeddings, actions_mask) + return sliced_emb diff --git a/acme/agents/tf/iqn/learning_test.py b/acme/agents/tf/iqn/learning_test.py index 9e2bdae6a3..a508a0708e 100644 --- a/acme/agents/tf/iqn/learning_test.py +++ b/acme/agents/tf/iqn/learning_test.py @@ -16,74 +16,69 @@ import copy +import numpy as np +import sonnet as snt +from absl.testing import absltest + from acme import specs from acme.agents.tf import iqn from acme.testing import fakes from acme.tf import networks from acme.tf import utils as tf2_utils from acme.utils import counting -import numpy as np -import sonnet as snt - -from absl.testing import absltest def _make_torso_network(num_outputs: int) -> snt.Module: - """Create torso network (outputs intermediate representation).""" - return snt.Sequential([ - snt.Flatten(), - snt.nets.MLP([num_outputs]) - ]) + """Create torso network (outputs intermediate representation).""" + return snt.Sequential([snt.Flatten(), snt.nets.MLP([num_outputs])]) def _make_head_network(num_outputs: int) -> snt.Module: - """Create head network (outputs Q-values).""" - return snt.nets.MLP([num_outputs]) + """Create head network (outputs Q-values).""" + return snt.nets.MLP([num_outputs]) class IQNLearnerTest(absltest.TestCase): - - def test_full_learner(self): - # Create dataset. - environment = fakes.DiscreteEnvironment( - num_actions=5, - num_observations=10, - obs_dtype=np.float32, - episode_length=10) - spec = specs.make_environment_spec(environment) - dataset = fakes.transition_dataset(environment).batch( - 2, drop_remainder=True) - - # Build network. - network = networks.IQNNetwork( - torso=_make_torso_network(num_outputs=2), - head=_make_head_network(num_outputs=spec.actions.num_values), - latent_dim=2, - num_quantile_samples=1) - tf2_utils.create_variables(network, [spec.observations]) - - # Build learner. 
- counter = counting.Counter() - learner = iqn.IQNLearner( - network=network, - target_network=copy.deepcopy(network), - dataset=dataset, - learning_rate=1e-4, - discount=0.99, - importance_sampling_exponent=0.2, - target_update_period=1, - counter=counter) - - # Run a learner step. - learner.step() - - # Check counts from IQN learner. - counts = counter.get_counts() - self.assertEqual(1, counts['steps']) - - # Check learner state. - self.assertEqual(1, learner.state['num_steps'].numpy()) - - -if __name__ == '__main__': - absltest.main() + def test_full_learner(self): + # Create dataset. + environment = fakes.DiscreteEnvironment( + num_actions=5, num_observations=10, obs_dtype=np.float32, episode_length=10 + ) + spec = specs.make_environment_spec(environment) + dataset = fakes.transition_dataset(environment).batch(2, drop_remainder=True) + + # Build network. + network = networks.IQNNetwork( + torso=_make_torso_network(num_outputs=2), + head=_make_head_network(num_outputs=spec.actions.num_values), + latent_dim=2, + num_quantile_samples=1, + ) + tf2_utils.create_variables(network, [spec.observations]) + + # Build learner. + counter = counting.Counter() + learner = iqn.IQNLearner( + network=network, + target_network=copy.deepcopy(network), + dataset=dataset, + learning_rate=1e-4, + discount=0.99, + importance_sampling_exponent=0.2, + target_update_period=1, + counter=counter, + ) + + # Run a learner step. + learner.step() + + # Check counts from IQN learner. + counts = counter.get_counts() + self.assertEqual(1, counts["steps"]) + + # Check learner state. + self.assertEqual(1, learner.state["num_steps"].numpy()) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/tf/mcts/acting.py b/acme/agents/tf/mcts/acting.py index 887d7c25e5..41e98d706d 100644 --- a/acme/agents/tf/mcts/acting.py +++ b/acme/agents/tf/mcts/acting.py @@ -16,105 +16,104 @@ from typing import Optional, Tuple -import acme -from acme import adders -from acme import specs -from acme.agents.tf.mcts import models -from acme.agents.tf.mcts import search -from acme.agents.tf.mcts import types -from acme.tf import variable_utils as tf2_variable_utils - import dm_env import numpy as np -from scipy import special import sonnet as snt import tensorflow as tf +from scipy import special + +import acme +from acme import adders, specs +from acme.agents.tf.mcts import models, search, types +from acme.tf import variable_utils as tf2_variable_utils class MCTSActor(acme.Actor): - """Executes a policy- and value-network guided MCTS search.""" - - _prev_timestep: dm_env.TimeStep - - def __init__( - self, - environment_spec: specs.EnvironmentSpec, - model: models.Model, - network: snt.Module, - discount: float, - num_simulations: int, - adder: Optional[adders.Adder] = None, - variable_client: Optional[tf2_variable_utils.VariableClient] = None, - ): - - # Internalize components: model, network, data sink and variable source. - self._model = model - self._network = tf.function(network) - self._variable_client = variable_client - self._adder = adder - - # Internalize hyperparameters. - self._num_actions = environment_spec.actions.num_values - self._num_simulations = num_simulations - self._actions = list(range(self._num_actions)) - self._discount = discount - - # We need to save the policy so as to add it to replay on the next step. 
- self._probs = np.ones( - shape=(self._num_actions,), dtype=np.float32) / self._num_actions - - def _forward( - self, observation: types.Observation) -> Tuple[types.Probs, types.Value]: - """Performs a forward pass of the policy-value network.""" - logits, value = self._network(tf.expand_dims(observation, axis=0)) - - # Convert to numpy & take softmax. - logits = logits.numpy().squeeze(axis=0) - value = value.numpy().item() - probs = special.softmax(logits) - - return probs, value - - def select_action(self, observation: types.Observation) -> types.Action: - """Computes the agent's policy via MCTS.""" - if self._model.needs_reset: - self._model.reset(observation) - - # Compute a fresh MCTS plan. - root = search.mcts( - observation, - model=self._model, - search_policy=search.puct, - evaluation=self._forward, - num_simulations=self._num_simulations, - num_actions=self._num_actions, - discount=self._discount, - ) - - # The agent's policy is softmax w.r.t. the *visit counts* as in AlphaZero. - probs = search.visit_count_policy(root) - action = np.int32(np.random.choice(self._actions, p=probs)) - - # Save the policy probs so that we can add them to replay in `observe()`. - self._probs = probs.astype(np.float32) - - return action - - def update(self, wait: bool = False): - """Fetches the latest variables from the variable source, if needed.""" - if self._variable_client: - self._variable_client.update(wait) - - def observe_first(self, timestep: dm_env.TimeStep): - self._prev_timestep = timestep - if self._adder: - self._adder.add_first(timestep) - - def observe(self, action: types.Action, next_timestep: dm_env.TimeStep): - """Updates the agent's internal model and adds the transition to replay.""" - self._model.update(self._prev_timestep, action, next_timestep) - - self._prev_timestep = next_timestep - - if self._adder: - self._adder.add(action, next_timestep, extras={'pi': self._probs}) + """Executes a policy- and value-network guided MCTS search.""" + + _prev_timestep: dm_env.TimeStep + + def __init__( + self, + environment_spec: specs.EnvironmentSpec, + model: models.Model, + network: snt.Module, + discount: float, + num_simulations: int, + adder: Optional[adders.Adder] = None, + variable_client: Optional[tf2_variable_utils.VariableClient] = None, + ): + + # Internalize components: model, network, data sink and variable source. + self._model = model + self._network = tf.function(network) + self._variable_client = variable_client + self._adder = adder + + # Internalize hyperparameters. + self._num_actions = environment_spec.actions.num_values + self._num_simulations = num_simulations + self._actions = list(range(self._num_actions)) + self._discount = discount + + # We need to save the policy so as to add it to replay on the next step. + self._probs = ( + np.ones(shape=(self._num_actions,), dtype=np.float32) / self._num_actions + ) + + def _forward( + self, observation: types.Observation + ) -> Tuple[types.Probs, types.Value]: + """Performs a forward pass of the policy-value network.""" + logits, value = self._network(tf.expand_dims(observation, axis=0)) + + # Convert to numpy & take softmax. + logits = logits.numpy().squeeze(axis=0) + value = value.numpy().item() + probs = special.softmax(logits) + + return probs, value + + def select_action(self, observation: types.Observation) -> types.Action: + """Computes the agent's policy via MCTS.""" + if self._model.needs_reset: + self._model.reset(observation) + + # Compute a fresh MCTS plan. 
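+        # `search.mcts` replans from scratch on every call: it runs
+        # `num_simulations` simulated rollouts through the model, selects
+        # actions inside the tree with the PUCT rule, and evaluates leaf
+        # nodes with the policy-value network via `self._forward`.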
+ root = search.mcts( + observation, + model=self._model, + search_policy=search.puct, + evaluation=self._forward, + num_simulations=self._num_simulations, + num_actions=self._num_actions, + discount=self._discount, + ) + + # The agent's policy is softmax w.r.t. the *visit counts* as in AlphaZero. + probs = search.visit_count_policy(root) + action = np.int32(np.random.choice(self._actions, p=probs)) + + # Save the policy probs so that we can add them to replay in `observe()`. + self._probs = probs.astype(np.float32) + + return action + + def update(self, wait: bool = False): + """Fetches the latest variables from the variable source, if needed.""" + if self._variable_client: + self._variable_client.update(wait) + + def observe_first(self, timestep: dm_env.TimeStep): + self._prev_timestep = timestep + if self._adder: + self._adder.add_first(timestep) + + def observe(self, action: types.Action, next_timestep: dm_env.TimeStep): + """Updates the agent's internal model and adds the transition to replay.""" + self._model.update(self._prev_timestep, action, next_timestep) + + self._prev_timestep = next_timestep + + if self._adder: + self._adder.add(action, next_timestep, extras={"pi": self._probs}) diff --git a/acme/agents/tf/mcts/agent.py b/acme/agents/tf/mcts/agent.py index c7b58bc0e0..0fd81c5e32 100644 --- a/acme/agents/tf/mcts/agent.py +++ b/acme/agents/tf/mcts/agent.py @@ -14,86 +14,78 @@ """A single-process MCTS agent.""" -from acme import datasets -from acme import specs -from acme.adders import reverb as adders -from acme.agents import agent -from acme.agents.tf.mcts import acting -from acme.agents.tf.mcts import learning -from acme.agents.tf.mcts import models -from acme.tf import utils as tf2_utils - import numpy as np import reverb import sonnet as snt +from acme import datasets, specs +from acme.adders import reverb as adders +from acme.agents import agent +from acme.agents.tf.mcts import acting, learning, models +from acme.tf import utils as tf2_utils + class MCTS(agent.Agent): - """A single-process MCTS agent.""" + """A single-process MCTS agent.""" - def __init__( - self, - network: snt.Module, - model: models.Model, - optimizer: snt.Optimizer, - n_step: int, - discount: float, - replay_capacity: int, - num_simulations: int, - environment_spec: specs.EnvironmentSpec, - batch_size: int, - ): + def __init__( + self, + network: snt.Module, + model: models.Model, + optimizer: snt.Optimizer, + n_step: int, + discount: float, + replay_capacity: int, + num_simulations: int, + environment_spec: specs.EnvironmentSpec, + batch_size: int, + ): - extra_spec = { - 'pi': - specs.Array( - shape=(environment_spec.actions.num_values,), dtype=np.float32) - } - # Create a replay server for storing transitions. - replay_table = reverb.Table( - name=adders.DEFAULT_PRIORITY_TABLE, - sampler=reverb.selectors.Uniform(), - remover=reverb.selectors.Fifo(), - max_size=replay_capacity, - rate_limiter=reverb.rate_limiters.MinSize(1), - signature=adders.NStepTransitionAdder.signature( - environment_spec, extra_spec)) - self._server = reverb.Server([replay_table], port=None) + extra_spec = { + "pi": specs.Array( + shape=(environment_spec.actions.num_values,), dtype=np.float32 + ) + } + # Create a replay server for storing transitions. 
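+        # The table samples uniformly, evicts oldest-first (FIFO) once
+        # `replay_capacity` is reached, and needs only a single item before
+        # sampling may begin (MinSize rate limiter). The signature includes
+        # the extra 'pi' spec so the MCTS visit-count policy is stored with
+        # each transition.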
+ replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=replay_capacity, + rate_limiter=reverb.rate_limiters.MinSize(1), + signature=adders.NStepTransitionAdder.signature( + environment_spec, extra_spec + ), + ) + self._server = reverb.Server([replay_table], port=None) - # The adder is used to insert observations into replay. - address = f'localhost:{self._server.port}' - adder = adders.NStepTransitionAdder( - client=reverb.Client(address), - n_step=n_step, - discount=discount) + # The adder is used to insert observations into replay. + address = f"localhost:{self._server.port}" + adder = adders.NStepTransitionAdder( + client=reverb.Client(address), n_step=n_step, discount=discount + ) - # The dataset provides an interface to sample from replay. - dataset = datasets.make_reverb_dataset(server_address=address) - dataset = dataset.batch(batch_size, drop_remainder=True) + # The dataset provides an interface to sample from replay. + dataset = datasets.make_reverb_dataset(server_address=address) + dataset = dataset.batch(batch_size, drop_remainder=True) - tf2_utils.create_variables(network, [environment_spec.observations]) + tf2_utils.create_variables(network, [environment_spec.observations]) - # Now create the agent components: actor & learner. - actor = acting.MCTSActor( - environment_spec=environment_spec, - model=model, - network=network, - discount=discount, - adder=adder, - num_simulations=num_simulations, - ) + # Now create the agent components: actor & learner. + actor = acting.MCTSActor( + environment_spec=environment_spec, + model=model, + network=network, + discount=discount, + adder=adder, + num_simulations=num_simulations, + ) - learner = learning.AZLearner( - network=network, - optimizer=optimizer, - dataset=dataset, - discount=discount, - ) + learner = learning.AZLearner( + network=network, optimizer=optimizer, dataset=dataset, discount=discount, + ) - # The parent class combines these together into one 'agent'. - super().__init__( - actor=actor, - learner=learner, - min_observations=10, - observations_per_step=1, - ) + # The parent class combines these together into one 'agent'. 
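+        # Roughly speaking, `agent.Agent` starts learning once
+        # `min_observations` (10) observations have been collected and then
+        # takes about one learner step per environment step
+        # (observations_per_step=1).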
+ super().__init__( + actor=actor, learner=learner, min_observations=10, observations_per_step=1, + ) diff --git a/acme/agents/tf/mcts/agent_distributed.py b/acme/agents/tf/mcts/agent_distributed.py index b2eae72a03..41d15ac45f 100644 --- a/acme/agents/tf/mcts/agent_distributed.py +++ b/acme/agents/tf/mcts/agent_distributed.py @@ -16,218 +16,218 @@ from typing import Callable, Optional -import acme -from acme import datasets -from acme import specs -from acme.adders import reverb as adders -from acme.agents.tf.mcts import acting -from acme.agents.tf.mcts import learning -from acme.agents.tf.mcts import models -from acme.tf import utils as tf2_utils -from acme.tf import variable_utils as tf2_variable_utils -from acme.utils import counting -from acme.utils import loggers import dm_env import launchpad as lp import reverb import sonnet as snt +import acme +from acme import datasets, specs +from acme.adders import reverb as adders +from acme.agents.tf.mcts import acting, learning, models +from acme.tf import utils as tf2_utils +from acme.tf import variable_utils as tf2_variable_utils +from acme.utils import counting, loggers + class DistributedMCTS: - """Distributed MCTS agent.""" - - def __init__( - self, - environment_factory: Callable[[], dm_env.Environment], - network_factory: Callable[[specs.DiscreteArray], snt.Module], - model_factory: Callable[[specs.EnvironmentSpec], models.Model], - num_actors: int, - num_simulations: int = 50, - batch_size: int = 256, - prefetch_size: int = 4, - target_update_period: int = 100, - samples_per_insert: float = 32.0, - min_replay_size: int = 1000, - max_replay_size: int = 1000000, - importance_sampling_exponent: float = 0.2, - priority_exponent: float = 0.6, - n_step: int = 5, - learning_rate: float = 1e-3, - discount: float = 0.99, - environment_spec: Optional[specs.EnvironmentSpec] = None, - save_logs: bool = False, - variable_update_period: int = 1000, - ): - - if environment_spec is None: - environment_spec = specs.make_environment_spec(environment_factory()) - - # These 'factories' create the relevant components on the workers. - self._environment_factory = environment_factory - self._network_factory = network_factory - self._model_factory = model_factory - - # Internalize hyperparameters. 
- self._num_actors = num_actors - self._num_simulations = num_simulations - self._env_spec = environment_spec - self._batch_size = batch_size - self._prefetch_size = prefetch_size - self._target_update_period = target_update_period - self._samples_per_insert = samples_per_insert - self._min_replay_size = min_replay_size - self._max_replay_size = max_replay_size - self._importance_sampling_exponent = importance_sampling_exponent - self._priority_exponent = priority_exponent - self._n_step = n_step - self._learning_rate = learning_rate - self._discount = discount - self._save_logs = save_logs - self._variable_update_period = variable_update_period - - def replay(self): - """The replay storage worker.""" - limiter = reverb.rate_limiters.SampleToInsertRatio( - min_size_to_sample=self._min_replay_size, - samples_per_insert=self._samples_per_insert, - error_buffer=self._batch_size) - extra_spec = { - 'pi': - specs.Array( - shape=(self._env_spec.actions.num_values,), dtype='float32') - } - signature = adders.NStepTransitionAdder.signature(self._env_spec, - extra_spec) - replay_table = reverb.Table( - name=adders.DEFAULT_PRIORITY_TABLE, - sampler=reverb.selectors.Uniform(), - remover=reverb.selectors.Fifo(), - max_size=self._max_replay_size, - rate_limiter=limiter, - signature=signature) - return [replay_table] - - def learner(self, replay: reverb.Client, counter: counting.Counter): - """The learning part of the agent.""" - # Create the networks. - network = self._network_factory(self._env_spec.actions) - - tf2_utils.create_variables(network, [self._env_spec.observations]) - - # The dataset object to learn from. - dataset = datasets.make_reverb_dataset( - server_address=replay.server_address, - batch_size=self._batch_size, - prefetch_size=self._prefetch_size) - - # Create the optimizer. - optimizer = snt.optimizers.Adam(self._learning_rate) - - # Return the learning agent. - return learning.AZLearner( - network=network, - discount=self._discount, - dataset=dataset, - optimizer=optimizer, - counter=counter, - ) - - def actor( - self, - replay: reverb.Client, - variable_source: acme.VariableSource, - counter: counting.Counter, - ) -> acme.EnvironmentLoop: - """The actor process.""" - - # Build environment, model, network. - environment = self._environment_factory() - network = self._network_factory(self._env_spec.actions) - model = self._model_factory(self._env_spec) - - # Create variable client for communicating with the learner. - tf2_utils.create_variables(network, [self._env_spec.observations]) - variable_client = tf2_variable_utils.VariableClient( - client=variable_source, - variables={'network': network.trainable_variables}, - update_period=self._variable_update_period) - - # Component to add things into replay. - adder = adders.NStepTransitionAdder( - client=replay, - n_step=self._n_step, - discount=self._discount, - ) - - # Create the agent. - actor = acting.MCTSActor( - environment_spec=self._env_spec, - model=model, - network=network, - discount=self._discount, - adder=adder, - variable_client=variable_client, - num_simulations=self._num_simulations, - ) - - # Create the loop to connect environment and agent. - return acme.EnvironmentLoop(environment, actor, counter) - - def evaluator( - self, - variable_source: acme.VariableSource, - counter: counting.Counter, - ): - """The evaluation process.""" - - # Build environment, model, network. 
- environment = self._environment_factory() - network = self._network_factory(self._env_spec.actions) - model = self._model_factory(self._env_spec) - - # Create variable client for communicating with the learner. - tf2_utils.create_variables(network, [self._env_spec.observations]) - variable_client = tf2_variable_utils.VariableClient( - client=variable_source, - variables={'policy': network.trainable_variables}, - update_period=self._variable_update_period) - - # Create the agent. - actor = acting.MCTSActor( - environment_spec=self._env_spec, - model=model, - network=network, - discount=self._discount, - variable_client=variable_client, - num_simulations=self._num_simulations, - ) - - # Create the run loop and return it. - logger = loggers.make_default_logger('evaluator') - return acme.EnvironmentLoop( - environment, actor, counter=counter, logger=logger) - - def build(self, name='MCTS'): - """Builds the distributed agent topology.""" - program = lp.Program(name=name) - - with program.group('replay'): - replay = program.add_node(lp.ReverbNode(self.replay), label='replay') - - with program.group('counter'): - counter = program.add_node( - lp.CourierNode(counting.Counter), label='counter') - - with program.group('learner'): - learner = program.add_node( - lp.CourierNode(self.learner, replay, counter), label='learner') - - with program.group('evaluator'): - program.add_node( - lp.CourierNode(self.evaluator, learner, counter), label='evaluator') - - with program.group('actor'): - program.add_node( - lp.CourierNode(self.actor, replay, learner, counter), label='actor') - - return program + """Distributed MCTS agent.""" + + def __init__( + self, + environment_factory: Callable[[], dm_env.Environment], + network_factory: Callable[[specs.DiscreteArray], snt.Module], + model_factory: Callable[[specs.EnvironmentSpec], models.Model], + num_actors: int, + num_simulations: int = 50, + batch_size: int = 256, + prefetch_size: int = 4, + target_update_period: int = 100, + samples_per_insert: float = 32.0, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + importance_sampling_exponent: float = 0.2, + priority_exponent: float = 0.6, + n_step: int = 5, + learning_rate: float = 1e-3, + discount: float = 0.99, + environment_spec: Optional[specs.EnvironmentSpec] = None, + save_logs: bool = False, + variable_update_period: int = 1000, + ): + + if environment_spec is None: + environment_spec = specs.make_environment_spec(environment_factory()) + + # These 'factories' create the relevant components on the workers. + self._environment_factory = environment_factory + self._network_factory = network_factory + self._model_factory = model_factory + + # Internalize hyperparameters. 
+ self._num_actors = num_actors + self._num_simulations = num_simulations + self._env_spec = environment_spec + self._batch_size = batch_size + self._prefetch_size = prefetch_size + self._target_update_period = target_update_period + self._samples_per_insert = samples_per_insert + self._min_replay_size = min_replay_size + self._max_replay_size = max_replay_size + self._importance_sampling_exponent = importance_sampling_exponent + self._priority_exponent = priority_exponent + self._n_step = n_step + self._learning_rate = learning_rate + self._discount = discount + self._save_logs = save_logs + self._variable_update_period = variable_update_period + + def replay(self): + """The replay storage worker.""" + limiter = reverb.rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._min_replay_size, + samples_per_insert=self._samples_per_insert, + error_buffer=self._batch_size, + ) + extra_spec = { + "pi": specs.Array( + shape=(self._env_spec.actions.num_values,), dtype="float32" + ) + } + signature = adders.NStepTransitionAdder.signature(self._env_spec, extra_spec) + replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._max_replay_size, + rate_limiter=limiter, + signature=signature, + ) + return [replay_table] + + def learner(self, replay: reverb.Client, counter: counting.Counter): + """The learning part of the agent.""" + # Create the networks. + network = self._network_factory(self._env_spec.actions) + + tf2_utils.create_variables(network, [self._env_spec.observations]) + + # The dataset object to learn from. + dataset = datasets.make_reverb_dataset( + server_address=replay.server_address, + batch_size=self._batch_size, + prefetch_size=self._prefetch_size, + ) + + # Create the optimizer. + optimizer = snt.optimizers.Adam(self._learning_rate) + + # Return the learning agent. + return learning.AZLearner( + network=network, + discount=self._discount, + dataset=dataset, + optimizer=optimizer, + counter=counter, + ) + + def actor( + self, + replay: reverb.Client, + variable_source: acme.VariableSource, + counter: counting.Counter, + ) -> acme.EnvironmentLoop: + """The actor process.""" + + # Build environment, model, network. + environment = self._environment_factory() + network = self._network_factory(self._env_spec.actions) + model = self._model_factory(self._env_spec) + + # Create variable client for communicating with the learner. + tf2_utils.create_variables(network, [self._env_spec.observations]) + variable_client = tf2_variable_utils.VariableClient( + client=variable_source, + variables={"network": network.trainable_variables}, + update_period=self._variable_update_period, + ) + + # Component to add things into replay. + adder = adders.NStepTransitionAdder( + client=replay, n_step=self._n_step, discount=self._discount, + ) + + # Create the agent. + actor = acting.MCTSActor( + environment_spec=self._env_spec, + model=model, + network=network, + discount=self._discount, + adder=adder, + variable_client=variable_client, + num_simulations=self._num_simulations, + ) + + # Create the loop to connect environment and agent. + return acme.EnvironmentLoop(environment, actor, counter) + + def evaluator( + self, variable_source: acme.VariableSource, counter: counting.Counter, + ): + """The evaluation process.""" + + # Build environment, model, network. 
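+        # The evaluator builds its own environment, network and model from
+        # the factories and pulls the latest weights from the learner via the
+        # variable client; unlike the actors it has no adder, so its episodes
+        # are never written to replay.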
+ environment = self._environment_factory() + network = self._network_factory(self._env_spec.actions) + model = self._model_factory(self._env_spec) + + # Create variable client for communicating with the learner. + tf2_utils.create_variables(network, [self._env_spec.observations]) + variable_client = tf2_variable_utils.VariableClient( + client=variable_source, + variables={"policy": network.trainable_variables}, + update_period=self._variable_update_period, + ) + + # Create the agent. + actor = acting.MCTSActor( + environment_spec=self._env_spec, + model=model, + network=network, + discount=self._discount, + variable_client=variable_client, + num_simulations=self._num_simulations, + ) + + # Create the run loop and return it. + logger = loggers.make_default_logger("evaluator") + return acme.EnvironmentLoop(environment, actor, counter=counter, logger=logger) + + def build(self, name="MCTS"): + """Builds the distributed agent topology.""" + program = lp.Program(name=name) + + with program.group("replay"): + replay = program.add_node(lp.ReverbNode(self.replay), label="replay") + + with program.group("counter"): + counter = program.add_node( + lp.CourierNode(counting.Counter), label="counter" + ) + + with program.group("learner"): + learner = program.add_node( + lp.CourierNode(self.learner, replay, counter), label="learner" + ) + + with program.group("evaluator"): + program.add_node( + lp.CourierNode(self.evaluator, learner, counter), label="evaluator" + ) + + with program.group("actor"): + program.add_node( + lp.CourierNode(self.actor, replay, learner, counter), label="actor" + ) + + return program diff --git a/acme/agents/tf/mcts/agent_test.py b/acme/agents/tf/mcts/agent_test.py index de1a8029ab..e1481281d8 100644 --- a/acme/agents/tf/mcts/agent_test.py +++ b/acme/agents/tf/mcts/agent_test.py @@ -14,55 +14,58 @@ """Tests for the MCTS agent.""" +import numpy as np +import sonnet as snt +from absl.testing import absltest + import acme from acme import specs from acme.agents.tf import mcts from acme.agents.tf.mcts.models import simulator from acme.testing import fakes from acme.tf import networks -import numpy as np -import sonnet as snt - -from absl.testing import absltest class MCTSTest(absltest.TestCase): + def test_mcts(self): + # Create a fake environment to test with. + num_actions = 5 + environment = fakes.DiscreteEnvironment( + num_actions=num_actions, + num_observations=10, + obs_dtype=np.float32, + episode_length=10, + ) + spec = specs.make_environment_spec(environment) - def test_mcts(self): - # Create a fake environment to test with. - num_actions = 5 - environment = fakes.DiscreteEnvironment( - num_actions=num_actions, - num_observations=10, - obs_dtype=np.float32, - episode_length=10) - spec = specs.make_environment_spec(environment) - - network = snt.Sequential([ - snt.Flatten(), - snt.nets.MLP([50, 50]), - networks.PolicyValueHead(spec.actions.num_values), - ]) - model = simulator.Simulator(environment) - optimizer = snt.optimizers.Adam(1e-3) + network = snt.Sequential( + [ + snt.Flatten(), + snt.nets.MLP([50, 50]), + networks.PolicyValueHead(spec.actions.num_values), + ] + ) + model = simulator.Simulator(environment) + optimizer = snt.optimizers.Adam(1e-3) - # Construct the agent. - agent = mcts.MCTS( - environment_spec=spec, - network=network, - model=model, - optimizer=optimizer, - n_step=1, - discount=1., - replay_capacity=100, - num_simulations=10, - batch_size=10) + # Construct the agent. 
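+        # The single-process `mcts.MCTS` agent wires up an in-process Reverb
+        # table, an N-step adder (n_step=1 here), the MCTSActor and the
+        # AZLearner internally, mirroring agent.py above.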
+ agent = mcts.MCTS( + environment_spec=spec, + network=network, + model=model, + optimizer=optimizer, + n_step=1, + discount=1.0, + replay_capacity=100, + num_simulations=10, + batch_size=10, + ) - # Try running the environment loop. We have no assertions here because all - # we care about is that the agent runs without raising any errors. - loop = acme.EnvironmentLoop(environment, agent) - loop.run(num_episodes=2) + # Try running the environment loop. We have no assertions here because all + # we care about is that the agent runs without raising any errors. + loop = acme.EnvironmentLoop(environment, agent) + loop.run(num_episodes=2) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/tf/mcts/learning.py b/acme/agents/tf/mcts/learning.py index 6ec52c3d7f..d6ea06b223 100644 --- a/acme/agents/tf/mcts/learning.py +++ b/acme/agents/tf/mcts/learning.py @@ -16,74 +16,75 @@ from typing import List, Optional -import acme -from acme.tf import utils as tf2_utils -from acme.utils import counting -from acme.utils import loggers import numpy as np import sonnet as snt import tensorflow as tf +import acme +from acme.tf import utils as tf2_utils +from acme.utils import counting, loggers + class AZLearner(acme.Learner): - """AlphaZero-style learning.""" - - def __init__( - self, - network: snt.Module, - optimizer: snt.Optimizer, - dataset: tf.data.Dataset, - discount: float, - logger: Optional[loggers.Logger] = None, - counter: Optional[counting.Counter] = None, - ): - - # Logger and counter for tracking statistics / writing out to terminal. - self._counter = counting.Counter(counter, 'learner') - self._logger = logger or loggers.TerminalLogger('learner', time_delta=30.) - - # Internalize components. - # TODO(b/155086959): Fix type stubs and remove. - self._iterator = iter(dataset) # pytype: disable=wrong-arg-types - self._optimizer = optimizer - self._network = network - self._variables = network.trainable_variables - self._discount = np.float32(discount) - - @tf.function - def _step(self) -> tf.Tensor: - """Do a step of SGD on the loss.""" - - inputs = next(self._iterator) - o_t, _, r_t, d_t, o_tp1, extras = inputs.data - pi_t = extras['pi'] - - with tf.GradientTape() as tape: - # Forward the network on the two states in the transition. - logits, value = self._network(o_t) - _, target_value = self._network(o_tp1) - target_value = tf.stop_gradient(target_value) - - # Value loss is simply on-policy TD learning. - value_loss = tf.square(r_t + self._discount * d_t * target_value - value) - - # Policy loss distills MCTS policy into the policy network. - policy_loss = tf.nn.softmax_cross_entropy_with_logits( - logits=logits, labels=pi_t) - - # Compute gradients. 
- loss = tf.reduce_mean(value_loss + policy_loss) - gradients = tape.gradient(loss, self._network.trainable_variables) - - self._optimizer.apply(gradients, self._network.trainable_variables) - - return loss - - def step(self): - """Does a step of SGD and logs the results.""" - loss = self._step() - self._logger.write({'loss': loss}) - - def get_variables(self, names: List[str]) -> List[List[np.ndarray]]: - """Exposes the variables for actors to update from.""" - return tf2_utils.to_numpy(self._variables) + """AlphaZero-style learning.""" + + def __init__( + self, + network: snt.Module, + optimizer: snt.Optimizer, + dataset: tf.data.Dataset, + discount: float, + logger: Optional[loggers.Logger] = None, + counter: Optional[counting.Counter] = None, + ): + + # Logger and counter for tracking statistics / writing out to terminal. + self._counter = counting.Counter(counter, "learner") + self._logger = logger or loggers.TerminalLogger("learner", time_delta=30.0) + + # Internalize components. + # TODO(b/155086959): Fix type stubs and remove. + self._iterator = iter(dataset) # pytype: disable=wrong-arg-types + self._optimizer = optimizer + self._network = network + self._variables = network.trainable_variables + self._discount = np.float32(discount) + + @tf.function + def _step(self) -> tf.Tensor: + """Do a step of SGD on the loss.""" + + inputs = next(self._iterator) + o_t, _, r_t, d_t, o_tp1, extras = inputs.data + pi_t = extras["pi"] + + with tf.GradientTape() as tape: + # Forward the network on the two states in the transition. + logits, value = self._network(o_t) + _, target_value = self._network(o_tp1) + target_value = tf.stop_gradient(target_value) + + # Value loss is simply on-policy TD learning. + value_loss = tf.square(r_t + self._discount * d_t * target_value - value) + + # Policy loss distills MCTS policy into the policy network. + policy_loss = tf.nn.softmax_cross_entropy_with_logits( + logits=logits, labels=pi_t + ) + + # Compute gradients. 
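+            # Total loss, averaged over the batch:
+            #   (r_t + discount * d_t * V(o_tp1) - V(o_t))**2
+            #   + softmax_cross_entropy(pi_t, logits),
+            # with gradients taken w.r.t. the shared policy-value network.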
+ loss = tf.reduce_mean(value_loss + policy_loss) + gradients = tape.gradient(loss, self._network.trainable_variables) + + self._optimizer.apply(gradients, self._network.trainable_variables) + + return loss + + def step(self): + """Does a step of SGD and logs the results.""" + loss = self._step() + self._logger.write({"loss": loss}) + + def get_variables(self, names: List[str]) -> List[List[np.ndarray]]: + """Exposes the variables for actors to update from.""" + return tf2_utils.to_numpy(self._variables) diff --git a/acme/agents/tf/mcts/models/base.py b/acme/agents/tf/mcts/models/base.py index 616b8f9560..389471d534 100644 --- a/acme/agents/tf/mcts/models/base.py +++ b/acme/agents/tf/mcts/models/base.py @@ -17,36 +17,36 @@ import abc from typing import Optional -from acme.agents.tf.mcts import types - import dm_env +from acme.agents.tf.mcts import types + class Model(dm_env.Environment, abc.ABC): - """Base (abstract) class for models used for planning via MCTS.""" - - @abc.abstractmethod - def load_checkpoint(self): - """Loads a saved model state, if it exists.""" - - @abc.abstractmethod - def save_checkpoint(self): - """Saves the model state so that we can reset it after a rollout.""" - - @abc.abstractmethod - def update( - self, - timestep: dm_env.TimeStep, - action: types.Action, - next_timestep: dm_env.TimeStep, - ) -> dm_env.TimeStep: - """Updates the model given an observation, action, reward, and discount.""" - - @abc.abstractmethod - def reset(self, initial_state: Optional[types.Observation] = None): - """Resets the model, optionally to an initial state.""" - - @property - @abc.abstractmethod - def needs_reset(self) -> bool: - """Returns whether or not the model needs to be reset.""" + """Base (abstract) class for models used for planning via MCTS.""" + + @abc.abstractmethod + def load_checkpoint(self): + """Loads a saved model state, if it exists.""" + + @abc.abstractmethod + def save_checkpoint(self): + """Saves the model state so that we can reset it after a rollout.""" + + @abc.abstractmethod + def update( + self, + timestep: dm_env.TimeStep, + action: types.Action, + next_timestep: dm_env.TimeStep, + ) -> dm_env.TimeStep: + """Updates the model given an observation, action, reward, and discount.""" + + @abc.abstractmethod + def reset(self, initial_state: Optional[types.Observation] = None): + """Resets the model, optionally to an initial state.""" + + @property + @abc.abstractmethod + def needs_reset(self) -> bool: + """Returns whether or not the model needs to be reset.""" diff --git a/acme/agents/tf/mcts/models/mlp.py b/acme/agents/tf/mcts/models/mlp.py index 6f4c8adbcb..06c567ed9a 100644 --- a/acme/agents/tf/mcts/models/mlp.py +++ b/acme/agents/tf/mcts/models/mlp.py @@ -16,205 +16,207 @@ from typing import Optional, Tuple -from acme import specs -from acme.agents.tf.mcts import types -from acme.agents.tf.mcts.models import base -from acme.tf import utils as tf2_utils - -from bsuite.baselines.utils import replay import dm_env import numpy as np -from scipy import special import sonnet as snt import tensorflow as tf +from bsuite.baselines.utils import replay +from scipy import special + +from acme import specs +from acme.agents.tf.mcts import types +from acme.agents.tf.mcts.models import base +from acme.tf import utils as tf2_utils class MLPTransitionModel(snt.Module): - """This uses MLPs to model (s, a) -> (r, d, s').""" - - def __init__( - self, - environment_spec: specs.EnvironmentSpec, - hidden_sizes: Tuple[int, ...], - ): - super(MLPTransitionModel, 
self).__init__(name='mlp_transition_model') - - # Get num actions/observation shape. - self._num_actions = environment_spec.actions.num_values - self._input_shape = environment_spec.observations.shape - self._flat_shape = int(np.prod(self._input_shape)) - - # Prediction networks. - self._state_network = snt.Sequential([ - snt.nets.MLP(hidden_sizes + (self._flat_shape,)), - snt.Reshape(self._input_shape) - ]) - self._reward_network = snt.Sequential([ - snt.nets.MLP(hidden_sizes + (1,)), - lambda r: tf.squeeze(r, axis=-1), - ]) - self._discount_network = snt.Sequential([ - snt.nets.MLP(hidden_sizes + (1,)), - lambda d: tf.squeeze(d, axis=-1), - ]) - - def __call__(self, state: tf.Tensor, - action: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: - - embedded_state = snt.Flatten()(state) - embedded_action = tf.one_hot(action, depth=self._num_actions) - - embedding = tf.concat([embedded_state, embedded_action], axis=-1) - - # Predict the next state, reward, and termination. - next_state = self._state_network(embedding) - reward = self._reward_network(embedding) - discount_logits = self._discount_network(embedding) - - return next_state, reward, discount_logits + """This uses MLPs to model (s, a) -> (r, d, s').""" + + def __init__( + self, environment_spec: specs.EnvironmentSpec, hidden_sizes: Tuple[int, ...], + ): + super(MLPTransitionModel, self).__init__(name="mlp_transition_model") + + # Get num actions/observation shape. + self._num_actions = environment_spec.actions.num_values + self._input_shape = environment_spec.observations.shape + self._flat_shape = int(np.prod(self._input_shape)) + + # Prediction networks. + self._state_network = snt.Sequential( + [ + snt.nets.MLP(hidden_sizes + (self._flat_shape,)), + snt.Reshape(self._input_shape), + ] + ) + self._reward_network = snt.Sequential( + [snt.nets.MLP(hidden_sizes + (1,)), lambda r: tf.squeeze(r, axis=-1),] + ) + self._discount_network = snt.Sequential( + [snt.nets.MLP(hidden_sizes + (1,)), lambda d: tf.squeeze(d, axis=-1),] + ) + + def __call__( + self, state: tf.Tensor, action: tf.Tensor + ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: + + embedded_state = snt.Flatten()(state) + embedded_action = tf.one_hot(action, depth=self._num_actions) + + embedding = tf.concat([embedded_state, embedded_action], axis=-1) + + # Predict the next state, reward, and termination. + next_state = self._state_network(embedding) + reward = self._reward_network(embedding) + discount_logits = self._discount_network(embedding) + + return next_state, reward, discount_logits class MLPModel(base.Model): - """A simple environment model.""" - - _checkpoint: types.Observation - _state: types.Observation - - def __init__( - self, - environment_spec: specs.EnvironmentSpec, - replay_capacity: int, - batch_size: int, - hidden_sizes: Tuple[int, ...], - learning_rate: float = 1e-3, - terminal_tol: float = 1e-3, - ): - self._obs_spec = environment_spec.observations - self._action_spec = environment_spec.actions - # Hyperparameters. - self._batch_size = batch_size - self._terminal_tol = terminal_tol - - # Modelling - self._replay = replay.Replay(replay_capacity) - self._transition_model = MLPTransitionModel(environment_spec, hidden_sizes) - self._optimizer = snt.optimizers.Adam(learning_rate) - self._forward = tf.function(self._transition_model) - tf2_utils.create_variables( - self._transition_model, [self._obs_spec, self._action_spec]) - self._variables = self._transition_model.trainable_variables - - # Model state. 
- self._needs_reset = True - - @tf.function - def _step( - self, - o_t: tf.Tensor, - a_t: tf.Tensor, - r_t: tf.Tensor, - d_t: tf.Tensor, - o_tp1: tf.Tensor, - ) -> tf.Tensor: - - with tf.GradientTape() as tape: - next_state, reward, discount = self._transition_model(o_t, a_t) - - state_loss = tf.square(next_state - o_tp1) - reward_loss = tf.square(reward - r_t) - discount_loss = tf.nn.sigmoid_cross_entropy_with_logits(d_t, discount) - - loss = sum([ - tf.reduce_mean(state_loss), - tf.reduce_mean(reward_loss), - tf.reduce_mean(discount_loss), - ]) - - gradients = tape.gradient(loss, self._variables) - self._optimizer.apply(gradients, self._variables) - - return loss - - def step(self, action: types.Action): - # Reset if required. - if self._needs_reset: - raise ValueError('Model must be reset with an initial timestep.') - - # Step the model. - state, action = tf2_utils.add_batch_dim([self._state, action]) - new_state, reward, discount_logits = [ - x.numpy().squeeze(axis=0) for x in self._forward(state, action) - ] - discount = special.softmax(discount_logits) - - # Save the resulting state for the next step. - self._state = new_state - - # We threshold discount on a given tolerance. - if discount < self._terminal_tol: - self._needs_reset = True - return dm_env.termination(reward=reward, observation=self._state.copy()) - return dm_env.transition(reward=reward, observation=self._state.copy()) - - def reset(self, initial_state: Optional[types.Observation] = None): - if initial_state is None: - raise ValueError('Model must be reset with an initial state.') - # We reset to an initial state that we are explicitly given. - # This allows us to handle environments with stochastic resets (e.g. Catch). - self._state = initial_state.copy() - self._needs_reset = False - return dm_env.restart(self._state) - - def update( - self, - timestep: dm_env.TimeStep, - action: types.Action, - next_timestep: dm_env.TimeStep, - ) -> dm_env.TimeStep: - # Add the true transition to replay. - transition = [ - timestep.observation, - action, - next_timestep.reward, - next_timestep.discount, - next_timestep.observation, - ] - self._replay.add(transition) - - # Step the model to generate a synthetic transition. - ts = self.step(action) - - # Copy the *true* state on update. - self._state = next_timestep.observation.copy() - - if ts.last() or next_timestep.last(): - # Model believes that a termination has happened. - # This will result in a crash during planning if the true environment - # didn't terminate here as well. So, we indicate that we need a reset. - self._needs_reset = True - - # Sample from replay and do SGD. 
- if self._replay.size >= self._batch_size: - batch = self._replay.sample(self._batch_size) - self._step(*batch) - - return ts - - def save_checkpoint(self): - if self._needs_reset: - raise ValueError('Cannot save checkpoint: model must be reset first.') - self._checkpoint = self._state.copy() - - def load_checkpoint(self): - self._needs_reset = False - self._state = self._checkpoint.copy() - - def action_spec(self): - return self._action_spec - - def observation_spec(self): - return self._obs_spec - - @property - def needs_reset(self) -> bool: - return self._needs_reset + """A simple environment model.""" + + _checkpoint: types.Observation + _state: types.Observation + + def __init__( + self, + environment_spec: specs.EnvironmentSpec, + replay_capacity: int, + batch_size: int, + hidden_sizes: Tuple[int, ...], + learning_rate: float = 1e-3, + terminal_tol: float = 1e-3, + ): + self._obs_spec = environment_spec.observations + self._action_spec = environment_spec.actions + # Hyperparameters. + self._batch_size = batch_size + self._terminal_tol = terminal_tol + + # Modelling + self._replay = replay.Replay(replay_capacity) + self._transition_model = MLPTransitionModel(environment_spec, hidden_sizes) + self._optimizer = snt.optimizers.Adam(learning_rate) + self._forward = tf.function(self._transition_model) + tf2_utils.create_variables( + self._transition_model, [self._obs_spec, self._action_spec] + ) + self._variables = self._transition_model.trainable_variables + + # Model state. + self._needs_reset = True + + @tf.function + def _step( + self, + o_t: tf.Tensor, + a_t: tf.Tensor, + r_t: tf.Tensor, + d_t: tf.Tensor, + o_tp1: tf.Tensor, + ) -> tf.Tensor: + + with tf.GradientTape() as tape: + next_state, reward, discount = self._transition_model(o_t, a_t) + + state_loss = tf.square(next_state - o_tp1) + reward_loss = tf.square(reward - r_t) + discount_loss = tf.nn.sigmoid_cross_entropy_with_logits(d_t, discount) + + loss = sum( + [ + tf.reduce_mean(state_loss), + tf.reduce_mean(reward_loss), + tf.reduce_mean(discount_loss), + ] + ) + + gradients = tape.gradient(loss, self._variables) + self._optimizer.apply(gradients, self._variables) + + return loss + + def step(self, action: types.Action): + # Reset if required. + if self._needs_reset: + raise ValueError("Model must be reset with an initial timestep.") + + # Step the model. + state, action = tf2_utils.add_batch_dim([self._state, action]) + new_state, reward, discount_logits = [ + x.numpy().squeeze(axis=0) for x in self._forward(state, action) + ] + discount = special.softmax(discount_logits) + + # Save the resulting state for the next step. + self._state = new_state + + # We threshold discount on a given tolerance. + if discount < self._terminal_tol: + self._needs_reset = True + return dm_env.termination(reward=reward, observation=self._state.copy()) + return dm_env.transition(reward=reward, observation=self._state.copy()) + + def reset(self, initial_state: Optional[types.Observation] = None): + if initial_state is None: + raise ValueError("Model must be reset with an initial state.") + # We reset to an initial state that we are explicitly given. + # This allows us to handle environments with stochastic resets (e.g. Catch). + self._state = initial_state.copy() + self._needs_reset = False + return dm_env.restart(self._state) + + def update( + self, + timestep: dm_env.TimeStep, + action: types.Action, + next_timestep: dm_env.TimeStep, + ) -> dm_env.TimeStep: + # Add the true transition to replay. 
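+        # Each stored tuple is (o_t, a_t, r_t, d_t, o_tp1); `_step` regresses
+        # the model's next-state, reward and discount heads against these
+        # targets once the buffer holds at least `batch_size` transitions.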
+ transition = [ + timestep.observation, + action, + next_timestep.reward, + next_timestep.discount, + next_timestep.observation, + ] + self._replay.add(transition) + + # Step the model to generate a synthetic transition. + ts = self.step(action) + + # Copy the *true* state on update. + self._state = next_timestep.observation.copy() + + if ts.last() or next_timestep.last(): + # Model believes that a termination has happened. + # This will result in a crash during planning if the true environment + # didn't terminate here as well. So, we indicate that we need a reset. + self._needs_reset = True + + # Sample from replay and do SGD. + if self._replay.size >= self._batch_size: + batch = self._replay.sample(self._batch_size) + self._step(*batch) + + return ts + + def save_checkpoint(self): + if self._needs_reset: + raise ValueError("Cannot save checkpoint: model must be reset first.") + self._checkpoint = self._state.copy() + + def load_checkpoint(self): + self._needs_reset = False + self._state = self._checkpoint.copy() + + def action_spec(self): + return self._action_spec + + def observation_spec(self): + return self._obs_spec + + @property + def needs_reset(self) -> bool: + return self._needs_reset diff --git a/acme/agents/tf/mcts/models/simulator.py b/acme/agents/tf/mcts/models/simulator.py index e0acb5d7ba..4a6efd2cdc 100644 --- a/acme/agents/tf/mcts/models/simulator.py +++ b/acme/agents/tf/mcts/models/simulator.py @@ -17,71 +17,72 @@ import copy import dataclasses +import dm_env + from acme.agents.tf.mcts import types from acme.agents.tf.mcts.models import base -import dm_env @dataclasses.dataclass class Checkpoint: - """Holds the checkpoint state for the environment simulator.""" - needs_reset: bool - environment: dm_env.Environment + """Holds the checkpoint state for the environment simulator.""" + + needs_reset: bool + environment: dm_env.Environment class Simulator(base.Model): - """A simulator model, which wraps a copy of the true environment. + """A simulator model, which wraps a copy of the true environment. Assumptions: - The environment (including RNG) is fully copyable via `deepcopy`. - Environment dynamics (modulo episode resets) are deterministic. """ - _checkpoint: Checkpoint - _env: dm_env.Environment - - def __init__(self, env: dm_env.Environment): - # Make a 'checkpoint' copy env to save/load from when doing rollouts. - self._env = copy.deepcopy(env) - self._needs_reset = True - self.save_checkpoint() - - def update( - self, - timestep: dm_env.TimeStep, - action: types.Action, - next_timestep: dm_env.TimeStep, - ) -> dm_env.TimeStep: - # Call update() once per 'real' experience to keep this env in sync. 
- return self.step(action) - - def save_checkpoint(self): - self._checkpoint = Checkpoint( - needs_reset=self._needs_reset, - environment=copy.deepcopy(self._env), - ) - - def load_checkpoint(self): - self._env = copy.deepcopy(self._checkpoint.environment) - self._needs_reset = self._checkpoint.needs_reset - - def step(self, action: types.Action) -> dm_env.TimeStep: - if self._needs_reset: - raise ValueError('This model needs to be explicitly reset.') - timestep = self._env.step(action) - self._needs_reset = timestep.last() - return timestep - - def reset(self, *unused_args, **unused_kwargs): - self._needs_reset = False - return self._env.reset() - - def observation_spec(self): - return self._env.observation_spec() - - def action_spec(self): - return self._env.action_spec() - - @property - def needs_reset(self) -> bool: - return self._needs_reset + _checkpoint: Checkpoint + _env: dm_env.Environment + + def __init__(self, env: dm_env.Environment): + # Make a 'checkpoint' copy env to save/load from when doing rollouts. + self._env = copy.deepcopy(env) + self._needs_reset = True + self.save_checkpoint() + + def update( + self, + timestep: dm_env.TimeStep, + action: types.Action, + next_timestep: dm_env.TimeStep, + ) -> dm_env.TimeStep: + # Call update() once per 'real' experience to keep this env in sync. + return self.step(action) + + def save_checkpoint(self): + self._checkpoint = Checkpoint( + needs_reset=self._needs_reset, environment=copy.deepcopy(self._env), + ) + + def load_checkpoint(self): + self._env = copy.deepcopy(self._checkpoint.environment) + self._needs_reset = self._checkpoint.needs_reset + + def step(self, action: types.Action) -> dm_env.TimeStep: + if self._needs_reset: + raise ValueError("This model needs to be explicitly reset.") + timestep = self._env.step(action) + self._needs_reset = timestep.last() + return timestep + + def reset(self, *unused_args, **unused_kwargs): + self._needs_reset = False + return self._env.reset() + + def observation_spec(self): + return self._env.observation_spec() + + def action_spec(self): + return self._env.action_spec() + + @property + def needs_reset(self) -> bool: + return self._needs_reset diff --git a/acme/agents/tf/mcts/models/simulator_test.py b/acme/agents/tf/mcts/models/simulator_test.py index 90b68d5986..57ae347029 100644 --- a/acme/agents/tf/mcts/models/simulator_test.py +++ b/acme/agents/tf/mcts/models/simulator_test.py @@ -14,77 +14,76 @@ """Tests for simulator.py.""" -from acme.agents.tf.mcts.models import simulator -from bsuite.environments import catch import dm_env import numpy as np - from absl.testing import absltest +from bsuite.environments import catch +from acme.agents.tf.mcts.models import simulator -class SimulatorTest(absltest.TestCase): - def _check_equal(self, a: dm_env.TimeStep, b: dm_env.TimeStep): - self.assertEqual(a.reward, b.reward) - self.assertEqual(a.discount, b.discount) - self.assertEqual(a.step_type, b.step_type) - np.testing.assert_array_equal(a.observation, b.observation) +class SimulatorTest(absltest.TestCase): + def _check_equal(self, a: dm_env.TimeStep, b: dm_env.TimeStep): + self.assertEqual(a.reward, b.reward) + self.assertEqual(a.discount, b.discount) + self.assertEqual(a.step_type, b.step_type) + np.testing.assert_array_equal(a.observation, b.observation) - def test_simulator_fidelity(self): - """Tests whether the simulator match the ground truth.""" + def test_simulator_fidelity(self): + """Tests whether the simulator match the ground truth.""" - # Given an environment. 
- env = catch.Catch() + # Given an environment. + env = catch.Catch() - # If we instantiate a simulator 'model' of this environment. - model = simulator.Simulator(env) + # If we instantiate a simulator 'model' of this environment. + model = simulator.Simulator(env) - # Then the model and environment should always agree as we step them. - num_actions = env.action_spec().num_values - for _ in range(10): - true_timestep = env.reset() - self.assertTrue(model.needs_reset) - model_timestep = model.reset() - self.assertFalse(model.needs_reset) - self._check_equal(true_timestep, model_timestep) + # Then the model and environment should always agree as we step them. + num_actions = env.action_spec().num_values + for _ in range(10): + true_timestep = env.reset() + self.assertTrue(model.needs_reset) + model_timestep = model.reset() + self.assertFalse(model.needs_reset) + self._check_equal(true_timestep, model_timestep) - while not true_timestep.last(): - action = np.random.randint(num_actions) - true_timestep = env.step(action) - model_timestep = model.step(action) - self._check_equal(true_timestep, model_timestep) + while not true_timestep.last(): + action = np.random.randint(num_actions) + true_timestep = env.step(action) + model_timestep = model.step(action) + self._check_equal(true_timestep, model_timestep) - def test_checkpointing(self): - """Tests whether checkpointing restores the state correctly.""" - # Given an environment, and a model based on this environment. - model = simulator.Simulator(catch.Catch()) - num_actions = model.action_spec().num_values + def test_checkpointing(self): + """Tests whether checkpointing restores the state correctly.""" + # Given an environment, and a model based on this environment. + model = simulator.Simulator(catch.Catch()) + num_actions = model.action_spec().num_values - model.reset() + model.reset() - # Now, we save a checkpoint. - model.save_checkpoint() + # Now, we save a checkpoint. + model.save_checkpoint() - ts = model.step(1) + ts = model.step(1) - # Step the model once and load the checkpoint. - timestep = model.step(np.random.randint(num_actions)) - model.load_checkpoint() - self._check_equal(ts, model.step(1)) + # Step the model once and load the checkpoint. + timestep = model.step(np.random.randint(num_actions)) + model.load_checkpoint() + self._check_equal(ts, model.step(1)) - while not timestep.last(): - timestep = model.step(np.random.randint(num_actions)) + while not timestep.last(): + timestep = model.step(np.random.randint(num_actions)) - # The model should require a reset. - self.assertTrue(model.needs_reset) + # The model should require a reset. + self.assertTrue(model.needs_reset) - # Once we load checkpoint, the model should no longer require reset. - model.load_checkpoint() - self.assertFalse(model.needs_reset) + # Once we load checkpoint, the model should no longer require reset. + model.load_checkpoint() + self.assertFalse(model.needs_reset) - # Further steps should agree with the original environment state. - self._check_equal(ts, model.step(1)) + # Further steps should agree with the original environment state. 
+ self._check_equal(ts, model.step(1)) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/tf/mcts/search.py b/acme/agents/tf/mcts/search.py index 7bcc85f40a..e50f8aac0c 100644 --- a/acme/agents/tf/mcts/search.py +++ b/acme/agents/tf/mcts/search.py @@ -17,44 +17,44 @@ import dataclasses from typing import Callable, Dict -from acme.agents.tf.mcts import models -from acme.agents.tf.mcts import types import numpy as np +from acme.agents.tf.mcts import models, types + @dataclasses.dataclass class Node: - """A MCTS node.""" - - reward: float = 0. - visit_count: int = 0 - terminal: bool = False - prior: float = 1. - total_value: float = 0. - children: Dict[types.Action, 'Node'] = dataclasses.field(default_factory=dict) - - def expand(self, prior: np.ndarray): - """Expands this node, adding child nodes.""" - assert prior.ndim == 1 # Prior should be a flat vector. - for a, p in enumerate(prior): - self.children[a] = Node(prior=p) - - @property - def value(self) -> types.Value: # Q(s, a) - """Returns the value from this node.""" - if self.visit_count: - return self.total_value / self.visit_count - return 0. - - @property - def children_visits(self) -> np.ndarray: - """Return array of visit counts of visited children.""" - return np.array([c.visit_count for c in self.children.values()]) - - @property - def children_values(self) -> np.ndarray: - """Return array of values of visited children.""" - return np.array([c.value for c in self.children.values()]) + """A MCTS node.""" + + reward: float = 0.0 + visit_count: int = 0 + terminal: bool = False + prior: float = 1.0 + total_value: float = 0.0 + children: Dict[types.Action, "Node"] = dataclasses.field(default_factory=dict) + + def expand(self, prior: np.ndarray): + """Expands this node, adding child nodes.""" + assert prior.ndim == 1 # Prior should be a flat vector. + for a, p in enumerate(prior): + self.children[a] = Node(prior=p) + + @property + def value(self) -> types.Value: # Q(s, a) + """Returns the value from this node.""" + if self.visit_count: + return self.total_value / self.visit_count + return 0.0 + + @property + def children_visits(self) -> np.ndarray: + """Return array of visit counts of visited children.""" + return np.array([c.visit_count for c in self.children.values()]) + + @property + def children_values(self) -> np.ndarray: + """Return array of values of visited children.""" + return np.array([c.value for c in self.children.values()]) SearchPolicy = Callable[[Node], types.Action] @@ -67,128 +67,130 @@ def mcts( evaluation: types.EvaluationFn, num_simulations: int, num_actions: int, - discount: float = 1., + discount: float = 1.0, dirichlet_alpha: float = 1, - exploration_fraction: float = 0., + exploration_fraction: float = 0.0, ) -> Node: - """Does Monte Carlo tree search (MCTS), AlphaZero style.""" + """Does Monte Carlo tree search (MCTS), AlphaZero style.""" - # Evaluate the prior policy for this state. - prior, value = evaluation(observation) - assert prior.shape == (num_actions,) + # Evaluate the prior policy for this state. + prior, value = evaluation(observation) + assert prior.shape == (num_actions,) - # Add exploration noise to the prior. - noise = np.random.dirichlet(alpha=[dirichlet_alpha] * num_actions) - prior = prior * (1 - exploration_fraction) + noise * exploration_fraction + # Add exploration noise to the prior. 
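+    # AlphaZero-style root noise:
+    #   prior <- (1 - exploration_fraction) * prior
+    #            + exploration_fraction * Dirichlet(dirichlet_alpha),
+    # so with the default exploration_fraction of 0.0 the prior is unchanged.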
+ noise = np.random.dirichlet(alpha=[dirichlet_alpha] * num_actions) + prior = prior * (1 - exploration_fraction) + noise * exploration_fraction - # Create a fresh tree search. - root = Node() - root.expand(prior) + # Create a fresh tree search. + root = Node() + root.expand(prior) - # Save the model state so that we can reset it for each simulation. - model.save_checkpoint() - for _ in range(num_simulations): - # Start a new simulation from the top. - trajectory = [root] - node = root + # Save the model state so that we can reset it for each simulation. + model.save_checkpoint() + for _ in range(num_simulations): + # Start a new simulation from the top. + trajectory = [root] + node = root - # Generate a trajectory. - timestep = None - while node.children: - # Select an action according to the search policy. - action = search_policy(node) + # Generate a trajectory. + timestep = None + while node.children: + # Select an action according to the search policy. + action = search_policy(node) - # Point the node at the corresponding child. - node = node.children[action] + # Point the node at the corresponding child. + node = node.children[action] - # Step the simulator and add this timestep to the node. - timestep = model.step(action) - node.reward = timestep.reward or 0. - node.terminal = timestep.last() - trajectory.append(node) + # Step the simulator and add this timestep to the node. + timestep = model.step(action) + node.reward = timestep.reward or 0.0 + node.terminal = timestep.last() + trajectory.append(node) - if timestep is None: - raise ValueError('Generated an empty rollout; this should not happen.') + if timestep is None: + raise ValueError("Generated an empty rollout; this should not happen.") - # Calculate the bootstrap for leaf nodes. - if node.terminal: - # If terminal, there is no bootstrap value. - value = 0. - else: - # Otherwise, bootstrap from this node with our value function. - prior, value = evaluation(timestep.observation) + # Calculate the bootstrap for leaf nodes. + if node.terminal: + # If terminal, there is no bootstrap value. + value = 0.0 + else: + # Otherwise, bootstrap from this node with our value function. + prior, value = evaluation(timestep.observation) - # We also want to expand this node for next time. - node.expand(prior) + # We also want to expand this node for next time. + node.expand(prior) - # Load the saved model state. - model.load_checkpoint() + # Load the saved model state. + model.load_checkpoint() - # Monte Carlo back-up with bootstrap from value function. - ret = value - while trajectory: - # Pop off the latest node in the trajectory. - node = trajectory.pop() + # Monte Carlo back-up with bootstrap from value function. + ret = value + while trajectory: + # Pop off the latest node in the trajectory. + node = trajectory.pop() - # Accumulate the discounted return - ret *= discount - ret += node.reward + # Accumulate the discounted return + ret *= discount + ret += node.reward - # Update the node. - node.total_value += ret - node.visit_count += 1 + # Update the node. + node.total_value += ret + node.visit_count += 1 - return root + return root def bfs(node: Node) -> types.Action: - """Breadth-first search policy.""" - visit_counts = np.array([c.visit_count for c in node.children.values()]) - return argmax(-visit_counts) + """Breadth-first search policy.""" + visit_counts = np.array([c.visit_count for c in node.children.values()]) + return argmax(-visit_counts) -def puct(node: Node, ucb_scaling: float = 1.) 
-> types.Action: - """PUCT search policy, i.e. UCT with 'prior' policy.""" - # Action values Q(s,a). - value_scores = np.array([child.value for child in node.children.values()]) - check_numerics(value_scores) +def puct(node: Node, ucb_scaling: float = 1.0) -> types.Action: + """PUCT search policy, i.e. UCT with 'prior' policy.""" + # Action values Q(s,a). + value_scores = np.array([child.value for child in node.children.values()]) + check_numerics(value_scores) - # Policy prior P(s,a). - priors = np.array([child.prior for child in node.children.values()]) - check_numerics(priors) + # Policy prior P(s,a). + priors = np.array([child.prior for child in node.children.values()]) + check_numerics(priors) - # Visit ratios. - visit_ratios = np.array([ - np.sqrt(node.visit_count) / (child.visit_count + 1) - for child in node.children.values() - ]) - check_numerics(visit_ratios) + # Visit ratios. + visit_ratios = np.array( + [ + np.sqrt(node.visit_count) / (child.visit_count + 1) + for child in node.children.values() + ] + ) + check_numerics(visit_ratios) - # Combine. - puct_scores = value_scores + ucb_scaling * priors * visit_ratios - return argmax(puct_scores) + # Combine. + puct_scores = value_scores + ucb_scaling * priors * visit_ratios + return argmax(puct_scores) -def visit_count_policy(root: Node, temperature: float = 1.) -> types.Probs: - """Probability weighted by visit^{1/temp} of children nodes.""" - visits = root.children_visits - if np.sum(visits) == 0: # uniform policy for zero visits - visits += 1 - rescaled_visits = visits**(1 / temperature) - probs = rescaled_visits / np.sum(rescaled_visits) - check_numerics(probs) +def visit_count_policy(root: Node, temperature: float = 1.0) -> types.Probs: + """Probability weighted by visit^{1/temp} of children nodes.""" + visits = root.children_visits + if np.sum(visits) == 0: # uniform policy for zero visits + visits += 1 + rescaled_visits = visits ** (1 / temperature) + probs = rescaled_visits / np.sum(rescaled_visits) + check_numerics(probs) - return probs + return probs def argmax(values: np.ndarray) -> types.Action: - """Argmax with random tie-breaking.""" - check_numerics(values) - max_value = np.max(values) - return np.int32(np.random.choice(np.flatnonzero(values == max_value))) + """Argmax with random tie-breaking.""" + check_numerics(values) + max_value = np.max(values) + return np.int32(np.random.choice(np.flatnonzero(values == max_value))) def check_numerics(values: np.ndarray): - """Raises a ValueError if any of the inputs are NaN or Inf.""" - if not np.isfinite(values).all(): - raise ValueError('check_numerics failed. Inputs: {}. '.format(values)) + """Raises a ValueError if any of the inputs are NaN or Inf.""" + if not np.isfinite(values).all(): + raise ValueError("check_numerics failed. Inputs: {}. 
".format(values)) diff --git a/acme/agents/tf/mcts/search_test.py b/acme/agents/tf/mcts/search_test.py index c5b4190c87..9f29d1d087 100644 --- a/acme/agents/tf/mcts/search_test.py +++ b/acme/agents/tf/mcts/search_test.py @@ -16,50 +16,48 @@ from typing import Text -from acme.agents.tf.mcts import search -from acme.agents.tf.mcts.models import simulator -from bsuite.environments import catch import numpy as np +from absl.testing import absltest, parameterized +from bsuite.environments import catch -from absl.testing import absltest -from absl.testing import parameterized +from acme.agents.tf.mcts import search +from acme.agents.tf.mcts.models import simulator class TestSearch(parameterized.TestCase): - - @parameterized.parameters([ - 'puct', - 'bfs', - ]) - def test_catch(self, policy_type: Text): - env = catch.Catch(rows=2, seed=1) - num_actions = env.action_spec().num_values - model = simulator.Simulator(env) - eval_fn = lambda _: (np.ones(num_actions) / num_actions, 0.) - - timestep = env.reset() - model.reset() - - search_policy = search.bfs if policy_type == 'bfs' else search.puct - - root = search.mcts( - observation=timestep.observation, - model=model, - search_policy=search_policy, - evaluation=eval_fn, - num_simulations=100, - num_actions=num_actions) - - values = np.array([c.value for c in root.children.values()]) - best_action = search.argmax(values) - - if env._paddle_x > env._ball_x: - self.assertEqual(best_action, 0) - if env._paddle_x == env._ball_x: - self.assertEqual(best_action, 1) - if env._paddle_x < env._ball_x: - self.assertEqual(best_action, 2) - - -if __name__ == '__main__': - absltest.main() + @parameterized.parameters( + ["puct", "bfs",] + ) + def test_catch(self, policy_type: Text): + env = catch.Catch(rows=2, seed=1) + num_actions = env.action_spec().num_values + model = simulator.Simulator(env) + eval_fn = lambda _: (np.ones(num_actions) / num_actions, 0.0) + + timestep = env.reset() + model.reset() + + search_policy = search.bfs if policy_type == "bfs" else search.puct + + root = search.mcts( + observation=timestep.observation, + model=model, + search_policy=search_policy, + evaluation=eval_fn, + num_simulations=100, + num_actions=num_actions, + ) + + values = np.array([c.value for c in root.children.values()]) + best_action = search.argmax(values) + + if env._paddle_x > env._ball_x: + self.assertEqual(best_action, 0) + if env._paddle_x == env._ball_x: + self.assertEqual(best_action, 1) + if env._paddle_x < env._ball_x: + self.assertEqual(best_action, 2) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/tf/mcts/types.py b/acme/agents/tf/mcts/types.py index 93d8d8cd11..6e654e483c 100644 --- a/acme/agents/tf/mcts/types.py +++ b/acme/agents/tf/mcts/types.py @@ -15,6 +15,7 @@ """Type aliases and assumptions that are specific to the MCTS agent.""" from typing import Callable, Tuple, Union + import numpy as np # pylint: disable=invalid-name diff --git a/acme/agents/tf/mog_mpo/__init__.py b/acme/agents/tf/mog_mpo/__init__.py index bd2906f02a..de565043da 100644 --- a/acme/agents/tf/mog_mpo/__init__.py +++ b/acme/agents/tf/mog_mpo/__init__.py @@ -15,6 +15,5 @@ """Implementations of a (MoG) distributional MPO agent.""" from acme.agents.tf.mog_mpo.agent_distributed import DistributedMoGMPO -from acme.agents.tf.mog_mpo.learning import MoGMPOLearner -from acme.agents.tf.mog_mpo.learning import PolicyEvaluationConfig +from acme.agents.tf.mog_mpo.learning import MoGMPOLearner, PolicyEvaluationConfig from acme.agents.tf.mog_mpo.networks import 
make_default_networks diff --git a/acme/agents/tf/mog_mpo/agent_distributed.py b/acme/agents/tf/mog_mpo/agent_distributed.py index de9a4e606c..81c62f326a 100644 --- a/acme/agents/tf/mog_mpo/agent_distributed.py +++ b/acme/agents/tf/mog_mpo/agent_distributed.py @@ -16,277 +16,286 @@ from typing import Callable, Dict, Optional +import dm_env +import launchpad as lp +import reverb +import sonnet as snt + import acme -from acme import datasets -from acme import specs +from acme import datasets, specs from acme.adders import reverb as adders from acme.agents.tf import actors from acme.agents.tf.mog_mpo import learning from acme.tf import networks from acme.tf import savers as tf2_savers from acme.tf import variable_utils as tf2_variable_utils -from acme.utils import counting -from acme.utils import loggers -from acme.utils import lp_utils -import dm_env -import launchpad as lp -import reverb -import sonnet as snt +from acme.utils import counting, loggers, lp_utils class DistributedMoGMPO: - """Program definition for distributional (MoG) MPO.""" - - def __init__( - self, - environment_factory: Callable[[bool], dm_env.Environment], - network_factory: Callable[[specs.EnvironmentSpec], Dict[str, snt.Module]], - num_actors: int = 1, - num_caches: int = 0, - environment_spec: Optional[specs.EnvironmentSpec] = None, - batch_size: int = 256, - prefetch_size: int = 4, - min_replay_size: int = 1_000, - max_replay_size: int = 1_000_000, - samples_per_insert: Optional[float] = 32.0, - n_step: int = 5, - num_samples: int = 20, - policy_evaluation_config: Optional[ - learning.PolicyEvaluationConfig] = None, - additional_discount: float = 0.99, - target_policy_update_period: int = 100, - target_critic_update_period: int = 100, - policy_loss_factory: Optional[Callable[[], snt.Module]] = None, - max_actor_steps: Optional[int] = None, - log_every: float = 10.0, - ): - - if environment_spec is None: - environment_spec = specs.make_environment_spec(environment_factory(False)) - - self._environment_factory = environment_factory - self._network_factory = network_factory - self._policy_loss_factory = policy_loss_factory - self._environment_spec = environment_spec - self._num_actors = num_actors - self._num_caches = num_caches - self._batch_size = batch_size - self._prefetch_size = prefetch_size - self._min_replay_size = min_replay_size - self._max_replay_size = max_replay_size - self._samples_per_insert = samples_per_insert - self._n_step = n_step - self._additional_discount = additional_discount - self._num_samples = num_samples - self._policy_evaluation_config = policy_evaluation_config - self._target_policy_update_period = target_policy_update_period - self._target_critic_update_period = target_critic_update_period - self._max_actor_steps = max_actor_steps - self._log_every = log_every - - def replay(self): - """The replay storage.""" - if self._samples_per_insert is not None: - # Create enough of an error buffer to give a 10% tolerance in rate. 
- samples_per_insert_tolerance = 0.1 * self._samples_per_insert - error_buffer = self._min_replay_size * samples_per_insert_tolerance - - limiter = reverb.rate_limiters.SampleToInsertRatio( - min_size_to_sample=self._min_replay_size, - samples_per_insert=self._samples_per_insert, - error_buffer=error_buffer) - else: - limiter = reverb.rate_limiters.MinSize(self._min_replay_size) - replay_table = reverb.Table( - name=adders.DEFAULT_PRIORITY_TABLE, - sampler=reverb.selectors.Uniform(), - remover=reverb.selectors.Fifo(), - max_size=self._max_replay_size, - rate_limiter=limiter, - signature=adders.NStepTransitionAdder.signature( - self._environment_spec)) - return [replay_table] - - def counter(self): - return tf2_savers.CheckpointingRunner( - counting.Counter(), time_delta_minutes=1, subdirectory='counter') - - def coordinator(self, counter: counting.Counter, max_actor_steps: int): - return lp_utils.StepsLimiter(counter, max_actor_steps) - - def learner( - self, - replay: reverb.Client, - counter: counting.Counter, - ): - """The Learning part of the agent.""" - - # Create online and target networks. - online_networks = self._network_factory(self._environment_spec) - target_networks = self._network_factory(self._environment_spec) - - # The dataset object to learn from. - dataset = datasets.make_reverb_dataset( - server_address=replay.server_address, - batch_size=self._batch_size, - prefetch_size=self._prefetch_size, - ) - - counter = counting.Counter(counter, 'learner') - logger = loggers.make_default_logger('learner', time_delta=self._log_every) - - # Create policy loss module if a factory is passed. - if self._policy_loss_factory: - policy_loss_module = self._policy_loss_factory() - else: - policy_loss_module = None - - # Return the learning agent. - return learning.MoGMPOLearner( - policy_network=online_networks['policy'], - critic_network=online_networks['critic'], - observation_network=online_networks['observation'], - target_policy_network=target_networks['policy'], - target_critic_network=target_networks['critic'], - target_observation_network=target_networks['observation'], - discount=self._additional_discount, - num_samples=self._num_samples, - policy_evaluation_config=self._policy_evaluation_config, - target_policy_update_period=self._target_policy_update_period, - target_critic_update_period=self._target_critic_update_period, - policy_loss_module=policy_loss_module, - dataset=dataset, - counter=counter, - logger=logger) - - def actor( - self, - replay: reverb.Client, - variable_source: acme.VariableSource, - counter: counting.Counter, - actor_id: int, - ) -> acme.EnvironmentLoop: - """The actor process.""" - - # Create environment and target networks to act with. - environment = self._environment_factory(False) - agent_networks = self._network_factory(self._environment_spec) - - # Create a stochastic behavior policy. - behavior_network = snt.Sequential([ - agent_networks['observation'], - agent_networks['policy'], - networks.StochasticSamplingHead(), - ]) - - # Ensure network variables are created. - policy_variables = {'policy': behavior_network.variables} - - # Create the variable client responsible for keeping the actor up-to-date. - variable_client = tf2_variable_utils.VariableClient( - variable_source, policy_variables, update_period=1000) - - # Make sure not to use a random policy after checkpoint restoration by - # assigning variables before running the environment loop. - variable_client.update_and_wait() - - # Component to add things into replay. 
- adder = adders.NStepTransitionAdder( - client=replay, n_step=self._n_step, discount=self._additional_discount) - - # Create the agent. - actor = actors.FeedForwardActor( - policy_network=behavior_network, - adder=adder, - variable_client=variable_client) - - # Create logger and counter; actors will not spam bigtable. - save_data = actor_id == 0 - counter = counting.Counter(counter, 'actor') - logger = loggers.make_default_logger( - 'actor', save_data=save_data, time_delta=self._log_every) - - # Create the run loop and return it. - return acme.EnvironmentLoop(environment, actor, counter, logger) - - def evaluator( - self, - variable_source: acme.VariableSource, - counter: counting.Counter, - ): - """The evaluation process.""" - - # Create environment and target networks to act with. - environment = self._environment_factory(True) - agent_networks = self._network_factory(self._environment_spec) - - # Create a stochastic behavior policy. - evaluator_network = snt.Sequential([ - agent_networks['observation'], - agent_networks['policy'], - networks.StochasticMeanHead(), - ]) - - # Create the variable client responsible for keeping the actor up-to-date. - variable_client = tf2_variable_utils.VariableClient( - variable_source, - variables={'policy': evaluator_network.variables}, - update_period=1000) - - # Make sure not to evaluate a random actor by assigning variables before - # running the environment loop. - variable_client.update_and_wait() - - # Create the agent. - evaluator = actors.FeedForwardActor( - policy_network=evaluator_network, variable_client=variable_client) - - # Create logger and counter. - counter = counting.Counter(counter, 'evaluator') - logger = loggers.make_default_logger( - 'evaluator', time_delta=self._log_every) - - # Create the run loop and return it. - return acme.EnvironmentLoop(environment, evaluator, counter, logger) - - def build(self, name='dmpo'): - """Build the distributed agent topology.""" - program = lp.Program(name=name) - - with program.group('replay'): - replay = program.add_node(lp.ReverbNode(self.replay)) - - with program.group('counter'): - counter = program.add_node(lp.CourierNode(self.counter)) - - if self._max_actor_steps: - _ = program.add_node( - lp.CourierNode(self.coordinator, counter, self._max_actor_steps)) - - with program.group('learner'): - learner = program.add_node(lp.CourierNode(self.learner, replay, counter)) - - with program.group('evaluator'): - program.add_node(lp.CourierNode(self.evaluator, learner, counter)) - - if not self._num_caches: - # Use our learner as a single variable source. - sources = [learner] - else: - with program.group('cacher'): - # Create a set of learner caches. - sources = [] - for _ in range(self._num_caches): - cacher = program.add_node( - lp.CacherNode( - learner, refresh_interval_ms=2000, stale_after_ms=4000)) - sources.append(cacher) - - with program.group('actor'): - # Add actors which pull round-robin from our variable sources. 
- for actor_id in range(self._num_actors): - source = sources[actor_id % len(sources)] - program.add_node( - lp.CourierNode(self.actor, replay, source, counter, actor_id)) - - return program + """Program definition for distributional (MoG) MPO.""" + + def __init__( + self, + environment_factory: Callable[[bool], dm_env.Environment], + network_factory: Callable[[specs.EnvironmentSpec], Dict[str, snt.Module]], + num_actors: int = 1, + num_caches: int = 0, + environment_spec: Optional[specs.EnvironmentSpec] = None, + batch_size: int = 256, + prefetch_size: int = 4, + min_replay_size: int = 1_000, + max_replay_size: int = 1_000_000, + samples_per_insert: Optional[float] = 32.0, + n_step: int = 5, + num_samples: int = 20, + policy_evaluation_config: Optional[learning.PolicyEvaluationConfig] = None, + additional_discount: float = 0.99, + target_policy_update_period: int = 100, + target_critic_update_period: int = 100, + policy_loss_factory: Optional[Callable[[], snt.Module]] = None, + max_actor_steps: Optional[int] = None, + log_every: float = 10.0, + ): + + if environment_spec is None: + environment_spec = specs.make_environment_spec(environment_factory(False)) + + self._environment_factory = environment_factory + self._network_factory = network_factory + self._policy_loss_factory = policy_loss_factory + self._environment_spec = environment_spec + self._num_actors = num_actors + self._num_caches = num_caches + self._batch_size = batch_size + self._prefetch_size = prefetch_size + self._min_replay_size = min_replay_size + self._max_replay_size = max_replay_size + self._samples_per_insert = samples_per_insert + self._n_step = n_step + self._additional_discount = additional_discount + self._num_samples = num_samples + self._policy_evaluation_config = policy_evaluation_config + self._target_policy_update_period = target_policy_update_period + self._target_critic_update_period = target_critic_update_period + self._max_actor_steps = max_actor_steps + self._log_every = log_every + + def replay(self): + """The replay storage.""" + if self._samples_per_insert is not None: + # Create enough of an error buffer to give a 10% tolerance in rate. + samples_per_insert_tolerance = 0.1 * self._samples_per_insert + error_buffer = self._min_replay_size * samples_per_insert_tolerance + + limiter = reverb.rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._min_replay_size, + samples_per_insert=self._samples_per_insert, + error_buffer=error_buffer, + ) + else: + limiter = reverb.rate_limiters.MinSize(self._min_replay_size) + replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._max_replay_size, + rate_limiter=limiter, + signature=adders.NStepTransitionAdder.signature(self._environment_spec), + ) + return [replay_table] + + def counter(self): + return tf2_savers.CheckpointingRunner( + counting.Counter(), time_delta_minutes=1, subdirectory="counter" + ) + + def coordinator(self, counter: counting.Counter, max_actor_steps: int): + return lp_utils.StepsLimiter(counter, max_actor_steps) + + def learner( + self, replay: reverb.Client, counter: counting.Counter, + ): + """The Learning part of the agent.""" + + # Create online and target networks. + online_networks = self._network_factory(self._environment_spec) + target_networks = self._network_factory(self._environment_spec) + + # The dataset object to learn from. 
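+        # make_reverb_dataset streams transitions from the replay server and
+        # yields batched, prefetched samples for the learner's update step.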
+ dataset = datasets.make_reverb_dataset( + server_address=replay.server_address, + batch_size=self._batch_size, + prefetch_size=self._prefetch_size, + ) + + counter = counting.Counter(counter, "learner") + logger = loggers.make_default_logger("learner", time_delta=self._log_every) + + # Create policy loss module if a factory is passed. + if self._policy_loss_factory: + policy_loss_module = self._policy_loss_factory() + else: + policy_loss_module = None + + # Return the learning agent. + return learning.MoGMPOLearner( + policy_network=online_networks["policy"], + critic_network=online_networks["critic"], + observation_network=online_networks["observation"], + target_policy_network=target_networks["policy"], + target_critic_network=target_networks["critic"], + target_observation_network=target_networks["observation"], + discount=self._additional_discount, + num_samples=self._num_samples, + policy_evaluation_config=self._policy_evaluation_config, + target_policy_update_period=self._target_policy_update_period, + target_critic_update_period=self._target_critic_update_period, + policy_loss_module=policy_loss_module, + dataset=dataset, + counter=counter, + logger=logger, + ) + + def actor( + self, + replay: reverb.Client, + variable_source: acme.VariableSource, + counter: counting.Counter, + actor_id: int, + ) -> acme.EnvironmentLoop: + """The actor process.""" + + # Create environment and target networks to act with. + environment = self._environment_factory(False) + agent_networks = self._network_factory(self._environment_spec) + + # Create a stochastic behavior policy. + behavior_network = snt.Sequential( + [ + agent_networks["observation"], + agent_networks["policy"], + networks.StochasticSamplingHead(), + ] + ) + + # Ensure network variables are created. + policy_variables = {"policy": behavior_network.variables} + + # Create the variable client responsible for keeping the actor up-to-date. + variable_client = tf2_variable_utils.VariableClient( + variable_source, policy_variables, update_period=1000 + ) + + # Make sure not to use a random policy after checkpoint restoration by + # assigning variables before running the environment loop. + variable_client.update_and_wait() + + # Component to add things into replay. + adder = adders.NStepTransitionAdder( + client=replay, n_step=self._n_step, discount=self._additional_discount + ) + + # Create the agent. + actor = actors.FeedForwardActor( + policy_network=behavior_network, + adder=adder, + variable_client=variable_client, + ) + + # Create logger and counter; actors will not spam bigtable. + save_data = actor_id == 0 + counter = counting.Counter(counter, "actor") + logger = loggers.make_default_logger( + "actor", save_data=save_data, time_delta=self._log_every + ) + + # Create the run loop and return it. + return acme.EnvironmentLoop(environment, actor, counter, logger) + + def evaluator( + self, variable_source: acme.VariableSource, counter: counting.Counter, + ): + """The evaluation process.""" + + # Create environment and target networks to act with. + environment = self._environment_factory(True) + agent_networks = self._network_factory(self._environment_spec) + + # Create a stochastic behavior policy. + evaluator_network = snt.Sequential( + [ + agent_networks["observation"], + agent_networks["policy"], + networks.StochasticMeanHead(), + ] + ) + + # Create the variable client responsible for keeping the actor up-to-date. 
+ variable_client = tf2_variable_utils.VariableClient( + variable_source, + variables={"policy": evaluator_network.variables}, + update_period=1000, + ) + + # Make sure not to evaluate a random actor by assigning variables before + # running the environment loop. + variable_client.update_and_wait() + + # Create the agent. + evaluator = actors.FeedForwardActor( + policy_network=evaluator_network, variable_client=variable_client + ) + + # Create logger and counter. + counter = counting.Counter(counter, "evaluator") + logger = loggers.make_default_logger("evaluator", time_delta=self._log_every) + + # Create the run loop and return it. + return acme.EnvironmentLoop(environment, evaluator, counter, logger) + + def build(self, name="dmpo"): + """Build the distributed agent topology.""" + program = lp.Program(name=name) + + with program.group("replay"): + replay = program.add_node(lp.ReverbNode(self.replay)) + + with program.group("counter"): + counter = program.add_node(lp.CourierNode(self.counter)) + + if self._max_actor_steps: + _ = program.add_node( + lp.CourierNode(self.coordinator, counter, self._max_actor_steps) + ) + + with program.group("learner"): + learner = program.add_node(lp.CourierNode(self.learner, replay, counter)) + + with program.group("evaluator"): + program.add_node(lp.CourierNode(self.evaluator, learner, counter)) + + if not self._num_caches: + # Use our learner as a single variable source. + sources = [learner] + else: + with program.group("cacher"): + # Create a set of learner caches. + sources = [] + for _ in range(self._num_caches): + cacher = program.add_node( + lp.CacherNode( + learner, refresh_interval_ms=2000, stale_after_ms=4000 + ) + ) + sources.append(cacher) + + with program.group("actor"): + # Add actors which pull round-robin from our variable sources. 
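+            # Actor i pulls its weights from sources[i % len(sources)], so load
+            # is spread evenly across the learner and any cacher nodes.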
+ for actor_id in range(self._num_actors): + source = sources[actor_id % len(sources)] + program.add_node( + lp.CourierNode(self.actor, replay, source, counter, actor_id) + ) + + return program diff --git a/acme/agents/tf/mog_mpo/learning.py b/acme/agents/tf/mog_mpo/learning.py index 60b358afd2..36e8e397da 100644 --- a/acme/agents/tf/mog_mpo/learning.py +++ b/acme/agents/tf/mog_mpo/learning.py @@ -18,300 +18,315 @@ import time from typing import List, Optional -import acme -from acme import types -from acme.tf import losses -from acme.tf import savers as tf2_savers -from acme.tf import utils as tf2_utils -from acme.utils import counting -from acme.utils import loggers import numpy as np import reverb import sonnet as snt import tensorflow as tf import tensorflow_probability as tfp +import acme +from acme import types +from acme.tf import losses +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.utils import counting, loggers + tfd = tfp.distributions @dataclasses.dataclass class PolicyEvaluationConfig: - evaluate_stochastic_policy: bool = True - num_value_samples: int = 128 + evaluate_stochastic_policy: bool = True + num_value_samples: int = 128 class MoGMPOLearner(acme.Learner): - """Distributional (MoG) MPO learner.""" - - def __init__( - self, - policy_network: snt.Module, - critic_network: snt.Module, - target_policy_network: snt.Module, - target_critic_network: snt.Module, - discount: float, - num_samples: int, - target_policy_update_period: int, - target_critic_update_period: int, - dataset: tf.data.Dataset, - observation_network: snt.Module, - target_observation_network: snt.Module, - policy_evaluation_config: Optional[PolicyEvaluationConfig] = None, - policy_loss_module: Optional[snt.Module] = None, - policy_optimizer: Optional[snt.Optimizer] = None, - critic_optimizer: Optional[snt.Optimizer] = None, - dual_optimizer: Optional[snt.Optimizer] = None, - clipping: bool = True, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - checkpoint: bool = True, - ): - - # Store online and target networks. - self._policy_network = policy_network - self._critic_network = critic_network - self._observation_network = observation_network - self._target_policy_network = target_policy_network - self._target_critic_network = target_critic_network - self._target_observation_network = target_observation_network - - # General learner book-keeping and loggers. - self._counter = counter or counting.Counter() - self._logger = logger or loggers.make_default_logger('learner') - - # Other learner parameters. - self._discount = discount - self._num_samples = num_samples - if policy_evaluation_config is None: - policy_evaluation_config = PolicyEvaluationConfig() - self._policy_evaluation_config = policy_evaluation_config - self._clipping = clipping - - # Necessary to track when to update target networks. - self._num_steps = tf.Variable(0, dtype=tf.int32) - self._target_policy_update_period = target_policy_update_period - self._target_critic_update_period = target_critic_update_period - - # Batch dataset and create iterator. - self._iterator = iter(dataset) # pytype: disable=wrong-arg-types - - self._policy_loss_module = policy_loss_module or losses.MPO( - epsilon=1e-1, - epsilon_mean=3e-3, - epsilon_stddev=1e-6, - epsilon_penalty=1e-3, - init_log_temperature=10., - init_log_alpha_mean=10., - init_log_alpha_stddev=1000.) - - # Create the optimizers. 
- self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) - self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) - self._dual_optimizer = dual_optimizer or snt.optimizers.Adam(1e-2) - - # Expose the variables. - policy_network_to_expose = snt.Sequential( - [self._target_observation_network, self._target_policy_network]) - self._variables = { - 'critic': self._target_critic_network.variables, - 'policy': policy_network_to_expose.variables, - } - - # Create a checkpointer and snapshotter object. - self._checkpointer = None - self._snapshotter = None - - if checkpoint: - self._checkpointer = tf2_savers.Checkpointer( - subdirectory='mog_mpo_learner', - objects_to_save={ - 'counter': self._counter, - 'policy': self._policy_network, - 'critic': self._critic_network, - 'observation': self._observation_network, - 'target_policy': self._target_policy_network, - 'target_critic': self._target_critic_network, - 'target_observation': self._target_observation_network, - 'policy_optimizer': self._policy_optimizer, - 'critic_optimizer': self._critic_optimizer, - 'dual_optimizer': self._dual_optimizer, - 'policy_loss_module': self._policy_loss_module, - 'num_steps': self._num_steps, - }) - - self._snapshotter = tf2_savers.Snapshotter( - objects_to_save={ - 'policy': - snt.Sequential([ - self._target_observation_network, - self._target_policy_network - ]), - }) - - # Do not record timestamps until after the first learning step is done. - # This is to avoid including the time it takes for actors to come online and - # fill the replay buffer. - self._timestamp = None - - @tf.function - def _step(self, inputs: reverb.ReplaySample) -> types.NestedTensor: - - # Get data from replay (dropping extras if any). Note there is no - # extra data here because we do not insert any into Reverb. - o_tm1, a_tm1, r_t, d_t, o_t = (inputs.data.observation, inputs.data.action, - inputs.data.reward, inputs.data.discount, - inputs.data.next_observation) - - # Cast the additional discount to match the environment discount dtype. - discount = tf.cast(self._discount, dtype=d_t.dtype) - - with tf.GradientTape(persistent=True) as tape: - # Maybe transform the observation before feeding into policy and critic. - # Transforming the observations this way at the start of the learning - # step effectively means that the policy and critic share observation - # network weights. - o_tm1 = self._observation_network(o_tm1) - # This stop_gradient prevents gradients to propagate into the target - # observation network. In addition, since the online policy network is - # evaluated at o_t, this also means the policy loss does not influence - # the observation network training. - o_t = tf.stop_gradient(self._target_observation_network(o_t)) - - # Get online and target action distributions from policy networks. - online_action_distribution = self._policy_network(o_t) - target_action_distribution = self._target_policy_network(o_t) - - # Sample actions to evaluate policy; of size [N, B, ...]. - sampled_actions = target_action_distribution.sample(self._num_samples) - - # Tile embedded observations to feed into the target critic network. - # Note: this is more efficient than tiling before the embedding layer. - tiled_o_t = tf2_utils.tile_tensor(o_t, self._num_samples) # [N, B, ...] - - # Compute target-estimated distributional value of sampled actions at o_t. - sampled_q_t_distributions = self._target_critic_network( - # Merge batch dimensions; to shape [N*B, ...]. 
- snt.merge_leading_dims(tiled_o_t, num_dims=2), - snt.merge_leading_dims(sampled_actions, num_dims=2)) - - # Compute online critic value distribution of a_tm1 in state o_tm1. - q_tm1_distribution = self._critic_network(o_tm1, a_tm1) # [B, ...] - - # Get the return distributions used in the policy evaluation bootstrap. - if self._policy_evaluation_config.evaluate_stochastic_policy: - z_distributions = sampled_q_t_distributions - num_joint_samples = self._num_samples - else: - z_distributions = self._target_critic_network( - o_t, target_action_distribution.mean()) - num_joint_samples = 1 - - num_value_samples = self._policy_evaluation_config.num_value_samples - num_joint_samples *= num_value_samples - z_samples = z_distributions.sample(num_value_samples) - z_samples = tf.reshape(z_samples, (num_joint_samples, -1, 1)) - - # Expand dims of reward and discount tensors. - reward = r_t[..., tf.newaxis] # [B, 1] - full_discount = discount * d_t[..., tf.newaxis] - target_q = reward + full_discount * z_samples # [N, B, 1] - target_q = tf.stop_gradient(target_q) - - # Compute sample-based cross-entropy. - log_probs_q = q_tm1_distribution.log_prob(target_q) # [N, B, 1] - critic_loss = -tf.reduce_mean(log_probs_q, axis=0) # [B, 1] - critic_loss = tf.reduce_mean(critic_loss) - - # Compute Q-values of sampled actions and reshape to [N, B]. - sampled_q_values = sampled_q_t_distributions.mean() - sampled_q_values = tf.reshape(sampled_q_values, (self._num_samples, -1)) - - # Compute MPO policy loss. - policy_loss, policy_stats = self._policy_loss_module( - online_action_distribution=online_action_distribution, - target_action_distribution=target_action_distribution, - actions=sampled_actions, - q_values=sampled_q_values) - policy_loss = tf.reduce_mean(policy_loss) - - # For clarity, explicitly define which variables are trained by which loss. - critic_trainable_variables = ( - # In this agent, the critic loss trains the observation network. - self._observation_network.trainable_variables + - self._critic_network.trainable_variables) - policy_trainable_variables = self._policy_network.trainable_variables - # The following are the MPO dual variables, stored in the loss module. - dual_trainable_variables = self._policy_loss_module.trainable_variables - - # Compute gradients. - critic_gradients = tape.gradient(critic_loss, critic_trainable_variables) - policy_gradients, dual_gradients = tape.gradient( - policy_loss, (policy_trainable_variables, dual_trainable_variables)) - - # Delete the tape manually because of the persistent=True flag. - del tape - - # Maybe clip gradients. - if self._clipping: - policy_gradients = tuple(tf.clip_by_global_norm(policy_gradients, 40.)[0]) - critic_gradients = tuple(tf.clip_by_global_norm(critic_gradients, 40.)[0]) - - # Apply gradients. - self._critic_optimizer.apply(critic_gradients, critic_trainable_variables) - self._policy_optimizer.apply(policy_gradients, policy_trainable_variables) - self._dual_optimizer.apply(dual_gradients, dual_trainable_variables) - - # Losses to track. - fetches = { - 'critic_loss': critic_loss, - 'policy_loss': policy_loss, - } - # Log MPO stats. - fetches.update(policy_stats) - - return fetches - - def step(self): - self._maybe_update_target_networks() - self._num_steps.assign_add(1) - - # Run the learning step. - fetches = self._step(next(self._iterator)) - - # Compute elapsed time. - timestamp = time.time() - elapsed_time = timestamp - self._timestamp if self._timestamp else 0 - self._timestamp = timestamp - - # Update our counts and record it. 
- counts = self._counter.increment(steps=1, walltime=elapsed_time) - fetches.update(counts) - - # Checkpoint and attempt to write the logs. - if self._checkpointer is not None: - self._checkpointer.save() - if self._snapshotter is not None: - self._snapshotter.save() - self._logger.write(fetches) - - def get_variables(self, names: List[str]) -> List[List[np.ndarray]]: - return [tf2_utils.to_numpy(self._variables[name]) for name in names] - - def _maybe_update_target_networks(self): - # Update target network. - online_policy_variables = self._policy_network.variables - target_policy_variables = self._target_policy_network.variables - online_critic_variables = (*self._observation_network.variables, - *self._critic_network.variables) - target_critic_variables = (*self._target_observation_network.variables, - *self._target_critic_network.variables) - - # Make online policy -> target policy network update ops. - if tf.math.mod(self._num_steps, self._target_policy_update_period) == 0: - for src, dest in zip(online_policy_variables, target_policy_variables): - dest.assign(src) - - # Make online critic -> target critic network update ops. - if tf.math.mod(self._num_steps, self._target_critic_update_period) == 0: - for src, dest in zip(online_critic_variables, target_critic_variables): - dest.assign(src) + """Distributional (MoG) MPO learner.""" + + def __init__( + self, + policy_network: snt.Module, + critic_network: snt.Module, + target_policy_network: snt.Module, + target_critic_network: snt.Module, + discount: float, + num_samples: int, + target_policy_update_period: int, + target_critic_update_period: int, + dataset: tf.data.Dataset, + observation_network: snt.Module, + target_observation_network: snt.Module, + policy_evaluation_config: Optional[PolicyEvaluationConfig] = None, + policy_loss_module: Optional[snt.Module] = None, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + dual_optimizer: Optional[snt.Optimizer] = None, + clipping: bool = True, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = True, + ): + + # Store online and target networks. + self._policy_network = policy_network + self._critic_network = critic_network + self._observation_network = observation_network + self._target_policy_network = target_policy_network + self._target_critic_network = target_critic_network + self._target_observation_network = target_observation_network + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger("learner") + + # Other learner parameters. + self._discount = discount + self._num_samples = num_samples + if policy_evaluation_config is None: + policy_evaluation_config = PolicyEvaluationConfig() + self._policy_evaluation_config = policy_evaluation_config + self._clipping = clipping + + # Necessary to track when to update target networks. + self._num_steps = tf.Variable(0, dtype=tf.int32) + self._target_policy_update_period = target_policy_update_period + self._target_critic_update_period = target_critic_update_period + + # Batch dataset and create iterator. + self._iterator = iter(dataset) # pytype: disable=wrong-arg-types + + self._policy_loss_module = policy_loss_module or losses.MPO( + epsilon=1e-1, + epsilon_mean=3e-3, + epsilon_stddev=1e-6, + epsilon_penalty=1e-3, + init_log_temperature=10.0, + init_log_alpha_mean=10.0, + init_log_alpha_stddev=1000.0, + ) + + # Create the optimizers. 
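+        # Defaults: Adam(1e-4) for the policy and critic, Adam(1e-2) for the
+        # MPO dual (temperature/alpha) variables held by the policy loss module.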
+ self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) + self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) + self._dual_optimizer = dual_optimizer or snt.optimizers.Adam(1e-2) + + # Expose the variables. + policy_network_to_expose = snt.Sequential( + [self._target_observation_network, self._target_policy_network] + ) + self._variables = { + "critic": self._target_critic_network.variables, + "policy": policy_network_to_expose.variables, + } + + # Create a checkpointer and snapshotter object. + self._checkpointer = None + self._snapshotter = None + + if checkpoint: + self._checkpointer = tf2_savers.Checkpointer( + subdirectory="mog_mpo_learner", + objects_to_save={ + "counter": self._counter, + "policy": self._policy_network, + "critic": self._critic_network, + "observation": self._observation_network, + "target_policy": self._target_policy_network, + "target_critic": self._target_critic_network, + "target_observation": self._target_observation_network, + "policy_optimizer": self._policy_optimizer, + "critic_optimizer": self._critic_optimizer, + "dual_optimizer": self._dual_optimizer, + "policy_loss_module": self._policy_loss_module, + "num_steps": self._num_steps, + }, + ) + + self._snapshotter = tf2_savers.Snapshotter( + objects_to_save={ + "policy": snt.Sequential( + [self._target_observation_network, self._target_policy_network] + ), + } + ) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + @tf.function + def _step(self, inputs: reverb.ReplaySample) -> types.NestedTensor: + + # Get data from replay (dropping extras if any). Note there is no + # extra data here because we do not insert any into Reverb. + o_tm1, a_tm1, r_t, d_t, o_t = ( + inputs.data.observation, + inputs.data.action, + inputs.data.reward, + inputs.data.discount, + inputs.data.next_observation, + ) + + # Cast the additional discount to match the environment discount dtype. + discount = tf.cast(self._discount, dtype=d_t.dtype) + + with tf.GradientTape(persistent=True) as tape: + # Maybe transform the observation before feeding into policy and critic. + # Transforming the observations this way at the start of the learning + # step effectively means that the policy and critic share observation + # network weights. + o_tm1 = self._observation_network(o_tm1) + # This stop_gradient prevents gradients to propagate into the target + # observation network. In addition, since the online policy network is + # evaluated at o_t, this also means the policy loss does not influence + # the observation network training. + o_t = tf.stop_gradient(self._target_observation_network(o_t)) + + # Get online and target action distributions from policy networks. + online_action_distribution = self._policy_network(o_t) + target_action_distribution = self._target_policy_network(o_t) + + # Sample actions to evaluate policy; of size [N, B, ...]. + sampled_actions = target_action_distribution.sample(self._num_samples) + + # Tile embedded observations to feed into the target critic network. + # Note: this is more efficient than tiling before the embedding layer. + tiled_o_t = tf2_utils.tile_tensor(o_t, self._num_samples) # [N, B, ...] + + # Compute target-estimated distributional value of sampled actions at o_t. + sampled_q_t_distributions = self._target_critic_network( + # Merge batch dimensions; to shape [N*B, ...]. 
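+                # The critic takes a single leading batch dimension, so the N
+                # sampled actions are folded into the batch before the call.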
+ snt.merge_leading_dims(tiled_o_t, num_dims=2), + snt.merge_leading_dims(sampled_actions, num_dims=2), + ) + + # Compute online critic value distribution of a_tm1 in state o_tm1. + q_tm1_distribution = self._critic_network(o_tm1, a_tm1) # [B, ...] + + # Get the return distributions used in the policy evaluation bootstrap. + if self._policy_evaluation_config.evaluate_stochastic_policy: + z_distributions = sampled_q_t_distributions + num_joint_samples = self._num_samples + else: + z_distributions = self._target_critic_network( + o_t, target_action_distribution.mean() + ) + num_joint_samples = 1 + + num_value_samples = self._policy_evaluation_config.num_value_samples + num_joint_samples *= num_value_samples + z_samples = z_distributions.sample(num_value_samples) + z_samples = tf.reshape(z_samples, (num_joint_samples, -1, 1)) + + # Expand dims of reward and discount tensors. + reward = r_t[..., tf.newaxis] # [B, 1] + full_discount = discount * d_t[..., tf.newaxis] + target_q = reward + full_discount * z_samples # [N, B, 1] + target_q = tf.stop_gradient(target_q) + + # Compute sample-based cross-entropy. + log_probs_q = q_tm1_distribution.log_prob(target_q) # [N, B, 1] + critic_loss = -tf.reduce_mean(log_probs_q, axis=0) # [B, 1] + critic_loss = tf.reduce_mean(critic_loss) + + # Compute Q-values of sampled actions and reshape to [N, B]. + sampled_q_values = sampled_q_t_distributions.mean() + sampled_q_values = tf.reshape(sampled_q_values, (self._num_samples, -1)) + + # Compute MPO policy loss. + policy_loss, policy_stats = self._policy_loss_module( + online_action_distribution=online_action_distribution, + target_action_distribution=target_action_distribution, + actions=sampled_actions, + q_values=sampled_q_values, + ) + policy_loss = tf.reduce_mean(policy_loss) + + # For clarity, explicitly define which variables are trained by which loss. + critic_trainable_variables = ( + # In this agent, the critic loss trains the observation network. + self._observation_network.trainable_variables + + self._critic_network.trainable_variables + ) + policy_trainable_variables = self._policy_network.trainable_variables + # The following are the MPO dual variables, stored in the loss module. + dual_trainable_variables = self._policy_loss_module.trainable_variables + + # Compute gradients. + critic_gradients = tape.gradient(critic_loss, critic_trainable_variables) + policy_gradients, dual_gradients = tape.gradient( + policy_loss, (policy_trainable_variables, dual_trainable_variables) + ) + + # Delete the tape manually because of the persistent=True flag. + del tape + + # Maybe clip gradients. + if self._clipping: + policy_gradients = tuple(tf.clip_by_global_norm(policy_gradients, 40.0)[0]) + critic_gradients = tuple(tf.clip_by_global_norm(critic_gradients, 40.0)[0]) + + # Apply gradients. + self._critic_optimizer.apply(critic_gradients, critic_trainable_variables) + self._policy_optimizer.apply(policy_gradients, policy_trainable_variables) + self._dual_optimizer.apply(dual_gradients, dual_trainable_variables) + + # Losses to track. + fetches = { + "critic_loss": critic_loss, + "policy_loss": policy_loss, + } + # Log MPO stats. + fetches.update(policy_stats) + + return fetches + + def step(self): + self._maybe_update_target_networks() + self._num_steps.assign_add(1) + + # Run the learning step. + fetches = self._step(next(self._iterator)) + + # Compute elapsed time. 
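+        # On the first call there is no previous timestamp, so walltime is
+        # reported as zero rather than counting actor warm-up time.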
+ timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Update our counts and record it. + counts = self._counter.increment(steps=1, walltime=elapsed_time) + fetches.update(counts) + + # Checkpoint and attempt to write the logs. + if self._checkpointer is not None: + self._checkpointer.save() + if self._snapshotter is not None: + self._snapshotter.save() + self._logger.write(fetches) + + def get_variables(self, names: List[str]) -> List[List[np.ndarray]]: + return [tf2_utils.to_numpy(self._variables[name]) for name in names] + + def _maybe_update_target_networks(self): + # Update target network. + online_policy_variables = self._policy_network.variables + target_policy_variables = self._target_policy_network.variables + online_critic_variables = ( + *self._observation_network.variables, + *self._critic_network.variables, + ) + target_critic_variables = ( + *self._target_observation_network.variables, + *self._target_critic_network.variables, + ) + + # Make online policy -> target policy network update ops. + if tf.math.mod(self._num_steps, self._target_policy_update_period) == 0: + for src, dest in zip(online_policy_variables, target_policy_variables): + dest.assign(src) + + # Make online critic -> target critic network update ops. + if tf.math.mod(self._num_steps, self._target_critic_update_period) == 0: + for src, dest in zip(online_critic_variables, target_critic_variables): + dest.assign(src) diff --git a/acme/agents/tf/mog_mpo/networks.py b/acme/agents/tf/mog_mpo/networks.py index 1d9b18b78d..06011e11a8 100644 --- a/acme/agents/tf/mog_mpo/networks.py +++ b/acme/agents/tf/mog_mpo/networks.py @@ -16,13 +16,13 @@ from typing import Mapping, Sequence +import numpy as np +import sonnet as snt + from acme import specs from acme.tf import networks from acme.tf import utils as tf2_utils -import numpy as np -import sonnet as snt - def make_default_networks( environment_spec: specs.EnvironmentSpec, @@ -33,44 +33,48 @@ def make_default_networks( critic_init_scale: float = 1e-3, critic_num_components: int = 5, ) -> Mapping[str, snt.Module]: - """Creates networks used by the agent.""" + """Creates networks used by the agent.""" - # Unpack the environment spec to get appropriate shapes, dtypes, etc. - act_spec = environment_spec.actions - obs_spec = environment_spec.observations - num_dimensions = np.prod(act_spec.shape, dtype=int) + # Unpack the environment spec to get appropriate shapes, dtypes, etc. + act_spec = environment_spec.actions + obs_spec = environment_spec.observations + num_dimensions = np.prod(act_spec.shape, dtype=int) - # Create the observation network and make sure it's a Sonnet module. - observation_network = tf2_utils.batch_concat - observation_network = tf2_utils.to_sonnet_module(observation_network) + # Create the observation network and make sure it's a Sonnet module. + observation_network = tf2_utils.batch_concat + observation_network = tf2_utils.to_sonnet_module(observation_network) - # Create the policy network. - policy_network = snt.Sequential([ - networks.LayerNormMLP(policy_layer_sizes, activate_final=True), - networks.MultivariateNormalDiagHead( - num_dimensions, - init_scale=policy_init_scale, - use_tfd_independent=True) - ]) + # Create the policy network. 
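+    # A LayerNorm MLP torso followed by a head that outputs a diagonal
+    # Gaussian (MultivariateNormalDiag) over the action dimensions.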
+ policy_network = snt.Sequential( + [ + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, init_scale=policy_init_scale, use_tfd_independent=True + ), + ] + ) - # The multiplexer concatenates the (maybe transformed) observations/actions. - critic_network = snt.Sequential([ - networks.CriticMultiplexer(action_network=networks.ClipToSpec(act_spec)), - networks.LayerNormMLP(critic_layer_sizes, activate_final=True), - networks.GaussianMixtureHead( - num_dimensions=1, - num_components=critic_num_components, - init_scale=critic_init_scale) - ]) + # The multiplexer concatenates the (maybe transformed) observations/actions. + critic_network = snt.Sequential( + [ + networks.CriticMultiplexer(action_network=networks.ClipToSpec(act_spec)), + networks.LayerNormMLP(critic_layer_sizes, activate_final=True), + networks.GaussianMixtureHead( + num_dimensions=1, + num_components=critic_num_components, + init_scale=critic_init_scale, + ), + ] + ) - # Create network variables. - # Get embedding spec by creating observation network variables. - emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) - tf2_utils.create_variables(policy_network, [emb_spec]) - tf2_utils.create_variables(critic_network, [emb_spec, act_spec]) + # Create network variables. + # Get embedding spec by creating observation network variables. + emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) + tf2_utils.create_variables(policy_network, [emb_spec]) + tf2_utils.create_variables(critic_network, [emb_spec, act_spec]) - return { - 'policy': policy_network, - 'critic': critic_network, - 'observation': observation_network, - } + return { + "policy": policy_network, + "critic": critic_network, + "observation": observation_network, + } diff --git a/acme/agents/tf/mompo/__init__.py b/acme/agents/tf/mompo/__init__.py index cee0d99a84..6d8dc28d6e 100644 --- a/acme/agents/tf/mompo/__init__.py +++ b/acme/agents/tf/mompo/__init__.py @@ -16,6 +16,8 @@ from acme.agents.tf.mompo.agent import MultiObjectiveMPO from acme.agents.tf.mompo.agent_distributed import DistributedMultiObjectiveMPO -from acme.agents.tf.mompo.learning import MultiObjectiveMPOLearner -from acme.agents.tf.mompo.learning import QValueObjective -from acme.agents.tf.mompo.learning import RewardObjective +from acme.agents.tf.mompo.learning import ( + MultiObjectiveMPOLearner, + QValueObjective, + RewardObjective, +) diff --git a/acme/agents/tf/mompo/agent.py b/acme/agents/tf/mompo/agent.py index e597a55926..7d5333cc77 100644 --- a/acme/agents/tf/mompo/agent.py +++ b/acme/agents/tf/mompo/agent.py @@ -17,25 +17,22 @@ import copy from typing import Optional, Sequence -from acme import datasets -from acme import specs -from acme import types +import reverb +import sonnet as snt +import tensorflow as tf + +from acme import datasets, specs, types from acme.adders import reverb as adders from acme.agents import agent from acme.agents.tf import actors from acme.agents.tf.mompo import learning -from acme.tf import losses -from acme.tf import networks +from acme.tf import losses, networks from acme.tf import utils as tf2_utils -from acme.utils import counting -from acme.utils import loggers -import reverb -import sonnet as snt -import tensorflow as tf +from acme.utils import counting, loggers class MultiObjectiveMPO(agent.Agent): - """Multi-objective MPO Agent. + """Multi-objective MPO Agent. This implements a single-process multi-objective MPO agent. 
This is an actor-critic algorithm that generates data via a behavior policy, inserts @@ -49,32 +46,34 @@ class MultiObjectiveMPO(agent.Agent): Q-values or a DiscreteValuedDistribution. """ - def __init__(self, - reward_objectives: Sequence[learning.RewardObjective], - qvalue_objectives: Sequence[learning.QValueObjective], - environment_spec: specs.EnvironmentSpec, - policy_network: snt.Module, - critic_network: snt.Module, - observation_network: types.TensorTransformation = tf.identity, - discount: float = 0.99, - batch_size: int = 512, - prefetch_size: int = 4, - target_policy_update_period: int = 200, - target_critic_update_period: int = 200, - min_replay_size: int = 1000, - max_replay_size: int = 1000000, - samples_per_insert: float = 16., - policy_loss_module: Optional[losses.MultiObjectiveMPO] = None, - policy_optimizer: Optional[snt.Optimizer] = None, - critic_optimizer: Optional[snt.Optimizer] = None, - n_step: int = 5, - num_samples: int = 20, - clipping: bool = True, - logger: Optional[loggers.Logger] = None, - counter: Optional[counting.Counter] = None, - checkpoint: bool = True, - replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE): - """Initialize the agent. + def __init__( + self, + reward_objectives: Sequence[learning.RewardObjective], + qvalue_objectives: Sequence[learning.QValueObjective], + environment_spec: specs.EnvironmentSpec, + policy_network: snt.Module, + critic_network: snt.Module, + observation_network: types.TensorTransformation = tf.identity, + discount: float = 0.99, + batch_size: int = 512, + prefetch_size: int = 4, + target_policy_update_period: int = 200, + target_critic_update_period: int = 200, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: float = 16.0, + policy_loss_module: Optional[losses.MultiObjectiveMPO] = None, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + n_step: int = 5, + num_samples: int = 20, + clipping: bool = True, + logger: Optional[loggers.Logger] = None, + counter: Optional[counting.Counter] = None, + checkpoint: bool = True, + replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE, + ): + """Initialize the agent. Args: reward_objectives: list of the objectives that the policy should optimize; @@ -111,94 +110,94 @@ def __init__(self, checkpoint: boolean indicating whether to checkpoint the learner. replay_table_name: string indicating what name to give the replay table. """ - # Check that at least one objective's reward function is specified. - if not reward_objectives: - raise ValueError('Must specify at least one reward objective.') - - # Create a replay server to add data to. - replay_table = reverb.Table( - name=adders.DEFAULT_PRIORITY_TABLE, - sampler=reverb.selectors.Uniform(), - remover=reverb.selectors.Fifo(), - max_size=max_replay_size, - rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1), - signature=adders.NStepTransitionAdder.signature(environment_spec)) - self._server = reverb.Server([replay_table], port=None) - - # The adder is used to insert observations into replay. - address = f'localhost:{self._server.port}' - adder = adders.NStepTransitionAdder( - client=reverb.Client(address), - n_step=n_step, - discount=discount) - - # The dataset object to learn from. - dataset = datasets.make_reverb_dataset( - table=replay_table_name, - server_address=address, - batch_size=batch_size, - prefetch_size=prefetch_size) - - # Make sure observation network is a Sonnet Module. 
- observation_network = tf2_utils.to_sonnet_module(observation_network) - - # Create target networks before creating online/target network variables. - target_policy_network = copy.deepcopy(policy_network) - target_critic_network = copy.deepcopy(critic_network) - target_observation_network = copy.deepcopy(observation_network) - - # Get observation and action specs. - act_spec = environment_spec.actions - obs_spec = environment_spec.observations - emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) - - # Create the behavior policy. - behavior_network = snt.Sequential([ - observation_network, - policy_network, - networks.StochasticSamplingHead(), - ]) - - # Create variables. - tf2_utils.create_variables(policy_network, [emb_spec]) - tf2_utils.create_variables(critic_network, [emb_spec, act_spec]) - tf2_utils.create_variables(target_policy_network, [emb_spec]) - tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec]) - tf2_utils.create_variables(target_observation_network, [obs_spec]) - - # Create the actor which defines how we take actions. - actor = actors.FeedForwardActor( - policy_network=behavior_network, adder=adder) - - # Create optimizers. - policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) - critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) - - # The learner updates the parameters (and initializes them). - learner = learning.MultiObjectiveMPOLearner( - reward_objectives=reward_objectives, - qvalue_objectives=qvalue_objectives, - policy_network=policy_network, - critic_network=critic_network, - observation_network=observation_network, - target_policy_network=target_policy_network, - target_critic_network=target_critic_network, - target_observation_network=target_observation_network, - policy_loss_module=policy_loss_module, - policy_optimizer=policy_optimizer, - critic_optimizer=critic_optimizer, - clipping=clipping, - discount=discount, - num_samples=num_samples, - target_policy_update_period=target_policy_update_period, - target_critic_update_period=target_critic_update_period, - dataset=dataset, - logger=logger, - counter=counter, - checkpoint=checkpoint) - - super().__init__( - actor=actor, - learner=learner, - min_observations=max(batch_size, min_replay_size), - observations_per_step=float(batch_size) / samples_per_insert) + # Check that at least one objective's reward function is specified. + if not reward_objectives: + raise ValueError("Must specify at least one reward objective.") + + # Create a replay server to add data to. + replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=max_replay_size, + rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1), + signature=adders.NStepTransitionAdder.signature(environment_spec), + ) + self._server = reverb.Server([replay_table], port=None) + + # The adder is used to insert observations into replay. + address = f"localhost:{self._server.port}" + adder = adders.NStepTransitionAdder( + client=reverb.Client(address), n_step=n_step, discount=discount + ) + + # The dataset object to learn from. + dataset = datasets.make_reverb_dataset( + table=replay_table_name, + server_address=address, + batch_size=batch_size, + prefetch_size=prefetch_size, + ) + + # Make sure observation network is a Sonnet Module. + observation_network = tf2_utils.to_sonnet_module(observation_network) + + # Create target networks before creating online/target network variables. 
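+        # Deep copies give the target networks their own variables, which are
+        # only synced from the online networks at the periodic target updates.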
+ target_policy_network = copy.deepcopy(policy_network) + target_critic_network = copy.deepcopy(critic_network) + target_observation_network = copy.deepcopy(observation_network) + + # Get observation and action specs. + act_spec = environment_spec.actions + obs_spec = environment_spec.observations + emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) + + # Create the behavior policy. + behavior_network = snt.Sequential( + [observation_network, policy_network, networks.StochasticSamplingHead(),] + ) + + # Create variables. + tf2_utils.create_variables(policy_network, [emb_spec]) + tf2_utils.create_variables(critic_network, [emb_spec, act_spec]) + tf2_utils.create_variables(target_policy_network, [emb_spec]) + tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec]) + tf2_utils.create_variables(target_observation_network, [obs_spec]) + + # Create the actor which defines how we take actions. + actor = actors.FeedForwardActor(policy_network=behavior_network, adder=adder) + + # Create optimizers. + policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) + critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) + + # The learner updates the parameters (and initializes them). + learner = learning.MultiObjectiveMPOLearner( + reward_objectives=reward_objectives, + qvalue_objectives=qvalue_objectives, + policy_network=policy_network, + critic_network=critic_network, + observation_network=observation_network, + target_policy_network=target_policy_network, + target_critic_network=target_critic_network, + target_observation_network=target_observation_network, + policy_loss_module=policy_loss_module, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + clipping=clipping, + discount=discount, + num_samples=num_samples, + target_policy_update_period=target_policy_update_period, + target_critic_update_period=target_critic_update_period, + dataset=dataset, + logger=logger, + counter=counter, + checkpoint=checkpoint, + ) + + super().__init__( + actor=actor, + learner=learner, + min_observations=max(batch_size, min_replay_size), + observations_per_step=float(batch_size) / samples_per_insert, + ) diff --git a/acme/agents/tf/mompo/agent_distributed.py b/acme/agents/tf/mompo/agent_distributed.py index 7fc2a74174..fc77614653 100644 --- a/acme/agents/tf/mompo/agent_distributed.py +++ b/acme/agents/tf/mompo/agent_distributed.py @@ -16,33 +16,31 @@ from typing import Callable, Dict, Optional, Sequence +import dm_env +import launchpad as lp +import reverb +import sonnet as snt +import tensorflow as tf + import acme -from acme import datasets -from acme import specs +from acme import datasets, specs from acme.adders import reverb as adders from acme.agents.tf import actors from acme.agents.tf.mompo import learning -from acme.tf import losses -from acme.tf import networks +from acme.tf import losses, networks from acme.tf import savers as tf2_savers from acme.tf import utils as tf2_utils from acme.tf import variable_utils as tf2_variable_utils -from acme.utils import counting -from acme.utils import loggers -from acme.utils import lp_utils -import dm_env -import launchpad as lp -import reverb -import sonnet as snt -import tensorflow as tf +from acme.utils import counting, loggers, lp_utils MultiObjectiveNetworkFactorySpec = Callable[ - [specs.BoundedArray, int], Dict[str, snt.Module]] + [specs.BoundedArray, int], Dict[str, snt.Module] +] MultiObjectivePolicyLossFactorySpec = Callable[[], losses.MultiObjectiveMPO] class 
DistributedMultiObjectiveMPO: - """Program definition for multi-objective MPO. + """Program definition for multi-objective MPO. This agent distinguishes itself from the distributed MPO agent in two ways: - Allowing for one or more objectives (see `acme/agents/tf/mompo/learning.py` @@ -52,310 +50,318 @@ class DistributedMultiObjectiveMPO: Q-values or a DiscreteValuedDistribution. """ - def __init__( - self, - reward_objectives: Sequence[learning.RewardObjective], - qvalue_objectives: Sequence[learning.QValueObjective], - environment_factory: Callable[[bool], dm_env.Environment], - network_factory: MultiObjectiveNetworkFactorySpec, - num_actors: int = 1, - num_caches: int = 0, - environment_spec: Optional[specs.EnvironmentSpec] = None, - batch_size: int = 512, - prefetch_size: int = 4, - min_replay_size: int = 1000, - max_replay_size: int = 1000000, - samples_per_insert: Optional[float] = None, - n_step: int = 5, - max_in_flight_items: int = 5, - num_samples: int = 20, - additional_discount: float = 0.99, - target_policy_update_period: int = 200, - target_critic_update_period: int = 200, - policy_loss_factory: Optional[MultiObjectivePolicyLossFactorySpec] = None, - max_actor_steps: Optional[int] = None, - log_every: float = 10.0, - ): - - if environment_spec is None: - environment_spec = specs.make_environment_spec(environment_factory(False)) - - self._environment_factory = environment_factory - self._network_factory = network_factory - self._policy_loss_factory = policy_loss_factory - self._environment_spec = environment_spec - self._num_actors = num_actors - self._num_caches = num_caches - self._batch_size = batch_size - self._prefetch_size = prefetch_size - self._min_replay_size = min_replay_size - self._max_replay_size = max_replay_size - self._samples_per_insert = samples_per_insert - self._n_step = n_step - self._max_in_flight_items = max_in_flight_items - self._additional_discount = additional_discount - self._num_samples = num_samples - self._target_policy_update_period = target_policy_update_period - self._target_critic_update_period = target_critic_update_period - self._max_actor_steps = max_actor_steps - self._log_every = log_every - - self._reward_objectives = reward_objectives - self._qvalue_objectives = qvalue_objectives - self._num_critic_heads = len(self._reward_objectives) - - if not self._reward_objectives: - raise ValueError('Must specify at least one reward objective.') - - def replay(self): - """The replay storage.""" - if self._samples_per_insert is not None: - # Create enough of an error buffer to give a 10% tolerance in rate. 
- samples_per_insert_tolerance = 0.1 * self._samples_per_insert - error_buffer = self._min_replay_size * samples_per_insert_tolerance - - limiter = reverb.rate_limiters.SampleToInsertRatio( - min_size_to_sample=self._min_replay_size, - samples_per_insert=self._samples_per_insert, - error_buffer=error_buffer) - else: - limiter = reverb.rate_limiters.MinSize(self._min_replay_size) - replay_table = reverb.Table( - name=adders.DEFAULT_PRIORITY_TABLE, - sampler=reverb.selectors.Uniform(), - remover=reverb.selectors.Fifo(), - max_size=self._max_replay_size, - rate_limiter=limiter, - signature=adders.NStepTransitionAdder.signature( - self._environment_spec)) - return [replay_table] - - def counter(self): - return tf2_savers.CheckpointingRunner(counting.Counter(), - time_delta_minutes=1, - subdirectory='counter') - - def coordinator(self, counter: counting.Counter, max_actor_steps: int): - return lp_utils.StepsLimiter(counter, max_actor_steps) - - def learner( - self, - replay: reverb.Client, - counter: counting.Counter, - ): - """The Learning part of the agent.""" - - act_spec = self._environment_spec.actions - obs_spec = self._environment_spec.observations - - # Create online and target networks. - online_networks = self._network_factory(act_spec, self._num_critic_heads) - target_networks = self._network_factory(act_spec, self._num_critic_heads) - - # Make sure observation network is a Sonnet Module. - observation_network = online_networks.get('observation', tf.identity) - target_observation_network = target_networks.get('observation', tf.identity) - observation_network = tf2_utils.to_sonnet_module(observation_network) - target_observation_network = tf2_utils.to_sonnet_module( - target_observation_network) - - # Get embedding spec and create observation network variables. - emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) - - # Create variables. - tf2_utils.create_variables(online_networks['policy'], [emb_spec]) - tf2_utils.create_variables(online_networks['critic'], [emb_spec, act_spec]) - tf2_utils.create_variables(target_networks['policy'], [emb_spec]) - tf2_utils.create_variables(target_networks['critic'], [emb_spec, act_spec]) - tf2_utils.create_variables(target_observation_network, [obs_spec]) - - # The dataset object to learn from. - dataset = datasets.make_reverb_dataset(server_address=replay.server_address) - dataset = dataset.batch(self._batch_size, drop_remainder=True) - dataset = dataset.prefetch(self._prefetch_size) - - counter = counting.Counter(counter, 'learner') - logger = loggers.make_default_logger( - 'learner', time_delta=self._log_every, steps_key='learner_steps') - - # Create policy loss module if a factory is passed. - if self._policy_loss_factory: - policy_loss_module = self._policy_loss_factory() - else: - policy_loss_module = None - - # Return the learning agent. 
- return learning.MultiObjectiveMPOLearner( - reward_objectives=self._reward_objectives, - qvalue_objectives=self._qvalue_objectives, - policy_network=online_networks['policy'], - critic_network=online_networks['critic'], - observation_network=observation_network, - target_policy_network=target_networks['policy'], - target_critic_network=target_networks['critic'], - target_observation_network=target_observation_network, - discount=self._additional_discount, - num_samples=self._num_samples, - target_policy_update_period=self._target_policy_update_period, - target_critic_update_period=self._target_critic_update_period, - policy_loss_module=policy_loss_module, - dataset=dataset, - counter=counter, - logger=logger) - - def actor( - self, - replay: reverb.Client, - variable_source: acme.VariableSource, - counter: counting.Counter, - ) -> acme.EnvironmentLoop: - """The actor process.""" - - action_spec = self._environment_spec.actions - observation_spec = self._environment_spec.observations - - # Create environment and target networks to act with. - environment = self._environment_factory(False) - agent_networks = self._network_factory(action_spec, self._num_critic_heads) - - # Make sure observation network is defined. - observation_network = agent_networks.get('observation', tf.identity) - - # Create a stochastic behavior policy. - behavior_network = snt.Sequential([ - observation_network, - agent_networks['policy'], - networks.StochasticSamplingHead(), - ]) - - # Ensure network variables are created. - tf2_utils.create_variables(behavior_network, [observation_spec]) - policy_variables = {'policy': behavior_network.variables} - - # Create the variable client responsible for keeping the actor up-to-date. - variable_client = tf2_variable_utils.VariableClient( - variable_source, policy_variables, update_period=1000) - - # Make sure not to use a random policy after checkpoint restoration by - # assigning variables before running the environment loop. - variable_client.update_and_wait() - - # Component to add things into replay. - adder = adders.NStepTransitionAdder( - client=replay, - n_step=self._n_step, - max_in_flight_items=self._max_in_flight_items, - discount=self._additional_discount) - - # Create the agent. - actor = actors.FeedForwardActor( - policy_network=behavior_network, - adder=adder, - variable_client=variable_client) - - # Create logger and counter; actors will not spam bigtable. - counter = counting.Counter(counter, 'actor') - logger = loggers.make_default_logger( - 'actor', - save_data=False, - time_delta=self._log_every, - steps_key='actor_steps') - - # Create the run loop and return it. - return acme.EnvironmentLoop( - environment, actor, counter, logger) - - def evaluator( - self, - variable_source: acme.VariableSource, - counter: counting.Counter, - ): - """The evaluation process.""" - - action_spec = self._environment_spec.actions - observation_spec = self._environment_spec.observations - - # Create environment and target networks to act with. - environment = self._environment_factory(True) - agent_networks = self._network_factory(action_spec, self._num_critic_heads) - - # Make sure observation network is defined. - observation_network = agent_networks.get('observation', tf.identity) - - # Create a deterministic behavior policy. 
- evaluator_modules = [ - observation_network, - agent_networks['policy'], - networks.StochasticMeanHead(), - ] - if isinstance(action_spec, specs.BoundedArray): - evaluator_modules += [networks.ClipToSpec(action_spec)] - evaluator_network = snt.Sequential(evaluator_modules) - - # Ensure network variables are created. - tf2_utils.create_variables(evaluator_network, [observation_spec]) - policy_variables = {'policy': evaluator_network.variables} - - # Create the variable client responsible for keeping the actor up-to-date. - variable_client = tf2_variable_utils.VariableClient( - variable_source, policy_variables, update_period=1000) - - # Make sure not to evaluate a random actor by assigning variables before - # running the environment loop. - variable_client.update_and_wait() - - # Create the agent. - evaluator = actors.FeedForwardActor( - policy_network=evaluator_network, variable_client=variable_client) - - # Create logger and counter. - counter = counting.Counter(counter, 'evaluator') - logger = loggers.make_default_logger( - 'evaluator', time_delta=self._log_every, steps_key='evaluator_steps') - - # Create the run loop and return it. - return acme.EnvironmentLoop( - environment, evaluator, counter, logger) - - def build(self, name='mompo'): - """Build the distributed agent topology.""" - program = lp.Program(name=name) - - with program.group('replay'): - replay = program.add_node(lp.ReverbNode(self.replay)) - - with program.group('counter'): - counter = program.add_node(lp.CourierNode(self.counter)) - - if self._max_actor_steps: - _ = program.add_node( - lp.CourierNode(self.coordinator, counter, self._max_actor_steps)) - - with program.group('learner'): - learner = program.add_node( - lp.CourierNode(self.learner, replay, counter)) - - with program.group('evaluator'): - program.add_node( - lp.CourierNode(self.evaluator, learner, counter)) - - if not self._num_caches: - # Use our learner as a single variable source. - sources = [learner] - else: - with program.group('cacher'): - # Create a set of learner caches. - sources = [] - for _ in range(self._num_caches): - cacher = program.add_node( - lp.CacherNode( - learner, refresh_interval_ms=2000, stale_after_ms=4000)) - sources.append(cacher) - - with program.group('actor'): - # Add actors which pull round-robin from our variable sources. 
- for actor_id in range(self._num_actors): - source = sources[actor_id % len(sources)] - program.add_node(lp.CourierNode(self.actor, replay, source, counter)) - - return program + def __init__( + self, + reward_objectives: Sequence[learning.RewardObjective], + qvalue_objectives: Sequence[learning.QValueObjective], + environment_factory: Callable[[bool], dm_env.Environment], + network_factory: MultiObjectiveNetworkFactorySpec, + num_actors: int = 1, + num_caches: int = 0, + environment_spec: Optional[specs.EnvironmentSpec] = None, + batch_size: int = 512, + prefetch_size: int = 4, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: Optional[float] = None, + n_step: int = 5, + max_in_flight_items: int = 5, + num_samples: int = 20, + additional_discount: float = 0.99, + target_policy_update_period: int = 200, + target_critic_update_period: int = 200, + policy_loss_factory: Optional[MultiObjectivePolicyLossFactorySpec] = None, + max_actor_steps: Optional[int] = None, + log_every: float = 10.0, + ): + + if environment_spec is None: + environment_spec = specs.make_environment_spec(environment_factory(False)) + + self._environment_factory = environment_factory + self._network_factory = network_factory + self._policy_loss_factory = policy_loss_factory + self._environment_spec = environment_spec + self._num_actors = num_actors + self._num_caches = num_caches + self._batch_size = batch_size + self._prefetch_size = prefetch_size + self._min_replay_size = min_replay_size + self._max_replay_size = max_replay_size + self._samples_per_insert = samples_per_insert + self._n_step = n_step + self._max_in_flight_items = max_in_flight_items + self._additional_discount = additional_discount + self._num_samples = num_samples + self._target_policy_update_period = target_policy_update_period + self._target_critic_update_period = target_critic_update_period + self._max_actor_steps = max_actor_steps + self._log_every = log_every + + self._reward_objectives = reward_objectives + self._qvalue_objectives = qvalue_objectives + self._num_critic_heads = len(self._reward_objectives) + + if not self._reward_objectives: + raise ValueError("Must specify at least one reward objective.") + + def replay(self): + """The replay storage.""" + if self._samples_per_insert is not None: + # Create enough of an error buffer to give a 10% tolerance in rate. 
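            # Illustrative arithmetic only (assumed values, since this
            # distributed agent defaults samples_per_insert to None): with
            # samples_per_insert=16.0 and min_replay_size=1000, the tolerance
            # is 0.1 * 16.0 = 1.6 and error_buffer = 1000 * 1.6 = 1600, i.e.
            # the SampleToInsertRatio limiter tolerates roughly 1600 items of
            # drift from the target insert/sample ratio before blocking.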
+ samples_per_insert_tolerance = 0.1 * self._samples_per_insert + error_buffer = self._min_replay_size * samples_per_insert_tolerance + + limiter = reverb.rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._min_replay_size, + samples_per_insert=self._samples_per_insert, + error_buffer=error_buffer, + ) + else: + limiter = reverb.rate_limiters.MinSize(self._min_replay_size) + replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._max_replay_size, + rate_limiter=limiter, + signature=adders.NStepTransitionAdder.signature(self._environment_spec), + ) + return [replay_table] + + def counter(self): + return tf2_savers.CheckpointingRunner( + counting.Counter(), time_delta_minutes=1, subdirectory="counter" + ) + + def coordinator(self, counter: counting.Counter, max_actor_steps: int): + return lp_utils.StepsLimiter(counter, max_actor_steps) + + def learner( + self, replay: reverb.Client, counter: counting.Counter, + ): + """The Learning part of the agent.""" + + act_spec = self._environment_spec.actions + obs_spec = self._environment_spec.observations + + # Create online and target networks. + online_networks = self._network_factory(act_spec, self._num_critic_heads) + target_networks = self._network_factory(act_spec, self._num_critic_heads) + + # Make sure observation network is a Sonnet Module. + observation_network = online_networks.get("observation", tf.identity) + target_observation_network = target_networks.get("observation", tf.identity) + observation_network = tf2_utils.to_sonnet_module(observation_network) + target_observation_network = tf2_utils.to_sonnet_module( + target_observation_network + ) + + # Get embedding spec and create observation network variables. + emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) + + # Create variables. + tf2_utils.create_variables(online_networks["policy"], [emb_spec]) + tf2_utils.create_variables(online_networks["critic"], [emb_spec, act_spec]) + tf2_utils.create_variables(target_networks["policy"], [emb_spec]) + tf2_utils.create_variables(target_networks["critic"], [emb_spec, act_spec]) + tf2_utils.create_variables(target_observation_network, [obs_spec]) + + # The dataset object to learn from. + dataset = datasets.make_reverb_dataset(server_address=replay.server_address) + dataset = dataset.batch(self._batch_size, drop_remainder=True) + dataset = dataset.prefetch(self._prefetch_size) + + counter = counting.Counter(counter, "learner") + logger = loggers.make_default_logger( + "learner", time_delta=self._log_every, steps_key="learner_steps" + ) + + # Create policy loss module if a factory is passed. + if self._policy_loss_factory: + policy_loss_module = self._policy_loss_factory() + else: + policy_loss_module = None + + # Return the learning agent. 
+ return learning.MultiObjectiveMPOLearner( + reward_objectives=self._reward_objectives, + qvalue_objectives=self._qvalue_objectives, + policy_network=online_networks["policy"], + critic_network=online_networks["critic"], + observation_network=observation_network, + target_policy_network=target_networks["policy"], + target_critic_network=target_networks["critic"], + target_observation_network=target_observation_network, + discount=self._additional_discount, + num_samples=self._num_samples, + target_policy_update_period=self._target_policy_update_period, + target_critic_update_period=self._target_critic_update_period, + policy_loss_module=policy_loss_module, + dataset=dataset, + counter=counter, + logger=logger, + ) + + def actor( + self, + replay: reverb.Client, + variable_source: acme.VariableSource, + counter: counting.Counter, + ) -> acme.EnvironmentLoop: + """The actor process.""" + + action_spec = self._environment_spec.actions + observation_spec = self._environment_spec.observations + + # Create environment and target networks to act with. + environment = self._environment_factory(False) + agent_networks = self._network_factory(action_spec, self._num_critic_heads) + + # Make sure observation network is defined. + observation_network = agent_networks.get("observation", tf.identity) + + # Create a stochastic behavior policy. + behavior_network = snt.Sequential( + [ + observation_network, + agent_networks["policy"], + networks.StochasticSamplingHead(), + ] + ) + + # Ensure network variables are created. + tf2_utils.create_variables(behavior_network, [observation_spec]) + policy_variables = {"policy": behavior_network.variables} + + # Create the variable client responsible for keeping the actor up-to-date. + variable_client = tf2_variable_utils.VariableClient( + variable_source, policy_variables, update_period=1000 + ) + + # Make sure not to use a random policy after checkpoint restoration by + # assigning variables before running the environment loop. + variable_client.update_and_wait() + + # Component to add things into replay. + adder = adders.NStepTransitionAdder( + client=replay, + n_step=self._n_step, + max_in_flight_items=self._max_in_flight_items, + discount=self._additional_discount, + ) + + # Create the agent. + actor = actors.FeedForwardActor( + policy_network=behavior_network, + adder=adder, + variable_client=variable_client, + ) + + # Create logger and counter; actors will not spam bigtable. + counter = counting.Counter(counter, "actor") + logger = loggers.make_default_logger( + "actor", + save_data=False, + time_delta=self._log_every, + steps_key="actor_steps", + ) + + # Create the run loop and return it. + return acme.EnvironmentLoop(environment, actor, counter, logger) + + def evaluator( + self, variable_source: acme.VariableSource, counter: counting.Counter, + ): + """The evaluation process.""" + + action_spec = self._environment_spec.actions + observation_spec = self._environment_spec.observations + + # Create environment and target networks to act with. + environment = self._environment_factory(True) + agent_networks = self._network_factory(action_spec, self._num_critic_heads) + + # Make sure observation network is defined. + observation_network = agent_networks.get("observation", tf.identity) + + # Create a deterministic behavior policy. 
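        # In contrast to the actor above, which ends its network with
        # networks.StochasticSamplingHead() and therefore samples actions,
        # the evaluator below takes the mean of the policy distribution via
        # networks.StochasticMeanHead() and, for bounded action specs, clips
        # the result with networks.ClipToSpec(action_spec).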
+ evaluator_modules = [ + observation_network, + agent_networks["policy"], + networks.StochasticMeanHead(), + ] + if isinstance(action_spec, specs.BoundedArray): + evaluator_modules += [networks.ClipToSpec(action_spec)] + evaluator_network = snt.Sequential(evaluator_modules) + + # Ensure network variables are created. + tf2_utils.create_variables(evaluator_network, [observation_spec]) + policy_variables = {"policy": evaluator_network.variables} + + # Create the variable client responsible for keeping the actor up-to-date. + variable_client = tf2_variable_utils.VariableClient( + variable_source, policy_variables, update_period=1000 + ) + + # Make sure not to evaluate a random actor by assigning variables before + # running the environment loop. + variable_client.update_and_wait() + + # Create the agent. + evaluator = actors.FeedForwardActor( + policy_network=evaluator_network, variable_client=variable_client + ) + + # Create logger and counter. + counter = counting.Counter(counter, "evaluator") + logger = loggers.make_default_logger( + "evaluator", time_delta=self._log_every, steps_key="evaluator_steps" + ) + + # Create the run loop and return it. + return acme.EnvironmentLoop(environment, evaluator, counter, logger) + + def build(self, name="mompo"): + """Build the distributed agent topology.""" + program = lp.Program(name=name) + + with program.group("replay"): + replay = program.add_node(lp.ReverbNode(self.replay)) + + with program.group("counter"): + counter = program.add_node(lp.CourierNode(self.counter)) + + if self._max_actor_steps: + _ = program.add_node( + lp.CourierNode(self.coordinator, counter, self._max_actor_steps) + ) + + with program.group("learner"): + learner = program.add_node(lp.CourierNode(self.learner, replay, counter)) + + with program.group("evaluator"): + program.add_node(lp.CourierNode(self.evaluator, learner, counter)) + + if not self._num_caches: + # Use our learner as a single variable source. + sources = [learner] + else: + with program.group("cacher"): + # Create a set of learner caches. + sources = [] + for _ in range(self._num_caches): + cacher = program.add_node( + lp.CacherNode( + learner, refresh_interval_ms=2000, stale_after_ms=4000 + ) + ) + sources.append(cacher) + + with program.group("actor"): + # Add actors which pull round-robin from our variable sources. 
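            # A small worked example with assumed sizes: if num_caches=2 and
            # num_actors=5, `sources` holds two cacher nodes and the modulo
            # below sends actors 0, 2 and 4 to sources[0] and actors 1 and 3
            # to sources[1]; with num_caches=0 every actor reads variables
            # directly from the learner.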
+ for actor_id in range(self._num_actors): + source = sources[actor_id % len(sources)] + program.add_node(lp.CourierNode(self.actor, replay, source, counter)) + + return program diff --git a/acme/agents/tf/mompo/agent_distributed_test.py b/acme/agents/tf/mompo/agent_distributed_test.py index a76b20d2f0..8ead1436bd 100644 --- a/acme/agents/tf/mompo/agent_distributed_test.py +++ b/acme/agents/tf/mompo/agent_distributed_test.py @@ -16,21 +16,19 @@ from typing import Sequence, Tuple -import acme -from acme import specs -from acme import wrappers -from acme.agents.tf import mompo -from acme.tf import networks -from acme.tf import utils as tf2_utils -from acme.utils import lp_utils -from dm_control import suite import launchpad as lp import numpy as np import sonnet as snt import tensorflow as tf +from absl.testing import absltest, parameterized +from dm_control import suite -from absl.testing import absltest -from absl.testing import parameterized +import acme +from acme import specs, wrappers +from acme.agents.tf import mompo +from acme.tf import networks +from acme.tf import utils as tf2_utils +from acme.utils import lp_utils def make_networks( @@ -40,124 +38,130 @@ def make_networks( critic_layer_sizes: Sequence[int] = (50,), num_layers_shared: int = 1, distributional_critic: bool = True, - vmin: float = -150., - vmax: float = 150., + vmin: float = -150.0, + vmax: float = 150.0, num_atoms: int = 51, ): - """Creates networks used by the agent.""" - - num_dimensions = np.prod(action_spec.shape, dtype=int) - - policy_network = snt.Sequential([ - networks.LayerNormMLP(policy_layer_sizes, activate_final=True), - networks.MultivariateNormalDiagHead( - num_dimensions, - tanh_mean=False, - init_scale=0.69) - ]) - - if not distributional_critic: - critic_layer_sizes = list(critic_layer_sizes) + [1] - - if not num_layers_shared: - # No layers are shared - critic_network_base = None - else: - critic_network_base = networks.LayerNormMLP( - critic_layer_sizes[:num_layers_shared], activate_final=True) - critic_network_heads = [ - snt.nets.MLP(critic_layer_sizes, activation=tf.nn.elu, - activate_final=False) - for _ in range(num_critic_heads)] - if distributional_critic: + """Creates networks used by the agent.""" + + num_dimensions = np.prod(action_spec.shape, dtype=int) + + policy_network = snt.Sequential( + [ + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, tanh_mean=False, init_scale=0.69 + ), + ] + ) + + if not distributional_critic: + critic_layer_sizes = list(critic_layer_sizes) + [1] + + if not num_layers_shared: + # No layers are shared + critic_network_base = None + else: + critic_network_base = networks.LayerNormMLP( + critic_layer_sizes[:num_layers_shared], activate_final=True + ) critic_network_heads = [ - snt.Sequential([ - c, networks.DiscreteValuedHead(vmin, vmax, num_atoms) - ]) for c in critic_network_heads] - # The multiplexer concatenates the (maybe transformed) observations/actions. 
- critic_network = snt.Sequential([ - networks.CriticMultiplexer( - critic_network=critic_network_base, - action_network=networks.ClipToSpec(action_spec)), - networks.Multihead(network_heads=critic_network_heads), - ]) - - return { - 'policy': policy_network, - 'critic': critic_network, - 'observation': tf2_utils.batch_concat, - } + snt.nets.MLP(critic_layer_sizes, activation=tf.nn.elu, activate_final=False) + for _ in range(num_critic_heads) + ] + if distributional_critic: + critic_network_heads = [ + snt.Sequential([c, networks.DiscreteValuedHead(vmin, vmax, num_atoms)]) + for c in critic_network_heads + ] + # The multiplexer concatenates the (maybe transformed) observations/actions. + critic_network = snt.Sequential( + [ + networks.CriticMultiplexer( + critic_network=critic_network_base, + action_network=networks.ClipToSpec(action_spec), + ), + networks.Multihead(network_heads=critic_network_heads), + ] + ) + + return { + "policy": policy_network, + "critic": critic_network, + "observation": tf2_utils.batch_concat, + } def make_environment(evaluation: bool = False): - del evaluation # Unused. - environment = suite.load('cartpole', 'balance') - wrapped = wrappers.SinglePrecisionWrapper(environment) - return wrapped + del evaluation # Unused. + environment = suite.load("cartpole", "balance") + wrapped = wrappers.SinglePrecisionWrapper(environment) + return wrapped -def compute_action_norm(target_pi_samples: tf.Tensor, - target_q_target_pi_samples: tf.Tensor) -> tf.Tensor: - """Compute Q-values for the action norm objective from action samples.""" - del target_q_target_pi_samples - action_norm = tf.norm(target_pi_samples, ord=2, axis=-1) - return tf.stop_gradient(-1 * action_norm) +def compute_action_norm( + target_pi_samples: tf.Tensor, target_q_target_pi_samples: tf.Tensor +) -> tf.Tensor: + """Compute Q-values for the action norm objective from action samples.""" + del target_q_target_pi_samples + action_norm = tf.norm(target_pi_samples, ord=2, axis=-1) + return tf.stop_gradient(-1 * action_norm) -def task_reward_fn(observation: tf.Tensor, - action: tf.Tensor, - reward: tf.Tensor) -> tf.Tensor: - del observation, action - return tf.stop_gradient(reward) +def task_reward_fn( + observation: tf.Tensor, action: tf.Tensor, reward: tf.Tensor +) -> tf.Tensor: + del observation, action + return tf.stop_gradient(reward) def make_objectives() -> Tuple[ - Sequence[mompo.RewardObjective], Sequence[mompo.QValueObjective]]: - """Define the multiple objectives for the policy to learn.""" - task_reward = mompo.RewardObjective( - name='task', - reward_fn=task_reward_fn) - action_norm = mompo.QValueObjective( - name='action_norm_q', - qvalue_fn=compute_action_norm) - return [task_reward], [action_norm] + Sequence[mompo.RewardObjective], Sequence[mompo.QValueObjective] +]: + """Define the multiple objectives for the policy to learn.""" + task_reward = mompo.RewardObjective(name="task", reward_fn=task_reward_fn) + action_norm = mompo.QValueObjective( + name="action_norm_q", qvalue_fn=compute_action_norm + ) + return [task_reward], [action_norm] class DistributedAgentTest(parameterized.TestCase): - """Simple integration/smoke test for the distributed agent.""" - - @parameterized.named_parameters( - ('distributional_critic', True), - ('vanilla_critic', False)) - def test_agent(self, distributional_critic): - # Create objectives. 
- reward_objectives, qvalue_objectives = make_objectives() - - network_factory = lp_utils.partial_kwargs( - make_networks, distributional_critic=distributional_critic) - - agent = mompo.DistributedMultiObjectiveMPO( - reward_objectives, - qvalue_objectives, - environment_factory=make_environment, - network_factory=network_factory, - num_actors=2, - batch_size=32, - min_replay_size=32, - max_replay_size=1000, + """Simple integration/smoke test for the distributed agent.""" + + @parameterized.named_parameters( + ("distributional_critic", True), ("vanilla_critic", False) ) - program = agent.build() + def test_agent(self, distributional_critic): + # Create objectives. + reward_objectives, qvalue_objectives = make_objectives() + + network_factory = lp_utils.partial_kwargs( + make_networks, distributional_critic=distributional_critic + ) + + agent = mompo.DistributedMultiObjectiveMPO( + reward_objectives, + qvalue_objectives, + environment_factory=make_environment, + network_factory=network_factory, + num_actors=2, + batch_size=32, + min_replay_size=32, + max_replay_size=1000, + ) + program = agent.build() - (learner_node,) = program.groups['learner'] - learner_node.disable_run() + (learner_node,) = program.groups["learner"] + learner_node.disable_run() - lp.launch(program, launch_type='test_mt') + lp.launch(program, launch_type="test_mt") - learner: acme.Learner = learner_node.create_handle().dereference() + learner: acme.Learner = learner_node.create_handle().dereference() - for _ in range(5): - learner.step() + for _ in range(5): + learner.step() -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/tf/mompo/agent_test.py b/acme/agents/tf/mompo/agent_test.py index c08a8b06c9..63295ec4da 100644 --- a/acme/agents/tf/mompo/agent_test.py +++ b/acme/agents/tf/mompo/agent_test.py @@ -16,17 +16,16 @@ from typing import Dict, Sequence, Tuple +import numpy as np +import sonnet as snt +import tensorflow as tf +from absl.testing import absltest, parameterized + import acme from acme import specs from acme.agents.tf import mompo from acme.testing import fakes from acme.tf import networks -import numpy as np -import sonnet as snt -import tensorflow as tf - -from absl.testing import absltest -from absl.testing import parameterized def make_networks( @@ -36,114 +35,119 @@ def make_networks( critic_layer_sizes: Sequence[int] = (400, 300), num_layers_shared: int = 1, distributional_critic: bool = True, - vmin: float = -150., - vmax: float = 150., + vmin: float = -150.0, + vmax: float = 150.0, num_atoms: int = 51, ) -> Dict[str, snt.Module]: - """Creates networks used by the agent.""" - - num_dimensions = np.prod(action_spec.shape, dtype=int) - - policy_network = snt.Sequential([ - networks.LayerNormMLP(policy_layer_sizes, activate_final=True), - networks.MultivariateNormalDiagHead( - num_dimensions, - tanh_mean=False, - init_scale=0.69) - ]) - - if not distributional_critic: - critic_layer_sizes = list(critic_layer_sizes) + [1] - - if not num_layers_shared: - # No layers are shared - critic_network_base = None - else: - critic_network_base = networks.LayerNormMLP( - critic_layer_sizes[:num_layers_shared], activate_final=True) - critic_network_heads = [ - snt.nets.MLP(critic_layer_sizes, activation=tf.nn.elu, - activate_final=False) - for _ in range(num_critic_heads)] - if distributional_critic: + """Creates networks used by the agent.""" + + num_dimensions = np.prod(action_spec.shape, dtype=int) + + policy_network = snt.Sequential( + [ + 
networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, tanh_mean=False, init_scale=0.69 + ), + ] + ) + + if not distributional_critic: + critic_layer_sizes = list(critic_layer_sizes) + [1] + + if not num_layers_shared: + # No layers are shared + critic_network_base = None + else: + critic_network_base = networks.LayerNormMLP( + critic_layer_sizes[:num_layers_shared], activate_final=True + ) critic_network_heads = [ - snt.Sequential([ - c, networks.DiscreteValuedHead(vmin, vmax, num_atoms) - ]) for c in critic_network_heads] - # The multiplexer concatenates the (maybe transformed) observations/actions. - critic_network = snt.Sequential([ - networks.CriticMultiplexer( - critic_network=critic_network_base), - networks.Multihead(network_heads=critic_network_heads), - ]) - return { - 'policy': policy_network, - 'critic': critic_network, - } - - -def compute_action_norm(target_pi_samples: tf.Tensor, - target_q_target_pi_samples: tf.Tensor) -> tf.Tensor: - """Compute Q-values for the action norm objective from action samples.""" - del target_q_target_pi_samples - action_norm = tf.norm(target_pi_samples, ord=2, axis=-1) - return tf.stop_gradient(-1 * action_norm) - - -def task_reward_fn(observation: tf.Tensor, - action: tf.Tensor, - reward: tf.Tensor) -> tf.Tensor: - del observation, action - return tf.stop_gradient(reward) + snt.nets.MLP(critic_layer_sizes, activation=tf.nn.elu, activate_final=False) + for _ in range(num_critic_heads) + ] + if distributional_critic: + critic_network_heads = [ + snt.Sequential([c, networks.DiscreteValuedHead(vmin, vmax, num_atoms)]) + for c in critic_network_heads + ] + # The multiplexer concatenates the (maybe transformed) observations/actions. + critic_network = snt.Sequential( + [ + networks.CriticMultiplexer(critic_network=critic_network_base), + networks.Multihead(network_heads=critic_network_heads), + ] + ) + return { + "policy": policy_network, + "critic": critic_network, + } + + +def compute_action_norm( + target_pi_samples: tf.Tensor, target_q_target_pi_samples: tf.Tensor +) -> tf.Tensor: + """Compute Q-values for the action norm objective from action samples.""" + del target_q_target_pi_samples + action_norm = tf.norm(target_pi_samples, ord=2, axis=-1) + return tf.stop_gradient(-1 * action_norm) + + +def task_reward_fn( + observation: tf.Tensor, action: tf.Tensor, reward: tf.Tensor +) -> tf.Tensor: + del observation, action + return tf.stop_gradient(reward) def make_objectives() -> Tuple[ - Sequence[mompo.RewardObjective], Sequence[mompo.QValueObjective]]: - """Define the multiple objectives for the policy to learn.""" - task_reward = mompo.RewardObjective( - name='task', - reward_fn=task_reward_fn) - action_norm = mompo.QValueObjective( - name='action_norm_q', - qvalue_fn=compute_action_norm) - return [task_reward], [action_norm] + Sequence[mompo.RewardObjective], Sequence[mompo.QValueObjective] +]: + """Define the multiple objectives for the policy to learn.""" + task_reward = mompo.RewardObjective(name="task", reward_fn=task_reward_fn) + action_norm = mompo.QValueObjective( + name="action_norm_q", qvalue_fn=compute_action_norm + ) + return [task_reward], [action_norm] class MOMPOTest(parameterized.TestCase): - - @parameterized.named_parameters( - ('distributional_critic', True), - ('vanilla_critic', False)) - def test_mompo(self, distributional_critic): - # Create a fake environment to test with. 
- environment = fakes.ContinuousEnvironment(episode_length=10) - spec = specs.make_environment_spec(environment) - - # Create objectives. - reward_objectives, qvalue_objectives = make_objectives() - num_critic_heads = len(reward_objectives) - - # Create networks. - agent_networks = make_networks( - spec.actions, num_critic_heads=num_critic_heads, - distributional_critic=distributional_critic) - - # Construct the agent. - agent = mompo.MultiObjectiveMPO( - reward_objectives, - qvalue_objectives, - spec, - policy_network=agent_networks['policy'], - critic_network=agent_networks['critic'], - batch_size=10, - samples_per_insert=2, - min_replay_size=10) - - # Try running the environment loop. We have no assertions here because all - # we care about is that the agent runs without raising any errors. - loop = acme.EnvironmentLoop(environment, agent) - loop.run(num_episodes=2) - - -if __name__ == '__main__': - absltest.main() + @parameterized.named_parameters( + ("distributional_critic", True), ("vanilla_critic", False) + ) + def test_mompo(self, distributional_critic): + # Create a fake environment to test with. + environment = fakes.ContinuousEnvironment(episode_length=10) + spec = specs.make_environment_spec(environment) + + # Create objectives. + reward_objectives, qvalue_objectives = make_objectives() + num_critic_heads = len(reward_objectives) + + # Create networks. + agent_networks = make_networks( + spec.actions, + num_critic_heads=num_critic_heads, + distributional_critic=distributional_critic, + ) + + # Construct the agent. + agent = mompo.MultiObjectiveMPO( + reward_objectives, + qvalue_objectives, + spec, + policy_network=agent_networks["policy"], + critic_network=agent_networks["critic"], + batch_size=10, + samples_per_insert=2, + min_replay_size=10, + ) + + # Try running the environment loop. We have no assertions here because all + # we care about is that the agent runs without raising any errors. + loop = acme.EnvironmentLoop(environment, agent) + loop.run(num_episodes=2) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/tf/mompo/learning.py b/acme/agents/tf/mompo/learning.py index 095e253bed..030291355b 100644 --- a/acme/agents/tf/mompo/learning.py +++ b/acme/agents/tf/mompo/learning.py @@ -18,52 +18,51 @@ import time from typing import Callable, List, Optional, Sequence -import acme -from acme import types -from acme.tf import losses -from acme.tf import networks -from acme.tf import savers as tf2_savers -from acme.tf import utils as tf2_utils -from acme.utils import counting -from acme.utils import loggers import numpy as np import sonnet as snt import tensorflow as tf import trfl +import acme +from acme import types +from acme.tf import losses, networks +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.utils import counting, loggers + QValueFunctionSpec = Callable[[tf.Tensor, tf.Tensor], tf.Tensor] RewardFunctionSpec = Callable[[tf.Tensor, tf.Tensor, tf.Tensor], tf.Tensor] _DEFAULT_EPSILON = 1e-1 _DEFAULT_EPSILON_MEAN = 1e-3 _DEFAULT_EPSILON_STDDEV = 1e-6 -_DEFAULT_INIT_LOG_TEMPERATURE = 1. -_DEFAULT_INIT_LOG_ALPHA_MEAN = 1. -_DEFAULT_INIT_LOG_ALPHA_STDDEV = 10. 
+_DEFAULT_INIT_LOG_TEMPERATURE = 1.0 +_DEFAULT_INIT_LOG_ALPHA_MEAN = 1.0 +_DEFAULT_INIT_LOG_ALPHA_STDDEV = 10.0 @dataclasses.dataclass class QValueObjective: - """Defines an objective by specifying its 'Q-values' directly.""" + """Defines an objective by specifying its 'Q-values' directly.""" - name: str - # This computes "Q-values" directly from the sampled actions and other Q's. - qvalue_fn: QValueFunctionSpec + name: str + # This computes "Q-values" directly from the sampled actions and other Q's. + qvalue_fn: QValueFunctionSpec @dataclasses.dataclass class RewardObjective: - """Defines an objective by specifying its reward function.""" + """Defines an objective by specifying its reward function.""" - name: str - # This computes the reward from observations, actions, and environment task - # reward. In the learner, a head will automatically be added to the critic - # network, to learn Q-values for this objective. - reward_fn: RewardFunctionSpec + name: str + # This computes the reward from observations, actions, and environment task + # reward. In the learner, a head will automatically be added to the critic + # network, to learn Q-values for this objective. + reward_fn: RewardFunctionSpec class MultiObjectiveMPOLearner(acme.Learner): - """Distributional MPO learner. + """Distributional MPO learner. This is the learning component of a multi-objective MPO (MO-MPO) agent. Two sequences of objectives must be specified. Otherwise, the inputs are identical @@ -92,288 +91,317 @@ class MultiObjectiveMPOLearner(acme.Learner): (Abdolmaleki, Huang et al., 2020): https://arxiv.org/pdf/2005.07513.pdf """ - def __init__( - self, - reward_objectives: Sequence[RewardObjective], - qvalue_objectives: Sequence[QValueObjective], - policy_network: snt.Module, - critic_network: snt.Module, - target_policy_network: snt.Module, - target_critic_network: snt.Module, - discount: float, - num_samples: int, - target_policy_update_period: int, - target_critic_update_period: int, - dataset: tf.data.Dataset, - observation_network: types.TensorTransformation = tf.identity, - target_observation_network: types.TensorTransformation = tf.identity, - policy_loss_module: Optional[losses.MultiObjectiveMPO] = None, - policy_optimizer: Optional[snt.Optimizer] = None, - critic_optimizer: Optional[snt.Optimizer] = None, - dual_optimizer: Optional[snt.Optimizer] = None, - clipping: bool = True, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - checkpoint: bool = True, - ): - - # Store online and target networks. - self._policy_network = policy_network - self._critic_network = critic_network - self._target_policy_network = target_policy_network - self._target_critic_network = target_critic_network - - # Make sure observation networks are snt.Module's so they have variables. - self._observation_network = tf2_utils.to_sonnet_module(observation_network) - self._target_observation_network = tf2_utils.to_sonnet_module( - target_observation_network) - - # General learner book-keeping and loggers. - self._counter = counter or counting.Counter() - self._logger = logger or loggers.make_default_logger('learner') - - # Other learner parameters. - self._discount = discount - self._num_samples = num_samples - self._clipping = clipping - - # Necessary to track when to update target networks. 
- self._num_steps = tf.Variable(0, dtype=tf.int32) - self._target_policy_update_period = target_policy_update_period - self._target_critic_update_period = target_critic_update_period - - # Batch dataset and create iterator. - # TODO(b/155086959): Fix type stubs and remove. - self._iterator = iter(dataset) # pytype: disable=wrong-arg-types - - # Store objectives - self._reward_objectives = reward_objectives - self._qvalue_objectives = qvalue_objectives - if self._qvalue_objectives is None: - self._qvalue_objectives = [] - self._num_critic_heads = len(self._reward_objectives) # C - self._objective_names = ( - [x.name for x in self._reward_objectives] + - [x.name for x in self._qvalue_objectives]) - - self._policy_loss_module = policy_loss_module or losses.MultiObjectiveMPO( - epsilons=[losses.KLConstraint(name, _DEFAULT_EPSILON) - for name in self._objective_names], - epsilon_mean=_DEFAULT_EPSILON_MEAN, - epsilon_stddev=_DEFAULT_EPSILON_STDDEV, - init_log_temperature=_DEFAULT_INIT_LOG_TEMPERATURE, - init_log_alpha_mean=_DEFAULT_INIT_LOG_ALPHA_MEAN, - init_log_alpha_stddev=_DEFAULT_INIT_LOG_ALPHA_STDDEV) - - # Check that ordering of objectives matches the policy_loss_module's - if self._objective_names != list(self._policy_loss_module.objective_names): - raise ValueError("Agent's ordering of objectives doesn't match " - "the policy loss module's ordering of epsilons.") - - # Create the optimizers. - self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) - self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) - self._dual_optimizer = dual_optimizer or snt.optimizers.Adam(1e-2) - - # Expose the variables. - policy_network_to_expose = snt.Sequential( - [self._target_observation_network, self._target_policy_network]) - self._variables = { - 'critic': self._target_critic_network.variables, - 'policy': policy_network_to_expose.variables, - } - - # Create a checkpointer and snapshotter object. - self._checkpointer = None - self._snapshotter = None - - if checkpoint: - self._checkpointer = tf2_savers.Checkpointer( - subdirectory='mompo_learner', - objects_to_save={ - 'counter': self._counter, - 'policy': self._policy_network, - 'critic': self._critic_network, - 'observation': self._observation_network, - 'target_policy': self._target_policy_network, - 'target_critic': self._target_critic_network, - 'target_observation': self._target_observation_network, - 'policy_optimizer': self._policy_optimizer, - 'critic_optimizer': self._critic_optimizer, - 'dual_optimizer': self._dual_optimizer, - 'policy_loss_module': self._policy_loss_module, - 'num_steps': self._num_steps, - }) - - self._snapshotter = tf2_savers.Snapshotter( - objects_to_save={ - 'policy': - snt.Sequential([ - self._target_observation_network, - self._target_policy_network - ]), - }) - - # Do not record timestamps until after the first learning step is done. - # This is to avoid including the time it takes for actors to come online and - # fill the replay buffer. - self._timestamp: float = None - - @tf.function - def _step(self) -> types.NestedTensor: - # Update target network. - online_policy_variables = self._policy_network.variables - target_policy_variables = self._target_policy_network.variables - online_critic_variables = ( - *self._observation_network.variables, - *self._critic_network.variables, - ) - target_critic_variables = ( - *self._target_observation_network.variables, - *self._target_critic_network.variables, - ) - - # Make online policy -> target policy network update ops. 
- if tf.math.mod(self._num_steps, self._target_policy_update_period) == 0: - for src, dest in zip(online_policy_variables, target_policy_variables): - dest.assign(src) - # Make online critic -> target critic network update ops. - if tf.math.mod(self._num_steps, self._target_critic_update_period) == 0: - for src, dest in zip(online_critic_variables, target_critic_variables): - dest.assign(src) - - self._num_steps.assign_add(1) - - # Get data from replay (dropping extras if any). Note there is no - # extra data here because we do not insert any into Reverb. - inputs = next(self._iterator) - transitions: types.Transition = inputs.data - - with tf.GradientTape(persistent=True) as tape: - # Maybe transform the observation before feeding into policy and critic. - # Transforming the observations this way at the start of the learning - # step effectively means that the policy and critic share observation - # network weights. - o_tm1 = self._observation_network(transitions.observation) - # This stop_gradient prevents gradients to propagate into the target - # observation network. In addition, since the online policy network is - # evaluated at o_t, this also means the policy loss does not influence - # the observation network training. - o_t = tf.stop_gradient( - self._target_observation_network(transitions.next_observation)) - - # Get online and target action distributions from policy networks. - online_action_distribution = self._policy_network(o_t) - target_action_distribution = self._target_policy_network(o_t) - - # Sample actions to evaluate policy; of size [N, B, ...]. - sampled_actions = target_action_distribution.sample(self._num_samples) - - # Tile embedded observations to feed into the target critic network. - # Note: this is more efficient than tiling before the embedding layer. - tiled_o_t = tf2_utils.tile_tensor(o_t, self._num_samples) # [N, B, ...] - - # Compute target-estimated distributional value of sampled actions at o_t. - sampled_q_t_all = self._target_critic_network( - # Merge batch dimensions; to shape [N*B, ...]. - snt.merge_leading_dims(tiled_o_t, num_dims=2), - snt.merge_leading_dims(sampled_actions, num_dims=2)) - - # Compute online critic value distribution of a_tm1 in state o_tm1. - q_tm1_all = self._critic_network(o_tm1, transitions.action) - - # Compute rewards for objectives with defined reward_fn - reward_stats = {} - r_t_all = [] - for objective in self._reward_objectives: - r = objective.reward_fn(o_tm1, transitions.action, transitions.reward) - reward_stats['{}_reward'.format(objective.name)] = tf.reduce_mean(r) - r_t_all.append(r) - r_t_all = tf.stack(r_t_all, axis=-1) - r_t_all.get_shape().assert_has_rank(2) # [B, C] - - if isinstance(sampled_q_t_all, list): # Distributional critics - critic_loss, sampled_q_t = _compute_distributional_critic_loss( - sampled_q_t_all, q_tm1_all, r_t_all, transitions.discount, - self._discount, self._num_samples) - else: - critic_loss, sampled_q_t = _compute_critic_loss( - sampled_q_t_all, q_tm1_all, r_t_all, transitions.discount, - self._discount, self._num_samples, self._num_critic_heads) - - # Add sampled Q-values for objectives with defined qvalue_fn - sampled_q_t_k = [sampled_q_t] - for objective in self._qvalue_objectives: - sampled_q_t_k.append(tf.expand_dims(tf.stop_gradient( - objective.qvalue_fn(sampled_actions, sampled_q_t)), axis=-1)) - sampled_q_t_k = tf.concat(sampled_q_t_k, axis=-1) # [N, B, K] - - # Compute MPO policy loss. 
- policy_loss, policy_stats = self._policy_loss_module( - online_action_distribution=online_action_distribution, - target_action_distribution=target_action_distribution, - actions=sampled_actions, - q_values=sampled_q_t_k) - - # For clarity, explicitly define which variables are trained by which loss. - critic_trainable_variables = ( - # In this agent, the critic loss trains the observation network. - self._observation_network.trainable_variables + - self._critic_network.trainable_variables) - policy_trainable_variables = self._policy_network.trainable_variables - # The following are the MPO dual variables, stored in the loss module. - dual_trainable_variables = self._policy_loss_module.trainable_variables - - # Compute gradients. - critic_gradients = tape.gradient(critic_loss, critic_trainable_variables) - policy_gradients, dual_gradients = tape.gradient( - policy_loss, (policy_trainable_variables, dual_trainable_variables)) - - # Delete the tape manually because of the persistent=True flag. - del tape - - # Maybe clip gradients. - if self._clipping: - policy_gradients = tuple(tf.clip_by_global_norm(policy_gradients, 40.)[0]) - critic_gradients = tuple(tf.clip_by_global_norm(critic_gradients, 40.)[0]) - - # Apply gradients. - self._critic_optimizer.apply(critic_gradients, critic_trainable_variables) - self._policy_optimizer.apply(policy_gradients, policy_trainable_variables) - self._dual_optimizer.apply(dual_gradients, dual_trainable_variables) - - # Losses to track. - fetches = { - 'critic_loss': critic_loss, - 'policy_loss': policy_loss, - } - fetches.update(policy_stats) # Log MPO stats. - fetches.update(reward_stats) # Log reward stats. - - return fetches - - def step(self): - # Run the learning step. - fetches = self._step() - - # Compute elapsed time. - timestamp = time.time() - elapsed_time = timestamp - self._timestamp if self._timestamp else 0 - self._timestamp = timestamp - - # Update our counts and record it. - counts = self._counter.increment(steps=1, walltime=elapsed_time) - fetches.update(counts) - - # Checkpoint and attempt to write the logs. - if self._checkpointer is not None: - self._checkpointer.save() - if self._snapshotter is not None: - self._snapshotter.save() - self._logger.write(fetches) - - def get_variables(self, names: List[str]) -> List[List[np.ndarray]]: - return [tf2_utils.to_numpy(self._variables[name]) for name in names] + def __init__( + self, + reward_objectives: Sequence[RewardObjective], + qvalue_objectives: Sequence[QValueObjective], + policy_network: snt.Module, + critic_network: snt.Module, + target_policy_network: snt.Module, + target_critic_network: snt.Module, + discount: float, + num_samples: int, + target_policy_update_period: int, + target_critic_update_period: int, + dataset: tf.data.Dataset, + observation_network: types.TensorTransformation = tf.identity, + target_observation_network: types.TensorTransformation = tf.identity, + policy_loss_module: Optional[losses.MultiObjectiveMPO] = None, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + dual_optimizer: Optional[snt.Optimizer] = None, + clipping: bool = True, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = True, + ): + + # Store online and target networks. 
+ self._policy_network = policy_network + self._critic_network = critic_network + self._target_policy_network = target_policy_network + self._target_critic_network = target_critic_network + + # Make sure observation networks are snt.Module's so they have variables. + self._observation_network = tf2_utils.to_sonnet_module(observation_network) + self._target_observation_network = tf2_utils.to_sonnet_module( + target_observation_network + ) + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger("learner") + + # Other learner parameters. + self._discount = discount + self._num_samples = num_samples + self._clipping = clipping + + # Necessary to track when to update target networks. + self._num_steps = tf.Variable(0, dtype=tf.int32) + self._target_policy_update_period = target_policy_update_period + self._target_critic_update_period = target_critic_update_period + + # Batch dataset and create iterator. + # TODO(b/155086959): Fix type stubs and remove. + self._iterator = iter(dataset) # pytype: disable=wrong-arg-types + + # Store objectives + self._reward_objectives = reward_objectives + self._qvalue_objectives = qvalue_objectives + if self._qvalue_objectives is None: + self._qvalue_objectives = [] + self._num_critic_heads = len(self._reward_objectives) # C + self._objective_names = [x.name for x in self._reward_objectives] + [ + x.name for x in self._qvalue_objectives + ] + + self._policy_loss_module = policy_loss_module or losses.MultiObjectiveMPO( + epsilons=[ + losses.KLConstraint(name, _DEFAULT_EPSILON) + for name in self._objective_names + ], + epsilon_mean=_DEFAULT_EPSILON_MEAN, + epsilon_stddev=_DEFAULT_EPSILON_STDDEV, + init_log_temperature=_DEFAULT_INIT_LOG_TEMPERATURE, + init_log_alpha_mean=_DEFAULT_INIT_LOG_ALPHA_MEAN, + init_log_alpha_stddev=_DEFAULT_INIT_LOG_ALPHA_STDDEV, + ) + + # Check that ordering of objectives matches the policy_loss_module's + if self._objective_names != list(self._policy_loss_module.objective_names): + raise ValueError( + "Agent's ordering of objectives doesn't match " + "the policy loss module's ordering of epsilons." + ) + + # Create the optimizers. + self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) + self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) + self._dual_optimizer = dual_optimizer or snt.optimizers.Adam(1e-2) + + # Expose the variables. + policy_network_to_expose = snt.Sequential( + [self._target_observation_network, self._target_policy_network] + ) + self._variables = { + "critic": self._target_critic_network.variables, + "policy": policy_network_to_expose.variables, + } + + # Create a checkpointer and snapshotter object. 
+ self._checkpointer = None + self._snapshotter = None + + if checkpoint: + self._checkpointer = tf2_savers.Checkpointer( + subdirectory="mompo_learner", + objects_to_save={ + "counter": self._counter, + "policy": self._policy_network, + "critic": self._critic_network, + "observation": self._observation_network, + "target_policy": self._target_policy_network, + "target_critic": self._target_critic_network, + "target_observation": self._target_observation_network, + "policy_optimizer": self._policy_optimizer, + "critic_optimizer": self._critic_optimizer, + "dual_optimizer": self._dual_optimizer, + "policy_loss_module": self._policy_loss_module, + "num_steps": self._num_steps, + }, + ) + + self._snapshotter = tf2_savers.Snapshotter( + objects_to_save={ + "policy": snt.Sequential( + [self._target_observation_network, self._target_policy_network] + ), + } + ) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp: float = None + + @tf.function + def _step(self) -> types.NestedTensor: + # Update target network. + online_policy_variables = self._policy_network.variables + target_policy_variables = self._target_policy_network.variables + online_critic_variables = ( + *self._observation_network.variables, + *self._critic_network.variables, + ) + target_critic_variables = ( + *self._target_observation_network.variables, + *self._target_critic_network.variables, + ) + + # Make online policy -> target policy network update ops. + if tf.math.mod(self._num_steps, self._target_policy_update_period) == 0: + for src, dest in zip(online_policy_variables, target_policy_variables): + dest.assign(src) + # Make online critic -> target critic network update ops. + if tf.math.mod(self._num_steps, self._target_critic_update_period) == 0: + for src, dest in zip(online_critic_variables, target_critic_variables): + dest.assign(src) + + self._num_steps.assign_add(1) + + # Get data from replay (dropping extras if any). Note there is no + # extra data here because we do not insert any into Reverb. + inputs = next(self._iterator) + transitions: types.Transition = inputs.data + + with tf.GradientTape(persistent=True) as tape: + # Maybe transform the observation before feeding into policy and critic. + # Transforming the observations this way at the start of the learning + # step effectively means that the policy and critic share observation + # network weights. + o_tm1 = self._observation_network(transitions.observation) + # This stop_gradient prevents gradients to propagate into the target + # observation network. In addition, since the online policy network is + # evaluated at o_t, this also means the policy loss does not influence + # the observation network training. + o_t = tf.stop_gradient( + self._target_observation_network(transitions.next_observation) + ) + + # Get online and target action distributions from policy networks. + online_action_distribution = self._policy_network(o_t) + target_action_distribution = self._target_policy_network(o_t) + + # Sample actions to evaluate policy; of size [N, B, ...]. + sampled_actions = target_action_distribution.sample(self._num_samples) + + # Tile embedded observations to feed into the target critic network. + # Note: this is more efficient than tiling before the embedding layer. + tiled_o_t = tf2_utils.tile_tensor(o_t, self._num_samples) # [N, B, ...] + + # Compute target-estimated distributional value of sampled actions at o_t. 
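# Illustrative aside (shape sketch; tf2_utils.tile_tensor is assumed to tile
# along a new leading sample axis): evaluating N sampled actions per batch
# element means tiling the embedding to [N, B, ...] and merging the two leading
# dimensions so the critic sees an ordinary batch of size N*B.
import tensorflow as tf

N, B, D, A = 4, 3, 8, 2
o_t = tf.random.normal([B, D])
sampled_actions = tf.random.normal([N, B, A])
tiled_o_t = tf.tile(o_t[None], [N, 1, 1])           # [N, B, D]
merged_o = tf.reshape(tiled_o_t, [N * B, D])        # like snt.merge_leading_dims(..., 2)
merged_a = tf.reshape(sampled_actions, [N * B, A])  # [N*B, A]
q = tf.random.normal([N * B, 1])                    # stand-in for the critic output
q = tf.reshape(q, [N, B])                           # back to [N, B] for averaging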
+ sampled_q_t_all = self._target_critic_network( + # Merge batch dimensions; to shape [N*B, ...]. + snt.merge_leading_dims(tiled_o_t, num_dims=2), + snt.merge_leading_dims(sampled_actions, num_dims=2), + ) + + # Compute online critic value distribution of a_tm1 in state o_tm1. + q_tm1_all = self._critic_network(o_tm1, transitions.action) + + # Compute rewards for objectives with defined reward_fn + reward_stats = {} + r_t_all = [] + for objective in self._reward_objectives: + r = objective.reward_fn(o_tm1, transitions.action, transitions.reward) + reward_stats["{}_reward".format(objective.name)] = tf.reduce_mean(r) + r_t_all.append(r) + r_t_all = tf.stack(r_t_all, axis=-1) + r_t_all.get_shape().assert_has_rank(2) # [B, C] + + if isinstance(sampled_q_t_all, list): # Distributional critics + critic_loss, sampled_q_t = _compute_distributional_critic_loss( + sampled_q_t_all, + q_tm1_all, + r_t_all, + transitions.discount, + self._discount, + self._num_samples, + ) + else: + critic_loss, sampled_q_t = _compute_critic_loss( + sampled_q_t_all, + q_tm1_all, + r_t_all, + transitions.discount, + self._discount, + self._num_samples, + self._num_critic_heads, + ) + + # Add sampled Q-values for objectives with defined qvalue_fn + sampled_q_t_k = [sampled_q_t] + for objective in self._qvalue_objectives: + sampled_q_t_k.append( + tf.expand_dims( + tf.stop_gradient( + objective.qvalue_fn(sampled_actions, sampled_q_t) + ), + axis=-1, + ) + ) + sampled_q_t_k = tf.concat(sampled_q_t_k, axis=-1) # [N, B, K] + + # Compute MPO policy loss. + policy_loss, policy_stats = self._policy_loss_module( + online_action_distribution=online_action_distribution, + target_action_distribution=target_action_distribution, + actions=sampled_actions, + q_values=sampled_q_t_k, + ) + + # For clarity, explicitly define which variables are trained by which loss. + critic_trainable_variables = ( + # In this agent, the critic loss trains the observation network. + self._observation_network.trainable_variables + + self._critic_network.trainable_variables + ) + policy_trainable_variables = self._policy_network.trainable_variables + # The following are the MPO dual variables, stored in the loss module. + dual_trainable_variables = self._policy_loss_module.trainable_variables + + # Compute gradients. + critic_gradients = tape.gradient(critic_loss, critic_trainable_variables) + policy_gradients, dual_gradients = tape.gradient( + policy_loss, (policy_trainable_variables, dual_trainable_variables) + ) + + # Delete the tape manually because of the persistent=True flag. + del tape + + # Maybe clip gradients. + if self._clipping: + policy_gradients = tuple(tf.clip_by_global_norm(policy_gradients, 40.0)[0]) + critic_gradients = tuple(tf.clip_by_global_norm(critic_gradients, 40.0)[0]) + + # Apply gradients. + self._critic_optimizer.apply(critic_gradients, critic_trainable_variables) + self._policy_optimizer.apply(policy_gradients, policy_trainable_variables) + self._dual_optimizer.apply(dual_gradients, dual_trainable_variables) + + # Losses to track. + fetches = { + "critic_loss": critic_loss, + "policy_loss": policy_loss, + } + fetches.update(policy_stats) # Log MPO stats. + fetches.update(reward_stats) # Log reward stats. + + return fetches + + def step(self): + # Run the learning step. + fetches = self._step() + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Update our counts and record it. 
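# Illustrative aside: tf.clip_by_global_norm returns a (clipped_list,
# global_norm) pair, which is why the gradient-clipping code above indexes [0]
# before converting the clipped gradients back to a tuple.
import tensorflow as tf

grads = [tf.constant([3.0, 4.0]), tf.constant([0.0, 12.0])]  # global norm = 13
clipped, global_norm = tf.clip_by_global_norm(grads, 40.0)
# 13 < 40, so these gradients pass through unchanged; larger gradients would be
# rescaled by 40 / global_norm.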
+ counts = self._counter.increment(steps=1, walltime=elapsed_time) + fetches.update(counts) + + # Checkpoint and attempt to write the logs. + if self._checkpointer is not None: + self._checkpointer.save() + if self._snapshotter is not None: + self._snapshotter.save() + self._logger.write(fetches) + + def get_variables(self, names: List[str]) -> List[List[np.ndarray]]: + return [tf2_utils.to_numpy(self._variables[name]) for name in names] def _compute_distributional_critic_loss( @@ -382,41 +410,45 @@ def _compute_distributional_critic_loss( r_t_all: tf.Tensor, d_t: tf.Tensor, discount: float, - num_samples: int): - """Compute loss and sampled Q-values for distributional critics.""" - # Compute average logits by first reshaping them and normalizing them - # across atoms. - batch_size = r_t_all.get_shape()[0] - # Cast the additional discount to match the environment discount dtype. - discount = tf.cast(discount, dtype=d_t.dtype) - critic_losses = [] - sampled_q_ts = [] - for idx, (sampled_q_t_distributions, q_tm1_distribution) in enumerate( - zip(sampled_q_t_all, q_tm1_all)): - # Compute loss for distributional critic for objective c - sampled_logits = tf.reshape( - sampled_q_t_distributions.logits, - [num_samples, batch_size, -1]) # [N, B, A] - sampled_logprobs = tf.math.log_softmax(sampled_logits, axis=-1) - averaged_logits = tf.reduce_logsumexp(sampled_logprobs, axis=0) - - # Construct the expected distributional value for bootstrapping. - q_t_distribution = networks.DiscreteValuedDistribution( - values=sampled_q_t_distributions.values, logits=averaged_logits) - - # Compute critic distributional loss. - critic_loss = losses.categorical( - q_tm1_distribution, r_t_all[:, idx], discount * d_t, - q_t_distribution) - critic_losses.append(tf.reduce_mean(critic_loss)) - - # Compute Q-values of sampled actions and reshape to [N, B]. - sampled_q_ts.append(tf.reshape( - sampled_q_t_distributions.mean(), (num_samples, -1))) - - critic_loss = tf.reduce_mean(critic_losses) - sampled_q_t = tf.stack(sampled_q_ts, axis=-1) # [N, B, C] - return critic_loss, sampled_q_t + num_samples: int, +): + """Compute loss and sampled Q-values for distributional critics.""" + # Compute average logits by first reshaping them and normalizing them + # across atoms. + batch_size = r_t_all.get_shape()[0] + # Cast the additional discount to match the environment discount dtype. + discount = tf.cast(discount, dtype=d_t.dtype) + critic_losses = [] + sampled_q_ts = [] + for idx, (sampled_q_t_distributions, q_tm1_distribution) in enumerate( + zip(sampled_q_t_all, q_tm1_all) + ): + # Compute loss for distributional critic for objective c + sampled_logits = tf.reshape( + sampled_q_t_distributions.logits, [num_samples, batch_size, -1] + ) # [N, B, A] + sampled_logprobs = tf.math.log_softmax(sampled_logits, axis=-1) + averaged_logits = tf.reduce_logsumexp(sampled_logprobs, axis=0) + + # Construct the expected distributional value for bootstrapping. + q_t_distribution = networks.DiscreteValuedDistribution( + values=sampled_q_t_distributions.values, logits=averaged_logits + ) + + # Compute critic distributional loss. + critic_loss = losses.categorical( + q_tm1_distribution, r_t_all[:, idx], discount * d_t, q_t_distribution + ) + critic_losses.append(tf.reduce_mean(critic_loss)) + + # Compute Q-values of sampled actions and reshape to [N, B]. 
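# Illustrative aside (small numerical check): the N sampled distributional
# critics are averaged in log-space. log_softmax gives per-sample
# log-probabilities and reduce_logsumexp over the sample axis yields
# log(sum_n p_n), which is a valid set of logits for the equally-weighted
# mixture because softmax is invariant to the constant log(N) offset.
import tensorflow as tf

N, B, A = 4, 3, 51                                  # samples, batch, atoms
logits = tf.random.normal([N, B, A])
logprobs = tf.math.log_softmax(logits, axis=-1)     # log p_n(a)
averaged_logits = tf.reduce_logsumexp(logprobs, axis=0)
mixture = tf.nn.softmax(averaged_logits, axis=-1)
direct_mean = tf.reduce_mean(tf.nn.softmax(logits, axis=-1), axis=0)
tf.debugging.assert_near(mixture, direct_mean, atol=1e-5)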
+ sampled_q_ts.append( + tf.reshape(sampled_q_t_distributions.mean(), (num_samples, -1)) + ) + + critic_loss = tf.reduce_mean(critic_losses) + sampled_q_t = tf.stack(sampled_q_ts, axis=-1) # [N, B, C] + return critic_loss, sampled_q_t def _compute_critic_loss( @@ -426,29 +458,30 @@ def _compute_critic_loss( d_t: tf.Tensor, discount: float, num_samples: int, - num_critic_heads: int): - """Compute loss and sampled Q-values for (non-distributional) critics.""" - # Reshape Q-value samples back to original batch dimensions and average - # them to compute the TD-learning bootstrap target. - batch_size = r_t_all.get_shape()[0] - sampled_q_t = tf.reshape( - sampled_q_t_all, - (num_samples, batch_size, num_critic_heads)) # [N,B,C] - q_t = tf.reduce_mean(sampled_q_t, axis=0) # [B, C] - - # Flatten q_t and q_tm1; necessary for trfl.td_learning - q_t = tf.reshape(q_t, [-1]) # [B*C] - q_tm1 = tf.reshape(q_tm1_all, [-1]) # [B*C] - - # Flatten r_t_all; necessary for trfl.td_learning - r_t_all = tf.reshape(r_t_all, [-1]) # [B*C] - - # Broadcast and then flatten d_t, to match shape of q_t and q_tm1 - d_t = tf.tile(d_t, [num_critic_heads]) # [B*C] - # Cast the additional discount to match the environment discount dtype. - discount = tf.cast(discount, dtype=d_t.dtype) - - # Critic loss. - critic_loss = trfl.td_learning(q_tm1, r_t_all, discount * d_t, q_t).loss - critic_loss = tf.reduce_mean(critic_loss) - return critic_loss, sampled_q_t + num_critic_heads: int, +): + """Compute loss and sampled Q-values for (non-distributional) critics.""" + # Reshape Q-value samples back to original batch dimensions and average + # them to compute the TD-learning bootstrap target. + batch_size = r_t_all.get_shape()[0] + sampled_q_t = tf.reshape( + sampled_q_t_all, (num_samples, batch_size, num_critic_heads) + ) # [N,B,C] + q_t = tf.reduce_mean(sampled_q_t, axis=0) # [B, C] + + # Flatten q_t and q_tm1; necessary for trfl.td_learning + q_t = tf.reshape(q_t, [-1]) # [B*C] + q_tm1 = tf.reshape(q_tm1_all, [-1]) # [B*C] + + # Flatten r_t_all; necessary for trfl.td_learning + r_t_all = tf.reshape(r_t_all, [-1]) # [B*C] + + # Broadcast and then flatten d_t, to match shape of q_t and q_tm1 + d_t = tf.tile(d_t, [num_critic_heads]) # [B*C] + # Cast the additional discount to match the environment discount dtype. + discount = tf.cast(discount, dtype=d_t.dtype) + + # Critic loss. + critic_loss = trfl.td_learning(q_tm1, r_t_all, discount * d_t, q_t).loss + critic_loss = tf.reduce_mean(critic_loss) + return critic_loss, sampled_q_t diff --git a/acme/agents/tf/mpo/agent.py b/acme/agents/tf/mpo/agent.py index 82f60392f9..7b17d6b253 100644 --- a/acme/agents/tf/mpo/agent.py +++ b/acme/agents/tf/mpo/agent.py @@ -17,24 +17,22 @@ import copy from typing import Optional -from acme import datasets -from acme import specs -from acme import types +import reverb +import sonnet as snt +import tensorflow as tf + +from acme import datasets, specs, types from acme.adders import reverb as adders from acme.agents import agent from acme.agents.tf import actors from acme.agents.tf.mpo import learning from acme.tf import networks from acme.tf import utils as tf2_utils -from acme.utils import counting -from acme.utils import loggers -import reverb -import sonnet as snt -import tensorflow as tf +from acme.utils import counting, loggers class MPO(agent.Agent): - """MPO Agent. + """MPO Agent. This implements a single-process MPO agent. 
This is an actor-critic algorithm that generates data via a behavior policy, inserts N-step transitions into @@ -43,33 +41,33 @@ class MPO(agent.Agent): itself from the DPG agent by using MPO to learn a stochastic policy. """ - def __init__( - self, - environment_spec: specs.EnvironmentSpec, - policy_network: snt.Module, - critic_network: snt.Module, - observation_network: types.TensorTransformation = tf.identity, - discount: float = 0.99, - batch_size: int = 256, - prefetch_size: int = 4, - target_policy_update_period: int = 100, - target_critic_update_period: int = 100, - min_replay_size: int = 1000, - max_replay_size: int = 1000000, - samples_per_insert: float = 32.0, - policy_loss_module: Optional[snt.Module] = None, - policy_optimizer: Optional[snt.Optimizer] = None, - critic_optimizer: Optional[snt.Optimizer] = None, - n_step: int = 5, - num_samples: int = 20, - clipping: bool = True, - logger: Optional[loggers.Logger] = None, - counter: Optional[counting.Counter] = None, - checkpoint: bool = True, - save_directory: str = '~/acme', - replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE, - ): - """Initialize the agent. + def __init__( + self, + environment_spec: specs.EnvironmentSpec, + policy_network: snt.Module, + critic_network: snt.Module, + observation_network: types.TensorTransformation = tf.identity, + discount: float = 0.99, + batch_size: int = 256, + prefetch_size: int = 4, + target_policy_update_period: int = 100, + target_critic_update_period: int = 100, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: float = 32.0, + policy_loss_module: Optional[snt.Module] = None, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + n_step: int = 5, + num_samples: int = 20, + clipping: bool = True, + logger: Optional[loggers.Logger] = None, + counter: Optional[counting.Counter] = None, + checkpoint: bool = True, + save_directory: str = "~/acme", + replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE, + ): + """Initialize the agent. Args: environment_spec: description of the actions, observations, etc. @@ -105,87 +103,89 @@ def __init__( replay_table_name: string indicating what name to give the replay table. """ - # Create a replay server to add data to. - replay_table = reverb.Table( - name=adders.DEFAULT_PRIORITY_TABLE, - sampler=reverb.selectors.Uniform(), - remover=reverb.selectors.Fifo(), - max_size=max_replay_size, - rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1), - signature=adders.NStepTransitionAdder.signature(environment_spec)) - self._server = reverb.Server([replay_table], port=None) - - # The adder is used to insert observations into replay. - address = f'localhost:{self._server.port}' - adder = adders.NStepTransitionAdder( - client=reverb.Client(address), n_step=n_step, discount=discount) - - # The dataset object to learn from. - dataset = datasets.make_reverb_dataset( - table=replay_table_name, - server_address=address, - batch_size=batch_size, - prefetch_size=prefetch_size) - - # Make sure observation network is a Sonnet Module. - observation_network = tf2_utils.to_sonnet_module(observation_network) - - # Create target networks before creating online/target network variables. - target_policy_network = copy.deepcopy(policy_network) - target_critic_network = copy.deepcopy(critic_network) - target_observation_network = copy.deepcopy(observation_network) - - # Get observation and action specs. 
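# Illustrative aside (the NStepTransitionAdder is assumed to accumulate rewards
# in the standard way): an n-step transition replaces the single-step reward
# with the discounted n-step return and scales the bootstrap discount by
# gamma**n, e.g. for n_step=3 and discount=0.99:
discount = 0.99
rewards = [0.1, 0.0, 1.0]
n_step_return = sum(discount**i * r for i, r in enumerate(rewards))  # 0.1 + 0.9801
total_discount = discount ** len(rewards)                            # ~0.9703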
- act_spec = environment_spec.actions - obs_spec = environment_spec.observations - emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) - - # Create the behavior policy. - behavior_network = snt.Sequential([ - observation_network, - policy_network, - networks.StochasticSamplingHead(), - ]) - - # Create variables. - tf2_utils.create_variables(policy_network, [emb_spec]) - tf2_utils.create_variables(critic_network, [emb_spec, act_spec]) - tf2_utils.create_variables(target_policy_network, [emb_spec]) - tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec]) - tf2_utils.create_variables(target_observation_network, [obs_spec]) - - # Create the actor which defines how we take actions. - actor = actors.FeedForwardActor( - policy_network=behavior_network, adder=adder) - - # Create optimizers. - policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) - critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) - - # The learner updates the parameters (and initializes them). - learner = learning.MPOLearner( - policy_network=policy_network, - critic_network=critic_network, - observation_network=observation_network, - target_policy_network=target_policy_network, - target_critic_network=target_critic_network, - target_observation_network=target_observation_network, - policy_loss_module=policy_loss_module, - policy_optimizer=policy_optimizer, - critic_optimizer=critic_optimizer, - clipping=clipping, - discount=discount, - num_samples=num_samples, - target_policy_update_period=target_policy_update_period, - target_critic_update_period=target_critic_update_period, - dataset=dataset, - logger=logger, - counter=counter, - checkpoint=checkpoint, - save_directory=save_directory) - - super().__init__( - actor=actor, - learner=learner, - min_observations=max(batch_size, min_replay_size), - observations_per_step=float(batch_size) / samples_per_insert) + # Create a replay server to add data to. + replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=max_replay_size, + rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1), + signature=adders.NStepTransitionAdder.signature(environment_spec), + ) + self._server = reverb.Server([replay_table], port=None) + + # The adder is used to insert observations into replay. + address = f"localhost:{self._server.port}" + adder = adders.NStepTransitionAdder( + client=reverb.Client(address), n_step=n_step, discount=discount + ) + + # The dataset object to learn from. + dataset = datasets.make_reverb_dataset( + table=replay_table_name, + server_address=address, + batch_size=batch_size, + prefetch_size=prefetch_size, + ) + + # Make sure observation network is a Sonnet Module. + observation_network = tf2_utils.to_sonnet_module(observation_network) + + # Create target networks before creating online/target network variables. + target_policy_network = copy.deepcopy(policy_network) + target_critic_network = copy.deepcopy(critic_network) + target_observation_network = copy.deepcopy(observation_network) + + # Get observation and action specs. + act_spec = environment_spec.actions + obs_spec = environment_spec.observations + emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) + + # Create the behavior policy. + behavior_network = snt.Sequential( + [observation_network, policy_network, networks.StochasticSamplingHead(),] + ) + + # Create variables. 
+ tf2_utils.create_variables(policy_network, [emb_spec]) + tf2_utils.create_variables(critic_network, [emb_spec, act_spec]) + tf2_utils.create_variables(target_policy_network, [emb_spec]) + tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec]) + tf2_utils.create_variables(target_observation_network, [obs_spec]) + + # Create the actor which defines how we take actions. + actor = actors.FeedForwardActor(policy_network=behavior_network, adder=adder) + + # Create optimizers. + policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) + critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) + + # The learner updates the parameters (and initializes them). + learner = learning.MPOLearner( + policy_network=policy_network, + critic_network=critic_network, + observation_network=observation_network, + target_policy_network=target_policy_network, + target_critic_network=target_critic_network, + target_observation_network=target_observation_network, + policy_loss_module=policy_loss_module, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + clipping=clipping, + discount=discount, + num_samples=num_samples, + target_policy_update_period=target_policy_update_period, + target_critic_update_period=target_critic_update_period, + dataset=dataset, + logger=logger, + counter=counter, + checkpoint=checkpoint, + save_directory=save_directory, + ) + + super().__init__( + actor=actor, + learner=learner, + min_observations=max(batch_size, min_replay_size), + observations_per_step=float(batch_size) / samples_per_insert, + ) diff --git a/acme/agents/tf/mpo/agent_distributed.py b/acme/agents/tf/mpo/agent_distributed.py index 6e7799c756..82ecdd6b76 100644 --- a/acme/agents/tf/mpo/agent_distributed.py +++ b/acme/agents/tf/mpo/agent_distributed.py @@ -16,9 +16,14 @@ from typing import Callable, Dict, Optional +import dm_env +import launchpad as lp +import reverb +import sonnet as snt +import tensorflow as tf + import acme -from acme import datasets -from acme import specs +from acme import datasets, specs from acme.adders import reverb as adders from acme.agents.tf import actors from acme.agents.tf.mpo import learning @@ -26,313 +31,312 @@ from acme.tf import savers as tf2_savers from acme.tf import utils as tf2_utils from acme.tf import variable_utils as tf2_variable_utils -from acme.utils import counting -from acme.utils import loggers -from acme.utils import lp_utils -import dm_env -import launchpad as lp -import reverb -import sonnet as snt -import tensorflow as tf +from acme.utils import counting, loggers, lp_utils class DistributedMPO: - """Program definition for MPO.""" - - def __init__( - self, - environment_factory: Callable[[bool], dm_env.Environment], - network_factory: Callable[[specs.BoundedArray], Dict[str, snt.Module]], - num_actors: int = 1, - num_caches: int = 0, - environment_spec: Optional[specs.EnvironmentSpec] = None, - batch_size: int = 256, - prefetch_size: int = 4, - min_replay_size: int = 1000, - max_replay_size: int = 1000000, - samples_per_insert: Optional[float] = 32.0, - n_step: int = 5, - num_samples: int = 20, - additional_discount: float = 0.99, - target_policy_update_period: int = 100, - target_critic_update_period: int = 100, - variable_update_period: int = 1000, - policy_loss_factory: Optional[Callable[[], snt.Module]] = None, - max_actor_steps: Optional[int] = None, - log_every: float = 10.0, - ): - - if environment_spec is None: - environment_spec = specs.make_environment_spec(environment_factory(False)) - - 
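# Illustrative aside (worked defaults for the single-process agent above): with
# batch_size=256 and samples_per_insert=32.0 the agent performs roughly one
# learner step for every 8 new environment transitions, and only starts
# learning once max(batch_size, min_replay_size) observations are stored.
batch_size = 256
samples_per_insert = 32.0
min_replay_size = 1000
observations_per_step = float(batch_size) / samples_per_insert  # 8.0
min_observations = max(batch_size, min_replay_size)             # 1000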
self._environment_factory = environment_factory - self._network_factory = network_factory - self._policy_loss_factory = policy_loss_factory - self._environment_spec = environment_spec - self._num_actors = num_actors - self._num_caches = num_caches - self._batch_size = batch_size - self._prefetch_size = prefetch_size - self._min_replay_size = min_replay_size - self._max_replay_size = max_replay_size - self._samples_per_insert = samples_per_insert - self._n_step = n_step - self._additional_discount = additional_discount - self._num_samples = num_samples - self._target_policy_update_period = target_policy_update_period - self._target_critic_update_period = target_critic_update_period - self._variable_update_period = variable_update_period - self._max_actor_steps = max_actor_steps - self._log_every = log_every - - def replay(self): - """The replay storage.""" - if self._samples_per_insert is not None: - # Create enough of an error buffer to give a 10% tolerance in rate. - samples_per_insert_tolerance = 0.1 * self._samples_per_insert - error_buffer = self._min_replay_size * samples_per_insert_tolerance - - limiter = reverb.rate_limiters.SampleToInsertRatio( - min_size_to_sample=self._min_replay_size, - samples_per_insert=self._samples_per_insert, - error_buffer=error_buffer) - else: - limiter = reverb.rate_limiters.MinSize( - min_size_to_sample=self._min_replay_size) - replay_table = reverb.Table( - name=adders.DEFAULT_PRIORITY_TABLE, - sampler=reverb.selectors.Uniform(), - remover=reverb.selectors.Fifo(), - max_size=self._max_replay_size, - rate_limiter=limiter, - signature=adders.NStepTransitionAdder.signature( - self._environment_spec)) - return [replay_table] - - def counter(self): - return tf2_savers.CheckpointingRunner(counting.Counter(), - time_delta_minutes=1, - subdirectory='counter') - - def coordinator(self, counter: counting.Counter, max_actor_steps: int): - return lp_utils.StepsLimiter(counter, max_actor_steps) - - def learner( - self, - replay: reverb.Client, - counter: counting.Counter, - ): - """The Learning part of the agent.""" - - act_spec = self._environment_spec.actions - obs_spec = self._environment_spec.observations - - # Create online and target networks. - online_networks = self._network_factory(act_spec) - target_networks = self._network_factory(act_spec) - - # Make sure observation networks are Sonnet Modules. - observation_network = online_networks.get('observation', tf.identity) - observation_network = tf2_utils.to_sonnet_module(observation_network) - online_networks['observation'] = observation_network - target_observation_network = target_networks.get('observation', tf.identity) - target_observation_network = tf2_utils.to_sonnet_module( - target_observation_network) - target_networks['observation'] = target_observation_network - - # Get embedding spec and create observation network variables. - emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) - - tf2_utils.create_variables(online_networks['policy'], [emb_spec]) - tf2_utils.create_variables(online_networks['critic'], [emb_spec, act_spec]) - tf2_utils.create_variables(target_networks['observation'], [obs_spec]) - tf2_utils.create_variables(target_networks['policy'], [emb_spec]) - tf2_utils.create_variables(target_networks['critic'], [emb_spec, act_spec]) - - # The dataset object to learn from. 
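# Illustrative aside (worked defaults): the 10% tolerance translates into a
# Reverb error buffer of min_replay_size * (0.1 * samples_per_insert) items,
# i.e. 1000 * 3.2 = 3200 with the constructor defaults, so sampling may run
# ahead of or behind the exact 32:1 sample-to-insert ratio by that much.
min_replay_size = 1000
samples_per_insert = 32.0
samples_per_insert_tolerance = 0.1 * samples_per_insert        # 3.2
error_buffer = min_replay_size * samples_per_insert_tolerance  # 3200.0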
- dataset = datasets.make_reverb_dataset( - server_address=replay.server_address) - dataset = dataset.batch(self._batch_size, drop_remainder=True) - dataset = dataset.prefetch(self._prefetch_size) - - # Create a counter and logger for bookkeeping steps and performance. - counter = counting.Counter(counter, 'learner') - logger = loggers.make_default_logger( - 'learner', time_delta=self._log_every, steps_key='learner_steps') - - # Create policy loss module if a factory is passed. - if self._policy_loss_factory: - policy_loss_module = self._policy_loss_factory() - else: - policy_loss_module = None - - # Return the learning agent. - return learning.MPOLearner( - policy_network=online_networks['policy'], - critic_network=online_networks['critic'], - observation_network=observation_network, - target_policy_network=target_networks['policy'], - target_critic_network=target_networks['critic'], - target_observation_network=target_observation_network, - discount=self._additional_discount, - num_samples=self._num_samples, - target_policy_update_period=self._target_policy_update_period, - target_critic_update_period=self._target_critic_update_period, - policy_loss_module=policy_loss_module, - dataset=dataset, - counter=counter, - logger=logger) - - def actor( - self, - replay: reverb.Client, - variable_source: acme.VariableSource, - counter: counting.Counter, - ) -> acme.EnvironmentLoop: - """The actor process.""" - - action_spec = self._environment_spec.actions - observation_spec = self._environment_spec.observations - - # Create environment and target networks to act with. - environment = self._environment_factory(False) - agent_networks = self._network_factory(action_spec) - - # Create a stochastic behavior policy. - behavior_modules = [ - agent_networks.get('observation', tf.identity), - agent_networks.get('policy'), - networks.StochasticSamplingHead() - ] - behavior_network = snt.Sequential(behavior_modules) - - # Ensure network variables are created. - tf2_utils.create_variables(behavior_network, [observation_spec]) - policy_variables = {'policy': behavior_network.variables} - - # Create the variable client responsible for keeping the actor up-to-date. - variable_client = tf2_variable_utils.VariableClient( - variable_source, - policy_variables, - update_period=self._variable_update_period) - - # Make sure not to use a random policy after checkpoint restoration by - # assigning variables before running the environment loop. - variable_client.update_and_wait() - - # Component to add things into replay. - adder = adders.NStepTransitionAdder( - client=replay, - n_step=self._n_step, - discount=self._additional_discount) - - # Create the agent. - actor = actors.FeedForwardActor( - policy_network=behavior_network, - adder=adder, - variable_client=variable_client) - - # Create logger and counter; actors will not spam bigtable. - counter = counting.Counter(counter, 'actor') - logger = loggers.make_default_logger( - 'actor', - save_data=False, - time_delta=self._log_every, - steps_key='actor_steps') - - # Create the run loop and return it. - return acme.EnvironmentLoop(environment, actor, counter, logger) - - def evaluator( - self, - variable_source: acme.VariableSource, - counter: counting.Counter, - ): - """The evaluation process.""" - - action_spec = self._environment_spec.actions - observation_spec = self._environment_spec.observations - - # Create environment and target networks to act with. 
- environment = self._environment_factory(True) - agent_networks = self._network_factory(action_spec) - - # Create a stochastic behavior policy. - evaluator_modules = [ - agent_networks.get('observation', tf.identity), - agent_networks.get('policy'), - networks.StochasticMeanHead(), - ] - - if isinstance(action_spec, specs.BoundedArray): - evaluator_modules += [networks.ClipToSpec(action_spec)] - evaluator_network = snt.Sequential(evaluator_modules) - - # Ensure network variables are created. - tf2_utils.create_variables(evaluator_network, [observation_spec]) - policy_variables = {'policy': evaluator_network.variables} - - # Create the variable client responsible for keeping the actor up-to-date. - variable_client = tf2_variable_utils.VariableClient( - variable_source, - policy_variables, - update_period=self._variable_update_period) - - # Make sure not to evaluate a random actor by assigning variables before - # running the environment loop. - variable_client.update_and_wait() - - # Create the agent. - evaluator = actors.FeedForwardActor( - policy_network=evaluator_network, variable_client=variable_client) - - # Create logger and counter. - counter = counting.Counter(counter, 'evaluator') - logger = loggers.make_default_logger( - 'evaluator', time_delta=self._log_every, steps_key='evaluator_steps') - - # Create the run loop and return it. - return acme.EnvironmentLoop(environment, evaluator, counter, logger) - - def build(self, name='mpo'): - """Build the distributed agent topology.""" - program = lp.Program(name=name) - - with program.group('replay'): - replay = program.add_node(lp.ReverbNode(self.replay)) - - with program.group('counter'): - counter = program.add_node(lp.CourierNode(self.counter)) - - if self._max_actor_steps: - _ = program.add_node( - lp.CourierNode(self.coordinator, counter, self._max_actor_steps)) - - with program.group('learner'): - learner = program.add_node( - lp.CourierNode(self.learner, replay, counter)) - - with program.group('evaluator'): - program.add_node( - lp.CourierNode(self.evaluator, learner, counter)) - - if not self._num_caches: - # Use our learner as a single variable source. - sources = [learner] - else: - with program.group('cacher'): - # Create a set of learner caches. - sources = [] - for _ in range(self._num_caches): - cacher = program.add_node( - lp.CacherNode( - learner, refresh_interval_ms=2000, stale_after_ms=4000)) - sources.append(cacher) - - with program.group('actor'): - # Add actors which pull round-robin from our variable sources. 
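# Illustrative aside: the round-robin assignment of actors to variable sources
# simply cycles through the available learner/cacher handles, e.g. with four
# actors and three sources:
sources = ["learner", "cacher_0", "cacher_1"]
num_actors = 4
assignment = {
    actor_id: sources[actor_id % len(sources)] for actor_id in range(num_actors)
}
# {0: 'learner', 1: 'cacher_0', 2: 'cacher_1', 3: 'learner'}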
- for actor_id in range(self._num_actors): - source = sources[actor_id % len(sources)] - program.add_node(lp.CourierNode(self.actor, replay, source, counter)) - - return program + """Program definition for MPO.""" + + def __init__( + self, + environment_factory: Callable[[bool], dm_env.Environment], + network_factory: Callable[[specs.BoundedArray], Dict[str, snt.Module]], + num_actors: int = 1, + num_caches: int = 0, + environment_spec: Optional[specs.EnvironmentSpec] = None, + batch_size: int = 256, + prefetch_size: int = 4, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: Optional[float] = 32.0, + n_step: int = 5, + num_samples: int = 20, + additional_discount: float = 0.99, + target_policy_update_period: int = 100, + target_critic_update_period: int = 100, + variable_update_period: int = 1000, + policy_loss_factory: Optional[Callable[[], snt.Module]] = None, + max_actor_steps: Optional[int] = None, + log_every: float = 10.0, + ): + + if environment_spec is None: + environment_spec = specs.make_environment_spec(environment_factory(False)) + + self._environment_factory = environment_factory + self._network_factory = network_factory + self._policy_loss_factory = policy_loss_factory + self._environment_spec = environment_spec + self._num_actors = num_actors + self._num_caches = num_caches + self._batch_size = batch_size + self._prefetch_size = prefetch_size + self._min_replay_size = min_replay_size + self._max_replay_size = max_replay_size + self._samples_per_insert = samples_per_insert + self._n_step = n_step + self._additional_discount = additional_discount + self._num_samples = num_samples + self._target_policy_update_period = target_policy_update_period + self._target_critic_update_period = target_critic_update_period + self._variable_update_period = variable_update_period + self._max_actor_steps = max_actor_steps + self._log_every = log_every + + def replay(self): + """The replay storage.""" + if self._samples_per_insert is not None: + # Create enough of an error buffer to give a 10% tolerance in rate. + samples_per_insert_tolerance = 0.1 * self._samples_per_insert + error_buffer = self._min_replay_size * samples_per_insert_tolerance + + limiter = reverb.rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._min_replay_size, + samples_per_insert=self._samples_per_insert, + error_buffer=error_buffer, + ) + else: + limiter = reverb.rate_limiters.MinSize( + min_size_to_sample=self._min_replay_size + ) + replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._max_replay_size, + rate_limiter=limiter, + signature=adders.NStepTransitionAdder.signature(self._environment_spec), + ) + return [replay_table] + + def counter(self): + return tf2_savers.CheckpointingRunner( + counting.Counter(), time_delta_minutes=1, subdirectory="counter" + ) + + def coordinator(self, counter: counting.Counter, max_actor_steps: int): + return lp_utils.StepsLimiter(counter, max_actor_steps) + + def learner( + self, replay: reverb.Client, counter: counting.Counter, + ): + """The Learning part of the agent.""" + + act_spec = self._environment_spec.actions + obs_spec = self._environment_spec.observations + + # Create online and target networks. + online_networks = self._network_factory(act_spec) + target_networks = self._network_factory(act_spec) + + # Make sure observation networks are Sonnet Modules. 
+ observation_network = online_networks.get("observation", tf.identity) + observation_network = tf2_utils.to_sonnet_module(observation_network) + online_networks["observation"] = observation_network + target_observation_network = target_networks.get("observation", tf.identity) + target_observation_network = tf2_utils.to_sonnet_module( + target_observation_network + ) + target_networks["observation"] = target_observation_network + + # Get embedding spec and create observation network variables. + emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) + + tf2_utils.create_variables(online_networks["policy"], [emb_spec]) + tf2_utils.create_variables(online_networks["critic"], [emb_spec, act_spec]) + tf2_utils.create_variables(target_networks["observation"], [obs_spec]) + tf2_utils.create_variables(target_networks["policy"], [emb_spec]) + tf2_utils.create_variables(target_networks["critic"], [emb_spec, act_spec]) + + # The dataset object to learn from. + dataset = datasets.make_reverb_dataset(server_address=replay.server_address) + dataset = dataset.batch(self._batch_size, drop_remainder=True) + dataset = dataset.prefetch(self._prefetch_size) + + # Create a counter and logger for bookkeeping steps and performance. + counter = counting.Counter(counter, "learner") + logger = loggers.make_default_logger( + "learner", time_delta=self._log_every, steps_key="learner_steps" + ) + + # Create policy loss module if a factory is passed. + if self._policy_loss_factory: + policy_loss_module = self._policy_loss_factory() + else: + policy_loss_module = None + + # Return the learning agent. + return learning.MPOLearner( + policy_network=online_networks["policy"], + critic_network=online_networks["critic"], + observation_network=observation_network, + target_policy_network=target_networks["policy"], + target_critic_network=target_networks["critic"], + target_observation_network=target_observation_network, + discount=self._additional_discount, + num_samples=self._num_samples, + target_policy_update_period=self._target_policy_update_period, + target_critic_update_period=self._target_critic_update_period, + policy_loss_module=policy_loss_module, + dataset=dataset, + counter=counter, + logger=logger, + ) + + def actor( + self, + replay: reverb.Client, + variable_source: acme.VariableSource, + counter: counting.Counter, + ) -> acme.EnvironmentLoop: + """The actor process.""" + + action_spec = self._environment_spec.actions + observation_spec = self._environment_spec.observations + + # Create environment and target networks to act with. + environment = self._environment_factory(False) + agent_networks = self._network_factory(action_spec) + + # Create a stochastic behavior policy. + behavior_modules = [ + agent_networks.get("observation", tf.identity), + agent_networks.get("policy"), + networks.StochasticSamplingHead(), + ] + behavior_network = snt.Sequential(behavior_modules) + + # Ensure network variables are created. + tf2_utils.create_variables(behavior_network, [observation_spec]) + policy_variables = {"policy": behavior_network.variables} + + # Create the variable client responsible for keeping the actor up-to-date. + variable_client = tf2_variable_utils.VariableClient( + variable_source, + policy_variables, + update_period=self._variable_update_period, + ) + + # Make sure not to use a random policy after checkpoint restoration by + # assigning variables before running the environment loop. + variable_client.update_and_wait() + + # Component to add things into replay. 
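# Illustrative aside (toy stand-in for the Reverb dataset): the learner-side
# pipeline above just batches and prefetches whatever the replay client yields.
import tensorflow as tf

dataset = tf.data.Dataset.range(10).batch(4, drop_remainder=True).prefetch(2)
print(list(dataset.as_numpy_iterator()))  # [array([0, 1, 2, 3]), array([4, 5, 6, 7])]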
+ adder = adders.NStepTransitionAdder( + client=replay, n_step=self._n_step, discount=self._additional_discount + ) + + # Create the agent. + actor = actors.FeedForwardActor( + policy_network=behavior_network, + adder=adder, + variable_client=variable_client, + ) + + # Create logger and counter; actors will not spam bigtable. + counter = counting.Counter(counter, "actor") + logger = loggers.make_default_logger( + "actor", + save_data=False, + time_delta=self._log_every, + steps_key="actor_steps", + ) + + # Create the run loop and return it. + return acme.EnvironmentLoop(environment, actor, counter, logger) + + def evaluator( + self, variable_source: acme.VariableSource, counter: counting.Counter, + ): + """The evaluation process.""" + + action_spec = self._environment_spec.actions + observation_spec = self._environment_spec.observations + + # Create environment and target networks to act with. + environment = self._environment_factory(True) + agent_networks = self._network_factory(action_spec) + + # Create a stochastic behavior policy. + evaluator_modules = [ + agent_networks.get("observation", tf.identity), + agent_networks.get("policy"), + networks.StochasticMeanHead(), + ] + + if isinstance(action_spec, specs.BoundedArray): + evaluator_modules += [networks.ClipToSpec(action_spec)] + evaluator_network = snt.Sequential(evaluator_modules) + + # Ensure network variables are created. + tf2_utils.create_variables(evaluator_network, [observation_spec]) + policy_variables = {"policy": evaluator_network.variables} + + # Create the variable client responsible for keeping the actor up-to-date. + variable_client = tf2_variable_utils.VariableClient( + variable_source, + policy_variables, + update_period=self._variable_update_period, + ) + + # Make sure not to evaluate a random actor by assigning variables before + # running the environment loop. + variable_client.update_and_wait() + + # Create the agent. + evaluator = actors.FeedForwardActor( + policy_network=evaluator_network, variable_client=variable_client + ) + + # Create logger and counter. + counter = counting.Counter(counter, "evaluator") + logger = loggers.make_default_logger( + "evaluator", time_delta=self._log_every, steps_key="evaluator_steps" + ) + + # Create the run loop and return it. + return acme.EnvironmentLoop(environment, evaluator, counter, logger) + + def build(self, name="mpo"): + """Build the distributed agent topology.""" + program = lp.Program(name=name) + + with program.group("replay"): + replay = program.add_node(lp.ReverbNode(self.replay)) + + with program.group("counter"): + counter = program.add_node(lp.CourierNode(self.counter)) + + if self._max_actor_steps: + _ = program.add_node( + lp.CourierNode(self.coordinator, counter, self._max_actor_steps) + ) + + with program.group("learner"): + learner = program.add_node(lp.CourierNode(self.learner, replay, counter)) + + with program.group("evaluator"): + program.add_node(lp.CourierNode(self.evaluator, learner, counter)) + + if not self._num_caches: + # Use our learner as a single variable source. + sources = [learner] + else: + with program.group("cacher"): + # Create a set of learner caches. + sources = [] + for _ in range(self._num_caches): + cacher = program.add_node( + lp.CacherNode( + learner, refresh_interval_ms=2000, stale_after_ms=4000 + ) + ) + sources.append(cacher) + + with program.group("actor"): + # Add actors which pull round-robin from our variable sources. 
+ for actor_id in range(self._num_actors): + source = sources[actor_id % len(sources)] + program.add_node(lp.CourierNode(self.actor, replay, source, counter)) + + return program diff --git a/acme/agents/tf/mpo/agent_distributed_test.py b/acme/agents/tf/mpo/agent_distributed_test.py index 1bf1bda1f2..d2df498f67 100644 --- a/acme/agents/tf/mpo/agent_distributed_test.py +++ b/acme/agents/tf/mpo/agent_distributed_test.py @@ -16,17 +16,17 @@ from typing import Sequence +import launchpad as lp +import numpy as np +import sonnet as snt +from absl.testing import absltest + import acme from acme import specs from acme.agents.tf import mpo from acme.testing import fakes from acme.tf import networks from acme.tf import utils as tf2_utils -import launchpad as lp -import numpy as np -import sonnet as snt - -from absl.testing import absltest def make_networks( @@ -34,67 +34,71 @@ def make_networks( policy_layer_sizes: Sequence[int] = (50, 50), critic_layer_sizes: Sequence[int] = (50, 50), ): - """Creates networks used by the agent.""" - - num_dimensions = np.prod(action_spec.shape, dtype=int) - - observation_network = tf2_utils.batch_concat - policy_network = snt.Sequential([ - networks.LayerNormMLP(policy_layer_sizes, activate_final=True), - networks.MultivariateNormalDiagHead( - num_dimensions, - tanh_mean=True, - init_scale=0.3, - fixed_scale=True, - use_tfd_independent=False) - ]) - evaluator_network = snt.Sequential([ - observation_network, - policy_network, - networks.StochasticMeanHead(), - ]) - # The multiplexer concatenates the (maybe transformed) observations/actions. - multiplexer = networks.CriticMultiplexer( - action_network=networks.ClipToSpec(action_spec)) - critic_network = snt.Sequential([ - multiplexer, - networks.LayerNormMLP(critic_layer_sizes, activate_final=True), - networks.NearZeroInitializedLinear(1), - ]) - - return { - 'policy': policy_network, - 'critic': critic_network, - 'observation': observation_network, - 'evaluator': evaluator_network, - } + """Creates networks used by the agent.""" + + num_dimensions = np.prod(action_spec.shape, dtype=int) + + observation_network = tf2_utils.batch_concat + policy_network = snt.Sequential( + [ + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, + tanh_mean=True, + init_scale=0.3, + fixed_scale=True, + use_tfd_independent=False, + ), + ] + ) + evaluator_network = snt.Sequential( + [observation_network, policy_network, networks.StochasticMeanHead(),] + ) + # The multiplexer concatenates the (maybe transformed) observations/actions. 
+ multiplexer = networks.CriticMultiplexer( + action_network=networks.ClipToSpec(action_spec) + ) + critic_network = snt.Sequential( + [ + multiplexer, + networks.LayerNormMLP(critic_layer_sizes, activate_final=True), + networks.NearZeroInitializedLinear(1), + ] + ) + + return { + "policy": policy_network, + "critic": critic_network, + "observation": observation_network, + "evaluator": evaluator_network, + } class DistributedAgentTest(absltest.TestCase): - """Simple integration/smoke test for the distributed agent.""" + """Simple integration/smoke test for the distributed agent.""" - def test_agent(self): + def test_agent(self): - agent = mpo.DistributedMPO( - environment_factory=lambda x: fakes.ContinuousEnvironment(bounded=True), - network_factory=make_networks, - num_actors=2, - batch_size=32, - min_replay_size=32, - max_replay_size=1000, - ) - program = agent.build() + agent = mpo.DistributedMPO( + environment_factory=lambda x: fakes.ContinuousEnvironment(bounded=True), + network_factory=make_networks, + num_actors=2, + batch_size=32, + min_replay_size=32, + max_replay_size=1000, + ) + program = agent.build() - (learner_node,) = program.groups['learner'] - learner_node.disable_run() + (learner_node,) = program.groups["learner"] + learner_node.disable_run() - lp.launch(program, launch_type='test_mt') + lp.launch(program, launch_type="test_mt") - learner: acme.Learner = learner_node.create_handle().dereference() + learner: acme.Learner = learner_node.create_handle().dereference() - for _ in range(5): - learner.step() + for _ in range(5): + learner.step() -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/tf/mpo/agent_test.py b/acme/agents/tf/mpo/agent_test.py index 9a7634191f..92e779bd87 100644 --- a/acme/agents/tf/mpo/agent_test.py +++ b/acme/agents/tf/mpo/agent_test.py @@ -14,64 +14,65 @@ """Tests for the MPO agent.""" +import numpy as np +import sonnet as snt +from absl.testing import absltest + import acme from acme import specs from acme.agents.tf import mpo from acme.testing import fakes from acme.tf import networks -import numpy as np -import sonnet as snt - -from absl.testing import absltest def make_networks( - action_spec, - policy_layer_sizes=(10, 10), - critic_layer_sizes=(10, 10), + action_spec, policy_layer_sizes=(10, 10), critic_layer_sizes=(10, 10), ): - """Creates networks used by the agent.""" + """Creates networks used by the agent.""" - num_dimensions = np.prod(action_spec.shape, dtype=int) - critic_layer_sizes = list(critic_layer_sizes) + [1] + num_dimensions = np.prod(action_spec.shape, dtype=int) + critic_layer_sizes = list(critic_layer_sizes) + [1] - policy_network = snt.Sequential([ - networks.LayerNormMLP(policy_layer_sizes), - networks.MultivariateNormalDiagHead(num_dimensions) - ]) - critic_network = networks.CriticMultiplexer( - critic_network=networks.LayerNormMLP(critic_layer_sizes)) + policy_network = snt.Sequential( + [ + networks.LayerNormMLP(policy_layer_sizes), + networks.MultivariateNormalDiagHead(num_dimensions), + ] + ) + critic_network = networks.CriticMultiplexer( + critic_network=networks.LayerNormMLP(critic_layer_sizes) + ) - return { - 'policy': policy_network, - 'critic': critic_network, - } + return { + "policy": policy_network, + "critic": critic_network, + } class MPOTest(absltest.TestCase): - - def test_mpo(self): - # Create a fake environment to test with. 
- environment = fakes.ContinuousEnvironment(episode_length=10, bounded=False) - spec = specs.make_environment_spec(environment) - - # Create networks. - agent_networks = make_networks(spec.actions) - - # Construct the agent. - agent = mpo.MPO( - spec, - policy_network=agent_networks['policy'], - critic_network=agent_networks['critic'], - batch_size=10, - samples_per_insert=2, - min_replay_size=10) - - # Try running the environment loop. We have no assertions here because all - # we care about is that the agent runs without raising any errors. - loop = acme.EnvironmentLoop(environment, agent) - loop.run(num_episodes=2) - - -if __name__ == '__main__': - absltest.main() + def test_mpo(self): + # Create a fake environment to test with. + environment = fakes.ContinuousEnvironment(episode_length=10, bounded=False) + spec = specs.make_environment_spec(environment) + + # Create networks. + agent_networks = make_networks(spec.actions) + + # Construct the agent. + agent = mpo.MPO( + spec, + policy_network=agent_networks["policy"], + critic_network=agent_networks["critic"], + batch_size=10, + samples_per_insert=2, + min_replay_size=10, + ) + + # Try running the environment loop. We have no assertions here because all + # we care about is that the agent runs without raising any errors. + loop = acme.EnvironmentLoop(environment, agent) + loop.run(num_episodes=2) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/tf/mpo/learning.py b/acme/agents/tf/mpo/learning.py index b18ff459bf..be582f3d5c 100644 --- a/acme/agents/tf/mpo/learning.py +++ b/acme/agents/tf/mpo/learning.py @@ -17,271 +17,280 @@ import time from typing import List, Optional +import numpy as np +import sonnet as snt +import tensorflow as tf +import trfl + import acme from acme import types from acme.tf import losses from acme.tf import savers as tf2_savers from acme.tf import utils as tf2_utils -from acme.utils import counting -from acme.utils import loggers -import numpy as np -import sonnet as snt -import tensorflow as tf -import trfl +from acme.utils import counting, loggers class MPOLearner(acme.Learner): - """MPO learner.""" - - def __init__( - self, - policy_network: snt.Module, - critic_network: snt.Module, - target_policy_network: snt.Module, - target_critic_network: snt.Module, - discount: float, - num_samples: int, - target_policy_update_period: int, - target_critic_update_period: int, - dataset: tf.data.Dataset, - observation_network: types.TensorTransformation = tf.identity, - target_observation_network: types.TensorTransformation = tf.identity, - policy_loss_module: Optional[snt.Module] = None, - policy_optimizer: Optional[snt.Optimizer] = None, - critic_optimizer: Optional[snt.Optimizer] = None, - dual_optimizer: Optional[snt.Optimizer] = None, - clipping: bool = True, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - checkpoint: bool = True, - save_directory: str = '~/acme', - ): - - self._counter = counter or counting.Counter() - self._logger = logger or loggers.make_default_logger('learner') - self._discount = discount - self._num_samples = num_samples - self._clipping = clipping - - # Necessary to track when to update target networks. - self._num_steps = tf.Variable(0, dtype=tf.int32) - self._target_policy_update_period = target_policy_update_period - self._target_critic_update_period = target_critic_update_period - - # Batch dataset and create iterator. - # TODO(b/155086959): Fix type stubs and remove. 
- self._iterator = iter(dataset) # pytype: disable=wrong-arg-types - - # Store online and target networks. - self._policy_network = policy_network - self._critic_network = critic_network - self._target_policy_network = target_policy_network - self._target_critic_network = target_critic_network - - # Make sure observation networks are snt.Module's so they have variables. - self._observation_network = tf2_utils.to_sonnet_module(observation_network) - self._target_observation_network = tf2_utils.to_sonnet_module( - target_observation_network) - - self._policy_loss_module = policy_loss_module or losses.MPO( - epsilon=1e-1, - epsilon_penalty=1e-3, - epsilon_mean=2.5e-3, - epsilon_stddev=1e-6, - init_log_temperature=10., - init_log_alpha_mean=10., - init_log_alpha_stddev=1000.) - - # Create the optimizers. - self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) - self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) - self._dual_optimizer = dual_optimizer or snt.optimizers.Adam(1e-2) - - # Expose the variables. - policy_network_to_expose = snt.Sequential( - [self._target_observation_network, self._target_policy_network]) - self._variables = { - 'critic': self._target_critic_network.variables, - 'policy': policy_network_to_expose.variables, - } - - # Create a checkpointer and snapshotter object. - self._checkpointer = None - self._snapshotter = None - - if checkpoint: - self._checkpointer = tf2_savers.Checkpointer( - directory=save_directory, - subdirectory='mpo_learner', - objects_to_save={ - 'counter': self._counter, - 'policy': self._policy_network, - 'critic': self._critic_network, - 'observation_network': self._observation_network, - 'target_policy': self._target_policy_network, - 'target_critic': self._target_critic_network, - 'target_observation_network': self._target_observation_network, - 'policy_optimizer': self._policy_optimizer, - 'critic_optimizer': self._critic_optimizer, - 'dual_optimizer': self._dual_optimizer, - 'policy_loss_module': self._policy_loss_module, - 'num_steps': self._num_steps, - }) - - self._snapshotter = tf2_savers.Snapshotter( - directory=save_directory, - objects_to_save={ - 'policy': - snt.Sequential([ - self._target_observation_network, - self._target_policy_network - ]), - }) - - # Do not record timestamps until after the first learning step is done. - # This is to avoid including the time it takes for actors to come online and - # fill the replay buffer. - self._timestamp = None - - @tf.function - def _step(self) -> types.Nest: - # Update target network. - online_policy_variables = self._policy_network.variables - target_policy_variables = self._target_policy_network.variables - online_critic_variables = ( - *self._observation_network.variables, - *self._critic_network.variables, - ) - target_critic_variables = ( - *self._target_observation_network.variables, - *self._target_critic_network.variables, - ) - - # Make online policy -> target policy network update ops. - if tf.math.mod(self._num_steps, self._target_policy_update_period) == 0: - for src, dest in zip(online_policy_variables, target_policy_variables): - dest.assign(src) - # Make online critic -> target critic network update ops. - if tf.math.mod(self._num_steps, self._target_critic_update_period) == 0: - for src, dest in zip(online_critic_variables, target_critic_variables): - dest.assign(src) - - # Increment number of learner steps for periodic update bookkeeping. - self._num_steps.assign_add(1) - - # Get next batch of data. 
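# Illustrative aside (standalone version of the periodic target update used in
# both learners): a tf.Variable step counter plus a modulo check copies the
# online variables into the target variables every `period` learner steps.
import tensorflow as tf

num_steps = tf.Variable(0, dtype=tf.int32)
period = 100
online = [tf.Variable([1.0]), tf.Variable([2.0])]
target = [tf.Variable([0.0]), tf.Variable([0.0])]

if tf.math.mod(num_steps, period) == 0:
    for src, dest in zip(online, target):
        dest.assign(src)
num_steps.assign_add(1)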
- inputs = next(self._iterator) - - # Get data from replay (dropping extras if any). Note there is no - # extra data here because we do not insert any into Reverb. - transitions: types.Transition = inputs.data - - # Cast the additional discount to match the environment discount dtype. - discount = tf.cast(self._discount, dtype=transitions.discount.dtype) - - with tf.GradientTape(persistent=True) as tape: - # Maybe transform the observation before feeding into policy and critic. - # Transforming the observations this way at the start of the learning - # step effectively means that the policy and critic share observation - # network weights. - o_tm1 = self._observation_network(transitions.observation) - # This stop_gradient prevents gradients to propagate into the target - # observation network. In addition, since the online policy network is - # evaluated at o_t, this also means the policy loss does not influence - # the observation network training. - o_t = tf.stop_gradient( - self._target_observation_network(transitions.next_observation)) - - # Get action distributions from policy networks. - online_action_distribution = self._policy_network(o_t) - target_action_distribution = self._target_policy_network(o_t) - - # Get sampled actions to evaluate policy; of size [N, B, ...]. - sampled_actions = target_action_distribution.sample(self._num_samples) - tiled_o_t = tf2_utils.tile_tensor(o_t, self._num_samples) # [N, B, ...] - - # Compute the target critic's Q-value of the sampled actions in state o_t. - sampled_q_t = self._target_critic_network( - # Merge batch dimensions; to shape [N*B, ...]. - snt.merge_leading_dims(tiled_o_t, num_dims=2), - snt.merge_leading_dims(sampled_actions, num_dims=2)) - - # Reshape Q-value samples back to original batch dimensions and average - # them to compute the TD-learning bootstrap target. - sampled_q_t = tf.reshape(sampled_q_t, (self._num_samples, -1)) # [N, B] - q_t = tf.reduce_mean(sampled_q_t, axis=0) # [B] - - # Compute online critic value of a_tm1 in state o_tm1. - q_tm1 = self._critic_network(o_tm1, transitions.action) # [B, 1] - q_tm1 = tf.squeeze(q_tm1, axis=-1) # [B]; necessary for trfl.td_learning. - - # Critic loss. - critic_loss = trfl.td_learning(q_tm1, transitions.reward, - discount * transitions.discount, q_t).loss - critic_loss = tf.reduce_mean(critic_loss) - - # Actor learning. - policy_loss, policy_stats = self._policy_loss_module( - online_action_distribution=online_action_distribution, - target_action_distribution=target_action_distribution, - actions=sampled_actions, - q_values=sampled_q_t) - - # For clarity, explicitly define which variables are trained by which loss. - critic_trainable_variables = ( - # In this agent, the critic loss trains the observation network. - self._observation_network.trainable_variables + - self._critic_network.trainable_variables) - policy_trainable_variables = self._policy_network.trainable_variables - # The following are the MPO dual variables, stored in the loss module. - dual_trainable_variables = self._policy_loss_module.trainable_variables - - # Compute gradients. - critic_gradients = tape.gradient(critic_loss, critic_trainable_variables) - policy_gradients, dual_gradients = tape.gradient( - policy_loss, (policy_trainable_variables, dual_trainable_variables)) - - # Delete the tape manually because of the persistent=True flag. - del tape - - # Maybe clip gradients. 
- if self._clipping: - policy_gradients = tuple(tf.clip_by_global_norm(policy_gradients, 40.)[0]) - critic_gradients = tuple(tf.clip_by_global_norm(critic_gradients, 40.)[0]) - - # Apply gradients. - self._critic_optimizer.apply(critic_gradients, critic_trainable_variables) - self._policy_optimizer.apply(policy_gradients, policy_trainable_variables) - self._dual_optimizer.apply(dual_gradients, dual_trainable_variables) - - # Losses to track. - fetches = { - 'critic_loss': critic_loss, - 'policy_loss': policy_loss, - } - fetches.update(policy_stats) # Log MPO stats. - - return fetches - - def step(self): - # Run the learning step. - fetches = self._step() - - # Compute elapsed time. - timestamp = time.time() - elapsed_time = timestamp - self._timestamp if self._timestamp else 0 - self._timestamp = timestamp - - # Update our counts and record it. - counts = self._counter.increment(steps=1, walltime=elapsed_time) - fetches.update(counts) - - # Checkpoint and attempt to write the logs. - if self._checkpointer is not None: - self._checkpointer.save() - if self._snapshotter is not None: - self._snapshotter.save() - self._logger.write(fetches) - - def get_variables(self, names: List[str]) -> List[List[np.ndarray]]: - return [tf2_utils.to_numpy(self._variables[name]) for name in names] + """MPO learner.""" + + def __init__( + self, + policy_network: snt.Module, + critic_network: snt.Module, + target_policy_network: snt.Module, + target_critic_network: snt.Module, + discount: float, + num_samples: int, + target_policy_update_period: int, + target_critic_update_period: int, + dataset: tf.data.Dataset, + observation_network: types.TensorTransformation = tf.identity, + target_observation_network: types.TensorTransformation = tf.identity, + policy_loss_module: Optional[snt.Module] = None, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + dual_optimizer: Optional[snt.Optimizer] = None, + clipping: bool = True, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = True, + save_directory: str = "~/acme", + ): + + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger("learner") + self._discount = discount + self._num_samples = num_samples + self._clipping = clipping + + # Necessary to track when to update target networks. + self._num_steps = tf.Variable(0, dtype=tf.int32) + self._target_policy_update_period = target_policy_update_period + self._target_critic_update_period = target_critic_update_period + + # Batch dataset and create iterator. + # TODO(b/155086959): Fix type stubs and remove. + self._iterator = iter(dataset) # pytype: disable=wrong-arg-types + + # Store online and target networks. + self._policy_network = policy_network + self._critic_network = critic_network + self._target_policy_network = target_policy_network + self._target_critic_network = target_critic_network + + # Make sure observation networks are snt.Module's so they have variables. + self._observation_network = tf2_utils.to_sonnet_module(observation_network) + self._target_observation_network = tf2_utils.to_sonnet_module( + target_observation_network + ) + + self._policy_loss_module = policy_loss_module or losses.MPO( + epsilon=1e-1, + epsilon_penalty=1e-3, + epsilon_mean=2.5e-3, + epsilon_stddev=1e-6, + init_log_temperature=10.0, + init_log_alpha_mean=10.0, + init_log_alpha_stddev=1000.0, + ) + + # Create the optimizers. 
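+        # Note (added for clarity): when the caller does not supply optimizers,
+        # the defaults below are Adam with learning rates of 1e-4 for the policy
+        # and critic and 1e-2 for the MPO dual variables.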
+ self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) + self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) + self._dual_optimizer = dual_optimizer or snt.optimizers.Adam(1e-2) + + # Expose the variables. + policy_network_to_expose = snt.Sequential( + [self._target_observation_network, self._target_policy_network] + ) + self._variables = { + "critic": self._target_critic_network.variables, + "policy": policy_network_to_expose.variables, + } + + # Create a checkpointer and snapshotter object. + self._checkpointer = None + self._snapshotter = None + + if checkpoint: + self._checkpointer = tf2_savers.Checkpointer( + directory=save_directory, + subdirectory="mpo_learner", + objects_to_save={ + "counter": self._counter, + "policy": self._policy_network, + "critic": self._critic_network, + "observation_network": self._observation_network, + "target_policy": self._target_policy_network, + "target_critic": self._target_critic_network, + "target_observation_network": self._target_observation_network, + "policy_optimizer": self._policy_optimizer, + "critic_optimizer": self._critic_optimizer, + "dual_optimizer": self._dual_optimizer, + "policy_loss_module": self._policy_loss_module, + "num_steps": self._num_steps, + }, + ) + + self._snapshotter = tf2_savers.Snapshotter( + directory=save_directory, + objects_to_save={ + "policy": snt.Sequential( + [self._target_observation_network, self._target_policy_network] + ), + }, + ) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + @tf.function + def _step(self) -> types.Nest: + # Update target network. + online_policy_variables = self._policy_network.variables + target_policy_variables = self._target_policy_network.variables + online_critic_variables = ( + *self._observation_network.variables, + *self._critic_network.variables, + ) + target_critic_variables = ( + *self._target_observation_network.variables, + *self._target_critic_network.variables, + ) + + # Make online policy -> target policy network update ops. + if tf.math.mod(self._num_steps, self._target_policy_update_period) == 0: + for src, dest in zip(online_policy_variables, target_policy_variables): + dest.assign(src) + # Make online critic -> target critic network update ops. + if tf.math.mod(self._num_steps, self._target_critic_update_period) == 0: + for src, dest in zip(online_critic_variables, target_critic_variables): + dest.assign(src) + + # Increment number of learner steps for periodic update bookkeeping. + self._num_steps.assign_add(1) + + # Get next batch of data. + inputs = next(self._iterator) + + # Get data from replay (dropping extras if any). Note there is no + # extra data here because we do not insert any into Reverb. + transitions: types.Transition = inputs.data + + # Cast the additional discount to match the environment discount dtype. + discount = tf.cast(self._discount, dtype=transitions.discount.dtype) + + with tf.GradientTape(persistent=True) as tape: + # Maybe transform the observation before feeding into policy and critic. + # Transforming the observations this way at the start of the learning + # step effectively means that the policy and critic share observation + # network weights. + o_tm1 = self._observation_network(transitions.observation) + # This stop_gradient prevents gradients to propagate into the target + # observation network. 
In addition, since the online policy network is + # evaluated at o_t, this also means the policy loss does not influence + # the observation network training. + o_t = tf.stop_gradient( + self._target_observation_network(transitions.next_observation) + ) + + # Get action distributions from policy networks. + online_action_distribution = self._policy_network(o_t) + target_action_distribution = self._target_policy_network(o_t) + + # Get sampled actions to evaluate policy; of size [N, B, ...]. + sampled_actions = target_action_distribution.sample(self._num_samples) + tiled_o_t = tf2_utils.tile_tensor(o_t, self._num_samples) # [N, B, ...] + + # Compute the target critic's Q-value of the sampled actions in state o_t. + sampled_q_t = self._target_critic_network( + # Merge batch dimensions; to shape [N*B, ...]. + snt.merge_leading_dims(tiled_o_t, num_dims=2), + snt.merge_leading_dims(sampled_actions, num_dims=2), + ) + + # Reshape Q-value samples back to original batch dimensions and average + # them to compute the TD-learning bootstrap target. + sampled_q_t = tf.reshape(sampled_q_t, (self._num_samples, -1)) # [N, B] + q_t = tf.reduce_mean(sampled_q_t, axis=0) # [B] + + # Compute online critic value of a_tm1 in state o_tm1. + q_tm1 = self._critic_network(o_tm1, transitions.action) # [B, 1] + q_tm1 = tf.squeeze(q_tm1, axis=-1) # [B]; necessary for trfl.td_learning. + + # Critic loss. + critic_loss = trfl.td_learning( + q_tm1, transitions.reward, discount * transitions.discount, q_t + ).loss + critic_loss = tf.reduce_mean(critic_loss) + + # Actor learning. + policy_loss, policy_stats = self._policy_loss_module( + online_action_distribution=online_action_distribution, + target_action_distribution=target_action_distribution, + actions=sampled_actions, + q_values=sampled_q_t, + ) + + # For clarity, explicitly define which variables are trained by which loss. + critic_trainable_variables = ( + # In this agent, the critic loss trains the observation network. + self._observation_network.trainable_variables + + self._critic_network.trainable_variables + ) + policy_trainable_variables = self._policy_network.trainable_variables + # The following are the MPO dual variables, stored in the loss module. + dual_trainable_variables = self._policy_loss_module.trainable_variables + + # Compute gradients. + critic_gradients = tape.gradient(critic_loss, critic_trainable_variables) + policy_gradients, dual_gradients = tape.gradient( + policy_loss, (policy_trainable_variables, dual_trainable_variables) + ) + + # Delete the tape manually because of the persistent=True flag. + del tape + + # Maybe clip gradients. + if self._clipping: + policy_gradients = tuple(tf.clip_by_global_norm(policy_gradients, 40.0)[0]) + critic_gradients = tuple(tf.clip_by_global_norm(critic_gradients, 40.0)[0]) + + # Apply gradients. + self._critic_optimizer.apply(critic_gradients, critic_trainable_variables) + self._policy_optimizer.apply(policy_gradients, policy_trainable_variables) + self._dual_optimizer.apply(dual_gradients, dual_trainable_variables) + + # Losses to track. + fetches = { + "critic_loss": critic_loss, + "policy_loss": policy_loss, + } + fetches.update(policy_stats) # Log MPO stats. + + return fetches + + def step(self): + # Run the learning step. + fetches = self._step() + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Update our counts and record it. 
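+        # Note (added for clarity): counts accumulate across calls to step(), and
+        # the reported walltime lets downstream consumers derive steps-per-second.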
+ counts = self._counter.increment(steps=1, walltime=elapsed_time) + fetches.update(counts) + + # Checkpoint and attempt to write the logs. + if self._checkpointer is not None: + self._checkpointer.save() + if self._snapshotter is not None: + self._snapshotter.save() + self._logger.write(fetches) + + def get_variables(self, names: List[str]) -> List[List[np.ndarray]]: + return [tf2_utils.to_numpy(self._variables[name]) for name in names] diff --git a/acme/agents/tf/r2d2/agent.py b/acme/agents/tf/r2d2/agent.py index 9acd5668da..463ce3f0a2 100644 --- a/acme/agents/tf/r2d2/agent.py +++ b/acme/agents/tf/r2d2/agent.py @@ -17,24 +17,23 @@ import copy from typing import Optional -from acme import datasets -from acme import specs +import reverb +import sonnet as snt +import tensorflow as tf +import trfl + +from acme import datasets, specs from acme.adders import reverb as adders from acme.agents import agent from acme.agents.tf import actors from acme.agents.tf.r2d2 import learning from acme.tf import savers as tf2_savers from acme.tf import utils as tf2_utils -from acme.utils import counting -from acme.utils import loggers -import reverb -import sonnet as snt -import tensorflow as tf -import trfl +from acme.utils import counting, loggers class R2D2(agent.Agent): - """R2D2 Agent. + """R2D2 Agent. This implements a single-process R2D2 agent. This is a Q-learning algorithm that generates data via a (epislon-greedy) behavior policy, inserts @@ -42,111 +41,113 @@ class R2D2(agent.Agent): as a result the behavior) by sampling from this buffer. """ - def __init__( - self, - environment_spec: specs.EnvironmentSpec, - network: snt.RNNCore, - burn_in_length: int, - trace_length: int, - replay_period: int, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - discount: float = 0.99, - batch_size: int = 32, - prefetch_size: int = tf.data.experimental.AUTOTUNE, - target_update_period: int = 100, - importance_sampling_exponent: float = 0.2, - priority_exponent: float = 0.6, - epsilon: float = 0.01, - learning_rate: float = 1e-3, - min_replay_size: int = 1000, - max_replay_size: int = 1000000, - samples_per_insert: float = 32.0, - store_lstm_state: bool = True, - max_priority_weight: float = 0.9, - checkpoint: bool = True, - ): - - if store_lstm_state: - extra_spec = { - 'core_state': tf2_utils.squeeze_batch_dim(network.initial_state(1)), - } - else: - extra_spec = () - - sequence_length = burn_in_length + trace_length + 1 - replay_table = reverb.Table( - name=adders.DEFAULT_PRIORITY_TABLE, - sampler=reverb.selectors.Prioritized(priority_exponent), - remover=reverb.selectors.Fifo(), - max_size=max_replay_size, - rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1), - signature=adders.SequenceAdder.signature( - environment_spec, extra_spec, sequence_length=sequence_length)) - self._server = reverb.Server([replay_table], port=None) - address = f'localhost:{self._server.port}' - - # Component to add things into replay. - adder = adders.SequenceAdder( - client=reverb.Client(address), - period=replay_period, - sequence_length=sequence_length, - ) - - # The dataset object to learn from. 
- dataset = datasets.make_reverb_dataset( - server_address=address, - batch_size=batch_size, - prefetch_size=prefetch_size) - - target_network = copy.deepcopy(network) - tf2_utils.create_variables(network, [environment_spec.observations]) - tf2_utils.create_variables(target_network, [environment_spec.observations]) - - learner = learning.R2D2Learner( - environment_spec=environment_spec, - network=network, - target_network=target_network, - burn_in_length=burn_in_length, - sequence_length=sequence_length, - dataset=dataset, - reverb_client=reverb.TFClient(address), - counter=counter, - logger=logger, - discount=discount, - target_update_period=target_update_period, - importance_sampling_exponent=importance_sampling_exponent, - max_replay_size=max_replay_size, - learning_rate=learning_rate, - store_lstm_state=store_lstm_state, - max_priority_weight=max_priority_weight, - ) - - self._checkpointer = tf2_savers.Checkpointer( - subdirectory='r2d2_learner', - time_delta_minutes=60, - objects_to_save=learner.state, - enable_checkpointing=checkpoint, - ) - self._snapshotter = tf2_savers.Snapshotter( - objects_to_save={'network': network}, time_delta_minutes=60.) - - policy_network = snt.DeepRNN([ - network, - lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(), - ]) - - actor = actors.RecurrentActor( - policy_network, adder, store_recurrent_state=store_lstm_state) - observations_per_step = ( - float(replay_period * batch_size) / samples_per_insert) - super().__init__( - actor=actor, - learner=learner, - min_observations=replay_period * max(batch_size, min_replay_size), - observations_per_step=observations_per_step) - - def update(self): - super().update() - self._snapshotter.save() - self._checkpointer.save() + def __init__( + self, + environment_spec: specs.EnvironmentSpec, + network: snt.RNNCore, + burn_in_length: int, + trace_length: int, + replay_period: int, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + discount: float = 0.99, + batch_size: int = 32, + prefetch_size: int = tf.data.experimental.AUTOTUNE, + target_update_period: int = 100, + importance_sampling_exponent: float = 0.2, + priority_exponent: float = 0.6, + epsilon: float = 0.01, + learning_rate: float = 1e-3, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: float = 32.0, + store_lstm_state: bool = True, + max_priority_weight: float = 0.9, + checkpoint: bool = True, + ): + + if store_lstm_state: + extra_spec = { + "core_state": tf2_utils.squeeze_batch_dim(network.initial_state(1)), + } + else: + extra_spec = () + + sequence_length = burn_in_length + trace_length + 1 + replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Prioritized(priority_exponent), + remover=reverb.selectors.Fifo(), + max_size=max_replay_size, + rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1), + signature=adders.SequenceAdder.signature( + environment_spec, extra_spec, sequence_length=sequence_length + ), + ) + self._server = reverb.Server([replay_table], port=None) + address = f"localhost:{self._server.port}" + + # Component to add things into replay. + adder = adders.SequenceAdder( + client=reverb.Client(address), + period=replay_period, + sequence_length=sequence_length, + ) + + # The dataset object to learn from. 
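+        # Note (added for clarity): this samples batches of fixed-length sequences
+        # from the prioritized Reverb table created above.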
+ dataset = datasets.make_reverb_dataset( + server_address=address, batch_size=batch_size, prefetch_size=prefetch_size + ) + + target_network = copy.deepcopy(network) + tf2_utils.create_variables(network, [environment_spec.observations]) + tf2_utils.create_variables(target_network, [environment_spec.observations]) + + learner = learning.R2D2Learner( + environment_spec=environment_spec, + network=network, + target_network=target_network, + burn_in_length=burn_in_length, + sequence_length=sequence_length, + dataset=dataset, + reverb_client=reverb.TFClient(address), + counter=counter, + logger=logger, + discount=discount, + target_update_period=target_update_period, + importance_sampling_exponent=importance_sampling_exponent, + max_replay_size=max_replay_size, + learning_rate=learning_rate, + store_lstm_state=store_lstm_state, + max_priority_weight=max_priority_weight, + ) + + self._checkpointer = tf2_savers.Checkpointer( + subdirectory="r2d2_learner", + time_delta_minutes=60, + objects_to_save=learner.state, + enable_checkpointing=checkpoint, + ) + self._snapshotter = tf2_savers.Snapshotter( + objects_to_save={"network": network}, time_delta_minutes=60.0 + ) + + policy_network = snt.DeepRNN( + [network, lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),] + ) + + actor = actors.RecurrentActor( + policy_network, adder, store_recurrent_state=store_lstm_state + ) + observations_per_step = float(replay_period * batch_size) / samples_per_insert + super().__init__( + actor=actor, + learner=learner, + min_observations=replay_period * max(batch_size, min_replay_size), + observations_per_step=observations_per_step, + ) + + def update(self): + super().update() + self._snapshotter.save() + self._checkpointer.save() diff --git a/acme/agents/tf/r2d2/agent_distributed.py b/acme/agents/tf/r2d2/agent_distributed.py index 6be0025195..ba2d090ba7 100644 --- a/acme/agents/tf/r2d2/agent_distributed.py +++ b/acme/agents/tf/r2d2/agent_distributed.py @@ -17,17 +17,6 @@ import copy from typing import Callable, List, Optional -import acme -from acme import datasets -from acme import specs -from acme.adders import reverb as adders -from acme.agents.tf import actors -from acme.agents.tf.r2d2 import learning -from acme.tf import savers as tf2_savers -from acme.tf import utils as tf2_utils -from acme.tf import variable_utils as tf2_variable_utils -from acme.utils import counting -from acme.utils import loggers import dm_env import launchpad as lp import numpy as np @@ -36,241 +25,266 @@ import tensorflow as tf import trfl +import acme +from acme import datasets, specs +from acme.adders import reverb as adders +from acme.agents.tf import actors +from acme.agents.tf.r2d2 import learning +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.tf import variable_utils as tf2_variable_utils +from acme.utils import counting, loggers + class DistributedR2D2: - """Program definition for Recurrent Replay Distributed DQN (R2D2).""" - - def __init__(self, - environment_factory: Callable[[bool], dm_env.Environment], - network_factory: Callable[[specs.DiscreteArray], snt.RNNCore], - num_actors: int, - burn_in_length: int, - trace_length: int, - replay_period: int, - environment_spec: Optional[specs.EnvironmentSpec] = None, - batch_size: int = 256, - prefetch_size: int = tf.data.experimental.AUTOTUNE, - min_replay_size: int = 1000, - max_replay_size: int = 100_000, - samples_per_insert: float = 32.0, - discount: float = 0.99, - priority_exponent: float = 0.6, - 
importance_sampling_exponent: float = 0.2, - variable_update_period: int = 1000, - learning_rate: float = 1e-3, - evaluator_epsilon: float = 0., - target_update_period: int = 100, - save_logs: bool = False): - - if environment_spec is None: - environment_spec = specs.make_environment_spec(environment_factory(False)) - - self._environment_factory = environment_factory - self._network_factory = network_factory - self._environment_spec = environment_spec - self._num_actors = num_actors - self._batch_size = batch_size - self._prefetch_size = prefetch_size - self._min_replay_size = min_replay_size - self._max_replay_size = max_replay_size - self._samples_per_insert = samples_per_insert - self._burn_in_length = burn_in_length - self._trace_length = trace_length - self._replay_period = replay_period - self._discount = discount - self._target_update_period = target_update_period - self._variable_update_period = variable_update_period - self._save_logs = save_logs - self._priority_exponent = priority_exponent - self._learning_rate = learning_rate - self._evaluator_epsilon = evaluator_epsilon - self._importance_sampling_exponent = importance_sampling_exponent - - self._obs_spec = environment_spec.observations - - def replay(self) -> List[reverb.Table]: - """The replay storage.""" - network = self._network_factory(self._environment_spec.actions) - extra_spec = { - 'core_state': network.initial_state(1), - } - # Remove batch dimensions. - extra_spec = tf2_utils.squeeze_batch_dim(extra_spec) - if self._samples_per_insert: - limiter = reverb.rate_limiters.SampleToInsertRatio( - min_size_to_sample=self._min_replay_size, - samples_per_insert=self._samples_per_insert, - error_buffer=self._batch_size) - else: - limiter = reverb.rate_limiters.MinSize(self._min_replay_size) - table = reverb.Table( - name=adders.DEFAULT_PRIORITY_TABLE, - sampler=reverb.selectors.Prioritized(self._priority_exponent), - remover=reverb.selectors.Fifo(), - max_size=self._max_replay_size, - rate_limiter=limiter, - signature=adders.SequenceAdder.signature( - self._environment_spec, - extra_spec, - sequence_length=self._burn_in_length + self._trace_length + 1)) - - return [table] - - def counter(self): - """Creates the master counter process.""" - return tf2_savers.CheckpointingRunner( - counting.Counter(), time_delta_minutes=1, subdirectory='counter') - - def learner(self, replay: reverb.Client, counter: counting.Counter): - """The Learning part of the agent.""" - # Use architect and create the environment. - # Create the networks. - network = self._network_factory(self._environment_spec.actions) - target_network = copy.deepcopy(network) - - tf2_utils.create_variables(network, [self._obs_spec]) - tf2_utils.create_variables(target_network, [self._obs_spec]) - - # The dataset object to learn from. - reverb_client = reverb.TFClient(replay.server_address) - sequence_length = self._burn_in_length + self._trace_length + 1 - dataset = datasets.make_reverb_dataset( - server_address=replay.server_address, - batch_size=self._batch_size, - prefetch_size=self._prefetch_size) - - counter = counting.Counter(counter, 'learner') - logger = loggers.make_default_logger( - 'learner', save_data=True, steps_key='learner_steps') - # Return the learning agent. 
- learner = learning.R2D2Learner( - environment_spec=self._environment_spec, - network=network, - target_network=target_network, - burn_in_length=self._burn_in_length, - sequence_length=sequence_length, - dataset=dataset, - reverb_client=reverb_client, - counter=counter, - logger=logger, - discount=self._discount, - target_update_period=self._target_update_period, - importance_sampling_exponent=self._importance_sampling_exponent, - learning_rate=self._learning_rate, - max_replay_size=self._max_replay_size) - return tf2_savers.CheckpointingRunner( - wrapped=learner, time_delta_minutes=60, subdirectory='r2d2_learner') - - def actor( - self, - replay: reverb.Client, - variable_source: acme.VariableSource, - counter: counting.Counter, - epsilon: float, - ) -> acme.EnvironmentLoop: - """The actor process.""" - environment = self._environment_factory(False) - network = self._network_factory(self._environment_spec.actions) - - tf2_utils.create_variables(network, [self._obs_spec]) - - policy_network = snt.DeepRNN([ - network, - lambda qs: tf.cast(trfl.epsilon_greedy(qs, epsilon).sample(), tf.int32), - ]) - - # Component to add things into replay. - sequence_length = self._burn_in_length + self._trace_length + 1 - adder = adders.SequenceAdder( - client=replay, - period=self._replay_period, - sequence_length=sequence_length, - delta_encoded=True, - ) - - variable_client = tf2_variable_utils.VariableClient( - client=variable_source, - variables={'policy': policy_network.variables}, - update_period=self._variable_update_period) - - # Make sure not to use a random policy after checkpoint restoration by - # assigning variables before running the environment loop. - variable_client.update_and_wait() - - # Create the agent. - actor = actors.RecurrentActor( - policy_network=policy_network, - variable_client=variable_client, - adder=adder) - - counter = counting.Counter(counter, 'actor') - logger = loggers.make_default_logger( - 'actor', save_data=False, steps_key='actor_steps') - - # Create the loop to connect environment and agent. - return acme.EnvironmentLoop(environment, actor, counter, logger) - - def evaluator( - self, - variable_source: acme.VariableSource, - counter: counting.Counter, - ): - """The evaluation process.""" - environment = self._environment_factory(True) - network = self._network_factory(self._environment_spec.actions) - - tf2_utils.create_variables(network, [self._obs_spec]) - policy_network = snt.DeepRNN([ - network, - lambda qs: tf.cast(tf.argmax(qs, axis=-1), tf.int32), - ]) - - variable_client = tf2_variable_utils.VariableClient( - client=variable_source, - variables={'policy': policy_network.variables}, - update_period=self._variable_update_period) - - # Make sure not to use a random policy after checkpoint restoration by - # assigning variables before running the environment loop. - variable_client.update_and_wait() - - # Create the agent. - actor = actors.RecurrentActor( - policy_network=policy_network, variable_client=variable_client) - - # Create the run loop and return it. 
- logger = loggers.make_default_logger( - 'evaluator', save_data=True, steps_key='evaluator_steps') - counter = counting.Counter(counter, 'evaluator') - - return acme.EnvironmentLoop(environment, actor, counter, logger) - - def build(self, name='r2d2'): - """Build the distributed agent topology.""" - program = lp.Program(name=name) - - with program.group('replay'): - replay = program.add_node(lp.ReverbNode(self.replay)) - - with program.group('counter'): - counter = program.add_node(lp.CourierNode(self.counter)) - - with program.group('learner'): - learner = program.add_node(lp.CourierNode(self.learner, replay, counter)) - - with program.group('cacher'): - cacher = program.add_node( - lp.CacherNode(learner, refresh_interval_ms=2000, stale_after_ms=4000)) - - with program.group('evaluator'): - program.add_node(lp.CourierNode(self.evaluator, cacher, counter)) - - # Generate an epsilon for each actor. - epsilons = np.flip(np.logspace(1, 8, self._num_actors, base=0.4), axis=0) - - with program.group('actor'): - for epsilon in epsilons: - program.add_node( - lp.CourierNode(self.actor, replay, cacher, counter, epsilon)) - - return program + """Program definition for Recurrent Replay Distributed DQN (R2D2).""" + + def __init__( + self, + environment_factory: Callable[[bool], dm_env.Environment], + network_factory: Callable[[specs.DiscreteArray], snt.RNNCore], + num_actors: int, + burn_in_length: int, + trace_length: int, + replay_period: int, + environment_spec: Optional[specs.EnvironmentSpec] = None, + batch_size: int = 256, + prefetch_size: int = tf.data.experimental.AUTOTUNE, + min_replay_size: int = 1000, + max_replay_size: int = 100_000, + samples_per_insert: float = 32.0, + discount: float = 0.99, + priority_exponent: float = 0.6, + importance_sampling_exponent: float = 0.2, + variable_update_period: int = 1000, + learning_rate: float = 1e-3, + evaluator_epsilon: float = 0.0, + target_update_period: int = 100, + save_logs: bool = False, + ): + + if environment_spec is None: + environment_spec = specs.make_environment_spec(environment_factory(False)) + + self._environment_factory = environment_factory + self._network_factory = network_factory + self._environment_spec = environment_spec + self._num_actors = num_actors + self._batch_size = batch_size + self._prefetch_size = prefetch_size + self._min_replay_size = min_replay_size + self._max_replay_size = max_replay_size + self._samples_per_insert = samples_per_insert + self._burn_in_length = burn_in_length + self._trace_length = trace_length + self._replay_period = replay_period + self._discount = discount + self._target_update_period = target_update_period + self._variable_update_period = variable_update_period + self._save_logs = save_logs + self._priority_exponent = priority_exponent + self._learning_rate = learning_rate + self._evaluator_epsilon = evaluator_epsilon + self._importance_sampling_exponent = importance_sampling_exponent + + self._obs_spec = environment_spec.observations + + def replay(self) -> List[reverb.Table]: + """The replay storage.""" + network = self._network_factory(self._environment_spec.actions) + extra_spec = { + "core_state": network.initial_state(1), + } + # Remove batch dimensions. 
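+        # initial_state(1) returns a state with a leading batch dimension of 1,
+        # whereas the adder signature expects unbatched per-step specs.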
+ extra_spec = tf2_utils.squeeze_batch_dim(extra_spec) + if self._samples_per_insert: + limiter = reverb.rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._min_replay_size, + samples_per_insert=self._samples_per_insert, + error_buffer=self._batch_size, + ) + else: + limiter = reverb.rate_limiters.MinSize(self._min_replay_size) + table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Prioritized(self._priority_exponent), + remover=reverb.selectors.Fifo(), + max_size=self._max_replay_size, + rate_limiter=limiter, + signature=adders.SequenceAdder.signature( + self._environment_spec, + extra_spec, + sequence_length=self._burn_in_length + self._trace_length + 1, + ), + ) + + return [table] + + def counter(self): + """Creates the master counter process.""" + return tf2_savers.CheckpointingRunner( + counting.Counter(), time_delta_minutes=1, subdirectory="counter" + ) + + def learner(self, replay: reverb.Client, counter: counting.Counter): + """The Learning part of the agent.""" + # Use architect and create the environment. + # Create the networks. + network = self._network_factory(self._environment_spec.actions) + target_network = copy.deepcopy(network) + + tf2_utils.create_variables(network, [self._obs_spec]) + tf2_utils.create_variables(target_network, [self._obs_spec]) + + # The dataset object to learn from. + reverb_client = reverb.TFClient(replay.server_address) + sequence_length = self._burn_in_length + self._trace_length + 1 + dataset = datasets.make_reverb_dataset( + server_address=replay.server_address, + batch_size=self._batch_size, + prefetch_size=self._prefetch_size, + ) + + counter = counting.Counter(counter, "learner") + logger = loggers.make_default_logger( + "learner", save_data=True, steps_key="learner_steps" + ) + # Return the learning agent. + learner = learning.R2D2Learner( + environment_spec=self._environment_spec, + network=network, + target_network=target_network, + burn_in_length=self._burn_in_length, + sequence_length=sequence_length, + dataset=dataset, + reverb_client=reverb_client, + counter=counter, + logger=logger, + discount=self._discount, + target_update_period=self._target_update_period, + importance_sampling_exponent=self._importance_sampling_exponent, + learning_rate=self._learning_rate, + max_replay_size=self._max_replay_size, + ) + return tf2_savers.CheckpointingRunner( + wrapped=learner, time_delta_minutes=60, subdirectory="r2d2_learner" + ) + + def actor( + self, + replay: reverb.Client, + variable_source: acme.VariableSource, + counter: counting.Counter, + epsilon: float, + ) -> acme.EnvironmentLoop: + """The actor process.""" + environment = self._environment_factory(False) + network = self._network_factory(self._environment_spec.actions) + + tf2_utils.create_variables(network, [self._obs_spec]) + + policy_network = snt.DeepRNN( + [ + network, + lambda qs: tf.cast(trfl.epsilon_greedy(qs, epsilon).sample(), tf.int32), + ] + ) + + # Component to add things into replay. + sequence_length = self._burn_in_length + self._trace_length + 1 + adder = adders.SequenceAdder( + client=replay, + period=self._replay_period, + sequence_length=sequence_length, + delta_encoded=True, + ) + + variable_client = tf2_variable_utils.VariableClient( + client=variable_source, + variables={"policy": policy_network.variables}, + update_period=self._variable_update_period, + ) + + # Make sure not to use a random policy after checkpoint restoration by + # assigning variables before running the environment loop. 
+ variable_client.update_and_wait() + + # Create the agent. + actor = actors.RecurrentActor( + policy_network=policy_network, variable_client=variable_client, adder=adder + ) + + counter = counting.Counter(counter, "actor") + logger = loggers.make_default_logger( + "actor", save_data=False, steps_key="actor_steps" + ) + + # Create the loop to connect environment and agent. + return acme.EnvironmentLoop(environment, actor, counter, logger) + + def evaluator( + self, variable_source: acme.VariableSource, counter: counting.Counter, + ): + """The evaluation process.""" + environment = self._environment_factory(True) + network = self._network_factory(self._environment_spec.actions) + + tf2_utils.create_variables(network, [self._obs_spec]) + policy_network = snt.DeepRNN( + [network, lambda qs: tf.cast(tf.argmax(qs, axis=-1), tf.int32),] + ) + + variable_client = tf2_variable_utils.VariableClient( + client=variable_source, + variables={"policy": policy_network.variables}, + update_period=self._variable_update_period, + ) + + # Make sure not to use a random policy after checkpoint restoration by + # assigning variables before running the environment loop. + variable_client.update_and_wait() + + # Create the agent. + actor = actors.RecurrentActor( + policy_network=policy_network, variable_client=variable_client + ) + + # Create the run loop and return it. + logger = loggers.make_default_logger( + "evaluator", save_data=True, steps_key="evaluator_steps" + ) + counter = counting.Counter(counter, "evaluator") + + return acme.EnvironmentLoop(environment, actor, counter, logger) + + def build(self, name="r2d2"): + """Build the distributed agent topology.""" + program = lp.Program(name=name) + + with program.group("replay"): + replay = program.add_node(lp.ReverbNode(self.replay)) + + with program.group("counter"): + counter = program.add_node(lp.CourierNode(self.counter)) + + with program.group("learner"): + learner = program.add_node(lp.CourierNode(self.learner, replay, counter)) + + with program.group("cacher"): + cacher = program.add_node( + lp.CacherNode(learner, refresh_interval_ms=2000, stale_after_ms=4000) + ) + + with program.group("evaluator"): + program.add_node(lp.CourierNode(self.evaluator, cacher, counter)) + + # Generate an epsilon for each actor. 
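+        # Note (added for clarity): this yields an Ape-X-style per-actor schedule,
+        # a geometric range from 0.4**8 (most greedy) up to 0.4 (most exploratory).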
+ epsilons = np.flip(np.logspace(1, 8, self._num_actors, base=0.4), axis=0) + + with program.group("actor"): + for epsilon in epsilons: + program.add_node( + lp.CourierNode(self.actor, replay, cacher, counter, epsilon) + ) + + return program diff --git a/acme/agents/tf/r2d2/agent_distributed_test.py b/acme/agents/tf/r2d2/agent_distributed_test.py index 81000f63f2..5cafdd8778 100644 --- a/acme/agents/tf/r2d2/agent_distributed_test.py +++ b/acme/agents/tf/r2d2/agent_distributed_test.py @@ -14,45 +14,45 @@ """Integration test for the distributed agent.""" +import launchpad as lp +from absl.testing import absltest + import acme from acme.agents.tf import r2d2 from acme.testing import fakes from acme.tf import networks -import launchpad as lp - -from absl.testing import absltest class DistributedAgentTest(absltest.TestCase): - """Simple integration/smoke test for the distributed agent.""" + """Simple integration/smoke test for the distributed agent.""" - def test_agent(self): - env_factory = lambda x: fakes.fake_atari_wrapped(oar_wrapper=True) - net_factory = lambda spec: networks.R2D2AtariNetwork(spec.num_values) + def test_agent(self): + env_factory = lambda x: fakes.fake_atari_wrapped(oar_wrapper=True) + net_factory = lambda spec: networks.R2D2AtariNetwork(spec.num_values) - agent = r2d2.DistributedR2D2( - environment_factory=env_factory, - network_factory=net_factory, - num_actors=2, - batch_size=32, - min_replay_size=32, - max_replay_size=1000, - replay_period=1, - burn_in_length=1, - trace_length=10, - ) - program = agent.build() + agent = r2d2.DistributedR2D2( + environment_factory=env_factory, + network_factory=net_factory, + num_actors=2, + batch_size=32, + min_replay_size=32, + max_replay_size=1000, + replay_period=1, + burn_in_length=1, + trace_length=10, + ) + program = agent.build() - (learner_node,) = program.groups['learner'] - learner_node.disable_run() + (learner_node,) = program.groups["learner"] + learner_node.disable_run() - lp.launch(program, launch_type='test_mt') + lp.launch(program, launch_type="test_mt") - learner: acme.Learner = learner_node.create_handle().dereference() + learner: acme.Learner = learner_node.create_handle().dereference() - for _ in range(5): - learner.step() + for _ in range(5): + learner.step() -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/tf/r2d2/agent_test.py b/acme/agents/tf/r2d2/agent_test.py index 033667824f..0bf7e69b3e 100644 --- a/acme/agents/tf/r2d2/agent_test.py +++ b/acme/agents/tf/r2d2/agent_test.py @@ -14,71 +14,71 @@ """Tests for RDQN agent.""" +import numpy as np +import sonnet as snt +from absl.testing import absltest, parameterized + import acme from acme import specs from acme.agents.tf import r2d2 from acme.testing import fakes from acme.tf import networks -import numpy as np -import sonnet as snt - -from absl.testing import absltest -from absl.testing import parameterized class SimpleNetwork(networks.RNNCore): + def __init__(self, action_spec: specs.DiscreteArray): + super().__init__(name="r2d2_test_network") + self._net = snt.DeepRNN( + [ + snt.Flatten(), + snt.LSTM(20), + snt.nets.MLP([50, 50, action_spec.num_values]), + ] + ) - def __init__(self, action_spec: specs.DiscreteArray): - super().__init__(name='r2d2_test_network') - self._net = snt.DeepRNN([ - snt.Flatten(), - snt.LSTM(20), - snt.nets.MLP([50, 50, action_spec.num_values]) - ]) + def __call__(self, inputs, state): + return self._net(inputs, state) - def __call__(self, inputs, state): - return 
self._net(inputs, state) + def initial_state(self, batch_size: int, **kwargs): + return self._net.initial_state(batch_size) - def initial_state(self, batch_size: int, **kwargs): - return self._net.initial_state(batch_size) - - def unroll(self, inputs, state, sequence_length): - return snt.static_unroll(self._net, inputs, state, sequence_length) + def unroll(self, inputs, state, sequence_length): + return snt.static_unroll(self._net, inputs, state, sequence_length) class R2D2Test(parameterized.TestCase): - - @parameterized.parameters(True, False) - def test_r2d2(self, store_lstm_state: bool): - # Create a fake environment to test with. - # TODO(b/152596848): Allow R2D2 to deal with integer observations. - environment = fakes.DiscreteEnvironment( - num_actions=5, - num_observations=10, - obs_shape=(10, 4), - obs_dtype=np.float32, - episode_length=10) - spec = specs.make_environment_spec(environment) - - # Construct the agent. - agent = r2d2.R2D2( - environment_spec=spec, - network=SimpleNetwork(spec.actions), - batch_size=10, - samples_per_insert=2, - min_replay_size=10, - store_lstm_state=store_lstm_state, - burn_in_length=2, - trace_length=6, - replay_period=4, - checkpoint=False, - ) - - # Try running the environment loop. We have no assertions here because all - # we care about is that the agent runs without raising any errors. - loop = acme.EnvironmentLoop(environment, agent) - loop.run(num_episodes=5) - - -if __name__ == '__main__': - absltest.main() + @parameterized.parameters(True, False) + def test_r2d2(self, store_lstm_state: bool): + # Create a fake environment to test with. + # TODO(b/152596848): Allow R2D2 to deal with integer observations. + environment = fakes.DiscreteEnvironment( + num_actions=5, + num_observations=10, + obs_shape=(10, 4), + obs_dtype=np.float32, + episode_length=10, + ) + spec = specs.make_environment_spec(environment) + + # Construct the agent. + agent = r2d2.R2D2( + environment_spec=spec, + network=SimpleNetwork(spec.actions), + batch_size=10, + samples_per_insert=2, + min_replay_size=10, + store_lstm_state=store_lstm_state, + burn_in_length=2, + trace_length=6, + replay_period=4, + checkpoint=False, + ) + + # Try running the environment loop. We have no assertions here because all + # we care about is that the agent runs without raising any errors. + loop = acme.EnvironmentLoop(environment, agent) + loop.run(num_episodes=5) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/tf/r2d2/learning.py b/acme/agents/tf/r2d2/learning.py index 5d8f9d520c..4e0129f2b4 100644 --- a/acme/agents/tf/r2d2/learning.py +++ b/acme/agents/tf/r2d2/learning.py @@ -16,226 +16,235 @@ import functools import time -from typing import Dict, Iterator, List, Mapping, Union, Optional +from typing import Dict, Iterator, List, Mapping, Optional, Union -import acme -from acme import specs -from acme.adders import reverb as adders -from acme.tf import losses -from acme.tf import networks -from acme.tf import savers as tf2_savers -from acme.tf import utils as tf2_utils -from acme.utils import counting -from acme.utils import loggers import numpy as np import reverb import sonnet as snt import tensorflow as tf import tree +import acme +from acme import specs +from acme.adders import reverb as adders +from acme.tf import losses, networks +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.utils import counting, loggers + Variables = List[np.ndarray] class R2D2Learner(acme.Learner, tf2_savers.TFSaveable): - """R2D2 learner. 
+ """R2D2 learner. This is the learning component of the R2D2 agent. It takes a dataset as input and implements update functionality to learn from this dataset. """ - def __init__( - self, - environment_spec: specs.EnvironmentSpec, - network: Union[networks.RNNCore, snt.RNNCore], - target_network: Union[networks.RNNCore, snt.RNNCore], - burn_in_length: int, - sequence_length: int, - dataset: tf.data.Dataset, - reverb_client: Optional[reverb.TFClient] = None, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - discount: float = 0.99, - target_update_period: int = 100, - importance_sampling_exponent: float = 0.2, - max_replay_size: int = 1_000_000, - learning_rate: float = 1e-3, - # TODO(sergomez): rename to use_core_state for consistency with JAX agent. - store_lstm_state: bool = True, - max_priority_weight: float = 0.9, - n_step: int = 5, - clip_grad_norm: Optional[float] = None, - ): - - if not isinstance(network, networks.RNNCore): - network.unroll = functools.partial(snt.static_unroll, network) - target_network.unroll = functools.partial(snt.static_unroll, - target_network) - - # Internalise agent components (replay buffer, networks, optimizer). - # TODO(b/155086959): Fix type stubs and remove. - self._iterator: Iterator[reverb.ReplaySample] = iter(dataset) # pytype: disable=wrong-arg-types - self._network = network - self._target_network = target_network - self._optimizer = snt.optimizers.Adam(learning_rate, epsilon=1e-3) - self._reverb_client = reverb_client - - # Internalise the hyperparameters. - self._store_lstm_state = store_lstm_state - self._burn_in_length = burn_in_length - self._discount = discount - self._max_replay_size = max_replay_size - self._importance_sampling_exponent = importance_sampling_exponent - self._max_priority_weight = max_priority_weight - self._target_update_period = target_update_period - self._num_actions = environment_spec.actions.num_values - self._sequence_length = sequence_length - self._n_step = n_step - self._clip_grad_norm = clip_grad_norm - - if burn_in_length: - self._burn_in = lambda o, s: self._network.unroll(o, s, burn_in_length) - else: - self._burn_in = lambda o, s: (o, s) # pylint: disable=unnecessary-lambda - - # Learner state. - self._variables = network.variables - self._num_steps = tf.Variable( - 0., dtype=tf.float32, trainable=False, name='step') - - # Internalise logging/counting objects. - self._counter = counting.Counter(counter, 'learner') - self._logger = logger or loggers.TerminalLogger('learner', time_delta=100.) - - # Do not record timestamps until after the first learning step is done. - # This is to avoid including the time it takes for actors to come online and - # fill the replay buffer. - self._timestamp = None - - @tf.function - def _step(self) -> Dict[str, tf.Tensor]: - - # Draw a batch of data from replay. - sample: reverb.ReplaySample = next(self._iterator) - - data = tf2_utils.batch_to_sequence(sample.data) - observations, actions, rewards, discounts, extra = (data.observation, - data.action, - data.reward, - data.discount, - data.extras) - unused_sequence_length, batch_size = actions.shape - - # Get initial state for the LSTM, either from replay or simply use zeros. - if self._store_lstm_state: - core_state = tree.map_structure(lambda x: x[0], extra['core_state']) - else: - core_state = self._network.initial_state(batch_size) - target_core_state = tree.map_structure(tf.identity, core_state) - - # Before training, optionally unroll the LSTM for a fixed warmup period. 
- burn_in_obs = tree.map_structure(lambda x: x[:self._burn_in_length], - observations) - _, core_state = self._burn_in(burn_in_obs, core_state) - _, target_core_state = self._burn_in(burn_in_obs, target_core_state) - - # Don't train on the warmup period. - observations, actions, rewards, discounts, extra = tree.map_structure( - lambda x: x[self._burn_in_length:], - (observations, actions, rewards, discounts, extra)) - - with tf.GradientTape() as tape: - # Unroll the online and target Q-networks on the sequences. - q_values, _ = self._network.unroll(observations, core_state, - self._sequence_length) - target_q_values, _ = self._target_network.unroll(observations, - target_core_state, - self._sequence_length) - - # Compute the target policy distribution (greedy). - greedy_actions = tf.argmax(q_values, output_type=tf.int32, axis=-1) - target_policy_probs = tf.one_hot( - greedy_actions, depth=self._num_actions, dtype=q_values.dtype) - - # Compute the transformed n-step loss. - rewards = tree.map_structure(lambda x: x[:-1], rewards) - discounts = tree.map_structure(lambda x: x[:-1], discounts) - loss, extra = losses.transformed_n_step_loss( - qs=q_values, - targnet_qs=target_q_values, - actions=actions, - rewards=rewards, - pcontinues=discounts * self._discount, - target_policy_probs=target_policy_probs, - bootstrap_n=self._n_step, - ) - - # Calculate importance weights and use them to scale the loss. - sample_info = sample.info - keys, probs = sample_info.key, sample_info.probability - importance_weights = 1. / (self._max_replay_size * probs) # [T, B] - importance_weights **= self._importance_sampling_exponent - importance_weights /= tf.reduce_max(importance_weights) - loss *= tf.cast(importance_weights, tf.float32) # [T, B] - loss = tf.reduce_mean(loss) # [] - - # Apply gradients via optimizer. - gradients = tape.gradient(loss, self._network.trainable_variables) - # Clip and apply gradients. - if self._clip_grad_norm is not None: - gradients, _ = tf.clip_by_global_norm(gradients, self._clip_grad_norm) - - self._optimizer.apply(gradients, self._network.trainable_variables) - - # Periodically update the target network. - if tf.math.mod(self._num_steps, self._target_update_period) == 0: - for src, dest in zip(self._network.variables, - self._target_network.variables): - dest.assign(src) - self._num_steps.assign_add(1) - - if self._reverb_client: - # Compute updated priorities. - priorities = compute_priority(extra.errors, self._max_priority_weight) - # Compute priorities and add an op to update them on the reverb side. - self._reverb_client.update_priorities( - table=adders.DEFAULT_PRIORITY_TABLE, - keys=keys, - priorities=tf.cast(priorities, tf.float64)) - - return {'loss': loss} - - def step(self): - # Run the learning step. - results = self._step() - - # Compute elapsed time. - timestamp = time.time() - elapsed_time = timestamp - self._timestamp if self._timestamp else 0 - self._timestamp = timestamp - - # Update our counts and record it. 
- counts = self._counter.increment(steps=1, walltime=elapsed_time) - results.update(counts) - self._logger.write(results) - - def get_variables(self, names: List[str]) -> List[Variables]: - return [tf2_utils.to_numpy(self._variables)] - - @property - def state(self) -> Mapping[str, tf2_savers.Checkpointable]: - """Returns the stateful parts of the learner for checkpointing.""" - return { - 'network': self._network, - 'target_network': self._target_network, - 'optimizer': self._optimizer, - 'num_steps': self._num_steps, - } + def __init__( + self, + environment_spec: specs.EnvironmentSpec, + network: Union[networks.RNNCore, snt.RNNCore], + target_network: Union[networks.RNNCore, snt.RNNCore], + burn_in_length: int, + sequence_length: int, + dataset: tf.data.Dataset, + reverb_client: Optional[reverb.TFClient] = None, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + discount: float = 0.99, + target_update_period: int = 100, + importance_sampling_exponent: float = 0.2, + max_replay_size: int = 1_000_000, + learning_rate: float = 1e-3, + # TODO(sergomez): rename to use_core_state for consistency with JAX agent. + store_lstm_state: bool = True, + max_priority_weight: float = 0.9, + n_step: int = 5, + clip_grad_norm: Optional[float] = None, + ): + + if not isinstance(network, networks.RNNCore): + network.unroll = functools.partial(snt.static_unroll, network) + target_network.unroll = functools.partial(snt.static_unroll, target_network) + + # Internalise agent components (replay buffer, networks, optimizer). + # TODO(b/155086959): Fix type stubs and remove. + self._iterator: Iterator[reverb.ReplaySample] = iter( + dataset + ) # pytype: disable=wrong-arg-types + self._network = network + self._target_network = target_network + self._optimizer = snt.optimizers.Adam(learning_rate, epsilon=1e-3) + self._reverb_client = reverb_client + + # Internalise the hyperparameters. + self._store_lstm_state = store_lstm_state + self._burn_in_length = burn_in_length + self._discount = discount + self._max_replay_size = max_replay_size + self._importance_sampling_exponent = importance_sampling_exponent + self._max_priority_weight = max_priority_weight + self._target_update_period = target_update_period + self._num_actions = environment_spec.actions.num_values + self._sequence_length = sequence_length + self._n_step = n_step + self._clip_grad_norm = clip_grad_norm + + if burn_in_length: + self._burn_in = lambda o, s: self._network.unroll(o, s, burn_in_length) + else: + self._burn_in = lambda o, s: (o, s) # pylint: disable=unnecessary-lambda + + # Learner state. + self._variables = network.variables + self._num_steps = tf.Variable( + 0.0, dtype=tf.float32, trainable=False, name="step" + ) + + # Internalise logging/counting objects. + self._counter = counting.Counter(counter, "learner") + self._logger = logger or loggers.TerminalLogger("learner", time_delta=100.0) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + @tf.function + def _step(self) -> Dict[str, tf.Tensor]: + + # Draw a batch of data from replay. 
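+        # batch_to_sequence below transposes the sampled data from batch-major
+        # [B, T, ...] to time-major [T, B, ...], as expected by the unrolls.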
+ sample: reverb.ReplaySample = next(self._iterator) + + data = tf2_utils.batch_to_sequence(sample.data) + observations, actions, rewards, discounts, extra = ( + data.observation, + data.action, + data.reward, + data.discount, + data.extras, + ) + unused_sequence_length, batch_size = actions.shape + + # Get initial state for the LSTM, either from replay or simply use zeros. + if self._store_lstm_state: + core_state = tree.map_structure(lambda x: x[0], extra["core_state"]) + else: + core_state = self._network.initial_state(batch_size) + target_core_state = tree.map_structure(tf.identity, core_state) + + # Before training, optionally unroll the LSTM for a fixed warmup period. + burn_in_obs = tree.map_structure( + lambda x: x[: self._burn_in_length], observations + ) + _, core_state = self._burn_in(burn_in_obs, core_state) + _, target_core_state = self._burn_in(burn_in_obs, target_core_state) + + # Don't train on the warmup period. + observations, actions, rewards, discounts, extra = tree.map_structure( + lambda x: x[self._burn_in_length :], + (observations, actions, rewards, discounts, extra), + ) + + with tf.GradientTape() as tape: + # Unroll the online and target Q-networks on the sequences. + q_values, _ = self._network.unroll( + observations, core_state, self._sequence_length + ) + target_q_values, _ = self._target_network.unroll( + observations, target_core_state, self._sequence_length + ) + + # Compute the target policy distribution (greedy). + greedy_actions = tf.argmax(q_values, output_type=tf.int32, axis=-1) + target_policy_probs = tf.one_hot( + greedy_actions, depth=self._num_actions, dtype=q_values.dtype + ) + + # Compute the transformed n-step loss. + rewards = tree.map_structure(lambda x: x[:-1], rewards) + discounts = tree.map_structure(lambda x: x[:-1], discounts) + loss, extra = losses.transformed_n_step_loss( + qs=q_values, + targnet_qs=target_q_values, + actions=actions, + rewards=rewards, + pcontinues=discounts * self._discount, + target_policy_probs=target_policy_probs, + bootstrap_n=self._n_step, + ) + + # Calculate importance weights and use them to scale the loss. + sample_info = sample.info + keys, probs = sample_info.key, sample_info.probability + importance_weights = 1.0 / (self._max_replay_size * probs) # [T, B] + importance_weights **= self._importance_sampling_exponent + importance_weights /= tf.reduce_max(importance_weights) + loss *= tf.cast(importance_weights, tf.float32) # [T, B] + loss = tf.reduce_mean(loss) # [] + + # Apply gradients via optimizer. + gradients = tape.gradient(loss, self._network.trainable_variables) + #  Clip and apply gradients. + if self._clip_grad_norm is not None: + gradients, _ = tf.clip_by_global_norm(gradients, self._clip_grad_norm) + + self._optimizer.apply(gradients, self._network.trainable_variables) + + # Periodically update the target network. + if tf.math.mod(self._num_steps, self._target_update_period) == 0: + for src, dest in zip( + self._network.variables, self._target_network.variables + ): + dest.assign(src) + self._num_steps.assign_add(1) + + if self._reverb_client: + # Compute updated priorities. + priorities = compute_priority(extra.errors, self._max_priority_weight) + # Compute priorities and add an op to update them on the reverb side. + self._reverb_client.update_priorities( + table=adders.DEFAULT_PRIORITY_TABLE, + keys=keys, + priorities=tf.cast(priorities, tf.float64), + ) + + return {"loss": loss} + + def step(self): + # Run the learning step. + results = self._step() + + # Compute elapsed time. 
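+        # On the very first call there is no previous timestamp, so the elapsed
+        # time is reported as zero.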
+ timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Update our counts and record it. + counts = self._counter.increment(steps=1, walltime=elapsed_time) + results.update(counts) + self._logger.write(results) + + def get_variables(self, names: List[str]) -> List[Variables]: + return [tf2_utils.to_numpy(self._variables)] + + @property + def state(self) -> Mapping[str, tf2_savers.Checkpointable]: + """Returns the stateful parts of the learner for checkpointing.""" + return { + "network": self._network, + "target_network": self._target_network, + "optimizer": self._optimizer, + "num_steps": self._num_steps, + } def compute_priority(errors: tf.Tensor, alpha: float): - """Compute priority as mixture of max and mean sequence errors.""" - abs_errors = tf.abs(errors) - mean_priority = tf.reduce_mean(abs_errors, axis=0) - max_priority = tf.reduce_max(abs_errors, axis=0) + """Compute priority as mixture of max and mean sequence errors.""" + abs_errors = tf.abs(errors) + mean_priority = tf.reduce_mean(abs_errors, axis=0) + max_priority = tf.reduce_max(abs_errors, axis=0) - return alpha * max_priority + (1 - alpha) * mean_priority + return alpha * max_priority + (1 - alpha) * mean_priority diff --git a/acme/agents/tf/r2d3/agent.py b/acme/agents/tf/r2d3/agent.py index 8bbe89d24b..b672edcd81 100644 --- a/acme/agents/tf/r2d3/agent.py +++ b/acme/agents/tf/r2d3/agent.py @@ -17,8 +17,13 @@ import functools from typing import Optional -from acme import datasets -from acme import specs +import reverb +import sonnet as snt +import tensorflow as tf +import tree +import trfl + +from acme import datasets, specs from acme import types as acme_types from acme.adders import reverb as adders from acme.agents import agent @@ -26,147 +31,142 @@ from acme.agents.tf.r2d2 import learning from acme.tf import savers as tf2_savers from acme.tf import utils as tf2_utils -from acme.utils import counting -from acme.utils import loggers -import reverb -import sonnet as snt -import tensorflow as tf -import tree -import trfl +from acme.utils import counting, loggers class R2D3(agent.Agent): - """R2D3 Agent. + """R2D3 Agent. This implements a single-process R2D2 agent that mixes demonstrations with actor experience. """ - def __init__(self, - environment_spec: specs.EnvironmentSpec, - network: snt.RNNCore, - target_network: snt.RNNCore, - burn_in_length: int, - trace_length: int, - replay_period: int, - demonstration_dataset: tf.data.Dataset, - demonstration_ratio: float, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - discount: float = 0.99, - batch_size: int = 32, - target_update_period: int = 100, - importance_sampling_exponent: float = 0.2, - epsilon: float = 0.01, - learning_rate: float = 1e-3, - save_logs: bool = False, - log_name: str = 'agent', - checkpoint: bool = True, - min_replay_size: int = 1000, - max_replay_size: int = 1000000, - samples_per_insert: float = 32.0): - - sequence_length = burn_in_length + trace_length + 1 - extra_spec = { - 'core_state': network.initial_state(1), - } - # Remove batch dimensions. 
- extra_spec = tf2_utils.squeeze_batch_dim(extra_spec) - replay_table = reverb.Table( - name=adders.DEFAULT_PRIORITY_TABLE, - sampler=reverb.selectors.Uniform(), - remover=reverb.selectors.Fifo(), - max_size=max_replay_size, - rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1), - signature=adders.SequenceAdder.signature( - environment_spec, extra_spec, sequence_length=sequence_length)) - self._server = reverb.Server([replay_table], port=None) - address = f'localhost:{self._server.port}' - - # Component to add things into replay. - sequence_kwargs = dict( - period=replay_period, - sequence_length=sequence_length, - ) - adder = adders.SequenceAdder(client=reverb.Client(address), - **sequence_kwargs) - - # The dataset object to learn from. - dataset = datasets.make_reverb_dataset( - server_address=address) - - # Combine with demonstration dataset. - transition = functools.partial(_sequence_from_episode, - extra_spec=extra_spec, - **sequence_kwargs) - dataset_demos = demonstration_dataset.map(transition) - dataset = tf.data.experimental.sample_from_datasets( - [dataset, dataset_demos], - [1 - demonstration_ratio, demonstration_ratio]) - - # Batch and prefetch. - dataset = dataset.batch(batch_size, drop_remainder=True) - dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) - - tf2_utils.create_variables(network, [environment_spec.observations]) - tf2_utils.create_variables(target_network, [environment_spec.observations]) - - learner = learning.R2D2Learner( - environment_spec=environment_spec, - network=network, - target_network=target_network, - burn_in_length=burn_in_length, - dataset=dataset, - reverb_client=reverb.TFClient(address), - counter=counter, - logger=logger, - sequence_length=sequence_length, - discount=discount, - target_update_period=target_update_period, - importance_sampling_exponent=importance_sampling_exponent, - max_replay_size=max_replay_size, - learning_rate=learning_rate, - store_lstm_state=False, - ) - - self._checkpointer = tf2_savers.Checkpointer( - subdirectory='r2d2_learner', - time_delta_minutes=60, - objects_to_save=learner.state, - enable_checkpointing=checkpoint, - ) - - self._snapshotter = tf2_savers.Snapshotter( - objects_to_save={'network': network}, time_delta_minutes=60.) - - policy_network = snt.DeepRNN([ - network, - lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(), - ]) - - actor = actors.RecurrentActor(policy_network, adder) - observations_per_step = (float(replay_period * batch_size) / - samples_per_insert) - super().__init__( - actor=actor, - learner=learner, - min_observations=replay_period * max(batch_size, min_replay_size), - observations_per_step=observations_per_step) - - def update(self): - super().update() - self._snapshotter.save() - self._checkpointer.save() - - -def _sequence_from_episode(observations: acme_types.NestedTensor, - actions: tf.Tensor, - rewards: tf.Tensor, - discounts: tf.Tensor, - extra_spec: acme_types.NestedSpec, - period: int, - sequence_length: int): - """Produce Reverb-like sequence from a full episode. 
+ def __init__( + self, + environment_spec: specs.EnvironmentSpec, + network: snt.RNNCore, + target_network: snt.RNNCore, + burn_in_length: int, + trace_length: int, + replay_period: int, + demonstration_dataset: tf.data.Dataset, + demonstration_ratio: float, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + discount: float = 0.99, + batch_size: int = 32, + target_update_period: int = 100, + importance_sampling_exponent: float = 0.2, + epsilon: float = 0.01, + learning_rate: float = 1e-3, + save_logs: bool = False, + log_name: str = "agent", + checkpoint: bool = True, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: float = 32.0, + ): + + sequence_length = burn_in_length + trace_length + 1 + extra_spec = { + "core_state": network.initial_state(1), + } + # Remove batch dimensions. + extra_spec = tf2_utils.squeeze_batch_dim(extra_spec) + replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=max_replay_size, + rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1), + signature=adders.SequenceAdder.signature( + environment_spec, extra_spec, sequence_length=sequence_length + ), + ) + self._server = reverb.Server([replay_table], port=None) + address = f"localhost:{self._server.port}" + + # Component to add things into replay. + sequence_kwargs = dict(period=replay_period, sequence_length=sequence_length,) + adder = adders.SequenceAdder(client=reverb.Client(address), **sequence_kwargs) + + # The dataset object to learn from. + dataset = datasets.make_reverb_dataset(server_address=address) + + # Combine with demonstration dataset. + transition = functools.partial( + _sequence_from_episode, extra_spec=extra_spec, **sequence_kwargs + ) + dataset_demos = demonstration_dataset.map(transition) + dataset = tf.data.experimental.sample_from_datasets( + [dataset, dataset_demos], [1 - demonstration_ratio, demonstration_ratio] + ) + + # Batch and prefetch. 
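+        # drop_remainder=True keeps the batch dimension static, and prefetching
+        # overlaps dataset production with the learner's training step.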
+ dataset = dataset.batch(batch_size, drop_remainder=True) + dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) + + tf2_utils.create_variables(network, [environment_spec.observations]) + tf2_utils.create_variables(target_network, [environment_spec.observations]) + + learner = learning.R2D2Learner( + environment_spec=environment_spec, + network=network, + target_network=target_network, + burn_in_length=burn_in_length, + dataset=dataset, + reverb_client=reverb.TFClient(address), + counter=counter, + logger=logger, + sequence_length=sequence_length, + discount=discount, + target_update_period=target_update_period, + importance_sampling_exponent=importance_sampling_exponent, + max_replay_size=max_replay_size, + learning_rate=learning_rate, + store_lstm_state=False, + ) + + self._checkpointer = tf2_savers.Checkpointer( + subdirectory="r2d2_learner", + time_delta_minutes=60, + objects_to_save=learner.state, + enable_checkpointing=checkpoint, + ) + + self._snapshotter = tf2_savers.Snapshotter( + objects_to_save={"network": network}, time_delta_minutes=60.0 + ) + + policy_network = snt.DeepRNN( + [network, lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),] + ) + + actor = actors.RecurrentActor(policy_network, adder) + observations_per_step = float(replay_period * batch_size) / samples_per_insert + super().__init__( + actor=actor, + learner=learner, + min_observations=replay_period * max(batch_size, min_replay_size), + observations_per_step=observations_per_step, + ) + + def update(self): + super().update() + self._snapshotter.save() + self._checkpointer.save() + + +def _sequence_from_episode( + observations: acme_types.NestedTensor, + actions: tf.Tensor, + rewards: tf.Tensor, + discounts: tf.Tensor, + extra_spec: acme_types.NestedSpec, + period: int, + sequence_length: int, +): + """Produce Reverb-like sequence from a full episode. Observations, actions, rewards and discounts have the same length. This function will ignore the first reward and discount and the last action. @@ -189,38 +189,41 @@ def _sequence_from_episode(observations: acme_types.NestedTensor, (o_t, a_t, r_t, d_t, e_t) Tuple. """ - length = tf.shape(rewards)[0] - first = tf.random.uniform(shape=(), minval=0, maxval=length, dtype=tf.int32) - first = first // period * period # Get a multiple of `period`. 
- to = tf.minimum(first + sequence_length, length) - - def _slice_and_pad(x): - pad_length = sequence_length + first - to - padding_shape = tf.concat([[pad_length], tf.shape(x)[1:]], axis=0) - result = tf.concat([x[first:to], tf.zeros(padding_shape, x.dtype)], axis=0) - result.set_shape([sequence_length] + x.shape.as_list()[1:]) - return result - - o_t = tree.map_structure(_slice_and_pad, observations) - a_t = tree.map_structure(_slice_and_pad, actions) - r_t = _slice_and_pad(rewards) - d_t = _slice_and_pad(discounts) - start_of_episode = tf.equal(first, 0) - start_of_episode = tf.expand_dims(start_of_episode, axis=0) - start_of_episode = tf.tile(start_of_episode, [sequence_length]) - - def _sequence_zeros(spec): - return tf.zeros([sequence_length] + spec.shape, spec.dtype) - - e_t = tree.map_structure(_sequence_zeros, extra_spec) - info = tree.map_structure(lambda dtype: tf.ones([], dtype), - reverb.SampleInfo.tf_dtypes()) - return reverb.ReplaySample( - info=info, - data=adders.Step( - observation=o_t, - action=a_t, - reward=r_t, - discount=d_t, - start_of_episode=start_of_episode, - extras=e_t)) + length = tf.shape(rewards)[0] + first = tf.random.uniform(shape=(), minval=0, maxval=length, dtype=tf.int32) + first = first // period * period # Get a multiple of `period`. + to = tf.minimum(first + sequence_length, length) + + def _slice_and_pad(x): + pad_length = sequence_length + first - to + padding_shape = tf.concat([[pad_length], tf.shape(x)[1:]], axis=0) + result = tf.concat([x[first:to], tf.zeros(padding_shape, x.dtype)], axis=0) + result.set_shape([sequence_length] + x.shape.as_list()[1:]) + return result + + o_t = tree.map_structure(_slice_and_pad, observations) + a_t = tree.map_structure(_slice_and_pad, actions) + r_t = _slice_and_pad(rewards) + d_t = _slice_and_pad(discounts) + start_of_episode = tf.equal(first, 0) + start_of_episode = tf.expand_dims(start_of_episode, axis=0) + start_of_episode = tf.tile(start_of_episode, [sequence_length]) + + def _sequence_zeros(spec): + return tf.zeros([sequence_length] + spec.shape, spec.dtype) + + e_t = tree.map_structure(_sequence_zeros, extra_spec) + info = tree.map_structure( + lambda dtype: tf.ones([], dtype), reverb.SampleInfo.tf_dtypes() + ) + return reverb.ReplaySample( + info=info, + data=adders.Step( + observation=o_t, + action=a_t, + reward=r_t, + discount=d_t, + start_of_episode=start_of_episode, + extras=e_t, + ), + ) diff --git a/acme/agents/tf/r2d3/agent_test.py b/acme/agents/tf/r2d3/agent_test.py index e5822166d2..4cebc9fbd9 100644 --- a/acme/agents/tf/r2d3/agent_test.py +++ b/acme/agents/tf/r2d3/agent_test.py @@ -14,81 +14,79 @@ """Tests for R2D3 agent.""" +import dm_env +import numpy as np +import sonnet as snt +from absl.testing import absltest + import acme from acme import specs from acme.agents.tf import r2d3 from acme.agents.tf.dqfd import bsuite_demonstrations from acme.testing import fakes from acme.tf import networks -import dm_env -import numpy as np -import sonnet as snt - -from absl.testing import absltest class SimpleNetwork(networks.RNNCore): + def __init__(self, action_spec: specs.DiscreteArray): + super().__init__(name="r2d2_test_network") + self._net = snt.DeepRNN( + [ + snt.Flatten(), + snt.LSTM(20), + snt.nets.MLP([50, 50, action_spec.num_values]), + ] + ) - def __init__(self, action_spec: specs.DiscreteArray): - super().__init__(name='r2d2_test_network') - self._net = snt.DeepRNN([ - snt.Flatten(), - snt.LSTM(20), - snt.nets.MLP([50, 50, action_spec.num_values]) - ]) - - def __call__(self, inputs, state): - 
return self._net(inputs, state) + def __call__(self, inputs, state): + return self._net(inputs, state) - def initial_state(self, batch_size: int, **kwargs): - return self._net.initial_state(batch_size) + def initial_state(self, batch_size: int, **kwargs): + return self._net.initial_state(batch_size) - def unroll(self, inputs, state, sequence_length): - return snt.static_unroll(self._net, inputs, state, sequence_length) + def unroll(self, inputs, state, sequence_length): + return snt.static_unroll(self._net, inputs, state, sequence_length) class R2D3Test(absltest.TestCase): - - def test_r2d3(self): - # Create a fake environment to test with. - environment = fakes.DiscreteEnvironment( - num_actions=5, - num_observations=10, - obs_dtype=np.float32, - episode_length=10) - spec = specs.make_environment_spec(environment) - - # Build demonstrations. - dummy_action = np.zeros((), dtype=np.int32) - recorder = bsuite_demonstrations.DemonstrationRecorder() - timestep = environment.reset() - while timestep.step_type is not dm_env.StepType.LAST: - recorder.step(timestep, dummy_action) - timestep = environment.step(dummy_action) - recorder.step(timestep, dummy_action) - recorder.record_episode() - - # Construct the agent. - agent = r2d3.R2D3( - environment_spec=spec, - network=SimpleNetwork(spec.actions), - target_network=SimpleNetwork(spec.actions), - demonstration_dataset=recorder.make_tf_dataset(), - demonstration_ratio=0.5, - batch_size=10, - samples_per_insert=2, - min_replay_size=10, - burn_in_length=2, - trace_length=6, - replay_period=4, - checkpoint=False, - ) - - # Try running the environment loop. We have no assertions here because all - # we care about is that the agent runs without raising any errors. - loop = acme.EnvironmentLoop(environment, agent) - loop.run(num_episodes=5) - - -if __name__ == '__main__': - absltest.main() + def test_r2d3(self): + # Create a fake environment to test with. + environment = fakes.DiscreteEnvironment( + num_actions=5, num_observations=10, obs_dtype=np.float32, episode_length=10 + ) + spec = specs.make_environment_spec(environment) + + # Build demonstrations. + dummy_action = np.zeros((), dtype=np.int32) + recorder = bsuite_demonstrations.DemonstrationRecorder() + timestep = environment.reset() + while timestep.step_type is not dm_env.StepType.LAST: + recorder.step(timestep, dummy_action) + timestep = environment.step(dummy_action) + recorder.step(timestep, dummy_action) + recorder.record_episode() + + # Construct the agent. + agent = r2d3.R2D3( + environment_spec=spec, + network=SimpleNetwork(spec.actions), + target_network=SimpleNetwork(spec.actions), + demonstration_dataset=recorder.make_tf_dataset(), + demonstration_ratio=0.5, + batch_size=10, + samples_per_insert=2, + min_replay_size=10, + burn_in_length=2, + trace_length=6, + replay_period=4, + checkpoint=False, + ) + + # Try running the environment loop. We have no assertions here because all + # we care about is that the agent runs without raising any errors. 
+ loop = acme.EnvironmentLoop(environment, agent) + loop.run(num_episodes=5) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/tf/svg0_prior/__init__.py b/acme/agents/tf/svg0_prior/__init__.py index b4218db22c..854c7017e8 100644 --- a/acme/agents/tf/svg0_prior/__init__.py +++ b/acme/agents/tf/svg0_prior/__init__.py @@ -17,5 +17,7 @@ from acme.agents.tf.svg0_prior.agent import SVG0 from acme.agents.tf.svg0_prior.agent_distributed import DistributedSVG0 from acme.agents.tf.svg0_prior.learning import SVG0Learner -from acme.agents.tf.svg0_prior.networks import make_default_networks -from acme.agents.tf.svg0_prior.networks import make_network_with_prior +from acme.agents.tf.svg0_prior.networks import ( + make_default_networks, + make_network_with_prior, +) diff --git a/acme/agents/tf/svg0_prior/acting.py b/acme/agents/tf/svg0_prior/acting.py index f044a14e4b..ac0d9c5f23 100644 --- a/acme/agents/tf/svg0_prior/acting.py +++ b/acme/agents/tf/svg0_prior/acting.py @@ -16,52 +16,48 @@ from typing import Optional -from acme import adders -from acme import types +import dm_env +import sonnet as snt +from acme import adders, types from acme.agents.tf import actors from acme.tf import utils as tf2_utils from acme.tf import variable_utils as tf2_variable_utils -import dm_env -import sonnet as snt - class SVG0Actor(actors.FeedForwardActor): - """An actor that also returns `log_prob`.""" - - def __init__( - self, - policy_network: snt.Module, - adder: Optional[adders.Adder] = None, - variable_client: Optional[tf2_variable_utils.VariableClient] = None, - deterministic_policy: Optional[bool] = False, - ): - super().__init__(policy_network, adder, variable_client) - self._log_prob = None - self._deterministic_policy = deterministic_policy - - def select_action(self, observation: types.NestedArray) -> types.NestedArray: - # Add a dummy batch dimension and as a side effect convert numpy to TF. - batched_observation = tf2_utils.add_batch_dim(observation) - - # Compute the policy, conditioned on the observation. - policy = self._policy_network(batched_observation) - if self._deterministic_policy: - action = policy.mean() - else: - action = policy.sample() - self._log_prob = policy.log_prob(action) - return tf2_utils.to_numpy_squeeze(action) - - def observe( - self, - action: types.NestedArray, - next_timestep: dm_env.TimeStep, - ): - if not self._adder: - return - - extras = {'log_prob': self._log_prob} - extras = tf2_utils.to_numpy_squeeze(extras) - self._adder.add(action, next_timestep, extras) + """An actor that also returns `log_prob`.""" + + def __init__( + self, + policy_network: snt.Module, + adder: Optional[adders.Adder] = None, + variable_client: Optional[tf2_variable_utils.VariableClient] = None, + deterministic_policy: Optional[bool] = False, + ): + super().__init__(policy_network, adder, variable_client) + self._log_prob = None + self._deterministic_policy = deterministic_policy + + def select_action(self, observation: types.NestedArray) -> types.NestedArray: + # Add a dummy batch dimension and as a side effect convert numpy to TF. + batched_observation = tf2_utils.add_batch_dim(observation) + + # Compute the policy, conditioned on the observation. 
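+        # The policy network returns a distribution: evaluation takes its mean,
+        # otherwise we sample an action and keep its log-probability so it can
+        # be written to replay as an extra alongside the transition.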
+ policy = self._policy_network(batched_observation) + if self._deterministic_policy: + action = policy.mean() + else: + action = policy.sample() + self._log_prob = policy.log_prob(action) + return tf2_utils.to_numpy_squeeze(action) + + def observe( + self, action: types.NestedArray, next_timestep: dm_env.TimeStep, + ): + if not self._adder: + return + + extras = {"log_prob": self._log_prob} + extras = tf2_utils.to_numpy_squeeze(extras) + self._adder.add(action, next_timestep, extras) diff --git a/acme/agents/tf/svg0_prior/agent.py b/acme/agents/tf/svg0_prior/agent.py index e9303c7261..e9a1a14a1c 100644 --- a/acme/agents/tf/svg0_prior/agent.py +++ b/acme/agents/tf/svg0_prior/agent.py @@ -18,227 +18,213 @@ import dataclasses from typing import Iterator, List, Optional, Tuple -from acme import adders -from acme import core -from acme import datasets -from acme import specs -from acme.adders import reverb as reverb_adders -from acme.agents import agent -from acme.agents.tf.svg0_prior import acting -from acme.agents.tf.svg0_prior import learning -from acme.tf import utils -from acme.tf import variable_utils -from acme.utils import counting -from acme.utils import loggers import reverb import sonnet as snt import tensorflow as tf +from acme import adders, core, datasets, specs +from acme.adders import reverb as reverb_adders +from acme.agents import agent +from acme.agents.tf.svg0_prior import acting, learning +from acme.tf import utils, variable_utils +from acme.utils import counting, loggers + @dataclasses.dataclass class SVG0Config: - """Configuration options for the agent.""" - - discount: float = 0.99 - batch_size: int = 256 - prefetch_size: int = 4 - target_update_period: int = 100 - policy_optimizer: Optional[snt.Optimizer] = None - critic_optimizer: Optional[snt.Optimizer] = None - prior_optimizer: Optional[snt.Optimizer] = None - min_replay_size: int = 1000 - max_replay_size: int = 1000000 - samples_per_insert: Optional[float] = 32.0 - sequence_length: int = 10 - sigma: float = 0.3 - replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE - distillation_cost: Optional[float] = 1e-3 - entropy_regularizer_cost: Optional[float] = 1e-3 + """Configuration options for the agent.""" + + discount: float = 0.99 + batch_size: int = 256 + prefetch_size: int = 4 + target_update_period: int = 100 + policy_optimizer: Optional[snt.Optimizer] = None + critic_optimizer: Optional[snt.Optimizer] = None + prior_optimizer: Optional[snt.Optimizer] = None + min_replay_size: int = 1000 + max_replay_size: int = 1000000 + samples_per_insert: Optional[float] = 32.0 + sequence_length: int = 10 + sigma: float = 0.3 + replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE + distillation_cost: Optional[float] = 1e-3 + entropy_regularizer_cost: Optional[float] = 1e-3 @dataclasses.dataclass class SVG0Networks: - """Structure containing the networks for SVG0.""" - - policy_network: snt.Module - critic_network: snt.Module - prior_network: Optional[snt.Module] - - def __init__( - self, - policy_network: snt.Module, - critic_network: snt.Module, - prior_network: Optional[snt.Module] = None - ): - # This method is implemented (rather than added by the dataclass decorator) - # in order to allow observation network to be passed as an arbitrary tensor - # transformation rather than as a snt Module. - # TODO(mwhoffman): use Protocol rather than Module/TensorTransformation. 
- self.policy_network = policy_network - self.critic_network = critic_network - self.prior_network = prior_network - - def init(self, environment_spec: specs.EnvironmentSpec): - """Initialize the networks given an environment spec.""" - # Get observation and action specs. - act_spec = environment_spec.actions - obs_spec = environment_spec.observations - - # Create variables for the policy and critic nets. - _ = utils.create_variables(self.policy_network, [obs_spec]) - _ = utils.create_variables(self.critic_network, [obs_spec, act_spec]) - if self.prior_network is not None: - _ = utils.create_variables(self.prior_network, [obs_spec]) - - def make_policy( - self, - ) -> snt.Module: - """Create a single network which evaluates the policy.""" - return self.policy_network - - def make_prior( - self, - ) -> snt.Module: - """Create a single network which evaluates the prior.""" - behavior_prior = self.prior_network - return behavior_prior + """Structure containing the networks for SVG0.""" + + policy_network: snt.Module + critic_network: snt.Module + prior_network: Optional[snt.Module] + + def __init__( + self, + policy_network: snt.Module, + critic_network: snt.Module, + prior_network: Optional[snt.Module] = None, + ): + # This method is implemented (rather than added by the dataclass decorator) + # in order to allow observation network to be passed as an arbitrary tensor + # transformation rather than as a snt Module. + # TODO(mwhoffman): use Protocol rather than Module/TensorTransformation. + self.policy_network = policy_network + self.critic_network = critic_network + self.prior_network = prior_network + + def init(self, environment_spec: specs.EnvironmentSpec): + """Initialize the networks given an environment spec.""" + # Get observation and action specs. + act_spec = environment_spec.actions + obs_spec = environment_spec.observations + + # Create variables for the policy and critic nets. + _ = utils.create_variables(self.policy_network, [obs_spec]) + _ = utils.create_variables(self.critic_network, [obs_spec, act_spec]) + if self.prior_network is not None: + _ = utils.create_variables(self.prior_network, [obs_spec]) + + def make_policy(self,) -> snt.Module: + """Create a single network which evaluates the policy.""" + return self.policy_network + + def make_prior(self,) -> snt.Module: + """Create a single network which evaluates the prior.""" + behavior_prior = self.prior_network + return behavior_prior class SVG0Builder: - """Builder for SVG0 which constructs individual components of the agent.""" - - def __init__(self, config: SVG0Config): - self._config = config - - def make_replay_tables( - self, - environment_spec: specs.EnvironmentSpec, - sequence_length: int, - ) -> List[reverb.Table]: - """Create tables to insert data into.""" - if self._config.samples_per_insert is None: - # We will take a samples_per_insert ratio of None to mean that there is - # no limit, i.e. this only implies a min size limit. 
- limiter = reverb.rate_limiters.MinSize(self._config.min_replay_size) - - else: - error_buffer = max(1, self._config.samples_per_insert) - limiter = reverb.rate_limiters.SampleToInsertRatio( - min_size_to_sample=self._config.min_replay_size, - samples_per_insert=self._config.samples_per_insert, - error_buffer=error_buffer) - - extras_spec = { - 'log_prob': tf.ones( - shape=(), dtype=tf.float32) - } - replay_table = reverb.Table( - name=self._config.replay_table_name, - sampler=reverb.selectors.Uniform(), - remover=reverb.selectors.Fifo(), - max_size=self._config.max_replay_size, - rate_limiter=limiter, - signature=reverb_adders.SequenceAdder.signature( - environment_spec, - extras_spec=extras_spec, - sequence_length=sequence_length + 1)) - - return [replay_table] - - def make_dataset_iterator( - self, - reverb_client: reverb.Client, - ) -> Iterator[reverb.ReplaySample]: - """Create a dataset iterator to use for learning/updating the agent.""" - # The dataset provides an interface to sample from replay. - dataset = datasets.make_reverb_dataset( - table=self._config.replay_table_name, - server_address=reverb_client.server_address, - batch_size=self._config.batch_size, - prefetch_size=self._config.prefetch_size) - - # TODO(b/155086959): Fix type stubs and remove. - return iter(dataset) # pytype: disable=wrong-arg-types - - def make_adder( - self, - replay_client: reverb.Client, - ) -> adders.Adder: - """Create an adder which records data generated by the actor/environment.""" - return reverb_adders.SequenceAdder( - client=replay_client, - sequence_length=self._config.sequence_length+1, - priority_fns={self._config.replay_table_name: lambda x: 1.}, - period=self._config.sequence_length, - end_of_episode_behavior=reverb_adders.EndBehavior.CONTINUE, + """Builder for SVG0 which constructs individual components of the agent.""" + + def __init__(self, config: SVG0Config): + self._config = config + + def make_replay_tables( + self, environment_spec: specs.EnvironmentSpec, sequence_length: int, + ) -> List[reverb.Table]: + """Create tables to insert data into.""" + if self._config.samples_per_insert is None: + # We will take a samples_per_insert ratio of None to mean that there is + # no limit, i.e. this only implies a min size limit. + limiter = reverb.rate_limiters.MinSize(self._config.min_replay_size) + + else: + error_buffer = max(1, self._config.samples_per_insert) + limiter = reverb.rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._config.min_replay_size, + samples_per_insert=self._config.samples_per_insert, + error_buffer=error_buffer, + ) + + extras_spec = {"log_prob": tf.ones(shape=(), dtype=tf.float32)} + replay_table = reverb.Table( + name=self._config.replay_table_name, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._config.max_replay_size, + rate_limiter=limiter, + signature=reverb_adders.SequenceAdder.signature( + environment_spec, + extras_spec=extras_spec, + sequence_length=sequence_length + 1, + ), + ) + + return [replay_table] + + def make_dataset_iterator( + self, reverb_client: reverb.Client, + ) -> Iterator[reverb.ReplaySample]: + """Create a dataset iterator to use for learning/updating the agent.""" + # The dataset provides an interface to sample from replay. + dataset = datasets.make_reverb_dataset( + table=self._config.replay_table_name, + server_address=reverb_client.server_address, + batch_size=self._config.batch_size, + prefetch_size=self._config.prefetch_size, + ) + + # TODO(b/155086959): Fix type stubs and remove. 
+ return iter(dataset) # pytype: disable=wrong-arg-types + + def make_adder(self, replay_client: reverb.Client,) -> adders.Adder: + """Create an adder which records data generated by the actor/environment.""" + return reverb_adders.SequenceAdder( + client=replay_client, + sequence_length=self._config.sequence_length + 1, + priority_fns={self._config.replay_table_name: lambda x: 1.0}, + period=self._config.sequence_length, + end_of_episode_behavior=reverb_adders.EndBehavior.CONTINUE, ) - def make_actor( - self, - policy_network: snt.Module, - adder: Optional[adders.Adder] = None, - variable_source: Optional[core.VariableSource] = None, - deterministic_policy: Optional[bool] = False, - ): - """Create an actor instance.""" - if variable_source: - # Create the variable client responsible for keeping the actor up-to-date. - variable_client = variable_utils.VariableClient( - client=variable_source, - variables={'policy': policy_network.variables}, - update_period=1000, - ) - - # Make sure not to use a random policy after checkpoint restoration by - # assigning variables before running the environment loop. - variable_client.update_and_wait() - - else: - variable_client = None - - # Create the actor which defines how we take actions. - return acting.SVG0Actor( - policy_network=policy_network, - adder=adder, - variable_client=variable_client, - deterministic_policy=deterministic_policy - ) - - def make_learner( - self, - networks: Tuple[SVG0Networks, SVG0Networks], - dataset: Iterator[reverb.ReplaySample], - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - checkpoint: bool = False, - ): - """Creates an instance of the learner.""" - online_networks, target_networks = networks - - # The learner updates the parameters (and initializes them). - return learning.SVG0Learner( - policy_network=online_networks.policy_network, - critic_network=online_networks.critic_network, - target_policy_network=target_networks.policy_network, - target_critic_network=target_networks.critic_network, - prior_network=online_networks.prior_network, - target_prior_network=target_networks.prior_network, - policy_optimizer=self._config.policy_optimizer, - critic_optimizer=self._config.critic_optimizer, - prior_optimizer=self._config.prior_optimizer, - distillation_cost=self._config.distillation_cost, - entropy_regularizer_cost=self._config.entropy_regularizer_cost, - discount=self._config.discount, - target_update_period=self._config.target_update_period, - dataset_iterator=dataset, - counter=counter, - logger=logger, - checkpoint=checkpoint, - ) + def make_actor( + self, + policy_network: snt.Module, + adder: Optional[adders.Adder] = None, + variable_source: Optional[core.VariableSource] = None, + deterministic_policy: Optional[bool] = False, + ): + """Create an actor instance.""" + if variable_source: + # Create the variable client responsible for keeping the actor up-to-date. + variable_client = variable_utils.VariableClient( + client=variable_source, + variables={"policy": policy_network.variables}, + update_period=1000, + ) + + # Make sure not to use a random policy after checkpoint restoration by + # assigning variables before running the environment loop. + variable_client.update_and_wait() + + else: + variable_client = None + + # Create the actor which defines how we take actions. 
+ return acting.SVG0Actor( + policy_network=policy_network, + adder=adder, + variable_client=variable_client, + deterministic_policy=deterministic_policy, + ) + + def make_learner( + self, + networks: Tuple[SVG0Networks, SVG0Networks], + dataset: Iterator[reverb.ReplaySample], + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = False, + ): + """Creates an instance of the learner.""" + online_networks, target_networks = networks + + # The learner updates the parameters (and initializes them). + return learning.SVG0Learner( + policy_network=online_networks.policy_network, + critic_network=online_networks.critic_network, + target_policy_network=target_networks.policy_network, + target_critic_network=target_networks.critic_network, + prior_network=online_networks.prior_network, + target_prior_network=target_networks.prior_network, + policy_optimizer=self._config.policy_optimizer, + critic_optimizer=self._config.critic_optimizer, + prior_optimizer=self._config.prior_optimizer, + distillation_cost=self._config.distillation_cost, + entropy_regularizer_cost=self._config.entropy_regularizer_cost, + discount=self._config.discount, + target_update_period=self._config.target_update_period, + dataset_iterator=dataset, + counter=counter, + logger=logger, + checkpoint=checkpoint, + ) class SVG0(agent.Agent): - """SVG0 Agent with prior. + """SVG0 Agent with prior. This implements a single-process SVG0 agent. This is an actor-critic algorithm that generates data via a behavior policy, inserts N-step transitions into @@ -246,32 +232,32 @@ class SVG0(agent.Agent): behavior) by sampling uniformly from this buffer. """ - def __init__( - self, - environment_spec: specs.EnvironmentSpec, - policy_network: snt.Module, - critic_network: snt.Module, - discount: float = 0.99, - batch_size: int = 256, - prefetch_size: int = 4, - target_update_period: int = 100, - prior_network: Optional[snt.Module] = None, - policy_optimizer: Optional[snt.Optimizer] = None, - critic_optimizer: Optional[snt.Optimizer] = None, - prior_optimizer: Optional[snt.Optimizer] = None, - distillation_cost: Optional[float] = 1e-3, - entropy_regularizer_cost: Optional[float] = 1e-3, - min_replay_size: int = 1000, - max_replay_size: int = 1000000, - samples_per_insert: float = 32.0, - sequence_length: int = 10, - sigma: float = 0.3, - replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - checkpoint: bool = True, - ): - """Initialize the agent. + def __init__( + self, + environment_spec: specs.EnvironmentSpec, + policy_network: snt.Module, + critic_network: snt.Module, + discount: float = 0.99, + batch_size: int = 256, + prefetch_size: int = 4, + target_update_period: int = 100, + prior_network: Optional[snt.Module] = None, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + prior_optimizer: Optional[snt.Optimizer] = None, + distillation_cost: Optional[float] = 1e-3, + entropy_regularizer_cost: Optional[float] = 1e-3, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: float = 32.0, + sequence_length: int = 10, + sigma: float = 0.3, + replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = True, + ): + """Initialize the agent. 
Args: environment_spec: description of the actions, observations, etc. @@ -301,70 +287,72 @@ def __init__( logger: logger object to be used by learner. checkpoint: boolean indicating whether to checkpoint the learner. """ - # Create the Builder object which will internally create agent components. - builder = SVG0Builder( - # TODO(mwhoffman): pass the config dataclass in directly. - # TODO(mwhoffman): use the limiter rather than the workaround below. - # Right now this modifies min_replay_size and samples_per_insert so that - # they are not controlled by a limiter and are instead handled by the - # Agent base class (the above TODO directly references this behavior). - SVG0Config( - discount=discount, - batch_size=batch_size, - prefetch_size=prefetch_size, - target_update_period=target_update_period, - policy_optimizer=policy_optimizer, - critic_optimizer=critic_optimizer, - prior_optimizer=prior_optimizer, - distillation_cost=distillation_cost, - entropy_regularizer_cost=entropy_regularizer_cost, - min_replay_size=1, # Let the Agent class handle this. - max_replay_size=max_replay_size, - samples_per_insert=None, # Let the Agent class handle this. - sequence_length=sequence_length, - sigma=sigma, - replay_table_name=replay_table_name, - )) - - # TODO(mwhoffman): pass the network dataclass in directly. - online_networks = SVG0Networks(policy_network=policy_network, - critic_network=critic_network, - prior_network=prior_network,) - - # Target networks are just a copy of the online networks. - target_networks = copy.deepcopy(online_networks) - - # Initialize the networks. - online_networks.init(environment_spec) - target_networks.init(environment_spec) - - # TODO(mwhoffman): either make this Dataclass or pass only one struct. - # The network struct passed to make_learner is just a tuple for the - # time-being (for backwards compatibility). - networks = (online_networks, target_networks) - - # Create the behavior policy. - policy_network = online_networks.make_policy() - - # Create the replay server and grab its address. - replay_tables = builder.make_replay_tables(environment_spec, - sequence_length) - replay_server = reverb.Server(replay_tables, port=None) - replay_client = reverb.Client(f'localhost:{replay_server.port}') - - # Create actor, dataset, and learner for generating, storing, and consuming - # data respectively. - adder = builder.make_adder(replay_client) - actor = builder.make_actor(policy_network, adder) - dataset = builder.make_dataset_iterator(replay_client) - learner = builder.make_learner(networks, dataset, counter, logger, - checkpoint) - - super().__init__( - actor=actor, - learner=learner, - min_observations=max(batch_size, min_replay_size), - observations_per_step=float(batch_size) / samples_per_insert) - - # Save the replay so we don't garbage collect it. - self._replay_server = replay_server + # Create the Builder object which will internally create agent components. + builder = SVG0Builder( + # TODO(mwhoffman): pass the config dataclass in directly. + # TODO(mwhoffman): use the limiter rather than the workaround below. + # Right now this modifies min_replay_size and samples_per_insert so that + # they are not controlled by a limiter and are instead handled by the + # Agent base class (the above TODO directly references this behavior). 
+ SVG0Config( + discount=discount, + batch_size=batch_size, + prefetch_size=prefetch_size, + target_update_period=target_update_period, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + prior_optimizer=prior_optimizer, + distillation_cost=distillation_cost, + entropy_regularizer_cost=entropy_regularizer_cost, + min_replay_size=1, # Let the Agent class handle this. + max_replay_size=max_replay_size, + samples_per_insert=None, # Let the Agent class handle this. + sequence_length=sequence_length, + sigma=sigma, + replay_table_name=replay_table_name, + ) + ) + + # TODO(mwhoffman): pass the network dataclass in directly. + online_networks = SVG0Networks( + policy_network=policy_network, + critic_network=critic_network, + prior_network=prior_network, + ) + + # Target networks are just a copy of the online networks. + target_networks = copy.deepcopy(online_networks) + + # Initialize the networks. + online_networks.init(environment_spec) + target_networks.init(environment_spec) + + # TODO(mwhoffman): either make this Dataclass or pass only one struct. + # The network struct passed to make_learner is just a tuple for the + # time-being (for backwards compatibility). + networks = (online_networks, target_networks) + + # Create the behavior policy. + policy_network = online_networks.make_policy() + + # Create the replay server and grab its address. + replay_tables = builder.make_replay_tables(environment_spec, sequence_length) + replay_server = reverb.Server(replay_tables, port=None) + replay_client = reverb.Client(f"localhost:{replay_server.port}") + + # Create actor, dataset, and learner for generating, storing, and consuming + # data respectively. + adder = builder.make_adder(replay_client) + actor = builder.make_actor(policy_network, adder) + dataset = builder.make_dataset_iterator(replay_client) + learner = builder.make_learner(networks, dataset, counter, logger, checkpoint) + + super().__init__( + actor=actor, + learner=learner, + min_observations=max(batch_size, min_replay_size), + observations_per_step=float(batch_size) / samples_per_insert, + ) + + # Save the replay so we don't garbage collect it. 
+ self._replay_server = replay_server diff --git a/acme/agents/tf/svg0_prior/agent_distributed.py b/acme/agents/tf/svg0_prior/agent_distributed.py index 8bf0bebcf2..641bef8d0a 100644 --- a/acme/agents/tf/svg0_prior/agent_distributed.py +++ b/acme/agents/tf/svg0_prior/agent_distributed.py @@ -17,235 +17,237 @@ import copy from typing import Callable, Dict, Optional -import acme -from acme import specs -from acme.agents.tf.svg0_prior import agent -from acme.tf import savers as tf2_savers -from acme.utils import counting -from acme.utils import loggers -from acme.utils import lp_utils import dm_env import launchpad as lp import reverb import sonnet as snt +import acme +from acme import specs +from acme.agents.tf.svg0_prior import agent +from acme.tf import savers as tf2_savers +from acme.utils import counting, loggers, lp_utils + class DistributedSVG0: - """Program definition for SVG0.""" - - def __init__( - self, - environment_factory: Callable[[bool], dm_env.Environment], - network_factory: Callable[[specs.BoundedArray], Dict[str, snt.Module]], - num_actors: int = 1, - num_caches: int = 0, - environment_spec: Optional[specs.EnvironmentSpec] = None, - batch_size: int = 256, - prefetch_size: int = 4, - min_replay_size: int = 1000, - max_replay_size: int = 1000000, - samples_per_insert: Optional[float] = 32.0, - sequence_length: int = 10, - sigma: float = 0.3, - discount: float = 0.99, - policy_optimizer: Optional[snt.Optimizer] = None, - critic_optimizer: Optional[snt.Optimizer] = None, - prior_optimizer: Optional[snt.Optimizer] = None, - distillation_cost: Optional[float] = 1e-3, - entropy_regularizer_cost: Optional[float] = 1e-3, - target_update_period: int = 100, - max_actor_steps: Optional[int] = None, - log_every: float = 10.0, - ): - - if not environment_spec: - environment_spec = specs.make_environment_spec(environment_factory(False)) - - # TODO(mwhoffman): Make network_factory directly return the struct. - # TODO(mwhoffman): Make the factory take the entire spec. - def wrapped_network_factory(action_spec): - networks_dict = network_factory(action_spec) - networks = agent.SVG0Networks( - policy_network=networks_dict.get('policy'), - critic_network=networks_dict.get('critic'), - prior_network=networks_dict.get('prior', None),) - return networks - - self._environment_factory = environment_factory - self._network_factory = wrapped_network_factory - self._environment_spec = environment_spec - self._sigma = sigma - self._num_actors = num_actors - self._num_caches = num_caches - self._max_actor_steps = max_actor_steps - self._log_every = log_every - self._sequence_length = sequence_length - - self._builder = agent.SVG0Builder( - # TODO(mwhoffman): pass the config dataclass in directly. - # TODO(mwhoffman): use the limiter rather than the workaround below. 
- agent.SVG0Config( - discount=discount, - batch_size=batch_size, - prefetch_size=prefetch_size, - target_update_period=target_update_period, - policy_optimizer=policy_optimizer, - critic_optimizer=critic_optimizer, - prior_optimizer=prior_optimizer, - min_replay_size=min_replay_size, - max_replay_size=max_replay_size, - samples_per_insert=samples_per_insert, - sequence_length=sequence_length, - sigma=sigma, - distillation_cost=distillation_cost, - entropy_regularizer_cost=entropy_regularizer_cost, - )) - - def replay(self): - """The replay storage.""" - return self._builder.make_replay_tables(self._environment_spec, - self._sequence_length) - - def counter(self): - return tf2_savers.CheckpointingRunner(counting.Counter(), - time_delta_minutes=1, - subdirectory='counter') - - def coordinator(self, counter: counting.Counter): - return lp_utils.StepsLimiter(counter, self._max_actor_steps) - - def learner( - self, - replay: reverb.Client, - counter: counting.Counter, - ): - """The Learning part of the agent.""" - - # Create the networks to optimize (online) and target networks. - online_networks = self._network_factory(self._environment_spec.actions) - target_networks = copy.deepcopy(online_networks) - - # Initialize the networks. - online_networks.init(self._environment_spec) - target_networks.init(self._environment_spec) - - dataset = self._builder.make_dataset_iterator(replay) - counter = counting.Counter(counter, 'learner') - logger = loggers.make_default_logger( - 'learner', time_delta=self._log_every, steps_key='learner_steps') - - return self._builder.make_learner( - networks=(online_networks, target_networks), - dataset=dataset, - counter=counter, - logger=logger, - ) - - def actor( - self, - replay: reverb.Client, - variable_source: acme.VariableSource, - counter: counting.Counter, - ) -> acme.EnvironmentLoop: - """The actor process.""" - - # Create the behavior policy. - networks = self._network_factory(self._environment_spec.actions) - networks.init(self._environment_spec) - policy_network = networks.make_policy() - - # Create the agent. - actor = self._builder.make_actor( - policy_network=policy_network, - adder=self._builder.make_adder(replay), - variable_source=variable_source, - ) - - # Create the environment. - environment = self._environment_factory(False) - - # Create logger and counter; actors will not spam bigtable. - counter = counting.Counter(counter, 'actor') - logger = loggers.make_default_logger( - 'actor', - save_data=False, - time_delta=self._log_every, - steps_key='actor_steps') - - # Create the loop to connect environment and agent. - return acme.EnvironmentLoop(environment, actor, counter, logger) - - def evaluator( - self, - variable_source: acme.VariableSource, - counter: counting.Counter, - logger: Optional[loggers.Logger] = None, - ): - """The evaluation process.""" - - # Create the behavior policy. - networks = self._network_factory(self._environment_spec.actions) - networks.init(self._environment_spec) - policy_network = networks.make_policy() - - # Create the agent. - actor = self._builder.make_actor( - policy_network=policy_network, - variable_source=variable_source, - deterministic_policy=True, - ) - - # Make the environment. - environment = self._environment_factory(True) - - # Create logger and counter. - counter = counting.Counter(counter, 'evaluator') - logger = logger or loggers.make_default_logger( - 'evaluator', - time_delta=self._log_every, - steps_key='evaluator_steps', - ) - - # Create the run loop and return it. 
- return acme.EnvironmentLoop(environment, actor, counter, logger) - - def build(self, name='svg0'): - """Build the distributed agent topology.""" - program = lp.Program(name=name) - - with program.group('replay'): - replay = program.add_node(lp.ReverbNode(self.replay)) - - with program.group('counter'): - counter = program.add_node(lp.CourierNode(self.counter)) - - if self._max_actor_steps: - with program.group('coordinator'): - _ = program.add_node(lp.CourierNode(self.coordinator, counter)) - - with program.group('learner'): - learner = program.add_node(lp.CourierNode(self.learner, replay, counter)) - - with program.group('evaluator'): - program.add_node(lp.CourierNode(self.evaluator, learner, counter)) - - if not self._num_caches: - # Use our learner as a single variable source. - sources = [learner] - else: - with program.group('cacher'): - # Create a set of learner caches. - sources = [] - for _ in range(self._num_caches): - cacher = program.add_node( - lp.CacherNode( - learner, refresh_interval_ms=2000, stale_after_ms=4000)) - sources.append(cacher) - - with program.group('actor'): - # Add actors which pull round-robin from our variable sources. - for actor_id in range(self._num_actors): - source = sources[actor_id % len(sources)] - program.add_node(lp.CourierNode(self.actor, replay, source, counter)) - - return program + """Program definition for SVG0.""" + + def __init__( + self, + environment_factory: Callable[[bool], dm_env.Environment], + network_factory: Callable[[specs.BoundedArray], Dict[str, snt.Module]], + num_actors: int = 1, + num_caches: int = 0, + environment_spec: Optional[specs.EnvironmentSpec] = None, + batch_size: int = 256, + prefetch_size: int = 4, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: Optional[float] = 32.0, + sequence_length: int = 10, + sigma: float = 0.3, + discount: float = 0.99, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + prior_optimizer: Optional[snt.Optimizer] = None, + distillation_cost: Optional[float] = 1e-3, + entropy_regularizer_cost: Optional[float] = 1e-3, + target_update_period: int = 100, + max_actor_steps: Optional[int] = None, + log_every: float = 10.0, + ): + + if not environment_spec: + environment_spec = specs.make_environment_spec(environment_factory(False)) + + # TODO(mwhoffman): Make network_factory directly return the struct. + # TODO(mwhoffman): Make the factory take the entire spec. + def wrapped_network_factory(action_spec): + networks_dict = network_factory(action_spec) + networks = agent.SVG0Networks( + policy_network=networks_dict.get("policy"), + critic_network=networks_dict.get("critic"), + prior_network=networks_dict.get("prior", None), + ) + return networks + + self._environment_factory = environment_factory + self._network_factory = wrapped_network_factory + self._environment_spec = environment_spec + self._sigma = sigma + self._num_actors = num_actors + self._num_caches = num_caches + self._max_actor_steps = max_actor_steps + self._log_every = log_every + self._sequence_length = sequence_length + + self._builder = agent.SVG0Builder( + # TODO(mwhoffman): pass the config dataclass in directly. + # TODO(mwhoffman): use the limiter rather than the workaround below. 
+ agent.SVG0Config( + discount=discount, + batch_size=batch_size, + prefetch_size=prefetch_size, + target_update_period=target_update_period, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + prior_optimizer=prior_optimizer, + min_replay_size=min_replay_size, + max_replay_size=max_replay_size, + samples_per_insert=samples_per_insert, + sequence_length=sequence_length, + sigma=sigma, + distillation_cost=distillation_cost, + entropy_regularizer_cost=entropy_regularizer_cost, + ) + ) + + def replay(self): + """The replay storage.""" + return self._builder.make_replay_tables( + self._environment_spec, self._sequence_length + ) + + def counter(self): + return tf2_savers.CheckpointingRunner( + counting.Counter(), time_delta_minutes=1, subdirectory="counter" + ) + + def coordinator(self, counter: counting.Counter): + return lp_utils.StepsLimiter(counter, self._max_actor_steps) + + def learner( + self, replay: reverb.Client, counter: counting.Counter, + ): + """The Learning part of the agent.""" + + # Create the networks to optimize (online) and target networks. + online_networks = self._network_factory(self._environment_spec.actions) + target_networks = copy.deepcopy(online_networks) + + # Initialize the networks. + online_networks.init(self._environment_spec) + target_networks.init(self._environment_spec) + + dataset = self._builder.make_dataset_iterator(replay) + counter = counting.Counter(counter, "learner") + logger = loggers.make_default_logger( + "learner", time_delta=self._log_every, steps_key="learner_steps" + ) + + return self._builder.make_learner( + networks=(online_networks, target_networks), + dataset=dataset, + counter=counter, + logger=logger, + ) + + def actor( + self, + replay: reverb.Client, + variable_source: acme.VariableSource, + counter: counting.Counter, + ) -> acme.EnvironmentLoop: + """The actor process.""" + + # Create the behavior policy. + networks = self._network_factory(self._environment_spec.actions) + networks.init(self._environment_spec) + policy_network = networks.make_policy() + + # Create the agent. + actor = self._builder.make_actor( + policy_network=policy_network, + adder=self._builder.make_adder(replay), + variable_source=variable_source, + ) + + # Create the environment. + environment = self._environment_factory(False) + + # Create logger and counter; actors will not spam bigtable. + counter = counting.Counter(counter, "actor") + logger = loggers.make_default_logger( + "actor", + save_data=False, + time_delta=self._log_every, + steps_key="actor_steps", + ) + + # Create the loop to connect environment and agent. + return acme.EnvironmentLoop(environment, actor, counter, logger) + + def evaluator( + self, + variable_source: acme.VariableSource, + counter: counting.Counter, + logger: Optional[loggers.Logger] = None, + ): + """The evaluation process.""" + + # Create the behavior policy. + networks = self._network_factory(self._environment_spec.actions) + networks.init(self._environment_spec) + policy_network = networks.make_policy() + + # Create the agent. + actor = self._builder.make_actor( + policy_network=policy_network, + variable_source=variable_source, + deterministic_policy=True, + ) + + # Make the environment. + environment = self._environment_factory(True) + + # Create logger and counter. + counter = counting.Counter(counter, "evaluator") + logger = logger or loggers.make_default_logger( + "evaluator", time_delta=self._log_every, steps_key="evaluator_steps", + ) + + # Create the run loop and return it. 
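+        # The evaluator runs the same loop as the actors but with a
+        # deterministic policy and its own counter/logger prefix.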
+ return acme.EnvironmentLoop(environment, actor, counter, logger) + + def build(self, name="svg0"): + """Build the distributed agent topology.""" + program = lp.Program(name=name) + + with program.group("replay"): + replay = program.add_node(lp.ReverbNode(self.replay)) + + with program.group("counter"): + counter = program.add_node(lp.CourierNode(self.counter)) + + if self._max_actor_steps: + with program.group("coordinator"): + _ = program.add_node(lp.CourierNode(self.coordinator, counter)) + + with program.group("learner"): + learner = program.add_node(lp.CourierNode(self.learner, replay, counter)) + + with program.group("evaluator"): + program.add_node(lp.CourierNode(self.evaluator, learner, counter)) + + if not self._num_caches: + # Use our learner as a single variable source. + sources = [learner] + else: + with program.group("cacher"): + # Create a set of learner caches. + sources = [] + for _ in range(self._num_caches): + cacher = program.add_node( + lp.CacherNode( + learner, refresh_interval_ms=2000, stale_after_ms=4000 + ) + ) + sources.append(cacher) + + with program.group("actor"): + # Add actors which pull round-robin from our variable sources. + for actor_id in range(self._num_actors): + source = sources[actor_id % len(sources)] + program.add_node(lp.CourierNode(self.actor, replay, source, counter)) + + return program diff --git a/acme/agents/tf/svg0_prior/agent_distributed_test.py b/acme/agents/tf/svg0_prior/agent_distributed_test.py index 070231ab8c..af5533cc95 100644 --- a/acme/agents/tf/svg0_prior/agent_distributed_test.py +++ b/acme/agents/tf/svg0_prior/agent_distributed_test.py @@ -16,17 +16,17 @@ from typing import Sequence +import launchpad as lp +import numpy as np +import sonnet as snt +from absl.testing import absltest + import acme from acme import specs from acme.agents.tf import svg0_prior from acme.testing import fakes from acme.tf import networks from acme.tf import utils as tf2_utils -import launchpad as lp -import numpy as np -import sonnet as snt - -from absl.testing import absltest def make_networks( @@ -34,62 +34,67 @@ def make_networks( policy_layer_sizes: Sequence[int] = (10, 10), critic_layer_sizes: Sequence[int] = (10, 10), ): - """Simple networks for testing..""" - - # Get total number of action dimensions from action spec. - num_dimensions = np.prod(action_spec.shape, dtype=int) - - policy_network = snt.Sequential([ - tf2_utils.batch_concat, - networks.LayerNormMLP(policy_layer_sizes, activate_final=True), - networks.MultivariateNormalDiagHead( - num_dimensions, - tanh_mean=True, - min_scale=0.3, - init_scale=0.7, - fixed_scale=False, - use_tfd_independent=False) - ]) - # The multiplexer concatenates the (maybe transformed) observations/actions. - multiplexer = networks.CriticMultiplexer() - critic_network = snt.Sequential([ - multiplexer, - networks.LayerNormMLP(critic_layer_sizes, activate_final=True), - networks.NearZeroInitializedLinear(1), - ]) - - return { - 'policy': policy_network, - 'critic': critic_network, - } + """Simple networks for testing..""" + + # Get total number of action dimensions from action spec. 
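+    # (np.prod flattens the action shape, e.g. a (2, 3) spec gives 6 dimensions.)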
+ num_dimensions = np.prod(action_spec.shape, dtype=int) + + policy_network = snt.Sequential( + [ + tf2_utils.batch_concat, + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, + tanh_mean=True, + min_scale=0.3, + init_scale=0.7, + fixed_scale=False, + use_tfd_independent=False, + ), + ] + ) + # The multiplexer concatenates the (maybe transformed) observations/actions. + multiplexer = networks.CriticMultiplexer() + critic_network = snt.Sequential( + [ + multiplexer, + networks.LayerNormMLP(critic_layer_sizes, activate_final=True), + networks.NearZeroInitializedLinear(1), + ] + ) + + return { + "policy": policy_network, + "critic": critic_network, + } class DistributedAgentTest(absltest.TestCase): - """Simple integration/smoke test for the distributed agent.""" - - def test_control_suite(self): - """Tests that the agent can run on the control suite without crashing.""" - - agent = svg0_prior.DistributedSVG0( - environment_factory=lambda x: fakes.ContinuousEnvironment(), - network_factory=make_networks, - num_actors=2, - batch_size=32, - min_replay_size=32, - max_replay_size=1000, - ) - program = agent.build() + """Simple integration/smoke test for the distributed agent.""" + + def test_control_suite(self): + """Tests that the agent can run on the control suite without crashing.""" + + agent = svg0_prior.DistributedSVG0( + environment_factory=lambda x: fakes.ContinuousEnvironment(), + network_factory=make_networks, + num_actors=2, + batch_size=32, + min_replay_size=32, + max_replay_size=1000, + ) + program = agent.build() - (learner_node,) = program.groups['learner'] - learner_node.disable_run() + (learner_node,) = program.groups["learner"] + learner_node.disable_run() - lp.launch(program, launch_type='test_mt') + lp.launch(program, launch_type="test_mt") - learner: acme.Learner = learner_node.create_handle().dereference() + learner: acme.Learner = learner_node.create_handle().dereference() - for _ in range(5): - learner.step() + for _ in range(5): + learner.step() -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/tf/svg0_prior/agent_test.py b/acme/agents/tf/svg0_prior/agent_test.py index c8f0b03c08..da343c5e5a 100644 --- a/acme/agents/tf/svg0_prior/agent_test.py +++ b/acme/agents/tf/svg0_prior/agent_test.py @@ -17,17 +17,16 @@ import sys from typing import Dict, Sequence +import numpy as np +import sonnet as snt +from absl.testing import absltest + import acme -from acme import specs -from acme import types +from acme import specs, types from acme.agents.tf import svg0_prior from acme.testing import fakes from acme.tf import networks from acme.tf import utils as tf2_utils -import numpy as np -import sonnet as snt - -from absl.testing import absltest def make_networks( @@ -35,62 +34,66 @@ def make_networks( policy_layer_sizes: Sequence[int] = (10, 10), critic_layer_sizes: Sequence[int] = (10, 10), ) -> Dict[str, snt.Module]: - """Creates networks used by the agent.""" - # Get total number of action dimensions from action spec. - num_dimensions = np.prod(action_spec.shape, dtype=int) - - policy_network = snt.Sequential([ - tf2_utils.batch_concat, - networks.LayerNormMLP(policy_layer_sizes, activate_final=True), - networks.MultivariateNormalDiagHead( - num_dimensions, - tanh_mean=True, - min_scale=0.3, - init_scale=0.7, - fixed_scale=False, - use_tfd_independent=False) - ]) - # The multiplexer concatenates the (maybe transformed) observations/actions. 
- multiplexer = networks.CriticMultiplexer() - critic_network = snt.Sequential([ - multiplexer, - networks.LayerNormMLP(critic_layer_sizes, activate_final=True), - networks.NearZeroInitializedLinear(1), - ]) - - return { - 'policy': policy_network, - 'critic': critic_network, - } - - -class SVG0Test(absltest.TestCase): - - def test_svg0(self): - # Create a fake environment to test with. - environment = fakes.ContinuousEnvironment(episode_length=10) - spec = specs.make_environment_spec(environment) - - # Create the networks. - agent_networks = make_networks(spec.actions) - - # Construct the agent. - agent = svg0_prior.SVG0( - environment_spec=spec, - policy_network=agent_networks['policy'], - critic_network=agent_networks['critic'], - batch_size=10, - samples_per_insert=2, - min_replay_size=10, + """Creates networks used by the agent.""" + # Get total number of action dimensions from action spec. + num_dimensions = np.prod(action_spec.shape, dtype=int) + + policy_network = snt.Sequential( + [ + tf2_utils.batch_concat, + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, + tanh_mean=True, + min_scale=0.3, + init_scale=0.7, + fixed_scale=False, + use_tfd_independent=False, + ), + ] + ) + # The multiplexer concatenates the (maybe transformed) observations/actions. + multiplexer = networks.CriticMultiplexer() + critic_network = snt.Sequential( + [ + multiplexer, + networks.LayerNormMLP(critic_layer_sizes, activate_final=True), + networks.NearZeroInitializedLinear(1), + ] ) - # Try running the environment loop. We have no assertions here because all - # we care about is that the agent runs without raising any errors. - loop = acme.EnvironmentLoop(environment, agent) - loop.run(num_episodes=2) - - # Imports check + return { + "policy": policy_network, + "critic": critic_network, + } -if __name__ == '__main__': - absltest.main() +class SVG0Test(absltest.TestCase): + def test_svg0(self): + # Create a fake environment to test with. + environment = fakes.ContinuousEnvironment(episode_length=10) + spec = specs.make_environment_spec(environment) + + # Create the networks. + agent_networks = make_networks(spec.actions) + + # Construct the agent. + agent = svg0_prior.SVG0( + environment_spec=spec, + policy_network=agent_networks["policy"], + critic_network=agent_networks["critic"], + batch_size=10, + samples_per_insert=2, + min_replay_size=10, + ) + + # Try running the environment loop. We have no assertions here because all + # we care about is that the agent runs without raising any errors. 
+ loop = acme.EnvironmentLoop(environment, agent) + loop.run(num_episodes=2) + + # Imports check + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/agents/tf/svg0_prior/learning.py b/acme/agents/tf/svg0_prior/learning.py index 297228e1ee..12b2bfda15 100644 --- a/acme/agents/tf/svg0_prior/learning.py +++ b/acme/agents/tf/svg0_prior/learning.py @@ -17,51 +17,51 @@ import time from typing import Dict, Iterator, List, Optional -import acme -from acme.agents.tf.svg0_prior import utils as svg0_utils -from acme.tf import savers as tf2_savers -from acme.tf import utils as tf2_utils -from acme.utils import counting -from acme.utils import loggers import numpy as np import reverb import sonnet as snt import tensorflow as tf from trfl import continuous_retrace_ops +import acme +from acme.agents.tf.svg0_prior import utils as svg0_utils +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.utils import counting, loggers + _MIN_LOG_VAL = 1e-20 class SVG0Learner(acme.Learner): - """SVG0 learner with optional prior. + """SVG0 learner with optional prior. This is the learning component of an SVG0 agent. IE it takes a dataset as input and implements update functionality to learn from this dataset. """ - def __init__( - self, - policy_network: snt.Module, - critic_network: snt.Module, - target_policy_network: snt.Module, - target_critic_network: snt.Module, - discount: float, - target_update_period: int, - dataset_iterator: Iterator[reverb.ReplaySample], - prior_network: Optional[snt.Module] = None, - target_prior_network: Optional[snt.Module] = None, - policy_optimizer: Optional[snt.Optimizer] = None, - critic_optimizer: Optional[snt.Optimizer] = None, - prior_optimizer: Optional[snt.Optimizer] = None, - distillation_cost: Optional[float] = 1e-3, - entropy_regularizer_cost: Optional[float] = 1e-3, - num_action_samples: int = 10, - lambda_: float = 1.0, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - checkpoint: bool = True, - ): - """Initializes the learner. + def __init__( + self, + policy_network: snt.Module, + critic_network: snt.Module, + target_policy_network: snt.Module, + target_critic_network: snt.Module, + discount: float, + target_update_period: int, + dataset_iterator: Iterator[reverb.ReplaySample], + prior_network: Optional[snt.Module] = None, + target_prior_network: Optional[snt.Module] = None, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + prior_optimizer: Optional[snt.Optimizer] = None, + distillation_cost: Optional[float] = 1e-3, + entropy_regularizer_cost: Optional[float] = 1e-3, + num_action_samples: int = 10, + lambda_: float = 1.0, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = True, + ): + """Initializes the learner. Args: policy_network: the online (optimized) policy. @@ -94,293 +94,307 @@ def __init__( checkpoint: boolean indicating whether to checkpoint the learner. """ - # Store online and target networks. 
- self._policy_network = policy_network - self._critic_network = critic_network - self._target_policy_network = target_policy_network - self._target_critic_network = target_critic_network - - self._prior_network = prior_network - self._target_prior_network = target_prior_network - - self._lambda = lambda_ - self._num_action_samples = num_action_samples - self._distillation_cost = distillation_cost - self._entropy_regularizer_cost = entropy_regularizer_cost - - # General learner book-keeping and loggers. - self._counter = counter or counting.Counter() - self._logger = logger or loggers.make_default_logger('learner') - - # Other learner parameters. - self._discount = discount - - # Necessary to track when to update target networks. - self._num_steps = tf.Variable(0, dtype=tf.int32) - self._target_update_period = target_update_period - - # Batch dataset and create iterator. - self._iterator = dataset_iterator - - # Create optimizers if they aren't given. - self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) - self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) - self._prior_optimizer = prior_optimizer or snt.optimizers.Adam(1e-4) - - # Expose the variables. - self._variables = { - 'critic': self._critic_network.variables, - 'policy': self._policy_network.variables, - } - if self._prior_network is not None: - self._variables['prior'] = self._prior_network.variables - - # Create a checkpointer and snapshotter objects. - self._checkpointer = None - self._snapshotter = None - - if checkpoint: - objects_to_save = { - 'counter': self._counter, - 'policy': self._policy_network, - 'critic': self._critic_network, - 'target_policy': self._target_policy_network, - 'target_critic': self._target_critic_network, - 'policy_optimizer': self._policy_optimizer, - 'critic_optimizer': self._critic_optimizer, - 'num_steps': self._num_steps, - } - if self._prior_network is not None: - objects_to_save['prior'] = self._prior_network - objects_to_save['target_prior'] = self._target_prior_network - objects_to_save['prior_optimizer'] = self._prior_optimizer - - self._checkpointer = tf2_savers.Checkpointer( - subdirectory='svg0_learner', - objects_to_save=objects_to_save) - objects_to_snapshot = { - 'policy': self._policy_network, - 'critic': self._critic_network, - } - if self._prior_network is not None: - objects_to_snapshot['prior'] = self._prior_network - - self._snapshotter = tf2_savers.Snapshotter( - objects_to_save=objects_to_snapshot) - - # Do not record timestamps until after the first learning step is done. - # This is to avoid including the time it takes for actors to come online and - # fill the replay buffer. - self._timestamp = None - - @tf.function - def _step(self) -> Dict[str, tf.Tensor]: - # Update target network - online_variables = [ - *self._critic_network.variables, - *self._policy_network.variables, - ] - if self._prior_network is not None: - online_variables += [*self._prior_network.variables] - online_variables = tuple(online_variables) - - target_variables = [ - *self._target_critic_network.variables, - *self._target_policy_network.variables, - ] - if self._prior_network is not None: - target_variables += [*self._target_prior_network.variables] - target_variables = tuple(target_variables) - - # Make online -> target network update ops. 
- if tf.math.mod(self._num_steps, self._target_update_period) == 0: - for src, dest in zip(online_variables, target_variables): - dest.assign(src) - self._num_steps.assign_add(1) - - # Get data from replay (dropping extras if any) and flip to `[T, B, ...]`. - sample: reverb.ReplaySample = next(self._iterator) - data = tf2_utils.batch_to_sequence(sample.data) - observations, actions, rewards, discounts, extra = (data.observation, - data.action, - data.reward, - data.discount, - data.extras) - online_target_pi_q = svg0_utils.OnlineTargetPiQ( - online_pi=self._policy_network, - online_q=self._critic_network, - target_pi=self._target_policy_network, - target_q=self._target_critic_network, - num_samples=self._num_action_samples, - online_prior=self._prior_network, - target_prior=self._target_prior_network, - ) - with tf.GradientTape(persistent=True) as tape: - step_outputs = svg0_utils.static_rnn( - core=online_target_pi_q, - inputs=(observations, actions), - unroll_length=rewards.shape[0]) - - # Flip target samples to have shape [S, T+1, B, ...] where 'S' is the - # number of action samples taken. - target_pi_samples = tf2_utils.batch_to_sequence( - step_outputs.target_samples) - # Tile observations to have shape [S, T+1, B,..]. - tiled_observations = tf2_utils.tile_nested(observations, - self._num_action_samples) - - # Finally compute target Q values on the new action samples. - # Shape: [S, T+1, B, 1] - target_q_target_pi_samples = snt.BatchApply(self._target_critic_network, - 3)(tiled_observations, - target_pi_samples) - # Compute the value estimate by averaging over the action dimension. - # Shape: [T+1, B, 1]. - target_v_target_pi = tf.reduce_mean(target_q_target_pi_samples, axis=0) - - # Split the target V's into the target for learning - # `value_function_target` and the bootstrap value. Shape: [T, B]. - value_function_target = tf.squeeze(target_v_target_pi[:-1], axis=-1) - # Shape: [B]. - bootstrap_value = tf.squeeze(target_v_target_pi[-1], axis=-1) - - # When learning with a prior, add entropy terms to value targets. - if self._prior_network is not None: - value_function_target -= self._distillation_cost * tf.stop_gradient( - step_outputs.analytic_kl_to_target[:-1] + # Store online and target networks. + self._policy_network = policy_network + self._critic_network = critic_network + self._target_policy_network = target_policy_network + self._target_critic_network = target_critic_network + + self._prior_network = prior_network + self._target_prior_network = target_prior_network + + self._lambda = lambda_ + self._num_action_samples = num_action_samples + self._distillation_cost = distillation_cost + self._entropy_regularizer_cost = entropy_regularizer_cost + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger("learner") + + # Other learner parameters. + self._discount = discount + + # Necessary to track when to update target networks. + self._num_steps = tf.Variable(0, dtype=tf.int32) + self._target_update_period = target_update_period + + # Batch dataset and create iterator. + self._iterator = dataset_iterator + + # Create optimizers if they aren't given. + self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) + self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) + self._prior_optimizer = prior_optimizer or snt.optimizers.Adam(1e-4) + + # Expose the variables. 
+ self._variables = { + "critic": self._critic_network.variables, + "policy": self._policy_network.variables, + } + if self._prior_network is not None: + self._variables["prior"] = self._prior_network.variables + + # Create a checkpointer and snapshotter objects. + self._checkpointer = None + self._snapshotter = None + + if checkpoint: + objects_to_save = { + "counter": self._counter, + "policy": self._policy_network, + "critic": self._critic_network, + "target_policy": self._target_policy_network, + "target_critic": self._target_critic_network, + "policy_optimizer": self._policy_optimizer, + "critic_optimizer": self._critic_optimizer, + "num_steps": self._num_steps, + } + if self._prior_network is not None: + objects_to_save["prior"] = self._prior_network + objects_to_save["target_prior"] = self._target_prior_network + objects_to_save["prior_optimizer"] = self._prior_optimizer + + self._checkpointer = tf2_savers.Checkpointer( + subdirectory="svg0_learner", objects_to_save=objects_to_save + ) + objects_to_snapshot = { + "policy": self._policy_network, + "critic": self._critic_network, + } + if self._prior_network is not None: + objects_to_snapshot["prior"] = self._prior_network + + self._snapshotter = tf2_savers.Snapshotter( + objects_to_save=objects_to_snapshot + ) + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + @tf.function + def _step(self) -> Dict[str, tf.Tensor]: + # Update target network + online_variables = [ + *self._critic_network.variables, + *self._policy_network.variables, + ] + if self._prior_network is not None: + online_variables += [*self._prior_network.variables] + online_variables = tuple(online_variables) + + target_variables = [ + *self._target_critic_network.variables, + *self._target_policy_network.variables, + ] + if self._prior_network is not None: + target_variables += [*self._target_prior_network.variables] + target_variables = tuple(target_variables) + + # Make online -> target network update ops. + if tf.math.mod(self._num_steps, self._target_update_period) == 0: + for src, dest in zip(online_variables, target_variables): + dest.assign(src) + self._num_steps.assign_add(1) + + # Get data from replay (dropping extras if any) and flip to `[T, B, ...]`. + sample: reverb.ReplaySample = next(self._iterator) + data = tf2_utils.batch_to_sequence(sample.data) + observations, actions, rewards, discounts, extra = ( + data.observation, + data.action, + data.reward, + data.discount, + data.extras, + ) + online_target_pi_q = svg0_utils.OnlineTargetPiQ( + online_pi=self._policy_network, + online_q=self._critic_network, + target_pi=self._target_policy_network, + target_q=self._target_critic_network, + num_samples=self._num_action_samples, + online_prior=self._prior_network, + target_prior=self._target_prior_network, + ) + with tf.GradientTape(persistent=True) as tape: + step_outputs = svg0_utils.static_rnn( + core=online_target_pi_q, + inputs=(observations, actions), + unroll_length=rewards.shape[0], + ) + + # Flip target samples to have shape [S, T+1, B, ...] where 'S' is the + # number of action samples taken. + target_pi_samples = tf2_utils.batch_to_sequence(step_outputs.target_samples) + # Tile observations to have shape [S, T+1, B,..]. + tiled_observations = tf2_utils.tile_nested( + observations, self._num_action_samples + ) + + # Finally compute target Q values on the new action samples. 
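Illustrative aside (not part of the patch): `_step()` above copies the online variables into the target variables only once every `target_update_period` learner steps. A rough NumPy sketch of that periodic-copy pattern (the names below are stand-ins, not Acme code):

    import numpy as np

    target_update_period = 100
    online_vars = [np.random.randn(3), np.random.randn(2)]  # stand-in network variables
    target_vars = [np.zeros(3), np.zeros(2)]

    num_steps = 0
    for _ in range(250):
        if num_steps % target_update_period == 0:
            for src, dest in zip(online_vars, target_vars):
                dest[...] = src  # plays the role of dest.assign(src)
        num_steps += 1
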
+ # Shape: [S, T+1, B, 1] + target_q_target_pi_samples = snt.BatchApply(self._target_critic_network, 3)( + tiled_observations, target_pi_samples + ) + # Compute the value estimate by averaging over the action dimension. + # Shape: [T+1, B, 1]. + target_v_target_pi = tf.reduce_mean(target_q_target_pi_samples, axis=0) + + # Split the target V's into the target for learning + # `value_function_target` and the bootstrap value. Shape: [T, B]. + value_function_target = tf.squeeze(target_v_target_pi[:-1], axis=-1) + # Shape: [B]. + bootstrap_value = tf.squeeze(target_v_target_pi[-1], axis=-1) + + # When learning with a prior, add entropy terms to value targets. + if self._prior_network is not None: + value_function_target -= self._distillation_cost * tf.stop_gradient( + step_outputs.analytic_kl_to_target[:-1] + ) + bootstrap_value -= self._distillation_cost * tf.stop_gradient( + step_outputs.analytic_kl_to_target[-1] + ) + + # Get target log probs and behavior log probs from rollout. + # Shape: [T+1, B]. + target_log_probs_behavior_actions = ( + step_outputs.target_log_probs_behavior_actions + ) + behavior_log_probs = extra["log_prob"] + # Calculate importance weights. Shape: [T+1, B]. + rhos = tf.exp(target_log_probs_behavior_actions - behavior_log_probs) + + # Filter the importance weights to mask out episode restarts. Ignore the + # last action and consider the step type of the next step for masking. + # Shape: [T, B]. + episode_start_mask = tf2_utils.batch_to_sequence( + sample.data.start_of_episode + )[1:] + + rhos = svg0_utils.mask_out_restarting(rhos[:-1], episode_start_mask) + + # rhos = rhos[:-1] + # Compute the log importance weights with a small value added for + # stability. + # Shape: [T, B] + log_rhos = tf.math.log(rhos + _MIN_LOG_VAL) + + # Retrieve the target and online Q values and throw away the last action. + # Shape: [T, B]. + target_q_values = tf.squeeze(step_outputs.target_q[:-1], -1) + online_q_values = tf.squeeze(step_outputs.online_q[:-1], -1) + + # Flip target samples to have shape [S, T+1, B, ...] where 'S' is the + # number of action samples taken. + online_pi_samples = tf2_utils.batch_to_sequence(step_outputs.online_samples) + target_q_online_pi_samples = snt.BatchApply(self._target_critic_network, 3)( + tiled_observations, online_pi_samples + ) + expected_q = tf.reduce_mean( + tf.squeeze(target_q_online_pi_samples, -1), axis=0 + ) + + # Flip online_log_probs to be of shape [S, T+1, B] and then compute + # entropy by averaging over num samples. Final shape: [T+1, B]. + online_log_probs = tf2_utils.batch_to_sequence( + step_outputs.online_log_probs + ) + sample_based_entropy = tf.reduce_mean(-online_log_probs, axis=0) + retrace_outputs = continuous_retrace_ops.retrace_from_importance_weights( + log_rhos=log_rhos, + discounts=self._discount * discounts[:-1], + rewards=rewards[:-1], + q_values=target_q_values, + values=value_function_target, + bootstrap_value=bootstrap_value, + lambda_=self._lambda, + ) + + # Critic loss. Shape: [T, B]. + critic_loss = 0.5 * tf.math.squared_difference( + tf.stop_gradient(retrace_outputs.qs), online_q_values + ) + + # Policy loss- SVG0 with sample based entropy. Shape: [T, B] + policy_loss = -( + expected_q + self._entropy_regularizer_cost * sample_based_entropy + ) + policy_loss = policy_loss[:-1] + + if self._prior_network is not None: + # When training the prior, also add the per-timestep KL cost. 
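Illustrative aside (not part of the patch): the importance weights above are the exponentiated gap between target-policy and behaviour-policy log-probabilities, with a small constant added before the final log for numerical stability. A tiny NumPy sketch with made-up numbers:

    import numpy as np

    _MIN_LOG_VAL = 1e-20
    target_log_probs = np.array([[-1.2, -0.7], [-2.0, -0.4]])    # [T, B], hypothetical
    behavior_log_probs = np.array([[-1.0, -0.9], [-1.5, -0.6]])  # [T, B], hypothetical

    rhos = np.exp(target_log_probs - behavior_log_probs)  # importance weights
    log_rhos = np.log(rhos + _MIN_LOG_VAL)                # stabilised log-weights
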
+ policy_loss += ( + self._distillation_cost * step_outputs.analytic_kl_to_target[:-1] + ) + + # Ensure episode restarts are masked out when computing the losses. + critic_loss = svg0_utils.mask_out_restarting( + critic_loss, episode_start_mask + ) + critic_loss = tf.reduce_mean(critic_loss) + + policy_loss = svg0_utils.mask_out_restarting( + policy_loss, episode_start_mask ) - bootstrap_value -= self._distillation_cost * tf.stop_gradient( - step_outputs.analytic_kl_to_target[-1]) - - # Get target log probs and behavior log probs from rollout. - # Shape: [T+1, B]. - target_log_probs_behavior_actions = ( - step_outputs.target_log_probs_behavior_actions) - behavior_log_probs = extra['log_prob'] - # Calculate importance weights. Shape: [T+1, B]. - rhos = tf.exp(target_log_probs_behavior_actions - behavior_log_probs) - - # Filter the importance weights to mask out episode restarts. Ignore the - # last action and consider the step type of the next step for masking. - # Shape: [T, B]. - episode_start_mask = tf2_utils.batch_to_sequence( - sample.data.start_of_episode)[1:] - - rhos = svg0_utils.mask_out_restarting(rhos[:-1], episode_start_mask) - - # rhos = rhos[:-1] - # Compute the log importance weights with a small value added for - # stability. - # Shape: [T, B] - log_rhos = tf.math.log(rhos + _MIN_LOG_VAL) - - # Retrieve the target and online Q values and throw away the last action. - # Shape: [T, B]. - target_q_values = tf.squeeze(step_outputs.target_q[:-1], -1) - online_q_values = tf.squeeze(step_outputs.online_q[:-1], -1) - - # Flip target samples to have shape [S, T+1, B, ...] where 'S' is the - # number of action samples taken. - online_pi_samples = tf2_utils.batch_to_sequence( - step_outputs.online_samples) - target_q_online_pi_samples = snt.BatchApply(self._target_critic_network, - 3)(tiled_observations, - online_pi_samples) - expected_q = tf.reduce_mean( - tf.squeeze(target_q_online_pi_samples, -1), axis=0) - - # Flip online_log_probs to be of shape [S, T+1, B] and then compute - # entropy by averaging over num samples. Final shape: [T+1, B]. - online_log_probs = tf2_utils.batch_to_sequence( - step_outputs.online_log_probs) - sample_based_entropy = tf.reduce_mean(-online_log_probs, axis=0) - retrace_outputs = continuous_retrace_ops.retrace_from_importance_weights( - log_rhos=log_rhos, - discounts=self._discount * discounts[:-1], - rewards=rewards[:-1], - q_values=target_q_values, - values=value_function_target, - bootstrap_value=bootstrap_value, - lambda_=self._lambda, - ) - - # Critic loss. Shape: [T, B]. - critic_loss = 0.5 * tf.math.squared_difference( - tf.stop_gradient(retrace_outputs.qs), online_q_values) - - # Policy loss- SVG0 with sample based entropy. Shape: [T, B] - policy_loss = -( - expected_q + self._entropy_regularizer_cost * sample_based_entropy) - policy_loss = policy_loss[:-1] - - if self._prior_network is not None: - # When training the prior, also add the per-timestep KL cost. - policy_loss += ( - self._distillation_cost * step_outputs.analytic_kl_to_target[:-1]) - - # Ensure episode restarts are masked out when computing the losses. 
- critic_loss = svg0_utils.mask_out_restarting(critic_loss, - episode_start_mask) - critic_loss = tf.reduce_mean(critic_loss) - - policy_loss = svg0_utils.mask_out_restarting(policy_loss, - episode_start_mask) - policy_loss = tf.reduce_mean(policy_loss) - - if self._prior_network is not None: - prior_loss = step_outputs.analytic_kl_divergence[:-1] - prior_loss = svg0_utils.mask_out_restarting(prior_loss, - episode_start_mask) - prior_loss = tf.reduce_mean(prior_loss) - - # Get trainable variables. - policy_variables = self._policy_network.trainable_variables - critic_variables = self._critic_network.trainable_variables - - # Compute gradients. - policy_gradients = tape.gradient(policy_loss, policy_variables) - critic_gradients = tape.gradient(critic_loss, critic_variables) - if self._prior_network is not None: - prior_variables = self._prior_network.trainable_variables - prior_gradients = tape.gradient(prior_loss, prior_variables) - - # Delete the tape manually because of the persistent=True flag. - del tape - - # Apply gradients. - self._policy_optimizer.apply(policy_gradients, policy_variables) - self._critic_optimizer.apply(critic_gradients, critic_variables) - losses = { - 'critic_loss': critic_loss, - 'policy_loss': policy_loss, - } - - if self._prior_network is not None: - self._prior_optimizer.apply(prior_gradients, prior_variables) - losses['prior_loss'] = prior_loss - - # Losses to track. - return losses - - def step(self): - # Run the learning step. - fetches = self._step() - - # Compute elapsed time. - timestamp = time.time() - elapsed_time = timestamp - self._timestamp if self._timestamp else 0 - self._timestamp = timestamp - - # Update our counts and record it. - counts = self._counter.increment(steps=1, walltime=elapsed_time) - fetches.update(counts) - - # Checkpoint and attempt to write the logs. - if self._checkpointer is not None: - self._checkpointer.save() - if self._snapshotter is not None: - self._snapshotter.save() - self._logger.write(fetches) - - def get_variables(self, names: List[str]) -> List[List[np.ndarray]]: - return [tf2_utils.to_numpy(self._variables[name]) for name in names] + policy_loss = tf.reduce_mean(policy_loss) + + if self._prior_network is not None: + prior_loss = step_outputs.analytic_kl_divergence[:-1] + prior_loss = svg0_utils.mask_out_restarting( + prior_loss, episode_start_mask + ) + prior_loss = tf.reduce_mean(prior_loss) + + # Get trainable variables. + policy_variables = self._policy_network.trainable_variables + critic_variables = self._critic_network.trainable_variables + + # Compute gradients. + policy_gradients = tape.gradient(policy_loss, policy_variables) + critic_gradients = tape.gradient(critic_loss, critic_variables) + if self._prior_network is not None: + prior_variables = self._prior_network.trainable_variables + prior_gradients = tape.gradient(prior_loss, prior_variables) + + # Delete the tape manually because of the persistent=True flag. + del tape + + # Apply gradients. + self._policy_optimizer.apply(policy_gradients, policy_variables) + self._critic_optimizer.apply(critic_gradients, critic_variables) + losses = { + "critic_loss": critic_loss, + "policy_loss": policy_loss, + } + + if self._prior_network is not None: + self._prior_optimizer.apply(prior_gradients, prior_variables) + losses["prior_loss"] = prior_loss + + # Losses to track. + return losses + + def step(self): + # Run the learning step. + fetches = self._step() + + # Compute elapsed time. 
+ timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Update our counts and record it. + counts = self._counter.increment(steps=1, walltime=elapsed_time) + fetches.update(counts) + + # Checkpoint and attempt to write the logs. + if self._checkpointer is not None: + self._checkpointer.save() + if self._snapshotter is not None: + self._snapshotter.save() + self._logger.write(fetches) + + def get_variables(self, names: List[str]) -> List[List[np.ndarray]]: + return [tf2_utils.to_numpy(self._variables[name]) for name in names] diff --git a/acme/agents/tf/svg0_prior/networks.py b/acme/agents/tf/svg0_prior/networks.py index 945a200336..23d7ac86fb 100644 --- a/acme/agents/tf/svg0_prior/networks.py +++ b/acme/agents/tf/svg0_prior/networks.py @@ -15,52 +15,57 @@ """Shared helpers for different experiment flavours.""" import functools -from typing import Mapping, Sequence, Optional +from typing import Mapping, Optional, Sequence -from acme import specs -from acme import types +import numpy as np +import sonnet as snt + +from acme import specs, types from acme.agents.tf.svg0_prior import utils as svg0_utils from acme.tf import networks from acme.tf import utils as tf2_utils -import numpy as np -import sonnet as snt - def make_default_networks( action_spec: specs.BoundedArray, policy_layer_sizes: Sequence[int] = (256, 256, 256), critic_layer_sizes: Sequence[int] = (512, 512, 256), ) -> Mapping[str, types.TensorTransformation]: - """Creates networks used by the agent.""" + """Creates networks used by the agent.""" - # Get total number of action dimensions from action spec. - num_dimensions = np.prod(action_spec.shape, dtype=int) + # Get total number of action dimensions from action spec. + num_dimensions = np.prod(action_spec.shape, dtype=int) - policy_network = snt.Sequential([ - tf2_utils.batch_concat, - networks.LayerNormMLP(policy_layer_sizes, activate_final=True), - networks.MultivariateNormalDiagHead( - num_dimensions, - tanh_mean=True, - min_scale=0.3, - init_scale=0.7, - fixed_scale=False, - use_tfd_independent=False) - ]) - # The multiplexer concatenates the (maybe transformed) observations/actions. - multiplexer = networks.CriticMultiplexer( - action_network=networks.ClipToSpec(action_spec)) - critic_network = snt.Sequential([ - multiplexer, - networks.LayerNormMLP(critic_layer_sizes, activate_final=True), - networks.NearZeroInitializedLinear(1), - ]) + policy_network = snt.Sequential( + [ + tf2_utils.batch_concat, + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, + tanh_mean=True, + min_scale=0.3, + init_scale=0.7, + fixed_scale=False, + use_tfd_independent=False, + ), + ] + ) + # The multiplexer concatenates the (maybe transformed) observations/actions. 
+ multiplexer = networks.CriticMultiplexer( + action_network=networks.ClipToSpec(action_spec) + ) + critic_network = snt.Sequential( + [ + multiplexer, + networks.LayerNormMLP(critic_layer_sizes, activate_final=True), + networks.NearZeroInitializedLinear(1), + ] + ) - return { - "policy": policy_network, - "critic": critic_network, - } + return { + "policy": policy_network, + "critic": critic_network, + } def make_network_with_prior( @@ -71,48 +76,59 @@ def make_network_with_prior( policy_keys: Optional[Sequence[str]] = None, prior_keys: Optional[Sequence[str]] = None, ) -> Mapping[str, types.TensorTransformation]: - """Creates networks used by the agent.""" + """Creates networks used by the agent.""" - # Get total number of action dimensions from action spec. - num_dimensions = np.prod(action_spec.shape, dtype=int) - flatten_concat_policy = functools.partial( - svg0_utils.batch_concat_selection, concat_keys=policy_keys) - flatten_concat_prior = functools.partial( - svg0_utils.batch_concat_selection, concat_keys=prior_keys) + # Get total number of action dimensions from action spec. + num_dimensions = np.prod(action_spec.shape, dtype=int) + flatten_concat_policy = functools.partial( + svg0_utils.batch_concat_selection, concat_keys=policy_keys + ) + flatten_concat_prior = functools.partial( + svg0_utils.batch_concat_selection, concat_keys=prior_keys + ) - policy_network = snt.Sequential([ - flatten_concat_policy, - networks.LayerNormMLP(policy_layer_sizes, activate_final=True), - networks.MultivariateNormalDiagHead( - num_dimensions, - tanh_mean=True, - min_scale=0.1, - init_scale=0.7, - fixed_scale=False, - use_tfd_independent=False) - ]) - # The multiplexer concatenates the (maybe transformed) observations/actions. - multiplexer = networks.CriticMultiplexer( - observation_network=flatten_concat_policy, - action_network=networks.ClipToSpec(action_spec)) - critic_network = snt.Sequential([ - multiplexer, - networks.LayerNormMLP(critic_layer_sizes, activate_final=True), - networks.NearZeroInitializedLinear(1), - ]) - prior_network = snt.Sequential([ - flatten_concat_prior, - networks.LayerNormMLP(prior_layer_sizes, activate_final=True), - networks.MultivariateNormalDiagHead( - num_dimensions, - tanh_mean=True, - min_scale=0.1, - init_scale=0.7, - fixed_scale=False, - use_tfd_independent=False) - ]) - return { - "policy": policy_network, - "critic": critic_network, - "prior": prior_network, - } + policy_network = snt.Sequential( + [ + flatten_concat_policy, + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, + tanh_mean=True, + min_scale=0.1, + init_scale=0.7, + fixed_scale=False, + use_tfd_independent=False, + ), + ] + ) + # The multiplexer concatenates the (maybe transformed) observations/actions. 
+ multiplexer = networks.CriticMultiplexer( + observation_network=flatten_concat_policy, + action_network=networks.ClipToSpec(action_spec), + ) + critic_network = snt.Sequential( + [ + multiplexer, + networks.LayerNormMLP(critic_layer_sizes, activate_final=True), + networks.NearZeroInitializedLinear(1), + ] + ) + prior_network = snt.Sequential( + [ + flatten_concat_prior, + networks.LayerNormMLP(prior_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, + tanh_mean=True, + min_scale=0.1, + init_scale=0.7, + fixed_scale=False, + use_tfd_independent=False, + ), + ] + ) + return { + "policy": policy_network, + "critic": critic_network, + "prior": prior_network, + } diff --git a/acme/agents/tf/svg0_prior/utils.py b/acme/agents/tf/svg0_prior/utils.py index 8474fea6cc..fb0f41072f 100644 --- a/acme/agents/tf/svg0_prior/utils.py +++ b/acme/agents/tf/svg0_prior/utils.py @@ -15,84 +15,92 @@ """Utility functions for SVG0 algorithm with priors.""" import collections -from typing import Tuple, Optional, Dict, Iterable - -from acme import types -from acme.tf import utils as tf2_utils +from typing import Dict, Iterable, Optional, Tuple import sonnet as snt import tensorflow as tf import tree +from acme import types +from acme.tf import utils as tf2_utils + class OnlineTargetPiQ(snt.Module): - """Core to unroll online and target policies and Q functions at once. + """Core to unroll online and target policies and Q functions at once. A core that runs online and target policies and Q functions. This can be more efficient if the core needs to be unrolled across time and called many times. """ - def __init__(self, - online_pi: snt.Module, - online_q: snt.Module, - target_pi: snt.Module, - target_q: snt.Module, - num_samples: int, - online_prior: Optional[snt.Module] = None, - target_prior: Optional[snt.Module] = None, - name='OnlineTargetPiQ'): - super().__init__(name) - - self._online_pi = online_pi - self._target_pi = target_pi - self._online_q = online_q - self._target_q = target_q - self._online_prior = online_prior - self._target_prior = target_prior - - self._num_samples = num_samples - output_list = [ - 'online_samples', 'target_samples', 'target_log_probs_behavior_actions', - 'online_log_probs', 'online_q', 'target_q' - ] - if online_prior is not None: - output_list += ['analytic_kl_divergence', 'analytic_kl_to_target'] - self._output_tuple = collections.namedtuple( - 'OnlineTargetPiQ', output_list) - - def __call__(self, input_obs_and_action: Tuple[tf.Tensor, tf.Tensor]): - (obs, action) = input_obs_and_action - online_pi_dist = self._online_pi(obs) - target_pi_dist = self._target_pi(obs) - - online_samples = online_pi_dist.sample(self._num_samples) - target_samples = target_pi_dist.sample(self._num_samples) - target_log_probs_behavior_actions = target_pi_dist.log_prob(action) - - online_log_probs = online_pi_dist.log_prob(tf.stop_gradient(online_samples)) - - online_q_out = self._online_q(obs, action) - target_q_out = self._target_q(obs, action) - - output_list = [ - online_samples, target_samples, target_log_probs_behavior_actions, - online_log_probs, online_q_out, target_q_out - ] - - if self._online_prior is not None: - prior_dist = self._online_prior(obs) - target_prior_dist = self._target_prior(obs) - analytic_kl_divergence = online_pi_dist.kl_divergence(prior_dist) - analytic_kl_to_target = online_pi_dist.kl_divergence(target_prior_dist) - - output_list += [analytic_kl_divergence, analytic_kl_to_target] - output = self._output_tuple(*output_list) - return 
output - - -def static_rnn(core: snt.Module, inputs: types.NestedTensor, - unroll_length: int): - """Unroll core along inputs for unroll_length steps. + def __init__( + self, + online_pi: snt.Module, + online_q: snt.Module, + target_pi: snt.Module, + target_q: snt.Module, + num_samples: int, + online_prior: Optional[snt.Module] = None, + target_prior: Optional[snt.Module] = None, + name="OnlineTargetPiQ", + ): + super().__init__(name) + + self._online_pi = online_pi + self._target_pi = target_pi + self._online_q = online_q + self._target_q = target_q + self._online_prior = online_prior + self._target_prior = target_prior + + self._num_samples = num_samples + output_list = [ + "online_samples", + "target_samples", + "target_log_probs_behavior_actions", + "online_log_probs", + "online_q", + "target_q", + ] + if online_prior is not None: + output_list += ["analytic_kl_divergence", "analytic_kl_to_target"] + self._output_tuple = collections.namedtuple("OnlineTargetPiQ", output_list) + + def __call__(self, input_obs_and_action: Tuple[tf.Tensor, tf.Tensor]): + (obs, action) = input_obs_and_action + online_pi_dist = self._online_pi(obs) + target_pi_dist = self._target_pi(obs) + + online_samples = online_pi_dist.sample(self._num_samples) + target_samples = target_pi_dist.sample(self._num_samples) + target_log_probs_behavior_actions = target_pi_dist.log_prob(action) + + online_log_probs = online_pi_dist.log_prob(tf.stop_gradient(online_samples)) + + online_q_out = self._online_q(obs, action) + target_q_out = self._target_q(obs, action) + + output_list = [ + online_samples, + target_samples, + target_log_probs_behavior_actions, + online_log_probs, + online_q_out, + target_q_out, + ] + + if self._online_prior is not None: + prior_dist = self._online_prior(obs) + target_prior_dist = self._target_prior(obs) + analytic_kl_divergence = online_pi_dist.kl_divergence(prior_dist) + analytic_kl_to_target = online_pi_dist.kl_divergence(target_prior_dist) + + output_list += [analytic_kl_divergence, analytic_kl_to_target] + output = self._output_tuple(*output_list) + return output + + +def static_rnn(core: snt.Module, inputs: types.NestedTensor, unroll_length: int): + """Unroll core along inputs for unroll_length steps. Note: for time-major input tensors whose leading dimension is less than unroll_length, `None` would be provided instead. @@ -106,19 +114,20 @@ def static_rnn(core: snt.Module, inputs: types.NestedTensor, step_outputs: a `nest` of time-major stacked output tensors of length `unroll_length`. """ - step_outputs = [] - for time_dim in range(unroll_length): - inputs_t = tree.map_structure( - lambda t, i_=time_dim: t[i_] if i_ < t.shape[0] else None, inputs) - step_output = core(inputs_t) - step_outputs.append(step_output) + step_outputs = [] + for time_dim in range(unroll_length): + inputs_t = tree.map_structure( + lambda t, i_=time_dim: t[i_] if i_ < t.shape[0] else None, inputs + ) + step_output = core(inputs_t) + step_outputs.append(step_output) - step_outputs = _nest_stack(step_outputs) - return step_outputs + step_outputs = _nest_stack(step_outputs) + return step_outputs def mask_out_restarting(tensor: tf.Tensor, start_of_episode: tf.Tensor): - """Mask out `tensor` taken on the step that resets the environment. + """Mask out `tensor` taken on the step that resets the environment. Args: tensor: a time-major 2-D `Tensor` of shape [T, B]. 
@@ -129,29 +138,33 @@ def mask_out_restarting(tensor: tf.Tensor, start_of_episode: tf.Tensor): tensor of shape [T, B] with elements are masked out according to step_types, restarting weights of shape [T, B] """ - tensor.get_shape().assert_has_rank(2) - start_of_episode.get_shape().assert_has_rank(2) - weights = tf.cast(~start_of_episode, dtype=tf.float32) - masked_tensor = tensor * weights - return masked_tensor - - -def batch_concat_selection(observation_dict: Dict[str, types.NestedTensor], - concat_keys: Optional[Iterable[str]] = None, - output_dtype=tf.float32) -> tf.Tensor: - """Concatenate a dict of observations into 2-D tensors.""" - concat_keys = concat_keys or sorted(observation_dict.keys()) - to_concat = [] - for obs in concat_keys: - if obs not in observation_dict: - raise KeyError( - 'Missing observation. Requested: {} (available: {})'.format( - obs, list(observation_dict.keys()))) - to_concat.append(tf.cast(observation_dict[obs], output_dtype)) - - return tf2_utils.batch_concat(to_concat) + tensor.get_shape().assert_has_rank(2) + start_of_episode.get_shape().assert_has_rank(2) + weights = tf.cast(~start_of_episode, dtype=tf.float32) + masked_tensor = tensor * weights + return masked_tensor + + +def batch_concat_selection( + observation_dict: Dict[str, types.NestedTensor], + concat_keys: Optional[Iterable[str]] = None, + output_dtype=tf.float32, +) -> tf.Tensor: + """Concatenate a dict of observations into 2-D tensors.""" + concat_keys = concat_keys or sorted(observation_dict.keys()) + to_concat = [] + for obs in concat_keys: + if obs not in observation_dict: + raise KeyError( + "Missing observation. Requested: {} (available: {})".format( + obs, list(observation_dict.keys()) + ) + ) + to_concat.append(tf.cast(observation_dict[obs], output_dtype)) + + return tf2_utils.batch_concat(to_concat) def _nest_stack(list_of_nests, axis=0): - """Convert a list of nests to a nest of stacked lists.""" - return tree.map_structure(lambda *ts: tf.stack(ts, axis=axis), *list_of_nests) + """Convert a list of nests to a nest of stacked lists.""" + return tree.map_structure(lambda *ts: tf.stack(ts, axis=axis), *list_of_nests) diff --git a/acme/core.py b/acme/core.py index edd7ad6f2a..3fcfda82ef 100644 --- a/acme/core.py +++ b/acme/core.py @@ -21,16 +21,17 @@ import itertools from typing import Generic, Iterator, List, Optional, Sequence, TypeVar +import dm_env + from acme import types from acme.utils import metrics -import dm_env -T = TypeVar('T') +T = TypeVar("T") @metrics.record_class_usage class Actor(abc.ABC): - """Interface for an agent that can act. + """Interface for an agent that can act. This interface defines an API for an Actor to interact with an EnvironmentLoop (see acme.environment_loop), e.g. a simple RL loop where each step is of the @@ -49,13 +50,13 @@ class Actor(abc.ABC): actor.update() """ - @abc.abstractmethod - def select_action(self, observation: types.NestedArray) -> types.NestedArray: - """Samples from the policy and returns an action.""" + @abc.abstractmethod + def select_action(self, observation: types.NestedArray) -> types.NestedArray: + """Samples from the policy and returns an action.""" - @abc.abstractmethod - def observe_first(self, timestep: dm_env.TimeStep): - """Make a first observation from the environment. + @abc.abstractmethod + def observe_first(self, timestep: dm_env.TimeStep): + """Make a first observation from the environment. Note that this need not be an initial state, it is merely beginning the recording of a trajectory. 
@@ -64,22 +65,20 @@ def observe_first(self, timestep: dm_env.TimeStep): timestep: first timestep. """ - @abc.abstractmethod - def observe( - self, - action: types.NestedArray, - next_timestep: dm_env.TimeStep, - ): - """Make an observation of timestep data from the environment. + @abc.abstractmethod + def observe( + self, action: types.NestedArray, next_timestep: dm_env.TimeStep, + ): + """Make an observation of timestep data from the environment. Args: action: action taken in the environment. next_timestep: timestep produced by the environment given the action. """ - @abc.abstractmethod - def update(self, wait: bool = False): - """Perform an update of the actor parameters from past observations. + @abc.abstractmethod + def update(self, wait: bool = False): + """Perform an update of the actor parameters from past observations. Args: wait: if True, the update will be blocking. @@ -87,16 +86,16 @@ def update(self, wait: bool = False): class VariableSource(abc.ABC): - """Abstract source of variables. + """Abstract source of variables. Objects which implement this interface provide a source of variables, returned as a collection of (nested) numpy arrays. Generally this will be used to provide variables to some learned policy/etc. """ - @abc.abstractmethod - def get_variables(self, names: Sequence[str]) -> List[types.NestedArray]: - """Return the named variables as a collection of (nested) numpy arrays. + @abc.abstractmethod + def get_variables(self, names: Sequence[str]) -> List[types.NestedArray]: + """Return the named variables as a collection of (nested) numpy arrays. Args: names: args where each name is a string identifying a predefined subset of @@ -110,27 +109,27 @@ def get_variables(self, names: Sequence[str]) -> List[types.NestedArray]: @metrics.record_class_usage class Worker(abc.ABC): - """An interface for (potentially) distributed workers.""" + """An interface for (potentially) distributed workers.""" - @abc.abstractmethod - def run(self): - """Runs the worker.""" + @abc.abstractmethod + def run(self): + """Runs the worker.""" class Saveable(abc.ABC, Generic[T]): - """An interface for saveable objects.""" + """An interface for saveable objects.""" - @abc.abstractmethod - def save(self) -> T: - """Returns the state from the object to be saved.""" + @abc.abstractmethod + def save(self) -> T: + """Returns the state from the object to be saved.""" - @abc.abstractmethod - def restore(self, state: T): - """Given the state, restores the object.""" + @abc.abstractmethod + def restore(self, state: T): + """Given the state, restores the object.""" class Learner(VariableSource, Worker, Saveable): - """Abstract learner object. + """Abstract learner object. This corresponds to an object which implements a learning loop. A single step of learning should be implemented via the `step` method and this step @@ -145,32 +144,32 @@ class Learner(VariableSource, Worker, Saveable): useful when the dataset is filled by an external process. 
""" - @abc.abstractmethod - def step(self): - """Perform an update step of the learner's parameters.""" + @abc.abstractmethod + def step(self): + """Perform an update step of the learner's parameters.""" - def run(self, num_steps: Optional[int] = None) -> None: - """Run the update loop; typically an infinite loop which calls step.""" + def run(self, num_steps: Optional[int] = None) -> None: + """Run the update loop; typically an infinite loop which calls step.""" - iterator = range(num_steps) if num_steps is not None else itertools.count() + iterator = range(num_steps) if num_steps is not None else itertools.count() - for _ in iterator: - self.step() + for _ in iterator: + self.step() - def save(self): - raise NotImplementedError('Method "save" is not implemented.') + def save(self): + raise NotImplementedError('Method "save" is not implemented.') - def restore(self, state): - raise NotImplementedError('Method "restore" is not implemented.') + def restore(self, state): + raise NotImplementedError('Method "restore" is not implemented.') class PrefetchingIterator(Iterator[T], abc.ABC): - """Abstract iterator object which supports `ready` method.""" + """Abstract iterator object which supports `ready` method.""" - @abc.abstractmethod - def ready(self) -> bool: - """Is there any data waiting for processing.""" + @abc.abstractmethod + def ready(self) -> bool: + """Is there any data waiting for processing.""" - @abc.abstractmethod - def retrieved_elements(self) -> int: - """How many elements were retrieved from the iterator.""" + @abc.abstractmethod + def retrieved_elements(self) -> int: + """How many elements were retrieved from the iterator.""" diff --git a/acme/core_test.py b/acme/core_test.py index a7f2db554b..2ec2f05ed1 100644 --- a/acme/core_test.py +++ b/acme/core_test.py @@ -16,42 +16,40 @@ from typing import List -from acme import core -from acme import types - from absl.testing import absltest +from acme import core, types + class StepCountingLearner(core.Learner): - """A learner which counts `num_steps` and then raises `StopIteration`.""" + """A learner which counts `num_steps` and then raises `StopIteration`.""" - def __init__(self, num_steps: int): - self.step_count = 0 - self.num_steps = num_steps + def __init__(self, num_steps: int): + self.step_count = 0 + self.num_steps = num_steps - def step(self): - self.step_count += 1 - if self.step_count >= self.num_steps: - raise StopIteration() + def step(self): + self.step_count += 1 + if self.step_count >= self.num_steps: + raise StopIteration() - def get_variables(self, unused: List[str]) -> List[types.NestedArray]: - del unused - return [] + def get_variables(self, unused: List[str]) -> List[types.NestedArray]: + del unused + return [] class CoreTest(absltest.TestCase): + def test_learner_run_with_limit(self): + learner = StepCountingLearner(100) + learner.run(7) + self.assertEqual(learner.step_count, 7) - def test_learner_run_with_limit(self): - learner = StepCountingLearner(100) - learner.run(7) - self.assertEqual(learner.step_count, 7) - - def test_learner_run_no_limit(self): - learner = StepCountingLearner(100) - with self.assertRaises(StopIteration): - learner.run() - self.assertEqual(learner.step_count, 100) + def test_learner_run_no_limit(self): + learner = StepCountingLearner(100) + with self.assertRaises(StopIteration): + learner.run() + self.assertEqual(learner.step_count, 100) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/datasets/__init__.py 
b/acme/datasets/__init__.py index 6dcfae02df..42627e2d39 100644 --- a/acme/datasets/__init__.py +++ b/acme/datasets/__init__.py @@ -16,4 +16,5 @@ from acme.datasets.numpy_iterator import NumpyIterator from acme.datasets.reverb import make_reverb_dataset + # from acme.datasets.reverb import make_reverb_dataset_trajectory diff --git a/acme/datasets/image_augmentation.py b/acme/datasets/image_augmentation.py index bcf4f07080..3e89ffba4f 100644 --- a/acme/datasets/image_augmentation.py +++ b/acme/datasets/image_augmentation.py @@ -16,26 +16,28 @@ import enum -from acme import types -from acme.datasets import reverb as reverb_dataset import reverb import tensorflow as tf +from acme import types +from acme.datasets import reverb as reverb_dataset + class CropType(enum.Enum): - """Types of cropping supported by the image aumentation transforms. + """Types of cropping supported by the image aumentation transforms. BILINEAR: Continuously randomly located then bilinearly interpolated. ALIGNED: Aligned with input image's pixel grid. """ - BILINEAR = 'bilinear' - ALIGNED = 'aligned' + + BILINEAR = "bilinear" + ALIGNED = "aligned" -def pad_and_crop(img: tf.Tensor, - pad_size: int = 4, - method: CropType = CropType.ALIGNED) -> tf.Tensor: - """Pad and crop image to mimic a random translation with mirroring at edges. +def pad_and_crop( + img: tf.Tensor, pad_size: int = 4, method: CropType = CropType.ALIGNED +) -> tf.Tensor: + """Pad and crop image to mimic a random translation with mirroring at edges. This implements the image augmentation from section 3.1 in (Kostrikov et al.) https://arxiv.org/abs/2004.13649. @@ -50,71 +52,82 @@ def pad_and_crop(img: tf.Tensor, Returns: The image after having been padded and cropped. """ - num_batch_dims = img.shape[:-3].rank - - if img.shape.is_fully_defined(): - img_shape = img.shape.as_list() - else: - img_shape = tf.shape(img) - - # Set paddings for height and width only, batches and channels set to [0, 0]. - paddings = [[0, 0]] * num_batch_dims # Do not pad batch dims. - paddings.extend([[pad_size, pad_size], [pad_size, pad_size], [0, 0]]) - - # Pad using symmetric padding. - padded_img = tf.pad(img, paddings=paddings, mode='SYMMETRIC') - - # Crop padded image using requested method. - if method == CropType.ALIGNED: - cropped_img = tf.image.random_crop(padded_img, img_shape) - elif method == CropType.BILINEAR: - height, width = img_shape[-3:-1] - padded_height, padded_width = height + 2 * pad_size, width + 2 * pad_size - - # Pick a top-left point uniformly at random. - top_left = tf.random.uniform( - shape=(2,), maxval=2 * pad_size + 1, dtype=tf.int32) - - # This single box is applied to the entire batch if a batch is passed. - batch_size = tf.shape(padded_img)[0] - box = tf.cast( - tf.tile( - tf.expand_dims([ - top_left[0] / padded_height, - top_left[1] / padded_width, - (top_left[0] + height) / padded_height, - (top_left[1] + width) / padded_width, - ], axis=0), [batch_size, 1]), - tf.float32) # Shape [batch_size, 2]. - - # Crop and resize according to `box` then reshape back to input shape. - cropped_img = tf.image.crop_and_resize( - padded_img, - box, - tf.range(batch_size), - (height, width), - method='bilinear') - cropped_img = tf.reshape(cropped_img, img_shape) - - return cropped_img + num_batch_dims = img.shape[:-3].rank + + if img.shape.is_fully_defined(): + img_shape = img.shape.as_list() + else: + img_shape = tf.shape(img) + + # Set paddings for height and width only, batches and channels set to [0, 0]. 
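Illustrative aside (not part of the patch): `pad_and_crop` above implements a random-shift augmentation by symmetrically padding the image and then cropping back to the original size. A minimal sketch of the aligned variant on a single, hypothetical 84x84x3 image:

    import tensorflow as tf

    img = tf.random.uniform([84, 84, 3])  # hypothetical unbatched image
    pad_size = 4
    padded = tf.pad(
        img, [[pad_size, pad_size], [pad_size, pad_size], [0, 0]], mode="SYMMETRIC"
    )
    # Crop back to the original shape at a random, pixel-aligned offset.
    shifted = tf.image.random_crop(padded, tf.shape(img))
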
+ paddings = [[0, 0]] * num_batch_dims # Do not pad batch dims. + paddings.extend([[pad_size, pad_size], [pad_size, pad_size], [0, 0]]) + + # Pad using symmetric padding. + padded_img = tf.pad(img, paddings=paddings, mode="SYMMETRIC") + + # Crop padded image using requested method. + if method == CropType.ALIGNED: + cropped_img = tf.image.random_crop(padded_img, img_shape) + elif method == CropType.BILINEAR: + height, width = img_shape[-3:-1] + padded_height, padded_width = height + 2 * pad_size, width + 2 * pad_size + + # Pick a top-left point uniformly at random. + top_left = tf.random.uniform( + shape=(2,), maxval=2 * pad_size + 1, dtype=tf.int32 + ) + + # This single box is applied to the entire batch if a batch is passed. + batch_size = tf.shape(padded_img)[0] + box = tf.cast( + tf.tile( + tf.expand_dims( + [ + top_left[0] / padded_height, + top_left[1] / padded_width, + (top_left[0] + height) / padded_height, + (top_left[1] + width) / padded_width, + ], + axis=0, + ), + [batch_size, 1], + ), + tf.float32, + ) # Shape [batch_size, 2]. + + # Crop and resize according to `box` then reshape back to input shape. + cropped_img = tf.image.crop_and_resize( + padded_img, box, tf.range(batch_size), (height, width), method="bilinear" + ) + cropped_img = tf.reshape(cropped_img, img_shape) + + return cropped_img def make_transform( observation_transform: types.TensorTransformation, transform_next_observation: bool = True, ) -> reverb_dataset.Transform: - """Creates the appropriate dataset transform for the given signature.""" - - if transform_next_observation: - def transform(x: reverb.ReplaySample) -> reverb.ReplaySample: - return x._replace( - data=x.data._replace( - observation=observation_transform(x.data.observation), - next_observation=observation_transform(x.data.next_observation))) - else: - def transform(x: reverb.ReplaySample) -> reverb.ReplaySample: - return x._replace( - data=x.data._replace( - observation=observation_transform(x.data.observation))) - - return transform + """Creates the appropriate dataset transform for the given signature.""" + + if transform_next_observation: + + def transform(x: reverb.ReplaySample) -> reverb.ReplaySample: + return x._replace( + data=x.data._replace( + observation=observation_transform(x.data.observation), + next_observation=observation_transform(x.data.next_observation), + ) + ) + + else: + + def transform(x: reverb.ReplaySample) -> reverb.ReplaySample: + return x._replace( + data=x.data._replace( + observation=observation_transform(x.data.observation) + ) + ) + + return transform diff --git a/acme/datasets/numpy_iterator.py b/acme/datasets/numpy_iterator.py index 8d6733690f..a2744e820e 100644 --- a/acme/datasets/numpy_iterator.py +++ b/acme/datasets/numpy_iterator.py @@ -16,13 +16,14 @@ from typing import Iterator -from acme import types import numpy as np import tree +from acme import types + class NumpyIterator(Iterator[types.NestedArray]): - """Iterator over a dataset with elements converted to numpy. + """Iterator over a dataset with elements converted to numpy. Note: This iterator returns read-only numpy arrays. @@ -32,17 +33,18 @@ class NumpyIterator(Iterator[types.NestedArray]): TODO(b/178684359): Remove this when it is upstreamed into `tf.data`. 
""" - __slots__ = ['_iterator'] + __slots__ = ["_iterator"] - def __init__(self, dataset): - self._iterator: Iterator[types.NestedTensor] = iter(dataset) + def __init__(self, dataset): + self._iterator: Iterator[types.NestedTensor] = iter(dataset) - def __iter__(self) -> 'NumpyIterator': - return self + def __iter__(self) -> "NumpyIterator": + return self - def __next__(self) -> types.NestedArray: - return tree.map_structure(lambda t: np.asarray(memoryview(t)), - next(self._iterator)) + def __next__(self) -> types.NestedArray: + return tree.map_structure( + lambda t: np.asarray(memoryview(t)), next(self._iterator) + ) - def next(self): - return self.__next__() + def next(self): + return self.__next__() diff --git a/acme/datasets/numpy_iterator_test.py b/acme/datasets/numpy_iterator_test.py index 500a4c3fbc..685aac1b42 100644 --- a/acme/datasets/numpy_iterator_test.py +++ b/acme/datasets/numpy_iterator_test.py @@ -16,34 +16,30 @@ import collections -from acme.datasets import numpy_iterator import tensorflow as tf - from absl.testing import absltest +from acme.datasets import numpy_iterator -class NumpyIteratorTest(absltest.TestCase): - def testBasic(self): - ds = tf.data.Dataset.range(3) - self.assertEqual([0, 1, 2], list(numpy_iterator.NumpyIterator(ds))) - - def testNestedStructure(self): - point = collections.namedtuple('Point', ['x', 'y']) - ds = tf.data.Dataset.from_tensor_slices({ - 'a': ([1, 2], [3, 4]), - 'b': [5, 6], - 'c': point([7, 8], [9, 10]) - }) - self.assertEqual([{ - 'a': (1, 3), - 'b': 5, - 'c': point(7, 9) - }, { - 'a': (2, 4), - 'b': 6, - 'c': point(8, 10) - }], list(numpy_iterator.NumpyIterator(ds))) - -if __name__ == '__main__': - absltest.main() +class NumpyIteratorTest(absltest.TestCase): + def testBasic(self): + ds = tf.data.Dataset.range(3) + self.assertEqual([0, 1, 2], list(numpy_iterator.NumpyIterator(ds))) + + def testNestedStructure(self): + point = collections.namedtuple("Point", ["x", "y"]) + ds = tf.data.Dataset.from_tensor_slices( + {"a": ([1, 2], [3, 4]), "b": [5, 6], "c": point([7, 8], [9, 10])} + ) + self.assertEqual( + [ + {"a": (1, 3), "b": 5, "c": point(7, 9)}, + {"a": (2, 4), "b": 6, "c": point(8, 10)}, + ], + list(numpy_iterator.NumpyIterator(ds)), + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/datasets/reverb.py b/acme/datasets/reverb.py index 7751566003..c258283920 100644 --- a/acme/datasets/reverb.py +++ b/acme/datasets/reverb.py @@ -18,12 +18,12 @@ import os from typing import Callable, Mapping, Optional, Union -from acme import specs -from acme import types -from acme.adders import reverb as adders import reverb import tensorflow as tf +from acme import specs, types +from acme.adders import reverb as adders + Transform = Callable[[reverb.ReplaySample], reverb.ReplaySample] @@ -43,7 +43,7 @@ def make_reverb_dataset( using_deprecated_adder: bool = False, sequence_length: Optional[int] = None, ) -> tf.data.Dataset: - """Make a TensorFlow dataset backed by a Reverb trajectory replay service. + """Make a TensorFlow dataset backed by a Reverb trajectory replay service. Arguments: server_address: Address of the Reverb server. @@ -73,85 +73,92 @@ def make_reverb_dataset( mapping with no positive weight values. """ - if environment_spec or extra_spec: - raise ValueError( - 'The make_reverb_dataset factory function no longer requires specs as' - ' as they should be passed as a signature to the reverb.Table when it' - ' is created. 
Consider either updating your code or falling back to the' - ' deprecated dataset factory in acme/datasets/deprecated.') - - # These are no longer used and are only kept in the call signature for - # backward compatibility. - del environment_spec - del extra_spec - del transition_adder - del convert_zero_size_to_none - del using_deprecated_adder - del sequence_length - - # This is the default that used to be set by reverb.TFClient.dataset(). - if max_in_flight_samples_per_worker is None and batch_size is None: - max_in_flight_samples_per_worker = 100 - elif max_in_flight_samples_per_worker is None: - max_in_flight_samples_per_worker = 2 * batch_size - - # Create mapping from tables to non-zero weights. - if isinstance(table, str): - tables = collections.OrderedDict([(table, 1.)]) - else: - tables = collections.OrderedDict([ - (name, weight) for name, weight in table.items() if weight > 0. - ]) - if len(tables) <= 0: - raise ValueError(f'No positive weights in input tables {tables}') - - # Normalize weights. - total_weight = sum(tables.values()) - tables = collections.OrderedDict([ - (name, weight / total_weight) for name, weight in tables.items() - ]) - - def _make_dataset(unused_idx: tf.Tensor) -> tf.data.Dataset: - datasets = () - for table_name, weight in tables.items(): - max_in_flight_samples = max( - 1, int(max_in_flight_samples_per_worker * weight)) - dataset = reverb.TrajectoryDataset.from_table_signature( - server_address=server_address, - table=table_name, - max_in_flight_samples_per_worker=max_in_flight_samples) - datasets += (dataset,) - if len(datasets) > 1: - dataset = tf.data.Dataset.sample_from_datasets( - datasets, weights=tables.values()) + if environment_spec or extra_spec: + raise ValueError( + "The make_reverb_dataset factory function no longer requires specs as" + " as they should be passed as a signature to the reverb.Table when it" + " is created. Consider either updating your code or falling back to the" + " deprecated dataset factory in acme/datasets/deprecated." + ) + + # These are no longer used and are only kept in the call signature for + # backward compatibility. + del environment_spec + del extra_spec + del transition_adder + del convert_zero_size_to_none + del using_deprecated_adder + del sequence_length + + # This is the default that used to be set by reverb.TFClient.dataset(). + if max_in_flight_samples_per_worker is None and batch_size is None: + max_in_flight_samples_per_worker = 100 + elif max_in_flight_samples_per_worker is None: + max_in_flight_samples_per_worker = 2 * batch_size + + # Create mapping from tables to non-zero weights. + if isinstance(table, str): + tables = collections.OrderedDict([(table, 1.0)]) else: - dataset = datasets[0] - - # Post-process each element if a post-processing function is passed, e.g. - # observation-stacking or data augmenting transformations. - if postprocess: - dataset = dataset.map(postprocess) + tables = collections.OrderedDict( + [(name, weight) for name, weight in table.items() if weight > 0.0] + ) + if len(tables) <= 0: + raise ValueError(f"No positive weights in input tables {tables}") + + # Normalize weights. 
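+ # For example, {"table_a": 1.0, "table_b": 3.0} normalizes to
+ # {"table_a": 0.25, "table_b": 0.75}.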
+ total_weight = sum(tables.values()) + tables = collections.OrderedDict( + [(name, weight / total_weight) for name, weight in tables.items()] + ) + + def _make_dataset(unused_idx: tf.Tensor) -> tf.data.Dataset: + datasets = () + for table_name, weight in tables.items(): + max_in_flight_samples = max( + 1, int(max_in_flight_samples_per_worker * weight) + ) + dataset = reverb.TrajectoryDataset.from_table_signature( + server_address=server_address, + table=table_name, + max_in_flight_samples_per_worker=max_in_flight_samples, + ) + datasets += (dataset,) + if len(datasets) > 1: + dataset = tf.data.Dataset.sample_from_datasets( + datasets, weights=tables.values() + ) + else: + dataset = datasets[0] + + # Post-process each element if a post-processing function is passed, e.g. + # observation-stacking or data augmenting transformations. + if postprocess: + dataset = dataset.map(postprocess) + + if batch_size: + dataset = dataset.batch(batch_size, drop_remainder=True) + + return dataset + + if num_parallel_calls is not None: + # Create a datasets and interleaves it to create `num_parallel_calls` + # `TrajectoryDataset`s. + num_datasets_to_interleave = ( + os.cpu_count() + if num_parallel_calls == tf.data.AUTOTUNE + else num_parallel_calls + ) + dataset = tf.data.Dataset.range(num_datasets_to_interleave).interleave( + map_func=_make_dataset, + cycle_length=num_parallel_calls, + num_parallel_calls=num_parallel_calls, + deterministic=False, + ) + else: + dataset = _make_dataset(tf.constant(0)) - if batch_size: - dataset = dataset.batch(batch_size, drop_remainder=True) + if prefetch_size: + dataset = dataset.prefetch(prefetch_size) return dataset - - if num_parallel_calls is not None: - # Create a datasets and interleaves it to create `num_parallel_calls` - # `TrajectoryDataset`s. 
- num_datasets_to_interleave = ( - os.cpu_count() - if num_parallel_calls == tf.data.AUTOTUNE else num_parallel_calls) - dataset = tf.data.Dataset.range(num_datasets_to_interleave).interleave( - map_func=_make_dataset, - cycle_length=num_parallel_calls, - num_parallel_calls=num_parallel_calls, - deterministic=False) - else: - dataset = _make_dataset(tf.constant(0)) - - if prefetch_size: - dataset = dataset.prefetch(prefetch_size) - - return dataset diff --git a/acme/datasets/reverb_benchmark.py b/acme/datasets/reverb_benchmark.py index 65af786f9a..9205d17a1d 100644 --- a/acme/datasets/reverb_benchmark.py +++ b/acme/datasets/reverb_benchmark.py @@ -20,78 +20,76 @@ import time from typing import Sequence -from absl import app -from absl import logging -from acme import adders -from acme import specs -from acme.adders import reverb as adders_reverb -from acme.datasets import reverb as datasets -from acme.testing import fakes import numpy as np import reverb +from absl import app, logging from reverb import rate_limiters +from acme import adders, specs +from acme.adders import reverb as adders_reverb +from acme.datasets import reverb as datasets +from acme.testing import fakes + -def make_replay_tables(environment_spec: specs.EnvironmentSpec - ) -> Sequence[reverb.Table]: - """Create tables to insert data into.""" - return [ - reverb.Table( - name='default', - sampler=reverb.selectors.Uniform(), - remover=reverb.selectors.Fifo(), - max_size=1000000, - rate_limiter=rate_limiters.MinSize(1), - signature=adders_reverb.NStepTransitionAdder.signature( - environment_spec)) - ] +def make_replay_tables( + environment_spec: specs.EnvironmentSpec, +) -> Sequence[reverb.Table]: + """Create tables to insert data into.""" + return [ + reverb.Table( + name="default", + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=1000000, + rate_limiter=rate_limiters.MinSize(1), + signature=adders_reverb.NStepTransitionAdder.signature(environment_spec), + ) + ] def make_adder(replay_client: reverb.Client) -> adders.Adder: - return adders_reverb.NStepTransitionAdder( - priority_fns={'default': None}, - client=replay_client, - n_step=1, - discount=1) + return adders_reverb.NStepTransitionAdder( + priority_fns={"default": None}, client=replay_client, n_step=1, discount=1 + ) def main(_): - environment = fakes.ContinuousEnvironment(action_dim=8, - observation_dim=87, - episode_length=10000000) - spec = specs.make_environment_spec(environment) - replay_tables = make_replay_tables(spec) - replay_server = reverb.Server(replay_tables, port=None) - replay_client = reverb.Client(f'localhost:{replay_server.port}') - adder = make_adder(replay_client) - - timestep = environment.reset() - adder.add_first(timestep) - # TODO(raveman): Consider also filling the table to say 1M (too slow). 
- for steps in range(10000): - if steps % 1000 == 0: - logging.info('Processed %s steps', steps) - action = np.asarray(np.random.uniform(-1, 1, (8,)), dtype=np.float32) - next_timestep = environment.step(action) - adder.add(action, next_timestep, extras=()) - - for batch_size in [256, 256 * 8, 256 * 64]: - for prefetch_size in [0, 1, 4]: - print(f'Processing batch_size={batch_size} prefetch_size={prefetch_size}') - ds = datasets.make_reverb_dataset( - table='default', - server_address=replay_client.server_address, - batch_size=batch_size, - prefetch_size=prefetch_size, - ) - it = ds.as_numpy_iterator() - - for iteration in range(3): - t = time.time() - for _ in range(1000): - _ = next(it) - print(f'Iteration {iteration} finished in {time.time() - t}s') - - -if __name__ == '__main__': - app.run(main) + environment = fakes.ContinuousEnvironment( + action_dim=8, observation_dim=87, episode_length=10000000 + ) + spec = specs.make_environment_spec(environment) + replay_tables = make_replay_tables(spec) + replay_server = reverb.Server(replay_tables, port=None) + replay_client = reverb.Client(f"localhost:{replay_server.port}") + adder = make_adder(replay_client) + + timestep = environment.reset() + adder.add_first(timestep) + # TODO(raveman): Consider also filling the table to say 1M (too slow). + for steps in range(10000): + if steps % 1000 == 0: + logging.info("Processed %s steps", steps) + action = np.asarray(np.random.uniform(-1, 1, (8,)), dtype=np.float32) + next_timestep = environment.step(action) + adder.add(action, next_timestep, extras=()) + + for batch_size in [256, 256 * 8, 256 * 64]: + for prefetch_size in [0, 1, 4]: + print(f"Processing batch_size={batch_size} prefetch_size={prefetch_size}") + ds = datasets.make_reverb_dataset( + table="default", + server_address=replay_client.server_address, + batch_size=batch_size, + prefetch_size=prefetch_size, + ) + it = ds.as_numpy_iterator() + + for iteration in range(3): + t = time.time() + for _ in range(1000): + _ = next(it) + print(f"Iteration {iteration} finished in {time.time() - t}s") + + +if __name__ == "__main__": + app.run(main) diff --git a/acme/datasets/tfds.py b/acme/datasets/tfds.py index 80c79e2497..39cdc12c2e 100644 --- a/acme/datasets/tfds.py +++ b/acme/datasets/tfds.py @@ -15,67 +15,67 @@ """Utilities related to loading TFDS datasets.""" import logging -from typing import Any, Iterator, Optional, Tuple, Sequence +from typing import Any, Iterator, Optional, Sequence, Tuple -from acme import specs -from acme import types -from flax import jax_utils import jax import jax.numpy as jnp import numpy as np import rlds import tensorflow as tf import tensorflow_datasets as tfds +from flax import jax_utils + +from acme import specs, types def _batched_step_to_transition(step: rlds.BatchedStep) -> types.Transition: - return types.Transition( - observation=tf.nest.map_structure(lambda x: x[0], step[rlds.OBSERVATION]), - action=tf.nest.map_structure(lambda x: x[0], step[rlds.ACTION]), - reward=tf.nest.map_structure(lambda x: x[0], step[rlds.REWARD]), - discount=1.0 - tf.cast(step[rlds.IS_TERMINAL][1], dtype=tf.float32), - # If next step is terminal, then the observation may be arbitrary. 
- next_observation=tf.nest.map_structure( - lambda x: x[1], step[rlds.OBSERVATION]) - ) + return types.Transition( + observation=tf.nest.map_structure(lambda x: x[0], step[rlds.OBSERVATION]), + action=tf.nest.map_structure(lambda x: x[0], step[rlds.ACTION]), + reward=tf.nest.map_structure(lambda x: x[0], step[rlds.REWARD]), + discount=1.0 - tf.cast(step[rlds.IS_TERMINAL][1], dtype=tf.float32), + # If next step is terminal, then the observation may be arbitrary. + next_observation=tf.nest.map_structure(lambda x: x[1], step[rlds.OBSERVATION]), + ) def _batch_steps(episode: rlds.Episode) -> tf.data.Dataset: - return rlds.transformations.batch( - episode[rlds.STEPS], size=2, shift=1, drop_remainder=True) + return rlds.transformations.batch( + episode[rlds.STEPS], size=2, shift=1, drop_remainder=True + ) def _dataset_size_upperbound(dataset: tf.data.Dataset) -> int: - if dataset.cardinality() != tf.data.experimental.UNKNOWN_CARDINALITY: - return dataset.cardinality() - return tf.cast( - dataset.batch(1000).reduce(0, lambda x, step: x + 1000), tf.int64) + if dataset.cardinality() != tf.data.experimental.UNKNOWN_CARDINALITY: + return dataset.cardinality() + return tf.cast(dataset.batch(1000).reduce(0, lambda x, step: x + 1000), tf.int64) def load_tfds_dataset( dataset_name: str, num_episodes: Optional[int] = None, - env_spec: Optional[specs.EnvironmentSpec] = None) -> tf.data.Dataset: - """Returns a TFDS dataset with the given name.""" - # Used only in tests. - del env_spec + env_spec: Optional[specs.EnvironmentSpec] = None, +) -> tf.data.Dataset: + """Returns a TFDS dataset with the given name.""" + # Used only in tests. + del env_spec - dataset = tfds.load(dataset_name)['train'] - if num_episodes: - dataset = dataset.take(num_episodes) - return dataset + dataset = tfds.load(dataset_name)["train"] + if num_episodes: + dataset = dataset.take(num_episodes) + return dataset # TODO(sinopalnikov): replace get_ftds_dataset with a pair of load/transform. def get_tfds_dataset( dataset_name: str, num_episodes: Optional[int] = None, - env_spec: Optional[specs.EnvironmentSpec] = None) -> tf.data.Dataset: - """Returns a TFDS dataset transformed to a dataset of transitions.""" - dataset = load_tfds_dataset(dataset_name, num_episodes, env_spec) - batched_steps = dataset.flat_map(_batch_steps) - return rlds.transformations.map_steps(batched_steps, - _batched_step_to_transition) + env_spec: Optional[specs.EnvironmentSpec] = None, +) -> tf.data.Dataset: + """Returns a TFDS dataset transformed to a dataset of transitions.""" + dataset = load_tfds_dataset(dataset_name, num_episodes, env_spec) + batched_steps = dataset.flat_map(_batch_steps) + return rlds.transformations.map_steps(batched_steps, _batched_step_to_transition) # In order to avoid excessive copying on TPU one needs to make the last @@ -84,28 +84,28 @@ def get_tfds_dataset( def _pad(x: jnp.ndarray) -> jnp.ndarray: - if len(x.shape) != 2: + if len(x.shape) != 2: + return x + # Find a more scientific way to find this threshold (30). Depending on various + # conditions for low enough sizes the excessive copying is not triggered. + if x.shape[-1] % _BEST_DIVISOR != 0 and x.shape[-1] > 30: + n = _BEST_DIVISOR - (x.shape[-1] % _BEST_DIVISOR) + x = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(0, n)], "constant") return x - # Find a more scientific way to find this threshold (30). Depending on various - # conditions for low enough sizes the excessive copying is not triggered. 
- if x.shape[-1] % _BEST_DIVISOR != 0 and x.shape[-1] > 30: - n = _BEST_DIVISOR - (x.shape[-1] % _BEST_DIVISOR) - x = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(0, n)], 'constant') - return x # Undo the padding. def _unpad(x: jnp.ndarray, shape: Sequence[int]) -> jnp.ndarray: - if len(shape) == 2 and x.shape[-1] != shape[-1]: - return x[..., :shape[-1]] - return x + if len(shape) == 2 and x.shape[-1] != shape[-1]: + return x[..., : shape[-1]] + return x -_PMAP_AXIS_NAME = 'data' +_PMAP_AXIS_NAME = "data" class JaxInMemoryRandomSampleIterator(Iterator[Any]): - """In memory random sample iterator implemented in JAX. + """In memory random sample iterator implemented in JAX. Loads the whole dataset in memory and performs random sampling with replacement of batches of `batch_size`. @@ -113,12 +113,14 @@ class JaxInMemoryRandomSampleIterator(Iterator[Any]): an iterator on tf.data.Dataset. """ - def __init__(self, - dataset: tf.data.Dataset, - key: jnp.ndarray, - batch_size: int, - shard_dataset_across_devices: bool = False): - """Creates an iterator. + def __init__( + self, + dataset: tf.data.Dataset, + key: jnp.ndarray, + batch_size: int, + shard_dataset_across_devices: bool = False, + ): + """Creates an iterator. Args: dataset: underlying tf Dataset @@ -133,77 +135,88 @@ def __init__(self, only within its data chunk The number of available devices must divide the batch_size evenly. """ - # Read the whole dataset. We use artificially large batch_size to make sure - # we capture the whole dataset. - size = _dataset_size_upperbound(dataset) - data = next(dataset.batch(size).as_numpy_iterator()) - self._dataset_size = jax.tree_flatten( - jax.tree_map(lambda x: x.shape[0], data))[0][0] - device = jax_utils._pmap_device_order() - if not shard_dataset_across_devices: - device = device[:1] - should_pmap = len(device) > 1 - assert batch_size % len(device) == 0 - self._dataset_size = self._dataset_size - self._dataset_size % len(device) - # len(device) needs to divide self._dataset_size evenly. - assert self._dataset_size % len(device) == 0 - logging.info('Trying to load %s elements to %s', self._dataset_size, device) - logging.info('Dataset %s %s', - ('before padding' if should_pmap else ''), - jax.tree_map(lambda x: x.shape, data)) - if should_pmap: - shapes = jax.tree_map(lambda x: x.shape, data) - # Padding to a multiple of 128 is needed to avoid excessive copying on TPU - data = jax.tree_map(_pad, data) - logging.info('Dataset after padding %s', - jax.tree_map(lambda x: x.shape, data)) - def split_and_put(x: jnp.ndarray) -> jnp.ndarray: - return jax.device_put_sharded( - np.split(x[:self._dataset_size], len(device)), devices=device) - self._jax_dataset = jax.tree_map(split_and_put, data) - else: - self._jax_dataset = jax.tree_map(jax.device_put, data) - - self._key = (jnp.stack(jax.random.split(key, len(device))) - if should_pmap else key) - - def sample_per_shard(data: Any, - key: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - key1, key2 = jax.random.split(key) - indices = jax.random.randint( - key1, (batch_size // len(device),), - minval=0, - maxval=self._dataset_size // len(device)) - data_sample = jax.tree_map(lambda d: jnp.take(d, indices, axis=0), data) - return data_sample, key2 - - if should_pmap: - def sample(data, key): - data_sample, key = sample_per_shard(data, key) - # Gathering data on TPUs is much more efficient that doing so on a host - # since it avoids Host - Device communications. 
- data_sample = jax.lax.all_gather( - data_sample, axis_name=_PMAP_AXIS_NAME, axis=0, tiled=True) - data_sample = jax.tree_map(_unpad, data_sample, shapes) - return data_sample, key - - pmapped_sample = jax.pmap(sample, axis_name=_PMAP_AXIS_NAME) - - def sample_and_postprocess(key: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - data, key = pmapped_sample(self._jax_dataset, key) - # All pmapped devices return the same data, so we just take the one from - # the first device. - return jax.tree_map(lambda x: x[0], data), key - self._sample = sample_and_postprocess - else: - self._sample = jax.jit( - lambda key: sample_per_shard(self._jax_dataset, key)) - - def __next__(self) -> Any: - data, self._key = self._sample(self._key) - return data - - @property - def dataset_size(self) -> int: - """An integer of the dataset cardinality.""" - return self._dataset_size + # Read the whole dataset. We use artificially large batch_size to make sure + # we capture the whole dataset. + size = _dataset_size_upperbound(dataset) + data = next(dataset.batch(size).as_numpy_iterator()) + self._dataset_size = jax.tree_flatten(jax.tree_map(lambda x: x.shape[0], data))[ + 0 + ][0] + device = jax_utils._pmap_device_order() + if not shard_dataset_across_devices: + device = device[:1] + should_pmap = len(device) > 1 + assert batch_size % len(device) == 0 + self._dataset_size = self._dataset_size - self._dataset_size % len(device) + # len(device) needs to divide self._dataset_size evenly. + assert self._dataset_size % len(device) == 0 + logging.info("Trying to load %s elements to %s", self._dataset_size, device) + logging.info( + "Dataset %s %s", + ("before padding" if should_pmap else ""), + jax.tree_map(lambda x: x.shape, data), + ) + if should_pmap: + shapes = jax.tree_map(lambda x: x.shape, data) + # Padding to a multiple of 128 is needed to avoid excessive copying on TPU + data = jax.tree_map(_pad, data) + logging.info( + "Dataset after padding %s", jax.tree_map(lambda x: x.shape, data) + ) + + def split_and_put(x: jnp.ndarray) -> jnp.ndarray: + return jax.device_put_sharded( + np.split(x[: self._dataset_size], len(device)), devices=device + ) + + self._jax_dataset = jax.tree_map(split_and_put, data) + else: + self._jax_dataset = jax.tree_map(jax.device_put, data) + + self._key = ( + jnp.stack(jax.random.split(key, len(device))) if should_pmap else key + ) + + def sample_per_shard(data: Any, key: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: + key1, key2 = jax.random.split(key) + indices = jax.random.randint( + key1, + (batch_size // len(device),), + minval=0, + maxval=self._dataset_size // len(device), + ) + data_sample = jax.tree_map(lambda d: jnp.take(d, indices, axis=0), data) + return data_sample, key2 + + if should_pmap: + + def sample(data, key): + data_sample, key = sample_per_shard(data, key) + # Gathering data on TPUs is much more efficient that doing so on a host + # since it avoids Host - Device communications. + data_sample = jax.lax.all_gather( + data_sample, axis_name=_PMAP_AXIS_NAME, axis=0, tiled=True + ) + data_sample = jax.tree_map(_unpad, data_sample, shapes) + return data_sample, key + + pmapped_sample = jax.pmap(sample, axis_name=_PMAP_AXIS_NAME) + + def sample_and_postprocess(key: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: + data, key = pmapped_sample(self._jax_dataset, key) + # All pmapped devices return the same data, so we just take the one from + # the first device. 
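+ # The pmapped function returns arrays with a leading device axis, so
+ # indexing with [0] drops that axis.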
+ return jax.tree_map(lambda x: x[0], data), key + + self._sample = sample_and_postprocess + else: + self._sample = jax.jit(lambda key: sample_per_shard(self._jax_dataset, key)) + + def __next__(self) -> Any: + data, self._key = self._sample(self._key) + return data + + @property + def dataset_size(self) -> int: + """An integer of the dataset cardinality.""" + return self._dataset_size diff --git a/acme/environment_loop.py b/acme/environment_loop.py index 23cc972326..8bbc17937d 100644 --- a/acme/environment_loop.py +++ b/acme/environment_loop.py @@ -18,20 +18,19 @@ import time from typing import List, Optional, Sequence -from acme import core -from acme.utils import counting -from acme.utils import loggers -from acme.utils import observers as observers_lib -from acme.utils import signals - import dm_env -from dm_env import specs import numpy as np import tree +from dm_env import specs + +from acme import core +from acme.utils import counting, loggers +from acme.utils import observers as observers_lib +from acme.utils import signals class EnvironmentLoop(core.Worker): - """A simple RL environment loop. + """A simple RL environment loop. This takes `Environment` and `Actor` instances and coordinates their interaction. Agent is updated if `should_update=True`. This can be used as: @@ -54,27 +53,28 @@ class EnvironmentLoop(core.Worker): the current timestep datastruct and the current action. """ - def __init__( - self, - environment: dm_env.Environment, - actor: core.Actor, - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - should_update: bool = True, - label: str = 'environment_loop', - observers: Sequence[observers_lib.EnvLoopObserver] = (), - ): - # Internalize agent and environment. - self._environment = environment - self._actor = actor - self._counter = counter or counting.Counter() - self._logger = logger or loggers.make_default_logger( - label, steps_key=self._counter.get_steps_key()) - self._should_update = should_update - self._observers = observers - - def run_episode(self) -> loggers.LoggingData: - """Run one episode. + def __init__( + self, + environment: dm_env.Environment, + actor: core.Actor, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + should_update: bool = True, + label: str = "environment_loop", + observers: Sequence[observers_lib.EnvLoopObserver] = (), + ): + # Internalize agent and environment. + self._environment = environment + self._actor = actor + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger( + label, steps_key=self._counter.get_steps_key() + ) + self._should_update = should_update + self._observers = observers + + def run_episode(self) -> loggers.LoggingData: + """Run one episode. Each episode is a loop which interacts first with the environment to get an observation and then give that observation to the agent in order to retrieve @@ -83,85 +83,84 @@ def run_episode(self) -> loggers.LoggingData: Returns: An instance of `loggers.LoggingData`. """ - # Reset any counts and start the environment. - episode_start_time = time.time() - select_action_durations: List[float] = [] - env_step_durations: List[float] = [] - episode_steps: int = 0 - - # For evaluation, this keeps track of the total undiscounted reward - # accumulated during the episode. 
- episode_return = tree.map_structure(_generate_zeros_from_spec, - self._environment.reward_spec()) - env_reset_start = time.time() - timestep = self._environment.reset() - env_reset_duration = time.time() - env_reset_start - # Make the first observation. - self._actor.observe_first(timestep) - for observer in self._observers: - # Initialize the observer with the current state of the env after reset - # and the initial timestep. - observer.observe_first(self._environment, timestep) - - # Run an episode. - while not timestep.last(): - # Book-keeping. - episode_steps += 1 - - # Generate an action from the agent's policy. - select_action_start = time.time() - action = self._actor.select_action(timestep.observation) - select_action_durations.append(time.time() - select_action_start) - - # Step the environment with the agent's selected action. - env_step_start = time.time() - timestep = self._environment.step(action) - env_step_durations.append(time.time() - env_step_start) - - # Have the agent and observers observe the timestep. - self._actor.observe(action, next_timestep=timestep) - for observer in self._observers: - # One environment step was completed. Observe the current state of the - # environment, the current timestep and the action. - observer.observe(self._environment, timestep, action) - - # Give the actor the opportunity to update itself. - if self._should_update: - self._actor.update() - - # Equivalent to: episode_return += timestep.reward - # We capture the return value because if timestep.reward is a JAX - # DeviceArray, episode_return will not be mutated in-place. (In all other - # cases, the returned episode_return will be the same object as the - # argument episode_return.) - episode_return = tree.map_structure(operator.iadd, - episode_return, - timestep.reward) - - # Record counts. - counts = self._counter.increment(episodes=1, steps=episode_steps) - - # Collect the results and combine with counts. - steps_per_second = episode_steps / (time.time() - episode_start_time) - result = { - 'episode_length': episode_steps, - 'episode_return': episode_return, - 'steps_per_second': steps_per_second, - 'env_reset_duration_sec': env_reset_duration, - 'select_action_duration_sec': np.mean(select_action_durations), - 'env_step_duration_sec': np.mean(env_step_durations), - } - result.update(counts) - for observer in self._observers: - result.update(observer.get_metrics()) - return result - - def run( - self, - num_episodes: Optional[int] = None, - num_steps: Optional[int] = None, - ) -> int: - """Perform the run loop. + # Reset any counts and start the environment. + episode_start_time = time.time() + select_action_durations: List[float] = [] + env_step_durations: List[float] = [] + episode_steps: int = 0 + + # For evaluation, this keeps track of the total undiscounted reward + # accumulated during the episode. + episode_return = tree.map_structure( + _generate_zeros_from_spec, self._environment.reward_spec() + ) + env_reset_start = time.time() + timestep = self._environment.reset() + env_reset_duration = time.time() - env_reset_start + # Make the first observation. + self._actor.observe_first(timestep) + for observer in self._observers: + # Initialize the observer with the current state of the env after reset + # and the initial timestep. + observer.observe_first(self._environment, timestep) + + # Run an episode. + while not timestep.last(): + # Book-keeping. + episode_steps += 1 + + # Generate an action from the agent's policy. 
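+ # Timing each call lets the loop report the mean select_action_duration_sec
+ # and env_step_duration_sec in the episode results below.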
+ select_action_start = time.time() + action = self._actor.select_action(timestep.observation) + select_action_durations.append(time.time() - select_action_start) + + # Step the environment with the agent's selected action. + env_step_start = time.time() + timestep = self._environment.step(action) + env_step_durations.append(time.time() - env_step_start) + + # Have the agent and observers observe the timestep. + self._actor.observe(action, next_timestep=timestep) + for observer in self._observers: + # One environment step was completed. Observe the current state of the + # environment, the current timestep and the action. + observer.observe(self._environment, timestep, action) + + # Give the actor the opportunity to update itself. + if self._should_update: + self._actor.update() + + # Equivalent to: episode_return += timestep.reward + # We capture the return value because if timestep.reward is a JAX + # DeviceArray, episode_return will not be mutated in-place. (In all other + # cases, the returned episode_return will be the same object as the + # argument episode_return.) + episode_return = tree.map_structure( + operator.iadd, episode_return, timestep.reward + ) + + # Record counts. + counts = self._counter.increment(episodes=1, steps=episode_steps) + + # Collect the results and combine with counts. + steps_per_second = episode_steps / (time.time() - episode_start_time) + result = { + "episode_length": episode_steps, + "episode_return": episode_return, + "steps_per_second": steps_per_second, + "env_reset_duration_sec": env_reset_duration, + "select_action_duration_sec": np.mean(select_action_durations), + "env_step_duration_sec": np.mean(env_step_durations), + } + result.update(counts) + for observer in self._observers: + result.update(observer.get_metrics()) + return result + + def run( + self, num_episodes: Optional[int] = None, num_steps: Optional[int] = None, + ) -> int: + """Perform the run loop. Run the environment loop either for `num_episodes` episodes or for at least `num_steps` steps (the last episode is always run until completion, @@ -183,27 +182,28 @@ def run( ValueError: If both 'num_episodes' and 'num_steps' are not None. """ - if not (num_episodes is None or num_steps is None): - raise ValueError('Either "num_episodes" or "num_steps" should be None.') + if not (num_episodes is None or num_steps is None): + raise ValueError('Either "num_episodes" or "num_steps" should be None.') - def should_terminate(episode_count: int, step_count: int) -> bool: - return ((num_episodes is not None and episode_count >= num_episodes) or - (num_steps is not None and step_count >= num_steps)) + def should_terminate(episode_count: int, step_count: int) -> bool: + return (num_episodes is not None and episode_count >= num_episodes) or ( + num_steps is not None and step_count >= num_steps + ) - episode_count: int = 0 - step_count: int = 0 - with signals.runtime_terminator(): - while not should_terminate(episode_count, step_count): - episode_start = time.time() - result = self.run_episode() - result = {**result, **{'episode_duration': time.time() - episode_start}} - episode_count += 1 - step_count += int(result['episode_length']) - # Log the given episode results. 
- self._logger.write(result) + episode_count: int = 0 + step_count: int = 0 + with signals.runtime_terminator(): + while not should_terminate(episode_count, step_count): + episode_start = time.time() + result = self.run_episode() + result = {**result, **{"episode_duration": time.time() - episode_start}} + episode_count += 1 + step_count += int(result["episode_length"]) + # Log the given episode results. + self._logger.write(result) - return step_count + return step_count def _generate_zeros_from_spec(spec: specs.Array) -> np.ndarray: - return np.zeros(spec.shape, spec.dtype) + return np.zeros(spec.shape, spec.dtype) diff --git a/acme/environment_loop_test.py b/acme/environment_loop_test.py index 677ddb4590..f9cdc8181d 100644 --- a/acme/environment_loop_test.py +++ b/acme/environment_loop_test.py @@ -16,69 +16,69 @@ from typing import Optional -from acme import environment_loop -from acme import specs -from acme import types -from acme.testing import fakes import numpy as np +from absl.testing import absltest, parameterized -from absl.testing import absltest -from absl.testing import parameterized +from acme import environment_loop, specs, types +from acme.testing import fakes EPISODE_LENGTH = 10 # Discount specs F32_2_MIN_0_MAX_1 = specs.BoundedArray( - dtype=np.float32, shape=(2,), minimum=0.0, maximum=1.0) + dtype=np.float32, shape=(2,), minimum=0.0, maximum=1.0 +) F32_2x1_MIN_0_MAX_1 = specs.BoundedArray( - dtype=np.float32, shape=(2, 1), minimum=0.0, maximum=1.0) -TREE_MIN_0_MAX_1 = {'a': F32_2_MIN_0_MAX_1, 'b': F32_2x1_MIN_0_MAX_1} + dtype=np.float32, shape=(2, 1), minimum=0.0, maximum=1.0 +) +TREE_MIN_0_MAX_1 = {"a": F32_2_MIN_0_MAX_1, "b": F32_2x1_MIN_0_MAX_1} # Reward specs F32 = specs.Array(dtype=np.float32, shape=()) F32_1x3 = specs.Array(dtype=np.float32, shape=(1, 3)) -TREE = {'a': F32, 'b': F32_1x3} +TREE = {"a": F32, "b": F32_1x3} TEST_CASES = ( - ('scalar_discount_scalar_reward', None, None), - ('vector_discount_scalar_reward', F32_2_MIN_0_MAX_1, F32), - ('matrix_discount_matrix_reward', F32_2x1_MIN_0_MAX_1, F32_1x3), - ('tree_discount_tree_reward', TREE_MIN_0_MAX_1, TREE), - ) + ("scalar_discount_scalar_reward", None, None), + ("vector_discount_scalar_reward", F32_2_MIN_0_MAX_1, F32), + ("matrix_discount_matrix_reward", F32_2x1_MIN_0_MAX_1, F32_1x3), + ("tree_discount_tree_reward", TREE_MIN_0_MAX_1, TREE), +) class EnvironmentLoopTest(parameterized.TestCase): - - @parameterized.named_parameters(*TEST_CASES) - def test_one_episode(self, discount_spec, reward_spec): - _, loop = _parameterized_setup(discount_spec, reward_spec) - result = loop.run_episode() - self.assertIn('episode_length', result) - self.assertEqual(EPISODE_LENGTH, result['episode_length']) - self.assertIn('episode_return', result) - self.assertIn('steps_per_second', result) - - @parameterized.named_parameters(*TEST_CASES) - def test_run_episodes(self, discount_spec, reward_spec): - actor, loop = _parameterized_setup(discount_spec, reward_spec) - - # Run the loop. There should be EPISODE_LENGTH update calls per episode. - loop.run(num_episodes=10) - self.assertEqual(actor.num_updates, 10 * EPISODE_LENGTH) - - @parameterized.named_parameters(*TEST_CASES) - def test_run_steps(self, discount_spec, reward_spec): - actor, loop = _parameterized_setup(discount_spec, reward_spec) - - # Run the loop. This will run 2 episodes so that total number of steps is - # at least 15. 
- loop.run(num_steps=EPISODE_LENGTH + 5) - self.assertEqual(actor.num_updates, 2 * EPISODE_LENGTH) - - -def _parameterized_setup(discount_spec: Optional[types.NestedSpec] = None, - reward_spec: Optional[types.NestedSpec] = None): - """Common setup code that, unlike self.setUp, takes arguments. + @parameterized.named_parameters(*TEST_CASES) + def test_one_episode(self, discount_spec, reward_spec): + _, loop = _parameterized_setup(discount_spec, reward_spec) + result = loop.run_episode() + self.assertIn("episode_length", result) + self.assertEqual(EPISODE_LENGTH, result["episode_length"]) + self.assertIn("episode_return", result) + self.assertIn("steps_per_second", result) + + @parameterized.named_parameters(*TEST_CASES) + def test_run_episodes(self, discount_spec, reward_spec): + actor, loop = _parameterized_setup(discount_spec, reward_spec) + + # Run the loop. There should be EPISODE_LENGTH update calls per episode. + loop.run(num_episodes=10) + self.assertEqual(actor.num_updates, 10 * EPISODE_LENGTH) + + @parameterized.named_parameters(*TEST_CASES) + def test_run_steps(self, discount_spec, reward_spec): + actor, loop = _parameterized_setup(discount_spec, reward_spec) + + # Run the loop. This will run 2 episodes so that total number of steps is + # at least 15. + loop.run(num_steps=EPISODE_LENGTH + 5) + self.assertEqual(actor.num_updates, 2 * EPISODE_LENGTH) + + +def _parameterized_setup( + discount_spec: Optional[types.NestedSpec] = None, + reward_spec: Optional[types.NestedSpec] = None, +): + """Common setup code that, unlike self.setUp, takes arguments. Args: discount_spec: None, or a (nested) specs.BoundedArray. @@ -86,17 +86,17 @@ def _parameterized_setup(discount_spec: Optional[types.NestedSpec] = None, Returns: environment, actor, loop """ - env_kwargs = {'episode_length': EPISODE_LENGTH} - if discount_spec: - env_kwargs['discount_spec'] = discount_spec - if reward_spec: - env_kwargs['reward_spec'] = reward_spec + env_kwargs = {"episode_length": EPISODE_LENGTH} + if discount_spec: + env_kwargs["discount_spec"] = discount_spec + if reward_spec: + env_kwargs["reward_spec"] = reward_spec - environment = fakes.DiscreteEnvironment(**env_kwargs) - actor = fakes.Actor(specs.make_environment_spec(environment)) - loop = environment_loop.EnvironmentLoop(environment, actor) - return actor, loop + environment = fakes.DiscreteEnvironment(**env_kwargs) + actor = fakes.Actor(specs.make_environment_spec(environment)) + loop = environment_loop.EnvironmentLoop(environment, actor) + return actor, loop -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/environment_loops/__init__.py b/acme/environment_loops/__init__.py index 32a4e75237..c6b3c9b061 100644 --- a/acme/environment_loops/__init__.py +++ b/acme/environment_loops/__init__.py @@ -15,7 +15,9 @@ """Specialized environment loops.""" try: - # pylint: disable=g-import-not-at-top - from acme.environment_loops.open_spiel_environment_loop import OpenSpielEnvironmentLoop + # pylint: disable=g-import-not-at-top + from acme.environment_loops.open_spiel_environment_loop import ( + OpenSpielEnvironmentLoop, + ) except ImportError: - pass + pass diff --git a/acme/environment_loops/open_spiel_environment_loop.py b/acme/environment_loops/open_spiel_environment_loop.py index 4e9c81a36c..7548c7380c 100644 --- a/acme/environment_loops/open_spiel_environment_loop.py +++ b/acme/environment_loops/open_spiel_environment_loop.py @@ -18,22 +18,23 @@ import time from typing import Optional, Sequence -from acme 
import core -from acme.utils import counting -from acme.utils import loggers -from acme.wrappers import open_spiel_wrapper import dm_env -from dm_env import specs import numpy as np -import tree # pytype: disable=import-error import pyspiel +import tree +from dm_env import specs + +from acme import core +from acme.utils import counting, loggers +from acme.wrappers import open_spiel_wrapper + # pytype: enable=import-error class OpenSpielEnvironmentLoop(core.Worker): - """An OpenSpiel RL environment loop. + """An OpenSpiel RL environment loop. This takes `Environment` and list of `Actor` instances and coordinates their interaction. Agents are updated if `should_update=True`. This can be used as: @@ -52,71 +53,78 @@ class OpenSpielEnvironmentLoop(core.Worker): `Logger` instance is given. """ - def __init__( - self, - environment: open_spiel_wrapper.OpenSpielWrapper, - actors: Sequence[core.Actor], - counter: Optional[counting.Counter] = None, - logger: Optional[loggers.Logger] = None, - should_update: bool = True, - label: str = 'open_spiel_environment_loop', - ): - # Internalize agent and environment. - self._environment = environment - self._actors = actors - self._counter = counter or counting.Counter() - self._logger = logger or loggers.make_default_logger(label) - self._should_update = should_update - - # Track information necessary to coordinate updates among multiple actors. - self._observed_first = [False] * len(self._actors) - self._prev_actions = [pyspiel.INVALID_ACTION] * len(self._actors) - - def _send_observation(self, timestep: dm_env.TimeStep, player: int): - # If terminal all actors must update - if player == pyspiel.PlayerId.TERMINAL: - for player_id in range(len(self._actors)): - # Note: we must account for situations where the first observation - # is a terminal state, e.g. if an opponent folds in poker before we get - # to act. - if self._observed_first[player_id]: - player_timestep = self._get_player_timestep(timestep, player_id) - self._actors[player_id].observe(self._prev_actions[player_id], - player_timestep) - if self._should_update: - self._actors[player_id].update() - self._observed_first = [False] * len(self._actors) - self._prev_actions = [pyspiel.INVALID_ACTION] * len(self._actors) - else: - if not self._observed_first[player]: - player_timestep = dm_env.TimeStep( + def __init__( + self, + environment: open_spiel_wrapper.OpenSpielWrapper, + actors: Sequence[core.Actor], + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + should_update: bool = True, + label: str = "open_spiel_environment_loop", + ): + # Internalize agent and environment. + self._environment = environment + self._actors = actors + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger(label) + self._should_update = should_update + + # Track information necessary to coordinate updates among multiple actors. + self._observed_first = [False] * len(self._actors) + self._prev_actions = [pyspiel.INVALID_ACTION] * len(self._actors) + + def _send_observation(self, timestep: dm_env.TimeStep, player: int): + # If terminal all actors must update + if player == pyspiel.PlayerId.TERMINAL: + for player_id in range(len(self._actors)): + # Note: we must account for situations where the first observation + # is a terminal state, e.g. if an opponent folds in poker before we get + # to act. 
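+ # Only actors that have already observed a first timestep receive this
+ # terminal update.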
+ if self._observed_first[player_id]: + player_timestep = self._get_player_timestep(timestep, player_id) + self._actors[player_id].observe( + self._prev_actions[player_id], player_timestep + ) + if self._should_update: + self._actors[player_id].update() + self._observed_first = [False] * len(self._actors) + self._prev_actions = [pyspiel.INVALID_ACTION] * len(self._actors) + else: + if not self._observed_first[player]: + player_timestep = dm_env.TimeStep( + observation=timestep.observation[player], + reward=None, + discount=None, + step_type=dm_env.StepType.FIRST, + ) + self._actors[player].observe_first(player_timestep) + self._observed_first[player] = True + else: + player_timestep = self._get_player_timestep(timestep, player) + self._actors[player].observe( + self._prev_actions[player], player_timestep + ) + if self._should_update: + self._actors[player].update() + + def _get_action(self, timestep: dm_env.TimeStep, player: int) -> int: + self._prev_actions[player] = self._actors[player].select_action( + timestep.observation[player] + ) + return self._prev_actions[player] + + def _get_player_timestep( + self, timestep: dm_env.TimeStep, player: int + ) -> dm_env.TimeStep: + return dm_env.TimeStep( observation=timestep.observation[player], - reward=None, - discount=None, - step_type=dm_env.StepType.FIRST) - self._actors[player].observe_first(player_timestep) - self._observed_first[player] = True - else: - player_timestep = self._get_player_timestep(timestep, player) - self._actors[player].observe(self._prev_actions[player], - player_timestep) - if self._should_update: - self._actors[player].update() - - def _get_action(self, timestep: dm_env.TimeStep, player: int) -> int: - self._prev_actions[player] = self._actors[player].select_action( - timestep.observation[player]) - return self._prev_actions[player] - - def _get_player_timestep(self, timestep: dm_env.TimeStep, - player: int) -> dm_env.TimeStep: - return dm_env.TimeStep(observation=timestep.observation[player], - reward=timestep.reward[player], - discount=timestep.discount[player], - step_type=timestep.step_type) - - def run_episode(self) -> loggers.LoggingData: - """Run one episode. + reward=timestep.reward[player], + discount=timestep.discount[player], + step_type=timestep.step_type, + ) + + def run_episode(self) -> loggers.LoggingData: + """Run one episode. Each episode is a loop which interacts first with the environment to get an observation and then give that observation to the agent in order to retrieve @@ -125,70 +133,70 @@ def run_episode(self) -> loggers.LoggingData: Returns: An instance of `loggers.LoggingData`. """ - # Reset any counts and start the environment. - start_time = time.time() - episode_steps = 0 - - # For evaluation, this keeps track of the total undiscounted reward - # for each player accumulated during the episode. - multiplayer_reward_spec = specs.BoundedArray( - (self._environment.game.num_players(),), - np.float32, - minimum=self._environment.game.min_utility(), - maximum=self._environment.game.max_utility()) - episode_return = tree.map_structure(_generate_zeros_from_spec, - multiplayer_reward_spec) - - timestep = self._environment.reset() - - # Make the first observation. - self._send_observation(timestep, self._environment.current_player) - - # Run an episode. - while not timestep.last(): - # Generate an action from the agent's policy and step the environment. 
- if self._environment.is_turn_based: - action_list = [ - self._get_action(timestep, self._environment.current_player) - ] - else: - # FIXME: Support simultaneous move games. - raise ValueError('Currently only supports sequential games.') - - timestep = self._environment.step(action_list) - - # Have the agent observe the timestep and let the actor update itself. - self._send_observation(timestep, self._environment.current_player) - - # Book-keeping. - episode_steps += 1 - - # Equivalent to: episode_return += timestep.reward - # We capture the return value because if timestep.reward is a JAX - # DeviceArray, episode_return will not be mutated in-place. (In all other - # cases, the returned episode_return will be the same object as the - # argument episode_return.) - episode_return = tree.map_structure(operator.iadd, - episode_return, - timestep.reward) - - # Record counts. - counts = self._counter.increment(episodes=1, steps=episode_steps) - - # Collect the results and combine with counts. - steps_per_second = episode_steps / (time.time() - start_time) - result = { - 'episode_length': episode_steps, - 'episode_return': episode_return, - 'steps_per_second': steps_per_second, - } - result.update(counts) - return result - - def run(self, - num_episodes: Optional[int] = None, - num_steps: Optional[int] = None): - """Perform the run loop. + # Reset any counts and start the environment. + start_time = time.time() + episode_steps = 0 + + # For evaluation, this keeps track of the total undiscounted reward + # for each player accumulated during the episode. + multiplayer_reward_spec = specs.BoundedArray( + (self._environment.game.num_players(),), + np.float32, + minimum=self._environment.game.min_utility(), + maximum=self._environment.game.max_utility(), + ) + episode_return = tree.map_structure( + _generate_zeros_from_spec, multiplayer_reward_spec + ) + + timestep = self._environment.reset() + + # Make the first observation. + self._send_observation(timestep, self._environment.current_player) + + # Run an episode. + while not timestep.last(): + # Generate an action from the agent's policy and step the environment. + if self._environment.is_turn_based: + action_list = [ + self._get_action(timestep, self._environment.current_player) + ] + else: + # FIXME: Support simultaneous move games. + raise ValueError("Currently only supports sequential games.") + + timestep = self._environment.step(action_list) + + # Have the agent observe the timestep and let the actor update itself. + self._send_observation(timestep, self._environment.current_player) + + # Book-keeping. + episode_steps += 1 + + # Equivalent to: episode_return += timestep.reward + # We capture the return value because if timestep.reward is a JAX + # DeviceArray, episode_return will not be mutated in-place. (In all other + # cases, the returned episode_return will be the same object as the + # argument episode_return.) + episode_return = tree.map_structure( + operator.iadd, episode_return, timestep.reward + ) + + # Record counts. + counts = self._counter.increment(episodes=1, steps=episode_steps) + + # Collect the results and combine with counts. + steps_per_second = episode_steps / (time.time() - start_time) + result = { + "episode_length": episode_steps, + "episode_return": episode_return, + "steps_per_second": steps_per_second, + } + result.update(counts) + return result + + def run(self, num_episodes: Optional[int] = None, num_steps: Optional[int] = None): + """Perform the run loop. 
Run the environment loop either for `num_episodes` episodes or for at least `num_steps` steps (the last episode is always run until completion, @@ -207,21 +215,22 @@ def run(self, ValueError: If both 'num_episodes' and 'num_steps' are not None. """ - if not (num_episodes is None or num_steps is None): - raise ValueError('Either "num_episodes" or "num_steps" should be None.') + if not (num_episodes is None or num_steps is None): + raise ValueError('Either "num_episodes" or "num_steps" should be None.') - def should_terminate(episode_count: int, step_count: int) -> bool: - return ((num_episodes is not None and episode_count >= num_episodes) or - (num_steps is not None and step_count >= num_steps)) + def should_terminate(episode_count: int, step_count: int) -> bool: + return (num_episodes is not None and episode_count >= num_episodes) or ( + num_steps is not None and step_count >= num_steps + ) - episode_count, step_count = 0, 0 - while not should_terminate(episode_count, step_count): - result = self.run_episode() - episode_count += 1 - step_count += result['episode_length'] - # Log the given results. - self._logger.write(result) + episode_count, step_count = 0, 0 + while not should_terminate(episode_count, step_count): + result = self.run_episode() + episode_count += 1 + step_count += result["episode_length"] + # Log the given results. + self._logger.write(result) def _generate_zeros_from_spec(spec: specs.Array) -> np.ndarray: - return np.zeros(spec.shape, spec.dtype) + return np.zeros(spec.shape, spec.dtype) diff --git a/acme/environment_loops/open_spiel_environment_loop_test.py b/acme/environment_loops/open_spiel_environment_loop_test.py index e09d3d4957..b7dcf56245 100644 --- a/acme/environment_loops/open_spiel_environment_loop_test.py +++ b/acme/environment_loops/open_spiel_environment_loop_test.py @@ -16,86 +16,83 @@ import unittest -import acme -from acme import core -from acme import specs -from acme import types -from acme import wrappers import dm_env import numpy as np import tree +from absl.testing import absltest, parameterized -from absl.testing import absltest -from absl.testing import parameterized +import acme +from acme import core, specs, types, wrappers SKIP_OPEN_SPIEL_TESTS = False -SKIP_OPEN_SPIEL_MESSAGE = 'open_spiel not installed.' +SKIP_OPEN_SPIEL_MESSAGE = "open_spiel not installed." 
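+# The OpenSpiel-dependent imports and the RandomActor test fixture below are only
+# defined when open_spiel can be imported; otherwise the tests are skipped.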
try: - # pylint: disable=g-import-not-at-top - # pytype: disable=import-error - from acme.environment_loops import open_spiel_environment_loop - from acme.wrappers import open_spiel_wrapper - from open_spiel.python import rl_environment - # pytype: disable=import-error - - class RandomActor(core.Actor): - """Fake actor which generates random actions and validates specs.""" - - def __init__(self, spec: specs.EnvironmentSpec): - self._spec = spec - self.num_updates = 0 - - def select_action(self, observation: open_spiel_wrapper.OLT) -> int: - _validate_spec(self._spec.observations, observation) - legals = np.array(np.nonzero(observation.legal_actions), dtype=np.int32) - return np.random.choice(legals[0]) - - def observe_first(self, timestep: dm_env.TimeStep): - _validate_spec(self._spec.observations, timestep.observation) - - def observe(self, action: types.NestedArray, - next_timestep: dm_env.TimeStep): - _validate_spec(self._spec.actions, action) - _validate_spec(self._spec.rewards, next_timestep.reward) - _validate_spec(self._spec.discounts, next_timestep.discount) - _validate_spec(self._spec.observations, next_timestep.observation) - - def update(self, wait: bool = False): - self.num_updates += 1 + # pylint: disable=g-import-not-at-top + # pytype: disable=import-error + from open_spiel.python import rl_environment + + from acme.environment_loops import open_spiel_environment_loop + from acme.wrappers import open_spiel_wrapper + + # pytype: disable=import-error + + class RandomActor(core.Actor): + """Fake actor which generates random actions and validates specs.""" + + def __init__(self, spec: specs.EnvironmentSpec): + self._spec = spec + self.num_updates = 0 + + def select_action(self, observation: open_spiel_wrapper.OLT) -> int: + _validate_spec(self._spec.observations, observation) + legals = np.array(np.nonzero(observation.legal_actions), dtype=np.int32) + return np.random.choice(legals[0]) + + def observe_first(self, timestep: dm_env.TimeStep): + _validate_spec(self._spec.observations, timestep.observation) + + def observe(self, action: types.NestedArray, next_timestep: dm_env.TimeStep): + _validate_spec(self._spec.actions, action) + _validate_spec(self._spec.rewards, next_timestep.reward) + _validate_spec(self._spec.discounts, next_timestep.discount) + _validate_spec(self._spec.observations, next_timestep.observation) + + def update(self, wait: bool = False): + self.num_updates += 1 + except ModuleNotFoundError: - SKIP_OPEN_SPIEL_TESTS = True + SKIP_OPEN_SPIEL_TESTS = True def _validate_spec(spec: types.NestedSpec, value: types.NestedArray): - """Validate a value from a potentially nested spec.""" - tree.assert_same_structure(value, spec) - tree.map_structure(lambda s, v: s.validate(v), spec, value) + """Validate a value from a potentially nested spec.""" + tree.assert_same_structure(value, spec) + tree.map_structure(lambda s, v: s.validate(v), spec, value) @unittest.skipIf(SKIP_OPEN_SPIEL_TESTS, SKIP_OPEN_SPIEL_MESSAGE) class OpenSpielEnvironmentLoopTest(parameterized.TestCase): + def test_loop_run(self): + raw_env = rl_environment.Environment("tic_tac_toe") + env = open_spiel_wrapper.OpenSpielWrapper(raw_env) + env = wrappers.SinglePrecisionWrapper(env) + environment_spec = acme.make_environment_spec(env) - def test_loop_run(self): - raw_env = rl_environment.Environment('tic_tac_toe') - env = open_spiel_wrapper.OpenSpielWrapper(raw_env) - env = wrappers.SinglePrecisionWrapper(env) - environment_spec = acme.make_environment_spec(env) - - actors = [] - for _ in 
range(env.num_players): - actors.append(RandomActor(environment_spec)) + actors = [] + for _ in range(env.num_players): + actors.append(RandomActor(environment_spec)) - loop = open_spiel_environment_loop.OpenSpielEnvironmentLoop(env, actors) - result = loop.run_episode() - self.assertIn('episode_length', result) - self.assertIn('episode_return', result) - self.assertIn('steps_per_second', result) + loop = open_spiel_environment_loop.OpenSpielEnvironmentLoop(env, actors) + result = loop.run_episode() + self.assertIn("episode_length", result) + self.assertIn("episode_return", result) + self.assertIn("steps_per_second", result) - loop.run(num_episodes=10) - loop.run(num_steps=100) + loop.run(num_episodes=10) + loop.run(num_steps=100) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/jax/__init__.py b/acme/jax/__init__.py index 240cb71526..de867df849 100644 --- a/acme/jax/__init__.py +++ b/acme/jax/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/acme/jax/experiments/__init__.py b/acme/jax/experiments/__init__.py index 4fe42aec72..09ccf9a673 100644 --- a/acme/jax/experiments/__init__.py +++ b/acme/jax/experiments/__init__.py @@ -14,18 +14,22 @@ """JAX experiment utils.""" -from acme.jax.experiments.config import CheckpointingConfig -from acme.jax.experiments.config import default_evaluator_factory -from acme.jax.experiments.config import DeprecatedPolicyFactory -from acme.jax.experiments.config import EvaluatorFactory -from acme.jax.experiments.config import ExperimentConfig -from acme.jax.experiments.config import make_policy -from acme.jax.experiments.config import MakeActorFn -from acme.jax.experiments.config import NetworkFactory -from acme.jax.experiments.config import OfflineExperimentConfig -from acme.jax.experiments.config import PolicyFactory -from acme.jax.experiments.config import SnapshotModelFactory +from acme.jax.experiments.config import ( + CheckpointingConfig, + DeprecatedPolicyFactory, + EvaluatorFactory, + ExperimentConfig, + MakeActorFn, + NetworkFactory, + OfflineExperimentConfig, + PolicyFactory, + SnapshotModelFactory, + default_evaluator_factory, + make_policy, +) from acme.jax.experiments.make_distributed_experiment import make_distributed_experiment -from acme.jax.experiments.make_distributed_offline_experiment import make_distributed_offline_experiment +from acme.jax.experiments.make_distributed_offline_experiment import ( + make_distributed_offline_experiment, +) from acme.jax.experiments.run_experiment import run_experiment from acme.jax.experiments.run_offline_experiment import run_offline_experiment diff --git a/acme/jax/experiments/config.py b/acme/jax/experiments/config.py index 5836919ba7..47c1938913 100644 --- a/acme/jax/experiments/config.py +++ b/acme/jax/experiments/config.py @@ -18,71 +18,68 @@ import datetime from typing import Any, Callable, Dict, Generic, Iterator, Optional, Sequence -from acme import core -from acme import environment_loop -from acme import specs -from acme.agents.jax import builders -from acme.jax import types -from acme.jax import utils -from acme.utils import counting -from acme.utils import loggers -from acme.utils import observers as observers_lib -from acme.utils import experiment_utils import jax from typing_extensions import Protocol +from acme import core, environment_loop, specs +from 
acme.agents.jax import builders +from acme.jax import types, utils +from acme.utils import counting, experiment_utils, loggers +from acme.utils import observers as observers_lib + class MakeActorFn(Protocol, Generic[builders.Policy]): - - def __call__(self, random_key: types.PRNGKey, policy: builders.Policy, - environment_spec: specs.EnvironmentSpec, - variable_source: core.VariableSource) -> core.Actor: - ... + def __call__( + self, + random_key: types.PRNGKey, + policy: builders.Policy, + environment_spec: specs.EnvironmentSpec, + variable_source: core.VariableSource, + ) -> core.Actor: + ... class NetworkFactory(Protocol, Generic[builders.Networks]): + def __call__(self, environment_spec: specs.EnvironmentSpec) -> builders.Networks: + ... - def __call__(self, - environment_spec: specs.EnvironmentSpec) -> builders.Networks: - ... - - -class DeprecatedPolicyFactory(Protocol, Generic[builders.Networks, - builders.Policy]): - def __call__(self, networks: builders.Networks) -> builders.Policy: - ... +class DeprecatedPolicyFactory(Protocol, Generic[builders.Networks, builders.Policy]): + def __call__(self, networks: builders.Networks) -> builders.Policy: + ... class PolicyFactory(Protocol, Generic[builders.Networks, builders.Policy]): - - def __call__(self, networks: builders.Networks, - environment_spec: specs.EnvironmentSpec, - evaluation: bool) -> builders.Policy: - ... + def __call__( + self, + networks: builders.Networks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool, + ) -> builders.Policy: + ... class EvaluatorFactory(Protocol, Generic[builders.Policy]): - - def __call__(self, random_key: types.PRNGKey, - variable_source: core.VariableSource, counter: counting.Counter, - make_actor_fn: MakeActorFn[builders.Policy]) -> core.Worker: - ... + def __call__( + self, + random_key: types.PRNGKey, + variable_source: core.VariableSource, + counter: counting.Counter, + make_actor_fn: MakeActorFn[builders.Policy], + ) -> core.Worker: + ... class SnapshotModelFactory(Protocol, Generic[builders.Networks]): - - def __call__( - self, networks: builders.Networks, environment_spec: specs.EnvironmentSpec - ) -> Dict[str, Callable[[core.VariableSource], types.ModelToSnapshot]]: - ... - - + def __call__( + self, networks: builders.Networks, environment_spec: specs.EnvironmentSpec + ) -> Dict[str, Callable[[core.VariableSource], types.ModelToSnapshot]]: + ... @dataclasses.dataclass(frozen=True) class CheckpointingConfig: - """Configuration options for checkpointing. + """Configuration options for checkpointing. Attributes: max_to_keep: Maximum number of checkpoints to keep. Unless preserved by @@ -108,21 +105,21 @@ class CheckpointingConfig: checkpoint_ttl_seconds: TTL (time to leave) in seconds for checkpoints. Indefinite if set to None. 
""" - max_to_keep: int = 1 - directory: str = '~/acme' - add_uid: bool = True - time_delta_minutes: int = 5 - keep_checkpoint_every_n_hours: Optional[int] = None - replay_checkpointing_time_delta_minutes: Optional[int] = None - checkpoint_ttl_seconds: Optional[int] = int( - datetime.timedelta(days=5).total_seconds() - ) + + max_to_keep: int = 1 + directory: str = "~/acme" + add_uid: bool = True + time_delta_minutes: int = 5 + keep_checkpoint_every_n_hours: Optional[int] = None + replay_checkpointing_time_delta_minutes: Optional[int] = None + checkpoint_ttl_seconds: Optional[int] = int( + datetime.timedelta(days=5).total_seconds() + ) @dataclasses.dataclass(frozen=True) -class ExperimentConfig(Generic[builders.Networks, builders.Policy, - builders.Sample]): - """Config which defines aspects of constructing an experiment. +class ExperimentConfig(Generic[builders.Networks, builders.Policy, builders.Sample]): + """Config which defines aspects of constructing an experiment. Attributes: builder: Builds components of an RL agent (Learner, Actor...). @@ -147,66 +144,75 @@ class ExperimentConfig(Generic[builders.Networks, builders.Policy, checkpointing: Configuration options for checkpointing. If None, checkpointing and snapshotting is disabled. """ - # Below fields must be explicitly specified for any Agent. - builder: builders.ActorLearnerBuilder[builders.Networks, builders.Policy, - builders.Sample] - network_factory: NetworkFactory[builders.Networks] - environment_factory: types.EnvironmentFactory - max_num_actor_steps: int - seed: int - # policy_network_factory is deprecated. Use builder.make_policy to - # create the policy. - policy_network_factory: Optional[DeprecatedPolicyFactory[ - builders.Networks, builders.Policy]] = None - # Fields below are optional. If you just started with Acme do not worry about - # them. You might need them later when you want to customize your RL agent. - # TODO(stanczyk): Introduce a marker for the default value (instead of None). - evaluator_factories: Optional[Sequence[EvaluatorFactory[ - builders.Policy]]] = None - # eval_policy_network_factory is deprecated. Use builder.make_policy to - # create the policy. - eval_policy_network_factory: Optional[DeprecatedPolicyFactory[ - builders.Networks, builders.Policy]] = None - environment_spec: Optional[specs.EnvironmentSpec] = None - observers: Sequence[observers_lib.EnvLoopObserver] = () - logger_factory: loggers.LoggerFactory = dataclasses.field( - default_factory=experiment_utils.create_experiment_logger_factory) - checkpointing: Optional[CheckpointingConfig] = CheckpointingConfig() - - # TODO(stanczyk): Make get_evaluator_factories a standalone function. 
- def get_evaluator_factories(self): - """Constructs the evaluator factories.""" - if self.evaluator_factories is not None: - return self.evaluator_factories - - def eval_policy_factory(networks: builders.Networks, - environment_spec: specs.EnvironmentSpec, - evaluation: bool) -> builders.Policy: - del evaluation - # The config factory has precedence until all agents are migrated to use - # builder.make_policy - if self.eval_policy_network_factory is not None: - return self.eval_policy_network_factory(networks) - else: - return self.builder.make_policy( - networks=networks, - environment_spec=environment_spec, - evaluation=True) - - return [ - default_evaluator_factory( - environment_factory=self.environment_factory, - network_factory=self.network_factory, - policy_factory=eval_policy_factory, - logger_factory=self.logger_factory, - observers=self.observers) + + # Below fields must be explicitly specified for any Agent. + builder: builders.ActorLearnerBuilder[ + builders.Networks, builders.Policy, builders.Sample ] + network_factory: NetworkFactory[builders.Networks] + environment_factory: types.EnvironmentFactory + max_num_actor_steps: int + seed: int + # policy_network_factory is deprecated. Use builder.make_policy to + # create the policy. + policy_network_factory: Optional[ + DeprecatedPolicyFactory[builders.Networks, builders.Policy] + ] = None + # Fields below are optional. If you just started with Acme do not worry about + # them. You might need them later when you want to customize your RL agent. + # TODO(stanczyk): Introduce a marker for the default value (instead of None). + evaluator_factories: Optional[Sequence[EvaluatorFactory[builders.Policy]]] = None + # eval_policy_network_factory is deprecated. Use builder.make_policy to + # create the policy. + eval_policy_network_factory: Optional[ + DeprecatedPolicyFactory[builders.Networks, builders.Policy] + ] = None + environment_spec: Optional[specs.EnvironmentSpec] = None + observers: Sequence[observers_lib.EnvLoopObserver] = () + logger_factory: loggers.LoggerFactory = dataclasses.field( + default_factory=experiment_utils.create_experiment_logger_factory + ) + checkpointing: Optional[CheckpointingConfig] = CheckpointingConfig() + + # TODO(stanczyk): Make get_evaluator_factories a standalone function. + def get_evaluator_factories(self): + """Constructs the evaluator factories.""" + if self.evaluator_factories is not None: + return self.evaluator_factories + + def eval_policy_factory( + networks: builders.Networks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool, + ) -> builders.Policy: + del evaluation + # The config factory has precedence until all agents are migrated to use + # builder.make_policy + if self.eval_policy_network_factory is not None: + return self.eval_policy_network_factory(networks) + else: + return self.builder.make_policy( + networks=networks, + environment_spec=environment_spec, + evaluation=True, + ) + + return [ + default_evaluator_factory( + environment_factory=self.environment_factory, + network_factory=self.network_factory, + policy_factory=eval_policy_factory, + logger_factory=self.logger_factory, + observers=self.observers, + ) + ] @dataclasses.dataclass -class OfflineExperimentConfig(Generic[builders.Networks, builders.Policy, - builders.Sample]): - """Config which defines aspects of constructing an offline RL experiment. 
+class OfflineExperimentConfig( + Generic[builders.Networks, builders.Policy, builders.Sample] +): + """Config which defines aspects of constructing an offline RL experiment. This class is similar to the ExperimentConfig, but is tailored to offline RL setting, so it excludes attributes related to training via interaction with @@ -235,44 +241,48 @@ class OfflineExperimentConfig(Generic[builders.Networks, builders.Policy, checkpointing: Configuration options for checkpointing. If None, checkpointing and snapshotting is disabled. """ - # Below fields must be explicitly specified for any Agent. - builder: builders.OfflineBuilder[builders.Networks, builders.Policy, - builders.Sample] - network_factory: Callable[[specs.EnvironmentSpec], builders.Networks] - demonstration_dataset_factory: Callable[[types.PRNGKey], - Iterator[builders.Sample]] - environment_factory: types.EnvironmentFactory - max_num_learner_steps: int - seed: int - # Fields below are optional. If you just started with Acme do not worry about - # them. You might need them later when you want to customize your RL agent. - # TODO(stanczyk): Introduce a marker for the default value (instead of None). - evaluator_factories: Optional[Sequence[EvaluatorFactory]] = None - environment_spec: Optional[specs.EnvironmentSpec] = None - observers: Sequence[observers_lib.EnvLoopObserver] = () - logger_factory: loggers.LoggerFactory = dataclasses.field( - default_factory=experiment_utils.create_experiment_logger_factory) - checkpointing: Optional[CheckpointingConfig] = CheckpointingConfig() - - # TODO(stanczyk): Make get_evaluator_factories a standalone function. - def get_evaluator_factories(self): - """Constructs the evaluator factories.""" - if self.evaluator_factories is not None: - return self.evaluator_factories - if self.environment_factory is None: - raise ValueError( - 'You need to set `environment_factory` in `OfflineExperimentConfig` ' - 'when `evaluator_factories` are not specified. To disable evaluation ' - 'altogether just set `evaluator_factories = []`') - - return [ - default_evaluator_factory( - environment_factory=self.environment_factory, - network_factory=self.network_factory, - policy_factory=self.builder.make_policy, - logger_factory=self.logger_factory, - observers=self.observers) + + # Below fields must be explicitly specified for any Agent. + builder: builders.OfflineBuilder[ + builders.Networks, builders.Policy, builders.Sample ] + network_factory: Callable[[specs.EnvironmentSpec], builders.Networks] + demonstration_dataset_factory: Callable[[types.PRNGKey], Iterator[builders.Sample]] + environment_factory: types.EnvironmentFactory + max_num_learner_steps: int + seed: int + # Fields below are optional. If you just started with Acme do not worry about + # them. You might need them later when you want to customize your RL agent. + # TODO(stanczyk): Introduce a marker for the default value (instead of None). + evaluator_factories: Optional[Sequence[EvaluatorFactory]] = None + environment_spec: Optional[specs.EnvironmentSpec] = None + observers: Sequence[observers_lib.EnvLoopObserver] = () + logger_factory: loggers.LoggerFactory = dataclasses.field( + default_factory=experiment_utils.create_experiment_logger_factory + ) + checkpointing: Optional[CheckpointingConfig] = CheckpointingConfig() + + # TODO(stanczyk): Make get_evaluator_factories a standalone function. 
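For orientation, a minimal sketch of constructing the online ExperimentConfig documented above, showing only the fields it requires; `my_builder`, `my_network_factory` and `my_environment_factory` are hypothetical placeholders for an ActorLearnerBuilder and its network/environment factories:

from acme.jax import experiments

experiment = experiments.ExperimentConfig(
    builder=my_builder,                          # an ActorLearnerBuilder
    network_factory=my_network_factory,          # environment spec -> networks
    environment_factory=my_environment_factory,  # seed -> environment
    max_num_actor_steps=1_000_000,
    seed=0,
)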
+ def get_evaluator_factories(self): + """Constructs the evaluator factories.""" + if self.evaluator_factories is not None: + return self.evaluator_factories + if self.environment_factory is None: + raise ValueError( + "You need to set `environment_factory` in `OfflineExperimentConfig` " + "when `evaluator_factories` are not specified. To disable evaluation " + "altogether just set `evaluator_factories = []`" + ) + + return [ + default_evaluator_factory( + environment_factory=self.environment_factory, + network_factory=self.network_factory, + policy_factory=self.builder.make_policy, + logger_factory=self.logger_factory, + observers=self.observers, + ) + ] def default_evaluator_factory( @@ -282,48 +292,50 @@ def default_evaluator_factory( logger_factory: loggers.LoggerFactory, observers: Sequence[observers_lib.EnvLoopObserver] = (), ) -> EvaluatorFactory[builders.Policy]: - """Returns a default evaluator process.""" - - def evaluator( - random_key: types.PRNGKey, - variable_source: core.VariableSource, - counter: counting.Counter, - make_actor: MakeActorFn[builders.Policy], - ): - """The evaluation process.""" - - # Create environment and evaluator networks - environment_key, actor_key = jax.random.split(random_key) - # Environments normally require uint32 as a seed. - environment = environment_factory(utils.sample_uint32(environment_key)) - environment_spec = specs.make_environment_spec(environment) - networks = network_factory(environment_spec) - policy = policy_factory(networks, environment_spec, True) - actor = make_actor(actor_key, policy, environment_spec, variable_source) - - # Create logger and counter. - counter = counting.Counter(counter, 'evaluator') - logger = logger_factory('evaluator', 'actor_steps', 0) - - # Create the run loop and return it. - return environment_loop.EnvironmentLoop( - environment, actor, counter, logger, observers=observers) - - return evaluator - - -def make_policy(experiment: ExperimentConfig[builders.Networks, builders.Policy, - Any], networks: builders.Networks, - environment_spec: specs.EnvironmentSpec, - evaluation: bool) -> builders.Policy: - """Constructs a policy. It is only meant to be used internally.""" - # TODO(sabela): remove and update callers once all agents use - # builder.make_policy - if not evaluation and experiment.policy_network_factory: - return experiment.policy_network_factory(networks) - if evaluation and experiment.eval_policy_network_factory: - return experiment.eval_policy_network_factory(networks) - return experiment.builder.make_policy( - networks=networks, - environment_spec=environment_spec, - evaluation=evaluation) + """Returns a default evaluator process.""" + + def evaluator( + random_key: types.PRNGKey, + variable_source: core.VariableSource, + counter: counting.Counter, + make_actor: MakeActorFn[builders.Policy], + ): + """The evaluation process.""" + + # Create environment and evaluator networks + environment_key, actor_key = jax.random.split(random_key) + # Environments normally require uint32 as a seed. + environment = environment_factory(utils.sample_uint32(environment_key)) + environment_spec = specs.make_environment_spec(environment) + networks = network_factory(environment_spec) + policy = policy_factory(networks, environment_spec, True) + actor = make_actor(actor_key, policy, environment_spec, variable_source) + + # Create logger and counter. + counter = counting.Counter(counter, "evaluator") + logger = logger_factory("evaluator", "actor_steps", 0) + + # Create the run loop and return it. 
+ return environment_loop.EnvironmentLoop( + environment, actor, counter, logger, observers=observers + ) + + return evaluator + + +def make_policy( + experiment: ExperimentConfig[builders.Networks, builders.Policy, Any], + networks: builders.Networks, + environment_spec: specs.EnvironmentSpec, + evaluation: bool, +) -> builders.Policy: + """Constructs a policy. It is only meant to be used internally.""" + # TODO(sabela): remove and update callers once all agents use + # builder.make_policy + if not evaluation and experiment.policy_network_factory: + return experiment.policy_network_factory(networks) + if evaluation and experiment.eval_policy_network_factory: + return experiment.eval_policy_network_factory(networks) + return experiment.builder.make_policy( + networks=networks, environment_spec=environment_spec, evaluation=evaluation + ) diff --git a/acme/jax/experiments/make_distributed_experiment.py b/acme/jax/experiments/make_distributed_experiment.py index d90022fb05..4e278534cd 100644 --- a/acme/jax/experiments/make_distributed_experiment.py +++ b/acme/jax/experiments/make_distributed_experiment.py @@ -18,29 +18,20 @@ import math from typing import Any, List, Optional -from acme import core -from acme import environment_loop -from acme import specs -from acme.agents.jax import actor_core -from acme.agents.jax import builders -from acme.jax import inference_server as inference_server_lib -from acme.jax import networks as networks_lib -from acme.jax import savers -from acme.jax import utils -from acme.jax import variable_utils -from acme.jax.experiments import config -from acme.jax import snapshotter -from acme.utils import counting -from acme.utils import lp_utils import jax import launchpad as lp import reverb -ActorId = int -InferenceServer = inference_server_lib.InferenceServer[ - actor_core.SelectActionFn] - +from acme import core, environment_loop, specs +from acme.agents.jax import actor_core, builders +from acme.jax import inference_server as inference_server_lib +from acme.jax import networks as networks_lib +from acme.jax import savers, snapshotter, utils, variable_utils +from acme.jax.experiments import config +from acme.utils import counting, lp_utils +ActorId = int +InferenceServer = inference_server_lib.InferenceServer[actor_core.SelectActionFn] def make_distributed_experiment( @@ -58,10 +49,10 @@ def make_distributed_experiment( make_snapshot_models: Optional[ config.SnapshotModelFactory[builders.Networks] ] = None, - name: str = 'agent', + name: str = "agent", program: Optional[lp.Program] = None, ) -> lp.Program: - """Builds a Launchpad program for running the experiment. + """Builds a Launchpad program for running the experiment. Args: experiment: configuration of the experiment. @@ -98,307 +89,332 @@ def make_distributed_experiment( The Launchpad program with all the nodes needed for running the experiment. """ - if multithreading_colocate_learner_and_reverb and num_learner_nodes > 1: - raise ValueError( - 'Replay and learner colocation is not yet supported when the learner is' - ' spread across multiple nodes (num_learner_nodes > 1). Please contact' - ' Acme devs if this is a feature you want. 
Got:' - '\tmultithreading_colocate_learner_and_reverb=' - f'{multithreading_colocate_learner_and_reverb}' - f'\tnum_learner_nodes={num_learner_nodes}.') - - - def build_replay(): - """The replay storage.""" - dummy_seed = 1 - spec = ( - experiment.environment_spec or - specs.make_environment_spec(experiment.environment_factory(dummy_seed))) - network = experiment.network_factory(spec) - policy = config.make_policy( - experiment=experiment, - networks=network, - environment_spec=spec, - evaluation=False) - return experiment.builder.make_replay_tables(spec, policy) - - def build_model_saver(variable_source: core.VariableSource): - assert experiment.checkpointing - environment = experiment.environment_factory(0) - spec = specs.make_environment_spec(environment) - networks = experiment.network_factory(spec) - models = make_snapshot_models(networks, spec) - # TODO(raveman): Decouple checkpointing and snapshotting configs. - return snapshotter.JAXSnapshotter( - variable_source=variable_source, - models=models, - path=experiment.checkpointing.directory, - subdirectory='snapshots', - add_uid=experiment.checkpointing.add_uid) - - def build_counter(): - counter = counting.Counter() - if experiment.checkpointing: - checkpointing = experiment.checkpointing - counter = savers.CheckpointingRunner( - counter, - key='counter', - subdirectory='counter', - time_delta_minutes=checkpointing.time_delta_minutes, - directory=checkpointing.directory, - add_uid=checkpointing.add_uid, - max_to_keep=checkpointing.max_to_keep, - keep_checkpoint_every_n_hours=checkpointing.keep_checkpoint_every_n_hours, - checkpoint_ttl_seconds=checkpointing.checkpoint_ttl_seconds, - ) - return counter - - def build_learner( - random_key: networks_lib.PRNGKey, - replay: reverb.Client, - counter: Optional[counting.Counter] = None, - primary_learner: Optional[core.Learner] = None, - ): - """The Learning part of the agent.""" - - dummy_seed = 1 - spec = ( - experiment.environment_spec or - specs.make_environment_spec(experiment.environment_factory(dummy_seed))) - - # Creates the networks to optimize (online) and target networks. - networks = experiment.network_factory(spec) - - iterator = experiment.builder.make_dataset_iterator(replay) - # make_dataset_iterator is responsible for putting data onto appropriate - # training devices, so here we apply prefetch, so that data is copied over - # in the background. - iterator = utils.prefetch(iterable=iterator, buffer_size=1) - counter = counting.Counter(counter, 'learner') - learner = experiment.builder.make_learner(random_key, networks, iterator, - experiment.logger_factory, spec, - replay, counter) - - if experiment.checkpointing: - if primary_learner is None: - checkpointing = experiment.checkpointing - learner = savers.CheckpointingRunner( - learner, - key='learner', - subdirectory='learner', - time_delta_minutes=5, - directory=checkpointing.directory, - add_uid=checkpointing.add_uid, - max_to_keep=checkpointing.max_to_keep, - keep_checkpoint_every_n_hours=checkpointing.keep_checkpoint_every_n_hours, - checkpoint_ttl_seconds=checkpointing.checkpoint_ttl_seconds, + if multithreading_colocate_learner_and_reverb and num_learner_nodes > 1: + raise ValueError( + "Replay and learner colocation is not yet supported when the learner is" + " spread across multiple nodes (num_learner_nodes > 1). Please contact" + " Acme devs if this is a feature you want. 
Got:" + "\tmultithreading_colocate_learner_and_reverb=" + f"{multithreading_colocate_learner_and_reverb}" + f"\tnum_learner_nodes={num_learner_nodes}." + ) + + def build_replay(): + """The replay storage.""" + dummy_seed = 1 + spec = experiment.environment_spec or specs.make_environment_spec( + experiment.environment_factory(dummy_seed) ) - else: - learner.restore(primary_learner.save()) - # NOTE: This initially synchronizes secondary learner states with the - # primary one. Further synchronization should be handled by the learner - # properly doing a pmap/pmean on the loss/gradients, respectively. - - return learner - - def build_inference_server( - inference_server_config: inference_server_lib.InferenceServerConfig, - variable_source: core.VariableSource, - ) -> InferenceServer: - """Builds an inference server for `ActorCore` policies.""" - dummy_seed = 1 - spec = ( - experiment.environment_spec or - specs.make_environment_spec(experiment.environment_factory(dummy_seed))) - networks = experiment.network_factory(spec) - policy = config.make_policy( - experiment=experiment, - networks=networks, - environment_spec=spec, - evaluation=False, + network = experiment.network_factory(spec) + policy = config.make_policy( + experiment=experiment, + networks=network, + environment_spec=spec, + evaluation=False, + ) + return experiment.builder.make_replay_tables(spec, policy) + + def build_model_saver(variable_source: core.VariableSource): + assert experiment.checkpointing + environment = experiment.environment_factory(0) + spec = specs.make_environment_spec(environment) + networks = experiment.network_factory(spec) + models = make_snapshot_models(networks, spec) + # TODO(raveman): Decouple checkpointing and snapshotting configs. + return snapshotter.JAXSnapshotter( + variable_source=variable_source, + models=models, + path=experiment.checkpointing.directory, + subdirectory="snapshots", + add_uid=experiment.checkpointing.add_uid, + ) + + def build_counter(): + counter = counting.Counter() + if experiment.checkpointing: + checkpointing = experiment.checkpointing + counter = savers.CheckpointingRunner( + counter, + key="counter", + subdirectory="counter", + time_delta_minutes=checkpointing.time_delta_minutes, + directory=checkpointing.directory, + add_uid=checkpointing.add_uid, + max_to_keep=checkpointing.max_to_keep, + keep_checkpoint_every_n_hours=checkpointing.keep_checkpoint_every_n_hours, + checkpoint_ttl_seconds=checkpointing.checkpoint_ttl_seconds, + ) + return counter + + def build_learner( + random_key: networks_lib.PRNGKey, + replay: reverb.Client, + counter: Optional[counting.Counter] = None, + primary_learner: Optional[core.Learner] = None, + ): + """The Learning part of the agent.""" + + dummy_seed = 1 + spec = experiment.environment_spec or specs.make_environment_spec( + experiment.environment_factory(dummy_seed) + ) + + # Creates the networks to optimize (online) and target networks. + networks = experiment.network_factory(spec) + + iterator = experiment.builder.make_dataset_iterator(replay) + # make_dataset_iterator is responsible for putting data onto appropriate + # training devices, so here we apply prefetch, so that data is copied over + # in the background. 
+ iterator = utils.prefetch(iterable=iterator, buffer_size=1) + counter = counting.Counter(counter, "learner") + learner = experiment.builder.make_learner( + random_key, + networks, + iterator, + experiment.logger_factory, + spec, + replay, + counter, + ) + + if experiment.checkpointing: + if primary_learner is None: + checkpointing = experiment.checkpointing + learner = savers.CheckpointingRunner( + learner, + key="learner", + subdirectory="learner", + time_delta_minutes=5, + directory=checkpointing.directory, + add_uid=checkpointing.add_uid, + max_to_keep=checkpointing.max_to_keep, + keep_checkpoint_every_n_hours=checkpointing.keep_checkpoint_every_n_hours, + checkpoint_ttl_seconds=checkpointing.checkpoint_ttl_seconds, + ) + else: + learner.restore(primary_learner.save()) + # NOTE: This initially synchronizes secondary learner states with the + # primary one. Further synchronization should be handled by the learner + # properly doing a pmap/pmean on the loss/gradients, respectively. + + return learner + + def build_inference_server( + inference_server_config: inference_server_lib.InferenceServerConfig, + variable_source: core.VariableSource, + ) -> InferenceServer: + """Builds an inference server for `ActorCore` policies.""" + dummy_seed = 1 + spec = experiment.environment_spec or specs.make_environment_spec( + experiment.environment_factory(dummy_seed) + ) + networks = experiment.network_factory(spec) + policy = config.make_policy( + experiment=experiment, + networks=networks, + environment_spec=spec, + evaluation=False, + ) + if not isinstance(policy, actor_core.ActorCore): + raise TypeError( + f"Using InferenceServer with policy of unsupported type:" + f"{type(policy)}. InferenceServer only supports `ActorCore` policies." + ) + + return InferenceServer( + handler=jax.jit( + jax.vmap( + policy.select_action, + in_axes=(None, 0, 0), + # Note on in_axes: Params will not be batched. Only the + # observations and actor state will be stacked along a new + # leading axis by the inference server. + ), + ), + variable_source=variable_source, + devices=jax.local_devices(), + config=inference_server_config, + ) + + def build_actor( + random_key: networks_lib.PRNGKey, + replay: reverb.Client, + variable_source: core.VariableSource, + counter: counting.Counter, + actor_id: ActorId, + inference_server: Optional[InferenceServer], + ) -> environment_loop.EnvironmentLoop: + """The actor process.""" + environment_key, actor_key = jax.random.split(random_key) + # Create environment and policy core. + + # Environments normally require uint32 as a seed. + environment = experiment.environment_factory( + utils.sample_uint32(environment_key) + ) + environment_spec = specs.make_environment_spec(environment) + + networks = experiment.network_factory(environment_spec) + policy_network = config.make_policy( + experiment=experiment, + networks=networks, + environment_spec=environment_spec, + evaluation=False, + ) + if inference_server is not None: + policy_network = actor_core.ActorCore( + init=policy_network.init, + select_action=inference_server.handler, + get_extras=policy_network.get_extras, + ) + variable_source = variable_utils.ReferenceVariableSource() + + adder = experiment.builder.make_adder(replay, environment_spec, policy_network) + actor = experiment.builder.make_actor( + actor_key, policy_network, environment_spec, variable_source, adder + ) + + # Create logger and counter. 
+ counter = counting.Counter(counter, "actor") + logger = experiment.logger_factory("actor", counter.get_steps_key(), actor_id) + # Create the loop to connect environment and agent. + return environment_loop.EnvironmentLoop( + environment, actor, counter, logger, observers=experiment.observers + ) + + if not program: + program = lp.Program(name=name) + + key = jax.random.PRNGKey(experiment.seed) + + checkpoint_time_delta_minutes: Optional[int] = ( + experiment.checkpointing.replay_checkpointing_time_delta_minutes + if experiment.checkpointing + else None ) - if not isinstance(policy, actor_core.ActorCore): - raise TypeError( - f'Using InferenceServer with policy of unsupported type:' - f'{type(policy)}. InferenceServer only supports `ActorCore` policies.' - ) - - return InferenceServer( - handler=jax.jit( - jax.vmap( - policy.select_action, - in_axes=(None, 0, 0), - # Note on in_axes: Params will not be batched. Only the - # observations and actor state will be stacked along a new - # leading axis by the inference server. - ),), - variable_source=variable_source, - devices=jax.local_devices(), - config=inference_server_config, + replay_node = lp.ReverbNode( + build_replay, checkpoint_time_delta_minutes=checkpoint_time_delta_minutes ) + replay = replay_node.create_handle() - def build_actor( - random_key: networks_lib.PRNGKey, - replay: reverb.Client, - variable_source: core.VariableSource, - counter: counting.Counter, - actor_id: ActorId, - inference_server: Optional[InferenceServer], - ) -> environment_loop.EnvironmentLoop: - """The actor process.""" - environment_key, actor_key = jax.random.split(random_key) - # Create environment and policy core. - - # Environments normally require uint32 as a seed. - environment = experiment.environment_factory( - utils.sample_uint32(environment_key)) - environment_spec = specs.make_environment_spec(environment) - - networks = experiment.network_factory(environment_spec) - policy_network = config.make_policy( - experiment=experiment, - networks=networks, - environment_spec=environment_spec, - evaluation=False) - if inference_server is not None: - policy_network = actor_core.ActorCore( - init=policy_network.init, - select_action=inference_server.handler, - get_extras=policy_network.get_extras, - ) - variable_source = variable_utils.ReferenceVariableSource() - - adder = experiment.builder.make_adder(replay, environment_spec, - policy_network) - actor = experiment.builder.make_actor(actor_key, policy_network, - environment_spec, variable_source, - adder) - - # Create logger and counter. - counter = counting.Counter(counter, 'actor') - logger = experiment.logger_factory('actor', counter.get_steps_key(), - actor_id) - # Create the loop to connect environment and agent. 
- return environment_loop.EnvironmentLoop( - environment, actor, counter, logger, observers=experiment.observers) - - if not program: - program = lp.Program(name=name) - - key = jax.random.PRNGKey(experiment.seed) - - checkpoint_time_delta_minutes: Optional[int] = ( - experiment.checkpointing.replay_checkpointing_time_delta_minutes - if experiment.checkpointing else None) - replay_node = lp.ReverbNode( - build_replay, checkpoint_time_delta_minutes=checkpoint_time_delta_minutes) - replay = replay_node.create_handle() - - counter = program.add_node(lp.CourierNode(build_counter), label='counter') - - if experiment.max_num_actor_steps is not None: - program.add_node( - lp.CourierNode(lp_utils.StepsLimiter, counter, - experiment.max_num_actor_steps), - label='counter') - - learner_key, key = jax.random.split(key) - learner_node = lp.CourierNode(build_learner, learner_key, replay, counter) - learner = learner_node.create_handle() - variable_sources = [learner] - - if multithreading_colocate_learner_and_reverb: - program.add_node( - lp.MultiThreadingColocation([learner_node, replay_node]), - label='learner') - else: - program.add_node(replay_node, label='replay') - - with program.group('learner'): - program.add_node(learner_node) - - # Maybe create secondary learners, necessary when using multi-host - # accelerators. - # Warning! If you set num_learner_nodes > 1, make sure the learner class - # does the appropriate pmap/pmean operations on the loss/gradients, - # respectively. - for _ in range(1, num_learner_nodes): - learner_key, key = jax.random.split(key) - variable_sources.append( - program.add_node( - lp.CourierNode( - build_learner, learner_key, replay, - primary_learner=learner))) - # NOTE: Secondary learners are used to load-balance get_variables calls, - # which is why they get added to the list of available variable sources. - # NOTE: Only the primary learner checkpoints. - # NOTE: Do not pass the counter to the secondary learners to avoid - # double counting of learner steps. - - if inference_server_config is not None: - num_actors_per_server = math.ceil(num_actors / num_inference_servers) - with program.group('inference_server'): - inference_nodes = [] - for _ in range(num_inference_servers): - inference_nodes.append( - program.add_node( - lp.CourierNode( - build_inference_server, - inference_server_config, - learner, - courier_kwargs={'thread_pool_size': num_actors_per_server - }))) - else: - num_inference_servers = 1 - inference_nodes = [None] + counter = program.add_node(lp.CourierNode(build_counter), label="counter") - num_actor_nodes, remainder = divmod(num_actors, num_actors_per_node) - num_actor_nodes += int(remainder > 0) + if experiment.max_num_actor_steps is not None: + program.add_node( + lp.CourierNode( + lp_utils.StepsLimiter, counter, experiment.max_num_actor_steps + ), + label="counter", + ) + learner_key, key = jax.random.split(key) + learner_node = lp.CourierNode(build_learner, learner_key, replay, counter) + learner = learner_node.create_handle() + variable_sources = [learner] - with program.group('actor'): - # Create all actor threads. - *actor_keys, key = jax.random.split(key, num_actors + 1) + if multithreading_colocate_learner_and_reverb: + program.add_node( + lp.MultiThreadingColocation([learner_node, replay_node]), label="learner" + ) + else: + program.add_node(replay_node, label="replay") + + with program.group("learner"): + program.add_node(learner_node) + + # Maybe create secondary learners, necessary when using multi-host + # accelerators. + # Warning! 
If you set num_learner_nodes > 1, make sure the learner class + # does the appropriate pmap/pmean operations on the loss/gradients, + # respectively. + for _ in range(1, num_learner_nodes): + learner_key, key = jax.random.split(key) + variable_sources.append( + program.add_node( + lp.CourierNode( + build_learner, learner_key, replay, primary_learner=learner + ) + ) + ) + # NOTE: Secondary learners are used to load-balance get_variables calls, + # which is why they get added to the list of available variable sources. + # NOTE: Only the primary learner checkpoints. + # NOTE: Do not pass the counter to the secondary learners to avoid + # double counting of learner steps. + + if inference_server_config is not None: + num_actors_per_server = math.ceil(num_actors / num_inference_servers) + with program.group("inference_server"): + inference_nodes = [] + for _ in range(num_inference_servers): + inference_nodes.append( + program.add_node( + lp.CourierNode( + build_inference_server, + inference_server_config, + learner, + courier_kwargs={"thread_pool_size": num_actors_per_server}, + ) + ) + ) + else: + num_inference_servers = 1 + inference_nodes = [None] + + num_actor_nodes, remainder = divmod(num_actors, num_actors_per_node) + num_actor_nodes += int(remainder > 0) + + with program.group("actor"): + # Create all actor threads. + *actor_keys, key = jax.random.split(key, num_actors + 1) + + # Create (maybe colocated) actor nodes. + for node_id, variable_source, inference_node in zip( + range(num_actor_nodes), + itertools.cycle(variable_sources), + itertools.cycle(inference_nodes), + ): + colocation_nodes = [] + + first_actor_id = node_id * num_actors_per_node + for actor_id in range( + first_actor_id, min(first_actor_id + num_actors_per_node, num_actors) + ): + actor = lp.CourierNode( + build_actor, + actor_keys[actor_id], + replay, + variable_source, + counter, + actor_id, + inference_node, + ) + colocation_nodes.append(actor) + + if len(colocation_nodes) == 1: + program.add_node(colocation_nodes[0]) + elif multiprocessing_colocate_actors: + program.add_node(lp.MultiProcessingColocation(colocation_nodes)) + else: + program.add_node(lp.MultiThreadingColocation(colocation_nodes)) + + for evaluator in experiment.get_evaluator_factories(): + evaluator_key, key = jax.random.split(key) + program.add_node( + lp.CourierNode( + evaluator, + evaluator_key, + learner, + counter, + experiment.builder.make_actor, + ), + label="evaluator", + ) - # Create (maybe colocated) actor nodes. 
- for node_id, variable_source, inference_node in zip( - range(num_actor_nodes), - itertools.cycle(variable_sources), - itertools.cycle(inference_nodes), - ): - colocation_nodes = [] - - first_actor_id = node_id * num_actors_per_node - for actor_id in range( - first_actor_id, min(first_actor_id + num_actors_per_node, num_actors) - ): - actor = lp.CourierNode( - build_actor, - actor_keys[actor_id], - replay, - variable_source, - counter, - actor_id, - inference_node, + if make_snapshot_models and experiment.checkpointing: + program.add_node( + lp.CourierNode(build_model_saver, learner), label="model_saver" ) - colocation_nodes.append(actor) - - if len(colocation_nodes) == 1: - program.add_node(colocation_nodes[0]) - elif multiprocessing_colocate_actors: - program.add_node(lp.MultiProcessingColocation(colocation_nodes)) - else: - program.add_node(lp.MultiThreadingColocation(colocation_nodes)) - - for evaluator in experiment.get_evaluator_factories(): - evaluator_key, key = jax.random.split(key) - program.add_node( - lp.CourierNode(evaluator, evaluator_key, learner, counter, - experiment.builder.make_actor), - label='evaluator') - - if make_snapshot_models and experiment.checkpointing: - program.add_node( - lp.CourierNode(build_model_saver, learner), label='model_saver') - - return program + + return program diff --git a/acme/jax/experiments/make_distributed_offline_experiment.py b/acme/jax/experiments/make_distributed_offline_experiment.py index 40599eb2d9..88b1208cd2 100644 --- a/acme/jax/experiments/make_distributed_offline_experiment.py +++ b/acme/jax/experiments/make_distributed_offline_experiment.py @@ -16,28 +16,27 @@ from typing import Any, Optional -from acme import core -from acme import specs +import jax +import launchpad as lp + +from acme import core, specs from acme.agents.jax import builders from acme.jax import networks as networks_lib -from acme.jax import savers -from acme.jax import utils +from acme.jax import savers, snapshotter, utils from acme.jax.experiments import config -from acme.jax import snapshotter -from acme.utils import counting -from acme.utils import lp_utils -import jax -import launchpad as lp +from acme.utils import counting, lp_utils def make_distributed_offline_experiment( experiment: config.OfflineExperimentConfig[builders.Networks, Any, Any], *, - make_snapshot_models: Optional[config.SnapshotModelFactory[ - builders.Networks]] = None, - name: str = 'agent', - program: Optional[lp.Program] = None) -> lp.Program: - """Builds a Launchpad program for running the experiment. + make_snapshot_models: Optional[ + config.SnapshotModelFactory[builders.Networks] + ] = None, + name: str = "agent", + program: Optional[lp.Program] = None +) -> lp.Program: + """Builds a Launchpad program for running the experiment. Args: experiment: configuration for the experiment. @@ -51,107 +50,117 @@ def make_distributed_offline_experiment( The Launchpad program with all the nodes needed for running the experiment. """ - def build_model_saver(variable_source: core.VariableSource): - assert experiment.checkpointing - environment = experiment.environment_factory(0) - spec = specs.make_environment_spec(environment) - networks = experiment.network_factory(spec) - models = make_snapshot_models(networks, spec) - # TODO(raveman): Decouple checkpointing and snahpshotting configs. 
- return snapshotter.JAXSnapshotter( - variable_source=variable_source, - models=models, - path=experiment.checkpointing.directory, - add_uid=experiment.checkpointing.add_uid) - - def build_counter(): - counter = counting.Counter() - if experiment.checkpointing: - counter = savers.CheckpointingRunner( - counter, - key='counter', - subdirectory='counter', - time_delta_minutes=experiment.checkpointing.time_delta_minutes, - directory=experiment.checkpointing.directory, - add_uid=experiment.checkpointing.add_uid, - max_to_keep=experiment.checkpointing.max_to_keep, - checkpoint_ttl_seconds=experiment.checkpointing.checkpoint_ttl_seconds, - ) - return counter - - def build_learner( - random_key: networks_lib.PRNGKey, - counter: Optional[counting.Counter] = None, - ): - """The Learning part of the agent.""" - - dummy_seed = 1 - spec = ( - experiment.environment_spec or - specs.make_environment_spec(experiment.environment_factory(dummy_seed))) - - # Creates the networks to optimize (online) and target networks. - networks = experiment.network_factory(spec) - - dataset_key, random_key = jax.random.split(random_key) - iterator = experiment.demonstration_dataset_factory(dataset_key) - # make_demonstrations is responsible for putting data onto appropriate - # training devices, so here we apply prefetch, so that data is copied over - # in the background. - iterator = utils.prefetch(iterable=iterator, buffer_size=1) - counter = counting.Counter(counter, 'learner') - learner = experiment.builder.make_learner( - random_key=random_key, - networks=networks, - dataset=iterator, - logger_fn=experiment.logger_factory, - environment_spec=spec, - counter=counter) - - if experiment.checkpointing: - learner = savers.CheckpointingRunner( - learner, - key='learner', - subdirectory='learner', - time_delta_minutes=5, - directory=experiment.checkpointing.directory, - add_uid=experiment.checkpointing.add_uid, - max_to_keep=experiment.checkpointing.max_to_keep, - checkpoint_ttl_seconds=experiment.checkpointing.checkpoint_ttl_seconds, - ) - - return learner - - if not program: - program = lp.Program(name=name) - - key = jax.random.PRNGKey(experiment.seed) - - counter = program.add_node(lp.CourierNode(build_counter), label='counter') - - if experiment.max_num_learner_steps is not None: - program.add_node( - lp.CourierNode( - lp_utils.StepsLimiter, - counter, - experiment.max_num_learner_steps, - steps_key='learner_steps'), - label='counter') - - learner_key, key = jax.random.split(key) - learner_node = lp.CourierNode(build_learner, learner_key, counter) - learner = learner_node.create_handle() - program.add_node(learner_node, label='learner') - - for evaluator in experiment.get_evaluator_factories(): - evaluator_key, key = jax.random.split(key) - program.add_node( - lp.CourierNode(evaluator, evaluator_key, learner, counter, - experiment.builder.make_actor), - label='evaluator') - - if make_snapshot_models and experiment.checkpointing: - program.add_node(lp.CourierNode(build_model_saver, learner), - label='model_saver') - - return program + def build_model_saver(variable_source: core.VariableSource): + assert experiment.checkpointing + environment = experiment.environment_factory(0) + spec = specs.make_environment_spec(environment) + networks = experiment.network_factory(spec) + models = make_snapshot_models(networks, spec) + # TODO(raveman): Decouple checkpointing and snahpshotting configs. 
+ return snapshotter.JAXSnapshotter( + variable_source=variable_source, + models=models, + path=experiment.checkpointing.directory, + add_uid=experiment.checkpointing.add_uid, + ) + + def build_counter(): + counter = counting.Counter() + if experiment.checkpointing: + counter = savers.CheckpointingRunner( + counter, + key="counter", + subdirectory="counter", + time_delta_minutes=experiment.checkpointing.time_delta_minutes, + directory=experiment.checkpointing.directory, + add_uid=experiment.checkpointing.add_uid, + max_to_keep=experiment.checkpointing.max_to_keep, + checkpoint_ttl_seconds=experiment.checkpointing.checkpoint_ttl_seconds, + ) + return counter + + def build_learner( + random_key: networks_lib.PRNGKey, counter: Optional[counting.Counter] = None, + ): + """The Learning part of the agent.""" + + dummy_seed = 1 + spec = experiment.environment_spec or specs.make_environment_spec( + experiment.environment_factory(dummy_seed) + ) + + # Creates the networks to optimize (online) and target networks. + networks = experiment.network_factory(spec) + + dataset_key, random_key = jax.random.split(random_key) + iterator = experiment.demonstration_dataset_factory(dataset_key) + # make_demonstrations is responsible for putting data onto appropriate + # training devices, so here we apply prefetch, so that data is copied over + # in the background. + iterator = utils.prefetch(iterable=iterator, buffer_size=1) + counter = counting.Counter(counter, "learner") + learner = experiment.builder.make_learner( + random_key=random_key, + networks=networks, + dataset=iterator, + logger_fn=experiment.logger_factory, + environment_spec=spec, + counter=counter, + ) + + if experiment.checkpointing: + learner = savers.CheckpointingRunner( + learner, + key="learner", + subdirectory="learner", + time_delta_minutes=5, + directory=experiment.checkpointing.directory, + add_uid=experiment.checkpointing.add_uid, + max_to_keep=experiment.checkpointing.max_to_keep, + checkpoint_ttl_seconds=experiment.checkpointing.checkpoint_ttl_seconds, + ) + + return learner + + if not program: + program = lp.Program(name=name) + + key = jax.random.PRNGKey(experiment.seed) + + counter = program.add_node(lp.CourierNode(build_counter), label="counter") + + if experiment.max_num_learner_steps is not None: + program.add_node( + lp.CourierNode( + lp_utils.StepsLimiter, + counter, + experiment.max_num_learner_steps, + steps_key="learner_steps", + ), + label="counter", + ) + + learner_key, key = jax.random.split(key) + learner_node = lp.CourierNode(build_learner, learner_key, counter) + learner = learner_node.create_handle() + program.add_node(learner_node, label="learner") + + for evaluator in experiment.get_evaluator_factories(): + evaluator_key, key = jax.random.split(key) + program.add_node( + lp.CourierNode( + evaluator, + evaluator_key, + learner, + counter, + experiment.builder.make_actor, + ), + label="evaluator", + ) + + if make_snapshot_models and experiment.checkpointing: + program.add_node( + lp.CourierNode(build_model_saver, learner), label="model_saver" + ) + + return program diff --git a/acme/jax/experiments/run_experiment.py b/acme/jax/experiments/run_experiment.py index 3e5a5968a7..3af50460a6 100644 --- a/acme/jax/experiments/run_experiment.py +++ b/acme/jax/experiments/run_experiment.py @@ -18,23 +18,24 @@ import time from typing import Optional, Sequence, Tuple +import dm_env +import jax +import reverb + import acme -from acme import core -from acme import specs -from acme import types +from acme import core, specs, types 
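Assuming an `experiment` config like the sketch above, the single-process entry point whose diff follows can be invoked directly; `eval_every` and `num_eval_episodes` default to 100 and 1:

from acme.jax.experiments import run_experiment

run_experiment(experiment=experiment, eval_every=1_000, num_eval_episodes=5)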
from acme.jax import utils from acme.jax.experiments import config from acme.tf import savers from acme.utils import counting -import dm_env -import jax -import reverb -def run_experiment(experiment: config.ExperimentConfig, - eval_every: int = 100, - num_eval_episodes: int = 1): - """Runs a simple, single-threaded training loop using the default evaluators. +def run_experiment( + experiment: config.ExperimentConfig, + eval_every: int = 100, + num_eval_episodes: int = 1, +): + """Runs a simple, single-threaded training loop using the default evaluators. It targets simplicity of the code and so only the basic features of the ExperimentConfig are supported. @@ -46,142 +47,147 @@ def run_experiment(experiment: config.ExperimentConfig, evaluation step. """ - key = jax.random.PRNGKey(experiment.seed) - - # Create the environment and get its spec. - environment = experiment.environment_factory(experiment.seed) - environment_spec = experiment.environment_spec or specs.make_environment_spec( - environment) - - # Create the networks and policy. - networks = experiment.network_factory(environment_spec) - policy = config.make_policy( - experiment=experiment, - networks=networks, - environment_spec=environment_spec, - evaluation=False) - - # Create the replay server and grab its address. - replay_tables = experiment.builder.make_replay_tables(environment_spec, - policy) - - # Disable blocking of inserts by tables' rate limiters, as this function - # executes learning (sampling from the table) and data generation - # (inserting into the table) sequentially from the same thread - # which could result in blocked insert making the algorithm hang. - replay_tables, rate_limiters_max_diff = _disable_insert_blocking( - replay_tables) - - replay_server = reverb.Server(replay_tables, port=None) - replay_client = reverb.Client(f'localhost:{replay_server.port}') - - # Parent counter allows to share step counts between train and eval loops and - # the learner, so that it is possible to plot for example evaluator's return - # value as a function of the number of training episodes. - parent_counter = counting.Counter(time_delta=0.) - - dataset = experiment.builder.make_dataset_iterator(replay_client) - # We always use prefetch as it provides an iterator with an additional - # 'ready' method. - dataset = utils.prefetch(dataset, buffer_size=1) - - # Create actor, adder, and learner for generating, storing, and consuming - # data respectively. - # NOTE: These are created in reverse order as the actor needs to be given the - # adder and the learner (as a source of variables). - learner_key, key = jax.random.split(key) - learner = experiment.builder.make_learner( - random_key=learner_key, - networks=networks, - dataset=dataset, - logger_fn=experiment.logger_factory, - environment_spec=environment_spec, - replay_client=replay_client, - counter=counting.Counter(parent_counter, prefix='learner', time_delta=0.)) - - adder = experiment.builder.make_adder(replay_client, environment_spec, policy) - - actor_key, key = jax.random.split(key) - actor = experiment.builder.make_actor( - actor_key, policy, environment_spec, variable_source=learner, adder=adder) - - # Create the environment loop used for training. - train_counter = counting.Counter( - parent_counter, prefix='actor', time_delta=0.) 
- train_logger = experiment.logger_factory('actor', - train_counter.get_steps_key(), 0) - - checkpointer = None - if experiment.checkpointing is not None: - checkpointing = experiment.checkpointing - checkpointer = savers.Checkpointer( - objects_to_save={'learner': learner, 'counter': parent_counter}, - time_delta_minutes=checkpointing.time_delta_minutes, - directory=checkpointing.directory, - add_uid=checkpointing.add_uid, - max_to_keep=checkpointing.max_to_keep, - keep_checkpoint_every_n_hours=checkpointing.keep_checkpoint_every_n_hours, - checkpoint_ttl_seconds=checkpointing.checkpoint_ttl_seconds, + key = jax.random.PRNGKey(experiment.seed) + + # Create the environment and get its spec. + environment = experiment.environment_factory(experiment.seed) + environment_spec = experiment.environment_spec or specs.make_environment_spec( + environment + ) + + # Create the networks and policy. + networks = experiment.network_factory(environment_spec) + policy = config.make_policy( + experiment=experiment, + networks=networks, + environment_spec=environment_spec, + evaluation=False, + ) + + # Create the replay server and grab its address. + replay_tables = experiment.builder.make_replay_tables(environment_spec, policy) + + # Disable blocking of inserts by tables' rate limiters, as this function + # executes learning (sampling from the table) and data generation + # (inserting into the table) sequentially from the same thread + # which could result in blocked insert making the algorithm hang. + replay_tables, rate_limiters_max_diff = _disable_insert_blocking(replay_tables) + + replay_server = reverb.Server(replay_tables, port=None) + replay_client = reverb.Client(f"localhost:{replay_server.port}") + + # Parent counter allows to share step counts between train and eval loops and + # the learner, so that it is possible to plot for example evaluator's return + # value as a function of the number of training episodes. + parent_counter = counting.Counter(time_delta=0.0) + + dataset = experiment.builder.make_dataset_iterator(replay_client) + # We always use prefetch as it provides an iterator with an additional + # 'ready' method. + dataset = utils.prefetch(dataset, buffer_size=1) + + # Create actor, adder, and learner for generating, storing, and consuming + # data respectively. + # NOTE: These are created in reverse order as the actor needs to be given the + # adder and the learner (as a source of variables). + learner_key, key = jax.random.split(key) + learner = experiment.builder.make_learner( + random_key=learner_key, + networks=networks, + dataset=dataset, + logger_fn=experiment.logger_factory, + environment_spec=environment_spec, + replay_client=replay_client, + counter=counting.Counter(parent_counter, prefix="learner", time_delta=0.0), + ) + + adder = experiment.builder.make_adder(replay_client, environment_spec, policy) + + actor_key, key = jax.random.split(key) + actor = experiment.builder.make_actor( + actor_key, policy, environment_spec, variable_source=learner, adder=adder + ) + + # Create the environment loop used for training. 
+ train_counter = counting.Counter(parent_counter, prefix="actor", time_delta=0.0) + train_logger = experiment.logger_factory("actor", train_counter.get_steps_key(), 0) + + checkpointer = None + if experiment.checkpointing is not None: + checkpointing = experiment.checkpointing + checkpointer = savers.Checkpointer( + objects_to_save={"learner": learner, "counter": parent_counter}, + time_delta_minutes=checkpointing.time_delta_minutes, + directory=checkpointing.directory, + add_uid=checkpointing.add_uid, + max_to_keep=checkpointing.max_to_keep, + keep_checkpoint_every_n_hours=checkpointing.keep_checkpoint_every_n_hours, + checkpoint_ttl_seconds=checkpointing.checkpoint_ttl_seconds, + ) + + # Replace the actor with a LearningActor. This makes sure that every time + # that `update` is called on the actor it checks to see whether there is + # any new data to learn from and if so it runs a learner step. The rate + # at which new data is released is controlled by the replay table's + # rate_limiter which is created by the builder.make_replay_tables call above. + actor = _LearningActor( + actor, learner, dataset, replay_tables, rate_limiters_max_diff, checkpointer + ) + + train_loop = acme.EnvironmentLoop( + environment, + actor, + counter=train_counter, + logger=train_logger, + observers=experiment.observers, + ) + + max_num_actor_steps = experiment.max_num_actor_steps - parent_counter.get_counts().get( + train_counter.get_steps_key(), 0 + ) + + if num_eval_episodes == 0: + # No evaluation. Just run the training loop. + train_loop.run(num_steps=max_num_actor_steps) + return + + # Create the evaluation actor and loop. + eval_counter = counting.Counter(parent_counter, prefix="evaluator", time_delta=0.0) + eval_logger = experiment.logger_factory( + "evaluator", eval_counter.get_steps_key(), 0 + ) + eval_policy = config.make_policy( + experiment=experiment, + networks=networks, + environment_spec=environment_spec, + evaluation=True, + ) + eval_actor = experiment.builder.make_actor( + random_key=jax.random.PRNGKey(experiment.seed), + policy=eval_policy, + environment_spec=environment_spec, + variable_source=learner, + ) + eval_loop = acme.EnvironmentLoop( + environment, + eval_actor, + counter=eval_counter, + logger=eval_logger, + observers=experiment.observers, ) - # Replace the actor with a LearningActor. This makes sure that every time - # that `update` is called on the actor it checks to see whether there is - # any new data to learn from and if so it runs a learner step. The rate - # at which new data is released is controlled by the replay table's - # rate_limiter which is created by the builder.make_replay_tables call above. - actor = _LearningActor(actor, learner, dataset, replay_tables, - rate_limiters_max_diff, checkpointer) - - train_loop = acme.EnvironmentLoop( - environment, - actor, - counter=train_counter, - logger=train_logger, - observers=experiment.observers) - - max_num_actor_steps = ( - experiment.max_num_actor_steps - - parent_counter.get_counts().get(train_counter.get_steps_key(), 0)) - - if num_eval_episodes == 0: - # No evaluation. Just run the training loop. - train_loop.run(num_steps=max_num_actor_steps) - return - - # Create the evaluation actor and loop. - eval_counter = counting.Counter( - parent_counter, prefix='evaluator', time_delta=0.) 
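# Illustrative sketch, not part of this patch: minimal end-to-end usage of
# `run_experiment`, mirroring the SAC-based test in run_experiment_test.py
# later in this diff. The fake environment and step counts are assumptions for
# the example only.
import dm_env

from acme.agents.jax import sac
from acme.jax import experiments
from acme.testing import fakes


def environment_factory(seed: int) -> dm_env.Environment:
    del seed
    return fakes.ContinuousEnvironment(
        episode_length=10, action_dim=3, observation_dim=5
    )


experiment_config = experiments.ExperimentConfig(
    builder=sac.SACBuilder(sac.SACConfig()),
    environment_factory=environment_factory,
    network_factory=sac.make_networks,
    seed=0,
    max_num_actor_steps=100,
    checkpointing=None,  # Skip checkpointing for this sketch; None is handled above.
)
experiments.run_experiment(experiment_config, eval_every=10, num_eval_episodes=1)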
- eval_logger = experiment.logger_factory('evaluator', - eval_counter.get_steps_key(), 0) - eval_policy = config.make_policy( - experiment=experiment, - networks=networks, - environment_spec=environment_spec, - evaluation=True) - eval_actor = experiment.builder.make_actor( - random_key=jax.random.PRNGKey(experiment.seed), - policy=eval_policy, - environment_spec=environment_spec, - variable_source=learner) - eval_loop = acme.EnvironmentLoop( - environment, - eval_actor, - counter=eval_counter, - logger=eval_logger, - observers=experiment.observers) - - steps = 0 - while steps < max_num_actor_steps: + steps = 0 + while steps < max_num_actor_steps: + eval_loop.run(num_episodes=num_eval_episodes) + num_steps = min(eval_every, max_num_actor_steps - steps) + steps += train_loop.run(num_steps=num_steps) eval_loop.run(num_episodes=num_eval_episodes) - num_steps = min(eval_every, max_num_actor_steps - steps) - steps += train_loop.run(num_steps=num_steps) - eval_loop.run(num_episodes=num_eval_episodes) - environment.close() + environment.close() class _LearningActor(core.Actor): - """Actor which learns (updates its parameters) when `update` is called. + """Actor which learns (updates its parameters) when `update` is called. This combines a base actor and a learner. Whenever `update` is called on the wrapping actor the learner will take a step (e.g. one step of gradient @@ -191,12 +197,16 @@ class _LearningActor(core.Actor): Intended to be used by the `run_experiment` only. """ - def __init__(self, actor: core.Actor, learner: core.Learner, - iterator: core.PrefetchingIterator, - replay_tables: Sequence[reverb.Table], - sample_sizes: Sequence[int], - checkpointer: Optional[savers.Checkpointer]): - """Initializes _LearningActor. + def __init__( + self, + actor: core.Actor, + learner: core.Learner, + iterator: core.PrefetchingIterator, + replay_tables: Sequence[reverb.Table], + sample_sizes: Sequence[int], + checkpointer: Optional[savers.Checkpointer], + ): + """Initializes _LearningActor. Args: actor: Actor to be wrapped. @@ -210,71 +220,73 @@ def __init__(self, actor: core.Actor, learner: core.Learner, collected by the actor. checkpointer: Checkpointer to save the state on update. """ - self._actor = actor - self._learner = learner - self._iterator = iterator - self._replay_tables = replay_tables - self._sample_sizes = sample_sizes - self._learner_steps = 0 - self._checkpointer = checkpointer - - def select_action(self, observation: types.NestedArray) -> types.NestedArray: - return self._actor.select_action(observation) - - def observe_first(self, timestep: dm_env.TimeStep): - self._actor.observe_first(timestep) - - def observe(self, action: types.NestedArray, next_timestep: dm_env.TimeStep): - self._actor.observe(action, next_timestep) - - def _maybe_train(self): - trained = False - while True: - if self._iterator.ready(): - self._learner.step() - batches = self._iterator.retrieved_elements() - self._learner_steps - self._learner_steps += 1 - assert batches == 1, ( - 'Learner step must retrieve exactly one element from the iterator' - f' (retrieved {batches}). Otherwise agent can deadlock. Example ' - 'cause is that your chosen agent' - 's Builder has a `make_learner` ' - 'factory that prefetches the data but it shouldn' - 't.') - trained = True - else: - # Wait for the iterator to fetch more data from the table(s) only - # if there plenty of data to sample from each table. 
-        for table, sample_size in zip(self._replay_tables, self._sample_sizes):
-          if not table.can_sample(sample_size):
-            return trained
-        # Let iterator's prefetching thread get data from the table(s).
-        time.sleep(0.001)
-
-  def update(self):
-    if self._maybe_train():
-      # Update the actor weights only when learner was updated.
-      self._actor.update()
-    if self._checkpointer:
-      self._checkpointer.save()
+        self._actor = actor
+        self._learner = learner
+        self._iterator = iterator
+        self._replay_tables = replay_tables
+        self._sample_sizes = sample_sizes
+        self._learner_steps = 0
+        self._checkpointer = checkpointer
+
+    def select_action(self, observation: types.NestedArray) -> types.NestedArray:
+        return self._actor.select_action(observation)
+
+    def observe_first(self, timestep: dm_env.TimeStep):
+        self._actor.observe_first(timestep)
+
+    def observe(self, action: types.NestedArray, next_timestep: dm_env.TimeStep):
+        self._actor.observe(action, next_timestep)
+
+    def _maybe_train(self):
+        trained = False
+        while True:
+            if self._iterator.ready():
+                self._learner.step()
+                batches = self._iterator.retrieved_elements() - self._learner_steps
+                self._learner_steps += 1
+                assert batches == 1, (
+                    "Learner step must retrieve exactly one element from the iterator"
+                    f" (retrieved {batches}). Otherwise the agent can deadlock. Example "
+                    "cause is that your chosen agent's Builder has a `make_learner` "
+                    "factory that prefetches the data but it shouldn't."
+                )
+                trained = True
+            else:
+                # Wait for the iterator to fetch more data from the table(s) only
+                # if there is plenty of data to sample from each table.
+                for table, sample_size in zip(self._replay_tables, self._sample_sizes):
+                    if not table.can_sample(sample_size):
+                        return trained
+                # Let iterator's prefetching thread get data from the table(s).
+                time.sleep(0.001)
+
+    def update(self):
+        if self._maybe_train():
+            # Update the actor weights only when learner was updated.
+            self._actor.update()
+        if self._checkpointer:
+            self._checkpointer.save()


 def _disable_insert_blocking(
-    tables: Sequence[reverb.Table]
+    tables: Sequence[reverb.Table],
 ) -> Tuple[Sequence[reverb.Table], Sequence[int]]:
-  """Disables blocking of insert operations for a given collection of tables."""
-  modified_tables = []
-  sample_sizes = []
-  for table in tables:
-    rate_limiter_info = table.info.rate_limiter_info
-    rate_limiter = reverb.rate_limiters.RateLimiter(
-        samples_per_insert=rate_limiter_info.samples_per_insert,
-        min_size_to_sample=rate_limiter_info.min_size_to_sample,
-        min_diff=rate_limiter_info.min_diff,
-        max_diff=sys.float_info.max)
-    modified_tables.append(table.replace(rate_limiter=rate_limiter))
-    # Target the middle of the rate limiter's insert-sample balance window.
-    sample_sizes.append(
-        max(1, int(
-            (rate_limiter_info.max_diff - rate_limiter_info.min_diff) / 2)))
-  return modified_tables, sample_sizes
+    """Disables blocking of insert operations for a given collection of tables."""
+    modified_tables = []
+    sample_sizes = []
+    for table in tables:
+        rate_limiter_info = table.info.rate_limiter_info
+        rate_limiter = reverb.rate_limiters.RateLimiter(
+            samples_per_insert=rate_limiter_info.samples_per_insert,
+            min_size_to_sample=rate_limiter_info.min_size_to_sample,
+            min_diff=rate_limiter_info.min_diff,
+            max_diff=sys.float_info.max,
+        )
+        modified_tables.append(table.replace(rate_limiter=rate_limiter))
+        # Target the middle of the rate limiter's insert-sample balance window.
+ sample_sizes.append( + max(1, int((rate_limiter_info.max_diff - rate_limiter_info.min_diff) / 2)) + ) + return modified_tables, sample_sizes diff --git a/acme/jax/experiments/run_experiment_test.py b/acme/jax/experiments/run_experiment_test.py index 99b9187041..94b216e835 100644 --- a/acme/jax/experiments/run_experiment_test.py +++ b/acme/jax/experiments/run_experiment_test.py @@ -14,83 +14,92 @@ """Tests for the run_experiment function.""" +import dm_env +from absl.testing import absltest, parameterized + from acme.agents.jax import sac from acme.jax import experiments from acme.jax.experiments import test_utils as experiment_test_utils -from acme.testing import fakes -from acme.testing import test_utils -import dm_env -from absl.testing import absltest -from absl.testing import parameterized +from acme.testing import fakes, test_utils class RunExperimentTest(test_utils.TestCase): - - @parameterized.named_parameters( - dict(testcase_name='noeval', num_eval_episodes=0), - dict(testcase_name='eval', num_eval_episodes=1)) - def test_checkpointing(self, num_eval_episodes: int): - num_train_steps = 100 - experiment_config = self._get_experiment_config( - num_train_steps=num_train_steps) - - experiments.run_experiment( - experiment_config, eval_every=10, num_eval_episodes=num_eval_episodes) - - checkpoint_counter = experiment_test_utils.restore_counter( - experiment_config.checkpointing) - self.assertIn('actor_steps', checkpoint_counter.get_counts()) - self.assertGreater(checkpoint_counter.get_counts()['actor_steps'], 0) - - # Run the second experiment with the same checkpointing config to verify - # that it restores from the latest saved checkpoint. - experiments.run_experiment( - experiment_config, eval_every=50, num_eval_episodes=num_eval_episodes) - - checkpoint_counter = experiment_test_utils.restore_counter( - experiment_config.checkpointing) - self.assertIn('actor_steps', checkpoint_counter.get_counts()) - # Verify that the steps done in the first run are taken into account. 
- self.assertLessEqual(checkpoint_counter.get_counts()['actor_steps'], - num_train_steps) - - def test_eval_every(self): - num_train_steps = 100 - experiment_config = self._get_experiment_config( - num_train_steps=num_train_steps) - - experiments.run_experiment( - experiment_config, eval_every=70, num_eval_episodes=1) - - checkpoint_counter = experiment_test_utils.restore_counter( - experiment_config.checkpointing) - self.assertIn('actor_steps', checkpoint_counter.get_counts()) - self.assertGreater(checkpoint_counter.get_counts()['actor_steps'], 0) - self.assertLessEqual(checkpoint_counter.get_counts()['actor_steps'], - num_train_steps) - - def _get_experiment_config( - self, *, num_train_steps: int) -> experiments.ExperimentConfig: - """Returns a config for a test experiment with the given number of steps.""" - - def environment_factory(seed: int) -> dm_env.Environment: - del seed - return fakes.ContinuousEnvironment( - episode_length=10, action_dim=3, observation_dim=5) - - num_train_steps = 100 - - sac_config = sac.SACConfig() - checkpointing_config = experiments.CheckpointingConfig( - directory=self.get_tempdir(), time_delta_minutes=0) - return experiments.ExperimentConfig( - builder=sac.SACBuilder(sac_config), - environment_factory=environment_factory, - network_factory=sac.make_networks, - seed=0, - max_num_actor_steps=num_train_steps, - checkpointing=checkpointing_config) - - -if __name__ == '__main__': - absltest.main() + @parameterized.named_parameters( + dict(testcase_name="noeval", num_eval_episodes=0), + dict(testcase_name="eval", num_eval_episodes=1), + ) + def test_checkpointing(self, num_eval_episodes: int): + num_train_steps = 100 + experiment_config = self._get_experiment_config(num_train_steps=num_train_steps) + + experiments.run_experiment( + experiment_config, eval_every=10, num_eval_episodes=num_eval_episodes + ) + + checkpoint_counter = experiment_test_utils.restore_counter( + experiment_config.checkpointing + ) + self.assertIn("actor_steps", checkpoint_counter.get_counts()) + self.assertGreater(checkpoint_counter.get_counts()["actor_steps"], 0) + + # Run the second experiment with the same checkpointing config to verify + # that it restores from the latest saved checkpoint. + experiments.run_experiment( + experiment_config, eval_every=50, num_eval_episodes=num_eval_episodes + ) + + checkpoint_counter = experiment_test_utils.restore_counter( + experiment_config.checkpointing + ) + self.assertIn("actor_steps", checkpoint_counter.get_counts()) + # Verify that the steps done in the first run are taken into account. 
+ self.assertLessEqual( + checkpoint_counter.get_counts()["actor_steps"], num_train_steps + ) + + def test_eval_every(self): + num_train_steps = 100 + experiment_config = self._get_experiment_config(num_train_steps=num_train_steps) + + experiments.run_experiment( + experiment_config, eval_every=70, num_eval_episodes=1 + ) + + checkpoint_counter = experiment_test_utils.restore_counter( + experiment_config.checkpointing + ) + self.assertIn("actor_steps", checkpoint_counter.get_counts()) + self.assertGreater(checkpoint_counter.get_counts()["actor_steps"], 0) + self.assertLessEqual( + checkpoint_counter.get_counts()["actor_steps"], num_train_steps + ) + + def _get_experiment_config( + self, *, num_train_steps: int + ) -> experiments.ExperimentConfig: + """Returns a config for a test experiment with the given number of steps.""" + + def environment_factory(seed: int) -> dm_env.Environment: + del seed + return fakes.ContinuousEnvironment( + episode_length=10, action_dim=3, observation_dim=5 + ) + + num_train_steps = 100 + + sac_config = sac.SACConfig() + checkpointing_config = experiments.CheckpointingConfig( + directory=self.get_tempdir(), time_delta_minutes=0 + ) + return experiments.ExperimentConfig( + builder=sac.SACBuilder(sac_config), + environment_factory=environment_factory, + network_factory=sac.make_networks, + seed=0, + max_num_actor_steps=num_train_steps, + checkpointing=checkpointing_config, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/jax/experiments/run_offline_experiment.py b/acme/jax/experiments/run_offline_experiment.py index 772bb1fb6e..2c6902fcaf 100644 --- a/acme/jax/experiments/run_offline_experiment.py +++ b/acme/jax/experiments/run_offline_experiment.py @@ -14,18 +14,21 @@ """Runner used for executing local offline RL agents.""" +import jax + import acme from acme import specs from acme.jax.experiments import config from acme.tf import savers from acme.utils import counting -import jax -def run_offline_experiment(experiment: config.OfflineExperimentConfig, - eval_every: int = 100, - num_eval_episodes: int = 1): - """Runs a simple, single-threaded training loop using the default evaluators. +def run_offline_experiment( + experiment: config.OfflineExperimentConfig, + eval_every: int = 100, + num_eval_episodes: int = 1, +): + """Runs a simple, single-threaded training loop using the default evaluators. It targets simplicity of the code and so only the basic features of the OfflineExperimentConfig are supported. @@ -37,83 +40,89 @@ def run_offline_experiment(experiment: config.OfflineExperimentConfig, evaluation step. """ - key = jax.random.PRNGKey(experiment.seed) - - # Create the environment and get its spec. - environment = experiment.environment_factory(experiment.seed) - environment_spec = experiment.environment_spec or specs.make_environment_spec( - environment) - - # Create the networks and policy. - networks = experiment.network_factory(environment_spec) - - # Parent counter allows to share step counts between train and eval loops and - # the learner, so that it is possible to plot for example evaluator's return - # value as a function of the number of training episodes. - parent_counter = counting.Counter(time_delta=0.) - - # Create the demonstrations dataset. - dataset_key, key = jax.random.split(key) - dataset = experiment.demonstration_dataset_factory(dataset_key) - - # Create the learner. 
- learner_key, key = jax.random.split(key) - learner = experiment.builder.make_learner( - random_key=learner_key, - networks=networks, - dataset=dataset, - logger_fn=experiment.logger_factory, - environment_spec=environment_spec, - counter=counting.Counter(parent_counter, prefix='learner', time_delta=0.)) - - # Define the evaluation loop. - eval_loop = None - if num_eval_episodes > 0: - # Create the evaluation actor and loop. - eval_counter = counting.Counter( - parent_counter, prefix='evaluator', time_delta=0.) - eval_logger = experiment.logger_factory('evaluator', - eval_counter.get_steps_key(), 0) - eval_key, key = jax.random.split(key) - eval_actor = experiment.builder.make_actor( - random_key=eval_key, - policy=experiment.builder.make_policy(networks, environment_spec, True), + key = jax.random.PRNGKey(experiment.seed) + + # Create the environment and get its spec. + environment = experiment.environment_factory(experiment.seed) + environment_spec = experiment.environment_spec or specs.make_environment_spec( + environment + ) + + # Create the networks and policy. + networks = experiment.network_factory(environment_spec) + + # Parent counter allows to share step counts between train and eval loops and + # the learner, so that it is possible to plot for example evaluator's return + # value as a function of the number of training episodes. + parent_counter = counting.Counter(time_delta=0.0) + + # Create the demonstrations dataset. + dataset_key, key = jax.random.split(key) + dataset = experiment.demonstration_dataset_factory(dataset_key) + + # Create the learner. + learner_key, key = jax.random.split(key) + learner = experiment.builder.make_learner( + random_key=learner_key, + networks=networks, + dataset=dataset, + logger_fn=experiment.logger_factory, environment_spec=environment_spec, - variable_source=learner) - eval_loop = acme.EnvironmentLoop( - environment, - eval_actor, - counter=eval_counter, - logger=eval_logger, - observers=experiment.observers) - - checkpointer = None - if experiment.checkpointing is not None: - checkpointing = experiment.checkpointing - checkpointer = savers.Checkpointer( - objects_to_save={'learner': learner, 'counter': parent_counter}, - time_delta_minutes=checkpointing.time_delta_minutes, - directory=checkpointing.directory, - add_uid=checkpointing.add_uid, - max_to_keep=checkpointing.max_to_keep, - keep_checkpoint_every_n_hours=checkpointing.keep_checkpoint_every_n_hours, - checkpoint_ttl_seconds=checkpointing.checkpoint_ttl_seconds, + counter=counting.Counter(parent_counter, prefix="learner", time_delta=0.0), + ) + + # Define the evaluation loop. + eval_loop = None + if num_eval_episodes > 0: + # Create the evaluation actor and loop. 
+ eval_counter = counting.Counter( + parent_counter, prefix="evaluator", time_delta=0.0 + ) + eval_logger = experiment.logger_factory( + "evaluator", eval_counter.get_steps_key(), 0 + ) + eval_key, key = jax.random.split(key) + eval_actor = experiment.builder.make_actor( + random_key=eval_key, + policy=experiment.builder.make_policy(networks, environment_spec, True), + environment_spec=environment_spec, + variable_source=learner, + ) + eval_loop = acme.EnvironmentLoop( + environment, + eval_actor, + counter=eval_counter, + logger=eval_logger, + observers=experiment.observers, + ) + + checkpointer = None + if experiment.checkpointing is not None: + checkpointing = experiment.checkpointing + checkpointer = savers.Checkpointer( + objects_to_save={"learner": learner, "counter": parent_counter}, + time_delta_minutes=checkpointing.time_delta_minutes, + directory=checkpointing.directory, + add_uid=checkpointing.add_uid, + max_to_keep=checkpointing.max_to_keep, + keep_checkpoint_every_n_hours=checkpointing.keep_checkpoint_every_n_hours, + checkpoint_ttl_seconds=checkpointing.checkpoint_ttl_seconds, + ) + + max_num_learner_steps = experiment.max_num_learner_steps - parent_counter.get_counts().get( + "learner_steps", 0 ) - max_num_learner_steps = ( - experiment.max_num_learner_steps - - parent_counter.get_counts().get('learner_steps', 0)) - - # Run the training loop. - if eval_loop: - eval_loop.run(num_eval_episodes) - steps = 0 - while steps < max_num_learner_steps: - learner_steps = min(eval_every, max_num_learner_steps - steps) - for _ in range(learner_steps): - learner.step() - if checkpointer is not None: - checkpointer.save() + # Run the training loop. if eval_loop: - eval_loop.run(num_eval_episodes) - steps += learner_steps + eval_loop.run(num_eval_episodes) + steps = 0 + while steps < max_num_learner_steps: + learner_steps = min(eval_every, max_num_learner_steps - steps) + for _ in range(learner_steps): + learner.step() + if checkpointer is not None: + checkpointer.save() + if eval_loop: + eval_loop.run(num_eval_episodes) + steps += learner_steps diff --git a/acme/jax/experiments/run_offline_experiment_test.py b/acme/jax/experiments/run_offline_experiment_test.py index 5e24d3df21..f1611ed39c 100644 --- a/acme/jax/experiments/run_offline_experiment_test.py +++ b/acme/jax/experiments/run_offline_experiment_test.py @@ -16,98 +16,111 @@ from typing import Iterator -from acme import specs -from acme import types +import dm_env +from absl.testing import absltest, parameterized + +from acme import specs, types from acme.agents.jax import crr from acme.jax import experiments from acme.jax import types as jax_types from acme.jax.experiments import test_utils as experiment_test_utils -from acme.testing import fakes -from acme.testing import test_utils -import dm_env -from absl.testing import absltest -from absl.testing import parameterized +from acme.testing import fakes, test_utils class RunOfflineExperimentTest(test_utils.TestCase): - - @parameterized.named_parameters( - dict(testcase_name='noeval', num_eval_episodes=0), - dict(testcase_name='eval', num_eval_episodes=1)) - def test_checkpointing(self, num_eval_episodes: int): - num_learner_steps = 100 - - experiment_config = self._get_experiment_config( - num_learner_steps=num_learner_steps) - - experiments.run_offline_experiment( - experiment_config, num_eval_episodes=num_eval_episodes) - - checkpoint_counter = experiment_test_utils.restore_counter( - experiment_config.checkpointing) - self.assertIn('learner_steps', checkpoint_counter.get_counts()) 
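# Illustrative sketch, not part of this patch: the evaluate/train interleaving
# implemented by `run_offline_experiment` above, written as a standalone helper
# (checkpoint saving omitted). The function and argument names are assumptions
# for the example only.
from typing import Callable, Optional


def interleave_training_and_eval(
    learner_step: Callable[[], None],
    evaluate: Optional[Callable[[], None]],
    max_num_learner_steps: int,
    eval_every: int,
) -> None:
    # Evaluate once up front, then alternate blocks of at most `eval_every`
    # learner steps with evaluations until the step budget is exhausted.
    if evaluate:
        evaluate()
    steps = 0
    while steps < max_num_learner_steps:
        learner_steps = min(eval_every, max_num_learner_steps - steps)
        for _ in range(learner_steps):
            learner_step()
        if evaluate:
            evaluate()
        steps += learner_steps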
- self.assertGreater(checkpoint_counter.get_counts()['learner_steps'], 0) - - # Run the second experiment with the same checkpointing config to verify - # that it restores from the latest saved checkpoint. - experiments.run_offline_experiment( - experiment_config, num_eval_episodes=num_eval_episodes) - - checkpoint_counter = experiment_test_utils.restore_counter( - experiment_config.checkpointing) - self.assertIn('learner_steps', checkpoint_counter.get_counts()) - # Verify that the steps done in the first run are taken into account. - self.assertLessEqual(checkpoint_counter.get_counts()['learner_steps'], - num_learner_steps) - - def test_eval_every(self): - num_learner_steps = 100 - - experiment_config = self._get_experiment_config( - num_learner_steps=num_learner_steps) - - experiments.run_offline_experiment( - experiment_config, eval_every=70, num_eval_episodes=1) - - checkpoint_counter = experiment_test_utils.restore_counter( - experiment_config.checkpointing) - self.assertIn('learner_steps', checkpoint_counter.get_counts()) - self.assertGreater(checkpoint_counter.get_counts()['learner_steps'], 0) - self.assertLessEqual(checkpoint_counter.get_counts()['learner_steps'], - num_learner_steps) - - def _get_experiment_config( - self, *, num_learner_steps: int) -> experiments.OfflineExperimentConfig: - def environment_factory(seed: int) -> dm_env.Environment: - del seed - return fakes.ContinuousEnvironment( - episode_length=10, action_dim=3, observation_dim=5) - - environment = environment_factory(seed=1) - environment_spec = specs.make_environment_spec(environment) - - def demonstration_dataset_factory( - random_key: jax_types.PRNGKey) -> Iterator[types.Transition]: - del random_key - batch_size = 64 - return fakes.transition_iterator_from_spec(environment_spec)(batch_size) - - crr_config = crr.CRRConfig() - crr_builder = crr.CRRBuilder( - crr_config, policy_loss_coeff_fn=crr.policy_loss_coeff_advantage_exp) - checkpointing_config = experiments.CheckpointingConfig( - directory=self.get_tempdir(), time_delta_minutes=0) - return experiments.OfflineExperimentConfig( - builder=crr_builder, - network_factory=crr.make_networks, - demonstration_dataset_factory=demonstration_dataset_factory, - environment_factory=environment_factory, - max_num_learner_steps=num_learner_steps, - seed=0, - environment_spec=environment_spec, - checkpointing=checkpointing_config, + @parameterized.named_parameters( + dict(testcase_name="noeval", num_eval_episodes=0), + dict(testcase_name="eval", num_eval_episodes=1), ) - - -if __name__ == '__main__': - absltest.main() + def test_checkpointing(self, num_eval_episodes: int): + num_learner_steps = 100 + + experiment_config = self._get_experiment_config( + num_learner_steps=num_learner_steps + ) + + experiments.run_offline_experiment( + experiment_config, num_eval_episodes=num_eval_episodes + ) + + checkpoint_counter = experiment_test_utils.restore_counter( + experiment_config.checkpointing + ) + self.assertIn("learner_steps", checkpoint_counter.get_counts()) + self.assertGreater(checkpoint_counter.get_counts()["learner_steps"], 0) + + # Run the second experiment with the same checkpointing config to verify + # that it restores from the latest saved checkpoint. 
+ experiments.run_offline_experiment( + experiment_config, num_eval_episodes=num_eval_episodes + ) + + checkpoint_counter = experiment_test_utils.restore_counter( + experiment_config.checkpointing + ) + self.assertIn("learner_steps", checkpoint_counter.get_counts()) + # Verify that the steps done in the first run are taken into account. + self.assertLessEqual( + checkpoint_counter.get_counts()["learner_steps"], num_learner_steps + ) + + def test_eval_every(self): + num_learner_steps = 100 + + experiment_config = self._get_experiment_config( + num_learner_steps=num_learner_steps + ) + + experiments.run_offline_experiment( + experiment_config, eval_every=70, num_eval_episodes=1 + ) + + checkpoint_counter = experiment_test_utils.restore_counter( + experiment_config.checkpointing + ) + self.assertIn("learner_steps", checkpoint_counter.get_counts()) + self.assertGreater(checkpoint_counter.get_counts()["learner_steps"], 0) + self.assertLessEqual( + checkpoint_counter.get_counts()["learner_steps"], num_learner_steps + ) + + def _get_experiment_config( + self, *, num_learner_steps: int + ) -> experiments.OfflineExperimentConfig: + def environment_factory(seed: int) -> dm_env.Environment: + del seed + return fakes.ContinuousEnvironment( + episode_length=10, action_dim=3, observation_dim=5 + ) + + environment = environment_factory(seed=1) + environment_spec = specs.make_environment_spec(environment) + + def demonstration_dataset_factory( + random_key: jax_types.PRNGKey, + ) -> Iterator[types.Transition]: + del random_key + batch_size = 64 + return fakes.transition_iterator_from_spec(environment_spec)(batch_size) + + crr_config = crr.CRRConfig() + crr_builder = crr.CRRBuilder( + crr_config, policy_loss_coeff_fn=crr.policy_loss_coeff_advantage_exp + ) + checkpointing_config = experiments.CheckpointingConfig( + directory=self.get_tempdir(), time_delta_minutes=0 + ) + return experiments.OfflineExperimentConfig( + builder=crr_builder, + network_factory=crr.make_networks, + demonstration_dataset_factory=demonstration_dataset_factory, + environment_factory=environment_factory, + max_num_learner_steps=num_learner_steps, + seed=0, + environment_spec=environment_spec, + checkpointing=checkpointing_config, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/jax/experiments/test_utils.py b/acme/jax/experiments/test_utils.py index 5fe695a5a5..dece474194 100644 --- a/acme/jax/experiments/test_utils.py +++ b/acme/jax/experiments/test_utils.py @@ -20,12 +20,14 @@ def restore_counter( - checkpointing_config: experiments.CheckpointingConfig) -> counting.Counter: - """Restores a counter from the latest checkpoint saved with this config.""" - counter = counting.Counter() - savers.Checkpointer( - objects_to_save={'counter': counter}, - directory=checkpointing_config.directory, - add_uid=checkpointing_config.add_uid, - max_to_keep=checkpointing_config.max_to_keep) - return counter + checkpointing_config: experiments.CheckpointingConfig, +) -> counting.Counter: + """Restores a counter from the latest checkpoint saved with this config.""" + counter = counting.Counter() + savers.Checkpointer( + objects_to_save={"counter": counter}, + directory=checkpointing_config.directory, + add_uid=checkpointing_config.add_uid, + max_to_keep=checkpointing_config.max_to_keep, + ) + return counter diff --git a/acme/jax/imitation_learning_types.py b/acme/jax/imitation_learning_types.py index 996e6e26d5..7a429dc359 100644 --- a/acme/jax/imitation_learning_types.py +++ b/acme/jax/imitation_learning_types.py @@ -17,6 +17,6 
@@ from typing import TypeVar # Common TypeVars that correspond to various aspects of the direct RL algorithm. -DirectPolicyNetwork = TypeVar('DirectPolicyNetwork') -DirectRLNetworks = TypeVar('DirectRLNetworks') -DirectRLTrainingState = TypeVar('DirectRLTrainingState') +DirectPolicyNetwork = TypeVar("DirectPolicyNetwork") +DirectRLNetworks = TypeVar("DirectRLNetworks") +DirectRLTrainingState = TypeVar("DirectRLTrainingState") diff --git a/acme/jax/inference_server.py b/acme/jax/inference_server.py index 9e4f27a684..70193bd2c7 100644 --- a/acme/jax/inference_server.py +++ b/acme/jax/inference_server.py @@ -18,15 +18,17 @@ import datetime import threading from typing import Any, Callable, Generic, Optional, Sequence, TypeVar -import acme -from acme.jax import variable_utils + import jax import launchpad as lp +import acme +from acme.jax import variable_utils + @dataclasses.dataclass class InferenceServerConfig: - """Configuration options for centralised inference. + """Configuration options for centralised inference. Attributes: batch_size: How many elements to batch together per single inference call. @@ -37,25 +39,26 @@ class InferenceServerConfig: so there batch handler is always called with batch_size elements). By default timeout is effectively disabled (set to 30 days). """ - batch_size: Optional[int] = None - update_period: Optional[int] = None - timeout: datetime.timedelta = datetime.timedelta(days=30) + + batch_size: Optional[int] = None + update_period: Optional[int] = None + timeout: datetime.timedelta = datetime.timedelta(days=30) -InferenceServerHandler = TypeVar('InferenceServerHandler') +InferenceServerHandler = TypeVar("InferenceServerHandler") class InferenceServer(Generic[InferenceServerHandler]): - """Centralised, batched inference server.""" + """Centralised, batched inference server.""" - def __init__( - self, - handler: InferenceServerHandler, - variable_source: acme.VariableSource, - devices: Sequence[jax.Device], - config: InferenceServerConfig, - ): - """Constructs an inference server object. + def __init__( + self, + handler: InferenceServerHandler, + variable_source: acme.VariableSource, + devices: Sequence[jax.Device], + config: InferenceServerConfig, + ): + """Constructs an inference server object. Args: handler: A callable or a mapping of callables to be exposed @@ -65,78 +68,81 @@ def __init__( parallel. config: Inference Server configuration. """ - self._variable_source = variable_source - self._variable_client = None - self._keys = [] - self._devices = devices - self._config = config - self._call_cnt = 0 - self._device_params = [None] * len(self._devices) - self._device_params_ids = [None] * len(self._devices) - self._mutex = threading.Lock() - self._handler = jax.tree_map(self._build_handler, handler, is_leaf=callable) - - @property - def handler(self) -> InferenceServerHandler: - return self._handler - - def _dereference_params(self, arg): - """Replaces VariableReferences with their corresponding param values.""" - - if not isinstance(arg, variable_utils.VariableReference): - # All arguments but VariableReference are returned without modifications. - return arg - - # Due to batching dimension we take the first element. - variable_name = arg.variable_name[0] - - if variable_name not in self._keys: - # Create a new VariableClient which also serves new variables. 
- self._keys.append(variable_name) - self._variable_client = variable_utils.VariableClient( - client=self._variable_source, - key=self._keys, - update_period=self._config.update_period) - - params = self._variable_client.params - device_idx = self._call_cnt % len(self._devices) - # Select device via round robin, and update its params if they changed. - if self._device_params_ids[device_idx] != id(params): - self._device_params_ids[device_idx] = id(params) - self._device_params[device_idx] = jax.device_put( - params, self._devices[device_idx]) - - # Return the params that are located on the chosen device. - device_params = self._device_params[device_idx] - if len(self._keys) == 1: - return device_params - return device_params[self._keys.index(variable_name)] - - def _build_handler(self, handler: Callable[..., Any]) -> Callable[..., Any]: - """Builds a batched handler for a given callable handler and its name.""" - - def dereference_params_and_call_handler(*args, **kwargs): - with self._mutex: - # Dereference args corresponding to params, leaving others unchanged. - args_with_dereferenced_params = [ - self._dereference_params(arg) for arg in args - ] - kwargs_with_dereferenced_params = { - key: self._dereference_params(value) - for key, value in kwargs.items() - } - self._call_cnt += 1 - - # Maybe update params, depending on client configuration. - if self._variable_client is not None: - self._variable_client.update() - - return handler(*args_with_dereferenced_params, - **kwargs_with_dereferenced_params) - - return lp.batched_handler( - batch_size=self._config.batch_size, - timeout=self._config.timeout, - pad_batch=True, - max_parallelism=2 * len(self._devices))( - dereference_params_and_call_handler) + self._variable_source = variable_source + self._variable_client = None + self._keys = [] + self._devices = devices + self._config = config + self._call_cnt = 0 + self._device_params = [None] * len(self._devices) + self._device_params_ids = [None] * len(self._devices) + self._mutex = threading.Lock() + self._handler = jax.tree_map(self._build_handler, handler, is_leaf=callable) + + @property + def handler(self) -> InferenceServerHandler: + return self._handler + + def _dereference_params(self, arg): + """Replaces VariableReferences with their corresponding param values.""" + + if not isinstance(arg, variable_utils.VariableReference): + # All arguments but VariableReference are returned without modifications. + return arg + + # Due to batching dimension we take the first element. + variable_name = arg.variable_name[0] + + if variable_name not in self._keys: + # Create a new VariableClient which also serves new variables. + self._keys.append(variable_name) + self._variable_client = variable_utils.VariableClient( + client=self._variable_source, + key=self._keys, + update_period=self._config.update_period, + ) + + params = self._variable_client.params + device_idx = self._call_cnt % len(self._devices) + # Select device via round robin, and update its params if they changed. + if self._device_params_ids[device_idx] != id(params): + self._device_params_ids[device_idx] = id(params) + self._device_params[device_idx] = jax.device_put( + params, self._devices[device_idx] + ) + + # Return the params that are located on the chosen device. 
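# Illustrative sketch, not part of this patch: the round-robin device placement
# used above, re-transferring parameters to a device only when the host-side
# object changes (tracked via id()). The helper name and arguments are
# assumptions for the example only.
from typing import Any, List, Sequence

import jax


def params_for_call(
    params: Any,
    call_count: int,
    devices: Sequence[jax.Device],
    device_params: List[Any],
    device_params_ids: List[Any],
) -> Any:
    # Pick a device in round-robin order based on how many calls were served.
    device_idx = call_count % len(devices)
    # Only copy params to that device if they changed since the last transfer.
    if device_params_ids[device_idx] != id(params):
        device_params_ids[device_idx] = id(params)
        device_params[device_idx] = jax.device_put(params, devices[device_idx])
    return device_params[device_idx]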
+ device_params = self._device_params[device_idx] + if len(self._keys) == 1: + return device_params + return device_params[self._keys.index(variable_name)] + + def _build_handler(self, handler: Callable[..., Any]) -> Callable[..., Any]: + """Builds a batched handler for a given callable handler and its name.""" + + def dereference_params_and_call_handler(*args, **kwargs): + with self._mutex: + # Dereference args corresponding to params, leaving others unchanged. + args_with_dereferenced_params = [ + self._dereference_params(arg) for arg in args + ] + kwargs_with_dereferenced_params = { + key: self._dereference_params(value) + for key, value in kwargs.items() + } + self._call_cnt += 1 + + # Maybe update params, depending on client configuration. + if self._variable_client is not None: + self._variable_client.update() + + return handler( + *args_with_dereferenced_params, **kwargs_with_dereferenced_params + ) + + return lp.batched_handler( + batch_size=self._config.batch_size, + timeout=self._config.timeout, + pad_batch=True, + max_parallelism=2 * len(self._devices), + )(dereference_params_and_call_handler) diff --git a/acme/jax/losses/__init__.py b/acme/jax/losses/__init__.py index 9a1f9d9532..ed5cdb76d5 100644 --- a/acme/jax/losses/__init__.py +++ b/acme/jax/losses/__init__.py @@ -15,6 +15,4 @@ """Common loss functions.""" from acme.jax.losses.impala import impala_loss -from acme.jax.losses.mpo import MPO -from acme.jax.losses.mpo import MPOParams -from acme.jax.losses.mpo import MPOStats +from acme.jax.losses.mpo import MPO, MPOParams, MPOStats diff --git a/acme/jax/losses/impala.py b/acme/jax/losses/impala.py index dba61033d7..129b04f0a6 100644 --- a/acme/jax/losses/impala.py +++ b/acme/jax/losses/impala.py @@ -19,8 +19,6 @@ from typing import Callable, Mapping, Tuple -from acme.agents.jax.impala import types -from acme.jax import utils import haiku as hk import jax import jax.numpy as jnp @@ -29,6 +27,9 @@ import rlax import tree +from acme.agents.jax.impala import types +from acme.jax import utils + def impala_loss( unroll_fn: types.PolicyValueFn, @@ -38,7 +39,7 @@ def impala_loss( baseline_cost: float = 1.0, entropy_cost: float = 0.0, ) -> Callable[[hk.Params, reverb.ReplaySample], jax.Array]: - """Builds the standard entropy-regularised IMPALA loss function. + """Builds the standard entropy-regularised IMPALA loss function. Args: unroll_fn: A `hk.Transformed` object containing a callable which maps @@ -52,63 +53,71 @@ def impala_loss( A loss function with signature (params, data) -> (loss_scalar, metrics). """ - def loss_fn( - params: hk.Params, - sample: reverb.ReplaySample, - ) -> Tuple[jax.Array, Mapping[str, jax.Array]]: - """Batched, entropy-regularised actor-critic loss with V-trace.""" - - # Extract the data. - data = sample.data - observations, actions, rewards, discounts, extra = (data.observation, - data.action, - data.reward, - data.discount, - data.extras) - initial_state = tree.map_structure(lambda s: s[0], extra['core_state']) - behaviour_logits = extra['logits'] - - # Apply reward clipping. - rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward) - - # Unroll current policy over observations. - (logits, values), _ = unroll_fn(params, observations, initial_state) - - # Compute importance sampling weights: current policy / behavior policy. - rhos = rlax.categorical_importance_sampling_ratios(logits[:-1], - behaviour_logits[:-1], - actions[:-1]) - - # Critic loss. 
- vtrace_returns = rlax.vtrace_td_error_and_advantage( - v_tm1=values[:-1], - v_t=values[1:], - r_t=rewards[:-1], - discount_t=discounts[:-1] * discount, - rho_tm1=rhos) - critic_loss = jnp.square(vtrace_returns.errors) - - # Policy gradient loss. - policy_gradient_loss = rlax.policy_gradient_loss( - logits_t=logits[:-1], - a_t=actions[:-1], - adv_t=vtrace_returns.pg_advantage, - w_t=jnp.ones_like(rewards[:-1])) - - # Entropy regulariser. - entropy_loss = rlax.entropy_loss(logits[:-1], jnp.ones_like(rewards[:-1])) - - # Combine weighted sum of actor & critic losses, averaged over the sequence. - mean_loss = jnp.mean(policy_gradient_loss + baseline_cost * critic_loss + - entropy_cost * entropy_loss) # [] - - metrics = { - 'policy_loss': jnp.mean(policy_gradient_loss), - 'critic_loss': jnp.mean(baseline_cost * critic_loss), - 'entropy_loss': jnp.mean(entropy_cost * entropy_loss), - 'entropy': jnp.mean(entropy_loss), - } - - return mean_loss, metrics - - return utils.mapreduce(loss_fn, in_axes=(None, 0)) # pytype: disable=bad-return-type # jax-devicearray + def loss_fn( + params: hk.Params, sample: reverb.ReplaySample, + ) -> Tuple[jax.Array, Mapping[str, jax.Array]]: + """Batched, entropy-regularised actor-critic loss with V-trace.""" + + # Extract the data. + data = sample.data + observations, actions, rewards, discounts, extra = ( + data.observation, + data.action, + data.reward, + data.discount, + data.extras, + ) + initial_state = tree.map_structure(lambda s: s[0], extra["core_state"]) + behaviour_logits = extra["logits"] + + # Apply reward clipping. + rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward) + + # Unroll current policy over observations. + (logits, values), _ = unroll_fn(params, observations, initial_state) + + # Compute importance sampling weights: current policy / behavior policy. + rhos = rlax.categorical_importance_sampling_ratios( + logits[:-1], behaviour_logits[:-1], actions[:-1] + ) + + # Critic loss. + vtrace_returns = rlax.vtrace_td_error_and_advantage( + v_tm1=values[:-1], + v_t=values[1:], + r_t=rewards[:-1], + discount_t=discounts[:-1] * discount, + rho_tm1=rhos, + ) + critic_loss = jnp.square(vtrace_returns.errors) + + # Policy gradient loss. + policy_gradient_loss = rlax.policy_gradient_loss( + logits_t=logits[:-1], + a_t=actions[:-1], + adv_t=vtrace_returns.pg_advantage, + w_t=jnp.ones_like(rewards[:-1]), + ) + + # Entropy regulariser. + entropy_loss = rlax.entropy_loss(logits[:-1], jnp.ones_like(rewards[:-1])) + + # Combine weighted sum of actor & critic losses, averaged over the sequence. 
+ mean_loss = jnp.mean( + policy_gradient_loss + + baseline_cost * critic_loss + + entropy_cost * entropy_loss + ) # [] + + metrics = { + "policy_loss": jnp.mean(policy_gradient_loss), + "critic_loss": jnp.mean(baseline_cost * critic_loss), + "entropy_loss": jnp.mean(entropy_cost * entropy_loss), + "entropy": jnp.mean(entropy_loss), + } + + return mean_loss, metrics + + return utils.mapreduce( + loss_fn, in_axes=(None, 0) + ) # pytype: disable=bad-return-type # jax-devicearray diff --git a/acme/jax/losses/impala_test.py b/acme/jax/losses/impala_test.py index b8f2739680..59861013c4 100644 --- a/acme/jax/losses/impala_test.py +++ b/acme/jax/losses/impala_test.py @@ -14,88 +14,87 @@ """Tests for the IMPALA loss function.""" -from acme.adders import reverb as adders -from acme.jax.losses import impala -from acme.utils.tree_utils import tree_map import haiku as hk import jax import jax.numpy as jnp import numpy as np import reverb - from absl.testing import absltest +from acme.adders import reverb as adders +from acme.jax.losses import impala +from acme.utils.tree_utils import tree_map -class ImpalaTest(absltest.TestCase): - def test_shapes(self): - - # - batch_size = 2 - sequence_len = 3 - num_actions = 5 - hidden_size = 7 - - # Define a trivial recurrent actor-critic network. - @hk.without_apply_rng - @hk.transform - def unroll_fn_transformed(observations, state): - lstm = hk.LSTM(hidden_size) - embedding, state = hk.dynamic_unroll(lstm, observations, state) - logits = hk.Linear(num_actions)(embedding) - values = jnp.squeeze(hk.Linear(1)(embedding), axis=-1) - - return (logits, values), state - - @hk.without_apply_rng - @hk.transform - def initial_state_fn(): - return hk.LSTM(hidden_size).initial_state(None) - - # Initial recurrent network state. - initial_state = initial_state_fn.apply(None) - - # Make some fake data. - observations = np.ones(shape=(sequence_len, 50)) - actions = np.random.randint(num_actions, size=sequence_len) - rewards = np.random.rand(sequence_len) - discounts = np.ones(shape=(sequence_len,)) - - batch_tile = tree_map(lambda x: np.tile(x, [batch_size, *([1] * x.ndim)])) - seq_tile = tree_map(lambda x: np.tile(x, [sequence_len, *([1] * x.ndim)])) - - extras = { - 'logits': np.random.rand(sequence_len, num_actions), - 'core_state': seq_tile(initial_state), - } - - # Package up the data into a ReverbSample. - data = adders.Step( - observations, - actions, - rewards, - discounts, - extras=extras, - start_of_episode=()) - data = batch_tile(data) - sample = reverb.ReplaySample(info=None, data=data) - - # Initialise parameters. - rng = hk.PRNGSequence(1) - params = unroll_fn_transformed.init(next(rng), observations, initial_state) - - # Make loss function. - loss_fn = impala.impala_loss( - unroll_fn_transformed.apply, discount=0.99) - - # Return value should be scalar. - loss, metrics = loss_fn(params, sample) - loss = jax.device_get(loss) - self.assertEqual(loss.shape, ()) - for value in metrics.values(): - value = jax.device_get(value) - self.assertEqual(value.shape, ()) - - -if __name__ == '__main__': - absltest.main() +class ImpalaTest(absltest.TestCase): + def test_shapes(self): + + # + batch_size = 2 + sequence_len = 3 + num_actions = 5 + hidden_size = 7 + + # Define a trivial recurrent actor-critic network. 
+ @hk.without_apply_rng + @hk.transform + def unroll_fn_transformed(observations, state): + lstm = hk.LSTM(hidden_size) + embedding, state = hk.dynamic_unroll(lstm, observations, state) + logits = hk.Linear(num_actions)(embedding) + values = jnp.squeeze(hk.Linear(1)(embedding), axis=-1) + + return (logits, values), state + + @hk.without_apply_rng + @hk.transform + def initial_state_fn(): + return hk.LSTM(hidden_size).initial_state(None) + + # Initial recurrent network state. + initial_state = initial_state_fn.apply(None) + + # Make some fake data. + observations = np.ones(shape=(sequence_len, 50)) + actions = np.random.randint(num_actions, size=sequence_len) + rewards = np.random.rand(sequence_len) + discounts = np.ones(shape=(sequence_len,)) + + batch_tile = tree_map(lambda x: np.tile(x, [batch_size, *([1] * x.ndim)])) + seq_tile = tree_map(lambda x: np.tile(x, [sequence_len, *([1] * x.ndim)])) + + extras = { + "logits": np.random.rand(sequence_len, num_actions), + "core_state": seq_tile(initial_state), + } + + # Package up the data into a ReverbSample. + data = adders.Step( + observations, + actions, + rewards, + discounts, + extras=extras, + start_of_episode=(), + ) + data = batch_tile(data) + sample = reverb.ReplaySample(info=None, data=data) + + # Initialise parameters. + rng = hk.PRNGSequence(1) + params = unroll_fn_transformed.init(next(rng), observations, initial_state) + + # Make loss function. + loss_fn = impala.impala_loss(unroll_fn_transformed.apply, discount=0.99) + + # Return value should be scalar. + loss, metrics = loss_fn(params, sample) + loss = jax.device_get(loss) + self.assertEqual(loss.shape, ()) + for value in metrics.values(): + value = jax.device_get(value) + self.assertEqual(value.shape, ()) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/jax/losses/mpo.py b/acme/jax/losses/mpo.py index 6902c00cf2..2d39abcb3f 100644 --- a/acme/jax/losses/mpo.py +++ b/acme/jax/losses/mpo.py @@ -41,39 +41,41 @@ class MPOParams(NamedTuple): - """NamedTuple to store trainable loss parameters.""" - log_temperature: jnp.ndarray - log_alpha_mean: jnp.ndarray - log_alpha_stddev: jnp.ndarray - log_penalty_temperature: Optional[jnp.ndarray] = None + """NamedTuple to store trainable loss parameters.""" + + log_temperature: jnp.ndarray + log_alpha_mean: jnp.ndarray + log_alpha_stddev: jnp.ndarray + log_penalty_temperature: Optional[jnp.ndarray] = None class MPOStats(NamedTuple): - """NamedTuple to store loss statistics.""" - dual_alpha_mean: float - dual_alpha_stddev: float - dual_temperature: float + """NamedTuple to store loss statistics.""" + + dual_alpha_mean: float + dual_alpha_stddev: float + dual_temperature: float - loss_policy: float - loss_alpha: float - loss_temperature: float - kl_q_rel: float + loss_policy: float + loss_alpha: float + loss_temperature: float + kl_q_rel: float - kl_mean_rel: float - kl_stddev_rel: float + kl_mean_rel: float + kl_stddev_rel: float - q_min: float - q_max: float + q_min: float + q_max: float - pi_stddev_min: float - pi_stddev_max: float - pi_stddev_cond: float + pi_stddev_min: float + pi_stddev_max: float + pi_stddev_cond: float - penalty_kl_q_rel: Optional[float] = None + penalty_kl_q_rel: Optional[float] = None class MPO: - """MPO loss with decoupled KL constraints as in (Abdolmaleki et al., 2018). + """MPO loss with decoupled KL constraints as in (Abdolmaleki et al., 2018). 
This implementation of the MPO loss includes the following features, as options: @@ -86,17 +88,19 @@ class MPO: (Abdolmaleki et al., 2020): https://arxiv.org/pdf/2005.07513.pdf """ - def __init__(self, - epsilon: float, - epsilon_mean: float, - epsilon_stddev: float, - init_log_temperature: float, - init_log_alpha_mean: float, - init_log_alpha_stddev: float, - per_dim_constraining: bool = True, - action_penalization: bool = True, - epsilon_penalty: float = 0.001): - """Initialize and configure the MPO loss. + def __init__( + self, + epsilon: float, + epsilon_mean: float, + epsilon_stddev: float, + init_log_temperature: float, + init_log_alpha_mean: float, + init_log_alpha_stddev: float, + per_dim_constraining: bool = True, + action_penalization: bool = True, + epsilon_penalty: float = 0.001, + ): + """Initialize and configure the MPO loss. Args: epsilon: KL constraint on the non-parametric auxiliary policy, the one @@ -121,68 +125,69 @@ def __init__(self, constraint. """ - # MPO constrain thresholds. - self._epsilon = epsilon - self._epsilon_mean = epsilon_mean - self._epsilon_stddev = epsilon_stddev - - # Initial values for the constraints' dual variables. - self._init_log_temperature = init_log_temperature - self._init_log_alpha_mean = init_log_alpha_mean - self._init_log_alpha_stddev = init_log_alpha_stddev - - # Whether to penalize out-of-bound actions via MO-MPO and its corresponding - # constraint threshold. - self._action_penalization = action_penalization - self._epsilon_penalty = epsilon_penalty - - # Whether to ensure per-dimension KL constraint satisfication. - self._per_dim_constraining = per_dim_constraining - - @property - def per_dim_constraining(self): - return self._per_dim_constraining - - def init_params(self, action_dim: int, dtype: DType = jnp.float32): - """Creates an initial set of parameters.""" - - if self._per_dim_constraining: - dual_variable_shape = [action_dim] - else: - dual_variable_shape = [1] - - log_temperature = jnp.full([1], self._init_log_temperature, dtype=dtype) - - log_alpha_mean = jnp.full( - dual_variable_shape, self._init_log_alpha_mean, dtype=dtype) - - log_alpha_stddev = jnp.full( - dual_variable_shape, self._init_log_alpha_stddev, dtype=dtype) - - if self._action_penalization: - log_penalty_temperature = jnp.full([1], - self._init_log_temperature, - dtype=dtype) - else: - log_penalty_temperature = None - - return MPOParams( - log_temperature=log_temperature, - log_alpha_mean=log_alpha_mean, - log_alpha_stddev=log_alpha_stddev, - log_penalty_temperature=log_penalty_temperature) - - def __call__( - self, - params: MPOParams, - online_action_distribution: Union[tfd.MultivariateNormalDiag, - tfd.Independent], - target_action_distribution: Union[tfd.MultivariateNormalDiag, - tfd.Independent], - actions: jnp.ndarray, # Shape [N, B, D]. - q_values: jnp.ndarray, # Shape [N, B]. - ) -> Tuple[jnp.ndarray, MPOStats]: - """Computes the decoupled MPO loss. + # MPO constrain thresholds. + self._epsilon = epsilon + self._epsilon_mean = epsilon_mean + self._epsilon_stddev = epsilon_stddev + + # Initial values for the constraints' dual variables. + self._init_log_temperature = init_log_temperature + self._init_log_alpha_mean = init_log_alpha_mean + self._init_log_alpha_stddev = init_log_alpha_stddev + + # Whether to penalize out-of-bound actions via MO-MPO and its corresponding + # constraint threshold. + self._action_penalization = action_penalization + self._epsilon_penalty = epsilon_penalty + + # Whether to ensure per-dimension KL constraint satisfication. 
+ self._per_dim_constraining = per_dim_constraining + + @property + def per_dim_constraining(self): + return self._per_dim_constraining + + def init_params(self, action_dim: int, dtype: DType = jnp.float32): + """Creates an initial set of parameters.""" + + if self._per_dim_constraining: + dual_variable_shape = [action_dim] + else: + dual_variable_shape = [1] + + log_temperature = jnp.full([1], self._init_log_temperature, dtype=dtype) + + log_alpha_mean = jnp.full( + dual_variable_shape, self._init_log_alpha_mean, dtype=dtype + ) + + log_alpha_stddev = jnp.full( + dual_variable_shape, self._init_log_alpha_stddev, dtype=dtype + ) + + if self._action_penalization: + log_penalty_temperature = jnp.full( + [1], self._init_log_temperature, dtype=dtype + ) + else: + log_penalty_temperature = None + + return MPOParams( + log_temperature=log_temperature, + log_alpha_mean=log_alpha_mean, + log_alpha_stddev=log_alpha_stddev, + log_penalty_temperature=log_penalty_temperature, + ) + + def __call__( + self, + params: MPOParams, + online_action_distribution: Union[tfd.MultivariateNormalDiag, tfd.Independent], + target_action_distribution: Union[tfd.MultivariateNormalDiag, tfd.Independent], + actions: jnp.ndarray, # Shape [N, B, D]. + q_values: jnp.ndarray, # Shape [N, B]. + ) -> Tuple[jnp.ndarray, MPOStats]: + """Computes the decoupled MPO loss. Args: params: parameters tracking the temperature and the dual variables. @@ -199,137 +204,162 @@ def __call__( Stats, for diagnostics and tracking performance. """ - # Cast `MultivariateNormalDiag`s to Independent Normals. - # The latter allows us to satisfy KL constraints per-dimension. - if isinstance(target_action_distribution, tfd.MultivariateNormalDiag): - target_action_distribution = tfd.Independent( - tfd.Normal(target_action_distribution.mean(), - target_action_distribution.stddev())) - online_action_distribution = tfd.Independent( - tfd.Normal(online_action_distribution.mean(), - online_action_distribution.stddev())) - - # Transform dual variables from log-space. - # Note: using softplus instead of exponential for numerical stability. - temperature = jax.nn.softplus(params.log_temperature) + _MPO_FLOAT_EPSILON - alpha_mean = jax.nn.softplus(params.log_alpha_mean) + _MPO_FLOAT_EPSILON - alpha_stddev = jax.nn.softplus(params.log_alpha_stddev) + _MPO_FLOAT_EPSILON - - # Get online and target means and stddevs in preparation for decomposition. - online_mean = online_action_distribution.distribution.mean() - online_scale = online_action_distribution.distribution.stddev() - target_mean = target_action_distribution.distribution.mean() - target_scale = target_action_distribution.distribution.stddev() - - # Compute normalized importance weights, used to compute expectations with - # respect to the non-parametric policy; and the temperature loss, used to - # adapt the tempering of Q-values. - normalized_weights, loss_temperature = compute_weights_and_temperature_loss( - q_values, self._epsilon, temperature) - - # Only needed for diagnostics: Compute estimated actualized KL between the - # non-parametric and current target policies. - kl_nonparametric = compute_nonparametric_kl_from_normalized_weights( - normalized_weights) - - if self._action_penalization: - # Transform action penalization temperature. - penalty_temperature = jax.nn.softplus( - params.log_penalty_temperature) + _MPO_FLOAT_EPSILON - - # Compute action penalization cost. - # Note: the cost is zero in [-1, 1] and quadratic beyond. 
- diff_out_of_bound = actions - jnp.clip(actions, -1.0, 1.0) - cost_out_of_bound = -jnp.linalg.norm(diff_out_of_bound, axis=-1) - - penalty_normalized_weights, loss_penalty_temperature = compute_weights_and_temperature_loss( - cost_out_of_bound, self._epsilon_penalty, penalty_temperature) - - # Only needed for diagnostics: Compute estimated actualized KL between the - # non-parametric and current target policies. - penalty_kl_nonparametric = compute_nonparametric_kl_from_normalized_weights( - penalty_normalized_weights) - - # Combine normalized weights. - normalized_weights += penalty_normalized_weights - loss_temperature += loss_penalty_temperature - - # Decompose the online policy into fixed-mean & fixed-stddev distributions. - # This has been documented as having better performance in bandit settings, - # see e.g. https://arxiv.org/pdf/1812.02256.pdf. - fixed_stddev_distribution = tfd.Independent( - tfd.Normal(loc=online_mean, scale=target_scale)) - fixed_mean_distribution = tfd.Independent( - tfd.Normal(loc=target_mean, scale=online_scale)) - - # Compute the decomposed policy losses. - loss_policy_mean = compute_cross_entropy_loss(actions, normalized_weights, - fixed_stddev_distribution) - loss_policy_stddev = compute_cross_entropy_loss(actions, normalized_weights, - fixed_mean_distribution) - - # Compute the decomposed KL between the target and online policies. - if self._per_dim_constraining: - kl_mean = target_action_distribution.distribution.kl_divergence( - fixed_stddev_distribution.distribution) # Shape [B, D]. - kl_stddev = target_action_distribution.distribution.kl_divergence( - fixed_mean_distribution.distribution) # Shape [B, D]. - else: - kl_mean = target_action_distribution.kl_divergence( - fixed_stddev_distribution) # Shape [B]. - kl_stddev = target_action_distribution.kl_divergence( - fixed_mean_distribution) # Shape [B]. - - # Compute the alpha-weighted KL-penalty and dual losses to adapt the alphas. - loss_kl_mean, loss_alpha_mean = compute_parametric_kl_penalty_and_dual_loss( - kl_mean, alpha_mean, self._epsilon_mean) - loss_kl_stddev, loss_alpha_stddev = compute_parametric_kl_penalty_and_dual_loss( - kl_stddev, alpha_stddev, self._epsilon_stddev) - - # Combine losses. - loss_policy = loss_policy_mean + loss_policy_stddev - loss_kl_penalty = loss_kl_mean + loss_kl_stddev - loss_dual = loss_alpha_mean + loss_alpha_stddev + loss_temperature - loss = loss_policy + loss_kl_penalty + loss_dual - - # Create statistics. - pi_stddev = online_action_distribution.distribution.stddev() - stats = MPOStats( - # Dual Variables. - dual_alpha_mean=jnp.mean(alpha_mean), - dual_alpha_stddev=jnp.mean(alpha_stddev), - dual_temperature=jnp.mean(temperature), - # Losses. - loss_policy=jnp.mean(loss), - loss_alpha=jnp.mean(loss_alpha_mean + loss_alpha_stddev), - loss_temperature=jnp.mean(loss_temperature), - # KL measurements. - kl_q_rel=jnp.mean(kl_nonparametric) / self._epsilon, - penalty_kl_q_rel=((jnp.mean(penalty_kl_nonparametric) / - self._epsilon_penalty) - if self._action_penalization else None), - kl_mean_rel=jnp.mean(kl_mean, axis=0) / self._epsilon_mean, - kl_stddev_rel=jnp.mean(kl_stddev, axis=0) / self._epsilon_stddev, - # Q measurements. - q_min=jnp.mean(jnp.min(q_values, axis=0)), - q_max=jnp.mean(jnp.max(q_values, axis=0)), - # If the policy has stddev, log summary stats for this as well. - pi_stddev_min=jnp.mean(jnp.min(pi_stddev, axis=-1)), - pi_stddev_max=jnp.mean(jnp.max(pi_stddev, axis=-1)), - # Condition number of the diagonal covariance (actually, stddev) matrix. 
- pi_stddev_cond=jnp.mean( - jnp.max(pi_stddev, axis=-1) / jnp.min(pi_stddev, axis=-1)), - ) - - return loss, stats + # Cast `MultivariateNormalDiag`s to Independent Normals. + # The latter allows us to satisfy KL constraints per-dimension. + if isinstance(target_action_distribution, tfd.MultivariateNormalDiag): + target_action_distribution = tfd.Independent( + tfd.Normal( + target_action_distribution.mean(), + target_action_distribution.stddev(), + ) + ) + online_action_distribution = tfd.Independent( + tfd.Normal( + online_action_distribution.mean(), + online_action_distribution.stddev(), + ) + ) + + # Transform dual variables from log-space. + # Note: using softplus instead of exponential for numerical stability. + temperature = jax.nn.softplus(params.log_temperature) + _MPO_FLOAT_EPSILON + alpha_mean = jax.nn.softplus(params.log_alpha_mean) + _MPO_FLOAT_EPSILON + alpha_stddev = jax.nn.softplus(params.log_alpha_stddev) + _MPO_FLOAT_EPSILON + + # Get online and target means and stddevs in preparation for decomposition. + online_mean = online_action_distribution.distribution.mean() + online_scale = online_action_distribution.distribution.stddev() + target_mean = target_action_distribution.distribution.mean() + target_scale = target_action_distribution.distribution.stddev() + + # Compute normalized importance weights, used to compute expectations with + # respect to the non-parametric policy; and the temperature loss, used to + # adapt the tempering of Q-values. + normalized_weights, loss_temperature = compute_weights_and_temperature_loss( + q_values, self._epsilon, temperature + ) + + # Only needed for diagnostics: Compute estimated actualized KL between the + # non-parametric and current target policies. + kl_nonparametric = compute_nonparametric_kl_from_normalized_weights( + normalized_weights + ) + + if self._action_penalization: + # Transform action penalization temperature. + penalty_temperature = ( + jax.nn.softplus(params.log_penalty_temperature) + _MPO_FLOAT_EPSILON + ) + + # Compute action penalization cost. + # Note: the cost is zero in [-1, 1] and quadratic beyond. + diff_out_of_bound = actions - jnp.clip(actions, -1.0, 1.0) + cost_out_of_bound = -jnp.linalg.norm(diff_out_of_bound, axis=-1) + + ( + penalty_normalized_weights, + loss_penalty_temperature, + ) = compute_weights_and_temperature_loss( + cost_out_of_bound, self._epsilon_penalty, penalty_temperature + ) + + # Only needed for diagnostics: Compute estimated actualized KL between the + # non-parametric and current target policies. + penalty_kl_nonparametric = compute_nonparametric_kl_from_normalized_weights( + penalty_normalized_weights + ) + + # Combine normalized weights. + normalized_weights += penalty_normalized_weights + loss_temperature += loss_penalty_temperature + + # Decompose the online policy into fixed-mean & fixed-stddev distributions. + # This has been documented as having better performance in bandit settings, + # see e.g. https://arxiv.org/pdf/1812.02256.pdf. + fixed_stddev_distribution = tfd.Independent( + tfd.Normal(loc=online_mean, scale=target_scale) + ) + fixed_mean_distribution = tfd.Independent( + tfd.Normal(loc=target_mean, scale=online_scale) + ) + + # Compute the decomposed policy losses. + loss_policy_mean = compute_cross_entropy_loss( + actions, normalized_weights, fixed_stddev_distribution + ) + loss_policy_stddev = compute_cross_entropy_loss( + actions, normalized_weights, fixed_mean_distribution + ) + + # Compute the decomposed KL between the target and online policies. 
+ if self._per_dim_constraining: + kl_mean = target_action_distribution.distribution.kl_divergence( + fixed_stddev_distribution.distribution + ) # Shape [B, D]. + kl_stddev = target_action_distribution.distribution.kl_divergence( + fixed_mean_distribution.distribution + ) # Shape [B, D]. + else: + kl_mean = target_action_distribution.kl_divergence( + fixed_stddev_distribution + ) # Shape [B]. + kl_stddev = target_action_distribution.kl_divergence( + fixed_mean_distribution + ) # Shape [B]. + + # Compute the alpha-weighted KL-penalty and dual losses to adapt the alphas. + loss_kl_mean, loss_alpha_mean = compute_parametric_kl_penalty_and_dual_loss( + kl_mean, alpha_mean, self._epsilon_mean + ) + loss_kl_stddev, loss_alpha_stddev = compute_parametric_kl_penalty_and_dual_loss( + kl_stddev, alpha_stddev, self._epsilon_stddev + ) + + # Combine losses. + loss_policy = loss_policy_mean + loss_policy_stddev + loss_kl_penalty = loss_kl_mean + loss_kl_stddev + loss_dual = loss_alpha_mean + loss_alpha_stddev + loss_temperature + loss = loss_policy + loss_kl_penalty + loss_dual + + # Create statistics. + pi_stddev = online_action_distribution.distribution.stddev() + stats = MPOStats( + # Dual Variables. + dual_alpha_mean=jnp.mean(alpha_mean), + dual_alpha_stddev=jnp.mean(alpha_stddev), + dual_temperature=jnp.mean(temperature), + # Losses. + loss_policy=jnp.mean(loss), + loss_alpha=jnp.mean(loss_alpha_mean + loss_alpha_stddev), + loss_temperature=jnp.mean(loss_temperature), + # KL measurements. + kl_q_rel=jnp.mean(kl_nonparametric) / self._epsilon, + penalty_kl_q_rel=( + (jnp.mean(penalty_kl_nonparametric) / self._epsilon_penalty) + if self._action_penalization + else None + ), + kl_mean_rel=jnp.mean(kl_mean, axis=0) / self._epsilon_mean, + kl_stddev_rel=jnp.mean(kl_stddev, axis=0) / self._epsilon_stddev, + # Q measurements. + q_min=jnp.mean(jnp.min(q_values, axis=0)), + q_max=jnp.mean(jnp.max(q_values, axis=0)), + # If the policy has stddev, log summary stats for this as well. + pi_stddev_min=jnp.mean(jnp.min(pi_stddev, axis=-1)), + pi_stddev_max=jnp.mean(jnp.max(pi_stddev, axis=-1)), + # Condition number of the diagonal covariance (actually, stddev) matrix. + pi_stddev_cond=jnp.mean( + jnp.max(pi_stddev, axis=-1) / jnp.min(pi_stddev, axis=-1) + ), + ) + + return loss, stats def compute_weights_and_temperature_loss( - q_values: jnp.ndarray, - epsilon: float, - temperature: jnp.ndarray, + q_values: jnp.ndarray, epsilon: float, temperature: jnp.ndarray, ) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Computes normalized importance weights for the policy optimization. + """Computes normalized importance weights for the policy optimization. Args: q_values: Q-values associated with the actions sampled from the target @@ -346,33 +376,34 @@ def compute_weights_and_temperature_loss( Temperature loss, used to adapt the temperature. """ - # Temper the given Q-values using the current temperature. - tempered_q_values = jax.lax.stop_gradient(q_values) / temperature + # Temper the given Q-values using the current temperature. + tempered_q_values = jax.lax.stop_gradient(q_values) / temperature - # Compute the normalized importance weights used to compute expectations with - # respect to the non-parametric policy. - normalized_weights = jax.nn.softmax(tempered_q_values, axis=0) - normalized_weights = jax.lax.stop_gradient(normalized_weights) + # Compute the normalized importance weights used to compute expectations with + # respect to the non-parametric policy. 
+ normalized_weights = jax.nn.softmax(tempered_q_values, axis=0) + normalized_weights = jax.lax.stop_gradient(normalized_weights) - # Compute the temperature loss (dual of the E-step optimization problem). - q_logsumexp = jax.scipy.special.logsumexp(tempered_q_values, axis=0) - log_num_actions = jnp.log(q_values.shape[0] / 1.) - loss_temperature = epsilon + jnp.mean(q_logsumexp) - log_num_actions - loss_temperature = temperature * loss_temperature + # Compute the temperature loss (dual of the E-step optimization problem). + q_logsumexp = jax.scipy.special.logsumexp(tempered_q_values, axis=0) + log_num_actions = jnp.log(q_values.shape[0] / 1.0) + loss_temperature = epsilon + jnp.mean(q_logsumexp) - log_num_actions + loss_temperature = temperature * loss_temperature - return normalized_weights, loss_temperature + return normalized_weights, loss_temperature def compute_nonparametric_kl_from_normalized_weights( - normalized_weights: jnp.ndarray) -> jnp.ndarray: - """Estimate the actualized KL between the non-parametric and target policies.""" + normalized_weights: jnp.ndarray, +) -> jnp.ndarray: + """Estimate the actualized KL between the non-parametric and target policies.""" - # Compute integrand. - num_action_samples = normalized_weights.shape[0] / 1. - integrand = jnp.log(num_action_samples * normalized_weights + 1e-8) + # Compute integrand. + num_action_samples = normalized_weights.shape[0] / 1.0 + integrand = jnp.log(num_action_samples * normalized_weights + 1e-8) - # Return the expectation with respect to the non-parametric policy. - return jnp.sum(normalized_weights * integrand, axis=0) + # Return the expectation with respect to the non-parametric policy. + return jnp.sum(normalized_weights * integrand, axis=0) def compute_cross_entropy_loss( @@ -380,7 +411,7 @@ def compute_cross_entropy_loss( normalized_weights: jnp.ndarray, online_action_distribution: tfd.Distribution, ) -> jnp.ndarray: - """Compute cross-entropy online and the reweighted target policy. + """Compute cross-entropy online and the reweighted target policy. Args: sampled_actions: samples used in the Monte Carlo integration in the policy @@ -395,22 +426,20 @@ def compute_cross_entropy_loss( produces the policy gradient. """ - # Compute the M-step loss. - log_prob = online_action_distribution.log_prob(sampled_actions) + # Compute the M-step loss. + log_prob = online_action_distribution.log_prob(sampled_actions) - # Compute the weighted average log-prob using the normalized weights. - loss_policy_gradient = -jnp.sum(log_prob * normalized_weights, axis=0) + # Compute the weighted average log-prob using the normalized weights. + loss_policy_gradient = -jnp.sum(log_prob * normalized_weights, axis=0) - # Return the mean loss over the batch of states. - return jnp.mean(loss_policy_gradient, axis=0) + # Return the mean loss over the batch of states. + return jnp.mean(loss_policy_gradient, axis=0) def compute_parametric_kl_penalty_and_dual_loss( - kl: jnp.ndarray, - alpha: jnp.ndarray, - epsilon: float, + kl: jnp.ndarray, alpha: jnp.ndarray, epsilon: float, ) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Computes the KL cost to be added to the Lagragian and its dual loss. + """Computes the KL cost to be added to the Lagragian and its dual loss. The KL cost is simply the alpha-weighted KL divergence and it is added as a regularizer to the policy loss. The dual variable alpha itself has a loss that @@ -427,26 +456,29 @@ def compute_parametric_kl_penalty_and_dual_loss( loss_alpha: The Lagrange dual loss minimized to adapt alpha. 
""" - # Compute the mean KL over the batch. - mean_kl = jnp.mean(kl, axis=0) + # Compute the mean KL over the batch. + mean_kl = jnp.mean(kl, axis=0) - # Compute the regularization. - loss_kl = jnp.sum(jax.lax.stop_gradient(alpha) * mean_kl) + # Compute the regularization. + loss_kl = jnp.sum(jax.lax.stop_gradient(alpha) * mean_kl) - # Compute the dual loss. - loss_alpha = jnp.sum(alpha * (epsilon - jax.lax.stop_gradient(mean_kl))) + # Compute the dual loss. + loss_alpha = jnp.sum(alpha * (epsilon - jax.lax.stop_gradient(mean_kl))) - return loss_kl, loss_alpha + return loss_kl, loss_alpha def clip_mpo_params(params: MPOParams, per_dim_constraining: bool) -> MPOParams: - clipped_params = MPOParams( - log_temperature=jnp.maximum(_MIN_LOG_TEMPERATURE, params.log_temperature), - log_alpha_mean=jnp.maximum(_MIN_LOG_ALPHA, params.log_alpha_mean), - log_alpha_stddev=jnp.maximum(_MIN_LOG_ALPHA, params.log_alpha_stddev)) - if not per_dim_constraining: - return clipped_params - else: - return clipped_params._replace( - log_penalty_temperature=jnp.maximum(_MIN_LOG_TEMPERATURE, - params.log_penalty_temperature)) + clipped_params = MPOParams( + log_temperature=jnp.maximum(_MIN_LOG_TEMPERATURE, params.log_temperature), + log_alpha_mean=jnp.maximum(_MIN_LOG_ALPHA, params.log_alpha_mean), + log_alpha_stddev=jnp.maximum(_MIN_LOG_ALPHA, params.log_alpha_stddev), + ) + if not per_dim_constraining: + return clipped_params + else: + return clipped_params._replace( + log_penalty_temperature=jnp.maximum( + _MIN_LOG_TEMPERATURE, params.log_penalty_temperature + ) + ) diff --git a/acme/jax/networks/__init__.py b/acme/jax/networks/__init__.py index ffc3d41296..51a8c4fee2 100644 --- a/acme/jax/networks/__init__.py +++ b/acme/jax/networks/__init__.py @@ -14,47 +14,49 @@ """JAX networks implemented with Haiku.""" -from acme.jax.networks.atari import AtariTorso -from acme.jax.networks.atari import DeepIMPALAAtariNetwork -from acme.jax.networks.atari import dqn_atari_network -from acme.jax.networks.atari import R2D2AtariNetwork -from acme.jax.networks.base import Action -from acme.jax.networks.base import Entropy -from acme.jax.networks.base import FeedForwardNetwork -from acme.jax.networks.base import Logits -from acme.jax.networks.base import LogProb -from acme.jax.networks.base import LogProbFn -from acme.jax.networks.base import LSTMOutputs -from acme.jax.networks.base import make_unrollable_network -from acme.jax.networks.base import NetworkOutput -from acme.jax.networks.base import non_stochastic_network_to_typed -from acme.jax.networks.base import Observation -from acme.jax.networks.base import Params -from acme.jax.networks.base import PolicyValueRNN -from acme.jax.networks.base import PRNGKey -from acme.jax.networks.base import QNetwork -from acme.jax.networks.base import QValues -from acme.jax.networks.base import RecurrentQNetwork -from acme.jax.networks.base import RecurrentState -from acme.jax.networks.base import SampleFn -from acme.jax.networks.base import TypedFeedForwardNetwork -from acme.jax.networks.base import UnrollableNetwork -from acme.jax.networks.base import Value -from acme.jax.networks.continuous import LayerNormMLP -from acme.jax.networks.continuous import NearZeroInitializedLinear -from acme.jax.networks.distributional import CategoricalCriticHead -from acme.jax.networks.distributional import CategoricalHead -from acme.jax.networks.distributional import CategoricalValueHead -from acme.jax.networks.distributional import DiscreteValued -from acme.jax.networks.distributional import 
GaussianMixture -from acme.jax.networks.distributional import MultivariateNormalDiagHead -from acme.jax.networks.distributional import NormalTanhDistribution -from acme.jax.networks.distributional import TanhTransformedDistribution +from acme.jax.networks.atari import ( + AtariTorso, + DeepIMPALAAtariNetwork, + R2D2AtariNetwork, + dqn_atari_network, +) +from acme.jax.networks.base import ( + Action, + Entropy, + FeedForwardNetwork, + Logits, + LogProb, + LogProbFn, + LSTMOutputs, + NetworkOutput, + Observation, + Params, + PolicyValueRNN, + PRNGKey, + QNetwork, + QValues, + RecurrentQNetwork, + RecurrentState, + SampleFn, + TypedFeedForwardNetwork, + UnrollableNetwork, + Value, + make_unrollable_network, + non_stochastic_network_to_typed, +) +from acme.jax.networks.continuous import LayerNormMLP, NearZeroInitializedLinear +from acme.jax.networks.distributional import ( + CategoricalCriticHead, + CategoricalHead, + CategoricalValueHead, + DiscreteValued, + GaussianMixture, + MultivariateNormalDiagHead, + NormalTanhDistribution, + TanhTransformedDistribution, +) from acme.jax.networks.duelling import DuellingMLP from acme.jax.networks.multiplexers import CriticMultiplexer from acme.jax.networks.policy_value import PolicyValueHead -from acme.jax.networks.rescaling import ClipToSpec -from acme.jax.networks.rescaling import TanhToSpec -from acme.jax.networks.resnet import DownsamplingStrategy -from acme.jax.networks.resnet import ResidualBlock -from acme.jax.networks.resnet import ResNetTorso +from acme.jax.networks.rescaling import ClipToSpec, TanhToSpec +from acme.jax.networks.resnet import DownsamplingStrategy, ResidualBlock, ResNetTorso diff --git a/acme/jax/networks/atari.py b/acme/jax/networks/atari.py index 37ca991879..0d2fcab93c 100644 --- a/acme/jax/networks/atari.py +++ b/acme/jax/networks/atari.py @@ -22,162 +22,169 @@ - X?: X is optional (e.g. optional batch/sequence dimension). """ -from typing import Optional, Tuple, Sequence +from typing import Optional, Sequence, Tuple -from acme.jax.networks import base -from acme.jax.networks import duelling -from acme.jax.networks import embedding -from acme.jax.networks import policy_value -from acme.jax.networks import resnet -from acme.wrappers import observation_action_reward import haiku as hk import jax import jax.numpy as jnp +from acme.jax.networks import base, duelling, embedding, policy_value, resnet +from acme.wrappers import observation_action_reward + # Useful type aliases. Images = jnp.ndarray class AtariTorso(hk.Module): - """Simple convolutional stack commonly used for Atari.""" - - def __init__(self): - super().__init__(name='atari_torso') - self._network = hk.Sequential([ - hk.Conv2D(32, [8, 8], 4), jax.nn.relu, - hk.Conv2D(64, [4, 4], 2), jax.nn.relu, - hk.Conv2D(64, [3, 3], 1), jax.nn.relu - ]) - - def __call__(self, inputs: Images) -> jnp.ndarray: - inputs_rank = jnp.ndim(inputs) - batched_inputs = inputs_rank == 4 - if inputs_rank < 3 or inputs_rank > 4: - raise ValueError('Expected input BHWC or HWC. 
Got rank %d' % inputs_rank) - - outputs = self._network(inputs) - - if batched_inputs: - return jnp.reshape(outputs, [outputs.shape[0], -1]) # [B, D] - return jnp.reshape(outputs, [-1]) # [D] + """Simple convolutional stack commonly used for Atari.""" + + def __init__(self): + super().__init__(name="atari_torso") + self._network = hk.Sequential( + [ + hk.Conv2D(32, [8, 8], 4), + jax.nn.relu, + hk.Conv2D(64, [4, 4], 2), + jax.nn.relu, + hk.Conv2D(64, [3, 3], 1), + jax.nn.relu, + ] + ) + + def __call__(self, inputs: Images) -> jnp.ndarray: + inputs_rank = jnp.ndim(inputs) + batched_inputs = inputs_rank == 4 + if inputs_rank < 3 or inputs_rank > 4: + raise ValueError("Expected input BHWC or HWC. Got rank %d" % inputs_rank) + + outputs = self._network(inputs) + + if batched_inputs: + return jnp.reshape(outputs, [outputs.shape[0], -1]) # [B, D] + return jnp.reshape(outputs, [-1]) # [D] def dqn_atari_network(num_actions: int) -> base.QNetwork: - """A feed-forward network for use with Ape-X DQN.""" + """A feed-forward network for use with Ape-X DQN.""" - def network(inputs: Images) -> base.QValues: - model = hk.Sequential([ - AtariTorso(), - duelling.DuellingMLP(num_actions, hidden_sizes=[512]), - ]) - return model(inputs) + def network(inputs: Images) -> base.QValues: + model = hk.Sequential( + [AtariTorso(), duelling.DuellingMLP(num_actions, hidden_sizes=[512]),] + ) + return model(inputs) - return network + return network class DeepAtariTorso(hk.Module): - """Deep torso for Atari, from the IMPALA paper.""" - - def __init__( - self, - channels_per_group: Sequence[int] = (16, 32, 32), - blocks_per_group: Sequence[int] = (2, 2, 2), - downsampling_strategies: Sequence[resnet.DownsamplingStrategy] = ( - resnet.DownsamplingStrategy.CONV_MAX,) * 3, - hidden_sizes: Sequence[int] = (256,), - use_layer_norm: bool = False, - name: str = 'deep_atari_torso'): - super().__init__(name=name) - self._use_layer_norm = use_layer_norm - self.resnet = resnet.ResNetTorso( - channels_per_group=channels_per_group, - blocks_per_group=blocks_per_group, - downsampling_strategies=downsampling_strategies, - use_layer_norm=use_layer_norm) - # Make sure to activate the last layer as this torso is expected to feed - # into the rest of a bigger network. - self.mlp_head = hk.nets.MLP(output_sizes=hidden_sizes, activate_final=True) - - def __call__(self, x: jnp.ndarray) -> jnp.ndarray: - output = self.resnet(x) - output = jax.nn.relu(output) - output = hk.Flatten(preserve_dims=-3)(output) - output = self.mlp_head(output) - return output + """Deep torso for Atari, from the IMPALA paper.""" + + def __init__( + self, + channels_per_group: Sequence[int] = (16, 32, 32), + blocks_per_group: Sequence[int] = (2, 2, 2), + downsampling_strategies: Sequence[resnet.DownsamplingStrategy] = ( + resnet.DownsamplingStrategy.CONV_MAX, + ) + * 3, + hidden_sizes: Sequence[int] = (256,), + use_layer_norm: bool = False, + name: str = "deep_atari_torso", + ): + super().__init__(name=name) + self._use_layer_norm = use_layer_norm + self.resnet = resnet.ResNetTorso( + channels_per_group=channels_per_group, + blocks_per_group=blocks_per_group, + downsampling_strategies=downsampling_strategies, + use_layer_norm=use_layer_norm, + ) + # Make sure to activate the last layer as this torso is expected to feed + # into the rest of a bigger network. 
+ self.mlp_head = hk.nets.MLP(output_sizes=hidden_sizes, activate_final=True) + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + output = self.resnet(x) + output = jax.nn.relu(output) + output = hk.Flatten(preserve_dims=-3)(output) + output = self.mlp_head(output) + return output class DeepIMPALAAtariNetwork(hk.RNNCore): - """A recurrent network for use with IMPALA. + """A recurrent network for use with IMPALA. See https://arxiv.org/pdf/1802.01561.pdf for more information. """ - def __init__(self, num_actions: int): - super().__init__(name='impala_atari_network') - self._embed = embedding.OAREmbedding( - DeepAtariTorso(use_layer_norm=True), num_actions) - self._core = hk.GRU(256) - self._head = policy_value.PolicyValueHead(num_actions) - self._num_actions = num_actions + def __init__(self, num_actions: int): + super().__init__(name="impala_atari_network") + self._embed = embedding.OAREmbedding( + DeepAtariTorso(use_layer_norm=True), num_actions + ) + self._core = hk.GRU(256) + self._head = policy_value.PolicyValueHead(num_actions) + self._num_actions = num_actions - def __call__(self, inputs: observation_action_reward.OAR, - state: hk.LSTMState) -> base.LSTMOutputs: + def __call__( + self, inputs: observation_action_reward.OAR, state: hk.LSTMState + ) -> base.LSTMOutputs: - embeddings = self._embed(inputs) # [B?, D+A+1] - embeddings, new_state = self._core(embeddings, state) - logits, value = self._head(embeddings) # logits: [B?, A], value: [B?, 1] + embeddings = self._embed(inputs) # [B?, D+A+1] + embeddings, new_state = self._core(embeddings, state) + logits, value = self._head(embeddings) # logits: [B?, A], value: [B?, 1] - return (logits, value), new_state + return (logits, value), new_state - def initial_state(self, batch_size: Optional[int], - **unused_kwargs) -> hk.LSTMState: - return self._core.initial_state(batch_size) + def initial_state(self, batch_size: Optional[int], **unused_kwargs) -> hk.LSTMState: + return self._core.initial_state(batch_size) - def unroll(self, inputs: observation_action_reward.OAR, - state: hk.LSTMState) -> base.LSTMOutputs: - """Efficient unroll that applies embeddings, MLP, & convnet in one pass.""" - embeddings = self._embed(inputs) - embeddings, new_states = hk.static_unroll(self._core, embeddings, state) - logits, values = self._head(embeddings) + def unroll( + self, inputs: observation_action_reward.OAR, state: hk.LSTMState + ) -> base.LSTMOutputs: + """Efficient unroll that applies embeddings, MLP, & convnet in one pass.""" + embeddings = self._embed(inputs) + embeddings, new_states = hk.static_unroll(self._core, embeddings, state) + logits, values = self._head(embeddings) - return (logits, values), new_states + return (logits, values), new_states class R2D2AtariNetwork(hk.RNNCore): - """A duelling recurrent network for use with Atari observations as seen in R2D2. + """A duelling recurrent network for use with Atari observations as seen in R2D2. See https://openreview.net/forum?id=r1lyTjAqYX for more information. """ - def __init__(self, num_actions: int): - super().__init__(name='r2d2_atari_network') - self._embed = embedding.OAREmbedding( - DeepAtariTorso(hidden_sizes=[512], use_layer_norm=True), num_actions) - self._core = hk.LSTM(512) - self._duelling_head = duelling.DuellingMLP(num_actions, hidden_sizes=[512]) - self._num_actions = num_actions - - def __call__( - self, - inputs: observation_action_reward.OAR, # [B, ...] - state: hk.LSTMState # [B, ...] 
- ) -> Tuple[base.QValues, hk.LSTMState]: - embeddings = self._embed(inputs) # [B, D+A+1] - core_outputs, new_state = self._core(embeddings, state) - q_values = self._duelling_head(core_outputs) - return q_values, new_state - - def initial_state(self, batch_size: Optional[int], - **unused_kwargs) -> hk.LSTMState: - return self._core.initial_state(batch_size) - - def unroll( - self, - inputs: observation_action_reward.OAR, # [T, B, ...] - state: hk.LSTMState # [T, ...] - ) -> Tuple[base.QValues, hk.LSTMState]: - """Efficient unroll that applies torso, core, and duelling mlp in one pass.""" - embeddings = hk.BatchApply(self._embed)(inputs) # [T, B, D+A+1] - core_outputs, new_states = hk.static_unroll(self._core, embeddings, state) - q_values = hk.BatchApply(self._duelling_head)(core_outputs) # [T, B, A] - return q_values, new_states + def __init__(self, num_actions: int): + super().__init__(name="r2d2_atari_network") + self._embed = embedding.OAREmbedding( + DeepAtariTorso(hidden_sizes=[512], use_layer_norm=True), num_actions + ) + self._core = hk.LSTM(512) + self._duelling_head = duelling.DuellingMLP(num_actions, hidden_sizes=[512]) + self._num_actions = num_actions + + def __call__( + self, + inputs: observation_action_reward.OAR, # [B, ...] + state: hk.LSTMState, # [B, ...] + ) -> Tuple[base.QValues, hk.LSTMState]: + embeddings = self._embed(inputs) # [B, D+A+1] + core_outputs, new_state = self._core(embeddings, state) + q_values = self._duelling_head(core_outputs) + return q_values, new_state + + def initial_state(self, batch_size: Optional[int], **unused_kwargs) -> hk.LSTMState: + return self._core.initial_state(batch_size) + + def unroll( + self, + inputs: observation_action_reward.OAR, # [T, B, ...] + state: hk.LSTMState, # [T, ...] + ) -> Tuple[base.QValues, hk.LSTMState]: + """Efficient unroll that applies torso, core, and duelling mlp in one pass.""" + embeddings = hk.BatchApply(self._embed)(inputs) # [T, B, D+A+1] + core_outputs, new_states = hk.static_unroll(self._core, embeddings, state) + q_values = hk.BatchApply(self._duelling_head)(core_outputs) # [T, B, A] + return q_values, new_states diff --git a/acme/jax/networks/base.py b/acme/jax/networks/base.py index c6825f5184..b822e65ee3 100644 --- a/acme/jax/networks/base.py +++ b/acme/jax/networks/base.py @@ -17,14 +17,14 @@ import dataclasses from typing import Callable, Optional, Tuple -from acme import specs -from acme import types -from acme.jax import types as jax_types -from acme.jax import utils as jax_utils import haiku as hk import jax.numpy as jnp from typing_extensions import Protocol +from acme import specs, types +from acme.jax import types as jax_types +from acme.jax import utils as jax_utils + # This definition is deprecated. Use jax_types.PRNGKey directly instead. # TODO(sinopalnikov): migrate all users and remove this definition. PRNGKey = jax_types.PRNGKey @@ -46,53 +46,56 @@ QNetwork = Callable[[Observation], QValues] LSTMOutputs = Tuple[Tuple[Logits, Value], hk.LSTMState] PolicyValueRNN = Callable[[Observation, hk.LSTMState], LSTMOutputs] -RecurrentQNetwork = Callable[[Observation, hk.LSTMState], - Tuple[QValues, hk.LSTMState]] +RecurrentQNetwork = Callable[[Observation, hk.LSTMState], Tuple[QValues, hk.LSTMState]] SampleFn = Callable[[NetworkOutput, PRNGKey], Action] LogProbFn = Callable[[NetworkOutput, Action], LogProb] @dataclasses.dataclass class FeedForwardNetwork: - """Holds a pair of pure functions defining a feed-forward network. + """Holds a pair of pure functions defining a feed-forward network. 
Attributes: init: A pure function: ``params = init(rng, *a, **k)`` apply: A pure function: ``out = apply(params, rng, *a, **k)`` """ - # Initializes and returns the networks parameters. - init: Callable[..., Params] - # Computes and returns the outputs of a forward pass. - apply: Callable[..., NetworkOutput] + # Initializes and returns the networks parameters. + init: Callable[..., Params] + # Computes and returns the outputs of a forward pass. + apply: Callable[..., NetworkOutput] -class ApplyFn(Protocol): - def __call__(self, - params: Params, - observation: Observation, - *args, - is_training: bool, - key: Optional[PRNGKey] = None, - **kwargs) -> NetworkOutput: - ... +class ApplyFn(Protocol): + def __call__( + self, + params: Params, + observation: Observation, + *args, + is_training: bool, + key: Optional[PRNGKey] = None, + **kwargs + ) -> NetworkOutput: + ... @dataclasses.dataclass class TypedFeedForwardNetwork: - """FeedForwardNetwork with more specific types of the member functions. + """FeedForwardNetwork with more specific types of the member functions. Attributes: init: A pure function. Initializes and returns the networks parameters. apply: A pure function. Computes and returns the outputs of a forward pass. """ - init: Callable[[PRNGKey], Params] - apply: ApplyFn + + init: Callable[[PRNGKey], Params] + apply: ApplyFn def non_stochastic_network_to_typed( - network: FeedForwardNetwork) -> TypedFeedForwardNetwork: - """Converts non-stochastic FeedForwardNetwork to TypedFeedForwardNetwork. + network: FeedForwardNetwork, +) -> TypedFeedForwardNetwork: + """Converts non-stochastic FeedForwardNetwork to TypedFeedForwardNetwork. Non-stochastic network is the one that doesn't take a random key as an input for its `apply` method. @@ -104,56 +107,67 @@ def non_stochastic_network_to_typed( corresponding TypedFeedForwardNetwork """ - def apply(params: Params, - observation: Observation, - *args, - is_training: bool, - key: Optional[PRNGKey] = None, - **kwargs) -> NetworkOutput: - del is_training, key - return network.apply(params, observation, *args, **kwargs) + def apply( + params: Params, + observation: Observation, + *args, + is_training: bool, + key: Optional[PRNGKey] = None, + **kwargs + ) -> NetworkOutput: + del is_training, key + return network.apply(params, observation, *args, **kwargs) - return TypedFeedForwardNetwork(init=network.init, apply=apply) + return TypedFeedForwardNetwork(init=network.init, apply=apply) @dataclasses.dataclass class UnrollableNetwork: - """Network that can unroll over an input sequence.""" - init: Callable[[PRNGKey], Params] - apply: Callable[[Params, PRNGKey, Observation, RecurrentState], - Tuple[NetworkOutput, RecurrentState]] - unroll: Callable[[Params, PRNGKey, Observation, RecurrentState], - Tuple[NetworkOutput, RecurrentState]] - init_recurrent_state: Callable[[PRNGKey, Optional[BatchSize]], RecurrentState] - # TODO(b/244311990): Consider supporting parameterized and learnable initial - # state functions. + """Network that can unroll over an input sequence.""" + + init: Callable[[PRNGKey], Params] + apply: Callable[ + [Params, PRNGKey, Observation, RecurrentState], + Tuple[NetworkOutput, RecurrentState], + ] + unroll: Callable[ + [Params, PRNGKey, Observation, RecurrentState], + Tuple[NetworkOutput, RecurrentState], + ] + init_recurrent_state: Callable[[PRNGKey, Optional[BatchSize]], RecurrentState] + # TODO(b/244311990): Consider supporting parameterized and learnable initial + # state functions. 
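# A minimal usage sketch for `make_unrollable_network` (defined just below),
# kept as a comment so it is not mistaken for library code. The `SimpleCore`
# module, the toy flat-observation spec, the hidden size of 32, and the PRNG
# seed are illustrative assumptions, not part of Acme.
#
# import jax
# import jax.numpy as jnp
# import numpy as np
# from acme import specs
# from acme.jax import networks as networks_lib
#
# class SimpleCore(hk.RNNCore):
#     """Tiny LSTM core exposing the __call__/unroll/initial_state interface."""
#     def __init__(self):
#         super().__init__(name="simple_core")
#         self._core = hk.LSTM(32)
#     def __call__(self, inputs, state):
#         return self._core(inputs, state)
#     def initial_state(self, batch_size):
#         return self._core.initial_state(batch_size)
#     def unroll(self, inputs, state):
#         # Time-major unroll over a sequence of inputs.
#         return hk.static_unroll(self._core, inputs, state)
#
# # A toy environment spec with flat float observations.
# spec = specs.EnvironmentSpec(
#     observations=specs.Array(shape=(8,), dtype=np.float32),
#     actions=specs.DiscreteArray(num_values=4),
#     rewards=specs.Array(shape=(), dtype=np.float32),
#     discounts=specs.BoundedArray(
#         shape=(), dtype=np.float32, minimum=0.0, maximum=1.0),
# )
#
# network = networks_lib.make_unrollable_network(spec, SimpleCore)
# key = jax.random.PRNGKey(0)
# params = network.init(key)
# state = network.init_recurrent_state(key, None)  # Unbatched recurrent state.
# output, state = network.apply(params, key, jnp.zeros((8,), jnp.float32), state)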
def make_unrollable_network( - environment_spec: specs.EnvironmentSpec, - make_core_module: Callable[[], hk.RNNCore]) -> UnrollableNetwork: - """Builds an UnrollableNetwork from a hk.Module factory.""" + environment_spec: specs.EnvironmentSpec, make_core_module: Callable[[], hk.RNNCore] +) -> UnrollableNetwork: + """Builds an UnrollableNetwork from a hk.Module factory.""" - dummy_observation = jax_utils.zeros_like(environment_spec.observations) + dummy_observation = jax_utils.zeros_like(environment_spec.observations) - def make_unrollable_network_functions(): - model = make_core_module() - apply = model.__call__ + def make_unrollable_network_functions(): + model = make_core_module() + apply = model.__call__ - def init() -> Tuple[NetworkOutput, RecurrentState]: - return model(dummy_observation, model.initial_state(None)) + def init() -> Tuple[NetworkOutput, RecurrentState]: + return model(dummy_observation, model.initial_state(None)) - return init, (apply, model.unroll, model.initial_state) # pytype: disable=attribute-error + return ( + init, + (apply, model.unroll, model.initial_state), + ) # pytype: disable=attribute-error - # Transform and unpack pure functions - f = hk.multi_transform(make_unrollable_network_functions) - apply, unroll, initial_state_fn = f.apply + # Transform and unpack pure functions + f = hk.multi_transform(make_unrollable_network_functions) + apply, unroll, initial_state_fn = f.apply - def init_recurrent_state(key: jax_types.PRNGKey, - batch_size: Optional[int]) -> RecurrentState: - # TODO(b/244311990): Consider supporting parameterized and learnable initial - # state functions. - no_params = None - return initial_state_fn(no_params, key, batch_size) + def init_recurrent_state( + key: jax_types.PRNGKey, batch_size: Optional[int] + ) -> RecurrentState: + # TODO(b/244311990): Consider supporting parameterized and learnable initial + # state functions. + no_params = None + return initial_state_fn(no_params, key, batch_size) - return UnrollableNetwork(f.init, apply, unroll, init_recurrent_state) + return UnrollableNetwork(f.init, apply, unroll, init_recurrent_state) diff --git a/acme/jax/networks/continuous.py b/acme/jax/networks/continuous.py index 71f88fa5b2..65d391b9be 100644 --- a/acme/jax/networks/continuous.py +++ b/acme/jax/networks/continuous.py @@ -24,14 +24,14 @@ class NearZeroInitializedLinear(hk.Linear): - """Simple linear layer, initialized at near zero weights and zero biases.""" + """Simple linear layer, initialized at near zero weights and zero biases.""" - def __init__(self, output_size: int, scale: float = 1e-4): - super().__init__(output_size, w_init=hk.initializers.VarianceScaling(scale)) + def __init__(self, output_size: int, scale: float = 1e-4): + super().__init__(output_size, w_init=hk.initializers.VarianceScaling(scale)) class LayerNormMLP(hk.Module): - """Simple feedforward MLP torso with initial layer-norm. + """Simple feedforward MLP torso with initial layer-norm. This MLP's first linear layer is followed by a LayerNorm layer and a tanh non-linearity; subsequent layers use `activation`, which defaults to elu. @@ -40,13 +40,15 @@ class LayerNormMLP(hk.Module): legacy reasons. """ - def __init__(self, - layer_sizes: Sequence[int], - w_init: hk.initializers.Initializer = uniform_initializer, - activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.elu, - activate_final: bool = False, - name: str = 'feedforward_mlp_torso'): - """Construct the MLP. 
+ def __init__( + self, + layer_sizes: Sequence[int], + w_init: hk.initializers.Initializer = uniform_initializer, + activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.elu, + activate_final: bool = False, + name: str = "feedforward_mlp_torso", + ): + """Construct the MLP. Args: layer_sizes: a sequence of ints specifying the size of each layer. @@ -58,19 +60,22 @@ def __init__(self, layer of the neural network. name: a name for the module. """ - super().__init__(name=name) + super().__init__(name=name) - self._network = hk.Sequential([ - hk.Linear(layer_sizes[0], w_init=w_init), - hk.LayerNorm(axis=-1, create_scale=True, create_offset=True), - jax.lax.tanh, - hk.nets.MLP( - layer_sizes[1:], - w_init=w_init, - activation=activation, - activate_final=activate_final), - ]) + self._network = hk.Sequential( + [ + hk.Linear(layer_sizes[0], w_init=w_init), + hk.LayerNorm(axis=-1, create_scale=True, create_offset=True), + jax.lax.tanh, + hk.nets.MLP( + layer_sizes[1:], + w_init=w_init, + activation=activation, + activate_final=activate_final, + ), + ] + ) - def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: - """Forwards the policy network.""" - return self._network(inputs) + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + """Forwards the policy network.""" + return self._network(inputs) diff --git a/acme/jax/networks/distributional.py b/acme/jax/networks/distributional.py index c16945927f..7439c085b8 100644 --- a/acme/jax/networks/distributional.py +++ b/acme/jax/networks/distributional.py @@ -14,7 +14,7 @@ """Haiku modules that output tfd.Distributions.""" -from typing import Any, List, Optional, Sequence, Union, Callable +from typing import Any, Callable, List, Optional, Sequence, Union import chex import haiku as hk @@ -31,41 +31,44 @@ class CategoricalHead(hk.Module): - """Module that produces a categorical distribution with the given number of values.""" - - def __init__( - self, - num_values: Union[int, List[int]], - dtype: Optional[Any] = jnp.int32, - w_init: Optional[Initializer] = None, - name: Optional[str] = None, - ): - super().__init__(name=name) - self._dtype = dtype - self._logit_shape = num_values - self._linear = hk.Linear(np.prod(num_values), w_init=w_init) - - def __call__(self, inputs: jnp.ndarray) -> tfd.Distribution: - logits = self._linear(inputs) - if not isinstance(self._logit_shape, int): - logits = hk.Reshape(self._logit_shape)(logits) - return tfd.Categorical(logits=logits, dtype=self._dtype) + """Module that produces a categorical distribution with the given number of values.""" + + def __init__( + self, + num_values: Union[int, List[int]], + dtype: Optional[Any] = jnp.int32, + w_init: Optional[Initializer] = None, + name: Optional[str] = None, + ): + super().__init__(name=name) + self._dtype = dtype + self._logit_shape = num_values + self._linear = hk.Linear(np.prod(num_values), w_init=w_init) + + def __call__(self, inputs: jnp.ndarray) -> tfd.Distribution: + logits = self._linear(inputs) + if not isinstance(self._logit_shape, int): + logits = hk.Reshape(self._logit_shape)(logits) + return tfd.Categorical(logits=logits, dtype=self._dtype) class GaussianMixture(hk.Module): - """Module that outputs a Gaussian Mixture Distribution.""" - - def __init__(self, - num_dimensions: int, - num_components: int, - multivariate: bool, - init_scale: Optional[float] = None, - append_singleton_event_dim: bool = False, - reinterpreted_batch_ndims: Optional[int] = None, - transformation_fn: Optional[Callable[[tfd.Distribution], - tfd.Distribution]] = None, - name: 
str = 'GaussianMixture'): - """Initialization. + """Module that outputs a Gaussian Mixture Distribution.""" + + def __init__( + self, + num_dimensions: int, + num_components: int, + multivariate: bool, + init_scale: Optional[float] = None, + append_singleton_event_dim: bool = False, + reinterpreted_batch_ndims: Optional[int] = None, + transformation_fn: Optional[ + Callable[[tfd.Distribution], tfd.Distribution] + ] = None, + name: str = "GaussianMixture", + ): + """Initialization. Args: num_dimensions: dimensionality of the output distribution @@ -80,25 +83,25 @@ def __init__(self, applied to individual components. name: name of the module passed to snt.Module parent class. """ - super().__init__(name=name) + super().__init__(name=name) - self._num_dimensions = num_dimensions - self._num_components = num_components - self._multivariate = multivariate - self._append_singleton_event_dim = append_singleton_event_dim - self._reinterpreted_batch_ndims = reinterpreted_batch_ndims + self._num_dimensions = num_dimensions + self._num_components = num_components + self._multivariate = multivariate + self._append_singleton_event_dim = append_singleton_event_dim + self._reinterpreted_batch_ndims = reinterpreted_batch_ndims - if init_scale is not None: - self._scale_factor = init_scale / jax.nn.softplus(0.) - else: - self._scale_factor = 1.0 # Corresponds to init_scale = softplus(0). + if init_scale is not None: + self._scale_factor = init_scale / jax.nn.softplus(0.0) + else: + self._scale_factor = 1.0 # Corresponds to init_scale = softplus(0). - self._transformation_fn = transformation_fn + self._transformation_fn = transformation_fn - def __call__(self, - inputs: jnp.ndarray, - low_noise_policy: bool = False) -> tfd.Distribution: - """Run the networks through inputs. + def __call__( + self, inputs: jnp.ndarray, low_noise_policy: bool = False + ) -> tfd.Distribution: + """Run the networks through inputs. Args: inputs: hidden activations of the policy network body. @@ -110,140 +113,150 @@ def __call__(self, Mixture Gaussian distribution. """ - # Define the weight initializer. - w_init = hk.initializers.VarianceScaling(scale=1e-5) - - # Create a layer that outputs the unnormalized log-weights. - if self._multivariate: - logits_size = self._num_components - else: - logits_size = self._num_dimensions * self._num_components - logit_layer = hk.Linear(logits_size, w_init=w_init) - - # Create two layers that outputs a location and a scale, respectively, for - # each dimension and each component. - loc_layer = hk.Linear( - self._num_dimensions * self._num_components, w_init=w_init) - scale_layer = hk.Linear( - self._num_dimensions * self._num_components, w_init=w_init) - - # Compute logits, locs, and scales if necessary. - logits = logit_layer(inputs) - locs = loc_layer(inputs) - - # When a low_noise_policy is requested, set the scales to its minimum value. - if low_noise_policy: - scales = jnp.full(locs.shape, _MIN_SCALE) - else: - scales = scale_layer(inputs) - scales = self._scale_factor * jax.nn.softplus(scales) + _MIN_SCALE - - if self._multivariate: - components_class = tfd.MultivariateNormalDiag - shape = [-1, self._num_components, self._num_dimensions] # [B, C, D] - # In this case, no need to reshape logits as they are in the correct shape - # already, namely [batch_size, num_components]. 
- else: - components_class = tfd.Normal - shape = [-1, self._num_dimensions, self._num_components] # [B, D, C] - if self._append_singleton_event_dim: - shape.insert(2, 1) # [B, D, 1, C] - logits = logits.reshape(shape) - - # Reshape the mixture's location and scale parameters appropriately. - locs = locs.reshape(shape) - scales = scales.reshape(shape) - - if self._multivariate: - components_distribution = components_class(loc=locs, scale_diag=scales) - else: - components_distribution = components_class(loc=locs, scale=scales) - - # Transformed the component distributions in the mixture. - if self._transformation_fn: - components_distribution = self._transformation_fn(components_distribution) - - # Create the mixture distribution. - distribution = tfd.MixtureSameFamily( - mixture_distribution=tfd.Categorical(logits=logits), - components_distribution=components_distribution) - - if not self._multivariate: - distribution = tfd.Independent( - distribution, - reinterpreted_batch_ndims=self._reinterpreted_batch_ndims) - - return distribution + # Define the weight initializer. + w_init = hk.initializers.VarianceScaling(scale=1e-5) + + # Create a layer that outputs the unnormalized log-weights. + if self._multivariate: + logits_size = self._num_components + else: + logits_size = self._num_dimensions * self._num_components + logit_layer = hk.Linear(logits_size, w_init=w_init) + + # Create two layers that outputs a location and a scale, respectively, for + # each dimension and each component. + loc_layer = hk.Linear( + self._num_dimensions * self._num_components, w_init=w_init + ) + scale_layer = hk.Linear( + self._num_dimensions * self._num_components, w_init=w_init + ) + + # Compute logits, locs, and scales if necessary. + logits = logit_layer(inputs) + locs = loc_layer(inputs) + + # When a low_noise_policy is requested, set the scales to its minimum value. + if low_noise_policy: + scales = jnp.full(locs.shape, _MIN_SCALE) + else: + scales = scale_layer(inputs) + scales = self._scale_factor * jax.nn.softplus(scales) + _MIN_SCALE + + if self._multivariate: + components_class = tfd.MultivariateNormalDiag + shape = [-1, self._num_components, self._num_dimensions] # [B, C, D] + # In this case, no need to reshape logits as they are in the correct shape + # already, namely [batch_size, num_components]. + else: + components_class = tfd.Normal + shape = [-1, self._num_dimensions, self._num_components] # [B, D, C] + if self._append_singleton_event_dim: + shape.insert(2, 1) # [B, D, 1, C] + logits = logits.reshape(shape) + + # Reshape the mixture's location and scale parameters appropriately. + locs = locs.reshape(shape) + scales = scales.reshape(shape) + + if self._multivariate: + components_distribution = components_class(loc=locs, scale_diag=scales) + else: + components_distribution = components_class(loc=locs, scale=scales) + + # Transformed the component distributions in the mixture. + if self._transformation_fn: + components_distribution = self._transformation_fn(components_distribution) + + # Create the mixture distribution. 
+ distribution = tfd.MixtureSameFamily( + mixture_distribution=tfd.Categorical(logits=logits), + components_distribution=components_distribution, + ) + + if not self._multivariate: + distribution = tfd.Independent( + distribution, reinterpreted_batch_ndims=self._reinterpreted_batch_ndims + ) + + return distribution class TanhTransformedDistribution(tfd.TransformedDistribution): - """Distribution followed by tanh.""" + """Distribution followed by tanh.""" - def __init__(self, distribution, threshold=.999, validate_args=False): - """Initialize the distribution. + def __init__(self, distribution, threshold=0.999, validate_args=False): + """Initialize the distribution. Args: distribution: The distribution to transform. threshold: Clipping value of the action when computing the logprob. validate_args: Passed to super class. """ - super().__init__( - distribution=distribution, - bijector=tfp.bijectors.Tanh(), - validate_args=validate_args) - # Computes the log of the average probability distribution outside the - # clipping range, i.e. on the interval [-inf, -atanh(threshold)] for - # log_prob_left and [atanh(threshold), inf] for log_prob_right. - self._threshold = threshold - inverse_threshold = self.bijector.inverse(threshold) - # average(pdf) = p/epsilon - # So log(average(pdf)) = log(p) - log(epsilon) - log_epsilon = jnp.log(1. - threshold) - # Those 2 values are differentiable w.r.t. model parameters, such that the - # gradient is defined everywhere. - self._log_prob_left = self.distribution.log_cdf( - -inverse_threshold) - log_epsilon - self._log_prob_right = self.distribution.log_survival_function( - inverse_threshold) - log_epsilon - - def log_prob(self, event): - # Without this clip there would be NaNs in the inner tf.where and that - # causes issues for some reasons. - event = jnp.clip(event, -self._threshold, self._threshold) - # The inverse image of {threshold} is the interval [atanh(threshold), inf] - # which has a probability of "log_prob_right" under the given distribution. - return jnp.where( - event <= -self._threshold, self._log_prob_left, - jnp.where(event >= self._threshold, self._log_prob_right, - super().log_prob(event))) - - def mode(self): - return self.bijector.forward(self.distribution.mode()) - - def entropy(self, seed=None): - # We return an estimation using a single sample of the log_det_jacobian. - # We can still do some backpropagation with this estimate. - return self.distribution.entropy() + self.bijector.forward_log_det_jacobian( - self.distribution.sample(seed=seed), event_ndims=0) - - @classmethod - def _parameter_properties(cls, dtype: Optional[Any], num_classes=None): - td_properties = super()._parameter_properties(dtype, - num_classes=num_classes) - del td_properties['bijector'] - return td_properties + super().__init__( + distribution=distribution, + bijector=tfp.bijectors.Tanh(), + validate_args=validate_args, + ) + # Computes the log of the average probability distribution outside the + # clipping range, i.e. on the interval [-inf, -atanh(threshold)] for + # log_prob_left and [atanh(threshold), inf] for log_prob_right. + self._threshold = threshold + inverse_threshold = self.bijector.inverse(threshold) + # average(pdf) = p/epsilon + # So log(average(pdf)) = log(p) - log(epsilon) + log_epsilon = jnp.log(1.0 - threshold) + # Those 2 values are differentiable w.r.t. model parameters, such that the + # gradient is defined everywhere. 
+ self._log_prob_left = ( + self.distribution.log_cdf(-inverse_threshold) - log_epsilon + ) + self._log_prob_right = ( + self.distribution.log_survival_function(inverse_threshold) - log_epsilon + ) + + def log_prob(self, event): + # Without this clip there would be NaNs in the inner tf.where and that + # causes issues for some reasons. + event = jnp.clip(event, -self._threshold, self._threshold) + # The inverse image of {threshold} is the interval [atanh(threshold), inf] + # which has a probability of "log_prob_right" under the given distribution. + return jnp.where( + event <= -self._threshold, + self._log_prob_left, + jnp.where( + event >= self._threshold, self._log_prob_right, super().log_prob(event) + ), + ) + + def mode(self): + return self.bijector.forward(self.distribution.mode()) + + def entropy(self, seed=None): + # We return an estimation using a single sample of the log_det_jacobian. + # We can still do some backpropagation with this estimate. + return self.distribution.entropy() + self.bijector.forward_log_det_jacobian( + self.distribution.sample(seed=seed), event_ndims=0 + ) + + @classmethod + def _parameter_properties(cls, dtype: Optional[Any], num_classes=None): + td_properties = super()._parameter_properties(dtype, num_classes=num_classes) + del td_properties["bijector"] + return td_properties class NormalTanhDistribution(hk.Module): - """Module that produces a TanhTransformedDistribution distribution.""" + """Module that produces a TanhTransformedDistribution distribution.""" - def __init__(self, - num_dimensions: int, - min_scale: float = 1e-3, - w_init: hk_init.Initializer = hk_init.VarianceScaling( - 1.0, 'fan_in', 'uniform'), - b_init: hk_init.Initializer = hk_init.Constant(0.)): - """Initialization. + def __init__( + self, + num_dimensions: int, + min_scale: float = 1e-3, + w_init: hk_init.Initializer = hk_init.VarianceScaling(1.0, "fan_in", "uniform"), + b_init: hk_init.Initializer = hk_init.Constant(0.0), + ): + """Initialization. Args: num_dimensions: Number of dimensions of a distribution. @@ -251,30 +264,33 @@ def __init__(self, w_init: Initialization for linear layer weights. b_init: Initialization for linear layer biases. 
""" - super().__init__(name='Normal') - self._min_scale = min_scale - self._loc_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init) - self._scale_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init) + super().__init__(name="Normal") + self._min_scale = min_scale + self._loc_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init) + self._scale_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init) - def __call__(self, inputs: jnp.ndarray) -> tfd.Distribution: - loc = self._loc_layer(inputs) - scale = self._scale_layer(inputs) - scale = jax.nn.softplus(scale) + self._min_scale - distribution = tfd.Normal(loc=loc, scale=scale) - return tfd.Independent( - TanhTransformedDistribution(distribution), reinterpreted_batch_ndims=1) + def __call__(self, inputs: jnp.ndarray) -> tfd.Distribution: + loc = self._loc_layer(inputs) + scale = self._scale_layer(inputs) + scale = jax.nn.softplus(scale) + self._min_scale + distribution = tfd.Normal(loc=loc, scale=scale) + return tfd.Independent( + TanhTransformedDistribution(distribution), reinterpreted_batch_ndims=1 + ) class MultivariateNormalDiagHead(hk.Module): - """Module that produces a tfd.MultivariateNormalDiag distribution.""" - - def __init__(self, - num_dimensions: int, - init_scale: float = 0.3, - min_scale: float = 1e-6, - w_init: hk_init.Initializer = hk_init.VarianceScaling(1e-4), - b_init: hk_init.Initializer = hk_init.Constant(0.)): - """Initialization. + """Module that produces a tfd.MultivariateNormalDiag distribution.""" + + def __init__( + self, + num_dimensions: int, + init_scale: float = 0.3, + min_scale: float = 1e-6, + w_init: hk_init.Initializer = hk_init.VarianceScaling(1e-4), + b_init: hk_init.Initializer = hk_init.Constant(0.0), + ): + """Initialization. Args: num_dimensions: Number of dimensions of MVN distribution. @@ -283,108 +299,107 @@ def __init__(self, w_init: Initialization for linear layer weights. b_init: Initialization for linear layer biases. """ - super().__init__(name='MultivariateNormalDiagHead') - self._min_scale = min_scale - self._init_scale = init_scale - self._loc_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init) - self._scale_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init) + super().__init__(name="MultivariateNormalDiagHead") + self._min_scale = min_scale + self._init_scale = init_scale + self._loc_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init) + self._scale_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init) - def __call__(self, inputs: jnp.ndarray) -> tfd.Distribution: - loc = self._loc_layer(inputs) - scale = jax.nn.softplus(self._scale_layer(inputs)) - scale *= self._init_scale / jax.nn.softplus(0.) 
- scale += self._min_scale - return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale) + def __call__(self, inputs: jnp.ndarray) -> tfd.Distribution: + loc = self._loc_layer(inputs) + scale = jax.nn.softplus(self._scale_layer(inputs)) + scale *= self._init_scale / jax.nn.softplus(0.0) + scale += self._min_scale + return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale) class CategoricalValueHead(hk.Module): - """Network head that produces a categorical distribution and value.""" + """Network head that produces a categorical distribution and value.""" - def __init__( - self, - num_values: int, - name: Optional[str] = None, - ): - super().__init__(name=name) - self._logit_layer = hk.Linear(num_values) - self._value_layer = hk.Linear(1) + def __init__( + self, num_values: int, name: Optional[str] = None, + ): + super().__init__(name=name) + self._logit_layer = hk.Linear(num_values) + self._value_layer = hk.Linear(1) - def __call__(self, inputs: jnp.ndarray): - logits = self._logit_layer(inputs) - value = jnp.squeeze(self._value_layer(inputs), axis=-1) - return (tfd.Categorical(logits=logits), value) + def __call__(self, inputs: jnp.ndarray): + logits = self._logit_layer(inputs) + value = jnp.squeeze(self._value_layer(inputs), axis=-1) + return (tfd.Categorical(logits=logits), value) class DiscreteValued(hk.Module): - """C51-style head. + """C51-style head. For each action, it produces the logits for a discrete distribution over atoms. Therefore, the returned logits represents several distributions, one for each action. """ - def __init__( - self, - num_actions: int, - head_units: int = 512, - num_atoms: int = 51, - v_min: float = -1.0, - v_max: float = 1.0, - ): - super().__init__('DiscreteValued') - self._num_actions = num_actions - self._num_atoms = num_atoms - self._atoms = jnp.linspace(v_min, v_max, self._num_atoms) - self._network = hk.nets.MLP([head_units, num_actions * num_atoms]) - - def __call__(self, inputs: jnp.ndarray): - q_logits = self._network(inputs) - q_logits = jnp.reshape(q_logits, (-1, self._num_actions, self._num_atoms)) - q_dist = jax.nn.softmax(q_logits) - q_values = jnp.sum(q_dist * self._atoms, axis=2) - q_values = jax.lax.stop_gradient(q_values) - return q_values, q_logits, self._atoms + def __init__( + self, + num_actions: int, + head_units: int = 512, + num_atoms: int = 51, + v_min: float = -1.0, + v_max: float = 1.0, + ): + super().__init__("DiscreteValued") + self._num_actions = num_actions + self._num_atoms = num_atoms + self._atoms = jnp.linspace(v_min, v_max, self._num_atoms) + self._network = hk.nets.MLP([head_units, num_actions * num_atoms]) + + def __call__(self, inputs: jnp.ndarray): + q_logits = self._network(inputs) + q_logits = jnp.reshape(q_logits, (-1, self._num_actions, self._num_atoms)) + q_dist = jax.nn.softmax(q_logits) + q_values = jnp.sum(q_dist * self._atoms, axis=2) + q_values = jax.lax.stop_gradient(q_values) + return q_values, q_logits, self._atoms class CategoricalCriticHead(hk.Module): - """Critic head that uses a categorical to represent action values.""" + """Critic head that uses a categorical to represent action values.""" - def __init__(self, - num_bins: int = 601, - vmax: Optional[float] = None, - vmin: Optional[float] = None, - w_init: hk_init.Initializer = hk_init.VarianceScaling(1e-5)): - super().__init__(name='categorical_critic_head') - vmax = vmax if vmax is not None else 0.5 * (num_bins - 1) - vmin = vmin if vmin is not None else -1.0 * vmax + def __init__( + self, + num_bins: int = 601, + vmax: Optional[float] = None, 
+ vmin: Optional[float] = None, + w_init: hk_init.Initializer = hk_init.VarianceScaling(1e-5), + ): + super().__init__(name="categorical_critic_head") + vmax = vmax if vmax is not None else 0.5 * (num_bins - 1) + vmin = vmin if vmin is not None else -1.0 * vmax - self._head = DiscreteValuedTfpHead( - vmin=vmin, - vmax=vmax, - logits_shape=(1,), - num_atoms=num_bins, - w_init=w_init) + self._head = DiscreteValuedTfpHead( + vmin=vmin, vmax=vmax, logits_shape=(1,), num_atoms=num_bins, w_init=w_init + ) - def __call__(self, embedding: chex.Array) -> tfd.Distribution: - output = self._head(embedding) - return output + def __call__(self, embedding: chex.Array) -> tfd.Distribution: + output = self._head(embedding) + return output class DiscreteValuedTfpHead(hk.Module): - """Represents a parameterized discrete valued distribution. + """Represents a parameterized discrete valued distribution. The returned distribution is essentially a `tfd.Categorical` that knows its support and thus can compute the mean value. """ - def __init__(self, - vmin: float, - vmax: float, - num_atoms: int, - logits_shape: Optional[Sequence[int]] = None, - w_init: Optional[Initializer] = None, - b_init: Optional[Initializer] = None): - """Initialization. + def __init__( + self, + vmin: float, + vmax: float, + num_atoms: int, + logits_shape: Optional[Sequence[int]] = None, + w_init: Optional[Initializer] = None, + b_init: Optional[Initializer] = None, + ): + """Initialization. If vmin and vmax have shape S, this will store the category values as a Tensor of shape (S*, num_atoms). @@ -398,25 +413,26 @@ def __init__(self, w_init: Initialization for linear layer weights. b_init: Initialization for linear layer biases. """ - super().__init__(name='DiscreteValuedHead') - self._values = np.linspace(vmin, vmax, num=num_atoms, axis=-1) - if not logits_shape: - logits_shape = () - self._logits_shape = logits_shape + (num_atoms,) - self._w_init = w_init - self._b_init = b_init - - def __call__(self, inputs: chex.Array) -> tfd.Distribution: - net = hk.Linear( - np.prod(self._logits_shape), w_init=self._w_init, b_init=self._b_init) - logits = net(inputs) - logits = hk.Reshape(self._logits_shape, preserve_dims=1)(logits) - return DiscreteValuedTfpDistribution(values=self._values, logits=logits) + super().__init__(name="DiscreteValuedHead") + self._values = np.linspace(vmin, vmax, num=num_atoms, axis=-1) + if not logits_shape: + logits_shape = () + self._logits_shape = logits_shape + (num_atoms,) + self._w_init = w_init + self._b_init = b_init + + def __call__(self, inputs: chex.Array) -> tfd.Distribution: + net = hk.Linear( + np.prod(self._logits_shape), w_init=self._w_init, b_init=self._b_init + ) + logits = net(inputs) + logits = hk.Reshape(self._logits_shape, preserve_dims=1)(logits) + return DiscreteValuedTfpDistribution(values=self._values, logits=logits) @tf_tfp.experimental.auto_composite_tensor class DiscreteValuedTfpDistribution(tfd.Categorical): - """This is a generalization of a categorical distribution. + """This is a generalization of a categorical distribution. The support for the DiscreteValued distribution can be any real valued range, whereas the categorical distribution has support [0, n_categories - 1] or @@ -424,12 +440,14 @@ class DiscreteValuedTfpDistribution(tfd.Categorical): distribution over its support. """ - def __init__(self, - values: chex.Array, - logits: Optional[chex.Array] = None, - probs: Optional[chex.Array] = None, - name: str = 'DiscreteValuedDistribution'): - """Initialization. 
+ def __init__( + self, + values: chex.Array, + logits: Optional[chex.Array] = None, + probs: Optional[chex.Array] = None, + name: str = "DiscreteValuedDistribution", + ): + """Initialization. Args: values: Values making up support of the distribution. Should have a shape @@ -445,50 +463,52 @@ def __init__(self, passed in. name: Name of the distribution object. """ - parameters = dict(locals()) - self._values = np.asarray(values) - - if logits is not None: - logits = jnp.asarray(logits) - chex.assert_shape(logits, (..., *self._values.shape)) - - if probs is not None: - probs = jnp.asarray(probs) - chex.assert_shape(probs, (..., *self._values.shape)) - - super().__init__(logits=logits, probs=probs, name=name) - - self._parameters = parameters - - @property - def values(self): - return self._values - - @classmethod - def _parameter_properties(cls, dtype, num_classes=None): - return dict( - values=tfp.util.ParameterProperties( - event_ndims=None, - shape_fn=lambda shape: (num_classes,), - specifies_shape=True), - logits=tfp.util.ParameterProperties(event_ndims=1), - probs=tfp.util.ParameterProperties(event_ndims=1, is_preferred=False)) - - def _sample_n(self, key: chex.PRNGKey, n: int) -> chex.Array: - indices = super()._sample_n(key=key, n=n) - return jnp.take_along_axis(self._values, indices, axis=-1) - - def mean(self) -> chex.Array: - """Overrides the Categorical mean by incorporating category values.""" - return jnp.sum(self.probs_parameter() * self._values, axis=-1) - - def variance(self) -> chex.Array: - """Overrides the Categorical variance by incorporating category values.""" - dist_squared = jnp.square(jnp.expand_dims(self.mean(), -1) - self._values) - return jnp.sum(self.probs_parameter() * dist_squared, axis=-1) - - def _event_shape(self): - return jnp.zeros((), dtype=jnp.int32) - - def _event_shape_tensor(self): - return [] + parameters = dict(locals()) + self._values = np.asarray(values) + + if logits is not None: + logits = jnp.asarray(logits) + chex.assert_shape(logits, (..., *self._values.shape)) + + if probs is not None: + probs = jnp.asarray(probs) + chex.assert_shape(probs, (..., *self._values.shape)) + + super().__init__(logits=logits, probs=probs, name=name) + + self._parameters = parameters + + @property + def values(self): + return self._values + + @classmethod + def _parameter_properties(cls, dtype, num_classes=None): + return dict( + values=tfp.util.ParameterProperties( + event_ndims=None, + shape_fn=lambda shape: (num_classes,), + specifies_shape=True, + ), + logits=tfp.util.ParameterProperties(event_ndims=1), + probs=tfp.util.ParameterProperties(event_ndims=1, is_preferred=False), + ) + + def _sample_n(self, key: chex.PRNGKey, n: int) -> chex.Array: + indices = super()._sample_n(key=key, n=n) + return jnp.take_along_axis(self._values, indices, axis=-1) + + def mean(self) -> chex.Array: + """Overrides the Categorical mean by incorporating category values.""" + return jnp.sum(self.probs_parameter() * self._values, axis=-1) + + def variance(self) -> chex.Array: + """Overrides the Categorical variance by incorporating category values.""" + dist_squared = jnp.square(jnp.expand_dims(self.mean(), -1) - self._values) + return jnp.sum(self.probs_parameter() * dist_squared, axis=-1) + + def _event_shape(self): + return jnp.zeros((), dtype=jnp.int32) + + def _event_shape_tensor(self): + return [] diff --git a/acme/jax/networks/duelling.py b/acme/jax/networks/duelling.py index 8db6d173d2..156c08c9fd 100644 --- a/acme/jax/networks/duelling.py +++ b/acme/jax/networks/duelling.py 
@@ -17,29 +17,28 @@ [0] https://arxiv.org/abs/1511.06581 """ -from typing import Sequence, Optional +from typing import Optional, Sequence import haiku as hk import jax.numpy as jnp class DuellingMLP(hk.Module): - """A Duelling MLP Q-network.""" + """A Duelling MLP Q-network.""" - def __init__( - self, - num_actions: int, - hidden_sizes: Sequence[int], - w_init: Optional[hk.initializers.Initializer] = None, - ): - super().__init__(name='duelling_q_network') + def __init__( + self, + num_actions: int, + hidden_sizes: Sequence[int], + w_init: Optional[hk.initializers.Initializer] = None, + ): + super().__init__(name="duelling_q_network") - self._value_mlp = hk.nets.MLP([*hidden_sizes, 1], w_init=w_init) - self._advantage_mlp = hk.nets.MLP([*hidden_sizes, num_actions], - w_init=w_init) + self._value_mlp = hk.nets.MLP([*hidden_sizes, 1], w_init=w_init) + self._advantage_mlp = hk.nets.MLP([*hidden_sizes, num_actions], w_init=w_init) - def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: - """Forward pass of the duelling network. + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + """Forward pass of the duelling network. Args: inputs: 2-D tensor of shape [batch_size, embedding_size]. @@ -48,13 +47,13 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: q_values: 2-D tensor of action values of shape [batch_size, num_actions] """ - # Compute value & advantage for duelling. - value = self._value_mlp(inputs) # [B, 1] - advantages = self._advantage_mlp(inputs) # [B, A] + # Compute value & advantage for duelling. + value = self._value_mlp(inputs) # [B, 1] + advantages = self._advantage_mlp(inputs) # [B, A] - # Advantages have zero mean. - advantages -= jnp.mean(advantages, axis=-1, keepdims=True) # [B, A] + # Advantages have zero mean. + advantages -= jnp.mean(advantages, axis=-1, keepdims=True) # [B, A] - q_values = value + advantages # [B, A] + q_values = value + advantages # [B, A] - return q_values + return q_values diff --git a/acme/jax/networks/embedding.py b/acme/jax/networks/embedding.py index 2510c6a62c..30b030418f 100644 --- a/acme/jax/networks/embedding.py +++ b/acme/jax/networks/embedding.py @@ -16,46 +16,50 @@ import dataclasses -from acme.wrappers import observation_action_reward import haiku as hk import jax import jax.numpy as jnp +from acme.wrappers import observation_action_reward + @dataclasses.dataclass class OAREmbedding(hk.Module): - """Module for embedding (observation, action, reward) inputs together.""" + """Module for embedding (observation, action, reward) inputs together.""" - torso: hk.SupportsCall - num_actions: int + torso: hk.SupportsCall + num_actions: int - def __call__(self, inputs: observation_action_reward.OAR) -> jnp.ndarray: - """Embed each of the (observation, action, reward) inputs & concatenate.""" + def __call__(self, inputs: observation_action_reward.OAR) -> jnp.ndarray: + """Embed each of the (observation, action, reward) inputs & concatenate.""" - # Add dummy batch dimension to observation if necessary. - # This is needed because Conv2D assumes a leading batch dimension, i.e. - # that inputs are in [B, H, W, C] format. - expand_obs = len(inputs.observation.shape) == 3 - if expand_obs: - inputs = inputs._replace( - observation=jnp.expand_dims(inputs.observation, axis=0)) - features = self.torso(inputs.observation) # [T?, B, D] - if expand_obs: - features = jnp.squeeze(features, axis=0) + # Add dummy batch dimension to observation if necessary. + # This is needed because Conv2D assumes a leading batch dimension, i.e. 
+ # that inputs are in [B, H, W, C] format. + expand_obs = len(inputs.observation.shape) == 3 + if expand_obs: + inputs = inputs._replace( + observation=jnp.expand_dims(inputs.observation, axis=0) + ) + features = self.torso(inputs.observation) # [T?, B, D] + if expand_obs: + features = jnp.squeeze(features, axis=0) - # Do a one-hot embedding of the actions. - action = jax.nn.one_hot( - inputs.action, num_classes=self.num_actions) # [T?, B, A] + # Do a one-hot embedding of the actions. + action = jax.nn.one_hot( + inputs.action, num_classes=self.num_actions + ) # [T?, B, A] - # Map rewards -> [-1, 1]. - reward = jnp.tanh(inputs.reward) + # Map rewards -> [-1, 1]. + reward = jnp.tanh(inputs.reward) - # Add dummy trailing dimensions to rewards if necessary. - while reward.ndim < action.ndim: - reward = jnp.expand_dims(reward, axis=-1) + # Add dummy trailing dimensions to rewards if necessary. + while reward.ndim < action.ndim: + reward = jnp.expand_dims(reward, axis=-1) - # Concatenate on final dimension. - embedding = jnp.concatenate( - [features, action, reward], axis=-1) # [T?, B, D+A+1] + # Concatenate on final dimension. + embedding = jnp.concatenate( + [features, action, reward], axis=-1 + ) # [T?, B, D+A+1] - return embedding + return embedding diff --git a/acme/jax/networks/multiplexers.py b/acme/jax/networks/multiplexers.py index 632f0b735f..7e386e331c 100644 --- a/acme/jax/networks/multiplexers.py +++ b/acme/jax/networks/multiplexers.py @@ -16,17 +16,18 @@ from typing import Callable, Optional, Union -from acme.jax import utils import haiku as hk import jax.numpy as jnp import tensorflow_probability +from acme.jax import utils + tfd = tensorflow_probability.substrates.jax.distributions ModuleOrArrayTransform = Union[hk.Module, Callable[[jnp.ndarray], jnp.ndarray]] class CriticMultiplexer(hk.Module): - """Module connecting a critic torso to (transformed) observations/actions. + """Module connecting a critic torso to (transformed) observations/actions. This takes as input a `critic_network`, an `observation_network`, and an `action_network` and returns another network whose outputs are given by @@ -42,30 +43,30 @@ class CriticMultiplexer(hk.Module): module reduces to a simple `tf2_utils.batch_concat()`. """ - def __init__(self, - critic_network: Optional[ModuleOrArrayTransform] = None, - observation_network: Optional[ModuleOrArrayTransform] = None, - action_network: Optional[ModuleOrArrayTransform] = None): - self._critic_network = critic_network - self._observation_network = observation_network - self._action_network = action_network - super().__init__(name='critic_multiplexer') - - def __call__(self, - observation: jnp.ndarray, - action: jnp.ndarray) -> jnp.ndarray: - - # Maybe transform observations and actions before feeding them on. - if self._observation_network: - observation = self._observation_network(observation) - if self._action_network: - action = self._action_network(action) - - # Concat observations and actions, with one batch dimension. - outputs = utils.batch_concat([observation, action]) - - # Maybe transform output before returning. 
- if self._critic_network: - outputs = self._critic_network(outputs) - - return outputs + def __init__( + self, + critic_network: Optional[ModuleOrArrayTransform] = None, + observation_network: Optional[ModuleOrArrayTransform] = None, + action_network: Optional[ModuleOrArrayTransform] = None, + ): + self._critic_network = critic_network + self._observation_network = observation_network + self._action_network = action_network + super().__init__(name="critic_multiplexer") + + def __call__(self, observation: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray: + + # Maybe transform observations and actions before feeding them on. + if self._observation_network: + observation = self._observation_network(observation) + if self._action_network: + action = self._action_network(action) + + # Concat observations and actions, with one batch dimension. + outputs = utils.batch_concat([observation, action]) + + # Maybe transform output before returning. + if self._critic_network: + outputs = self._critic_network(outputs) + + return outputs diff --git a/acme/jax/networks/policy_value.py b/acme/jax/networks/policy_value.py index 509d3fb18e..5365554c73 100644 --- a/acme/jax/networks/policy_value.py +++ b/acme/jax/networks/policy_value.py @@ -21,16 +21,16 @@ class PolicyValueHead(hk.Module): - """A network with two linear layers, for policy and value respectively.""" + """A network with two linear layers, for policy and value respectively.""" - def __init__(self, num_actions: int): - super().__init__(name='policy_value_network') - self._policy_layer = hk.Linear(num_actions) - self._value_layer = hk.Linear(1) + def __init__(self, num_actions: int): + super().__init__(name="policy_value_network") + self._policy_layer = hk.Linear(num_actions) + self._value_layer = hk.Linear(1) - def __call__(self, inputs: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Returns a (Logits, Value) tuple.""" - logits = self._policy_layer(inputs) # [B, A] - value = jnp.squeeze(self._value_layer(inputs), axis=-1) # [B] + def __call__(self, inputs: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Returns a (Logits, Value) tuple.""" + logits = self._policy_layer(inputs) # [B, A] + value = self._value_layer(inputs)[..., 0] # [B] - return logits, value + return logits, value diff --git a/acme/jax/networks/rescaling.py b/acme/jax/networks/rescaling.py index 4d63c84593..ccb0f2392f 100644 --- a/acme/jax/networks/rescaling.py +++ b/acme/jax/networks/rescaling.py @@ -16,42 +16,46 @@ import dataclasses -from acme import specs -from jax import lax import jax.numpy as jnp +from jax import lax + +from acme import specs @dataclasses.dataclass class ClipToSpec: - """Clips inputs to within a BoundedArraySpec.""" - spec: specs.BoundedArray + """Clips inputs to within a BoundedArraySpec.""" - def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: - return jnp.clip(inputs, self.spec.minimum, self.spec.maximum) + spec: specs.BoundedArray + + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + return jnp.clip(inputs, self.spec.minimum, self.spec.maximum) @dataclasses.dataclass class RescaleToSpec: - """Rescales inputs in [-1, 1] to match a BoundedArraySpec.""" - spec: specs.BoundedArray + """Rescales inputs in [-1, 1] to match a BoundedArraySpec.""" + + spec: specs.BoundedArray - def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: - scale = self.spec.maximum - self.spec.minimum - offset = self.spec.minimum - inputs = 0.5 * (inputs + 1.0) # [0, 1] - output = inputs * scale + offset # [minimum, maximum] - return output + def 
__call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + scale = self.spec.maximum - self.spec.minimum + offset = self.spec.minimum + inputs = 0.5 * (inputs + 1.0) # [0, 1] + output = inputs * scale + offset # [minimum, maximum] + return output @dataclasses.dataclass class TanhToSpec: - """Squashes real-valued inputs to match a BoundedArraySpec.""" - spec: specs.BoundedArray - - def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: - scale = self.spec.maximum - self.spec.minimum - offset = self.spec.minimum - inputs = lax.tanh(inputs) # [-1, 1] - inputs = 0.5 * (inputs + 1.0) # [0, 1] - output = inputs * scale + offset # [minimum, maximum] - return output + """Squashes real-valued inputs to match a BoundedArraySpec.""" + + spec: specs.BoundedArray + + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + scale = self.spec.maximum - self.spec.minimum + offset = self.spec.minimum + inputs = lax.tanh(inputs) # [-1, 1] + inputs = 0.5 * (inputs + 1.0) # [0, 1] + output = inputs * scale + offset # [minimum, maximum] + return output diff --git a/acme/jax/networks/resnet.py b/acme/jax/networks/resnet.py index 71a7c9429c..446bc708bf 100644 --- a/acme/jax/networks/resnet.py +++ b/acme/jax/networks/resnet.py @@ -17,6 +17,7 @@ import enum import functools from typing import Callable, Sequence, Union + import haiku as hk import jax import jax.numpy as jnp @@ -27,133 +28,158 @@ class ResidualBlock(hk.Module): - """Residual block of operations, e.g. convolutional or MLP.""" - - def __init__(self, - make_inner_op: MakeInnerOp, - non_linearity: NonLinearity = jax.nn.relu, - use_layer_norm: bool = False, - name: str = 'residual_block'): - super().__init__(name=name) - self.inner_op1 = make_inner_op() - self.inner_op2 = make_inner_op() - self.non_linearity = non_linearity - self.use_layer_norm = use_layer_norm - - if use_layer_norm: - self.layernorm1 = hk.LayerNorm( - axis=(-3, -2, -1), create_scale=True, create_offset=True, eps=1e-6) - self.layernorm2 = hk.LayerNorm( - axis=(-3, -2, -1), create_scale=True, create_offset=True, eps=1e-6) - - def __call__(self, x: jnp.ndarray): - output = x - - # First layer in residual block. - if self.use_layer_norm: - output = self.layernorm1(output) - output = self.non_linearity(output) - output = self.inner_op1(output) - - # Second layer in residual block. - if self.use_layer_norm: - output = self.layernorm2(output) - output = self.non_linearity(output) - output = self.inner_op2(output) - return x + output + """Residual block of operations, e.g. convolutional or MLP.""" + + def __init__( + self, + make_inner_op: MakeInnerOp, + non_linearity: NonLinearity = jax.nn.relu, + use_layer_norm: bool = False, + name: str = "residual_block", + ): + super().__init__(name=name) + self.inner_op1 = make_inner_op() + self.inner_op2 = make_inner_op() + self.non_linearity = non_linearity + self.use_layer_norm = use_layer_norm + + if use_layer_norm: + self.layernorm1 = hk.LayerNorm( + axis=(-3, -2, -1), create_scale=True, create_offset=True, eps=1e-6 + ) + self.layernorm2 = hk.LayerNorm( + axis=(-3, -2, -1), create_scale=True, create_offset=True, eps=1e-6 + ) + + def __call__(self, x: jnp.ndarray): + output = x + + # First layer in residual block. + if self.use_layer_norm: + output = self.layernorm1(output) + output = self.non_linearity(output) + output = self.inner_op1(output) + + # Second layer in residual block. 
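+        # Same pre-activation ordering as above: optional LayerNorm, then the
+        # non-linearity, then the second inner op.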
+ if self.use_layer_norm: + output = self.layernorm2(output) + output = self.non_linearity(output) + output = self.inner_op2(output) + return x + output # TODO(nikola): Remove this enum and configure downsampling with a layer factory # instead. class DownsamplingStrategy(enum.Enum): - AVG_POOL = 'avg_pool' - CONV_MAX = 'conv+max' # Used in IMPALA - LAYERNORM_RELU_CONV = 'layernorm+relu+conv' # Used in MuZero - CONV = 'conv' + AVG_POOL = "avg_pool" + CONV_MAX = "conv+max" # Used in IMPALA + LAYERNORM_RELU_CONV = "layernorm+relu+conv" # Used in MuZero + CONV = "conv" def make_downsampling_layer( - strategy: Union[str, DownsamplingStrategy], - output_channels: int, + strategy: Union[str, DownsamplingStrategy], output_channels: int, ) -> hk.SupportsCall: - """Returns a sequence of modules corresponding to the desired downsampling.""" - strategy = DownsamplingStrategy(strategy) - - if strategy is DownsamplingStrategy.AVG_POOL: - return hk.AvgPool(window_shape=(3, 3, 1), strides=(2, 2, 1), padding='SAME') - - elif strategy is DownsamplingStrategy.CONV: - return hk.Sequential([ - hk.Conv2D( - output_channels, - kernel_shape=3, - stride=2, - w_init=hk.initializers.TruncatedNormal(1e-2)), - ]) - - elif strategy is DownsamplingStrategy.LAYERNORM_RELU_CONV: - return hk.Sequential([ - hk.LayerNorm( - axis=(-3, -2, -1), create_scale=True, create_offset=True, eps=1e-6), - jax.nn.relu, - hk.Conv2D( - output_channels, - kernel_shape=3, - stride=2, - w_init=hk.initializers.TruncatedNormal(1e-2)), - ]) - - elif strategy is DownsamplingStrategy.CONV_MAX: - return hk.Sequential([ - hk.Conv2D(output_channels, kernel_shape=3, stride=1), - hk.MaxPool(window_shape=(3, 3, 1), strides=(2, 2, 1), padding='SAME') - ]) - else: - raise ValueError('Unrecognized downsampling strategy. Expected one of' - f' {[strategy.value for strategy in DownsamplingStrategy]}' - f' but received {strategy}.') + """Returns a sequence of modules corresponding to the desired downsampling.""" + strategy = DownsamplingStrategy(strategy) + + if strategy is DownsamplingStrategy.AVG_POOL: + return hk.AvgPool(window_shape=(3, 3, 1), strides=(2, 2, 1), padding="SAME") + + elif strategy is DownsamplingStrategy.CONV: + return hk.Sequential( + [ + hk.Conv2D( + output_channels, + kernel_shape=3, + stride=2, + w_init=hk.initializers.TruncatedNormal(1e-2), + ), + ] + ) + + elif strategy is DownsamplingStrategy.LAYERNORM_RELU_CONV: + return hk.Sequential( + [ + hk.LayerNorm( + axis=(-3, -2, -1), create_scale=True, create_offset=True, eps=1e-6 + ), + jax.nn.relu, + hk.Conv2D( + output_channels, + kernel_shape=3, + stride=2, + w_init=hk.initializers.TruncatedNormal(1e-2), + ), + ] + ) + + elif strategy is DownsamplingStrategy.CONV_MAX: + return hk.Sequential( + [ + hk.Conv2D(output_channels, kernel_shape=3, stride=1), + hk.MaxPool(window_shape=(3, 3, 1), strides=(2, 2, 1), padding="SAME"), + ] + ) + else: + raise ValueError( + "Unrecognized downsampling strategy. Expected one of" + f" {[strategy.value for strategy in DownsamplingStrategy]}" + f" but received {strategy}." 
+ ) class ResNetTorso(hk.Module): - """ResNetTorso for visual inputs, inspired by the IMPALA paper.""" - - def __init__(self, - channels_per_group: Sequence[int] = (16, 32, 32), - blocks_per_group: Sequence[int] = (2, 2, 2), - downsampling_strategies: Sequence[DownsamplingStrategy] = ( - DownsamplingStrategy.CONV_MAX,) * 3, - use_layer_norm: bool = False, - name: str = 'resnet_torso'): - super().__init__(name=name) - self._channels_per_group = channels_per_group - self._blocks_per_group = blocks_per_group - self._downsampling_strategies = downsampling_strategies - self._use_layer_norm = use_layer_norm - - if (len(channels_per_group) != len(blocks_per_group) or - len(channels_per_group) != len(downsampling_strategies)): - raise ValueError('Length of channels_per_group, blocks_per_group, and ' - 'downsampling_strategies must be equal. ' - f'Got channels_per_group={channels_per_group}, ' - f'blocks_per_group={blocks_per_group}, and' - f'downsampling_strategies={downsampling_strategies}.') - - def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: - output = inputs - channels_blocks_strategies = zip(self._channels_per_group, - self._blocks_per_group, - self._downsampling_strategies) - - for i, (num_channels, num_blocks, - strategy) in enumerate(channels_blocks_strategies): - output = make_downsampling_layer(strategy, num_channels)(output) - - for j in range(num_blocks): - output = ResidualBlock( - make_inner_op=functools.partial( - hk.Conv2D, output_channels=num_channels, kernel_shape=3), - use_layer_norm=self._use_layer_norm, - name=f'residual_{i}_{j}')( - output) - - return output + """ResNetTorso for visual inputs, inspired by the IMPALA paper.""" + + def __init__( + self, + channels_per_group: Sequence[int] = (16, 32, 32), + blocks_per_group: Sequence[int] = (2, 2, 2), + downsampling_strategies: Sequence[DownsamplingStrategy] = ( + DownsamplingStrategy.CONV_MAX, + ) + * 3, + use_layer_norm: bool = False, + name: str = "resnet_torso", + ): + super().__init__(name=name) + self._channels_per_group = channels_per_group + self._blocks_per_group = blocks_per_group + self._downsampling_strategies = downsampling_strategies + self._use_layer_norm = use_layer_norm + + if len(channels_per_group) != len(blocks_per_group) or len( + channels_per_group + ) != len(downsampling_strategies): + raise ValueError( + "Length of channels_per_group, blocks_per_group, and " + "downsampling_strategies must be equal. " + f"Got channels_per_group={channels_per_group}, " + f"blocks_per_group={blocks_per_group}, and" + f"downsampling_strategies={downsampling_strategies}." 
+ ) + + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + output = inputs + channels_blocks_strategies = zip( + self._channels_per_group, + self._blocks_per_group, + self._downsampling_strategies, + ) + + for i, (num_channels, num_blocks, strategy) in enumerate( + channels_blocks_strategies + ): + output = make_downsampling_layer(strategy, num_channels)(output) + + for j in range(num_blocks): + output = ResidualBlock( + make_inner_op=functools.partial( + hk.Conv2D, output_channels=num_channels, kernel_shape=3 + ), + use_layer_norm=self._use_layer_norm, + name=f"residual_{i}_{j}", + )(output) + + return output diff --git a/acme/jax/observation_stacking.py b/acme/jax/observation_stacking.py index d495715ede..0eeb91437f 100644 --- a/acme/jax/observation_stacking.py +++ b/acme/jax/observation_stacking.py @@ -16,6 +16,12 @@ from typing import Any, Mapping, NamedTuple, Tuple +import jax +import jax.numpy as jnp +import reverb +import tensorflow as tf +import tree + from acme import specs from acme import types as acme_types from acme.agents.jax import actor_core as actor_core_lib @@ -23,11 +29,6 @@ from acme.jax import types as jax_types from acme.jax import utils as jax_utils from acme.tf import utils as tf_utils -import jax -import jax.numpy as jnp -import reverb -import tensorflow as tf -import tree ActorState = Any Observation = networks_lib.Observation @@ -36,28 +37,27 @@ class StackerState(NamedTuple): - stack: jax.Array # Observations stacked along the final dimension. - needs_reset: jax.Array # A scalar boolean. + stack: jax.Array # Observations stacked along the final dimension. + needs_reset: jax.Array # A scalar boolean. class StackingActorState(NamedTuple): - actor_state: ActorState - stacker_state: StackerState + actor_state: ActorState + stacker_state: StackerState # TODO(bshahr): Consider moving to jax_utils, extending current tiling function. def tile_nested_array(nest: acme_types.NestedArray, num: int, axis: int): + def _tile_array(array: jnp.ndarray) -> jnp.ndarray: + reps = [1] * array.ndim + reps[axis] = num + return jnp.tile(array, reps) - def _tile_array(array: jnp.ndarray) -> jnp.ndarray: - reps = [1] * array.ndim - reps[axis] = num - return jnp.tile(array, reps) - - return jax.tree_map(_tile_array, nest) + return jax.tree_map(_tile_array, nest) class ObservationStacker: - """Class used to handle agent-side observation stacking. + """Class used to handle agent-side observation stacking. Once an ObservationStacker is initialized and an initial_state is obtained from it, one can stack nested observations by simply calling the @@ -67,113 +67,124 @@ class ObservationStacker: See also observation_stacking.wrap_actor_core for hints on how to use it. """ - def __init__(self, - observation_spec: acme_types.NestedSpec, - stack_size: int = 4): - - def _repeat_observation(state: StackerState, - first_observation: Observation) -> StackerState: - return state._replace( - needs_reset=jnp.array(False), - stack=tile_nested_array(first_observation, stack_size - 1, axis=-1)) - - self._zero_stack = tile_nested_array( - jax_utils.zeros_like(observation_spec), stack_size - 1, axis=-1) - self._repeat_observation = _repeat_observation - - def __call__(self, inputs: Observation, - state: StackerState) -> Tuple[Observation, StackerState]: - - # If this is a first observation, initialize the stack by repeating it, - # otherwise leave it intact. - state = jax.lax.cond( - state.needs_reset, - self._repeat_observation, - lambda state, *args: state, # No-op on state. 
- state, - inputs) - - # Concatenate frames along the final axis (assumed to be for channels). - output = jax.tree_map(lambda *x: jnp.concatenate(x, axis=-1), - state.stack, inputs) - - # Update the frame stack by adding the input and dropping the first - # observation in the stack. Note that we use the final dimension as each - # leaf in the nested observation may have a different last dim. - new_state = state._replace( - stack=jax.tree_map(lambda x, y: y[..., x.shape[-1]:], inputs, output)) - - return output, new_state - - def initial_state(self) -> StackerState: - return StackerState(stack=self._zero_stack, needs_reset=jnp.array(True)) - - -def get_adjusted_environment_spec(environment_spec: specs.EnvironmentSpec, - stack_size: int) -> specs.EnvironmentSpec: - """Returns a spec where the observation spec accounts for stacking.""" - - def stack_observation_spec(obs_spec: specs.Array) -> specs.Array: - """Adjusts last axis shape to account for observation stacking.""" - new_shape = obs_spec.shape[:-1] + (obs_spec.shape[-1] * stack_size,) - return obs_spec.replace(shape=new_shape) - - adjusted_observation_spec = jax.tree_map(stack_observation_spec, - environment_spec.observations) - - return environment_spec._replace(observations=adjusted_observation_spec) + def __init__(self, observation_spec: acme_types.NestedSpec, stack_size: int = 4): + def _repeat_observation( + state: StackerState, first_observation: Observation + ) -> StackerState: + return state._replace( + needs_reset=jnp.array(False), + stack=tile_nested_array(first_observation, stack_size - 1, axis=-1), + ) + + self._zero_stack = tile_nested_array( + jax_utils.zeros_like(observation_spec), stack_size - 1, axis=-1 + ) + self._repeat_observation = _repeat_observation + + def __call__( + self, inputs: Observation, state: StackerState + ) -> Tuple[Observation, StackerState]: + + # If this is a first observation, initialize the stack by repeating it, + # otherwise leave it intact. + state = jax.lax.cond( + state.needs_reset, + self._repeat_observation, + lambda state, *args: state, # No-op on state. + state, + inputs, + ) + + # Concatenate frames along the final axis (assumed to be for channels). + output = jax.tree_map( + lambda *x: jnp.concatenate(x, axis=-1), state.stack, inputs + ) + + # Update the frame stack by adding the input and dropping the first + # observation in the stack. Note that we use the final dimension as each + # leaf in the nested observation may have a different last dim. 
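+        # E.g. with stack_size=4 and [H, W, 1] frames, `output` has 4 channels and the
+        # retained stack keeps only the 3 most recent frames (the oldest channels are
+        # sliced off below).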
+ new_state = state._replace( + stack=jax.tree_map(lambda x, y: y[..., x.shape[-1] :], inputs, output) + ) + + return output, new_state + + def initial_state(self) -> StackerState: + return StackerState(stack=self._zero_stack, needs_reset=jnp.array(True)) + + +def get_adjusted_environment_spec( + environment_spec: specs.EnvironmentSpec, stack_size: int +) -> specs.EnvironmentSpec: + """Returns a spec where the observation spec accounts for stacking.""" + + def stack_observation_spec(obs_spec: specs.Array) -> specs.Array: + """Adjusts last axis shape to account for observation stacking.""" + new_shape = obs_spec.shape[:-1] + (obs_spec.shape[-1] * stack_size,) + return obs_spec.replace(shape=new_shape) + + adjusted_observation_spec = jax.tree_map( + stack_observation_spec, environment_spec.observations + ) + + return environment_spec._replace(observations=adjusted_observation_spec) def wrap_actor_core( actor_core: actor_core_lib.ActorCore, observation_spec: specs.Array, - num_stacked_observations: int = 1) -> actor_core_lib.ActorCore: - """Wraps an actor core so that it performs observation stacking.""" + num_stacked_observations: int = 1, +) -> actor_core_lib.ActorCore: + """Wraps an actor core so that it performs observation stacking.""" - if num_stacked_observations <= 0: - raise ValueError( - 'Number of stacked observations must be strictly positive.' - f' Received num_stacked_observations={num_stacked_observations}.') + if num_stacked_observations <= 0: + raise ValueError( + "Number of stacked observations must be strictly positive." + f" Received num_stacked_observations={num_stacked_observations}." + ) - if num_stacked_observations == 1: - # Return unwrapped core when a trivial stack size is requested. - return actor_core + if num_stacked_observations == 1: + # Return unwrapped core when a trivial stack size is requested. 
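+        # (Stacking a single observation is the identity, so no wrapping is needed.)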
+ return actor_core - obs_stacker = ObservationStacker( - observation_spec=observation_spec, stack_size=num_stacked_observations) + obs_stacker = ObservationStacker( + observation_spec=observation_spec, stack_size=num_stacked_observations + ) - def init(key: jax_types.PRNGKey) -> StackingActorState: - return StackingActorState( - actor_state=actor_core.init(key), - stacker_state=obs_stacker.initial_state()) + def init(key: jax_types.PRNGKey) -> StackingActorState: + return StackingActorState( + actor_state=actor_core.init(key), stacker_state=obs_stacker.initial_state() + ) - def select_action( - params: Params, - observations: Observation, - state: StackingActorState, - ) -> Tuple[Action, StackingActorState]: + def select_action( + params: Params, observations: Observation, state: StackingActorState, + ) -> Tuple[Action, StackingActorState]: - stacked_observations, stacker_state = obs_stacker(observations, - state.stacker_state) + stacked_observations, stacker_state = obs_stacker( + observations, state.stacker_state + ) - actions, actor_state = actor_core.select_action(params, - stacked_observations, - state.actor_state) - new_state = StackingActorState( - actor_state=actor_state, stacker_state=stacker_state) + actions, actor_state = actor_core.select_action( + params, stacked_observations, state.actor_state + ) + new_state = StackingActorState( + actor_state=actor_state, stacker_state=stacker_state + ) - return actions, new_state + return actions, new_state - def get_extras(state: StackingActorState) -> Mapping[str, jnp.ndarray]: - return actor_core.get_extras(state.actor_state) + def get_extras(state: StackingActorState) -> Mapping[str, jnp.ndarray]: + return actor_core.get_extras(state.actor_state) - return actor_core_lib.ActorCore( - init=init, select_action=select_action, get_extras=get_extras) + return actor_core_lib.ActorCore( + init=init, select_action=select_action, get_extras=get_extras + ) -def stack_reverb_observation(sample: reverb.ReplaySample, - stack_size: int) -> reverb.ReplaySample: - """Stacks observations in a Reverb sample. +def stack_reverb_observation( + sample: reverb.ReplaySample, stack_size: int +) -> reverb.ReplaySample: + """Stacks observations in a Reverb sample. This function is meant to be used on the dataset creation side as a post-processing function before batching. @@ -199,26 +210,29 @@ def stack_reverb_observation(sample: reverb.ReplaySample, multiplied by `stack_size`. """ - def _repeat_first(sequence: tf.Tensor) -> tf.Tensor: - repeated_first_step = tf_utils.tile_tensor(sequence[0], stack_size - 1) - return tf.concat([repeated_first_step, sequence], 0)[:-(stack_size - 1)] + def _repeat_first(sequence: tf.Tensor) -> tf.Tensor: + repeated_first_step = tf_utils.tile_tensor(sequence[0], stack_size - 1) + return tf.concat([repeated_first_step, sequence], 0)[: -(stack_size - 1)] - def _stack_observation(observation: tf.Tensor) -> tf.Tensor: - stack = [tf.roll(observation, i, axis=0) for i in range(stack_size)] - stack.reverse() # Reverse stack order to be chronological. - return tf.concat(stack, axis=-1) + def _stack_observation(observation: tf.Tensor) -> tf.Tensor: + stack = [tf.roll(observation, i, axis=0) for i in range(stack_size)] + stack.reverse() # Reverse stack order to be chronological. + return tf.concat(stack, axis=-1) - # Maybe repeat the first observation, if at the start of an episode. 
- data = tf.cond(sample.data.start_of_episode[0], - lambda: tree.map_structure(_repeat_first, sample.data), - lambda: sample.data) + # Maybe repeat the first observation, if at the start of an episode. + data = tf.cond( + sample.data.start_of_episode[0], + lambda: tree.map_structure(_repeat_first, sample.data), + lambda: sample.data, + ) - # Stack observation in the sample's data. - data_with_stacked_obs = data._replace( - observation=tree.map_structure(_stack_observation, data.observation)) + # Stack observation in the sample's data. + data_with_stacked_obs = data._replace( + observation=tree.map_structure(_stack_observation, data.observation) + ) - # Truncate the start of the sequence due to the first stacks containing the - # final observations that were rolled over to the start. - data = tree.map_structure(lambda x: x[stack_size - 1:], data_with_stacked_obs) + # Truncate the start of the sequence due to the first stacks containing the + # final observations that were rolled over to the start. + data = tree.map_structure(lambda x: x[stack_size - 1 :], data_with_stacked_obs) - return reverb.ReplaySample(info=sample.info, data=data) + return reverb.ReplaySample(info=sample.info, data=data) diff --git a/acme/jax/running_statistics.py b/acme/jax/running_statistics.py index 6f6688e41d..14e7aa06f1 100644 --- a/acme/jax/running_statistics.py +++ b/acme/jax/running_statistics.py @@ -17,14 +17,14 @@ import dataclasses from typing import Any, Optional, Tuple, Union -from acme import types -from acme.utils import tree_utils import chex import jax import jax.numpy as jnp import numpy as np import tree +from acme import types +from acme.utils import tree_utils Path = Tuple[Any, ...] """Path in a nested structure. @@ -38,69 +38,75 @@ def _is_prefix(a: Path, b: Path) -> bool: - """Returns whether `a` is a prefix of `b`.""" - return b[:len(a)] == a + """Returns whether `a` is a prefix of `b`.""" + return b[: len(a)] == a def _zeros_like(nest: types.Nest, dtype=None) -> types.NestedArray: - return jax.tree_map(lambda x: jnp.zeros(x.shape, dtype or x.dtype), nest) + return jax.tree_map(lambda x: jnp.zeros(x.shape, dtype or x.dtype), nest) def _ones_like(nest: types.Nest, dtype=None) -> types.NestedArray: - return jax.tree_map(lambda x: jnp.ones(x.shape, dtype or x.dtype), nest) + return jax.tree_map(lambda x: jnp.ones(x.shape, dtype or x.dtype), nest) @chex.dataclass(frozen=True) class NestedMeanStd: - """A container for running statistics (mean, std) of possibly nested data.""" - mean: types.NestedArray - std: types.NestedArray + """A container for running statistics (mean, std) of possibly nested data.""" + + mean: types.NestedArray + std: types.NestedArray @chex.dataclass(frozen=True) class RunningStatisticsState(NestedMeanStd): - """Full state of running statistics computation.""" - count: Union[int, jnp.ndarray] - summed_variance: types.NestedArray + """Full state of running statistics computation.""" + + count: Union[int, jnp.ndarray] + summed_variance: types.NestedArray @dataclasses.dataclass(frozen=True) class NestStatisticsConfig: - """Specifies how to compute statistics for Nests with the same structure. + """Specifies how to compute statistics for Nests with the same structure. Attributes: paths: A sequence of Nest paths to compute statistics for. If there is a collision between paths (one is a prefix of the other), the shorter path takes precedence. """ - paths: Tuple[Path, ...] = ((),) + + paths: Tuple[Path, ...] 
= ((),) def _is_path_included(config: NestStatisticsConfig, path: Path) -> bool: - """Returns whether the path is included in the config.""" - # A path is included in the config if it corresponds to a tree node that - # belongs to a subtree rooted at the node corresponding to some path in - # the config. - return any(_is_prefix(config_path, path) for config_path in config.paths) + """Returns whether the path is included in the config.""" + # A path is included in the config if it corresponds to a tree node that + # belongs to a subtree rooted at the node corresponding to some path in + # the config. + return any(_is_prefix(config_path, path) for config_path in config.paths) def init_state(nest: types.Nest) -> RunningStatisticsState: - """Initializes the running statistics for the given nested structure.""" - dtype = jnp.float64 if jax.config.jax_enable_x64 else jnp.float32 - - return RunningStatisticsState( # pytype: disable=wrong-arg-types # jax-ndarray - count=0., - mean=_zeros_like(nest, dtype=dtype), - summed_variance=_zeros_like(nest, dtype=dtype), - # Initialize with ones to make sure normalization works correctly - # in the initial state. - std=_ones_like(nest, dtype=dtype)) - - -def _validate_batch_shapes(batch: types.NestedArray, - reference_sample: types.NestedArray, - batch_dims: Tuple[int, ...]) -> None: - """Verifies shapes of the batch leaves against the reference sample. + """Initializes the running statistics for the given nested structure.""" + dtype = jnp.float64 if jax.config.jax_enable_x64 else jnp.float32 + + return RunningStatisticsState( # pytype: disable=wrong-arg-types # jax-ndarray + count=0.0, + mean=_zeros_like(nest, dtype=dtype), + summed_variance=_zeros_like(nest, dtype=dtype), + # Initialize with ones to make sure normalization works correctly + # in the initial state. + std=_ones_like(nest, dtype=dtype), + ) + + +def _validate_batch_shapes( + batch: types.NestedArray, + reference_sample: types.NestedArray, + batch_dims: Tuple[int, ...], +) -> None: + """Verifies shapes of the batch leaves against the reference sample. Checks that batch dimensions are the same in all leaves in the batch. Checks that non-batch dimensions for all leaves in the batch are the same @@ -114,24 +120,26 @@ def _validate_batch_shapes(batch: types.NestedArray, Returns: None. """ - def validate_node_shape(reference_sample: jnp.ndarray, - batch: jnp.ndarray) -> None: - expected_shape = batch_dims + reference_sample.shape - assert batch.shape == expected_shape, f'{batch.shape} != {expected_shape}' - tree_utils.fast_map_structure(validate_node_shape, reference_sample, batch) + def validate_node_shape(reference_sample: jnp.ndarray, batch: jnp.ndarray) -> None: + expected_shape = batch_dims + reference_sample.shape + assert batch.shape == expected_shape, f"{batch.shape} != {expected_shape}" + + tree_utils.fast_map_structure(validate_node_shape, reference_sample, batch) -def update(state: RunningStatisticsState, - batch: types.NestedArray, - *, - config: NestStatisticsConfig = NestStatisticsConfig(), - weights: Optional[jnp.ndarray] = None, - std_min_value: float = 1e-6, - std_max_value: float = 1e6, - pmap_axis_name: Optional[str] = None, - validate_shapes: bool = True) -> RunningStatisticsState: - """Updates the running statistics with the given batch of data. 
+def update( + state: RunningStatisticsState, + batch: types.NestedArray, + *, + config: NestStatisticsConfig = NestStatisticsConfig(), + weights: Optional[jnp.ndarray] = None, + std_min_value: float = 1e-6, + std_max_value: float = 1e6, + pmap_axis_name: Optional[str] = None, + validate_shapes: bool = True, +) -> RunningStatisticsState: + """Updates the running statistics with the given batch of data. Note: data batch and state elements (mean, etc.) must have the same structure. @@ -159,116 +167,123 @@ def update(state: RunningStatisticsState, Returns: Updated running statistics. """ - # We require exactly the same structure to avoid issues when flattened - # batch and state have different order of elements. - tree.assert_same_structure(batch, state.mean) - batch_shape = tree.flatten(batch)[0].shape - # We assume the batch dimensions always go first. - batch_dims = batch_shape[:len(batch_shape) - tree.flatten(state.mean)[0].ndim] - batch_axis = range(len(batch_dims)) - if weights is None: - step_increment = np.prod(batch_dims) - else: - step_increment = jnp.sum(weights) - if pmap_axis_name is not None: - step_increment = jax.lax.psum(step_increment, axis_name=pmap_axis_name) - count = state.count + step_increment - - # Validation is important. If the shapes don't match exactly, but are - # compatible, arrays will be silently broadcasted resulting in incorrect - # statistics. - if validate_shapes: - if weights is not None: - if weights.shape != batch_dims: - raise ValueError(f'{weights.shape} != {batch_dims}') - _validate_batch_shapes(batch, state.mean, batch_dims) - - def _compute_node_statistics( - path: Path, mean: jnp.ndarray, summed_variance: jnp.ndarray, - batch: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: - assert isinstance(mean, jnp.ndarray), type(mean) - assert isinstance(summed_variance, jnp.ndarray), type(summed_variance) - if not _is_path_included(config, path): - # Return unchanged. - return mean, summed_variance - # The mean and the sum of past variances are updated with Welford's - # algorithm using batches (see https://stackoverflow.com/q/56402955). - diff_to_old_mean = batch - mean - if weights is not None: - expanded_weights = jnp.reshape( - weights, - list(weights.shape) + [1] * (batch.ndim - weights.ndim)) - diff_to_old_mean = diff_to_old_mean * expanded_weights - mean_update = jnp.sum(diff_to_old_mean, axis=batch_axis) / count + # We require exactly the same structure to avoid issues when flattened + # batch and state have different order of elements. + tree.assert_same_structure(batch, state.mean) + batch_shape = tree.flatten(batch)[0].shape + # We assume the batch dimensions always go first. 
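+    # E.g. a batch leaf of shape [T, B, 5] paired with a mean leaf of shape [5] gives
+    # batch_dims == (T, B); the reductions below then sum over those leading axes.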
+ batch_dims = batch_shape[: len(batch_shape) - tree.flatten(state.mean)[0].ndim] + batch_axis = range(len(batch_dims)) + if weights is None: + step_increment = np.prod(batch_dims) + else: + step_increment = jnp.sum(weights) if pmap_axis_name is not None: - mean_update = jax.lax.psum( - mean_update, axis_name=pmap_axis_name) - mean = mean + mean_update - - diff_to_new_mean = batch - mean - variance_update = diff_to_old_mean * diff_to_new_mean - variance_update = jnp.sum(variance_update, axis=batch_axis) - if pmap_axis_name is not None: - variance_update = jax.lax.psum(variance_update, axis_name=pmap_axis_name) - summed_variance = summed_variance + variance_update - return mean, summed_variance - - updated_stats = tree_utils.fast_map_structure_with_path( - _compute_node_statistics, state.mean, state.summed_variance, batch) - # map_structure_up_to is slow, so shortcut if we know the input is not - # structured. - if isinstance(state.mean, jnp.ndarray): - mean, summed_variance = updated_stats - else: - # Reshape the updated stats from `nest(mean, summed_variance)` to - # `nest(mean), nest(summed_variance)`. - mean, summed_variance = [ - tree.map_structure_up_to( - state.mean, lambda s, i=idx: s[i], updated_stats) - for idx in range(2) - ] - - def compute_std(path: Path, summed_variance: jnp.ndarray, - std: jnp.ndarray) -> jnp.ndarray: - assert isinstance(summed_variance, jnp.ndarray) - if not _is_path_included(config, path): - return std - # Summed variance can get negative due to rounding errors. - summed_variance = jnp.maximum(summed_variance, 0) - std = jnp.sqrt(summed_variance / count) - std = jnp.clip(std, std_min_value, std_max_value) - return std - - std = tree_utils.fast_map_structure_with_path(compute_std, summed_variance, - state.std) - - return RunningStatisticsState( - count=count, mean=mean, summed_variance=summed_variance, std=std) - - -def normalize(batch: types.NestedArray, - mean_std: NestedMeanStd, - max_abs_value: Optional[float] = None) -> types.NestedArray: - """Normalizes data using running statistics.""" - - def normalize_leaf(data: jnp.ndarray, mean: jnp.ndarray, - std: jnp.ndarray) -> jnp.ndarray: - # Only normalize inexact types. - if not jnp.issubdtype(data.dtype, jnp.inexact): - return data - data = (data - mean) / std - if max_abs_value is not None: - # TODO(b/124318564): remove pylint directive - data = jnp.clip(data, -max_abs_value, +max_abs_value) # pylint: disable=invalid-unary-operand-type - return data - - return tree_utils.fast_map_structure(normalize_leaf, batch, mean_std.mean, - mean_std.std) - - -def denormalize(batch: types.NestedArray, - mean_std: NestedMeanStd) -> types.NestedArray: - """Denormalizes values in a nested structure using the given mean/std. + step_increment = jax.lax.psum(step_increment, axis_name=pmap_axis_name) + count = state.count + step_increment + + # Validation is important. If the shapes don't match exactly, but are + # compatible, arrays will be silently broadcasted resulting in incorrect + # statistics. 
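+    # E.g. a batch leaf of shape [B, 1] against a mean leaf of shape [5] would silently
+    # broadcast to [B, 5] instead of raising, hence the explicit checks below.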
+ if validate_shapes: + if weights is not None: + if weights.shape != batch_dims: + raise ValueError(f"{weights.shape} != {batch_dims}") + _validate_batch_shapes(batch, state.mean, batch_dims) + + def _compute_node_statistics( + path: Path, mean: jnp.ndarray, summed_variance: jnp.ndarray, batch: jnp.ndarray + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + assert isinstance(mean, jnp.ndarray), type(mean) + assert isinstance(summed_variance, jnp.ndarray), type(summed_variance) + if not _is_path_included(config, path): + # Return unchanged. + return mean, summed_variance + # The mean and the sum of past variances are updated with Welford's + # algorithm using batches (see https://stackoverflow.com/q/56402955). + diff_to_old_mean = batch - mean + if weights is not None: + expanded_weights = jnp.reshape( + weights, list(weights.shape) + [1] * (batch.ndim - weights.ndim) + ) + diff_to_old_mean = diff_to_old_mean * expanded_weights + mean_update = jnp.sum(diff_to_old_mean, axis=batch_axis) / count + if pmap_axis_name is not None: + mean_update = jax.lax.psum(mean_update, axis_name=pmap_axis_name) + mean = mean + mean_update + + diff_to_new_mean = batch - mean + variance_update = diff_to_old_mean * diff_to_new_mean + variance_update = jnp.sum(variance_update, axis=batch_axis) + if pmap_axis_name is not None: + variance_update = jax.lax.psum(variance_update, axis_name=pmap_axis_name) + summed_variance = summed_variance + variance_update + return mean, summed_variance + + updated_stats = tree_utils.fast_map_structure_with_path( + _compute_node_statistics, state.mean, state.summed_variance, batch + ) + # map_structure_up_to is slow, so shortcut if we know the input is not + # structured. + if isinstance(state.mean, jnp.ndarray): + mean, summed_variance = updated_stats + else: + # Reshape the updated stats from `nest(mean, summed_variance)` to + # `nest(mean), nest(summed_variance)`. + mean, summed_variance = [ + tree.map_structure_up_to(state.mean, lambda s, i=idx: s[i], updated_stats) + for idx in range(2) + ] + + def compute_std( + path: Path, summed_variance: jnp.ndarray, std: jnp.ndarray + ) -> jnp.ndarray: + assert isinstance(summed_variance, jnp.ndarray) + if not _is_path_included(config, path): + return std + # Summed variance can get negative due to rounding errors. + summed_variance = jnp.maximum(summed_variance, 0) + std = jnp.sqrt(summed_variance / count) + std = jnp.clip(std, std_min_value, std_max_value) + return std + + std = tree_utils.fast_map_structure_with_path( + compute_std, summed_variance, state.std + ) + + return RunningStatisticsState( + count=count, mean=mean, summed_variance=summed_variance, std=std + ) + + +def normalize( + batch: types.NestedArray, + mean_std: NestedMeanStd, + max_abs_value: Optional[float] = None, +) -> types.NestedArray: + """Normalizes data using running statistics.""" + + def normalize_leaf( + data: jnp.ndarray, mean: jnp.ndarray, std: jnp.ndarray + ) -> jnp.ndarray: + # Only normalize inexact types. + if not jnp.issubdtype(data.dtype, jnp.inexact): + return data + data = (data - mean) / std + if max_abs_value is not None: + # TODO(b/124318564): remove pylint directive + data = jnp.clip( + data, -max_abs_value, +max_abs_value + ) # pylint: disable=invalid-unary-operand-type + return data + + return tree_utils.fast_map_structure( + normalize_leaf, batch, mean_std.mean, mean_std.std + ) + + +def denormalize(batch: types.NestedArray, mean_std: NestedMeanStd) -> types.NestedArray: + """Denormalizes values in a nested structure using the given mean/std. 
Only values of inexact types are denormalized. See https://numpy.org/doc/stable/_images/dtype-hierarchy.png for Numpy type @@ -282,20 +297,22 @@ def denormalize(batch: types.NestedArray, Nested structure with denormalized values. """ - def denormalize_leaf(data: jnp.ndarray, mean: jnp.ndarray, - std: jnp.ndarray) -> jnp.ndarray: - # Only denormalize inexact types. - if not np.issubdtype(data.dtype, np.inexact): - return data - return data * std + mean + def denormalize_leaf( + data: jnp.ndarray, mean: jnp.ndarray, std: jnp.ndarray + ) -> jnp.ndarray: + # Only denormalize inexact types. + if not np.issubdtype(data.dtype, np.inexact): + return data + return data * std + mean - return tree_utils.fast_map_structure(denormalize_leaf, batch, mean_std.mean, - mean_std.std) + return tree_utils.fast_map_structure( + denormalize_leaf, batch, mean_std.mean, mean_std.std + ) @dataclasses.dataclass(frozen=True) class NestClippingConfig: - """Specifies how to clip Nests with the same structure. + """Specifies how to clip Nests with the same structure. Attributes: path_map: A map that specifies how to clip values in Nests with the same @@ -303,53 +320,64 @@ class NestClippingConfig: absolute values to use for clipping. If there is a collision between paths (one path is a prefix of the other), the behavior is undefined. """ - path_map: Tuple[Tuple[Path, float], ...] = () - - -def get_clip_config_for_path(config: NestClippingConfig, - path: Path) -> NestClippingConfig: - """Returns the config for a subtree from the leaf defined by the path.""" - # Start with an empty config. - path_map = [] - for map_path, max_abs_value in config.path_map: - if _is_prefix(map_path, path): - return NestClippingConfig(path_map=(((), max_abs_value),)) - if _is_prefix(path, map_path): - path_map.append((map_path[len(path):], max_abs_value)) - return NestClippingConfig(path_map=tuple(path_map)) - -def clip(batch: types.NestedArray, - clipping_config: NestClippingConfig) -> types.NestedArray: - """Clips the batch.""" - - def max_abs_value_for_path(path: Path, x: jnp.ndarray) -> Optional[float]: - del x # Unused, needed by interface. - return next((max_abs_value - for clipping_path, max_abs_value in clipping_config.path_map - if _is_prefix(clipping_path, path)), None) - - max_abs_values = tree_utils.fast_map_structure_with_path( - max_abs_value_for_path, batch) - - def clip_leaf(data: jnp.ndarray, - max_abs_value: Optional[float]) -> jnp.ndarray: - if max_abs_value is not None: - # TODO(b/124318564): remove pylint directive - data = jnp.clip(data, -max_abs_value, +max_abs_value) # pylint: disable=invalid-unary-operand-type - return data - - return tree_utils.fast_map_structure(clip_leaf, batch, max_abs_values) + path_map: Tuple[Tuple[Path, float], ...] = () + + +def get_clip_config_for_path( + config: NestClippingConfig, path: Path +) -> NestClippingConfig: + """Returns the config for a subtree from the leaf defined by the path.""" + # Start with an empty config. + path_map = [] + for map_path, max_abs_value in config.path_map: + if _is_prefix(map_path, path): + return NestClippingConfig(path_map=(((), max_abs_value),)) + if _is_prefix(path, map_path): + path_map.append((map_path[len(path) :], max_abs_value)) + return NestClippingConfig(path_map=tuple(path_map)) + + +def clip( + batch: types.NestedArray, clipping_config: NestClippingConfig +) -> types.NestedArray: + """Clips the batch.""" + + def max_abs_value_for_path(path: Path, x: jnp.ndarray) -> Optional[float]: + del x # Unused, needed by interface. 
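+        # Pick the bound of the first configured path that is a prefix of `path`;
+        # None means the corresponding leaf is left unclipped.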
+ return next( + ( + max_abs_value + for clipping_path, max_abs_value in clipping_config.path_map + if _is_prefix(clipping_path, path) + ), + None, + ) + + max_abs_values = tree_utils.fast_map_structure_with_path( + max_abs_value_for_path, batch + ) + + def clip_leaf(data: jnp.ndarray, max_abs_value: Optional[float]) -> jnp.ndarray: + if max_abs_value is not None: + # TODO(b/124318564): remove pylint directive + data = jnp.clip( + data, -max_abs_value, +max_abs_value + ) # pylint: disable=invalid-unary-operand-type + return data + + return tree_utils.fast_map_structure(clip_leaf, batch, max_abs_values) @dataclasses.dataclass(frozen=True) class NestNormalizationConfig: - """Specifies how to normalize Nests with the same structure. + """Specifies how to normalize Nests with the same structure. Attributes: stats_config: A config that defines how to compute running statistics to be used for normalization. clip_config: A config that defines how to clip normalized values. """ - stats_config: NestStatisticsConfig = NestStatisticsConfig() - clip_config: NestClippingConfig = NestClippingConfig() + + stats_config: NestStatisticsConfig = NestStatisticsConfig() + clip_config: NestClippingConfig = NestClippingConfig() diff --git a/acme/jax/running_statistics_test.py b/acme/jax/running_statistics_test.py index 2119515737..f299af8898 100644 --- a/acme/jax/running_statistics_test.py +++ b/acme/jax/running_statistics_test.py @@ -18,288 +18,285 @@ import math from typing import NamedTuple -from acme import specs -from acme.jax import running_statistics import jax -from jax.config import config as jax_config import jax.numpy as jnp import numpy as np import tree - from absl.testing import absltest +from jax.config import config as jax_config -update_and_validate = functools.partial( - running_statistics.update, validate_shapes=True) +from acme import specs +from acme.jax import running_statistics + +update_and_validate = functools.partial(running_statistics.update, validate_shapes=True) class TestNestedSpec(NamedTuple): - # Note: the fields are intentionally in reverse order to test ordering. - a: specs.Array - b: specs.Array + # Note: the fields are intentionally in reverse order to test ordering. 
+ a: specs.Array + b: specs.Array class RunningStatisticsTest(absltest.TestCase): + def setUp(self): + super().setUp() + jax_config.update("jax_enable_x64", False) + + def assert_allclose( + self, actual: jnp.ndarray, desired: jnp.ndarray, err_msg: str = "" + ) -> None: + np.testing.assert_allclose( + actual, desired, atol=1e-5, rtol=1e-5, err_msg=err_msg + ) - def setUp(self): - super().setUp() - jax_config.update('jax_enable_x64', False) - - def assert_allclose(self, - actual: jnp.ndarray, - desired: jnp.ndarray, - err_msg: str = '') -> None: - np.testing.assert_allclose( - actual, desired, atol=1e-5, rtol=1e-5, err_msg=err_msg) - - def test_normalize(self): - state = running_statistics.init_state(specs.Array((5,), jnp.float32)) - - x = jnp.arange(200, dtype=jnp.float32).reshape(20, 2, 5) - x1, x2, x3, x4 = jnp.split(x, 4, axis=0) - - state = update_and_validate(state, x1) - state = update_and_validate(state, x2) - state = update_and_validate(state, x3) - state = update_and_validate(state, x4) - normalized = running_statistics.normalize(x, state) - - mean = jnp.mean(normalized) - std = jnp.std(normalized) - self.assert_allclose(mean, jnp.zeros_like(mean)) - self.assert_allclose(std, jnp.ones_like(std)) - - def test_init_normalize(self): - state = running_statistics.init_state(specs.Array((5,), jnp.float32)) - - x = jnp.arange(200, dtype=jnp.float32).reshape(20, 2, 5) - normalized = running_statistics.normalize(x, state) - - self.assert_allclose(normalized, x) - - def test_one_batch_dim(self): - state = running_statistics.init_state(specs.Array((5,), jnp.float32)) - - x = jnp.arange(10, dtype=jnp.float32).reshape(2, 5) - - state = update_and_validate(state, x) - normalized = running_statistics.normalize(x, state) - - mean = jnp.mean(normalized, axis=0) - std = jnp.std(normalized, axis=0) - self.assert_allclose(mean, jnp.zeros_like(mean)) - self.assert_allclose(std, jnp.ones_like(std)) - - def test_clip(self): - state = running_statistics.init_state(specs.Array((), jnp.float32)) - - x = jnp.arange(5, dtype=jnp.float32) - - state = update_and_validate(state, x) - normalized = running_statistics.normalize(x, state, max_abs_value=1.0) - - mean = jnp.mean(normalized) - std = jnp.std(normalized) - self.assert_allclose(mean, jnp.zeros_like(mean)) - self.assert_allclose(std, jnp.ones_like(std) * math.sqrt(0.6)) - - def test_nested_normalize(self): - state = running_statistics.init_state({ - 'a': specs.Array((5,), jnp.float32), - 'b': specs.Array((2,), jnp.float32) - }) - - x1 = { - 'a': jnp.arange(20, dtype=jnp.float32).reshape(2, 2, 5), - 'b': jnp.arange(8, dtype=jnp.float32).reshape(2, 2, 2) - } - x2 = { - 'a': jnp.arange(20, dtype=jnp.float32).reshape(2, 2, 5) + 20, - 'b': jnp.arange(8, dtype=jnp.float32).reshape(2, 2, 2) + 8 - } - x3 = { - 'a': jnp.arange(40, dtype=jnp.float32).reshape(4, 2, 5), - 'b': jnp.arange(16, dtype=jnp.float32).reshape(4, 2, 2) - } - - state = update_and_validate(state, x1) - state = update_and_validate(state, x2) - state = update_and_validate(state, x3) - normalized = running_statistics.normalize(x3, state) - - mean = tree.map_structure(lambda x: jnp.mean(x, axis=(0, 1)), normalized) - std = tree.map_structure(lambda x: jnp.std(x, axis=(0, 1)), normalized) - tree.map_structure( - lambda x: self.assert_allclose(x, jnp.zeros_like(x)), - mean) - tree.map_structure( - lambda x: self.assert_allclose(x, jnp.ones_like(x)), - std) - - def test_validation(self): - state = running_statistics.init_state(specs.Array((1, 2, 3), jnp.float32)) - - x = jnp.arange(12, 
dtype=jnp.float32).reshape(2, 2, 3) - with self.assertRaises(AssertionError): - update_and_validate(state, x) - - x = jnp.arange(3, dtype=jnp.float32).reshape(1, 1, 3) - with self.assertRaises(AssertionError): - update_and_validate(state, x) - - def test_int_not_normalized(self): - state = running_statistics.init_state(specs.Array((), jnp.int32)) - - x = jnp.arange(5, dtype=jnp.int32) - - state = update_and_validate(state, x) - normalized = running_statistics.normalize(x, state) - - np.testing.assert_array_equal(normalized, x) - - def test_pmap_update_nested(self): - local_device_count = jax.local_device_count() - state = running_statistics.init_state({ - 'a': specs.Array((5,), jnp.float32), - 'b': specs.Array((2,), jnp.float32) - }) - - x = { - 'a': (jnp.arange(15 * local_device_count, - dtype=jnp.float32)).reshape(local_device_count, 3, 5), - 'b': (jnp.arange(6 * local_device_count, - dtype=jnp.float32)).reshape(local_device_count, 3, 2), - } - - devices = jax.local_devices() - state = jax.device_put_replicated(state, devices) - pmap_axis_name = 'i' - state = jax.pmap( - functools.partial(update_and_validate, pmap_axis_name=pmap_axis_name), - pmap_axis_name)(state, x) - state = jax.pmap( - functools.partial(update_and_validate, pmap_axis_name=pmap_axis_name), - pmap_axis_name)(state, x) - normalized = jax.pmap(running_statistics.normalize)(x, state) - - mean = tree.map_structure(lambda x: jnp.mean(x, axis=(0, 1)), normalized) - std = tree.map_structure(lambda x: jnp.std(x, axis=(0, 1)), normalized) - tree.map_structure( - lambda x: self.assert_allclose(x, jnp.zeros_like(x)), mean) - tree.map_structure( - lambda x: self.assert_allclose(x, jnp.ones_like(x)), std) - - def test_different_structure_normalize(self): - spec = TestNestedSpec( - a=specs.Array((5,), jnp.float32), b=specs.Array((2,), jnp.float32)) - state = running_statistics.init_state(spec) - - x = { - 'a': jnp.arange(20, dtype=jnp.float32).reshape(2, 2, 5), - 'b': jnp.arange(8, dtype=jnp.float32).reshape(2, 2, 2) - } - - with self.assertRaises(TypeError): - state = update_and_validate(state, x) - - def test_weights(self): - state = running_statistics.init_state(specs.Array((), jnp.float32)) - - x = jnp.arange(5, dtype=jnp.float32) - x_weights = jnp.ones_like(x) - y = 2 * x + 5 - y_weights = 2 * x_weights - z = jnp.concatenate([x, y]) - weights = jnp.concatenate([x_weights, y_weights]) - - state = update_and_validate(state, z, weights=weights) - - self.assertEqual(state.mean, (jnp.mean(x) + 2 * jnp.mean(y)) / 3) - big_z = jnp.concatenate([x, y, y]) - normalized = running_statistics.normalize(big_z, state) - self.assertAlmostEqual(jnp.mean(normalized), 0., places=6) - self.assertAlmostEqual(jnp.std(normalized), 1., places=6) - - def test_normalize_config(self): - x = jnp.arange(200, dtype=jnp.float32).reshape(20, 2, 5) - x_split = jnp.split(x, 5, axis=0) - - y = jnp.arange(160, dtype=jnp.float32).reshape(20, 2, 4) - y_split = jnp.split(y, 5, axis=0) - - z = {'a': x, 'b': y} - - z_split = [{'a': xx, 'b': yy} for xx, yy in zip(x_split, y_split)] - - update = jax.jit(running_statistics.update, static_argnames=('config',)) - - config = running_statistics.NestStatisticsConfig((('a',),)) - state = running_statistics.init_state({ - 'a': specs.Array((5,), jnp.float32), - 'b': specs.Array((4,), jnp.float32) - }) - # Test initialization from the first element. 
- state = update(state, z_split[0], config=config) - state = update(state, z_split[1], config=config) - state = update(state, z_split[2], config=config) - state = update(state, z_split[3], config=config) - state = update(state, z_split[4], config=config) - - normalize = jax.jit(running_statistics.normalize) - normalized = normalize(z, state) - - for key in normalized: - mean = jnp.mean(normalized[key], axis=(0, 1)) - std = jnp.std(normalized[key], axis=(0, 1)) - if key == 'a': - self.assert_allclose( - mean, - jnp.zeros_like(mean), - err_msg=f'key:{key} mean:{mean} normalized:{normalized[key]}') - self.assert_allclose( - std, - jnp.ones_like(std), - err_msg=f'key:{key} std:{std} normalized:{normalized[key]}') - else: - assert key == 'b' - np.testing.assert_array_equal( - normalized[key], - z[key], - err_msg=f'z:{z[key]} normalized:{normalized[key]}') - - def test_clip_config(self): - x = jnp.arange(10, dtype=jnp.float32) - 5 - y = jnp.arange(8, dtype=jnp.float32) - 4 - - z = {'x': x, 'y': y} - - max_abs_x = 2 - config = running_statistics.NestClippingConfig(((('x',), max_abs_x),)) - - clipped_z = running_statistics.clip(z, config) - - clipped_x = jnp.clip(a=x, a_min=-max_abs_x, a_max=max_abs_x) - np.testing.assert_array_equal(clipped_z['x'], clipped_x) - - np.testing.assert_array_equal(clipped_z['y'], z['y']) - - def test_denormalize(self): - state = running_statistics.init_state(specs.Array((5,), jnp.float32)) - - x = jnp.arange(100, dtype=jnp.float32).reshape(10, 2, 5) - x1, x2 = jnp.split(x, 2, axis=0) - - state = update_and_validate(state, x1) - state = update_and_validate(state, x2) - normalized = running_statistics.normalize(x, state) - - mean = jnp.mean(normalized) - std = jnp.std(normalized) - self.assert_allclose(mean, jnp.zeros_like(mean)) - self.assert_allclose(std, jnp.ones_like(std)) - - denormalized = running_statistics.denormalize(normalized, state) - self.assert_allclose(denormalized, x) + def test_normalize(self): + state = running_statistics.init_state(specs.Array((5,), jnp.float32)) + x = jnp.arange(200, dtype=jnp.float32).reshape(20, 2, 5) + x1, x2, x3, x4 = jnp.split(x, 4, axis=0) -if __name__ == '__main__': - absltest.main() + state = update_and_validate(state, x1) + state = update_and_validate(state, x2) + state = update_and_validate(state, x3) + state = update_and_validate(state, x4) + normalized = running_statistics.normalize(x, state) + + mean = jnp.mean(normalized) + std = jnp.std(normalized) + self.assert_allclose(mean, jnp.zeros_like(mean)) + self.assert_allclose(std, jnp.ones_like(std)) + + def test_init_normalize(self): + state = running_statistics.init_state(specs.Array((5,), jnp.float32)) + + x = jnp.arange(200, dtype=jnp.float32).reshape(20, 2, 5) + normalized = running_statistics.normalize(x, state) + + self.assert_allclose(normalized, x) + + def test_one_batch_dim(self): + state = running_statistics.init_state(specs.Array((5,), jnp.float32)) + + x = jnp.arange(10, dtype=jnp.float32).reshape(2, 5) + + state = update_and_validate(state, x) + normalized = running_statistics.normalize(x, state) + + mean = jnp.mean(normalized, axis=0) + std = jnp.std(normalized, axis=0) + self.assert_allclose(mean, jnp.zeros_like(mean)) + self.assert_allclose(std, jnp.ones_like(std)) + + def test_clip(self): + state = running_statistics.init_state(specs.Array((), jnp.float32)) + + x = jnp.arange(5, dtype=jnp.float32) + + state = update_and_validate(state, x) + normalized = running_statistics.normalize(x, state, max_abs_value=1.0) + + mean = jnp.mean(normalized) + std = 
jnp.std(normalized) + self.assert_allclose(mean, jnp.zeros_like(mean)) + self.assert_allclose(std, jnp.ones_like(std) * math.sqrt(0.6)) + + def test_nested_normalize(self): + state = running_statistics.init_state( + {"a": specs.Array((5,), jnp.float32), "b": specs.Array((2,), jnp.float32)} + ) + + x1 = { + "a": jnp.arange(20, dtype=jnp.float32).reshape(2, 2, 5), + "b": jnp.arange(8, dtype=jnp.float32).reshape(2, 2, 2), + } + x2 = { + "a": jnp.arange(20, dtype=jnp.float32).reshape(2, 2, 5) + 20, + "b": jnp.arange(8, dtype=jnp.float32).reshape(2, 2, 2) + 8, + } + x3 = { + "a": jnp.arange(40, dtype=jnp.float32).reshape(4, 2, 5), + "b": jnp.arange(16, dtype=jnp.float32).reshape(4, 2, 2), + } + + state = update_and_validate(state, x1) + state = update_and_validate(state, x2) + state = update_and_validate(state, x3) + normalized = running_statistics.normalize(x3, state) + + mean = tree.map_structure(lambda x: jnp.mean(x, axis=(0, 1)), normalized) + std = tree.map_structure(lambda x: jnp.std(x, axis=(0, 1)), normalized) + tree.map_structure(lambda x: self.assert_allclose(x, jnp.zeros_like(x)), mean) + tree.map_structure(lambda x: self.assert_allclose(x, jnp.ones_like(x)), std) + + def test_validation(self): + state = running_statistics.init_state(specs.Array((1, 2, 3), jnp.float32)) + + x = jnp.arange(12, dtype=jnp.float32).reshape(2, 2, 3) + with self.assertRaises(AssertionError): + update_and_validate(state, x) + + x = jnp.arange(3, dtype=jnp.float32).reshape(1, 1, 3) + with self.assertRaises(AssertionError): + update_and_validate(state, x) + + def test_int_not_normalized(self): + state = running_statistics.init_state(specs.Array((), jnp.int32)) + + x = jnp.arange(5, dtype=jnp.int32) + + state = update_and_validate(state, x) + normalized = running_statistics.normalize(x, state) + + np.testing.assert_array_equal(normalized, x) + + def test_pmap_update_nested(self): + local_device_count = jax.local_device_count() + state = running_statistics.init_state( + {"a": specs.Array((5,), jnp.float32), "b": specs.Array((2,), jnp.float32)} + ) + + x = { + "a": (jnp.arange(15 * local_device_count, dtype=jnp.float32)).reshape( + local_device_count, 3, 5 + ), + "b": (jnp.arange(6 * local_device_count, dtype=jnp.float32)).reshape( + local_device_count, 3, 2 + ), + } + + devices = jax.local_devices() + state = jax.device_put_replicated(state, devices) + pmap_axis_name = "i" + state = jax.pmap( + functools.partial(update_and_validate, pmap_axis_name=pmap_axis_name), + pmap_axis_name, + )(state, x) + state = jax.pmap( + functools.partial(update_and_validate, pmap_axis_name=pmap_axis_name), + pmap_axis_name, + )(state, x) + normalized = jax.pmap(running_statistics.normalize)(x, state) + + mean = tree.map_structure(lambda x: jnp.mean(x, axis=(0, 1)), normalized) + std = tree.map_structure(lambda x: jnp.std(x, axis=(0, 1)), normalized) + tree.map_structure(lambda x: self.assert_allclose(x, jnp.zeros_like(x)), mean) + tree.map_structure(lambda x: self.assert_allclose(x, jnp.ones_like(x)), std) + + def test_different_structure_normalize(self): + spec = TestNestedSpec( + a=specs.Array((5,), jnp.float32), b=specs.Array((2,), jnp.float32) + ) + state = running_statistics.init_state(spec) + + x = { + "a": jnp.arange(20, dtype=jnp.float32).reshape(2, 2, 5), + "b": jnp.arange(8, dtype=jnp.float32).reshape(2, 2, 2), + } + + with self.assertRaises(TypeError): + state = update_and_validate(state, x) + + def test_weights(self): + state = running_statistics.init_state(specs.Array((), jnp.float32)) + + x = jnp.arange(5, 
dtype=jnp.float32) + x_weights = jnp.ones_like(x) + y = 2 * x + 5 + y_weights = 2 * x_weights + z = jnp.concatenate([x, y]) + weights = jnp.concatenate([x_weights, y_weights]) + + state = update_and_validate(state, z, weights=weights) + + self.assertEqual(state.mean, (jnp.mean(x) + 2 * jnp.mean(y)) / 3) + big_z = jnp.concatenate([x, y, y]) + normalized = running_statistics.normalize(big_z, state) + self.assertAlmostEqual(jnp.mean(normalized), 0.0, places=6) + self.assertAlmostEqual(jnp.std(normalized), 1.0, places=6) + + def test_normalize_config(self): + x = jnp.arange(200, dtype=jnp.float32).reshape(20, 2, 5) + x_split = jnp.split(x, 5, axis=0) + + y = jnp.arange(160, dtype=jnp.float32).reshape(20, 2, 4) + y_split = jnp.split(y, 5, axis=0) + + z = {"a": x, "b": y} + + z_split = [{"a": xx, "b": yy} for xx, yy in zip(x_split, y_split)] + + update = jax.jit(running_statistics.update, static_argnames=("config",)) + + config = running_statistics.NestStatisticsConfig((("a",),)) + state = running_statistics.init_state( + {"a": specs.Array((5,), jnp.float32), "b": specs.Array((4,), jnp.float32)} + ) + # Test initialization from the first element. + state = update(state, z_split[0], config=config) + state = update(state, z_split[1], config=config) + state = update(state, z_split[2], config=config) + state = update(state, z_split[3], config=config) + state = update(state, z_split[4], config=config) + + normalize = jax.jit(running_statistics.normalize) + normalized = normalize(z, state) + + for key in normalized: + mean = jnp.mean(normalized[key], axis=(0, 1)) + std = jnp.std(normalized[key], axis=(0, 1)) + if key == "a": + self.assert_allclose( + mean, + jnp.zeros_like(mean), + err_msg=f"key:{key} mean:{mean} normalized:{normalized[key]}", + ) + self.assert_allclose( + std, + jnp.ones_like(std), + err_msg=f"key:{key} std:{std} normalized:{normalized[key]}", + ) + else: + assert key == "b" + np.testing.assert_array_equal( + normalized[key], + z[key], + err_msg=f"z:{z[key]} normalized:{normalized[key]}", + ) + + def test_clip_config(self): + x = jnp.arange(10, dtype=jnp.float32) - 5 + y = jnp.arange(8, dtype=jnp.float32) - 4 + + z = {"x": x, "y": y} + + max_abs_x = 2 + config = running_statistics.NestClippingConfig(((("x",), max_abs_x),)) + + clipped_z = running_statistics.clip(z, config) + + clipped_x = jnp.clip(a=x, a_min=-max_abs_x, a_max=max_abs_x) + np.testing.assert_array_equal(clipped_z["x"], clipped_x) + + np.testing.assert_array_equal(clipped_z["y"], z["y"]) + + def test_denormalize(self): + state = running_statistics.init_state(specs.Array((5,), jnp.float32)) + + x = jnp.arange(100, dtype=jnp.float32).reshape(10, 2, 5) + x1, x2 = jnp.split(x, 2, axis=0) + + state = update_and_validate(state, x1) + state = update_and_validate(state, x2) + normalized = running_statistics.normalize(x, state) + + mean = jnp.mean(normalized) + std = jnp.std(normalized) + self.assert_allclose(mean, jnp.zeros_like(mean)) + self.assert_allclose(std, jnp.ones_like(std)) + + denormalized = running_statistics.denormalize(normalized, state) + self.assert_allclose(denormalized, x) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/jax/savers.py b/acme/jax/savers.py index 2b5d5a013a..0f180cde51 100644 --- a/acme/jax/savers.py +++ b/acme/jax/savers.py @@ -19,78 +19,83 @@ import pickle from typing import Any -from absl import logging -from acme import core -from acme.tf import savers as tf_savers import jax import numpy as np import tree +from absl import logging + +from acme import core +from acme.tf 
import savers as tf_savers # Internal imports. CheckpointState = Any _DEFAULT_CHECKPOINT_TTL = int(datetime.timedelta(days=5).total_seconds()) -_ARRAY_NAME = 'array_nest' -_EXEMPLAR_NAME = 'nest_exemplar' +_ARRAY_NAME = "array_nest" +_EXEMPLAR_NAME = "nest_exemplar" def restore_from_path(ckpt_dir: str) -> CheckpointState: - """Restore the state stored in ckpt_dir.""" - array_path = os.path.join(ckpt_dir, _ARRAY_NAME) - exemplar_path = os.path.join(ckpt_dir, _EXEMPLAR_NAME) + """Restore the state stored in ckpt_dir.""" + array_path = os.path.join(ckpt_dir, _ARRAY_NAME) + exemplar_path = os.path.join(ckpt_dir, _EXEMPLAR_NAME) - with open(exemplar_path, 'rb') as f: - exemplar = pickle.load(f) + with open(exemplar_path, "rb") as f: + exemplar = pickle.load(f) - with open(array_path, 'rb') as f: - files = np.load(f, allow_pickle=True) - flat_state = [files[key] for key in files.files] - unflattened_tree = tree.unflatten_as(exemplar, flat_state) + with open(array_path, "rb") as f: + files = np.load(f, allow_pickle=True) + flat_state = [files[key] for key in files.files] + unflattened_tree = tree.unflatten_as(exemplar, flat_state) - def maybe_convert_to_python(value, numpy): - return value if numpy else value.item() + def maybe_convert_to_python(value, numpy): + return value if numpy else value.item() - return tree.map_structure(maybe_convert_to_python, unflattened_tree, exemplar) + return tree.map_structure(maybe_convert_to_python, unflattened_tree, exemplar) def save_to_path(ckpt_dir: str, state: CheckpointState): - """Save the state in ckpt_dir.""" + """Save the state in ckpt_dir.""" + + if not os.path.exists(ckpt_dir): + os.makedirs(ckpt_dir) - if not os.path.exists(ckpt_dir): - os.makedirs(ckpt_dir) + is_numpy = lambda x: isinstance(x, (np.ndarray, jax.Array)) + flat_state = tree.flatten(state) + nest_exemplar = tree.map_structure(is_numpy, state) - is_numpy = lambda x: isinstance(x, (np.ndarray, jax.Array)) - flat_state = tree.flatten(state) - nest_exemplar = tree.map_structure(is_numpy, state) + array_path = os.path.join(ckpt_dir, _ARRAY_NAME) + logging.info("Saving flattened array nest to %s", array_path) - array_path = os.path.join(ckpt_dir, _ARRAY_NAME) - logging.info('Saving flattened array nest to %s', array_path) - def _disabled_seek(*_): - raise AttributeError('seek() is disabled on this object.') - with open(array_path, 'wb') as f: - setattr(f, 'seek', _disabled_seek) - np.savez(f, *flat_state) + def _disabled_seek(*_): + raise AttributeError("seek() is disabled on this object.") - exemplar_path = os.path.join(ckpt_dir, _EXEMPLAR_NAME) - logging.info('Saving nest exemplar to %s', exemplar_path) - with open(exemplar_path, 'wb') as f: - pickle.dump(nest_exemplar, f) + with open(array_path, "wb") as f: + setattr(f, "seek", _disabled_seek) + np.savez(f, *flat_state) + + exemplar_path = os.path.join(ckpt_dir, _EXEMPLAR_NAME) + logging.info("Saving nest exemplar to %s", exemplar_path) + with open(exemplar_path, "wb") as f: + pickle.dump(nest_exemplar, f) # Use TF checkpointer. 
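
# Illustrative sketch: a minimal round trip through the save_to_path /
# restore_from_path helpers above. The temporary directory and the nested
# state used here are invented for the example; any nest of arrays and
# Python scalars is handled the same way.
import tempfile

import jax.numpy as jnp
import numpy as np

from acme.jax import savers

example_state = {
    "params": jnp.ones((2, 3), dtype=jnp.float32),  # Arrays go into the .npz file.
    "steps": 100,  # Non-array leaves are converted back to Python scalars on restore.
}

ckpt_dir = tempfile.mkdtemp()
savers.save_to_path(ckpt_dir, example_state)
restored = savers.restore_from_path(ckpt_dir)

np.testing.assert_array_equal(restored["params"], np.asarray(example_state["params"]))
assert restored["steps"] == 100
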
class Checkpointer(tf_savers.Checkpointer): - - def __init__( - self, - object_to_save: core.Saveable, - directory: str = '~/acme', - subdirectory: str = 'default', - **tf_checkpointer_kwargs): - super().__init__(dict(saveable=object_to_save), - directory=directory, - subdirectory=subdirectory, - **tf_checkpointer_kwargs) + def __init__( + self, + object_to_save: core.Saveable, + directory: str = "~/acme", + subdirectory: str = "default", + **tf_checkpointer_kwargs + ): + super().__init__( + dict(saveable=object_to_save), + directory=directory, + subdirectory=subdirectory, + **tf_checkpointer_kwargs + ) CheckpointingRunner = tf_savers.CheckpointingRunner diff --git a/acme/jax/savers_test.py b/acme/jax/savers_test.py index d67403dd6b..99d9905bf3 100644 --- a/acme/jax/savers_test.py +++ b/acme/jax/savers_test.py @@ -16,74 +16,72 @@ from unittest import mock -from acme import core -from acme.jax import savers -from acme.testing import test_utils -from acme.utils import paths import jax.numpy as jnp import numpy as np import tree - from absl.testing import absltest +from acme import core +from acme.jax import savers +from acme.testing import test_utils +from acme.utils import paths -class DummySaveable(core.Saveable): - def __init__(self, state): - self.state = state +class DummySaveable(core.Saveable): + def __init__(self, state): + self.state = state - def save(self): - return self.state + def save(self): + return self.state - def restore(self, state): - self.state = state + def restore(self, state): + self.state = state def nest_assert_equal(a, b): - tree.map_structure(np.testing.assert_array_equal, a, b) + tree.map_structure(np.testing.assert_array_equal, a, b) class SaverTest(test_utils.TestCase): - - def setUp(self): - super().setUp() - self._test_state = { - 'foo': jnp.ones(shape=(8, 4), dtype=jnp.float32), - 'bar': [jnp.zeros(shape=(3, 2), dtype=jnp.int32)], - 'baz': 3, - } - - def test_save_restore(self): - """Checks that we can save and restore state.""" - directory = self.get_tempdir() - savers.save_to_path(directory, self._test_state) - result = savers.restore_from_path(directory) - nest_assert_equal(result, self._test_state) - - def test_checkpointer(self): - """Checks that the Checkpointer class saves and restores as expected.""" - - with mock.patch.object(paths, 'get_unique_id') as mock_unique_id: - mock_unique_id.return_value = ('test',) - - # Given a path and some stateful object... - directory = self.get_tempdir() - x = DummySaveable(self._test_state) - - # If we checkpoint it... - checkpointer = savers.Checkpointer(x, directory, time_delta_minutes=0) - checkpointer.save() - - # The checkpointer should restore the object's state. - x.state = None - checkpointer.restore() - nest_assert_equal(x.state, self._test_state) - - # Checkpointers should also attempt a restore at construction time. 
- x.state = None - savers.Checkpointer(x, directory, time_delta_minutes=0) - nest_assert_equal(x.state, self._test_state) - - -if __name__ == '__main__': - absltest.main() + def setUp(self): + super().setUp() + self._test_state = { + "foo": jnp.ones(shape=(8, 4), dtype=jnp.float32), + "bar": [jnp.zeros(shape=(3, 2), dtype=jnp.int32)], + "baz": 3, + } + + def test_save_restore(self): + """Checks that we can save and restore state.""" + directory = self.get_tempdir() + savers.save_to_path(directory, self._test_state) + result = savers.restore_from_path(directory) + nest_assert_equal(result, self._test_state) + + def test_checkpointer(self): + """Checks that the Checkpointer class saves and restores as expected.""" + + with mock.patch.object(paths, "get_unique_id") as mock_unique_id: + mock_unique_id.return_value = ("test",) + + # Given a path and some stateful object... + directory = self.get_tempdir() + x = DummySaveable(self._test_state) + + # If we checkpoint it... + checkpointer = savers.Checkpointer(x, directory, time_delta_minutes=0) + checkpointer.save() + + # The checkpointer should restore the object's state. + x.state = None + checkpointer.restore() + nest_assert_equal(x.state, self._test_state) + + # Checkpointers should also attempt a restore at construction time. + x.state = None + savers.Checkpointer(x, directory, time_delta_minutes=0) + nest_assert_equal(x.state, self._test_state) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/jax/snapshotter.py b/acme/jax/snapshotter.py index 81ca4784aa..7644aa0087 100644 --- a/acme/jax/snapshotter.py +++ b/acme/jax/snapshotter.py @@ -18,99 +18,99 @@ import time from typing import Callable, Dict, List, Optional, Sequence, Tuple +import tensorflow as tf from absl import logging +from jax.experimental import jax2tf + from acme import core from acme.jax import types -from acme.utils import signals -from acme.utils import paths -from jax.experimental import jax2tf -import tensorflow as tf +from acme.utils import paths, signals # Internal imports. class JAXSnapshotter(core.Worker): - """Periodically fetches new version of params and stores tf.saved_models.""" - - # NOTE: External contributor please refrain from modifying the high level of - # the API defined here. - - def __init__(self, - variable_source: core.VariableSource, - models: Dict[str, Callable[[core.VariableSource], - types.ModelToSnapshot]], - path: str, - subdirectory: Optional[str] = None, - max_to_keep: Optional[int] = None, - add_uid: bool = False): - self._variable_source = variable_source - self._models = models - if subdirectory is not None: - self._path = paths.process_path(path, subdirectory, add_uid=add_uid) - else: - self._path = paths.process_path(path, add_uid=add_uid) - self._max_to_keep = max_to_keep - self._snapshot_paths: Optional[List[str]] = None - - # Handle preemption signal. Note that this must happen in the main thread. - def _signal_handler(self): - logging.info('Caught SIGTERM: forcing models save.') - self._save() - - def _save(self): - if not self._snapshot_paths: - # Lazy discovery of already existing snapshots. 
- self._snapshot_paths = os.listdir(self._path) - self._snapshot_paths.sort(reverse=True) - - snapshot_location = os.path.join(self._path, time.strftime('%Y%m%d-%H%M%S')) - if self._snapshot_paths and self._snapshot_paths[0] == snapshot_location: - logging.info('Snapshot for the current time already exists.') - return - - # To make sure models are captured as close as possible from the same time - # we gather all the `ModelToSnapshot` in a 1st loop. We then convert/saved - # them in another loop as this operation can be slow. - models_and_paths = self._get_models_and_paths(path=snapshot_location) - self._snapshot_paths.insert(0, snapshot_location) - - for model, saving_path in models_and_paths: - self._snapshot_model(model=model, saving_path=saving_path) - - # Delete any excess snapshots. - while self._max_to_keep and len(self._snapshot_paths) > self._max_to_keep: - paths.rmdir(os.path.join(self._path, self._snapshot_paths.pop())) - - def _get_models_and_paths( - self, path: str) -> Sequence[Tuple[types.ModelToSnapshot, str]]: - """Gets the models to save asssociated with their saving path.""" - models_and_paths = [] - for name, model_fn in self._models.items(): - model = model_fn(self._variable_source) - model_path = os.path.join(path, name) - models_and_paths.append((model, model_path)) - return models_and_paths - - def _snapshot_model(self, model: types.ModelToSnapshot, - saving_path: str) -> None: - module = model_to_tf_module(model) - tf.saved_model.save(module, saving_path) - - def run(self): - """Runs the saver.""" - with signals.runtime_terminator(self._signal_handler): - while True: + """Periodically fetches new version of params and stores tf.saved_models.""" + + # NOTE: External contributor please refrain from modifying the high level of + # the API defined here. + + def __init__( + self, + variable_source: core.VariableSource, + models: Dict[str, Callable[[core.VariableSource], types.ModelToSnapshot]], + path: str, + subdirectory: Optional[str] = None, + max_to_keep: Optional[int] = None, + add_uid: bool = False, + ): + self._variable_source = variable_source + self._models = models + if subdirectory is not None: + self._path = paths.process_path(path, subdirectory, add_uid=add_uid) + else: + self._path = paths.process_path(path, add_uid=add_uid) + self._max_to_keep = max_to_keep + self._snapshot_paths: Optional[List[str]] = None + + # Handle preemption signal. Note that this must happen in the main thread. + def _signal_handler(self): + logging.info("Caught SIGTERM: forcing models save.") self._save() - time.sleep(5 * 60) - -def model_to_tf_module(model: types.ModelToSnapshot) -> tf.Module: + def _save(self): + if not self._snapshot_paths: + # Lazy discovery of already existing snapshots. + self._snapshot_paths = os.listdir(self._path) + self._snapshot_paths.sort(reverse=True) + + snapshot_location = os.path.join(self._path, time.strftime("%Y%m%d-%H%M%S")) + if self._snapshot_paths and self._snapshot_paths[0] == snapshot_location: + logging.info("Snapshot for the current time already exists.") + return + + # To make sure models are captured as close as possible from the same time + # we gather all the `ModelToSnapshot` in a 1st loop. We then convert/saved + # them in another loop as this operation can be slow. + models_and_paths = self._get_models_and_paths(path=snapshot_location) + self._snapshot_paths.insert(0, snapshot_location) + + for model, saving_path in models_and_paths: + self._snapshot_model(model=model, saving_path=saving_path) + + # Delete any excess snapshots. 
+ while self._max_to_keep and len(self._snapshot_paths) > self._max_to_keep: + paths.rmdir(os.path.join(self._path, self._snapshot_paths.pop())) + + def _get_models_and_paths( + self, path: str + ) -> Sequence[Tuple[types.ModelToSnapshot, str]]: + """Gets the models to save asssociated with their saving path.""" + models_and_paths = [] + for name, model_fn in self._models.items(): + model = model_fn(self._variable_source) + model_path = os.path.join(path, name) + models_and_paths.append((model, model_path)) + return models_and_paths + + def _snapshot_model(self, model: types.ModelToSnapshot, saving_path: str) -> None: + module = model_to_tf_module(model) + tf.saved_model.save(module, saving_path) + + def run(self): + """Runs the saver.""" + with signals.runtime_terminator(self._signal_handler): + while True: + self._save() + time.sleep(5 * 60) - def jax_fn_to_save(**kwargs): - return model.model(model.params, **kwargs) - module = tf.Module() - module.f = tf.function(jax2tf.convert(jax_fn_to_save), autograph=False) - # Traces input to ensure the model has the correct shapes. - module.f(**model.dummy_kwargs) - return module +def model_to_tf_module(model: types.ModelToSnapshot) -> tf.Module: + def jax_fn_to_save(**kwargs): + return model.model(model.params, **kwargs) + + module = tf.Module() + module.f = tf.function(jax2tf.convert(jax_fn_to_save), autograph=False) + # Traces input to ensure the model has the correct shapes. + module.f(**model.dummy_kwargs) + return module diff --git a/acme/jax/snapshotter_test.py b/acme/jax/snapshotter_test.py index 40a1a7509f..8db1ddd68e 100644 --- a/acme/jax/snapshotter_test.py +++ b/acme/jax/snapshotter_test.py @@ -18,122 +18,121 @@ import time from typing import Any, Sequence -from acme import core -from acme.jax import snapshotter -from acme.jax import types -from acme.testing import test_utils import jax.numpy as jnp - from absl.testing import absltest +from acme import core +from acme.jax import snapshotter, types +from acme.testing import test_utils + def _model0(params, x1, x2): - return params['w0'] * jnp.sin(x1) + params['w1'] * jnp.cos(x2) + return params["w0"] * jnp.sin(x1) + params["w1"] * jnp.cos(x2) def _model1(params, x): - return params['p0'] * jnp.log(x) + return params["p0"] * jnp.log(x) class _DummyVariableSource(core.VariableSource): - - def __init__(self): - self._params_model0 = { - 'w0': jnp.ones([2, 3], dtype=jnp.float32), - 'w1': 2 * jnp.ones([2, 3], dtype=jnp.float32), - } - - self._params_model1 = { - 'p0': jnp.ones([3, 1], dtype=jnp.float32), - } - - def get_variables(self, names: Sequence[str]) -> Sequence[Any]: # pytype: disable=signature-mismatch # overriding-return-type-checks - variables = [] - for n in names: - if n == 'params_model0': - variables.append(self._params_model0) - elif n == 'params_model1': - variables.append(self._params_model1) - else: - raise ValueError('Unknow variable name: {n}') - return variables + def __init__(self): + self._params_model0 = { + "w0": jnp.ones([2, 3], dtype=jnp.float32), + "w1": 2 * jnp.ones([2, 3], dtype=jnp.float32), + } + + self._params_model1 = { + "p0": jnp.ones([3, 1], dtype=jnp.float32), + } + + def get_variables( + self, names: Sequence[str] + ) -> Sequence[ + Any + ]: # pytype: disable=signature-mismatch # overriding-return-type-checks + variables = [] + for n in names: + if n == "params_model0": + variables.append(self._params_model0) + elif n == "params_model1": + variables.append(self._params_model1) + else: + raise ValueError("Unknow variable name: {n}") + return variables 
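
# Illustrative sketch of the export path implemented by model_to_tf_module in
# snapshotter.py above: jax2tf.convert wraps a JAX function closed over fixed
# params, the result is attached to a tf.Module, traced once with dummy inputs,
# and written with tf.saved_model.save. The function, params and export
# directory below are invented for the example.
import tempfile

import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf


def linear(params, x):
    return x @ params["w"]


params = {"w": jnp.ones((3, 2), dtype=jnp.float32)}

module = tf.Module()
module.f = tf.function(jax2tf.convert(lambda x: linear(params, x)), autograph=False)
module.f(jnp.ones((1, 3), dtype=jnp.float32))  # Trace once to fix input shapes.
tf.saved_model.save(module, tempfile.mkdtemp())
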
def _get_model0(variable_source: core.VariableSource) -> types.ModelToSnapshot: - return types.ModelToSnapshot( - model=_model0, - params=variable_source.get_variables(['params_model0'])[0], - dummy_kwargs={ - 'x1': jnp.ones([2, 3], dtype=jnp.float32), - 'x2': jnp.ones([2, 3], dtype=jnp.float32), - }, - ) + return types.ModelToSnapshot( + model=_model0, + params=variable_source.get_variables(["params_model0"])[0], + dummy_kwargs={ + "x1": jnp.ones([2, 3], dtype=jnp.float32), + "x2": jnp.ones([2, 3], dtype=jnp.float32), + }, + ) def _get_model1(variable_source: core.VariableSource) -> types.ModelToSnapshot: - return types.ModelToSnapshot( - model=_model1, - params=variable_source.get_variables(['params_model1'])[0], - dummy_kwargs={ - 'x': jnp.ones([3, 1], dtype=jnp.float32), - }, - ) + return types.ModelToSnapshot( + model=_model1, + params=variable_source.get_variables(["params_model1"])[0], + dummy_kwargs={"x": jnp.ones([3, 1], dtype=jnp.float32),}, + ) class SnapshotterTest(test_utils.TestCase): - - def setUp(self): - super().setUp() - self._test_models = {'model0': _get_model0, 'model1': _get_model1} - - def _check_snapshot(self, directory: str, name: str): - self.assertTrue(os.path.exists(os.path.join(directory, name, 'model0'))) - self.assertTrue(os.path.exists(os.path.join(directory, name, 'model1'))) - - def test_snapshotter(self): - """Checks that the Snapshotter class saves as expected.""" - directory = self.get_tempdir() - - models_snapshotter = snapshotter.JAXSnapshotter( - variable_source=_DummyVariableSource(), - models=self._test_models, - path=directory, - max_to_keep=2, - add_uid=False, - ) - models_snapshotter._save() - - # The snapshots are written in a folder of the form: - # PATH/{time.strftime}/MODEL_NAME - first_snapshots = os.listdir(directory) - self.assertEqual(len(first_snapshots), 1) - self._check_snapshot(directory, first_snapshots[0]) - # Make sure that the second snapshot is constructed. - time.sleep(1.1) - models_snapshotter._save() - snapshots = os.listdir(directory) - self.assertEqual(len(snapshots), 2) - self._check_snapshot(directory, snapshots[0]) - self._check_snapshot(directory, snapshots[1]) - - # Make sure that new snapshotter deletes the oldest snapshot upon _save(). 
- time.sleep(1.1) - models_snapshotter2 = snapshotter.JAXSnapshotter( - variable_source=_DummyVariableSource(), - models=self._test_models, - path=directory, - max_to_keep=2, - add_uid=False, - ) - self.assertEqual(snapshots, os.listdir(directory)) - time.sleep(1.1) - models_snapshotter2._save() - snapshots = os.listdir(directory) - self.assertNotIn(first_snapshots[0], snapshots) - self.assertEqual(len(snapshots), 2) - self._check_snapshot(directory, snapshots[0]) - self._check_snapshot(directory, snapshots[1]) - - -if __name__ == '__main__': - absltest.main() + def setUp(self): + super().setUp() + self._test_models = {"model0": _get_model0, "model1": _get_model1} + + def _check_snapshot(self, directory: str, name: str): + self.assertTrue(os.path.exists(os.path.join(directory, name, "model0"))) + self.assertTrue(os.path.exists(os.path.join(directory, name, "model1"))) + + def test_snapshotter(self): + """Checks that the Snapshotter class saves as expected.""" + directory = self.get_tempdir() + + models_snapshotter = snapshotter.JAXSnapshotter( + variable_source=_DummyVariableSource(), + models=self._test_models, + path=directory, + max_to_keep=2, + add_uid=False, + ) + models_snapshotter._save() + + # The snapshots are written in a folder of the form: + # PATH/{time.strftime}/MODEL_NAME + first_snapshots = os.listdir(directory) + self.assertEqual(len(first_snapshots), 1) + self._check_snapshot(directory, first_snapshots[0]) + # Make sure that the second snapshot is constructed. + time.sleep(1.1) + models_snapshotter._save() + snapshots = os.listdir(directory) + self.assertEqual(len(snapshots), 2) + self._check_snapshot(directory, snapshots[0]) + self._check_snapshot(directory, snapshots[1]) + + # Make sure that new snapshotter deletes the oldest snapshot upon _save(). + time.sleep(1.1) + models_snapshotter2 = snapshotter.JAXSnapshotter( + variable_source=_DummyVariableSource(), + models=self._test_models, + path=directory, + max_to_keep=2, + add_uid=False, + ) + self.assertEqual(snapshots, os.listdir(directory)) + time.sleep(1.1) + models_snapshotter2._save() + snapshots = os.listdir(directory) + self.assertNotIn(first_snapshots[0], snapshots) + self.assertEqual(len(snapshots), 2) + self._check_snapshot(directory, snapshots[0]) + self._check_snapshot(directory, snapshots[1]) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/jax/types.py b/acme/jax/types.py index d1c4f69d10..432de6fbe5 100644 --- a/acme/jax/types.py +++ b/acme/jax/types.py @@ -17,20 +17,21 @@ import dataclasses from typing import Any, Callable, Dict, Generic, Mapping, TypeVar -from acme import types import chex import dm_env import jax import jax.numpy as jnp +from acme import types + PRNGKey = jax.random.KeyArray -Networks = TypeVar('Networks') +Networks = TypeVar("Networks") """Container for all agent network components.""" -Policy = TypeVar('Policy') +Policy = TypeVar("Policy") """Function or container for agent policy functions.""" -Sample = TypeVar('Sample') +Sample = TypeVar("Sample") """Sample from the demonstrations or replay buffer.""" -TrainingState = TypeVar('TrainingState') +TrainingState = TypeVar("TrainingState") TrainingMetrics = Mapping[str, jnp.ndarray] """Metrics returned by the training step. 
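
# Illustrative sketch of how the generic TrainingState / TrainingMetrics types
# above are typically instantiated: a concrete learner state stands in for
# TrainingState, and a training step returns it bundled with a metrics mapping
# in a TrainingStepOutput. LearnerState, sgd_step and the quadratic loss are
# invented for the example.
from typing import NamedTuple

import jax
import jax.numpy as jnp

from acme.jax import types as jax_types


class LearnerState(NamedTuple):
    params: jnp.ndarray
    step: jnp.ndarray


def sgd_step(
    state: LearnerState, batch: jnp.ndarray
) -> jax_types.TrainingStepOutput[LearnerState]:
    # One plain gradient step on a toy quadratic loss.
    loss, grad = jax.value_and_grad(lambda p: jnp.mean((batch - p) ** 2))(state.params)
    new_state = LearnerState(params=state.params - 0.1 * grad, step=state.step + 1)
    metrics: jax_types.TrainingMetrics = {"loss": loss}
    return jax_types.TrainingStepOutput(state=new_state, metrics=metrics)


state = LearnerState(params=jnp.zeros(3), step=jnp.zeros((), dtype=jnp.int32))
out = sgd_step(state, batch=jnp.ones((8, 3)))  # out.state, out.metrics["loss"]
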
@@ -49,8 +50,8 @@ @chex.dataclass(frozen=True, mappable_dataclass=False) class TrainingStepOutput(Generic[TrainingState]): - state: TrainingState - metrics: TrainingMetrics + state: TrainingState + metrics: TrainingMetrics Seed = int @@ -59,13 +60,14 @@ class TrainingStepOutput(Generic[TrainingState]): @dataclasses.dataclass class ModelToSnapshot: - """Stores all necessary info to be able to save a model. + """Stores all necessary info to be able to save a model. Attributes: model: a jax function to be saved. params: fixed params to be passed to the function. dummy_kwargs: arguments to be passed to the function. """ - model: Any # Callable[params, **dummy_kwargs] - params: Any - dummy_kwargs: Dict[str, Any] + + model: Any # Callable[params, **dummy_kwargs] + params: Any + dummy_kwargs: Dict[str, Any] diff --git a/acme/jax/utils.py b/acme/jax/utils.py index 60ed8aaf39..e9f4c11692 100644 --- a/acme/jax/utils.py +++ b/acme/jax/utils.py @@ -18,34 +18,42 @@ import itertools import queue import threading -from typing import Callable, Iterable, Iterator, NamedTuple, Optional, Sequence, Tuple, TypeVar +from typing import ( + Callable, + Iterable, + Iterator, + NamedTuple, + Optional, + Sequence, + Tuple, + TypeVar, +) -from absl import logging -from acme import core -from acme import types -from acme.jax import types as jax_types import haiku as hk import jax import jax.numpy as jnp import numpy as np import reverb import tree +from absl import logging +from acme import core, types +from acme.jax import types as jax_types -F = TypeVar('F', bound=Callable) -N = TypeVar('N', bound=types.NestedArray) -T = TypeVar('T') +F = TypeVar("F", bound=Callable) +N = TypeVar("N", bound=types.NestedArray) +T = TypeVar("T") NUM_PREFETCH_THREADS = 1 def add_batch_dim(values: types.Nest) -> types.NestedArray: - return jax.tree_map(lambda x: jnp.expand_dims(x, axis=0), values) + return jax.tree_map(lambda x: jnp.expand_dims(x, axis=0), values) def _flatten(x: jnp.ndarray, num_batch_dims: int) -> jnp.ndarray: - """Flattens the input, preserving the first ``num_batch_dims`` dimensions. + """Flattens the input, preserving the first ``num_batch_dims`` dimensions. If the input has fewer than ``num_batch_dims`` dimensions, it is returned unchanged. @@ -59,72 +67,70 @@ def _flatten(x: jnp.ndarray, num_batch_dims: int) -> jnp.ndarray: Returns: flattened input. """ - # TODO(b/173492429): consider throwing an error instead. - if x.ndim < num_batch_dims: - return x - return jnp.reshape(x, list(x.shape[:num_batch_dims]) + [-1]) + # TODO(b/173492429): consider throwing an error instead. 
+ if x.ndim < num_batch_dims: + return x + return jnp.reshape(x, list(x.shape[:num_batch_dims]) + [-1]) -def batch_concat( - values: types.NestedArray, - num_batch_dims: int = 1, -) -> jnp.ndarray: - """Flatten and concatenate nested array structure, keeping batch dims.""" - flatten_fn = lambda x: _flatten(x, num_batch_dims) - flat_leaves = tree.map_structure(flatten_fn, values) - return jnp.concatenate(tree.flatten(flat_leaves), axis=-1) +def batch_concat(values: types.NestedArray, num_batch_dims: int = 1,) -> jnp.ndarray: + """Flatten and concatenate nested array structure, keeping batch dims.""" + flatten_fn = lambda x: _flatten(x, num_batch_dims) + flat_leaves = tree.map_structure(flatten_fn, values) + return jnp.concatenate(tree.flatten(flat_leaves), axis=-1) def zeros_like(nest: types.Nest, dtype=None) -> types.NestedArray: - return jax.tree_map(lambda x: jnp.zeros(x.shape, dtype or x.dtype), nest) + return jax.tree_map(lambda x: jnp.zeros(x.shape, dtype or x.dtype), nest) def ones_like(nest: types.Nest, dtype=None) -> types.NestedArray: - return jax.tree_map(lambda x: jnp.ones(x.shape, dtype or x.dtype), nest) + return jax.tree_map(lambda x: jnp.ones(x.shape, dtype or x.dtype), nest) def squeeze_batch_dim(nest: types.Nest) -> types.NestedArray: - return jax.tree_map(lambda x: jnp.squeeze(x, axis=0), nest) + return jax.tree_map(lambda x: jnp.squeeze(x, axis=0), nest) def to_numpy_squeeze(values: types.Nest) -> types.NestedArray: - """Converts to numpy and squeezes out dummy batch dimension.""" - return jax.tree_map(lambda x: np.asarray(x).squeeze(axis=0), values) + """Converts to numpy and squeezes out dummy batch dimension.""" + return jax.tree_map(lambda x: np.asarray(x).squeeze(axis=0), values) def to_numpy(values: types.Nest) -> types.NestedArray: - return jax.tree_map(np.asarray, values) + return jax.tree_map(np.asarray, values) def fetch_devicearray(values: types.Nest) -> types.Nest: - """Fetches and converts any DeviceArrays to np.ndarrays.""" - return tree.map_structure(_fetch_devicearray, values) + """Fetches and converts any DeviceArrays to np.ndarrays.""" + return tree.map_structure(_fetch_devicearray, values) def _fetch_devicearray(x): - if isinstance(x, jax.Array): - return np.asarray(x) - return x + if isinstance(x, jax.Array): + return np.asarray(x) + return x def batch_to_sequence(values: types.Nest) -> types.NestedArray: - return jax.tree_map( - lambda x: jnp.transpose(x, axes=(1, 0, *range(2, len(x.shape)))), values) + return jax.tree_map( + lambda x: jnp.transpose(x, axes=(1, 0, *range(2, len(x.shape)))), values + ) def tile_array(array: jnp.ndarray, multiple: int) -> jnp.ndarray: - """Tiles `multiple` copies of `array` along a new leading axis.""" - return jnp.stack([array] * multiple) + """Tiles `multiple` copies of `array` along a new leading axis.""" + return jnp.stack([array] * multiple) def tile_nested(inputs: types.Nest, multiple: int) -> types.Nest: - """Tiles tensors in a nested structure along a new leading axis.""" - tile = functools.partial(tile_array, multiple=multiple) - return jax.tree_map(tile, inputs) + """Tiles tensors in a nested structure along a new leading axis.""" + tile = functools.partial(tile_array, multiple=multiple) + return jax.tree_map(tile, inputs) def maybe_recover_lstm_type(state: types.NestedArray) -> types.NestedArray: - """Recovers the type hk.LSTMState if LSTMState is in the type name. + """Recovers the type hk.LSTMState if LSTMState is in the type name. 
When the recurrent state of recurrent neural networks (RNN) is deserialized, for example when it is sampled from replay, it is sometimes repacked in a type @@ -140,7 +146,7 @@ def maybe_recover_lstm_type(state: types.NestedArray) -> types.NestedArray: Either the state unchanged if it is anything but an LSTMState, otherwise returns the state arrays properly contained in an hk.LSTMState. """ - return hk.LSTMState(*state) if type(state).__name__ == 'LSTMState' else state + return hk.LSTMState(*state) if type(state).__name__ == "LSTMState" else state def prefetch( @@ -149,21 +155,21 @@ def prefetch( device: Optional[jax.Device] = None, num_threads: int = NUM_PREFETCH_THREADS, ) -> core.PrefetchingIterator[T]: - """Returns prefetching iterator with additional 'ready' method.""" + """Returns prefetching iterator with additional 'ready' method.""" - return PrefetchIterator(iterable, buffer_size, device, num_threads) + return PrefetchIterator(iterable, buffer_size, device, num_threads) class PrefetchingSplit(NamedTuple): - host: types.NestedArray - device: types.NestedArray + host: types.NestedArray + device: types.NestedArray _SplitFunction = Callable[[types.NestedArray], PrefetchingSplit] def keep_key_on_host(sample: reverb.ReplaySample) -> PrefetchingSplit: - """Returns PrefetchingSplit which keeps uint64 reverb key on the host. + """Returns PrefetchingSplit which keeps uint64 reverb key on the host. We want to avoid truncation of the uint64 reverb key by JAX. @@ -173,7 +179,7 @@ def keep_key_on_host(sample: reverb.ReplaySample) -> PrefetchingSplit: Returns: PrefetchingSplit with device having the reverb sample, and key on host. """ - return PrefetchingSplit(host=sample.info.key, device=sample) + return PrefetchingSplit(host=sample.info.key, device=sample) def device_put( @@ -181,13 +187,11 @@ def device_put( device: jax.Device, split_fn: Optional[_SplitFunction] = None, ): - """Returns iterator that samples an item and places it on the device.""" + """Returns iterator that samples an item and places it on the device.""" - return PutToDevicesIterable( - iterable=iterable, - pmapped_user=False, - devices=[device], - split_fn=split_fn) + return PutToDevicesIterable( + iterable=iterable, pmapped_user=False, devices=[device], split_fn=split_fn + ) def multi_device_put( @@ -195,14 +199,15 @@ def multi_device_put( devices: Sequence[jax.Device], split_fn: Optional[_SplitFunction] = None, ): - """Returns iterator that, per device, samples an item and places on device.""" + """Returns iterator that, per device, samples an item and places on device.""" - return PutToDevicesIterable( - iterable=iterable, pmapped_user=True, devices=devices, split_fn=split_fn) + return PutToDevicesIterable( + iterable=iterable, pmapped_user=True, devices=devices, split_fn=split_fn + ) class PutToDevicesIterable(Iterable[types.NestedArray]): - """Per device, samples an item from iterator and places on device. + """Per device, samples an item from iterator and places on device. if pmapped_user: Items from the resulting generator are intended to be used in a pmapped @@ -232,14 +237,14 @@ class PutToDevicesIterable(Iterable[types.NestedArray]): the producer, but after it finishes executing. """ - def __init__( - self, - iterable: Iterable[types.NestedArray], - pmapped_user: bool, - devices: Sequence[jax.Device], - split_fn: Optional[_SplitFunction] = None, - ): - """Constructs PutToDevicesIterable. 
+ def __init__( + self, + iterable: Iterable[types.NestedArray], + pmapped_user: bool, + devices: Sequence[jax.Device], + split_fn: Optional[_SplitFunction] = None, + ): + """Constructs PutToDevicesIterable. Args: iterable: A python iterable. This is used to build the python prefetcher. @@ -256,62 +261,67 @@ def __init__( ValueError: If devices list is empty, or if pmapped_use=False and more than 1 device is provided. """ - self.num_devices = len(devices) - if self.num_devices == 0: - raise ValueError('At least one device must be specified.') - if (not pmapped_user) and (self.num_devices != 1): - raise ValueError('User is not implemented with pmapping but len(devices) ' - f'= {len(devices)} is not equal to 1! Devices given are:' - f'\n{devices}') - - self.iterable = iterable - self.pmapped_user = pmapped_user - self.split_fn = split_fn - self.devices = devices - self.iterator = iter(self.iterable) - - def __iter__(self) -> Iterator[types.NestedArray]: - # It is important to structure the Iterable like this, because in - # JustPrefetchIterator we must build a new iterable for each thread. - # This is crucial if working with tensorflow datasets because tf.Graph - # objects are thread local. - self.iterator = iter(self.iterable) - return self - - def __next__(self) -> types.NestedArray: - try: - if not self.pmapped_user: - item = next(self.iterator) - if self.split_fn is None: - return jax.device_put(item, self.devices[0]) - item_split = self.split_fn(item) - return PrefetchingSplit( - host=item_split.host, - device=jax.device_put(item_split.device, self.devices[0])) - - items = itertools.islice(self.iterator, self.num_devices) - items = tuple(items) - if len(items) < self.num_devices: - raise StopIteration - if self.split_fn is None: - return jax.device_put_sharded(tuple(items), self.devices) - else: - # ((host: x1, device: y1), ..., (host: xN, device: yN)). - items_split = (self.split_fn(item) for item in items) - # (host: (x1, ..., xN), device: (y1, ..., yN)). - split = tree.map_structure_up_to( - PrefetchingSplit(None, None), lambda *x: x, *items_split) - - return PrefetchingSplit( - host=np.stack(split.host), - device=jax.device_put_sharded(split.device, self.devices)) - - except StopIteration: - raise - - except Exception: # pylint: disable=broad-except - logging.exception('Error for %s', self.iterable) - raise + self.num_devices = len(devices) + if self.num_devices == 0: + raise ValueError("At least one device must be specified.") + if (not pmapped_user) and (self.num_devices != 1): + raise ValueError( + "User is not implemented with pmapping but len(devices) " + f"= {len(devices)} is not equal to 1! Devices given are:" + f"\n{devices}" + ) + + self.iterable = iterable + self.pmapped_user = pmapped_user + self.split_fn = split_fn + self.devices = devices + self.iterator = iter(self.iterable) + + def __iter__(self) -> Iterator[types.NestedArray]: + # It is important to structure the Iterable like this, because in + # JustPrefetchIterator we must build a new iterable for each thread. + # This is crucial if working with tensorflow datasets because tf.Graph + # objects are thread local. 
+ self.iterator = iter(self.iterable) + return self + + def __next__(self) -> types.NestedArray: + try: + if not self.pmapped_user: + item = next(self.iterator) + if self.split_fn is None: + return jax.device_put(item, self.devices[0]) + item_split = self.split_fn(item) + return PrefetchingSplit( + host=item_split.host, + device=jax.device_put(item_split.device, self.devices[0]), + ) + + items = itertools.islice(self.iterator, self.num_devices) + items = tuple(items) + if len(items) < self.num_devices: + raise StopIteration + if self.split_fn is None: + return jax.device_put_sharded(tuple(items), self.devices) + else: + # ((host: x1, device: y1), ..., (host: xN, device: yN)). + items_split = (self.split_fn(item) for item in items) + # (host: (x1, ..., xN), device: (y1, ..., yN)). + split = tree.map_structure_up_to( + PrefetchingSplit(None, None), lambda *x: x, *items_split + ) + + return PrefetchingSplit( + host=np.stack(split.host), + device=jax.device_put_sharded(split.device, self.devices), + ) + + except StopIteration: + raise + + except Exception: # pylint: disable=broad-except + logging.exception("Error for %s", self.iterable) + raise def sharded_prefetch( @@ -321,7 +331,7 @@ def sharded_prefetch( split_fn: Optional[_SplitFunction] = None, devices: Optional[Sequence[jax.Device]] = None, ) -> core.PrefetchingIterator: - """Performs sharded prefetching from an iterable in separate threads. + """Performs sharded prefetching from an iterable in separate threads. Elements from the resulting generator are intended to be used in a jax.pmap call. Every element is a sharded prefetched array with an additional replica @@ -349,24 +359,25 @@ def sharded_prefetch( the producer, but after it finishes executing. """ - devices = devices or jax.local_devices() + devices = devices or jax.local_devices() - iterable = PutToDevicesIterable( - iterable=iterable, pmapped_user=True, devices=devices, split_fn=split_fn) + iterable = PutToDevicesIterable( + iterable=iterable, pmapped_user=True, devices=devices, split_fn=split_fn + ) - return prefetch(iterable, buffer_size, device=None, num_threads=num_threads) + return prefetch(iterable, buffer_size, device=None, num_threads=num_threads) def replicate_in_all_devices( nest: N, devices: Optional[Sequence[jax.Device]] = None ) -> N: - """Replicate array nest in all available devices.""" - devices = devices or jax.local_devices() - return jax.device_put_sharded([nest] * len(devices), devices) + """Replicate array nest in all available devices.""" + devices = devices or jax.local_devices() + return jax.device_put_sharded([nest] * len(devices), devices) def get_from_first_device(nest: N, as_numpy: bool = True) -> N: - """Gets the first array of a nest of `jax.Array`s. + """Gets the first array of a nest of `jax.Array`s. Args: nest: A nest of `jax.Array`s. @@ -379,16 +390,14 @@ def get_from_first_device(nest: N, as_numpy: bool = True) -> N: the same device as the sharded device array). If `as_numpy=True` then the array will be copied to the host machine and converted into a `np.ndarray`. 
""" - zeroth_nest = jax.tree_map(lambda x: x[0], nest) - return jax.device_get(zeroth_nest) if as_numpy else zeroth_nest + zeroth_nest = jax.tree_map(lambda x: x[0], nest) + return jax.device_get(zeroth_nest) if as_numpy else zeroth_nest def mapreduce( - f: F, - reduce_fn: Optional[Callable[[jax.Array], jax.Array]] = None, - **vmap_kwargs, + f: F, reduce_fn: Optional[Callable[[jax.Array], jax.Array]] = None, **vmap_kwargs, ) -> F: - """A simple decorator that transforms `f` into (`reduce_fn` o vmap o f). + """A simple decorator that transforms `f` into (`reduce_fn` o vmap o f). By default, we vmap over axis 0, and the `reduce_fn` is jnp.mean over axis 0. Note that the call signature of `f` is invariant under this transformation. @@ -405,32 +414,32 @@ def mapreduce( g: A pure function over batches of examples. """ - if reduce_fn is None: - reduce_fn = lambda x: jnp.mean(x, axis=0) + if reduce_fn is None: + reduce_fn = lambda x: jnp.mean(x, axis=0) - vmapped_f = jax.vmap(f, **vmap_kwargs) + vmapped_f = jax.vmap(f, **vmap_kwargs) - def g(*args, **kwargs): - return jax.tree_map(reduce_fn, vmapped_f(*args, **kwargs)) + def g(*args, **kwargs): + return jax.tree_map(reduce_fn, vmapped_f(*args, **kwargs)) - return g + return g -_TrainingState = TypeVar('_TrainingState') -_TrainingData = TypeVar('_TrainingData') -_TrainingAux = TypeVar('_TrainingAux') +_TrainingState = TypeVar("_TrainingState") +_TrainingData = TypeVar("_TrainingData") +_TrainingAux = TypeVar("_TrainingAux") # TODO(b/192806089): migrate all callers to process_many_batches and remove this # method. def process_multiple_batches( - process_one_batch: Callable[[_TrainingState, _TrainingData], - Tuple[_TrainingState, _TrainingAux]], + process_one_batch: Callable[ + [_TrainingState, _TrainingData], Tuple[_TrainingState, _TrainingAux] + ], num_batches: int, - postprocess_aux: Optional[Callable[[_TrainingAux], _TrainingAux]] = None -) -> Callable[[_TrainingState, _TrainingData], Tuple[_TrainingState, - _TrainingAux]]: - """Makes 'process_one_batch' process multiple batches at once. + postprocess_aux: Optional[Callable[[_TrainingAux], _TrainingAux]] = None, +) -> Callable[[_TrainingState, _TrainingData], Tuple[_TrainingState, _TrainingAux]]: + """Makes 'process_one_batch' process multiple batches at once. Args: process_one_batch: a function that takes 'state' and 'data', and returns @@ -443,74 +452,80 @@ def process_multiple_batches( A function with the same interface as 'process_one_batch' which processes multiple batches at once. 
""" - assert num_batches >= 1 - if num_batches == 1: - if not postprocess_aux: - return process_one_batch - def _process_one_batch(state, data): - state, aux = process_one_batch(state, data) - return state, postprocess_aux(aux) - return _process_one_batch + assert num_batches >= 1 + if num_batches == 1: + if not postprocess_aux: + return process_one_batch + + def _process_one_batch(state, data): + state, aux = process_one_batch(state, data) + return state, postprocess_aux(aux) + + return _process_one_batch - if postprocess_aux is None: - postprocess_aux = lambda x: jax.tree_map(jnp.mean, x) + if postprocess_aux is None: + postprocess_aux = lambda x: jax.tree_map(jnp.mean, x) - def _process_multiple_batches(state, data): - data = jax.tree_map( - lambda a: jnp.reshape(a, (num_batches, -1, *a.shape[1:])), data) + def _process_multiple_batches(state, data): + data = jax.tree_map( + lambda a: jnp.reshape(a, (num_batches, -1, *a.shape[1:])), data + ) - state, aux = jax.lax.scan( - process_one_batch, state, data, length=num_batches) - return state, postprocess_aux(aux) + state, aux = jax.lax.scan(process_one_batch, state, data, length=num_batches) + return state, postprocess_aux(aux) - return _process_multiple_batches + return _process_multiple_batches def process_many_batches( - process_one_batch: Callable[[_TrainingState, _TrainingData], - jax_types.TrainingStepOutput[_TrainingState]], + process_one_batch: Callable[ + [_TrainingState, _TrainingData], jax_types.TrainingStepOutput[_TrainingState] + ], num_batches: int, - postprocess_aux: Optional[Callable[[jax_types.TrainingMetrics], - jax_types.TrainingMetrics]] = None -) -> Callable[[_TrainingState, _TrainingData], - jax_types.TrainingStepOutput[_TrainingState]]: - """The version of 'process_multiple_batches' with stronger typing.""" + postprocess_aux: Optional[ + Callable[[jax_types.TrainingMetrics], jax_types.TrainingMetrics] + ] = None, +) -> Callable[ + [_TrainingState, _TrainingData], jax_types.TrainingStepOutput[_TrainingState] +]: + """The version of 'process_multiple_batches' with stronger typing.""" - def _process_one_batch( - state: _TrainingState, - data: _TrainingData) -> Tuple[_TrainingState, jax_types.TrainingMetrics]: - result = process_one_batch(state, data) - return result.state, result.metrics + def _process_one_batch( + state: _TrainingState, data: _TrainingData + ) -> Tuple[_TrainingState, jax_types.TrainingMetrics]: + result = process_one_batch(state, data) + return result.state, result.metrics - func = process_multiple_batches(_process_one_batch, num_batches, - postprocess_aux) + func = process_multiple_batches(_process_one_batch, num_batches, postprocess_aux) - def _process_many_batches( - state: _TrainingState, - data: _TrainingData) -> jax_types.TrainingStepOutput[_TrainingState]: - state, aux = func(state, data) - return jax_types.TrainingStepOutput(state, aux) + def _process_many_batches( + state: _TrainingState, data: _TrainingData + ) -> jax_types.TrainingStepOutput[_TrainingState]: + state, aux = func(state, data) + return jax_types.TrainingStepOutput(state, aux) - return _process_many_batches + return _process_many_batches def weighted_softmax(x: jnp.ndarray, weights: jnp.ndarray, axis: int = 0): - x = x - jnp.max(x, axis=axis) - return weights * jnp.exp(x) / jnp.sum(weights * jnp.exp(x), - axis=axis, keepdims=True) + x = x - jnp.max(x, axis=axis) + return ( + weights * jnp.exp(x) / jnp.sum(weights * jnp.exp(x), axis=axis, keepdims=True) + ) def sample_uint32(random_key: jax_types.PRNGKey) -> int: - """Returns an 
integer uniformly distributed in 0..2^32-1.""" - iinfo = jnp.iinfo(jnp.int32) - # randint only accepts int32 values as min and max. - jax_random = jax.random.randint( - random_key, shape=(), minval=iinfo.min, maxval=iinfo.max, dtype=jnp.int32) - return np.uint32(jax_random).item() + """Returns an integer uniformly distributed in 0..2^32-1.""" + iinfo = jnp.iinfo(jnp.int32) + # randint only accepts int32 values as min and max. + jax_random = jax.random.randint( + random_key, shape=(), minval=iinfo.min, maxval=iinfo.max, dtype=jnp.int32 + ) + return np.uint32(jax_random).item() class PrefetchIterator(core.PrefetchingIterator): - """Performs prefetching from an iterable in separate threads. + """Performs prefetching from an iterable in separate threads. Its interface is additionally extended with `ready` method which tells whether there is any data waiting for processing and a `retrieved_elements` method @@ -526,14 +541,14 @@ class PrefetchIterator(core.PrefetchingIterator): the producer, but after it finishes executing. """ - def __init__( - self, - iterable: Iterable[types.NestedArray], - buffer_size: int = 5, - device: Optional[jax.Device] = None, - num_threads: int = NUM_PREFETCH_THREADS, - ): - """Constructs PrefetchIterator. + def __init__( + self, + iterable: Iterable[types.NestedArray], + buffer_size: int = 5, + device: Optional[jax.Device] = None, + num_threads: int = NUM_PREFETCH_THREADS, + ): + """Constructs PrefetchIterator. Args: iterable: A python iterable. This is used to build the python prefetcher. @@ -548,48 +563,48 @@ def __init__( num_threads (int): Number of threads. """ - if buffer_size < 1: - raise ValueError('the buffer_size should be >= 1') - self.buffer = queue.Queue(maxsize=buffer_size) - self.producer_error = [] - self.end = object() - self.iterable = iterable - self.device = device - self.count = 0 - - # Start producer threads. - for _ in range(num_threads): - threading.Thread(target=self.producer, daemon=True).start() - - def producer(self): - """Enqueues items from `iterable` on a given thread.""" - try: - # Build a new iterable for each thread. This is crucial if working with - # tensorflow datasets because tf.Graph objects are thread local. - for item in self.iterable: - if self.device: - jax.device_put(item, self.device) - self.buffer.put(item) - except Exception as e: # pylint: disable=broad-except - logging.exception('Error in producer thread for %s', self.iterable) - self.producer_error.append(e) - finally: - self.buffer.put(self.end) - - def __iter__(self): - return self - - def ready(self): - return not self.buffer.empty() - - def retrieved_elements(self): - return self.count - - def __next__(self): - value = self.buffer.get() - if value is self.end: - if self.producer_error: - raise self.producer_error[0] from self.producer_error[0] - raise StopIteration - self.count += 1 - return value + if buffer_size < 1: + raise ValueError("the buffer_size should be >= 1") + self.buffer = queue.Queue(maxsize=buffer_size) + self.producer_error = [] + self.end = object() + self.iterable = iterable + self.device = device + self.count = 0 + + # Start producer threads. + for _ in range(num_threads): + threading.Thread(target=self.producer, daemon=True).start() + + def producer(self): + """Enqueues items from `iterable` on a given thread.""" + try: + # Build a new iterable for each thread. This is crucial if working with + # tensorflow datasets because tf.Graph objects are thread local. 
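# --- Editor's example (sketch, not part of the patch): prefetching batches from a
# plain Python generator on a background thread with PrefetchIterator. The
# generator and shapes are made up; with a real tf.data pipeline the same pattern
# hides host-side latency from the consumer.
import numpy as np

from acme.jax import utils

def make_batches():
    while True:
        yield np.zeros((256, 84, 84, 4), dtype=np.float32)

iterator = utils.PrefetchIterator(make_batches(), buffer_size=2, num_threads=1)
batch = next(iterator)                     # blocks until the producer enqueues an item
assert iterator.retrieved_elements() == 1
# --- End of editor's example.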
+ for item in self.iterable: + if self.device: + jax.device_put(item, self.device) + self.buffer.put(item) + except Exception as e: # pylint: disable=broad-except + logging.exception("Error in producer thread for %s", self.iterable) + self.producer_error.append(e) + finally: + self.buffer.put(self.end) + + def __iter__(self): + return self + + def ready(self): + return not self.buffer.empty() + + def retrieved_elements(self): + return self.count + + def __next__(self): + value = self.buffer.get() + if value is self.end: + if self.producer_error: + raise self.producer_error[0] from self.producer_error[0] + raise StopIteration + self.count += 1 + return value diff --git a/acme/jax/utils_test.py b/acme/jax/utils_test.py index 04786d7319..f6afcf7428 100644 --- a/acme/jax/utils_test.py +++ b/acme/jax/utils_test.py @@ -14,74 +14,67 @@ """Tests for utils.""" -from acme.jax import utils import chex import jax import jax.numpy as jnp import numpy as np - from absl.testing import absltest +from acme.jax import utils + chex.set_n_cpu_devices(4) class JaxUtilsTest(absltest.TestCase): - - def test_batch_concat(self): - batch_size = 32 - inputs = [ - jnp.zeros(shape=(batch_size, 2)), - { - 'foo': jnp.zeros(shape=(batch_size, 5, 3)) - }, - [jnp.zeros(shape=(batch_size, 1))], - jnp.zeros(shape=(batch_size,)), - ] - - output_shape = utils.batch_concat(inputs).shape - expected_shape = [batch_size, 2 + 5 * 3 + 1 + 1] - self.assertSequenceEqual(output_shape, expected_shape) - - def test_mapreduce(self): - - @utils.mapreduce - def f(y, x): - return jnp.square(x + y) - - z = f(jnp.ones(shape=(32,)), jnp.ones(shape=(32,))) - z = jax.device_get(z) - self.assertEqual(z, 4) - - def test_get_from_first_device(self): - sharded = { - 'a': - jax.device_put_sharded( + def test_batch_concat(self): + batch_size = 32 + inputs = [ + jnp.zeros(shape=(batch_size, 2)), + {"foo": jnp.zeros(shape=(batch_size, 5, 3))}, + [jnp.zeros(shape=(batch_size, 1))], + jnp.zeros(shape=(batch_size,)), + ] + + output_shape = utils.batch_concat(inputs).shape + expected_shape = [batch_size, 2 + 5 * 3 + 1 + 1] + self.assertSequenceEqual(output_shape, expected_shape) + + def test_mapreduce(self): + @utils.mapreduce + def f(y, x): + return jnp.square(x + y) + + z = f(jnp.ones(shape=(32,)), jnp.ones(shape=(32,))) + z = jax.device_get(z) + self.assertEqual(z, 4) + + def test_get_from_first_device(self): + sharded = { + "a": jax.device_put_sharded( list(jnp.arange(16).reshape([jax.local_device_count(), 4])), - jax.local_devices()), - 'b': - jax.device_put_sharded( + jax.local_devices(), + ), + "b": jax.device_put_sharded( list(jnp.arange(8).reshape([jax.local_device_count(), 2])), jax.local_devices(), ), - } + } - want = { - 'a': jnp.arange(4), - 'b': jnp.arange(2), - } + want = { + "a": jnp.arange(4), + "b": jnp.arange(2), + } - # Get zeroth device content as DeviceArray. - device_arrays = utils.get_from_first_device(sharded, as_numpy=False) - jax.tree_map( - lambda x: self.assertIsInstance(x, jax.Array), - device_arrays) - jax.tree_map(np.testing.assert_array_equal, want, device_arrays) + # Get zeroth device content as DeviceArray. + device_arrays = utils.get_from_first_device(sharded, as_numpy=False) + jax.tree_map(lambda x: self.assertIsInstance(x, jax.Array), device_arrays) + jax.tree_map(np.testing.assert_array_equal, want, device_arrays) - # Get the zeroth device content as numpy arrays. 
- numpy_arrays = utils.get_from_first_device(sharded, as_numpy=True) - jax.tree_map(lambda x: self.assertIsInstance(x, np.ndarray), numpy_arrays) - jax.tree_map(np.testing.assert_array_equal, want, numpy_arrays) + # Get the zeroth device content as numpy arrays. + numpy_arrays = utils.get_from_first_device(sharded, as_numpy=True) + jax.tree_map(lambda x: self.assertIsInstance(x, np.ndarray), numpy_arrays) + jax.tree_map(np.testing.assert_array_equal, want, numpy_arrays) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/jax/variable_utils.py b/acme/jax/variable_utils.py index 5134184991..c80248192e 100644 --- a/acme/jax/variable_utils.py +++ b/acme/jax/variable_utils.py @@ -14,22 +14,23 @@ """Variable utilities for JAX.""" -from concurrent import futures import datetime import time +from concurrent import futures from typing import List, NamedTuple, Optional, Sequence, Union +import jax + from acme import core from acme.jax import networks as network_types -import jax class VariableReference(NamedTuple): - variable_name: str + variable_name: str class ReferenceVariableSource(core.VariableSource): - """Variable source which returns references instead of values. + """Variable source which returns references instead of values. This is passed to each actor when using a centralized inference server. The actor uses this special variable source to get references rather than values. @@ -39,21 +40,21 @@ class ReferenceVariableSource(core.VariableSource): actor to the inference server. """ - def get_variables(self, names: Sequence[str]) -> List[VariableReference]: - return [VariableReference(name) for name in names] + def get_variables(self, names: Sequence[str]) -> List[VariableReference]: + return [VariableReference(name) for name in names] class VariableClient: - """A variable client for updating variables from a remote source.""" + """A variable client for updating variables from a remote source.""" - def __init__( - self, - client: core.VariableSource, - key: Union[str, Sequence[str]], - update_period: Union[int, datetime.timedelta] = 1, - device: Optional[Union[str, jax.Device]] = None, - ): - """Initializes the variable client. + def __init__( + self, + client: core.VariableSource, + key: Union[str, Sequence[str]], + update_period: Union[int, datetime.timedelta] = 1, + device: Optional[Union[str, jax.Device]] = None, + ): + """Initializes the variable client. Args: client: A variable source from which we fetch variables. @@ -65,28 +66,28 @@ def __init__( device: The name of a JAX device to put variables on. If None (default), VariableClient won't put params on any device. 
""" - self._update_period = update_period - self._call_counter = 0 - self._last_call = time.time() - self._client = client - self._params: Sequence[network_types.Params] = None + self._update_period = update_period + self._call_counter = 0 + self._last_call = time.time() + self._client = client + self._params: Sequence[network_types.Params] = None - self._device = device - if isinstance(self._device, str): - self._device = jax.devices(device)[0] + self._device = device + if isinstance(self._device, str): + self._device = jax.devices(device)[0] - self._executor = futures.ThreadPoolExecutor(max_workers=1) + self._executor = futures.ThreadPoolExecutor(max_workers=1) - if isinstance(key, str): - key = [key] + if isinstance(key, str): + key = [key] - self._key = key - self._request = lambda k=key: client.get_variables(k) - self._future: Optional[futures.Future] = None # pylint: disable=g-bare-generic - self._async_request = lambda: self._executor.submit(self._request) + self._key = key + self._request = lambda k=key: client.get_variables(k) + self._future: Optional[futures.Future] = None # pylint: disable=g-bare-generic + self._async_request = lambda: self._executor.submit(self._request) - def update(self, wait: bool = False) -> None: - """Periodically updates the variables with the latest copy from the source. + def update(self, wait: bool = False) -> None: + """Periodically updates the variables with the latest copy from the source. If wait is True, a blocking request is executed. Any active request will be cancelled. @@ -96,59 +97,59 @@ def update(self, wait: bool = False) -> None: wait: Whether to execute asynchronous (False) or blocking updates (True). Defaults to False. """ - # Track calls (we only update periodically). - self._call_counter += 1 - - # Return if it's not time to fetch another update. - if isinstance(self._update_period, datetime.timedelta): - if self._update_period.total_seconds() + self._last_call > time.time(): - return - else: - if self._call_counter < self._update_period: - return - - if wait: - if self._future is not None: - if self._future.running(): - self._future.cancel() - self._future = None - self._call_counter = 0 - self._last_call = time.time() - self.update_and_wait() - return - - # Return early if we are still waiting for a previous request to come back. - if self._future and not self._future.done(): - return - - # Get a future and add the copy function as a callback. - self._call_counter = 0 - self._last_call = time.time() - self._future = self._async_request() - self._future.add_done_callback(lambda f: self._callback(f.result())) - - def update_and_wait(self): - """Immediately update and block until we get the result.""" - self._callback(self._request()) - - def _callback(self, params_list: List[network_types.Params]): - if self._device and not isinstance(self._client, ReferenceVariableSource): - # Move variables to a proper device. - self._params = jax.device_put(params_list, self._device) - else: - self._params = params_list - - @property - def device(self) -> Optional[jax.Device]: - return self._device - - @property - def params(self) -> Union[network_types.Params, List[network_types.Params]]: - """Returns the first params for one key, otherwise the whole params list.""" - if self._params is None: - self.update_and_wait() - - if len(self._params) == 1: - return self._params[0] - else: - return self._params + # Track calls (we only update periodically). + self._call_counter += 1 + + # Return if it's not time to fetch another update. 
+ if isinstance(self._update_period, datetime.timedelta): + if self._update_period.total_seconds() + self._last_call > time.time(): + return + else: + if self._call_counter < self._update_period: + return + + if wait: + if self._future is not None: + if self._future.running(): + self._future.cancel() + self._future = None + self._call_counter = 0 + self._last_call = time.time() + self.update_and_wait() + return + + # Return early if we are still waiting for a previous request to come back. + if self._future and not self._future.done(): + return + + # Get a future and add the copy function as a callback. + self._call_counter = 0 + self._last_call = time.time() + self._future = self._async_request() + self._future.add_done_callback(lambda f: self._callback(f.result())) + + def update_and_wait(self): + """Immediately update and block until we get the result.""" + self._callback(self._request()) + + def _callback(self, params_list: List[network_types.Params]): + if self._device and not isinstance(self._client, ReferenceVariableSource): + # Move variables to a proper device. + self._params = jax.device_put(params_list, self._device) + else: + self._params = params_list + + @property + def device(self) -> Optional[jax.Device]: + return self._device + + @property + def params(self) -> Union[network_types.Params, List[network_types.Params]]: + """Returns the first params for one key, otherwise the whole params list.""" + if self._params is None: + self.update_and_wait() + + if len(self._params) == 1: + return self._params[0] + else: + return self._params diff --git a/acme/jax/variable_utils_test.py b/acme/jax/variable_utils_test.py index 826807ecd6..b41e07ba28 100644 --- a/acme/jax/variable_utils_test.py +++ b/acme/jax/variable_utils_test.py @@ -14,50 +14,50 @@ """Tests for variable utilities.""" -from acme.jax import variable_utils -from acme.testing import fakes import haiku as hk import jax import jax.numpy as jnp import numpy as np import tree - from absl.testing import absltest +from acme.jax import variable_utils +from acme.testing import fakes + def dummy_network(x): - return hk.nets.MLP([50, 10])(x) + return hk.nets.MLP([50, 10])(x) class VariableClientTest(absltest.TestCase): - - def test_update(self): - init_fn, _ = hk.without_apply_rng( - hk.transform(dummy_network)) - params = init_fn(jax.random.PRNGKey(1), jnp.zeros(shape=(1, 32))) - variable_source = fakes.VariableSource(params) - variable_client = variable_utils.VariableClient( - variable_source, key='policy') - variable_client.update_and_wait() - tree.map_structure(np.testing.assert_array_equal, variable_client.params, - params) - - def test_multiple_keys(self): - init_fn, _ = hk.without_apply_rng( - hk.transform(dummy_network)) - params = init_fn(jax.random.PRNGKey(1), jnp.zeros(shape=(1, 32))) - steps = jnp.zeros(shape=1) - variables = {'network': params, 'steps': steps} - variable_source = fakes.VariableSource(variables, use_default_key=False) - variable_client = variable_utils.VariableClient( - variable_source, key=['network', 'steps']) - variable_client.update_and_wait() - - tree.map_structure(np.testing.assert_array_equal, variable_client.params[0], - params) - tree.map_structure(np.testing.assert_array_equal, variable_client.params[1], - steps) - - -if __name__ == '__main__': - absltest.main() + def test_update(self): + init_fn, _ = hk.without_apply_rng(hk.transform(dummy_network)) + params = init_fn(jax.random.PRNGKey(1), jnp.zeros(shape=(1, 32))) + variable_source = fakes.VariableSource(params) + variable_client = 
variable_utils.VariableClient(variable_source, key="policy") + variable_client.update_and_wait() + tree.map_structure( + np.testing.assert_array_equal, variable_client.params, params + ) + + def test_multiple_keys(self): + init_fn, _ = hk.without_apply_rng(hk.transform(dummy_network)) + params = init_fn(jax.random.PRNGKey(1), jnp.zeros(shape=(1, 32))) + steps = jnp.zeros(shape=1) + variables = {"network": params, "steps": steps} + variable_source = fakes.VariableSource(variables, use_default_key=False) + variable_client = variable_utils.VariableClient( + variable_source, key=["network", "steps"] + ) + variable_client.update_and_wait() + + tree.map_structure( + np.testing.assert_array_equal, variable_client.params[0], params + ) + tree.map_structure( + np.testing.assert_array_equal, variable_client.params[1], steps + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/multiagent/types.py b/acme/multiagent/types.py index c33251c4a2..b82867a626 100644 --- a/acme/multiagent/types.py +++ b/acme/multiagent/types.py @@ -16,11 +16,11 @@ from typing import Any, Callable, Dict, Tuple +import reverb + from acme import specs from acme.agents.jax import builders as jax_builders from acme.utils.loggers import base -import reverb - # Sub-agent types AgentID = str @@ -32,10 +32,11 @@ LoggerFn = Callable[[], base.Logger] InitNetworkFn = Callable[[GenericAgent, specs.EnvironmentSpec], Networks] InitPolicyNetworkFn = Callable[ - [GenericAgent, Networks, specs.EnvironmentSpec, AgentConfig, bool], - Networks] -InitBuilderFn = Callable[[GenericAgent, AgentConfig], - jax_builders.GenericActorLearnerBuilder] + [GenericAgent, Networks, specs.EnvironmentSpec, AgentConfig, bool], Networks +] +InitBuilderFn = Callable[ + [GenericAgent, AgentConfig], jax_builders.GenericActorLearnerBuilder +] # Multiagent types MultiAgentLoggerFn = Dict[AgentID, LoggerFn] @@ -43,9 +44,8 @@ MultiAgentPolicyNetworks = Dict[AgentID, PolicyNetwork] MultiAgentSample = Tuple[reverb.ReplaySample, ...] NetworkFactory = Callable[[specs.EnvironmentSpec], MultiAgentNetworks] -PolicyFactory = Callable[[MultiAgentNetworks, EvalMode], - MultiAgentPolicyNetworks] -BuilderFactory = Callable[[ - Dict[AgentID, GenericAgent], - Dict[AgentID, AgentConfig], -], Dict[AgentID, jax_builders.GenericActorLearnerBuilder]] +PolicyFactory = Callable[[MultiAgentNetworks, EvalMode], MultiAgentPolicyNetworks] +BuilderFactory = Callable[ + [Dict[AgentID, GenericAgent], Dict[AgentID, AgentConfig],], + Dict[AgentID, jax_builders.GenericActorLearnerBuilder], +] diff --git a/acme/multiagent/utils.py b/acme/multiagent/utils.py index 4c91d90bca..5e4990efb1 100644 --- a/acme/multiagent/utils.py +++ b/acme/multiagent/utils.py @@ -14,14 +14,16 @@ """Multiagent utilities.""" +import dm_env + from acme import specs from acme.multiagent import types -import dm_env -def get_agent_spec(env_spec: specs.EnvironmentSpec, - agent_id: types.AgentID) -> specs.EnvironmentSpec: - """Returns a single agent spec from environment spec. +def get_agent_spec( + env_spec: specs.EnvironmentSpec, agent_id: types.AgentID +) -> specs.EnvironmentSpec: + """Returns a single agent spec from environment spec. Args: env_spec: environment spec, wherein observation, action, and reward specs @@ -29,20 +31,23 @@ def get_agent_spec(env_spec: specs.EnvironmentSpec, given agent index). Discounts are scalars shared amongst agents. agent_id: agent index. 
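# --- Editor's example (sketch, not part of the patch): slicing a per-agent spec
# and timestep out of a joint multiagent environment, using the fake helpers from
# acme.testing. The agent ids are arbitrary strings.
from acme.multiagent import utils as multiagent_utils
from acme.testing import fakes, multiagent_fakes

joint_spec = multiagent_fakes.make_multiagent_environment_spec(["player_0", "player_1"])
env = fakes.Environment(joint_spec)
timestep = env.reset()

spec_0 = multiagent_utils.get_agent_spec(joint_spec, agent_id="player_0")
ts_0 = multiagent_utils.get_agent_timestep(timestep, agent_id="player_0")
# --- End of editor's example.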
""" - return specs.EnvironmentSpec( - actions=env_spec.actions[agent_id], - discounts=env_spec.discounts, - observations=env_spec.observations[agent_id], - rewards=env_spec.rewards[agent_id]) - - -def get_agent_timestep(timestep: dm_env.TimeStep, - agent_id: types.AgentID) -> dm_env.TimeStep: - """Returns the extracted timestep for a particular agent.""" - # Discounts are assumed to be shared amongst agents - reward = None if timestep.reward is None else timestep.reward[agent_id] - return dm_env.TimeStep( - observation=timestep.observation[agent_id], - reward=reward, - discount=timestep.discount, - step_type=timestep.step_type) + return specs.EnvironmentSpec( + actions=env_spec.actions[agent_id], + discounts=env_spec.discounts, + observations=env_spec.observations[agent_id], + rewards=env_spec.rewards[agent_id], + ) + + +def get_agent_timestep( + timestep: dm_env.TimeStep, agent_id: types.AgentID +) -> dm_env.TimeStep: + """Returns the extracted timestep for a particular agent.""" + # Discounts are assumed to be shared amongst agents + reward = None if timestep.reward is None else timestep.reward[agent_id] + return dm_env.TimeStep( + observation=timestep.observation[agent_id], + reward=reward, + discount=timestep.discount, + step_type=timestep.step_type, + ) diff --git a/acme/multiagent/utils_test.py b/acme/multiagent/utils_test.py index 7c325fb228..96d19dfecb 100644 --- a/acme/multiagent/utils_test.py +++ b/acme/multiagent/utils_test.py @@ -14,46 +14,45 @@ """Tests for multiagent_utils.""" -from acme import specs -from acme.multiagent import utils as multiagent_utils -from acme.testing import fakes -from acme.testing import multiagent_fakes import dm_env from absl.testing import absltest +from acme import specs +from acme.multiagent import utils as multiagent_utils +from acme.testing import fakes, multiagent_fakes + class UtilsTest(absltest.TestCase): - - def test_get_agent_spec(self): - agent_indices = ['a', '99', 'Z'] - spec = multiagent_fakes.make_multiagent_environment_spec(agent_indices) - for agent_id in spec.actions.keys(): - single_agent_spec = multiagent_utils.get_agent_spec( - spec, agent_id=agent_id) - expected_spec = specs.EnvironmentSpec( - actions=spec.actions[agent_id], - discounts=spec.discounts, - observations=spec.observations[agent_id], - rewards=spec.rewards[agent_id] - ) - self.assertEqual(single_agent_spec, expected_spec) - - def test_get_agent_timestep(self): - agent_indices = ['a', '99', 'Z'] - spec = multiagent_fakes.make_multiagent_environment_spec(agent_indices) - env = fakes.Environment(spec) - timestep = env.reset() - for agent_id in spec.actions.keys(): - single_agent_timestep = multiagent_utils.get_agent_timestep( - timestep, agent_id) - expected_timestep = dm_env.TimeStep( - observation=timestep.observation[agent_id], - reward=None, - discount=None, - step_type=timestep.step_type - ) - self.assertEqual(single_agent_timestep, expected_timestep) - - -if __name__ == '__main__': - absltest.main() + def test_get_agent_spec(self): + agent_indices = ["a", "99", "Z"] + spec = multiagent_fakes.make_multiagent_environment_spec(agent_indices) + for agent_id in spec.actions.keys(): + single_agent_spec = multiagent_utils.get_agent_spec(spec, agent_id=agent_id) + expected_spec = specs.EnvironmentSpec( + actions=spec.actions[agent_id], + discounts=spec.discounts, + observations=spec.observations[agent_id], + rewards=spec.rewards[agent_id], + ) + self.assertEqual(single_agent_spec, expected_spec) + + def test_get_agent_timestep(self): + agent_indices = ["a", "99", "Z"] + 
spec = multiagent_fakes.make_multiagent_environment_spec(agent_indices) + env = fakes.Environment(spec) + timestep = env.reset() + for agent_id in spec.actions.keys(): + single_agent_timestep = multiagent_utils.get_agent_timestep( + timestep, agent_id + ) + expected_timestep = dm_env.TimeStep( + observation=timestep.observation[agent_id], + reward=None, + discount=None, + step_type=timestep.step_type, + ) + self.assertEqual(single_agent_timestep, expected_timestep) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/specs.py b/acme/specs.py index 1d568436fa..68bc83a9be 100644 --- a/acme/specs.py +++ b/acme/specs.py @@ -31,18 +31,20 @@ class EnvironmentSpec(NamedTuple): - """Full specification of the domains used by a given environment.""" - # TODO(b/144758674): Use NestedSpec type here. - observations: Any - actions: Any - rewards: Any - discounts: Any + """Full specification of the domains used by a given environment.""" + + # TODO(b/144758674): Use NestedSpec type here. + observations: Any + actions: Any + rewards: Any + discounts: Any def make_environment_spec(environment: dm_env.Environment) -> EnvironmentSpec: - """Returns an `EnvironmentSpec` describing values used by an environment.""" - return EnvironmentSpec( - observations=environment.observation_spec(), - actions=environment.action_spec(), - rewards=environment.reward_spec(), - discounts=environment.discount_spec()) + """Returns an `EnvironmentSpec` describing values used by an environment.""" + return EnvironmentSpec( + observations=environment.observation_spec(), + actions=environment.action_spec(), + rewards=environment.reward_spec(), + discounts=environment.discount_spec(), + ) diff --git a/acme/testing/fakes.py b/acme/testing/fakes.py index 41fcb3fcff..1a3a862e4c 100644 --- a/acme/testing/fakes.py +++ b/acme/testing/fakes.py @@ -19,259 +19,273 @@ """ import threading -from typing import List, Mapping, Optional, Sequence, Callable, Iterator +from typing import Callable, Iterator, List, Mapping, Optional, Sequence -from acme import core -from acme import specs -from acme import types -from acme import wrappers import dm_env import numpy as np import reverb -from rlds import rlds_types import tensorflow as tf import tree +from rlds import rlds_types + +from acme import core, specs, types, wrappers class Actor(core.Actor): - """Fake actor which generates random actions and validates specs.""" + """Fake actor which generates random actions and validates specs.""" - def __init__(self, spec: specs.EnvironmentSpec): - self._spec = spec - self.num_updates = 0 + def __init__(self, spec: specs.EnvironmentSpec): + self._spec = spec + self.num_updates = 0 - def select_action(self, observation: types.NestedArray) -> types.NestedArray: - _validate_spec(self._spec.observations, observation) - return _generate_from_spec(self._spec.actions) + def select_action(self, observation: types.NestedArray) -> types.NestedArray: + _validate_spec(self._spec.observations, observation) + return _generate_from_spec(self._spec.actions) - def observe_first(self, timestep: dm_env.TimeStep): - _validate_spec(self._spec.observations, timestep.observation) + def observe_first(self, timestep: dm_env.TimeStep): + _validate_spec(self._spec.observations, timestep.observation) - def observe( - self, - action: types.NestedArray, - next_timestep: dm_env.TimeStep, - ): - _validate_spec(self._spec.actions, action) - _validate_spec(self._spec.rewards, next_timestep.reward) - _validate_spec(self._spec.discounts, next_timestep.discount) - 
_validate_spec(self._spec.observations, next_timestep.observation) + def observe( + self, action: types.NestedArray, next_timestep: dm_env.TimeStep, + ): + _validate_spec(self._spec.actions, action) + _validate_spec(self._spec.rewards, next_timestep.reward) + _validate_spec(self._spec.discounts, next_timestep.discount) + _validate_spec(self._spec.observations, next_timestep.observation) - def update(self, wait: bool = False): - self.num_updates += 1 + def update(self, wait: bool = False): + self.num_updates += 1 class VariableSource(core.VariableSource): - """Fake variable source.""" - - def __init__(self, - variables: Optional[types.NestedArray] = None, - barrier: Optional[threading.Barrier] = None, - use_default_key: bool = True): - # Add dummy variables so we can expose them in get_variables. - if use_default_key: - self._variables = {'policy': [] if variables is None else variables} - else: - self._variables = variables - self._barrier = barrier + """Fake variable source.""" + + def __init__( + self, + variables: Optional[types.NestedArray] = None, + barrier: Optional[threading.Barrier] = None, + use_default_key: bool = True, + ): + # Add dummy variables so we can expose them in get_variables. + if use_default_key: + self._variables = {"policy": [] if variables is None else variables} + else: + self._variables = variables + self._barrier = barrier - def get_variables(self, names: Sequence[str]) -> List[types.NestedArray]: - if self._barrier is not None: - self._barrier.wait() - return [self._variables[name] for name in names] + def get_variables(self, names: Sequence[str]) -> List[types.NestedArray]: + if self._barrier is not None: + self._barrier.wait() + return [self._variables[name] for name in names] class Learner(core.Learner, VariableSource): - """Fake Learner.""" + """Fake Learner.""" - def __init__(self, - variables: Optional[types.NestedArray] = None, - barrier: Optional[threading.Barrier] = None): - super().__init__(variables=variables, barrier=barrier) - self.step_counter = 0 + def __init__( + self, + variables: Optional[types.NestedArray] = None, + barrier: Optional[threading.Barrier] = None, + ): + super().__init__(variables=variables, barrier=barrier) + self.step_counter = 0 - def step(self): - self.step_counter += 1 + def step(self): + self.step_counter += 1 class Environment(dm_env.Environment): - """A fake environment with a given spec.""" - - def __init__( - self, - spec: specs.EnvironmentSpec, - *, - episode_length: int = 25, - ): - # Assert that the discount spec is a BoundedArray with range [0, 1]. - def check_discount_spec(path, discount_spec): - if (not isinstance(discount_spec, specs.BoundedArray) or - not np.isclose(discount_spec.minimum, 0) or - not np.isclose(discount_spec.maximum, 1)): - if path: - path_str = ' ' + '/'.join(str(p) for p in path) + """A fake environment with a given spec.""" + + def __init__( + self, spec: specs.EnvironmentSpec, *, episode_length: int = 25, + ): + # Assert that the discount spec is a BoundedArray with range [0, 1]. 
+ def check_discount_spec(path, discount_spec): + if ( + not isinstance(discount_spec, specs.BoundedArray) + or not np.isclose(discount_spec.minimum, 0) + or not np.isclose(discount_spec.maximum, 1) + ): + if path: + path_str = " " + "/".join(str(p) for p in path) + else: + path_str = "" + raise ValueError( + "discount_spec {}isn't a BoundedArray in [0, 1].".format(path_str) + ) + + tree.map_structure_with_path(check_discount_spec, spec.discounts) + + self._spec = spec + self._episode_length = episode_length + self._step = 0 + + def _generate_fake_observation(self): + return _generate_from_spec(self._spec.observations) + + def _generate_fake_reward(self): + return _generate_from_spec(self._spec.rewards) + + def _generate_fake_discount(self): + return _generate_from_spec(self._spec.discounts) + + def reset(self) -> dm_env.TimeStep: + observation = self._generate_fake_observation() + self._step = 1 + return dm_env.restart(observation) + + def step(self, action) -> dm_env.TimeStep: + # Return a reset timestep if we haven't touched the environment yet. + if not self._step: + return self.reset() + + _validate_spec(self._spec.actions, action) + + observation = self._generate_fake_observation() + reward = self._generate_fake_reward() + discount = self._generate_fake_discount() + + if self._episode_length and (self._step == self._episode_length): + self._step = 0 + # We can't use dm_env.termination directly because then the discount + # wouldn't necessarily conform to the spec (if eg. we want float32). + return dm_env.TimeStep(dm_env.StepType.LAST, reward, discount, observation) else: - path_str = '' - raise ValueError( - 'discount_spec {}isn\'t a BoundedArray in [0, 1].'.format(path_str)) - - tree.map_structure_with_path(check_discount_spec, spec.discounts) - - self._spec = spec - self._episode_length = episode_length - self._step = 0 - - def _generate_fake_observation(self): - return _generate_from_spec(self._spec.observations) - - def _generate_fake_reward(self): - return _generate_from_spec(self._spec.rewards) - - def _generate_fake_discount(self): - return _generate_from_spec(self._spec.discounts) + self._step += 1 + return dm_env.transition( + reward=reward, observation=observation, discount=discount + ) - def reset(self) -> dm_env.TimeStep: - observation = self._generate_fake_observation() - self._step = 1 - return dm_env.restart(observation) + def action_spec(self): + return self._spec.actions - def step(self, action) -> dm_env.TimeStep: - # Return a reset timestep if we haven't touched the environment yet. - if not self._step: - return self.reset() + def observation_spec(self): + return self._spec.observations - _validate_spec(self._spec.actions, action) + def reward_spec(self): + return self._spec.rewards - observation = self._generate_fake_observation() - reward = self._generate_fake_reward() - discount = self._generate_fake_discount() - - if self._episode_length and (self._step == self._episode_length): - self._step = 0 - # We can't use dm_env.termination directly because then the discount - # wouldn't necessarily conform to the spec (if eg. we want float32). 
- return dm_env.TimeStep(dm_env.StepType.LAST, reward, discount, - observation) - else: - self._step += 1 - return dm_env.transition( - reward=reward, observation=observation, discount=discount) - - def action_spec(self): - return self._spec.actions - - def observation_spec(self): - return self._spec.observations - - def reward_spec(self): - return self._spec.rewards - - def discount_spec(self): - return self._spec.discounts + def discount_spec(self): + return self._spec.discounts class _BaseDiscreteEnvironment(Environment): - """Discrete action fake environment.""" - - def __init__(self, - *, - num_actions: int = 1, - action_dtype=np.int32, - observation_spec: types.NestedSpec, - discount_spec: Optional[types.NestedSpec] = None, - reward_spec: Optional[types.NestedSpec] = None, - **kwargs): - """Initialize the environment.""" - if reward_spec is None: - reward_spec = specs.Array((), np.float32) - - if discount_spec is None: - discount_spec = specs.BoundedArray((), np.float32, 0.0, 1.0) - - actions = specs.DiscreteArray(num_actions, dtype=action_dtype) - - super().__init__( - spec=specs.EnvironmentSpec( - observations=observation_spec, - actions=actions, - rewards=reward_spec, - discounts=discount_spec), - **kwargs) + """Discrete action fake environment.""" + + def __init__( + self, + *, + num_actions: int = 1, + action_dtype=np.int32, + observation_spec: types.NestedSpec, + discount_spec: Optional[types.NestedSpec] = None, + reward_spec: Optional[types.NestedSpec] = None, + **kwargs, + ): + """Initialize the environment.""" + if reward_spec is None: + reward_spec = specs.Array((), np.float32) + + if discount_spec is None: + discount_spec = specs.BoundedArray((), np.float32, 0.0, 1.0) + + actions = specs.DiscreteArray(num_actions, dtype=action_dtype) + + super().__init__( + spec=specs.EnvironmentSpec( + observations=observation_spec, + actions=actions, + rewards=reward_spec, + discounts=discount_spec, + ), + **kwargs, + ) class DiscreteEnvironment(_BaseDiscreteEnvironment): - """Discrete state and action fake environment.""" - - def __init__(self, - *, - num_actions: int = 1, - num_observations: int = 1, - action_dtype=np.int32, - obs_dtype=np.int32, - obs_shape: Sequence[int] = (), - discount_spec: Optional[types.NestedSpec] = None, - reward_spec: Optional[types.NestedSpec] = None, - **kwargs): - """Initialize the environment.""" - observations_spec = specs.BoundedArray( - shape=obs_shape, - dtype=obs_dtype, - minimum=obs_dtype(0), - maximum=obs_dtype(num_observations - 1)) - - super().__init__( - num_actions=num_actions, - action_dtype=action_dtype, - observation_spec=observations_spec, - discount_spec=discount_spec, - reward_spec=reward_spec, - **kwargs) + """Discrete state and action fake environment.""" + + def __init__( + self, + *, + num_actions: int = 1, + num_observations: int = 1, + action_dtype=np.int32, + obs_dtype=np.int32, + obs_shape: Sequence[int] = (), + discount_spec: Optional[types.NestedSpec] = None, + reward_spec: Optional[types.NestedSpec] = None, + **kwargs, + ): + """Initialize the environment.""" + observations_spec = specs.BoundedArray( + shape=obs_shape, + dtype=obs_dtype, + minimum=obs_dtype(0), + maximum=obs_dtype(num_observations - 1), + ) + + super().__init__( + num_actions=num_actions, + action_dtype=action_dtype, + observation_spec=observations_spec, + discount_spec=discount_spec, + reward_spec=reward_spec, + **kwargs, + ) class NestedDiscreteEnvironment(_BaseDiscreteEnvironment): - """Discrete action fake environment with nested discrete state.""" - - def 
__init__(self, - *, - num_observations: Mapping[str, int], - num_actions: int = 1, - action_dtype=np.int32, - obs_dtype=np.int32, - obs_shape: Sequence[int] = (), - discount_spec: Optional[types.NestedSpec] = None, - reward_spec: Optional[types.NestedSpec] = None, - **kwargs): - """Initialize the environment.""" - - observations_spec = {} - for key in num_observations: - observations_spec[key] = specs.BoundedArray( - shape=obs_shape, - dtype=obs_dtype, - minimum=obs_dtype(0), - maximum=obs_dtype(num_observations[key] - 1)) - - super().__init__( - num_actions=num_actions, - action_dtype=action_dtype, - observation_spec=observations_spec, - discount_spec=discount_spec, - reward_spec=reward_spec, - **kwargs) + """Discrete action fake environment with nested discrete state.""" + + def __init__( + self, + *, + num_observations: Mapping[str, int], + num_actions: int = 1, + action_dtype=np.int32, + obs_dtype=np.int32, + obs_shape: Sequence[int] = (), + discount_spec: Optional[types.NestedSpec] = None, + reward_spec: Optional[types.NestedSpec] = None, + **kwargs, + ): + """Initialize the environment.""" + + observations_spec = {} + for key in num_observations: + observations_spec[key] = specs.BoundedArray( + shape=obs_shape, + dtype=obs_dtype, + minimum=obs_dtype(0), + maximum=obs_dtype(num_observations[key] - 1), + ) + + super().__init__( + num_actions=num_actions, + action_dtype=action_dtype, + observation_spec=observations_spec, + discount_spec=discount_spec, + reward_spec=reward_spec, + **kwargs, + ) class ContinuousEnvironment(Environment): - """Continuous state and action fake environment.""" - - def __init__(self, - *, - action_dim: int = 1, - observation_dim: int = 1, - bounded: bool = False, - dtype=np.float32, - reward_dtype=np.float32, - **kwargs): - """Initialize the environment. + """Continuous state and action fake environment.""" + + def __init__( + self, + *, + action_dim: int = 1, + observation_dim: int = 1, + bounded: bool = False, + dtype=np.float32, + reward_dtype=np.float32, + **kwargs, + ): + """Initialize the environment. Args: action_dim: number of action dimensions. @@ -282,35 +296,37 @@ def __init__(self, **kwargs: additional kwargs passed to the Environment base class. 
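# --- Editor's example (sketch, not part of the patch): a fake continuous-control
# environment driven by the spec-validating fake Actor for one short episode.
# Dimensions and episode length are arbitrary.
from acme import specs
from acme.testing import fakes

env = fakes.ContinuousEnvironment(
    action_dim=2, observation_dim=3, bounded=True, episode_length=5)
actor = fakes.Actor(specs.make_environment_spec(env))

timestep = env.reset()
actor.observe_first(timestep)
while not timestep.last():
    action = actor.select_action(timestep.observation)
    timestep = env.step(action)
    actor.observe(action, next_timestep=timestep)
# --- End of editor's example.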
""" - action_shape = () if action_dim == 0 else (action_dim,) - observation_shape = () if observation_dim == 0 else (observation_dim,) + action_shape = () if action_dim == 0 else (action_dim,) + observation_shape = () if observation_dim == 0 else (observation_dim,) - observations = specs.Array(observation_shape, dtype) - rewards = specs.Array((), reward_dtype) - discounts = specs.BoundedArray((), reward_dtype, 0.0, 1.0) + observations = specs.Array(observation_shape, dtype) + rewards = specs.Array((), reward_dtype) + discounts = specs.BoundedArray((), reward_dtype, 0.0, 1.0) - if bounded: - actions = specs.BoundedArray(action_shape, dtype, -1.0, 1.0) - else: - actions = specs.Array(action_shape, dtype) + if bounded: + actions = specs.BoundedArray(action_shape, dtype, -1.0, 1.0) + else: + actions = specs.Array(action_shape, dtype) - super().__init__( - spec=specs.EnvironmentSpec( - observations=observations, - actions=actions, - rewards=rewards, - discounts=discounts), - **kwargs) + super().__init__( + spec=specs.EnvironmentSpec( + observations=observations, + actions=actions, + rewards=rewards, + discounts=discounts, + ), + **kwargs, + ) def _validate_spec(spec: types.NestedSpec, value: types.NestedArray): - """Validate a value from a potentially nested spec.""" - tree.assert_same_structure(value, spec) - tree.map_structure(lambda s, v: s.validate(v), spec, value) + """Validate a value from a potentially nested spec.""" + tree.assert_same_structure(value, spec) + tree.map_structure(lambda s, v: s.validate(v), spec, value) def _normalize_array(array: specs.Array) -> specs.Array: - """Converts bounded arrays with (-inf,+inf) bounds to unbounded arrays. + """Converts bounded arrays with (-inf,+inf) bounds to unbounded arrays. The returned array should be mostly equivalent to the input, except that `generate_value()` returns -infs on arrays bounded to (-inf,+inf) and zeros @@ -322,26 +338,24 @@ def _normalize_array(array: specs.Array) -> specs.Array: Returns: normalized array. """ - if isinstance(array, specs.DiscreteArray): - return array - if not isinstance(array, specs.BoundedArray): - return array - if not (array.minimum == float('-inf')).all(): - return array - if not (array.maximum == float('+inf')).all(): - return array - return specs.Array(array.shape, array.dtype, array.name) + if isinstance(array, specs.DiscreteArray): + return array + if not isinstance(array, specs.BoundedArray): + return array + if not (array.minimum == float("-inf")).all(): + return array + if not (array.maximum == float("+inf")).all(): + return array + return specs.Array(array.shape, array.dtype, array.name) def _generate_from_spec(spec: types.NestedSpec) -> types.NestedArray: - """Generate a value from a potentially nested spec.""" - return tree.map_structure(lambda s: _normalize_array(s).generate_value(), - spec) + """Generate a value from a potentially nested spec.""" + return tree.map_structure(lambda s: _normalize_array(s).generate_value(), spec) -def transition_dataset_from_spec( - spec: specs.EnvironmentSpec) -> tf.data.Dataset: - """Constructs fake dataset of Reverb N-step transition samples. +def transition_dataset_from_spec(spec: specs.EnvironmentSpec) -> tf.data.Dataset: + """Constructs fake dataset of Reverb N-step transition samples. Args: spec: Constructed fake transitions match the provided specification. @@ -351,22 +365,23 @@ def transition_dataset_from_spec( object indefinitely. 
""" - observation = _generate_from_spec(spec.observations) - action = _generate_from_spec(spec.actions) - reward = _generate_from_spec(spec.rewards) - discount = _generate_from_spec(spec.discounts) - data = types.Transition(observation, action, reward, discount, observation) + observation = _generate_from_spec(spec.observations) + action = _generate_from_spec(spec.actions) + reward = _generate_from_spec(spec.rewards) + discount = _generate_from_spec(spec.discounts) + data = types.Transition(observation, action, reward, discount, observation) - info = tree.map_structure( - lambda tf_dtype: tf.ones([], tf_dtype.as_numpy_dtype), - reverb.SampleInfo.tf_dtypes()) - sample = reverb.ReplaySample(info=info, data=data) + info = tree.map_structure( + lambda tf_dtype: tf.ones([], tf_dtype.as_numpy_dtype), + reverb.SampleInfo.tf_dtypes(), + ) + sample = reverb.ReplaySample(info=info, data=data) - return tf.data.Dataset.from_tensors(sample).repeat() + return tf.data.Dataset.from_tensors(sample).repeat() def transition_dataset(environment: dm_env.Environment) -> tf.data.Dataset: - """Constructs fake dataset of Reverb N-step transition samples. + """Constructs fake dataset of Reverb N-step transition samples. Args: environment: Constructed fake transitions will match the specification of @@ -376,12 +391,13 @@ def transition_dataset(environment: dm_env.Environment) -> tf.data.Dataset: tf.data.Dataset that produces the same fake N-step transition ReverbSample object indefinitely. """ - return transition_dataset_from_spec(specs.make_environment_spec(environment)) + return transition_dataset_from_spec(specs.make_environment_spec(environment)) def transition_iterator_from_spec( - spec: specs.EnvironmentSpec) -> Callable[[int], Iterator[types.Transition]]: - """Constructs fake iterator of transitions. + spec: specs.EnvironmentSpec, +) -> Callable[[int], Iterator[types.Transition]]: + """Constructs fake iterator of transitions. Args: spec: Constructed fake transitions match the provided specification.. @@ -390,21 +406,21 @@ def transition_iterator_from_spec( A callable that given a batch_size returns an iterator of transitions. """ - observation = _generate_from_spec(spec.observations) - action = _generate_from_spec(spec.actions) - reward = _generate_from_spec(spec.rewards) - discount = _generate_from_spec(spec.discounts) - data = types.Transition(observation, action, reward, discount, observation) + observation = _generate_from_spec(spec.observations) + action = _generate_from_spec(spec.actions) + reward = _generate_from_spec(spec.rewards) + discount = _generate_from_spec(spec.discounts) + data = types.Transition(observation, action, reward, discount, observation) - dataset = tf.data.Dataset.from_tensors(data).repeat() + dataset = tf.data.Dataset.from_tensors(data).repeat() - return lambda batch_size: dataset.batch(batch_size).as_numpy_iterator() + return lambda batch_size: dataset.batch(batch_size).as_numpy_iterator() def transition_iterator( - environment: dm_env.Environment + environment: dm_env.Environment, ) -> Callable[[int], Iterator[types.Transition]]: - """Constructs fake iterator of transitions. + """Constructs fake iterator of transitions. Args: environment: Constructed fake transitions will match the specification of @@ -413,12 +429,13 @@ def transition_iterator( Returns: A callable that given a batch_size returns an iterator of transitions. 
""" - return transition_iterator_from_spec(specs.make_environment_spec(environment)) + return transition_iterator_from_spec(specs.make_environment_spec(environment)) -def fake_atari_wrapped(episode_length: int = 10, - oar_wrapper: bool = False) -> dm_env.Environment: - """Builds fake version of the environment to be used by tests. +def fake_atari_wrapped( + episode_length: int = 10, oar_wrapper: bool = False +) -> dm_env.Environment: + """Builds fake version of the environment to be used by tests. Args: episode_length: The length of episodes produced by this environment. @@ -428,25 +445,23 @@ def fake_atari_wrapped(episode_length: int = 10, Fake version of the environment equivalent to the one returned by env_loader.load_atari_wrapped """ - env = DiscreteEnvironment( - num_actions=18, - num_observations=2, - obs_shape=(84, 84, 4), - obs_dtype=np.float32, - episode_length=episode_length) + env = DiscreteEnvironment( + num_actions=18, + num_observations=2, + obs_shape=(84, 84, 4), + obs_dtype=np.float32, + episode_length=episode_length, + ) - if oar_wrapper: - env = wrappers.ObservationActionRewardWrapper(env) - return env + if oar_wrapper: + env = wrappers.ObservationActionRewardWrapper(env) + return env def rlds_dataset_from_env_spec( - spec: specs.EnvironmentSpec, - *, - episode_count: int = 10, - episode_length: int = 25, + spec: specs.EnvironmentSpec, *, episode_count: int = 10, episode_length: int = 25, ) -> tf.data.Dataset: - """Constructs a fake RLDS dataset with the given spec. + """Constructs a fake RLDS dataset with the given spec. Args: spec: specification to use for generation of fake steps. @@ -457,18 +472,19 @@ def rlds_dataset_from_env_spec( a fake RLDS dataset. """ - fake_steps = { - rlds_types.OBSERVATION: - ([_generate_from_spec(spec.observations)] * episode_length), - rlds_types.ACTION: ([_generate_from_spec(spec.actions)] * episode_length), - rlds_types.REWARD: ([_generate_from_spec(spec.rewards)] * episode_length), - rlds_types.DISCOUNT: - ([_generate_from_spec(spec.discounts)] * episode_length), - rlds_types.IS_TERMINAL: [False] * (episode_length - 1) + [True], - rlds_types.IS_FIRST: [True] + [False] * (episode_length - 1), - rlds_types.IS_LAST: [False] * (episode_length - 1) + [True], - } - steps_dataset = tf.data.Dataset.from_tensor_slices(fake_steps) - - return tf.data.Dataset.from_tensor_slices( - {rlds_types.STEPS: [steps_dataset] * episode_count}) + fake_steps = { + rlds_types.OBSERVATION: ( + [_generate_from_spec(spec.observations)] * episode_length + ), + rlds_types.ACTION: ([_generate_from_spec(spec.actions)] * episode_length), + rlds_types.REWARD: ([_generate_from_spec(spec.rewards)] * episode_length), + rlds_types.DISCOUNT: ([_generate_from_spec(spec.discounts)] * episode_length), + rlds_types.IS_TERMINAL: [False] * (episode_length - 1) + [True], + rlds_types.IS_FIRST: [True] + [False] * (episode_length - 1), + rlds_types.IS_LAST: [False] * (episode_length - 1) + [True], + } + steps_dataset = tf.data.Dataset.from_tensor_slices(fake_steps) + + return tf.data.Dataset.from_tensor_slices( + {rlds_types.STEPS: [steps_dataset] * episode_count} + ) diff --git a/acme/testing/multiagent_fakes.py b/acme/testing/multiagent_fakes.py index 0bfe11b618..38e9f082d2 100644 --- a/acme/testing/multiagent_fakes.py +++ b/acme/testing/multiagent_fakes.py @@ -16,35 +16,36 @@ from typing import Dict, List -from acme import specs import numpy as np +from acme import specs + def _make_multiagent_spec(agent_indices: List[str]) -> Dict[str, specs.Array]: - """Returns dummy multiagent 
sub-spec (e.g., observation or action spec). + """Returns dummy multiagent sub-spec (e.g., observation or action spec). Args: agent_indices: a list of agent indices. """ - return { - agent_id: specs.BoundedArray((1,), np.float32, 0, 1) - for agent_id in agent_indices - } + return { + agent_id: specs.BoundedArray((1,), np.float32, 0, 1) + for agent_id in agent_indices + } -def make_multiagent_environment_spec( - agent_indices: List[str]) -> specs.EnvironmentSpec: - """Returns dummy multiagent environment spec. +def make_multiagent_environment_spec(agent_indices: List[str]) -> specs.EnvironmentSpec: + """Returns dummy multiagent environment spec. Args: agent_indices: a list of agent indices. """ - action_spec = _make_multiagent_spec(agent_indices) - discount_spec = specs.BoundedArray((), np.float32, 0.0, 1.0) - observation_spec = _make_multiagent_spec(agent_indices) - reward_spec = _make_multiagent_spec(agent_indices) - return specs.EnvironmentSpec( - actions=action_spec, - discounts=discount_spec, - observations=observation_spec, - rewards=reward_spec) + action_spec = _make_multiagent_spec(agent_indices) + discount_spec = specs.BoundedArray((), np.float32, 0.0, 1.0) + observation_spec = _make_multiagent_spec(agent_indices) + reward_spec = _make_multiagent_spec(agent_indices) + return specs.EnvironmentSpec( + actions=action_spec, + discounts=discount_spec, + observations=observation_spec, + rewards=reward_spec, + ) diff --git a/acme/testing/test_utils.py b/acme/testing/test_utils.py index 576c9b0c9b..f7be2d7eba 100644 --- a/acme/testing/test_utils.py +++ b/acme/testing/test_utils.py @@ -22,12 +22,12 @@ class TestCase(parameterized.TestCase): - """A custom TestCase which handles FLAG parsing for pytest compatibility.""" + """A custom TestCase which handles FLAG parsing for pytest compatibility.""" - def get_tempdir(self, name: Optional[str] = None) -> str: - try: - flags.FLAGS.test_tmpdir - except flags.UnparsedFlagAccessError: - # Need to initialize flags when running `pytest`. - flags.FLAGS(sys.argv, known_only=True) - return self.create_tempdir(name).full_path + def get_tempdir(self, name: Optional[str] = None) -> str: + try: + flags.FLAGS.test_tmpdir + except flags.UnparsedFlagAccessError: + # Need to initialize flags when running `pytest`. + flags.FLAGS(sys.argv, known_only=True) + return self.create_tempdir(name).full_path diff --git a/acme/tf/__init__.py b/acme/tf/__init__.py index 240cb71526..de867df849 100644 --- a/acme/tf/__init__.py +++ b/acme/tf/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/acme/tf/losses/__init__.py b/acme/tf/losses/__init__.py index 70d51bf62a..d2451252ec 100644 --- a/acme/tf/losses/__init__.py +++ b/acme/tf/losses/__init__.py @@ -14,16 +14,13 @@ """Various losses for training agent components (policies, critics, etc).""" -from acme.tf.losses.distributional import categorical -from acme.tf.losses.distributional import multiaxis_categorical +from acme.tf.losses.distributional import categorical, multiaxis_categorical from acme.tf.losses.dpg import dpg from acme.tf.losses.huber import huber -from acme.tf.losses.mompo import KLConstraint -from acme.tf.losses.mompo import MultiObjectiveMPO +from acme.tf.losses.mompo import KLConstraint, MultiObjectiveMPO from acme.tf.losses.mpo import MPO -from acme.tf.losses.r2d2 import transformed_n_step_loss # Internal imports. 
# pylint: disable=g-bad-import-order,g-import-not-at-top -from acme.tf.losses.quantile import NonUniformQuantileRegression -from acme.tf.losses.quantile import QuantileDistribution +from acme.tf.losses.quantile import NonUniformQuantileRegression, QuantileDistribution +from acme.tf.losses.r2d2 import transformed_n_step_loss diff --git a/acme/tf/losses/distributional.py b/acme/tf/losses/distributional.py index 54c0560c92..d4566d7bd4 100644 --- a/acme/tf/losses/distributional.py +++ b/acme/tf/losses/distributional.py @@ -14,36 +14,37 @@ """Losses and projection operators relevant to distributional RL.""" -from acme.tf import networks import tensorflow as tf +from acme.tf import networks + -def categorical(q_tm1: networks.DiscreteValuedDistribution, r_t: tf.Tensor, - d_t: tf.Tensor, - q_t: networks.DiscreteValuedDistribution) -> tf.Tensor: - """Implements the Categorical Distributional TD(0)-learning loss.""" +def categorical( + q_tm1: networks.DiscreteValuedDistribution, + r_t: tf.Tensor, + d_t: tf.Tensor, + q_t: networks.DiscreteValuedDistribution, +) -> tf.Tensor: + """Implements the Categorical Distributional TD(0)-learning loss.""" - z_t = tf.reshape(r_t, (-1, 1)) + tf.reshape(d_t, (-1, 1)) * q_t.values - p_t = tf.nn.softmax(q_t.logits) + z_t = tf.reshape(r_t, (-1, 1)) + tf.reshape(d_t, (-1, 1)) * q_t.values + p_t = tf.nn.softmax(q_t.logits) - # Performs L2 projection. - target = tf.stop_gradient(l2_project(z_t, p_t, q_t.values)) + # Performs L2 projection. + target = tf.stop_gradient(l2_project(z_t, p_t, q_t.values)) - # Calculates loss. - loss = tf.nn.softmax_cross_entropy_with_logits( - logits=q_tm1.logits, labels=target) + # Calculates loss. + loss = tf.nn.softmax_cross_entropy_with_logits(logits=q_tm1.logits, labels=target) - return loss + return loss # Use an old version of the l2 projection which is probably slower on CPU # but will run on GPUs. def l2_project( # pylint: disable=invalid-name - Zp: tf.Tensor, - P: tf.Tensor, - Zq: tf.Tensor, + Zp: tf.Tensor, P: tf.Tensor, Zq: tf.Tensor, ) -> tf.Tensor: - """Project distribution (Zp, P) onto support Zq under the L2-metric over CDFs. + """Project distribution (Zp, P) onto support Zq under the L2-metric over CDFs. This projection works for any support Zq. Let Kq be len(Zq) and Kp be len(Zp). @@ -57,37 +58,38 @@ def l2_project( # pylint: disable=invalid-name L2 projection of (Zp, P) onto Zq. """ - # Asserts that Zq has no leading dimension of size 1. - if Zq.get_shape().ndims > 1: - Zq = tf.squeeze(Zq, axis=0) + # Asserts that Zq has no leading dimension of size 1. + if Zq.get_shape().ndims > 1: + Zq = tf.squeeze(Zq, axis=0) - # Extracts vmin and vmax and construct helper tensors from Zq. - vmin, vmax = Zq[0], Zq[-1] - d_pos = tf.concat([Zq, vmin[None]], 0)[1:] - d_neg = tf.concat([vmax[None], Zq], 0)[:-1] + # Extracts vmin and vmax and construct helper tensors from Zq. + vmin, vmax = Zq[0], Zq[-1] + d_pos = tf.concat([Zq, vmin[None]], 0)[1:] + d_neg = tf.concat([vmax[None], Zq], 0)[:-1] - # Clips Zp to be in new support range (vmin, vmax). - clipped_zp = tf.clip_by_value(Zp, vmin, vmax)[:, None, :] - clipped_zq = Zq[None, :, None] + # Clips Zp to be in new support range (vmin, vmax). + clipped_zp = tf.clip_by_value(Zp, vmin, vmax)[:, None, :] + clipped_zq = Zq[None, :, None] - # Gets the distance between atom values in support. - d_pos = (d_pos - Zq)[None, :, None] # Zq[i+1] - Zq[i] - d_neg = (Zq - d_neg)[None, :, None] # Zq[i] - Zq[i-1] + # Gets the distance between atom values in support. 
+ d_pos = (d_pos - Zq)[None, :, None] # Zq[i+1] - Zq[i] + d_neg = (Zq - d_neg)[None, :, None] # Zq[i] - Zq[i-1] - delta_qp = clipped_zp - clipped_zq # Zp[j] - Zq[i] + delta_qp = clipped_zp - clipped_zq # Zp[j] - Zq[i] - d_sign = tf.cast(delta_qp >= 0., dtype=P.dtype) - delta_hat = (d_sign * delta_qp / d_pos) - ((1. - d_sign) * delta_qp / d_neg) - P = P[:, None, :] - return tf.reduce_sum(tf.clip_by_value(1. - delta_hat, 0., 1.) * P, 2) + d_sign = tf.cast(delta_qp >= 0.0, dtype=P.dtype) + delta_hat = (d_sign * delta_qp / d_pos) - ((1.0 - d_sign) * delta_qp / d_neg) + P = P[:, None, :] + return tf.reduce_sum(tf.clip_by_value(1.0 - delta_hat, 0.0, 1.0) * P, 2) def multiaxis_categorical( # pylint: disable=invalid-name q_tm1: networks.DiscreteValuedDistribution, r_t: tf.Tensor, d_t: tf.Tensor, - q_t: networks.DiscreteValuedDistribution) -> tf.Tensor: - """Implements a multi-axis categorical distributional TD(0)-learning loss. + q_t: networks.DiscreteValuedDistribution, +) -> tf.Tensor: + """Implements a multi-axis categorical distributional TD(0)-learning loss. All arguments may have a leading batch axis, but q_tm1.logits, and one of r_t or d_t *must* have a leading batch axis. @@ -104,34 +106,31 @@ def multiaxis_categorical( # pylint: disable=invalid-name B is the batch size. E is the broadcasted shape of r_t, d_t, and q_t.values[:-1]. """ - tf.assert_equal(tf.rank(r_t), tf.rank(d_t)) + tf.assert_equal(tf.rank(r_t), tf.rank(d_t)) - # Append a singleton axis corresponding to the axis that indexes the atoms in - # q_t.values. - r_t = r_t[..., None] # shape: (B, *R, 1) - d_t = d_t[..., None] # shape: (B, *D, 1) + # Append a singleton axis corresponding to the axis that indexes the atoms in + # q_t.values. + r_t = r_t[..., None] # shape: (B, *R, 1) + d_t = d_t[..., None] # shape: (B, *D, 1) - z_t = r_t + d_t * q_t.values # shape: (B, *E, N) + z_t = r_t + d_t * q_t.values # shape: (B, *E, N) - p_t = tf.nn.softmax(q_t.logits) + p_t = tf.nn.softmax(q_t.logits) - # Performs L2 projection. - target = tf.stop_gradient(multiaxis_l2_project(z_t, p_t, q_t.values)) + # Performs L2 projection. + target = tf.stop_gradient(multiaxis_l2_project(z_t, p_t, q_t.values)) - # Calculates loss. - loss = tf.nn.softmax_cross_entropy_with_logits( - logits=q_tm1.logits, labels=target) + # Calculates loss. + loss = tf.nn.softmax_cross_entropy_with_logits(logits=q_tm1.logits, labels=target) - return loss + return loss # A modification of l2_project that allows multi-axis support arguments. def multiaxis_l2_project( # pylint: disable=invalid-name - Zp: tf.Tensor, - P: tf.Tensor, - Zq: tf.Tensor, + Zp: tf.Tensor, P: tf.Tensor, Zq: tf.Tensor, ) -> tf.Tensor: - """Project distribution (Zp, P) onto support Zq under the L2-metric over CDFs. + """Project distribution (Zp, P) onto support Zq under the L2-metric over CDFs. Let source support Zp's shape be described as (B, *C, M), where: B is the batch size. @@ -155,82 +154,87 @@ def multiaxis_l2_project( # pylint: disable=invalid-name Shape: (B, *E, N), where E is the broadcast-merged shape of C and D. """ - tf.assert_equal(tf.shape(Zp), tf.shape(P)) - - # Shapes C, D, and E as defined in the docstring above. - shape_c = tf.shape(Zp)[1:-1] # drop the batch and atom axes - shape_d = tf.shape(Zq)[:-1] # drop the atom axis - shape_e = tf.broadcast_dynamic_shape(shape_c, shape_d) - - # If Zq has fewer inner axes than the broadcasted output shape, insert some - # size-1 axes to broadcast. 
- ndim_c = tf.size(shape_c) - ndim_e = tf.size(shape_e) - Zp = tf.reshape( - Zp, - tf.concat([tf.shape(Zp)[:1], # B - tf.ones(tf.math.maximum(ndim_e - ndim_c, 0), dtype=tf.int32), - shape_c, # C - tf.shape(Zp)[-1:]], # M - axis=0)) - P = tf.reshape(P, tf.shape(Zp)) - - # Broadcast Zp, P, and Zq's common axes to the same shape: E. - # - # Normally it'd be sufficient to ensure that these args have the same number - # of axes, then let the arithmetic operators broadcast as necessary. Instead, - # we need to explicitly broadcast them here, because there's a call to - # tf.clip_by_value(t, vmin, vmax) below, which doesn't allow t's dimensions - # to be expanded to match vmin and vmax. - - # Shape: (B, *E, M) - Zp = tf.broadcast_to( - Zp, - tf.concat([tf.shape(Zp)[:1], # B - shape_e, # E - tf.shape(Zp)[-1:]], # M - axis=0)) - - # Shape: (B, *E, M) - P = tf.broadcast_to(P, tf.shape(Zp)) - - # Shape: (*E, N) - Zq = tf.broadcast_to(Zq, tf.concat([shape_e, tf.shape(Zq)[-1:]], axis=0)) - - # Extracts vmin and vmax and construct helper tensors from Zq. - # These have shape shape_q, except the last axis has size 1. - # Shape: (*E, 1) - vmin, vmax = Zq[..., :1], Zq[..., -1:] - - # The distances between neighboring atom values in the target support. - # Shape: (*E, N) - d_pos = tf.roll(Zq, shift=-1, axis=-1) - Zq # d_pos[i] := Zq[i+1] - Zq[i] - d_neg = Zq - tf.roll(Zq, shift=1, axis=-1) # d_neg[i] := Zq[i] - Zq[i-1] - - # Clips Zp to be in new support range (vmin, vmax). - # Shape: (B, *E, 1, M) - clipped_zp = tf.clip_by_value(Zp, vmin, vmax)[..., None, :] - - # Shape: (1, *E, N, 1) - clipped_zq = Zq[None, ..., :, None] - - # Shape: (B, *E, N, M) - delta_qp = clipped_zp - clipped_zq # Zp[j] - Zq[i] - - # Shape: (B, *E, N, M) - d_sign = tf.cast(delta_qp >= 0., dtype=P.dtype) - - # Insert singleton axes to d_pos and d_neg to maintain the same shape as - # clipped_zq. - # Shape: (1, *E, N, 1) - d_pos = d_pos[None, ..., :, None] - d_neg = d_neg[None, ..., :, None] - - # Shape: (B, *E, N, M) - delta_hat = (d_sign * delta_qp / d_pos) - ((1. - d_sign) * delta_qp / d_neg) - - # Shape: (B, *E, 1, M) - P = P[..., None, :] - - # Shape: (B, *E, N) - return tf.reduce_sum(tf.clip_by_value(1. - delta_hat, 0., 1.) * P, axis=-1) + tf.assert_equal(tf.shape(Zp), tf.shape(P)) + + # Shapes C, D, and E as defined in the docstring above. + shape_c = tf.shape(Zp)[1:-1] # drop the batch and atom axes + shape_d = tf.shape(Zq)[:-1] # drop the atom axis + shape_e = tf.broadcast_dynamic_shape(shape_c, shape_d) + + # If Zq has fewer inner axes than the broadcasted output shape, insert some + # size-1 axes to broadcast. + ndim_c = tf.size(shape_c) + ndim_e = tf.size(shape_e) + Zp = tf.reshape( + Zp, + tf.concat( + [ + tf.shape(Zp)[:1], # B + tf.ones(tf.math.maximum(ndim_e - ndim_c, 0), dtype=tf.int32), + shape_c, # C + tf.shape(Zp)[-1:], + ], # M + axis=0, + ), + ) + P = tf.reshape(P, tf.shape(Zp)) + + # Broadcast Zp, P, and Zq's common axes to the same shape: E. + # + # Normally it'd be sufficient to ensure that these args have the same number + # of axes, then let the arithmetic operators broadcast as necessary. Instead, + # we need to explicitly broadcast them here, because there's a call to + # tf.clip_by_value(t, vmin, vmax) below, which doesn't allow t's dimensions + # to be expanded to match vmin and vmax. 
+ + # Shape: (B, *E, M) + Zp = tf.broadcast_to( + Zp, + tf.concat( + [tf.shape(Zp)[:1], shape_e, tf.shape(Zp)[-1:]], axis=0 # B # E # M + ), + ) + + # Shape: (B, *E, M) + P = tf.broadcast_to(P, tf.shape(Zp)) + + # Shape: (*E, N) + Zq = tf.broadcast_to(Zq, tf.concat([shape_e, tf.shape(Zq)[-1:]], axis=0)) + + # Extracts vmin and vmax and construct helper tensors from Zq. + # These have shape shape_q, except the last axis has size 1. + # Shape: (*E, 1) + vmin, vmax = Zq[..., :1], Zq[..., -1:] + + # The distances between neighboring atom values in the target support. + # Shape: (*E, N) + d_pos = tf.roll(Zq, shift=-1, axis=-1) - Zq # d_pos[i] := Zq[i+1] - Zq[i] + d_neg = Zq - tf.roll(Zq, shift=1, axis=-1) # d_neg[i] := Zq[i] - Zq[i-1] + + # Clips Zp to be in new support range (vmin, vmax). + # Shape: (B, *E, 1, M) + clipped_zp = tf.clip_by_value(Zp, vmin, vmax)[..., None, :] + + # Shape: (1, *E, N, 1) + clipped_zq = Zq[None, ..., :, None] + + # Shape: (B, *E, N, M) + delta_qp = clipped_zp - clipped_zq # Zp[j] - Zq[i] + + # Shape: (B, *E, N, M) + d_sign = tf.cast(delta_qp >= 0.0, dtype=P.dtype) + + # Insert singleton axes to d_pos and d_neg to maintain the same shape as + # clipped_zq. + # Shape: (1, *E, N, 1) + d_pos = d_pos[None, ..., :, None] + d_neg = d_neg[None, ..., :, None] + + # Shape: (B, *E, N, M) + delta_hat = (d_sign * delta_qp / d_pos) - ((1.0 - d_sign) * delta_qp / d_neg) + + # Shape: (B, *E, 1, M) + P = P[..., None, :] + + # Shape: (B, *E, N) + return tf.reduce_sum(tf.clip_by_value(1.0 - delta_hat, 0.0, 1.0) * P, axis=-1) diff --git a/acme/tf/losses/distributional_test.py b/acme/tf/losses/distributional_test.py index 3a368c07f0..888f66fb29 100644 --- a/acme/tf/losses/distributional_test.py +++ b/acme/tf/losses/distributional_test.py @@ -14,17 +14,16 @@ """Tests for acme.tf.losses.distributional.""" -from acme.tf.losses import distributional import numpy as np -from numpy import testing as npt import tensorflow as tf +from absl.testing import absltest, parameterized +from numpy import testing as npt -from absl.testing import absltest -from absl.testing import parameterized +from acme.tf.losses import distributional def _reference_l2_project(src_support, src_probs, dst_support): - """Multi-axis l2_project, implemented using single-axis l2_project. + """Multi-axis l2_project, implemented using single-axis l2_project. This is for testing multiaxis_l2_project's consistency with l2_project, when used with multi-axis support vs single-axis support. @@ -37,67 +36,62 @@ def _reference_l2_project(src_support, src_probs, dst_support): Returns: src_probs, projected onto dst_support. """ - assert src_support.shape == src_probs.shape + assert src_support.shape == src_probs.shape - # Remove the batch and value axes, and broadcast the rest to a common shape. - common_shape = np.broadcast(src_support[0, ..., 0], - dst_support[..., 0]).shape + # Remove the batch and value axes, and broadcast the rest to a common shape. + common_shape = np.broadcast(src_support[0, ..., 0], dst_support[..., 0]).shape - # If src_* have fewer internal axes than len(common_shape), insert size-1 - # axes. - while src_support.ndim-2 < len(common_shape): - src_support = src_support[:, None, ...] + # If src_* have fewer internal axes than len(common_shape), insert size-1 + # axes. + while src_support.ndim - 2 < len(common_shape): + src_support = src_support[:, None, ...] 
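# Illustrative sketch (not part of the patch): what the single-axis l2_project
# reformatted above computes -- probability mass is moved onto the target
# support by linear interpolation between neighbouring atoms. The expected
# output matches the hand-computed case used in the tests below.
import tensorflow as tf

from acme.tf.losses import distributional

src_support = tf.constant([[0.0, 1.0, 2.0, 3.0]])    # shape [B=1, M=4]
src_probs = tf.constant([[0.0, 1.0, 0.0, 0.0]])      # all mass on the atom z=1.0
dst_support = tf.constant([0.25, 1.25, 2.25, 3.25])  # shifted support, shape [N=4]

projected = distributional.l2_project(src_support, src_probs, dst_support)
# z=1.0 sits between 0.25 and 1.25, so the unit mass splits 0.25 / 0.75:
# projected ~= [[0.25, 0.75, 0.0, 0.0]]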
- src_probs = np.reshape(src_probs, src_support.shape) + src_probs = np.reshape(src_probs, src_support.shape) - # Broadcast args' non-batch, non-value axes to common_shape. - src_support = np.broadcast_to( - src_support, - src_support.shape[:1] + common_shape + src_support.shape[-1:]) - src_probs = np.broadcast_to(src_probs, src_support.shape) - dst_support = np.broadcast_to( - dst_support, - common_shape + dst_support.shape[-1:]) + # Broadcast args' non-batch, non-value axes to common_shape. + src_support = np.broadcast_to( + src_support, src_support.shape[:1] + common_shape + src_support.shape[-1:] + ) + src_probs = np.broadcast_to(src_probs, src_support.shape) + dst_support = np.broadcast_to(dst_support, common_shape + dst_support.shape[-1:]) - output_shape = (src_support.shape[0],) + dst_support.shape + output_shape = (src_support.shape[0],) + dst_support.shape - # Collapse all but the first (batch) and last (atom) axes. - src_support = src_support.reshape( - [src_support.shape[0], -1, src_support.shape[-1]]) - src_probs = src_probs.reshape( - [src_probs.shape[0], -1, src_probs.shape[-1]]) + # Collapse all but the first (batch) and last (atom) axes. + src_support = src_support.reshape([src_support.shape[0], -1, src_support.shape[-1]]) + src_probs = src_probs.reshape([src_probs.shape[0], -1, src_probs.shape[-1]]) - # Collapse all but the last (atom) axes. - dst_support = dst_support.reshape([-1, dst_support.shape[-1]]) + # Collapse all but the last (atom) axes. + dst_support = dst_support.reshape([-1, dst_support.shape[-1]]) - dst_probs = np.zeros(src_support.shape[:1] + dst_support.shape, - dtype=src_probs.dtype) + dst_probs = np.zeros( + src_support.shape[:1] + dst_support.shape, dtype=src_probs.dtype + ) - # iterate over all supports - for i in range(src_support.shape[1]): - s_support = tf.convert_to_tensor(src_support[:, i, :]) - s_probs = tf.convert_to_tensor(src_probs[:, i, :]) - d_support = tf.convert_to_tensor(dst_support[i, :]) - d_probs = distributional.l2_project(s_support, s_probs, d_support) - dst_probs[:, i, :] = d_probs.numpy() + # iterate over all supports + for i in range(src_support.shape[1]): + s_support = tf.convert_to_tensor(src_support[:, i, :]) + s_probs = tf.convert_to_tensor(src_probs[:, i, :]) + d_support = tf.convert_to_tensor(dst_support[i, :]) + d_probs = distributional.l2_project(s_support, s_probs, d_support) + dst_probs[:, i, :] = d_probs.numpy() - return dst_probs.reshape(output_shape) + return dst_probs.reshape(output_shape) class L2ProjectTest(parameterized.TestCase): - - @parameterized.parameters( - [(2, 11), (11,)], # C = (), D = (), matching num_atoms (11 and 11) - [(2, 11), (5,)], # C = (), D = (), differing num_atoms (11 and 5). - [(2, 3, 11), (3, 5)], # C = (3,), D = (3,) - [(2, 1, 11), (3, 5)], # C = (1,), D = (3,) - [(2, 3, 11), (1, 5)], # (C = (3,), D = (1,) - [(2, 3, 4, 11), (3, 4, 5)], # C = (3, 4), D = (3, 4) - [(2, 3, 4, 11), (4, 5)], # C = (3, 4), D = (4,) - [(2, 4, 11), (3, 4, 5)], # C = (4,), D = (3, 4) - ) - def test_multiaxis(self, src_shape, dst_shape): - """Tests consistency between multi-axis and single-axis l2_project. + @parameterized.parameters( + [(2, 11), (11,)], # C = (), D = (), matching num_atoms (11 and 11) + [(2, 11), (5,)], # C = (), D = (), differing num_atoms (11 and 5). 
+ [(2, 3, 11), (3, 5)], # C = (3,), D = (3,) + [(2, 1, 11), (3, 5)], # C = (1,), D = (3,) + [(2, 3, 11), (1, 5)], # (C = (3,), D = (1,) + [(2, 3, 4, 11), (3, 4, 5)], # C = (3, 4), D = (3, 4) + [(2, 3, 4, 11), (4, 5)], # C = (3, 4), D = (4,) + [(2, 4, 11), (3, 4, 5)], # C = (4,), D = (3, 4) + ) + def test_multiaxis(self, src_shape, dst_shape): + """Tests consistency between multi-axis and single-axis l2_project. This calls l2_project on multi-axis supports, and checks that it gets the same outcomes as many calls to single-axis supports. @@ -107,71 +101,76 @@ def test_multiaxis(self, src_shape, dst_shape): dst_shape: Shape of destination support. Does not include a leading batch axis. """ - # src_shape includes a leading batch axis, whereas dst_shape does not. - # assert len(src_shape) >= (1 + len(dst_shape)) - - def make_support(shape, minimum): - """Creates a ndarray of supports.""" - values = np.linspace(start=minimum, stop=minimum+100, num=shape[-1]) - offsets = np.arange(np.prod(shape[:-1])) - result = values[None, :] + offsets[:, None] - return result.reshape(shape) - - src_support = make_support(src_shape, -1) - dst_support = make_support(dst_shape, -.75) - - rng = np.random.RandomState(1) - src_probs = rng.uniform(low=1.0, high=2.0, size=src_shape) - src_probs /= src_probs.sum() - - # Repeated calls to l2_project using single-axis supports. - expected_dst_probs = _reference_l2_project(src_support, - src_probs, - dst_support) - - # A single call to l2_project, with multi-axis supports. - dst_probs = distributional.multiaxis_l2_project( - tf.convert_to_tensor(src_support), - tf.convert_to_tensor(src_probs), - tf.convert_to_tensor(dst_support)).numpy() - - npt.assert_allclose(dst_probs, expected_dst_probs) - - @parameterized.parameters( - # Same src and dst support shape, dst support is shifted by +.25 - ([[0., 1, 2, 3]], - [[0., 1, 0, 0]], - [.25, 1.25, 2.25, 3.25], - [[.25, .75, 0, 0]]), - # Similar to above, but with batched src. - ([[0., 1, 2, 3], - [0., 1, 2, 3]], - [[0., 1, 0, 0], - [0., 0, 1, 0]], - [.25, 1.25, 2.25, 3.25], - [[.25, .75, 0, 0], - [0., .25, .75, 0]]), - # Similar to above, but src_probs has two 0.5's, instead of being one-hot. - ([[0., 1, 2, 3]], - [[0., .5, .5, 0]], - [.25, 1.25, 2.25, 3.25], - 0.5 * (np.array([[.25, .75, 0, 0]]) + np.array([[0., .25, .75, 0]]))), - # src and dst support have differing sizes - ([[0., 1, 2, 3]], - [[0., 1, 0, 0]], - [0.00, 0.25, 0.50, 0.75, 1.00, 1.25, 1.50, 1.75, 2.00, 2.25, 2.50], - [[0.00, 0.00, 0.00, 0.00, 1.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00]]), - ) - def test_l2_projection( - self, src_support, src_probs, dst_support, expected_dst_probs): - - dst_probs = distributional.multiaxis_l2_project( - tf.convert_to_tensor(src_support), - tf.convert_to_tensor(src_probs), - tf.convert_to_tensor(dst_support)).numpy() - npt.assert_allclose(dst_probs, expected_dst_probs) - - -if __name__ == '__main__': - absltest.main() - + # src_shape includes a leading batch axis, whereas dst_shape does not. 
+ # assert len(src_shape) >= (1 + len(dst_shape)) + + def make_support(shape, minimum): + """Creates a ndarray of supports.""" + values = np.linspace(start=minimum, stop=minimum + 100, num=shape[-1]) + offsets = np.arange(np.prod(shape[:-1])) + result = values[None, :] + offsets[:, None] + return result.reshape(shape) + + src_support = make_support(src_shape, -1) + dst_support = make_support(dst_shape, -0.75) + + rng = np.random.RandomState(1) + src_probs = rng.uniform(low=1.0, high=2.0, size=src_shape) + src_probs /= src_probs.sum() + + # Repeated calls to l2_project using single-axis supports. + expected_dst_probs = _reference_l2_project(src_support, src_probs, dst_support) + + # A single call to l2_project, with multi-axis supports. + dst_probs = distributional.multiaxis_l2_project( + tf.convert_to_tensor(src_support), + tf.convert_to_tensor(src_probs), + tf.convert_to_tensor(dst_support), + ).numpy() + + npt.assert_allclose(dst_probs, expected_dst_probs) + + @parameterized.parameters( + # Same src and dst support shape, dst support is shifted by +.25 + ( + [[0.0, 1, 2, 3]], + [[0.0, 1, 0, 0]], + [0.25, 1.25, 2.25, 3.25], + [[0.25, 0.75, 0, 0]], + ), + # Similar to above, but with batched src. + ( + [[0.0, 1, 2, 3], [0.0, 1, 2, 3]], + [[0.0, 1, 0, 0], [0.0, 0, 1, 0]], + [0.25, 1.25, 2.25, 3.25], + [[0.25, 0.75, 0, 0], [0.0, 0.25, 0.75, 0]], + ), + # Similar to above, but src_probs has two 0.5's, instead of being one-hot. + ( + [[0.0, 1, 2, 3]], + [[0.0, 0.5, 0.5, 0]], + [0.25, 1.25, 2.25, 3.25], + 0.5 * (np.array([[0.25, 0.75, 0, 0]]) + np.array([[0.0, 0.25, 0.75, 0]])), + ), + # src and dst support have differing sizes + ( + [[0.0, 1, 2, 3]], + [[0.0, 1, 0, 0]], + [0.00, 0.25, 0.50, 0.75, 1.00, 1.25, 1.50, 1.75, 2.00, 2.25, 2.50], + [[0.00, 0.00, 0.00, 0.00, 1.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00]], + ), + ) + def test_l2_projection( + self, src_support, src_probs, dst_support, expected_dst_probs + ): + + dst_probs = distributional.multiaxis_l2_project( + tf.convert_to_tensor(src_support), + tf.convert_to_tensor(src_probs), + tf.convert_to_tensor(dst_support), + ).numpy() + npt.assert_allclose(dst_probs, expected_dst_probs) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/tf/losses/dpg.py b/acme/tf/losses/dpg.py index e268e45a90..8b2adb9ac3 100644 --- a/acme/tf/losses/dpg.py +++ b/acme/tf/losses/dpg.py @@ -15,6 +15,7 @@ """Losses for Deterministic Policy Gradients.""" from typing import Optional + import tensorflow as tf @@ -25,35 +26,36 @@ def dpg( dqda_clipping: Optional[float] = None, clip_norm: bool = False, ) -> tf.Tensor: - """Deterministic policy gradient loss, similar to trfl.dpg.""" - - # Calculate the gradient dq/da. - dqda = tape.gradient([q_max], [a_max])[0] - - if dqda is None: - raise ValueError('q_max needs to be a function of a_max.') - - # Clipping the gradient dq/da. - if dqda_clipping is not None: - if dqda_clipping <= 0: - raise ValueError('dqda_clipping should be bigger than 0, {} found'.format( - dqda_clipping)) - if clip_norm: - dqda = tf.clip_by_norm(dqda, dqda_clipping, axes=-1) - else: - dqda = tf.clip_by_value(dqda, -1. * dqda_clipping, dqda_clipping) - - # Target_a ensures correct gradient calculated during backprop. - target_a = dqda + a_max - # Stop the gradient going through Q network when backprop. - target_a = tf.stop_gradient(target_a) - # Gradient only go through actor network. 
- loss = 0.5 * tf.reduce_sum(tf.square(target_a - a_max), axis=-1) - # This recovers the DPG because (letting w be the actor network weights): - # d(loss)/dw = 0.5 * (2 * (target_a - a_max) * d(target_a - a_max)/dw) - # = (target_a - a_max) * [d(target_a)/dw - d(a_max)/dw] - # = dq/da * [d(target_a)/dw - d(a_max)/dw] # by defn of target_a - # = dq/da * [0 - d(a_max)/dw] # by stop_gradient - # = - dq/da * da/dw - - return loss + """Deterministic policy gradient loss, similar to trfl.dpg.""" + + # Calculate the gradient dq/da. + dqda = tape.gradient([q_max], [a_max])[0] + + if dqda is None: + raise ValueError("q_max needs to be a function of a_max.") + + # Clipping the gradient dq/da. + if dqda_clipping is not None: + if dqda_clipping <= 0: + raise ValueError( + "dqda_clipping should be bigger than 0, {} found".format(dqda_clipping) + ) + if clip_norm: + dqda = tf.clip_by_norm(dqda, dqda_clipping, axes=-1) + else: + dqda = tf.clip_by_value(dqda, -1.0 * dqda_clipping, dqda_clipping) + + # Target_a ensures correct gradient calculated during backprop. + target_a = dqda + a_max + # Stop the gradient going through Q network when backprop. + target_a = tf.stop_gradient(target_a) + # Gradient only go through actor network. + loss = 0.5 * tf.reduce_sum(tf.square(target_a - a_max), axis=-1) + # This recovers the DPG because (letting w be the actor network weights): + # d(loss)/dw = 0.5 * (2 * (target_a - a_max) * d(target_a - a_max)/dw) + # = (target_a - a_max) * [d(target_a)/dw - d(a_max)/dw] + # = dq/da * [d(target_a)/dw - d(a_max)/dw] # by defn of target_a + # = dq/da * [0 - d(a_max)/dw] # by stop_gradient + # = - dq/da * da/dw + + return loss diff --git a/acme/tf/losses/huber.py b/acme/tf/losses/huber.py index 5d97c24327..e9e26f4b6c 100644 --- a/acme/tf/losses/huber.py +++ b/acme/tf/losses/huber.py @@ -18,7 +18,7 @@ def huber(inputs: tf.Tensor, quadratic_linear_boundary: float) -> tf.Tensor: - """Calculates huber loss of `inputs`. + """Calculates huber loss of `inputs`. For each value x in `inputs`, the following is calculated: @@ -41,16 +41,16 @@ def huber(inputs: tf.Tensor, quadratic_linear_boundary: float) -> tf.Tensor: Raises: ValueError: if quadratic_linear_boundary < 0. """ - if quadratic_linear_boundary < 0: - raise ValueError("quadratic_linear_boundary must be >= 0.") - - abs_x = tf.abs(inputs) - delta = tf.constant(quadratic_linear_boundary) - quad = tf.minimum(abs_x, delta) - # The following expression is the same in value as - # tf.maximum(abs_x - delta, 0), but importantly the gradient for the - # expression when abs_x == delta is 0 (for tf.maximum it would be 1). This - # is necessary to avoid doubling the gradient, since there is already a - # nonzero contribution to the gradient from the quadratic term. - lin = (abs_x - quad) - return 0.5 * quad**2 + delta * lin + if quadratic_linear_boundary < 0: + raise ValueError("quadratic_linear_boundary must be >= 0.") + + abs_x = tf.abs(inputs) + delta = tf.constant(quadratic_linear_boundary) + quad = tf.minimum(abs_x, delta) + # The following expression is the same in value as + # tf.maximum(abs_x - delta, 0), but importantly the gradient for the + # expression when abs_x == delta is 0 (for tf.maximum it would be 1). This + # is necessary to avoid doubling the gradient, since there is already a + # nonzero contribution to the gradient from the quadratic term. 
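# Illustrative sketch (not part of the patch): how the dpg loss reformatted
# above is typically driven. Both the action and its Q-value are computed under
# a persistent GradientTape, which dpg() then uses to form dq/da. The toy actor
# and critic below are hypothetical stand-ins.
import tensorflow as tf

from acme.tf.losses import dpg as dpg_lib

w = tf.Variable(tf.ones([3, 2]))  # toy actor weights


def policy(o):
    return tf.matmul(o, w)  # a_max = policy(o)


def critic(o, a):
    return -tf.reduce_sum(tf.square(a), axis=-1)  # q_max = critic(o, a)


o_t = tf.random.normal([4, 3])  # batch of observations
with tf.GradientTape(persistent=True) as tape:
    a_max = policy(o_t)
    q_max = critic(o_t, a_max)
    policy_loss = dpg_lib.dpg(q_max, a_max, tape=tape, dqda_clipping=1.0)

# Differentiating the surrogate loss w.r.t. the actor weights recovers
# -dq/da * da/dw, as derived in the comments above.
policy_grads = tape.gradient(policy_loss, [w])
del tape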
+ lin = abs_x - quad + return 0.5 * quad ** 2 + delta * lin diff --git a/acme/tf/losses/mompo.py b/acme/tf/losses/mompo.py index 591a27861b..06c7fb4500 100644 --- a/acme/tf/losses/mompo.py +++ b/acme/tf/losses/mompo.py @@ -32,11 +32,12 @@ import dataclasses from typing import Dict, Sequence, Tuple, Union -from acme.tf.losses import mpo import sonnet as snt import tensorflow as tf import tensorflow_probability as tfp +from acme.tf.losses import mpo + tfd = tfp.distributions _MPO_FLOAT_EPSILON = 1e-8 @@ -44,18 +45,18 @@ @dataclasses.dataclass class KLConstraint: - """Defines a per-objective policy improvement step constraint for MO-MPO.""" + """Defines a per-objective policy improvement step constraint for MO-MPO.""" - name: str - value: float + name: str + value: float - def __post_init__(self): - if self.value < 0: - raise ValueError("KL constraint epsilon must be non-negative.") + def __post_init__(self): + if self.value < 0: + raise ValueError("KL constraint epsilon must be non-negative.") class MultiObjectiveMPO(snt.Module): - """Multi-objective MPO loss with decoupled KL constraints. + """Multi-objective MPO loss with decoupled KL constraints. This implementation of the MO-MPO loss is based on the approach proposed in (Abdolmaleki, Huang et al., 2020). The following features are included as @@ -65,16 +66,18 @@ class MultiObjectiveMPO(snt.Module): (Abdolmaleki, Huang et al., 2020): https://arxiv.org/pdf/2005.07513.pdf """ - def __init__(self, - epsilons: Sequence[KLConstraint], - epsilon_mean: float, - epsilon_stddev: float, - init_log_temperature: float, - init_log_alpha_mean: float, - init_log_alpha_stddev: float, - per_dim_constraining: bool = True, - name: str = "MOMPO"): - """Initialize and configure the MPO loss. + def __init__( + self, + epsilons: Sequence[KLConstraint], + epsilon_mean: float, + epsilon_stddev: float, + init_log_temperature: float, + init_log_alpha_mean: float, + init_log_alpha_stddev: float, + per_dim_constraining: bool = True, + name: str = "MOMPO", + ): + """Initialize and configure the MPO loss. Args: epsilons: per-objective KL constraints on the non-parametric auxiliary @@ -97,69 +100,70 @@ def __init__(self, name: a name for the module, passed directly to snt.Module. """ - super().__init__(name=name) - - # MO-MPO constraint thresholds. - self._epsilons = tf.constant([x.value for x in epsilons]) - self._epsilon_mean = tf.constant(epsilon_mean) - self._epsilon_stddev = tf.constant(epsilon_stddev) - - # Initial values for the constraints' dual variables. - self._init_log_temperature = init_log_temperature - self._init_log_alpha_mean = init_log_alpha_mean - self._init_log_alpha_stddev = init_log_alpha_stddev - - # Whether to ensure per-dimension KL constraint satisfication. - self._per_dim_constraining = per_dim_constraining - - # Remember the number of objectives - self._num_objectives = len(epsilons) # K = number of objectives - self._objective_names = [x.name for x in epsilons] - - # Make sure there are no duplicate objective names - if len(self._objective_names) != len(set(self._objective_names)): - raise ValueError("Duplicate objective names are not allowed.") - - @property - def objective_names(self): - return self._objective_names - - @snt.once - def create_dual_variables_once(self, shape: tf.TensorShape, dtype: tf.DType): - """Creates the dual variables the first time the loss module is called.""" - - # Create the dual variables. 
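# Illustrative sketch (not part of the patch): a quick numeric check of the
# huber loss reformatted earlier in this diff -- quadratic inside the boundary,
# linear beyond it.
import tensorflow as tf

from acme.tf.losses.huber import huber

x = tf.constant([0.5, 3.0])
# 0.5 * 0.5**2 = 0.125 inside the boundary; 0.5 * 1**2 + 1 * (3 - 1) = 2.5 beyond.
print(huber(x, quadratic_linear_boundary=1.0))  # ~[0.125, 2.5]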
- self._log_temperature = tf.Variable( - initial_value=[self._init_log_temperature] * self._num_objectives, - dtype=dtype, - name="log_temperature", - shape=(self._num_objectives,)) - self._log_alpha_mean = tf.Variable( - initial_value=tf.fill(shape, self._init_log_alpha_mean), - dtype=dtype, - name="log_alpha_mean", - shape=shape) - self._log_alpha_stddev = tf.Variable( - initial_value=tf.fill(shape, self._init_log_alpha_stddev), - dtype=dtype, - name="log_alpha_stddev", - shape=shape) - - # Cast constraint thresholds to the expected dtype. - self._epsilons = tf.cast(self._epsilons, dtype) - self._epsilon_mean = tf.cast(self._epsilon_mean, dtype) - self._epsilon_stddev = tf.cast(self._epsilon_stddev, dtype) - - def __call__( - self, - online_action_distribution: Union[tfd.MultivariateNormalDiag, - tfd.Independent], - target_action_distribution: Union[tfd.MultivariateNormalDiag, - tfd.Independent], - actions: tf.Tensor, # Shape [N, B, D]. - q_values: tf.Tensor, # Shape [N, B, K]. - ) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]: - """Computes the decoupled MO-MPO loss. + super().__init__(name=name) + + # MO-MPO constraint thresholds. + self._epsilons = tf.constant([x.value for x in epsilons]) + self._epsilon_mean = tf.constant(epsilon_mean) + self._epsilon_stddev = tf.constant(epsilon_stddev) + + # Initial values for the constraints' dual variables. + self._init_log_temperature = init_log_temperature + self._init_log_alpha_mean = init_log_alpha_mean + self._init_log_alpha_stddev = init_log_alpha_stddev + + # Whether to ensure per-dimension KL constraint satisfication. + self._per_dim_constraining = per_dim_constraining + + # Remember the number of objectives + self._num_objectives = len(epsilons) # K = number of objectives + self._objective_names = [x.name for x in epsilons] + + # Make sure there are no duplicate objective names + if len(self._objective_names) != len(set(self._objective_names)): + raise ValueError("Duplicate objective names are not allowed.") + + @property + def objective_names(self): + return self._objective_names + + @snt.once + def create_dual_variables_once(self, shape: tf.TensorShape, dtype: tf.DType): + """Creates the dual variables the first time the loss module is called.""" + + # Create the dual variables. + self._log_temperature = tf.Variable( + initial_value=[self._init_log_temperature] * self._num_objectives, + dtype=dtype, + name="log_temperature", + shape=(self._num_objectives,), + ) + self._log_alpha_mean = tf.Variable( + initial_value=tf.fill(shape, self._init_log_alpha_mean), + dtype=dtype, + name="log_alpha_mean", + shape=shape, + ) + self._log_alpha_stddev = tf.Variable( + initial_value=tf.fill(shape, self._init_log_alpha_stddev), + dtype=dtype, + name="log_alpha_stddev", + shape=shape, + ) + + # Cast constraint thresholds to the expected dtype. + self._epsilons = tf.cast(self._epsilons, dtype) + self._epsilon_mean = tf.cast(self._epsilon_mean, dtype) + self._epsilon_stddev = tf.cast(self._epsilon_stddev, dtype) + + def __call__( + self, + online_action_distribution: Union[tfd.MultivariateNormalDiag, tfd.Independent], + target_action_distribution: Union[tfd.MultivariateNormalDiag, tfd.Independent], + actions: tf.Tensor, # Shape [N, B, D]. + q_values: tf.Tensor, # Shape [N, B, K]. + ) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]: + """Computes the decoupled MO-MPO loss. Args: online_action_distribution: online distribution returned by the online @@ -175,149 +179,180 @@ def __call__( Stats, for diagnostics and tracking performance. 
""" - # Make sure the Q-values are per-objective - q_values.get_shape().assert_has_rank(3) - if q_values.get_shape()[-1] != self._num_objectives: - raise ValueError("Q-values do not match expected number of objectives.") - - # Cast `MultivariateNormalDiag`s to Independent Normals. - # The latter allows us to satisfy KL constraints per-dimension. - if isinstance(target_action_distribution, tfd.MultivariateNormalDiag): - target_action_distribution = tfd.Independent( - tfd.Normal(target_action_distribution.mean(), - target_action_distribution.stddev())) - online_action_distribution = tfd.Independent( - tfd.Normal(online_action_distribution.mean(), - online_action_distribution.stddev())) - - # Infer the shape and dtype of dual variables. - scalar_dtype = q_values.dtype - if self._per_dim_constraining: - dual_variable_shape = target_action_distribution.distribution.kl_divergence( - online_action_distribution.distribution).shape[1:] # Should be [D]. - else: - dual_variable_shape = target_action_distribution.kl_divergence( - online_action_distribution).shape[1:] # Should be [1]. - - # Create dual variables for the KL constraints; only happens the first call. - self.create_dual_variables_once(dual_variable_shape, scalar_dtype) - - # Project dual variables to ensure they stay positive. - min_log_temperature = tf.constant(-18.0, scalar_dtype) - min_log_alpha = tf.constant(-18.0, scalar_dtype) - self._log_temperature.assign( - tf.maximum(min_log_temperature, self._log_temperature)) - self._log_alpha_mean.assign(tf.maximum(min_log_alpha, self._log_alpha_mean)) - self._log_alpha_stddev.assign( - tf.maximum(min_log_alpha, self._log_alpha_stddev)) - - # Transform dual variables from log-space. - # Note: using softplus instead of exponential for numerical stability. - temperature = tf.math.softplus(self._log_temperature) + _MPO_FLOAT_EPSILON - alpha_mean = tf.math.softplus(self._log_alpha_mean) + _MPO_FLOAT_EPSILON - alpha_stddev = tf.math.softplus(self._log_alpha_stddev) + _MPO_FLOAT_EPSILON - - # Get online and target means and stddevs in preparation for decomposition. - online_mean = online_action_distribution.distribution.mean() - online_scale = online_action_distribution.distribution.stddev() - target_mean = target_action_distribution.distribution.mean() - target_scale = target_action_distribution.distribution.stddev() - - # Compute normalized importance weights, used to compute expectations with - # respect to the non-parametric policy; and the temperature loss, used to - # adapt the tempering of Q-values. - normalized_weights, loss_temperature = compute_weights_and_temperature_loss( - q_values, self._epsilons, temperature) # Shapes [N, B, K] and [1, K]. - normalized_weights_sum = tf.reduce_sum(normalized_weights, axis=-1) - loss_temperature_mean = tf.reduce_mean(loss_temperature) - - # Only needed for diagnostics: Compute estimated actualized KL between the - # non-parametric and current target policies. - kl_nonparametric = mpo.compute_nonparametric_kl_from_normalized_weights( - normalized_weights) - - # Decompose the online policy into fixed-mean & fixed-stddev distributions. - # This has been documented as having better performance in bandit settings, - # see e.g. https://arxiv.org/pdf/1812.02256.pdf. - fixed_stddev_distribution = tfd.Independent( - tfd.Normal(loc=online_mean, scale=target_scale)) - fixed_mean_distribution = tfd.Independent( - tfd.Normal(loc=target_mean, scale=online_scale)) - - # Compute the decomposed policy losses. 
- loss_policy_mean = mpo.compute_cross_entropy_loss( - actions, normalized_weights_sum, fixed_stddev_distribution) - loss_policy_stddev = mpo.compute_cross_entropy_loss( - actions, normalized_weights_sum, fixed_mean_distribution) - - # Compute the decomposed KL between the target and online policies. - if self._per_dim_constraining: - kl_mean = target_action_distribution.distribution.kl_divergence( - fixed_stddev_distribution.distribution) # Shape [B, D]. - kl_stddev = target_action_distribution.distribution.kl_divergence( - fixed_mean_distribution.distribution) # Shape [B, D]. - else: - kl_mean = target_action_distribution.kl_divergence( - fixed_stddev_distribution) # Shape [B]. - kl_stddev = target_action_distribution.kl_divergence( - fixed_mean_distribution) # Shape [B]. - - # Compute the alpha-weighted KL-penalty and dual losses to adapt the alphas. - loss_kl_mean, loss_alpha_mean = mpo.compute_parametric_kl_penalty_and_dual_loss( - kl_mean, alpha_mean, self._epsilon_mean) - loss_kl_stddev, loss_alpha_stddev = mpo.compute_parametric_kl_penalty_and_dual_loss( - kl_stddev, alpha_stddev, self._epsilon_stddev) - - # Combine losses. - loss_policy = loss_policy_mean + loss_policy_stddev - loss_kl_penalty = loss_kl_mean + loss_kl_stddev - loss_dual = loss_alpha_mean + loss_alpha_stddev + loss_temperature_mean - loss = loss_policy + loss_kl_penalty + loss_dual - - stats = {} - # Dual Variables. - stats["dual_alpha_mean"] = tf.reduce_mean(alpha_mean) - stats["dual_alpha_stddev"] = tf.reduce_mean(alpha_stddev) - # Losses. - stats["loss_policy"] = tf.reduce_mean(loss) - stats["loss_alpha"] = tf.reduce_mean(loss_alpha_mean + loss_alpha_stddev) - # KL measurements. - stats["kl_mean_rel"] = tf.reduce_mean(kl_mean, axis=0) / self._epsilon_mean - stats["kl_stddev_rel"] = tf.reduce_mean( - kl_stddev, axis=0) / self._epsilon_stddev - # If the policy has standard deviation, log summary stats for this as well. - pi_stddev = online_action_distribution.distribution.stddev() - stats["pi_stddev_min"] = tf.reduce_mean(tf.reduce_min(pi_stddev, axis=-1)) - stats["pi_stddev_max"] = tf.reduce_mean(tf.reduce_max(pi_stddev, axis=-1)) - - # Condition number of the diagonal covariance (actually, stddev) matrix. - stats["pi_stddev_cond"] = tf.reduce_mean( - tf.reduce_max(pi_stddev, axis=-1) / tf.reduce_min(pi_stddev, axis=-1)) - - # Log per-objective values. - for i, name in enumerate(self._objective_names): - stats["{}_dual_temperature".format(name)] = temperature[i] - stats["{}_loss_temperature".format(name)] = loss_temperature[i] - stats["{}_kl_q_rel".format(name)] = tf.reduce_mean( - kl_nonparametric[:, i]) / self._epsilons[i] - - # Q measurements. - stats["{}_q_min".format(name)] = tf.reduce_mean(tf.reduce_min( - q_values, axis=0)[:, i]) - stats["{}_q_mean".format(name)] = tf.reduce_mean(tf.reduce_mean( - q_values, axis=0)[:, i]) - stats["{}_q_max".format(name)] = tf.reduce_mean(tf.reduce_max( - q_values, axis=0)[:, i]) - - return loss, stats + # Make sure the Q-values are per-objective + q_values.get_shape().assert_has_rank(3) + if q_values.get_shape()[-1] != self._num_objectives: + raise ValueError("Q-values do not match expected number of objectives.") + + # Cast `MultivariateNormalDiag`s to Independent Normals. + # The latter allows us to satisfy KL constraints per-dimension. 
+ if isinstance(target_action_distribution, tfd.MultivariateNormalDiag): + target_action_distribution = tfd.Independent( + tfd.Normal( + target_action_distribution.mean(), + target_action_distribution.stddev(), + ) + ) + online_action_distribution = tfd.Independent( + tfd.Normal( + online_action_distribution.mean(), + online_action_distribution.stddev(), + ) + ) + + # Infer the shape and dtype of dual variables. + scalar_dtype = q_values.dtype + if self._per_dim_constraining: + dual_variable_shape = target_action_distribution.distribution.kl_divergence( + online_action_distribution.distribution + ).shape[ + 1: + ] # Should be [D]. + else: + dual_variable_shape = target_action_distribution.kl_divergence( + online_action_distribution + ).shape[ + 1: + ] # Should be [1]. + + # Create dual variables for the KL constraints; only happens the first call. + self.create_dual_variables_once(dual_variable_shape, scalar_dtype) + + # Project dual variables to ensure they stay positive. + min_log_temperature = tf.constant(-18.0, scalar_dtype) + min_log_alpha = tf.constant(-18.0, scalar_dtype) + self._log_temperature.assign( + tf.maximum(min_log_temperature, self._log_temperature) + ) + self._log_alpha_mean.assign(tf.maximum(min_log_alpha, self._log_alpha_mean)) + self._log_alpha_stddev.assign(tf.maximum(min_log_alpha, self._log_alpha_stddev)) + + # Transform dual variables from log-space. + # Note: using softplus instead of exponential for numerical stability. + temperature = tf.math.softplus(self._log_temperature) + _MPO_FLOAT_EPSILON + alpha_mean = tf.math.softplus(self._log_alpha_mean) + _MPO_FLOAT_EPSILON + alpha_stddev = tf.math.softplus(self._log_alpha_stddev) + _MPO_FLOAT_EPSILON + + # Get online and target means and stddevs in preparation for decomposition. + online_mean = online_action_distribution.distribution.mean() + online_scale = online_action_distribution.distribution.stddev() + target_mean = target_action_distribution.distribution.mean() + target_scale = target_action_distribution.distribution.stddev() + + # Compute normalized importance weights, used to compute expectations with + # respect to the non-parametric policy; and the temperature loss, used to + # adapt the tempering of Q-values. + normalized_weights, loss_temperature = compute_weights_and_temperature_loss( + q_values, self._epsilons, temperature + ) # Shapes [N, B, K] and [1, K]. + normalized_weights_sum = tf.reduce_sum(normalized_weights, axis=-1) + loss_temperature_mean = tf.reduce_mean(loss_temperature) + + # Only needed for diagnostics: Compute estimated actualized KL between the + # non-parametric and current target policies. + kl_nonparametric = mpo.compute_nonparametric_kl_from_normalized_weights( + normalized_weights + ) + + # Decompose the online policy into fixed-mean & fixed-stddev distributions. + # This has been documented as having better performance in bandit settings, + # see e.g. https://arxiv.org/pdf/1812.02256.pdf. + fixed_stddev_distribution = tfd.Independent( + tfd.Normal(loc=online_mean, scale=target_scale) + ) + fixed_mean_distribution = tfd.Independent( + tfd.Normal(loc=target_mean, scale=online_scale) + ) + + # Compute the decomposed policy losses. + loss_policy_mean = mpo.compute_cross_entropy_loss( + actions, normalized_weights_sum, fixed_stddev_distribution + ) + loss_policy_stddev = mpo.compute_cross_entropy_loss( + actions, normalized_weights_sum, fixed_mean_distribution + ) + + # Compute the decomposed KL between the target and online policies. 
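# Illustrative sketch (not part of the patch): configuring the MultiObjectiveMPO
# loss reformatted above with one KLConstraint per objective. The objective
# names and all numeric values here are hypothetical.
from acme.tf.losses import mompo

loss_fn = mompo.MultiObjectiveMPO(
    epsilons=[
        mompo.KLConstraint(name="task_reward", value=0.1),
        mompo.KLConstraint(name="action_penalty", value=1e-3),
    ],
    epsilon_mean=1e-3,
    epsilon_stddev=1e-6,
    init_log_temperature=10.0,
    init_log_alpha_mean=10.0,
    init_log_alpha_stddev=1000.0,
)
# When called, q_values must have shape [N, B, K] with K == len(epsilons) == 2,
# matching the per-objective rank and size check at the top of __call__ above.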
+ if self._per_dim_constraining: + kl_mean = target_action_distribution.distribution.kl_divergence( + fixed_stddev_distribution.distribution + ) # Shape [B, D]. + kl_stddev = target_action_distribution.distribution.kl_divergence( + fixed_mean_distribution.distribution + ) # Shape [B, D]. + else: + kl_mean = target_action_distribution.kl_divergence( + fixed_stddev_distribution + ) # Shape [B]. + kl_stddev = target_action_distribution.kl_divergence( + fixed_mean_distribution + ) # Shape [B]. + + # Compute the alpha-weighted KL-penalty and dual losses to adapt the alphas. + loss_kl_mean, loss_alpha_mean = mpo.compute_parametric_kl_penalty_and_dual_loss( + kl_mean, alpha_mean, self._epsilon_mean + ) + ( + loss_kl_stddev, + loss_alpha_stddev, + ) = mpo.compute_parametric_kl_penalty_and_dual_loss( + kl_stddev, alpha_stddev, self._epsilon_stddev + ) + + # Combine losses. + loss_policy = loss_policy_mean + loss_policy_stddev + loss_kl_penalty = loss_kl_mean + loss_kl_stddev + loss_dual = loss_alpha_mean + loss_alpha_stddev + loss_temperature_mean + loss = loss_policy + loss_kl_penalty + loss_dual + + stats = {} + # Dual Variables. + stats["dual_alpha_mean"] = tf.reduce_mean(alpha_mean) + stats["dual_alpha_stddev"] = tf.reduce_mean(alpha_stddev) + # Losses. + stats["loss_policy"] = tf.reduce_mean(loss) + stats["loss_alpha"] = tf.reduce_mean(loss_alpha_mean + loss_alpha_stddev) + # KL measurements. + stats["kl_mean_rel"] = tf.reduce_mean(kl_mean, axis=0) / self._epsilon_mean + stats["kl_stddev_rel"] = ( + tf.reduce_mean(kl_stddev, axis=0) / self._epsilon_stddev + ) + # If the policy has standard deviation, log summary stats for this as well. + pi_stddev = online_action_distribution.distribution.stddev() + stats["pi_stddev_min"] = tf.reduce_mean(tf.reduce_min(pi_stddev, axis=-1)) + stats["pi_stddev_max"] = tf.reduce_mean(tf.reduce_max(pi_stddev, axis=-1)) + + # Condition number of the diagonal covariance (actually, stddev) matrix. + stats["pi_stddev_cond"] = tf.reduce_mean( + tf.reduce_max(pi_stddev, axis=-1) / tf.reduce_min(pi_stddev, axis=-1) + ) + + # Log per-objective values. + for i, name in enumerate(self._objective_names): + stats["{}_dual_temperature".format(name)] = temperature[i] + stats["{}_loss_temperature".format(name)] = loss_temperature[i] + stats["{}_kl_q_rel".format(name)] = ( + tf.reduce_mean(kl_nonparametric[:, i]) / self._epsilons[i] + ) + + # Q measurements. + stats["{}_q_min".format(name)] = tf.reduce_mean( + tf.reduce_min(q_values, axis=0)[:, i] + ) + stats["{}_q_mean".format(name)] = tf.reduce_mean( + tf.reduce_mean(q_values, axis=0)[:, i] + ) + stats["{}_q_max".format(name)] = tf.reduce_mean( + tf.reduce_max(q_values, axis=0)[:, i] + ) + + return loss, stats def compute_weights_and_temperature_loss( - q_values: tf.Tensor, - epsilons: tf.Tensor, - temperature: tf.Variable, + q_values: tf.Tensor, epsilons: tf.Tensor, temperature: tf.Variable, ) -> Tuple[tf.Tensor, tf.Tensor]: - """Computes normalized importance weights for the policy optimization. + """Computes normalized importance weights for the policy optimization. Args: q_values: Q-values associated with the actions sampled from the target @@ -335,19 +370,18 @@ def compute_weights_and_temperature_loss( Temperature loss, used to adapt the temperature; shape [1, K]. """ - # Temper the given Q-values using the current temperature. - tempered_q_values = tf.stop_gradient(q_values) / temperature[None, None, :] + # Temper the given Q-values using the current temperature. 
+ tempered_q_values = tf.stop_gradient(q_values) / temperature[None, None, :] - # Compute the normalized importance weights used to compute expectations with - # respect to the non-parametric policy. - normalized_weights = tf.nn.softmax(tempered_q_values, axis=0) - normalized_weights = tf.stop_gradient(normalized_weights) + # Compute the normalized importance weights used to compute expectations with + # respect to the non-parametric policy. + normalized_weights = tf.nn.softmax(tempered_q_values, axis=0) + normalized_weights = tf.stop_gradient(normalized_weights) - # Compute the temperature loss (dual of the E-step optimization problem). - q_logsumexp = tf.reduce_logsumexp(tempered_q_values, axis=0) - log_num_actions = tf.math.log(tf.cast(q_values.shape[0], tf.float32)) - loss_temperature = ( - epsilons + tf.reduce_mean(q_logsumexp, axis=0) - log_num_actions) - loss_temperature = temperature * loss_temperature + # Compute the temperature loss (dual of the E-step optimization problem). + q_logsumexp = tf.reduce_logsumexp(tempered_q_values, axis=0) + log_num_actions = tf.math.log(tf.cast(q_values.shape[0], tf.float32)) + loss_temperature = epsilons + tf.reduce_mean(q_logsumexp, axis=0) - log_num_actions + loss_temperature = temperature * loss_temperature - return normalized_weights, loss_temperature + return normalized_weights, loss_temperature diff --git a/acme/tf/losses/mpo.py b/acme/tf/losses/mpo.py index 0b956cd1e0..4731a4a7b7 100644 --- a/acme/tf/losses/mpo.py +++ b/acme/tf/losses/mpo.py @@ -36,7 +36,7 @@ class MPO(snt.Module): - """MPO loss with decoupled KL constraints as in (Abdolmaleki et al., 2018). + """MPO loss with decoupled KL constraints as in (Abdolmaleki et al., 2018). This implementation of the MPO loss includes the following features, as options: @@ -49,18 +49,20 @@ class MPO(snt.Module): (Abdolmaleki et al., 2020): https://arxiv.org/pdf/2005.07513.pdf """ - def __init__(self, - epsilon: float, - epsilon_mean: float, - epsilon_stddev: float, - init_log_temperature: float, - init_log_alpha_mean: float, - init_log_alpha_stddev: float, - per_dim_constraining: bool = True, - action_penalization: bool = True, - epsilon_penalty: float = 0.001, - name: str = "MPO"): - """Initialize and configure the MPO loss. + def __init__( + self, + epsilon: float, + epsilon_mean: float, + epsilon_stddev: float, + init_log_temperature: float, + init_log_alpha_mean: float, + init_log_alpha_stddev: float, + per_dim_constraining: bool = True, + action_penalization: bool = True, + epsilon_penalty: float = 0.001, + name: str = "MPO", + ): + """Initialize and configure the MPO loss. Args: epsilon: KL constraint on the non-parametric auxiliary policy, the one @@ -86,71 +88,73 @@ def __init__(self, name: a name for the module, passed directly to snt.Module. """ - super().__init__(name=name) - - # MPO constrain thresholds. - self._epsilon = tf.constant(epsilon) - self._epsilon_mean = tf.constant(epsilon_mean) - self._epsilon_stddev = tf.constant(epsilon_stddev) - - # Initial values for the constraints' dual variables. - self._init_log_temperature = init_log_temperature - self._init_log_alpha_mean = init_log_alpha_mean - self._init_log_alpha_stddev = init_log_alpha_stddev - - # Whether to penalize out-of-bound actions via MO-MPO and its corresponding - # constraint threshold. - self._action_penalization = action_penalization - self._epsilon_penalty = tf.constant(epsilon_penalty) - - # Whether to ensure per-dimension KL constraint satisfication. 
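# Illustrative sketch (not part of the patch): the E-step weighting performed by
# compute_weights_and_temperature_loss above, written out for dummy shapes. Per
# (batch, objective) pair, the N sampled Q-values are tempered and softmaxed
# over the sample axis.
import tensorflow as tf

N, B, K = 20, 8, 2
q_values = tf.random.normal([N, B, K])
temperature = tf.constant([1.0, 0.5])  # one temperature per objective
weights = tf.nn.softmax(
    tf.stop_gradient(q_values) / temperature[None, None, :], axis=0
)
print(tf.reduce_sum(weights, axis=0))  # ~ tf.ones([B, K]); weights sum to one over N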
- self._per_dim_constraining = per_dim_constraining - - @snt.once - def create_dual_variables_once(self, shape: tf.TensorShape, dtype: tf.DType): - """Creates the dual variables the first time the loss module is called.""" - - # Create the dual variables. - self._log_temperature = tf.Variable( - initial_value=[self._init_log_temperature], - dtype=dtype, - name="log_temperature", - shape=(1,)) - self._log_alpha_mean = tf.Variable( - initial_value=tf.fill(shape, self._init_log_alpha_mean), - dtype=dtype, - name="log_alpha_mean", - shape=shape) - self._log_alpha_stddev = tf.Variable( - initial_value=tf.fill(shape, self._init_log_alpha_stddev), - dtype=dtype, - name="log_alpha_stddev", - shape=shape) - - # Cast constraint thresholds to the expected dtype. - self._epsilon = tf.cast(self._epsilon, dtype) - self._epsilon_mean = tf.cast(self._epsilon_mean, dtype) - self._epsilon_stddev = tf.cast(self._epsilon_stddev, dtype) - - # Maybe create the action penalization dual variable. - if self._action_penalization: - self._epsilon_penalty = tf.cast(self._epsilon_penalty, dtype) - self._log_penalty_temperature = tf.Variable( - initial_value=[self._init_log_temperature], - dtype=dtype, - name="log_penalty_temperature", - shape=(1,)) - - def __call__( - self, - online_action_distribution: Union[tfd.MultivariateNormalDiag, - tfd.Independent], - target_action_distribution: Union[tfd.MultivariateNormalDiag, - tfd.Independent], - actions: tf.Tensor, # Shape [N, B, D]. - q_values: tf.Tensor, # Shape [N, B]. - ) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]: - """Computes the decoupled MPO loss. + super().__init__(name=name) + + # MPO constrain thresholds. + self._epsilon = tf.constant(epsilon) + self._epsilon_mean = tf.constant(epsilon_mean) + self._epsilon_stddev = tf.constant(epsilon_stddev) + + # Initial values for the constraints' dual variables. + self._init_log_temperature = init_log_temperature + self._init_log_alpha_mean = init_log_alpha_mean + self._init_log_alpha_stddev = init_log_alpha_stddev + + # Whether to penalize out-of-bound actions via MO-MPO and its corresponding + # constraint threshold. + self._action_penalization = action_penalization + self._epsilon_penalty = tf.constant(epsilon_penalty) + + # Whether to ensure per-dimension KL constraint satisfication. + self._per_dim_constraining = per_dim_constraining + + @snt.once + def create_dual_variables_once(self, shape: tf.TensorShape, dtype: tf.DType): + """Creates the dual variables the first time the loss module is called.""" + + # Create the dual variables. + self._log_temperature = tf.Variable( + initial_value=[self._init_log_temperature], + dtype=dtype, + name="log_temperature", + shape=(1,), + ) + self._log_alpha_mean = tf.Variable( + initial_value=tf.fill(shape, self._init_log_alpha_mean), + dtype=dtype, + name="log_alpha_mean", + shape=shape, + ) + self._log_alpha_stddev = tf.Variable( + initial_value=tf.fill(shape, self._init_log_alpha_stddev), + dtype=dtype, + name="log_alpha_stddev", + shape=shape, + ) + + # Cast constraint thresholds to the expected dtype. + self._epsilon = tf.cast(self._epsilon, dtype) + self._epsilon_mean = tf.cast(self._epsilon_mean, dtype) + self._epsilon_stddev = tf.cast(self._epsilon_stddev, dtype) + + # Maybe create the action penalization dual variable. 
+ if self._action_penalization: + self._epsilon_penalty = tf.cast(self._epsilon_penalty, dtype) + self._log_penalty_temperature = tf.Variable( + initial_value=[self._init_log_temperature], + dtype=dtype, + name="log_penalty_temperature", + shape=(1,), + ) + + def __call__( + self, + online_action_distribution: Union[tfd.MultivariateNormalDiag, tfd.Independent], + target_action_distribution: Union[tfd.MultivariateNormalDiag, tfd.Independent], + actions: tf.Tensor, # Shape [N, B, D]. + q_values: tf.Tensor, # Shape [N, B]. + ) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]: + """Computes the decoupled MPO loss. Args: online_action_distribution: online distribution returned by the online @@ -166,170 +170,205 @@ def __call__( Stats, for diagnostics and tracking performance. """ - # Cast `MultivariateNormalDiag`s to Independent Normals. - # The latter allows us to satisfy KL constraints per-dimension. - if isinstance(target_action_distribution, tfd.MultivariateNormalDiag): - target_action_distribution = tfd.Independent( - tfd.Normal(target_action_distribution.mean(), - target_action_distribution.stddev())) - online_action_distribution = tfd.Independent( - tfd.Normal(online_action_distribution.mean(), - online_action_distribution.stddev())) - - # Infer the shape and dtype of dual variables. - scalar_dtype = q_values.dtype - if self._per_dim_constraining: - dual_variable_shape = target_action_distribution.distribution.kl_divergence( - online_action_distribution.distribution).shape[1:] # Should be [D]. - else: - dual_variable_shape = target_action_distribution.kl_divergence( - online_action_distribution).shape[1:] # Should be [1]. - - # Create dual variables for the KL constraints; only happens the first call. - self.create_dual_variables_once(dual_variable_shape, scalar_dtype) - - # Project dual variables to ensure they stay positive. - min_log_temperature = tf.constant(-18.0, scalar_dtype) - min_log_alpha = tf.constant(-18.0, scalar_dtype) - self._log_temperature.assign( - tf.maximum(min_log_temperature, self._log_temperature)) - self._log_alpha_mean.assign(tf.maximum(min_log_alpha, self._log_alpha_mean)) - self._log_alpha_stddev.assign( - tf.maximum(min_log_alpha, self._log_alpha_stddev)) - - # Transform dual variables from log-space. - # Note: using softplus instead of exponential for numerical stability. - temperature = tf.math.softplus(self._log_temperature) + _MPO_FLOAT_EPSILON - alpha_mean = tf.math.softplus(self._log_alpha_mean) + _MPO_FLOAT_EPSILON - alpha_stddev = tf.math.softplus(self._log_alpha_stddev) + _MPO_FLOAT_EPSILON - - # Get online and target means and stddevs in preparation for decomposition. - online_mean = online_action_distribution.distribution.mean() - online_scale = online_action_distribution.distribution.stddev() - target_mean = target_action_distribution.distribution.mean() - target_scale = target_action_distribution.distribution.stddev() - - # Compute normalized importance weights, used to compute expectations with - # respect to the non-parametric policy; and the temperature loss, used to - # adapt the tempering of Q-values. - normalized_weights, loss_temperature = compute_weights_and_temperature_loss( - q_values, self._epsilon, temperature) - - # Only needed for diagnostics: Compute estimated actualized KL between the - # non-parametric and current target policies. - kl_nonparametric = compute_nonparametric_kl_from_normalized_weights( - normalized_weights) - - if self._action_penalization: - # Project and transform action penalization temperature. 
- self._log_penalty_temperature.assign( - tf.maximum(min_log_temperature, self._log_penalty_temperature)) - penalty_temperature = tf.math.softplus( - self._log_penalty_temperature) + _MPO_FLOAT_EPSILON - - # Compute action penalization cost. - # Note: the cost is zero in [-1, 1] and quadratic beyond. - diff_out_of_bound = actions - tf.clip_by_value(actions, -1.0, 1.0) - cost_out_of_bound = -tf.norm(diff_out_of_bound, axis=-1) - - penalty_normalized_weights, loss_penalty_temperature = compute_weights_and_temperature_loss( - cost_out_of_bound, self._epsilon_penalty, penalty_temperature) - - # Only needed for diagnostics: Compute estimated actualized KL between the - # non-parametric and current target policies. - penalty_kl_nonparametric = compute_nonparametric_kl_from_normalized_weights( - penalty_normalized_weights) - - # Combine normalized weights. - normalized_weights += penalty_normalized_weights - loss_temperature += loss_penalty_temperature - # Decompose the online policy into fixed-mean & fixed-stddev distributions. - # This has been documented as having better performance in bandit settings, - # see e.g. https://arxiv.org/pdf/1812.02256.pdf. - fixed_stddev_distribution = tfd.Independent( - tfd.Normal(loc=online_mean, scale=target_scale)) - fixed_mean_distribution = tfd.Independent( - tfd.Normal(loc=target_mean, scale=online_scale)) - - # Compute the decomposed policy losses. - loss_policy_mean = compute_cross_entropy_loss( - actions, normalized_weights, fixed_stddev_distribution) - loss_policy_stddev = compute_cross_entropy_loss( - actions, normalized_weights, fixed_mean_distribution) - - # Compute the decomposed KL between the target and online policies. - if self._per_dim_constraining: - kl_mean = target_action_distribution.distribution.kl_divergence( - fixed_stddev_distribution.distribution) # Shape [B, D]. - kl_stddev = target_action_distribution.distribution.kl_divergence( - fixed_mean_distribution.distribution) # Shape [B, D]. - else: - kl_mean = target_action_distribution.kl_divergence( - fixed_stddev_distribution) # Shape [B]. - kl_stddev = target_action_distribution.kl_divergence( - fixed_mean_distribution) # Shape [B]. - - # Compute the alpha-weighted KL-penalty and dual losses to adapt the alphas. - loss_kl_mean, loss_alpha_mean = compute_parametric_kl_penalty_and_dual_loss( - kl_mean, alpha_mean, self._epsilon_mean) - loss_kl_stddev, loss_alpha_stddev = compute_parametric_kl_penalty_and_dual_loss( - kl_stddev, alpha_stddev, self._epsilon_stddev) - - # Combine losses. - loss_policy = loss_policy_mean + loss_policy_stddev - loss_kl_penalty = loss_kl_mean + loss_kl_stddev - loss_dual = loss_alpha_mean + loss_alpha_stddev + loss_temperature - loss = loss_policy + loss_kl_penalty + loss_dual - - stats = {} - # Dual Variables. - stats["dual_alpha_mean"] = tf.reduce_mean(alpha_mean) - stats["dual_alpha_stddev"] = tf.reduce_mean(alpha_stddev) - stats["dual_temperature"] = tf.reduce_mean(temperature) - # Losses. - stats["loss_policy"] = tf.reduce_mean(loss) - stats["loss_alpha"] = tf.reduce_mean(loss_alpha_mean + loss_alpha_stddev) - stats["loss_temperature"] = tf.reduce_mean(loss_temperature) - # KL measurements. 
- stats["kl_q_rel"] = tf.reduce_mean(kl_nonparametric) / self._epsilon - - if self._action_penalization: - stats["penalty_kl_q_rel"] = tf.reduce_mean( - penalty_kl_nonparametric) / self._epsilon_penalty - - stats["kl_mean_rel"] = tf.reduce_mean(kl_mean) / self._epsilon_mean - stats["kl_stddev_rel"] = tf.reduce_mean(kl_stddev) / self._epsilon_stddev - if self._per_dim_constraining: - # When KL is constrained per-dimension, we also log per-dimension min and - # max of mean/std of the realized KL costs. - stats["kl_mean_rel_min"] = tf.reduce_min(tf.reduce_mean( - kl_mean, axis=0)) / self._epsilon_mean - stats["kl_mean_rel_max"] = tf.reduce_max(tf.reduce_mean( - kl_mean, axis=0)) / self._epsilon_mean - stats["kl_stddev_rel_min"] = tf.reduce_min( - tf.reduce_mean(kl_stddev, axis=0)) / self._epsilon_stddev - stats["kl_stddev_rel_max"] = tf.reduce_max( - tf.reduce_mean(kl_stddev, axis=0)) / self._epsilon_stddev - # Q measurements. - stats["q_min"] = tf.reduce_mean(tf.reduce_min(q_values, axis=0)) - stats["q_max"] = tf.reduce_mean(tf.reduce_max(q_values, axis=0)) - # If the policy has standard deviation, log summary stats for this as well. - pi_stddev = online_action_distribution.distribution.stddev() - stats["pi_stddev_min"] = tf.reduce_mean(tf.reduce_min(pi_stddev, axis=-1)) - stats["pi_stddev_max"] = tf.reduce_mean(tf.reduce_max(pi_stddev, axis=-1)) - # Condition number of the diagonal covariance (actually, stddev) matrix. - stats["pi_stddev_cond"] = tf.reduce_mean( - tf.reduce_max(pi_stddev, axis=-1) / tf.reduce_min(pi_stddev, axis=-1)) - - return loss, stats + # Cast `MultivariateNormalDiag`s to Independent Normals. + # The latter allows us to satisfy KL constraints per-dimension. + if isinstance(target_action_distribution, tfd.MultivariateNormalDiag): + target_action_distribution = tfd.Independent( + tfd.Normal( + target_action_distribution.mean(), + target_action_distribution.stddev(), + ) + ) + online_action_distribution = tfd.Independent( + tfd.Normal( + online_action_distribution.mean(), + online_action_distribution.stddev(), + ) + ) + + # Infer the shape and dtype of dual variables. + scalar_dtype = q_values.dtype + if self._per_dim_constraining: + dual_variable_shape = target_action_distribution.distribution.kl_divergence( + online_action_distribution.distribution + ).shape[ + 1: + ] # Should be [D]. + else: + dual_variable_shape = target_action_distribution.kl_divergence( + online_action_distribution + ).shape[ + 1: + ] # Should be [1]. + + # Create dual variables for the KL constraints; only happens the first call. + self.create_dual_variables_once(dual_variable_shape, scalar_dtype) + + # Project dual variables to ensure they stay positive. + min_log_temperature = tf.constant(-18.0, scalar_dtype) + min_log_alpha = tf.constant(-18.0, scalar_dtype) + self._log_temperature.assign( + tf.maximum(min_log_temperature, self._log_temperature) + ) + self._log_alpha_mean.assign(tf.maximum(min_log_alpha, self._log_alpha_mean)) + self._log_alpha_stddev.assign(tf.maximum(min_log_alpha, self._log_alpha_stddev)) + + # Transform dual variables from log-space. + # Note: using softplus instead of exponential for numerical stability. + temperature = tf.math.softplus(self._log_temperature) + _MPO_FLOAT_EPSILON + alpha_mean = tf.math.softplus(self._log_alpha_mean) + _MPO_FLOAT_EPSILON + alpha_stddev = tf.math.softplus(self._log_alpha_stddev) + _MPO_FLOAT_EPSILON + + # Get online and target means and stddevs in preparation for decomposition. 
+ online_mean = online_action_distribution.distribution.mean() + online_scale = online_action_distribution.distribution.stddev() + target_mean = target_action_distribution.distribution.mean() + target_scale = target_action_distribution.distribution.stddev() + + # Compute normalized importance weights, used to compute expectations with + # respect to the non-parametric policy; and the temperature loss, used to + # adapt the tempering of Q-values. + normalized_weights, loss_temperature = compute_weights_and_temperature_loss( + q_values, self._epsilon, temperature + ) + + # Only needed for diagnostics: Compute estimated actualized KL between the + # non-parametric and current target policies. + kl_nonparametric = compute_nonparametric_kl_from_normalized_weights( + normalized_weights + ) + + if self._action_penalization: + # Project and transform action penalization temperature. + self._log_penalty_temperature.assign( + tf.maximum(min_log_temperature, self._log_penalty_temperature) + ) + penalty_temperature = ( + tf.math.softplus(self._log_penalty_temperature) + _MPO_FLOAT_EPSILON + ) + + # Compute action penalization cost. + # Note: the cost is zero in [-1, 1] and quadratic beyond. + diff_out_of_bound = actions - tf.clip_by_value(actions, -1.0, 1.0) + cost_out_of_bound = -tf.norm(diff_out_of_bound, axis=-1) + + ( + penalty_normalized_weights, + loss_penalty_temperature, + ) = compute_weights_and_temperature_loss( + cost_out_of_bound, self._epsilon_penalty, penalty_temperature + ) + + # Only needed for diagnostics: Compute estimated actualized KL between the + # non-parametric and current target policies. + penalty_kl_nonparametric = compute_nonparametric_kl_from_normalized_weights( + penalty_normalized_weights + ) + + # Combine normalized weights. + normalized_weights += penalty_normalized_weights + loss_temperature += loss_penalty_temperature + # Decompose the online policy into fixed-mean & fixed-stddev distributions. + # This has been documented as having better performance in bandit settings, + # see e.g. https://arxiv.org/pdf/1812.02256.pdf. + fixed_stddev_distribution = tfd.Independent( + tfd.Normal(loc=online_mean, scale=target_scale) + ) + fixed_mean_distribution = tfd.Independent( + tfd.Normal(loc=target_mean, scale=online_scale) + ) + + # Compute the decomposed policy losses. + loss_policy_mean = compute_cross_entropy_loss( + actions, normalized_weights, fixed_stddev_distribution + ) + loss_policy_stddev = compute_cross_entropy_loss( + actions, normalized_weights, fixed_mean_distribution + ) + + # Compute the decomposed KL between the target and online policies. + if self._per_dim_constraining: + kl_mean = target_action_distribution.distribution.kl_divergence( + fixed_stddev_distribution.distribution + ) # Shape [B, D]. + kl_stddev = target_action_distribution.distribution.kl_divergence( + fixed_mean_distribution.distribution + ) # Shape [B, D]. + else: + kl_mean = target_action_distribution.kl_divergence( + fixed_stddev_distribution + ) # Shape [B]. + kl_stddev = target_action_distribution.kl_divergence( + fixed_mean_distribution + ) # Shape [B]. + + # Compute the alpha-weighted KL-penalty and dual losses to adapt the alphas. + loss_kl_mean, loss_alpha_mean = compute_parametric_kl_penalty_and_dual_loss( + kl_mean, alpha_mean, self._epsilon_mean + ) + loss_kl_stddev, loss_alpha_stddev = compute_parametric_kl_penalty_and_dual_loss( + kl_stddev, alpha_stddev, self._epsilon_stddev + ) + + # Combine losses. 
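# Rough sketch of the M-step objective produced by compute_cross_entropy_loss
# with the decomposed distributions above: a weighted maximum-likelihood loss
# under the non-parametric weights (toy shapes: actions [N, B, D], weights
# [N, B], and a unit-Gaussian stand-in for the online policy).
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
n, b, d = 10, 4, 2
sampled_actions = tf.random.normal([n, b, d])
weights = tf.nn.softmax(tf.random.normal([n, b]), axis=0)   # sums to 1 over N
policy = tfd.Independent(tfd.Normal(tf.zeros([b, d]), tf.ones([b, d])),
                         reinterpreted_batch_ndims=1)
log_prob = policy.log_prob(sampled_actions)                 # shape [N, B]
loss = tf.reduce_mean(-tf.reduce_sum(weights * log_prob, axis=0))
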
+ loss_policy = loss_policy_mean + loss_policy_stddev + loss_kl_penalty = loss_kl_mean + loss_kl_stddev + loss_dual = loss_alpha_mean + loss_alpha_stddev + loss_temperature + loss = loss_policy + loss_kl_penalty + loss_dual + + stats = {} + # Dual Variables. + stats["dual_alpha_mean"] = tf.reduce_mean(alpha_mean) + stats["dual_alpha_stddev"] = tf.reduce_mean(alpha_stddev) + stats["dual_temperature"] = tf.reduce_mean(temperature) + # Losses. + stats["loss_policy"] = tf.reduce_mean(loss) + stats["loss_alpha"] = tf.reduce_mean(loss_alpha_mean + loss_alpha_stddev) + stats["loss_temperature"] = tf.reduce_mean(loss_temperature) + # KL measurements. + stats["kl_q_rel"] = tf.reduce_mean(kl_nonparametric) / self._epsilon + + if self._action_penalization: + stats["penalty_kl_q_rel"] = ( + tf.reduce_mean(penalty_kl_nonparametric) / self._epsilon_penalty + ) + + stats["kl_mean_rel"] = tf.reduce_mean(kl_mean) / self._epsilon_mean + stats["kl_stddev_rel"] = tf.reduce_mean(kl_stddev) / self._epsilon_stddev + if self._per_dim_constraining: + # When KL is constrained per-dimension, we also log per-dimension min and + # max of mean/std of the realized KL costs. + stats["kl_mean_rel_min"] = ( + tf.reduce_min(tf.reduce_mean(kl_mean, axis=0)) / self._epsilon_mean + ) + stats["kl_mean_rel_max"] = ( + tf.reduce_max(tf.reduce_mean(kl_mean, axis=0)) / self._epsilon_mean + ) + stats["kl_stddev_rel_min"] = ( + tf.reduce_min(tf.reduce_mean(kl_stddev, axis=0)) / self._epsilon_stddev + ) + stats["kl_stddev_rel_max"] = ( + tf.reduce_max(tf.reduce_mean(kl_stddev, axis=0)) / self._epsilon_stddev + ) + # Q measurements. + stats["q_min"] = tf.reduce_mean(tf.reduce_min(q_values, axis=0)) + stats["q_max"] = tf.reduce_mean(tf.reduce_max(q_values, axis=0)) + # If the policy has standard deviation, log summary stats for this as well. + pi_stddev = online_action_distribution.distribution.stddev() + stats["pi_stddev_min"] = tf.reduce_mean(tf.reduce_min(pi_stddev, axis=-1)) + stats["pi_stddev_max"] = tf.reduce_mean(tf.reduce_max(pi_stddev, axis=-1)) + # Condition number of the diagonal covariance (actually, stddev) matrix. + stats["pi_stddev_cond"] = tf.reduce_mean( + tf.reduce_max(pi_stddev, axis=-1) / tf.reduce_min(pi_stddev, axis=-1) + ) + + return loss, stats def compute_weights_and_temperature_loss( - q_values: tf.Tensor, - epsilon: float, - temperature: tf.Variable, + q_values: tf.Tensor, epsilon: float, temperature: tf.Variable, ) -> Tuple[tf.Tensor, tf.Tensor]: - """Computes normalized importance weights for the policy optimization. + """Computes normalized importance weights for the policy optimization. Args: q_values: Q-values associated with the actions sampled from the target @@ -346,33 +385,34 @@ def compute_weights_and_temperature_loss( Temperature loss, used to adapt the temperature. """ - # Temper the given Q-values using the current temperature. - tempered_q_values = tf.stop_gradient(q_values) / temperature + # Temper the given Q-values using the current temperature. + tempered_q_values = tf.stop_gradient(q_values) / temperature - # Compute the normalized importance weights used to compute expectations with - # respect to the non-parametric policy. - normalized_weights = tf.nn.softmax(tempered_q_values, axis=0) - normalized_weights = tf.stop_gradient(normalized_weights) + # Compute the normalized importance weights used to compute expectations with + # respect to the non-parametric policy. 
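# Toy NumPy version of the weighting and temperature dual loss computed just
# below, for Q-values of shape [N, B]; epsilon and temperature are plain
# scalars here purely for illustration.
import numpy as np

q = np.random.randn(16, 4)                  # N=16 sampled actions, batch of 4
epsilon, temperature = 0.1, 1.0
tempered = q / temperature
shifted = np.exp(tempered - tempered.max(axis=0))
weights = shifted / shifted.sum(axis=0)     # softmax over the N axis
q_logsumexp = np.log(shifted.sum(axis=0)) + tempered.max(axis=0)
loss_temperature = temperature * (epsilon + q_logsumexp.mean() - np.log(q.shape[0]))
assert np.allclose(weights.sum(axis=0), 1.0)
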
+ normalized_weights = tf.nn.softmax(tempered_q_values, axis=0) + normalized_weights = tf.stop_gradient(normalized_weights) - # Compute the temperature loss (dual of the E-step optimization problem). - q_logsumexp = tf.reduce_logsumexp(tempered_q_values, axis=0) - log_num_actions = tf.math.log(tf.cast(q_values.shape[0], tf.float32)) - loss_temperature = epsilon + tf.reduce_mean(q_logsumexp) - log_num_actions - loss_temperature = temperature * loss_temperature + # Compute the temperature loss (dual of the E-step optimization problem). + q_logsumexp = tf.reduce_logsumexp(tempered_q_values, axis=0) + log_num_actions = tf.math.log(tf.cast(q_values.shape[0], tf.float32)) + loss_temperature = epsilon + tf.reduce_mean(q_logsumexp) - log_num_actions + loss_temperature = temperature * loss_temperature - return normalized_weights, loss_temperature + return normalized_weights, loss_temperature def compute_nonparametric_kl_from_normalized_weights( - normalized_weights: tf.Tensor) -> tf.Tensor: - """Estimate the actualized KL between the non-parametric and target policies.""" + normalized_weights: tf.Tensor, +) -> tf.Tensor: + """Estimate the actualized KL between the non-parametric and target policies.""" - # Compute integrand. - num_action_samples = tf.cast(normalized_weights.shape[0], tf.float32) - integrand = tf.math.log(num_action_samples * normalized_weights + 1e-8) + # Compute integrand. + num_action_samples = tf.cast(normalized_weights.shape[0], tf.float32) + integrand = tf.math.log(num_action_samples * normalized_weights + 1e-8) - # Return the expectation with respect to the non-parametric policy. - return tf.reduce_sum(normalized_weights * integrand, axis=0) + # Return the expectation with respect to the non-parametric policy. + return tf.reduce_sum(normalized_weights * integrand, axis=0) def compute_cross_entropy_loss( @@ -380,7 +420,7 @@ def compute_cross_entropy_loss( normalized_weights: tf.Tensor, online_action_distribution: tfp.distributions.Distribution, ) -> tf.Tensor: - """Compute cross-entropy online and the reweighted target policy. + """Compute cross-entropy online and the reweighted target policy. Args: sampled_actions: samples used in the Monte Carlo integration in the policy @@ -395,22 +435,20 @@ def compute_cross_entropy_loss( produces the policy gradient. """ - # Compute the M-step loss. - log_prob = online_action_distribution.log_prob(sampled_actions) + # Compute the M-step loss. + log_prob = online_action_distribution.log_prob(sampled_actions) - # Compute the weighted average log-prob using the normalized weights. - loss_policy_gradient = -tf.reduce_sum(log_prob * normalized_weights, axis=0) + # Compute the weighted average log-prob using the normalized weights. + loss_policy_gradient = -tf.reduce_sum(log_prob * normalized_weights, axis=0) - # Return the mean loss over the batch of states. - return tf.reduce_mean(loss_policy_gradient, axis=0) + # Return the mean loss over the batch of states. + return tf.reduce_mean(loss_policy_gradient, axis=0) def compute_parametric_kl_penalty_and_dual_loss( - kl: tf.Tensor, - alpha: tf.Variable, - epsilon: float, + kl: tf.Tensor, alpha: tf.Variable, epsilon: float, ) -> Tuple[tf.Tensor, tf.Tensor]: - """Computes the KL cost to be added to the Lagragian and its dual loss. + """Computes the KL cost to be added to the Lagragian and its dual loss. The KL cost is simply the alpha-weighted KL divergence and it is added as a regularizer to the policy loss. 
The dual variable alpha itself has a loss that @@ -427,13 +465,13 @@ def compute_parametric_kl_penalty_and_dual_loss( loss_alpha: The Lagrange dual loss minimized to adapt alpha. """ - # Compute the mean KL over the batch. - mean_kl = tf.reduce_mean(kl, axis=0) + # Compute the mean KL over the batch. + mean_kl = tf.reduce_mean(kl, axis=0) - # Compute the regularization. - loss_kl = tf.reduce_sum(tf.stop_gradient(alpha) * mean_kl) + # Compute the regularization. + loss_kl = tf.reduce_sum(tf.stop_gradient(alpha) * mean_kl) - # Compute the dual loss. - loss_alpha = tf.reduce_sum(alpha * (epsilon - tf.stop_gradient(mean_kl))) + # Compute the dual loss. + loss_alpha = tf.reduce_sum(alpha * (epsilon - tf.stop_gradient(mean_kl))) - return loss_kl, loss_alpha + return loss_kl, loss_alpha diff --git a/acme/tf/losses/quantile.py b/acme/tf/losses/quantile.py index bfbb18b2d9..28622425ba 100644 --- a/acme/tf/losses/quantile.py +++ b/acme/tf/losses/quantile.py @@ -16,42 +16,40 @@ from typing import NamedTuple -from .huber import huber import sonnet as snt import tensorflow as tf +from .huber import huber + class QuantileDistribution(NamedTuple): - values: tf.Tensor - logits: tf.Tensor + values: tf.Tensor + logits: tf.Tensor class NonUniformQuantileRegression(snt.Module): - """Compute the quantile regression loss for the distributional TD error.""" + """Compute the quantile regression loss for the distributional TD error.""" - def __init__( - self, - huber_param: float = 0., - name: str = 'NUQuantileRegression'): - """Initializes the module. + def __init__(self, huber_param: float = 0.0, name: str = "NUQuantileRegression"): + """Initializes the module. Args: huber_param: The point where the huber loss function changes from a quadratic to linear. name: name to use for grouping operations. """ - super().__init__(name=name) - self._huber_param = huber_param - - def __call__( - self, - q_tm1: QuantileDistribution, - r_t: tf.Tensor, - pcont_t: tf.Tensor, - q_t: QuantileDistribution, - tau: tf.Tensor, - ) -> tf.Tensor: - """Calculates the loss. + super().__init__(name=name) + self._huber_param = huber_param + + def __call__( + self, + q_tm1: QuantileDistribution, + r_t: tf.Tensor, + pcont_t: tf.Tensor, + q_t: QuantileDistribution, + tau: tf.Tensor, + ) -> tf.Tensor: + """Calculates the loss. Note that this is only defined for discrete quantile-valued distributions. In particular we assume that the distributions define q.logits and @@ -67,28 +65,27 @@ def __call__( Returns: Value of the loss. """ - # Distributional Bellman update - values_t = (tf.reshape(r_t, (-1, 1)) + - tf.reshape(pcont_t, (-1, 1)) * q_t.values) - values_t = tf.stop_gradient(values_t) - probs_t = tf.nn.softmax(q_t.logits) - - # Quantile regression loss - # Tau gives the quantile regression targets, where in the sample - # space [0, 1] each output should train towards - # Tau applies along the second dimension in delta (below) - tau = tf.expand_dims(tau, -1) - - # quantile td-error and assymmetric weighting - delta = values_t[:, None, :] - q_tm1.values[:, :, None] - delta_neg = tf.cast(delta < 0., dtype=tf.float32) - # This stop_gradient is very important, do not remove - weight = tf.stop_gradient(tf.abs(tau - delta_neg)) - - # loss - loss = huber(delta, self._huber_param) * weight - loss = tf.reduce_sum(loss * probs_t[:, None, :], 2) - - # Have not been able to get quite as good performance with mean vs. 
sum - loss = tf.reduce_sum(loss, -1) - return loss + # Distributional Bellman update + values_t = tf.reshape(r_t, (-1, 1)) + tf.reshape(pcont_t, (-1, 1)) * q_t.values + values_t = tf.stop_gradient(values_t) + probs_t = tf.nn.softmax(q_t.logits) + + # Quantile regression loss + # Tau gives the quantile regression targets, where in the sample + # space [0, 1] each output should train towards + # Tau applies along the second dimension in delta (below) + tau = tf.expand_dims(tau, -1) + + # quantile td-error and assymmetric weighting + delta = values_t[:, None, :] - q_tm1.values[:, :, None] + delta_neg = tf.cast(delta < 0.0, dtype=tf.float32) + # This stop_gradient is very important, do not remove + weight = tf.stop_gradient(tf.abs(tau - delta_neg)) + + # loss + loss = huber(delta, self._huber_param) * weight + loss = tf.reduce_sum(loss * probs_t[:, None, :], 2) + + # Have not been able to get quite as good performance with mean vs. sum + loss = tf.reduce_sum(loss, -1) + return loss diff --git a/acme/tf/losses/r2d2.py b/acme/tf/losses/r2d2.py index 7c82868d09..ee50cb089b 100644 --- a/acme/tf/losses/r2d2.py +++ b/acme/tf/losses/r2d2.py @@ -21,8 +21,8 @@ class LossCoreExtra(NamedTuple): - targets: tf.Tensor - errors: tf.Tensor + targets: tf.Tensor + errors: tf.Tensor def transformed_n_step_loss( @@ -34,9 +34,9 @@ def transformed_n_step_loss( target_policy_probs: tf.Tensor, bootstrap_n: int, stop_targnet_gradients: bool = True, - name: str = 'transformed_n_step_loss', + name: str = "transformed_n_step_loss", ) -> trfl.base_ops.LossOutput: - """Helper function for computing transformed loss on sequences. + """Helper function for computing transformed loss on sequences. Args: qs: 3-D tensor corresponding to the Q-values to be learned. Shape is [T+1, @@ -66,65 +66,64 @@ def transformed_n_step_loss( * `LossCoreExtra`: namedtuple containing the fields `targets` and `errors`. """ - with tf.name_scope(name): - # Require correct tensor ranks---as long as we have shape information - # available to check. If there isn't any, we print a warning. - def check_rank(tensors: Iterable[tf.Tensor], ranks: Sequence[int]): - for i, (tensor, rank) in enumerate(zip(tensors, ranks)): - if tensor.get_shape(): - trfl.assert_rank_and_shape_compatibility([tensor], rank) - else: - raise ValueError( - f'Tensor "{tensor.name}", which was offered as transformed_n_step_loss' - f'parameter {i+1}, has no rank at construction time, so cannot verify' - f'that it has the necessary rank of {rank}') - - check_rank( - [qs, targnet_qs, actions, rewards, pcontinues, target_policy_probs], - [3, 3, 2, 2, 2, 3]) - - # Construct arguments to compute bootstrap target. - a_tm1 = actions[:-1] # (0:T) x B - r_t, pcont_t = rewards, pcontinues # (1:T+1) x B - q_tm1 = qs[:-1] # (0:T) x B x A - target_policy_t = target_policy_probs[1:] # (1:T+1) x B x A - targnet_q_t = targnet_qs[1:] # (1:T+1) x B x A - - bootstrap_value = tf.reduce_sum( - target_policy_t * _signed_parabolic_tx(targnet_q_t), -1) - target = _compute_n_step_sequence_targets( - r_t=r_t, - pcont_t=pcont_t, - bootstrap_value=bootstrap_value, - n=bootstrap_n) - - if stop_targnet_gradients: - target = tf.stop_gradient(target) - - # tx/inv_tx may result in numerical instabilities so mask any NaNs. - finite_mask = tf.math.is_finite(target) - target = tf.where(finite_mask, target, tf.zeros_like(target)) - - qa_tm1 = trfl.batched_index(q_tm1, a_tm1) - errors = qa_tm1 - _signed_hyperbolic_tx(target) - - # Only compute n-step errors w.r.t. finite targets. 
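# Quick NumPy check of the value transforms this loss relies on: the signed
# hyperbolic and signed parabolic transforms defined at the bottom of this
# file are inverses of each other, which is what lets targets be built and
# compared in transformed space.
import numpy as np

def signed_hyperbolic(x, eps=1e-3):
    return np.sign(x) * (np.sqrt(np.abs(x) + 1.0) - 1.0) + eps * x

def signed_parabolic(x, eps=1e-3):
    z = np.sqrt(1.0 + 4.0 * eps * (eps + 1.0 + np.abs(x))) / (2.0 * eps) - 1.0 / (2.0 * eps)
    return np.sign(x) * (np.square(z) - 1.0)

x = np.linspace(-10.0, 10.0, 21)
np.testing.assert_allclose(signed_parabolic(signed_hyperbolic(x)), x, atol=1e-4)
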
- errors = tf.where(finite_mask, errors, tf.zeros_like(errors)) - - # Sum over time dimension. - loss = 0.5 * tf.reduce_sum(tf.square(errors), axis=0) - - return trfl.base_ops.LossOutput( - loss, LossCoreExtra(targets=target, errors=errors)) + with tf.name_scope(name): + # Require correct tensor ranks---as long as we have shape information + # available to check. If there isn't any, we print a warning. + def check_rank(tensors: Iterable[tf.Tensor], ranks: Sequence[int]): + for i, (tensor, rank) in enumerate(zip(tensors, ranks)): + if tensor.get_shape(): + trfl.assert_rank_and_shape_compatibility([tensor], rank) + else: + raise ValueError( + f'Tensor "{tensor.name}", which was offered as transformed_n_step_loss' + f"parameter {i+1}, has no rank at construction time, so cannot verify" + f"that it has the necessary rank of {rank}" + ) + + check_rank( + [qs, targnet_qs, actions, rewards, pcontinues, target_policy_probs], + [3, 3, 2, 2, 2, 3], + ) + + # Construct arguments to compute bootstrap target. + a_tm1 = actions[:-1] # (0:T) x B + r_t, pcont_t = rewards, pcontinues # (1:T+1) x B + q_tm1 = qs[:-1] # (0:T) x B x A + target_policy_t = target_policy_probs[1:] # (1:T+1) x B x A + targnet_q_t = targnet_qs[1:] # (1:T+1) x B x A + + bootstrap_value = tf.reduce_sum( + target_policy_t * _signed_parabolic_tx(targnet_q_t), -1 + ) + target = _compute_n_step_sequence_targets( + r_t=r_t, pcont_t=pcont_t, bootstrap_value=bootstrap_value, n=bootstrap_n + ) + + if stop_targnet_gradients: + target = tf.stop_gradient(target) + + # tx/inv_tx may result in numerical instabilities so mask any NaNs. + finite_mask = tf.math.is_finite(target) + target = tf.where(finite_mask, target, tf.zeros_like(target)) + + qa_tm1 = trfl.batched_index(q_tm1, a_tm1) + errors = qa_tm1 - _signed_hyperbolic_tx(target) + + # Only compute n-step errors w.r.t. finite targets. + errors = tf.where(finite_mask, errors, tf.zeros_like(errors)) + + # Sum over time dimension. + loss = 0.5 * tf.reduce_sum(tf.square(errors), axis=0) + + return trfl.base_ops.LossOutput( + loss, LossCoreExtra(targets=target, errors=errors) + ) def _compute_n_step_sequence_targets( - r_t: tf.Tensor, - pcont_t: tf.Tensor, - bootstrap_value: tf.Tensor, - n: int, + r_t: tf.Tensor, pcont_t: tf.Tensor, bootstrap_value: tf.Tensor, n: int, ) -> tf.Tensor: - """Computes n-step bootstrapped returns over a sequence. + """Computes n-step bootstrapped returns over a sequence. Args: r_t: 2-D tensor of shape [T, B] corresponding to rewards. @@ -136,44 +135,43 @@ def _compute_n_step_sequence_targets( Returns: 2-D tensor of shape [T, B] corresponding to bootstrapped returns. """ - time_size, batch_size = r_t.shape.as_list() - - # Pad r_t and pcont_t so we can use static slice shapes in scan. - r_t = tf.concat([r_t, tf.zeros((n - 1, batch_size))], 0) - pcont_t = tf.concat([pcont_t, tf.ones((n - 1, batch_size))], 0) - - # We need to use tf.slice with static shapes for TPU compatibility. - def _slice(tensor, index, size): - return tf.slice(tensor, [index, 0], [size, batch_size]) - - # Construct correct bootstrap targets for each time slice t, which are exactly - # the target values at timestep min(t+n-1, time_size-1). - last_bootstrap_value = _slice(bootstrap_value, time_size - 1, 1) - if time_size > n - 1: - full_bootstrap_steps = [_slice(bootstrap_value, n - 1, time_size - (n - 1))] - truncated_bootstrap_steps = [last_bootstrap_value] * (n - 1) - else: - # Only truncated steps, since n > time_size. 
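# Loop-form sketch (NumPy, single sequence) of what this function computes:
# an n-step return per timestep, bootstrapping from the value aligned with
# step min(t + n - 1, T - 1); the vectorised TF code in this diff pads and
# slices to do the same thing batch-wise.
import numpy as np

def n_step_targets(r, pcont, bootstrap, n):
    T = len(r)
    targets = np.zeros(T)
    for t in range(T):
        acc, disc = 0.0, 1.0
        for k in range(n):
            if t + k >= T:
                break
            acc += disc * r[t + k]
            disc *= pcont[t + k]
        targets[t] = acc + disc * bootstrap[min(t + n - 1, T - 1)]
    return targets

r = np.array([1.0, 0.0, 2.0, 1.0])
pcont = np.full(4, 0.9)
bootstrap = np.full(4, 3.0)
print(n_step_targets(r, pcont, bootstrap, n=2))  # targets[0] == 1 + 0.9*(0 + 0.9*3)
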
- full_bootstrap_steps = [] - truncated_bootstrap_steps = [last_bootstrap_value] * time_size - bootstrap_value = tf.concat(full_bootstrap_steps + truncated_bootstrap_steps, - 0) - - # Iterate backwards for n steps to construct n-step return targets. - targets = bootstrap_value - for i in range(n - 1, -1, -1): - this_pcont_t = _slice(pcont_t, i, time_size) - this_r_t = _slice(r_t, i, time_size) - targets = this_r_t + this_pcont_t * targets - return targets + time_size, batch_size = r_t.shape.as_list() + + # Pad r_t and pcont_t so we can use static slice shapes in scan. + r_t = tf.concat([r_t, tf.zeros((n - 1, batch_size))], 0) + pcont_t = tf.concat([pcont_t, tf.ones((n - 1, batch_size))], 0) + + # We need to use tf.slice with static shapes for TPU compatibility. + def _slice(tensor, index, size): + return tf.slice(tensor, [index, 0], [size, batch_size]) + + # Construct correct bootstrap targets for each time slice t, which are exactly + # the target values at timestep min(t+n-1, time_size-1). + last_bootstrap_value = _slice(bootstrap_value, time_size - 1, 1) + if time_size > n - 1: + full_bootstrap_steps = [_slice(bootstrap_value, n - 1, time_size - (n - 1))] + truncated_bootstrap_steps = [last_bootstrap_value] * (n - 1) + else: + # Only truncated steps, since n > time_size. + full_bootstrap_steps = [] + truncated_bootstrap_steps = [last_bootstrap_value] * time_size + bootstrap_value = tf.concat(full_bootstrap_steps + truncated_bootstrap_steps, 0) + + # Iterate backwards for n steps to construct n-step return targets. + targets = bootstrap_value + for i in range(n - 1, -1, -1): + this_pcont_t = _slice(pcont_t, i, time_size) + this_r_t = _slice(r_t, i, time_size) + targets = this_r_t + this_pcont_t * targets + return targets def _signed_hyperbolic_tx(x: tf.Tensor, eps: float = 1e-3) -> tf.Tensor: - """Signed hyperbolic transform, inverse of signed_parabolic.""" - return tf.sign(x) * (tf.sqrt(abs(x) + 1) - 1) + eps * x + """Signed hyperbolic transform, inverse of signed_parabolic.""" + return tf.sign(x) * (tf.sqrt(abs(x) + 1) - 1) + eps * x def _signed_parabolic_tx(x: tf.Tensor, eps: float = 1e-3) -> tf.Tensor: - """Signed parabolic transform, inverse of signed_hyperbolic.""" - z = tf.sqrt(1 + 4 * eps * (eps + 1 + abs(x))) / 2 / eps - 1 / 2 / eps - return tf.sign(x) * (tf.square(z) - 1) + """Signed parabolic transform, inverse of signed_hyperbolic.""" + z = tf.sqrt(1 + 4 * eps * (eps + 1 + abs(x))) / 2 / eps - 1 / 2 / eps + return tf.sign(x) * (tf.square(z) - 1) diff --git a/acme/tf/networks/__init__.py b/acme/tf/networks/__init__.py index 672421969e..0412af39e5 100644 --- a/acme/tf/networks/__init__.py +++ b/acme/tf/networks/__init__.py @@ -14,23 +14,27 @@ """Useful network definitions.""" -from acme.tf.networks.atari import AtariTorso -from acme.tf.networks.atari import DeepIMPALAAtariNetwork -from acme.tf.networks.atari import DQNAtariNetwork -from acme.tf.networks.atari import IMPALAAtariNetwork -from acme.tf.networks.atari import R2D2AtariNetwork -from acme.tf.networks.base import DistributionalModule -from acme.tf.networks.base import Module -from acme.tf.networks.base import RNNCore -from acme.tf.networks.continuous import LayerNormAndResidualMLP -from acme.tf.networks.continuous import LayerNormMLP -from acme.tf.networks.continuous import NearZeroInitializedLinear +from acme.tf.networks.atari import ( + AtariTorso, + DeepIMPALAAtariNetwork, + DQNAtariNetwork, + IMPALAAtariNetwork, + R2D2AtariNetwork, +) +from acme.tf.networks.base import DistributionalModule, Module, RNNCore +from 
acme.tf.networks.continuous import ( + LayerNormAndResidualMLP, + LayerNormMLP, + NearZeroInitializedLinear, +) from acme.tf.networks.discrete import DiscreteFilteredQNetwork -from acme.tf.networks.distributional import ApproximateMode -from acme.tf.networks.distributional import DiscreteValuedHead -from acme.tf.networks.distributional import MultivariateGaussianMixture -from acme.tf.networks.distributional import MultivariateNormalDiagHead -from acme.tf.networks.distributional import UnivariateGaussianMixture +from acme.tf.networks.distributional import ( + ApproximateMode, + DiscreteValuedHead, + MultivariateGaussianMixture, + MultivariateNormalDiagHead, + UnivariateGaussianMixture, +) from acme.tf.networks.distributions import DiscreteValuedDistribution from acme.tf.networks.duelling import DuellingMLP from acme.tf.networks.masked_epsilon_greedy import NetworkWithMaskedEpsilonGreedy @@ -38,29 +42,29 @@ from acme.tf.networks.multiplexers import CriticMultiplexer from acme.tf.networks.noise import ClippedGaussian from acme.tf.networks.policy_value import PolicyValueHead -from acme.tf.networks.recurrence import CriticDeepRNN -from acme.tf.networks.recurrence import DeepRNN -from acme.tf.networks.recurrence import LSTM -from acme.tf.networks.recurrence import RecurrentExpQWeightedPolicy -from acme.tf.networks.rescaling import ClipToSpec -from acme.tf.networks.rescaling import RescaleToSpec -from acme.tf.networks.rescaling import TanhToSpec -from acme.tf.networks.stochastic import ExpQWeightedPolicy -from acme.tf.networks.stochastic import StochasticMeanHead -from acme.tf.networks.stochastic import StochasticModeHead -from acme.tf.networks.stochastic import StochasticSamplingHead -from acme.tf.networks.vision import DrQTorso -from acme.tf.networks.vision import ResNetTorso +from acme.tf.networks.recurrence import ( + LSTM, + CriticDeepRNN, + DeepRNN, + RecurrentExpQWeightedPolicy, +) +from acme.tf.networks.rescaling import ClipToSpec, RescaleToSpec, TanhToSpec +from acme.tf.networks.stochastic import ( + ExpQWeightedPolicy, + StochasticMeanHead, + StochasticModeHead, + StochasticSamplingHead, +) +from acme.tf.networks.vision import DrQTorso, ResNetTorso # For backwards compatibility. GaussianMixtureHead = UnivariateGaussianMixture try: - # pylint: disable=g-bad-import-order,g-import-not-at-top - from acme.tf.networks.legal_actions import MaskedSequential - from acme.tf.networks.legal_actions import EpsilonGreedy + # pylint: disable=g-bad-import-order,g-import-not-at-top + from acme.tf.networks.legal_actions import EpsilonGreedy, MaskedSequential except ImportError: - pass + pass # Internal imports. 
from acme.tf.networks.quantile import IQNNetwork diff --git a/acme/tf/networks/atari.py b/acme/tf/networks/atari.py index 2b722e4704..adc0c190da 100644 --- a/acme/tf/networks/atari.py +++ b/acme/tf/networks/atari.py @@ -16,17 +16,12 @@ from typing import Optional, Tuple -from acme.tf.networks import base -from acme.tf.networks import duelling -from acme.tf.networks import embedding -from acme.tf.networks import policy_value -from acme.tf.networks import recurrence -from acme.tf.networks import vision -from acme.wrappers import observation_action_reward - import sonnet as snt import tensorflow as tf +from acme.tf.networks import base, duelling, embedding, policy_value, recurrence, vision +from acme.wrappers import observation_action_reward + Images = tf.Tensor QValues = tf.Tensor Logits = tf.Tensor @@ -34,158 +29,158 @@ class AtariTorso(base.Module): - """Simple convolutional stack commonly used for Atari.""" - - def __init__(self): - super().__init__(name='atari_torso') - self._network = snt.Sequential([ - snt.Conv2D(32, [8, 8], [4, 4]), - tf.nn.relu, - snt.Conv2D(64, [4, 4], [2, 2]), - tf.nn.relu, - snt.Conv2D(64, [3, 3], [1, 1]), - tf.nn.relu, - snt.Flatten(), - ]) - - def __call__(self, inputs: Images) -> tf.Tensor: - return self._network(inputs) + """Simple convolutional stack commonly used for Atari.""" + + def __init__(self): + super().__init__(name="atari_torso") + self._network = snt.Sequential( + [ + snt.Conv2D(32, [8, 8], [4, 4]), + tf.nn.relu, + snt.Conv2D(64, [4, 4], [2, 2]), + tf.nn.relu, + snt.Conv2D(64, [3, 3], [1, 1]), + tf.nn.relu, + snt.Flatten(), + ] + ) + + def __call__(self, inputs: Images) -> tf.Tensor: + return self._network(inputs) class DQNAtariNetwork(base.Module): - """A feed-forward network for use with Ape-X DQN. + """A feed-forward network for use with Ape-X DQN. See https://arxiv.org/pdf/1803.00933.pdf for more information. """ - def __init__(self, num_actions: int): - super().__init__(name='dqn_atari_network') - self._network = snt.Sequential([ - AtariTorso(), - duelling.DuellingMLP(num_actions, hidden_sizes=[512]), - ]) + def __init__(self, num_actions: int): + super().__init__(name="dqn_atari_network") + self._network = snt.Sequential( + [AtariTorso(), duelling.DuellingMLP(num_actions, hidden_sizes=[512]),] + ) - def __call__(self, inputs: Images) -> QValues: - return self._network(inputs) + def __call__(self, inputs: Images) -> QValues: + return self._network(inputs) class R2D2AtariNetwork(base.RNNCore): - """A recurrent network for use with R2D2. + """A recurrent network for use with R2D2. See https://openreview.net/forum?id=r1lyTjAqYX for more information. 
""" - def __init__(self, num_actions: int, core: Optional[base.RNNCore] = None): - super().__init__(name='r2d2_atari_network') - self._embed = embedding.OAREmbedding( - torso=AtariTorso(), num_actions=num_actions) - self._core = core if core is not None else recurrence.LSTM(512) - self._head = duelling.DuellingMLP(num_actions, hidden_sizes=[512]) + def __init__(self, num_actions: int, core: Optional[base.RNNCore] = None): + super().__init__(name="r2d2_atari_network") + self._embed = embedding.OAREmbedding( + torso=AtariTorso(), num_actions=num_actions + ) + self._core = core if core is not None else recurrence.LSTM(512) + self._head = duelling.DuellingMLP(num_actions, hidden_sizes=[512]) - def __call__( - self, - inputs: observation_action_reward.OAR, - state: base.State, - ) -> Tuple[QValues, base.State]: + def __call__( + self, inputs: observation_action_reward.OAR, state: base.State, + ) -> Tuple[QValues, base.State]: - embeddings = self._embed(inputs) - embeddings, new_state = self._core(embeddings, state) - action_values = self._head(embeddings) # [B, A] + embeddings = self._embed(inputs) + embeddings, new_state = self._core(embeddings, state) + action_values = self._head(embeddings) # [B, A] - return action_values, new_state + return action_values, new_state - # TODO(b/171287329): Figure out why return type annotation causes error. - def initial_state(self, batch_size: int, **unused_kwargs) -> base.State: # pytype: disable=invalid-annotation - return self._core.initial_state(batch_size) + # TODO(b/171287329): Figure out why return type annotation causes error. + def initial_state( + self, batch_size: int, **unused_kwargs + ) -> base.State: # pytype: disable=invalid-annotation + return self._core.initial_state(batch_size) - def unroll( - self, - inputs: observation_action_reward.OAR, - state: base.State, - sequence_length: int, - ) -> Tuple[QValues, base.State]: - """Efficient unroll that applies embeddings, MLP, & convnet in one pass.""" - embeddings = snt.BatchApply(self._embed)(inputs) # [T, B, D+A+1] - embeddings, new_state = self._core.unroll(embeddings, state, - sequence_length) - action_values = snt.BatchApply(self._head)(embeddings) + def unroll( + self, + inputs: observation_action_reward.OAR, + state: base.State, + sequence_length: int, + ) -> Tuple[QValues, base.State]: + """Efficient unroll that applies embeddings, MLP, & convnet in one pass.""" + embeddings = snt.BatchApply(self._embed)(inputs) # [T, B, D+A+1] + embeddings, new_state = self._core.unroll(embeddings, state, sequence_length) + action_values = snt.BatchApply(self._head)(embeddings) - return action_values, new_state + return action_values, new_state class IMPALAAtariNetwork(snt.RNNCore): - """A recurrent network for use with IMPALA. + """A recurrent network for use with IMPALA. See https://arxiv.org/pdf/1802.01561.pdf for more information. 
""" - def __init__(self, num_actions: int): - super().__init__(name='impala_atari_network') - self._embed = embedding.OAREmbedding( - torso=AtariTorso(), num_actions=num_actions) - self._core = snt.LSTM(256) - self._head = snt.Sequential([ - snt.Linear(256), - tf.nn.relu, - policy_value.PolicyValueHead(num_actions), - ]) - self._num_actions = num_actions + def __init__(self, num_actions: int): + super().__init__(name="impala_atari_network") + self._embed = embedding.OAREmbedding( + torso=AtariTorso(), num_actions=num_actions + ) + self._core = snt.LSTM(256) + self._head = snt.Sequential( + [snt.Linear(256), tf.nn.relu, policy_value.PolicyValueHead(num_actions),] + ) + self._num_actions = num_actions - def __call__( - self, inputs: observation_action_reward.OAR, - state: snt.LSTMState) -> Tuple[Tuple[Logits, Value], snt.LSTMState]: + def __call__( + self, inputs: observation_action_reward.OAR, state: snt.LSTMState + ) -> Tuple[Tuple[Logits, Value], snt.LSTMState]: - embeddings = self._embed(inputs) - embeddings, new_state = self._core(embeddings, state) - logits, value = self._head(embeddings) # [B, A] + embeddings = self._embed(inputs) + embeddings, new_state = self._core(embeddings, state) + logits, value = self._head(embeddings) # [B, A] - return (logits, value), new_state + return (logits, value), new_state - def initial_state(self, batch_size: int, **unused_kwargs) -> snt.LSTMState: - return self._core.initial_state(batch_size) + def initial_state(self, batch_size: int, **unused_kwargs) -> snt.LSTMState: + return self._core.initial_state(batch_size) class DeepIMPALAAtariNetwork(base.RNNCore): - """A recurrent network for use with IMPALA. + """A recurrent network for use with IMPALA. See https://arxiv.org/pdf/1802.01561.pdf for more information. 
""" - def __init__(self, num_actions: int): - super().__init__(name='deep_impala_atari_network') - self._embed = embedding.OAREmbedding( - torso=vision.ResNetTorso(), num_actions=num_actions) - self._core = snt.LSTM(256) - self._head = snt.Sequential([ - snt.Linear(256), - tf.nn.relu, - policy_value.PolicyValueHead(num_actions), - ]) - self._num_actions = num_actions - - def __call__( - self, inputs: observation_action_reward.OAR, - state: snt.LSTMState) -> Tuple[Tuple[Logits, Value], snt.LSTMState]: - - embeddings = self._embed(inputs) - embeddings, new_state = self._core(embeddings, state) - logits, value = self._head(embeddings) # [B, A] - - return (logits, value), new_state - - def initial_state(self, batch_size: int, **unused_kwargs) -> snt.LSTMState: - return self._core.initial_state(batch_size) - - def unroll( - self, - inputs: observation_action_reward.OAR, - states: snt.LSTMState, - sequence_length: int, - ) -> Tuple[Tuple[Logits, Value], snt.LSTMState]: - """Efficient unroll that applies embeddings, MLP, & convnet in one pass.""" - embeddings = snt.BatchApply(self._embed)(inputs) # [T, B, D+A+1] - embeddings, new_states = snt.static_unroll(self._core, embeddings, states, - sequence_length) - logits, values = snt.BatchApply(self._head)(embeddings) - - return (logits, values), new_states + def __init__(self, num_actions: int): + super().__init__(name="deep_impala_atari_network") + self._embed = embedding.OAREmbedding( + torso=vision.ResNetTorso(), num_actions=num_actions + ) + self._core = snt.LSTM(256) + self._head = snt.Sequential( + [snt.Linear(256), tf.nn.relu, policy_value.PolicyValueHead(num_actions),] + ) + self._num_actions = num_actions + + def __call__( + self, inputs: observation_action_reward.OAR, state: snt.LSTMState + ) -> Tuple[Tuple[Logits, Value], snt.LSTMState]: + + embeddings = self._embed(inputs) + embeddings, new_state = self._core(embeddings, state) + logits, value = self._head(embeddings) # [B, A] + + return (logits, value), new_state + + def initial_state(self, batch_size: int, **unused_kwargs) -> snt.LSTMState: + return self._core.initial_state(batch_size) + + def unroll( + self, + inputs: observation_action_reward.OAR, + states: snt.LSTMState, + sequence_length: int, + ) -> Tuple[Tuple[Logits, Value], snt.LSTMState]: + """Efficient unroll that applies embeddings, MLP, & convnet in one pass.""" + embeddings = snt.BatchApply(self._embed)(inputs) # [T, B, D+A+1] + embeddings, new_states = snt.static_unroll( + self._core, embeddings, states, sequence_length + ) + logits, values = snt.BatchApply(self._head)(embeddings) + + return (logits, values), new_states diff --git a/acme/tf/networks/base.py b/acme/tf/networks/base.py index ca6c01905a..40f48cf832 100644 --- a/acme/tf/networks/base.py +++ b/acme/tf/networks/base.py @@ -17,39 +17,38 @@ import abc from typing import Tuple, TypeVar -from acme import types import sonnet as snt import tensorflow_probability as tfp -State = TypeVar('State') +from acme import types + +State = TypeVar("State") class Module(snt.Module, abc.ABC): - """A base class for module with abstract __call__ method.""" + """A base class for module with abstract __call__ method.""" - @abc.abstractmethod - def __call__(self, *args, **kwargs) -> types.NestedTensor: - """Forward pass of the module.""" + @abc.abstractmethod + def __call__(self, *args, **kwargs) -> types.NestedTensor: + """Forward pass of the module.""" class DistributionalModule(snt.Module, abc.ABC): - """A base class for modules that output distributions.""" + """A base class for 
modules that output distributions.""" - @abc.abstractmethod - def __call__(self, *args, **kwargs) -> tfp.distributions.Distribution: - """Forward pass of the module.""" + @abc.abstractmethod + def __call__(self, *args, **kwargs) -> tfp.distributions.Distribution: + """Forward pass of the module.""" class RNNCore(snt.RNNCore, abc.ABC): - """An RNN core with a custom `unroll` function.""" - - @abc.abstractmethod - def unroll(self, - inputs: types.NestedTensor, - state: State, - sequence_length: int, - ) -> Tuple[types.NestedTensor, State]: - """A custom function for doing static unrolls over sequences. + """An RNN core with a custom `unroll` function.""" + + @abc.abstractmethod + def unroll( + self, inputs: types.NestedTensor, state: State, sequence_length: int, + ) -> Tuple[types.NestedTensor, State]: + """A custom function for doing static unrolls over sequences. This has the same API as `snt.static_unroll`, but allows the user to specify their own implementation to take advantage of the structure of the network diff --git a/acme/tf/networks/continuous.py b/acme/tf/networks/continuous.py index e003ac8d68..38c2bce501 100644 --- a/acme/tf/networks/continuous.py +++ b/acme/tf/networks/continuous.py @@ -16,38 +16,42 @@ from typing import Callable, Optional, Sequence +import sonnet as snt +import tensorflow as tf + from acme import types from acme.tf import utils as tf2_utils from acme.tf.networks import base -import sonnet as snt -import tensorflow as tf def _uniform_initializer(): - return tf.initializers.VarianceScaling( - distribution='uniform', mode='fan_out', scale=0.333) + return tf.initializers.VarianceScaling( + distribution="uniform", mode="fan_out", scale=0.333 + ) class NearZeroInitializedLinear(snt.Linear): - """Simple linear layer, initialized at near zero weights and zero biases.""" + """Simple linear layer, initialized at near zero weights and zero biases.""" - def __init__(self, output_size: int, scale: float = 1e-4): - super().__init__(output_size, w_init=tf.initializers.VarianceScaling(scale)) + def __init__(self, output_size: int, scale: float = 1e-4): + super().__init__(output_size, w_init=tf.initializers.VarianceScaling(scale)) class LayerNormMLP(snt.Module): - """Simple feedforward MLP torso with initial layer-norm. + """Simple feedforward MLP torso with initial layer-norm. This module is an MLP which uses LayerNorm (with a tanh normalizer) on the first layer and non-linearities (elu) on all but the last remaining layers. """ - def __init__(self, - layer_sizes: Sequence[int], - w_init: Optional[snt.initializers.Initializer] = None, - activation: Callable[[tf.Tensor], tf.Tensor] = tf.nn.elu, - activate_final: bool = False): - """Construct the MLP. + def __init__( + self, + layer_sizes: Sequence[int], + w_init: Optional[snt.initializers.Initializer] = None, + activation: Callable[[tf.Tensor], tf.Tensor] = tf.nn.elu, + activate_final: bool = False, + ): + """Construct the MLP. Args: layer_sizes: a sequence of ints specifying the size of each layer. @@ -57,82 +61,86 @@ def __init__(self, activate_final: whether or not to use the activation function on the final layer of the neural network. 
""" - super().__init__(name='feedforward_mlp_torso') - - self._network = snt.Sequential([ - snt.Linear(layer_sizes[0], w_init=w_init or _uniform_initializer()), - snt.LayerNorm( - axis=slice(1, None), create_scale=True, create_offset=True), - tf.nn.tanh, - snt.nets.MLP( - layer_sizes[1:], - w_init=w_init or _uniform_initializer(), - activation=activation, - activate_final=activate_final), - ]) - - def __call__(self, observations: types.Nest) -> tf.Tensor: - """Forwards the policy network.""" - return self._network(tf2_utils.batch_concat(observations)) + super().__init__(name="feedforward_mlp_torso") + + self._network = snt.Sequential( + [ + snt.Linear(layer_sizes[0], w_init=w_init or _uniform_initializer()), + snt.LayerNorm( + axis=slice(1, None), create_scale=True, create_offset=True + ), + tf.nn.tanh, + snt.nets.MLP( + layer_sizes[1:], + w_init=w_init or _uniform_initializer(), + activation=activation, + activate_final=activate_final, + ), + ] + ) + + def __call__(self, observations: types.Nest) -> tf.Tensor: + """Forwards the policy network.""" + return self._network(tf2_utils.batch_concat(observations)) class ResidualLayernormWrapper(snt.Module): - """Wrapper that applies residual connections and layer norm.""" + """Wrapper that applies residual connections and layer norm.""" - def __init__(self, layer: base.Module): - """Creates the Wrapper Class. + def __init__(self, layer: base.Module): + """Creates the Wrapper Class. Args: layer: module to wrap. """ - super().__init__(name='ResidualLayernormWrapper') - self._layer = layer + super().__init__(name="ResidualLayernormWrapper") + self._layer = layer - self._layer_norm = snt.LayerNorm( - axis=-1, create_scale=True, create_offset=True) + self._layer_norm = snt.LayerNorm(axis=-1, create_scale=True, create_offset=True) - def __call__(self, inputs: tf.Tensor): - """Returns the result of the residual and layernorm computation. + def __call__(self, inputs: tf.Tensor): + """Returns the result of the residual and layernorm computation. Args: inputs: inputs to the main module. """ - # Apply main module. - outputs = self._layer(inputs) - outputs = self._layer_norm(outputs + inputs) + # Apply main module. + outputs = self._layer(inputs) + outputs = self._layer_norm(outputs + inputs) - return outputs + return outputs class LayerNormAndResidualMLP(snt.Module): - """MLP with residual connections and layer norm. + """MLP with residual connections and layer norm. An MLP which applies residual connection and layer normalisation every two linear layers. Similar to Resnet, but with FC layers instead of convolutions. """ - def __init__(self, hidden_size: int, num_blocks: int): - """Create the model. + def __init__(self, hidden_size: int, num_blocks: int): + """Create the model. Args: hidden_size: width of each hidden layer. num_blocks: number of blocks, each block being MLP([hidden_size, hidden_size]) + layer norm + residual connection. """ - super().__init__(name='LayerNormAndResidualMLP') + super().__init__(name="LayerNormAndResidualMLP") - # Create initial MLP layer. - layers = [snt.nets.MLP([hidden_size], w_init=_uniform_initializer())] + # Create initial MLP layer. + layers = [snt.nets.MLP([hidden_size], w_init=_uniform_initializer())] - # Follow it up with num_blocks MLPs with layernorm and residual connections. - for _ in range(num_blocks): - mlp = snt.nets.MLP([hidden_size, hidden_size], - w_init=_uniform_initializer()) - layers.append(ResidualLayernormWrapper(mlp)) + # Follow it up with num_blocks MLPs with layernorm and residual connections. 
+ for _ in range(num_blocks): + mlp = snt.nets.MLP( + [hidden_size, hidden_size], w_init=_uniform_initializer() + ) + layers.append(ResidualLayernormWrapper(mlp)) - self._module = snt.Sequential(layers) + self._module = snt.Sequential(layers) - def __call__(self, inputs: tf.Tensor): - return self._module(inputs) + def __call__(self, inputs: tf.Tensor): + return self._module(inputs) diff --git a/acme/tf/networks/discrete.py b/acme/tf/networks/discrete.py index b4134c35fa..529982a817 100644 --- a/acme/tf/networks/discrete.py +++ b/acme/tf/networks/discrete.py @@ -19,27 +19,24 @@ class DiscreteFilteredQNetwork(snt.Module): - """Discrete filtered Q-network. + """Discrete filtered Q-network. This produces filtered Q values according to the method used in the discrete BCQ algorithm (https://arxiv.org/pdf/1910.01708.pdf - section 4). """ - def __init__(self, - g_network: snt.Module, - q_network: snt.Module, - threshold: float): - super().__init__(name='discrete_filtered_qnet') - assert threshold >= 0 and threshold <= 1 - self.g_network = g_network - self.q_network = q_network - self._threshold = threshold - - def __call__(self, o_t: tf.Tensor) -> tf.Tensor: - q_t = self.q_network(o_t) - g_t = tf.nn.softmax(self.g_network(o_t)) - normalized_g_t = g_t / tf.reduce_max(g_t, axis=-1, keepdims=True) - - # Filter actions based on g_network outputs. - min_q = tf.reduce_min(q_t, axis=-1, keepdims=True) - return tf.where(normalized_g_t >= self._threshold, q_t, min_q) + def __init__(self, g_network: snt.Module, q_network: snt.Module, threshold: float): + super().__init__(name="discrete_filtered_qnet") + assert threshold >= 0 and threshold <= 1 + self.g_network = g_network + self.q_network = q_network + self._threshold = threshold + + def __call__(self, o_t: tf.Tensor) -> tf.Tensor: + q_t = self.q_network(o_t) + g_t = tf.nn.softmax(self.g_network(o_t)) + normalized_g_t = g_t / tf.reduce_max(g_t, axis=-1, keepdims=True) + + # Filter actions based on g_network outputs. + min_q = tf.reduce_min(q_t, axis=-1, keepdims=True) + return tf.where(normalized_g_t >= self._threshold, q_t, min_q) diff --git a/acme/tf/networks/distributional.py b/acme/tf/networks/distributional.py index e96dc29f29..d7b9ec0f89 100644 --- a/acme/tf/networks/distributional.py +++ b/acme/tf/networks/distributional.py @@ -20,12 +20,14 @@ import types from typing import Optional, Union -from absl import logging -from acme.tf.networks import distributions as ad + import numpy as np import sonnet as snt import tensorflow as tf import tensorflow_probability as tfp +from absl import logging + +from acme.tf.networks import distributions as ad tfd = tfp.distributions snt_init = snt.initializers @@ -34,19 +36,21 @@ class DiscreteValuedHead(snt.Module): - """Represents a parameterized discrete valued distribution. + """Represents a parameterized discrete valued distribution. The returned distribution is essentially a `tfd.Categorical`, but one which knows its support and so can compute the mean value. """ - def __init__(self, - vmin: Union[float, np.ndarray, tf.Tensor], - vmax: Union[float, np.ndarray, tf.Tensor], - num_atoms: int, - w_init: Optional[snt.initializers.Initializer] = None, - b_init: Optional[snt.initializers.Initializer] = None): - """Initialization. + def __init__( + self, + vmin: Union[float, np.ndarray, tf.Tensor], + vmax: Union[float, np.ndarray, tf.Tensor], + num_atoms: int, + w_init: Optional[snt.initializers.Initializer] = None, + b_init: Optional[snt.initializers.Initializer] = None, + ): + """Initialization. 
If vmin and vmax have shape S, this will store the category values as a Tensor of shape (S*, num_atoms). @@ -58,39 +62,42 @@ def __init__(self, w_init: Initialization for linear layer weights. b_init: Initialization for linear layer biases. """ - super().__init__(name='DiscreteValuedHead') - vmin = tf.convert_to_tensor(vmin) - vmax = tf.convert_to_tensor(vmax) - self._values = tf.linspace(vmin, vmax, num_atoms, axis=-1) - self._distributional_layer = snt.Linear(tf.size(self._values), - w_init=w_init, - b_init=b_init) - - def __call__(self, inputs: tf.Tensor) -> tfd.Distribution: - logits = self._distributional_layer(inputs) - logits = tf.reshape(logits, - tf.concat([tf.shape(logits)[:1], # batch size - tf.shape(self._values)], - axis=0)) - values = tf.cast(self._values, logits.dtype) - - return ad.DiscreteValuedDistribution(values=values, logits=logits) + super().__init__(name="DiscreteValuedHead") + vmin = tf.convert_to_tensor(vmin) + vmax = tf.convert_to_tensor(vmax) + self._values = tf.linspace(vmin, vmax, num_atoms, axis=-1) + self._distributional_layer = snt.Linear( + tf.size(self._values), w_init=w_init, b_init=b_init + ) + + def __call__(self, inputs: tf.Tensor) -> tfd.Distribution: + logits = self._distributional_layer(inputs) + logits = tf.reshape( + logits, + tf.concat( + [tf.shape(logits)[:1], tf.shape(self._values)], axis=0 # batch size + ), + ) + values = tf.cast(self._values, logits.dtype) + + return ad.DiscreteValuedDistribution(values=values, logits=logits) class MultivariateNormalDiagHead(snt.Module): - """Module that produces a multivariate normal distribution using tfd.Independent or tfd.MultivariateNormalDiag.""" - - def __init__( - self, - num_dimensions: int, - init_scale: float = 0.3, - min_scale: float = 1e-6, - tanh_mean: bool = False, - fixed_scale: bool = False, - use_tfd_independent: bool = False, - w_init: snt_init.Initializer = tf.initializers.VarianceScaling(1e-4), - b_init: snt_init.Initializer = tf.initializers.Zeros()): - """Initialization. + """Module that produces a multivariate normal distribution using tfd.Independent or tfd.MultivariateNormalDiag.""" + + def __init__( + self, + num_dimensions: int, + init_scale: float = 0.3, + min_scale: float = 1e-6, + tanh_mean: bool = False, + fixed_scale: bool = False, + use_tfd_independent: bool = False, + w_init: snt_init.Initializer = tf.initializers.VarianceScaling(1e-4), + b_init: snt_init.Initializer = tf.initializers.Zeros(), + ): + """Initialization. Args: num_dimensions: Number of dimensions of MVN distribution. @@ -104,51 +111,52 @@ def __init__( w_init: Initialization for linear layer weights. b_init: Initialization for linear layer biases. 
""" - super().__init__(name='MultivariateNormalDiagHead') - self._init_scale = init_scale - self._min_scale = min_scale - self._tanh_mean = tanh_mean - self._mean_layer = snt.Linear(num_dimensions, w_init=w_init, b_init=b_init) - self._fixed_scale = fixed_scale + super().__init__(name="MultivariateNormalDiagHead") + self._init_scale = init_scale + self._min_scale = min_scale + self._tanh_mean = tanh_mean + self._mean_layer = snt.Linear(num_dimensions, w_init=w_init, b_init=b_init) + self._fixed_scale = fixed_scale - if not fixed_scale: - self._scale_layer = snt.Linear( - num_dimensions, w_init=w_init, b_init=b_init) - self._use_tfd_independent = use_tfd_independent + if not fixed_scale: + self._scale_layer = snt.Linear(num_dimensions, w_init=w_init, b_init=b_init) + self._use_tfd_independent = use_tfd_independent - def __call__(self, inputs: tf.Tensor) -> tfd.Distribution: - zero = tf.constant(0, dtype=inputs.dtype) - mean = self._mean_layer(inputs) + def __call__(self, inputs: tf.Tensor) -> tfd.Distribution: + zero = tf.constant(0, dtype=inputs.dtype) + mean = self._mean_layer(inputs) - if self._fixed_scale: - scale = tf.ones_like(mean) * self._init_scale - else: - scale = tf.nn.softplus(self._scale_layer(inputs)) - scale *= self._init_scale / tf.nn.softplus(zero) - scale += self._min_scale + if self._fixed_scale: + scale = tf.ones_like(mean) * self._init_scale + else: + scale = tf.nn.softplus(self._scale_layer(inputs)) + scale *= self._init_scale / tf.nn.softplus(zero) + scale += self._min_scale - # Maybe transform the mean. - if self._tanh_mean: - mean = tf.tanh(mean) + # Maybe transform the mean. + if self._tanh_mean: + mean = tf.tanh(mean) - if self._use_tfd_independent: - dist = tfd.Independent(tfd.Normal(loc=mean, scale=scale)) - else: - dist = tfd.MultivariateNormalDiag(loc=mean, scale_diag=scale) + if self._use_tfd_independent: + dist = tfd.Independent(tfd.Normal(loc=mean, scale=scale)) + else: + dist = tfd.MultivariateNormalDiag(loc=mean, scale_diag=scale) - return dist + return dist class GaussianMixture(snt.Module): - """Module that outputs a Gaussian Mixture Distribution.""" - - def __init__(self, - num_dimensions: int, - num_components: int, - multivariate: bool, - init_scale: Optional[float] = None, - name: str = 'GaussianMixture'): - """Initialization. + """Module that outputs a Gaussian Mixture Distribution.""" + + def __init__( + self, + num_dimensions: int, + num_components: int, + multivariate: bool, + init_scale: Optional[float] = None, + name: str = "GaussianMixture", + ): + """Initialization. Args: num_dimensions: dimensionality of the output distribution @@ -157,38 +165,40 @@ def __init__(self, init_scale: the initial scale for the Gaussian mixture components. name: name of the module passed to snt.Module parent class. """ - super().__init__(name=name) - - self._num_dimensions = num_dimensions - self._num_components = num_components - self._multivariate = multivariate - - if init_scale is not None: - self._scale_factor = init_scale / tf.nn.softplus(0.) - else: - self._scale_factor = 1.0 # Corresponds to init_scale = softplus(0). - - # Define the weight initializer. - w_init = tf.initializers.VarianceScaling(1e-5) - - # Create a layer that outputs the unnormalized log-weights. 
- if self._multivariate: - logits_size = self._num_components - else: - logits_size = self._num_dimensions * self._num_components - self._logit_layer = snt.Linear(logits_size, w_init=w_init) - - # Create two layers that outputs a location and a scale, respectively, for - # each dimension and each component. - self._loc_layer = snt.Linear( - self._num_dimensions * self._num_components, w_init=w_init) - self._scale_layer = snt.Linear( - self._num_dimensions * self._num_components, w_init=w_init) - - def __call__(self, - inputs: tf.Tensor, - low_noise_policy: bool = False) -> tfd.Distribution: - """Run the networks through inputs. + super().__init__(name=name) + + self._num_dimensions = num_dimensions + self._num_components = num_components + self._multivariate = multivariate + + if init_scale is not None: + self._scale_factor = init_scale / tf.nn.softplus(0.0) + else: + self._scale_factor = 1.0 # Corresponds to init_scale = softplus(0). + + # Define the weight initializer. + w_init = tf.initializers.VarianceScaling(1e-5) + + # Create a layer that outputs the unnormalized log-weights. + if self._multivariate: + logits_size = self._num_components + else: + logits_size = self._num_dimensions * self._num_components + self._logit_layer = snt.Linear(logits_size, w_init=w_init) + + # Create two layers that outputs a location and a scale, respectively, for + # each dimension and each component. + self._loc_layer = snt.Linear( + self._num_dimensions * self._num_components, w_init=w_init + ) + self._scale_layer = snt.Linear( + self._num_dimensions * self._num_components, w_init=w_init + ) + + def __call__( + self, inputs: tf.Tensor, low_noise_policy: bool = False + ) -> tfd.Distribution: + """Run the networks through inputs. Args: inputs: hidden activations of the policy network body. @@ -200,54 +210,58 @@ def __call__(self, Mixture Gaussian distribution. """ - # Compute logits, locs, and scales if necessary. - logits = self._logit_layer(inputs) - locs = self._loc_layer(inputs) - - # When a low_noise_policy is requested, set the scales to its minimum value. - if low_noise_policy: - scales = tf.fill(locs.shape, _MIN_SCALE) - else: - scales = self._scale_layer(inputs) - scales = self._scale_factor * tf.nn.softplus(scales) + _MIN_SCALE - - if self._multivariate: - shape = [-1, self._num_components, self._num_dimensions] - # Reshape the mixture's location and scale parameters appropriately. - locs = tf.reshape(locs, shape) - scales = tf.reshape(scales, shape) - # In this case, no need to reshape logits as they are in the correct shape - # already, namely [batch_size, num_components]. - components_distribution = tfd.MultivariateNormalDiag( - loc=locs, scale_diag=scales) - else: - shape = [-1, self._num_dimensions, self._num_components] - # Reshape the mixture's location and scale parameters appropriately. - locs = tf.reshape(locs, shape) - scales = tf.reshape(scales, shape) - components_distribution = tfd.Normal(loc=locs, scale=scales) - logits = tf.reshape(logits, shape) - - # Create the mixture distribution. - distribution = tfd.MixtureSameFamily( - mixture_distribution=tfd.Categorical(logits=logits), - components_distribution=components_distribution) - - if not self._multivariate: - distribution = tfd.Independent(distribution) - - return distribution + # Compute logits, locs, and scales if necessary. + logits = self._logit_layer(inputs) + locs = self._loc_layer(inputs) + + # When a low_noise_policy is requested, set the scales to its minimum value. 
+ if low_noise_policy: + scales = tf.fill(locs.shape, _MIN_SCALE) + else: + scales = self._scale_layer(inputs) + scales = self._scale_factor * tf.nn.softplus(scales) + _MIN_SCALE + + if self._multivariate: + shape = [-1, self._num_components, self._num_dimensions] + # Reshape the mixture's location and scale parameters appropriately. + locs = tf.reshape(locs, shape) + scales = tf.reshape(scales, shape) + # In this case, no need to reshape logits as they are in the correct shape + # already, namely [batch_size, num_components]. + components_distribution = tfd.MultivariateNormalDiag( + loc=locs, scale_diag=scales + ) + else: + shape = [-1, self._num_dimensions, self._num_components] + # Reshape the mixture's location and scale parameters appropriately. + locs = tf.reshape(locs, shape) + scales = tf.reshape(scales, shape) + components_distribution = tfd.Normal(loc=locs, scale=scales) + logits = tf.reshape(logits, shape) + + # Create the mixture distribution. + distribution = tfd.MixtureSameFamily( + mixture_distribution=tfd.Categorical(logits=logits), + components_distribution=components_distribution, + ) + + if not self._multivariate: + distribution = tfd.Independent(distribution) + + return distribution class UnivariateGaussianMixture(GaussianMixture): - """Head which outputs a Mixture of Gaussians Distribution.""" + """Head which outputs a Mixture of Gaussians Distribution.""" - def __init__(self, - num_dimensions: int, - num_components: int = 5, - init_scale: Optional[float] = None, - num_mixtures: Optional[int] = None): - """Create an mixture of Gaussian actor head. + def __init__( + self, + num_dimensions: int, + num_components: int = 5, + init_scale: Optional[float] = None, + num_mixtures: Optional[int] = None, + ): + """Create an mixture of Gaussian actor head. Args: num_dimensions: dimensionality of the output distribution. Each dimension @@ -256,26 +270,32 @@ def __init__(self, init_scale: the initial scale for the Gaussian mixture components. num_mixtures: deprecated argument which overwrites num_components. """ - if num_mixtures is not None: - logging.warning("""the num_mixtures parameter has been deprecated; use + if num_mixtures is not None: + logging.warning( + """the num_mixtures parameter has been deprecated; use num_components instead; the value of num_components is being - ignored""") - num_components = num_mixtures - super().__init__(num_dimensions=num_dimensions, - num_components=num_components, - multivariate=False, - init_scale=init_scale, - name='UnivariateGaussianMixture') + ignored""" + ) + num_components = num_mixtures + super().__init__( + num_dimensions=num_dimensions, + num_components=num_components, + multivariate=False, + init_scale=init_scale, + name="UnivariateGaussianMixture", + ) class MultivariateGaussianMixture(GaussianMixture): - """Head which outputs a mixture of multivariate Gaussians distribution.""" + """Head which outputs a mixture of multivariate Gaussians distribution.""" - def __init__(self, - num_dimensions: int, - num_components: int = 5, - init_scale: Optional[float] = None): - """Initialization. + def __init__( + self, + num_dimensions: int, + num_components: int = 5, + init_scale: Optional[float] = None, + ): + """Initialization. Args: num_dimensions: dimensionality of the output distribution @@ -283,15 +303,17 @@ def __init__(self, num_components: number of mixture components. init_scale: the initial scale for the Gaussian mixture components. 
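The GaussianMixture head and its UnivariateGaussianMixture subclass are likewise only reflowed. A sketch of the univariate case with made-up dimensions; low_noise_policy collapses the component scales to the module's minimum scale, giving a near-deterministic policy:

```
import tensorflow as tf

from acme.tf.networks.distributional import UnivariateGaussianMixture

head = UnivariateGaussianMixture(num_dimensions=6, num_components=5)
hidden = tf.random.normal([8, 256])          # illustrative torso output
dist = head(hidden)                          # mixture over each action dimension
stochastic_action = dist.sample()            # [8, 6]
greedy_action = head(hidden, low_noise_policy=True).sample()  # near-deterministic
```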
""" - super().__init__(num_dimensions=num_dimensions, - num_components=num_components, - multivariate=True, - init_scale=init_scale, - name='MultivariateGaussianMixture') + super().__init__( + num_dimensions=num_dimensions, + num_components=num_components, + multivariate=True, + init_scale=init_scale, + name="MultivariateGaussianMixture", + ) class ApproximateMode(snt.Module): - """Override the mode function of the distribution. + """Override the mode function of the distribution. For non-constant Jacobian transformed distributions, the mode is non-trivial to compute, so for these distributions the mode function is not supported in @@ -302,13 +324,15 @@ class ApproximateMode(snt.Module): constant Jacobian), this is a no-op. """ - def __call__(self, inputs: tfd.Distribution) -> tfd.Distribution: - if isinstance(inputs, tfd.TransformedDistribution): - if not inputs.bijector.is_constant_jacobian: - def _mode(self, **kwargs): - distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs) - x = self.distribution.mode(**distribution_kwargs) - y = self.bijector.forward(x, **bijector_kwargs) - return y - inputs._mode = types.MethodType(_mode, inputs) - return inputs + def __call__(self, inputs: tfd.Distribution) -> tfd.Distribution: + if isinstance(inputs, tfd.TransformedDistribution): + if not inputs.bijector.is_constant_jacobian: + + def _mode(self, **kwargs): + distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs) + x = self.distribution.mode(**distribution_kwargs) + y = self.bijector.forward(x, **bijector_kwargs) + return y + + inputs._mode = types.MethodType(_mode, inputs) + return inputs diff --git a/acme/tf/networks/distributional_test.py b/acme/tf/networks/distributional_test.py index 448dc2debf..c92484748d 100644 --- a/acme/tf/networks/distributional_test.py +++ b/acme/tf/networks/distributional_test.py @@ -14,55 +14,50 @@ """Tests for acme.tf.networks.distributional.""" -from acme.tf.networks import distributional import numpy as np +from absl.testing import absltest, parameterized from numpy import testing as npt -from absl.testing import absltest -from absl.testing import parameterized +from acme.tf.networks import distributional class DistributionalTest(parameterized.TestCase): - - @parameterized.parameters( - ((2, 3), (), (), 5, (2, 5)), - ((2, 3), (4, 1), (1, 5), 6, (2, 4, 5, 6)), - ) - def test_discrete_valued_head( - self, - input_shape, - vmin_shape, - vmax_shape, - num_atoms, - expected_logits_shape): - - vmin = np.zeros(vmin_shape, float) - vmax = np.ones(vmax_shape, float) - head = distributional.DiscreteValuedHead( - vmin=vmin, - vmax=vmax, - num_atoms=num_atoms) - input_array = np.zeros(input_shape, dtype=float) - output_distribution = head(input_array) - self.assertEqual(output_distribution.logits_parameter().shape, - expected_logits_shape) - - values = output_distribution._values - - # Can't do assert_allclose(values[..., 0], vmin), because the args may - # have broadcast-compatible but unequal shapes. Do the following instead: - npt.assert_allclose(values[..., 0] - vmin, np.zeros_like(values[..., 0])) - npt.assert_allclose(values[..., -1] - vmax, np.zeros_like(values[..., -1])) - - # Check that values are monotonically increasing. - intervals = values[..., 1:] - values[..., :-1] - npt.assert_array_less(np.zeros_like(intervals), intervals) - - # Check that the values are equally spaced. 
- npt.assert_allclose(intervals[..., 1:] - intervals[..., :1], - np.zeros_like(intervals[..., 1:]), - atol=1e-7) - - -if __name__ == '__main__': - absltest.main() + @parameterized.parameters( + ((2, 3), (), (), 5, (2, 5)), ((2, 3), (4, 1), (1, 5), 6, (2, 4, 5, 6)), + ) + def test_discrete_valued_head( + self, input_shape, vmin_shape, vmax_shape, num_atoms, expected_logits_shape + ): + + vmin = np.zeros(vmin_shape, float) + vmax = np.ones(vmax_shape, float) + head = distributional.DiscreteValuedHead( + vmin=vmin, vmax=vmax, num_atoms=num_atoms + ) + input_array = np.zeros(input_shape, dtype=float) + output_distribution = head(input_array) + self.assertEqual( + output_distribution.logits_parameter().shape, expected_logits_shape + ) + + values = output_distribution._values + + # Can't do assert_allclose(values[..., 0], vmin), because the args may + # have broadcast-compatible but unequal shapes. Do the following instead: + npt.assert_allclose(values[..., 0] - vmin, np.zeros_like(values[..., 0])) + npt.assert_allclose(values[..., -1] - vmax, np.zeros_like(values[..., -1])) + + # Check that values are monotonically increasing. + intervals = values[..., 1:] - values[..., :-1] + npt.assert_array_less(np.zeros_like(intervals), intervals) + + # Check that the values are equally spaced. + npt.assert_allclose( + intervals[..., 1:] - intervals[..., :1], + np.zeros_like(intervals[..., 1:]), + atol=1e-7, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/tf/networks/distributions.py b/acme/tf/networks/distributions.py index a5b14ea9c9..4580b5166d 100644 --- a/acme/tf/networks/distributions.py +++ b/acme/tf/networks/distributions.py @@ -15,6 +15,7 @@ """Distributions, for use in acme/networks/distributional.py.""" from typing import Optional + import tensorflow as tf import tensorflow_probability as tfp @@ -23,7 +24,7 @@ @tfp.experimental.auto_composite_tensor class DiscreteValuedDistribution(tfd.Categorical): - """This is a generalization of a categorical distribution. + """This is a generalization of a categorical distribution. The support for the DiscreteValued distribution can be any real valued range, whereas the categorical distribution has support [0, n_categories - 1] or @@ -31,12 +32,14 @@ class DiscreteValuedDistribution(tfd.Categorical): distribution over its support. """ - def __init__(self, - values: tf.Tensor, - logits: Optional[tf.Tensor] = None, - probs: Optional[tf.Tensor] = None, - name: str = 'DiscreteValuedDistribution'): - """Initialization. + def __init__( + self, + values: tf.Tensor, + logits: Optional[tf.Tensor] = None, + probs: Optional[tf.Tensor] = None, + name: str = "DiscreteValuedDistribution", + ): + """Initialization. Args: values: Values making up support of the distribution. Should have a shape @@ -52,56 +55,57 @@ def __init__(self, passed in. name: Name of the distribution object. 
""" - self._values = tf.convert_to_tensor(values) - shape_strings = [f'D{i}' for i, _ in enumerate(values.shape)] - - if logits is not None: - logits = tf.convert_to_tensor(logits) - tf.debugging.assert_shapes([(values, shape_strings), - (logits, [..., *shape_strings])]) - if probs is not None: - probs = tf.convert_to_tensor(probs) - tf.debugging.assert_shapes([(values, shape_strings), - (probs, [..., *shape_strings])]) - - super().__init__(logits=logits, probs=probs, name=name) - - self._parameters = dict(values=values, - logits=logits, - probs=probs, - name=name) - - @property - def values(self) -> tf.Tensor: - return self._values - - @classmethod - def _parameter_properties(cls, dtype, num_classes=None): - return dict( - values=tfp.util.ParameterProperties(event_ndims=None), - logits=tfp.util.ParameterProperties( - event_ndims=lambda self: self.values.shape.rank), - probs=tfp.util.ParameterProperties( - event_ndims=lambda self: self.values.shape.rank, - is_preferred=False)) - - def _sample_n(self, n, seed=None) -> tf.Tensor: - indices = super()._sample_n(n, seed=seed) - return tf.gather(self.values, indices, axis=-1) - - def _mean(self) -> tf.Tensor: - """Overrides the Categorical mean by incorporating category values.""" - return tf.reduce_sum(self.probs_parameter() * self.values, axis=-1) - - def _variance(self) -> tf.Tensor: - """Overrides the Categorical variance by incorporating category values.""" - dist_squared = tf.square(tf.expand_dims(self.mean(), -1) - self.values) - return tf.reduce_sum(self.probs_parameter() * dist_squared, axis=-1) - - def _event_shape(self): - # Omit the atoms axis, to return just the shape of a single (i.e. unbatched) - # sample value. - return self._values.shape[:-1] - - def _event_shape_tensor(self): - return tf.shape(self._values)[:-1] + self._values = tf.convert_to_tensor(values) + shape_strings = [f"D{i}" for i, _ in enumerate(values.shape)] + + if logits is not None: + logits = tf.convert_to_tensor(logits) + tf.debugging.assert_shapes( + [(values, shape_strings), (logits, [..., *shape_strings])] + ) + if probs is not None: + probs = tf.convert_to_tensor(probs) + tf.debugging.assert_shapes( + [(values, shape_strings), (probs, [..., *shape_strings])] + ) + + super().__init__(logits=logits, probs=probs, name=name) + + self._parameters = dict(values=values, logits=logits, probs=probs, name=name) + + @property + def values(self) -> tf.Tensor: + return self._values + + @classmethod + def _parameter_properties(cls, dtype, num_classes=None): + return dict( + values=tfp.util.ParameterProperties(event_ndims=None), + logits=tfp.util.ParameterProperties( + event_ndims=lambda self: self.values.shape.rank + ), + probs=tfp.util.ParameterProperties( + event_ndims=lambda self: self.values.shape.rank, is_preferred=False + ), + ) + + def _sample_n(self, n, seed=None) -> tf.Tensor: + indices = super()._sample_n(n, seed=seed) + return tf.gather(self.values, indices, axis=-1) + + def _mean(self) -> tf.Tensor: + """Overrides the Categorical mean by incorporating category values.""" + return tf.reduce_sum(self.probs_parameter() * self.values, axis=-1) + + def _variance(self) -> tf.Tensor: + """Overrides the Categorical variance by incorporating category values.""" + dist_squared = tf.square(tf.expand_dims(self.mean(), -1) - self.values) + return tf.reduce_sum(self.probs_parameter() * dist_squared, axis=-1) + + def _event_shape(self): + # Omit the atoms axis, to return just the shape of a single (i.e. unbatched) + # sample value. 
+ return self._values.shape[:-1] + + def _event_shape_tensor(self): + return tf.shape(self._values)[:-1] diff --git a/acme/tf/networks/distributions_test.py b/acme/tf/networks/distributions_test.py index 211eed8ed8..a4539300ff 100644 --- a/acme/tf/networks/distributions_test.py +++ b/acme/tf/networks/distributions_test.py @@ -14,54 +14,57 @@ """Tests for acme.tf.networks.distributions.""" -from acme.tf.networks import distributions import numpy as np +from absl.testing import absltest, parameterized from numpy import testing as npt -from absl.testing import absltest -from absl.testing import parameterized +from acme.tf.networks import distributions class DiscreteValuedDistributionTest(parameterized.TestCase): + @parameterized.parameters( + ((), (), 5), + ((2,), (), 5), + ((), (3, 4), 5), + ((2,), (3, 4), 5), + ((2, 6), (3, 4), 5), + ) + def test_constructor(self, batch_shape, event_shape, num_values): + logits_shape = batch_shape + event_shape + (num_values,) + logits_size = np.prod(logits_shape) + logits = np.arange(logits_size, dtype=float).reshape(logits_shape) + values = np.linspace( + start=-np.ones(event_shape, dtype=float), + stop=np.ones(event_shape, dtype=float), + num=num_values, + axis=-1, + ) + distribution = distributions.DiscreteValuedDistribution( + values=values, logits=logits + ) - @parameterized.parameters( - ((), (), 5), - ((2,), (), 5), - ((), (3, 4), 5), - ((2,), (3, 4), 5), - ((2, 6), (3, 4), 5), - ) - def test_constructor(self, batch_shape, event_shape, num_values): - logits_shape = batch_shape + event_shape + (num_values,) - logits_size = np.prod(logits_shape) - logits = np.arange(logits_size, dtype=float).reshape(logits_shape) - values = np.linspace(start=-np.ones(event_shape, dtype=float), - stop=np.ones(event_shape, dtype=float), - num=num_values, - axis=-1) - distribution = distributions.DiscreteValuedDistribution(values=values, - logits=logits) - - # Check batch and event shapes. - self.assertEqual(distribution.batch_shape, batch_shape) - self.assertEqual(distribution.event_shape, event_shape) - self.assertEqual(distribution.logits_parameter().shape.as_list(), - list(logits.shape)) - self.assertEqual(distribution.logits_parameter().shape.as_list()[-1], - logits.shape[-1]) + # Check batch and event shapes. 
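DiscreteValuedDistribution generalises tfd.Categorical to an arbitrary real-valued support, so mean() and variance() are taken over that support rather than over category indices. A small sketch with an assumed 11-atom support:

```
import numpy as np

from acme.tf.networks.distributions import DiscreteValuedDistribution

support = np.linspace(-1.0, 1.0, num=11).astype(np.float32)   # 11 atoms
logits = np.zeros((4, 11), dtype=np.float32)                  # batch of 4, uniform
dist = DiscreteValuedDistribution(values=support, logits=logits)

samples = dist.sample()      # drawn from the support, shape [4]
means = dist.mean()          # [4]; 0.0 for a uniform distribution on [-1, 1]
variances = dist.variance()  # [4]
```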
+ self.assertEqual(distribution.batch_shape, batch_shape) + self.assertEqual(distribution.event_shape, event_shape) + self.assertEqual( + distribution.logits_parameter().shape.as_list(), list(logits.shape) + ) + self.assertEqual( + distribution.logits_parameter().shape.as_list()[-1], logits.shape[-1] + ) - # Test slicing - if len(batch_shape) == 1: - slice_0_logits = distribution[1:3].logits_parameter().numpy() - expected_slice_0_logits = distribution.logits_parameter().numpy()[1:3] - npt.assert_allclose(slice_0_logits, expected_slice_0_logits) - elif len(batch_shape) == 2: - slice_logits = distribution[0, 1:3].logits_parameter().numpy() - expected_slice_logits = distribution.logits_parameter().numpy()[0, 1:3] - npt.assert_allclose(slice_logits, expected_slice_logits) - else: - assert not batch_shape + # Test slicing + if len(batch_shape) == 1: + slice_0_logits = distribution[1:3].logits_parameter().numpy() + expected_slice_0_logits = distribution.logits_parameter().numpy()[1:3] + npt.assert_allclose(slice_0_logits, expected_slice_0_logits) + elif len(batch_shape) == 2: + slice_logits = distribution[0, 1:3].logits_parameter().numpy() + expected_slice_logits = distribution.logits_parameter().numpy()[0, 1:3] + npt.assert_allclose(slice_logits, expected_slice_logits) + else: + assert not batch_shape -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/tf/networks/duelling.py b/acme/tf/networks/duelling.py index ed891f5a47..10b64ca79b 100644 --- a/acme/tf/networks/duelling.py +++ b/acme/tf/networks/duelling.py @@ -24,20 +24,18 @@ class DuellingMLP(snt.Module): - """A Duelling MLP Q-network.""" + """A Duelling MLP Q-network.""" - def __init__( - self, - num_actions: int, - hidden_sizes: Sequence[int], - ): - super().__init__(name='duelling_q_network') + def __init__( + self, num_actions: int, hidden_sizes: Sequence[int], + ): + super().__init__(name="duelling_q_network") - self._value_mlp = snt.nets.MLP([*hidden_sizes, 1]) - self._advantage_mlp = snt.nets.MLP([*hidden_sizes, num_actions]) + self._value_mlp = snt.nets.MLP([*hidden_sizes, 1]) + self._advantage_mlp = snt.nets.MLP([*hidden_sizes, num_actions]) - def __call__(self, inputs: tf.Tensor) -> tf.Tensor: - """Forward pass of the duelling network. + def __call__(self, inputs: tf.Tensor) -> tf.Tensor: + """Forward pass of the duelling network. Args: inputs: 2-D tensor of shape [batch_size, embedding_size]. @@ -46,13 +44,13 @@ def __call__(self, inputs: tf.Tensor) -> tf.Tensor: q_values: 2-D tensor of action values of shape [batch_size, num_actions] """ - # Compute value & advantage for duelling. - value = self._value_mlp(inputs) # [B, 1] - advantages = self._advantage_mlp(inputs) # [B, A] + # Compute value & advantage for duelling. + value = self._value_mlp(inputs) # [B, 1] + advantages = self._advantage_mlp(inputs) # [B, A] - # Advantages have zero mean. - advantages -= tf.reduce_mean(advantages, axis=-1, keepdims=True) # [B, A] + # Advantages have zero mean. 
+ advantages -= tf.reduce_mean(advantages, axis=-1, keepdims=True) # [B, A] - q_values = value + advantages # [B, A] + q_values = value + advantages # [B, A] - return q_values + return q_values diff --git a/acme/tf/networks/embedding.py b/acme/tf/networks/embedding.py index c11b765d86..33bc048ebd 100644 --- a/acme/tf/networks/embedding.py +++ b/acme/tf/networks/embedding.py @@ -14,32 +14,32 @@ """Modules for computing custom embeddings.""" -from acme.tf.networks import base -from acme.wrappers import observation_action_reward - import sonnet as snt import tensorflow as tf +from acme.tf.networks import base +from acme.wrappers import observation_action_reward + class OAREmbedding(snt.Module): - """Module for embedding (observation, action, reward) inputs together.""" + """Module for embedding (observation, action, reward) inputs together.""" - def __init__(self, torso: base.Module, num_actions: int): - super().__init__(name='oar_embedding') - self._num_actions = num_actions - self._torso = torso + def __init__(self, torso: base.Module, num_actions: int): + super().__init__(name="oar_embedding") + self._num_actions = num_actions + self._torso = torso - def __call__(self, inputs: observation_action_reward.OAR) -> tf.Tensor: - """Embed each of the (observation, action, reward) inputs & concatenate.""" + def __call__(self, inputs: observation_action_reward.OAR) -> tf.Tensor: + """Embed each of the (observation, action, reward) inputs & concatenate.""" - # Add dummy trailing dimension to rewards if necessary. - if len(inputs.reward.shape.dims) == 1: - inputs = inputs._replace(reward=tf.expand_dims(inputs.reward, axis=-1)) + # Add dummy trailing dimension to rewards if necessary. + if len(inputs.reward.shape.dims) == 1: + inputs = inputs._replace(reward=tf.expand_dims(inputs.reward, axis=-1)) - features = self._torso(inputs.observation) # [T?, B, D] - action = tf.one_hot(inputs.action, depth=self._num_actions) # [T?, B, A] - reward = tf.nn.tanh(inputs.reward) # [T?, B, 1] + features = self._torso(inputs.observation) # [T?, B, D] + action = tf.one_hot(inputs.action, depth=self._num_actions) # [T?, B, A] + reward = tf.nn.tanh(inputs.reward) # [T?, B, 1] - embedding = tf.concat([features, action, reward], axis=-1) # [T?, B, D+A+1] + embedding = tf.concat([features, action, reward], axis=-1) # [T?, B, D+A+1] - return embedding + return embedding diff --git a/acme/tf/networks/legal_actions.py b/acme/tf/networks/legal_actions.py index 84273905bf..cdffc8de63 100644 --- a/acme/tf/networks/legal_actions.py +++ b/acme/tf/networks/legal_actions.py @@ -16,53 +16,60 @@ from typing import Any, Callable, Iterable, Optional, Union -# pytype: disable=import-error -from acme.wrappers import open_spiel_wrapper -# pytype: enable=import-error - import numpy as np import sonnet as snt import tensorflow as tf import tensorflow_probability as tfp +# pytype: disable=import-error +from acme.wrappers import open_spiel_wrapper + +# pytype: enable=import-error + + tfd = tfp.distributions class MaskedSequential(snt.Module): - """Applies a legal actions mask to a linear chain of modules / callables. + """Applies a legal actions mask to a linear chain of modules / callables. It is assumed the trailing dimension of the final layer (representing action values) is the same as the trailing dimension of legal_actions. 
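DuellingMLP is unchanged apart from formatting: the output is the state value plus zero-mean advantages. A minimal sketch with illustrative sizes:

```
import tensorflow as tf

from acme.tf.networks.duelling import DuellingMLP

network = DuellingMLP(num_actions=4, hidden_sizes=[64, 64])
embedding = tf.random.normal([32, 128])   # [batch_size, embedding_size]
q_values = network(embedding)             # [32, 4] = value + zero-mean advantages
```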
""" - def __init__(self, - layers: Optional[Iterable[Callable[..., Any]]] = None, - name: str = 'MaskedSequential'): - super().__init__(name=name) - self._layers = list(layers) if layers is not None else [] - self._illegal_action_penalty = -1e9 - # Note: illegal_action_penalty cannot be -np.inf because trfl's qlearning - # ops utilize a batched_index function that returns NaN whenever -np.inf - # is present among action values. - - def __call__(self, inputs: open_spiel_wrapper.OLT) -> tf.Tensor: - # Extract observation, legal actions, and terminal - outputs = inputs.observation - legal_actions = inputs.legal_actions - terminal = inputs.terminal - - for mod in self._layers: - outputs = mod(outputs) - - # Apply legal actions mask - outputs = tf.where(tf.equal(legal_actions, 1), outputs, - tf.fill(tf.shape(outputs), self._illegal_action_penalty)) - - # When computing the Q-learning target (r_t + d_t * max q_t) we need to - # ensure max q_t = 0 in terminal states. - outputs = tf.where(tf.equal(terminal, 1), tf.zeros_like(outputs), outputs) - - return outputs + def __init__( + self, + layers: Optional[Iterable[Callable[..., Any]]] = None, + name: str = "MaskedSequential", + ): + super().__init__(name=name) + self._layers = list(layers) if layers is not None else [] + self._illegal_action_penalty = -1e9 + # Note: illegal_action_penalty cannot be -np.inf because trfl's qlearning + # ops utilize a batched_index function that returns NaN whenever -np.inf + # is present among action values. + + def __call__(self, inputs: open_spiel_wrapper.OLT) -> tf.Tensor: + # Extract observation, legal actions, and terminal + outputs = inputs.observation + legal_actions = inputs.legal_actions + terminal = inputs.terminal + + for mod in self._layers: + outputs = mod(outputs) + + # Apply legal actions mask + outputs = tf.where( + tf.equal(legal_actions, 1), + outputs, + tf.fill(tf.shape(outputs), self._illegal_action_penalty), + ) + + # When computing the Q-learning target (r_t + d_t * max q_t) we need to + # ensure max q_t = 0 in terminal states. + outputs = tf.where(tf.equal(terminal, 1), tf.zeros_like(outputs), outputs) + + return outputs # FIXME: Add functionality to support decaying epsilon parameter. @@ -70,7 +77,7 @@ def __call__(self, inputs: open_spiel_wrapper.OLT) -> tf.Tensor: # incorporates code from the bug fix described here # https://github.com/deepmind/trfl/pull/28 class EpsilonGreedy(snt.Module): - """Computes an epsilon-greedy distribution over actions. + """Computes an epsilon-greedy distribution over actions. This policy does the following: - With probability 1 - epsilon, take the action corresponding to the highest @@ -78,11 +85,13 @@ class EpsilonGreedy(snt.Module): - With probability epsilon, take an action uniformly at random. """ - def __init__(self, - epsilon: Union[tf.Tensor, float], - threshold: float, - name: str = 'EpsilonGreedy'): - """Initialize the policy. + def __init__( + self, + epsilon: Union[tf.Tensor, float], + threshold: float, + name: str = "EpsilonGreedy", + ): + """Initialize the policy. Args: epsilon: Exploratory param with value between 0 and 1. @@ -94,34 +103,40 @@ def __init__(self, policy: tfp.distributions.Categorical distribution representing the policy. 
""" - super().__init__(name=name) - self._epsilon = tf.Variable(epsilon, trainable=False) - self._threshold = threshold - - def __call__(self, action_values: tf.Tensor) -> tfd.Categorical: - legal_actions_mask = tf.where( - tf.math.less_equal(action_values, self._threshold), - tf.fill(tf.shape(action_values), 0.), - tf.fill(tf.shape(action_values), 1.)) - - # Dithering action distribution. - dither_probs = 1 / tf.reduce_sum(legal_actions_mask, axis=-1, - keepdims=True) * legal_actions_mask - masked_action_values = tf.where(tf.equal(legal_actions_mask, 1), - action_values, - tf.fill(tf.shape(action_values), -np.inf)) - # Greedy action distribution, breaking ties uniformly at random. - max_value = tf.reduce_max(masked_action_values, axis=-1, keepdims=True) - greedy_probs = tf.cast( - tf.equal(action_values * legal_actions_mask, max_value), - action_values.dtype) - - greedy_probs /= tf.reduce_sum(greedy_probs, axis=-1, keepdims=True) - - # Epsilon-greedy action distribution. - probs = self._epsilon * dither_probs + (1 - self._epsilon) * greedy_probs - - # Make the policy object. - policy = tfd.Categorical(probs=probs) - - return policy + super().__init__(name=name) + self._epsilon = tf.Variable(epsilon, trainable=False) + self._threshold = threshold + + def __call__(self, action_values: tf.Tensor) -> tfd.Categorical: + legal_actions_mask = tf.where( + tf.math.less_equal(action_values, self._threshold), + tf.fill(tf.shape(action_values), 0.0), + tf.fill(tf.shape(action_values), 1.0), + ) + + # Dithering action distribution. + dither_probs = ( + 1 + / tf.reduce_sum(legal_actions_mask, axis=-1, keepdims=True) + * legal_actions_mask + ) + masked_action_values = tf.where( + tf.equal(legal_actions_mask, 1), + action_values, + tf.fill(tf.shape(action_values), -np.inf), + ) + # Greedy action distribution, breaking ties uniformly at random. + max_value = tf.reduce_max(masked_action_values, axis=-1, keepdims=True) + greedy_probs = tf.cast( + tf.equal(action_values * legal_actions_mask, max_value), action_values.dtype + ) + + greedy_probs /= tf.reduce_sum(greedy_probs, axis=-1, keepdims=True) + + # Epsilon-greedy action distribution. + probs = self._epsilon * dither_probs + (1 - self._epsilon) * greedy_probs + + # Make the policy object. + policy = tfd.Categorical(probs=probs) + + return policy diff --git a/acme/tf/networks/masked_epsilon_greedy.py b/acme/tf/networks/masked_epsilon_greedy.py index bf707dac50..45d0385d25 100644 --- a/acme/tf/networks/masked_epsilon_greedy.py +++ b/acme/tf/networks/masked_epsilon_greedy.py @@ -14,7 +14,7 @@ """Wrapping trfl epsilon_greedy with legal action masking.""" -from typing import Optional, Mapping, Union +from typing import Mapping, Optional, Union import sonnet as snt import tensorflow as tf @@ -22,12 +22,10 @@ class NetworkWithMaskedEpsilonGreedy(snt.Module): - """Epsilon greedy sampling with action masking on network outputs.""" + """Epsilon greedy sampling with action masking on network outputs.""" - def __init__(self, - network: snt.Module, - epsilon: Optional[tf.Tensor] = None): - """Initialize the network and epsilon. + def __init__(self, network: snt.Module, epsilon: Optional[tf.Tensor] = None): + """Initialize the network and epsilon. Usage: Wrap an observation in a dictionary in your environment as follows: @@ -42,14 +40,16 @@ def __init__(self, network: the online Q network (the one being optimized) epsilon: probability of taking a random action. 
""" - super().__init__() - self._network = network - self._epsilon = epsilon - - def __call__( - self, observation: Union[Mapping[str, tf.Tensor], - tf.Tensor]) -> tf.Tensor: - q = self._network(observation) - return trfl.epsilon_greedy( - q, epsilon=self._epsilon, - legal_actions_mask=observation['legal_actions_mask']).sample() + super().__init__() + self._network = network + self._epsilon = epsilon + + def __call__( + self, observation: Union[Mapping[str, tf.Tensor], tf.Tensor] + ) -> tf.Tensor: + q = self._network(observation) + return trfl.epsilon_greedy( + q, + epsilon=self._epsilon, + legal_actions_mask=observation["legal_actions_mask"], + ).sample() diff --git a/acme/tf/networks/multihead.py b/acme/tf/networks/multihead.py index 49a731d8f1..b3eb264281 100644 --- a/acme/tf/networks/multihead.py +++ b/acme/tf/networks/multihead.py @@ -14,39 +14,36 @@ """Multihead networks apply separate networks to the input.""" -from typing import Callable, Union, Sequence - -from acme import types +from typing import Callable, Sequence, Union import sonnet as snt import tensorflow as tf import tensorflow_probability as tfp +from acme import types + tfd = tfp.distributions -TensorTransformation = Union[snt.Module, Callable[[types.NestedTensor], - tf.Tensor]] +TensorTransformation = Union[snt.Module, Callable[[types.NestedTensor], tf.Tensor]] class Multihead(snt.Module): - """Multi-head network module. + """Multi-head network module. This takes as input a list of N `network_heads`, and returns another network whose output is the stacked outputs of each of these network heads separately applied to the module input. The dimension of the output is [..., N]. """ - def __init__(self, - network_heads: Sequence[TensorTransformation]): - if not network_heads: - raise ValueError('Must specify non-empty, non-None critic_network_heads.') - self._network_heads = network_heads - super().__init__(name='multihead') - - def __call__(self, - inputs: tf.Tensor) -> Union[tf.Tensor, Sequence[tf.Tensor]]: - outputs = [network_head(inputs) for network_head in self._network_heads] - if isinstance(outputs[0], tfd.Distribution): - # Cannot stack distributions - return outputs - outputs = tf.stack(outputs, axis=-1) - return outputs + def __init__(self, network_heads: Sequence[TensorTransformation]): + if not network_heads: + raise ValueError("Must specify non-empty, non-None critic_network_heads.") + self._network_heads = network_heads + super().__init__(name="multihead") + + def __call__(self, inputs: tf.Tensor) -> Union[tf.Tensor, Sequence[tf.Tensor]]: + outputs = [network_head(inputs) for network_head in self._network_heads] + if isinstance(outputs[0], tfd.Distribution): + # Cannot stack distributions + return outputs + outputs = tf.stack(outputs, axis=-1) + return outputs diff --git a/acme/tf/networks/multiplexers.py b/acme/tf/networks/multiplexers.py index 32815373d5..d9c933e5e5 100644 --- a/acme/tf/networks/multiplexers.py +++ b/acme/tf/networks/multiplexers.py @@ -16,20 +16,19 @@ from typing import Callable, Optional, Union -from acme import types -from acme.tf import utils as tf2_utils - import sonnet as snt import tensorflow as tf import tensorflow_probability as tfp +from acme import types +from acme.tf import utils as tf2_utils + tfd = tfp.distributions -TensorTransformation = Union[snt.Module, Callable[[types.NestedTensor], - tf.Tensor]] +TensorTransformation = Union[snt.Module, Callable[[types.NestedTensor], tf.Tensor]] class CriticMultiplexer(snt.Module): - """Module connecting a critic torso to (transformed) 
observations/actions. + """Module connecting a critic torso to (transformed) observations/actions. This takes as input a `critic_network`, an `observation_network`, and an `action_network` and returns another network whose outputs are given by @@ -45,35 +44,37 @@ class CriticMultiplexer(snt.Module): module reduces to a simple `tf2_utils.batch_concat()`. """ - def __init__(self, - critic_network: Optional[TensorTransformation] = None, - observation_network: Optional[TensorTransformation] = None, - action_network: Optional[TensorTransformation] = None): - self._critic_network = critic_network - self._observation_network = observation_network - self._action_network = action_network - super().__init__(name='critic_multiplexer') - - def __call__(self, - observation: types.NestedTensor, - action: types.NestedTensor) -> tf.Tensor: - - # Maybe transform observations and actions before feeding them on. - if self._observation_network: - observation = self._observation_network(observation) - if self._action_network: - action = self._action_network(action) - - if hasattr(observation, 'dtype') and hasattr(action, 'dtype'): - if observation.dtype != action.dtype: - # Observation and action must be the same type for concat to work - action = tf.cast(action, observation.dtype) - - # Concat observations and actions, with one batch dimension. - outputs = tf2_utils.batch_concat([observation, action]) - - # Maybe transform output before returning. - if self._critic_network: - outputs = self._critic_network(outputs) - - return outputs + def __init__( + self, + critic_network: Optional[TensorTransformation] = None, + observation_network: Optional[TensorTransformation] = None, + action_network: Optional[TensorTransformation] = None, + ): + self._critic_network = critic_network + self._observation_network = observation_network + self._action_network = action_network + super().__init__(name="critic_multiplexer") + + def __call__( + self, observation: types.NestedTensor, action: types.NestedTensor + ) -> tf.Tensor: + + # Maybe transform observations and actions before feeding them on. + if self._observation_network: + observation = self._observation_network(observation) + if self._action_network: + action = self._action_network(action) + + if hasattr(observation, "dtype") and hasattr(action, "dtype"): + if observation.dtype != action.dtype: + # Observation and action must be the same type for concat to work + action = tf.cast(action, observation.dtype) + + # Concat observations and actions, with one batch dimension. + outputs = tf2_utils.batch_concat([observation, action]) + + # Maybe transform output before returning. 
+ if self._critic_network: + outputs = self._critic_network(outputs) + + return outputs diff --git a/acme/tf/networks/noise.py b/acme/tf/networks/noise.py index 6a48223335..d7011cbf0b 100644 --- a/acme/tf/networks/noise.py +++ b/acme/tf/networks/noise.py @@ -14,27 +14,29 @@ """Noise layers (for exploration).""" -from acme import types import sonnet as snt import tensorflow as tf import tensorflow_probability as tfp import tree +from acme import types + tfd = tfp.distributions class ClippedGaussian(snt.Module): - """Sonnet module for adding clipped Gaussian noise to each output.""" + """Sonnet module for adding clipped Gaussian noise to each output.""" - def __init__(self, stddev: float, name: str = 'clipped_gaussian'): - super().__init__(name=name) - self._noise = tfd.Normal(loc=0., scale=stddev) + def __init__(self, stddev: float, name: str = "clipped_gaussian"): + super().__init__(name=name) + self._noise = tfd.Normal(loc=0.0, scale=stddev) - def __call__(self, inputs: types.NestedTensor) -> types.NestedTensor: - def add_noise(tensor: tf.Tensor): - output = tensor + tf.cast(self._noise.sample(tensor.shape), - dtype=tensor.dtype) - output = tf.clip_by_value(output, -1.0, 1.0) - return output + def __call__(self, inputs: types.NestedTensor) -> types.NestedTensor: + def add_noise(tensor: tf.Tensor): + output = tensor + tf.cast( + self._noise.sample(tensor.shape), dtype=tensor.dtype + ) + output = tf.clip_by_value(output, -1.0, 1.0) + return output - return tree.map_structure(add_noise, inputs) + return tree.map_structure(add_noise, inputs) diff --git a/acme/tf/networks/policy_value.py b/acme/tf/networks/policy_value.py index cecbe305a6..00e89e2158 100644 --- a/acme/tf/networks/policy_value.py +++ b/acme/tf/networks/policy_value.py @@ -21,16 +21,16 @@ class PolicyValueHead(snt.Module): - """A network with two linear layers, for policy and value respectively.""" + """A network with two linear layers, for policy and value respectively.""" - def __init__(self, num_actions: int): - super().__init__(name='policy_value_network') - self._policy_layer = snt.Linear(num_actions) - self._value_layer = snt.Linear(1) + def __init__(self, num_actions: int): + super().__init__(name="policy_value_network") + self._policy_layer = snt.Linear(num_actions) + self._value_layer = snt.Linear(1) - def __call__(self, inputs: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: - """Returns a (Logits, Value) tuple.""" - logits = self._policy_layer(inputs) # [B, A] - value = tf.squeeze(self._value_layer(inputs), axis=-1) # [B] + def __call__(self, inputs: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: + """Returns a (Logits, Value) tuple.""" + logits = self._policy_layer(inputs) # [B, A] + value = tf.squeeze(self._value_layer(inputs), axis=-1) # [B] - return logits, value + return logits, value diff --git a/acme/tf/networks/quantile.py b/acme/tf/networks/quantile.py index 89bd7bfa3f..ebb1cb7575 100644 --- a/acme/tf/networks/quantile.py +++ b/acme/tf/networks/quantile.py @@ -23,7 +23,7 @@ class IQNNetwork(snt.Module): - """A feedforward network for use with IQN. + """A feedforward network for use with IQN. IQN extends the Q-network of regular DQN which consists of torso and head networks. IQN embeds sampled quantile thresholds into the output space of the @@ -33,13 +33,15 @@ class IQNNetwork(snt.Module): quantile thresholds. """ - def __init__(self, - torso: snt.Module, - head: snt.Module, - latent_dim: int, - num_quantile_samples: int, - name: str = 'iqn_network'): - """Initializes the network. 
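CriticMultiplexer with no sub-networks reduces to a batch concat of observation and action; supplying critic_network turns it into a full critic. A sketch with illustrative shapes and layer sizes:

```
import sonnet as snt
import tensorflow as tf

from acme.tf.networks.multiplexers import CriticMultiplexer

critic = CriticMultiplexer(critic_network=snt.nets.MLP([256, 1]))

observation = tf.random.normal([16, 24])   # illustrative shapes
action = tf.random.uniform([16, 4])
q_value = critic(observation, action)      # batch_concat -> MLP -> [16, 1]
```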
+ def __init__( + self, + torso: snt.Module, + head: snt.Module, + latent_dim: int, + num_quantile_samples: int, + name: str = "iqn_network", + ): + """Initializes the network. Args: torso: Network producing an intermediate representation, typically a @@ -49,46 +51,46 @@ def __init__(self, num_quantile_samples: Number of quantile thresholds to sample. name: Module name. """ - super().__init__(name) - self._torso = torso - self._head = head - self._latent_dim = latent_dim - self._num_quantile_samples = num_quantile_samples - - @snt.once - def _create_embedding(self, size): - self._embedding = snt.Linear(size) - - def __call__(self, observations): - # Transform observations to intermediate representations (typically a - # convolutional network). - torso_output = self._torso(observations) - - # Now that dimension of intermediate representation is known initialize - # embedding of sample quantile thresholds (only done once). - self._create_embedding(torso_output.shape[-1]) - - # Sample quantile thresholds. - batch_size = tf.shape(observations)[0] - tau_shape = tf.stack([batch_size, self._num_quantile_samples]) - tau = tf.random.uniform(tau_shape) - indices = tf.range(1, self._latent_dim+1, dtype=tf.float32) - - # Embed sampled quantile thresholds in intermediate representation space. - tau_tiled = tf.tile(tau[:, :, None], (1, 1, self._latent_dim)) - indices_tiled = tf.tile(indices[None, None, :], - tf.concat([tau_shape, [1]], 0)) - tau_embedding = tf.cos(tau_tiled * indices_tiled * np.pi) - tau_embedding = snt.BatchApply(self._embedding)(tau_embedding) - tau_embedding = tf.nn.relu(tau_embedding) - - # Merge intermediate representations with embeddings, and apply head - # network (typically an MLP). - torso_output = tf.tile(torso_output[:, None, :], - (1, self._num_quantile_samples, 1)) - q_value_quantiles = snt.BatchApply(self._head)(tau_embedding * torso_output) - q_dist = tf.transpose(q_value_quantiles, (0, 2, 1)) - q_values = tf.reduce_mean(q_value_quantiles, axis=1) - q_values = tf.stop_gradient(q_values) - - return q_values, q_dist, tau + super().__init__(name) + self._torso = torso + self._head = head + self._latent_dim = latent_dim + self._num_quantile_samples = num_quantile_samples + + @snt.once + def _create_embedding(self, size): + self._embedding = snt.Linear(size) + + def __call__(self, observations): + # Transform observations to intermediate representations (typically a + # convolutional network). + torso_output = self._torso(observations) + + # Now that dimension of intermediate representation is known initialize + # embedding of sample quantile thresholds (only done once). + self._create_embedding(torso_output.shape[-1]) + + # Sample quantile thresholds. + batch_size = tf.shape(observations)[0] + tau_shape = tf.stack([batch_size, self._num_quantile_samples]) + tau = tf.random.uniform(tau_shape) + indices = tf.range(1, self._latent_dim + 1, dtype=tf.float32) + + # Embed sampled quantile thresholds in intermediate representation space. + tau_tiled = tf.tile(tau[:, :, None], (1, 1, self._latent_dim)) + indices_tiled = tf.tile(indices[None, None, :], tf.concat([tau_shape, [1]], 0)) + tau_embedding = tf.cos(tau_tiled * indices_tiled * np.pi) + tau_embedding = snt.BatchApply(self._embedding)(tau_embedding) + tau_embedding = tf.nn.relu(tau_embedding) + + # Merge intermediate representations with embeddings, and apply head + # network (typically an MLP). 
+ torso_output = tf.tile( + torso_output[:, None, :], (1, self._num_quantile_samples, 1) + ) + q_value_quantiles = snt.BatchApply(self._head)(tau_embedding * torso_output) + q_dist = tf.transpose(q_value_quantiles, (0, 2, 1)) + q_values = tf.reduce_mean(q_value_quantiles, axis=1) + q_values = tf.stop_gradient(q_values) + + return q_values, q_dist, tau diff --git a/acme/tf/networks/recurrence.py b/acme/tf/networks/recurrence.py index b07bf9eaf6..835e4a90d3 100644 --- a/acme/tf/networks/recurrence.py +++ b/acme/tf/networks/recurrence.py @@ -17,27 +17,29 @@ import functools from typing import NamedTuple, Optional, Sequence, Tuple -from absl import logging -from acme import types -from acme.tf import savers -from acme.tf import utils -from acme.tf.networks import base + import sonnet as snt import tensorflow as tf import tensorflow_probability as tfp import tree +from absl import logging + +from acme import types +from acme.tf import savers, utils +from acme.tf.networks import base RNNState = types.NestedTensor class PolicyCriticRNNState(NamedTuple): - """Consists of two RNNStates called 'policy' and 'critic'.""" - policy: RNNState - critic: RNNState + """Consists of two RNNStates called 'policy' and 'critic'.""" + + policy: RNNState + critic: RNNState class UnpackWrapper(snt.Module): - """Gets a list of arguments and pass them as separate arguments. + """Gets a list of arguments and pass them as separate arguments. Example ``` @@ -51,18 +53,17 @@ def __call__(self, o, a): calls critic(o, a) """ - def __init__(self, module: snt.Module, name: str = 'UnpackWrapper'): - super().__init__(name=name) - self._module = module + def __init__(self, module: snt.Module, name: str = "UnpackWrapper"): + super().__init__(name=name) + self._module = module - def __call__(self, - inputs: Sequence[types.NestedTensor]) -> types.NestedTensor: - # Unpack the inputs before passing to the underlying module. - return self._module(*inputs) + def __call__(self, inputs: Sequence[types.NestedTensor]) -> types.NestedTensor: + # Unpack the inputs before passing to the underlying module. + return self._module(*inputs) class RNNUnpackWrapper(snt.RNNCore): - """Gets a list of arguments and pass them as separate arguments. + """Gets a list of arguments and pass them as separate arguments. Example ``` @@ -76,45 +77,47 @@ def __call__(self, o, a, prev_state): calls m(o, a, prev_state) """ - def __init__(self, module: snt.RNNCore, name: str = 'RNNUnpackWrapper'): - super().__init__(name=name) - self._module = module + def __init__(self, module: snt.RNNCore, name: str = "RNNUnpackWrapper"): + super().__init__(name=name) + self._module = module - def __call__(self, inputs: Sequence[types.NestedTensor], - prev_state: RNNState) -> Tuple[types.NestedTensor, RNNState]: - # Unpack the inputs before passing to the underlying module. - return self._module(*inputs, prev_state) + def __call__( + self, inputs: Sequence[types.NestedTensor], prev_state: RNNState + ) -> Tuple[types.NestedTensor, RNNState]: + # Unpack the inputs before passing to the underlying module. + return self._module(*inputs, prev_state) - def initial_state(self, batch_size): - return self._module.initial_state(batch_size) + def initial_state(self, batch_size): + return self._module.initial_state(batch_size) class CriticDeepRNN(snt.DeepRNN): - """Same as snt.DeepRNN, but takes three inputs (obs, act, prev_state). + """Same as snt.DeepRNN, but takes three inputs (obs, act, prev_state). 
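IQNNetwork is reflowed only; it still returns the (q_values, q_dist, tau) triple, with the quantile-threshold embedding created lazily once the torso's output size is known. A sketch with made-up torso/head sizes and flat observations:

```
import sonnet as snt
import tensorflow as tf

from acme.tf.networks.quantile import IQNNetwork

num_actions = 4
network = IQNNetwork(
    torso=snt.nets.MLP([128]),        # produces the intermediate representation
    head=snt.nets.MLP([64, num_actions]),
    latent_dim=16,
    num_quantile_samples=8,
)

observations = tf.random.normal([32, 10])       # illustrative flat observations
q_values, q_dist, tau = network(observations)   # [32, 4], [32, 4, 8], [32, 8]
```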
""" - def __init__(self, layers: Sequence[snt.Module]): - # Make the first layer take a single input instead of a list of arguments. - if isinstance(layers[0], snt.RNNCore): - first_layer = RNNUnpackWrapper(layers[0]) - else: - first_layer = UnpackWrapper(layers[0]) - super().__init__([first_layer] + list(layers[1:])) - - self._unwrapped_first_layer = layers[0] - self.__input_signature = None - - def __call__(self, inputs: types.NestedTensor, action: tf.Tensor, - prev_state: RNNState) -> Tuple[types.NestedTensor, RNNState]: - # Pack the inputs into a tuple and then using inherited DeepRNN logic to - # pass them through the layers. - # This in turn will pass the packed inputs into the first layer - # (UnpackWrapper) which will unpack them back. - return super().__call__((inputs, action), prev_state) - - @property - def _input_signature(self) -> Optional[tf.TensorSpec]: - """Return input signature for Acme snapshotting. + def __init__(self, layers: Sequence[snt.Module]): + # Make the first layer take a single input instead of a list of arguments. + if isinstance(layers[0], snt.RNNCore): + first_layer = RNNUnpackWrapper(layers[0]) + else: + first_layer = UnpackWrapper(layers[0]) + super().__init__([first_layer] + list(layers[1:])) + + self._unwrapped_first_layer = layers[0] + self.__input_signature = None + + def __call__( + self, inputs: types.NestedTensor, action: tf.Tensor, prev_state: RNNState + ) -> Tuple[types.NestedTensor, RNNState]: + # Pack the inputs into a tuple and then using inherited DeepRNN logic to + # pass them through the layers. + # This in turn will pass the packed inputs into the first layer + # (UnpackWrapper) which will unpack them back. + return super().__call__((inputs, action), prev_state) + + @property + def _input_signature(self) -> Optional[tf.TensorSpec]: + """Return input signature for Acme snapshotting. The Acme way of snapshotting works as follows: you first create your network variables via the utility function `acme.tf.utils.create_variables()`, which @@ -148,62 +151,65 @@ def _input_signature(self) -> Optional[tf.TensorSpec]: module nor for any of its descendants). """ - if self.__input_signature is not None: - # To make case (2) (see above) work, we need to allow create_variables to - # assign an _input_signature attribute to this module, which is why we - # create additional __input_signature attribute with a setter (see below). - return self.__input_signature - - # To make case (1) work, we descend into self._unwrapped_first_layer - # and try to get its input signature (if it exists) by calling - # savers.get_input_signature. - - # Ideally, savers.get_input_signature should automatically descend into - # DeepRNN. But in this case it breaks on CriticDeepRNN because - # CriticDeepRNN._layers[0] is an UnpackWrapper around the underlying module - # and not the module itself. - input_signature = savers._get_input_signature(self._unwrapped_first_layer) # pylint: disable=protected-access - if input_signature is None: - return None - # Since adding recurrent modules via CriticDeepRNN changes the recurrent - # state, we need to update its spec here. 
- state = self.initial_state(1) - input_signature[-1] = tree.map_structure( - lambda t: tf.TensorSpec((None,) + t.shape[1:], t.dtype), state) - self.__input_signature = input_signature - return input_signature - - @_input_signature.setter - def _input_signature(self, new_spec: tf.TensorSpec): - self.__input_signature = new_spec + if self.__input_signature is not None: + # To make case (2) (see above) work, we need to allow create_variables to + # assign an _input_signature attribute to this module, which is why we + # create additional __input_signature attribute with a setter (see below). + return self.__input_signature + + # To make case (1) work, we descend into self._unwrapped_first_layer + # and try to get its input signature (if it exists) by calling + # savers.get_input_signature. + + # Ideally, savers.get_input_signature should automatically descend into + # DeepRNN. But in this case it breaks on CriticDeepRNN because + # CriticDeepRNN._layers[0] is an UnpackWrapper around the underlying module + # and not the module itself. + input_signature = savers._get_input_signature( + self._unwrapped_first_layer + ) # pylint: disable=protected-access + if input_signature is None: + return None + # Since adding recurrent modules via CriticDeepRNN changes the recurrent + # state, we need to update its spec here. + state = self.initial_state(1) + input_signature[-1] = tree.map_structure( + lambda t: tf.TensorSpec((None,) + t.shape[1:], t.dtype), state + ) + self.__input_signature = input_signature + return input_signature + + @_input_signature.setter + def _input_signature(self, new_spec: tf.TensorSpec): + self.__input_signature = new_spec class RecurrentExpQWeightedPolicy(snt.RNNCore): - """Recurrent exponentially Q-weighted policy.""" - - def __init__(self, - policy_network: snt.Module, - critic_network: snt.Module, - temperature_beta: float = 1.0, - num_action_samples: int = 16): - super().__init__(name='RecurrentExpQWeightedPolicy') - self._policy_network = policy_network - self._critic_network = critic_network - self._num_action_samples = num_action_samples - self._temperature_beta = temperature_beta - - def __call__(self, - observation: types.NestedTensor, - prev_state: PolicyCriticRNNState - ) -> Tuple[types.NestedTensor, PolicyCriticRNNState]: - - return tf.vectorized_map(self._call, (observation, prev_state)) - - def _call( - self, observation_and_state: Tuple[types.NestedTensor, - PolicyCriticRNNState] - ) -> Tuple[types.NestedTensor, PolicyCriticRNNState]: - """Computes a forward step for a single element. + """Recurrent exponentially Q-weighted policy.""" + + def __init__( + self, + policy_network: snt.Module, + critic_network: snt.Module, + temperature_beta: float = 1.0, + num_action_samples: int = 16, + ): + super().__init__(name="RecurrentExpQWeightedPolicy") + self._policy_network = policy_network + self._critic_network = critic_network + self._num_action_samples = num_action_samples + self._temperature_beta = temperature_beta + + def __call__( + self, observation: types.NestedTensor, prev_state: PolicyCriticRNNState + ) -> Tuple[types.NestedTensor, PolicyCriticRNNState]: + + return tf.vectorized_map(self._call, (observation, prev_state)) + + def _call( + self, observation_and_state: Tuple[types.NestedTensor, PolicyCriticRNNState] + ) -> Tuple[types.NestedTensor, PolicyCriticRNNState]: + """Computes a forward step for a single element. The observation and state are packed together in order to use `tf.vectorized_map` to handle batches of observations. 
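CriticDeepRNN wraps its first layer so that the (observation, action) pair is unpacked before the usual DeepRNN stepping. A sketch of a small recurrent critic with illustrative layer sizes; using CriticMultiplexer as the first layer is one common choice, not something mandated by the diff:

```
import sonnet as snt
import tensorflow as tf

from acme.tf.networks.multiplexers import CriticMultiplexer
from acme.tf.networks.recurrence import CriticDeepRNN

critic = CriticDeepRNN([
    CriticMultiplexer(),   # merges (observation, action) into one tensor
    snt.LSTM(64),
    snt.Linear(1),
])

batch_size = 8
state = critic.initial_state(batch_size)            # tuple with one LSTM state
observation = tf.random.normal([batch_size, 12])    # illustrative shapes
action = tf.random.uniform([batch_size, 3])
value, next_state = critic(observation, action, state)   # value: [8, 1]
```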
@@ -215,42 +221,44 @@ def _call( Returns: The selected action and the corresponding state. """ - observation, prev_state = observation_and_state - - # Tile input observations and states to allow multiple policy predictions. - tiled_observation, tiled_prev_state = utils.tile_nested( - (observation, prev_state), self._num_action_samples) - actions, policy_states = self._policy_network( - tiled_observation, tiled_prev_state.policy) - - # Evaluate multiple critic predictions with the sampled actions. - value_distribution, critic_states = self._critic_network( - tiled_observation, actions, tiled_prev_state.critic) - value_estimate = value_distribution.mean() - - # Resample a single action of the sampled actions according to logits given - # by the tempered Q-values. - selected_action_idx = tfp.distributions.Categorical( - probs=tf.nn.softmax(value_estimate / self._temperature_beta)).sample() - selected_action = actions[selected_action_idx] - - # Select and return the RNN state that corresponds to the selected action. - states = PolicyCriticRNNState( - policy=policy_states, critic=critic_states) - selected_state = tree.map_structure( - lambda x: x[selected_action_idx], states) - - return selected_action, selected_state - - def initial_state(self, batch_size: int) -> PolicyCriticRNNState: - return PolicyCriticRNNState( - policy=self._policy_network.initial_state(batch_size), - critic=self._critic_network.initial_state(batch_size) + observation, prev_state = observation_and_state + + # Tile input observations and states to allow multiple policy predictions. + tiled_observation, tiled_prev_state = utils.tile_nested( + (observation, prev_state), self._num_action_samples + ) + actions, policy_states = self._policy_network( + tiled_observation, tiled_prev_state.policy + ) + + # Evaluate multiple critic predictions with the sampled actions. + value_distribution, critic_states = self._critic_network( + tiled_observation, actions, tiled_prev_state.critic + ) + value_estimate = value_distribution.mean() + + # Resample a single action of the sampled actions according to logits given + # by the tempered Q-values. + selected_action_idx = tfp.distributions.Categorical( + probs=tf.nn.softmax(value_estimate / self._temperature_beta) + ).sample() + selected_action = actions[selected_action_idx] + + # Select and return the RNN state that corresponds to the selected action. + states = PolicyCriticRNNState(policy=policy_states, critic=critic_states) + selected_state = tree.map_structure(lambda x: x[selected_action_idx], states) + + return selected_action, selected_state + + def initial_state(self, batch_size: int) -> PolicyCriticRNNState: + return PolicyCriticRNNState( + policy=self._policy_network.initial_state(batch_size), + critic=self._critic_network.initial_state(batch_size), ) class DeepRNN(snt.DeepRNN, base.RNNCore): - """Unroll-aware deep RNN module. + """Unroll-aware deep RNN module. Sonnet's DeepRNN steps through RNNCores sequentially which can result in a performance hit, in particular when using Transformers. This module adds an @@ -273,34 +281,33 @@ class DeepRNN(snt.DeepRNN, base.RNNCore): fairly large batches, potentially leading to out-of-memory issues. """ - def __init__(self, layers, name: Optional[str] = None): - """Initializes the module.""" - super().__init__(layers, name=name) - - self.__input_signature = None - self._num_unrollable = 0 - - # As a convenience, check for snt.RNNCore modules and dynamically unroll - # them if they don't already support unrolling. This check can fail, e.g. 
- # if a partially applied RNNCore is passed in. Sonnet's implementation of - # DeepRNN suffers from the same problem. - for layer in self._layers: - if hasattr(layer, 'unroll'): - self._num_unrollable += 1 - elif isinstance(layer, snt.RNNCore): - self._num_unrollable += 1 - layer.unroll = functools.partial(snt.dynamic_unroll, layer) - logging.warning( - 'Acme DeepRNN detected a Sonnet RNNCore. ' - 'This will be dynamically unrolled. Please implement unroll() ' - 'to suppress this warning.') - - def unroll(self, - inputs: types.NestedTensor, - state: base.State, - sequence_length: int, - ) -> Tuple[types.NestedTensor, base.State]: - """Unroll each layer individually. + def __init__(self, layers, name: Optional[str] = None): + """Initializes the module.""" + super().__init__(layers, name=name) + + self.__input_signature = None + self._num_unrollable = 0 + + # As a convenience, check for snt.RNNCore modules and dynamically unroll + # them if they don't already support unrolling. This check can fail, e.g. + # if a partially applied RNNCore is passed in. Sonnet's implementation of + # DeepRNN suffers from the same problem. + for layer in self._layers: + if hasattr(layer, "unroll"): + self._num_unrollable += 1 + elif isinstance(layer, snt.RNNCore): + self._num_unrollable += 1 + layer.unroll = functools.partial(snt.dynamic_unroll, layer) + logging.warning( + "Acme DeepRNN detected a Sonnet RNNCore. " + "This will be dynamically unrolled. Please implement unroll() " + "to suppress this warning." + ) + + def unroll( + self, inputs: types.NestedTensor, state: base.State, sequence_length: int, + ) -> Tuple[types.NestedTensor, base.State]: + """Unroll each layer individually. Calls unroll() on layers which support it, all other layers are batch-applied over the first two axes (assumed to be the time and batch @@ -318,58 +325,61 @@ def unroll(self, ValueError if the length of `state` does not match the number of unrollable layers. """ - if len(state) != self._num_unrollable: - raise ValueError( - 'DeepRNN was called with the wrong number of states. The length of ' - '`state` does not match the number of unrollable layers.') - - states = iter(state) - outputs = inputs - next_states = [] - for layer in self._layers: - if hasattr(layer, 'unroll'): - # The length of the `states` list was checked above. - outputs, next_state = layer.unroll(outputs, next(states), - sequence_length) - next_states.append(next_state) - else: - # Couldn't unroll(); assume that this is a stateless module. - outputs = snt.BatchApply(layer, num_dims=2)(outputs) - - return outputs, tuple(next_states) - - @property - def _input_signature(self) -> Optional[tf.TensorSpec]: - """Return input signature for Acme snapshotting, see CriticDeepRNN.""" - - if self.__input_signature is not None: - return self.__input_signature - - input_signature = savers._get_input_signature(self._layers[0]) # pylint: disable=protected-access - if input_signature is None: - return None - - state = self.initial_state(1) - input_signature[-1] = tree.map_structure( - lambda t: tf.TensorSpec((None,) + t.shape[1:], t.dtype), state) - self.__input_signature = input_signature - return input_signature - - @_input_signature.setter - def _input_signature(self, new_spec: tf.TensorSpec): - self.__input_signature = new_spec + if len(state) != self._num_unrollable: + raise ValueError( + "DeepRNN was called with the wrong number of states. The length of " + "`state` does not match the number of unrollable layers." 
+ ) + + states = iter(state) + outputs = inputs + next_states = [] + for layer in self._layers: + if hasattr(layer, "unroll"): + # The length of the `states` list was checked above. + outputs, next_state = layer.unroll( + outputs, next(states), sequence_length + ) + next_states.append(next_state) + else: + # Couldn't unroll(); assume that this is a stateless module. + outputs = snt.BatchApply(layer, num_dims=2)(outputs) + + return outputs, tuple(next_states) + + @property + def _input_signature(self) -> Optional[tf.TensorSpec]: + """Return input signature for Acme snapshotting, see CriticDeepRNN.""" + + if self.__input_signature is not None: + return self.__input_signature + + input_signature = savers._get_input_signature( + self._layers[0] + ) # pylint: disable=protected-access + if input_signature is None: + return None + + state = self.initial_state(1) + input_signature[-1] = tree.map_structure( + lambda t: tf.TensorSpec((None,) + t.shape[1:], t.dtype), state + ) + self.__input_signature = input_signature + return input_signature + + @_input_signature.setter + def _input_signature(self, new_spec: tf.TensorSpec): + self.__input_signature = new_spec class LSTM(snt.LSTM, base.RNNCore): - """Unrollable interface to LSTM. + """Unrollable interface to LSTM. This module is supposed to be used with the DeepRNN class above, and more generally in networks which support unroll(). """ - def unroll(self, - inputs: types.NestedTensor, - state: base.State, - sequence_length: int, - ) -> Tuple[types.NestedTensor, base.State]: - return snt.static_unroll(self, inputs, state, sequence_length) + def unroll( + self, inputs: types.NestedTensor, state: base.State, sequence_length: int, + ) -> Tuple[types.NestedTensor, base.State]: + return snt.static_unroll(self, inputs, state, sequence_length) diff --git a/acme/tf/networks/recurrence_test.py b/acme/tf/networks/recurrence_test.py index 2c97c8fe65..a4b5f5e551 100644 --- a/acme/tf/networks/recurrence_test.py +++ b/acme/tf/networks/recurrence_test.py @@ -16,73 +16,71 @@ import os -from acme import specs -from acme.tf import savers as tf2_savers -from acme.tf import utils as tf2_utils -from acme.tf.networks import recurrence import numpy as np import sonnet as snt import tensorflow as tf import tree - from absl.testing import absltest +from acme import specs +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.tf.networks import recurrence + # Simple critic-like modules for testing. class Critic(snt.Module): - - def __call__(self, o, a): - return o * a + def __call__(self, o, a): + return o * a class RNNCritic(snt.RNNCore): + def __call__(self, o, a, prev_state): + return o * a, prev_state - def __call__(self, o, a, prev_state): - return o * a, prev_state - - def initial_state(self, batch_size): - return () + def initial_state(self, batch_size): + return () class NetsTest(tf.test.TestCase): - - def test_criticdeeprnn_snapshot(self): - """Test that CriticDeepRNN works correctly with snapshotting.""" - # Create a test network. - critic = Critic() - rnn_critic = RNNCritic() - - for base_net in [critic, rnn_critic]: - net = recurrence.CriticDeepRNN([base_net, snt.LSTM(10)]) - obs = specs.Array([10], dtype=np.float32) - actions = specs.Array([10], dtype=np.float32) - spec = [obs, actions] - tf2_utils.create_variables(net, spec) - - # Test that if you add some postprocessing without rerunning - # create_variables, it still works. 
- wrapped_net = recurrence.CriticDeepRNN([net, lambda x: x]) - - for curr_net in [net, wrapped_net]: - # Save the test network. - directory = absltest.get_default_test_tmpdir() - objects_to_save = {'net': curr_net} - snapshotter = tf2_savers.Snapshotter( - objects_to_save, directory=directory) - snapshotter.save() - - # Reload the test network. - net2 = tf.saved_model.load(os.path.join(snapshotter.directory, 'net')) - - obs = tf.ones((2, 10)) - actions = tf.ones((2, 10)) - state = curr_net.initial_state(2) - outputs1, next_state1 = curr_net(obs, actions, state) - outputs2, next_state2 = net2(obs, actions, state) - - assert np.allclose(outputs1, outputs2) - assert np.allclose(tree.flatten(next_state1), tree.flatten(next_state2)) - - -if __name__ == '__main__': - absltest.main() + def test_criticdeeprnn_snapshot(self): + """Test that CriticDeepRNN works correctly with snapshotting.""" + # Create a test network. + critic = Critic() + rnn_critic = RNNCritic() + + for base_net in [critic, rnn_critic]: + net = recurrence.CriticDeepRNN([base_net, snt.LSTM(10)]) + obs = specs.Array([10], dtype=np.float32) + actions = specs.Array([10], dtype=np.float32) + spec = [obs, actions] + tf2_utils.create_variables(net, spec) + + # Test that if you add some postprocessing without rerunning + # create_variables, it still works. + wrapped_net = recurrence.CriticDeepRNN([net, lambda x: x]) + + for curr_net in [net, wrapped_net]: + # Save the test network. + directory = absltest.get_default_test_tmpdir() + objects_to_save = {"net": curr_net} + snapshotter = tf2_savers.Snapshotter( + objects_to_save, directory=directory + ) + snapshotter.save() + + # Reload the test network. + net2 = tf.saved_model.load(os.path.join(snapshotter.directory, "net")) + + obs = tf.ones((2, 10)) + actions = tf.ones((2, 10)) + state = curr_net.initial_state(2) + outputs1, next_state1 = curr_net(obs, actions, state) + outputs2, next_state2 = net2(obs, actions, state) + + assert np.allclose(outputs1, outputs2) + assert np.allclose(tree.flatten(next_state1), tree.flatten(next_state2)) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/tf/networks/rescaling.py b/acme/tf/networks/rescaling.py index d661d18be4..bbe0973cbd 100644 --- a/acme/tf/networks/rescaling.py +++ b/acme/tf/networks/rescaling.py @@ -15,59 +15,61 @@ """Rescaling layers (e.g. 
to match action specs).""" from typing import Union -from acme import specs + import sonnet as snt import tensorflow as tf import tensorflow_probability as tfp +from acme import specs + tfd = tfp.distributions tfb = tfp.bijectors class ClipToSpec(snt.Module): - """Sonnet module clipping inputs to within a BoundedArraySpec.""" + """Sonnet module clipping inputs to within a BoundedArraySpec.""" - def __init__(self, spec: specs.BoundedArray, name: str = 'clip_to_spec'): - super().__init__(name=name) - self._min = spec.minimum - self._max = spec.maximum + def __init__(self, spec: specs.BoundedArray, name: str = "clip_to_spec"): + super().__init__(name=name) + self._min = spec.minimum + self._max = spec.maximum - def __call__(self, inputs: tf.Tensor) -> tf.Tensor: - return tf.clip_by_value(inputs, self._min, self._max) + def __call__(self, inputs: tf.Tensor) -> tf.Tensor: + return tf.clip_by_value(inputs, self._min, self._max) class RescaleToSpec(snt.Module): - """Sonnet module rescaling inputs in [-1, 1] to match a BoundedArraySpec.""" + """Sonnet module rescaling inputs in [-1, 1] to match a BoundedArraySpec.""" - def __init__(self, spec: specs.BoundedArray, name: str = 'rescale_to_spec'): - super().__init__(name=name) - self._scale = spec.maximum - spec.minimum - self._offset = spec.minimum + def __init__(self, spec: specs.BoundedArray, name: str = "rescale_to_spec"): + super().__init__(name=name) + self._scale = spec.maximum - spec.minimum + self._offset = spec.minimum - def __call__(self, inputs: tf.Tensor) -> tf.Tensor: - inputs = 0.5 * (inputs + 1.0) # [0, 1] - output = inputs * self._scale + self._offset # [minimum, maximum] + def __call__(self, inputs: tf.Tensor) -> tf.Tensor: + inputs = 0.5 * (inputs + 1.0) # [0, 1] + output = inputs * self._scale + self._offset # [minimum, maximum] - return output + return output class TanhToSpec(snt.Module): - """Sonnet module squashing real-valued inputs to match a BoundedArraySpec.""" - - def __init__(self, spec: specs.BoundedArray, name: str = 'tanh_to_spec'): - super().__init__(name=name) - self._scale = spec.maximum - spec.minimum - self._offset = spec.minimum - - def __call__( - self, inputs: Union[tf.Tensor, tfd.Distribution] - ) -> Union[tf.Tensor, tfd.Distribution]: - if isinstance(inputs, tfd.Distribution): - inputs = tfb.Tanh()(inputs) - inputs = tfb.ScaleMatvecDiag(0.5 * self._scale)(inputs) - output = tfb.Shift(self._offset + 0.5 * self._scale)(inputs) - else: - inputs = tf.tanh(inputs) # [-1, 1] - inputs = 0.5 * (inputs + 1.0) # [0, 1] - output = inputs * self._scale + self._offset # [minimum, maximum] - return output + """Sonnet module squashing real-valued inputs to match a BoundedArraySpec.""" + + def __init__(self, spec: specs.BoundedArray, name: str = "tanh_to_spec"): + super().__init__(name=name) + self._scale = spec.maximum - spec.minimum + self._offset = spec.minimum + + def __call__( + self, inputs: Union[tf.Tensor, tfd.Distribution] + ) -> Union[tf.Tensor, tfd.Distribution]: + if isinstance(inputs, tfd.Distribution): + inputs = tfb.Tanh()(inputs) + inputs = tfb.ScaleMatvecDiag(0.5 * self._scale)(inputs) + output = tfb.Shift(self._offset + 0.5 * self._scale)(inputs) + else: + inputs = tf.tanh(inputs) # [-1, 1] + inputs = 0.5 * (inputs + 1.0) # [0, 1] + output = inputs * self._scale + self._offset # [minimum, maximum] + return output diff --git a/acme/tf/networks/stochastic.py b/acme/tf/networks/stochastic.py index 264d270ab8..c574aaab48 100644 --- a/acme/tf/networks/stochastic.py +++ b/acme/tf/networks/stochastic.py @@ -14,39 
+14,40 @@ """Useful sonnet modules to chain after distributional module outputs.""" -from acme import types -from acme.tf import utils as tf2_utils import sonnet as snt import tensorflow as tf import tensorflow_probability as tfp import tree +from acme import types +from acme.tf import utils as tf2_utils + tfd = tfp.distributions class StochasticModeHead(snt.Module): - """Simple sonnet module to produce the mode of a tfp.Distribution.""" + """Simple sonnet module to produce the mode of a tfp.Distribution.""" - def __call__(self, distribution: tfd.Distribution): - return distribution.mode() + def __call__(self, distribution: tfd.Distribution): + return distribution.mode() class StochasticMeanHead(snt.Module): - """Simple sonnet module to produce the mean of a tfp.Distribution.""" + """Simple sonnet module to produce the mean of a tfp.Distribution.""" - def __call__(self, distribution: tfd.Distribution): - return distribution.mean() + def __call__(self, distribution: tfd.Distribution): + return distribution.mean() class StochasticSamplingHead(snt.Module): - """Simple sonnet module to sample from a tfp.Distribution.""" + """Simple sonnet module to sample from a tfp.Distribution.""" - def __call__(self, distribution: tfd.Distribution): - return distribution.sample() + def __call__(self, distribution: tfd.Distribution): + return distribution.sample() class ExpQWeightedPolicy(snt.Module): - """Exponentially Q-weighted policy. + """Exponentially Q-weighted policy. Given a stochastic policy and a critic, returns a (stochastic) policy which samples multiple actions from the underlying policy, computes the Q-values for @@ -55,50 +56,55 @@ class ExpQWeightedPolicy(snt.Module): a parameter beta. """ - def __init__(self, - actor_network: snt.Module, - critic_network: snt.Module, - beta: float = 1.0, - num_action_samples: int = 16): - super().__init__(name='ExpQWeightedPolicy') - self._actor_network = actor_network - self._critic_network = critic_network - self._num_action_samples = num_action_samples - self._beta = beta - - def __call__(self, inputs: types.NestedTensor) -> tf.Tensor: - # Inputs are of size [B, ...]. Here we tile them to be of shape [N, B, ...]. - tiled_inputs = tf2_utils.tile_nested(inputs, self._num_action_samples) - shape = tf.shape(tree.flatten(tiled_inputs)[0]) - n, b = shape[0], shape[1] - tf.debugging.assert_equal(n, self._num_action_samples, - 'Internal Error. Unexpected tiled_inputs shape.') - dummy_zeros_n_b = tf.zeros((n, b)) - # Reshape to [N * B, ...]. - merge = lambda x: snt.merge_leading_dims(x, 2) - tiled_inputs = tree.map_structure(merge, tiled_inputs) - - tiled_actions = self._actor_network(tiled_inputs) - - # Compute Q-values and the resulting tempered probabilities. - q = self._critic_network(tiled_inputs, tiled_actions) - boltzmann_logits = q / self._beta - - boltzmann_logits = snt.split_leading_dim(boltzmann_logits, dummy_zeros_n_b, - 2) - # [B, N] - boltzmann_logits = tf.transpose(boltzmann_logits, perm=(1, 0)) - # Resample one action per batch according to the Boltzmann distribution. - action_idx = tfp.distributions.Categorical(logits=boltzmann_logits).sample() - # [B, 2], where the first column is 0, 1, 2,... corresponding to indices to - # the batch dimension. - action_idx = tf.stack((tf.range(b), action_idx), axis=1) - - tiled_actions = snt.split_leading_dim(tiled_actions, dummy_zeros_n_b, 2) - action_dim = len(tiled_actions.get_shape().as_list()) - tiled_actions = tf.transpose(tiled_actions, - perm=[1, 0] + list(range(2, action_dim))) - # [B, ...] 
- action_sample = tf.gather_nd(tiled_actions, action_idx) - - return action_sample + def __init__( + self, + actor_network: snt.Module, + critic_network: snt.Module, + beta: float = 1.0, + num_action_samples: int = 16, + ): + super().__init__(name="ExpQWeightedPolicy") + self._actor_network = actor_network + self._critic_network = critic_network + self._num_action_samples = num_action_samples + self._beta = beta + + def __call__(self, inputs: types.NestedTensor) -> tf.Tensor: + # Inputs are of size [B, ...]. Here we tile them to be of shape [N, B, ...]. + tiled_inputs = tf2_utils.tile_nested(inputs, self._num_action_samples) + shape = tf.shape(tree.flatten(tiled_inputs)[0]) + n, b = shape[0], shape[1] + tf.debugging.assert_equal( + n, + self._num_action_samples, + "Internal Error. Unexpected tiled_inputs shape.", + ) + dummy_zeros_n_b = tf.zeros((n, b)) + # Reshape to [N * B, ...]. + merge = lambda x: snt.merge_leading_dims(x, 2) + tiled_inputs = tree.map_structure(merge, tiled_inputs) + + tiled_actions = self._actor_network(tiled_inputs) + + # Compute Q-values and the resulting tempered probabilities. + q = self._critic_network(tiled_inputs, tiled_actions) + boltzmann_logits = q / self._beta + + boltzmann_logits = snt.split_leading_dim(boltzmann_logits, dummy_zeros_n_b, 2) + # [B, N] + boltzmann_logits = tf.transpose(boltzmann_logits, perm=(1, 0)) + # Resample one action per batch according to the Boltzmann distribution. + action_idx = tfp.distributions.Categorical(logits=boltzmann_logits).sample() + # [B, 2], where the first column is 0, 1, 2,... corresponding to indices to + # the batch dimension. + action_idx = tf.stack((tf.range(b), action_idx), axis=1) + + tiled_actions = snt.split_leading_dim(tiled_actions, dummy_zeros_n_b, 2) + action_dim = len(tiled_actions.get_shape().as_list()) + tiled_actions = tf.transpose( + tiled_actions, perm=[1, 0] + list(range(2, action_dim)) + ) + # [B, ...] + action_sample = tf.gather_nd(tiled_actions, action_idx) + + return action_sample diff --git a/acme/tf/networks/vision.py b/acme/tf/networks/vision.py index 7065dba790..850345eb45 100644 --- a/acme/tf/networks/vision.py +++ b/acme/tf/networks/vision.py @@ -15,27 +15,29 @@ """Visual networks for processing pixel inputs.""" from typing import Callable, Optional, Sequence, Union + import sonnet as snt import tensorflow as tf class ResNetTorso(snt.Module): - """ResNet architecture used in IMPALA paper.""" - - def __init__( - self, - num_channels: Sequence[int] = (16, 32, 32), # default to IMPALA resnet. - num_blocks: Sequence[int] = (2, 2, 2), # default to IMPALA resnet. - num_output_hidden: Sequence[int] = (256,), # default to IMPALA resnet. - conv_shape: Union[int, Sequence[int]] = 3, - conv_stride: Union[int, Sequence[int]] = 1, - pool_size: Union[int, Sequence[int]] = 3, - pool_stride: Union[int, Sequence[int], Sequence[Sequence[int]]] = 2, - data_format: str = 'NHWC', - activation: Callable[[tf.Tensor], tf.Tensor] = tf.nn.relu, - output_dtype: tf.DType = tf.float32, - name: str = 'resnet_torso'): - """Builds an IMPALA-style ResNet. + """ResNet architecture used in IMPALA paper.""" + + def __init__( + self, + num_channels: Sequence[int] = (16, 32, 32), # default to IMPALA resnet. + num_blocks: Sequence[int] = (2, 2, 2), # default to IMPALA resnet. + num_output_hidden: Sequence[int] = (256,), # default to IMPALA resnet. 
+ conv_shape: Union[int, Sequence[int]] = 3, + conv_stride: Union[int, Sequence[int]] = 1, + pool_size: Union[int, Sequence[int]] = 3, + pool_stride: Union[int, Sequence[int], Sequence[Sequence[int]]] = 2, + data_format: str = "NHWC", + activation: Callable[[tf.Tensor], tf.Tensor] = tf.nn.relu, + output_dtype: tf.DType = tf.float32, + name: str = "resnet_torso", + ): + """Builds an IMPALA-style ResNet. The arguments' default values construct the IMPALA resnet. @@ -53,184 +55,198 @@ def __init__( output_dtype: the output dtype. name: The Sonnet module name. """ - super().__init__(name=name) + super().__init__(name=name) - self._output_dtype = output_dtype - self._num_layers = len(num_blocks) + self._output_dtype = output_dtype + self._num_layers = len(num_blocks) - if isinstance(pool_stride, int): - pool_stride = (pool_stride, pool_stride) + if isinstance(pool_stride, int): + pool_stride = (pool_stride, pool_stride) - if isinstance(pool_stride[0], int): - pool_stride = self._num_layers * (pool_stride,) + if isinstance(pool_stride[0], int): + pool_stride = self._num_layers * (pool_stride,) - # Create sequence of residual blocks. - blocks = [] - for i in range(self._num_layers): - blocks.append( - ResidualBlockGroup( - num_blocks[i], - num_channels[i], - conv_shape, - conv_stride, - pool_size, - pool_stride[i], - data_format=data_format, - activation=activation)) + # Create sequence of residual blocks. + blocks = [] + for i in range(self._num_layers): + blocks.append( + ResidualBlockGroup( + num_blocks[i], + num_channels[i], + conv_shape, + conv_stride, + pool_size, + pool_stride[i], + data_format=data_format, + activation=activation, + ) + ) - # Create output layer. - out_layer = snt.nets.MLP(num_output_hidden, activation=activation) + # Create output layer. + out_layer = snt.nets.MLP(num_output_hidden, activation=activation) - # Compose blocks and final layer. - self._resnet = snt.Sequential( - blocks + [activation, snt.Flatten(), out_layer]) + # Compose blocks and final layer. + self._resnet = snt.Sequential(blocks + [activation, snt.Flatten(), out_layer]) - def __call__(self, inputs: tf.Tensor) -> tf.Tensor: - """Evaluates the ResidualPixelCore.""" + def __call__(self, inputs: tf.Tensor) -> tf.Tensor: + """Evaluates the ResidualPixelCore.""" - # Convert to floats. - preprocessed_inputs = _preprocess_inputs(inputs, self._output_dtype) - torso_output = self._resnet(preprocessed_inputs) + # Convert to floats. + preprocessed_inputs = _preprocess_inputs(inputs, self._output_dtype) + torso_output = self._resnet(preprocessed_inputs) - return torso_output + return torso_output class ResidualBlockGroup(snt.Module): - """Higher level block for ResNet implementation.""" - - def __init__(self, - num_blocks: int, - num_output_channels: int, - conv_shape: Union[int, Sequence[int]], - conv_stride: Union[int, Sequence[int]], - pool_shape: Union[int, Sequence[int]], - pool_stride: Union[int, Sequence[int]], - data_format: str = 'NHWC', - activation: Callable[[tf.Tensor], tf.Tensor] = tf.nn.relu, - name: Optional[str] = None): - super().__init__(name=name) - - self._num_blocks = num_blocks - self._data_format = data_format - self._activation = activation - - # The pooling operation expects a 2-rank shape/stride (height and width). - if isinstance(pool_shape, int): - pool_shape = 2 * [pool_shape] - if isinstance(pool_stride, int): - pool_stride = 2 * [pool_stride] - - # Create a Conv2D factory since we'll be making quite a few. 
- def build_conv_layer(name: str): - return snt.Conv2D( - num_output_channels, - conv_shape, - stride=conv_stride, - padding='SAME', - data_format=data_format, - name=name) - - # Create a pooling layer. - def pooling_layer(inputs: tf.Tensor) -> tf.Tensor: - return tf.nn.pool( - inputs, - pool_shape, - pooling_type='MAX', - strides=pool_stride, - padding='SAME', - data_format=data_format) - - # Create an initial conv layer and pooling to scale the image down. - self._downscale = snt.Sequential( - [build_conv_layer('downscale'), pooling_layer]) - - # Residual block(s). - self._convs = [] - for i in range(self._num_blocks): - name = 'residual_block_%d' % i - self._convs.append( - [build_conv_layer(name + '_0'), - build_conv_layer(name + '_1')]) - - def __call__(self, inputs: tf.Tensor) -> tf.Tensor: - # Downscale the inputs. - conv_out = self._downscale(inputs) - - # Apply (sequence of) residual block(s). - for i in range(self._num_blocks): - block_input = conv_out - conv_out = self._activation(conv_out) - conv_out = self._convs[i][0](conv_out) - conv_out = self._activation(conv_out) - conv_out = self._convs[i][1](conv_out) - conv_out += block_input - return conv_out + """Higher level block for ResNet implementation.""" + + def __init__( + self, + num_blocks: int, + num_output_channels: int, + conv_shape: Union[int, Sequence[int]], + conv_stride: Union[int, Sequence[int]], + pool_shape: Union[int, Sequence[int]], + pool_stride: Union[int, Sequence[int]], + data_format: str = "NHWC", + activation: Callable[[tf.Tensor], tf.Tensor] = tf.nn.relu, + name: Optional[str] = None, + ): + super().__init__(name=name) + + self._num_blocks = num_blocks + self._data_format = data_format + self._activation = activation + + # The pooling operation expects a 2-rank shape/stride (height and width). + if isinstance(pool_shape, int): + pool_shape = 2 * [pool_shape] + if isinstance(pool_stride, int): + pool_stride = 2 * [pool_stride] + + # Create a Conv2D factory since we'll be making quite a few. + def build_conv_layer(name: str): + return snt.Conv2D( + num_output_channels, + conv_shape, + stride=conv_stride, + padding="SAME", + data_format=data_format, + name=name, + ) + + # Create a pooling layer. + def pooling_layer(inputs: tf.Tensor) -> tf.Tensor: + return tf.nn.pool( + inputs, + pool_shape, + pooling_type="MAX", + strides=pool_stride, + padding="SAME", + data_format=data_format, + ) + + # Create an initial conv layer and pooling to scale the image down. + self._downscale = snt.Sequential([build_conv_layer("downscale"), pooling_layer]) + + # Residual block(s). + self._convs = [] + for i in range(self._num_blocks): + name = "residual_block_%d" % i + self._convs.append( + [build_conv_layer(name + "_0"), build_conv_layer(name + "_1")] + ) + + def __call__(self, inputs: tf.Tensor) -> tf.Tensor: + # Downscale the inputs. + conv_out = self._downscale(inputs) + + # Apply (sequence of) residual block(s). 
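+        # Each block is a pre-activation residual unit: activation -> conv ->
+        # activation -> conv, with the block input added back through the skip
+        # connection, as in the IMPALA torso.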
+ for i in range(self._num_blocks): + block_input = conv_out + conv_out = self._activation(conv_out) + conv_out = self._convs[i][0](conv_out) + conv_out = self._activation(conv_out) + conv_out = self._convs[i][1](conv_out) + conv_out += block_input + return conv_out def _preprocess_inputs(inputs: tf.Tensor, output_dtype: tf.DType) -> tf.Tensor: - """Returns the `Tensor` corresponding to the preprocessed inputs.""" - rank = inputs.shape.rank - if rank < 4: - raise ValueError( - 'Input Tensor must have at least 4 dimensions (for ' - 'batch size, height, width, and channels), but it only has ' - '{}'.format(rank)) - - flattened_inputs = snt.Flatten(preserve_dims=3)(inputs) - processed_inputs = tf.image.convert_image_dtype( - flattened_inputs, dtype=output_dtype) - return processed_inputs + """Returns the `Tensor` corresponding to the preprocessed inputs.""" + rank = inputs.shape.rank + if rank < 4: + raise ValueError( + "Input Tensor must have at least 4 dimensions (for " + "batch size, height, width, and channels), but it only has " + "{}".format(rank) + ) + + flattened_inputs = snt.Flatten(preserve_dims=3)(inputs) + processed_inputs = tf.image.convert_image_dtype( + flattened_inputs, dtype=output_dtype + ) + return processed_inputs class DrQTorso(snt.Module): - """DrQ Torso inspired by the second DrQ paper [Yarats et al., 2021]. + """DrQ Torso inspired by the second DrQ paper [Yarats et al., 2021]. [Yarats et al., 2021] https://arxiv.org/abs/2107.09645 """ - def __init__( - self, - data_format: str = 'NHWC', - activation: Callable[[tf.Tensor], tf.Tensor] = tf.nn.relu, - output_dtype: tf.DType = tf.float32, - name: str = 'resnet_torso'): - super().__init__(name=name) - - self._output_dtype = output_dtype - - # Create a Conv2D factory since we'll be making quite a few. - gain = 2**0.5 if activation == tf.nn.relu else 1. - def build_conv_layer(name: str, - output_channels: int = 32, - kernel_shape: Sequence[int] = (3, 3), - stride: int = 1): - return snt.Conv2D( - output_channels=output_channels, - kernel_shape=kernel_shape, - stride=stride, - padding='SAME', - data_format=data_format, - w_init=snt.initializers.Orthogonal(gain=gain, seed=None), - b_init=snt.initializers.Zeros(), - name=name) - - self._network = snt.Sequential( - [build_conv_layer('conv_0', stride=2), - activation, - build_conv_layer('conv_1', stride=1), - activation, - build_conv_layer('conv_2', stride=1), - activation, - build_conv_layer('conv_3', stride=1), - activation, - snt.Flatten()]) - - def __call__(self, inputs: tf.Tensor) -> tf.Tensor: - """Evaluates the ResidualPixelCore.""" - - # Normalize to -0.5 to 0.5 - preprocessed_inputs = _preprocess_inputs(inputs, self._output_dtype) - 0.5 - - torso_output = self._network(preprocessed_inputs) - - return torso_output + def __init__( + self, + data_format: str = "NHWC", + activation: Callable[[tf.Tensor], tf.Tensor] = tf.nn.relu, + output_dtype: tf.DType = tf.float32, + name: str = "resnet_torso", + ): + super().__init__(name=name) + + self._output_dtype = output_dtype + + # Create a Conv2D factory since we'll be making quite a few. 
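+        # A gain of sqrt(2) is the usual He-style correction for ReLU when using
+        # orthogonal initialisation; other activations keep the default gain of 1.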
+ gain = 2 ** 0.5 if activation == tf.nn.relu else 1.0 + + def build_conv_layer( + name: str, + output_channels: int = 32, + kernel_shape: Sequence[int] = (3, 3), + stride: int = 1, + ): + return snt.Conv2D( + output_channels=output_channels, + kernel_shape=kernel_shape, + stride=stride, + padding="SAME", + data_format=data_format, + w_init=snt.initializers.Orthogonal(gain=gain, seed=None), + b_init=snt.initializers.Zeros(), + name=name, + ) + + self._network = snt.Sequential( + [ + build_conv_layer("conv_0", stride=2), + activation, + build_conv_layer("conv_1", stride=1), + activation, + build_conv_layer("conv_2", stride=1), + activation, + build_conv_layer("conv_3", stride=1), + activation, + snt.Flatten(), + ] + ) + + def __call__(self, inputs: tf.Tensor) -> tf.Tensor: + """Evaluates the ResidualPixelCore.""" + + # Normalize to -0.5 to 0.5 + preprocessed_inputs = _preprocess_inputs(inputs, self._output_dtype) - 0.5 + + torso_output = self._network(preprocessed_inputs) + + return torso_output diff --git a/acme/tf/savers.py b/acme/tf/savers.py index c0efd19a65..a650a1a3e3 100644 --- a/acme/tf/savers.py +++ b/acme/tf/savers.py @@ -21,17 +21,16 @@ import time from typing import Mapping, Optional, Union -from absl import logging -from acme import core -from acme.utils import signals -from acme.utils import paths import sonnet as snt import tensorflow as tf import tensorflow_probability as tfp import tree - +from absl import logging from tensorflow.python.saved_model import revived_types +from acme import core +from acme.utils import paths, signals + PythonState = tf.train.experimental.PythonState Checkpointable = Union[tf.Module, tf.Variable, PythonState] @@ -40,16 +39,16 @@ class TFSaveable(abc.ABC): - """An interface for objects that expose their checkpointable TF state.""" + """An interface for objects that expose their checkpointable TF state.""" - @property - @abc.abstractmethod - def state(self) -> Mapping[str, Checkpointable]: - """Returns TensorFlow checkpointable state.""" + @property + @abc.abstractmethod + def state(self) -> Mapping[str, Checkpointable]: + """Returns TensorFlow checkpointable state.""" class Checkpointer: - """Convenience class for periodically checkpointing. + """Convenience class for periodically checkpointing. This can be used to checkpoint any object with trackable state (e.g. tensorflow variables or modules); see tf.train.Checkpoint for @@ -72,20 +71,20 @@ class Checkpointer: ``` """ - def __init__( - self, - objects_to_save: Mapping[str, Union[Checkpointable, core.Saveable]], - *, - directory: str = '~/acme/', - subdirectory: str = 'default', - time_delta_minutes: float = 10.0, - enable_checkpointing: bool = True, - add_uid: bool = True, - max_to_keep: int = 1, - checkpoint_ttl_seconds: Optional[int] = _DEFAULT_CHECKPOINT_TTL, - keep_checkpoint_every_n_hours: Optional[int] = None, - ): - """Builds the saver object. + def __init__( + self, + objects_to_save: Mapping[str, Union[Checkpointable, core.Saveable]], + *, + directory: str = "~/acme/", + subdirectory: str = "default", + time_delta_minutes: float = 10.0, + enable_checkpointing: bool = True, + add_uid: bool = True, + max_to_keep: int = 1, + checkpoint_ttl_seconds: Optional[int] = _DEFAULT_CHECKPOINT_TTL, + keep_checkpoint_every_n_hours: Optional[int] = None, + ): + """Builds the saver object. Args: objects_to_save: Mapping specifying what to checkpoint. @@ -102,41 +101,43 @@ def __init__( tf.train.CheckpointManager. """ - # Convert `Saveable` objects to TF `Checkpointable` first, if necessary. 
- def to_ckptable(x: Union[Checkpointable, core.Saveable]) -> Checkpointable: - if isinstance(x, core.Saveable): - return SaveableAdapter(x) - return x - - objects_to_save = {k: to_ckptable(v) for k, v in objects_to_save.items()} - - self._time_delta_minutes = time_delta_minutes - self._last_saved = 0. - self._enable_checkpointing = enable_checkpointing - self._checkpoint_manager = None - - if enable_checkpointing: - # Checkpoint object that handles saving/restoring. - self._checkpoint = tf.train.Checkpoint(**objects_to_save) - self._checkpoint_dir = paths.process_path( - directory, - 'checkpoints', - subdirectory, - ttl_seconds=checkpoint_ttl_seconds, - backups=False, - add_uid=add_uid) - - # Create a manager to maintain different checkpoints. - self._checkpoint_manager = tf.train.CheckpointManager( - self._checkpoint, - directory=self._checkpoint_dir, - max_to_keep=max_to_keep, - keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours) - - self.restore() - - def save(self, force: bool = False) -> bool: - """Save the checkpoint if it's the appropriate time, otherwise no-ops. + # Convert `Saveable` objects to TF `Checkpointable` first, if necessary. + def to_ckptable(x: Union[Checkpointable, core.Saveable]) -> Checkpointable: + if isinstance(x, core.Saveable): + return SaveableAdapter(x) + return x + + objects_to_save = {k: to_ckptable(v) for k, v in objects_to_save.items()} + + self._time_delta_minutes = time_delta_minutes + self._last_saved = 0.0 + self._enable_checkpointing = enable_checkpointing + self._checkpoint_manager = None + + if enable_checkpointing: + # Checkpoint object that handles saving/restoring. + self._checkpoint = tf.train.Checkpoint(**objects_to_save) + self._checkpoint_dir = paths.process_path( + directory, + "checkpoints", + subdirectory, + ttl_seconds=checkpoint_ttl_seconds, + backups=False, + add_uid=add_uid, + ) + + # Create a manager to maintain different checkpoints. + self._checkpoint_manager = tf.train.CheckpointManager( + self._checkpoint, + directory=self._checkpoint_dir, + max_to_keep=max_to_keep, + keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours, + ) + + self.restore() + + def save(self, force: bool = False) -> bool: + """Save the checkpoint if it's the appropriate time, otherwise no-ops. Args: force: Whether to force a save regardless of time elapsed since last save. @@ -144,105 +145,104 @@ def save(self, force: bool = False) -> bool: Returns: A boolean indicating if a save event happened. """ - if not self._enable_checkpointing: - return False + if not self._enable_checkpointing: + return False - if (not force and - time.time() - self._last_saved < 60 * self._time_delta_minutes): - return False + if not force and time.time() - self._last_saved < 60 * self._time_delta_minutes: + return False - # Save any checkpoints. - logging.info('Saving checkpoint: %s', self._checkpoint_manager.directory) - self._checkpoint_manager.save() - self._last_saved = time.time() + # Save any checkpoints. + logging.info("Saving checkpoint: %s", self._checkpoint_manager.directory) + self._checkpoint_manager.save() + self._last_saved = time.time() - return True + return True - def restore(self): - # Restore from the most recent checkpoint (if it exists). - checkpoint_to_restore = self._checkpoint_manager.latest_checkpoint - logging.info('Attempting to restore checkpoint: %s', - checkpoint_to_restore) - self._checkpoint.restore(checkpoint_to_restore) + def restore(self): + # Restore from the most recent checkpoint (if it exists). 
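+        # If nothing has been saved yet, `latest_checkpoint` is None and
+        # tf.train.Checkpoint.restore(None) is a no-op, so a fresh run keeps the
+        # objects' initial values.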
+ checkpoint_to_restore = self._checkpoint_manager.latest_checkpoint + logging.info("Attempting to restore checkpoint: %s", checkpoint_to_restore) + self._checkpoint.restore(checkpoint_to_restore) - @property - def directory(self): - return self._checkpoint_manager.directory + @property + def directory(self): + return self._checkpoint_manager.directory class CheckpointingRunner(core.Worker): - """Wrap an object and expose a run method which checkpoints periodically. + """Wrap an object and expose a run method which checkpoints periodically. This internally creates a Checkpointer around `wrapped` object and exposes all of the methods of `wrapped`. Additionally, any `**kwargs` passed to the runner are forwarded to the internal Checkpointer. """ - def __init__( - self, - wrapped: Union[Checkpointable, core.Saveable, TFSaveable], - key: str = 'wrapped', - *, - time_delta_minutes: int = 30, - **kwargs, - ): - - if isinstance(wrapped, TFSaveable): - # If the object to be wrapped exposes its TF State, checkpoint that. - objects_to_save = wrapped.state - else: - # Otherwise checkpoint the wrapped object itself. - objects_to_save = wrapped - - self._wrapped = wrapped - self._time_delta_minutes = time_delta_minutes - self._checkpointer = Checkpointer( - objects_to_save={key: objects_to_save}, - time_delta_minutes=time_delta_minutes, - **kwargs) - - # Handle preemption signal. Note that this must happen in the main thread. - def _signal_handler(self): - logging.info('Caught SIGTERM: forcing a checkpoint save.') - self._checkpointer.save(force=True) - - def step(self): - if isinstance(self._wrapped, core.Learner): - # Learners have a step() method, so alternate between that and ckpt call. - self._wrapped.step() - self._checkpointer.save() - else: - # Wrapped object doesn't have a run method; set our run method to ckpt. - self.checkpoint() - - def run(self): - """Runs the checkpointer.""" - with signals.runtime_terminator(self._signal_handler): - while True: - self.step() - - def __dir__(self): - return dir(self._wrapped) + ['get_directory'] - - # TODO(b/195915583) : Throw when wrapped object has get_directory() method. - def __getattr__(self, name): - if name == 'get_directory': - return self.get_directory - return getattr(self._wrapped, name) - - def checkpoint(self): - self._checkpointer.save() - # Do not sleep for a long period of time to avoid LaunchPad program - # termination hangs (time.sleep is not interruptible). - for _ in range(self._time_delta_minutes * 60): - time.sleep(1) - - def get_directory(self): - return self._checkpointer.directory + def __init__( + self, + wrapped: Union[Checkpointable, core.Saveable, TFSaveable], + key: str = "wrapped", + *, + time_delta_minutes: int = 30, + **kwargs, + ): + + if isinstance(wrapped, TFSaveable): + # If the object to be wrapped exposes its TF State, checkpoint that. + objects_to_save = wrapped.state + else: + # Otherwise checkpoint the wrapped object itself. + objects_to_save = wrapped + + self._wrapped = wrapped + self._time_delta_minutes = time_delta_minutes + self._checkpointer = Checkpointer( + objects_to_save={key: objects_to_save}, + time_delta_minutes=time_delta_minutes, + **kwargs, + ) + + # Handle preemption signal. Note that this must happen in the main thread. + def _signal_handler(self): + logging.info("Caught SIGTERM: forcing a checkpoint save.") + self._checkpointer.save(force=True) + + def step(self): + if isinstance(self._wrapped, core.Learner): + # Learners have a step() method, so alternate between that and ckpt call. 
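+            # save() is rate limited by `time_delta_minutes`, so calling it after
+            # every learner step is cheap: it returns without writing until enough
+            # time has elapsed (or force=True is passed).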
+ self._wrapped.step() + self._checkpointer.save() + else: + # Wrapped object doesn't have a run method; set our run method to ckpt. + self.checkpoint() + + def run(self): + """Runs the checkpointer.""" + with signals.runtime_terminator(self._signal_handler): + while True: + self.step() + + def __dir__(self): + return dir(self._wrapped) + ["get_directory"] + + # TODO(b/195915583) : Throw when wrapped object has get_directory() method. + def __getattr__(self, name): + if name == "get_directory": + return self.get_directory + return getattr(self._wrapped, name) + + def checkpoint(self): + self._checkpointer.save() + # Do not sleep for a long period of time to avoid LaunchPad program + # termination hangs (time.sleep is not interruptible). + for _ in range(self._time_delta_minutes * 60): + time.sleep(1) + + def get_directory(self): + return self._checkpointer.directory class Snapshotter: - """Convenience class for periodically snapshotting. + """Convenience class for periodically snapshotting. Objects which can be snapshotted are limited to Sonnet or tensorflow Modules which implement a __call__ method. This will save the module's graph and @@ -266,15 +266,15 @@ class Snapshotter: ``` """ - def __init__( - self, - objects_to_save: Mapping[str, snt.Module], - *, - directory: str = '~/acme/', - time_delta_minutes: float = 30.0, - snapshot_ttl_seconds: int = _DEFAULT_SNAPSHOT_TTL, - ): - """Builds the saver object. + def __init__( + self, + objects_to_save: Mapping[str, snt.Module], + *, + directory: str = "~/acme/", + time_delta_minutes: float = 30.0, + snapshot_ttl_seconds: int = _DEFAULT_SNAPSHOT_TTL, + ): + """Builds the saver object. Args: objects_to_save: Mapping specifying what to snapshot. @@ -282,23 +282,24 @@ def __init__( time_delta_minutes: How often to save the snapshot, in minutes. snapshot_ttl_seconds: TTL (time to leave) in seconds for snapshots. """ - objects_to_save = objects_to_save or {} + objects_to_save = objects_to_save or {} - self._time_delta_minutes = time_delta_minutes - self._last_saved = 0. - self._snapshots = {} + self._time_delta_minutes = time_delta_minutes + self._last_saved = 0.0 + self._snapshots = {} - # Save the base directory path so we can refer to it if needed. - self.directory = paths.process_path( - directory, 'snapshots', ttl_seconds=snapshot_ttl_seconds) + # Save the base directory path so we can refer to it if needed. + self.directory = paths.process_path( + directory, "snapshots", ttl_seconds=snapshot_ttl_seconds + ) - # Save a dictionary mapping paths to snapshot capable models. - for name, module in objects_to_save.items(): - path = os.path.join(self.directory, name) - self._snapshots[path] = make_snapshot(module) + # Save a dictionary mapping paths to snapshot capable models. + for name, module in objects_to_save.items(): + path = os.path.join(self.directory, name) + self._snapshots[path] = make_snapshot(module) - def save(self, force: bool = False) -> bool: - """Snapshots if it's the appropriate time, otherwise no-ops. + def save(self, force: bool = False) -> bool: + """Snapshots if it's the appropriate time, otherwise no-ops. Args: force: If True, save new snapshot no matter how long it's been since the @@ -307,52 +308,53 @@ def save(self, force: bool = False) -> bool: Returns: A boolean indicating if a save event happened. """ - seconds_since_last = time.time() - self._last_saved - if (self._snapshots and - (force or seconds_since_last >= 60 * self._time_delta_minutes)): - # Save any snapshots. 
- for path, snapshot in self._snapshots.items(): - tf.saved_model.save(snapshot, path) + seconds_since_last = time.time() - self._last_saved + if self._snapshots and ( + force or seconds_since_last >= 60 * self._time_delta_minutes + ): + # Save any snapshots. + for path, snapshot in self._snapshots.items(): + tf.saved_model.save(snapshot, path) - # Record the time we finished saving. - self._last_saved = time.time() + # Record the time we finished saving. + self._last_saved = time.time() - return True + return True - return False + return False class Snapshot(tf.Module): - """Thin wrapper which allows the module to be saved.""" + """Thin wrapper which allows the module to be saved.""" - def __init__(self): - super().__init__() - self._module = None - self._variables = None - self._trainable_variables = None + def __init__(self): + super().__init__() + self._module = None + self._variables = None + self._trainable_variables = None - @tf.function - def __call__(self, *args, **kwargs): - return self._module(*args, **kwargs) + @tf.function + def __call__(self, *args, **kwargs): + return self._module(*args, **kwargs) - @property - def submodules(self): - return [self._module] + @property + def submodules(self): + return [self._module] - @property - def variables(self): - return self._variables + @property + def variables(self): + return self._variables - @property - def trainable_variables(self): - return self._trainable_variables + @property + def trainable_variables(self): + return self._trainable_variables # Registers the Snapshot object above such that when it is restored by # tf.saved_model.load it will be restored as a Snapshot. This is important # because it allows us to expose the __call__, and *_variables properties. revived_types.register_revived_type( - 'acme_snapshot', + "acme_snapshot", lambda obj: isinstance(obj, Snapshot), versions=[ revived_types.VersionedTypeRegistration( @@ -362,54 +364,59 @@ def trainable_variables(self): min_consumer_version=1, setter=setattr, ) - ]) + ], +) def make_snapshot(module: snt.Module): - """Create a thin wrapper around a module to make it snapshottable.""" - # Get the input signature as long as it has been created. - input_signature = _get_input_signature(module) - if input_signature is None: - raise ValueError( - ('module instance "{}" has no input_signature attribute, ' - 'which is required for snapshotting; run ' - 'create_variables to add this annotation.').format(module.name)) - - # This function will return the object as a composite tensor if it is a - # distribution and will otherwise return it with no changes. - def as_composite(obj): - if isinstance(obj, tfp.distributions.Distribution): - return tfp.experimental.as_composite(obj) - else: - return obj - - # Replace any distributions returned by the module with composite tensors and - # wrap it up in tf.function so we can process it properly. - @tf.function - def wrapped_module(*args, **kwargs): - return tree.map_structure(as_composite, module(*args, **kwargs)) - - # pylint: disable=protected-access - snapshot = Snapshot() - snapshot._module = wrapped_module - snapshot._variables = module.variables - snapshot._trainable_variables = module.trainable_variables - # pylint: disable=protected-access - - # Make sure the snapshot has the proper input signature. - snapshot.__call__.get_concrete_function(*input_signature) - - # If we are an RNN also save the initial-state generating function. 
- if isinstance(module, snt.RNNCore): - snapshot.initial_state = tf.function(module.initial_state) - snapshot.initial_state.get_concrete_function( - tf.TensorSpec(shape=(), dtype=tf.int32)) - - return snapshot + """Create a thin wrapper around a module to make it snapshottable.""" + # Get the input signature as long as it has been created. + input_signature = _get_input_signature(module) + if input_signature is None: + raise ValueError( + ( + 'module instance "{}" has no input_signature attribute, ' + "which is required for snapshotting; run " + "create_variables to add this annotation." + ).format(module.name) + ) + + # This function will return the object as a composite tensor if it is a + # distribution and will otherwise return it with no changes. + def as_composite(obj): + if isinstance(obj, tfp.distributions.Distribution): + return tfp.experimental.as_composite(obj) + else: + return obj + + # Replace any distributions returned by the module with composite tensors and + # wrap it up in tf.function so we can process it properly. + @tf.function + def wrapped_module(*args, **kwargs): + return tree.map_structure(as_composite, module(*args, **kwargs)) + + # pylint: disable=protected-access + snapshot = Snapshot() + snapshot._module = wrapped_module + snapshot._variables = module.variables + snapshot._trainable_variables = module.trainable_variables + # pylint: disable=protected-access + + # Make sure the snapshot has the proper input signature. + snapshot.__call__.get_concrete_function(*input_signature) + + # If we are an RNN also save the initial-state generating function. + if isinstance(module, snt.RNNCore): + snapshot.initial_state = tf.function(module.initial_state) + snapshot.initial_state.get_concrete_function( + tf.TensorSpec(shape=(), dtype=tf.int32) + ) + + return snapshot def _get_input_signature(module: snt.Module) -> Optional[tf.TensorSpec]: - """Get module input signature. + """Get module input signature. Works even if the module with signature is wrapper into snt.Sequentual or snt.DeepRNN. @@ -423,38 +430,39 @@ def _get_input_signature(module: snt.Module) -> Optional[tf.TensorSpec]: Returns: Input signature of the module or None if it's not available. """ - if hasattr(module, '_input_signature'): - return module._input_signature # pylint: disable=protected-access - - if isinstance(module, snt.Sequential): - first_layer = module._layers[0] # pylint: disable=protected-access - return _get_input_signature(first_layer) - - if isinstance(module, snt.DeepRNN): - first_layer = module._layers[0] # pylint: disable=protected-access - input_signature = _get_input_signature(first_layer) - - # Wrapping a module in DeepRNN changes its state shape, so we need to bring - # it up to date. - state = module.initial_state(1) - input_signature[-1] = tree.map_structure( - lambda t: tf.TensorSpec((None,) + t.shape[1:], t.dtype), state) + if hasattr(module, "_input_signature"): + return module._input_signature # pylint: disable=protected-access + + if isinstance(module, snt.Sequential): + first_layer = module._layers[0] # pylint: disable=protected-access + return _get_input_signature(first_layer) + + if isinstance(module, snt.DeepRNN): + first_layer = module._layers[0] # pylint: disable=protected-access + input_signature = _get_input_signature(first_layer) + + # Wrapping a module in DeepRNN changes its state shape, so we need to bring + # it up to date. 
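+        # The final entry of the signature holds the recurrent-state spec; rebuild
+        # it from a batch-size-1 initial state, with a leading dimension of None so
+        # the exported signature accepts any batch size.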
+ state = module.initial_state(1) + input_signature[-1] = tree.map_structure( + lambda t: tf.TensorSpec((None,) + t.shape[1:], t.dtype), state + ) - return input_signature + return input_signature - return None + return None class SaveableAdapter(tf.train.experimental.PythonState): - """Adapter which allows `Saveable` object to be checkpointed by TensorFlow.""" + """Adapter which allows `Saveable` object to be checkpointed by TensorFlow.""" - def __init__(self, object_to_save: core.Saveable): - self._object_to_save = object_to_save + def __init__(self, object_to_save: core.Saveable): + self._object_to_save = object_to_save - def serialize(self): - state = self._object_to_save.save() - return pickle.dumps(state) + def serialize(self): + state = self._object_to_save.save() + return pickle.dumps(state) - def deserialize(self, pickled: bytes): - state = pickle.loads(pickled) - self._object_to_save.restore(state) + def deserialize(self, pickled: bytes): + state = pickle.loads(pickled) + self._object_to_save.restore(state) diff --git a/acme/tf/savers_test.py b/acme/tf/savers_test.py index cd075a1818..c2e70788b2 100644 --- a/acme/tf/savers_test.py +++ b/acme/tf/savers_test.py @@ -19,276 +19,279 @@ import time from unittest import mock -from acme import specs -from acme.testing import test_utils -from acme.tf import networks -from acme.tf import savers as tf2_savers -from acme.tf import utils as tf2_utils -from acme.utils import paths import launchpad import numpy as np import sonnet as snt import tensorflow as tf import tree - from absl.testing import absltest +from acme import specs +from acme.testing import test_utils +from acme.tf import networks +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.utils import paths + class DummySaveable(tf2_savers.TFSaveable): - _state: tf.Variable + _state: tf.Variable - def __init__(self): - self._state = tf.Variable(0, dtype=tf.int32) + def __init__(self): + self._state = tf.Variable(0, dtype=tf.int32) - @property - def state(self): - return {'state': self._state} + @property + def state(self): + return {"state": self._state} class CheckpointerTest(test_utils.TestCase): - - def test_save_and_restore(self): - """Test that checkpointer correctly calls save and restore.""" - - x = tf.Variable(0, dtype=tf.int32) - directory = self.get_tempdir() - checkpointer = tf2_savers.Checkpointer( - objects_to_save={'x': x}, time_delta_minutes=0., directory=directory) - - for _ in range(10): - saved = checkpointer.save() - self.assertTrue(saved) - x.assign_add(1) - checkpointer.restore() - np.testing.assert_array_equal(x.numpy(), np.int32(0)) - - def test_save_and_new_restore(self): - """Tests that a fresh checkpointer correctly restores an existing ckpt.""" - with mock.patch.object(paths, 'get_unique_id') as mock_unique_id: - mock_unique_id.return_value = ('test',) - x = tf.Variable(0, dtype=tf.int32) - directory = self.get_tempdir() - checkpointer1 = tf2_savers.Checkpointer( - objects_to_save={'x': x}, time_delta_minutes=0., directory=directory) - checkpointer1.save() - x.assign_add(1) - # Simulate a preemption: x is changed, and we make a new Checkpointer. 
- checkpointer2 = tf2_savers.Checkpointer( - objects_to_save={'x': x}, time_delta_minutes=0., directory=directory) - checkpointer2.restore() - np.testing.assert_array_equal(x.numpy(), np.int32(0)) - - def test_save_and_restore_time_based(self): - """Test that checkpointer correctly calls save and restore based on time.""" - - x = tf.Variable(0, dtype=tf.int32) - directory = self.get_tempdir() - checkpointer = tf2_savers.Checkpointer( - objects_to_save={'x': x}, time_delta_minutes=1., directory=directory) - - with mock.patch.object(time, 'time') as mock_time: - mock_time.return_value = 0. - self.assertFalse(checkpointer.save()) - - mock_time.return_value = 40. - self.assertFalse(checkpointer.save()) - - mock_time.return_value = 70. - self.assertTrue(checkpointer.save()) - x.assign_add(1) - checkpointer.restore() - np.testing.assert_array_equal(x.numpy(), np.int32(0)) - - def test_no_checkpoint(self): - """Test that checkpointer does nothing when checkpoint=False.""" - num_steps = tf.Variable(0) - checkpointer = tf2_savers.Checkpointer( - objects_to_save={'num_steps': num_steps}, enable_checkpointing=False) - - for _ in range(10): - self.assertFalse(checkpointer.save()) - self.assertIsNone(checkpointer._checkpoint_manager) - - def test_tf_saveable(self): - x = DummySaveable() - - directory = self.get_tempdir() - checkpoint_runner = tf2_savers.CheckpointingRunner( - x, time_delta_minutes=0, directory=directory) - checkpoint_runner._checkpointer.save() - - x._state.assign_add(1) - checkpoint_runner._checkpointer.restore() - - np.testing.assert_array_equal(x._state.numpy(), np.int32(0)) + def test_save_and_restore(self): + """Test that checkpointer correctly calls save and restore.""" + + x = tf.Variable(0, dtype=tf.int32) + directory = self.get_tempdir() + checkpointer = tf2_savers.Checkpointer( + objects_to_save={"x": x}, time_delta_minutes=0.0, directory=directory + ) + + for _ in range(10): + saved = checkpointer.save() + self.assertTrue(saved) + x.assign_add(1) + checkpointer.restore() + np.testing.assert_array_equal(x.numpy(), np.int32(0)) + + def test_save_and_new_restore(self): + """Tests that a fresh checkpointer correctly restores an existing ckpt.""" + with mock.patch.object(paths, "get_unique_id") as mock_unique_id: + mock_unique_id.return_value = ("test",) + x = tf.Variable(0, dtype=tf.int32) + directory = self.get_tempdir() + checkpointer1 = tf2_savers.Checkpointer( + objects_to_save={"x": x}, time_delta_minutes=0.0, directory=directory + ) + checkpointer1.save() + x.assign_add(1) + # Simulate a preemption: x is changed, and we make a new Checkpointer. 
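# --- Editor's aside: illustrative usage sketch, not part of this patch. ------
# The Checkpointer exercised by the tests above writes at most once per
# `time_delta_minutes`, and `save()` reports whether a checkpoint was actually
# written.  The directory below is a placeholder path.
import tensorflow as tf

from acme.tf import savers as tf2_savers

step = tf.Variable(0, dtype=tf.int32)
checkpointer = tf2_savers.Checkpointer(
    objects_to_save={"step": step},
    time_delta_minutes=10.0,
    directory="/tmp/acme",
)

for _ in range(1000):
    step.assign_add(1)
    wrote = checkpointer.save()  # True only once the time delta has elapsed.

checkpointer.restore()  # Reloads `step` from the most recent checkpoint.
# ------------------------------------------------------------------------------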
+ checkpointer2 = tf2_savers.Checkpointer( + objects_to_save={"x": x}, time_delta_minutes=0.0, directory=directory + ) + checkpointer2.restore() + np.testing.assert_array_equal(x.numpy(), np.int32(0)) + + def test_save_and_restore_time_based(self): + """Test that checkpointer correctly calls save and restore based on time.""" + + x = tf.Variable(0, dtype=tf.int32) + directory = self.get_tempdir() + checkpointer = tf2_savers.Checkpointer( + objects_to_save={"x": x}, time_delta_minutes=1.0, directory=directory + ) + + with mock.patch.object(time, "time") as mock_time: + mock_time.return_value = 0.0 + self.assertFalse(checkpointer.save()) + + mock_time.return_value = 40.0 + self.assertFalse(checkpointer.save()) + + mock_time.return_value = 70.0 + self.assertTrue(checkpointer.save()) + x.assign_add(1) + checkpointer.restore() + np.testing.assert_array_equal(x.numpy(), np.int32(0)) + + def test_no_checkpoint(self): + """Test that checkpointer does nothing when checkpoint=False.""" + num_steps = tf.Variable(0) + checkpointer = tf2_savers.Checkpointer( + objects_to_save={"num_steps": num_steps}, enable_checkpointing=False + ) + + for _ in range(10): + self.assertFalse(checkpointer.save()) + self.assertIsNone(checkpointer._checkpoint_manager) + + def test_tf_saveable(self): + x = DummySaveable() + + directory = self.get_tempdir() + checkpoint_runner = tf2_savers.CheckpointingRunner( + x, time_delta_minutes=0, directory=directory + ) + checkpoint_runner._checkpointer.save() + + x._state.assign_add(1) + checkpoint_runner._checkpointer.restore() + + np.testing.assert_array_equal(x._state.numpy(), np.int32(0)) class CheckpointingRunnerTest(test_utils.TestCase): + def test_signal_handling(self): + x = DummySaveable() - def test_signal_handling(self): - x = DummySaveable() - - # Increment the value of DummySavable. - x.state['state'].assign_add(1) - - directory = self.get_tempdir() - - # Patch signals.add_handler so the registered signal handler sets the event. - with mock.patch.object( - launchpad, 'register_stop_handler') as mock_register_stop_handler: - def add_handler(fn): - fn() - mock_register_stop_handler.side_effect = add_handler - - runner = tf2_savers.CheckpointingRunner( - wrapped=x, - time_delta_minutes=0, - directory=directory) - with self.assertRaises(SystemExit): - runner.run() - - # Recreate DummySavable(), its tf.Variable is initialized to 0. - x = DummySaveable() - # Recreate the CheckpointingRunner, which will restore DummySavable() to 1. - tf2_savers.CheckpointingRunner( - wrapped=x, - time_delta_minutes=0, - directory=directory) - # Check DummyVariable() was restored properly. - np.testing.assert_array_equal(x.state['state'].numpy(), np.int32(1)) - - def test_checkpoint_dir(self): - directory = self.get_tempdir() - ckpt_runner = tf2_savers.CheckpointingRunner( - wrapped=DummySaveable(), - time_delta_minutes=0, - directory=directory) - expected_dir_re = f'{directory}/[a-z0-9-]*/checkpoints/default' - regexp = re.compile(expected_dir_re) - self.assertIsNotNone(regexp.fullmatch(ckpt_runner.get_directory())) + # Increment the value of DummySavable. + x.state["state"].assign_add(1) + directory = self.get_tempdir() -class SnapshotterTest(test_utils.TestCase): + # Patch signals.add_handler so the registered signal handler sets the event. 
+ with mock.patch.object( + launchpad, "register_stop_handler" + ) as mock_register_stop_handler: + + def add_handler(fn): + fn() + + mock_register_stop_handler.side_effect = add_handler - def test_snapshot(self): - """Test that snapshotter correctly calls saves/restores snapshots.""" - # Create a test network. - net1 = networks.LayerNormMLP([10, 10]) - spec = specs.Array([10], dtype=np.float32) - tf2_utils.create_variables(net1, [spec]) - - # Save the test network. - directory = self.get_tempdir() - objects_to_save = {'net': net1} - snapshotter = tf2_savers.Snapshotter(objects_to_save, directory=directory) - snapshotter.save() - - # Reload the test network. - net2 = tf.saved_model.load(os.path.join(snapshotter.directory, 'net')) - inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec)) - - with tf.GradientTape() as tape: - outputs1 = net1(inputs) - loss1 = tf.math.reduce_sum(outputs1) - grads1 = tape.gradient(loss1, net1.trainable_variables) - - with tf.GradientTape() as tape: - outputs2 = net2(inputs) - loss2 = tf.math.reduce_sum(outputs2) - grads2 = tape.gradient(loss2, net2.trainable_variables) - - assert np.allclose(outputs1, outputs2) - assert all(tree.map_structure(np.allclose, list(grads1), list(grads2))) - - def test_snapshot_distribution(self): - """Test that snapshotter correctly calls saves/restores snapshots.""" - # Create a test network. - net1 = snt.Sequential([ - networks.LayerNormMLP([10, 10]), - networks.MultivariateNormalDiagHead(1) - ]) - spec = specs.Array([10], dtype=np.float32) - tf2_utils.create_variables(net1, [spec]) - - # Save the test network. - directory = self.get_tempdir() - objects_to_save = {'net': net1} - snapshotter = tf2_savers.Snapshotter(objects_to_save, directory=directory) - snapshotter.save() - - # Reload the test network. - net2 = tf.saved_model.load(os.path.join(snapshotter.directory, 'net')) - inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec)) - - with tf.GradientTape() as tape: - dist1 = net1(inputs) - loss1 = tf.math.reduce_sum(dist1.mean() + dist1.variance()) - grads1 = tape.gradient(loss1, net1.trainable_variables) - - with tf.GradientTape() as tape: - dist2 = net2(inputs) - loss2 = tf.math.reduce_sum(dist2.mean() + dist2.variance()) - grads2 = tape.gradient(loss2, net2.trainable_variables) - - assert all(tree.map_structure(np.allclose, list(grads1), list(grads2))) - - def test_force_snapshot(self): - """Test that the force feature in Snapshotter.save() works correctly.""" - # Create a test network. - net = snt.Linear(10) - spec = specs.Array([10], dtype=np.float32) - tf2_utils.create_variables(net, [spec]) - - # Save the test network. - directory = self.get_tempdir() - objects_to_save = {'net': net} - # Very long time_delta_minutes. - snapshotter = tf2_savers.Snapshotter(objects_to_save, directory=directory, - time_delta_minutes=1000) - self.assertTrue(snapshotter.save(force=False)) - - # Due to the long time_delta_minutes, only force=True will create a new - # snapshot. This also checks the default is force=False. - self.assertFalse(snapshotter.save()) - self.assertTrue(snapshotter.save(force=True)) - - def test_rnn_snapshot(self): - """Test that snapshotter correctly calls saves/restores snapshots on RNNs.""" - # Create a test network. - net = snt.LSTM(10) - spec = specs.Array([10], dtype=np.float32) - tf2_utils.create_variables(net, [spec]) - - # Test that if you add some postprocessing without rerunning - # create_variables, it still works. 
- wrapped_net = snt.DeepRNN([net, lambda x: x]) - - for net1 in [net, wrapped_net]: - # Save the test network. - directory = self.get_tempdir() - objects_to_save = {'net': net1} - snapshotter = tf2_savers.Snapshotter(objects_to_save, directory=directory) - snapshotter.save() - - # Reload the test network. - net2 = tf.saved_model.load(os.path.join(snapshotter.directory, 'net')) - inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec)) - - with tf.GradientTape() as tape: - outputs1, next_state1 = net1(inputs, net1.initial_state(1)) - loss1 = tf.math.reduce_sum(outputs1) - grads1 = tape.gradient(loss1, net1.trainable_variables) - - with tf.GradientTape() as tape: - outputs2, next_state2 = net2(inputs, net2.initial_state(1)) - loss2 = tf.math.reduce_sum(outputs2) - grads2 = tape.gradient(loss2, net2.trainable_variables) - - assert np.allclose(outputs1, outputs2) - assert np.allclose(tree.flatten(next_state1), tree.flatten(next_state2)) - assert all(tree.map_structure(np.allclose, list(grads1), list(grads2))) - - -if __name__ == '__main__': - absltest.main() + runner = tf2_savers.CheckpointingRunner( + wrapped=x, time_delta_minutes=0, directory=directory + ) + with self.assertRaises(SystemExit): + runner.run() + + # Recreate DummySavable(), its tf.Variable is initialized to 0. + x = DummySaveable() + # Recreate the CheckpointingRunner, which will restore DummySavable() to 1. + tf2_savers.CheckpointingRunner( + wrapped=x, time_delta_minutes=0, directory=directory + ) + # Check DummyVariable() was restored properly. + np.testing.assert_array_equal(x.state["state"].numpy(), np.int32(1)) + + def test_checkpoint_dir(self): + directory = self.get_tempdir() + ckpt_runner = tf2_savers.CheckpointingRunner( + wrapped=DummySaveable(), time_delta_minutes=0, directory=directory + ) + expected_dir_re = f"{directory}/[a-z0-9-]*/checkpoints/default" + regexp = re.compile(expected_dir_re) + self.assertIsNotNone(regexp.fullmatch(ckpt_runner.get_directory())) + + +class SnapshotterTest(test_utils.TestCase): + def test_snapshot(self): + """Test that snapshotter correctly calls saves/restores snapshots.""" + # Create a test network. + net1 = networks.LayerNormMLP([10, 10]) + spec = specs.Array([10], dtype=np.float32) + tf2_utils.create_variables(net1, [spec]) + + # Save the test network. + directory = self.get_tempdir() + objects_to_save = {"net": net1} + snapshotter = tf2_savers.Snapshotter(objects_to_save, directory=directory) + snapshotter.save() + + # Reload the test network. + net2 = tf.saved_model.load(os.path.join(snapshotter.directory, "net")) + inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec)) + + with tf.GradientTape() as tape: + outputs1 = net1(inputs) + loss1 = tf.math.reduce_sum(outputs1) + grads1 = tape.gradient(loss1, net1.trainable_variables) + + with tf.GradientTape() as tape: + outputs2 = net2(inputs) + loss2 = tf.math.reduce_sum(outputs2) + grads2 = tape.gradient(loss2, net2.trainable_variables) + + assert np.allclose(outputs1, outputs2) + assert all(tree.map_structure(np.allclose, list(grads1), list(grads2))) + + def test_snapshot_distribution(self): + """Test that snapshotter correctly calls saves/restores snapshots.""" + # Create a test network. + net1 = snt.Sequential( + [networks.LayerNormMLP([10, 10]), networks.MultivariateNormalDiagHead(1)] + ) + spec = specs.Array([10], dtype=np.float32) + tf2_utils.create_variables(net1, [spec]) + + # Save the test network. 
+ directory = self.get_tempdir() + objects_to_save = {"net": net1} + snapshotter = tf2_savers.Snapshotter(objects_to_save, directory=directory) + snapshotter.save() + + # Reload the test network. + net2 = tf.saved_model.load(os.path.join(snapshotter.directory, "net")) + inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec)) + + with tf.GradientTape() as tape: + dist1 = net1(inputs) + loss1 = tf.math.reduce_sum(dist1.mean() + dist1.variance()) + grads1 = tape.gradient(loss1, net1.trainable_variables) + + with tf.GradientTape() as tape: + dist2 = net2(inputs) + loss2 = tf.math.reduce_sum(dist2.mean() + dist2.variance()) + grads2 = tape.gradient(loss2, net2.trainable_variables) + + assert all(tree.map_structure(np.allclose, list(grads1), list(grads2))) + + def test_force_snapshot(self): + """Test that the force feature in Snapshotter.save() works correctly.""" + # Create a test network. + net = snt.Linear(10) + spec = specs.Array([10], dtype=np.float32) + tf2_utils.create_variables(net, [spec]) + + # Save the test network. + directory = self.get_tempdir() + objects_to_save = {"net": net} + # Very long time_delta_minutes. + snapshotter = tf2_savers.Snapshotter( + objects_to_save, directory=directory, time_delta_minutes=1000 + ) + self.assertTrue(snapshotter.save(force=False)) + + # Due to the long time_delta_minutes, only force=True will create a new + # snapshot. This also checks the default is force=False. + self.assertFalse(snapshotter.save()) + self.assertTrue(snapshotter.save(force=True)) + + def test_rnn_snapshot(self): + """Test that snapshotter correctly calls saves/restores snapshots on RNNs.""" + # Create a test network. + net = snt.LSTM(10) + spec = specs.Array([10], dtype=np.float32) + tf2_utils.create_variables(net, [spec]) + + # Test that if you add some postprocessing without rerunning + # create_variables, it still works. + wrapped_net = snt.DeepRNN([net, lambda x: x]) + + for net1 in [net, wrapped_net]: + # Save the test network. + directory = self.get_tempdir() + objects_to_save = {"net": net1} + snapshotter = tf2_savers.Snapshotter(objects_to_save, directory=directory) + snapshotter.save() + + # Reload the test network. 
+ net2 = tf.saved_model.load(os.path.join(snapshotter.directory, "net")) + inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec)) + + with tf.GradientTape() as tape: + outputs1, next_state1 = net1(inputs, net1.initial_state(1)) + loss1 = tf.math.reduce_sum(outputs1) + grads1 = tape.gradient(loss1, net1.trainable_variables) + + with tf.GradientTape() as tape: + outputs2, next_state2 = net2(inputs, net2.initial_state(1)) + loss2 = tf.math.reduce_sum(outputs2) + grads2 = tape.gradient(loss2, net2.trainable_variables) + + assert np.allclose(outputs1, outputs2) + assert np.allclose(tree.flatten(next_state1), tree.flatten(next_state2)) + assert all(tree.map_structure(np.allclose, list(grads1), list(grads2))) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/tf/utils.py b/acme/tf/utils.py index 8a52e93ccf..e7e631737d 100644 --- a/acme/tf/utils.py +++ b/acme/tf/utils.py @@ -17,26 +17,26 @@ import functools from typing import List, Optional, Union -from acme import types -from acme.utils import tree_utils - import sonnet as snt import tensorflow as tf import tree +from acme import types +from acme.utils import tree_utils + def add_batch_dim(nest: types.NestedTensor) -> types.NestedTensor: - """Adds a batch dimension to each leaf of a nested structure of Tensors.""" - return tree.map_structure(lambda x: tf.expand_dims(x, axis=0), nest) + """Adds a batch dimension to each leaf of a nested structure of Tensors.""" + return tree.map_structure(lambda x: tf.expand_dims(x, axis=0), nest) def squeeze_batch_dim(nest: types.NestedTensor) -> types.NestedTensor: - """Squeezes out a batch dimension from each leaf of a nested structure.""" - return tree.map_structure(lambda x: tf.squeeze(x, axis=0), nest) + """Squeezes out a batch dimension from each leaf of a nested structure.""" + return tree.map_structure(lambda x: tf.squeeze(x, axis=0), nest) def batch_concat(inputs: types.NestedTensor) -> tf.Tensor: - """Concatenate a collection of Tensors while preserving the batch dimension. + """Concatenate a collection of Tensors while preserving the batch dimension. This takes a potentially nested collection of tensors, flattens everything but the batch (first) dimension, and concatenates along the resulting data @@ -49,36 +49,35 @@ def batch_concat(inputs: types.NestedTensor) -> tf.Tensor: A concatenated tensor which maintains the batch dimension but concatenates all other data along the flattened second dimension. 
""" - flat_leaves = tree.map_structure(snt.Flatten(), inputs) - return tf.concat(tree.flatten(flat_leaves), axis=-1) + flat_leaves = tree.map_structure(snt.Flatten(), inputs) + return tf.concat(tree.flatten(flat_leaves), axis=-1) def batch_to_sequence(data: types.NestedTensor) -> types.NestedTensor: - """Converts data between sequence-major and batch-major format.""" - return tree.map_structure( - lambda t: tf.transpose(t, [1, 0] + list(range(2, t.shape.rank))), data) + """Converts data between sequence-major and batch-major format.""" + return tree.map_structure( + lambda t: tf.transpose(t, [1, 0] + list(range(2, t.shape.rank))), data + ) def tile_tensor(tensor: tf.Tensor, multiple: int) -> tf.Tensor: - """Tiles `multiple` copies of `tensor` along a new leading axis.""" - rank = len(tensor.shape) - multiples = tf.constant([multiple] + [1] * rank, dtype=tf.int32) - expanded_tensor = tf.expand_dims(tensor, axis=0) - return tf.tile(expanded_tensor, multiples) + """Tiles `multiple` copies of `tensor` along a new leading axis.""" + rank = len(tensor.shape) + multiples = tf.constant([multiple] + [1] * rank, dtype=tf.int32) + expanded_tensor = tf.expand_dims(tensor, axis=0) + return tf.tile(expanded_tensor, multiples) -def tile_nested(inputs: types.NestedTensor, - multiple: int) -> types.NestedTensor: - """Tiles tensors in a nested structure along a new leading axis.""" - tile = functools.partial(tile_tensor, multiple=multiple) - return tree.map_structure(tile, inputs) +def tile_nested(inputs: types.NestedTensor, multiple: int) -> types.NestedTensor: + """Tiles tensors in a nested structure along a new leading axis.""" + tile = functools.partial(tile_tensor, multiple=multiple) + return tree.map_structure(tile, inputs) def create_variables( - network: snt.Module, - input_spec: List[Union[types.NestedSpec, tf.TensorSpec]], + network: snt.Module, input_spec: List[Union[types.NestedSpec, tf.TensorSpec]], ) -> Optional[tf.TensorSpec]: - """Builds the network with dummy inputs to create the necessary variables. + """Builds the network with dummy inputs to create the necessary variables. Args: network: Sonnet Module whose variables are to be created. @@ -90,40 +89,41 @@ def create_variables( it doesn't return anything (None); e.g. if the output is a tfp.distributions.Distribution. """ - # Create a dummy observation with no batch dimension. - dummy_input = zeros_like(input_spec) - - # If we have an RNNCore the hidden state will be an additional input. - if isinstance(network, snt.RNNCore): - initial_state = squeeze_batch_dim(network.initial_state(1)) - dummy_input += [initial_state] - - # Forward pass of the network which will create variables as a side effect. - dummy_output = network(*add_batch_dim(dummy_input)) - - # Evaluate the input signature by converting the dummy input into a - # TensorSpec. We then save the signature as a property of the network. This is - # done so that we can later use it when creating snapshots. We do this here - # because the snapshot code may not have access to the precise form of the - # inputs. - input_signature = tree.map_structure( - lambda t: tf.TensorSpec((None,) + t.shape, t.dtype), dummy_input) - network._input_signature = input_signature # pylint: disable=protected-access - - def spec(output): - # If the output is not a Tensor, return None as spec is ill-defined. - if not isinstance(output, tf.Tensor): - return None - # If this is not a scalar Tensor, make sure to squeeze out the batch dim. 
- if tf.rank(output) > 0: - output = squeeze_batch_dim(output) - return tf.TensorSpec(output.shape, output.dtype) - - return tree.map_structure(spec, dummy_output) + # Create a dummy observation with no batch dimension. + dummy_input = zeros_like(input_spec) + + # If we have an RNNCore the hidden state will be an additional input. + if isinstance(network, snt.RNNCore): + initial_state = squeeze_batch_dim(network.initial_state(1)) + dummy_input += [initial_state] + + # Forward pass of the network which will create variables as a side effect. + dummy_output = network(*add_batch_dim(dummy_input)) + + # Evaluate the input signature by converting the dummy input into a + # TensorSpec. We then save the signature as a property of the network. This is + # done so that we can later use it when creating snapshots. We do this here + # because the snapshot code may not have access to the precise form of the + # inputs. + input_signature = tree.map_structure( + lambda t: tf.TensorSpec((None,) + t.shape, t.dtype), dummy_input + ) + network._input_signature = input_signature # pylint: disable=protected-access + + def spec(output): + # If the output is not a Tensor, return None as spec is ill-defined. + if not isinstance(output, tf.Tensor): + return None + # If this is not a scalar Tensor, make sure to squeeze out the batch dim. + if tf.rank(output) > 0: + output = squeeze_batch_dim(output) + return tf.TensorSpec(output.shape, output.dtype) + + return tree.map_structure(spec, dummy_output) class TransformationWrapper(snt.Module): - """Helper class for to_sonnet_module. + """Helper class for to_sonnet_module. This wraps arbitrary Tensor-valued callables as a Sonnet module. A use case for this is in agent code that could take either a trainable @@ -133,20 +133,18 @@ class TransformationWrapper(snt.Module): otherwise need if e.g. calling get_variables() on the policy. """ - def __init__(self, - transformation: types.TensorValuedCallable, - name: Optional[str] = None): - super().__init__(name=name) - self._transformation = transformation + def __init__( + self, transformation: types.TensorValuedCallable, name: Optional[str] = None + ): + super().__init__(name=name) + self._transformation = transformation - def __call__(self, *args, **kwargs): - return self._transformation(*args, **kwargs) + def __call__(self, *args, **kwargs): + return self._transformation(*args, **kwargs) -def to_sonnet_module( - transformation: types.TensorValuedCallable - ) -> snt.Module: - """Convert a tensor transformation to a Sonnet Module. +def to_sonnet_module(transformation: types.TensorValuedCallable) -> snt.Module: + """Convert a tensor transformation to a Sonnet Module. Args: transformation: A Callable that takes one or more (nested) Tensors, and @@ -156,28 +154,28 @@ def to_sonnet_module( A Sonnet Module that wraps the transformation. """ - if isinstance(transformation, snt.Module): - return transformation + if isinstance(transformation, snt.Module): + return transformation - module = TransformationWrapper(transformation) + module = TransformationWrapper(transformation) - # Wrap the module to allow it to return an empty variable tuple. - return snt.allow_empty_variables(module) + # Wrap the module to allow it to return an empty variable tuple. 
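# --- Editor's aside: illustrative sketch, not part of this patch. ------------
# `to_sonnet_module` (above) wraps an arbitrary tensor-valued callable so that
# `create_variables` can build it once and report its output spec:
import numpy as np
import tensorflow as tf

from acme import specs
from acme.tf import utils as tf2_utils

module = tf2_utils.to_sonnet_module(tf.reduce_sum)
output_spec = tf2_utils.create_variables(
    module, [specs.Array(shape=(10,), dtype=np.float32)]
)
assert output_spec == tf.TensorSpec(shape=(), dtype=tf.float32)
assert module.variables == ()  # The wrapped callable has no parameters.
# ------------------------------------------------------------------------------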
+ return snt.allow_empty_variables(module) def to_numpy(nest: types.NestedTensor) -> types.NestedArray: - """Converts a nest of Tensors to a nest of numpy arrays.""" - return tree.map_structure(lambda x: x.numpy(), nest) + """Converts a nest of Tensors to a nest of numpy arrays.""" + return tree.map_structure(lambda x: x.numpy(), nest) def to_numpy_squeeze(nest: types.NestedTensor, axis=0) -> types.NestedArray: - """Converts a nest of Tensors to a nest of numpy arrays and squeeze axis.""" - return tree.map_structure(lambda x: tf.squeeze(x, axis=axis).numpy(), nest) + """Converts a nest of Tensors to a nest of numpy arrays and squeeze axis.""" + return tree.map_structure(lambda x: tf.squeeze(x, axis=axis).numpy(), nest) def zeros_like(nest: types.Nest) -> types.NestedTensor: - """Given a nest of array-like objects, returns similarly nested tf.zeros.""" - return tree.map_structure(lambda x: tf.zeros(x.shape, x.dtype), nest) + """Given a nest of array-like objects, returns similarly nested tf.zeros.""" + return tree.map_structure(lambda x: tf.zeros(x.shape, x.dtype), nest) # TODO(b/160311329): Migrate call-sites and remove. diff --git a/acme/tf/utils_test.py b/acme/tf/utils_test.py index f54212629a..64e6331e6d 100644 --- a/acme/tf/utils_test.py +++ b/acme/tf/utils_test.py @@ -16,119 +16,121 @@ from typing import Sequence, Tuple -from acme import specs -from acme.tf import utils as tf2_utils import numpy as np import sonnet as snt import tensorflow as tf +from absl.testing import absltest, parameterized -from absl.testing import absltest -from absl.testing import parameterized +from acme import specs +from acme.tf import utils as tf2_utils class PolicyValueHead(snt.Module): - """A network with two linear layers, for policy and value respectively.""" + """A network with two linear layers, for policy and value respectively.""" - def __init__(self, num_actions: int): - super().__init__(name='policy_value_network') - self._policy_layer = snt.Linear(num_actions) - self._value_layer = snt.Linear(1) + def __init__(self, num_actions: int): + super().__init__(name="policy_value_network") + self._policy_layer = snt.Linear(num_actions) + self._value_layer = snt.Linear(1) - def __call__(self, inputs: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: - """Returns a (Logits, Value) tuple.""" - logits = self._policy_layer(inputs) # [B, A] - value = tf.squeeze(self._value_layer(inputs), axis=-1) # [B] + def __call__(self, inputs: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: + """Returns a (Logits, Value) tuple.""" + logits = self._policy_layer(inputs) # [B, A] + value = tf.squeeze(self._value_layer(inputs), axis=-1) # [B] - return logits, value + return logits, value class CreateVariableTest(parameterized.TestCase): - """Tests for tf2_utils.create_variables method.""" - - @parameterized.parameters([True, False]) - def test_feedforward(self, recurrent: bool): - model = snt.Linear(42) - if recurrent: - model = snt.DeepRNN([model]) - input_spec = specs.Array(shape=(10,), dtype=np.float32) - tf2_utils.create_variables(model, [input_spec]) - variables: Sequence[tf.Variable] = model.variables - shapes = [v.shape.as_list() for v in variables] - self.assertSequenceEqual(shapes, [[42], [10, 42]]) - - @parameterized.parameters([True, False]) - def test_output_spec_feedforward(self, recurrent: bool): - input_spec = specs.Array(shape=(10,), dtype=np.float32) - model = snt.Linear(42) - expected_spec = tf.TensorSpec(shape=(42,), dtype=tf.float32) - if recurrent: - model = snt.DeepRNN([model]) - expected_spec = (expected_spec, ()) - 
- output_spec = tf2_utils.create_variables(model, [input_spec]) - self.assertEqual(output_spec, expected_spec) - - def test_multiple_outputs(self): - model = PolicyValueHead(42) - input_spec = specs.Array(shape=(10,), dtype=np.float32) - expected_spec = (tf.TensorSpec(shape=(42,), dtype=tf.float32), - tf.TensorSpec(shape=(), dtype=tf.float32)) - output_spec = tf2_utils.create_variables(model, [input_spec]) - variables: Sequence[tf.Variable] = model.variables - shapes = [v.shape.as_list() for v in variables] - self.assertSequenceEqual(shapes, [[42], [10, 42], [1], [10, 1]]) - self.assertSequenceEqual(output_spec, expected_spec) - - def test_scalar_output(self): - model = tf2_utils.to_sonnet_module(tf.reduce_sum) - input_spec = specs.Array(shape=(10,), dtype=np.float32) - expected_spec = tf.TensorSpec(shape=(), dtype=tf.float32) - output_spec = tf2_utils.create_variables(model, [input_spec]) - self.assertEqual(model.variables, ()) - self.assertEqual(output_spec, expected_spec) - - def test_none_output(self): - model = tf2_utils.to_sonnet_module(lambda x: None) - input_spec = specs.Array(shape=(10,), dtype=np.float32) - expected_spec = None - output_spec = tf2_utils.create_variables(model, [input_spec]) - self.assertEqual(model.variables, ()) - self.assertEqual(output_spec, expected_spec) - - def test_multiple_inputs_and_outputs(self): - def transformation(aa, bb, cc): - return (tf.concat([aa, bb, cc], axis=-1), - tf.concat([bb, cc], axis=-1)) - - model = tf2_utils.to_sonnet_module(transformation) - dtype = np.float32 - input_spec = [specs.Array(shape=(2,), dtype=dtype), - specs.Array(shape=(3,), dtype=dtype), - specs.Array(shape=(4,), dtype=dtype)] - expected_output_spec = (tf.TensorSpec(shape=(9,), dtype=dtype), - tf.TensorSpec(shape=(7,), dtype=dtype)) - output_spec = tf2_utils.create_variables(model, input_spec) - self.assertEqual(model.variables, ()) - self.assertEqual(output_spec, expected_output_spec) + """Tests for tf2_utils.create_variables method.""" + + @parameterized.parameters([True, False]) + def test_feedforward(self, recurrent: bool): + model = snt.Linear(42) + if recurrent: + model = snt.DeepRNN([model]) + input_spec = specs.Array(shape=(10,), dtype=np.float32) + tf2_utils.create_variables(model, [input_spec]) + variables: Sequence[tf.Variable] = model.variables + shapes = [v.shape.as_list() for v in variables] + self.assertSequenceEqual(shapes, [[42], [10, 42]]) + + @parameterized.parameters([True, False]) + def test_output_spec_feedforward(self, recurrent: bool): + input_spec = specs.Array(shape=(10,), dtype=np.float32) + model = snt.Linear(42) + expected_spec = tf.TensorSpec(shape=(42,), dtype=tf.float32) + if recurrent: + model = snt.DeepRNN([model]) + expected_spec = (expected_spec, ()) + + output_spec = tf2_utils.create_variables(model, [input_spec]) + self.assertEqual(output_spec, expected_spec) + + def test_multiple_outputs(self): + model = PolicyValueHead(42) + input_spec = specs.Array(shape=(10,), dtype=np.float32) + expected_spec = ( + tf.TensorSpec(shape=(42,), dtype=tf.float32), + tf.TensorSpec(shape=(), dtype=tf.float32), + ) + output_spec = tf2_utils.create_variables(model, [input_spec]) + variables: Sequence[tf.Variable] = model.variables + shapes = [v.shape.as_list() for v in variables] + self.assertSequenceEqual(shapes, [[42], [10, 42], [1], [10, 1]]) + self.assertSequenceEqual(output_spec, expected_spec) + + def test_scalar_output(self): + model = tf2_utils.to_sonnet_module(tf.reduce_sum) + input_spec = specs.Array(shape=(10,), dtype=np.float32) + 
expected_spec = tf.TensorSpec(shape=(), dtype=tf.float32) + output_spec = tf2_utils.create_variables(model, [input_spec]) + self.assertEqual(model.variables, ()) + self.assertEqual(output_spec, expected_spec) + + def test_none_output(self): + model = tf2_utils.to_sonnet_module(lambda x: None) + input_spec = specs.Array(shape=(10,), dtype=np.float32) + expected_spec = None + output_spec = tf2_utils.create_variables(model, [input_spec]) + self.assertEqual(model.variables, ()) + self.assertEqual(output_spec, expected_spec) + + def test_multiple_inputs_and_outputs(self): + def transformation(aa, bb, cc): + return (tf.concat([aa, bb, cc], axis=-1), tf.concat([bb, cc], axis=-1)) + + model = tf2_utils.to_sonnet_module(transformation) + dtype = np.float32 + input_spec = [ + specs.Array(shape=(2,), dtype=dtype), + specs.Array(shape=(3,), dtype=dtype), + specs.Array(shape=(4,), dtype=dtype), + ] + expected_output_spec = ( + tf.TensorSpec(shape=(9,), dtype=dtype), + tf.TensorSpec(shape=(7,), dtype=dtype), + ) + output_spec = tf2_utils.create_variables(model, input_spec) + self.assertEqual(model.variables, ()) + self.assertEqual(output_spec, expected_output_spec) class Tf2UtilsTest(parameterized.TestCase): - """Tests for tf2_utils methods.""" + """Tests for tf2_utils methods.""" - def test_batch_concat(self): - batch_size = 32 - inputs = [ - tf.zeros(shape=(batch_size, 2)), - { - 'foo': tf.zeros(shape=(batch_size, 5, 3)) - }, - [tf.zeros(shape=(batch_size, 1))], - ] + def test_batch_concat(self): + batch_size = 32 + inputs = [ + tf.zeros(shape=(batch_size, 2)), + {"foo": tf.zeros(shape=(batch_size, 5, 3))}, + [tf.zeros(shape=(batch_size, 1))], + ] - output_shape = tf2_utils.batch_concat(inputs).shape.as_list() - expected_shape = [batch_size, 2 + 5 * 3 + 1] - self.assertSequenceEqual(output_shape, expected_shape) + output_shape = tf2_utils.batch_concat(inputs).shape.as_list() + expected_shape = [batch_size, 2 + 5 * 3 + 1] + self.assertSequenceEqual(output_shape, expected_shape) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/tf/variable_utils.py b/acme/tf/variable_utils.py index 462d96b2a1..030e578bc3 100644 --- a/acme/tf/variable_utils.py +++ b/acme/tf/variable_utils.py @@ -17,37 +17,39 @@ from concurrent import futures from typing import Mapping, Optional, Sequence -from acme import core - import tensorflow as tf import tree +from acme import core + class VariableClient: - """A variable client for updating variables from a remote source.""" - - def __init__(self, - client: core.VariableSource, - variables: Mapping[str, Sequence[tf.Variable]], - update_period: int = 1): - self._keys = list(variables.keys()) - self._variables = tree.flatten(list(variables.values())) - self._call_counter = 0 - self._update_period = update_period - self._client = client - self._request = lambda: client.get_variables(self._keys) - - # Create a single background thread to fetch variables without necessarily - # blocking the actor. - self._executor = futures.ThreadPoolExecutor(max_workers=1) - self._async_request = lambda: self._executor.submit(self._request) - - # Initialize this client's future to None to indicate to the `update()` - # method that there is no pending/running request. - self._future: Optional[futures.Future] = None - - def update(self, wait: bool = False): - """Periodically updates the variables with the latest copy from the source. 
+ """A variable client for updating variables from a remote source.""" + + def __init__( + self, + client: core.VariableSource, + variables: Mapping[str, Sequence[tf.Variable]], + update_period: int = 1, + ): + self._keys = list(variables.keys()) + self._variables = tree.flatten(list(variables.values())) + self._call_counter = 0 + self._update_period = update_period + self._client = client + self._request = lambda: client.get_variables(self._keys) + + # Create a single background thread to fetch variables without necessarily + # blocking the actor. + self._executor = futures.ThreadPoolExecutor(max_workers=1) + self._async_request = lambda: self._executor.submit(self._request) + + # Initialize this client's future to None to indicate to the `update()` + # method that there is no pending/running request. + self._future: Optional[futures.Future] = None + + def update(self, wait: bool = False): + """Periodically updates the variables with the latest copy from the source. This stateful update method keeps track of the number of calls to it and, every `update_period` call, sends a request to its server to retrieve the @@ -65,44 +67,44 @@ def update(self, wait: bool = False): Args: wait: if True, executes blocking update. """ - # Track the number of calls (we only update periodically). - if self._call_counter < self._update_period: - self._call_counter += 1 - - period_reached: bool = self._call_counter >= self._update_period - - if period_reached and wait: - # Cancel any active request. - self._future: Optional[futures.Future] = None - self.update_and_wait() - self._call_counter = 0 - return - - if period_reached and self._future is None: - # The update period has been reached and no request has been sent yet, so - # making an asynchronous request now. - self._future = self._async_request() - self._call_counter = 0 - - if self._future is not None and self._future.done(): - # The active request is done so copy the result and remove the future. - self._copy(self._future.result()) - self._future: Optional[futures.Future] = None - else: - # There is either a pending/running request or we're between update - # periods, so just carry on. - return - - def update_and_wait(self): - """Immediately update and block until we get the result.""" - self._copy(self._request()) - - def _copy(self, new_variables: Sequence[Sequence[tf.Variable]]): - """Copies the new variables to the old ones.""" - - new_variables = tree.flatten(new_variables) - if len(self._variables) != len(new_variables): - raise ValueError('Length mismatch between old variables and new.') - - for new, old in zip(new_variables, self._variables): - old.assign(new) + # Track the number of calls (we only update periodically). + if self._call_counter < self._update_period: + self._call_counter += 1 + + period_reached: bool = self._call_counter >= self._update_period + + if period_reached and wait: + # Cancel any active request. + self._future: Optional[futures.Future] = None + self.update_and_wait() + self._call_counter = 0 + return + + if period_reached and self._future is None: + # The update period has been reached and no request has been sent yet, so + # making an asynchronous request now. + self._future = self._async_request() + self._call_counter = 0 + + if self._future is not None and self._future.done(): + # The active request is done so copy the result and remove the future. 
+ self._copy(self._future.result()) + self._future: Optional[futures.Future] = None + else: + # There is either a pending/running request or we're between update + # periods, so just carry on. + return + + def update_and_wait(self): + """Immediately update and block until we get the result.""" + self._copy(self._request()) + + def _copy(self, new_variables: Sequence[Sequence[tf.Variable]]): + """Copies the new variables to the old ones.""" + + new_variables = tree.flatten(new_variables) + if len(self._variables) != len(new_variables): + raise ValueError("Length mismatch between old variables and new.") + + for new, old in zip(new_variables, self._variables): + old.assign(new) diff --git a/acme/tf/variable_utils_test.py b/acme/tf/variable_utils_test.py index 2c626c431b..8621dbeaf1 100644 --- a/acme/tf/variable_utils_test.py +++ b/acme/tf/variable_utils_test.py @@ -16,12 +16,12 @@ import threading -from acme.testing import fakes -from acme.tf import utils as tf2_utils -from acme.tf import variable_utils as tf2_variable_utils import sonnet as snt import tensorflow as tf +from acme.testing import fakes +from acme.tf import utils as tf2_utils +from acme.tf import variable_utils as tf2_variable_utils _MLP_LAYERS = [50, 30] _INPUT_SIZE = 28 @@ -30,104 +30,107 @@ class VariableClientTest(tf.test.TestCase): + def setUp(self): + super().setUp() + + # Create two instances of the same model. + self._actor_model = snt.nets.MLP(_MLP_LAYERS) + self._learner_model = snt.nets.MLP(_MLP_LAYERS) + + # Create variables first. + input_spec = tf.TensorSpec(shape=(_INPUT_SIZE,), dtype=tf.float32) + tf2_utils.create_variables(self._actor_model, [input_spec]) + tf2_utils.create_variables(self._learner_model, [input_spec]) + + def test_update_and_wait(self): + # Create a variable source (emulating the learner). + np_learner_variables = tf2_utils.to_numpy(self._learner_model.variables) + variable_source = fakes.VariableSource(np_learner_variables) + + # Create a variable client (emulating the actor). + variable_client = tf2_variable_utils.VariableClient( + variable_source, {"policy": self._actor_model.variables} + ) + + # Create some random batch of test input: + x = tf.random.normal(shape=(_BATCH_SIZE, _INPUT_SIZE)) + + # Before copying variables, the models have different outputs. + self.assertNotAllClose(self._actor_model(x), self._learner_model(x)) + + # Update the variable client. + variable_client.update_and_wait() + + # After copying variables (by updating the client), the models are the same. + self.assertAllClose(self._actor_model(x), self._learner_model(x)) + + def test_update(self): + # Create a barrier to be shared between the test body and the variable + # source. The barrier will block until, in this case, two threads call + # wait(). Note that the (fake) variable source will call it within its + # get_variables() call. + barrier = threading.Barrier(2) + + # Create a variable source (emulating the learner). + np_learner_variables = tf2_utils.to_numpy(self._learner_model.variables) + variable_source = fakes.VariableSource(np_learner_variables, barrier) + + # Create a variable client (emulating the actor). + variable_client = tf2_variable_utils.VariableClient( + variable_source, + {"policy": self._actor_model.variables}, + update_period=_UPDATE_PERIOD, + ) + + # Create some random batch of test input: + x = tf.random.normal(shape=(_BATCH_SIZE, _INPUT_SIZE)) + + # Create variables by doing the computation once. 
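# --- Editor's aside: illustrative sketch, not part of this patch. ------------
# Typical actor-side use of the VariableClient defined above: a blocking fetch
# before acting starts, then periodic non-blocking refreshes.  The fake
# VariableSource from acme.testing stands in for a real (e.g. remote) learner.
import sonnet as snt
import tensorflow as tf

from acme.testing import fakes
from acme.tf import utils as tf2_utils
from acme.tf import variable_utils as tf2_variable_utils

input_spec = tf.TensorSpec(shape=(28,), dtype=tf.float32)

learner_policy = snt.nets.MLP([50, 30])
actor_policy = snt.nets.MLP([50, 30])
tf2_utils.create_variables(learner_policy, [input_spec])
tf2_utils.create_variables(actor_policy, [input_spec])

source = fakes.VariableSource(tf2_utils.to_numpy(learner_policy.variables))
client = tf2_variable_utils.VariableClient(
    source, {"policy": actor_policy.variables}, update_period=100
)

client.update_and_wait()  # Blocking fetch of the latest weights.
for _ in range(1000):
    client.update()  # Non-blocking; issues a request every 100th call.
# ------------------------------------------------------------------------------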
+ learner_output = self._learner_model(x) + actor_output = self._actor_model(x) + del learner_output, actor_output + + for _ in range(_UPDATE_PERIOD): + # Before the update period is reached, the models have different outputs. + self.assertNotAllClose( + self._actor_model.variables, self._learner_model.variables + ) + + # Before the update period is reached, the variable client should not make + # any requests for variables. + self.assertIsNone(variable_client._future) + + variable_client.update() + + # Make sure the last call created a request for variables and reset the + # internal call counter. + self.assertIsNotNone(variable_client._future) + self.assertEqual(variable_client._call_counter, 0) + future = variable_client._future + + for _ in range(_UPDATE_PERIOD): + # Before the barrier allows the variables to be released, the models have + # different outputs. + self.assertNotAllClose( + self._actor_model.variables, self._learner_model.variables + ) + + variable_client.update() + + # Make sure no new requests are made. + self.assertEqual(variable_client._future, future) + + # Calling wait() on the barrier will now allow the variables to be copied + # over from source to client. + barrier.wait() + + # Update once more to ensure the variables are copied over. + while variable_client._future is not None: + variable_client.update() + + # After a number of update calls, the variables should be the same. + self.assertAllClose(self._actor_model.variables, self._learner_model.variables) - def setUp(self): - super().setUp() - - # Create two instances of the same model. - self._actor_model = snt.nets.MLP(_MLP_LAYERS) - self._learner_model = snt.nets.MLP(_MLP_LAYERS) - - # Create variables first. - input_spec = tf.TensorSpec(shape=(_INPUT_SIZE,), dtype=tf.float32) - tf2_utils.create_variables(self._actor_model, [input_spec]) - tf2_utils.create_variables(self._learner_model, [input_spec]) - - def test_update_and_wait(self): - # Create a variable source (emulating the learner). - np_learner_variables = tf2_utils.to_numpy(self._learner_model.variables) - variable_source = fakes.VariableSource(np_learner_variables) - - # Create a variable client (emulating the actor). - variable_client = tf2_variable_utils.VariableClient( - variable_source, {'policy': self._actor_model.variables}) - - # Create some random batch of test input: - x = tf.random.normal(shape=(_BATCH_SIZE, _INPUT_SIZE)) - - # Before copying variables, the models have different outputs. - self.assertNotAllClose(self._actor_model(x), self._learner_model(x)) - - # Update the variable client. - variable_client.update_and_wait() - - # After copying variables (by updating the client), the models are the same. - self.assertAllClose(self._actor_model(x), self._learner_model(x)) - - def test_update(self): - # Create a barrier to be shared between the test body and the variable - # source. The barrier will block until, in this case, two threads call - # wait(). Note that the (fake) variable source will call it within its - # get_variables() call. - barrier = threading.Barrier(2) - - # Create a variable source (emulating the learner). - np_learner_variables = tf2_utils.to_numpy(self._learner_model.variables) - variable_source = fakes.VariableSource(np_learner_variables, barrier) - - # Create a variable client (emulating the actor). 
- variable_client = tf2_variable_utils.VariableClient( - variable_source, {'policy': self._actor_model.variables}, - update_period=_UPDATE_PERIOD) - - # Create some random batch of test input: - x = tf.random.normal(shape=(_BATCH_SIZE, _INPUT_SIZE)) - - # Create variables by doing the computation once. - learner_output = self._learner_model(x) - actor_output = self._actor_model(x) - del learner_output, actor_output - - for _ in range(_UPDATE_PERIOD): - # Before the update period is reached, the models have different outputs. - self.assertNotAllClose(self._actor_model.variables, - self._learner_model.variables) - - # Before the update period is reached, the variable client should not make - # any requests for variables. - self.assertIsNone(variable_client._future) - - variable_client.update() - - # Make sure the last call created a request for variables and reset the - # internal call counter. - self.assertIsNotNone(variable_client._future) - self.assertEqual(variable_client._call_counter, 0) - future = variable_client._future - - for _ in range(_UPDATE_PERIOD): - # Before the barrier allows the variables to be released, the models have - # different outputs. - self.assertNotAllClose(self._actor_model.variables, - self._learner_model.variables) - - variable_client.update() - - # Make sure no new requests are made. - self.assertEqual(variable_client._future, future) - - # Calling wait() on the barrier will now allow the variables to be copied - # over from source to client. - barrier.wait() - - # Update once more to ensure the variables are copied over. - while variable_client._future is not None: - variable_client.update() - - # After a number of update calls, the variables should be the same. - self.assertAllClose(self._actor_model.variables, - self._learner_model.variables) - -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/acme/types.py b/acme/types.py index 8305880f7a..e0756ce1af 100644 --- a/acme/types.py +++ b/acme/types.py @@ -15,6 +15,7 @@ """Common types used throughout Acme.""" from typing import Any, Callable, Iterable, Mapping, NamedTuple, Union + from acme import specs # Define types for nested arrays and tensors. @@ -24,9 +25,7 @@ # pytype: disable=not-supported-yet NestedSpec = Union[ - specs.Array, - Iterable['NestedSpec'], - Mapping[Any, 'NestedSpec'], + specs.Array, Iterable["NestedSpec"], Mapping[Any, "NestedSpec"], ] # pytype: enable=not-supported-yet @@ -38,7 +37,7 @@ class Batches(int): - """Helper class for specification of quantities in units of batches. + """Helper class for specification of quantities in units of batches. 
Example usage: @@ -56,10 +55,11 @@ class Batches(int): class Transition(NamedTuple): - """Container for a transition.""" - observation: NestedArray - action: NestedArray - reward: NestedArray - discount: NestedArray - next_observation: NestedArray - extras: NestedArray = () + """Container for a transition.""" + + observation: NestedArray + action: NestedArray + reward: NestedArray + discount: NestedArray + next_observation: NestedArray + extras: NestedArray = () diff --git a/acme/utils/async_utils.py b/acme/utils/async_utils.py index aaccc5611f..788fc07b37 100644 --- a/acme/utils/async_utils.py +++ b/acme/utils/async_utils.py @@ -16,24 +16,23 @@ import queue import threading -from typing import Callable, TypeVar, Generic +from typing import Callable, Generic, TypeVar from absl import logging - E = TypeVar("E") class AsyncExecutor(Generic[E]): - """Executes a blocking function asynchronously on a queue of items.""" + """Executes a blocking function asynchronously on a queue of items.""" - def __init__( - self, - fn: Callable[[E], None], - queue_size: int = 1, - interruptible_interval_secs: float = 1.0, - ): - """Buffers elements in a queue and runs `fn` asynchronously.. + def __init__( + self, + fn: Callable[[E], None], + queue_size: int = 1, + interruptible_interval_secs: float = 1.0, + ): + """Buffers elements in a queue and runs `fn` asynchronously.. NOTE: Once closed, `AsyncExecutor` will block until current `fn` finishes but is not guaranteed to dequeue all elements currently stored in @@ -48,49 +47,49 @@ def __init__( queue operations after which the background threads check for errors and if background threads should stop. """ - self._data = queue.Queue(maxsize=queue_size) - self._should_stop = threading.Event() - self._errors = queue.Queue() - self._interruptible_interval_secs = interruptible_interval_secs - - def _dequeue() -> None: - """Dequeue data from a queue and invoke blocking call.""" - while not self._should_stop.is_set(): + self._data = queue.Queue(maxsize=queue_size) + self._should_stop = threading.Event() + self._errors = queue.Queue() + self._interruptible_interval_secs = interruptible_interval_secs + + def _dequeue() -> None: + """Dequeue data from a queue and invoke blocking call.""" + while not self._should_stop.is_set(): + try: + element = self._data.get(timeout=self._interruptible_interval_secs) + # Execute fn upon dequeuing an element from the data queue. + fn(element) + except queue.Empty: + # If queue is Empty for longer than the specified time interval, + # check again if should_stop has been requested and retry. + continue + except Exception as e: + logging.error("AsyncExecuter thread terminated with error.") + logging.exception(e) + self._errors.put(e) + self._should_stop.set() + raise # Never caught by anything, just terminates the thread. + + self._thread = threading.Thread(target=_dequeue, daemon=True) + self._thread.start() + + def _raise_on_error(self) -> None: try: - element = self._data.get(timeout=self._interruptible_interval_secs) - # Execute fn upon dequeuing an element from the data queue. - fn(element) + # Raise the error on the caller thread if an error has been raised in the + # looper thread. + raise self._errors.get_nowait() except queue.Empty: - # If queue is Empty for longer than the specified time interval, - # check again if should_stop has been requested and retry. 
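# --- Editor's aside: illustrative sketch, not part of this patch. ------------
# AsyncExecutor (above) runs a blocking callable on a background thread fed by
# a bounded queue: `put()` enqueues work and `close()` waits for the current
# call to finish, joins the thread and re-raises any background error (items
# still queued at close time are not guaranteed to be processed).
from acme.utils import async_utils


def handle(element: int) -> None:
    # Stand-in for a blocking call, e.g. a write to external storage.
    print("processed", element)


executor = async_utils.AsyncExecutor(handle, queue_size=4)
for i in range(10):
    executor.put(i)  # Blocks only while the queue stays full.
executor.close()
# ------------------------------------------------------------------------------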
- continue - except Exception as e: - logging.error("AsyncExecuter thread terminated with error.") - logging.exception(e) - self._errors.put(e) - self._should_stop.set() - raise # Never caught by anything, just terminates the thread. - - self._thread = threading.Thread(target=_dequeue, daemon=True) - self._thread.start() - - def _raise_on_error(self) -> None: - try: - # Raise the error on the caller thread if an error has been raised in the - # looper thread. - raise self._errors.get_nowait() - except queue.Empty: - pass - - def close(self): - self._should_stop.set() - # Join all background threads. - self._thread.join() - # Raise errors produced by background threads. - self._raise_on_error() - - def put(self, element: E) -> None: - """Puts `element` asynchronuously onto the underlying data queue. + pass + + def close(self): + self._should_stop.set() + # Join all background threads. + self._thread.join() + # Raise errors produced by background threads. + self._raise_on_error() + + def put(self, element: E) -> None: + """Puts `element` asynchronuously onto the underlying data queue. The write call blocks if the underlying data_queue contains `queue_size` elements for over `self._interruptible_interval_secs` second, in which @@ -101,13 +100,13 @@ def put(self, element: E) -> None: element: an element to be put into the underlying data queue and dequeued asynchronuously for `fn(element)` call. """ - while not self._should_stop.is_set(): - try: - self._data.put(element, timeout=self._interruptible_interval_secs) - break - except queue.Full: - continue - else: - # If `should_stop` has been set, then raises if any has been raised on - # the background thread. - self._raise_on_error() + while not self._should_stop.is_set(): + try: + self._data.put(element, timeout=self._interruptible_interval_secs) + break + except queue.Full: + continue + else: + # If `should_stop` has been set, then raises if any has been raised on + # the background thread. + self._raise_on_error() diff --git a/acme/utils/counting.py b/acme/utils/counting.py index 8492b3d456..a3195dc0b9 100644 --- a/acme/utils/counting.py +++ b/acme/utils/counting.py @@ -24,14 +24,16 @@ class Counter(core.Saveable): - """A simple counter object that can periodically sync with a parent.""" + """A simple counter object that can periodically sync with a parent.""" - def __init__(self, - parent: Optional['Counter'] = None, - prefix: str = '', - time_delta: float = 1.0, - return_only_prefixed: bool = False): - """Initialize the counter. + def __init__( + self, + parent: Optional["Counter"] = None, + prefix: str = "", + time_delta: float = 1.0, + return_only_prefixed: bool = False, + ): + """Initialize the counter. Args: parent: a Counter object to cache locally (or None for no caching). @@ -43,24 +45,24 @@ def __init__(self, `get_counts`. The `prefix` is stripped from returned count names. """ - self._parent = parent - self._prefix = prefix - self._time_delta = time_delta + self._parent = parent + self._prefix = prefix + self._time_delta = time_delta - # Hold local counts and we'll lock around that. - # These are counts to be synced to the parent and the cache. - self._counts = {} - self._lock = threading.Lock() + # Hold local counts and we'll lock around that. + # These are counts to be synced to the parent and the cache. + self._counts = {} + self._lock = threading.Lock() - # We'll sync periodically (when the last sync was more than self._time_delta - # seconds ago.) 
- self._cache = {} - self._last_sync_time = 0.0 + # We'll sync periodically (when the last sync was more than self._time_delta + # seconds ago.) + self._cache = {} + self._last_sync_time = 0.0 - self._return_only_prefixed = return_only_prefixed + self._return_only_prefixed = return_only_prefixed - def increment(self, **counts: Number) -> Dict[str, Number]: - """Increment a set of counters. + def increment(self, **counts: Number) -> Dict[str, Number]: + """Increment a set of counters. Args: **counts: keyword arguments specifying count increments. @@ -69,61 +71,65 @@ def increment(self, **counts: Number) -> Dict[str, Number]: The [name, value] mapping of all counters stored, i.e. this will also include counts that were not updated by this call to increment. """ - with self._lock: - for key, value in counts.items(): - self._counts.setdefault(key, 0) - self._counts[key] += value - return self.get_counts() - - def get_counts(self) -> Dict[str, Number]: - """Return all counts tracked by this counter.""" - now = time.time() - # TODO(b/144421838): use futures instead of blocking. - if self._parent and (now - self._last_sync_time) > self._time_delta: - with self._lock: + with self._lock: + for key, value in counts.items(): + self._counts.setdefault(key, 0) + self._counts[key] += value + return self.get_counts() + + def get_counts(self) -> Dict[str, Number]: + """Return all counts tracked by this counter.""" + now = time.time() + # TODO(b/144421838): use futures instead of blocking. + if self._parent and (now - self._last_sync_time) > self._time_delta: + with self._lock: + counts = _prefix_keys(self._counts, self._prefix) + # Reset the local counts, as they will be merged into the parent and the + # cache. + self._counts = {} + self._cache = self._parent.increment(**counts) + self._last_sync_time = now + + # Potentially prefix the keys in the counts dictionary. counts = _prefix_keys(self._counts, self._prefix) - # Reset the local counts, as they will be merged into the parent and the - # cache. - self._counts = {} - self._cache = self._parent.increment(**counts) - self._last_sync_time = now - - # Potentially prefix the keys in the counts dictionary. - counts = _prefix_keys(self._counts, self._prefix) - - # If there's no prefix make a copy of the dictionary so we don't modify the - # internal self._counts. - if not self._prefix: - counts = dict(counts) - - # Combine local counts with any parent counts. - for key, value in self._cache.items(): - counts[key] = counts.get(key, 0) + value - - if self._prefix and self._return_only_prefixed: - counts = dict([(key[len(self._prefix) + 1:], value) - for key, value in counts.items() - if key.startswith(f'{self._prefix}_')]) - return counts - - def save(self) -> Mapping[str, Mapping[str, Number]]: - return {'counts': self._counts, 'cache': self._cache} - - def restore(self, state: Mapping[str, Mapping[str, Number]]): - # Force a sync, if necessary, on the next get_counts call. - self._last_sync_time = 0. - self._counts = state['counts'] - self._cache = state['cache'] - def get_steps_key(self) -> str: - """Returns the key to use for steps by this counter.""" - if not self._prefix or self._return_only_prefixed: - return 'steps' - return f'{self._prefix}_steps' + # If there's no prefix make a copy of the dictionary so we don't modify the + # internal self._counts. + if not self._prefix: + counts = dict(counts) + + # Combine local counts with any parent counts. 
+ for key, value in self._cache.items(): + counts[key] = counts.get(key, 0) + value + + if self._prefix and self._return_only_prefixed: + counts = dict( + [ + (key[len(self._prefix) + 1 :], value) + for key, value in counts.items() + if key.startswith(f"{self._prefix}_") + ] + ) + return counts + + def save(self) -> Mapping[str, Mapping[str, Number]]: + return {"counts": self._counts, "cache": self._cache} + + def restore(self, state: Mapping[str, Mapping[str, Number]]): + # Force a sync, if necessary, on the next get_counts call. + self._last_sync_time = 0.0 + self._counts = state["counts"] + self._cache = state["cache"] + + def get_steps_key(self) -> str: + """Returns the key to use for steps by this counter.""" + if not self._prefix or self._return_only_prefixed: + return "steps" + return f"{self._prefix}_steps" def _prefix_keys(dictionary: Dict[str, Number], prefix: str): - """Return a dictionary with prefixed keys. + """Return a dictionary with prefixed keys. Args: dictionary: dictionary to return a copy of. @@ -134,6 +140,6 @@ def _prefix_keys(dictionary: Dict[str, Number], prefix: str): "{prefix}_{key}". If the prefix is the empty string it returns the given dictionary unchanged. """ - if prefix: - dictionary = {f'{prefix}_{k}': v for k, v in dictionary.items()} - return dictionary + if prefix: + dictionary = {f"{prefix}_{k}": v for k, v in dictionary.items()} + return dictionary diff --git a/acme/utils/counting_test.py b/acme/utils/counting_test.py index 1736b632aa..d73a9b17b5 100644 --- a/acme/utils/counting_test.py +++ b/acme/utils/counting_test.py @@ -16,102 +16,106 @@ import threading -from acme.utils import counting - from absl.testing import absltest +from acme.utils import counting + class Barrier: - """Defines a simple barrier class to synchronize on a particular event.""" + """Defines a simple barrier class to synchronize on a particular event.""" - def __init__(self, num_threads): - """Constructor. + def __init__(self, num_threads): + """Constructor. Args: num_threads: int - how many threads will be syncronizing on this barrier """ - self._num_threads = num_threads - self._count = 0 - self._cond = threading.Condition() + self._num_threads = num_threads + self._count = 0 + self._cond = threading.Condition() - def wait(self): - """Waits on the barrier until all threads have called this method.""" - with self._cond: - self._count += 1 - self._cond.notifyAll() - while self._count < self._num_threads: - self._cond.wait() + def wait(self): + """Waits on the barrier until all threads have called this method.""" + with self._cond: + self._count += 1 + self._cond.notifyAll() + while self._count < self._num_threads: + self._cond.wait() class CountingTest(absltest.TestCase): - - def test_counter_threading(self): - counter = counting.Counter() - num_threads = 10 - barrier = Barrier(num_threads) - - # Increment in every thread at the same time. - def add_to_counter(): - barrier.wait() - counter.increment(foo=1) - - # Run the threads. - threads = [] - for _ in range(num_threads): - t = threading.Thread(target=add_to_counter) - t.start() - threads.append(t) - for t in threads: - t.join() - - # Make sure the counter has been incremented once per thread. - counts = counter.get_counts() - self.assertEqual(counts['foo'], num_threads) - - def test_counter_caching(self): - parent = counting.Counter() - counter = counting.Counter(parent, time_delta=0.) 
- counter.increment(foo=12) - self.assertEqual(parent.get_counts(), counter.get_counts()) - - def test_shared_counts(self): - # Two counters with shared parent should share counts (modulo namespacing). - parent = counting.Counter() - child1 = counting.Counter(parent, 'child1') - child2 = counting.Counter(parent, 'child2') - child1.increment(foo=1) - result = child2.increment(foo=2) - expected = {'child1_foo': 1, 'child2_foo': 2} - self.assertEqual(result, expected) - - def test_return_only_prefixed(self): - parent = counting.Counter() - child1 = counting.Counter( - parent, 'child1', time_delta=0., return_only_prefixed=False) - child2 = counting.Counter( - parent, 'child2', time_delta=0., return_only_prefixed=True) - child1.increment(foo=1) - child2.increment(bar=1) - self.assertEqual(child1.get_counts(), {'child1_foo': 1, 'child2_bar': 1}) - self.assertEqual(child2.get_counts(), {'bar': 1}) - - def test_get_steps_key(self): - parent = counting.Counter() - child1 = counting.Counter( - parent, 'child1', time_delta=0., return_only_prefixed=False) - child2 = counting.Counter( - parent, 'child2', time_delta=0., return_only_prefixed=True) - self.assertEqual(child1.get_steps_key(), 'child1_steps') - self.assertEqual(child2.get_steps_key(), 'steps') - child1.increment(steps=1) - child2.increment(steps=2) - self.assertEqual(child1.get_counts().get(child1.get_steps_key()), 1) - self.assertEqual(child2.get_counts().get(child2.get_steps_key()), 2) - - def test_parent_prefix(self): - parent = counting.Counter(prefix='parent') - child = counting.Counter(parent, prefix='child', time_delta=0.) - self.assertEqual(child.get_steps_key(), 'child_steps') - -if __name__ == '__main__': - absltest.main() + def test_counter_threading(self): + counter = counting.Counter() + num_threads = 10 + barrier = Barrier(num_threads) + + # Increment in every thread at the same time. + def add_to_counter(): + barrier.wait() + counter.increment(foo=1) + + # Run the threads. + threads = [] + for _ in range(num_threads): + t = threading.Thread(target=add_to_counter) + t.start() + threads.append(t) + for t in threads: + t.join() + + # Make sure the counter has been incremented once per thread. + counts = counter.get_counts() + self.assertEqual(counts["foo"], num_threads) + + def test_counter_caching(self): + parent = counting.Counter() + counter = counting.Counter(parent, time_delta=0.0) + counter.increment(foo=12) + self.assertEqual(parent.get_counts(), counter.get_counts()) + + def test_shared_counts(self): + # Two counters with shared parent should share counts (modulo namespacing). 
+ parent = counting.Counter() + child1 = counting.Counter(parent, "child1") + child2 = counting.Counter(parent, "child2") + child1.increment(foo=1) + result = child2.increment(foo=2) + expected = {"child1_foo": 1, "child2_foo": 2} + self.assertEqual(result, expected) + + def test_return_only_prefixed(self): + parent = counting.Counter() + child1 = counting.Counter( + parent, "child1", time_delta=0.0, return_only_prefixed=False + ) + child2 = counting.Counter( + parent, "child2", time_delta=0.0, return_only_prefixed=True + ) + child1.increment(foo=1) + child2.increment(bar=1) + self.assertEqual(child1.get_counts(), {"child1_foo": 1, "child2_bar": 1}) + self.assertEqual(child2.get_counts(), {"bar": 1}) + + def test_get_steps_key(self): + parent = counting.Counter() + child1 = counting.Counter( + parent, "child1", time_delta=0.0, return_only_prefixed=False + ) + child2 = counting.Counter( + parent, "child2", time_delta=0.0, return_only_prefixed=True + ) + self.assertEqual(child1.get_steps_key(), "child1_steps") + self.assertEqual(child2.get_steps_key(), "steps") + child1.increment(steps=1) + child2.increment(steps=2) + self.assertEqual(child1.get_counts().get(child1.get_steps_key()), 1) + self.assertEqual(child2.get_counts().get(child2.get_steps_key()), 2) + + def test_parent_prefix(self): + parent = counting.Counter(prefix="parent") + child = counting.Counter(parent, prefix="child", time_delta=0.0) + self.assertEqual(child.get_steps_key(), "child_steps") + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/utils/experiment_utils.py b/acme/utils/experiment_utils.py index d4cccf8f7c..01f6a5390d 100644 --- a/acme/utils/experiment_utils.py +++ b/acme/utils/experiment_utils.py @@ -19,14 +19,14 @@ from acme.utils import loggers -def make_experiment_logger(label: str, - steps_key: Optional[str] = None, - task_instance: int = 0) -> loggers.Logger: - del task_instance - if steps_key is None: - steps_key = f'{label}_steps' - return loggers.make_default_logger(label=label, steps_key=steps_key) +def make_experiment_logger( + label: str, steps_key: Optional[str] = None, task_instance: int = 0 +) -> loggers.Logger: + del task_instance + if steps_key is None: + steps_key = f"{label}_steps" + return loggers.make_default_logger(label=label, steps_key=steps_key) def create_experiment_logger_factory() -> loggers.LoggerFactory: - return make_experiment_logger + return make_experiment_logger diff --git a/acme/utils/frozen_learner.py b/acme/utils/frozen_learner.py index d33d221b12..37f787e2f7 100644 --- a/acme/utils/frozen_learner.py +++ b/acme/utils/frozen_learner.py @@ -20,12 +20,12 @@ class FrozenLearner(acme.Learner): - """Wraps a learner ignoring the step calls, i.e. freezing it.""" + """Wraps a learner ignoring the step calls, i.e. freezing it.""" - def __init__(self, - learner: acme.Learner, - step_fn: Optional[Callable[[], None]] = None): - """Initializes the frozen learner. + def __init__( + self, learner: acme.Learner, step_fn: Optional[Callable[[], None]] = None + ): + """Initializes the frozen learner. Args: learner: Learner to be wrapped. @@ -33,26 +33,26 @@ def __init__(self, This can be used, e.g. to drop samples from an iterator that would normally be consumed by the learner. 
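# === Editorial usage sketch (not part of the diff) ==========================
# Sketches the FrozenLearner pattern described in the docstring above: step()
# never updates the wrapped learner, but an optional step_fn can keep draining
# a sample iterator so upstream producers are not blocked. `base_learner` and
# `sample_iterator` are hypothetical stand-ins, not library objects.
from acme.utils import frozen_learner

def freeze(base_learner, sample_iterator):
    return frozen_learner.FrozenLearner(
        base_learner,
        step_fn=lambda: next(sample_iterator, None),  # Drop one sample per step.
    )

# freeze(...).step() only drains the iterator; save()/restore()/get_variables()
# still delegate to the wrapped learner.
# =============================================================================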
""" - self._learner = learner - self._step_fn = step_fn + self._learner = learner + self._step_fn = step_fn - def step(self): - """See base class.""" - if self._step_fn: - self._step_fn() + def step(self): + """See base class.""" + if self._step_fn: + self._step_fn() - def run(self, num_steps: Optional[int] = None): - """See base class.""" - self._learner.run(num_steps) + def run(self, num_steps: Optional[int] = None): + """See base class.""" + self._learner.run(num_steps) - def save(self): - """See base class.""" - return self._learner.save() + def save(self): + """See base class.""" + return self._learner.save() - def restore(self, state): - """See base class.""" - self._learner.restore(state) + def restore(self, state): + """See base class.""" + self._learner.restore(state) - def get_variables(self, names: Sequence[str]) -> List[acme.types.NestedArray]: - """See base class.""" - return self._learner.get_variables(names) + def get_variables(self, names: Sequence[str]) -> List[acme.types.NestedArray]: + """See base class.""" + return self._learner.get_variables(names) diff --git a/acme/utils/frozen_learner_test.py b/acme/utils/frozen_learner_test.py index aabd3eb6d7..ff378b401c 100644 --- a/acme/utils/frozen_learner_test.py +++ b/acme/utils/frozen_learner_test.py @@ -16,62 +16,62 @@ from unittest import mock +from absl.testing import absltest + import acme from acme.utils import frozen_learner -from absl.testing import absltest class FrozenLearnerTest(absltest.TestCase): + @mock.patch.object(acme, "Learner", autospec=True) + def test_step_fn(self, mock_learner): + num_calls = 0 - @mock.patch.object(acme, 'Learner', autospec=True) - def test_step_fn(self, mock_learner): - num_calls = 0 - - def step_fn(): - nonlocal num_calls - num_calls += 1 + def step_fn(): + nonlocal num_calls + num_calls += 1 - learner = frozen_learner.FrozenLearner(mock_learner, step_fn=step_fn) + learner = frozen_learner.FrozenLearner(mock_learner, step_fn=step_fn) - # Step two times. - learner.step() - learner.step() + # Step two times. + learner.step() + learner.step() - self.assertEqual(num_calls, 2) - # step() method of the wrapped learner should not be called. - mock_learner.step.assert_not_called() + self.assertEqual(num_calls, 2) + # step() method of the wrapped learner should not be called. + mock_learner.step.assert_not_called() - @mock.patch.object(acme, 'Learner', autospec=True) - def test_no_step_fn(self, mock_learner): - learner = frozen_learner.FrozenLearner(mock_learner) - learner.step() - # step() method of the wrapped learner should not be called. - mock_learner.step.assert_not_called() + @mock.patch.object(acme, "Learner", autospec=True) + def test_no_step_fn(self, mock_learner): + learner = frozen_learner.FrozenLearner(mock_learner) + learner.step() + # step() method of the wrapped learner should not be called. + mock_learner.step.assert_not_called() - @mock.patch.object(acme, 'Learner', autospec=True) - def test_save_and_restore(self, mock_learner): - learner = frozen_learner.FrozenLearner(mock_learner) + @mock.patch.object(acme, "Learner", autospec=True) + def test_save_and_restore(self, mock_learner): + learner = frozen_learner.FrozenLearner(mock_learner) - mock_learner.save.return_value = 'state1' + mock_learner.save.return_value = "state1" - state = learner.save() - self.assertEqual(state, 'state1') + state = learner.save() + self.assertEqual(state, "state1") - learner.restore('state2') - # State of the wrapped learner should be restored. 
- mock_learner.restore.assert_called_once_with('state2') + learner.restore("state2") + # State of the wrapped learner should be restored. + mock_learner.restore.assert_called_once_with("state2") - @mock.patch.object(acme, 'Learner', autospec=True) - def test_get_variables(self, mock_learner): - learner = frozen_learner.FrozenLearner(mock_learner) + @mock.patch.object(acme, "Learner", autospec=True) + def test_get_variables(self, mock_learner): + learner = frozen_learner.FrozenLearner(mock_learner) - mock_learner.get_variables.return_value = [1, 2] + mock_learner.get_variables.return_value = [1, 2] - variables = learner.get_variables(['a', 'b']) - # Values should match with those returned by the wrapped learner. - self.assertEqual(variables, [1, 2]) - mock_learner.get_variables.assert_called_once_with(['a', 'b']) + variables = learner.get_variables(["a", "b"]) + # Values should match with those returned by the wrapped learner. + self.assertEqual(variables, [1, 2]) + mock_learner.get_variables.assert_called_once_with(["a", "b"]) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/utils/iterator_utils.py b/acme/utils/iterator_utils.py index 67a2f17fcf..57b12983d9 100644 --- a/acme/utils/iterator_utils.py +++ b/acme/utils/iterator_utils.py @@ -18,9 +18,10 @@ from typing import Any, Iterator, List, Sequence -def unzip_iterators(zipped_iterators: Iterator[Sequence[Any]], - num_sub_iterators: int) -> List[Iterator[Any]]: - """Returns unzipped iterators. +def unzip_iterators( + zipped_iterators: Iterator[Sequence[Any]], num_sub_iterators: int +) -> List[Iterator[Any]]: + """Returns unzipped iterators. Note that simply returning: [(x[i] for x in iter_tuple[i]) for i in range(num_sub_iterators)] @@ -31,8 +32,7 @@ def unzip_iterators(zipped_iterators: Iterator[Sequence[Any]], zipped_iterators: zipped iterators (e.g., from zip_iterators()). num_sub_iterators: the number of sub-iterators in the zipped iterator. 
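# === Editorial aside (not part of the diff) ==================================
# Why unzip_iterators (above) binds the column index eagerly with
# operator.itemgetter: a free variable inside a generator expression is looked
# up lazily, so a naive list of generators would all read the final value of
# `i`. A self-contained demonstration:
import itertools
import operator

zipped = zip(range(0, 3), range(20, 23))
tees = itertools.tee(zipped, 2)

# Late-bound (broken) variant: both generators would yield the second column.
#   broken = [(row[i] for row in tees[i]) for i in range(2)]

# Eagerly-bound variant, matching iterator_utils.unzip_iterators:
unzipped = [map(operator.itemgetter(i), tees[i]) for i in range(2)]
print(list(unzipped[0]), list(unzipped[1]))  # [0, 1, 2] [20, 21, 22]
# =============================================================================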
""" - iter_tuple = itertools.tee(zipped_iterators, num_sub_iterators) - return [ - map(operator.itemgetter(i), iter_tuple[i]) - for i in range(num_sub_iterators) - ] + iter_tuple = itertools.tee(zipped_iterators, num_sub_iterators) + return [ + map(operator.itemgetter(i), iter_tuple[i]) for i in range(num_sub_iterators) + ] diff --git a/acme/utils/iterator_utils_test.py b/acme/utils/iterator_utils_test.py index ebe21f3a60..223c01fcd8 100644 --- a/acme/utils/iterator_utils_test.py +++ b/acme/utils/iterator_utils_test.py @@ -14,27 +14,25 @@ """Tests for iterator_utils.""" -from acme.utils import iterator_utils import numpy as np - from absl.testing import absltest +from acme.utils import iterator_utils -class IteratorUtilsTest(absltest.TestCase): - - def test_iterator_zipping(self): - def get_iters(): - x = iter(range(0, 10)) - y = iter(range(20, 30)) - return [x, y] +class IteratorUtilsTest(absltest.TestCase): + def test_iterator_zipping(self): + def get_iters(): + x = iter(range(0, 10)) + y = iter(range(20, 30)) + return [x, y] - zipped = zip(*get_iters()) - unzipped = iterator_utils.unzip_iterators(zipped, num_sub_iterators=2) - expected_x, expected_y = get_iters() - np.testing.assert_equal(list(unzipped[0]), list(expected_x)) - np.testing.assert_equal(list(unzipped[1]), list(expected_y)) + zipped = zip(*get_iters()) + unzipped = iterator_utils.unzip_iterators(zipped, num_sub_iterators=2) + expected_x, expected_y = get_iters() + np.testing.assert_equal(list(unzipped[0]), list(expected_x)) + np.testing.assert_equal(list(unzipped[1]), list(expected_y)) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/utils/loggers/__init__.py b/acme/utils/loggers/__init__.py index de2e89187a..fa9abe8a9e 100644 --- a/acme/utils/loggers/__init__.py +++ b/acme/utils/loggers/__init__.py @@ -17,23 +17,24 @@ from acme.utils.loggers.aggregators import Dispatcher from acme.utils.loggers.asynchronous import AsyncLogger from acme.utils.loggers.auto_close import AutoCloseLogger -from acme.utils.loggers.base import Logger -from acme.utils.loggers.base import LoggerFactory -from acme.utils.loggers.base import LoggerLabel -from acme.utils.loggers.base import LoggerStepsKey -from acme.utils.loggers.base import LoggingData -from acme.utils.loggers.base import NoOpLogger -from acme.utils.loggers.base import TaskInstance -from acme.utils.loggers.base import to_numpy +from acme.utils.loggers.base import ( + Logger, + LoggerFactory, + LoggerLabel, + LoggerStepsKey, + LoggingData, + NoOpLogger, + TaskInstance, + to_numpy, +) from acme.utils.loggers.constant import ConstantLogger from acme.utils.loggers.csv import CSVLogger from acme.utils.loggers.dataframe import InMemoryLogger -from acme.utils.loggers.filters import GatedFilter -from acme.utils.loggers.filters import KeyFilter -from acme.utils.loggers.filters import NoneFilter -from acme.utils.loggers.filters import TimeFilter +from acme.utils.loggers.default import ( + make_default_logger, +) # pylint: disable=g-bad-import-order +from acme.utils.loggers.filters import GatedFilter, KeyFilter, NoneFilter, TimeFilter from acme.utils.loggers.flatten import FlattenDictLogger -from acme.utils.loggers.default import make_default_logger # pylint: disable=g-bad-import-order from acme.utils.loggers.terminal import TerminalLogger from acme.utils.loggers.timestamp import TimestampLogger diff --git a/acme/utils/loggers/aggregators.py b/acme/utils/loggers/aggregators.py index 354f72cbe0..4a37881fa0 100644 --- 
a/acme/utils/loggers/aggregators.py +++ b/acme/utils/loggers/aggregators.py @@ -15,28 +15,29 @@ """Utilities for aggregating to other loggers.""" from typing import Callable, Optional, Sequence + from acme.utils.loggers import base class Dispatcher(base.Logger): - """Writes data to multiple `Logger` objects.""" - - def __init__( - self, - to: Sequence[base.Logger], - serialize_fn: Optional[Callable[[base.LoggingData], str]] = None, - ): - """Initialize `Dispatcher` connected to several `Logger` objects.""" - self._to = to - self._serialize_fn = serialize_fn - - def write(self, values: base.LoggingData): - """Writes `values` to the underlying `Logger` objects.""" - if self._serialize_fn: - values = self._serialize_fn(values) - for logger in self._to: - logger.write(values) - - def close(self): - for logger in self._to: - logger.close() + """Writes data to multiple `Logger` objects.""" + + def __init__( + self, + to: Sequence[base.Logger], + serialize_fn: Optional[Callable[[base.LoggingData], str]] = None, + ): + """Initialize `Dispatcher` connected to several `Logger` objects.""" + self._to = to + self._serialize_fn = serialize_fn + + def write(self, values: base.LoggingData): + """Writes `values` to the underlying `Logger` objects.""" + if self._serialize_fn: + values = self._serialize_fn(values) + for logger in self._to: + logger.write(values) + + def close(self): + for logger in self._to: + logger.close() diff --git a/acme/utils/loggers/asynchronous.py b/acme/utils/loggers/asynchronous.py index 06aeb005a3..4514cf7093 100644 --- a/acme/utils/loggers/asynchronous.py +++ b/acme/utils/loggers/asynchronous.py @@ -21,22 +21,22 @@ class AsyncLogger(base.Logger): - """Logger which makes the logging to another logger asyncronous.""" + """Logger which makes the logging to another logger asyncronous.""" - def __init__(self, to: base.Logger): - """Initializes the logger. + def __init__(self, to: base.Logger): + """Initializes the logger. Args: to: A `Logger` object to which the current object will forward its results when `write` is called. """ - self._to = to - self._async_worker = async_utils.AsyncExecutor(self._to.write, queue_size=5) + self._to = to + self._async_worker = async_utils.AsyncExecutor(self._to.write, queue_size=5) - def write(self, values: Mapping[str, Any]): - self._async_worker.put(values) + def write(self, values: Mapping[str, Any]): + self._async_worker.put(values) - def close(self): - """Closes the logger, closing is synchronous.""" - self._async_worker.close() - self._to.close() + def close(self): + """Closes the logger, closing is synchronous.""" + self._async_worker.close() + self._to.close() diff --git a/acme/utils/loggers/auto_close.py b/acme/utils/loggers/auto_close.py index c3a92eef9a..af6faf4315 100644 --- a/acme/utils/loggers/auto_close.py +++ b/acme/utils/loggers/auto_close.py @@ -20,24 +20,24 @@ class AutoCloseLogger(base.Logger): - """Logger which auto closes itself on exit if not already closed.""" - - def __init__(self, logger: base.Logger): - self._logger = logger - # The finalizer "logger.close" is invoked in one of the following scenario: - # 1) the current logger is GC - # 2) from the python doc, when the program exits, each remaining live - # finalizer is called. - # Note that in the normal flow, where "close" is explicitly called, - # the finalizer is marked as dead using the detach function so that - # the underlying logger is not closed twice (once explicitly and once - # implicitly when the object is GC or when the program exits). 
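# === Editorial aside (not part of the diff) ==================================
# A minimal sketch of the weakref.finalize / detach() pattern described by the
# AutoCloseLogger comments above, assuming a hypothetical `resource` object
# with a close() method (this is not library code).
import weakref

class AutoCloser:
    def __init__(self, resource):
        self._resource = resource
        # Runs resource.close() when `self` is garbage collected or when the
        # interpreter exits, unless the finalizer is detached first.
        self._finalizer = weakref.finalize(self, resource.close)

    def close(self):
        # detach() is truthy only if the finalizer has not run yet, so an
        # explicit close() and the finalizer never both close the resource.
        if self._finalizer.detach():
            self._resource.close()
# =============================================================================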
- self._finalizer = weakref.finalize(self, logger.close) - - def write(self, values: base.LoggingData): - self._logger.write(values) - - def close(self): - if self._finalizer.detach(): - self._logger.close() - self._logger = None + """Logger which auto closes itself on exit if not already closed.""" + + def __init__(self, logger: base.Logger): + self._logger = logger + # The finalizer "logger.close" is invoked in one of the following scenario: + # 1) the current logger is GC + # 2) from the python doc, when the program exits, each remaining live + # finalizer is called. + # Note that in the normal flow, where "close" is explicitly called, + # the finalizer is marked as dead using the detach function so that + # the underlying logger is not closed twice (once explicitly and once + # implicitly when the object is GC or when the program exits). + self._finalizer = weakref.finalize(self, logger.close) + + def write(self, values: base.LoggingData): + self._logger.write(values) + + def close(self): + if self._finalizer.detach(): + self._logger.close() + self._logger = None diff --git a/acme/utils/loggers/base.py b/acme/utils/loggers/base.py index 1517a1c27d..3d8fd91917 100644 --- a/acme/utils/loggers/base.py +++ b/acme/utils/loggers/base.py @@ -25,15 +25,15 @@ class Logger(abc.ABC): - """A logger has a `write` method.""" + """A logger has a `write` method.""" - @abc.abstractmethod - def write(self, data: LoggingData) -> None: - """Writes `data` to destination (file, terminal, database, etc).""" + @abc.abstractmethod + def write(self, data: LoggingData) -> None: + """Writes `data` to destination (file, terminal, database, etc).""" - @abc.abstractmethod - def close(self) -> None: - """Closes the logger, not expecting any further write.""" + @abc.abstractmethod + def close(self) -> None: + """Closes the logger, not expecting any further write.""" TaskInstance = int @@ -43,38 +43,39 @@ def close(self) -> None: class LoggerFactory(Protocol): - - def __call__(self, - label: LoggerLabel, - steps_key: Optional[LoggerStepsKey] = None, - instance: Optional[TaskInstance] = None) -> Logger: - ... + def __call__( + self, + label: LoggerLabel, + steps_key: Optional[LoggerStepsKey] = None, + instance: Optional[TaskInstance] = None, + ) -> Logger: + ... class NoOpLogger(Logger): - """Simple Logger which does nothing and outputs no logs. + """Simple Logger which does nothing and outputs no logs. This should be used sparingly, but it can prove useful if we want to quiet an individual component and have it produce no logging whatsoever. """ - def write(self, data: LoggingData): - pass + def write(self, data: LoggingData): + pass - def close(self): - pass + def close(self): + pass def tensor_to_numpy(value: Any): - if hasattr(value, 'numpy'): - return value.numpy() # tf.Tensor (TF2). - if hasattr(value, 'device_buffer'): - return np.asarray(value) # jnp.DeviceArray. - return value + if hasattr(value, "numpy"): + return value.numpy() # tf.Tensor (TF2). + if hasattr(value, "device_buffer"): + return np.asarray(value) # jnp.DeviceArray. + return value def to_numpy(values: Any): - """Converts tensors in a nested structure to numpy. + """Converts tensors in a nested structure to numpy. Converts tensors from TensorFlow to Numpy if needed without importing TF dependency. @@ -85,4 +86,4 @@ def to_numpy(values: Any): Returns: Same nested structure as values, but with numpy tensors. 
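# === Editorial usage sketch (not part of the diff) ==========================
# Small illustration of base.to_numpy above: the nested structure is preserved
# and any TF/JAX tensors inside it are converted to numpy, while plain Python
# and numpy values pass through unchanged.
from acme.utils.loggers import base

values = {"loss": 0.5, "nested": {"steps": 3}}
assert base.to_numpy(values) == values  # Nothing to convert here.
# =============================================================================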
""" - return tree.map_structure(tensor_to_numpy, values) + return tree.map_structure(tensor_to_numpy, values) diff --git a/acme/utils/loggers/base_test.py b/acme/utils/loggers/base_test.py index b392a2d479..b11b6ce082 100644 --- a/acme/utils/loggers/base_test.py +++ b/acme/utils/loggers/base_test.py @@ -14,28 +14,27 @@ """Tests for acme.utils.loggers.base.""" -from acme.utils.loggers import base import jax.numpy as jnp import numpy as np import tensorflow as tf - from absl.testing import absltest +from acme.utils.loggers import base -class BaseTest(absltest.TestCase): - def test_tensor_serialisation(self): - data = {'x': tf.zeros(shape=(32,))} - output = base.to_numpy(data) - expected = {'x': np.zeros(shape=(32,))} - np.testing.assert_array_equal(output['x'], expected['x']) +class BaseTest(absltest.TestCase): + def test_tensor_serialisation(self): + data = {"x": tf.zeros(shape=(32,))} + output = base.to_numpy(data) + expected = {"x": np.zeros(shape=(32,))} + np.testing.assert_array_equal(output["x"], expected["x"]) - def test_device_array_serialisation(self): - data = {'x': jnp.zeros(shape=(32,))} - output = base.to_numpy(data) - expected = {'x': np.zeros(shape=(32,))} - np.testing.assert_array_equal(output['x'], expected['x']) + def test_device_array_serialisation(self): + data = {"x": jnp.zeros(shape=(32,))} + output = base.to_numpy(data) + expected = {"x": np.zeros(shape=(32,))} + np.testing.assert_array_equal(output["x"], expected["x"]) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/utils/loggers/constant.py b/acme/utils/loggers/constant.py index 2dba268675..799d39b863 100644 --- a/acme/utils/loggers/constant.py +++ b/acme/utils/loggers/constant.py @@ -18,29 +18,27 @@ class ConstantLogger(base.Logger): - """Logger for values that remain constant throughout the experiment. + """Logger for values that remain constant throughout the experiment. This logger is used to log additional values e.g. level_name or hyperparameters that do not change in an experiment. Having these values allows to group or facet plots when analysing data post-experiment. """ - def __init__( - self, - constant_data: base.LoggingData, - to: base.Logger, - ): - """Initialise the extra info logger. + def __init__( + self, constant_data: base.LoggingData, to: base.Logger, + ): + """Initialise the extra info logger. Args: constant_data: Key-value pairs containing the constant info to be logged. to: The logger to add these extra info to. """ - self._constant_data = constant_data - self._to = to + self._constant_data = constant_data + self._to = to - def write(self, data: base.LoggingData): - self._to.write({**self._constant_data, **data}) + def write(self, data: base.LoggingData): + self._to.write({**self._constant_data, **data}) - def close(self): - self._to.close() + def close(self): + self._to.close() diff --git a/acme/utils/loggers/csv.py b/acme/utils/loggers/csv.py index eb19a5144b..e808f3cc0b 100644 --- a/acme/utils/loggers/csv.py +++ b/acme/utils/loggers/csv.py @@ -29,7 +29,7 @@ class CSVLogger(base.Logger): - """Standard CSV logger. + """Standard CSV logger. The fields are inferred from the first call to write() and any additional fields afterwards are ignored. @@ -37,17 +37,17 @@ class CSVLogger(base.Logger): TODO(jaslanides): Consider making this stateless/robust to preemption. 
""" - _open = open + _open = open - def __init__( - self, - directory_or_file: Union[str, TextIO] = '~/acme', - label: str = '', - time_delta: float = 0., - add_uid: bool = True, - flush_every: int = 30, - ): - """Instantiates the logger. + def __init__( + self, + directory_or_file: Union[str, TextIO] = "~/acme", + label: str = "", + time_delta: float = 0.0, + add_uid: bool = True, + flush_every: int = 30, + ): + """Instantiates the logger. Args: directory_or_file: Either a directory path as a string, or a file TextIO @@ -61,81 +61,85 @@ def __init__( flush_every: Interval (in writes) between flushes. """ - if flush_every <= 0: - raise ValueError( - f'`flush_every` must be a positive integer (got {flush_every}).') - - self._last_log_time = time.time() - time_delta - self._time_delta = time_delta - self._flush_every = flush_every - self._add_uid = add_uid - self._writer = None - self._file_owner = False - self._file = self._create_file(directory_or_file, label) - self._writes = 0 - logging.info('Logging to %s', self.file_path) - - def _create_file( - self, - directory_or_file: Union[str, TextIO], - label: str, - ) -> TextIO: - """Opens a file if input is a directory or use existing file.""" - if isinstance(directory_or_file, str): - directory = paths.process_path( - directory_or_file, 'logs', label, add_uid=self._add_uid) - file_path = os.path.join(directory, 'logs.csv') - self._file_owner = True - return self._open(file_path, mode='a') - - # TextIO instance. - file = directory_or_file - if label: - logging.info('File, not directory, passed to CSVLogger; label not used.') - if not file.mode.startswith('a'): - raise ValueError('File must be open in append mode; instead got ' - f'mode="{file.mode}".') - return file - - def write(self, data: base.LoggingData): - """Writes a `data` into a row of comma-separated values.""" - # Only log if `time_delta` seconds have passed since last logging event. - now = time.time() - - # TODO(b/192227744): Remove this in favour of filters.TimeFilter. - elapsed = now - self._last_log_time - if elapsed < self._time_delta: - logging.debug('Not due to log for another %.2f seconds, dropping data.', - self._time_delta - elapsed) - return - self._last_log_time = now - - # Append row to CSV. - data = base.to_numpy(data) - # Use fields from initial `data` to create the header. If extra fields are - # present in subsequent `data`, we ignore them. - if not self._writer: - fields = sorted(data.keys()) - self._writer = csv.DictWriter(self._file, fieldnames=fields, - extrasaction='ignore') - # Write header only if the file is empty. - if not self._file.tell(): - self._writer.writeheader() - self._writer.writerow(data) - - # Flush every `flush_every` writes. - if self._writes % self._flush_every == 0: - self.flush() - self._writes += 1 - - def close(self): - self.flush() - if self._file_owner: - self._file.close() - - def flush(self): - self._file.flush() - - @property - def file_path(self) -> str: - return self._file.name + if flush_every <= 0: + raise ValueError( + f"`flush_every` must be a positive integer (got {flush_every})." 
+ ) + + self._last_log_time = time.time() - time_delta + self._time_delta = time_delta + self._flush_every = flush_every + self._add_uid = add_uid + self._writer = None + self._file_owner = False + self._file = self._create_file(directory_or_file, label) + self._writes = 0 + logging.info("Logging to %s", self.file_path) + + def _create_file( + self, directory_or_file: Union[str, TextIO], label: str, + ) -> TextIO: + """Opens a file if input is a directory or use existing file.""" + if isinstance(directory_or_file, str): + directory = paths.process_path( + directory_or_file, "logs", label, add_uid=self._add_uid + ) + file_path = os.path.join(directory, "logs.csv") + self._file_owner = True + return self._open(file_path, mode="a") + + # TextIO instance. + file = directory_or_file + if label: + logging.info("File, not directory, passed to CSVLogger; label not used.") + if not file.mode.startswith("a"): + raise ValueError( + "File must be open in append mode; instead got " f'mode="{file.mode}".' + ) + return file + + def write(self, data: base.LoggingData): + """Writes a `data` into a row of comma-separated values.""" + # Only log if `time_delta` seconds have passed since last logging event. + now = time.time() + + # TODO(b/192227744): Remove this in favour of filters.TimeFilter. + elapsed = now - self._last_log_time + if elapsed < self._time_delta: + logging.debug( + "Not due to log for another %.2f seconds, dropping data.", + self._time_delta - elapsed, + ) + return + self._last_log_time = now + + # Append row to CSV. + data = base.to_numpy(data) + # Use fields from initial `data` to create the header. If extra fields are + # present in subsequent `data`, we ignore them. + if not self._writer: + fields = sorted(data.keys()) + self._writer = csv.DictWriter( + self._file, fieldnames=fields, extrasaction="ignore" + ) + # Write header only if the file is empty. + if not self._file.tell(): + self._writer.writeheader() + self._writer.writerow(data) + + # Flush every `flush_every` writes. + if self._writes % self._flush_every == 0: + self.flush() + self._writes += 1 + + def close(self): + self.flush() + if self._file_owner: + self._file.close() + + def flush(self): + self._file.flush() + + @property + def file_path(self) -> str: + return self._file.name diff --git a/acme/utils/loggers/csv_test.py b/acme/utils/loggers/csv_test.py index 3bf8d07014..f12e688856 100644 --- a/acme/utils/loggers/csv_test.py +++ b/acme/utils/loggers/csv_test.py @@ -17,86 +17,80 @@ import csv import os +from absl.testing import absltest, parameterized + from acme.testing import test_utils from acme.utils import paths from acme.utils.loggers import csv as csv_logger -from absl.testing import absltest -from absl.testing import parameterized - -_TEST_INPUTS = [{ - 'c': 'foo', - 'a': '1337', - 'b': '42.0001', -}, { - 'c': 'foo2', - 'a': '1338', - 'b': '43.0001', -}] +_TEST_INPUTS = [ + {"c": "foo", "a": "1337", "b": "42.0001",}, + {"c": "foo2", "a": "1338", "b": "43.0001",}, +] class CSVLoggingTest(test_utils.TestCase): - - def test_logging_input_is_directory(self): - - # Set up logger. - directory = self.get_tempdir() - label = 'test' - logger = csv_logger.CSVLogger(directory_or_file=directory, label=label) - - # Write data and close. - for inp in _TEST_INPUTS: - logger.write(inp) - logger.close() - - # Read back data. 
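# === Editorial usage sketch (not part of the diff) ==========================
# Illustrates the CSVLogger behaviour documented above: the header is fixed by
# the first write(), and keys that only appear in later writes are silently
# dropped (extrasaction="ignore"). "/tmp/acme_demo" is a hypothetical path.
from acme.utils.loggers import csv as csv_logger

logger = csv_logger.CSVLogger(directory_or_file="/tmp/acme_demo", label="eval")
logger.write({"step": 1, "return": 10.5})                # Header: return, step.
logger.write({"step": 2, "return": 11.0, "extra": 3.0})  # "extra" is ignored.
logger.close()
print(logger.file_path)  # .../logs/eval/<uid>/logs.csv (add_uid=True).
# =============================================================================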
- outputs = [] - with open(logger.file_path) as f: - csv_reader = csv.DictReader(f) - for row in csv_reader: - outputs.append(dict(row)) - self.assertEqual(outputs, _TEST_INPUTS) - - @parameterized.parameters(True, False) - def test_logging_input_is_file(self, add_uid: bool): - - # Set up logger. - directory = paths.process_path( - self.get_tempdir(), 'logs', 'my_label', add_uid=add_uid) - file = open(os.path.join(directory, 'logs.csv'), 'a') - logger = csv_logger.CSVLogger(directory_or_file=file, add_uid=add_uid) - - # Write data and close. - for inp in _TEST_INPUTS: - logger.write(inp) - logger.close() - - # Logger doesn't close the file; caller must do this manually. - self.assertFalse(file.closed) - file.close() - - # Read back data. - outputs = [] - with open(logger.file_path) as f: - csv_reader = csv.DictReader(f) - for row in csv_reader: - outputs.append(dict(row)) - self.assertEqual(outputs, _TEST_INPUTS) - - def test_flush(self): - - logger = csv_logger.CSVLogger(self.get_tempdir(), flush_every=1) - for inp in _TEST_INPUTS: - logger.write(inp) - - # Read back data. - outputs = [] - with open(logger.file_path) as f: - csv_reader = csv.DictReader(f) - for row in csv_reader: - outputs.append(dict(row)) - self.assertEqual(outputs, _TEST_INPUTS) - - -if __name__ == '__main__': - absltest.main() + def test_logging_input_is_directory(self): + + # Set up logger. + directory = self.get_tempdir() + label = "test" + logger = csv_logger.CSVLogger(directory_or_file=directory, label=label) + + # Write data and close. + for inp in _TEST_INPUTS: + logger.write(inp) + logger.close() + + # Read back data. + outputs = [] + with open(logger.file_path) as f: + csv_reader = csv.DictReader(f) + for row in csv_reader: + outputs.append(dict(row)) + self.assertEqual(outputs, _TEST_INPUTS) + + @parameterized.parameters(True, False) + def test_logging_input_is_file(self, add_uid: bool): + + # Set up logger. + directory = paths.process_path( + self.get_tempdir(), "logs", "my_label", add_uid=add_uid + ) + file = open(os.path.join(directory, "logs.csv"), "a") + logger = csv_logger.CSVLogger(directory_or_file=file, add_uid=add_uid) + + # Write data and close. + for inp in _TEST_INPUTS: + logger.write(inp) + logger.close() + + # Logger doesn't close the file; caller must do this manually. + self.assertFalse(file.closed) + file.close() + + # Read back data. + outputs = [] + with open(logger.file_path) as f: + csv_reader = csv.DictReader(f) + for row in csv_reader: + outputs.append(dict(row)) + self.assertEqual(outputs, _TEST_INPUTS) + + def test_flush(self): + + logger = csv_logger.CSVLogger(self.get_tempdir(), flush_every=1) + for inp in _TEST_INPUTS: + logger.write(inp) + + # Read back data. 
+ outputs = [] + with open(logger.file_path) as f: + csv_reader = csv.DictReader(f) + for row in csv_reader: + outputs.append(dict(row)) + self.assertEqual(outputs, _TEST_INPUTS) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/utils/loggers/dataframe.py b/acme/utils/loggers/dataframe.py index 16c59bed5f..b7a66dcec0 100644 --- a/acme/utils/loggers/dataframe.py +++ b/acme/utils/loggers/dataframe.py @@ -36,17 +36,17 @@ class InMemoryLogger(base.Logger): - """A simple logger that keeps all data in memory.""" + """A simple logger that keeps all data in memory.""" - def __init__(self): - self._data = [] + def __init__(self): + self._data = [] - def write(self, data: base.LoggingData): - self._data.append(data) + def write(self, data: base.LoggingData): + self._data.append(data) - def close(self): - pass + def close(self): + pass - @property - def data(self) -> Sequence[base.LoggingData]: - return self._data + @property + def data(self) -> Sequence[base.LoggingData]: + return self._data diff --git a/acme/utils/loggers/default.py b/acme/utils/loggers/default.py index 1c9e9de302..704f1e5037 100644 --- a/acme/utils/loggers/default.py +++ b/acme/utils/loggers/default.py @@ -19,10 +19,7 @@ from acme.utils.loggers import aggregators from acme.utils.loggers import asynchronous as async_logger -from acme.utils.loggers import base -from acme.utils.loggers import csv -from acme.utils.loggers import filters -from acme.utils.loggers import terminal +from acme.utils.loggers import base, csv, filters, terminal def make_default_logger( @@ -32,9 +29,9 @@ def make_default_logger( asynchronous: bool = False, print_fn: Optional[Callable[[str], None]] = None, serialize_fn: Optional[Callable[[Mapping[str, Any]], str]] = base.to_numpy, - steps_key: str = 'steps', + steps_key: str = "steps", ) -> base.Logger: - """Makes a default Acme logger. + """Makes a default Acme logger. Args: label: Name to give to the logger. @@ -49,21 +46,21 @@ def make_default_logger( Returns: A logger object that responds to logger.write(some_dict). """ - del steps_key - if not print_fn: - print_fn = logging.info - terminal_logger = terminal.TerminalLogger(label=label, print_fn=print_fn) + del steps_key + if not print_fn: + print_fn = logging.info + terminal_logger = terminal.TerminalLogger(label=label, print_fn=print_fn) - loggers = [terminal_logger] + loggers = [terminal_logger] - if save_data: - loggers.append(csv.CSVLogger(label=label)) + if save_data: + loggers.append(csv.CSVLogger(label=label)) - # Dispatch to all writers and filter Nones and by time. - logger = aggregators.Dispatcher(loggers, serialize_fn) - logger = filters.NoneFilter(logger) - if asynchronous: - logger = async_logger.AsyncLogger(logger) - logger = filters.TimeFilter(logger, time_delta) + # Dispatch to all writers and filter Nones and by time. + logger = aggregators.Dispatcher(loggers, serialize_fn) + logger = filters.NoneFilter(logger) + if asynchronous: + logger = async_logger.AsyncLogger(logger) + logger = filters.TimeFilter(logger, time_delta) - return logger + return logger diff --git a/acme/utils/loggers/filters.py b/acme/utils/loggers/filters.py index 10a2c241a5..dfbd70c639 100644 --- a/acme/utils/loggers/filters.py +++ b/acme/utils/loggers/filters.py @@ -22,30 +22,30 @@ class NoneFilter(base.Logger): - """Logger which writes to another logger, filtering any `None` values.""" + """Logger which writes to another logger, filtering any `None` values.""" - def __init__(self, to: base.Logger): - """Initializes the logger. 
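# === Editorial usage sketch (not part of the diff) ==========================
# Shows how make_default_logger (reformatted above) composes the pieces in this
# diff: a TerminalLogger, optionally a CSVLogger, behind a Dispatcher that is
# wrapped in NoneFilter, optionally AsyncLogger, and TimeFilter.
from acme.utils import loggers

logger = loggers.make_default_logger(
    label="learner",
    save_data=False,  # Skip the CSVLogger; terminal output only.
    time_delta=0.0,   # Do not rate-limit writes.
)
logger.write({"loss": 0.25, "steps": 100})
logger.close()
# =============================================================================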
+ def __init__(self, to: base.Logger): + """Initializes the logger. Args: to: A `Logger` object to which the current object will forward its results when `write` is called. """ - self._to = to + self._to = to - def write(self, values: base.LoggingData): - values = {k: v for k, v in values.items() if v is not None} - self._to.write(values) + def write(self, values: base.LoggingData): + values = {k: v for k, v in values.items() if v is not None} + self._to.write(values) - def close(self): - self._to.close() + def close(self): + self._to.close() class TimeFilter(base.Logger): - """Logger which writes to another logger at a given time interval.""" + """Logger which writes to another logger at a given time interval.""" - def __init__(self, to: base.Logger, time_delta: float): - """Initializes the logger. + def __init__(self, to: base.Logger, time_delta: float): + """Initializes the logger. Args: to: A `Logger` object to which the current object will forward its results @@ -53,33 +53,33 @@ def __init__(self, to: base.Logger, time_delta: float): time_delta: How often to write values out in seconds. Note that writes within `time_delta` are dropped. """ - self._to = to - self._time = 0 - self._time_delta = time_delta - if time_delta < 0: - raise ValueError(f'time_delta must be greater than 0 (got {time_delta}).') + self._to = to + self._time = 0 + self._time_delta = time_delta + if time_delta < 0: + raise ValueError(f"time_delta must be greater than 0 (got {time_delta}).") - def write(self, values: base.LoggingData): - now = time.time() - if (now - self._time) > self._time_delta: - self._to.write(values) - self._time = now + def write(self, values: base.LoggingData): + now = time.time() + if (now - self._time) > self._time_delta: + self._to.write(values) + self._time = now - def close(self): - self._to.close() + def close(self): + self._to.close() class KeyFilter(base.Logger): - """Logger which filters keys in logged data.""" + """Logger which filters keys in logged data.""" - def __init__( - self, - to: base.Logger, - *, - keep: Optional[Sequence[str]] = None, - drop: Optional[Sequence[str]] = None, - ): - """Creates the filter. + def __init__( + self, + to: base.Logger, + *, + keep: Optional[Sequence[str]] = None, + drop: Optional[Sequence[str]] = None, + ): + """Creates the filter. Args: to: A `Logger` object to which the current object will forward its writes. @@ -88,32 +88,32 @@ def __init__( drop: Keys that are dropped by the filter. Note that `keep` and `drop` cannot be both set at once. """ - if bool(keep) == bool(drop): - raise ValueError('Exactly one of `keep` & `drop` arguments must be set.') - self._to = to - self._keep = keep - self._drop = drop + if bool(keep) == bool(drop): + raise ValueError("Exactly one of `keep` & `drop` arguments must be set.") + self._to = to + self._keep = keep + self._drop = drop - def write(self, data: base.LoggingData): - if self._keep: - data = {k: data[k] for k in self._keep} - if self._drop: - data = {k: v for k, v in data.items() if k not in self._drop} - self._to.write(data) + def write(self, data: base.LoggingData): + if self._keep: + data = {k: data[k] for k in self._keep} + if self._drop: + data = {k: v for k, v in data.items() if k not in self._drop} + self._to.write(data) - def close(self): - self._to.close() + def close(self): + self._to.close() class GatedFilter(base.Logger): - """Logger which writes to another logger based on a gating function. + """Logger which writes to another logger based on a gating function. 
This logger tracks the number of times its `write` method is called, and uses a gating function on this number to decide when to write. """ - def __init__(self, to: base.Logger, gating_fn: Callable[[int], bool]): - """Initialises the logger. + def __init__(self, to: base.Logger, gating_fn: Callable[[int], bool]): + """Initialises the logger. Args: to: A `Logger` object to which the current object will forward its results @@ -121,21 +121,21 @@ def __init__(self, to: base.Logger, gating_fn: Callable[[int], bool]): gating_fn: A function that takes an integer (number of calls) as input. For example, to log every tenth call: gating_fn=lambda t: t % 10 == 0. """ - self._to = to - self._gating_fn = gating_fn - self._calls = 0 + self._to = to + self._gating_fn = gating_fn + self._calls = 0 - def write(self, values: base.LoggingData): - if self._gating_fn(self._calls): - self._to.write(values) - self._calls += 1 + def write(self, values: base.LoggingData): + if self._gating_fn(self._calls): + self._to.write(values) + self._calls += 1 - def close(self): - self._to.close() + def close(self): + self._to.close() - @classmethod - def logarithmic(cls, to: base.Logger, n: int = 10) -> 'GatedFilter': - """Builds a logger for writing at logarithmically-spaced intervals. + @classmethod + def logarithmic(cls, to: base.Logger, n: int = 10) -> "GatedFilter": + """Builds a logger for writing at logarithmically-spaced intervals. This will log on a linear scale at each order of magnitude of `n`. For example, with n=10, this will log at times: @@ -147,14 +147,16 @@ def logarithmic(cls, to: base.Logger, n: int = 10) -> 'GatedFilter': Returns: A GatedFilter logger, which gates logarithmically as described above. """ - def logarithmic_filter(t: int) -> bool: - magnitude = math.floor(math.log10(max(t, 1))/math.log10(n)) - return t % (n**magnitude) == 0 - return cls(to, gating_fn=logarithmic_filter) - @classmethod - def periodic(cls, to: base.Logger, interval: int = 10) -> 'GatedFilter': - """Builds a logger for writing at linearly-spaced intervals. + def logarithmic_filter(t: int) -> bool: + magnitude = math.floor(math.log10(max(t, 1)) / math.log10(n)) + return t % (n ** magnitude) == 0 + + return cls(to, gating_fn=logarithmic_filter) + + @classmethod + def periodic(cls, to: base.Logger, interval: int = 10) -> "GatedFilter": + """Builds a logger for writing at linearly-spaced intervals. Args: to: The underlying logger to write to. @@ -162,4 +164,4 @@ def periodic(cls, to: base.Logger, interval: int = 10) -> 'GatedFilter': Returns: A GatedFilter logger, which gates periodically as described above. """ - return cls(to, gating_fn=lambda t: t % interval == 0) + return cls(to, gating_fn=lambda t: t % interval == 0) diff --git a/acme/utils/loggers/filters_test.py b/acme/utils/loggers/filters_test.py index c32787412c..bcf699d752 100644 --- a/acme/utils/loggers/filters_test.py +++ b/acme/utils/loggers/filters_test.py @@ -16,97 +16,93 @@ import time -from acme.utils.loggers import base -from acme.utils.loggers import filters - from absl.testing import absltest +from acme.utils.loggers import base, filters + # TODO(jaslanides): extract this to test_utils, or similar, for re-use. 
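# === Editorial usage sketch (not part of the diff) ==========================
# Composes the filters reformatted above around an InMemoryLogger so the effect
# is easy to inspect: None values are dropped, only selected keys are kept, and
# writes are gated logarithmically.
from acme.utils.loggers import dataframe, filters

base_logger = dataframe.InMemoryLogger()
logger = filters.GatedFilter.logarithmic(
    filters.NoneFilter(filters.KeyFilter(base_logger, keep=("step", "loss"))),
    n=10,
)
for step in range(25):
    logger.write({"step": step, "loss": 1.0 / (step + 1), "unused": None})
print([row["step"] for row in base_logger.data])  # [0, 1, ..., 9, 10, 20]
# =============================================================================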
class FakeLogger(base.Logger): - """A fake logger for testing.""" + """A fake logger for testing.""" - def __init__(self): - self.data = [] + def __init__(self): + self.data = [] - def write(self, data): - self.data.append(data) + def write(self, data): + self.data.append(data) - @property - def last_write(self): - return self.data[-1] + @property + def last_write(self): + return self.data[-1] - def close(self): - pass + def close(self): + pass class GatedFilterTest(absltest.TestCase): - - def test_logarithmic_filter(self): - logger = FakeLogger() - filtered = filters.GatedFilter.logarithmic(logger, n=10) - for t in range(100): - filtered.write({'t': t}) - rows = [row['t'] for row in logger.data] - self.assertEqual(rows, [*range(10), *range(10, 100, 10)]) - - def test_periodic_filter(self): - logger = FakeLogger() - filtered = filters.GatedFilter.periodic(logger, interval=10) - for t in range(100): - filtered.write({'t': t}) - rows = [row['t'] for row in logger.data] - self.assertEqual(rows, list(range(0, 100, 10))) + def test_logarithmic_filter(self): + logger = FakeLogger() + filtered = filters.GatedFilter.logarithmic(logger, n=10) + for t in range(100): + filtered.write({"t": t}) + rows = [row["t"] for row in logger.data] + self.assertEqual(rows, [*range(10), *range(10, 100, 10)]) + + def test_periodic_filter(self): + logger = FakeLogger() + filtered = filters.GatedFilter.periodic(logger, interval=10) + for t in range(100): + filtered.write({"t": t}) + rows = [row["t"] for row in logger.data] + self.assertEqual(rows, list(range(0, 100, 10))) class TimeFilterTest(absltest.TestCase): + def test_delta(self): + logger = FakeLogger() + filtered = filters.TimeFilter(logger, time_delta=0.1) - def test_delta(self): - logger = FakeLogger() - filtered = filters.TimeFilter(logger, time_delta=0.1) - - # Logged. - filtered.write({'foo': 1}) - self.assertIn('foo', logger.last_write) + # Logged. + filtered.write({"foo": 1}) + self.assertIn("foo", logger.last_write) - # *Not* logged. - filtered.write({'bar': 2}) - self.assertNotIn('bar', logger.last_write) + # *Not* logged. + filtered.write({"bar": 2}) + self.assertNotIn("bar", logger.last_write) - # Wait out delta. - time.sleep(0.11) + # Wait out delta. + time.sleep(0.11) - # Logged. - filtered.write({'baz': 3}) - self.assertIn('baz', logger.last_write) + # Logged. 
+ filtered.write({"baz": 3}) + self.assertIn("baz", logger.last_write) - self.assertLen(logger.data, 2) + self.assertLen(logger.data, 2) class KeyFilterTest(absltest.TestCase): - - def test_keep_filter(self): - logger = FakeLogger() - filtered = filters.KeyFilter(logger, keep=('foo',)) - filtered.write({'foo': 'bar', 'baz': 12}) - row, *_ = logger.data - self.assertIn('foo', row) - self.assertNotIn('baz', row) - - def test_drop_filter(self): - logger = FakeLogger() - filtered = filters.KeyFilter(logger, drop=('foo',)) - filtered.write({'foo': 'bar', 'baz': 12}) - row, *_ = logger.data - self.assertIn('baz', row) - self.assertNotIn('foo', row) - - def test_bad_arguments(self): - with self.assertRaises(ValueError): - filters.KeyFilter(FakeLogger()) - with self.assertRaises(ValueError): - filters.KeyFilter(FakeLogger(), keep=('a',), drop=('b',)) - - -if __name__ == '__main__': - absltest.main() + def test_keep_filter(self): + logger = FakeLogger() + filtered = filters.KeyFilter(logger, keep=("foo",)) + filtered.write({"foo": "bar", "baz": 12}) + row, *_ = logger.data + self.assertIn("foo", row) + self.assertNotIn("baz", row) + + def test_drop_filter(self): + logger = FakeLogger() + filtered = filters.KeyFilter(logger, drop=("foo",)) + filtered.write({"foo": "bar", "baz": 12}) + row, *_ = logger.data + self.assertIn("baz", row) + self.assertNotIn("foo", row) + + def test_bad_arguments(self): + with self.assertRaises(ValueError): + filters.KeyFilter(FakeLogger()) + with self.assertRaises(ValueError): + filters.KeyFilter(FakeLogger(), keep=("a",), drop=("b",)) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/utils/loggers/flatten.py b/acme/utils/loggers/flatten.py index e06363aa50..057a85c00a 100644 --- a/acme/utils/loggers/flatten.py +++ b/acme/utils/loggers/flatten.py @@ -20,13 +20,12 @@ class FlattenDictLogger(base.Logger): - """Logger which flattens sub-dictionaries into the top level dict.""" + """Logger which flattens sub-dictionaries into the top level dict.""" - def __init__(self, - logger: base.Logger, - label: str = 'Logs', - raw_keys: Sequence[str] = ()): - """Initializer. + def __init__( + self, logger: base.Logger, label: str = "Logs", raw_keys: Sequence[str] = () + ): + """Initializer. Args: logger: The wrapped logger. @@ -36,24 +35,24 @@ def __init__(self, keys to be present in the logs (e.g. 'step', 'timestamp'), so these keys should not be prefixed. 
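# === Editorial usage sketch (not part of the diff) ==========================
# Demonstrates the flattening behaviour described in the FlattenDictLogger
# docstring above, using an InMemoryLogger purely to capture the written keys.
from acme.utils.loggers import dataframe, flatten

inner = dataframe.InMemoryLogger()
logger = flatten.FlattenDictLogger(inner, label="Actor", raw_keys=("steps",))
logger.write({"steps": 7, "loss": {"policy": 0.1, "critic": 0.2}, "return": 3.0})
print(inner.data[-1])
# {'steps': 7, 'Actor/loss/policy': 0.1, 'Actor/loss/critic': 0.2, 'Actor/return': 3.0}
# =============================================================================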
""" - self._logger = logger - self._label = label - self._raw_keys = raw_keys - - def write(self, values: base.LoggingData): - flattened_values = {} - for key, value in values.items(): - if key in self._raw_keys: - flattened_values[key] = value - continue - name = f'{self._label}/{key}' - if isinstance(value, dict): - for sub_key, sub_value in value.items(): - flattened_values[f'{name}/{sub_key}'] = sub_value - else: - flattened_values[name] = value - - self._logger.write(flattened_values) - - def close(self): - self._logger.close() + self._logger = logger + self._label = label + self._raw_keys = raw_keys + + def write(self, values: base.LoggingData): + flattened_values = {} + for key, value in values.items(): + if key in self._raw_keys: + flattened_values[key] = value + continue + name = f"{self._label}/{key}" + if isinstance(value, dict): + for sub_key, sub_value in value.items(): + flattened_values[f"{name}/{sub_key}"] = sub_value + else: + flattened_values[name] = value + + self._logger.write(flattened_values) + + def close(self): + self._logger.close() diff --git a/acme/utils/loggers/image.py b/acme/utils/loggers/image.py index d6def89008..bdc5cac865 100644 --- a/acme/utils/loggers/image.py +++ b/acme/utils/loggers/image.py @@ -19,26 +19,23 @@ from typing import Optional from absl import logging -from acme.utils.loggers import base from PIL import Image +from acme.utils.loggers import base + class ImageLogger(base.Logger): - """Logger for writing NumPy arrays as PNG images to disk. + """Logger for writing NumPy arrays as PNG images to disk. Assumes that all data passed are NumPy arrays that can be converted to images. TODO(jaslanides): Make this stateless/robust to preemptions. """ - def __init__( - self, - directory: str, - *, - label: str = '', - mode: Optional[str] = None, - ): - """Initialises the writer. + def __init__( + self, directory: str, *, label: str = "", mode: Optional[str] = None, + ): + """Initialises the writer. Args: directory: Base directory to which images are logged. 
@@ -49,28 +46,28 @@ def __init__( [0] https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes """ - self._path = self._get_path(directory, label) - if not self._path.exists(): - self._path.mkdir(parents=True) + self._path = self._get_path(directory, label) + if not self._path.exists(): + self._path.mkdir(parents=True) - self._mode = mode - self._indices = collections.defaultdict(int) + self._mode = mode + self._indices = collections.defaultdict(int) - def write(self, data: base.LoggingData): - for k, v in data.items(): - image = Image.fromarray(v, mode=self._mode) - path = self._path / f'{k}_{self._indices[k]:06}.png' - self._indices[k] += 1 - with path.open(mode='wb') as f: - logging.info('Writing image to %s.', str(path)) - image.save(f) + def write(self, data: base.LoggingData): + for k, v in data.items(): + image = Image.fromarray(v, mode=self._mode) + path = self._path / f"{k}_{self._indices[k]:06}.png" + self._indices[k] += 1 + with path.open(mode="wb") as f: + logging.info("Writing image to %s.", str(path)) + image.save(f) - def close(self): - pass + def close(self): + pass - @property - def directory(self) -> str: - return str(self._path) + @property + def directory(self) -> str: + return str(self._path) - def _get_path(self, *args, **kwargs) -> pathlib.Path: - return pathlib.Path(*args, **kwargs) + def _get_path(self, *args, **kwargs) -> pathlib.Path: + return pathlib.Path(*args, **kwargs) diff --git a/acme/utils/loggers/image_test.py b/acme/utils/loggers/image_test.py index a241255847..91f32881f6 100644 --- a/acme/utils/loggers/image_test.py +++ b/acme/utils/loggers/image_test.py @@ -16,45 +16,44 @@ import os -from acme.testing import test_utils -from acme.utils.loggers import image import numpy as np +from absl.testing import absltest from PIL import Image -from absl.testing import absltest +from acme.testing import test_utils +from acme.utils.loggers import image class ImageTest(test_utils.TestCase): - - def test_save_load_identity(self): - directory = self.get_tempdir() - logger = image.ImageLogger(directory, label='foo') - array = (np.random.rand(10, 10) * 255).astype(np.uint8) - logger.write({'img': array}) - - with open(f'{directory}/foo/img_000000.png', mode='rb') as f: - out = np.asarray(Image.open(f)) - np.testing.assert_array_equal(array, out) - - def test_indexing(self): - directory = self.get_tempdir() - logger = image.ImageLogger(directory, label='foo') - zeros = np.zeros(shape=(3, 3), dtype=np.uint8) - logger.write({'img': zeros, 'other_img': zeros + 1}) - logger.write({'img': zeros - 1}) - logger.write({'other_img': zeros + 1}) - logger.write({'other_img': zeros + 2}) - - fnames = sorted(os.listdir(f'{directory}/foo')) - expected = [ - 'img_000000.png', - 'img_000001.png', - 'other_img_000000.png', - 'other_img_000001.png', - 'other_img_000002.png', - ] - self.assertEqual(fnames, expected) - - -if __name__ == '__main__': - absltest.main() + def test_save_load_identity(self): + directory = self.get_tempdir() + logger = image.ImageLogger(directory, label="foo") + array = (np.random.rand(10, 10) * 255).astype(np.uint8) + logger.write({"img": array}) + + with open(f"{directory}/foo/img_000000.png", mode="rb") as f: + out = np.asarray(Image.open(f)) + np.testing.assert_array_equal(array, out) + + def test_indexing(self): + directory = self.get_tempdir() + logger = image.ImageLogger(directory, label="foo") + zeros = np.zeros(shape=(3, 3), dtype=np.uint8) + logger.write({"img": zeros, "other_img": zeros + 1}) + logger.write({"img": zeros - 1}) + 
logger.write({"other_img": zeros + 1}) + logger.write({"other_img": zeros + 2}) + + fnames = sorted(os.listdir(f"{directory}/foo")) + expected = [ + "img_000000.png", + "img_000001.png", + "other_img_000000.png", + "other_img_000001.png", + "other_img_000002.png", + ] + self.assertEqual(fnames, expected) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/utils/loggers/terminal.py b/acme/utils/loggers/terminal.py index 821a00c131..cafdbe5f05 100644 --- a/acme/utils/loggers/terminal.py +++ b/acme/utils/loggers/terminal.py @@ -18,25 +18,26 @@ import time from typing import Any, Callable -from acme.utils.loggers import base import numpy as np +from acme.utils.loggers import base + def _format_key(key: str) -> str: - """Internal function for formatting keys.""" - return key.replace('_', ' ').title() + """Internal function for formatting keys.""" + return key.replace("_", " ").title() def _format_value(value: Any) -> str: - """Internal function for formatting values.""" - value = base.to_numpy(value) - if isinstance(value, (float, np.number)): - return f'{value:0.3f}' - return f'{value}' + """Internal function for formatting values.""" + value = base.to_numpy(value) + if isinstance(value, (float, np.number)): + return f"{value:0.3f}" + return f"{value}" def serialize(values: base.LoggingData) -> str: - """Converts `values` to a pretty-printed string. + """Converts `values` to a pretty-printed string. This takes a dictionary `values` whose keys are strings and returns a formatted string such that each [key, value] pair is separated by ' = ' and @@ -55,21 +56,22 @@ def serialize(values: base.LoggingData) -> str: Returns: A formatted string. """ - return ' | '.join(f'{_format_key(k)} = {_format_value(v)}' - for k, v in sorted(values.items())) + return " | ".join( + f"{_format_key(k)} = {_format_value(v)}" for k, v in sorted(values.items()) + ) class TerminalLogger(base.Logger): - """Logs to terminal.""" + """Logs to terminal.""" - def __init__( - self, - label: str = '', - print_fn: Callable[[str], None] = logging.info, - serialize_fn: Callable[[base.LoggingData], str] = serialize, - time_delta: float = 0.0, - ): - """Initializes the logger. + def __init__( + self, + label: str = "", + print_fn: Callable[[str], None] = logging.info, + serialize_fn: Callable[[base.LoggingData], str] = serialize, + time_delta: float = 0.0, + ): + """Initializes the logger. Args: label: label string to use when logging. @@ -79,17 +81,17 @@ def __init__( minimize terminal spam, but is 0 by default---ie everything is written. 
""" - self._print_fn = print_fn - self._serialize_fn = serialize_fn - self._label = label and f'[{_format_key(label)}] ' - self._time = time.time() - self._time_delta = time_delta + self._print_fn = print_fn + self._serialize_fn = serialize_fn + self._label = label and f"[{_format_key(label)}] " + self._time = time.time() + self._time_delta = time_delta - def write(self, values: base.LoggingData): - now = time.time() - if (now - self._time) > self._time_delta: - self._print_fn(f'{self._label}{self._serialize_fn(values)}') - self._time = now + def write(self, values: base.LoggingData): + now = time.time() + if (now - self._time) > self._time_delta: + self._print_fn(f"{self._label}{self._serialize_fn(values)}") + self._time = now - def close(self): - pass + def close(self): + pass diff --git a/acme/utils/loggers/terminal_test.py b/acme/utils/loggers/terminal_test.py index facdcacba6..c02627f98d 100644 --- a/acme/utils/loggers/terminal_test.py +++ b/acme/utils/loggers/terminal_test.py @@ -14,33 +14,32 @@ """Tests for terminal logger.""" -from acme.utils.loggers import terminal - from absl.testing import absltest +from acme.utils.loggers import terminal -class LoggingTest(absltest.TestCase): - def test_logging_output_format(self): - inputs = { - 'c': 'foo', - 'a': 1337, - 'b': 42.0001, - } - expected_outputs = 'A = 1337 | B = 42.000 | C = foo' - test_fn = lambda outputs: self.assertEqual(outputs, expected_outputs) +class LoggingTest(absltest.TestCase): + def test_logging_output_format(self): + inputs = { + "c": "foo", + "a": 1337, + "b": 42.0001, + } + expected_outputs = "A = 1337 | B = 42.000 | C = foo" + test_fn = lambda outputs: self.assertEqual(outputs, expected_outputs) - logger = terminal.TerminalLogger(print_fn=test_fn) - logger.write(inputs) + logger = terminal.TerminalLogger(print_fn=test_fn) + logger.write(inputs) - def test_label(self): - inputs = {'foo': 'bar', 'baz': 123} - expected_outputs = '[Test] Baz = 123 | Foo = bar' - test_fn = lambda outputs: self.assertEqual(outputs, expected_outputs) + def test_label(self): + inputs = {"foo": "bar", "baz": 123} + expected_outputs = "[Test] Baz = 123 | Foo = bar" + test_fn = lambda outputs: self.assertEqual(outputs, expected_outputs) - logger = terminal.TerminalLogger(print_fn=test_fn, label='test') - logger.write(inputs) + logger = terminal.TerminalLogger(print_fn=test_fn, label="test") + logger.write(inputs) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/utils/loggers/tf_summary.py b/acme/utils/loggers/tf_summary.py index 868c900b29..49909973b2 100644 --- a/acme/utils/loggers/tf_summary.py +++ b/acme/utils/loggers/tf_summary.py @@ -17,58 +17,56 @@ import time from typing import Optional +import tensorflow as tf from absl import logging + from acme.utils.loggers import base -import tensorflow as tf def _format_key(key: str) -> str: - """Internal function for formatting keys in Tensorboard format.""" - return key.title().replace('_', '') + """Internal function for formatting keys in Tensorboard format.""" + return key.title().replace("_", "") class TFSummaryLogger(base.Logger): - """Logs to a tf.summary created in a given logdir. + """Logs to a tf.summary created in a given logdir. If multiple TFSummaryLogger are created with the same logdir, results will be categorized by labels. """ - def __init__( - self, - logdir: str, - label: str = 'Logs', - steps_key: Optional[str] = None - ): - """Initializes the logger. 
+ def __init__( + self, logdir: str, label: str = "Logs", steps_key: Optional[str] = None + ): + """Initializes the logger. Args: logdir: directory to which we should log files. label: label string to use when logging. Default to 'Logs'. steps_key: key to use for steps. Must be in the values passed to write. """ - self._time = time.time() - self.label = label - self._iter = 0 - self.summary = tf.summary.create_file_writer(logdir) - self._steps_key = steps_key + self._time = time.time() + self.label = label + self._iter = 0 + self.summary = tf.summary.create_file_writer(logdir) + self._steps_key = steps_key - def write(self, values: base.LoggingData): - if self._steps_key is not None and self._steps_key not in values: - logging.warning('steps key %s not found. Skip logging.', self._steps_key) - return + def write(self, values: base.LoggingData): + if self._steps_key is not None and self._steps_key not in values: + logging.warning("steps key %s not found. Skip logging.", self._steps_key) + return - step = values[ - self._steps_key] if self._steps_key is not None else self._iter + step = values[self._steps_key] if self._steps_key is not None else self._iter - with self.summary.as_default(): - # TODO(b/159065169): Remove this suppression once the bug is resolved. - # pytype: disable=unsupported-operands - for key in values.keys() - [self._steps_key]: - # pytype: enable=unsupported-operands - tf.summary.scalar( - f'{self.label}/{_format_key(key)}', data=values[key], step=step) - self._iter += 1 + with self.summary.as_default(): + # TODO(b/159065169): Remove this suppression once the bug is resolved. + # pytype: disable=unsupported-operands + for key in values.keys() - [self._steps_key]: + # pytype: enable=unsupported-operands + tf.summary.scalar( + f"{self.label}/{_format_key(key)}", data=values[key], step=step + ) + self._iter += 1 - def close(self): - self.summary.close() + def close(self): + self.summary.close() diff --git a/acme/utils/loggers/timestamp.py b/acme/utils/loggers/timestamp.py index 8dfc55265d..8ba255b4e0 100644 --- a/acme/utils/loggers/timestamp.py +++ b/acme/utils/loggers/timestamp.py @@ -20,16 +20,16 @@ class TimestampLogger(base.Logger): - """Logger which populates the timestamp key with the current timestamp.""" + """Logger which populates the timestamp key with the current timestamp.""" - def __init__(self, logger: base.Logger, timestamp_key: str): - self._logger = logger - self._timestamp_key = timestamp_key + def __init__(self, logger: base.Logger, timestamp_key: str): + self._logger = logger + self._timestamp_key = timestamp_key - def write(self, values: base.LoggingData): - values = dict(values) - values[self._timestamp_key] = time.time() - self._logger.write(values) + def write(self, values: base.LoggingData): + values = dict(values) + values[self._timestamp_key] = time.time() + self._logger.write(values) - def close(self): - self._logger.close() + def close(self): + self._logger.close() diff --git a/acme/utils/lp_utils.py b/acme/utils/lp_utils.py index 354c0b0d62..8aa30635a8 100644 --- a/acme/utils/lp_utils.py +++ b/acme/utils/lp_utils.py @@ -22,17 +22,15 @@ import time from typing import Any, Callable, Optional -from absl import flags -from absl import logging -from acme.utils import counting -from acme.utils import signals +from absl import flags, logging + +from acme.utils import counting, signals FLAGS = flags.FLAGS -def partial_kwargs(function: Callable[..., Any], - **kwargs: Any) -> Callable[..., Any]: - """Return a partial function application by overriding 
default keywords. +def partial_kwargs(function: Callable[..., Any], **kwargs: Any) -> Callable[..., Any]: + """Return a partial function application by overriding default keywords. This function is equivalent to `functools.partial(function, **kwargs)` but will raise a `ValueError` when called if either the given keyword arguments @@ -48,74 +46,78 @@ def partial_kwargs(function: Callable[..., Any], Returns: A function. """ - # Try to get the argspec of our function which we'll use to get which keywords - # have defaults. - argspec = inspect.getfullargspec(function) + # Try to get the argspec of our function which we'll use to get which keywords + # have defaults. + argspec = inspect.getfullargspec(function) - # Figure out which keywords have defaults. - if argspec.defaults is None: - defaults = [] - else: - defaults = argspec.args[-len(argspec.defaults):] + # Figure out which keywords have defaults. + if argspec.defaults is None: + defaults = [] + else: + defaults = argspec.args[-len(argspec.defaults) :] - # Find any keys not given as defaults by the function. - unknown_kwargs = set(kwargs.keys()).difference(defaults) + # Find any keys not given as defaults by the function. + unknown_kwargs = set(kwargs.keys()).difference(defaults) - # Raise an error - if unknown_kwargs: - error_string = 'Cannot override unknown or non-default kwargs: {}' - raise ValueError(error_string.format(', '.join(unknown_kwargs))) + # Raise an error + if unknown_kwargs: + error_string = "Cannot override unknown or non-default kwargs: {}" + raise ValueError(error_string.format(", ".join(unknown_kwargs))) - return functools.partial(function, **kwargs) + return functools.partial(function, **kwargs) class StepsLimiter: - """Process that terminates an experiment when `max_steps` is reached.""" - - def __init__(self, - counter: counting.Counter, - max_steps: int, - steps_key: str = 'actor_steps'): - self._counter = counter - self._max_steps = max_steps - self._steps_key = steps_key - - def run(self): - """Run steps limiter to terminate an experiment when max_steps is reached. + """Process that terminates an experiment when `max_steps` is reached.""" + + def __init__( + self, counter: counting.Counter, max_steps: int, steps_key: str = "actor_steps" + ): + self._counter = counter + self._max_steps = max_steps + self._steps_key = steps_key + + def run(self): + """Run steps limiter to terminate an experiment when max_steps is reached. """ - logging.info('StepsLimiter: Starting with max_steps = %d (%s)', - self._max_steps, self._steps_key) - with signals.runtime_terminator(): - while True: - # Update the counts. - counts = self._counter.get_counts() - num_steps = counts.get(self._steps_key, 0) + logging.info( + "StepsLimiter: Starting with max_steps = %d (%s)", + self._max_steps, + self._steps_key, + ) + with signals.runtime_terminator(): + while True: + # Update the counts. + counts = self._counter.get_counts() + num_steps = counts.get(self._steps_key, 0) + + logging.info("StepsLimiter: Reached %d recorded steps", num_steps) - logging.info('StepsLimiter: Reached %d recorded steps', num_steps) + if num_steps > self._max_steps: + logging.info( + "StepsLimiter: Max steps of %d was reached, terminating", + self._max_steps, + ) + # Avoid importing Launchpad until it is actually used. + import launchpad as lp # pylint: disable=g-import-not-at-top - if num_steps > self._max_steps: - logging.info('StepsLimiter: Max steps of %d was reached, terminating', - self._max_steps) - # Avoid importing Launchpad until it is actually used. 
- import launchpad as lp # pylint: disable=g-import-not-at-top - lp.stop() + lp.stop() - # Don't spam the counter. - for _ in range(10): - # Do not sleep for a long period of time to avoid LaunchPad program - # termination hangs (time.sleep is not interruptible). - time.sleep(1) + # Don't spam the counter. + for _ in range(10): + # Do not sleep for a long period of time to avoid LaunchPad program + # termination hangs (time.sleep is not interruptible). + time.sleep(1) def is_local_run() -> bool: - return FLAGS.lp_launch_type.startswith('local') + return FLAGS.lp_launch_type.startswith("local") # Resources for each individual instance of the program. -def make_xm_docker_resources(program, - requirements: Optional[str] = None): - """Returns Docker XManager resources for each program's node. +def make_xm_docker_resources(program, requirements: Optional[str] = None): + """Returns Docker XManager resources for each program's node. For each node of the Launchpad's program appropriate hardware requirements are specified (CPU, memory...), while the list of PyPi packages specified in @@ -126,104 +128,110 @@ def make_xm_docker_resources(program, requirements: file containing additional requirements to use. If not specified, default Acme dependencies are used instead. """ - if (FLAGS.lp_launch_type != 'vertex_ai' and - FLAGS.lp_launch_type != 'local_docker'): - # Avoid importing 'xmanager' for local runs. - return None - - # Avoid importing Launchpad until it is actually used. - import launchpad as lp # pylint: disable=g-import-not-at-top - # Reference lp.DockerConfig to force lazy import of xmanager by Launchpad and - # then import it. It is done this way to avoid heavy imports by default. - lp.DockerConfig # pylint: disable=pointless-statement - from xmanager import xm # pylint: disable=g-import-not-at-top - - # Get number of each type of node. - num_nodes = {k: len(v) for k, v in program.groups.items()} - - xm_resources = {} - - acme_location = os.path.dirname(os.path.dirname(__file__)) - if not requirements: - # Acme requirements are located in the Acme directory (when installed - # with pip), or need to be extracted from setup.py when using Acme codebase - # from GitHub without PyPi installation. - requirements = os.path.join(acme_location, 'requirements.txt') - if not os.path.isfile(requirements): - # Try to generate requirements.txt from setup.py - setup = os.path.join(os.path.dirname(acme_location), 'setup.py') - if os.path.isfile(setup): - # Generate requirements.txt file using setup.py. - import importlib.util # pylint: disable=g-import-not-at-top - spec = importlib.util.spec_from_file_location('setup', setup) - setup = importlib.util.module_from_spec(spec) - try: - spec.loader.exec_module(setup) # pytype: disable=attribute-error - except SystemExit: - pass - atexit.register(os.remove, requirements) - setup.generate_requirements_file(requirements) - - # Extend PYTHONPATH with paths used by the launcher. 
- python_path = [] - for path in sys.path: - if path.startswith(acme_location) and acme_location != path: - python_path.append(path[len(acme_location):]) - - if 'replay' in num_nodes: - replay_cpu = 6 + num_nodes.get('actor', 0) * 0.01 - replay_cpu = min(40, replay_cpu) - - xm_resources['replay'] = lp.DockerConfig( - acme_location, - requirements, - hw_requirements=xm.JobRequirements(cpu=replay_cpu, ram=10 * xm.GiB), - python_path=python_path) - - if 'evaluator' in num_nodes: - xm_resources['evaluator'] = lp.DockerConfig( - acme_location, - requirements, - hw_requirements=xm.JobRequirements(cpu=2, ram=4 * xm.GiB), - python_path=python_path) - - if 'actor' in num_nodes: - xm_resources['actor'] = lp.DockerConfig( - acme_location, - requirements, - hw_requirements=xm.JobRequirements(cpu=2, ram=4 * xm.GiB), - python_path=python_path) - - if 'learner' in num_nodes: - learner_cpu = 6 + num_nodes.get('actor', 0) * 0.01 - learner_cpu = min(40, learner_cpu) - xm_resources['learner'] = lp.DockerConfig( - acme_location, - requirements, - hw_requirements=xm.JobRequirements( - cpu=learner_cpu, ram=6 * xm.GiB, P100=1), - python_path=python_path) - - if 'environment_loop' in num_nodes: - xm_resources['environment_loop'] = lp.DockerConfig( - acme_location, - requirements, - hw_requirements=xm.JobRequirements( - cpu=6, ram=6 * xm.GiB, P100=1), - python_path=python_path) - - if 'counter' in num_nodes: - xm_resources['counter'] = lp.DockerConfig( - acme_location, - requirements, - hw_requirements=xm.JobRequirements(cpu=3, ram=4 * xm.GiB), - python_path=python_path) - - if 'cacher' in num_nodes: - xm_resources['cacher'] = lp.DockerConfig( - acme_location, - requirements, - hw_requirements=xm.JobRequirements(cpu=3, ram=6 * xm.GiB), - python_path=python_path) - - return xm_resources + if FLAGS.lp_launch_type != "vertex_ai" and FLAGS.lp_launch_type != "local_docker": + # Avoid importing 'xmanager' for local runs. + return None + + # Avoid importing Launchpad until it is actually used. + import launchpad as lp # pylint: disable=g-import-not-at-top + + # Reference lp.DockerConfig to force lazy import of xmanager by Launchpad and + # then import it. It is done this way to avoid heavy imports by default. + lp.DockerConfig # pylint: disable=pointless-statement + from xmanager import xm # pylint: disable=g-import-not-at-top + + # Get number of each type of node. + num_nodes = {k: len(v) for k, v in program.groups.items()} + + xm_resources = {} + + acme_location = os.path.dirname(os.path.dirname(__file__)) + if not requirements: + # Acme requirements are located in the Acme directory (when installed + # with pip), or need to be extracted from setup.py when using Acme codebase + # from GitHub without PyPi installation. + requirements = os.path.join(acme_location, "requirements.txt") + if not os.path.isfile(requirements): + # Try to generate requirements.txt from setup.py + setup = os.path.join(os.path.dirname(acme_location), "setup.py") + if os.path.isfile(setup): + # Generate requirements.txt file using setup.py. + import importlib.util # pylint: disable=g-import-not-at-top + + spec = importlib.util.spec_from_file_location("setup", setup) + setup = importlib.util.module_from_spec(spec) + try: + spec.loader.exec_module(setup) # pytype: disable=attribute-error + except SystemExit: + pass + atexit.register(os.remove, requirements) + setup.generate_requirements_file(requirements) + + # Extend PYTHONPATH with paths used by the launcher. 
+ python_path = [] + for path in sys.path: + if path.startswith(acme_location) and acme_location != path: + python_path.append(path[len(acme_location) :]) + + if "replay" in num_nodes: + replay_cpu = 6 + num_nodes.get("actor", 0) * 0.01 + replay_cpu = min(40, replay_cpu) + + xm_resources["replay"] = lp.DockerConfig( + acme_location, + requirements, + hw_requirements=xm.JobRequirements(cpu=replay_cpu, ram=10 * xm.GiB), + python_path=python_path, + ) + + if "evaluator" in num_nodes: + xm_resources["evaluator"] = lp.DockerConfig( + acme_location, + requirements, + hw_requirements=xm.JobRequirements(cpu=2, ram=4 * xm.GiB), + python_path=python_path, + ) + + if "actor" in num_nodes: + xm_resources["actor"] = lp.DockerConfig( + acme_location, + requirements, + hw_requirements=xm.JobRequirements(cpu=2, ram=4 * xm.GiB), + python_path=python_path, + ) + + if "learner" in num_nodes: + learner_cpu = 6 + num_nodes.get("actor", 0) * 0.01 + learner_cpu = min(40, learner_cpu) + xm_resources["learner"] = lp.DockerConfig( + acme_location, + requirements, + hw_requirements=xm.JobRequirements(cpu=learner_cpu, ram=6 * xm.GiB, P100=1), + python_path=python_path, + ) + + if "environment_loop" in num_nodes: + xm_resources["environment_loop"] = lp.DockerConfig( + acme_location, + requirements, + hw_requirements=xm.JobRequirements(cpu=6, ram=6 * xm.GiB, P100=1), + python_path=python_path, + ) + + if "counter" in num_nodes: + xm_resources["counter"] = lp.DockerConfig( + acme_location, + requirements, + hw_requirements=xm.JobRequirements(cpu=3, ram=4 * xm.GiB), + python_path=python_path, + ) + + if "cacher" in num_nodes: + xm_resources["cacher"] = lp.DockerConfig( + acme_location, + requirements, + hw_requirements=xm.JobRequirements(cpu=3, ram=6 * xm.GiB), + python_path=python_path, + ) + + return xm_resources diff --git a/acme/utils/lp_utils_test.py b/acme/utils/lp_utils_test.py index d125469856..b92c6ddd25 100644 --- a/acme/utils/lp_utils_test.py +++ b/acme/utils/lp_utils_test.py @@ -14,39 +14,37 @@ """Tests for acme launchpad utilities.""" -from acme.utils import lp_utils - from absl.testing import absltest +from acme.utils import lp_utils -class LpUtilsTest(absltest.TestCase): - - def test_partial_kwargs(self): - def foo(a, b, c=2): - return a, b, c +class LpUtilsTest(absltest.TestCase): + def test_partial_kwargs(self): + def foo(a, b, c=2): + return a, b, c - def bar(a, b): - return a, b + def bar(a, b): + return a, b - # Override the default values. The last two should be no-ops. - foo1 = lp_utils.partial_kwargs(foo, c=1) - foo2 = lp_utils.partial_kwargs(foo) - bar1 = lp_utils.partial_kwargs(bar) + # Override the default values. The last two should be no-ops. + foo1 = lp_utils.partial_kwargs(foo, c=1) + foo2 = lp_utils.partial_kwargs(foo) + bar1 = lp_utils.partial_kwargs(bar) - # Check that we raise errors on overriding kwargs with no default values - with self.assertRaises(ValueError): - lp_utils.partial_kwargs(foo, a=2) + # Check that we raise errors on overriding kwargs with no default values + with self.assertRaises(ValueError): + lp_utils.partial_kwargs(foo, a=2) - # CHeck the we raise if we try to override a kwarg that doesn't exist. - with self.assertRaises(ValueError): - lp_utils.partial_kwargs(foo, d=2) + # CHeck the we raise if we try to override a kwarg that doesn't exist. + with self.assertRaises(ValueError): + lp_utils.partial_kwargs(foo, d=2) - # Make sure we get back the correct values. 
- self.assertEqual(foo1(1, 2), (1, 2, 1)) - self.assertEqual(foo2(1, 2), (1, 2, 2)) - self.assertEqual(bar1(1, 2), (1, 2)) + # Make sure we get back the correct values. + self.assertEqual(foo1(1, 2), (1, 2, 1)) + self.assertEqual(foo2(1, 2), (1, 2, 2)) + self.assertEqual(bar1(1, 2), (1, 2)) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/utils/metrics.py b/acme/utils/metrics.py index 5a1991f820..877f737839 100644 --- a/acme/utils/metrics.py +++ b/acme/utils/metrics.py @@ -16,8 +16,8 @@ from typing import Type, TypeVar -T = TypeVar('T') +T = TypeVar("T") def record_class_usage(cls: Type[T]) -> Type[T]: - return cls + return cls diff --git a/acme/utils/observers/__init__.py b/acme/utils/observers/__init__.py index 093853d086..5016e406f9 100644 --- a/acme/utils/observers/__init__.py +++ b/acme/utils/observers/__init__.py @@ -16,7 +16,6 @@ from acme.utils.observers.action_metrics import ContinuousActionObserver from acme.utils.observers.action_norm import ActionNormObserver -from acme.utils.observers.base import EnvLoopObserver -from acme.utils.observers.base import Number +from acme.utils.observers.base import EnvLoopObserver, Number from acme.utils.observers.env_info import EnvInfoObserver from acme.utils.observers.measurement_metrics import MeasurementObserver diff --git a/acme/utils/observers/action_metrics.py b/acme/utils/observers/action_metrics.py index cb5665f1e2..3d1c7646d1 100644 --- a/acme/utils/observers/action_metrics.py +++ b/acme/utils/observers/action_metrics.py @@ -16,51 +16,51 @@ from typing import Dict -from acme.utils.observers import base import dm_env import numpy as np +from acme.utils.observers import base + class ContinuousActionObserver(base.EnvLoopObserver): - """Observer that tracks statstics of continuous actions taken by the agent. + """Observer that tracks statstics of continuous actions taken by the agent. Assumes the action is a np.ndarray, and for each dimension in the action, calculates some useful statistics for a particular episode. 
""" - def __init__(self): - self._actions = None + def __init__(self): + self._actions = None - def observe_first(self, env: dm_env.Environment, - timestep: dm_env.TimeStep) -> None: - """Observes the initial state.""" - self._actions = [] + def observe_first(self, env: dm_env.Environment, timestep: dm_env.TimeStep) -> None: + """Observes the initial state.""" + self._actions = [] - def observe(self, env: dm_env.Environment, timestep: dm_env.TimeStep, - action: np.ndarray) -> None: - """Records one environment step.""" - self._actions.append(action) + def observe( + self, env: dm_env.Environment, timestep: dm_env.TimeStep, action: np.ndarray + ) -> None: + """Records one environment step.""" + self._actions.append(action) - def get_metrics(self) -> Dict[str, base.Number]: - """Returns metrics collected for the current episode.""" - aggregate_metrics = {} - if not self._actions: - return aggregate_metrics + def get_metrics(self) -> Dict[str, base.Number]: + """Returns metrics collected for the current episode.""" + aggregate_metrics = {} + if not self._actions: + return aggregate_metrics - metrics = { - 'action_max': np.max(self._actions, axis=0), - 'action_min': np.min(self._actions, axis=0), - 'action_mean': np.mean(self._actions, axis=0), - 'action_p50': np.percentile(self._actions, q=50., axis=0) - } + metrics = { + "action_max": np.max(self._actions, axis=0), + "action_min": np.min(self._actions, axis=0), + "action_mean": np.mean(self._actions, axis=0), + "action_p50": np.percentile(self._actions, q=50.0, axis=0), + } - for index, sub_action_metric in np.ndenumerate(metrics['action_max']): - aggregate_metrics[f'action{list(index)}_max'] = sub_action_metric - aggregate_metrics[f'action{list(index)}_min'] = metrics['action_min'][ - index] - aggregate_metrics[f'action{list(index)}_mean'] = metrics['action_mean'][ - index] - aggregate_metrics[f'action{list(index)}_p50'] = metrics['action_p50'][ - index] + for index, sub_action_metric in np.ndenumerate(metrics["action_max"]): + aggregate_metrics[f"action{list(index)}_max"] = sub_action_metric + aggregate_metrics[f"action{list(index)}_min"] = metrics["action_min"][index] + aggregate_metrics[f"action{list(index)}_mean"] = metrics["action_mean"][ + index + ] + aggregate_metrics[f"action{list(index)}_p50"] = metrics["action_p50"][index] - return aggregate_metrics + return aggregate_metrics diff --git a/acme/utils/observers/action_metrics_test.py b/acme/utils/observers/action_metrics_test.py index 406e78c5fa..f662875b92 100644 --- a/acme/utils/observers/action_metrics_test.py +++ b/acme/utils/observers/action_metrics_test.py @@ -15,113 +15,115 @@ """Tests for action_metrics_observers.""" -from acme import specs -from acme.testing import fakes -from acme.utils.observers import action_metrics import dm_env import numpy as np - from absl.testing import absltest +from acme import specs +from acme.testing import fakes +from acme.utils.observers import action_metrics + def _make_fake_env() -> dm_env.Environment: - env_spec = specs.EnvironmentSpec( - observations=specs.Array(shape=(10, 5), dtype=np.float32), - actions=specs.BoundedArray( - shape=(1,), dtype=np.float32, minimum=-100., maximum=100.), - rewards=specs.Array(shape=(), dtype=np.float32), - discounts=specs.BoundedArray( - shape=(), dtype=np.float32, minimum=0., maximum=1.), - ) - return fakes.Environment(env_spec, episode_length=10) + env_spec = specs.EnvironmentSpec( + observations=specs.Array(shape=(10, 5), dtype=np.float32), + actions=specs.BoundedArray( + shape=(1,), dtype=np.float32, 
minimum=-100.0, maximum=100.0 + ), + rewards=specs.Array(shape=(), dtype=np.float32), + discounts=specs.BoundedArray( + shape=(), dtype=np.float32, minimum=0.0, maximum=1.0 + ), + ) + return fakes.Environment(env_spec, episode_length=10) + _FAKE_ENV = _make_fake_env() _TIMESTEP = _FAKE_ENV.reset() class ActionMetricsTest(absltest.TestCase): - - def test_observe_nothing(self): - observer = action_metrics.ContinuousActionObserver() - self.assertEqual({}, observer.get_metrics()) - - def test_observe_first(self): - observer = action_metrics.ContinuousActionObserver() - observer.observe_first(env=_FAKE_ENV, timestep=_TIMESTEP) - self.assertEqual({}, observer.get_metrics()) - - def test_observe_single_step(self): - observer = action_metrics.ContinuousActionObserver() - observer.observe_first(env=_FAKE_ENV, timestep=_TIMESTEP) - observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([1])) - self.assertEqual( - { - 'action[0]_max': 1, - 'action[0]_min': 1, - 'action[0]_mean': 1, - 'action[0]_p50': 1, - }, - observer.get_metrics(), - ) - - def test_observe_multiple_step(self): - observer = action_metrics.ContinuousActionObserver() - observer.observe_first(env=_FAKE_ENV, timestep=_TIMESTEP) - observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([1])) - observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([4])) - observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([5])) - self.assertEqual( - { - 'action[0]_max': 5, - 'action[0]_min': 1, - 'action[0]_mean': 10 / 3, - 'action[0]_p50': 4, - }, - observer.get_metrics(), - ) - - def test_observe_zero_dimensions(self): - observer = action_metrics.ContinuousActionObserver() - observer.observe_first(env=_FAKE_ENV, timestep=_TIMESTEP) - observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array(1)) - self.assertEqual( - { - 'action[]_max': 1, - 'action[]_min': 1, - 'action[]_mean': 1, - 'action[]_p50': 1, - }, - observer.get_metrics(), - ) - - def test_observe_multiple_dimensions(self): - observer = action_metrics.ContinuousActionObserver() - observer.observe_first(env=_FAKE_ENV, timestep=_TIMESTEP) - observer.observe( - env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([[1, 2], [3, 4]])) - np.testing.assert_equal( - { - 'action[0, 0]_max': 1, - 'action[0, 0]_min': 1, - 'action[0, 0]_mean': 1, - 'action[0, 0]_p50': 1, - 'action[0, 1]_max': 2, - 'action[0, 1]_min': 2, - 'action[0, 1]_mean': 2, - 'action[0, 1]_p50': 2, - 'action[1, 0]_max': 3, - 'action[1, 0]_min': 3, - 'action[1, 0]_mean': 3, - 'action[1, 0]_p50': 3, - 'action[1, 1]_max': 4, - 'action[1, 1]_min': 4, - 'action[1, 1]_mean': 4, - 'action[1, 1]_p50': 4, - }, - observer.get_metrics(), - ) - - -if __name__ == '__main__': - absltest.main() - + def test_observe_nothing(self): + observer = action_metrics.ContinuousActionObserver() + self.assertEqual({}, observer.get_metrics()) + + def test_observe_first(self): + observer = action_metrics.ContinuousActionObserver() + observer.observe_first(env=_FAKE_ENV, timestep=_TIMESTEP) + self.assertEqual({}, observer.get_metrics()) + + def test_observe_single_step(self): + observer = action_metrics.ContinuousActionObserver() + observer.observe_first(env=_FAKE_ENV, timestep=_TIMESTEP) + observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([1])) + self.assertEqual( + { + "action[0]_max": 1, + "action[0]_min": 1, + "action[0]_mean": 1, + "action[0]_p50": 1, + }, + observer.get_metrics(), + ) + + def test_observe_multiple_step(self): + observer = 
action_metrics.ContinuousActionObserver() + observer.observe_first(env=_FAKE_ENV, timestep=_TIMESTEP) + observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([1])) + observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([4])) + observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([5])) + self.assertEqual( + { + "action[0]_max": 5, + "action[0]_min": 1, + "action[0]_mean": 10 / 3, + "action[0]_p50": 4, + }, + observer.get_metrics(), + ) + + def test_observe_zero_dimensions(self): + observer = action_metrics.ContinuousActionObserver() + observer.observe_first(env=_FAKE_ENV, timestep=_TIMESTEP) + observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array(1)) + self.assertEqual( + { + "action[]_max": 1, + "action[]_min": 1, + "action[]_mean": 1, + "action[]_p50": 1, + }, + observer.get_metrics(), + ) + + def test_observe_multiple_dimensions(self): + observer = action_metrics.ContinuousActionObserver() + observer.observe_first(env=_FAKE_ENV, timestep=_TIMESTEP) + observer.observe( + env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([[1, 2], [3, 4]]) + ) + np.testing.assert_equal( + { + "action[0, 0]_max": 1, + "action[0, 0]_min": 1, + "action[0, 0]_mean": 1, + "action[0, 0]_p50": 1, + "action[0, 1]_max": 2, + "action[0, 1]_min": 2, + "action[0, 1]_mean": 2, + "action[0, 1]_p50": 2, + "action[1, 0]_max": 3, + "action[1, 0]_min": 3, + "action[1, 0]_mean": 3, + "action[1, 0]_p50": 3, + "action[1, 1]_max": 4, + "action[1, 1]_min": 4, + "action[1, 1]_mean": 4, + "action[1, 1]_p50": 4, + }, + observer.get_metrics(), + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/utils/observers/action_norm.py b/acme/utils/observers/action_norm.py index ed20aaafc9..95e35a8806 100644 --- a/acme/utils/observers/action_norm.py +++ b/acme/utils/observers/action_norm.py @@ -16,29 +16,32 @@ """ from typing import Dict -from acme.utils.observers import base import dm_env import numpy as np +from acme.utils.observers import base + class ActionNormObserver(base.EnvLoopObserver): - """An observer that collects action norm stats.""" - - def __init__(self): - self._action_norms = None - - def observe_first(self, env: dm_env.Environment, timestep: dm_env.TimeStep - ) -> None: - """Observes the initial state.""" - self._action_norms = [] - - def observe(self, env: dm_env.Environment, timestep: dm_env.TimeStep, - action: np.ndarray) -> None: - """Records one environment step.""" - self._action_norms.append(np.linalg.norm(action)) - - def get_metrics(self) -> Dict[str, base.Number]: - """Returns metrics collected for the current episode.""" - return {'action_norm_avg': np.mean(self._action_norms), - 'action_norm_min': np.min(self._action_norms), - 'action_norm_max': np.max(self._action_norms)} + """An observer that collects action norm stats.""" + + def __init__(self): + self._action_norms = None + + def observe_first(self, env: dm_env.Environment, timestep: dm_env.TimeStep) -> None: + """Observes the initial state.""" + self._action_norms = [] + + def observe( + self, env: dm_env.Environment, timestep: dm_env.TimeStep, action: np.ndarray + ) -> None: + """Records one environment step.""" + self._action_norms.append(np.linalg.norm(action)) + + def get_metrics(self) -> Dict[str, base.Number]: + """Returns metrics collected for the current episode.""" + return { + "action_norm_avg": np.mean(self._action_norms), + "action_norm_min": np.min(self._action_norms), + "action_norm_max": np.max(self._action_norms), + } diff --git 
a/acme/utils/observers/action_norm_test.py b/acme/utils/observers/action_norm_test.py index d6732f247a..1d743ae560 100644 --- a/acme/utils/observers/action_norm_test.py +++ b/acme/utils/observers/action_norm_test.py @@ -14,44 +14,45 @@ """Tests for acme.utils.observers.action_norm.""" -from acme import specs -from acme.testing import fakes -from acme.utils.observers import action_norm import dm_env import numpy as np - from absl.testing import absltest +from acme import specs +from acme.testing import fakes +from acme.utils.observers import action_norm + def _make_fake_env() -> dm_env.Environment: - env_spec = specs.EnvironmentSpec( - observations=specs.Array(shape=(10, 5), dtype=np.float32), - actions=specs.BoundedArray( - shape=(1,), dtype=np.float32, minimum=-10., maximum=10.), - rewards=specs.Array(shape=(), dtype=np.float32), - discounts=specs.BoundedArray( - shape=(), dtype=np.float32, minimum=0., maximum=1.), - ) - return fakes.Environment(env_spec, episode_length=10) + env_spec = specs.EnvironmentSpec( + observations=specs.Array(shape=(10, 5), dtype=np.float32), + actions=specs.BoundedArray( + shape=(1,), dtype=np.float32, minimum=-10.0, maximum=10.0 + ), + rewards=specs.Array(shape=(), dtype=np.float32), + discounts=specs.BoundedArray( + shape=(), dtype=np.float32, minimum=0.0, maximum=1.0 + ), + ) + return fakes.Environment(env_spec, episode_length=10) class ActionNormTest(absltest.TestCase): - - def test_basic(self): - env = _make_fake_env() - observer = action_norm.ActionNormObserver() - timestep = env.reset() - observer.observe_first(env, timestep) - for it in range(5): - action = np.ones((1,), dtype=np.float32) * it - timestep = env.step(action) - observer.observe(env, timestep, action) - metrics = observer.get_metrics() - self.assertLen(metrics, 3) - np.testing.assert_equal(metrics['action_norm_min'], 0) - np.testing.assert_equal(metrics['action_norm_max'], 4) - np.testing.assert_equal(metrics['action_norm_avg'], 2) - - -if __name__ == '__main__': - absltest.main() + def test_basic(self): + env = _make_fake_env() + observer = action_norm.ActionNormObserver() + timestep = env.reset() + observer.observe_first(env, timestep) + for it in range(5): + action = np.ones((1,), dtype=np.float32) * it + timestep = env.step(action) + observer.observe(env, timestep, action) + metrics = observer.get_metrics() + self.assertLen(metrics, 3) + np.testing.assert_equal(metrics["action_norm_min"], 0) + np.testing.assert_equal(metrics["action_norm_max"], 4) + np.testing.assert_equal(metrics["action_norm_avg"], 2) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/utils/observers/base.py b/acme/utils/observers/base.py index 3e85a71b37..f64e5f953f 100644 --- a/acme/utils/observers/base.py +++ b/acme/utils/observers/base.py @@ -20,23 +20,22 @@ import dm_env import numpy as np - Number = Union[int, float] class EnvLoopObserver(abc.ABC): - """An interface for collecting metrics/counters in EnvironmentLoop.""" + """An interface for collecting metrics/counters in EnvironmentLoop.""" - @abc.abstractmethod - def observe_first(self, env: dm_env.Environment, timestep: dm_env.TimeStep - ) -> None: - """Observes the initial state.""" + @abc.abstractmethod + def observe_first(self, env: dm_env.Environment, timestep: dm_env.TimeStep) -> None: + """Observes the initial state.""" - @abc.abstractmethod - def observe(self, env: dm_env.Environment, timestep: dm_env.TimeStep, - action: np.ndarray) -> None: - """Records one environment step.""" + @abc.abstractmethod + def observe( + self, env: 
dm_env.Environment, timestep: dm_env.TimeStep, action: np.ndarray + ) -> None: + """Records one environment step.""" - @abc.abstractmethod - def get_metrics(self) -> Dict[str, Number]: - """Returns metrics collected for the current episode.""" + @abc.abstractmethod + def get_metrics(self) -> Dict[str, Number]: + """Returns metrics collected for the current episode.""" diff --git a/acme/utils/observers/env_info.py b/acme/utils/observers/env_info.py index 5fc77dcab2..bd3410e977 100644 --- a/acme/utils/observers/env_info.py +++ b/acme/utils/observers/env_info.py @@ -16,38 +16,39 @@ """ from typing import Dict -from acme.utils.observers import base import dm_env import numpy as np +from acme.utils.observers import base + class EnvInfoObserver(base.EnvLoopObserver): - """An observer that collects and accumulates scalars from env's info.""" - - def __init__(self): - self._metrics = None - - def _accumulate_metrics(self, env: dm_env.Environment) -> None: - if not hasattr(env, 'get_info'): - return - info = getattr(env, 'get_info')() - if not info: - return - for k, v in info.items(): - if np.isscalar(v): - self._metrics[k] = self._metrics.get(k, 0) + v - - def observe_first(self, env: dm_env.Environment, timestep: dm_env.TimeStep - ) -> None: - """Observes the initial state.""" - self._metrics = {} - self._accumulate_metrics(env) - - def observe(self, env: dm_env.Environment, timestep: dm_env.TimeStep, - action: np.ndarray) -> None: - """Records one environment step.""" - self._accumulate_metrics(env) - - def get_metrics(self) -> Dict[str, base.Number]: - """Returns metrics collected for the current episode.""" - return self._metrics + """An observer that collects and accumulates scalars from env's info.""" + + def __init__(self): + self._metrics = None + + def _accumulate_metrics(self, env: dm_env.Environment) -> None: + if not hasattr(env, "get_info"): + return + info = getattr(env, "get_info")() + if not info: + return + for k, v in info.items(): + if np.isscalar(v): + self._metrics[k] = self._metrics.get(k, 0) + v + + def observe_first(self, env: dm_env.Environment, timestep: dm_env.TimeStep) -> None: + """Observes the initial state.""" + self._metrics = {} + self._accumulate_metrics(env) + + def observe( + self, env: dm_env.Environment, timestep: dm_env.TimeStep, action: np.ndarray + ) -> None: + """Records one environment step.""" + self._accumulate_metrics(env) + + def get_metrics(self) -> Dict[str, base.Number]: + """Returns metrics collected for the current episode.""" + return self._metrics diff --git a/acme/utils/observers/env_info_test.py b/acme/utils/observers/env_info_test.py index f8baabd3a4..22aeb65741 100644 --- a/acme/utils/observers/env_info_test.py +++ b/acme/utils/observers/env_info_test.py @@ -14,56 +14,54 @@ """Tests for acme.utils.observers.env_info.""" -from acme.utils.observers import env_info -from acme.wrappers import gym_wrapper import gym -from gym import spaces import numpy as np - from absl.testing import absltest +from gym import spaces +from acme.utils.observers import env_info +from acme.wrappers import gym_wrapper -class GymEnvWithInfo(gym.Env): - def __init__(self): - obs_space = np.ones((10,)) - self.observation_space = spaces.Box(-obs_space, obs_space, dtype=np.float32) - act_space = np.ones((3,)) - self.action_space = spaces.Box(-act_space, act_space, dtype=np.float32) - self._step = 0 +class GymEnvWithInfo(gym.Env): + def __init__(self): + obs_space = np.ones((10,)) + self.observation_space = spaces.Box(-obs_space, obs_space, dtype=np.float32) + act_space 
= np.ones((3,)) + self.action_space = spaces.Box(-act_space, act_space, dtype=np.float32) + self._step = 0 - def reset(self): - self._step = 0 - return self.observation_space.sample() + def reset(self): + self._step = 0 + return self.observation_space.sample() - def step(self, action: np.ndarray): - self._step += 1 - info = {'survival_bonus': 1} - if self._step == 1 or self._step == 7: - info['found_checkpoint'] = 1 - if self._step == 5: - info['picked_up_an_apple'] = 1 - return self.observation_space.sample(), 0, False, info + def step(self, action: np.ndarray): + self._step += 1 + info = {"survival_bonus": 1} + if self._step == 1 or self._step == 7: + info["found_checkpoint"] = 1 + if self._step == 5: + info["picked_up_an_apple"] = 1 + return self.observation_space.sample(), 0, False, info class ActionNormTest(absltest.TestCase): - - def test_basic(self): - env = GymEnvWithInfo() - env = gym_wrapper.GymWrapper(env) - observer = env_info.EnvInfoObserver() - timestep = env.reset() - observer.observe_first(env, timestep) - for _ in range(20): - action = np.zeros((3,)) - timestep = env.step(action) - observer.observe(env, timestep, action) - metrics = observer.get_metrics() - self.assertLen(metrics, 3) - np.testing.assert_equal(metrics['found_checkpoint'], 2) - np.testing.assert_equal(metrics['picked_up_an_apple'], 1) - np.testing.assert_equal(metrics['survival_bonus'], 20) + def test_basic(self): + env = GymEnvWithInfo() + env = gym_wrapper.GymWrapper(env) + observer = env_info.EnvInfoObserver() + timestep = env.reset() + observer.observe_first(env, timestep) + for _ in range(20): + action = np.zeros((3,)) + timestep = env.step(action) + observer.observe(env, timestep, action) + metrics = observer.get_metrics() + self.assertLen(metrics, 3) + np.testing.assert_equal(metrics["found_checkpoint"], 2) + np.testing.assert_equal(metrics["picked_up_an_apple"], 1) + np.testing.assert_equal(metrics["survival_bonus"], 20) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/utils/observers/measurement_metrics.py b/acme/utils/observers/measurement_metrics.py index 7199ec636f..b2a4cca797 100644 --- a/acme/utils/observers/measurement_metrics.py +++ b/acme/utils/observers/measurement_metrics.py @@ -14,15 +14,16 @@ """An observer that tracks statistics about the observations.""" -from typing import Mapping, List +from typing import List, Mapping -from acme.utils.observers import base import dm_env import numpy as np +from acme.utils.observers import base + class MeasurementObserver(base.EnvLoopObserver): - """Observer the provides statistics for measurements at every timestep. + """Observer the provides statistics for measurements at every timestep. This assumes the measurements is a multidimensional array with a static spec. Warning! It is not intended to be used for high dimensional observations. 
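Note: the observer modules in this diff all implement the same three-hook EnvLoopObserver protocol (observe_first, observe, get_metrics). The sketch below shows how a driver exercises those hooks over one episode, reusing the fake-environment pattern from the tests above; the hand-rolled loop is illustrative and only stands in for the real environment loop.

```python
# Sketch of the EnvLoopObserver protocol over a single episode. The fake
# environment mirrors the tests in this diff; the manual loop below is a
# stand-in for what Acme's environment loop does internally.
import numpy as np

from acme import specs
from acme.testing import fakes
from acme.utils.observers import action_norm

env_spec = specs.EnvironmentSpec(
    observations=specs.Array(shape=(10, 5), dtype=np.float32),
    actions=specs.BoundedArray(shape=(1,), dtype=np.float32, minimum=-1.0, maximum=1.0),
    rewards=specs.Array(shape=(), dtype=np.float32),
    discounts=specs.BoundedArray(shape=(), dtype=np.float32, minimum=0.0, maximum=1.0),
)
env = fakes.Environment(env_spec, episode_length=5)

observer = action_norm.ActionNormObserver()
timestep = env.reset()
observer.observe_first(env, timestep)        # start-of-episode hook
while not timestep.last():
    action = np.zeros((1,), dtype=np.float32)
    timestep = env.step(action)
    observer.observe(env, timestep, action)  # per-step hook
print(observer.get_metrics())                # e.g. {'action_norm_avg': 0.0, ...}
```

In practice these observers are normally handed to Acme's EnvironmentLoop, which in the versions this diff appears to target accepts an observers sequence and merges each observer's get_metrics() output into the per-episode results; that wiring is an assumption based on the wider codebase and is not part of the hunks shown here.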
@@ -30,45 +31,52 @@ class MeasurementObserver(base.EnvLoopObserver): self._measurements: List[np.ndarray] """ - def __init__(self): - self._measurements = [] + def __init__(self): + self._measurements = [] - def observe_first(self, env: dm_env.Environment, - timestep: dm_env.TimeStep) -> None: - """Observes the initial state.""" - self._measurements = [] + def observe_first(self, env: dm_env.Environment, timestep: dm_env.TimeStep) -> None: + """Observes the initial state.""" + self._measurements = [] - def observe(self, env: dm_env.Environment, timestep: dm_env.TimeStep, - action: np.ndarray) -> None: - """Records one environment step.""" - self._measurements.append(timestep.observation) + def observe( + self, env: dm_env.Environment, timestep: dm_env.TimeStep, action: np.ndarray + ) -> None: + """Records one environment step.""" + self._measurements.append(timestep.observation) - def get_metrics(self) -> Mapping[str, List[base.Number]]: # pytype: disable=signature-mismatch # overriding-return-type-checks - """Returns metrics collected for the current episode.""" - aggregate_metrics = {} - if not self._measurements: - return aggregate_metrics + def get_metrics( + self, + ) -> Mapping[ + str, List[base.Number] + ]: # pytype: disable=signature-mismatch # overriding-return-type-checks + """Returns metrics collected for the current episode.""" + aggregate_metrics = {} + if not self._measurements: + return aggregate_metrics - metrics = { - 'measurement_max': np.max(self._measurements, axis=0), - 'measurement_min': np.min(self._measurements, axis=0), - 'measurement_mean': np.mean(self._measurements, axis=0), - 'measurement_p25': np.percentile(self._measurements, q=25., axis=0), - 'measurement_p50': np.percentile(self._measurements, q=50., axis=0), - 'measurement_p75': np.percentile(self._measurements, q=75., axis=0), - } - for index, sub_observation_metric in np.ndenumerate( - metrics['measurement_max']): - aggregate_metrics[ - f'measurement{list(index)}_max'] = sub_observation_metric - aggregate_metrics[f'measurement{list(index)}_min'] = metrics[ - 'measurement_min'][index] - aggregate_metrics[f'measurement{list(index)}_mean'] = metrics[ - 'measurement_mean'][index] - aggregate_metrics[f'measurement{list(index)}_p50'] = metrics[ - 'measurement_p50'][index] - aggregate_metrics[f'measurement{list(index)}_p25'] = metrics[ - 'measurement_p25'][index] - aggregate_metrics[f'measurement{list(index)}_p75'] = metrics[ - 'measurement_p75'][index] - return aggregate_metrics + metrics = { + "measurement_max": np.max(self._measurements, axis=0), + "measurement_min": np.min(self._measurements, axis=0), + "measurement_mean": np.mean(self._measurements, axis=0), + "measurement_p25": np.percentile(self._measurements, q=25.0, axis=0), + "measurement_p50": np.percentile(self._measurements, q=50.0, axis=0), + "measurement_p75": np.percentile(self._measurements, q=75.0, axis=0), + } + for index, sub_observation_metric in np.ndenumerate(metrics["measurement_max"]): + aggregate_metrics[f"measurement{list(index)}_max"] = sub_observation_metric + aggregate_metrics[f"measurement{list(index)}_min"] = metrics[ + "measurement_min" + ][index] + aggregate_metrics[f"measurement{list(index)}_mean"] = metrics[ + "measurement_mean" + ][index] + aggregate_metrics[f"measurement{list(index)}_p50"] = metrics[ + "measurement_p50" + ][index] + aggregate_metrics[f"measurement{list(index)}_p25"] = metrics[ + "measurement_p25" + ][index] + aggregate_metrics[f"measurement{list(index)}_p75"] = metrics[ + "measurement_p75" + ][index] + 
return aggregate_metrics diff --git a/acme/utils/observers/measurement_metrics_test.py b/acme/utils/observers/measurement_metrics_test.py index 31c97d37a5..8ce23a71f1 100644 --- a/acme/utils/observers/measurement_metrics_test.py +++ b/acme/utils/observers/measurement_metrics_test.py @@ -17,25 +17,27 @@ import copy from unittest import mock -from acme import specs -from acme.testing import fakes -from acme.utils.observers import measurement_metrics import dm_env import numpy as np - from absl.testing import absltest +from acme import specs +from acme.testing import fakes +from acme.utils.observers import measurement_metrics + def _make_fake_env() -> dm_env.Environment: - env_spec = specs.EnvironmentSpec( - observations=specs.Array(shape=(10, 5), dtype=np.float32), - actions=specs.BoundedArray( - shape=(1,), dtype=np.float32, minimum=-100., maximum=100.), - rewards=specs.Array(shape=(), dtype=np.float32), - discounts=specs.BoundedArray( - shape=(), dtype=np.float32, minimum=0., maximum=1.), - ) - return fakes.Environment(env_spec, episode_length=10) + env_spec = specs.EnvironmentSpec( + observations=specs.Array(shape=(10, 5), dtype=np.float32), + actions=specs.BoundedArray( + shape=(1,), dtype=np.float32, minimum=-100.0, maximum=100.0 + ), + rewards=specs.Array(shape=(), dtype=np.float32), + discounts=specs.BoundedArray( + shape=(), dtype=np.float32, minimum=0.0, maximum=1.0 + ), + ) + return fakes.Environment(env_spec, episode_length=10) _FAKE_ENV = _make_fake_env() @@ -45,128 +47,130 @@ def _make_fake_env() -> dm_env.Environment: class MeasurementMetricsTest(absltest.TestCase): - - def test_observe_nothing(self): - observer = measurement_metrics.MeasurementObserver() - self.assertEqual({}, observer.get_metrics()) - - def test_observe_first(self): - observer = measurement_metrics.MeasurementObserver() - observer.observe_first(env=_FAKE_ENV, timestep=_TIMESTEP) - self.assertEqual({}, observer.get_metrics()) - - def test_observe_single_step(self): - observer = measurement_metrics.MeasurementObserver() - observer.observe_first(env=_FAKE_ENV, timestep=_TIMESTEP) - observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([1])) - self.assertEqual( - { - 'measurement[0]_max': 1.0, - 'measurement[0]_mean': 1.0, - 'measurement[0]_p25': 1.0, - 'measurement[0]_p50': 1.0, - 'measurement[0]_p75': 1.0, - 'measurement[1]_max': -2.0, - 'measurement[1]_mean': -2.0, - 'measurement[1]_p25': -2.0, - 'measurement[1]_p50': -2.0, - 'measurement[1]_p75': -2.0, - 'measurement[0]_min': 1.0, - 'measurement[1]_min': -2.0, - }, - observer.get_metrics(), - ) - - def test_observe_multiple_step_same_observation(self): - observer = measurement_metrics.MeasurementObserver() - observer.observe_first(env=_FAKE_ENV, timestep=_TIMESTEP) - observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([1])) - observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([4])) - observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([5])) - self.assertEqual( - { - 'measurement[0]_max': 1.0, - 'measurement[0]_mean': 1.0, - 'measurement[0]_p25': 1.0, - 'measurement[0]_p50': 1.0, - 'measurement[0]_p75': 1.0, - 'measurement[1]_max': -2.0, - 'measurement[1]_mean': -2.0, - 'measurement[1]_p25': -2.0, - 'measurement[1]_p50': -2.0, - 'measurement[1]_p75': -2.0, - 'measurement[0]_min': 1.0, - 'measurement[1]_min': -2.0, - }, - observer.get_metrics(), - ) - - def test_observe_multiple_step(self): - observer = measurement_metrics.MeasurementObserver() - observer.observe_first(env=_FAKE_ENV, 
timestep=_TIMESTEP) - observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([1])) - first_obs_timestep = copy.deepcopy(_TIMESTEP) - first_obs_timestep.observation = [1000.0, -50.0] - observer.observe( - env=_FAKE_ENV, timestep=first_obs_timestep, action=np.array([4])) - second_obs_timestep = copy.deepcopy(_TIMESTEP) - second_obs_timestep.observation = [-1000.0, 500.0] - observer.observe( - env=_FAKE_ENV, timestep=second_obs_timestep, action=np.array([4])) - self.assertEqual( - { - 'measurement[0]_max': 1000.0, - 'measurement[0]_mean': 1.0/3, - 'measurement[0]_p25': -499.5, - 'measurement[0]_p50': 1.0, - 'measurement[0]_p75': 500.5, - 'measurement[1]_max': 500.0, - 'measurement[1]_mean': 448.0/3.0, - 'measurement[1]_p25': -26.0, - 'measurement[1]_p50': -2.0, - 'measurement[1]_p75': 249.0, - 'measurement[0]_min': -1000.0, - 'measurement[1]_min': -50.0, - }, - observer.get_metrics(), - ) - - def test_observe_empty_observation(self): - observer = measurement_metrics.MeasurementObserver() - empty_timestep = copy.deepcopy(_TIMESTEP) - empty_timestep.observation = {} - observer.observe_first(env=_FAKE_ENV, timestep=empty_timestep) - self.assertEqual({}, observer.get_metrics()) - - def test_observe_single_dimensions(self): - observer = measurement_metrics.MeasurementObserver() - observer.observe_first(env=_FAKE_ENV, timestep=_TIMESTEP) - single_obs_timestep = copy.deepcopy(_TIMESTEP) - single_obs_timestep.observation = [1000.0, -50.0] - - observer.observe( - env=_FAKE_ENV, - timestep=single_obs_timestep, - action=np.array([[1, 2], [3, 4]])) - - np.testing.assert_equal( - { - 'measurement[0]_max': 1000.0, - 'measurement[0]_min': 1000.0, - 'measurement[0]_mean': 1000.0, - 'measurement[0]_p25': 1000.0, - 'measurement[0]_p50': 1000.0, - 'measurement[0]_p75': 1000.0, - 'measurement[1]_max': -50.0, - 'measurement[1]_mean': -50.0, - 'measurement[1]_p25': -50.0, - 'measurement[1]_p50': -50.0, - 'measurement[1]_p75': -50.0, - 'measurement[1]_min': -50.0, - }, - observer.get_metrics(), - ) - - -if __name__ == '__main__': - absltest.main() + def test_observe_nothing(self): + observer = measurement_metrics.MeasurementObserver() + self.assertEqual({}, observer.get_metrics()) + + def test_observe_first(self): + observer = measurement_metrics.MeasurementObserver() + observer.observe_first(env=_FAKE_ENV, timestep=_TIMESTEP) + self.assertEqual({}, observer.get_metrics()) + + def test_observe_single_step(self): + observer = measurement_metrics.MeasurementObserver() + observer.observe_first(env=_FAKE_ENV, timestep=_TIMESTEP) + observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([1])) + self.assertEqual( + { + "measurement[0]_max": 1.0, + "measurement[0]_mean": 1.0, + "measurement[0]_p25": 1.0, + "measurement[0]_p50": 1.0, + "measurement[0]_p75": 1.0, + "measurement[1]_max": -2.0, + "measurement[1]_mean": -2.0, + "measurement[1]_p25": -2.0, + "measurement[1]_p50": -2.0, + "measurement[1]_p75": -2.0, + "measurement[0]_min": 1.0, + "measurement[1]_min": -2.0, + }, + observer.get_metrics(), + ) + + def test_observe_multiple_step_same_observation(self): + observer = measurement_metrics.MeasurementObserver() + observer.observe_first(env=_FAKE_ENV, timestep=_TIMESTEP) + observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([1])) + observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([4])) + observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([5])) + self.assertEqual( + { + "measurement[0]_max": 1.0, + "measurement[0]_mean": 1.0, + 
"measurement[0]_p25": 1.0, + "measurement[0]_p50": 1.0, + "measurement[0]_p75": 1.0, + "measurement[1]_max": -2.0, + "measurement[1]_mean": -2.0, + "measurement[1]_p25": -2.0, + "measurement[1]_p50": -2.0, + "measurement[1]_p75": -2.0, + "measurement[0]_min": 1.0, + "measurement[1]_min": -2.0, + }, + observer.get_metrics(), + ) + + def test_observe_multiple_step(self): + observer = measurement_metrics.MeasurementObserver() + observer.observe_first(env=_FAKE_ENV, timestep=_TIMESTEP) + observer.observe(env=_FAKE_ENV, timestep=_TIMESTEP, action=np.array([1])) + first_obs_timestep = copy.deepcopy(_TIMESTEP) + first_obs_timestep.observation = [1000.0, -50.0] + observer.observe( + env=_FAKE_ENV, timestep=first_obs_timestep, action=np.array([4]) + ) + second_obs_timestep = copy.deepcopy(_TIMESTEP) + second_obs_timestep.observation = [-1000.0, 500.0] + observer.observe( + env=_FAKE_ENV, timestep=second_obs_timestep, action=np.array([4]) + ) + self.assertEqual( + { + "measurement[0]_max": 1000.0, + "measurement[0]_mean": 1.0 / 3, + "measurement[0]_p25": -499.5, + "measurement[0]_p50": 1.0, + "measurement[0]_p75": 500.5, + "measurement[1]_max": 500.0, + "measurement[1]_mean": 448.0 / 3.0, + "measurement[1]_p25": -26.0, + "measurement[1]_p50": -2.0, + "measurement[1]_p75": 249.0, + "measurement[0]_min": -1000.0, + "measurement[1]_min": -50.0, + }, + observer.get_metrics(), + ) + + def test_observe_empty_observation(self): + observer = measurement_metrics.MeasurementObserver() + empty_timestep = copy.deepcopy(_TIMESTEP) + empty_timestep.observation = {} + observer.observe_first(env=_FAKE_ENV, timestep=empty_timestep) + self.assertEqual({}, observer.get_metrics()) + + def test_observe_single_dimensions(self): + observer = measurement_metrics.MeasurementObserver() + observer.observe_first(env=_FAKE_ENV, timestep=_TIMESTEP) + single_obs_timestep = copy.deepcopy(_TIMESTEP) + single_obs_timestep.observation = [1000.0, -50.0] + + observer.observe( + env=_FAKE_ENV, + timestep=single_obs_timestep, + action=np.array([[1, 2], [3, 4]]), + ) + + np.testing.assert_equal( + { + "measurement[0]_max": 1000.0, + "measurement[0]_min": 1000.0, + "measurement[0]_mean": 1000.0, + "measurement[0]_p25": 1000.0, + "measurement[0]_p50": 1000.0, + "measurement[0]_p75": 1000.0, + "measurement[1]_max": -50.0, + "measurement[1]_mean": -50.0, + "measurement[1]_p25": -50.0, + "measurement[1]_p50": -50.0, + "measurement[1]_p75": -50.0, + "measurement[1]_min": -50.0, + }, + observer.get_metrics(), + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/utils/paths.py b/acme/utils/paths.py index 1aa3ea6c1f..a80438a9fa 100644 --- a/acme/utils/paths.py +++ b/acme/utils/paths.py @@ -22,16 +22,17 @@ from absl import flags -ACME_ID = flags.DEFINE_string('acme_id', None, - 'Experiment identifier to use for Acme.') +ACME_ID = flags.DEFINE_string("acme_id", None, "Experiment identifier to use for Acme.") -def process_path(path: str, - *subpaths: str, - ttl_seconds: Optional[int] = None, - backups: Optional[bool] = None, - add_uid: bool = True) -> str: - """Process the path string. +def process_path( + path: str, + *subpaths: str, + ttl_seconds: Optional[int] = None, + backups: Optional[bool] = None, + add_uid: bool = True +) -> str: + """Process the path string. This will process the path string by running `os.path.expanduser` to replace any initial "~". It will also append a unique string on the end of the path @@ -49,35 +50,35 @@ def process_path(path: str, Returns: the processed, expanded path string. 
""" - del backups, ttl_seconds + del backups, ttl_seconds - path = os.path.expanduser(path) - if add_uid: - path = os.path.join(path, *get_unique_id()) - path = os.path.join(path, *subpaths) - os.makedirs(path, exist_ok=True) - return path + path = os.path.expanduser(path) + if add_uid: + path = os.path.join(path, *get_unique_id()) + path = os.path.join(path, *subpaths) + os.makedirs(path, exist_ok=True) + return path -_DATETIME = time.strftime('%Y%m%d-%H%M%S') +_DATETIME = time.strftime("%Y%m%d-%H%M%S") def get_unique_id() -> Tuple[str, ...]: - """Makes a unique identifier for this process; override with --acme_id.""" - # By default we'll use the global id. - identifier = _DATETIME + """Makes a unique identifier for this process; override with --acme_id.""" + # By default we'll use the global id. + identifier = _DATETIME - # If the --acme_id flag is given prefer that; ignore if flag processing has - # been skipped (this happens in colab or in tests). - try: - identifier = ACME_ID.value or identifier - except flags.UnparsedFlagAccessError: - pass + # If the --acme_id flag is given prefer that; ignore if flag processing has + # been skipped (this happens in colab or in tests). + try: + identifier = ACME_ID.value or identifier + except flags.UnparsedFlagAccessError: + pass - # Return as a tuple (for future proofing). - return (identifier,) + # Return as a tuple (for future proofing). + return (identifier,) def rmdir(path: str): - """Remove directory recursively.""" - shutil.rmtree(path) + """Remove directory recursively.""" + shutil.rmtree(path) diff --git a/acme/utils/paths_test.py b/acme/utils/paths_test.py index 6af5d6a1ba..977eefb6a8 100644 --- a/acme/utils/paths_test.py +++ b/acme/utils/paths_test.py @@ -16,26 +16,24 @@ from unittest import mock -from acme.testing import test_utils -import acme.utils.paths as paths +from absl.testing import absltest, flagsaver -from absl.testing import flagsaver -from absl.testing import absltest +import acme.utils.paths as paths +from acme.testing import test_utils class PathTest(test_utils.TestCase): + def test_process_path(self): + root_directory = self.get_tempdir() + with mock.patch.object(paths, "get_unique_id") as mock_unique_id: + mock_unique_id.return_value = ("test",) + path = paths.process_path(root_directory, "foo", "bar") + self.assertEqual(path, f"{root_directory}/test/foo/bar") - def test_process_path(self): - root_directory = self.get_tempdir() - with mock.patch.object(paths, 'get_unique_id') as mock_unique_id: - mock_unique_id.return_value = ('test',) - path = paths.process_path(root_directory, 'foo', 'bar') - self.assertEqual(path, f'{root_directory}/test/foo/bar') - - def test_unique_id_with_flag(self): - with flagsaver.flagsaver((paths.ACME_ID, 'test_flag')): - self.assertEqual(paths.get_unique_id(), ('test_flag',)) + def test_unique_id_with_flag(self): + with flagsaver.flagsaver((paths.ACME_ID, "test_flag")): + self.assertEqual(paths.get_unique_id(), ("test_flag",)) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/utils/reverb_utils.py b/acme/utils/reverb_utils.py index 5df39153a7..711cb622ab 100644 --- a/acme/utils/reverb_utils.py +++ b/acme/utils/reverb_utils.py @@ -17,20 +17,18 @@ Contains functions manipulating reverb tables and samples. 
""" -from acme import types import jax import numpy as np import reverb -from reverb import item_selectors -from reverb import rate_limiters -from reverb import reverb_types import tensorflow as tf import tree +from reverb import item_selectors, rate_limiters, reverb_types + +from acme import types -def make_replay_table_from_info( - table_info: reverb_types.TableInfo) -> reverb.Table: - """Build a replay table out of its specs in a TableInfo. +def make_replay_table_from_info(table_info: reverb_types.TableInfo) -> reverb.Table: + """Build a replay table out of its specs in a TableInfo. Args: table_info: A TableInfo containing the Table specs. @@ -38,55 +36,55 @@ def make_replay_table_from_info( Returns: A reverb replay table matching the info specs. """ - sampler = _make_selector_from_key_distribution_options( - table_info.sampler_options) - remover = _make_selector_from_key_distribution_options( - table_info.remover_options) - rate_limiter = _make_rate_limiter_from_rate_limiter_info( - table_info.rate_limiter_info) - return reverb.Table( - name=table_info.name, - sampler=sampler, - remover=remover, - max_size=table_info.max_size, - rate_limiter=rate_limiter, - max_times_sampled=table_info.max_times_sampled, - signature=table_info.signature) - - -def _make_selector_from_key_distribution_options( - options) -> reverb_types.SelectorType: - """Returns a Selector from its KeyDistributionOptions description.""" - one_of = options.WhichOneof('distribution') - if one_of == 'fifo': - return item_selectors.Fifo() - if one_of == 'uniform': - return item_selectors.Uniform() - if one_of == 'prioritized': - return item_selectors.Prioritized(options.prioritized.priority_exponent) - if one_of == 'heap': - if options.heap.min_heap: - return item_selectors.MinHeap() - return item_selectors.MaxHeap() - if one_of == 'lifo': - return item_selectors.Lifo() - raise ValueError(f'Unknown distribution field: {one_of}') - - -def _make_rate_limiter_from_rate_limiter_info( - info) -> rate_limiters.RateLimiter: - return rate_limiters.SampleToInsertRatio( - samples_per_insert=info.samples_per_insert, - min_size_to_sample=info.min_size_to_sample, - error_buffer=(info.min_diff, info.max_diff)) + sampler = _make_selector_from_key_distribution_options(table_info.sampler_options) + remover = _make_selector_from_key_distribution_options(table_info.remover_options) + rate_limiter = _make_rate_limiter_from_rate_limiter_info( + table_info.rate_limiter_info + ) + return reverb.Table( + name=table_info.name, + sampler=sampler, + remover=remover, + max_size=table_info.max_size, + rate_limiter=rate_limiter, + max_times_sampled=table_info.max_times_sampled, + signature=table_info.signature, + ) + + +def _make_selector_from_key_distribution_options(options) -> reverb_types.SelectorType: + """Returns a Selector from its KeyDistributionOptions description.""" + one_of = options.WhichOneof("distribution") + if one_of == "fifo": + return item_selectors.Fifo() + if one_of == "uniform": + return item_selectors.Uniform() + if one_of == "prioritized": + return item_selectors.Prioritized(options.prioritized.priority_exponent) + if one_of == "heap": + if options.heap.min_heap: + return item_selectors.MinHeap() + return item_selectors.MaxHeap() + if one_of == "lifo": + return item_selectors.Lifo() + raise ValueError(f"Unknown distribution field: {one_of}") + + +def _make_rate_limiter_from_rate_limiter_info(info) -> rate_limiters.RateLimiter: + return rate_limiters.SampleToInsertRatio( + samples_per_insert=info.samples_per_insert, + 
min_size_to_sample=info.min_size_to_sample, + error_buffer=(info.min_diff, info.max_diff), + ) def replay_sample_to_sars_transition( sample: reverb.ReplaySample, is_sequence: bool, strip_last_transition: bool = False, - flatten_batch: bool = False) -> types.Transition: - """Converts the replay sample to a types.Transition. + flatten_batch: bool = False, +) -> types.Transition: + """Converts the replay sample to a types.Transition. NB: If is_sequence is True then the last next_observation of each sequence is rubbish. Don't train on it. @@ -107,34 +105,38 @@ def replay_sample_to_sars_transition( smaller than the output as the last transition of every sequence will have been removed. """ - if not is_sequence: - return types.Transition(*sample.data) - # Note that the last next_observation is invalid. - steps = sample.data - def roll(observation): - return np.roll(observation, shift=-1, axis=1) - transitions = types.Transition( - observation=steps.observation, - action=steps.action, - reward=steps.reward, - discount=steps.discount, - next_observation=tree.map_structure(roll, steps.observation), - extras=steps.extras) - if strip_last_transition: - # We remove the last transition as its next_observation field is incorrect. - # It has been obtained by rolling the observation field, such that - # transitions.next_observations[:, -1] is transitions.observations[:, 0] - transitions = jax.tree_map(lambda x: x[:, :-1, ...], transitions) - if flatten_batch: - # Merge the 2 leading batch dimensions into 1. - transitions = jax.tree_map(lambda x: np.reshape(x, (-1,) + x.shape[2:]), - transitions) - return transitions - - -def transition_to_replaysample( - transitions: types.Transition) -> reverb.ReplaySample: - """Converts a types.Transition to a reverb.ReplaySample.""" - info = tree.map_structure(lambda dtype: tf.ones([], dtype), - reverb.SampleInfo.tf_dtypes()) - return reverb.ReplaySample(info=info, data=transitions) + if not is_sequence: + return types.Transition(*sample.data) + # Note that the last next_observation is invalid. + steps = sample.data + + def roll(observation): + return np.roll(observation, shift=-1, axis=1) + + transitions = types.Transition( + observation=steps.observation, + action=steps.action, + reward=steps.reward, + discount=steps.discount, + next_observation=tree.map_structure(roll, steps.observation), + extras=steps.extras, + ) + if strip_last_transition: + # We remove the last transition as its next_observation field is incorrect. + # It has been obtained by rolling the observation field, such that + # transitions.next_observations[:, -1] is transitions.observations[:, 0] + transitions = jax.tree_map(lambda x: x[:, :-1, ...], transitions) + if flatten_batch: + # Merge the 2 leading batch dimensions into 1. 
+ transitions = jax.tree_map( + lambda x: np.reshape(x, (-1,) + x.shape[2:]), transitions + ) + return transitions + + +def transition_to_replaysample(transitions: types.Transition) -> reverb.ReplaySample: + """Converts a types.Transition to a reverb.ReplaySample.""" + info = tree.map_structure( + lambda dtype: tf.ones([], dtype), reverb.SampleInfo.tf_dtypes() + ) + return reverb.ReplaySample(info=info, data=transitions) diff --git a/acme/utils/reverb_utils_test.py b/acme/utils/reverb_utils_test.py index 2c71c52485..1e0ba06576 100644 --- a/acme/utils/reverb_utils_test.py +++ b/acme/utils/reverb_utils_test.py @@ -14,74 +14,81 @@ """Tests for acme.utils.reverb_utils.""" -from acme import types -from acme.adders import reverb as reverb_adders -from acme.utils import reverb_utils import numpy as np import reverb import tree - from absl.testing import absltest +from acme import types +from acme.adders import reverb as reverb_adders +from acme.utils import reverb_utils -class ReverbUtilsTest(absltest.TestCase): - def test_make_replay_table_preserves_table_info(self): - limiter = reverb.rate_limiters.SampleToInsertRatio( - samples_per_insert=1, min_size_to_sample=2, error_buffer=(0, 10)) - table = reverb.Table( - name='test', - sampler=reverb.selectors.Uniform(), - remover=reverb.selectors.Fifo(), - max_size=10, - rate_limiter=limiter) - new_table = reverb_utils.make_replay_table_from_info(table.info) - new_info = new_table.info +class ReverbUtilsTest(absltest.TestCase): + def test_make_replay_table_preserves_table_info(self): + limiter = reverb.rate_limiters.SampleToInsertRatio( + samples_per_insert=1, min_size_to_sample=2, error_buffer=(0, 10) + ) + table = reverb.Table( + name="test", + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=10, + rate_limiter=limiter, + ) + new_table = reverb_utils.make_replay_table_from_info(table.info) + new_info = new_table.info - # table_worker_time is not set by the above utility since this is meant to - # be monitoring information about any given table. So instead we copy this - # so that the assertion below checks that everything else matches. + # table_worker_time is not set by the above utility since this is meant to + # be monitoring information about any given table. So instead we copy this + # so that the assertion below checks that everything else matches. 
- new_info.table_worker_time.sleeping_ms = ( - table.info.table_worker_time.sleeping_ms) + new_info.table_worker_time.sleeping_ms = ( + table.info.table_worker_time.sleeping_ms + ) - self.assertEqual(new_info, table.info) + self.assertEqual(new_info, table.info) - _EMPTY_INFO = reverb.SampleInfo(*[() for _ in reverb.SampleInfo.tf_dtypes()]) - _DUMMY_OBS = np.array([[[0], [1], [2]]]) - _DUMMY_ACTION = np.array([[[3], [4], [5]]]) - _DUMMY_REWARD = np.array([[6, 7, 8]]) - _DUMMY_DISCOUNT = np.array([[.99, .99, .99]]) - _DUMMY_NEXT_OBS = np.array([[[1], [2], [0]]]) - _DUMMY_RETURN = np.array([[20.77, 14.92, 8.]]) + _EMPTY_INFO = reverb.SampleInfo(*[() for _ in reverb.SampleInfo.tf_dtypes()]) + _DUMMY_OBS = np.array([[[0], [1], [2]]]) + _DUMMY_ACTION = np.array([[[3], [4], [5]]]) + _DUMMY_REWARD = np.array([[6, 7, 8]]) + _DUMMY_DISCOUNT = np.array([[0.99, 0.99, 0.99]]) + _DUMMY_NEXT_OBS = np.array([[[1], [2], [0]]]) + _DUMMY_RETURN = np.array([[20.77, 14.92, 8.0]]) - def _create_dummy_steps(self): - return reverb_adders.Step( - observation=self._DUMMY_OBS, - action=self._DUMMY_ACTION, - reward=self._DUMMY_REWARD, - discount=self._DUMMY_DISCOUNT, - start_of_episode=True, - extras={'return': self._DUMMY_RETURN}) + def _create_dummy_steps(self): + return reverb_adders.Step( + observation=self._DUMMY_OBS, + action=self._DUMMY_ACTION, + reward=self._DUMMY_REWARD, + discount=self._DUMMY_DISCOUNT, + start_of_episode=True, + extras={"return": self._DUMMY_RETURN}, + ) - def _create_dummy_transitions(self): - return types.Transition( - observation=self._DUMMY_OBS, - action=self._DUMMY_ACTION, - reward=self._DUMMY_REWARD, - discount=self._DUMMY_DISCOUNT, - next_observation=self._DUMMY_NEXT_OBS, - extras={'return': self._DUMMY_RETURN}) + def _create_dummy_transitions(self): + return types.Transition( + observation=self._DUMMY_OBS, + action=self._DUMMY_ACTION, + reward=self._DUMMY_REWARD, + discount=self._DUMMY_DISCOUNT, + next_observation=self._DUMMY_NEXT_OBS, + extras={"return": self._DUMMY_RETURN}, + ) - def test_replay_sample_to_sars_transition_is_sequence(self): - fake_sample = reverb.ReplaySample( - info=self._EMPTY_INFO, data=self._create_dummy_steps()) - fake_transition = self._create_dummy_transitions() - transition_from_sample = reverb_utils.replay_sample_to_sars_transition( - fake_sample, is_sequence=True) - tree.map_structure(np.testing.assert_array_equal, transition_from_sample, - fake_transition) + def test_replay_sample_to_sars_transition_is_sequence(self): + fake_sample = reverb.ReplaySample( + info=self._EMPTY_INFO, data=self._create_dummy_steps() + ) + fake_transition = self._create_dummy_transitions() + transition_from_sample = reverb_utils.replay_sample_to_sars_transition( + fake_sample, is_sequence=True + ) + tree.map_structure( + np.testing.assert_array_equal, transition_from_sample, fake_transition + ) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/utils/signals.py b/acme/utils/signals.py index 907faac26f..66a38ac5a0 100644 --- a/acme/utils/signals.py +++ b/acme/utils/signals.py @@ -26,7 +26,7 @@ @contextlib.contextmanager def runtime_terminator(callback: Optional[_Handler] = None): - """Runtime terminator used for stopping computation upon agent termination. + """Runtime terminator used for stopping computation upon agent termination. Runtime terminator optionally executed a provided `callback` and then raises `SystemExit` exception in the thread performing the computation. 
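[Editorial aside, not part of the patch] A minimal usage sketch of the context manager reformatted above. It assumes Launchpad is installed and that this code runs inside a Launchpad worker; `_cleanup` and the sleep loop are placeholder stand-ins, not names from the codebase:

    import time

    from acme.utils import signals


    def _cleanup():
        # Hypothetical callback: release resources before the worker thread is unwound.
        print("worker stopping")


    def run_worker():
        # When Launchpad signals program termination, the handler registered by
        # runtime_terminator raises SystemExit in this thread, unwinding the loop.
        with signals.runtime_terminator(_cleanup):
            while True:
                time.sleep(1.0)  # stand-in for the worker's real computation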
@@ -37,13 +37,16 @@ def runtime_terminator(callback: Optional[_Handler] = None): Yields: None. """ - worker_id = threading.get_ident() - def signal_handler(): - if callback: - callback() - res = ctypes.pythonapi.PyThreadState_SetAsyncExc( - ctypes.c_long(worker_id), ctypes.py_object(SystemExit)) - assert res < 2, 'Stopping worker failed' - launchpad.register_stop_handler(signal_handler) - yield - launchpad.unregister_stop_handler(signal_handler) + worker_id = threading.get_ident() + + def signal_handler(): + if callback: + callback() + res = ctypes.pythonapi.PyThreadState_SetAsyncExc( + ctypes.c_long(worker_id), ctypes.py_object(SystemExit) + ) + assert res < 2, "Stopping worker failed" + + launchpad.register_stop_handler(signal_handler) + yield + launchpad.unregister_stop_handler(signal_handler) diff --git a/acme/utils/tree_utils.py b/acme/utils/tree_utils.py index 7e6f72bd68..8096eba771 100644 --- a/acme/utils/tree_utils.py +++ b/acme/utils/tree_utils.py @@ -14,40 +14,40 @@ """Tensor framework-agnostic utilities for manipulating nested structures.""" -from typing import Sequence, List, TypeVar, Any +from typing import Any, List, Sequence, TypeVar import numpy as np import tree -ElementType = TypeVar('ElementType') +ElementType = TypeVar("ElementType") def fast_map_structure(func, *structure): - """Faster map_structure implementation which skips some error checking.""" - flat_structure = (tree.flatten(s) for s in structure) - entries = zip(*flat_structure) - # Arbitrarily choose one of the structures of the original sequence (the last) - # to match the structure for the flattened sequence. - return tree.unflatten_as(structure[-1], [func(*x) for x in entries]) + """Faster map_structure implementation which skips some error checking.""" + flat_structure = (tree.flatten(s) for s in structure) + entries = zip(*flat_structure) + # Arbitrarily choose one of the structures of the original sequence (the last) + # to match the structure for the flattened sequence. + return tree.unflatten_as(structure[-1], [func(*x) for x in entries]) def fast_map_structure_with_path(func, *structure): - """Faster map_structure_with_path implementation.""" - head_entries_with_path = tree.flatten_with_path(structure[0]) - if len(structure) > 1: - tail_entries = (tree.flatten(s) for s in structure[1:]) - entries_with_path = [ - e[0] + e[1:] for e in zip(head_entries_with_path, *tail_entries) - ] - else: - entries_with_path = head_entries_with_path - # Arbitrarily choose one of the structures of the original sequence (the last) - # to match the structure for the flattened sequence. - return tree.unflatten_as(structure[-1], [func(*x) for x in entries_with_path]) + """Faster map_structure_with_path implementation.""" + head_entries_with_path = tree.flatten_with_path(structure[0]) + if len(structure) > 1: + tail_entries = (tree.flatten(s) for s in structure[1:]) + entries_with_path = [ + e[0] + e[1:] for e in zip(head_entries_with_path, *tail_entries) + ] + else: + entries_with_path = head_entries_with_path + # Arbitrarily choose one of the structures of the original sequence (the last) + # to match the structure for the flattened sequence. + return tree.unflatten_as(structure[-1], [func(*x) for x in entries_with_path]) def stack_sequence_fields(sequence: Sequence[ElementType]) -> ElementType: - """Stacks a list of identically nested objects. + """Stacks a list of identically nested objects. 
This takes a sequence of identically nested objects and returns a single nested object whose ith leaf is a stacked numpy array of the corresponding @@ -93,22 +93,22 @@ def stack_sequence_fields(sequence: Sequence[ElementType]) -> ElementType: Raises: ValueError: If `sequence` is an empty sequence. """ - # Handle empty input sequences. - if not sequence: - raise ValueError('Input sequence must not be empty') + # Handle empty input sequences. + if not sequence: + raise ValueError("Input sequence must not be empty") - # Default to asarray when arrays don't have the same shape to be compatible - # with old behaviour. - try: - return fast_map_structure(lambda *values: np.stack(values), *sequence) - except ValueError: - return fast_map_structure(lambda *values: np.asarray(values, dtype=object), - *sequence) + # Default to asarray when arrays don't have the same shape to be compatible + # with old behaviour. + try: + return fast_map_structure(lambda *values: np.stack(values), *sequence) + except ValueError: + return fast_map_structure( + lambda *values: np.asarray(values, dtype=object), *sequence + ) -def unstack_sequence_fields(struct: ElementType, - batch_size: int) -> List[ElementType]: - """Converts a struct of batched arrays to a list of structs. +def unstack_sequence_fields(struct: ElementType, batch_size: int) -> List[ElementType]: + """Converts a struct of batched arrays to a list of structs. This is effectively the inverse of `stack_sequence_fields`. @@ -122,13 +122,11 @@ def unstack_sequence_fields(struct: ElementType, is an unbatched element of the original leaf node. """ - return [ - tree.map_structure(lambda s, i=i: s[i], struct) for i in range(batch_size) - ] + return [tree.map_structure(lambda s, i=i: s[i], struct) for i in range(batch_size)] def broadcast_structures(*args: Any) -> Any: - """Returns versions of the arguments that give them the same nested structure. + """Returns versions of the arguments that give them the same nested structure. Any nested items in *args must have the same structure. @@ -156,37 +154,37 @@ def broadcast_structures(*args: Any) -> Any: Returns: `*args`, except with all items sharing the same nest structure. """ - if not args: - return - - reference_tree = None - for arg in args: - if tree.is_nested(arg): - reference_tree = arg - break - - # If reference_tree is None then none of args are nested and we can skip over - # the rest of this function, which would be a no-op. - if reference_tree is None: - return args - - def mirror_structure(value, reference_tree): - if tree.is_nested(value): - # Use check_types=True so that the types of the trees we construct aren't - # dependent on our arbitrary choice of which nested arg to use as the - # reference_tree. - tree.assert_same_structure(value, reference_tree, check_types=True) - return value - else: - return tree.map_structure(lambda _: value, reference_tree) + if not args: + return + + reference_tree = None + for arg in args: + if tree.is_nested(arg): + reference_tree = arg + break + + # If reference_tree is None then none of args are nested and we can skip over + # the rest of this function, which would be a no-op. + if reference_tree is None: + return args + + def mirror_structure(value, reference_tree): + if tree.is_nested(value): + # Use check_types=True so that the types of the trees we construct aren't + # dependent on our arbitrary choice of which nested arg to use as the + # reference_tree. 
+ tree.assert_same_structure(value, reference_tree, check_types=True) + return value + else: + return tree.map_structure(lambda _: value, reference_tree) - return tuple(mirror_structure(arg, reference_tree) for arg in args) + return tuple(mirror_structure(arg, reference_tree) for arg in args) def tree_map(f): - """Transforms `f` into a tree-mapped version.""" + """Transforms `f` into a tree-mapped version.""" - def mapped_f(*structures): - return tree.map_structure(f, *structures) + def mapped_f(*structures): + return tree.map_structure(f, *structures) - return mapped_f + return mapped_f diff --git a/acme/utils/tree_utils_test.py b/acme/utils/tree_utils_test.py index b24b87b725..96c5943eac 100644 --- a/acme/utils/tree_utils_test.py +++ b/acme/utils/tree_utils_test.py @@ -17,90 +17,90 @@ import functools from typing import Sequence -from acme.utils import tree_utils import numpy as np import tree - from absl.testing import absltest +from acme.utils import tree_utils + TEST_SEQUENCE = [ { - 'action': np.array([1.0]), - 'observation': (np.array([0.0, 1.0, 2.0]),), - 'reward': np.array(1.0), + "action": np.array([1.0]), + "observation": (np.array([0.0, 1.0, 2.0]),), + "reward": np.array(1.0), }, { - 'action': np.array([0.5]), - 'observation': (np.array([1.0, 2.0, 3.0]),), - 'reward': np.array(0.0), + "action": np.array([0.5]), + "observation": (np.array([1.0, 2.0, 3.0]),), + "reward": np.array(0.0), }, { - 'action': np.array([0.3]), - 'observation': (np.array([2.0, 3.0, 4.0]),), - 'reward': np.array(0.5), + "action": np.array([0.3]), + "observation": (np.array([2.0, 3.0, 4.0]),), + "reward": np.array(0.5), }, ] class SequenceStackTest(absltest.TestCase): - """Tests for various tree utilities.""" - - def test_stack_sequence_fields(self): - """Tests that `stack_sequence_fields` behaves correctly on nested data.""" - - stacked = tree_utils.stack_sequence_fields(TEST_SEQUENCE) - - # Check that the stacked output has the correct structure. - tree.assert_same_structure(stacked, TEST_SEQUENCE[0]) - - # Check that the leaves have the correct array shapes. - self.assertEqual(stacked['action'].shape, (3, 1)) - self.assertEqual(stacked['observation'][0].shape, (3, 3)) - self.assertEqual(stacked['reward'].shape, (3,)) - - # Check values. 
- self.assertEqual(stacked['observation'][0].tolist(), [ - [0., 1., 2.], - [1., 2., 3.], - [2., 3., 4.], - ]) - self.assertEqual(stacked['action'].tolist(), [[1.], [0.5], [0.3]]) - self.assertEqual(stacked['reward'].tolist(), [1., 0., 0.5]) - - def test_unstack_sequence_fields(self): - """Tests that `unstack_sequence_fields(stack_sequence_fields(x)) == x`.""" - stacked = tree_utils.stack_sequence_fields(TEST_SEQUENCE) - batch_size = len(TEST_SEQUENCE) - unstacked = tree_utils.unstack_sequence_fields(stacked, batch_size) - tree.map_structure(np.testing.assert_array_equal, unstacked, TEST_SEQUENCE) - - def test_fast_map_structure_with_path(self): - structure = { - 'a': { - 'b': np.array([0.0]) - }, - 'c': (np.array([1.0]), np.array([2.0])), - 'd': [np.array(3.0), np.array(4.0)], - } - - def map_fn(path: Sequence[str], x: np.ndarray, y: np.ndarray): - return x + y + len(path) - - single_arg_map_fn = functools.partial(map_fn, y=np.array([0.0])) - - expected_mapped_structure = ( - tree.map_structure_with_path(single_arg_map_fn, structure)) - mapped_structure = ( - tree_utils.fast_map_structure_with_path(single_arg_map_fn, structure)) - self.assertEqual(mapped_structure, expected_mapped_structure) - - expected_double_mapped_structure = ( - tree.map_structure_with_path(map_fn, structure, mapped_structure)) - double_mapped_structure = ( - tree_utils.fast_map_structure_with_path(map_fn, structure, - mapped_structure)) - self.assertEqual(double_mapped_structure, expected_double_mapped_structure) - - -if __name__ == '__main__': - absltest.main() + """Tests for various tree utilities.""" + + def test_stack_sequence_fields(self): + """Tests that `stack_sequence_fields` behaves correctly on nested data.""" + + stacked = tree_utils.stack_sequence_fields(TEST_SEQUENCE) + + # Check that the stacked output has the correct structure. + tree.assert_same_structure(stacked, TEST_SEQUENCE[0]) + + # Check that the leaves have the correct array shapes. + self.assertEqual(stacked["action"].shape, (3, 1)) + self.assertEqual(stacked["observation"][0].shape, (3, 3)) + self.assertEqual(stacked["reward"].shape, (3,)) + + # Check values. 
+ self.assertEqual( + stacked["observation"][0].tolist(), + [[0.0, 1.0, 2.0], [1.0, 2.0, 3.0], [2.0, 3.0, 4.0],], + ) + self.assertEqual(stacked["action"].tolist(), [[1.0], [0.5], [0.3]]) + self.assertEqual(stacked["reward"].tolist(), [1.0, 0.0, 0.5]) + + def test_unstack_sequence_fields(self): + """Tests that `unstack_sequence_fields(stack_sequence_fields(x)) == x`.""" + stacked = tree_utils.stack_sequence_fields(TEST_SEQUENCE) + batch_size = len(TEST_SEQUENCE) + unstacked = tree_utils.unstack_sequence_fields(stacked, batch_size) + tree.map_structure(np.testing.assert_array_equal, unstacked, TEST_SEQUENCE) + + def test_fast_map_structure_with_path(self): + structure = { + "a": {"b": np.array([0.0])}, + "c": (np.array([1.0]), np.array([2.0])), + "d": [np.array(3.0), np.array(4.0)], + } + + def map_fn(path: Sequence[str], x: np.ndarray, y: np.ndarray): + return x + y + len(path) + + single_arg_map_fn = functools.partial(map_fn, y=np.array([0.0])) + + expected_mapped_structure = tree.map_structure_with_path( + single_arg_map_fn, structure + ) + mapped_structure = tree_utils.fast_map_structure_with_path( + single_arg_map_fn, structure + ) + self.assertEqual(mapped_structure, expected_mapped_structure) + + expected_double_mapped_structure = tree.map_structure_with_path( + map_fn, structure, mapped_structure + ) + double_mapped_structure = tree_utils.fast_map_structure_with_path( + map_fn, structure, mapped_structure + ) + self.assertEqual(double_mapped_structure, expected_double_mapped_structure) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/wrappers/__init__.py b/acme/wrappers/__init__.py index 0772c0b354..dcb63ec958 100644 --- a/acme/wrappers/__init__.py +++ b/acme/wrappers/__init__.py @@ -16,25 +16,23 @@ from acme.wrappers.action_repeat import ActionRepeatWrapper from acme.wrappers.atari_wrapper import AtariWrapper -from acme.wrappers.base import EnvironmentWrapper -from acme.wrappers.base import wrap_all +from acme.wrappers.base import EnvironmentWrapper, wrap_all from acme.wrappers.canonical_spec import CanonicalSpecWrapper from acme.wrappers.concatenate_observations import ConcatObservationWrapper from acme.wrappers.delayed_reward import DelayedRewardWrapper -from acme.wrappers.expand_scalar_observation_shapes import ExpandScalarObservationShapesWrapper -from acme.wrappers.frame_stacking import FrameStacker -from acme.wrappers.frame_stacking import FrameStackingWrapper -from acme.wrappers.gym_wrapper import GymAtariAdapter -from acme.wrappers.gym_wrapper import GymWrapper +from acme.wrappers.expand_scalar_observation_shapes import ( + ExpandScalarObservationShapesWrapper, +) +from acme.wrappers.frame_stacking import FrameStacker, FrameStackingWrapper +from acme.wrappers.gym_wrapper import GymAtariAdapter, GymWrapper from acme.wrappers.noop_starts import NoopStartsWrapper from acme.wrappers.observation_action_reward import ObservationActionRewardWrapper from acme.wrappers.single_precision import SinglePrecisionWrapper from acme.wrappers.step_limit import StepLimitWrapper -from acme.wrappers.video import MujocoVideoWrapper -from acme.wrappers.video import VideoWrapper +from acme.wrappers.video import MujocoVideoWrapper, VideoWrapper try: - # pylint: disable=g-import-not-at-top - from acme.wrappers.open_spiel_wrapper import OpenSpielWrapper + # pylint: disable=g-import-not-at-top + from acme.wrappers.open_spiel_wrapper import OpenSpielWrapper except ImportError: - pass + pass diff --git a/acme/wrappers/action_repeat.py b/acme/wrappers/action_repeat.py index 
8c6978efad..16353fea97 100644 --- a/acme/wrappers/action_repeat.py +++ b/acme/wrappers/action_repeat.py @@ -14,34 +14,35 @@ """Wrapper that implements action repeats.""" +import dm_env + from acme import types from acme.wrappers import base -import dm_env class ActionRepeatWrapper(base.EnvironmentWrapper): - """Action repeat wrapper.""" + """Action repeat wrapper.""" - def __init__(self, environment: dm_env.Environment, num_repeats: int = 1): - super().__init__(environment) - self._num_repeats = num_repeats + def __init__(self, environment: dm_env.Environment, num_repeats: int = 1): + super().__init__(environment) + self._num_repeats = num_repeats - def step(self, action: types.NestedArray) -> dm_env.TimeStep: - # Initialize accumulated reward and discount. - reward = 0. - discount = 1. + def step(self, action: types.NestedArray) -> dm_env.TimeStep: + # Initialize accumulated reward and discount. + reward = 0.0 + discount = 1.0 - # Step the environment by repeating action. - for _ in range(self._num_repeats): - timestep = self._environment.step(action) + # Step the environment by repeating action. + for _ in range(self._num_repeats): + timestep = self._environment.step(action) - # Accumulate reward and discount. - reward += timestep.reward * discount - discount *= timestep.discount + # Accumulate reward and discount. + reward += timestep.reward * discount + discount *= timestep.discount - # Don't go over episode boundaries. - if timestep.last(): - break + # Don't go over episode boundaries. + if timestep.last(): + break - # Replace the final timestep's reward and discount. - return timestep._replace(reward=reward, discount=discount) + # Replace the final timestep's reward and discount. + return timestep._replace(reward=reward, discount=discount) diff --git a/acme/wrappers/atari_wrapper.py b/acme/wrappers/atari_wrapper.py index 25f7578d4e..3c5787e451 100644 --- a/acme/wrappers/atari_wrapper.py +++ b/acme/wrappers/atari_wrapper.py @@ -15,23 +15,22 @@ """Atari wrappers functionality for Python environments.""" import abc -from typing import Tuple, List, Optional, Sequence, Union - -from acme.wrappers import base -from acme.wrappers import frame_stacking +from typing import List, Optional, Sequence, Tuple, Union import dm_env -from dm_env import specs import numpy as np +from dm_env import specs from PIL import Image +from acme.wrappers import base, frame_stacking + RGB_INDEX = 0 # Observation index holding the RGB data. LIVES_INDEX = 1 # Observation index holding the lives count. NUM_COLOR_CHANNELS = 3 # Number of color channels in RGB data. class BaseAtariWrapper(abc.ABC, base.EnvironmentWrapper): - """Abstract base class for Atari wrappers. + """Abstract base class for Atari wrappers. This assumes that the input environment is a dm_env.Environment instance in which observations are tuples whose first element is an RGB observation and @@ -61,21 +60,23 @@ class BaseAtariWrapper(abc.ABC, base.EnvironmentWrapper): in other wrappers, which rescales pixel values to floats in the range [0, 1]. """ - def __init__(self, - environment: dm_env.Environment, - *, - max_abs_reward: Optional[float] = None, - scale_dims: Optional[Tuple[int, int]] = (84, 84), - action_repeats: int = 4, - pooled_frames: int = 2, - zero_discount_on_life_loss: bool = False, - expose_lives_observation: bool = False, - num_stacked_frames: int = 4, - flatten_frame_stack: bool = False, - max_episode_len: Optional[int] = None, - to_float: bool = False, - grayscaling: bool = True): - """Initializes a new AtariWrapper. 
+ def __init__( + self, + environment: dm_env.Environment, + *, + max_abs_reward: Optional[float] = None, + scale_dims: Optional[Tuple[int, int]] = (84, 84), + action_repeats: int = 4, + pooled_frames: int = 2, + zero_discount_on_life_loss: bool = False, + expose_lives_observation: bool = False, + num_stacked_frames: int = 4, + flatten_frame_stack: bool = False, + max_episode_len: Optional[int] = None, + to_float: bool = False, + grayscaling: bool = True + ): + """Initializes a new AtariWrapper. Args: environment: An Atari environment. @@ -108,201 +109,201 @@ def __init__(self, Raises: ValueError: For various invalid inputs. """ - if not 1 <= pooled_frames <= action_repeats: - raise ValueError("pooled_frames ({}) must be between 1 and " - "action_repeats ({}) inclusive".format( - pooled_frames, action_repeats)) - - if zero_discount_on_life_loss: - super().__init__(_ZeroDiscountOnLifeLoss(environment)) - else: - super().__init__(environment) - - if not max_episode_len: - max_episode_len = np.inf - - self._frame_stacker = frame_stacking.FrameStacker( - num_frames=num_stacked_frames, flatten=flatten_frame_stack) - self._action_repeats = action_repeats - self._pooled_frames = pooled_frames - self._scale_dims = scale_dims - self._max_abs_reward = max_abs_reward or np.inf - self._to_float = to_float - self._expose_lives_observation = expose_lives_observation - - if scale_dims: - self._height, self._width = scale_dims - else: - spec = environment.observation_spec() - self._height, self._width = spec[RGB_INDEX].shape[:2] - - self._episode_len = 0 - self._max_episode_len = max_episode_len - self._reset_next_step = True - - self._grayscaling = grayscaling - - # Based on underlying observation spec, decide whether lives are to be - # included in output observations. - observation_spec = self._environment.observation_spec() - spec_names = [spec.name for spec in observation_spec] - if "lives" in spec_names and spec_names.index("lives") != 1: - raise ValueError("`lives` observation needs to have index 1 in Atari.") - - self._observation_spec = self._init_observation_spec() - - self._raw_observation = None - - def _init_observation_spec(self): - """Computes the observation spec for the pixel observations. - - Returns: - An `Array` specification for the pixel observations. - """ - if self._to_float: - pixels_dtype = float - else: - pixels_dtype = np.uint8 - - if self._grayscaling: - pixels_spec_shape = (self._height, self._width) - pixels_spec_name = "grayscale" - else: - pixels_spec_shape = (self._height, self._width, NUM_COLOR_CHANNELS) - pixels_spec_name = "RGB" - - pixel_spec = specs.Array( - shape=pixels_spec_shape, dtype=pixels_dtype, name=pixels_spec_name) - pixel_spec = self._frame_stacker.update_spec(pixel_spec) - - if self._expose_lives_observation: - return (pixel_spec,) + self._environment.observation_spec()[1:] - return pixel_spec - - def reset(self) -> dm_env.TimeStep: - """Resets environment and provides the first timestep.""" - self._reset_next_step = False - self._episode_len = 0 - self._frame_stacker.reset() - timestep = self._environment.reset() - - observation = self._observation_from_timestep_stack([timestep]) - - return self._postprocess_observation( - timestep._replace(observation=observation)) - - def step(self, action: int) -> dm_env.TimeStep: - """Steps up to action_repeat times and returns a post-processed step.""" - if self._reset_next_step: - return self.reset() - - timestep_stack = [] - - # Step on environment multiple times for each selected action. 
- for _ in range(self._action_repeats): - timestep = self._environment.step([np.array([action])]) - - self._episode_len += 1 - if self._episode_len == self._max_episode_len: - timestep = timestep._replace(step_type=dm_env.StepType.LAST) - - timestep_stack.append(timestep) - - if timestep.last(): - # Action repeat frames should not span episode boundaries. Also, no need - # to pad with zero-valued observations as all the reductions in - # _postprocess_observation work gracefully for any non-zero size of - # timestep_stack. + if not 1 <= pooled_frames <= action_repeats: + raise ValueError( + "pooled_frames ({}) must be between 1 and " + "action_repeats ({}) inclusive".format(pooled_frames, action_repeats) + ) + + if zero_discount_on_life_loss: + super().__init__(_ZeroDiscountOnLifeLoss(environment)) + else: + super().__init__(environment) + + if not max_episode_len: + max_episode_len = np.inf + + self._frame_stacker = frame_stacking.FrameStacker( + num_frames=num_stacked_frames, flatten=flatten_frame_stack + ) + self._action_repeats = action_repeats + self._pooled_frames = pooled_frames + self._scale_dims = scale_dims + self._max_abs_reward = max_abs_reward or np.inf + self._to_float = to_float + self._expose_lives_observation = expose_lives_observation + + if scale_dims: + self._height, self._width = scale_dims + else: + spec = environment.observation_spec() + self._height, self._width = spec[RGB_INDEX].shape[:2] + + self._episode_len = 0 + self._max_episode_len = max_episode_len self._reset_next_step = True - break - - # Determine a single step type. We let FIRST take priority over LAST, since - # we think it's more likely algorithm code will be set up to deal with that, - # due to environments supporting reset() (which emits a FIRST). - # Note we'll never have LAST then FIRST in timestep_stack here. - step_type = dm_env.StepType.MID - for timestep in timestep_stack: - if timestep.first(): - step_type = dm_env.StepType.FIRST - break - elif timestep.last(): - step_type = dm_env.StepType.LAST - break - - if timestep_stack[0].first(): - # Update first timestep to have identity effect on reward and discount. - timestep_stack[0] = timestep_stack[0]._replace(reward=0., discount=1.) - # Sum reward over stack. - reward = sum(timestep_t.reward for timestep_t in timestep_stack) + self._grayscaling = grayscaling - # Multiply discount over stack (will either be 0. or 1.). - discount = np.product( - [timestep_t.discount for timestep_t in timestep_stack]) + # Based on underlying observation spec, decide whether lives are to be + # included in output observations. + observation_spec = self._environment.observation_spec() + spec_names = [spec.name for spec in observation_spec] + if "lives" in spec_names and spec_names.index("lives") != 1: + raise ValueError("`lives` observation needs to have index 1 in Atari.") - observation = self._observation_from_timestep_stack(timestep_stack) + self._observation_spec = self._init_observation_spec() - timestep = dm_env.TimeStep( - step_type=step_type, - reward=reward, - observation=observation, - discount=discount) + self._raw_observation = None - return self._postprocess_observation(timestep) + def _init_observation_spec(self): + """Computes the observation spec for the pixel observations. 
- @abc.abstractmethod - def _preprocess_pixels(self, timestep_stack: List[dm_env.TimeStep]): - """Process Atari pixels.""" - - def _observation_from_timestep_stack(self, - timestep_stack: List[dm_env.TimeStep]): - """Compute the observation for a stack of timesteps.""" - self._raw_observation = timestep_stack[-1].observation[RGB_INDEX].copy() - processed_pixels = self._preprocess_pixels(timestep_stack) - - if self._to_float: - stacked_observation = self._frame_stacker.step(processed_pixels / 255.0) - else: - stacked_observation = self._frame_stacker.step(processed_pixels) + Returns: + An `Array` specification for the pixel observations. + """ + if self._to_float: + pixels_dtype = float + else: + pixels_dtype = np.uint8 + + if self._grayscaling: + pixels_spec_shape = (self._height, self._width) + pixels_spec_name = "grayscale" + else: + pixels_spec_shape = (self._height, self._width, NUM_COLOR_CHANNELS) + pixels_spec_name = "RGB" + + pixel_spec = specs.Array( + shape=pixels_spec_shape, dtype=pixels_dtype, name=pixels_spec_name + ) + pixel_spec = self._frame_stacker.update_spec(pixel_spec) + + if self._expose_lives_observation: + return (pixel_spec,) + self._environment.observation_spec()[1:] + return pixel_spec + + def reset(self) -> dm_env.TimeStep: + """Resets environment and provides the first timestep.""" + self._reset_next_step = False + self._episode_len = 0 + self._frame_stacker.reset() + timestep = self._environment.reset() + + observation = self._observation_from_timestep_stack([timestep]) + + return self._postprocess_observation(timestep._replace(observation=observation)) + + def step(self, action: int) -> dm_env.TimeStep: + """Steps up to action_repeat times and returns a post-processed step.""" + if self._reset_next_step: + return self.reset() + + timestep_stack = [] + + # Step on environment multiple times for each selected action. + for _ in range(self._action_repeats): + timestep = self._environment.step([np.array([action])]) + + self._episode_len += 1 + if self._episode_len == self._max_episode_len: + timestep = timestep._replace(step_type=dm_env.StepType.LAST) + + timestep_stack.append(timestep) + + if timestep.last(): + # Action repeat frames should not span episode boundaries. Also, no need + # to pad with zero-valued observations as all the reductions in + # _postprocess_observation work gracefully for any non-zero size of + # timestep_stack. + self._reset_next_step = True + break + + # Determine a single step type. We let FIRST take priority over LAST, since + # we think it's more likely algorithm code will be set up to deal with that, + # due to environments supporting reset() (which emits a FIRST). + # Note we'll never have LAST then FIRST in timestep_stack here. + step_type = dm_env.StepType.MID + for timestep in timestep_stack: + if timestep.first(): + step_type = dm_env.StepType.FIRST + break + elif timestep.last(): + step_type = dm_env.StepType.LAST + break + + if timestep_stack[0].first(): + # Update first timestep to have identity effect on reward and discount. + timestep_stack[0] = timestep_stack[0]._replace(reward=0.0, discount=1.0) + + # Sum reward over stack. + reward = sum(timestep_t.reward for timestep_t in timestep_stack) + + # Multiply discount over stack (will either be 0. or 1.). 
+ discount = np.product([timestep_t.discount for timestep_t in timestep_stack]) + + observation = self._observation_from_timestep_stack(timestep_stack) + + timestep = dm_env.TimeStep( + step_type=step_type, + reward=reward, + observation=observation, + discount=discount, + ) + + return self._postprocess_observation(timestep) + + @abc.abstractmethod + def _preprocess_pixels(self, timestep_stack: List[dm_env.TimeStep]): + """Process Atari pixels.""" + + def _observation_from_timestep_stack(self, timestep_stack: List[dm_env.TimeStep]): + """Compute the observation for a stack of timesteps.""" + self._raw_observation = timestep_stack[-1].observation[RGB_INDEX].copy() + processed_pixels = self._preprocess_pixels(timestep_stack) + + if self._to_float: + stacked_observation = self._frame_stacker.step(processed_pixels / 255.0) + else: + stacked_observation = self._frame_stacker.step(processed_pixels) - # We use last timestep for lives only. - observation = timestep_stack[-1].observation - if self._expose_lives_observation: - return (stacked_observation,) + observation[1:] + # We use last timestep for lives only. + observation = timestep_stack[-1].observation + if self._expose_lives_observation: + return (stacked_observation,) + observation[1:] - return stacked_observation + return stacked_observation - def _postprocess_observation(self, - timestep: dm_env.TimeStep) -> dm_env.TimeStep: - """Observation processing applied after action repeat consolidation.""" + def _postprocess_observation(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep: + """Observation processing applied after action repeat consolidation.""" - if timestep.first(): - return dm_env.restart(timestep.observation) + if timestep.first(): + return dm_env.restart(timestep.observation) - reward = np.clip(timestep.reward, -self._max_abs_reward, - self._max_abs_reward) + reward = np.clip(timestep.reward, -self._max_abs_reward, self._max_abs_reward) - return timestep._replace(reward=reward) + return timestep._replace(reward=reward) - def action_spec(self) -> specs.DiscreteArray: - raw_spec = self._environment.action_spec()[0] - return specs.DiscreteArray(num_values=raw_spec.maximum.item() - - raw_spec.minimum.item() + 1) + def action_spec(self) -> specs.DiscreteArray: + raw_spec = self._environment.action_spec()[0] + return specs.DiscreteArray( + num_values=raw_spec.maximum.item() - raw_spec.minimum.item() + 1 + ) - def observation_spec(self) -> Union[specs.Array, Sequence[specs.Array]]: - return self._observation_spec + def observation_spec(self) -> Union[specs.Array, Sequence[specs.Array]]: + return self._observation_spec - def reward_spec(self) -> specs.Array: - return specs.Array(shape=(), dtype=float) + def reward_spec(self) -> specs.Array: + return specs.Array(shape=(), dtype=float) - @property - def raw_observation(self) -> np.ndarray: - """Returns the raw observation, after any pooling has been applied.""" - return self._raw_observation + @property + def raw_observation(self) -> np.ndarray: + """Returns the raw observation, after any pooling has been applied.""" + return self._raw_observation class AtariWrapper(BaseAtariWrapper): - """Standard "Nature Atari" wrapper for Python environments. + """Standard "Nature Atari" wrapper for Python environments. Before being fed to a neural network, Atari frames go through a prepocessing, implemented in this wrapper. For historical reasons, there were different @@ -325,37 +326,41 @@ class AtariWrapper(BaseAtariWrapper): user that wishes to compare agents between librairies. 
""" - def _preprocess_pixels(self, timestep_stack: List[dm_env.TimeStep]): - """Preprocess Atari frames.""" - # 1. Max pooling - processed_pixels = np.max( - np.stack([ - s.observation[RGB_INDEX] - for s in timestep_stack[-self._pooled_frames:] - ]), - axis=0) - - # 2. RGB to grayscale - if self._grayscaling: - processed_pixels = np.tensordot(processed_pixels, - [0.299, 0.587, 1 - (0.299 + 0.587)], - (-1, 0)) - - # 3. Resize - processed_pixels = processed_pixels.astype(np.uint8, copy=False) - if self._scale_dims != processed_pixels.shape[:2]: - processed_pixels = Image.fromarray(processed_pixels).resize( - (self._width, self._height), Image.Resampling.BILINEAR) - processed_pixels = np.array(processed_pixels, dtype=np.uint8) - - return processed_pixels + def _preprocess_pixels(self, timestep_stack: List[dm_env.TimeStep]): + """Preprocess Atari frames.""" + # 1. Max pooling + processed_pixels = np.max( + np.stack( + [ + s.observation[RGB_INDEX] + for s in timestep_stack[-self._pooled_frames :] + ] + ), + axis=0, + ) + + # 2. RGB to grayscale + if self._grayscaling: + processed_pixels = np.tensordot( + processed_pixels, [0.299, 0.587, 1 - (0.299 + 0.587)], (-1, 0) + ) + + # 3. Resize + processed_pixels = processed_pixels.astype(np.uint8, copy=False) + if self._scale_dims != processed_pixels.shape[:2]: + processed_pixels = Image.fromarray(processed_pixels).resize( + (self._width, self._height), Image.Resampling.BILINEAR + ) + processed_pixels = np.array(processed_pixels, dtype=np.uint8) + + return processed_pixels class _ZeroDiscountOnLifeLoss(base.EnvironmentWrapper): - """Implements soft-termination (zero discount) on life loss.""" + """Implements soft-termination (zero discount) on life loss.""" - def __init__(self, environment: dm_env.Environment): - """Initializes a new `_ZeroDiscountOnLifeLoss` wrapper. + def __init__(self, environment: dm_env.Environment): + """Initializes a new `_ZeroDiscountOnLifeLoss` wrapper. Args: environment: An Atari environment. @@ -363,31 +368,31 @@ def __init__(self, environment: dm_env.Environment): Raises: ValueError: If the environment does not expose a lives observation. """ - super().__init__(environment) - self._reset_next_step = True - self._last_num_lives = None - - def reset(self) -> dm_env.TimeStep: - timestep = self._environment.reset() - self._reset_next_step = False - self._last_num_lives = timestep.observation[LIVES_INDEX] - return timestep - - def step(self, action: int) -> dm_env.TimeStep: - if self._reset_next_step: - return self.reset() - - timestep = self._environment.step(action) - lives = timestep.observation[LIVES_INDEX] - - is_life_loss = True - # We have a life loss when: - # The wrapped environment is in a regular (MID) transition. - is_life_loss &= timestep.mid() - # Lives have decreased since last time `step` was called. 
- is_life_loss &= lives < self._last_num_lives - - self._last_num_lives = lives - if is_life_loss: - return timestep._replace(discount=0.0) - return timestep + super().__init__(environment) + self._reset_next_step = True + self._last_num_lives = None + + def reset(self) -> dm_env.TimeStep: + timestep = self._environment.reset() + self._reset_next_step = False + self._last_num_lives = timestep.observation[LIVES_INDEX] + return timestep + + def step(self, action: int) -> dm_env.TimeStep: + if self._reset_next_step: + return self.reset() + + timestep = self._environment.step(action) + lives = timestep.observation[LIVES_INDEX] + + is_life_loss = True + # We have a life loss when: + # The wrapped environment is in a regular (MID) transition. + is_life_loss &= timestep.mid() + # Lives have decreased since last time `step` was called. + is_life_loss &= lives < self._last_num_lives + + self._last_num_lives = lives + if is_life_loss: + return timestep._replace(discount=0.0) + return timestep diff --git a/acme/wrappers/atari_wrapper_dopamine.py b/acme/wrappers/atari_wrapper_dopamine.py index 81e2331db4..082d542bbb 100644 --- a/acme/wrappers/atari_wrapper_dopamine.py +++ b/acme/wrappers/atari_wrapper_dopamine.py @@ -26,45 +26,52 @@ from typing import List -from acme.wrappers import atari_wrapper # pytype: disable=import-error import cv2 + # pytype: enable=import-error import dm_env import numpy as np +from acme.wrappers import atari_wrapper + class AtariWrapperDopamine(atari_wrapper.BaseAtariWrapper): - """Atari wrapper that matches exactly Dopamine's prepocessing. + """Atari wrapper that matches exactly Dopamine's prepocessing. Warning: using this wrapper requires that you have opencv and its dependencies installed. In general, opencv is not required for Acme. """ - def _preprocess_pixels(self, timestep_stack: List[dm_env.TimeStep]): - """Preprocess Atari frames.""" - - # 1. RBG to grayscale - def rgb_to_grayscale(obs): - if self._grayscaling: - return np.tensordot(obs, [0.2989, 0.5870, 0.1140], (-1, 0)) - return obs - - # 2. Max pooling - processed_pixels = np.max( - np.stack([ - rgb_to_grayscale(s.observation[atari_wrapper.RGB_INDEX]) - for s in timestep_stack[-self._pooled_frames:] - ]), - axis=0) - - # 3. Resize - processed_pixels = np.round(processed_pixels).astype(np.uint8) - if self._scale_dims != processed_pixels.shape[:2]: - processed_pixels = cv2.resize( - processed_pixels, (self._width, self._height), - interpolation=cv2.INTER_AREA) - - processed_pixels = np.round(processed_pixels).astype(np.uint8) - - return processed_pixels + def _preprocess_pixels(self, timestep_stack: List[dm_env.TimeStep]): + """Preprocess Atari frames.""" + + # 1. RBG to grayscale + def rgb_to_grayscale(obs): + if self._grayscaling: + return np.tensordot(obs, [0.2989, 0.5870, 0.1140], (-1, 0)) + return obs + + # 2. Max pooling + processed_pixels = np.max( + np.stack( + [ + rgb_to_grayscale(s.observation[atari_wrapper.RGB_INDEX]) + for s in timestep_stack[-self._pooled_frames :] + ] + ), + axis=0, + ) + + # 3. 
Resize + processed_pixels = np.round(processed_pixels).astype(np.uint8) + if self._scale_dims != processed_pixels.shape[:2]: + processed_pixels = cv2.resize( + processed_pixels, + (self._width, self._height), + interpolation=cv2.INTER_AREA, + ) + + processed_pixels = np.round(processed_pixels).astype(np.uint8) + + return processed_pixels diff --git a/acme/wrappers/atari_wrapper_test.py b/acme/wrappers/atari_wrapper_test.py index 6b07e702df..cc38dc0f72 100644 --- a/acme/wrappers/atari_wrapper_test.py +++ b/acme/wrappers/atari_wrapper_test.py @@ -16,76 +16,78 @@ import unittest -from acme.wrappers import atari_wrapper -from dm_env import specs import numpy as np +from absl.testing import absltest, parameterized +from dm_env import specs -from absl.testing import absltest -from absl.testing import parameterized +from acme.wrappers import atari_wrapper SKIP_GYM_TESTS = False -SKIP_GYM_MESSAGE = 'gym not installed.' +SKIP_GYM_MESSAGE = "gym not installed." SKIP_ATARI_TESTS = False -SKIP_ATARI_MESSAGE = '' +SKIP_ATARI_MESSAGE = "" try: - # pylint: disable=g-import-not-at-top - from acme.wrappers import gym_wrapper - import gym - # pylint: enable=g-import-not-at-top + # pylint: disable=g-import-not-at-top + import gym + + from acme.wrappers import gym_wrapper + + # pylint: enable=g-import-not-at-top except ModuleNotFoundError: - SKIP_GYM_TESTS = True + SKIP_GYM_TESTS = True try: - import atari_py # pylint: disable=g-import-not-at-top - atari_py.get_game_path('pong') + import atari_py # pylint: disable=g-import-not-at-top + + atari_py.get_game_path("pong") except ModuleNotFoundError as e: - SKIP_ATARI_TESTS = True - SKIP_ATARI_MESSAGE = str(e) + SKIP_ATARI_TESTS = True + SKIP_ATARI_MESSAGE = str(e) except Exception as e: # pylint: disable=broad-except - # This exception is raised by atari_py.get_game_path('pong') if the Atari ROM - # file has not been installed. - SKIP_ATARI_TESTS = True - SKIP_ATARI_MESSAGE = str(e) - del atari_py + # This exception is raised by atari_py.get_game_path('pong') if the Atari ROM + # file has not been installed. + SKIP_ATARI_TESTS = True + SKIP_ATARI_MESSAGE = str(e) + del atari_py else: - del atari_py + del atari_py @unittest.skipIf(SKIP_ATARI_TESTS, SKIP_ATARI_MESSAGE) @unittest.skipIf(SKIP_GYM_TESTS, SKIP_GYM_MESSAGE) class AtariWrapperTest(parameterized.TestCase): - - @parameterized.parameters(True, False) - def test_pong(self, zero_discount_on_life_loss: bool): - env = gym.make('PongNoFrameskip-v4', full_action_space=True) - env = gym_wrapper.GymAtariAdapter(env) - env = atari_wrapper.AtariWrapper( - env, zero_discount_on_life_loss=zero_discount_on_life_loss) - - # Test converted observation spec. - observation_spec = env.observation_spec() - self.assertEqual(type(observation_spec), specs.Array) - - # Test converted action spec. - action_spec: specs.DiscreteArray = env.action_spec() - self.assertEqual(type(action_spec), specs.DiscreteArray) - self.assertEqual(action_spec.shape, ()) - self.assertEqual(action_spec.minimum, 0) - self.assertEqual(action_spec.maximum, 17) - self.assertEqual(action_spec.num_values, 18) - self.assertEqual(action_spec.dtype, np.dtype('int32')) - - # Check that the `render` call gets delegated to the underlying Gym env. - env.render('rgb_array') - - # Test step. 
- timestep = env.reset() - self.assertTrue(timestep.first()) - _ = env.step(0) - env.close() - - -if __name__ == '__main__': - absltest.main() + @parameterized.parameters(True, False) + def test_pong(self, zero_discount_on_life_loss: bool): + env = gym.make("PongNoFrameskip-v4", full_action_space=True) + env = gym_wrapper.GymAtariAdapter(env) + env = atari_wrapper.AtariWrapper( + env, zero_discount_on_life_loss=zero_discount_on_life_loss + ) + + # Test converted observation spec. + observation_spec = env.observation_spec() + self.assertEqual(type(observation_spec), specs.Array) + + # Test converted action spec. + action_spec: specs.DiscreteArray = env.action_spec() + self.assertEqual(type(action_spec), specs.DiscreteArray) + self.assertEqual(action_spec.shape, ()) + self.assertEqual(action_spec.minimum, 0) + self.assertEqual(action_spec.maximum, 17) + self.assertEqual(action_spec.num_values, 18) + self.assertEqual(action_spec.dtype, np.dtype("int32")) + + # Check that the `render` call gets delegated to the underlying Gym env. + env.render("rgb_array") + + # Test step. + timestep = env.reset() + self.assertTrue(timestep.first()) + _ = env.step(0) + env.close() + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/wrappers/base.py b/acme/wrappers/base.py index 987dccc71c..03e801969b 100644 --- a/acme/wrappers/base.py +++ b/acme/wrappers/base.py @@ -20,61 +20,62 @@ class EnvironmentWrapper(dm_env.Environment): - """Environment that wraps another environment. + """Environment that wraps another environment. This exposes the wrapped environment with the `.environment` property and also defines `__getattr__` so that attributes are invisibly forwarded to the wrapped environment (and hence enabling duck-typing). """ - _environment: dm_env.Environment + _environment: dm_env.Environment - def __init__(self, environment: dm_env.Environment): - self._environment = environment + def __init__(self, environment: dm_env.Environment): + self._environment = environment - def __getattr__(self, name): - if name.startswith("__"): - raise AttributeError( - "attempted to get missing private attribute '{}'".format(name)) - return getattr(self._environment, name) + def __getattr__(self, name): + if name.startswith("__"): + raise AttributeError( + "attempted to get missing private attribute '{}'".format(name) + ) + return getattr(self._environment, name) - @property - def environment(self) -> dm_env.Environment: - return self._environment + @property + def environment(self) -> dm_env.Environment: + return self._environment - # The following lines are necessary because methods defined in - # `dm_env.Environment` are not delegated through `__getattr__`, which would - # only be used to expose methods or properties that are not defined in the - # base `dm_env.Environment` class. + # The following lines are necessary because methods defined in + # `dm_env.Environment` are not delegated through `__getattr__`, which would + # only be used to expose methods or properties that are not defined in the + # base `dm_env.Environment` class. 
- def step(self, action) -> dm_env.TimeStep: - return self._environment.step(action) + def step(self, action) -> dm_env.TimeStep: + return self._environment.step(action) - def reset(self) -> dm_env.TimeStep: - return self._environment.reset() + def reset(self) -> dm_env.TimeStep: + return self._environment.reset() - def action_spec(self): - return self._environment.action_spec() + def action_spec(self): + return self._environment.action_spec() - def discount_spec(self): - return self._environment.discount_spec() + def discount_spec(self): + return self._environment.discount_spec() - def observation_spec(self): - return self._environment.observation_spec() + def observation_spec(self): + return self._environment.observation_spec() - def reward_spec(self): - return self._environment.reward_spec() + def reward_spec(self): + return self._environment.reward_spec() - def close(self): - return self._environment.close() + def close(self): + return self._environment.close() def wrap_all( environment: dm_env.Environment, wrappers: Sequence[Callable[[dm_env.Environment], dm_env.Environment]], ) -> dm_env.Environment: - """Given an environment, wrap it in a list of wrappers.""" - for w in wrappers: - environment = w(environment) + """Given an environment, wrap it in a list of wrappers.""" + for w in wrappers: + environment = w(environment) - return environment + return environment diff --git a/acme/wrappers/base_test.py b/acme/wrappers/base_test.py index 96160f6334..4bcb9a12dd 100644 --- a/acme/wrappers/base_test.py +++ b/acme/wrappers/base_test.py @@ -17,28 +17,27 @@ import copy import pickle +from absl.testing import absltest + from acme.testing import fakes from acme.wrappers import base -from absl.testing import absltest - class BaseTest(absltest.TestCase): + def test_pickle_unpickle(self): + test_env = base.EnvironmentWrapper(environment=fakes.DiscreteEnvironment()) - def test_pickle_unpickle(self): - test_env = base.EnvironmentWrapper(environment=fakes.DiscreteEnvironment()) + test_env_pickled = pickle.dumps(test_env) + test_env_restored = pickle.loads(test_env_pickled) + self.assertEqual( + test_env.observation_spec(), test_env_restored.observation_spec(), + ) - test_env_pickled = pickle.dumps(test_env) - test_env_restored = pickle.loads(test_env_pickled) - self.assertEqual( - test_env.observation_spec(), - test_env_restored.observation_spec(), - ) + def test_deepcopy(self): + test_env = base.EnvironmentWrapper(environment=fakes.DiscreteEnvironment()) + copied_env = copy.deepcopy(test_env) + del copied_env - def test_deepcopy(self): - test_env = base.EnvironmentWrapper(environment=fakes.DiscreteEnvironment()) - copied_env = copy.deepcopy(test_env) - del copied_env -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/wrappers/canonical_spec.py b/acme/wrappers/canonical_spec.py index c5363e488e..e3e3e1702b 100644 --- a/acme/wrappers/canonical_spec.py +++ b/acme/wrappers/canonical_spec.py @@ -21,17 +21,16 @@ of the spec is unchanged, while the maximum/minimum values are set to +/- 1. """ -from acme import specs -from acme import types -from acme.wrappers import base - import dm_env import numpy as np import tree +from acme import specs, types +from acme.wrappers import base + class CanonicalSpecWrapper(base.EnvironmentWrapper): - """Wrapper which converts environments to use canonical action specs. + """Wrapper which converts environments to use canonical action specs. This only affects action specs of type `specs.BoundedArray`. 
@@ -41,58 +40,57 @@ class CanonicalSpecWrapper(base.EnvironmentWrapper): to +/- 1. """ - def __init__(self, environment: dm_env.Environment, clip: bool = False): - super().__init__(environment) - self._action_spec = environment.action_spec() - self._clip = clip + def __init__(self, environment: dm_env.Environment, clip: bool = False): + super().__init__(environment) + self._action_spec = environment.action_spec() + self._clip = clip - def step(self, action: types.NestedArray) -> dm_env.TimeStep: - scaled_action = _scale_nested_action(action, self._action_spec, self._clip) - return self._environment.step(scaled_action) + def step(self, action: types.NestedArray) -> dm_env.TimeStep: + scaled_action = _scale_nested_action(action, self._action_spec, self._clip) + return self._environment.step(scaled_action) - def action_spec(self): - return _convert_spec(self._environment.action_spec()) + def action_spec(self): + return _convert_spec(self._environment.action_spec()) def _convert_spec(nested_spec: types.NestedSpec) -> types.NestedSpec: - """Converts all bounded specs in nested spec to the canonical scale.""" + """Converts all bounded specs in nested spec to the canonical scale.""" - def _convert_single_spec(spec: specs.Array) -> specs.Array: - """Converts a single spec to canonical if bounded.""" - if isinstance(spec, specs.BoundedArray): - return spec.replace( - minimum=-np.ones(spec.shape), maximum=np.ones(spec.shape)) - else: - return spec + def _convert_single_spec(spec: specs.Array) -> specs.Array: + """Converts a single spec to canonical if bounded.""" + if isinstance(spec, specs.BoundedArray): + return spec.replace( + minimum=-np.ones(spec.shape), maximum=np.ones(spec.shape) + ) + else: + return spec - return tree.map_structure(_convert_single_spec, nested_spec) + return tree.map_structure(_convert_single_spec, nested_spec) def _scale_nested_action( - nested_action: types.NestedArray, - nested_spec: types.NestedSpec, - clip: bool, + nested_action: types.NestedArray, nested_spec: types.NestedSpec, clip: bool, ) -> types.NestedArray: - """Converts a canonical nested action back to the given nested action spec.""" + """Converts a canonical nested action back to the given nested action spec.""" - def _scale_action(action: np.ndarray, spec: specs.Array): - """Converts a single canonical action back to the given action spec.""" - if isinstance(spec, specs.BoundedArray): - # Get scale and offset of output action spec. - scale = spec.maximum - spec.minimum - offset = spec.minimum + def _scale_action(action: np.ndarray, spec: specs.Array): + """Converts a single canonical action back to the given action spec.""" + if isinstance(spec, specs.BoundedArray): + # Get scale and offset of output action spec. + scale = spec.maximum - spec.minimum + offset = spec.minimum - # Maybe clip the action. - if clip: - action = np.clip(action, -1.0, 1.0) + # Maybe clip the action. + if clip: + action = np.clip(action, -1.0, 1.0) - # Map action to [0, 1]. - action = 0.5 * (action + 1.0) + # Map action to [0, 1]. + action = 0.5 * (action + 1.0) - # Map action to [spec.minimum, spec.maximum]. - action *= scale - action += offset + # Map action to [spec.minimum, spec.maximum]. 
+ action *= scale + action += offset - return action + return action - return tree.map_structure(_scale_action, nested_action, nested_spec) + return tree.map_structure(_scale_action, nested_action, nested_spec) diff --git a/acme/wrappers/concatenate_observations.py b/acme/wrappers/concatenate_observations.py index 895de39e78..31f5ea31e7 100644 --- a/acme/wrappers/concatenate_observations.py +++ b/acme/wrappers/concatenate_observations.py @@ -14,17 +14,18 @@ """Wrapper that implements concatenation of observation fields.""" -from typing import Sequence, Optional +from typing import Optional, Sequence -from acme import types -from acme.wrappers import base import dm_env import numpy as np import tree +from acme import types +from acme.wrappers import base + def _concat(values: types.NestedArray) -> np.ndarray: - """Concatenates the leaves of `values` along the leading dimension. + """Concatenates the leaves of `values` along the leading dimension. Treats scalars as 1d arrays and expects that the shapes of all leaves are the same except for the leading dimension. @@ -35,17 +36,17 @@ def _concat(values: types.NestedArray) -> np.ndarray: Returns: The concatenated array. """ - leaves = list(map(np.atleast_1d, tree.flatten(values))) - return np.concatenate(leaves) + leaves = list(map(np.atleast_1d, tree.flatten(values))) + return np.concatenate(leaves) def _zeros_like(nest, dtype=None): - """Generate a nested NumPy array according to spec.""" - return tree.map_structure(lambda x: np.zeros(x.shape, dtype or x.dtype), nest) + """Generate a nested NumPy array according to spec.""" + return tree.map_structure(lambda x: np.zeros(x.shape, dtype or x.dtype), nest) class ConcatObservationWrapper(base.EnvironmentWrapper): - """Wrapper that concatenates observation fields. + """Wrapper that concatenates observation fields. It takes an environment with nested observations and concatenates the fields in a single tensor. The original fields should be 1-dimensional. @@ -55,43 +56,48 @@ class ConcatObservationWrapper(base.EnvironmentWrapper): their names, see tree.flatten for more information. """ - def __init__(self, - environment: dm_env.Environment, - name_filter: Optional[Sequence[str]] = None): - """Initializes a new ConcatObservationWrapper. + def __init__( + self, + environment: dm_env.Environment, + name_filter: Optional[Sequence[str]] = None, + ): + """Initializes a new ConcatObservationWrapper. Args: environment: Environment to wrap. name_filter: Sequence of observation names to keep. None keeps them all. 
""" - super().__init__(environment) - observation_spec = environment.observation_spec() - if name_filter is None: - name_filter = list(observation_spec.keys()) - self._obs_names = [x for x in name_filter if x in observation_spec.keys()] - - dummy_obs = _zeros_like(observation_spec) - dummy_obs = self._convert_observation(dummy_obs) - self._observation_spec = dm_env.specs.BoundedArray( - shape=dummy_obs.shape, - dtype=dummy_obs.dtype, - minimum=-np.inf, - maximum=np.inf, - name='state') - - def _convert_observation(self, observation): - obs = {k: observation[k] for k in self._obs_names} - return _concat(obs) - - def step(self, action) -> dm_env.TimeStep: - timestep = self._environment.step(action) - return timestep._replace( - observation=self._convert_observation(timestep.observation)) - - def reset(self) -> dm_env.TimeStep: - timestep = self._environment.reset() - return timestep._replace( - observation=self._convert_observation(timestep.observation)) - - def observation_spec(self) -> types.NestedSpec: - return self._observation_spec + super().__init__(environment) + observation_spec = environment.observation_spec() + if name_filter is None: + name_filter = list(observation_spec.keys()) + self._obs_names = [x for x in name_filter if x in observation_spec.keys()] + + dummy_obs = _zeros_like(observation_spec) + dummy_obs = self._convert_observation(dummy_obs) + self._observation_spec = dm_env.specs.BoundedArray( + shape=dummy_obs.shape, + dtype=dummy_obs.dtype, + minimum=-np.inf, + maximum=np.inf, + name="state", + ) + + def _convert_observation(self, observation): + obs = {k: observation[k] for k in self._obs_names} + return _concat(obs) + + def step(self, action) -> dm_env.TimeStep: + timestep = self._environment.step(action) + return timestep._replace( + observation=self._convert_observation(timestep.observation) + ) + + def reset(self) -> dm_env.TimeStep: + timestep = self._environment.reset() + return timestep._replace( + observation=self._convert_observation(timestep.observation) + ) + + def observation_spec(self) -> types.NestedSpec: + return self._observation_spec diff --git a/acme/wrappers/delayed_reward.py b/acme/wrappers/delayed_reward.py index 65fb45c4fc..ceba937ead 100644 --- a/acme/wrappers/delayed_reward.py +++ b/acme/wrappers/delayed_reward.py @@ -17,15 +17,16 @@ import operator from typing import Optional -from acme import types -from acme.wrappers import base import dm_env import numpy as np import tree +from acme import types +from acme.wrappers import base + class DelayedRewardWrapper(base.EnvironmentWrapper): - """Implements delayed reward on environments. + """Implements delayed reward on environments. This wrapper sparsifies any environment by adding a reward delay. Instead of returning a reward at each step, the wrapped environment returns the @@ -34,10 +35,10 @@ class DelayedRewardWrapper(base.EnvironmentWrapper): the environment harder by adding exploration and longer term dependencies. """ - def __init__(self, - environment: dm_env.Environment, - accumulation_period: Optional[int] = 1): - """Initializes a `DelayedRewardWrapper`. + def __init__( + self, environment: dm_env.Environment, accumulation_period: Optional[int] = 1 + ): + """Initializes a `DelayedRewardWrapper`. Args: environment: An environment conforming to the dm_env.Environment @@ -49,41 +50,44 @@ def __init__(self, episode. If `accumulation_period`=1, this wrapper is a no-op. 
""" - super().__init__(environment) - if accumulation_period is not None and accumulation_period < 1: - raise ValueError( - f'Accumuluation period is {accumulation_period} but should be greater than 1.' - ) - self._accumuation_period = accumulation_period - self._delayed_reward = self._zero_reward - self._accumulation_counter = 0 + super().__init__(environment) + if accumulation_period is not None and accumulation_period < 1: + raise ValueError( + f"Accumuluation period is {accumulation_period} but should be greater than 1." + ) + self._accumuation_period = accumulation_period + self._delayed_reward = self._zero_reward + self._accumulation_counter = 0 - @property - def _zero_reward(self): - return tree.map_structure(lambda s: np.zeros(s.shape, s.dtype), - self._environment.reward_spec()) + @property + def _zero_reward(self): + return tree.map_structure( + lambda s: np.zeros(s.shape, s.dtype), self._environment.reward_spec() + ) - def reset(self) -> dm_env.TimeStep: - """Resets environment and provides the first timestep.""" - timestep = self.environment.reset() - self._delayed_reward = self._zero_reward - self._accumulation_counter = 0 - return timestep + def reset(self) -> dm_env.TimeStep: + """Resets environment and provides the first timestep.""" + timestep = self.environment.reset() + self._delayed_reward = self._zero_reward + self._accumulation_counter = 0 + return timestep - def step(self, action: types.NestedArray) -> dm_env.TimeStep: - """Performs one step and maybe returns a reward.""" - timestep = self.environment.step(action) - self._delayed_reward = tree.map_structure(operator.iadd, - self._delayed_reward, - timestep.reward) - self._accumulation_counter += 1 + def step(self, action: types.NestedArray) -> dm_env.TimeStep: + """Performs one step and maybe returns a reward.""" + timestep = self.environment.step(action) + self._delayed_reward = tree.map_structure( + operator.iadd, self._delayed_reward, timestep.reward + ) + self._accumulation_counter += 1 - if (self._accumuation_period is not None and self._accumulation_counter - == self._accumuation_period) or timestep.last(): - timestep = timestep._replace(reward=self._delayed_reward) - self._accumulation_counter = 0 - self._delayed_reward = self._zero_reward - else: - timestep = timestep._replace(reward=self._zero_reward) + if ( + self._accumuation_period is not None + and self._accumulation_counter == self._accumuation_period + ) or timestep.last(): + timestep = timestep._replace(reward=self._delayed_reward) + self._accumulation_counter = 0 + self._delayed_reward = self._zero_reward + else: + timestep = timestep._replace(reward=self._zero_reward) - return timestep + return timestep diff --git a/acme/wrappers/delayed_reward_test.py b/acme/wrappers/delayed_reward_test.py index c85032dff3..2d7276823f 100644 --- a/acme/wrappers/delayed_reward_test.py +++ b/acme/wrappers/delayed_reward_test.py @@ -15,76 +15,78 @@ """Tests for the delayed reward wrapper.""" from typing import Any -from acme import wrappers -from acme.testing import fakes -from dm_env import specs + import numpy as np import tree +from absl.testing import absltest, parameterized +from dm_env import specs -from absl.testing import absltest -from absl.testing import parameterized +from acme import wrappers +from acme.testing import fakes def _episode_reward(env): - timestep = env.reset() - action_spec = env.action_spec() - rng = np.random.RandomState(seed=1) - reward = [] - while not timestep.last(): - timestep = env.step(rng.randint(action_spec.num_values)) - 
reward.append(timestep.reward)
-  return reward
+    timestep = env.reset()
+    action_spec = env.action_spec()
+    rng = np.random.RandomState(seed=1)
+    reward = []
+    while not timestep.last():
+        timestep = env.step(rng.randint(action_spec.num_values))
+        reward.append(timestep.reward)
+    return reward
 
 
 def _compare_nested_sequences(seq1, seq2):
-  """Compare two sequences of arrays."""
-  return all([(l == m).all() for l, m in zip(seq1, seq2)])
+    """Compare two sequences of arrays."""
+    return all([(l == m).all() for l, m in zip(seq1, seq2)])
 
 
 class _DiscreteEnvironmentOneReward(fakes.DiscreteEnvironment):
-  """A fake discrete environement with constant reward of 1."""
+    """A fake discrete environment with constant reward of 1."""
 
-  def _generate_fake_reward(self) -> Any:
-    return tree.map_structure(lambda s: s.generate_value() + 1.,
-                              self._spec.rewards)
+    def _generate_fake_reward(self) -> Any:
+        return tree.map_structure(
+            lambda s: s.generate_value() + 1.0, self._spec.rewards
+        )
 
 
 class DelayedRewardTest(parameterized.TestCase):
-
-  def test_noop(self):
-    """Ensure when accumulation_period=1 it does not change anything."""
-    base_env = _DiscreteEnvironmentOneReward(
-        action_dtype=np.int64,
-        reward_spec=specs.Array(dtype=np.float32, shape=()))
-    wrapped_env = wrappers.DelayedRewardWrapper(base_env, accumulation_period=1)
-    base_episode_reward = _episode_reward(base_env)
-    wrapped_episode_reward = _episode_reward(wrapped_env)
-    self.assertEqual(base_episode_reward, wrapped_episode_reward)
-
-  def test_noop_composite_reward(self):
-    """No-op test with composite rewards."""
-    base_env = _DiscreteEnvironmentOneReward(
-        action_dtype=np.int64,
-        reward_spec=specs.Array(dtype=np.float32, shape=(2, 1)))
-    wrapped_env = wrappers.DelayedRewardWrapper(base_env, accumulation_period=1)
-    base_episode_reward = _episode_reward(base_env)
-    wrapped_episode_reward = _episode_reward(wrapped_env)
-    self.assertTrue(
-        _compare_nested_sequences(base_episode_reward, wrapped_episode_reward))
-
-  @parameterized.parameters(10, None)
-  def test_same_episode_composite_reward(self, accumulation_period):
-    """Ensure that wrapper does not change total reward."""
-    base_env = _DiscreteEnvironmentOneReward(
-        action_dtype=np.int64,
-        reward_spec=specs.Array(dtype=np.float32, shape=()))
-    wrapped_env = wrappers.DelayedRewardWrapper(
-        base_env, accumulation_period=accumulation_period)
-    base_episode_reward = _episode_reward(base_env)
-    wrapped_episode_reward = _episode_reward(wrapped_env)
-    self.assertTrue(
-        (sum(base_episode_reward) == sum(wrapped_episode_reward)).all())
-
-
-if __name__ == '__main__':
-  absltest.main()
+    def test_noop(self):
+        """Ensure when accumulation_period=1 it does not change anything."""
+        base_env = _DiscreteEnvironmentOneReward(
+            action_dtype=np.int64, reward_spec=specs.Array(dtype=np.float32, shape=())
+        )
+        wrapped_env = wrappers.DelayedRewardWrapper(base_env, accumulation_period=1)
+        base_episode_reward = _episode_reward(base_env)
+        wrapped_episode_reward = _episode_reward(wrapped_env)
+        self.assertEqual(base_episode_reward, wrapped_episode_reward)
+
+    def test_noop_composite_reward(self):
+        """No-op test with composite rewards."""
+        base_env = _DiscreteEnvironmentOneReward(
+            action_dtype=np.int64,
+            reward_spec=specs.Array(dtype=np.float32, shape=(2, 1)),
+        )
+        wrapped_env = wrappers.DelayedRewardWrapper(base_env, accumulation_period=1)
+        base_episode_reward = _episode_reward(base_env)
+        wrapped_episode_reward = _episode_reward(wrapped_env)
+        self.assertTrue(
+            
_compare_nested_sequences(base_episode_reward, wrapped_episode_reward) + ) + + @parameterized.parameters(10, None) + def test_same_episode_composite_reward(self, accumulation_period): + """Ensure that wrapper does not change total reward.""" + base_env = _DiscreteEnvironmentOneReward( + action_dtype=np.int64, reward_spec=specs.Array(dtype=np.float32, shape=()) + ) + wrapped_env = wrappers.DelayedRewardWrapper( + base_env, accumulation_period=accumulation_period + ) + base_episode_reward = _episode_reward(base_env) + wrapped_episode_reward = _episode_reward(wrapped_env) + self.assertTrue((sum(base_episode_reward) == sum(wrapped_episode_reward)).all()) + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/wrappers/expand_scalar_observation_shapes.py b/acme/wrappers/expand_scalar_observation_shapes.py index 7554d20272..f4c172eef0 100644 --- a/acme/wrappers/expand_scalar_observation_shapes.py +++ b/acme/wrappers/expand_scalar_observation_shapes.py @@ -23,15 +23,16 @@ from typing import Any -from acme.wrappers import base import dm_env -from dm_env import specs import numpy as np import tree +from dm_env import specs + +from acme.wrappers import base class ExpandScalarObservationShapesWrapper(base.EnvironmentWrapper): - """Expands scalar shapes in the observation. + """Expands scalar shapes in the observation. For example, if the observation holds the previous (scalar) action, this wrapper makes sure the environment returns a previous action with shape [1]. @@ -39,30 +40,33 @@ class ExpandScalarObservationShapesWrapper(base.EnvironmentWrapper): This can be necessary when stacking observations with previous actions. """ - def step(self, action: Any) -> dm_env.TimeStep: - timestep = self._environment.step(action) - expanded_observation = tree.map_structure(_expand_scalar_array_shape, - timestep.observation) - return timestep._replace(observation=expanded_observation) + def step(self, action: Any) -> dm_env.TimeStep: + timestep = self._environment.step(action) + expanded_observation = tree.map_structure( + _expand_scalar_array_shape, timestep.observation + ) + return timestep._replace(observation=expanded_observation) - def reset(self) -> dm_env.TimeStep: - timestep = self._environment.reset() - expanded_observation = tree.map_structure(_expand_scalar_array_shape, - timestep.observation) - return timestep._replace(observation=expanded_observation) + def reset(self) -> dm_env.TimeStep: + timestep = self._environment.reset() + expanded_observation = tree.map_structure( + _expand_scalar_array_shape, timestep.observation + ) + return timestep._replace(observation=expanded_observation) - def observation_spec(self) -> specs.Array: - return tree.map_structure(_expand_scalar_spec_shape, - self._environment.observation_spec()) + def observation_spec(self) -> specs.Array: + return tree.map_structure( + _expand_scalar_spec_shape, self._environment.observation_spec() + ) def _expand_scalar_spec_shape(spec: specs.Array) -> specs.Array: - if not spec.shape: - # NOTE: This line upcasts the spec to an Array to avoid edge cases (as in - # DiscreteSpec) where we cannot set the spec's shape. - spec = specs.Array(shape=(1,), dtype=spec.dtype, name=spec.name) - return spec + if not spec.shape: + # NOTE: This line upcasts the spec to an Array to avoid edge cases (as in + # DiscreteSpec) where we cannot set the spec's shape. 
+ spec = specs.Array(shape=(1,), dtype=spec.dtype, name=spec.name) + return spec def _expand_scalar_array_shape(array: np.ndarray) -> np.ndarray: - return array if array.shape else np.expand_dims(array, axis=-1) + return array if array.shape else np.expand_dims(array, axis=-1) diff --git a/acme/wrappers/frame_stacking.py b/acme/wrappers/frame_stacking.py index ce06dd71d7..cfccf9ae75 100644 --- a/acme/wrappers/frame_stacking.py +++ b/acme/wrappers/frame_stacking.py @@ -16,84 +16,93 @@ import collections -from acme import types -from acme.wrappers import base import dm_env -from dm_env import specs as dm_env_specs import numpy as np import tree +from dm_env import specs as dm_env_specs + +from acme import types +from acme.wrappers import base class FrameStackingWrapper(base.EnvironmentWrapper): - """Wrapper that stacks observations along a new final axis.""" + """Wrapper that stacks observations along a new final axis.""" - def __init__(self, environment: dm_env.Environment, num_frames: int = 4, - flatten: bool = False): - """Initializes a new FrameStackingWrapper. + def __init__( + self, + environment: dm_env.Environment, + num_frames: int = 4, + flatten: bool = False, + ): + """Initializes a new FrameStackingWrapper. Args: environment: Environment. num_frames: Number frames to stack. flatten: Whether to flatten the channel and stack dimensions together. """ - self._environment = environment - original_spec = self._environment.observation_spec() - self._stackers = tree.map_structure( - lambda _: FrameStacker(num_frames=num_frames, flatten=flatten), - self._environment.observation_spec()) - self._observation_spec = tree.map_structure( - lambda stacker, spec: stacker.update_spec(spec), - self._stackers, original_spec) - - def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep: - observation = tree.map_structure(lambda stacker, x: stacker.step(x), - self._stackers, timestep.observation) - return timestep._replace(observation=observation) - - def reset(self) -> dm_env.TimeStep: - for stacker in tree.flatten(self._stackers): - stacker.reset() - return self._process_timestep(self._environment.reset()) - - def step(self, action: int) -> dm_env.TimeStep: - return self._process_timestep(self._environment.step(action)) - - def observation_spec(self) -> types.NestedSpec: - return self._observation_spec + self._environment = environment + original_spec = self._environment.observation_spec() + self._stackers = tree.map_structure( + lambda _: FrameStacker(num_frames=num_frames, flatten=flatten), + self._environment.observation_spec(), + ) + self._observation_spec = tree.map_structure( + lambda stacker, spec: stacker.update_spec(spec), + self._stackers, + original_spec, + ) + + def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep: + observation = tree.map_structure( + lambda stacker, x: stacker.step(x), self._stackers, timestep.observation + ) + return timestep._replace(observation=observation) + + def reset(self) -> dm_env.TimeStep: + for stacker in tree.flatten(self._stackers): + stacker.reset() + return self._process_timestep(self._environment.reset()) + + def step(self, action: int) -> dm_env.TimeStep: + return self._process_timestep(self._environment.step(action)) + + def observation_spec(self) -> types.NestedSpec: + return self._observation_spec class FrameStacker: - """Simple class for frame-stacking observations.""" - - def __init__(self, num_frames: int, flatten: bool = False): - self._num_frames = num_frames - self._flatten = flatten - self.reset() - - 
@property - def num_frames(self) -> int: - return self._num_frames - - def reset(self): - self._stack = collections.deque(maxlen=self._num_frames) - - def step(self, frame: np.ndarray) -> np.ndarray: - """Append frame to stack and return the stack.""" - if not self._stack: - # Fill stack with blank frames if empty. - self._stack.extend([np.zeros_like(frame)] * (self._num_frames - 1)) - self._stack.append(frame) - stacked_frames = np.stack(self._stack, axis=-1) - - if not self._flatten: - return stacked_frames - else: - new_shape = stacked_frames.shape[:-2] + (-1,) - return stacked_frames.reshape(*new_shape) - - def update_spec(self, spec: dm_env_specs.Array) -> dm_env_specs.Array: - if not self._flatten: - new_shape = spec.shape + (self._num_frames,) - else: - new_shape = spec.shape[:-1] + (self._num_frames * spec.shape[-1],) - return dm_env_specs.Array(shape=new_shape, dtype=spec.dtype, name=spec.name) + """Simple class for frame-stacking observations.""" + + def __init__(self, num_frames: int, flatten: bool = False): + self._num_frames = num_frames + self._flatten = flatten + self.reset() + + @property + def num_frames(self) -> int: + return self._num_frames + + def reset(self): + self._stack = collections.deque(maxlen=self._num_frames) + + def step(self, frame: np.ndarray) -> np.ndarray: + """Append frame to stack and return the stack.""" + if not self._stack: + # Fill stack with blank frames if empty. + self._stack.extend([np.zeros_like(frame)] * (self._num_frames - 1)) + self._stack.append(frame) + stacked_frames = np.stack(self._stack, axis=-1) + + if not self._flatten: + return stacked_frames + else: + new_shape = stacked_frames.shape[:-2] + (-1,) + return stacked_frames.reshape(*new_shape) + + def update_spec(self, spec: dm_env_specs.Array) -> dm_env_specs.Array: + if not self._flatten: + new_shape = spec.shape + (self._num_frames,) + else: + new_shape = spec.shape[:-1] + (self._num_frames * spec.shape[-1],) + return dm_env_specs.Array(shape=new_shape, dtype=spec.dtype, name=spec.name) diff --git a/acme/wrappers/frame_stacking_test.py b/acme/wrappers/frame_stacking_test.py index ff21f47e2e..3c150f2d16 100644 --- a/acme/wrappers/frame_stacking_test.py +++ b/acme/wrappers/frame_stacking_test.py @@ -14,68 +14,67 @@ """Tests for the single precision wrapper.""" -from acme import wrappers -from acme.testing import fakes import numpy as np import tree - from absl.testing import absltest +from acme import wrappers +from acme.testing import fakes + class FakeNonZeroObservationEnvironment(fakes.ContinuousEnvironment): - """Fake environment with non-zero observations.""" + """Fake environment with non-zero observations.""" - def _generate_fake_observation(self): - original_observation = super()._generate_fake_observation() - return tree.map_structure(np.ones_like, original_observation) + def _generate_fake_observation(self): + original_observation = super()._generate_fake_observation() + return tree.map_structure(np.ones_like, original_observation) class FrameStackingTest(absltest.TestCase): + def test_specs(self): + original_env = FakeNonZeroObservationEnvironment() + env = wrappers.FrameStackingWrapper(original_env, 2) - def test_specs(self): - original_env = FakeNonZeroObservationEnvironment() - env = wrappers.FrameStackingWrapper(original_env, 2) - - original_observation_spec = original_env.observation_spec() - expected_shape = original_observation_spec.shape + (2,) - observation_spec = env.observation_spec() - self.assertEqual(expected_shape, observation_spec.shape) + 
original_observation_spec = original_env.observation_spec() + expected_shape = original_observation_spec.shape + (2,) + observation_spec = env.observation_spec() + self.assertEqual(expected_shape, observation_spec.shape) - expected_action_spec = original_env.action_spec() - action_spec = env.action_spec() - self.assertEqual(expected_action_spec, action_spec) + expected_action_spec = original_env.action_spec() + action_spec = env.action_spec() + self.assertEqual(expected_action_spec, action_spec) - expected_reward_spec = original_env.reward_spec() - reward_spec = env.reward_spec() - self.assertEqual(expected_reward_spec, reward_spec) + expected_reward_spec = original_env.reward_spec() + reward_spec = env.reward_spec() + self.assertEqual(expected_reward_spec, reward_spec) - expected_discount_spec = original_env.discount_spec() - discount_spec = env.discount_spec() - self.assertEqual(expected_discount_spec, discount_spec) + expected_discount_spec = original_env.discount_spec() + discount_spec = env.discount_spec() + self.assertEqual(expected_discount_spec, discount_spec) - def test_step(self): - original_env = FakeNonZeroObservationEnvironment() - env = wrappers.FrameStackingWrapper(original_env, 2) - observation_spec = env.observation_spec() - action_spec = env.action_spec() + def test_step(self): + original_env = FakeNonZeroObservationEnvironment() + env = wrappers.FrameStackingWrapper(original_env, 2) + observation_spec = env.observation_spec() + action_spec = env.action_spec() - timestep = env.reset() - self.assertEqual(observation_spec.shape, timestep.observation.shape) - self.assertTrue(np.all(timestep.observation[..., 0] == 0)) + timestep = env.reset() + self.assertEqual(observation_spec.shape, timestep.observation.shape) + self.assertTrue(np.all(timestep.observation[..., 0] == 0)) - timestep = env.step(action_spec.generate_value()) - self.assertEqual(observation_spec.shape, timestep.observation.shape) + timestep = env.step(action_spec.generate_value()) + self.assertEqual(observation_spec.shape, timestep.observation.shape) - def test_second_reset(self): - original_env = FakeNonZeroObservationEnvironment() - env = wrappers.FrameStackingWrapper(original_env, 2) - action_spec = env.action_spec() + def test_second_reset(self): + original_env = FakeNonZeroObservationEnvironment() + env = wrappers.FrameStackingWrapper(original_env, 2) + action_spec = env.action_spec() - env.reset() - env.step(action_spec.generate_value()) - timestep = env.reset() - self.assertTrue(np.all(timestep.observation[..., 0] == 0)) + env.reset() + env.step(action_spec.generate_value()) + timestep = env.reset() + self.assertTrue(np.all(timestep.observation[..., 0] == 0)) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/wrappers/gym_wrapper.py b/acme/wrappers/gym_wrapper.py index 8170f8a118..bdd33b2a64 100644 --- a/acme/wrappers/gym_wrapper.py +++ b/acme/wrappers/gym_wrapper.py @@ -16,99 +16,99 @@ from typing import Any, Dict, List, Optional -from acme import specs -from acme import types - import dm_env import gym -from gym import spaces import numpy as np import tree +from gym import spaces + +from acme import specs, types class GymWrapper(dm_env.Environment): - """Environment wrapper for OpenAI Gym environments.""" - - # Note: we don't inherit from base.EnvironmentWrapper because that class - # assumes that the wrapped environment is a dm_env.Environment. 
- - def __init__(self, environment: gym.Env): - - self._environment = environment - self._reset_next_step = True - self._last_info = None - - # Convert action and observation specs. - obs_space = self._environment.observation_space - act_space = self._environment.action_space - self._observation_spec = _convert_to_spec(obs_space, name='observation') - self._action_spec = _convert_to_spec(act_space, name='action') - - def reset(self) -> dm_env.TimeStep: - """Resets the episode.""" - self._reset_next_step = False - observation = self._environment.reset() - # Reset the diagnostic information. - self._last_info = None - return dm_env.restart(observation) - - def step(self, action: types.NestedArray) -> dm_env.TimeStep: - """Steps the environment.""" - if self._reset_next_step: - return self.reset() - - observation, reward, done, info = self._environment.step(action) - self._reset_next_step = done - self._last_info = info - - # Convert the type of the reward based on the spec, respecting the scalar or - # array property. - reward = tree.map_structure( - lambda x, t: ( # pylint: disable=g-long-lambda - t.dtype.type(x) - if np.isscalar(x) else np.asarray(x, dtype=t.dtype)), - reward, - self.reward_spec()) - - if done: - truncated = info.get('TimeLimit.truncated', False) - if truncated: - return dm_env.truncation(reward, observation) - return dm_env.termination(reward, observation) - return dm_env.transition(reward, observation) - - def observation_spec(self) -> types.NestedSpec: - return self._observation_spec - - def action_spec(self) -> types.NestedSpec: - return self._action_spec - - def get_info(self) -> Optional[Dict[str, Any]]: - """Returns the last info returned from env.step(action). + """Environment wrapper for OpenAI Gym environments.""" + + # Note: we don't inherit from base.EnvironmentWrapper because that class + # assumes that the wrapped environment is a dm_env.Environment. + + def __init__(self, environment: gym.Env): + + self._environment = environment + self._reset_next_step = True + self._last_info = None + + # Convert action and observation specs. + obs_space = self._environment.observation_space + act_space = self._environment.action_space + self._observation_spec = _convert_to_spec(obs_space, name="observation") + self._action_spec = _convert_to_spec(act_space, name="action") + + def reset(self) -> dm_env.TimeStep: + """Resets the episode.""" + self._reset_next_step = False + observation = self._environment.reset() + # Reset the diagnostic information. + self._last_info = None + return dm_env.restart(observation) + + def step(self, action: types.NestedArray) -> dm_env.TimeStep: + """Steps the environment.""" + if self._reset_next_step: + return self.reset() + + observation, reward, done, info = self._environment.step(action) + self._reset_next_step = done + self._last_info = info + + # Convert the type of the reward based on the spec, respecting the scalar or + # array property. 
+ reward = tree.map_structure( + lambda x, t: ( # pylint: disable=g-long-lambda + t.dtype.type(x) if np.isscalar(x) else np.asarray(x, dtype=t.dtype) + ), + reward, + self.reward_spec(), + ) + + if done: + truncated = info.get("TimeLimit.truncated", False) + if truncated: + return dm_env.truncation(reward, observation) + return dm_env.termination(reward, observation) + return dm_env.transition(reward, observation) + + def observation_spec(self) -> types.NestedSpec: + return self._observation_spec + + def action_spec(self) -> types.NestedSpec: + return self._action_spec + + def get_info(self) -> Optional[Dict[str, Any]]: + """Returns the last info returned from env.step(action). Returns: info: dictionary of diagnostic information from the last environment step """ - return self._last_info + return self._last_info - @property - def environment(self) -> gym.Env: - """Returns the wrapped environment.""" - return self._environment + @property + def environment(self) -> gym.Env: + """Returns the wrapped environment.""" + return self._environment - def __getattr__(self, name: str): - if name.startswith('__'): - raise AttributeError( - "attempted to get missing private attribute '{}'".format(name)) - return getattr(self._environment, name) + def __getattr__(self, name: str): + if name.startswith("__"): + raise AttributeError( + "attempted to get missing private attribute '{}'".format(name) + ) + return getattr(self._environment, name) - def close(self): - self._environment.close() + def close(self): + self._environment.close() -def _convert_to_spec(space: gym.Space, - name: Optional[str] = None) -> types.NestedSpec: - """Converts an OpenAI Gym space to a dm_env spec or nested structure of specs. +def _convert_to_spec(space: gym.Space, name: Optional[str] = None) -> types.NestedSpec: + """Converts an OpenAI Gym space to a dm_env spec or nested structure of specs. Box, MultiBinary and MultiDiscrete Gym spaces are converted to BoundedArray specs. Discrete OpenAI spaces are converted to DiscreteArray specs. Tuple and @@ -122,48 +122,46 @@ def _convert_to_spec(space: gym.Space, A dm_env spec or nested structure of specs, corresponding to the input space. 
""" - if isinstance(space, spaces.Discrete): - return specs.DiscreteArray(num_values=space.n, dtype=space.dtype, name=name) - - elif isinstance(space, spaces.Box): - return specs.BoundedArray( - shape=space.shape, - dtype=space.dtype, - minimum=space.low, - maximum=space.high, - name=name) - - elif isinstance(space, spaces.MultiBinary): - return specs.BoundedArray( - shape=space.shape, - dtype=space.dtype, - minimum=0.0, - maximum=1.0, - name=name) - - elif isinstance(space, spaces.MultiDiscrete): - return specs.BoundedArray( - shape=space.shape, - dtype=space.dtype, - minimum=np.zeros(space.shape), - maximum=space.nvec - 1, - name=name) - - elif isinstance(space, spaces.Tuple): - return tuple(_convert_to_spec(s, name) for s in space.spaces) - - elif isinstance(space, spaces.Dict): - return { - key: _convert_to_spec(value, key) - for key, value in space.spaces.items() - } - - else: - raise ValueError('Unexpected gym space: {}'.format(space)) + if isinstance(space, spaces.Discrete): + return specs.DiscreteArray(num_values=space.n, dtype=space.dtype, name=name) + + elif isinstance(space, spaces.Box): + return specs.BoundedArray( + shape=space.shape, + dtype=space.dtype, + minimum=space.low, + maximum=space.high, + name=name, + ) + + elif isinstance(space, spaces.MultiBinary): + return specs.BoundedArray( + shape=space.shape, dtype=space.dtype, minimum=0.0, maximum=1.0, name=name + ) + + elif isinstance(space, spaces.MultiDiscrete): + return specs.BoundedArray( + shape=space.shape, + dtype=space.dtype, + minimum=np.zeros(space.shape), + maximum=space.nvec - 1, + name=name, + ) + + elif isinstance(space, spaces.Tuple): + return tuple(_convert_to_spec(s, name) for s in space.spaces) + + elif isinstance(space, spaces.Dict): + return { + key: _convert_to_spec(value, key) for key, value in space.spaces.items() + } + + else: + raise ValueError("Unexpected gym space: {}".format(space)) class GymAtariAdapter(GymWrapper): - """Specialized wrapper exposing a Gym Atari environment. + """Specialized wrapper exposing a Gym Atari environment. This wraps the Gym Atari environment in the same way as GymWrapper, but also exposes the lives count as an observation. The resuling observations are @@ -171,36 +169,37 @@ class GymAtariAdapter(GymWrapper): lives count. 
""" - def _wrap_observation(self, - observation: types.NestedArray) -> types.NestedArray: - # pytype: disable=attribute-error - return observation, self._environment.ale.lives() - # pytype: enable=attribute-error + def _wrap_observation(self, observation: types.NestedArray) -> types.NestedArray: + # pytype: disable=attribute-error + return observation, self._environment.ale.lives() + # pytype: enable=attribute-error - def reset(self) -> dm_env.TimeStep: - """Resets the episode.""" - self._reset_next_step = False - observation = self._environment.reset() - observation = self._wrap_observation(observation) - return dm_env.restart(observation) + def reset(self) -> dm_env.TimeStep: + """Resets the episode.""" + self._reset_next_step = False + observation = self._environment.reset() + observation = self._wrap_observation(observation) + return dm_env.restart(observation) - def step(self, action: List[np.ndarray]) -> dm_env.TimeStep: - """Steps the environment.""" - if self._reset_next_step: - return self.reset() + def step(self, action: List[np.ndarray]) -> dm_env.TimeStep: + """Steps the environment.""" + if self._reset_next_step: + return self.reset() - observation, reward, done, _ = self._environment.step(action[0].item()) - self._reset_next_step = done + observation, reward, done, _ = self._environment.step(action[0].item()) + self._reset_next_step = done - observation = self._wrap_observation(observation) + observation = self._wrap_observation(observation) - if done: - return dm_env.termination(reward, observation) - return dm_env.transition(reward, observation) + if done: + return dm_env.termination(reward, observation) + return dm_env.transition(reward, observation) - def observation_spec(self) -> types.NestedSpec: - return (self._observation_spec, - specs.Array(shape=(), dtype=np.dtype('float64'), name='lives')) + def observation_spec(self) -> types.NestedSpec: + return ( + self._observation_spec, + specs.Array(shape=(), dtype=np.dtype("float64"), name="lives"), + ) - def action_spec(self) -> List[specs.BoundedArray]: - return [self._action_spec] # pytype: disable=bad-return-type + def action_spec(self) -> List[specs.BoundedArray]: + return [self._action_spec] # pytype: disable=bad-return-type diff --git a/acme/wrappers/gym_wrapper_test.py b/acme/wrappers/gym_wrapper_test.py index bc6fdd80f8..57aa2e1811 100644 --- a/acme/wrappers/gym_wrapper_test.py +++ b/acme/wrappers/gym_wrapper_test.py @@ -16,125 +16,125 @@ import unittest -from dm_env import specs import numpy as np - from absl.testing import absltest +from dm_env import specs SKIP_GYM_TESTS = False -SKIP_GYM_MESSAGE = 'gym not installed.' +SKIP_GYM_MESSAGE = "gym not installed." 
SKIP_ATARI_TESTS = False -SKIP_ATARI_MESSAGE = '' +SKIP_ATARI_MESSAGE = "" try: - # pylint: disable=g-import-not-at-top - from acme.wrappers import gym_wrapper - import gym - # pylint: enable=g-import-not-at-top + # pylint: disable=g-import-not-at-top + import gym + + from acme.wrappers import gym_wrapper + + # pylint: enable=g-import-not-at-top except ModuleNotFoundError: - SKIP_GYM_TESTS = True + SKIP_GYM_TESTS = True try: - import atari_py # pylint: disable=g-import-not-at-top - atari_py.get_game_path('pong') + import atari_py # pylint: disable=g-import-not-at-top + + atari_py.get_game_path("pong") except ModuleNotFoundError as e: - SKIP_ATARI_TESTS = True - SKIP_ATARI_MESSAGE = str(e) + SKIP_ATARI_TESTS = True + SKIP_ATARI_MESSAGE = str(e) except Exception as e: # pylint: disable=broad-except - # This exception is raised by atari_py.get_game_path('pong') if the Atari ROM - # file has not been installed. - SKIP_ATARI_TESTS = True - SKIP_ATARI_MESSAGE = str(e) - del atari_py + # This exception is raised by atari_py.get_game_path('pong') if the Atari ROM + # file has not been installed. + SKIP_ATARI_TESTS = True + SKIP_ATARI_MESSAGE = str(e) + del atari_py else: - del atari_py + del atari_py @unittest.skipIf(SKIP_GYM_TESTS, SKIP_GYM_MESSAGE) class GymWrapperTest(absltest.TestCase): - - def test_gym_cartpole(self): - env = gym_wrapper.GymWrapper(gym.make('CartPole-v0')) - - # Test converted observation spec. - observation_spec: specs.BoundedArray = env.observation_spec() - self.assertEqual(type(observation_spec), specs.BoundedArray) - self.assertEqual(observation_spec.shape, (4,)) - self.assertEqual(observation_spec.minimum.shape, (4,)) - self.assertEqual(observation_spec.maximum.shape, (4,)) - self.assertEqual(observation_spec.dtype, np.dtype('float32')) - - # Test converted action spec. - action_spec: specs.BoundedArray = env.action_spec() - self.assertEqual(type(action_spec), specs.DiscreteArray) - self.assertEqual(action_spec.shape, ()) - self.assertEqual(action_spec.minimum, 0) - self.assertEqual(action_spec.maximum, 1) - self.assertEqual(action_spec.num_values, 2) - self.assertEqual(action_spec.dtype, np.dtype('int64')) - - # Test step. - timestep = env.reset() - self.assertTrue(timestep.first()) - timestep = env.step(1) - self.assertEqual(timestep.reward, 1.0) - self.assertTrue(np.isscalar(timestep.reward)) - self.assertEqual(timestep.observation.shape, (4,)) - env.close() - - def test_early_truncation(self): - # Pendulum has no early termination condition. Recent versions of gym force - # to use v1. We try both in case an earlier version is installed. - try: - gym_env = gym.make('Pendulum-v1') - except: # pylint: disable=bare-except - gym_env = gym.make('Pendulum-v0') - env = gym_wrapper.GymWrapper(gym_env) - ts = env.reset() - while not ts.last(): - ts = env.step(env.action_spec().generate_value()) - self.assertEqual(ts.discount, 1.0) - self.assertTrue(np.isscalar(ts.reward)) - env.close() - - def test_multi_discrete(self): - space = gym.spaces.MultiDiscrete([2, 3]) - spec = gym_wrapper._convert_to_spec(space) - - spec.validate([0, 0]) - spec.validate([1, 2]) - - self.assertRaises(ValueError, spec.validate, [2, 2]) - self.assertRaises(ValueError, spec.validate, [1, 3]) + def test_gym_cartpole(self): + env = gym_wrapper.GymWrapper(gym.make("CartPole-v0")) + + # Test converted observation spec. 
+ observation_spec: specs.BoundedArray = env.observation_spec() + self.assertEqual(type(observation_spec), specs.BoundedArray) + self.assertEqual(observation_spec.shape, (4,)) + self.assertEqual(observation_spec.minimum.shape, (4,)) + self.assertEqual(observation_spec.maximum.shape, (4,)) + self.assertEqual(observation_spec.dtype, np.dtype("float32")) + + # Test converted action spec. + action_spec: specs.BoundedArray = env.action_spec() + self.assertEqual(type(action_spec), specs.DiscreteArray) + self.assertEqual(action_spec.shape, ()) + self.assertEqual(action_spec.minimum, 0) + self.assertEqual(action_spec.maximum, 1) + self.assertEqual(action_spec.num_values, 2) + self.assertEqual(action_spec.dtype, np.dtype("int64")) + + # Test step. + timestep = env.reset() + self.assertTrue(timestep.first()) + timestep = env.step(1) + self.assertEqual(timestep.reward, 1.0) + self.assertTrue(np.isscalar(timestep.reward)) + self.assertEqual(timestep.observation.shape, (4,)) + env.close() + + def test_early_truncation(self): + # Pendulum has no early termination condition. Recent versions of gym force + # to use v1. We try both in case an earlier version is installed. + try: + gym_env = gym.make("Pendulum-v1") + except: # pylint: disable=bare-except + gym_env = gym.make("Pendulum-v0") + env = gym_wrapper.GymWrapper(gym_env) + ts = env.reset() + while not ts.last(): + ts = env.step(env.action_spec().generate_value()) + self.assertEqual(ts.discount, 1.0) + self.assertTrue(np.isscalar(ts.reward)) + env.close() + + def test_multi_discrete(self): + space = gym.spaces.MultiDiscrete([2, 3]) + spec = gym_wrapper._convert_to_spec(space) + + spec.validate([0, 0]) + spec.validate([1, 2]) + + self.assertRaises(ValueError, spec.validate, [2, 2]) + self.assertRaises(ValueError, spec.validate, [1, 3]) @unittest.skipIf(SKIP_ATARI_TESTS, SKIP_ATARI_MESSAGE) class AtariGymWrapperTest(absltest.TestCase): - - def test_pong(self): - env = gym.make('PongNoFrameskip-v4', full_action_space=True) - env = gym_wrapper.GymAtariAdapter(env) - - # Test converted observation spec. This should expose (RGB, LIVES). - observation_spec = env.observation_spec() - self.assertEqual(type(observation_spec[0]), specs.BoundedArray) - self.assertEqual(type(observation_spec[1]), specs.Array) - - # Test converted action spec. - action_spec: specs.DiscreteArray = env.action_spec()[0] - self.assertEqual(type(action_spec), specs.DiscreteArray) - self.assertEqual(action_spec.shape, ()) - self.assertEqual(action_spec.minimum, 0) - self.assertEqual(action_spec.maximum, 17) - self.assertEqual(action_spec.num_values, 18) - self.assertEqual(action_spec.dtype, np.dtype('int64')) - - # Test step. - timestep = env.reset() - self.assertTrue(timestep.first()) - _ = env.step([np.array(0)]) - env.close() - - -if __name__ == '__main__': - absltest.main() + def test_pong(self): + env = gym.make("PongNoFrameskip-v4", full_action_space=True) + env = gym_wrapper.GymAtariAdapter(env) + + # Test converted observation spec. This should expose (RGB, LIVES). + observation_spec = env.observation_spec() + self.assertEqual(type(observation_spec[0]), specs.BoundedArray) + self.assertEqual(type(observation_spec[1]), specs.Array) + + # Test converted action spec. 
+ action_spec: specs.DiscreteArray = env.action_spec()[0] + self.assertEqual(type(action_spec), specs.DiscreteArray) + self.assertEqual(action_spec.shape, ()) + self.assertEqual(action_spec.minimum, 0) + self.assertEqual(action_spec.maximum, 17) + self.assertEqual(action_spec.num_values, 18) + self.assertEqual(action_spec.dtype, np.dtype("int64")) + + # Test step. + timestep = env.reset() + self.assertTrue(timestep.first()) + _ = env.step([np.array(0)]) + env.close() + + +if __name__ == "__main__": + absltest.main() diff --git a/acme/wrappers/mujoco.py b/acme/wrappers/mujoco.py index 60d1afc156..12416ea94d 100644 --- a/acme/wrappers/mujoco.py +++ b/acme/wrappers/mujoco.py @@ -15,36 +15,41 @@ """An environment wrapper to produce pixel observations from dm_control.""" import collections -from acme.wrappers import base + +import dm_env from dm_control.rl import control from dm_control.suite.wrappers import pixels # type: ignore -import dm_env + +from acme.wrappers import base class MujocoPixelWrapper(base.EnvironmentWrapper): - """Produces pixel observations from Mujoco environment observations.""" - - def __init__(self, - environment: control.Environment, - *, - height: int = 84, - width: int = 84, - camera_id: int = 0): - render_kwargs = {'height': height, 'width': width, 'camera_id': camera_id} - pixel_environment = pixels.Wrapper( - environment, pixels_only=True, render_kwargs=render_kwargs) - super().__init__(pixel_environment) - - def step(self, action) -> dm_env.TimeStep: - return self._convert_timestep(self._environment.step(action)) - - def reset(self) -> dm_env.TimeStep: - return self._convert_timestep(self._environment.reset()) - - def observation_spec(self): - return self._environment.observation_spec()['pixels'] - - def _convert_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep: - """Removes the pixel observation's OrderedDict wrapper.""" - observation: collections.OrderedDict = timestep.observation - return timestep._replace(observation=observation['pixels']) + """Produces pixel observations from Mujoco environment observations.""" + + def __init__( + self, + environment: control.Environment, + *, + height: int = 84, + width: int = 84, + camera_id: int = 0 + ): + render_kwargs = {"height": height, "width": width, "camera_id": camera_id} + pixel_environment = pixels.Wrapper( + environment, pixels_only=True, render_kwargs=render_kwargs + ) + super().__init__(pixel_environment) + + def step(self, action) -> dm_env.TimeStep: + return self._convert_timestep(self._environment.step(action)) + + def reset(self) -> dm_env.TimeStep: + return self._convert_timestep(self._environment.reset()) + + def observation_spec(self): + return self._environment.observation_spec()["pixels"] + + def _convert_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep: + """Removes the pixel observation's OrderedDict wrapper.""" + observation: collections.OrderedDict = timestep.observation + return timestep._replace(observation=observation["pixels"]) diff --git a/acme/wrappers/multiagent_dict_key_wrapper.py b/acme/wrappers/multiagent_dict_key_wrapper.py index 1c15c12199..1a8eb259d2 100644 --- a/acme/wrappers/multiagent_dict_key_wrapper.py +++ b/acme/wrappers/multiagent_dict_key_wrapper.py @@ -15,16 +15,17 @@ """Multiagent dict-indexed environment wrapped.""" from typing import Any, Dict, List, TypeVar, Union -from acme import types -from acme.wrappers import base import dm_env -V = TypeVar('V') +from acme import types +from acme.wrappers import base + +V = TypeVar("V") class 
MultiagentDictKeyWrapper(base.EnvironmentWrapper): - """Wrapper that converts list-indexed multiagent environments to dict-indexed. + """Wrapper that converts list-indexed multiagent environments to dict-indexed. Specifically, if the underlying environment observation and actions are: observation = [observation_agent_0, observation_agent_1, ...] @@ -39,49 +40,53 @@ class MultiagentDictKeyWrapper(base.EnvironmentWrapper): can directly be supported if dicts, but not natively supported as lists). """ - def __init__(self, environment: dm_env.Environment): - self._environment = environment - # Convert action and observation specs. - self._action_spec = self._list_to_dict(self._environment.action_spec()) - self._discount_spec = self._list_to_dict(self._environment.discount_spec()) - self._observation_spec = self._list_to_dict( - self._environment.observation_spec()) - self._reward_spec = self._list_to_dict(self._environment.reward_spec()) - - def _list_to_dict(self, data: Union[List[V], V]) -> Union[Dict[str, V], V]: - """Convert list-indexed data to dict-indexed, otherwise passthrough.""" - if isinstance(data, list): - return {str(k): v for k, v in enumerate(data)} - return data - - def _dict_to_list(self, data: Union[Dict[str, V], V]) -> Union[List[V], V]: - """Convert dict-indexed data to list-indexed, otherwise passthrough.""" - if isinstance(data, dict): - return [data[str(i_agent)] - for i_agent in range(self._environment.num_agents)] # pytype: disable=attribute-error - return data - - def _convert_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep: - return timestep._replace( - reward=self._list_to_dict(timestep.reward), - discount=self._list_to_dict(timestep.discount), - observation=self._list_to_dict(timestep.observation)) - - def step(self, action: Dict[int, Any]) -> dm_env.TimeStep: - return self._convert_timestep( - self._environment.step(self._dict_to_list(action))) - - def reset(self) -> dm_env.TimeStep: - return self._convert_timestep(self._environment.reset()) - - def action_spec(self) -> types.NestedSpec: # Internal pytype check. - return self._action_spec - - def discount_spec(self) -> types.NestedSpec: # Internal pytype check. - return self._discount_spec - - def observation_spec(self) -> types.NestedSpec: # Internal pytype check. - return self._observation_spec - - def reward_spec(self) -> types.NestedSpec: # Internal pytype check. - return self._reward_spec + def __init__(self, environment: dm_env.Environment): + self._environment = environment + # Convert action and observation specs. 
+ self._action_spec = self._list_to_dict(self._environment.action_spec()) + self._discount_spec = self._list_to_dict(self._environment.discount_spec()) + self._observation_spec = self._list_to_dict( + self._environment.observation_spec() + ) + self._reward_spec = self._list_to_dict(self._environment.reward_spec()) + + def _list_to_dict(self, data: Union[List[V], V]) -> Union[Dict[str, V], V]: + """Convert list-indexed data to dict-indexed, otherwise passthrough.""" + if isinstance(data, list): + return {str(k): v for k, v in enumerate(data)} + return data + + def _dict_to_list(self, data: Union[Dict[str, V], V]) -> Union[List[V], V]: + """Convert dict-indexed data to list-indexed, otherwise passthrough.""" + if isinstance(data, dict): + return [ + data[str(i_agent)] for i_agent in range(self._environment.num_agents) + ] # pytype: disable=attribute-error + return data + + def _convert_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep: + return timestep._replace( + reward=self._list_to_dict(timestep.reward), + discount=self._list_to_dict(timestep.discount), + observation=self._list_to_dict(timestep.observation), + ) + + def step(self, action: Dict[int, Any]) -> dm_env.TimeStep: + return self._convert_timestep( + self._environment.step(self._dict_to_list(action)) + ) + + def reset(self) -> dm_env.TimeStep: + return self._convert_timestep(self._environment.reset()) + + def action_spec(self) -> types.NestedSpec: # Internal pytype check. + return self._action_spec + + def discount_spec(self) -> types.NestedSpec: # Internal pytype check. + return self._discount_spec + + def observation_spec(self) -> types.NestedSpec: # Internal pytype check. + return self._observation_spec + + def reward_spec(self) -> types.NestedSpec: # Internal pytype check. + return self._reward_spec diff --git a/acme/wrappers/multigrid_wrapper.py b/acme/wrappers/multigrid_wrapper.py index 9c4b2287bb..643f03f9fe 100644 --- a/acme/wrappers/multigrid_wrapper.py +++ b/acme/wrappers/multigrid_wrapper.py @@ -14,149 +14,152 @@ """Wraps a Multigrid multiagent environment to be used as a dm_env.""" -from typing import Any, Dict, List, Optional import warnings +from typing import Any, Dict, List, Optional -from acme import specs -from acme import types -from acme import wrappers -from acme.multiagent import types as ma_types -from acme.wrappers import multiagent_dict_key_wrapper import dm_env import gym -from gym import spaces import jax import numpy as np import tree +from gym import spaces + +from acme import specs, types, wrappers +from acme.multiagent import types as ma_types +from acme.wrappers import multiagent_dict_key_wrapper try: - # The following import registers multigrid environments in gym. Do not remove. - # pylint: disable=unused-import, disable=g-import-not-at-top - # pytype: disable=import-error - from social_rl.gym_multigrid import multigrid - # pytype: enable=import-error - # pylint: enable=unused-import, enable=g-import-not-at-top + # The following import registers multigrid environments in gym. Do not remove. + # pylint: disable=unused-import, disable=g-import-not-at-top + # pytype: disable=import-error + from social_rl.gym_multigrid import multigrid + + # pytype: enable=import-error + # pylint: enable=unused-import, enable=g-import-not-at-top except ModuleNotFoundError as err: - raise ModuleNotFoundError( - 'The multiagent multigrid environment module could not be found. 
' - 'Ensure you have downloaded it from ' - 'https://github.com/google-research/google-research/tree/master/social_rl/gym_multigrid' - ' before running this example.') from err + raise ModuleNotFoundError( + "The multiagent multigrid environment module could not be found. " + "Ensure you have downloaded it from " + "https://github.com/google-research/google-research/tree/master/social_rl/gym_multigrid" + " before running this example." + ) from err # Disables verbose np.bool warnings that occur in multigrid. warnings.filterwarnings( - action='ignore', + action="ignore", category=DeprecationWarning, - message='`np.bool` is a deprecated alias') + message="`np.bool` is a deprecated alias", +) class MultigridWrapper(dm_env.Environment): - """Environment wrapper for Multigrid environments. + """Environment wrapper for Multigrid environments. Note: the main difference with vanilla GymWrapper is that reward_spec() is overridden and rewards are cast to np.arrays in step() """ - def __init__(self, environment: multigrid.MultiGridEnv): - """Initializes environment. + def __init__(self, environment: multigrid.MultiGridEnv): + """Initializes environment. Args: environment: the environment. """ - self._environment = environment - self._reset_next_step = True - self._last_info = None - self.num_agents = environment.n_agents # pytype: disable=attribute-error - - # Convert action and observation specs. - obs_space = self._environment.observation_space - act_space = self._environment.action_space - self._observation_spec = _convert_to_spec( - obs_space, self.num_agents, name='observation') - self._action_spec = _convert_to_spec( - act_space, self.num_agents, name='action') - - def process_obs(self, observation: types.NestedArray) -> types.NestedArray: - # Convert observations to agent-index-first format - observation = dict_obs_to_list_obs(observation) - - # Assign dtypes to multigrid observations (some of which are lists by - # default, so do not have a precise dtype that matches their observation - # spec. This ensures no replay signature mismatch issues occur). - observation = tree.map_structure(lambda x, t: np.asarray(x, dtype=t.dtype), - observation, self.observation_spec()) - return observation - - def reset(self) -> dm_env.TimeStep: - """Resets the episode.""" - self._reset_next_step = False - observation = self.process_obs(self._environment.reset()) - - # Reset the diagnostic information. - self._last_info = None - return dm_env.restart(observation) - - def step(self, action: types.NestedArray) -> dm_env.TimeStep: - """Steps the environment.""" - if self._reset_next_step: - return self.reset() - - observation, reward, done, info = self._environment.step(action) - observation = self.process_obs(observation) - - self._reset_next_step = done - self._last_info = info - - def _map_reward_spec(x, t): - if np.isscalar(x): - return t.dtype.type(x) - return np.asarray(x, dtype=t.dtype) - - reward = tree.map_structure( - _map_reward_spec, - reward, - self.reward_spec()) - - if done: - truncated = info.get('TimeLimit.truncated', False) - if truncated: - return dm_env.truncation(reward, observation) - return dm_env.termination(reward, observation) - return dm_env.transition(reward, observation) - - def observation_spec(self) -> types.NestedSpec: # Internal pytype check. - return self._observation_spec - - def action_spec(self) -> types.NestedSpec: # Internal pytype check. - return self._action_spec - - def reward_spec(self) -> types.NestedSpec: # Internal pytype check. 
- return [specs.Array(shape=(), dtype=float, name='rewards') - ] * self._environment.n_agents - - def get_info(self) -> Optional[Dict[str, Any]]: - """Returns the last info returned from env.step(action). + self._environment = environment + self._reset_next_step = True + self._last_info = None + self.num_agents = environment.n_agents # pytype: disable=attribute-error + + # Convert action and observation specs. + obs_space = self._environment.observation_space + act_space = self._environment.action_space + self._observation_spec = _convert_to_spec( + obs_space, self.num_agents, name="observation" + ) + self._action_spec = _convert_to_spec(act_space, self.num_agents, name="action") + + def process_obs(self, observation: types.NestedArray) -> types.NestedArray: + # Convert observations to agent-index-first format + observation = dict_obs_to_list_obs(observation) + + # Assign dtypes to multigrid observations (some of which are lists by + # default, so do not have a precise dtype that matches their observation + # spec. This ensures no replay signature mismatch issues occur). + observation = tree.map_structure( + lambda x, t: np.asarray(x, dtype=t.dtype), + observation, + self.observation_spec(), + ) + return observation + + def reset(self) -> dm_env.TimeStep: + """Resets the episode.""" + self._reset_next_step = False + observation = self.process_obs(self._environment.reset()) + + # Reset the diagnostic information. + self._last_info = None + return dm_env.restart(observation) + + def step(self, action: types.NestedArray) -> dm_env.TimeStep: + """Steps the environment.""" + if self._reset_next_step: + return self.reset() + + observation, reward, done, info = self._environment.step(action) + observation = self.process_obs(observation) + + self._reset_next_step = done + self._last_info = info + + def _map_reward_spec(x, t): + if np.isscalar(x): + return t.dtype.type(x) + return np.asarray(x, dtype=t.dtype) + + reward = tree.map_structure(_map_reward_spec, reward, self.reward_spec()) + + if done: + truncated = info.get("TimeLimit.truncated", False) + if truncated: + return dm_env.truncation(reward, observation) + return dm_env.termination(reward, observation) + return dm_env.transition(reward, observation) + + def observation_spec(self) -> types.NestedSpec: # Internal pytype check. + return self._observation_spec + + def action_spec(self) -> types.NestedSpec: # Internal pytype check. + return self._action_spec + + def reward_spec(self) -> types.NestedSpec: # Internal pytype check. + return [ + specs.Array(shape=(), dtype=float, name="rewards") + ] * self._environment.n_agents + + def get_info(self) -> Optional[Dict[str, Any]]: + """Returns the last info returned from env.step(action). 
Returns: info: dictionary of diagnostic information from the last environment step """ - return self._last_info + return self._last_info - @property - def environment(self) -> gym.Env: - """Returns the wrapped environment.""" - return self._environment + @property + def environment(self) -> gym.Env: + """Returns the wrapped environment.""" + return self._environment - def __getattr__(self, name: str) -> Any: - """Returns any other attributes of the underlying environment.""" - return getattr(self._environment, name) + def __getattr__(self, name: str) -> Any: + """Returns any other attributes of the underlying environment.""" + return getattr(self._environment, name) - def close(self): - self._environment.close() + def close(self): + self._environment.close() def _get_single_agent_spec(spec): - """Returns a single-agent spec from multiagent multigrid spec. + """Returns a single-agent spec from multiagent multigrid spec. Primarily used for converting multigrid specs to multiagent Acme specs, wherein actions and observations specs are expected to be lists (each entry @@ -167,41 +170,43 @@ def _get_single_agent_spec(spec): Args: spec: multigrid environment spec. """ - def make_single_agent_spec(spec): - if not spec.shape: # Rewards & discounts - shape = () - elif len(spec.shape) == 1: # Actions - shape = () - else: # Observations - shape = spec.shape[1:] - - if isinstance(spec, specs.BoundedArray): - # Bounded rewards and discounts often have no dimensions as they are - # amongst the agents, whereas observations are of shape [num_agents, ...]. - # The following pair of if statements handle both cases accordingly. - minimum = spec.minimum if spec.minimum.ndim == 0 else spec.minimum[0] - maximum = spec.maximum if spec.maximum.ndim == 0 else spec.maximum[0] - return specs.BoundedArray( - shape=shape, - name=spec.name, - minimum=minimum, - maximum=maximum, - dtype=spec.dtype) - elif isinstance(spec, specs.DiscreteArray): - return specs.DiscreteArray( - num_values=spec.num_values, dtype=spec.dtype, name=spec.name) - elif isinstance(spec, specs.Array): - return specs.Array(shape=shape, dtype=spec.dtype, name=spec.name) - else: - raise ValueError(f'Unexpected spec type {type(spec)}.') - - single_agent_spec = jax.tree_map(make_single_agent_spec, spec) - return single_agent_spec - -def _gym_to_spec(space: gym.Space, - name: Optional[str] = None) -> types.NestedSpec: - """Converts an OpenAI Gym space to a dm_env spec or nested structure of specs. + def make_single_agent_spec(spec): + if not spec.shape: # Rewards & discounts + shape = () + elif len(spec.shape) == 1: # Actions + shape = () + else: # Observations + shape = spec.shape[1:] + + if isinstance(spec, specs.BoundedArray): + # Bounded rewards and discounts often have no dimensions as they are + # amongst the agents, whereas observations are of shape [num_agents, ...]. + # The following pair of if statements handle both cases accordingly. 
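# A standalone sketch (illustrative shapes, not Acme code) of the per-agent
# spec slicing described in the comments above: multigrid observation specs
# carry a leading [num_agents, ...] dimension, and the single-agent spec drops
# that leading axis while keeping dtype, bounds and name.
import numpy as np
from dm_env import specs

multiagent_image_spec = specs.BoundedArray(
    shape=(3, 5, 5, 3),  # [num_agents, height, width, channels], illustrative
    dtype=np.uint8,
    minimum=0,
    maximum=255,
    name="image",
)
single_agent_image_spec = specs.BoundedArray(
    shape=multiagent_image_spec.shape[1:],  # drop the agent dimension
    dtype=multiagent_image_spec.dtype,
    minimum=multiagent_image_spec.minimum,
    maximum=multiagent_image_spec.maximum,
    name=multiagent_image_spec.name,
)
assert single_agent_image_spec.shape == (5, 5, 3)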
+ minimum = spec.minimum if spec.minimum.ndim == 0 else spec.minimum[0] + maximum = spec.maximum if spec.maximum.ndim == 0 else spec.maximum[0] + return specs.BoundedArray( + shape=shape, + name=spec.name, + minimum=minimum, + maximum=maximum, + dtype=spec.dtype, + ) + elif isinstance(spec, specs.DiscreteArray): + return specs.DiscreteArray( + num_values=spec.num_values, dtype=spec.dtype, name=spec.name + ) + elif isinstance(spec, specs.Array): + return specs.Array(shape=shape, dtype=spec.dtype, name=spec.name) + else: + raise ValueError(f"Unexpected spec type {type(spec)}.") + + single_agent_spec = jax.tree_map(make_single_agent_spec, spec) + return single_agent_spec + + +def _gym_to_spec(space: gym.Space, name: Optional[str] = None) -> types.NestedSpec: + """Converts an OpenAI Gym space to a dm_env spec or nested structure of specs. Box, MultiBinary and MultiDiscrete Gym spaces are converted to BoundedArray specs. Discrete OpenAI spaces are converted to DiscreteArray specs. Tuple and @@ -215,49 +220,46 @@ def _gym_to_spec(space: gym.Space, A dm_env spec or nested structure of specs, corresponding to the input space. """ - if isinstance(space, spaces.Discrete): - return specs.DiscreteArray(num_values=space.n, dtype=space.dtype, name=name) - - elif isinstance(space, spaces.Box): - return specs.BoundedArray( - shape=space.shape, - dtype=space.dtype, - minimum=space.low, - maximum=space.high, - name=name) - - elif isinstance(space, spaces.MultiBinary): - return specs.BoundedArray( - shape=space.shape, - dtype=space.dtype, - minimum=0.0, - maximum=1.0, - name=name) - - elif isinstance(space, spaces.MultiDiscrete): - return specs.BoundedArray( - shape=space.shape, - dtype=space.dtype, - minimum=np.zeros(space.shape), - maximum=space.nvec - 1, - name=name) - - elif isinstance(space, spaces.Tuple): - return tuple(_gym_to_spec(s, name) for s in space.spaces) - - elif isinstance(space, spaces.Dict): - return { - key: _gym_to_spec(value, key) for key, value in space.spaces.items() - } - - else: - raise ValueError('Unexpected gym space: {}'.format(space)) - - -def _convert_to_spec(space: gym.Space, - num_agents: int, - name: Optional[str] = None) -> types.NestedSpec: - """Converts multigrid Gym space to an Acme multiagent spec. + if isinstance(space, spaces.Discrete): + return specs.DiscreteArray(num_values=space.n, dtype=space.dtype, name=name) + + elif isinstance(space, spaces.Box): + return specs.BoundedArray( + shape=space.shape, + dtype=space.dtype, + minimum=space.low, + maximum=space.high, + name=name, + ) + + elif isinstance(space, spaces.MultiBinary): + return specs.BoundedArray( + shape=space.shape, dtype=space.dtype, minimum=0.0, maximum=1.0, name=name + ) + + elif isinstance(space, spaces.MultiDiscrete): + return specs.BoundedArray( + shape=space.shape, + dtype=space.dtype, + minimum=np.zeros(space.shape), + maximum=space.nvec - 1, + name=name, + ) + + elif isinstance(space, spaces.Tuple): + return tuple(_gym_to_spec(s, name) for s in space.spaces) + + elif isinstance(space, spaces.Dict): + return {key: _gym_to_spec(value, key) for key, value in space.spaces.items()} + + else: + raise ValueError("Unexpected gym space: {}".format(space)) + + +def _convert_to_spec( + space: gym.Space, num_agents: int, name: Optional[str] = None +) -> types.NestedSpec: + """Converts multigrid Gym space to an Acme multiagent spec. Args: space: The Gym space to convert. @@ -268,16 +270,16 @@ def _convert_to_spec(space: gym.Space, A dm_env spec or nested structure of specs, corresponding to the input space. 
""" - # Convert gym specs to acme specs - spec = _gym_to_spec(space, name) - # Then change spec indexing from observation-key-first to agent-index-first - return [_get_single_agent_spec(spec)] * num_agents + # Convert gym specs to acme specs + spec = _gym_to_spec(space, name) + # Then change spec indexing from observation-key-first to agent-index-first + return [_get_single_agent_spec(spec)] * num_agents def dict_obs_to_list_obs( - observation: types.NestedArray + observation: types.NestedArray, ) -> List[Dict[ma_types.AgentID, types.NestedArray]]: - """Returns multigrid observations converted to agent-index-first format. + """Returns multigrid observations converted to agent-index-first format. By default, multigrid observations are structured as: observation['image'][agent_index] @@ -293,22 +295,23 @@ def dict_obs_to_list_obs( Args: observation: """ - return [dict(zip(observation, v)) for v in zip(*observation.values())] + return [dict(zip(observation, v)) for v in zip(*observation.values())] def make_multigrid_environment( - env_name: str = 'MultiGrid-Empty-5x5-v0') -> dm_env.Environment: - """Returns Multigrid Multiagent Gym environment. + env_name: str = "MultiGrid-Empty-5x5-v0", +) -> dm_env.Environment: + """Returns Multigrid Multiagent Gym environment. Args: env_name: name of multigrid task. See social_rl.gym_multigrid.envs for the available environments. """ - # Load the gym environment. - env = gym.make(env_name) - - # Make sure the environment obeys the dm_env.Environment interface. - env = MultigridWrapper(env) - env = wrappers.SinglePrecisionWrapper(env) - env = multiagent_dict_key_wrapper.MultiagentDictKeyWrapper(env) - return env + # Load the gym environment. + env = gym.make(env_name) + + # Make sure the environment obeys the dm_env.Environment interface. + env = MultigridWrapper(env) + env = wrappers.SinglePrecisionWrapper(env) + env = multiagent_dict_key_wrapper.MultiagentDictKeyWrapper(env) + return env diff --git a/acme/wrappers/noop_starts.py b/acme/wrappers/noop_starts.py index ed23ebf8a2..336c24b730 100644 --- a/acme/wrappers/noop_starts.py +++ b/acme/wrappers/noop_starts.py @@ -16,14 +16,15 @@ from typing import Optional -from acme import types -from acme.wrappers import base import dm_env import numpy as np +from acme import types +from acme.wrappers import base + class NoopStartsWrapper(base.EnvironmentWrapper): - """Implements random noop starts to episodes. + """Implements random noop starts to episodes. This introduces randomness into an otherwise deterministic environment. @@ -31,12 +32,14 @@ class NoopStartsWrapper(base.EnvironmentWrapper): of this action must be known and provided to this wrapper. """ - def __init__(self, - environment: dm_env.Environment, - noop_action: types.NestedArray = 0, - noop_max: int = 30, - seed: Optional[int] = None): - """Initializes a `NoopStartsWrapper` wrapper. + def __init__( + self, + environment: dm_env.Environment, + noop_action: types.NestedArray = 0, + noop_max: int = 30, + seed: Optional[int] = None, + ): + """Initializes a `NoopStartsWrapper` wrapper. Args: environment: An environment conforming to the dm_env.Environment @@ -46,23 +49,24 @@ def __init__(self, noop_max: The maximal number of noop actions at the start of an episode. seed: The random seed used to sample the number of noops. """ - if noop_max < 0: - raise ValueError( - 'Maximal number of no-ops after reset cannot be negative. 
' - f'Received noop_max={noop_max}') + if noop_max < 0: + raise ValueError( + "Maximal number of no-ops after reset cannot be negative. " + f"Received noop_max={noop_max}" + ) - super().__init__(environment) - self.np_random = np.random.RandomState(seed) - self._noop_max = noop_max - self._noop_action = noop_action + super().__init__(environment) + self.np_random = np.random.RandomState(seed) + self._noop_max = noop_max + self._noop_action = noop_action - def reset(self) -> dm_env.TimeStep: - """Resets environment and provides the first timestep.""" - noops = self.np_random.randint(self._noop_max + 1) - timestep = self.environment.reset() - for _ in range(noops): - timestep = self.environment.step(self._noop_action) - if timestep.last(): + def reset(self) -> dm_env.TimeStep: + """Resets environment and provides the first timestep.""" + noops = self.np_random.randint(self._noop_max + 1) timestep = self.environment.reset() + for _ in range(noops): + timestep = self.environment.step(self._noop_action) + if timestep.last(): + timestep = self.environment.reset() - return timestep._replace(step_type=dm_env.StepType.FIRST) + return timestep._replace(step_type=dm_env.StepType.FIRST) diff --git a/acme/wrappers/noop_starts_test.py b/acme/wrappers/noop_starts_test.py index 74d96e788d..e7048a3a09 100644 --- a/acme/wrappers/noop_starts_test.py +++ b/acme/wrappers/noop_starts_test.py @@ -16,53 +16,51 @@ from unittest import mock -from acme import wrappers -from acme.testing import fakes -from dm_env import specs import numpy as np - from absl.testing import absltest +from dm_env import specs +from acme import wrappers +from acme.testing import fakes -class NoopStartsTest(absltest.TestCase): - def test_reset(self): - """Ensure that noop starts `reset` steps the environment multiple times.""" - noop_action = 0 - noop_max = 10 - seed = 24 +class NoopStartsTest(absltest.TestCase): + def test_reset(self): + """Ensure that noop starts `reset` steps the environment multiple times.""" + noop_action = 0 + noop_max = 10 + seed = 24 - base_env = fakes.DiscreteEnvironment( - action_dtype=np.int64, - obs_dtype=np.int64, - reward_spec=specs.Array(dtype=np.float64, shape=())) - mock_step_fn = mock.MagicMock() - expected_num_step_calls = np.random.RandomState(seed).randint(noop_max + 1) + base_env = fakes.DiscreteEnvironment( + action_dtype=np.int64, + obs_dtype=np.int64, + reward_spec=specs.Array(dtype=np.float64, shape=()), + ) + mock_step_fn = mock.MagicMock() + expected_num_step_calls = np.random.RandomState(seed).randint(noop_max + 1) - with mock.patch.object(base_env, 'step', mock_step_fn): - env = wrappers.NoopStartsWrapper( - base_env, - noop_action=noop_action, - noop_max=noop_max, - seed=seed, - ) - env.reset() + with mock.patch.object(base_env, "step", mock_step_fn): + env = wrappers.NoopStartsWrapper( + base_env, noop_action=noop_action, noop_max=noop_max, seed=seed, + ) + env.reset() - # Test environment step called with noop action as part of wrapper.reset - mock_step_fn.assert_called_with(noop_action) - self.assertEqual(mock_step_fn.call_count, expected_num_step_calls) - self.assertEqual(mock_step_fn.call_args, ((noop_action,), {})) + # Test environment step called with noop action as part of wrapper.reset + mock_step_fn.assert_called_with(noop_action) + self.assertEqual(mock_step_fn.call_count, expected_num_step_calls) + self.assertEqual(mock_step_fn.call_args, ((noop_action,), {})) - def test_raises_value_error(self): - """Ensure that wrapper raises error if noop_max is <0.""" - base_env = 
fakes.DiscreteEnvironment( - action_dtype=np.int64, - obs_dtype=np.int64, - reward_spec=specs.Array(dtype=np.float64, shape=())) + def test_raises_value_error(self): + """Ensure that wrapper raises error if noop_max is <0.""" + base_env = fakes.DiscreteEnvironment( + action_dtype=np.int64, + obs_dtype=np.int64, + reward_spec=specs.Array(dtype=np.float64, shape=()), + ) - with self.assertRaises(ValueError): - wrappers.NoopStartsWrapper(base_env, noop_action=0, noop_max=-1, seed=24) + with self.assertRaises(ValueError): + wrappers.NoopStartsWrapper(base_env, noop_action=0, noop_max=-1, seed=24) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/wrappers/observation_action_reward.py b/acme/wrappers/observation_action_reward.py index 2433de145f..109dc58705 100644 --- a/acme/wrappers/observation_action_reward.py +++ b/acme/wrappers/observation_action_reward.py @@ -16,47 +16,53 @@ from typing import NamedTuple -from acme import types -from acme.wrappers import base - import dm_env import tree +from acme import types +from acme.wrappers import base + class OAR(NamedTuple): - """Container for (Observation, Action, Reward) tuples.""" - observation: types.Nest - action: types.Nest - reward: types.Nest + """Container for (Observation, Action, Reward) tuples.""" + + observation: types.Nest + action: types.Nest + reward: types.Nest class ObservationActionRewardWrapper(base.EnvironmentWrapper): - """A wrapper that puts the previous action and reward into the observation.""" - - def reset(self) -> dm_env.TimeStep: - # Initialize with zeros of the appropriate shape/dtype. - action = tree.map_structure( - lambda x: x.generate_value(), self._environment.action_spec()) - reward = tree.map_structure( - lambda x: x.generate_value(), self._environment.reward_spec()) - timestep = self._environment.reset() - new_timestep = self._augment_observation(action, reward, timestep) - return new_timestep - - def step(self, action: types.NestedArray) -> dm_env.TimeStep: - timestep = self._environment.step(action) - new_timestep = self._augment_observation(action, timestep.reward, timestep) - return new_timestep - - def _augment_observation(self, action: types.NestedArray, - reward: types.NestedArray, - timestep: dm_env.TimeStep) -> dm_env.TimeStep: - oar = OAR(observation=timestep.observation, - action=action, - reward=reward) - return timestep._replace(observation=oar) - - def observation_spec(self): - return OAR(observation=self._environment.observation_spec(), - action=self.action_spec(), - reward=self.reward_spec()) + """A wrapper that puts the previous action and reward into the observation.""" + + def reset(self) -> dm_env.TimeStep: + # Initialize with zeros of the appropriate shape/dtype. 
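# Note on the zero-initialisation below (an aside, based on dm_env's spec
# semantics rather than on this file): a plain specs.Array generates
# np.zeros(shape, dtype) from generate_value(), and bounded/discrete specs
# generate their minimum value, so the first OAR observation carries
# well-typed placeholder values for the "previous" action and reward before
# any step has been taken.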
+ action = tree.map_structure( + lambda x: x.generate_value(), self._environment.action_spec() + ) + reward = tree.map_structure( + lambda x: x.generate_value(), self._environment.reward_spec() + ) + timestep = self._environment.reset() + new_timestep = self._augment_observation(action, reward, timestep) + return new_timestep + + def step(self, action: types.NestedArray) -> dm_env.TimeStep: + timestep = self._environment.step(action) + new_timestep = self._augment_observation(action, timestep.reward, timestep) + return new_timestep + + def _augment_observation( + self, + action: types.NestedArray, + reward: types.NestedArray, + timestep: dm_env.TimeStep, + ) -> dm_env.TimeStep: + oar = OAR(observation=timestep.observation, action=action, reward=reward) + return timestep._replace(observation=oar) + + def observation_spec(self): + return OAR( + observation=self._environment.observation_spec(), + action=self.action_spec(), + reward=self.reward_spec(), + ) diff --git a/acme/wrappers/open_spiel_wrapper.py b/acme/wrappers/open_spiel_wrapper.py index 3d6d323024..42e372266f 100644 --- a/acme/wrappers/open_spiel_wrapper.py +++ b/acme/wrappers/open_spiel_wrapper.py @@ -16,132 +16,151 @@ from typing import List, NamedTuple -from acme import specs -from acme import types import dm_env import numpy as np + # pytype: disable=import-error from open_spiel.python import rl_environment + +from acme import specs, types + # pytype: enable=import-error class OLT(NamedTuple): - """Container for (observation, legal_actions, terminal) tuples.""" - observation: types.Nest - legal_actions: types.Nest - terminal: types.Nest + """Container for (observation, legal_actions, terminal) tuples.""" + + observation: types.Nest + legal_actions: types.Nest + terminal: types.Nest class OpenSpielWrapper(dm_env.Environment): - """Environment wrapper for OpenSpiel RL environments.""" - - # Note: we don't inherit from base.EnvironmentWrapper because that class - # assumes that the wrapped environment is a dm_env.Environment. - - def __init__(self, environment: rl_environment.Environment): - self._environment = environment - self._reset_next_step = True - if not environment.is_turn_based: - raise ValueError("Currently only supports turn based games.") - - def reset(self) -> dm_env.TimeStep: - """Resets the episode.""" - self._reset_next_step = False - open_spiel_timestep = self._environment.reset() - observations = self._convert_observation(open_spiel_timestep) - return dm_env.restart(observations) - - def step(self, action: types.NestedArray) -> dm_env.TimeStep: - """Steps the environment.""" - if self._reset_next_step: - return self.reset() - - open_spiel_timestep = self._environment.step(action) - - if open_spiel_timestep.step_type == rl_environment.StepType.LAST: - self._reset_next_step = True - - observations = self._convert_observation(open_spiel_timestep) - rewards = np.asarray(open_spiel_timestep.rewards) - discounts = np.asarray(open_spiel_timestep.discounts) - step_type = open_spiel_timestep.step_type - - if step_type == rl_environment.StepType.FIRST: - step_type = dm_env.StepType.FIRST - elif step_type == rl_environment.StepType.MID: - step_type = dm_env.StepType.MID - elif step_type == rl_environment.StepType.LAST: - step_type = dm_env.StepType.LAST - else: - raise ValueError( - "Did not recognize OpenSpiel StepType: {}".format(step_type)) - - return dm_env.TimeStep(observation=observations, - reward=rewards, - discount=discounts, - step_type=step_type) - - # Convert OpenSpiel observation so it's dm_env compatible. 
Also, the list - # of legal actions must be converted to a legal actions mask. - def _convert_observation( - self, open_spiel_timestep: rl_environment.TimeStep) -> List[OLT]: - observations = [] - for pid in range(self._environment.num_players): - legals = np.zeros(self._environment.game.num_distinct_actions(), - dtype=np.float32) - legals[open_spiel_timestep.observations["legal_actions"][pid]] = 1.0 - player_observation = OLT(observation=np.asarray( - open_spiel_timestep.observations["info_state"][pid], - dtype=np.float32), - legal_actions=legals, - terminal=np.asarray([open_spiel_timestep.last()], - dtype=np.float32)) - observations.append(player_observation) - return observations - - def observation_spec(self) -> OLT: - # Observation spec depends on whether the OpenSpiel environment is using - # observation/information_state tensors. - if self._environment.use_observation: - return OLT(observation=specs.Array( - (self._environment.game.observation_tensor_size(),), np.float32), - legal_actions=specs.Array( - (self._environment.game.num_distinct_actions(),), - np.float32), - terminal=specs.Array((1,), np.float32)) - else: - return OLT(observation=specs.Array( - (self._environment.game.information_state_tensor_size(),), - np.float32), - legal_actions=specs.Array( - (self._environment.game.num_distinct_actions(),), - np.float32), - terminal=specs.Array((1,), np.float32)) - - def action_spec(self) -> specs.DiscreteArray: - return specs.DiscreteArray(self._environment.game.num_distinct_actions()) - - def reward_spec(self) -> specs.BoundedArray: - return specs.BoundedArray((), - np.float32, - minimum=self._environment.game.min_utility(), - maximum=self._environment.game.max_utility()) - - def discount_spec(self) -> specs.BoundedArray: - return specs.BoundedArray((), np.float32, minimum=0, maximum=1.0) - - @property - def environment(self) -> rl_environment.Environment: - """Returns the wrapped environment.""" - return self._environment - - @property - def current_player(self) -> int: - return self._environment.get_state.current_player() - - def __getattr__(self, name: str): - """Expose any other attributes of the underlying environment.""" - if name.startswith("__"): - raise AttributeError( - "attempted to get missing private attribute '{}'".format(name)) - return getattr(self._environment, name) + """Environment wrapper for OpenSpiel RL environments.""" + + # Note: we don't inherit from base.EnvironmentWrapper because that class + # assumes that the wrapped environment is a dm_env.Environment. 
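# A minimal standalone sketch (values illustrative, not taken from the
# wrapper) of the legal-action mask that _convert_observation builds further
# below: OpenSpiel reports legal actions as a list of action indices, which
# the wrapper expands into a float32 mask over all distinct actions.
import numpy as np

num_distinct_actions = 9          # e.g. tic_tac_toe
legal_action_indices = [0, 4, 8]  # illustrative
legals = np.zeros(num_distinct_actions, dtype=np.float32)
legals[legal_action_indices] = 1.0
# legals -> [1., 0., 0., 0., 1., 0., 0., 0., 1.]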
+ + def __init__(self, environment: rl_environment.Environment): + self._environment = environment + self._reset_next_step = True + if not environment.is_turn_based: + raise ValueError("Currently only supports turn based games.") + + def reset(self) -> dm_env.TimeStep: + """Resets the episode.""" + self._reset_next_step = False + open_spiel_timestep = self._environment.reset() + observations = self._convert_observation(open_spiel_timestep) + return dm_env.restart(observations) + + def step(self, action: types.NestedArray) -> dm_env.TimeStep: + """Steps the environment.""" + if self._reset_next_step: + return self.reset() + + open_spiel_timestep = self._environment.step(action) + + if open_spiel_timestep.step_type == rl_environment.StepType.LAST: + self._reset_next_step = True + + observations = self._convert_observation(open_spiel_timestep) + rewards = np.asarray(open_spiel_timestep.rewards) + discounts = np.asarray(open_spiel_timestep.discounts) + step_type = open_spiel_timestep.step_type + + if step_type == rl_environment.StepType.FIRST: + step_type = dm_env.StepType.FIRST + elif step_type == rl_environment.StepType.MID: + step_type = dm_env.StepType.MID + elif step_type == rl_environment.StepType.LAST: + step_type = dm_env.StepType.LAST + else: + raise ValueError( + "Did not recognize OpenSpiel StepType: {}".format(step_type) + ) + + return dm_env.TimeStep( + observation=observations, + reward=rewards, + discount=discounts, + step_type=step_type, + ) + + # Convert OpenSpiel observation so it's dm_env compatible. Also, the list + # of legal actions must be converted to a legal actions mask. + def _convert_observation( + self, open_spiel_timestep: rl_environment.TimeStep + ) -> List[OLT]: + observations = [] + for pid in range(self._environment.num_players): + legals = np.zeros( + self._environment.game.num_distinct_actions(), dtype=np.float32 + ) + legals[open_spiel_timestep.observations["legal_actions"][pid]] = 1.0 + player_observation = OLT( + observation=np.asarray( + open_spiel_timestep.observations["info_state"][pid], + dtype=np.float32, + ), + legal_actions=legals, + terminal=np.asarray([open_spiel_timestep.last()], dtype=np.float32), + ) + observations.append(player_observation) + return observations + + def observation_spec(self) -> OLT: + # Observation spec depends on whether the OpenSpiel environment is using + # observation/information_state tensors. 
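# For concreteness (an aside, not part of the wrapper): tic_tac_toe's
# observation tensor has 27 entries (three 3x3 planes) and the game has 9
# distinct actions, so the use_observation branch below would yield an OLT of
# three float32 Arrays with shapes (27,), (9,) and (1,).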
+ if self._environment.use_observation: + return OLT( + observation=specs.Array( + (self._environment.game.observation_tensor_size(),), np.float32 + ), + legal_actions=specs.Array( + (self._environment.game.num_distinct_actions(),), np.float32 + ), + terminal=specs.Array((1,), np.float32), + ) + else: + return OLT( + observation=specs.Array( + (self._environment.game.information_state_tensor_size(),), + np.float32, + ), + legal_actions=specs.Array( + (self._environment.game.num_distinct_actions(),), np.float32 + ), + terminal=specs.Array((1,), np.float32), + ) + + def action_spec(self) -> specs.DiscreteArray: + return specs.DiscreteArray(self._environment.game.num_distinct_actions()) + + def reward_spec(self) -> specs.BoundedArray: + return specs.BoundedArray( + (), + np.float32, + minimum=self._environment.game.min_utility(), + maximum=self._environment.game.max_utility(), + ) + + def discount_spec(self) -> specs.BoundedArray: + return specs.BoundedArray((), np.float32, minimum=0, maximum=1.0) + + @property + def environment(self) -> rl_environment.Environment: + """Returns the wrapped environment.""" + return self._environment + + @property + def current_player(self) -> int: + return self._environment.get_state.current_player() + + def __getattr__(self, name: str): + """Expose any other attributes of the underlying environment.""" + if name.startswith("__"): + raise AttributeError( + "attempted to get missing private attribute '{}'".format(name) + ) + return getattr(self._environment, name) diff --git a/acme/wrappers/open_spiel_wrapper_test.py b/acme/wrappers/open_spiel_wrapper_test.py index faaaf89971..88f2d92a53 100644 --- a/acme/wrappers/open_spiel_wrapper_test.py +++ b/acme/wrappers/open_spiel_wrapper_test.py @@ -16,53 +16,53 @@ import unittest -from dm_env import specs import numpy as np - from absl.testing import absltest +from dm_env import specs SKIP_OPEN_SPIEL_TESTS = False -SKIP_OPEN_SPIEL_MESSAGE = 'open_spiel not installed.' +SKIP_OPEN_SPIEL_MESSAGE = "open_spiel not installed." try: - # pylint: disable=g-import-not-at-top - # pytype: disable=import-error - from acme.wrappers import open_spiel_wrapper - from open_spiel.python import rl_environment - # pytype: enable=import-error + # pylint: disable=g-import-not-at-top + # pytype: disable=import-error + from open_spiel.python import rl_environment + + from acme.wrappers import open_spiel_wrapper + + # pytype: enable=import-error except ModuleNotFoundError: - SKIP_OPEN_SPIEL_TESTS = True + SKIP_OPEN_SPIEL_TESTS = True @unittest.skipIf(SKIP_OPEN_SPIEL_TESTS, SKIP_OPEN_SPIEL_MESSAGE) class OpenSpielWrapperTest(absltest.TestCase): + def test_tic_tac_toe(self): + raw_env = rl_environment.Environment("tic_tac_toe") + env = open_spiel_wrapper.OpenSpielWrapper(raw_env) - def test_tic_tac_toe(self): - raw_env = rl_environment.Environment('tic_tac_toe') - env = open_spiel_wrapper.OpenSpielWrapper(raw_env) - - # Test converted observation spec. - observation_spec = env.observation_spec() - self.assertEqual(type(observation_spec), open_spiel_wrapper.OLT) - self.assertEqual(type(observation_spec.observation), specs.Array) - self.assertEqual(type(observation_spec.legal_actions), specs.Array) - self.assertEqual(type(observation_spec.terminal), specs.Array) + # Test converted observation spec. 
+ observation_spec = env.observation_spec() + self.assertEqual(type(observation_spec), open_spiel_wrapper.OLT) + self.assertEqual(type(observation_spec.observation), specs.Array) + self.assertEqual(type(observation_spec.legal_actions), specs.Array) + self.assertEqual(type(observation_spec.terminal), specs.Array) - # Test converted action spec. - action_spec: specs.DiscreteArray = env.action_spec() - self.assertEqual(type(action_spec), specs.DiscreteArray) - self.assertEqual(action_spec.shape, ()) - self.assertEqual(action_spec.minimum, 0) - self.assertEqual(action_spec.maximum, 8) - self.assertEqual(action_spec.num_values, 9) - self.assertEqual(action_spec.dtype, np.dtype('int32')) + # Test converted action spec. + action_spec: specs.DiscreteArray = env.action_spec() + self.assertEqual(type(action_spec), specs.DiscreteArray) + self.assertEqual(action_spec.shape, ()) + self.assertEqual(action_spec.minimum, 0) + self.assertEqual(action_spec.maximum, 8) + self.assertEqual(action_spec.num_values, 9) + self.assertEqual(action_spec.dtype, np.dtype("int32")) - # Test step. - timestep = env.reset() - self.assertTrue(timestep.first()) - _ = env.step([0]) - env.close() + # Test step. + timestep = env.reset() + self.assertTrue(timestep.first()) + _ = env.step([0]) + env.close() -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/wrappers/single_precision.py b/acme/wrappers/single_precision.py index e1b90c7f67..225725e53e 100644 --- a/acme/wrappers/single_precision.py +++ b/acme/wrappers/single_precision.py @@ -14,72 +14,72 @@ """Environment wrapper which converts double-to-single precision.""" -from acme import specs -from acme import types -from acme.wrappers import base - import dm_env import numpy as np import tree +from acme import specs, types +from acme.wrappers import base + class SinglePrecisionWrapper(base.EnvironmentWrapper): - """Wrapper which converts environments from double- to single-precision.""" + """Wrapper which converts environments from double- to single-precision.""" - def _convert_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep: - return timestep._replace( - reward=_convert_value(timestep.reward), - discount=_convert_value(timestep.discount), - observation=_convert_value(timestep.observation)) + def _convert_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep: + return timestep._replace( + reward=_convert_value(timestep.reward), + discount=_convert_value(timestep.discount), + observation=_convert_value(timestep.observation), + ) - def step(self, action) -> dm_env.TimeStep: - return self._convert_timestep(self._environment.step(action)) + def step(self, action) -> dm_env.TimeStep: + return self._convert_timestep(self._environment.step(action)) - def reset(self) -> dm_env.TimeStep: - return self._convert_timestep(self._environment.reset()) + def reset(self) -> dm_env.TimeStep: + return self._convert_timestep(self._environment.reset()) - def action_spec(self): - return _convert_spec(self._environment.action_spec()) + def action_spec(self): + return _convert_spec(self._environment.action_spec()) - def discount_spec(self): - return _convert_spec(self._environment.discount_spec()) + def discount_spec(self): + return _convert_spec(self._environment.discount_spec()) - def observation_spec(self): - return _convert_spec(self._environment.observation_spec()) + def observation_spec(self): + return _convert_spec(self._environment.observation_spec()) - def reward_spec(self): - return 
_convert_spec(self._environment.reward_spec()) + def reward_spec(self): + return _convert_spec(self._environment.reward_spec()) def _convert_spec(nested_spec: types.NestedSpec) -> types.NestedSpec: - """Convert a nested spec.""" + """Convert a nested spec.""" - def _convert_single_spec(spec: specs.Array): - """Convert a single spec.""" - if spec.dtype == 'O': - # Pass StringArray objects through unmodified. - return spec - if np.issubdtype(spec.dtype, np.float64): - dtype = np.float32 - elif np.issubdtype(spec.dtype, np.int64): - dtype = np.int32 - else: - dtype = spec.dtype - return spec.replace(dtype=dtype) + def _convert_single_spec(spec: specs.Array): + """Convert a single spec.""" + if spec.dtype == "O": + # Pass StringArray objects through unmodified. + return spec + if np.issubdtype(spec.dtype, np.float64): + dtype = np.float32 + elif np.issubdtype(spec.dtype, np.int64): + dtype = np.int32 + else: + dtype = spec.dtype + return spec.replace(dtype=dtype) - return tree.map_structure(_convert_single_spec, nested_spec) + return tree.map_structure(_convert_single_spec, nested_spec) def _convert_value(nested_value: types.Nest) -> types.Nest: - """Convert a nested value given a desired nested spec.""" - - def _convert_single_value(value): - if value is not None: - value = np.array(value, copy=False) - if np.issubdtype(value.dtype, np.float64): - value = np.array(value, copy=False, dtype=np.float32) - elif np.issubdtype(value.dtype, np.int64): - value = np.array(value, copy=False, dtype=np.int32) - return value - - return tree.map_structure(_convert_single_value, nested_value) + """Convert a nested value given a desired nested spec.""" + + def _convert_single_value(value): + if value is not None: + value = np.array(value, copy=False) + if np.issubdtype(value.dtype, np.float64): + value = np.array(value, copy=False, dtype=np.float32) + elif np.issubdtype(value.dtype, np.int64): + value = np.array(value, copy=False, dtype=np.int32) + return value + + return tree.map_structure(_convert_single_value, nested_value) diff --git a/acme/wrappers/single_precision_test.py b/acme/wrappers/single_precision_test.py index f99779cbce..3a3d35bed8 100644 --- a/acme/wrappers/single_precision_test.py +++ b/acme/wrappers/single_precision_test.py @@ -14,58 +14,61 @@ """Tests for the single precision wrapper.""" -from acme import wrappers -from acme.testing import fakes -from dm_env import specs import numpy as np - from absl.testing import absltest +from dm_env import specs +from acme import wrappers +from acme.testing import fakes -class SinglePrecisionTest(absltest.TestCase): - def test_continuous(self): - env = wrappers.SinglePrecisionWrapper( - fakes.ContinuousEnvironment( - action_dim=0, dtype=np.float64, reward_dtype=np.float64)) +class SinglePrecisionTest(absltest.TestCase): + def test_continuous(self): + env = wrappers.SinglePrecisionWrapper( + fakes.ContinuousEnvironment( + action_dim=0, dtype=np.float64, reward_dtype=np.float64 + ) + ) - self.assertTrue(np.issubdtype(env.observation_spec().dtype, np.float32)) - self.assertTrue(np.issubdtype(env.action_spec().dtype, np.float32)) - self.assertTrue(np.issubdtype(env.reward_spec().dtype, np.float32)) - self.assertTrue(np.issubdtype(env.discount_spec().dtype, np.float32)) + self.assertTrue(np.issubdtype(env.observation_spec().dtype, np.float32)) + self.assertTrue(np.issubdtype(env.action_spec().dtype, np.float32)) + self.assertTrue(np.issubdtype(env.reward_spec().dtype, np.float32)) + self.assertTrue(np.issubdtype(env.discount_spec().dtype, np.float32)) - 
timestep = env.reset() - self.assertIsNone(timestep.reward) - self.assertIsNone(timestep.discount) - self.assertTrue(np.issubdtype(timestep.observation.dtype, np.float32)) + timestep = env.reset() + self.assertIsNone(timestep.reward) + self.assertIsNone(timestep.discount) + self.assertTrue(np.issubdtype(timestep.observation.dtype, np.float32)) - timestep = env.step(0.0) - self.assertTrue(np.issubdtype(timestep.reward.dtype, np.float32)) - self.assertTrue(np.issubdtype(timestep.discount.dtype, np.float32)) - self.assertTrue(np.issubdtype(timestep.observation.dtype, np.float32)) + timestep = env.step(0.0) + self.assertTrue(np.issubdtype(timestep.reward.dtype, np.float32)) + self.assertTrue(np.issubdtype(timestep.discount.dtype, np.float32)) + self.assertTrue(np.issubdtype(timestep.observation.dtype, np.float32)) - def test_discrete(self): - env = wrappers.SinglePrecisionWrapper( - fakes.DiscreteEnvironment( - action_dtype=np.int64, - obs_dtype=np.int64, - reward_spec=specs.Array(dtype=np.float64, shape=()))) + def test_discrete(self): + env = wrappers.SinglePrecisionWrapper( + fakes.DiscreteEnvironment( + action_dtype=np.int64, + obs_dtype=np.int64, + reward_spec=specs.Array(dtype=np.float64, shape=()), + ) + ) - self.assertTrue(np.issubdtype(env.observation_spec().dtype, np.int32)) - self.assertTrue(np.issubdtype(env.action_spec().dtype, np.int32)) - self.assertTrue(np.issubdtype(env.reward_spec().dtype, np.float32)) - self.assertTrue(np.issubdtype(env.discount_spec().dtype, np.float32)) + self.assertTrue(np.issubdtype(env.observation_spec().dtype, np.int32)) + self.assertTrue(np.issubdtype(env.action_spec().dtype, np.int32)) + self.assertTrue(np.issubdtype(env.reward_spec().dtype, np.float32)) + self.assertTrue(np.issubdtype(env.discount_spec().dtype, np.float32)) - timestep = env.reset() - self.assertIsNone(timestep.reward) - self.assertIsNone(timestep.discount) - self.assertTrue(np.issubdtype(timestep.observation.dtype, np.int32)) + timestep = env.reset() + self.assertIsNone(timestep.reward) + self.assertIsNone(timestep.discount) + self.assertTrue(np.issubdtype(timestep.observation.dtype, np.int32)) - timestep = env.step(0) - self.assertTrue(np.issubdtype(timestep.reward.dtype, np.float32)) - self.assertTrue(np.issubdtype(timestep.discount.dtype, np.float32)) - self.assertTrue(np.issubdtype(timestep.observation.dtype, np.int32)) + timestep = env.step(0) + self.assertTrue(np.issubdtype(timestep.reward.dtype, np.float32)) + self.assertTrue(np.issubdtype(timestep.discount.dtype, np.float32)) + self.assertTrue(np.issubdtype(timestep.observation.dtype, np.int32)) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/wrappers/step_limit.py b/acme/wrappers/step_limit.py index d327aaafd9..4b0c395b2a 100644 --- a/acme/wrappers/step_limit.py +++ b/acme/wrappers/step_limit.py @@ -15,40 +15,44 @@ """Wrapper that implements environment step limit.""" from typing import Optional + +import dm_env + from acme import types from acme.wrappers import base -import dm_env class StepLimitWrapper(base.EnvironmentWrapper): - """A wrapper which truncates episodes at the specified step limit.""" - - def __init__(self, environment: dm_env.Environment, - step_limit: Optional[int] = None): - super().__init__(environment) - self._step_limit = step_limit - self._elapsed_steps = 0 - - def reset(self) -> dm_env.TimeStep: - self._elapsed_steps = 0 - return self._environment.reset() - - def step(self, action: types.NestedArray) -> dm_env.TimeStep: - if 
self._elapsed_steps == -1: - # The previous episode was truncated by the wrapper, so start a new one. - timestep = self._environment.reset() - else: - timestep = self._environment.step(action) - # If this is the first timestep, then this `step()` call was done on a new, - # terminated or truncated environment instance without calling `reset()` - # first. In this case this `step()` call should be treated as `reset()`, - # so should not increment step count. - if timestep.first(): - self._elapsed_steps = 0 - return timestep - self._elapsed_steps += 1 - if self._step_limit is not None and self._elapsed_steps >= self._step_limit: - self._elapsed_steps = -1 - return dm_env.truncation( - timestep.reward, timestep.observation, timestep.discount) - return timestep + """A wrapper which truncates episodes at the specified step limit.""" + + def __init__( + self, environment: dm_env.Environment, step_limit: Optional[int] = None + ): + super().__init__(environment) + self._step_limit = step_limit + self._elapsed_steps = 0 + + def reset(self) -> dm_env.TimeStep: + self._elapsed_steps = 0 + return self._environment.reset() + + def step(self, action: types.NestedArray) -> dm_env.TimeStep: + if self._elapsed_steps == -1: + # The previous episode was truncated by the wrapper, so start a new one. + timestep = self._environment.reset() + else: + timestep = self._environment.step(action) + # If this is the first timestep, then this `step()` call was done on a new, + # terminated or truncated environment instance without calling `reset()` + # first. In this case this `step()` call should be treated as `reset()`, + # so should not increment step count. + if timestep.first(): + self._elapsed_steps = 0 + return timestep + self._elapsed_steps += 1 + if self._step_limit is not None and self._elapsed_steps >= self._step_limit: + self._elapsed_steps = -1 + return dm_env.truncation( + timestep.reward, timestep.observation, timestep.discount + ) + return timestep diff --git a/acme/wrappers/step_limit_test.py b/acme/wrappers/step_limit_test.py index 78fbc47856..f3d215540c 100644 --- a/acme/wrappers/step_limit_test.py +++ b/acme/wrappers/step_limit_test.py @@ -14,61 +14,60 @@ """Tests for the step limit wrapper.""" -from acme import wrappers -from acme.testing import fakes import numpy as np - from absl.testing import absltest +from acme import wrappers +from acme.testing import fakes + ACTION = np.array(0, dtype=np.int32) class StepLimitWrapperTest(absltest.TestCase): + def test_step(self): + fake_env = fakes.DiscreteEnvironment(episode_length=5) + env = wrappers.StepLimitWrapper(fake_env, step_limit=2) - def test_step(self): - fake_env = fakes.DiscreteEnvironment(episode_length=5) - env = wrappers.StepLimitWrapper(fake_env, step_limit=2) - - env.reset() - env.step(ACTION) - self.assertTrue(env.step(ACTION).last()) + env.reset() + env.step(ACTION) + self.assertTrue(env.step(ACTION).last()) - def test_step_on_new_env(self): - fake_env = fakes.DiscreteEnvironment(episode_length=5) - env = wrappers.StepLimitWrapper(fake_env, step_limit=2) + def test_step_on_new_env(self): + fake_env = fakes.DiscreteEnvironment(episode_length=5) + env = wrappers.StepLimitWrapper(fake_env, step_limit=2) - self.assertTrue(env.step(ACTION).first()) - self.assertFalse(env.step(ACTION).last()) - self.assertTrue(env.step(ACTION).last()) + self.assertTrue(env.step(ACTION).first()) + self.assertFalse(env.step(ACTION).last()) + self.assertTrue(env.step(ACTION).last()) - def test_step_after_truncation(self): - fake_env = 
fakes.DiscreteEnvironment(episode_length=5) - env = wrappers.StepLimitWrapper(fake_env, step_limit=2) + def test_step_after_truncation(self): + fake_env = fakes.DiscreteEnvironment(episode_length=5) + env = wrappers.StepLimitWrapper(fake_env, step_limit=2) - env.reset() - env.step(ACTION) - self.assertTrue(env.step(ACTION).last()) + env.reset() + env.step(ACTION) + self.assertTrue(env.step(ACTION).last()) - self.assertTrue(env.step(ACTION).first()) - self.assertFalse(env.step(ACTION).last()) - self.assertTrue(env.step(ACTION).last()) + self.assertTrue(env.step(ACTION).first()) + self.assertFalse(env.step(ACTION).last()) + self.assertTrue(env.step(ACTION).last()) - def test_step_after_termination(self): - fake_env = fakes.DiscreteEnvironment(episode_length=5) + def test_step_after_termination(self): + fake_env = fakes.DiscreteEnvironment(episode_length=5) - fake_env.reset() - fake_env.step(ACTION) - fake_env.step(ACTION) - fake_env.step(ACTION) - fake_env.step(ACTION) - self.assertTrue(fake_env.step(ACTION).last()) + fake_env.reset() + fake_env.step(ACTION) + fake_env.step(ACTION) + fake_env.step(ACTION) + fake_env.step(ACTION) + self.assertTrue(fake_env.step(ACTION).last()) - env = wrappers.StepLimitWrapper(fake_env, step_limit=2) + env = wrappers.StepLimitWrapper(fake_env, step_limit=2) - self.assertTrue(env.step(ACTION).first()) - self.assertFalse(env.step(ACTION).last()) - self.assertTrue(env.step(ACTION).last()) + self.assertTrue(env.step(ACTION).first()) + self.assertFalse(env.step(ACTION).last()) + self.assertTrue(env.step(ACTION).last()) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/acme/wrappers/video.py b/acme/wrappers/video.py index 4a35a2f63a..eb44a320b0 100644 --- a/acme/wrappers/video.py +++ b/acme/wrappers/video.py @@ -22,12 +22,13 @@ import tempfile from typing import Callable, Optional, Sequence, Tuple, Union +import dm_env +import matplotlib + from acme.utils import paths from acme.wrappers import base -import dm_env -import matplotlib -matplotlib.use('Agg') # Switch to headless 'Agg' to inhibit figure rendering. +matplotlib.use("Agg") # Switch to headless 'Agg' to inhibit figure rendering. import matplotlib.animation as anim # pylint: disable=g-import-not-at-top import matplotlib.pyplot as plt import numpy as np @@ -35,51 +36,50 @@ # Internal imports. # Make sure you have FFMpeg configured. + def make_animation( - frames: Sequence[np.ndarray], frame_rate: float, - figsize: Optional[Union[float, Tuple[int, int]]]) -> anim.Animation: - """Generates a matplotlib animation from a stack of frames.""" - - # Set animation characteristics. - if figsize is None: - height, width, _ = frames[0].shape - elif isinstance(figsize, tuple): - height, width = figsize - else: - diagonal = figsize - height, width, _ = frames[0].shape - scale_factor = diagonal / np.sqrt(height**2 + width**2) - width *= scale_factor - height *= scale_factor - - dpi = 70 - interval = int(round(1e3 / frame_rate)) # Time (in ms) between frames. - - # Create and configure the figure. - fig, ax = plt.subplots(1, 1, figsize=(width / dpi, height / dpi), dpi=dpi) - ax.set_axis_off() - ax.set_aspect('equal') - ax.set_position([0, 0, 1, 1]) - - # Initialize the first frame. - im = ax.imshow(frames[0]) - - # Create the function that will modify the frame, creating an animation. 
- def update(frame): - im.set_data(frame) - return [im] - - return anim.FuncAnimation( - fig=fig, - func=update, - frames=frames, - interval=interval, - blit=True, - repeat=False) + frames: Sequence[np.ndarray], + frame_rate: float, + figsize: Optional[Union[float, Tuple[int, int]]], +) -> anim.Animation: + """Generates a matplotlib animation from a stack of frames.""" + + # Set animation characteristics. + if figsize is None: + height, width, _ = frames[0].shape + elif isinstance(figsize, tuple): + height, width = figsize + else: + diagonal = figsize + height, width, _ = frames[0].shape + scale_factor = diagonal / np.sqrt(height ** 2 + width ** 2) + width *= scale_factor + height *= scale_factor + + dpi = 70 + interval = int(round(1e3 / frame_rate)) # Time (in ms) between frames. + + # Create and configure the figure. + fig, ax = plt.subplots(1, 1, figsize=(width / dpi, height / dpi), dpi=dpi) + ax.set_axis_off() + ax.set_aspect("equal") + ax.set_position([0, 0, 1, 1]) + + # Initialize the first frame. + im = ax.imshow(frames[0]) + + # Create the function that will modify the frame, creating an animation. + def update(frame): + im.set_data(frame) + return [im] + + return anim.FuncAnimation( + fig=fig, func=update, frames=frames, interval=interval, blit=True, repeat=False + ) class VideoWrapper(base.EnvironmentWrapper): - """Wrapper which creates and records videos from generated observations. + """Wrapper which creates and records videos from generated observations. This will limit itself to recording once every `record_every` episodes and videos will be recorded to the directory `path` + '//videos' where @@ -88,169 +88,185 @@ class VideoWrapper(base.EnvironmentWrapper): of the diagonal. """ - def __init__( - self, - environment: dm_env.Environment, - *, - path: str = '~/acme', - filename: str = '', - process_path: Callable[[str, str], str] = paths.process_path, - record_every: int = 100, - frame_rate: int = 30, - figsize: Optional[Union[float, Tuple[int, int]]] = None, - to_html: bool = True, - ): - super(VideoWrapper, self).__init__(environment) - self._path = process_path(path, 'videos') - self._filename = filename - self._record_every = record_every - self._frame_rate = frame_rate - self._frames = [] - self._counter = 0 - self._figsize = figsize - self._to_html = to_html - - def _render_frame(self, observation): - """Renders a frame from the given environment observation.""" - return observation - - def _write_frames(self): - """Writes frames to video.""" - if self._counter % self._record_every == 0: - animation = make_animation(self._frames, self._frame_rate, self._figsize) - path_without_extension = os.path.join( - self._path, f'{self._filename}_{self._counter:04d}' - ) - if self._to_html: - path = path_without_extension + '.html' - video = animation.to_html5_video() - with open(path, 'w') as f: - f.write(video) - else: - path = path_without_extension + '.m4v' - # Animation.save can save only locally. Save first and copy using - # gfile. - with tempfile.TemporaryDirectory() as tmp_dir: - tmp_path = os.path.join(tmp_dir, 'temp.m4v') - animation.save(tmp_path) - with open(path, 'wb') as f: - with open(tmp_path, 'rb') as g: - f.write(g.read()) - - # Clear the frame buffer whether a video was generated or not. 
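# Behavioural note on the recording schedule (an aside, not original source
# text): with the default record_every=100, frames are only buffered by
# _append_frame (defined below) while self._counter % 100 == 0, i.e. on
# episodes 0, 100, 200, ..., and the buffered episode is flushed here on the
# next reset() or close(), to a file named f"{filename}_{counter:04d}" with an
# .html or .m4v extension depending on to_html.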
- self._frames = [] - - def _append_frame(self, observation): - """Appends a frame to the sequence of frames.""" - if self._counter % self._record_every == 0: - self._frames.append(self._render_frame(observation)) - - def step(self, action) -> dm_env.TimeStep: - timestep = self.environment.step(action) - self._append_frame(timestep.observation) - return timestep - - def reset(self) -> dm_env.TimeStep: - # If the frame buffer is nonempty, flush it and record video - if self._frames: - self._write_frames() - self._counter += 1 - timestep = self.environment.reset() - self._append_frame(timestep.observation) - return timestep - - def make_html_animation(self): - if self._frames: - return make_animation(self._frames, self._frame_rate, - self._figsize).to_html5_video() - else: - raise ValueError('make_html_animation should be called after running a ' - 'trajectory and before calling reset().') - - def close(self): - if self._frames: - self._write_frames() - self._frames = [] - self.environment.close() + def __init__( + self, + environment: dm_env.Environment, + *, + path: str = "~/acme", + filename: str = "", + process_path: Callable[[str, str], str] = paths.process_path, + record_every: int = 100, + frame_rate: int = 30, + figsize: Optional[Union[float, Tuple[int, int]]] = None, + to_html: bool = True, + ): + super(VideoWrapper, self).__init__(environment) + self._path = process_path(path, "videos") + self._filename = filename + self._record_every = record_every + self._frame_rate = frame_rate + self._frames = [] + self._counter = 0 + self._figsize = figsize + self._to_html = to_html + + def _render_frame(self, observation): + """Renders a frame from the given environment observation.""" + return observation + + def _write_frames(self): + """Writes frames to video.""" + if self._counter % self._record_every == 0: + animation = make_animation(self._frames, self._frame_rate, self._figsize) + path_without_extension = os.path.join( + self._path, f"{self._filename}_{self._counter:04d}" + ) + if self._to_html: + path = path_without_extension + ".html" + video = animation.to_html5_video() + with open(path, "w") as f: + f.write(video) + else: + path = path_without_extension + ".m4v" + # Animation.save can save only locally. Save first and copy using + # gfile. + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = os.path.join(tmp_dir, "temp.m4v") + animation.save(tmp_path) + with open(path, "wb") as f: + with open(tmp_path, "rb") as g: + f.write(g.read()) + + # Clear the frame buffer whether a video was generated or not. + self._frames = [] + + def _append_frame(self, observation): + """Appends a frame to the sequence of frames.""" + if self._counter % self._record_every == 0: + self._frames.append(self._render_frame(observation)) + + def step(self, action) -> dm_env.TimeStep: + timestep = self.environment.step(action) + self._append_frame(timestep.observation) + return timestep + + def reset(self) -> dm_env.TimeStep: + # If the frame buffer is nonempty, flush it and record video + if self._frames: + self._write_frames() + self._counter += 1 + timestep = self.environment.reset() + self._append_frame(timestep.observation) + return timestep + + def make_html_animation(self): + if self._frames: + return make_animation( + self._frames, self._frame_rate, self._figsize + ).to_html5_video() + else: + raise ValueError( + "make_html_animation should be called after running a " + "trajectory and before calling reset()." 
+ ) + + def close(self): + if self._frames: + self._write_frames() + self._frames = [] + self.environment.close() class MujocoVideoWrapper(VideoWrapper): - """VideoWrapper which generates videos from a mujoco physics object. + """VideoWrapper which generates videos from a mujoco physics object. This passes its keyword arguments into the parent `VideoWrapper` class (refer here for any default arguments). """ - # Note that since we can be given a wrapped mujoco environment we can't give - # the type as dm_control.Environment. - - def __init__(self, - environment: dm_env.Environment, - *, - frame_rate: Optional[int] = None, - camera_id: Optional[int] = 0, - height: int = 240, - width: int = 320, - playback_speed: float = 1., - **kwargs): - - # Check that we have a mujoco environment (or a wrapper thereof). - if not hasattr(environment, 'physics'): - raise ValueError('MujocoVideoWrapper expects an environment which ' - 'exposes a physics attribute corresponding to a MuJoCo ' - 'physics engine') - - # Compute frame rate if not set. - if frame_rate is None: - try: - control_timestep = getattr(environment, 'control_timestep')() - except AttributeError as e: - raise AttributeError('MujocoVideoWrapper expects an environment which ' - 'exposes a control_timestep method, like ' - 'dm_control environments, or frame_rate ' - 'to be specified.') from e - frame_rate = int(round(playback_speed / control_timestep)) - - super().__init__(environment, frame_rate=frame_rate, **kwargs) - self._camera_id = camera_id - self._height = height - self._width = width - - def _render_frame(self, unused_observation): - del unused_observation - - # We've checked above that this attribute should exist. Pytype won't like - # it if we just try and do self.environment.physics, so we use the slightly - # grosser version below. - physics = getattr(self.environment, 'physics') - - if self._camera_id is not None: - frame = physics.render( - camera_id=self._camera_id, height=self._height, width=self._width) - else: - # If camera_id is None, we create a minimal canvas that will accommodate - # physics.model.ncam frames, and render all of them on a grid. - num_cameras = physics.model.ncam - num_columns = int(np.ceil(np.sqrt(num_cameras))) - num_rows = int(np.ceil(float(num_cameras)/num_columns)) - height = self._height - width = self._width - - # Make a black canvas. - frame = np.zeros((num_rows*height, num_columns*width, 3), dtype=np.uint8) - - for col in range(num_columns): - for row in range(num_rows): - - camera_id = row*num_columns + col - - if camera_id >= num_cameras: - break - - subframe = physics.render( - camera_id=camera_id, height=height, width=width) - - # Place the frame in the appropriate rectangle on the pixel canvas. - frame[row*height:(row+1)*height, col*width:(col+1)*width] = subframe - - return frame + # Note that since we can be given a wrapped mujoco environment we can't give + # the type as dm_control.Environment. + + def __init__( + self, + environment: dm_env.Environment, + *, + frame_rate: Optional[int] = None, + camera_id: Optional[int] = 0, + height: int = 240, + width: int = 320, + playback_speed: float = 1.0, + **kwargs, + ): + + # Check that we have a mujoco environment (or a wrapper thereof). + if not hasattr(environment, "physics"): + raise ValueError( + "MujocoVideoWrapper expects an environment which " + "exposes a physics attribute corresponding to a MuJoCo " + "physics engine" + ) + + # Compute frame rate if not set. 
+ if frame_rate is None: + try: + control_timestep = getattr(environment, "control_timestep")() + except AttributeError as e: + raise AttributeError( + "MujocoVideoWrapper expects an environment which " + "exposes a control_timestep method, like " + "dm_control environments, or frame_rate " + "to be specified." + ) from e + frame_rate = int(round(playback_speed / control_timestep)) + + super().__init__(environment, frame_rate=frame_rate, **kwargs) + self._camera_id = camera_id + self._height = height + self._width = width + + def _render_frame(self, unused_observation): + del unused_observation + + # We've checked above that this attribute should exist. Pytype won't like + # it if we just try and do self.environment.physics, so we use the slightly + # grosser version below. + physics = getattr(self.environment, "physics") + + if self._camera_id is not None: + frame = physics.render( + camera_id=self._camera_id, height=self._height, width=self._width + ) + else: + # If camera_id is None, we create a minimal canvas that will accommodate + # physics.model.ncam frames, and render all of them on a grid. + num_cameras = physics.model.ncam + num_columns = int(np.ceil(np.sqrt(num_cameras))) + num_rows = int(np.ceil(float(num_cameras) / num_columns)) + height = self._height + width = self._width + + # Make a black canvas. + frame = np.zeros( + (num_rows * height, num_columns * width, 3), dtype=np.uint8 + ) + + for col in range(num_columns): + for row in range(num_rows): + + camera_id = row * num_columns + col + + if camera_id >= num_cameras: + break + + subframe = physics.render( + camera_id=camera_id, height=height, width=width + ) + + # Place the frame in the appropriate rectangle on the pixel canvas. + frame[ + row * height : (row + 1) * height, + col * width : (col + 1) * width, + ] = subframe + + return frame diff --git a/docs/conf.py b/docs/conf.py index 5611bf80bf..d759e8c55c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -15,27 +15,24 @@ """Sphinx configuration. 
""" -project = 'Acme' -author = 'DeepMind Technologies Limited' -copyright = '2018, DeepMind Technologies Limited' # pylint: disable=redefined-builtin -version = '' -release = '' -master_doc = 'index' +project = "Acme" +author = "DeepMind Technologies Limited" +copyright = "2018, DeepMind Technologies Limited" # pylint: disable=redefined-builtin +version = "" +release = "" +master_doc = "index" -extensions = [ - 'myst_parser' -] +extensions = ["myst_parser"] -html_theme = 'sphinx_rtd_theme' -html_logo = 'imgs/acme.png' +html_theme = "sphinx_rtd_theme" +html_logo = "imgs/acme.png" html_theme_options = { - 'logo_only': True, + "logo_only": True, } html_css_files = [ - 'custom.css', + "custom.css", ] templates_path = [] -html_static_path = ['_static'] -exclude_patterns = ['_build', 'requirements.txt'] - +html_static_path = ["_static"] +exclude_patterns = ["_build", "requirements.txt"] diff --git a/examples/baselines/imitation/helpers.py b/examples/baselines/imitation/helpers.py index bf464bc963..ef7b6f4b04 100644 --- a/examples/baselines/imitation/helpers.py +++ b/examples/baselines/imitation/helpers.py @@ -15,54 +15,53 @@ """Helpers functions for imitation tasks.""" from typing import Tuple -from acme import wrappers - import dm_env import gym import numpy as np import tensorflow as tf +from acme import wrappers DATASET_NAMES = { - 'HalfCheetah-v2': 'locomotion/halfcheetah_sac_1M_single_policy_stochastic', - 'Ant-v2': 'locomotion/ant_sac_1M_single_policy_stochastic', - 'Walker2d-v2': 'locomotion/walker2d_sac_1M_single_policy_stochastic', - 'Hopper-v2': 'locomotion/hopper_sac_1M_single_policy_stochastic', - 'Humanoid-v2': 'locomotion/humanoid_sac_15M_single_policy_stochastic' + "HalfCheetah-v2": "locomotion/halfcheetah_sac_1M_single_policy_stochastic", + "Ant-v2": "locomotion/ant_sac_1M_single_policy_stochastic", + "Walker2d-v2": "locomotion/walker2d_sac_1M_single_policy_stochastic", + "Hopper-v2": "locomotion/hopper_sac_1M_single_policy_stochastic", + "Humanoid-v2": "locomotion/humanoid_sac_15M_single_policy_stochastic", } def get_dataset_name(env_name: str) -> str: - return DATASET_NAMES[env_name] + return DATASET_NAMES[env_name] -def get_observation_stats(transitions_iterator: tf.data.Dataset - ) -> Tuple[np.float64, np.float64]: - """Returns scale and shift of the observations in demonstrations.""" - observations = [step.observation for step in transitions_iterator] - mean = np.mean(observations, axis=0, dtype='float64') - std = np.std(observations, axis=0, dtype='float64') +def get_observation_stats( + transitions_iterator: tf.data.Dataset, +) -> Tuple[np.float64, np.float64]: + """Returns scale and shift of the observations in demonstrations.""" + observations = [step.observation for step in transitions_iterator] + mean = np.mean(observations, axis=0, dtype="float64") + std = np.std(observations, axis=0, dtype="float64") - shift = - mean - # The std is set to 1 if the observation values are below a threshold. - # This prevents normalizing observation values that are constant (which can - # be problematic with e.g. demonstrations coming from a different version - # of the environment and where the constant values are slightly different). - scale = 1 / ((std < 1e-6) + std) - return shift, scale + shift = -mean + # The std is set to 1 if the observation values are below a threshold. + # This prevents normalizing observation values that are constant (which can + # be problematic with e.g. 
demonstrations coming from a different version + # of the environment and where the constant values are slightly different). + scale = 1 / ((std < 1e-6) + std) + return shift, scale -def make_environment( - task: str = 'MountainCarContinuous-v0') -> dm_env.Environment: - """Creates an OpenAI Gym environment.""" +def make_environment(task: str = "MountainCarContinuous-v0") -> dm_env.Environment: + """Creates an OpenAI Gym environment.""" - # Load the gym environment. - environment = gym.make(task) + # Load the gym environment. + environment = gym.make(task) - # Make sure the environment obeys the dm_env.Environment interface. - environment = wrappers.GymWrapper(environment) - # Clip the action returned by the agent to the environment spec. - environment = wrappers.CanonicalSpecWrapper(environment, clip=True) - environment = wrappers.SinglePrecisionWrapper(environment) + # Make sure the environment obeys the dm_env.Environment interface. + environment = wrappers.GymWrapper(environment) + # Clip the action returned by the agent to the environment spec. + environment = wrappers.CanonicalSpecWrapper(environment, clip=True) + environment = wrappers.SinglePrecisionWrapper(environment) - return environment + return environment diff --git a/examples/baselines/imitation/run_bc.py b/examples/baselines/imitation/run_bc.py index ccbb2b29c6..e74faeacc8 100644 --- a/examples/baselines/imitation/run_bc.py +++ b/examples/baselines/imitation/run_bc.py @@ -20,77 +20,86 @@ from typing import Callable, Iterator, Tuple -from absl import flags -from acme import specs -from acme import types +import dm_env +import haiku as hk +import helpers +import launchpad as lp +import numpy as np +from absl import app, flags + +from acme import specs, types from acme.agents.jax import actor_core as actor_core_lib from acme.agents.jax import bc from acme.datasets import tfds -import helpers -from absl import app from acme.jax import experiments from acme.jax import types as jax_types from acme.jax import utils from acme.utils import lp_utils -import dm_env -import haiku as hk -import launchpad as lp -import numpy as np FLAGS = flags.FLAGS flags.DEFINE_bool( - 'run_distributed', True, 'Should an agent be executed in a distributed ' - 'way. If False, will run single-threaded.') + "run_distributed", + True, + "Should an agent be executed in a distributed " + "way. 
If False, will run single-threaded.", +) # Agent flags -flags.DEFINE_string('env_name', 'HalfCheetah-v2', 'What environment to run') -flags.DEFINE_integer('num_demonstrations', 11, - 'Number of demonstration trajectories.') -flags.DEFINE_integer('num_bc_steps', 100_000, 'Number of bc learning steps.') -flags.DEFINE_integer('num_steps', 0, 'Number of environment steps.') -flags.DEFINE_integer('batch_size', 64, 'Batch size.') -flags.DEFINE_float('learning_rate', 1e-4, 'Optimizer learning rate.') -flags.DEFINE_float('dropout_rate', 0.1, 'Dropout rate of bc network.') -flags.DEFINE_integer('num_layers', 3, 'Num layers of bc network.') -flags.DEFINE_integer('num_units', 256, 'Num units of bc network layers.') -flags.DEFINE_integer('eval_every', 5000, 'Evaluation period.') -flags.DEFINE_integer('evaluation_episodes', 10, 'Evaluation episodes.') -flags.DEFINE_integer('seed', 0, 'Random seed for learner and evaluator.') +flags.DEFINE_string("env_name", "HalfCheetah-v2", "What environment to run") +flags.DEFINE_integer("num_demonstrations", 11, "Number of demonstration trajectories.") +flags.DEFINE_integer("num_bc_steps", 100_000, "Number of bc learning steps.") +flags.DEFINE_integer("num_steps", 0, "Number of environment steps.") +flags.DEFINE_integer("batch_size", 64, "Batch size.") +flags.DEFINE_float("learning_rate", 1e-4, "Optimizer learning rate.") +flags.DEFINE_float("dropout_rate", 0.1, "Dropout rate of bc network.") +flags.DEFINE_integer("num_layers", 3, "Num layers of bc network.") +flags.DEFINE_integer("num_units", 256, "Num units of bc network layers.") +flags.DEFINE_integer("eval_every", 5000, "Evaluation period.") +flags.DEFINE_integer("evaluation_episodes", 10, "Evaluation episodes.") +flags.DEFINE_integer("seed", 0, "Random seed for learner and evaluator.") def _make_demonstration_dataset_factory( - dataset_name: str, num_demonstrations: int, - environment_spec: specs.EnvironmentSpec, batch_size: int + dataset_name: str, + num_demonstrations: int, + environment_spec: specs.EnvironmentSpec, + batch_size: int, ) -> Callable[[jax_types.PRNGKey], Iterator[types.Transition]]: - """Returns the demonstration dataset factory for the given dataset.""" + """Returns the demonstration dataset factory for the given dataset.""" - def demonstration_dataset_factory( - random_key: jax_types.PRNGKey) -> Iterator[types.Transition]: - """Returns an iterator of demonstration samples.""" + def demonstration_dataset_factory( + random_key: jax_types.PRNGKey, + ) -> Iterator[types.Transition]: + """Returns an iterator of demonstration samples.""" - transitions_iterator = tfds.get_tfds_dataset( - dataset_name, num_demonstrations, env_spec=environment_spec) - return tfds.JaxInMemoryRandomSampleIterator( - transitions_iterator, key=random_key, batch_size=batch_size) + transitions_iterator = tfds.get_tfds_dataset( + dataset_name, num_demonstrations, env_spec=environment_spec + ) + return tfds.JaxInMemoryRandomSampleIterator( + transitions_iterator, key=random_key, batch_size=batch_size + ) - return demonstration_dataset_factory + return demonstration_dataset_factory def _make_environment_factory(env_name: str) -> jax_types.EnvironmentFactory: - """Returns the environment factory for the given environment.""" + """Returns the environment factory for the given environment.""" - def environment_factory(seed: int) -> dm_env.Environment: - del seed - return helpers.make_environment(task=env_name) + def environment_factory(seed: int) -> dm_env.Environment: + del seed + return helpers.make_environment(task=env_name) - 
return environment_factory + return environment_factory def _make_network_factory( - shift: Tuple[np.float64], scale: Tuple[np.float64], num_layers: int, + shift: Tuple[np.float64], + scale: Tuple[np.float64], + num_layers: int, num_units: int, - dropout_rate: float) -> Callable[[specs.EnvironmentSpec], bc.BCNetworks]: - """Returns the factory of networks to be used by the agent. + dropout_rate: float, +) -> Callable[[specs.EnvironmentSpec], bc.BCNetworks]: + """Returns the factory of networks to be used by the agent. Args: shift: Shift of the observations in demonstrations. @@ -103,93 +112,96 @@ def _make_network_factory( Network factory. """ - def network_factory(spec: specs.EnvironmentSpec) -> bc.BCNetworks: - """Creates the network used by the agent.""" + def network_factory(spec: specs.EnvironmentSpec) -> bc.BCNetworks: + """Creates the network used by the agent.""" - action_spec = spec.actions - num_dimensions = np.prod(action_spec.shape, dtype=int) + action_spec = spec.actions + num_dimensions = np.prod(action_spec.shape, dtype=int) - def actor_fn(obs, is_training=False, key=None): - obs += shift - obs *= scale - hidden_layers = [num_units] * num_layers - mlp = hk.Sequential([ - hk.nets.MLP(hidden_layers + [num_dimensions]), - ]) - if is_training: - return mlp(obs, dropout_rate=dropout_rate, rng=key) - else: - return mlp(obs) + def actor_fn(obs, is_training=False, key=None): + obs += shift + obs *= scale + hidden_layers = [num_units] * num_layers + mlp = hk.Sequential([hk.nets.MLP(hidden_layers + [num_dimensions]),]) + if is_training: + return mlp(obs, dropout_rate=dropout_rate, rng=key) + else: + return mlp(obs) - policy = hk.without_apply_rng(hk.transform(actor_fn)) + policy = hk.without_apply_rng(hk.transform(actor_fn)) - # Create dummy observations to create network parameters. - dummy_obs = utils.zeros_like(spec.observations) - dummy_obs = utils.add_batch_dim(dummy_obs) + # Create dummy observations to create network parameters. + dummy_obs = utils.zeros_like(spec.observations) + dummy_obs = utils.add_batch_dim(dummy_obs) - policy_network = bc.BCPolicyNetwork(lambda key: policy.init(key, dummy_obs), - policy.apply) + policy_network = bc.BCPolicyNetwork( + lambda key: policy.init(key, dummy_obs), policy.apply + ) - return bc.BCNetworks(policy_network=policy_network) + return bc.BCNetworks(policy_network=policy_network) - return network_factory + return network_factory def build_experiment_config() -> experiments.OfflineExperimentConfig[ - bc.BCNetworks, actor_core_lib.FeedForwardPolicy, types.Transition]: - """Returns a config for BC experiments.""" - - # Create an environment, grab the spec, and use it to create networks. - environment = helpers.make_environment(task=FLAGS.env_name) - environment_spec = specs.make_environment_spec(environment) - - # Define the demonstrations factory. - dataset_name = helpers.get_dataset_name(FLAGS.env_name) - demonstration_dataset_factory = _make_demonstration_dataset_factory( - dataset_name, FLAGS.num_demonstrations, environment_spec, - FLAGS.batch_size) - - # Load the demonstrations to compute the stats. - dataset = tfds.get_tfds_dataset( - dataset_name, FLAGS.num_demonstrations, env_spec=environment_spec) - shift, scale = helpers.get_observation_stats(dataset) - - # Define the network factory. - network_factory = _make_network_factory( # pytype: disable=wrong-arg-types # numpy-scalars - shift=shift, - scale=scale, - num_layers=FLAGS.num_layers, - num_units=FLAGS.num_units, - dropout_rate=FLAGS.dropout_rate) - - # Create the BC builder. 
- bc_config = bc.BCConfig(learning_rate=FLAGS.learning_rate) - bc_builder = bc.BCBuilder(bc_config, loss_fn=bc.mse()) - - environment_factory = _make_environment_factory(FLAGS.env_name) - - return experiments.OfflineExperimentConfig( - builder=bc_builder, - network_factory=network_factory, - demonstration_dataset_factory=demonstration_dataset_factory, - environment_factory=environment_factory, - max_num_learner_steps=FLAGS.num_bc_steps, - seed=FLAGS.seed, - environment_spec=environment_spec, - ) + bc.BCNetworks, actor_core_lib.FeedForwardPolicy, types.Transition +]: + """Returns a config for BC experiments.""" + + # Create an environment, grab the spec, and use it to create networks. + environment = helpers.make_environment(task=FLAGS.env_name) + environment_spec = specs.make_environment_spec(environment) + + # Define the demonstrations factory. + dataset_name = helpers.get_dataset_name(FLAGS.env_name) + demonstration_dataset_factory = _make_demonstration_dataset_factory( + dataset_name, FLAGS.num_demonstrations, environment_spec, FLAGS.batch_size + ) + + # Load the demonstrations to compute the stats. + dataset = tfds.get_tfds_dataset( + dataset_name, FLAGS.num_demonstrations, env_spec=environment_spec + ) + shift, scale = helpers.get_observation_stats(dataset) + + # Define the network factory. + network_factory = _make_network_factory( # pytype: disable=wrong-arg-types # numpy-scalars + shift=shift, + scale=scale, + num_layers=FLAGS.num_layers, + num_units=FLAGS.num_units, + dropout_rate=FLAGS.dropout_rate, + ) + + # Create the BC builder. + bc_config = bc.BCConfig(learning_rate=FLAGS.learning_rate) + bc_builder = bc.BCBuilder(bc_config, loss_fn=bc.mse()) + + environment_factory = _make_environment_factory(FLAGS.env_name) + + return experiments.OfflineExperimentConfig( + builder=bc_builder, + network_factory=network_factory, + demonstration_dataset_factory=demonstration_dataset_factory, + environment_factory=environment_factory, + max_num_learner_steps=FLAGS.num_bc_steps, + seed=FLAGS.seed, + environment_spec=environment_spec, + ) def main(_): - config = build_experiment_config() - if FLAGS.run_distributed: - program = experiments.make_distributed_offline_experiment(experiment=config) - lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) - else: - experiments.run_offline_experiment( - experiment=config, - eval_every=FLAGS.eval_every, - num_eval_episodes=FLAGS.evaluation_episodes) - - -if __name__ == '__main__': - app.run(main) + config = build_experiment_config() + if FLAGS.run_distributed: + program = experiments.make_distributed_offline_experiment(experiment=config) + lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) + else: + experiments.run_offline_experiment( + experiment=config, + eval_every=FLAGS.eval_every, + num_eval_episodes=FLAGS.evaluation_episodes, + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/baselines/imitation/run_gail.py b/examples/baselines/imitation/run_gail.py index a7c43a6c04..37b04c952a 100644 --- a/examples/baselines/imitation/run_gail.py +++ b/examples/baselines/imitation/run_gail.py @@ -33,122 +33,131 @@ (even for Humanoid). 
""" -from absl import flags -from acme import specs -from acme.agents.jax import ail -from acme.agents.jax import td3 -from acme.datasets import tfds -import helpers -from absl import app -from acme.jax import experiments -from acme.jax import networks as networks_lib -from acme.utils import lp_utils import dm_env import haiku as hk +import helpers import jax import launchpad as lp +from absl import app, flags +from acme import specs +from acme.agents.jax import ail, td3 +from acme.datasets import tfds +from acme.jax import experiments +from acme.jax import networks as networks_lib +from acme.utils import lp_utils FLAGS = flags.FLAGS flags.DEFINE_bool( - 'run_distributed', True, 'Should an agent be executed in a distributed ' - 'way. If False, will run single-threaded.') -flags.DEFINE_string('env_name', 'HalfCheetah-v2', 'What environment to run') -flags.DEFINE_integer('seed', 0, 'Random seed.') -flags.DEFINE_integer('num_steps', 1_000_000, 'Number of env steps to run.') -flags.DEFINE_integer('eval_every', 50_000, 'Number of env steps to run.') -flags.DEFINE_integer('num_demonstrations', 11, - 'Number of demonstration trajectories.') -flags.DEFINE_integer('evaluation_episodes', 10, 'Evaluation episodes.') + "run_distributed", + True, + "Should an agent be executed in a distributed " + "way. If False, will run single-threaded.", +) +flags.DEFINE_string("env_name", "HalfCheetah-v2", "What environment to run") +flags.DEFINE_integer("seed", 0, "Random seed.") +flags.DEFINE_integer("num_steps", 1_000_000, "Number of env steps to run.") +flags.DEFINE_integer("eval_every", 50_000, "Number of env steps to run.") +flags.DEFINE_integer("num_demonstrations", 11, "Number of demonstration trajectories.") +flags.DEFINE_integer("evaluation_episodes", 10, "Evaluation episodes.") def build_experiment_config() -> experiments.ExperimentConfig: - """Returns a configuration for GAIL/DAC experiments.""" - - # Create an environment, grab the spec, and use it to create networks. - environment = helpers.make_environment(task=FLAGS.env_name) - environment_spec = specs.make_environment_spec(environment) - - # Create the direct RL agent. - td3_config = td3.TD3Config( - min_replay_size=1, - samples_per_insert_tolerance_rate=2.0) - td3_networks = td3.make_networks(environment_spec) - - # Create the discriminator. - def discriminator(*args, **kwargs) -> networks_lib.Logits: - return ail.DiscriminatorModule( - environment_spec=environment_spec, - use_action=True, - use_next_obs=False, - network_core=ail.DiscriminatorMLP( - hidden_layer_sizes=[64,], - spectral_normalization_lipschitz_coeff=1.) + """Returns a configuration for GAIL/DAC experiments.""" + + # Create an environment, grab the spec, and use it to create networks. + environment = helpers.make_environment(task=FLAGS.env_name) + environment_spec = specs.make_environment_spec(environment) + + # Create the direct RL agent. + td3_config = td3.TD3Config(min_replay_size=1, samples_per_insert_tolerance_rate=2.0) + td3_networks = td3.make_networks(environment_spec) + + # Create the discriminator. 
+ def discriminator(*args, **kwargs) -> networks_lib.Logits: + return ail.DiscriminatorModule( + environment_spec=environment_spec, + use_action=True, + use_next_obs=False, + network_core=ail.DiscriminatorMLP( + hidden_layer_sizes=[64,], spectral_normalization_lipschitz_coeff=1.0 + ), )(*args, **kwargs) - discriminator_transformed = hk.without_apply_rng( - hk.transform_with_state(discriminator)) - - def network_factory( - environment_spec: specs.EnvironmentSpec) -> ail.AILNetworks: - return ail.AILNetworks( - ail.make_discriminator(environment_spec, discriminator_transformed), - # reward balance = 0 corresponds to the GAIL reward: -ln(1-D) - imitation_reward_fn=ail.rewards.gail_reward(reward_balance=0.), - direct_rl_networks=td3_networks) - - # Create demonstrations function. - dataset_name = helpers.get_dataset_name(FLAGS.env_name) - num_demonstrations = FLAGS.num_demonstrations - def make_demonstrations(batch_size, seed: int = 0): - transitions_iterator = tfds.get_tfds_dataset( - dataset_name, num_demonstrations, env_spec=environment_spec) - return tfds.JaxInMemoryRandomSampleIterator( - transitions_iterator, jax.random.PRNGKey(seed), batch_size) - - # Create DAC agent. - ail_config = ail.AILConfig(direct_rl_batch_size=td3_config.batch_size * - td3_config.num_sgd_steps_per_step) - - env_name = FLAGS.env_name - - def environment_factory(seed: int) -> dm_env.Environment: - del seed - return helpers.make_environment(task=env_name) - - td3_builder = td3.TD3Builder(td3_config) - - dac_loss = ail.losses.add_gradient_penalty( - ail.losses.gail_loss(entropy_coefficient=1e-3), - gradient_penalty_coefficient=10., - gradient_penalty_target=1.) - - ail_builder = ail.AILBuilder( - rl_agent=td3_builder, - config=ail_config, - discriminator_loss=dac_loss, - make_demonstrations=make_demonstrations) - - return experiments.ExperimentConfig( - builder=ail_builder, - environment_factory=environment_factory, - network_factory=network_factory, - seed=FLAGS.seed, - max_num_actor_steps=FLAGS.num_steps) + + discriminator_transformed = hk.without_apply_rng( + hk.transform_with_state(discriminator) + ) + + def network_factory(environment_spec: specs.EnvironmentSpec) -> ail.AILNetworks: + return ail.AILNetworks( + ail.make_discriminator(environment_spec, discriminator_transformed), + # reward balance = 0 corresponds to the GAIL reward: -ln(1-D) + imitation_reward_fn=ail.rewards.gail_reward(reward_balance=0.0), + direct_rl_networks=td3_networks, + ) + + # Create demonstrations function. + dataset_name = helpers.get_dataset_name(FLAGS.env_name) + num_demonstrations = FLAGS.num_demonstrations + + def make_demonstrations(batch_size, seed: int = 0): + transitions_iterator = tfds.get_tfds_dataset( + dataset_name, num_demonstrations, env_spec=environment_spec + ) + return tfds.JaxInMemoryRandomSampleIterator( + transitions_iterator, jax.random.PRNGKey(seed), batch_size + ) + + # Create DAC agent. 
+ ail_config = ail.AILConfig( + direct_rl_batch_size=td3_config.batch_size * td3_config.num_sgd_steps_per_step + ) + + env_name = FLAGS.env_name + + def environment_factory(seed: int) -> dm_env.Environment: + del seed + return helpers.make_environment(task=env_name) + + td3_builder = td3.TD3Builder(td3_config) + + dac_loss = ail.losses.add_gradient_penalty( + ail.losses.gail_loss(entropy_coefficient=1e-3), + gradient_penalty_coefficient=10.0, + gradient_penalty_target=1.0, + ) + + ail_builder = ail.AILBuilder( + rl_agent=td3_builder, + config=ail_config, + discriminator_loss=dac_loss, + make_demonstrations=make_demonstrations, + ) + + return experiments.ExperimentConfig( + builder=ail_builder, + environment_factory=environment_factory, + network_factory=network_factory, + seed=FLAGS.seed, + max_num_actor_steps=FLAGS.num_steps, + ) def main(_): - config = build_experiment_config() - if FLAGS.run_distributed: - program = experiments.make_distributed_experiment( - experiment=config, num_actors=4) - lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) - else: - experiments.run_experiment( - experiment=config, - eval_every=FLAGS.eval_every, - num_eval_episodes=FLAGS.evaluation_episodes) - - -if __name__ == '__main__': - app.run(main) + config = build_experiment_config() + if FLAGS.run_distributed: + program = experiments.make_distributed_experiment( + experiment=config, num_actors=4 + ) + lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) + else: + experiments.run_experiment( + experiment=config, + eval_every=FLAGS.eval_every, + num_eval_episodes=FLAGS.evaluation_episodes, + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/baselines/imitation/run_iqlearn.py b/examples/baselines/imitation/run_iqlearn.py index 90bc2eeb6e..df3d4800a8 100644 --- a/examples/baselines/imitation/run_iqlearn.py +++ b/examples/baselines/imitation/run_iqlearn.py @@ -19,49 +19,46 @@ from typing import Callable, Iterator -from absl import flags -from acme import specs -from acme import types +import dm_env +import helpers +import jax +import launchpad as lp +from absl import app, flags + +from acme import specs, types from acme.agents.jax import actor_core as actor_core_lib from acme.agents.jax import iq_learn from acme.datasets import tfds -import helpers -from absl import app from acme.jax import experiments from acme.jax import types as jax_types from acme.utils import lp_utils -import dm_env -import jax -import launchpad as lp FLAGS = flags.FLAGS flags.DEFINE_bool( - 'run_distributed', + "run_distributed", True, ( - 'Should an agent be executed in a distributed ' - 'way. If False, will run single-threaded.' + "Should an agent be executed in a distributed " + "way. If False, will run single-threaded." ), ) -flags.DEFINE_string('env_name', 'HalfCheetah-v2', 'What environment to run') -flags.DEFINE_integer('seed', 0, 'Random seed.') -flags.DEFINE_integer('num_steps', 1_000_000, 'Number of env steps to run.') -flags.DEFINE_integer('eval_every', 50_000, 'Number of env steps to run.') -flags.DEFINE_integer( - 'num_demonstrations', 11, 'Number of demonstration trajectories.' 
-) -flags.DEFINE_integer('evaluation_episodes', 10, 'Evaluation episodes.') +flags.DEFINE_string("env_name", "HalfCheetah-v2", "What environment to run") +flags.DEFINE_integer("seed", 0, "Random seed.") +flags.DEFINE_integer("num_steps", 1_000_000, "Number of env steps to run.") +flags.DEFINE_integer("eval_every", 50_000, "Number of env steps to run.") +flags.DEFINE_integer("num_demonstrations", 11, "Number of demonstration trajectories.") +flags.DEFINE_integer("evaluation_episodes", 10, "Evaluation episodes.") def _make_environment_factory(env_name: str) -> jax_types.EnvironmentFactory: - """Returns the environment factory for the given environment.""" + """Returns the environment factory for the given environment.""" - def environment_factory(seed: int) -> dm_env.Environment: - del seed - return helpers.make_environment(task=env_name) + def environment_factory(seed: int) -> dm_env.Environment: + del seed + return helpers.make_environment(task=env_name) - return environment_factory + return environment_factory def _make_demonstration_dataset_factory( @@ -70,77 +67,73 @@ def _make_demonstration_dataset_factory( num_demonstrations: int, random_key: jax_types.PRNGKey, ) -> Callable[[jax_types.PRNGKey], Iterator[types.Transition]]: - """Returns the demonstration dataset factory for the given dataset.""" - - def demonstration_dataset_factory( - batch_size: int, - ) -> Iterator[types.Transition]: - """Returns an iterator of demonstration samples.""" - transitions_iterator = tfds.get_tfds_dataset( - dataset_name, num_episodes=num_demonstrations, env_spec=environment_spec - ) - return tfds.JaxInMemoryRandomSampleIterator( - transitions_iterator, key=random_key, batch_size=batch_size - ) + """Returns the demonstration dataset factory for the given dataset.""" - return demonstration_dataset_factory + def demonstration_dataset_factory(batch_size: int,) -> Iterator[types.Transition]: + """Returns an iterator of demonstration samples.""" + transitions_iterator = tfds.get_tfds_dataset( + dataset_name, num_episodes=num_demonstrations, env_spec=environment_spec + ) + return tfds.JaxInMemoryRandomSampleIterator( + transitions_iterator, key=random_key, batch_size=batch_size + ) + + return demonstration_dataset_factory def build_experiment_config() -> ( experiments.ExperimentConfig[ - iq_learn.IQLearnNetworks, - actor_core_lib.ActorCore, - iq_learn.IQLearnSample, + iq_learn.IQLearnNetworks, actor_core_lib.ActorCore, iq_learn.IQLearnSample, ] ): - """Returns a configuration for IQLearn experiments.""" - - # Create an environment, grab the spec, and use it to create networks. - env_name = FLAGS.env_name - environment_factory = _make_environment_factory(env_name) - - dummy_seed = 1 - environment = environment_factory(dummy_seed) - environment_spec = specs.make_environment_spec(environment) - - # Create demonstrations function. 
- dataset_name = helpers.get_dataset_name(env_name) - make_demonstrations = _make_demonstration_dataset_factory( - dataset_name, - environment_spec, - FLAGS.num_demonstrations, - jax.random.PRNGKey(FLAGS.seed), - ) - - # Construct the agent - iq_learn_config = iq_learn.IQLearnConfig(alpha=1.0) - iq_learn_builder = iq_learn.IQLearnBuilder( - config=iq_learn_config, make_demonstrations=make_demonstrations - ) - - return experiments.ExperimentConfig( - builder=iq_learn_builder, - environment_factory=environment_factory, - network_factory=iq_learn.make_networks, - seed=FLAGS.seed, - max_num_actor_steps=FLAGS.num_steps, - ) - + """Returns a configuration for IQLearn experiments.""" + + # Create an environment, grab the spec, and use it to create networks. + env_name = FLAGS.env_name + environment_factory = _make_environment_factory(env_name) + + dummy_seed = 1 + environment = environment_factory(dummy_seed) + environment_spec = specs.make_environment_spec(environment) + + # Create demonstrations function. + dataset_name = helpers.get_dataset_name(env_name) + make_demonstrations = _make_demonstration_dataset_factory( + dataset_name, + environment_spec, + FLAGS.num_demonstrations, + jax.random.PRNGKey(FLAGS.seed), + ) -def main(_): - config = build_experiment_config() - if FLAGS.run_distributed: - program = experiments.make_distributed_experiment( - experiment=config, num_actors=4 + # Construct the agent + iq_learn_config = iq_learn.IQLearnConfig(alpha=1.0) + iq_learn_builder = iq_learn.IQLearnBuilder( + config=iq_learn_config, make_demonstrations=make_demonstrations ) - lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) - else: - experiments.run_experiment( - experiment=config, - eval_every=FLAGS.eval_every, - num_eval_episodes=FLAGS.evaluation_episodes, + + return experiments.ExperimentConfig( + builder=iq_learn_builder, + environment_factory=environment_factory, + network_factory=iq_learn.make_networks, + seed=FLAGS.seed, + max_num_actor_steps=FLAGS.num_steps, ) -if __name__ == '__main__': - app.run(main) +def main(_): + config = build_experiment_config() + if FLAGS.run_distributed: + program = experiments.make_distributed_experiment( + experiment=config, num_actors=4 + ) + lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) + else: + experiments.run_experiment( + experiment=config, + eval_every=FLAGS.eval_every, + num_eval_episodes=FLAGS.evaluation_episodes, + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/baselines/imitation/run_pwil.py b/examples/baselines/imitation/run_pwil.py index 179b77945a..f102b71e77 100644 --- a/examples/baselines/imitation/run_pwil.py +++ b/examples/baselines/imitation/run_pwil.py @@ -20,141 +20,151 @@ from typing import Sequence -from absl import flags -from acme import specs -from acme.agents.jax import d4pg -from acme.agents.jax import pwil -from acme.datasets import tfds -import helpers -from absl import app -from acme.jax import experiments -from acme.jax import networks as networks_lib -from acme.jax import utils -from acme.utils import lp_utils import dm_env import haiku as hk +import helpers import jax.numpy as jnp import launchpad as lp import numpy as np +from absl import app, flags +from acme import specs +from acme.agents.jax import d4pg, pwil +from acme.datasets import tfds +from acme.jax import experiments +from acme.jax import networks as networks_lib +from acme.jax import utils +from acme.utils import lp_utils FLAGS = flags.FLAGS flags.DEFINE_bool( - 'run_distributed', True, 
'Should an agent be executed in a distributed ' - 'way. If False, will run single-threaded.') -flags.DEFINE_string('env_name', 'HalfCheetah-v2', 'What environment to run') -flags.DEFINE_integer('seed', 0, 'Random seed.') -flags.DEFINE_integer('num_steps', 1_000_000, 'Number of env steps to run.') -flags.DEFINE_integer('eval_every', 50_000, 'Number of env steps to run.') -flags.DEFINE_integer('num_demonstrations', 11, - 'Number of demonstration trajectories.') -flags.DEFINE_integer('evaluation_episodes', 10, 'Evaluation episodes.') + "run_distributed", + True, + "Should an agent be executed in a distributed " + "way. If False, will run single-threaded.", +) +flags.DEFINE_string("env_name", "HalfCheetah-v2", "What environment to run") +flags.DEFINE_integer("seed", 0, "Random seed.") +flags.DEFINE_integer("num_steps", 1_000_000, "Number of env steps to run.") +flags.DEFINE_integer("eval_every", 50_000, "Number of env steps to run.") +flags.DEFINE_integer("num_demonstrations", 11, "Number of demonstration trajectories.") +flags.DEFINE_integer("evaluation_episodes", 10, "Evaluation episodes.") def make_networks( spec: specs.EnvironmentSpec, policy_layer_sizes: Sequence[int] = (256, 256, 256), critic_layer_sizes: Sequence[int] = (512, 512, 256), - vmin: float = -150., - vmax: float = 150., + vmin: float = -150.0, + vmax: float = 150.0, num_atoms: int = 201, ) -> d4pg.D4PGNetworks: - """Creates networks used by the agent.""" - - action_spec = spec.actions - - num_dimensions = np.prod(action_spec.shape, dtype=int) - critic_atoms = jnp.linspace(vmin, vmax, num_atoms) - - def _actor_fn(obs): - network = hk.Sequential([ - utils.batch_concat, - networks_lib.LayerNormMLP(list(policy_layer_sizes) + [num_dimensions]), - networks_lib.TanhToSpec(action_spec), - ]) - return network(obs) - - def _critic_fn(obs, action): - network = hk.Sequential([ - utils.batch_concat, - networks_lib.LayerNormMLP(layer_sizes=[*critic_layer_sizes, num_atoms]), - ]) - value = network([obs, action]) - return value, critic_atoms - - policy = hk.without_apply_rng(hk.transform(_actor_fn)) - critic = hk.without_apply_rng(hk.transform(_critic_fn)) - - # Create dummy observations and actions to create network parameters. - dummy_action = utils.zeros_like(spec.actions) - dummy_obs = utils.zeros_like(spec.observations) - dummy_action = utils.add_batch_dim(dummy_action) - dummy_obs = utils.add_batch_dim(dummy_obs) - - return d4pg.D4PGNetworks( - policy_network=networks_lib.FeedForwardNetwork( - lambda rng: policy.init(rng, dummy_obs), policy.apply), - critic_network=networks_lib.FeedForwardNetwork( - lambda rng: critic.init(rng, dummy_obs, dummy_action), critic.apply)) + """Creates networks used by the agent.""" + + action_spec = spec.actions + + num_dimensions = np.prod(action_spec.shape, dtype=int) + critic_atoms = jnp.linspace(vmin, vmax, num_atoms) + + def _actor_fn(obs): + network = hk.Sequential( + [ + utils.batch_concat, + networks_lib.LayerNormMLP(list(policy_layer_sizes) + [num_dimensions]), + networks_lib.TanhToSpec(action_spec), + ] + ) + return network(obs) + + def _critic_fn(obs, action): + network = hk.Sequential( + [ + utils.batch_concat, + networks_lib.LayerNormMLP(layer_sizes=[*critic_layer_sizes, num_atoms]), + ] + ) + value = network([obs, action]) + return value, critic_atoms + + policy = hk.without_apply_rng(hk.transform(_actor_fn)) + critic = hk.without_apply_rng(hk.transform(_critic_fn)) + + # Create dummy observations and actions to create network parameters. 
+ dummy_action = utils.zeros_like(spec.actions) + dummy_obs = utils.zeros_like(spec.observations) + dummy_action = utils.add_batch_dim(dummy_action) + dummy_obs = utils.add_batch_dim(dummy_obs) + + return d4pg.D4PGNetworks( + policy_network=networks_lib.FeedForwardNetwork( + lambda rng: policy.init(rng, dummy_obs), policy.apply + ), + critic_network=networks_lib.FeedForwardNetwork( + lambda rng: critic.init(rng, dummy_obs, dummy_action), critic.apply + ), + ) def build_experiment_config() -> experiments.ExperimentConfig: - """Returns a configuration for PWIL experiments.""" - - # Create an environment, grab the spec, and use it to create networks. - env_name = FLAGS.env_name - - def environment_factory(seed: int) -> dm_env.Environment: - del seed - return helpers.make_environment(task=env_name) - - dummy_seed = 1 - environment = environment_factory(dummy_seed) - environment_spec = specs.make_environment_spec(environment) - - # Create d4pg agent - d4pg_config = d4pg.D4PGConfig( - learning_rate=5e-5, sigma=0.2, samples_per_insert=256) - d4pg_builder = d4pg.D4PGBuilder(config=d4pg_config) - - # Create demonstrations function. - dataset_name = helpers.get_dataset_name(FLAGS.env_name) - num_demonstrations = FLAGS.num_demonstrations - - def make_demonstrations(): - transitions_iterator = tfds.get_tfds_dataset( - dataset_name, num_demonstrations, env_spec=environment_spec) - return pwil.PWILDemonstrations( - demonstrations=transitions_iterator, episode_length=1000) - - # Construct PWIL agent - pwil_config = pwil.PWILConfig(num_transitions_rb=0) - pwil_builder = pwil.PWILBuilder( - rl_agent=d4pg_builder, - config=pwil_config, - demonstrations_fn=make_demonstrations) - - return experiments.ExperimentConfig( - builder=pwil_builder, - environment_factory=environment_factory, - network_factory=make_networks, - seed=FLAGS.seed, - max_num_actor_steps=FLAGS.num_steps) + """Returns a configuration for PWIL experiments.""" + + # Create an environment, grab the spec, and use it to create networks. + env_name = FLAGS.env_name + + def environment_factory(seed: int) -> dm_env.Environment: + del seed + return helpers.make_environment(task=env_name) + + dummy_seed = 1 + environment = environment_factory(dummy_seed) + environment_spec = specs.make_environment_spec(environment) + + # Create d4pg agent + d4pg_config = d4pg.D4PGConfig(learning_rate=5e-5, sigma=0.2, samples_per_insert=256) + d4pg_builder = d4pg.D4PGBuilder(config=d4pg_config) + + # Create demonstrations function. 
+ dataset_name = helpers.get_dataset_name(FLAGS.env_name) + num_demonstrations = FLAGS.num_demonstrations + + def make_demonstrations(): + transitions_iterator = tfds.get_tfds_dataset( + dataset_name, num_demonstrations, env_spec=environment_spec + ) + return pwil.PWILDemonstrations( + demonstrations=transitions_iterator, episode_length=1000 + ) + + # Construct PWIL agent + pwil_config = pwil.PWILConfig(num_transitions_rb=0) + pwil_builder = pwil.PWILBuilder( + rl_agent=d4pg_builder, config=pwil_config, demonstrations_fn=make_demonstrations + ) + + return experiments.ExperimentConfig( + builder=pwil_builder, + environment_factory=environment_factory, + network_factory=make_networks, + seed=FLAGS.seed, + max_num_actor_steps=FLAGS.num_steps, + ) def main(_): - config = build_experiment_config() - if FLAGS.run_distributed: - program = experiments.make_distributed_experiment( - experiment=config, num_actors=4) - lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) - else: - experiments.run_experiment( - experiment=config, - eval_every=FLAGS.eval_every, - num_eval_episodes=FLAGS.evaluation_episodes) - - -if __name__ == '__main__': - app.run(main) + config = build_experiment_config() + if FLAGS.run_distributed: + program = experiments.make_distributed_experiment( + experiment=config, num_actors=4 + ) + lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) + else: + experiments.run_experiment( + experiment=config, + eval_every=FLAGS.eval_every, + num_eval_episodes=FLAGS.evaluation_episodes, + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/baselines/imitation/run_sqil.py b/examples/baselines/imitation/run_sqil.py index aeb77c645c..1348a8d6ab 100644 --- a/examples/baselines/imitation/run_sqil.py +++ b/examples/baselines/imitation/run_sqil.py @@ -18,87 +18,95 @@ Reddy et al., 2019 https://arxiv.org/abs/1905.11108 """ -from absl import flags -from acme import specs -from acme.agents.jax import sac -from acme.agents.jax import sqil -from acme.datasets import tfds -import helpers -from absl import app -from acme.jax import experiments -from acme.utils import lp_utils import dm_env +import helpers import jax import launchpad as lp +from absl import app, flags +from acme import specs +from acme.agents.jax import sac, sqil +from acme.datasets import tfds +from acme.jax import experiments +from acme.utils import lp_utils FLAGS = flags.FLAGS flags.DEFINE_bool( - 'run_distributed', True, 'Should an agent be executed in a distributed ' - 'way. If False, will run single-threaded.') -flags.DEFINE_string('env_name', 'HalfCheetah-v2', 'What environment to run') -flags.DEFINE_integer('seed', 0, 'Random seed.') -flags.DEFINE_integer('num_steps', 1_000_000, 'Number of env steps to run.') -flags.DEFINE_integer('eval_every', 50_000, 'Number of env steps to run.') -flags.DEFINE_integer('num_demonstrations', 11, - 'Number of demonstration trajectories.') -flags.DEFINE_integer('evaluation_episodes', 10, 'Evaluation episodes.') + "run_distributed", + True, + "Should an agent be executed in a distributed " + "way. 
If False, will run single-threaded.", +) +flags.DEFINE_string("env_name", "HalfCheetah-v2", "What environment to run") +flags.DEFINE_integer("seed", 0, "Random seed.") +flags.DEFINE_integer("num_steps", 1_000_000, "Number of env steps to run.") +flags.DEFINE_integer("eval_every", 50_000, "Number of env steps to run.") +flags.DEFINE_integer("num_demonstrations", 11, "Number of demonstration trajectories.") +flags.DEFINE_integer("evaluation_episodes", 10, "Evaluation episodes.") def build_experiment_config() -> experiments.ExperimentConfig: - """Returns a configuration for SQIL experiments.""" - - # Create an environment, grab the spec, and use it to create networks. - env_name = FLAGS.env_name - - def environment_factory(seed: int) -> dm_env.Environment: - del seed - return helpers.make_environment(task=env_name) - - dummy_seed = 1 - environment = environment_factory(dummy_seed) - environment_spec = specs.make_environment_spec(environment) - - # Construct the agent. - sac_config = sac.SACConfig( - target_entropy=sac.target_entropy_from_env_spec(environment_spec), - min_replay_size=1, - samples_per_insert_tolerance_rate=2.0) - sac_builder = sac.SACBuilder(sac_config) - - # Create demonstrations function. - dataset_name = helpers.get_dataset_name(FLAGS.env_name) - num_demonstrations = FLAGS.num_demonstrations - def make_demonstrations(batch_size: int, seed: int = 0): - transitions_iterator = tfds.get_tfds_dataset( - dataset_name, num_demonstrations, env_spec=environment_spec) - return tfds.JaxInMemoryRandomSampleIterator( - transitions_iterator, jax.random.PRNGKey(seed), batch_size) - - sqil_builder = sqil.SQILBuilder(sac_builder, sac_config.batch_size, - make_demonstrations) - - return experiments.ExperimentConfig( - builder=sqil_builder, - environment_factory=environment_factory, - network_factory=sac.make_networks, - seed=FLAGS.seed, - max_num_actor_steps=FLAGS.num_steps) + """Returns a configuration for SQIL experiments.""" + + # Create an environment, grab the spec, and use it to create networks. + env_name = FLAGS.env_name + + def environment_factory(seed: int) -> dm_env.Environment: + del seed + return helpers.make_environment(task=env_name) + + dummy_seed = 1 + environment = environment_factory(dummy_seed) + environment_spec = specs.make_environment_spec(environment) + + # Construct the agent. + sac_config = sac.SACConfig( + target_entropy=sac.target_entropy_from_env_spec(environment_spec), + min_replay_size=1, + samples_per_insert_tolerance_rate=2.0, + ) + sac_builder = sac.SACBuilder(sac_config) + + # Create demonstrations function. 
+ dataset_name = helpers.get_dataset_name(FLAGS.env_name) + num_demonstrations = FLAGS.num_demonstrations + + def make_demonstrations(batch_size: int, seed: int = 0): + transitions_iterator = tfds.get_tfds_dataset( + dataset_name, num_demonstrations, env_spec=environment_spec + ) + return tfds.JaxInMemoryRandomSampleIterator( + transitions_iterator, jax.random.PRNGKey(seed), batch_size + ) + + sqil_builder = sqil.SQILBuilder( + sac_builder, sac_config.batch_size, make_demonstrations + ) + + return experiments.ExperimentConfig( + builder=sqil_builder, + environment_factory=environment_factory, + network_factory=sac.make_networks, + seed=FLAGS.seed, + max_num_actor_steps=FLAGS.num_steps, + ) def main(_): - config = build_experiment_config() - if FLAGS.run_distributed: - program = experiments.make_distributed_experiment( - experiment=config, num_actors=4) - lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) - else: - experiments.run_experiment( - experiment=config, - eval_every=FLAGS.eval_every, - num_eval_episodes=FLAGS.evaluation_episodes) - - -if __name__ == '__main__': - app.run(main) + config = build_experiment_config() + if FLAGS.run_distributed: + program = experiments.make_distributed_experiment( + experiment=config, num_actors=4 + ) + lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) + else: + experiments.run_experiment( + experiment=config, + eval_every=FLAGS.eval_every, + num_eval_episodes=FLAGS.evaluation_episodes, + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/baselines/rl_continuous/helpers.py b/examples/baselines/rl_continuous/helpers.py index 9c74703ce0..3bbb44238f 100644 --- a/examples/baselines/rl_continuous/helpers.py +++ b/examples/baselines/rl_continuous/helpers.py @@ -14,16 +14,16 @@ """Shared helpers for rl_continuous experiments.""" -from acme import wrappers import dm_env import gym +from acme import wrappers -_VALID_TASK_SUITES = ('gym', 'control') +_VALID_TASK_SUITES = ("gym", "control") def make_environment(suite: str, task: str) -> dm_env.Environment: - """Makes the requested continuous control environment. + """Makes the requested continuous control environment. Args: suite: One of 'gym' or 'control'. @@ -34,24 +34,26 @@ def make_environment(suite: str, task: str) -> dm_env.Environment: An environment satisfying the dm_env interface expected by Acme agents. """ - if suite not in _VALID_TASK_SUITES: - raise ValueError( - f'Unsupported suite: {suite}. Expected one of {_VALID_TASK_SUITES}') - - if suite == 'gym': - env = gym.make(task) - # Make sure the environment obeys the dm_env.Environment interface. - env = wrappers.GymWrapper(env) - - elif suite == 'control': - # Load dm_suite lazily not require Mujoco license when not using it. - from dm_control import suite as dm_suite # pylint: disable=g-import-not-at-top - domain_name, task_name = task.split(':') - env = dm_suite.load(domain_name, task_name) - env = wrappers.ConcatObservationWrapper(env) - - # Wrap the environment so the expected continuous action spec is [-1, 1]. - # Note: this is a no-op on 'control' tasks. - env = wrappers.CanonicalSpecWrapper(env, clip=True) - env = wrappers.SinglePrecisionWrapper(env) - return env + if suite not in _VALID_TASK_SUITES: + raise ValueError( + f"Unsupported suite: {suite}. Expected one of {_VALID_TASK_SUITES}" + ) + + if suite == "gym": + env = gym.make(task) + # Make sure the environment obeys the dm_env.Environment interface. 
+ env = wrappers.GymWrapper(env) + + elif suite == "control": + # Load dm_suite lazily not require Mujoco license when not using it. + from dm_control import suite as dm_suite # pylint: disable=g-import-not-at-top + + domain_name, task_name = task.split(":") + env = dm_suite.load(domain_name, task_name) + env = wrappers.ConcatObservationWrapper(env) + + # Wrap the environment so the expected continuous action spec is [-1, 1]. + # Note: this is a no-op on 'control' tasks. + env = wrappers.CanonicalSpecWrapper(env, clip=True) + env = wrappers.SinglePrecisionWrapper(env) + return env diff --git a/examples/baselines/rl_continuous/run_d4pg.py b/examples/baselines/rl_continuous/run_d4pg.py index e6e9326b38..9a255ac811 100644 --- a/examples/baselines/rl_continuous/run_d4pg.py +++ b/examples/baselines/rl_continuous/run_d4pg.py @@ -14,72 +14,78 @@ """Example running D4PG on continuous control tasks.""" -from absl import flags -from acme.agents.jax import d4pg import helpers -from absl import app +import launchpad as lp +from absl import app, flags + +from acme.agents.jax import d4pg from acme.jax import experiments from acme.utils import lp_utils -import launchpad as lp FLAGS = flags.FLAGS flags.DEFINE_bool( - 'run_distributed', True, 'Should an agent be executed in a distributed ' - 'way. If False, will run single-threaded.') -flags.DEFINE_string('env_name', 'gym:HalfCheetah-v2', 'What environment to run') -flags.DEFINE_integer('seed', 0, 'Random seed.') -flags.DEFINE_integer('num_steps', 1_000_000, 'Number of env steps to run.') -flags.DEFINE_integer('eval_every', 50_000, 'How often to run evaluation.') -flags.DEFINE_integer('evaluation_episodes', 10, 'Evaluation episodes.') + "run_distributed", + True, + "Should an agent be executed in a distributed " + "way. If False, will run single-threaded.", +) +flags.DEFINE_string("env_name", "gym:HalfCheetah-v2", "What environment to run") +flags.DEFINE_integer("seed", 0, "Random seed.") +flags.DEFINE_integer("num_steps", 1_000_000, "Number of env steps to run.") +flags.DEFINE_integer("eval_every", 50_000, "How often to run evaluation.") +flags.DEFINE_integer("evaluation_episodes", 10, "Evaluation episodes.") def build_experiment_config(): - """Builds D4PG experiment config which can be executed in different ways.""" - - # Create an environment, grab the spec, and use it to create networks. - suite, task = FLAGS.env_name.split(':', 1) - - # Bound of the distributional critic. The reward for control environments is - # normalized, not for gym locomotion environments hence the different scales. - vmax_values = { - 'gym': 1000., - 'control': 150., - } - vmax = vmax_values[suite] - - def network_factory(spec) -> d4pg.D4PGNetworks: - return d4pg.make_networks( - spec, - policy_layer_sizes=(256, 256, 256), - critic_layer_sizes=(256, 256, 256), - vmin=-vmax, - vmax=vmax, - ) + """Builds D4PG experiment config which can be executed in different ways.""" - # Configure the agent. - d4pg_config = d4pg.D4PGConfig(learning_rate=3e-4, sigma=0.2) + # Create an environment, grab the spec, and use it to create networks. + suite, task = FLAGS.env_name.split(":", 1) - return experiments.ExperimentConfig( - builder=d4pg.D4PGBuilder(d4pg_config), - environment_factory=lambda seed: helpers.make_environment(suite, task), - network_factory=network_factory, - seed=FLAGS.seed, - max_num_actor_steps=FLAGS.num_steps) + # Bound of the distributional critic. The reward for control environments is + # normalized, not for gym locomotion environments hence the different scales. 
+ vmax_values = { + "gym": 1000.0, + "control": 150.0, + } + vmax = vmax_values[suite] + + def network_factory(spec) -> d4pg.D4PGNetworks: + return d4pg.make_networks( + spec, + policy_layer_sizes=(256, 256, 256), + critic_layer_sizes=(256, 256, 256), + vmin=-vmax, + vmax=vmax, + ) + + # Configure the agent. + d4pg_config = d4pg.D4PGConfig(learning_rate=3e-4, sigma=0.2) + + return experiments.ExperimentConfig( + builder=d4pg.D4PGBuilder(d4pg_config), + environment_factory=lambda seed: helpers.make_environment(suite, task), + network_factory=network_factory, + seed=FLAGS.seed, + max_num_actor_steps=FLAGS.num_steps, + ) def main(_): - config = build_experiment_config() - if FLAGS.run_distributed: - program = experiments.make_distributed_experiment( - experiment=config, num_actors=4) - lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) - else: - experiments.run_experiment( - experiment=config, - eval_every=FLAGS.eval_every, - num_eval_episodes=FLAGS.evaluation_episodes) - - -if __name__ == '__main__': - app.run(main) + config = build_experiment_config() + if FLAGS.run_distributed: + program = experiments.make_distributed_experiment( + experiment=config, num_actors=4 + ) + lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) + else: + experiments.run_experiment( + experiment=config, + eval_every=FLAGS.eval_every, + num_eval_episodes=FLAGS.evaluation_episodes, + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/baselines/rl_continuous/run_dmpo.py b/examples/baselines/rl_continuous/run_dmpo.py index a477ce86f4..04fd74015c 100644 --- a/examples/baselines/rl_continuous/run_dmpo.py +++ b/examples/baselines/rl_continuous/run_dmpo.py @@ -14,86 +14,98 @@ """Example running Distributional MPO on continuous control tasks.""" -from absl import flags +import helpers +import launchpad as lp +from absl import app, flags + from acme import specs from acme.agents.jax import mpo from acme.agents.jax.mpo import types as mpo_types -import helpers -from absl import app from acme.jax import experiments from acme.utils import lp_utils -import launchpad as lp RUN_DISTRIBUTED = flags.DEFINE_bool( - 'run_distributed', True, 'Should an agent be executed in a distributed ' - 'way. If False, will run single-threaded.') + "run_distributed", + True, + "Should an agent be executed in a distributed " + "way. If False, will run single-threaded.", +) ENV_NAME = flags.DEFINE_string( - 'env_name', 'gym:HalfCheetah-v2', - 'What environment to run on, in the format {gym|control}:{task}, ' + "env_name", + "gym:HalfCheetah-v2", + "What environment to run on, in the format {gym|control}:{task}, " 'where "control" refers to the DM control suite. DM Control tasks are ' - 'further split into {domain_name}:{task_name}.') -SEED = flags.DEFINE_integer('seed', 0, 'Random seed.') + "further split into {domain_name}:{task_name}.", +) +SEED = flags.DEFINE_integer("seed", 0, "Random seed.") NUM_STEPS = flags.DEFINE_integer( - 'num_steps', 1_000_000, - 'Number of environment steps to run the experiment for.') + "num_steps", 1_000_000, "Number of environment steps to run the experiment for." 
+) EVAL_EVERY = flags.DEFINE_integer( - 'eval_every', 50_000, - 'How often (in actor environment steps) to run evaluation episodes.') + "eval_every", + 50_000, + "How often (in actor environment steps) to run evaluation episodes.", +) EVAL_EPISODES = flags.DEFINE_integer( - 'evaluation_episodes', 10, - 'Number of evaluation episodes to run periodically.') + "evaluation_episodes", 10, "Number of evaluation episodes to run periodically." +) def build_experiment_config(): - """Builds MPO experiment config which can be executed in different ways.""" - suite, task = ENV_NAME.value.split(':', 1) - critic_type = mpo.CriticType.CATEGORICAL + """Builds MPO experiment config which can be executed in different ways.""" + suite, task = ENV_NAME.value.split(":", 1) + critic_type = mpo.CriticType.CATEGORICAL - vmax_values = { - 'gym': 1600., - 'control': 150., - } - vmax = vmax_values[suite] + vmax_values = { + "gym": 1600.0, + "control": 150.0, + } + vmax = vmax_values[suite] - def network_factory(spec: specs.EnvironmentSpec) -> mpo.MPONetworks: - return mpo.make_control_networks( - spec, - policy_layer_sizes=(256, 256, 256), - critic_layer_sizes=(256, 256, 256), - policy_init_scale=0.5, - vmin=-vmax, - vmax=vmax, - critic_type=critic_type) + def network_factory(spec: specs.EnvironmentSpec) -> mpo.MPONetworks: + return mpo.make_control_networks( + spec, + policy_layer_sizes=(256, 256, 256), + critic_layer_sizes=(256, 256, 256), + policy_init_scale=0.5, + vmin=-vmax, + vmax=vmax, + critic_type=critic_type, + ) - # Configure and construct the agent builder. - config = mpo.MPOConfig( - critic_type=critic_type, - policy_loss_config=mpo_types.GaussianPolicyLossConfig(epsilon_mean=0.01), - samples_per_insert=64, - learning_rate=3e-4, - experience_type=mpo_types.FromTransitions(n_step=4)) - agent_builder = mpo.MPOBuilder(config, sgd_steps_per_learner_step=1) + # Configure and construct the agent builder. 
+ config = mpo.MPOConfig( + critic_type=critic_type, + policy_loss_config=mpo_types.GaussianPolicyLossConfig(epsilon_mean=0.01), + samples_per_insert=64, + learning_rate=3e-4, + experience_type=mpo_types.FromTransitions(n_step=4), + ) + agent_builder = mpo.MPOBuilder(config, sgd_steps_per_learner_step=1) - return experiments.ExperimentConfig( - builder=agent_builder, - environment_factory=lambda _: helpers.make_environment(suite, task), - network_factory=network_factory, - seed=SEED.value, - max_num_actor_steps=NUM_STEPS.value) + return experiments.ExperimentConfig( + builder=agent_builder, + environment_factory=lambda _: helpers.make_environment(suite, task), + network_factory=network_factory, + seed=SEED.value, + max_num_actor_steps=NUM_STEPS.value, + ) def main(_): - config = build_experiment_config() - if RUN_DISTRIBUTED.value: - program = experiments.make_distributed_experiment( - experiment=config, num_actors=4) - lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) - else: - experiments.run_experiment( - experiment=config, - eval_every=EVAL_EVERY.value, - num_eval_episodes=EVAL_EPISODES.value) + config = build_experiment_config() + if RUN_DISTRIBUTED.value: + program = experiments.make_distributed_experiment( + experiment=config, num_actors=4 + ) + lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) + else: + experiments.run_experiment( + experiment=config, + eval_every=EVAL_EVERY.value, + num_eval_episodes=EVAL_EPISODES.value, + ) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/examples/baselines/rl_continuous/run_mogmpo.py b/examples/baselines/rl_continuous/run_mogmpo.py index d4da09d2d3..53c11c267b 100644 --- a/examples/baselines/rl_continuous/run_mogmpo.py +++ b/examples/baselines/rl_continuous/run_mogmpo.py @@ -14,78 +14,90 @@ """Example running Mixture of Gaussian MPO on continuous control tasks.""" -from absl import flags +import helpers +import launchpad as lp +from absl import app, flags + from acme import specs from acme.agents.jax import mpo from acme.agents.jax.mpo import types as mpo_types -import helpers -from absl import app from acme.jax import experiments from acme.utils import lp_utils -import launchpad as lp RUN_DISTRIBUTED = flags.DEFINE_bool( - 'run_distributed', True, 'Should an agent be executed in a distributed ' - 'way. If False, will run single-threaded.') + "run_distributed", + True, + "Should an agent be executed in a distributed " + "way. If False, will run single-threaded.", +) ENV_NAME = flags.DEFINE_string( - 'env_name', 'gym:HalfCheetah-v2', - 'What environment to run on, in the format {gym|control}:{task}, ' + "env_name", + "gym:HalfCheetah-v2", + "What environment to run on, in the format {gym|control}:{task}, " 'where "control" refers to the DM control suite. DM Control tasks are ' - 'further split into {domain_name}:{task_name}.') -SEED = flags.DEFINE_integer('seed', 0, 'Random seed.') + "further split into {domain_name}:{task_name}.", +) +SEED = flags.DEFINE_integer("seed", 0, "Random seed.") NUM_STEPS = flags.DEFINE_integer( - 'num_steps', 1_000_000, - 'Number of environment steps to run the experiment for.') + "num_steps", 1_000_000, "Number of environment steps to run the experiment for." 
+) EVAL_EVERY = flags.DEFINE_integer( - 'eval_every', 50_000, - 'How often (in actor environment steps) to run evaluation episodes.') + "eval_every", + 50_000, + "How often (in actor environment steps) to run evaluation episodes.", +) EVAL_EPISODES = flags.DEFINE_integer( - 'evaluation_episodes', 10, - 'Number of evaluation episodes to run periodically.') + "evaluation_episodes", 10, "Number of evaluation episodes to run periodically." +) def build_experiment_config(): - """Builds MPO experiment config which can be executed in different ways.""" - suite, task = ENV_NAME.value.split(':', 1) - critic_type = mpo.CriticType.MIXTURE_OF_GAUSSIANS + """Builds MPO experiment config which can be executed in different ways.""" + suite, task = ENV_NAME.value.split(":", 1) + critic_type = mpo.CriticType.MIXTURE_OF_GAUSSIANS - def network_factory(spec: specs.EnvironmentSpec) -> mpo.MPONetworks: - return mpo.make_control_networks( - spec, - policy_layer_sizes=(256, 256, 256), - critic_layer_sizes=(256, 256, 256), - policy_init_scale=0.5, - critic_type=critic_type) + def network_factory(spec: specs.EnvironmentSpec) -> mpo.MPONetworks: + return mpo.make_control_networks( + spec, + policy_layer_sizes=(256, 256, 256), + critic_layer_sizes=(256, 256, 256), + policy_init_scale=0.5, + critic_type=critic_type, + ) - # Configure and construct the agent builder. - config = mpo.MPOConfig( - critic_type=critic_type, - policy_loss_config=mpo_types.GaussianPolicyLossConfig(epsilon_mean=0.01), - samples_per_insert=64, - learning_rate=3e-4, - experience_type=mpo_types.FromTransitions(n_step=4)) - agent_builder = mpo.MPOBuilder(config, sgd_steps_per_learner_step=1) + # Configure and construct the agent builder. + config = mpo.MPOConfig( + critic_type=critic_type, + policy_loss_config=mpo_types.GaussianPolicyLossConfig(epsilon_mean=0.01), + samples_per_insert=64, + learning_rate=3e-4, + experience_type=mpo_types.FromTransitions(n_step=4), + ) + agent_builder = mpo.MPOBuilder(config, sgd_steps_per_learner_step=1) - return experiments.ExperimentConfig( - builder=agent_builder, - environment_factory=lambda _: helpers.make_environment(suite, task), - network_factory=network_factory, - seed=SEED.value, - max_num_actor_steps=NUM_STEPS.value) + return experiments.ExperimentConfig( + builder=agent_builder, + environment_factory=lambda _: helpers.make_environment(suite, task), + network_factory=network_factory, + seed=SEED.value, + max_num_actor_steps=NUM_STEPS.value, + ) def main(_): - config = build_experiment_config() - if RUN_DISTRIBUTED.value: - program = experiments.make_distributed_experiment( - experiment=config, num_actors=4) - lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) - else: - experiments.run_experiment( - experiment=config, - eval_every=EVAL_EVERY.value, - num_eval_episodes=EVAL_EPISODES.value) + config = build_experiment_config() + if RUN_DISTRIBUTED.value: + program = experiments.make_distributed_experiment( + experiment=config, num_actors=4 + ) + lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) + else: + experiments.run_experiment( + experiment=config, + eval_every=EVAL_EVERY.value, + num_eval_episodes=EVAL_EPISODES.value, + ) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/examples/baselines/rl_continuous/run_mpo.py b/examples/baselines/rl_continuous/run_mpo.py index 51c928a62a..c040637fcc 100644 --- a/examples/baselines/rl_continuous/run_mpo.py +++ b/examples/baselines/rl_continuous/run_mpo.py @@ 
-14,78 +14,90 @@ """Example running MPO on continuous control tasks.""" -from absl import flags +import helpers +import launchpad as lp +from absl import app, flags + from acme import specs from acme.agents.jax import mpo from acme.agents.jax.mpo import types as mpo_types -import helpers -from absl import app from acme.jax import experiments from acme.utils import lp_utils -import launchpad as lp RUN_DISTRIBUTED = flags.DEFINE_bool( - 'run_distributed', True, 'Should an agent be executed in a distributed ' - 'way. If False, will run single-threaded.') + "run_distributed", + True, + "Should an agent be executed in a distributed " + "way. If False, will run single-threaded.", +) ENV_NAME = flags.DEFINE_string( - 'env_name', 'gym:HalfCheetah-v2', - 'What environment to run on, in the format {gym|control}:{task}, ' + "env_name", + "gym:HalfCheetah-v2", + "What environment to run on, in the format {gym|control}:{task}, " 'where "control" refers to the DM control suite. DM Control tasks are ' - 'further split into {domain_name}:{task_name}.') -SEED = flags.DEFINE_integer('seed', 0, 'Random seed.') + "further split into {domain_name}:{task_name}.", +) +SEED = flags.DEFINE_integer("seed", 0, "Random seed.") NUM_STEPS = flags.DEFINE_integer( - 'num_steps', 1_000_000, - 'Number of environment steps to run the experiment for.') + "num_steps", 1_000_000, "Number of environment steps to run the experiment for." +) EVAL_EVERY = flags.DEFINE_integer( - 'eval_every', 50_000, - 'How often (in actor environment steps) to run evaluation episodes.') + "eval_every", + 50_000, + "How often (in actor environment steps) to run evaluation episodes.", +) EVAL_EPISODES = flags.DEFINE_integer( - 'evaluation_episodes', 10, - 'Number of evaluation episodes to run periodically.') + "evaluation_episodes", 10, "Number of evaluation episodes to run periodically." +) def build_experiment_config(): - """Builds MPO experiment config which can be executed in different ways.""" - suite, task = ENV_NAME.value.split(':', 1) - critic_type = mpo.CriticType.NONDISTRIBUTIONAL + """Builds MPO experiment config which can be executed in different ways.""" + suite, task = ENV_NAME.value.split(":", 1) + critic_type = mpo.CriticType.NONDISTRIBUTIONAL - def network_factory(spec: specs.EnvironmentSpec) -> mpo.MPONetworks: - return mpo.make_control_networks( - spec, - policy_layer_sizes=(256, 256, 256), - critic_layer_sizes=(256, 256, 256), - policy_init_scale=0.5, - critic_type=critic_type) + def network_factory(spec: specs.EnvironmentSpec) -> mpo.MPONetworks: + return mpo.make_control_networks( + spec, + policy_layer_sizes=(256, 256, 256), + critic_layer_sizes=(256, 256, 256), + policy_init_scale=0.5, + critic_type=critic_type, + ) - # Configure and construct the agent builder. - config = mpo.MPOConfig( - critic_type=critic_type, - policy_loss_config=mpo_types.GaussianPolicyLossConfig(epsilon_mean=0.01), - samples_per_insert=64, - learning_rate=3e-4, - experience_type=mpo_types.FromTransitions(n_step=4)) - agent_builder = mpo.MPOBuilder(config, sgd_steps_per_learner_step=1) + # Configure and construct the agent builder. 
+ config = mpo.MPOConfig( + critic_type=critic_type, + policy_loss_config=mpo_types.GaussianPolicyLossConfig(epsilon_mean=0.01), + samples_per_insert=64, + learning_rate=3e-4, + experience_type=mpo_types.FromTransitions(n_step=4), + ) + agent_builder = mpo.MPOBuilder(config, sgd_steps_per_learner_step=1) - return experiments.ExperimentConfig( - builder=agent_builder, - environment_factory=lambda _: helpers.make_environment(suite, task), - network_factory=network_factory, - seed=SEED.value, - max_num_actor_steps=NUM_STEPS.value) + return experiments.ExperimentConfig( + builder=agent_builder, + environment_factory=lambda _: helpers.make_environment(suite, task), + network_factory=network_factory, + seed=SEED.value, + max_num_actor_steps=NUM_STEPS.value, + ) def main(_): - config = build_experiment_config() - if RUN_DISTRIBUTED.value: - program = experiments.make_distributed_experiment( - experiment=config, num_actors=4) - lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) - else: - experiments.run_experiment( - experiment=config, - eval_every=EVAL_EVERY.value, - num_eval_episodes=EVAL_EPISODES.value) + config = build_experiment_config() + if RUN_DISTRIBUTED.value: + program = experiments.make_distributed_experiment( + experiment=config, num_actors=4 + ) + lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) + else: + experiments.run_experiment( + experiment=config, + eval_every=EVAL_EVERY.value, + num_eval_episodes=EVAL_EPISODES.value, + ) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/examples/baselines/rl_continuous/run_ppo.py b/examples/baselines/rl_continuous/run_ppo.py index ae2fb12831..c51c355ffc 100644 --- a/examples/baselines/rl_continuous/run_ppo.py +++ b/examples/baselines/rl_continuous/run_ppo.py @@ -14,60 +14,68 @@ """Example running PPO on continuous control tasks.""" -from absl import flags -from acme.agents.jax import ppo import helpers -from absl import app +import launchpad as lp +from absl import app, flags + +from acme.agents.jax import ppo from acme.jax import experiments from acme.utils import lp_utils -import launchpad as lp FLAGS = flags.FLAGS flags.DEFINE_bool( - 'run_distributed', True, 'Should an agent be executed in a distributed ' - 'way. If False, will run single-threaded.') -flags.DEFINE_string('env_name', 'gym:HalfCheetah-v2', 'What environment to run') -flags.DEFINE_integer('seed', 0, 'Random seed.') -flags.DEFINE_integer('num_steps', 1_000_000, 'Number of env steps to run.') -flags.DEFINE_integer('eval_every', 50_000, 'How often to run evaluation.') -flags.DEFINE_integer('evaluation_episodes', 10, 'Evaluation episodes.') -flags.DEFINE_integer('num_distributed_actors', 64, - 'Number of actors to use in the distributed setting.') + "run_distributed", + True, + "Should an agent be executed in a distributed " + "way. If False, will run single-threaded.", +) +flags.DEFINE_string("env_name", "gym:HalfCheetah-v2", "What environment to run") +flags.DEFINE_integer("seed", 0, "Random seed.") +flags.DEFINE_integer("num_steps", 1_000_000, "Number of env steps to run.") +flags.DEFINE_integer("eval_every", 50_000, "How often to run evaluation.") +flags.DEFINE_integer("evaluation_episodes", 10, "Evaluation episodes.") +flags.DEFINE_integer( + "num_distributed_actors", 64, "Number of actors to use in the distributed setting." 
+) def build_experiment_config(): - """Builds PPO experiment config which can be executed in different ways.""" - # Create an environment, grab the spec, and use it to create networks. - suite, task = FLAGS.env_name.split(':', 1) + """Builds PPO experiment config which can be executed in different ways.""" + # Create an environment, grab the spec, and use it to create networks. + suite, task = FLAGS.env_name.split(":", 1) - config = ppo.PPOConfig( - normalize_advantage=True, - normalize_value=True, - obs_normalization_fns_factory=ppo.build_mean_std_normalizer) - ppo_builder = ppo.PPOBuilder(config) + config = ppo.PPOConfig( + normalize_advantage=True, + normalize_value=True, + obs_normalization_fns_factory=ppo.build_mean_std_normalizer, + ) + ppo_builder = ppo.PPOBuilder(config) - layer_sizes = (256, 256, 256) - return experiments.ExperimentConfig( - builder=ppo_builder, - environment_factory=lambda seed: helpers.make_environment(suite, task), - network_factory=lambda spec: ppo.make_networks(spec, layer_sizes), - seed=FLAGS.seed, - max_num_actor_steps=FLAGS.num_steps) + layer_sizes = (256, 256, 256) + return experiments.ExperimentConfig( + builder=ppo_builder, + environment_factory=lambda seed: helpers.make_environment(suite, task), + network_factory=lambda spec: ppo.make_networks(spec, layer_sizes), + seed=FLAGS.seed, + max_num_actor_steps=FLAGS.num_steps, + ) def main(_): - config = build_experiment_config() - if FLAGS.run_distributed: - program = experiments.make_distributed_experiment( - experiment=config, num_actors=FLAGS.num_distributed_actors) - lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) - else: - experiments.run_experiment( - experiment=config, - eval_every=FLAGS.eval_every, - num_eval_episodes=FLAGS.evaluation_episodes) + config = build_experiment_config() + if FLAGS.run_distributed: + program = experiments.make_distributed_experiment( + experiment=config, num_actors=FLAGS.num_distributed_actors + ) + lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) + else: + experiments.run_experiment( + experiment=config, + eval_every=FLAGS.eval_every, + num_eval_episodes=FLAGS.evaluation_episodes, + ) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/examples/baselines/rl_continuous/run_sac.py b/examples/baselines/rl_continuous/run_sac.py index c80e84b8fe..72a848598a 100644 --- a/examples/baselines/rl_continuous/run_sac.py +++ b/examples/baselines/rl_continuous/run_sac.py @@ -14,68 +14,75 @@ """Example running SAC on continuous control tasks.""" -from absl import flags +import helpers +import launchpad as lp +from absl import app, flags + from acme import specs -from acme.agents.jax import normalization -from acme.agents.jax import sac +from acme.agents.jax import normalization, sac from acme.agents.jax.sac import builder -import helpers -from absl import app from acme.jax import experiments from acme.utils import lp_utils -import launchpad as lp FLAGS = flags.FLAGS flags.DEFINE_bool( - 'run_distributed', True, 'Should an agent be executed in a distributed ' - 'way. 
If False, will run single-threaded.') -flags.DEFINE_string('env_name', 'gym:HalfCheetah-v2', 'What environment to run') -flags.DEFINE_integer('seed', 0, 'Random seed.') -flags.DEFINE_integer('num_steps', 1_000_000, 'Number of env steps to run.') -flags.DEFINE_integer('eval_every', 50_000, 'How often to run evaluation.') -flags.DEFINE_integer('evaluation_episodes', 10, 'Evaluation episodes.') + "run_distributed", + True, + "Should an agent be executed in a distributed " + "way. If False, will run single-threaded.", +) +flags.DEFINE_string("env_name", "gym:HalfCheetah-v2", "What environment to run") +flags.DEFINE_integer("seed", 0, "Random seed.") +flags.DEFINE_integer("num_steps", 1_000_000, "Number of env steps to run.") +flags.DEFINE_integer("eval_every", 50_000, "How often to run evaluation.") +flags.DEFINE_integer("evaluation_episodes", 10, "Evaluation episodes.") def build_experiment_config(): - """Builds SAC experiment config which can be executed in different ways.""" - # Create an environment, grab the spec, and use it to create networks. + """Builds SAC experiment config which can be executed in different ways.""" + # Create an environment, grab the spec, and use it to create networks. - suite, task = FLAGS.env_name.split(':', 1) - environment = helpers.make_environment(suite, task) + suite, task = FLAGS.env_name.split(":", 1) + environment = helpers.make_environment(suite, task) - environment_spec = specs.make_environment_spec(environment) - network_factory = ( - lambda spec: sac.make_networks(spec, hidden_layer_sizes=(256, 256, 256))) + environment_spec = specs.make_environment_spec(environment) + network_factory = lambda spec: sac.make_networks( + spec, hidden_layer_sizes=(256, 256, 256) + ) - # Construct the agent. - config = sac.SACConfig( - learning_rate=3e-4, - n_step=2, - target_entropy=sac.target_entropy_from_env_spec(environment_spec), - input_normalization=normalization.NormalizationConfig()) - sac_builder = builder.SACBuilder(config) + # Construct the agent. 
+ config = sac.SACConfig( + learning_rate=3e-4, + n_step=2, + target_entropy=sac.target_entropy_from_env_spec(environment_spec), + input_normalization=normalization.NormalizationConfig(), + ) + sac_builder = builder.SACBuilder(config) - return experiments.ExperimentConfig( - builder=sac_builder, - environment_factory=lambda seed: helpers.make_environment(suite, task), - network_factory=network_factory, - seed=FLAGS.seed, - max_num_actor_steps=FLAGS.num_steps) + return experiments.ExperimentConfig( + builder=sac_builder, + environment_factory=lambda seed: helpers.make_environment(suite, task), + network_factory=network_factory, + seed=FLAGS.seed, + max_num_actor_steps=FLAGS.num_steps, + ) def main(_): - config = build_experiment_config() - if FLAGS.run_distributed: - program = experiments.make_distributed_experiment( - experiment=config, num_actors=4) - lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) - else: - experiments.run_experiment( - experiment=config, - eval_every=FLAGS.eval_every, - num_eval_episodes=FLAGS.evaluation_episodes) + config = build_experiment_config() + if FLAGS.run_distributed: + program = experiments.make_distributed_experiment( + experiment=config, num_actors=4 + ) + lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) + else: + experiments.run_experiment( + experiment=config, + eval_every=FLAGS.eval_every, + num_eval_episodes=FLAGS.evaluation_episodes, + ) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/examples/baselines/rl_continuous/run_td3.py b/examples/baselines/rl_continuous/run_td3.py index 9bb446a900..17530f47e8 100644 --- a/examples/baselines/rl_continuous/run_td3.py +++ b/examples/baselines/rl_continuous/run_td3.py @@ -14,63 +14,66 @@ """Example running SAC on continuous control tasks.""" -from absl import flags -from acme.agents.jax import td3 import helpers -from absl import app -from acme.jax import experiments -from acme.utils import lp_utils import launchpad as lp +from absl import app, flags +from acme.agents.jax import td3 +from acme.jax import experiments +from acme.utils import lp_utils FLAGS = flags.FLAGS flags.DEFINE_bool( - 'run_distributed', True, 'Should an agent be executed in a distributed ' - 'way. If False, will run single-threaded.') -flags.DEFINE_string('env_name', 'gym:HalfCheetah-v2', 'What environment to run') -flags.DEFINE_integer('seed', 0, 'Random seed.') -flags.DEFINE_integer('num_steps', 1_000_000, 'Number of env steps to run.') -flags.DEFINE_integer('eval_every', 50_000, 'How often to run evaluation.') -flags.DEFINE_integer('evaluation_episodes', 10, 'Evaluation episodes.') + "run_distributed", + True, + "Should an agent be executed in a distributed " + "way. If False, will run single-threaded.", +) +flags.DEFINE_string("env_name", "gym:HalfCheetah-v2", "What environment to run") +flags.DEFINE_integer("seed", 0, "Random seed.") +flags.DEFINE_integer("num_steps", 1_000_000, "Number of env steps to run.") +flags.DEFINE_integer("eval_every", 50_000, "How often to run evaluation.") +flags.DEFINE_integer("evaluation_episodes", 10, "Evaluation episodes.") def build_experiment_config(): - """Builds TD3 experiment config which can be executed in different ways.""" - # Create an environment, grab the spec, and use it to create networks. + """Builds TD3 experiment config which can be executed in different ways.""" + # Create an environment, grab the spec, and use it to create networks. 
- suite, task = FLAGS.env_name.split(':', 1) - network_factory = ( - lambda spec: td3.make_networks(spec, hidden_layer_sizes=(256, 256, 256))) + suite, task = FLAGS.env_name.split(":", 1) + network_factory = lambda spec: td3.make_networks( + spec, hidden_layer_sizes=(256, 256, 256) + ) - # Construct the agent. - config = td3.TD3Config( - policy_learning_rate=3e-4, - critic_learning_rate=3e-4, - ) - td3_builder = td3.TD3Builder(config) - # pylint:disable=g-long-lambda - return experiments.ExperimentConfig( - builder=td3_builder, - environment_factory=lambda seed: helpers.make_environment(suite, task), - network_factory=network_factory, - seed=FLAGS.seed, - max_num_actor_steps=FLAGS.num_steps) - # pylint:enable=g-long-lambda + # Construct the agent. + config = td3.TD3Config(policy_learning_rate=3e-4, critic_learning_rate=3e-4,) + td3_builder = td3.TD3Builder(config) + # pylint:disable=g-long-lambda + return experiments.ExperimentConfig( + builder=td3_builder, + environment_factory=lambda seed: helpers.make_environment(suite, task), + network_factory=network_factory, + seed=FLAGS.seed, + max_num_actor_steps=FLAGS.num_steps, + ) + # pylint:enable=g-long-lambda def main(_): - config = build_experiment_config() - if FLAGS.run_distributed: - program = experiments.make_distributed_experiment( - experiment=config, num_actors=4) - lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) - else: - experiments.run_experiment( - experiment=config, - eval_every=FLAGS.eval_every, - num_eval_episodes=FLAGS.evaluation_episodes) + config = build_experiment_config() + if FLAGS.run_distributed: + program = experiments.make_distributed_experiment( + experiment=config, num_actors=4 + ) + lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) + else: + experiments.run_experiment( + experiment=config, + eval_every=FLAGS.eval_every, + num_eval_episodes=FLAGS.evaluation_episodes, + ) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/examples/baselines/rl_discrete/helpers.py b/examples/baselines/rl_discrete/helpers.py index 3c7afba4e2..9ac9e28a5c 100644 --- a/examples/baselines/rl_discrete/helpers.py +++ b/examples/baselines/rl_discrete/helpers.py @@ -18,24 +18,23 @@ import os from typing import Tuple -from absl import flags -from acme import specs -from acme import wrappers -from acme.agents.jax import dqn -from acme.jax import networks as networks_lib -from acme.jax import utils import atari_py # pylint:disable=unused-import import dm_env import gym import haiku as hk import jax.numpy as jnp +from absl import flags +from acme import specs, wrappers +from acme.agents.jax import dqn +from acme.jax import networks as networks_lib +from acme.jax import utils FLAGS = flags.FLAGS def make_atari_environment( - level: str = 'Pong', + level: str = "Pong", sticky_actions: bool = True, zero_discount_on_life_loss: bool = False, oar_wrapper: bool = False, @@ -45,69 +44,77 @@ def make_atari_environment( to_float: bool = True, scale_dims: Tuple[int, int] = (84, 84), ) -> dm_env.Environment: - """Loads the Atari environment.""" -# Internal logic. 
- version = 'v0' if sticky_actions else 'v4' - level_name = f'{level}NoFrameskip-{version}' - env = gym.make(level_name, full_action_space=True) - - wrapper_list = [ - wrappers.GymAtariAdapter, - functools.partial( - wrappers.AtariWrapper, - scale_dims=scale_dims, - to_float=to_float, - max_episode_len=108_000, - num_stacked_frames=num_stacked_frames, - flatten_frame_stack=flatten_frame_stack, - grayscaling=grayscaling, - zero_discount_on_life_loss=zero_discount_on_life_loss, - ), - wrappers.SinglePrecisionWrapper, - ] - - if oar_wrapper: - # E.g. IMPALA and R2D2 use this particular variant. - wrapper_list.append(wrappers.ObservationActionRewardWrapper) - - return wrappers.wrap_all(env, wrapper_list) - - -def make_dqn_atari_network( - environment_spec: specs.EnvironmentSpec) -> dqn.DQNNetworks: - """Creates networks for training DQN on Atari.""" - def network(inputs): - model = hk.Sequential([ - networks_lib.AtariTorso(), - hk.nets.MLP([512, environment_spec.actions.num_values]), - ]) - return model(inputs) - network_hk = hk.without_apply_rng(hk.transform(network)) - obs = utils.add_batch_dim(utils.zeros_like(environment_spec.observations)) - network = networks_lib.FeedForwardNetwork( - init=lambda rng: network_hk.init(rng, obs), apply=network_hk.apply) - typed_network = networks_lib.non_stochastic_network_to_typed(network) - return dqn.DQNNetworks(policy_network=typed_network) + """Loads the Atari environment.""" + # Internal logic. + version = "v0" if sticky_actions else "v4" + level_name = f"{level}NoFrameskip-{version}" + env = gym.make(level_name, full_action_space=True) + + wrapper_list = [ + wrappers.GymAtariAdapter, + functools.partial( + wrappers.AtariWrapper, + scale_dims=scale_dims, + to_float=to_float, + max_episode_len=108_000, + num_stacked_frames=num_stacked_frames, + flatten_frame_stack=flatten_frame_stack, + grayscaling=grayscaling, + zero_discount_on_life_loss=zero_discount_on_life_loss, + ), + wrappers.SinglePrecisionWrapper, + ] + + if oar_wrapper: + # E.g. IMPALA and R2D2 use this particular variant. 
+ wrapper_list.append(wrappers.ObservationActionRewardWrapper) + + return wrappers.wrap_all(env, wrapper_list) + + +def make_dqn_atari_network(environment_spec: specs.EnvironmentSpec) -> dqn.DQNNetworks: + """Creates networks for training DQN on Atari.""" + + def network(inputs): + model = hk.Sequential( + [ + networks_lib.AtariTorso(), + hk.nets.MLP([512, environment_spec.actions.num_values]), + ] + ) + return model(inputs) + + network_hk = hk.without_apply_rng(hk.transform(network)) + obs = utils.add_batch_dim(utils.zeros_like(environment_spec.observations)) + network = networks_lib.FeedForwardNetwork( + init=lambda rng: network_hk.init(rng, obs), apply=network_hk.apply + ) + typed_network = networks_lib.non_stochastic_network_to_typed(network) + return dqn.DQNNetworks(policy_network=typed_network) def make_distributional_dqn_atari_network( - environment_spec: specs.EnvironmentSpec, - num_quantiles: int) -> dqn.DQNNetworks: - """Creates networks for training Distributional DQN on Atari.""" - - def network(inputs): - model = hk.Sequential([ - networks_lib.AtariTorso(), - hk.nets.MLP([512, environment_spec.actions.num_values * num_quantiles]), - ]) - q_dist = model(inputs).reshape(-1, environment_spec.actions.num_values, - num_quantiles) - q_values = jnp.mean(q_dist, axis=-1) - return q_values, q_dist - - network_hk = hk.without_apply_rng(hk.transform(network)) - obs = utils.add_batch_dim(utils.zeros_like(environment_spec.observations)) - network = networks_lib.FeedForwardNetwork( - init=lambda rng: network_hk.init(rng, obs), apply=network_hk.apply) - typed_network = networks_lib.non_stochastic_network_to_typed(network) - return dqn.DQNNetworks(policy_network=typed_network) + environment_spec: specs.EnvironmentSpec, num_quantiles: int +) -> dqn.DQNNetworks: + """Creates networks for training Distributional DQN on Atari.""" + + def network(inputs): + model = hk.Sequential( + [ + networks_lib.AtariTorso(), + hk.nets.MLP([512, environment_spec.actions.num_values * num_quantiles]), + ] + ) + q_dist = model(inputs).reshape( + -1, environment_spec.actions.num_values, num_quantiles + ) + q_values = jnp.mean(q_dist, axis=-1) + return q_values, q_dist + + network_hk = hk.without_apply_rng(hk.transform(network)) + obs = utils.add_batch_dim(utils.zeros_like(environment_spec.observations)) + network = networks_lib.FeedForwardNetwork( + init=lambda rng: network_hk.init(rng, obs), apply=network_hk.apply + ) + typed_network = networks_lib.non_stochastic_network_to_typed(network) + return dqn.DQNNetworks(policy_network=typed_network) diff --git a/examples/baselines/rl_discrete/run_dqn.py b/examples/baselines/rl_discrete/run_dqn.py index 760e84ea52..05e9f7c7be 100644 --- a/examples/baselines/rl_discrete/run_dqn.py +++ b/examples/baselines/rl_discrete/run_dqn.py @@ -14,70 +14,74 @@ """Example running DQN on discrete control tasks.""" -from absl import flags +import helpers +import launchpad as lp +from absl import app, flags + from acme.agents.jax import dqn from acme.agents.jax.dqn import losses -import helpers -from absl import app from acme.jax import experiments from acme.utils import lp_utils -import launchpad as lp - RUN_DISTRIBUTED = flags.DEFINE_bool( - 'run_distributed', True, 'Should an agent be executed in a distributed ' - 'way. 
If False, will run single-threaded.') -ENV_NAME = flags.DEFINE_string('env_name', 'Pong', 'What environment to run') -SEED = flags.DEFINE_integer('seed', 0, 'Random seed.') -NUM_STEPS = flags.DEFINE_integer('num_steps', 1_000_000, - 'Number of env steps to run.') + "run_distributed", + True, + "Should an agent be executed in a distributed " + "way. If False, will run single-threaded.", +) +ENV_NAME = flags.DEFINE_string("env_name", "Pong", "What environment to run") +SEED = flags.DEFINE_integer("seed", 0, "Random seed.") +NUM_STEPS = flags.DEFINE_integer("num_steps", 1_000_000, "Number of env steps to run.") def build_experiment_config(): - """Builds DQN experiment config which can be executed in different ways.""" - # Create an environment, grab the spec, and use it to create networks. - env_name = ENV_NAME.value + """Builds DQN experiment config which can be executed in different ways.""" + # Create an environment, grab the spec, and use it to create networks. + env_name = ENV_NAME.value - def env_factory(seed): - del seed - return helpers.make_atari_environment( - level=env_name, sticky_actions=True, zero_discount_on_life_loss=False) + def env_factory(seed): + del seed + return helpers.make_atari_environment( + level=env_name, sticky_actions=True, zero_discount_on_life_loss=False + ) - # Construct the agent. - config = dqn.DQNConfig( - discount=0.99, - eval_epsilon=0., - learning_rate=5e-5, - n_step=1, - epsilon=0.01, - target_update_period=2000, - min_replay_size=20_000, - max_replay_size=1_000_000, - samples_per_insert=8, - batch_size=32) - loss_fn = losses.QLearning( - discount=config.discount, max_abs_reward=1.) + # Construct the agent. + config = dqn.DQNConfig( + discount=0.99, + eval_epsilon=0.0, + learning_rate=5e-5, + n_step=1, + epsilon=0.01, + target_update_period=2000, + min_replay_size=20_000, + max_replay_size=1_000_000, + samples_per_insert=8, + batch_size=32, + ) + loss_fn = losses.QLearning(discount=config.discount, max_abs_reward=1.0) - dqn_builder = dqn.DQNBuilder(config, loss_fn=loss_fn) + dqn_builder = dqn.DQNBuilder(config, loss_fn=loss_fn) - return experiments.ExperimentConfig( - builder=dqn_builder, - environment_factory=env_factory, - network_factory=helpers.make_dqn_atari_network, - seed=SEED.value, - max_num_actor_steps=NUM_STEPS.value) + return experiments.ExperimentConfig( + builder=dqn_builder, + environment_factory=env_factory, + network_factory=helpers.make_dqn_atari_network, + seed=SEED.value, + max_num_actor_steps=NUM_STEPS.value, + ) def main(_): - experiment_config = build_experiment_config() - if RUN_DISTRIBUTED.value: - program = experiments.make_distributed_experiment( - experiment=experiment_config, - num_actors=4 if lp_utils.is_local_run() else 128) - lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) - else: - experiments.run_experiment(experiment_config) + experiment_config = build_experiment_config() + if RUN_DISTRIBUTED.value: + program = experiments.make_distributed_experiment( + experiment=experiment_config, + num_actors=4 if lp_utils.is_local_run() else 128, + ) + lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) + else: + experiments.run_experiment(experiment_config) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/examples/baselines/rl_discrete/run_impala.py b/examples/baselines/rl_discrete/run_impala.py index c566a7558f..d4ce072443 100644 --- a/examples/baselines/rl_discrete/run_impala.py +++ 
b/examples/baselines/rl_discrete/run_impala.py @@ -14,75 +14,80 @@ """Example running IMPALA on discrete control tasks.""" -from absl import flags -from acme.agents.jax import impala -from acme.agents.jax.impala import builder as impala_builder import helpers -from absl import app -from acme.jax import experiments -from acme.utils import lp_utils import launchpad as lp import optax +from absl import app, flags +from acme.agents.jax import impala +from acme.agents.jax.impala import builder as impala_builder +from acme.jax import experiments +from acme.utils import lp_utils # Flags which modify the behavior of the launcher. RUN_DISTRIBUTED = flags.DEFINE_bool( - 'run_distributed', True, 'Should an agent be executed in a distributed ' - 'way. If False, will run single-threaded.') -ENV_NAME = flags.DEFINE_string('env_name', 'Pong', 'What environment to run.') -SEED = flags.DEFINE_integer('seed', 0, 'Random seed (experiment).') + "run_distributed", + True, + "Should an agent be executed in a distributed " + "way. If False, will run single-threaded.", +) +ENV_NAME = flags.DEFINE_string("env_name", "Pong", "What environment to run.") +SEED = flags.DEFINE_integer("seed", 0, "Random seed (experiment).") NUM_ACTOR_STEPS = flags.DEFINE_integer( - 'num_steps', 1_000_000, - 'Number of environment steps to run the agent for.') + "num_steps", 1_000_000, "Number of environment steps to run the agent for." +) _BATCH_SIZE = 32 _SEQUENCE_LENGTH = _SEQUENCE_PERIOD = 20 # Avoids overlapping sequences. def build_experiment_config(): - """Builds IMPALA experiment config which can be executed in different ways.""" + """Builds IMPALA experiment config which can be executed in different ways.""" - # Create an environment, grab the spec, and use it to create networks. - env_name = ENV_NAME.value + # Create an environment, grab the spec, and use it to create networks. + env_name = ENV_NAME.value - def env_factory(seed): - del seed - return helpers.make_atari_environment( - level=env_name, - sticky_actions=True, - zero_discount_on_life_loss=False, - oar_wrapper=True) + def env_factory(seed): + del seed + return helpers.make_atari_environment( + level=env_name, + sticky_actions=True, + zero_discount_on_life_loss=False, + oar_wrapper=True, + ) - # Construct the agent. - num_learner_steps = NUM_ACTOR_STEPS.value // (_SEQUENCE_PERIOD * _BATCH_SIZE) - lr_schedule = optax.linear_schedule(2e-4, 0., num_learner_steps) - config = impala.IMPALAConfig( - batch_size=_BATCH_SIZE, - sequence_length=_SEQUENCE_LENGTH, - sequence_period=_SEQUENCE_PERIOD, - learning_rate=lr_schedule, - entropy_cost=5e-3, - max_abs_reward=1., - ) + # Construct the agent. 
+ num_learner_steps = NUM_ACTOR_STEPS.value // (_SEQUENCE_PERIOD * _BATCH_SIZE) + lr_schedule = optax.linear_schedule(2e-4, 0.0, num_learner_steps) + config = impala.IMPALAConfig( + batch_size=_BATCH_SIZE, + sequence_length=_SEQUENCE_LENGTH, + sequence_period=_SEQUENCE_PERIOD, + learning_rate=lr_schedule, + entropy_cost=5e-3, + max_abs_reward=1.0, + ) - return experiments.ExperimentConfig( - builder=impala_builder.IMPALABuilder(config), - environment_factory=env_factory, - network_factory=impala.make_atari_networks, - seed=SEED.value, - max_num_actor_steps=NUM_ACTOR_STEPS.value) + return experiments.ExperimentConfig( + builder=impala_builder.IMPALABuilder(config), + environment_factory=env_factory, + network_factory=impala.make_atari_networks, + seed=SEED.value, + max_num_actor_steps=NUM_ACTOR_STEPS.value, + ) def main(_): - experiment_config = build_experiment_config() - if RUN_DISTRIBUTED.value: - program = experiments.make_distributed_experiment( - experiment=experiment_config, - num_actors=4 if lp_utils.is_local_run() else 256) - lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) - else: - experiments.run_experiment(experiment_config) + experiment_config = build_experiment_config() + if RUN_DISTRIBUTED.value: + program = experiments.make_distributed_experiment( + experiment=experiment_config, + num_actors=4 if lp_utils.is_local_run() else 256, + ) + lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) + else: + experiments.run_experiment(experiment_config) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/examples/baselines/rl_discrete/run_mdqn.py b/examples/baselines/rl_discrete/run_mdqn.py index 3b4c940097..31076f931c 100644 --- a/examples/baselines/rl_discrete/run_mdqn.py +++ b/examples/baselines/rl_discrete/run_mdqn.py @@ -14,71 +14,80 @@ """Example running Munchausen-DQN on discrete control tasks.""" -from absl import flags +import helpers +import launchpad as lp +from absl import app, flags + from acme.agents.jax import dqn from acme.agents.jax.dqn import losses -import helpers -from absl import app from acme.jax import experiments from acme.utils import lp_utils -import launchpad as lp - RUN_DISTRIBUTED = flags.DEFINE_bool( - 'run_distributed', True, 'Should an agent be executed in a distributed ' - 'way. If False, will run single-threaded.') -ENV_NAME = flags.DEFINE_string('env_name', 'Pong', 'What environment to run') -SEED = flags.DEFINE_integer('seed', 0, 'Random seed.') -NUM_STEPS = flags.DEFINE_integer('num_steps', 1_000_000, - 'Number of env steps to run.') + "run_distributed", + True, + "Should an agent be executed in a distributed " + "way. If False, will run single-threaded.", +) +ENV_NAME = flags.DEFINE_string("env_name", "Pong", "What environment to run") +SEED = flags.DEFINE_integer("seed", 0, "Random seed.") +NUM_STEPS = flags.DEFINE_integer("num_steps", 1_000_000, "Number of env steps to run.") def build_experiment_config(): - """Builds MDQN experiment config which can be executed in different ways.""" - # Create an environment, grab the spec, and use it to create networks. - env_name = ENV_NAME.value + """Builds MDQN experiment config which can be executed in different ways.""" + # Create an environment, grab the spec, and use it to create networks. 
+ env_name = ENV_NAME.value - def env_factory(seed): - del seed - return helpers.make_atari_environment( - level=env_name, sticky_actions=True, zero_discount_on_life_loss=False) + def env_factory(seed): + del seed + return helpers.make_atari_environment( + level=env_name, sticky_actions=True, zero_discount_on_life_loss=False + ) - # Construct the agent. - config = dqn.DQNConfig( - discount=0.99, - eval_epsilon=0., - learning_rate=5e-5, - n_step=1, - epsilon=0.01, - target_update_period=2000, - min_replay_size=20_000, - max_replay_size=1_000_000, - samples_per_insert=8, - batch_size=32) - loss_fn = losses.MunchausenQLearning( - discount=config.discount, max_abs_reward=1., huber_loss_parameter=1., - entropy_temperature=0.03, munchausen_coefficient=0.9) + # Construct the agent. + config = dqn.DQNConfig( + discount=0.99, + eval_epsilon=0.0, + learning_rate=5e-5, + n_step=1, + epsilon=0.01, + target_update_period=2000, + min_replay_size=20_000, + max_replay_size=1_000_000, + samples_per_insert=8, + batch_size=32, + ) + loss_fn = losses.MunchausenQLearning( + discount=config.discount, + max_abs_reward=1.0, + huber_loss_parameter=1.0, + entropy_temperature=0.03, + munchausen_coefficient=0.9, + ) - dqn_builder = dqn.DQNBuilder(config, loss_fn=loss_fn) + dqn_builder = dqn.DQNBuilder(config, loss_fn=loss_fn) - return experiments.ExperimentConfig( - builder=dqn_builder, - environment_factory=env_factory, - network_factory=helpers.make_dqn_atari_network, - seed=SEED.value, - max_num_actor_steps=NUM_STEPS.value) + return experiments.ExperimentConfig( + builder=dqn_builder, + environment_factory=env_factory, + network_factory=helpers.make_dqn_atari_network, + seed=SEED.value, + max_num_actor_steps=NUM_STEPS.value, + ) def main(_): - experiment_config = build_experiment_config() - if RUN_DISTRIBUTED.value: - program = experiments.make_distributed_experiment( - experiment=experiment_config, - num_actors=4 if lp_utils.is_local_run() else 128) - lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) - else: - experiments.run_experiment(experiment_config) + experiment_config = build_experiment_config() + if RUN_DISTRIBUTED.value: + program = experiments.make_distributed_experiment( + experiment=experiment_config, + num_actors=4 if lp_utils.is_local_run() else 128, + ) + lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) + else: + experiments.run_experiment(experiment_config) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/examples/baselines/rl_discrete/run_muzero.py b/examples/baselines/rl_discrete/run_muzero.py index ca8fc83eec..3be4660855 100644 --- a/examples/baselines/rl_discrete/run_muzero.py +++ b/examples/baselines/rl_discrete/run_muzero.py @@ -17,113 +17,102 @@ import datetime import math -from absl import flags +import dm_env +import helpers +import launchpad as lp +from absl import app, flags + from acme import specs from acme.agents.jax import muzero -import helpers -from absl import app from acme.jax import experiments from acme.jax import inference_server as inference_server_lib from acme.utils import lp_utils -import dm_env -import launchpad as lp - -ENV_NAME = flags.DEFINE_string('env_name', 'Pong', 'What environment to run') -SEED = flags.DEFINE_integer('seed', 0, 'Random seed.') -NUM_STEPS = flags.DEFINE_integer( - 'num_steps', 2_000_000, 'Number of env steps to run.' 
-)
-NUM_LEARNERS = flags.DEFINE_integer('num_learners', 1, 'Number of learners.')
-NUM_ACTORS = flags.DEFINE_integer('num_actors', 4, 'Number of actors.')
+ENV_NAME = flags.DEFINE_string("env_name", "Pong", "What environment to run")
+SEED = flags.DEFINE_integer("seed", 0, "Random seed.")
+NUM_STEPS = flags.DEFINE_integer("num_steps", 2_000_000, "Number of env steps to run.")
+NUM_LEARNERS = flags.DEFINE_integer("num_learners", 1, "Number of learners.")
+NUM_ACTORS = flags.DEFINE_integer("num_actors", 4, "Number of actors.")
 NUM_ACTORS_PER_NODE = flags.DEFINE_integer(
-    'num_actors_per_node',
-    2,
-    'Number of colocated actors',
+    "num_actors_per_node", 2, "Number of colocated actors",
 )
 RUN_DISTRIBUTED = flags.DEFINE_bool(
-    'run_distributed', True, 'Should an agent be executed in a distributed '
-    'way. If False, will run single-threaded.',)
+    "run_distributed",
+    True,
+    "Should an agent be executed in a distributed "
+    "way. If False, will run single-threaded.",
+)
 def build_experiment_config() -> experiments.ExperimentConfig:
-  """Builds DQN experiment config which can be executed in different ways."""
-  env_name = ENV_NAME.value
-  muzero_config = muzero.MZConfig()
-
-  def env_factory(seed: int) -> dm_env.Environment:
-    del seed
-    return helpers.make_atari_environment(
-        level=env_name,
-        sticky_actions=True,
-        zero_discount_on_life_loss=True,
-        num_stacked_frames=1,
-        grayscaling=False,
-        to_float=False,
+    """Builds MuZero experiment config which can be executed in different ways."""
+    env_name = ENV_NAME.value
+    muzero_config = muzero.MZConfig()
+
+    def env_factory(seed: int) -> dm_env.Environment:
+        del seed
+        return helpers.make_atari_environment(
+            level=env_name,
+            sticky_actions=True,
+            zero_discount_on_life_loss=True,
+            num_stacked_frames=1,
+            grayscaling=False,
+            to_float=False,
+        )
+
+    def network_factory(spec: specs.EnvironmentSpec,) -> muzero.MzNetworks:
+        return muzero.make_network(spec, stack_size=muzero_config.stack_size,)
+
+    # Construct the builder.
+    env_spec = specs.make_environment_spec(env_factory(SEED.value))
+    extra_spec = {
+        muzero.POLICY_PROBS_KEY: specs.Array(
+            shape=(env_spec.actions.num_values,), dtype="float32"
+        ),
+        muzero.RAW_VALUES_KEY: specs.Array(shape=(), dtype="float32"),
+    }
+    muzero_builder = muzero.MzBuilder(  # pytype: disable=wrong-arg-types  # jax-ndarray
+        muzero_config, extra_spec,
     )
-  def network_factory(
-      spec: specs.EnvironmentSpec,
-  ) -> muzero.MzNetworks:
-    return muzero.make_network(
-        spec,
-        stack_size=muzero_config.stack_size,
+    checkpointing_config = experiments.CheckpointingConfig(
+        replay_checkpointing_time_delta_minutes=20, time_delta_minutes=1,
+    )
+    return experiments.ExperimentConfig(
+        builder=muzero_builder,
+        environment_factory=env_factory,
+        network_factory=network_factory,
+        seed=SEED.value,
+        max_num_actor_steps=NUM_STEPS.value,
+        checkpointing=checkpointing_config,
+    )
-
-  # Construct the builder.
- env_spec = specs.make_environment_spec(env_factory(SEED.value)) - extra_spec = { - muzero.POLICY_PROBS_KEY: specs.Array( - shape=(env_spec.actions.num_values,), dtype='float32' - ), - muzero.RAW_VALUES_KEY: specs.Array(shape=(), dtype='float32'), - } - muzero_builder = muzero.MzBuilder( # pytype: disable=wrong-arg-types # jax-ndarray - muzero_config, - extra_spec, - ) - - checkpointing_config = experiments.CheckpointingConfig( - replay_checkpointing_time_delta_minutes=20, - time_delta_minutes=1, - ) - return experiments.ExperimentConfig( - builder=muzero_builder, - environment_factory=env_factory, - network_factory=network_factory, - seed=SEED.value, - max_num_actor_steps=NUM_STEPS.value, - checkpointing=checkpointing_config, - ) def main(_): - experiment_config = build_experiment_config() - - if not RUN_DISTRIBUTED.value: - raise NotImplementedError('Single threaded experiment not supported.') - - inference_server_config = inference_server_lib.InferenceServerConfig( - batch_size=64, - update_period=400, - timeout=datetime.timedelta( - seconds=1, - ), - ) - num_inference_servers = math.ceil( - NUM_ACTORS.value / (128 * NUM_ACTORS_PER_NODE.value), - ) - - program = experiments.make_distributed_experiment( - experiment=experiment_config, - num_actors=NUM_ACTORS.value, - num_learner_nodes=NUM_LEARNERS.value, - num_actors_per_node=NUM_ACTORS_PER_NODE.value, - num_inference_servers=num_inference_servers, - inference_server_config=inference_server_config, - ) - lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program,),) - - -if __name__ == '__main__': - app.run(main) + experiment_config = build_experiment_config() + + if not RUN_DISTRIBUTED.value: + raise NotImplementedError("Single threaded experiment not supported.") + + inference_server_config = inference_server_lib.InferenceServerConfig( + batch_size=64, update_period=400, timeout=datetime.timedelta(seconds=1,), + ) + num_inference_servers = math.ceil( + NUM_ACTORS.value / (128 * NUM_ACTORS_PER_NODE.value), + ) + + program = experiments.make_distributed_experiment( + experiment=experiment_config, + num_actors=NUM_ACTORS.value, + num_learner_nodes=NUM_LEARNERS.value, + num_actors_per_node=NUM_ACTORS_PER_NODE.value, + num_inference_servers=num_inference_servers, + inference_server_config=inference_server_config, + ) + lp.launch( + program, xm_resources=lp_utils.make_xm_docker_resources(program,), + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/baselines/rl_discrete/run_qr_dqn.py b/examples/baselines/rl_discrete/run_qr_dqn.py index ce3ba0852a..e36fd4cdee 100644 --- a/examples/baselines/rl_discrete/run_qr_dqn.py +++ b/examples/baselines/rl_discrete/run_qr_dqn.py @@ -14,75 +14,82 @@ """Example running QR-DQN on discrete control tasks.""" -from absl import flags +import helpers +import launchpad as lp +from absl import app, flags + from acme import specs from acme.agents.jax import dqn from acme.agents.jax.dqn import losses -import helpers -from absl import app from acme.jax import experiments from acme.utils import lp_utils -import launchpad as lp RUN_DISTRIBUTED = flags.DEFINE_bool( - 'run_distributed', True, 'Should an agent be executed in a distributed ' - 'way. 
If False, will run single-threaded.') -ENV_NAME = flags.DEFINE_string('env_name', 'Pong', 'What environment to run') -SEED = flags.DEFINE_integer('seed', 0, 'Random seed.') -NUM_STEPS = flags.DEFINE_integer('num_steps', 1_000_000, - 'Number of env steps to run.') -NUM_QUANTILES = flags.DEFINE_integer('num_quantiles', 200, - 'Number of bins to use.') + "run_distributed", + True, + "Should an agent be executed in a distributed " + "way. If False, will run single-threaded.", +) +ENV_NAME = flags.DEFINE_string("env_name", "Pong", "What environment to run") +SEED = flags.DEFINE_integer("seed", 0, "Random seed.") +NUM_STEPS = flags.DEFINE_integer("num_steps", 1_000_000, "Number of env steps to run.") +NUM_QUANTILES = flags.DEFINE_integer("num_quantiles", 200, "Number of bins to use.") def build_experiment_config(): - """Builds QR-DQN experiment config which can be executed in different ways.""" - # Create an environment, grab the spec, and use it to create networks. - env_name = ENV_NAME.value + """Builds QR-DQN experiment config which can be executed in different ways.""" + # Create an environment, grab the spec, and use it to create networks. + env_name = ENV_NAME.value + + def env_factory(seed): + del seed + return helpers.make_atari_environment( + level=env_name, sticky_actions=True, zero_discount_on_life_loss=False + ) - def env_factory(seed): - del seed - return helpers.make_atari_environment( - level=env_name, sticky_actions=True, zero_discount_on_life_loss=False) + num_quantiles = NUM_QUANTILES.value - num_quantiles = NUM_QUANTILES.value - def network_factory(environment_spec: specs.EnvironmentSpec): - return helpers.make_distributional_dqn_atari_network( - environment_spec=environment_spec, num_quantiles=num_quantiles) + def network_factory(environment_spec: specs.EnvironmentSpec): + return helpers.make_distributional_dqn_atari_network( + environment_spec=environment_spec, num_quantiles=num_quantiles + ) - # Construct the agent. - config = dqn.DQNConfig( - discount=0.99, - eval_epsilon=0., - learning_rate=5e-5, - n_step=3, - epsilon=0.01 / 32, - target_update_period=2000, - min_replay_size=20_000, - max_replay_size=1_000_000, - samples_per_insert=8, - batch_size=32) - loss_fn = losses.QrDqn(num_atoms=NUM_QUANTILES.value, huber_param=1.) - dqn_builder = dqn.DistributionalDQNBuilder(config, loss_fn=loss_fn) + # Construct the agent. 
+ config = dqn.DQNConfig( + discount=0.99, + eval_epsilon=0.0, + learning_rate=5e-5, + n_step=3, + epsilon=0.01 / 32, + target_update_period=2000, + min_replay_size=20_000, + max_replay_size=1_000_000, + samples_per_insert=8, + batch_size=32, + ) + loss_fn = losses.QrDqn(num_atoms=NUM_QUANTILES.value, huber_param=1.0) + dqn_builder = dqn.DistributionalDQNBuilder(config, loss_fn=loss_fn) - return experiments.ExperimentConfig( - builder=dqn_builder, - environment_factory=env_factory, - network_factory=network_factory, - seed=SEED.value, - max_num_actor_steps=NUM_STEPS.value) + return experiments.ExperimentConfig( + builder=dqn_builder, + environment_factory=env_factory, + network_factory=network_factory, + seed=SEED.value, + max_num_actor_steps=NUM_STEPS.value, + ) def main(_): - experiment_config = build_experiment_config() - if RUN_DISTRIBUTED.value: - program = experiments.make_distributed_experiment( - experiment=experiment_config, - num_actors=4 if lp_utils.is_local_run() else 16) - lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) - else: - experiments.run_experiment(experiment_config) + experiment_config = build_experiment_config() + if RUN_DISTRIBUTED.value: + program = experiments.make_distributed_experiment( + experiment=experiment_config, + num_actors=4 if lp_utils.is_local_run() else 16, + ) + lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) + else: + experiments.run_experiment(experiment_config) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/examples/baselines/rl_discrete/run_r2d2.py b/examples/baselines/rl_discrete/run_r2d2.py index c22166cbb5..eb9b7e2fd9 100644 --- a/examples/baselines/rl_discrete/run_r2d2.py +++ b/examples/baselines/rl_discrete/run_r2d2.py @@ -14,80 +14,85 @@ """Example running R2D2 on discrete control tasks.""" -from absl import flags -from acme.agents.jax import r2d2 +import dm_env import helpers -from absl import app +import launchpad as lp +from absl import app, flags + +from acme.agents.jax import r2d2 from acme.jax import experiments from acme.utils import lp_utils -import dm_env -import launchpad as lp # Flags which modify the behavior of the launcher. flags.DEFINE_bool( - 'run_distributed', True, 'Should an agent be executed in a distributed ' - 'way. If False, will run single-threaded.') -flags.DEFINE_string('env_name', 'Pong', 'What environment to run.') -flags.DEFINE_integer('seed', 0, 'Random seed (experiment).') -flags.DEFINE_integer('num_steps', 1_000_000, - 'Number of environment steps to run for.') + "run_distributed", + True, + "Should an agent be executed in a distributed " + "way. If False, will run single-threaded.", +) +flags.DEFINE_string("env_name", "Pong", "What environment to run.") +flags.DEFINE_integer("seed", 0, "Random seed (experiment).") +flags.DEFINE_integer("num_steps", 1_000_000, "Number of environment steps to run for.") FLAGS = flags.FLAGS def build_experiment_config(): - """Builds R2D2 experiment config which can be executed in different ways.""" - batch_size = 32 + """Builds R2D2 experiment config which can be executed in different ways.""" + batch_size = 32 - # The env_name must be dereferenced outside the environment factory as FLAGS - # cannot be pickled and pickling is necessary when launching distributed - # experiments via Launchpad. 
- env_name = FLAGS.env_name + # The env_name must be dereferenced outside the environment factory as FLAGS + # cannot be pickled and pickling is necessary when launching distributed + # experiments via Launchpad. + env_name = FLAGS.env_name - # Create an environment factory. - def environment_factory(seed: int) -> dm_env.Environment: - del seed - return helpers.make_atari_environment( - level=env_name, - sticky_actions=True, - zero_discount_on_life_loss=False, - oar_wrapper=True, - num_stacked_frames=1, - flatten_frame_stack=True, - grayscaling=False) + # Create an environment factory. + def environment_factory(seed: int) -> dm_env.Environment: + del seed + return helpers.make_atari_environment( + level=env_name, + sticky_actions=True, + zero_discount_on_life_loss=False, + oar_wrapper=True, + num_stacked_frames=1, + flatten_frame_stack=True, + grayscaling=False, + ) - # Configure the agent. - config = r2d2.R2D2Config( - burn_in_length=8, - trace_length=40, - sequence_period=20, - min_replay_size=10_000, - batch_size=batch_size, - prefetch_size=1, - samples_per_insert=1.0, - evaluation_epsilon=1e-3, - learning_rate=1e-4, - target_update_period=1200, - variable_update_period=100, - ) + # Configure the agent. + config = r2d2.R2D2Config( + burn_in_length=8, + trace_length=40, + sequence_period=20, + min_replay_size=10_000, + batch_size=batch_size, + prefetch_size=1, + samples_per_insert=1.0, + evaluation_epsilon=1e-3, + learning_rate=1e-4, + target_update_period=1200, + variable_update_period=100, + ) - return experiments.ExperimentConfig( - builder=r2d2.R2D2Builder(config), - network_factory=r2d2.make_atari_networks, - environment_factory=environment_factory, - seed=FLAGS.seed, - max_num_actor_steps=FLAGS.num_steps) + return experiments.ExperimentConfig( + builder=r2d2.R2D2Builder(config), + network_factory=r2d2.make_atari_networks, + environment_factory=environment_factory, + seed=FLAGS.seed, + max_num_actor_steps=FLAGS.num_steps, + ) def main(_): - config = build_experiment_config() - if FLAGS.run_distributed: - program = experiments.make_distributed_experiment( - experiment=config, num_actors=4 if lp_utils.is_local_run() else 80) - lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) - else: - experiments.run_experiment(experiment=config) + config = build_experiment_config() + if FLAGS.run_distributed: + program = experiments.make_distributed_experiment( + experiment=config, num_actors=4 if lp_utils.is_local_run() else 80 + ) + lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) + else: + experiments.run_experiment(experiment=config) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/examples/bsuite/run_dqn.py b/examples/bsuite/run_dqn.py index 0a570b44ee..bd2a774852 100644 --- a/examples/bsuite/run_dqn.py +++ b/examples/bsuite/run_dqn.py @@ -14,47 +14,44 @@ """Example running DQN on BSuite in a single process.""" -from absl import app -from absl import flags +import bsuite +import sonnet as snt +from absl import app, flags import acme -from acme import specs -from acme import wrappers +from acme import specs, wrappers from acme.agents.tf import dqn -import bsuite -import sonnet as snt - # Bsuite flags -flags.DEFINE_string('bsuite_id', 'deep_sea/0', 'Bsuite id.') -flags.DEFINE_string('results_dir', '/tmp/bsuite', 'CSV results directory.') -flags.DEFINE_boolean('overwrite', False, 'Whether to overwrite csv results.') +flags.DEFINE_string("bsuite_id", "deep_sea/0", "Bsuite id.") 
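The R2D2 launcher above notes that FLAGS cannot be pickled, so flag values have to be read before the environment factory is built. A minimal standalone sketch of that pattern, with a hypothetical `_LEVEL` flag and a placeholder `make_env` standing in for the real environment builder:

from absl import flags

_LEVEL = flags.DEFINE_string("level", "Pong", "Illustrative level flag.")


def make_env(level: str):
    """Placeholder for a real environment constructor (e.g. an Atari loader)."""
    return level


def make_environment_factory():
    # Called after app.run() has parsed the flags.
    level = _LEVEL.value  # Dereference the flag here, outside the closure.

    def environment_factory(seed: int):
        del seed
        return make_env(level)  # Captures a plain string, not FLAGS.

    return environment_factory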
+flags.DEFINE_string("results_dir", "/tmp/bsuite", "CSV results directory.") +flags.DEFINE_boolean("overwrite", False, "Whether to overwrite csv results.") FLAGS = flags.FLAGS def main(_): - # Create an environment and grab the spec. - raw_environment = bsuite.load_and_record_to_csv( - bsuite_id=FLAGS.bsuite_id, - results_dir=FLAGS.results_dir, - overwrite=FLAGS.overwrite, - ) - environment = wrappers.SinglePrecisionWrapper(raw_environment) - environment_spec = specs.make_environment_spec(environment) - - network = snt.Sequential([ - snt.Flatten(), - snt.nets.MLP([50, 50, environment_spec.actions.num_values]) - ]) - - # Construct the agent. - agent = dqn.DQN( - environment_spec=environment_spec, network=network) - - # Run the environment loop. - loop = acme.EnvironmentLoop(environment, agent) - loop.run(num_episodes=environment.bsuite_num_episodes) # pytype: disable=attribute-error - - -if __name__ == '__main__': - app.run(main) + # Create an environment and grab the spec. + raw_environment = bsuite.load_and_record_to_csv( + bsuite_id=FLAGS.bsuite_id, + results_dir=FLAGS.results_dir, + overwrite=FLAGS.overwrite, + ) + environment = wrappers.SinglePrecisionWrapper(raw_environment) + environment_spec = specs.make_environment_spec(environment) + + network = snt.Sequential( + [snt.Flatten(), snt.nets.MLP([50, 50, environment_spec.actions.num_values])] + ) + + # Construct the agent. + agent = dqn.DQN(environment_spec=environment_spec, network=network) + + # Run the environment loop. + loop = acme.EnvironmentLoop(environment, agent) + loop.run( + num_episodes=environment.bsuite_num_episodes + ) # pytype: disable=attribute-error + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/bsuite/run_impala.py b/examples/bsuite/run_impala.py index c592e604eb..eb9cd80783 100644 --- a/examples/bsuite/run_impala.py +++ b/examples/bsuite/run_impala.py @@ -14,56 +14,59 @@ """Runs IMPALA on bsuite locally.""" -from absl import app -from absl import flags +import bsuite +import sonnet as snt +from absl import app, flags + import acme -from acme import specs -from acme import wrappers +from acme import specs, wrappers from acme.agents.tf import impala from acme.tf import networks -import bsuite -import sonnet as snt # Bsuite flags -flags.DEFINE_string('bsuite_id', 'deep_sea/0', 'Bsuite id.') -flags.DEFINE_string('results_dir', '/tmp/bsuite', 'CSV results directory.') -flags.DEFINE_boolean('overwrite', False, 'Whether to overwrite csv results.') +flags.DEFINE_string("bsuite_id", "deep_sea/0", "Bsuite id.") +flags.DEFINE_string("results_dir", "/tmp/bsuite", "CSV results directory.") +flags.DEFINE_boolean("overwrite", False, "Whether to overwrite csv results.") FLAGS = flags.FLAGS def make_network(action_spec: specs.DiscreteArray) -> snt.RNNCore: - return snt.DeepRNN([ - snt.Flatten(), - snt.nets.MLP([50, 50]), - snt.LSTM(20), - networks.PolicyValueHead(action_spec.num_values), - ]) + return snt.DeepRNN( + [ + snt.Flatten(), + snt.nets.MLP([50, 50]), + snt.LSTM(20), + networks.PolicyValueHead(action_spec.num_values), + ] + ) def main(_): - # Create an environment and grab the spec. - raw_environment = bsuite.load_and_record_to_csv( - bsuite_id=FLAGS.bsuite_id, - results_dir=FLAGS.results_dir, - overwrite=FLAGS.overwrite, - ) - environment = wrappers.SinglePrecisionWrapper(raw_environment) - environment_spec = specs.make_environment_spec(environment) + # Create an environment and grab the spec. 
+ raw_environment = bsuite.load_and_record_to_csv( + bsuite_id=FLAGS.bsuite_id, + results_dir=FLAGS.results_dir, + overwrite=FLAGS.overwrite, + ) + environment = wrappers.SinglePrecisionWrapper(raw_environment) + environment_spec = specs.make_environment_spec(environment) - # Create the networks to optimize. - network = make_network(environment_spec.actions) + # Create the networks to optimize. + network = make_network(environment_spec.actions) - agent = impala.IMPALA( - environment_spec=environment_spec, - network=network, - sequence_length=3, - sequence_period=3, - ) + agent = impala.IMPALA( + environment_spec=environment_spec, + network=network, + sequence_length=3, + sequence_period=3, + ) - # Run the environment loop. - loop = acme.EnvironmentLoop(environment, agent) - loop.run(num_episodes=environment.bsuite_num_episodes) # pytype: disable=attribute-error + # Run the environment loop. + loop = acme.EnvironmentLoop(environment, agent) + loop.run( + num_episodes=environment.bsuite_num_episodes + ) # pytype: disable=attribute-error -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/examples/bsuite/run_mcts.py b/examples/bsuite/run_mcts.py index dc89b88da1..99e736a84a 100644 --- a/examples/bsuite/run_mcts.py +++ b/examples/bsuite/run_mcts.py @@ -16,90 +16,92 @@ from typing import Tuple -from absl import app -from absl import flags +import bsuite +import dm_env +import sonnet as snt +from absl import app, flags +from bsuite.logging import csv_logging + import acme -from acme import specs -from acme import wrappers +from acme import specs, wrappers from acme.agents.tf import mcts from acme.agents.tf.mcts import models -from acme.agents.tf.mcts.models import mlp -from acme.agents.tf.mcts.models import simulator +from acme.agents.tf.mcts.models import mlp, simulator from acme.tf import networks -import bsuite -from bsuite.logging import csv_logging -import dm_env -import sonnet as snt # Bsuite flags -flags.DEFINE_string('bsuite_id', 'deep_sea/0', 'Bsuite id.') -flags.DEFINE_string('results_dir', '/tmp/bsuite', 'CSV results directory.') -flags.DEFINE_boolean('overwrite', False, 'Whether to overwrite csv results.') +flags.DEFINE_string("bsuite_id", "deep_sea/0", "Bsuite id.") +flags.DEFINE_string("results_dir", "/tmp/bsuite", "CSV results directory.") +flags.DEFINE_boolean("overwrite", False, "Whether to overwrite csv results.") # Agent flags -flags.DEFINE_boolean('simulator', True, 'Simulator or learned model?') +flags.DEFINE_boolean("simulator", True, "Simulator or learned model?") FLAGS = flags.FLAGS def make_env_and_model( - bsuite_id: str, - results_dir: str, - overwrite: bool) -> Tuple[dm_env.Environment, models.Model]: - """Create environment and corresponding model (learned or simulator).""" - raw_env = bsuite.load_from_id(bsuite_id) - if FLAGS.simulator: - model = simulator.Simulator(raw_env) # pytype: disable=attribute-error - else: - model = mlp.MLPModel( - specs.make_environment_spec(raw_env), - replay_capacity=1000, - batch_size=16, - hidden_sizes=(50,), + bsuite_id: str, results_dir: str, overwrite: bool +) -> Tuple[dm_env.Environment, models.Model]: + """Create environment and corresponding model (learned or simulator).""" + raw_env = bsuite.load_from_id(bsuite_id) + if FLAGS.simulator: + model = simulator.Simulator(raw_env) # pytype: disable=attribute-error + else: + model = mlp.MLPModel( + specs.make_environment_spec(raw_env), + replay_capacity=1000, + batch_size=16, + hidden_sizes=(50,), + ) + environment = 
csv_logging.wrap_environment( + raw_env, bsuite_id, results_dir, overwrite ) - environment = csv_logging.wrap_environment( - raw_env, bsuite_id, results_dir, overwrite) - environment = wrappers.SinglePrecisionWrapper(environment) + environment = wrappers.SinglePrecisionWrapper(environment) - return environment, model + return environment, model def make_network(action_spec: specs.DiscreteArray) -> snt.Module: - return snt.Sequential([ - snt.Flatten(), - snt.nets.MLP([50, 50]), - networks.PolicyValueHead(action_spec.num_values), - ]) + return snt.Sequential( + [ + snt.Flatten(), + snt.nets.MLP([50, 50]), + networks.PolicyValueHead(action_spec.num_values), + ] + ) def main(_): - # Create an environment and environment model. - environment, model = make_env_and_model( - bsuite_id=FLAGS.bsuite_id, - results_dir=FLAGS.results_dir, - overwrite=FLAGS.overwrite, - ) - environment_spec = specs.make_environment_spec(environment) - - # Create the network and optimizer. - network = make_network(environment_spec.actions) - optimizer = snt.optimizers.Adam(learning_rate=1e-3) - - # Construct the agent. - agent = mcts.MCTS( - environment_spec=environment_spec, - model=model, - network=network, - optimizer=optimizer, - discount=0.99, - replay_capacity=10000, - n_step=1, - batch_size=16, - num_simulations=50, - ) - - # Run the environment loop. - loop = acme.EnvironmentLoop(environment, agent) - loop.run(num_episodes=environment.bsuite_num_episodes) # pytype: disable=attribute-error - - -if __name__ == '__main__': - app.run(main) + # Create an environment and environment model. + environment, model = make_env_and_model( + bsuite_id=FLAGS.bsuite_id, + results_dir=FLAGS.results_dir, + overwrite=FLAGS.overwrite, + ) + environment_spec = specs.make_environment_spec(environment) + + # Create the network and optimizer. + network = make_network(environment_spec.actions) + optimizer = snt.optimizers.Adam(learning_rate=1e-3) + + # Construct the agent. + agent = mcts.MCTS( + environment_spec=environment_spec, + model=model, + network=network, + optimizer=optimizer, + discount=0.99, + replay_capacity=10000, + n_step=1, + batch_size=16, + num_simulations=50, + ) + + # Run the environment loop. 
+ loop = acme.EnvironmentLoop(environment, agent) + loop.run( + num_episodes=environment.bsuite_num_episodes + ) # pytype: disable=attribute-error + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/multiagent/multigrid/helpers.py b/examples/multiagent/multigrid/helpers.py index 6e961ce42c..cae4a0883b 100644 --- a/examples/multiagent/multigrid/helpers.py +++ b/examples/multiagent/multigrid/helpers.py @@ -17,34 +17,38 @@ import functools from typing import Any, Dict, NamedTuple, Sequence +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np +import tensorflow_probability + from acme import specs from acme.agents.jax import ppo from acme.agents.jax.multiagent.decentralized import factories from acme.jax import networks as networks_lib from acme.jax import utils as acme_jax_utils from acme.multiagent import types as ma_types -import haiku as hk -import jax -import jax.numpy as jnp -import numpy as np -import tensorflow_probability tfp = tensorflow_probability.substrates.jax tfd = tfp.distributions class CategoricalParams(NamedTuple): - """Parameters for a categorical distribution.""" - logits: jnp.ndarray + """Parameters for a categorical distribution.""" + + logits: jnp.ndarray -def multigrid_obs_preproc(obs: Dict[str, Any], - conv_filters: int = 8, - conv_kernel: int = 3, - scalar_fc: int = 5, - scalar_name: str = 'direction', - scalar_dim: int = 4) -> jnp.ndarray: - """Conducts preprocessing on 'multigrid' environment dict observations. +def multigrid_obs_preproc( + obs: Dict[str, Any], + conv_filters: int = 8, + conv_kernel: int = 3, + scalar_fc: int = 5, + scalar_name: str = "direction", + scalar_dim: int = 4, +) -> jnp.ndarray: + """Conducts preprocessing on 'multigrid' environment dict observations. The preprocessing applied here is similar to those in: https://github.com/google-research/google-research/blob/master/social_rl/multiagent_tfagents/multigrid_networks.py @@ -67,107 +71,113 @@ def multigrid_obs_preproc(obs: Dict[str, Any], out: output observation. 
""" - def _cast_and_scale(x, scale_by=10.0): - if isinstance(x, jnp.ndarray): - x = x.astype(jnp.float32) - return x / scale_by - - outputs = [] - - if 'image' in obs.keys(): - image_preproc = hk.Sequential([ - _cast_and_scale, - hk.Conv2D(output_channels=conv_filters, kernel_shape=conv_kernel), - jax.nn.relu, - hk.Flatten() - ]) - outputs.append(image_preproc(obs['image'])) - - if 'position' in obs.keys(): - position_preproc = hk.Sequential([_cast_and_scale, hk.Linear(scalar_fc)]) - outputs.append(position_preproc(obs['position'])) - - if scalar_name in obs.keys(): - direction_preproc = hk.Sequential([ - functools.partial(jax.nn.one_hot, num_classes=scalar_dim), - hk.Flatten(), - hk.Linear(scalar_fc) - ]) - outputs.append(direction_preproc(obs[scalar_name])) - - out = jnp.concatenate(outputs, axis=-1) - return out + def _cast_and_scale(x, scale_by=10.0): + if isinstance(x, jnp.ndarray): + x = x.astype(jnp.float32) + return x / scale_by + + outputs = [] + + if "image" in obs.keys(): + image_preproc = hk.Sequential( + [ + _cast_and_scale, + hk.Conv2D(output_channels=conv_filters, kernel_shape=conv_kernel), + jax.nn.relu, + hk.Flatten(), + ] + ) + outputs.append(image_preproc(obs["image"])) + + if "position" in obs.keys(): + position_preproc = hk.Sequential([_cast_and_scale, hk.Linear(scalar_fc)]) + outputs.append(position_preproc(obs["position"])) + + if scalar_name in obs.keys(): + direction_preproc = hk.Sequential( + [ + functools.partial(jax.nn.one_hot, num_classes=scalar_dim), + hk.Flatten(), + hk.Linear(scalar_fc), + ] + ) + outputs.append(direction_preproc(obs[scalar_name])) + + out = jnp.concatenate(outputs, axis=-1) + return out def make_multigrid_dqn_networks( - environment_spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork: - """Returns DQN networks used by the agent in the multigrid environment.""" - # Check that multigrid environment is defined with discrete actions, 0-indexed - assert np.issubdtype(environment_spec.actions.dtype, np.integer), ( - 'Expected multigrid environment to have discrete actions with int dtype' - f' but environment_spec.actions.dtype == {environment_spec.actions.dtype}' - ) - assert environment_spec.actions.minimum == 0, ( - 'Expected multigrid environment to have 0-indexed action indices, but' - f' environment_spec.actions.minimum == {environment_spec.actions.minimum}' - ) - num_actions = environment_spec.actions.maximum + 1 - - def network(inputs): - model = hk.Sequential([ - hk.Flatten(), - hk.nets.MLP([50, 50, num_actions]), - ]) - processed_inputs = multigrid_obs_preproc(inputs) - return model(processed_inputs) - - network_hk = hk.without_apply_rng(hk.transform(network)) - dummy_obs = acme_jax_utils.add_batch_dim( - acme_jax_utils.zeros_like(environment_spec.observations)) - - return networks_lib.FeedForwardNetwork( - init=lambda rng: network_hk.init(rng, dummy_obs), apply=network_hk.apply) + environment_spec: specs.EnvironmentSpec, +) -> networks_lib.FeedForwardNetwork: + """Returns DQN networks used by the agent in the multigrid environment.""" + # Check that multigrid environment is defined with discrete actions, 0-indexed + assert np.issubdtype(environment_spec.actions.dtype, np.integer), ( + "Expected multigrid environment to have discrete actions with int dtype" + f" but environment_spec.actions.dtype == {environment_spec.actions.dtype}" + ) + assert environment_spec.actions.minimum == 0, ( + "Expected multigrid environment to have 0-indexed action indices, but" + f" environment_spec.actions.minimum == 
{environment_spec.actions.minimum}" + ) + num_actions = environment_spec.actions.maximum + 1 + + def network(inputs): + model = hk.Sequential([hk.Flatten(), hk.nets.MLP([50, 50, num_actions]),]) + processed_inputs = multigrid_obs_preproc(inputs) + return model(processed_inputs) + + network_hk = hk.without_apply_rng(hk.transform(network)) + dummy_obs = acme_jax_utils.add_batch_dim( + acme_jax_utils.zeros_like(environment_spec.observations) + ) + + return networks_lib.FeedForwardNetwork( + init=lambda rng: network_hk.init(rng, dummy_obs), apply=network_hk.apply + ) def make_multigrid_ppo_networks( environment_spec: specs.EnvironmentSpec, hidden_layer_sizes: Sequence[int] = (64, 64), ) -> ppo.PPONetworks: - """Returns PPO networks used by the agent in the multigrid environments.""" - - # Check that multigrid environment is defined with discrete actions, 0-indexed - assert np.issubdtype(environment_spec.actions.dtype, np.integer), ( - 'Expected multigrid environment to have discrete actions with int dtype' - f' but environment_spec.actions.dtype == {environment_spec.actions.dtype}' - ) - assert environment_spec.actions.minimum == 0, ( - 'Expected multigrid environment to have 0-indexed action indices, but' - f' environment_spec.actions.minimum == {environment_spec.actions.minimum}' - ) - num_actions = environment_spec.actions.maximum + 1 - - def forward_fn(inputs): - processed_inputs = multigrid_obs_preproc(inputs) - trunk = hk.nets.MLP(hidden_layer_sizes, activation=jnp.tanh) - h = trunk(processed_inputs) - logits = hk.Linear(num_actions)(h) - values = hk.Linear(1)(h) - values = jnp.squeeze(values, axis=-1) - return (CategoricalParams(logits=logits), values) - - # Transform into pure functions. - forward_fn = hk.without_apply_rng(hk.transform(forward_fn)) - - dummy_obs = acme_jax_utils.zeros_like(environment_spec.observations) - dummy_obs = acme_jax_utils.add_batch_dim(dummy_obs) # Dummy 'sequence' dim. - network = networks_lib.FeedForwardNetwork( - lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply) - return make_categorical_ppo_networks(network) # pylint:disable=undefined-variable + """Returns PPO networks used by the agent in the multigrid environments.""" + + # Check that multigrid environment is defined with discrete actions, 0-indexed + assert np.issubdtype(environment_spec.actions.dtype, np.integer), ( + "Expected multigrid environment to have discrete actions with int dtype" + f" but environment_spec.actions.dtype == {environment_spec.actions.dtype}" + ) + assert environment_spec.actions.minimum == 0, ( + "Expected multigrid environment to have 0-indexed action indices, but" + f" environment_spec.actions.minimum == {environment_spec.actions.minimum}" + ) + num_actions = environment_spec.actions.maximum + 1 + + def forward_fn(inputs): + processed_inputs = multigrid_obs_preproc(inputs) + trunk = hk.nets.MLP(hidden_layer_sizes, activation=jnp.tanh) + h = trunk(processed_inputs) + logits = hk.Linear(num_actions)(h) + values = hk.Linear(1)(h) + values = jnp.squeeze(values, axis=-1) + return (CategoricalParams(logits=logits), values) + + # Transform into pure functions. + forward_fn = hk.without_apply_rng(hk.transform(forward_fn)) + + dummy_obs = acme_jax_utils.zeros_like(environment_spec.observations) + dummy_obs = acme_jax_utils.add_batch_dim(dummy_obs) # Dummy 'sequence' dim. 
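The multigrid network factories above all use the same Haiku idiom: transform a forward function, then initialise it with a single batched dummy observation and expose init/apply. A self-contained sketch of that transform-and-init step with toy shapes (no multigrid preprocessing):

import haiku as hk
import jax
import jax.numpy as jnp


def _forward(x):
    # Toy Q-network standing in for the networks defined above.
    return hk.nets.MLP([50, 50, 4])(hk.Flatten()(x))


network_hk = hk.without_apply_rng(hk.transform(_forward))
dummy_obs = jnp.zeros((1, 5, 5, 3))  # A single batched dummy observation.
params = network_hk.init(jax.random.PRNGKey(0), dummy_obs)
q_values = network_hk.apply(params, dummy_obs)  # Shape (1, 4).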
+ network = networks_lib.FeedForwardNetwork( + lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply + ) + return make_categorical_ppo_networks(network) # pylint:disable=undefined-variable def make_categorical_ppo_networks( - network: networks_lib.FeedForwardNetwork) -> ppo.PPONetworks: - """Constructs a PPONetworks for Categorical Policy from FeedForwardNetwork. + network: networks_lib.FeedForwardNetwork, +) -> ppo.PPONetworks: + """Constructs a PPONetworks for Categorical Policy from FeedForwardNetwork. Args: network: a transformed Haiku network (or equivalent in other libraries) that @@ -177,33 +187,34 @@ def make_categorical_ppo_networks( A PPONetworks instance with pure functions wrapping the input network. """ - def log_prob(params: CategoricalParams, action): - return tfd.Categorical(logits=params.logits).log_prob(action) + def log_prob(params: CategoricalParams, action): + return tfd.Categorical(logits=params.logits).log_prob(action) - def entropy(params: CategoricalParams, key: networks_lib.PRNGKey): - del key - return tfd.Categorical(logits=params.logits).entropy() + def entropy(params: CategoricalParams, key: networks_lib.PRNGKey): + del key + return tfd.Categorical(logits=params.logits).entropy() - def sample(params: CategoricalParams, key: networks_lib.PRNGKey): - return tfd.Categorical(logits=params.logits).sample(seed=key) + def sample(params: CategoricalParams, key: networks_lib.PRNGKey): + return tfd.Categorical(logits=params.logits).sample(seed=key) - def sample_eval(params: CategoricalParams, key: networks_lib.PRNGKey): - del key - return tfd.Categorical(logits=params.logits).mode() + def sample_eval(params: CategoricalParams, key: networks_lib.PRNGKey): + del key + return tfd.Categorical(logits=params.logits).mode() - return ppo.PPONetworks( - network=network, - log_prob=log_prob, - entropy=entropy, - sample=sample, - sample_eval=sample_eval) + return ppo.PPONetworks( + network=network, + log_prob=log_prob, + entropy=entropy, + sample=sample, + sample_eval=sample_eval, + ) def init_default_multigrid_network( - agent_type: str, - agent_spec: specs.EnvironmentSpec) -> ma_types.Networks: - """Returns default networks for multigrid environment.""" - if agent_type == factories.DefaultSupportedAgent.PPO: - return make_multigrid_ppo_networks(agent_spec) - else: - raise ValueError(f'Unsupported agent type: {agent_type}.') + agent_type: str, agent_spec: specs.EnvironmentSpec +) -> ma_types.Networks: + """Returns default networks for multigrid environment.""" + if agent_type == factories.DefaultSupportedAgent.PPO: + return make_multigrid_ppo_networks(agent_spec) + else: + raise ValueError(f"Unsupported agent type: {agent_type}.") diff --git a/examples/multiagent/multigrid/run_multigrid.py b/examples/multiagent/multigrid/run_multigrid.py index 0b5ef0a06d..a4a44b8756 100644 --- a/examples/multiagent/multigrid/run_multigrid.py +++ b/examples/multiagent/multigrid/run_multigrid.py @@ -16,96 +16,104 @@ """Multiagent multigrid training run example.""" from typing import Callable, Dict -from absl import flags +import dm_env +import helpers +import launchpad as lp +from absl import app, flags from acme import specs from acme.agents.jax.multiagent import decentralized -from absl import app -import helpers from acme.jax import experiments from acme.jax import types as jax_types from acme.multiagent import types as ma_types from acme.utils import lp_utils from acme.wrappers import multigrid_wrapper -import dm_env -import launchpad as lp FLAGS = flags.FLAGS _RUN_DISTRIBUTED = 
flags.DEFINE_bool( - 'run_distributed', True, 'Should an agent be executed in a distributed ' - 'way. If False, will run single-threaded.') -_NUM_STEPS = flags.DEFINE_integer('num_steps', 10000, - 'Number of env steps to run training for.') -_EVAL_EVERY = flags.DEFINE_integer('eval_every', 1000, - 'How often to run evaluation.') -_ENV_NAME = flags.DEFINE_string('env_name', 'MultiGrid-Empty-5x5-v0', - 'What environment to run.') -_BATCH_SIZE = flags.DEFINE_integer('batch_size', 64, 'Batch size.') -_SEED = flags.DEFINE_integer('seed', 0, 'Random seed.') + "run_distributed", + True, + "Should an agent be executed in a distributed " + "way. If False, will run single-threaded.", +) +_NUM_STEPS = flags.DEFINE_integer( + "num_steps", 10000, "Number of env steps to run training for." +) +_EVAL_EVERY = flags.DEFINE_integer("eval_every", 1000, "How often to run evaluation.") +_ENV_NAME = flags.DEFINE_string( + "env_name", "MultiGrid-Empty-5x5-v0", "What environment to run." +) +_BATCH_SIZE = flags.DEFINE_integer("batch_size", 64, "Batch size.") +_SEED = flags.DEFINE_integer("seed", 0, "Random seed.") def _make_environment_factory(env_name: str) -> jax_types.EnvironmentFactory: + def environment_factory(seed: int) -> dm_env.Environment: + del seed + return multigrid_wrapper.make_multigrid_environment(env_name) - def environment_factory(seed: int) -> dm_env.Environment: - del seed - return multigrid_wrapper.make_multigrid_environment(env_name) - - return environment_factory + return environment_factory def _make_network_factory( agent_types: Dict[ma_types.AgentID, ma_types.GenericAgent] ) -> Callable[[specs.EnvironmentSpec], ma_types.MultiAgentNetworks]: + def environment_factory( + environment_spec: specs.EnvironmentSpec, + ) -> ma_types.MultiAgentNetworks: + return decentralized.network_factory( + environment_spec, agent_types, helpers.init_default_multigrid_network + ) - def environment_factory( - environment_spec: specs.EnvironmentSpec) -> ma_types.MultiAgentNetworks: - return decentralized.network_factory(environment_spec, agent_types, - helpers.init_default_multigrid_network) - - return environment_factory + return environment_factory def build_experiment_config() -> experiments.ExperimentConfig[ - ma_types.MultiAgentNetworks, ma_types.MultiAgentPolicyNetworks, - ma_types.MultiAgentSample]: - """Returns a config for multigrid experiments.""" - - environment_factory = _make_environment_factory(_ENV_NAME.value) - environment = environment_factory(_SEED.value) - agent_types = { - str(i): decentralized.DefaultSupportedAgent.PPO - for i in range(environment.num_agents) # pytype: disable=attribute-error - } - # Example of how to set custom sub-agent configurations. 
- ppo_configs = {'unroll_length': 16, 'num_minibatches': 32, 'num_epochs': 10} - config_overrides = { - k: ppo_configs for k, v in agent_types.items() if v == 'ppo' - } - - configs = decentralized.default_config_factory(agent_types, _BATCH_SIZE.value, - config_overrides) - - builder = decentralized.DecentralizedMultiAgentBuilder( - agent_types=agent_types, agent_configs=configs) - - return experiments.ExperimentConfig( - builder=builder, - environment_factory=environment_factory, - network_factory=_make_network_factory(agent_types=agent_types), - seed=_SEED.value, - max_num_actor_steps=_NUM_STEPS.value) + ma_types.MultiAgentNetworks, + ma_types.MultiAgentPolicyNetworks, + ma_types.MultiAgentSample, +]: + """Returns a config for multigrid experiments.""" + + environment_factory = _make_environment_factory(_ENV_NAME.value) + environment = environment_factory(_SEED.value) + agent_types = { + str(i): decentralized.DefaultSupportedAgent.PPO + for i in range(environment.num_agents) # pytype: disable=attribute-error + } + # Example of how to set custom sub-agent configurations. + ppo_configs = {"unroll_length": 16, "num_minibatches": 32, "num_epochs": 10} + config_overrides = {k: ppo_configs for k, v in agent_types.items() if v == "ppo"} + + configs = decentralized.default_config_factory( + agent_types, _BATCH_SIZE.value, config_overrides + ) + + builder = decentralized.DecentralizedMultiAgentBuilder( + agent_types=agent_types, agent_configs=configs + ) + + return experiments.ExperimentConfig( + builder=builder, + environment_factory=environment_factory, + network_factory=_make_network_factory(agent_types=agent_types), + seed=_SEED.value, + max_num_actor_steps=_NUM_STEPS.value, + ) def main(_): - config = build_experiment_config() - if _RUN_DISTRIBUTED.value: - program = experiments.make_distributed_experiment( - experiment=config, num_actors=4) - lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) - else: - experiments.run_experiment( - experiment=config, eval_every=_EVAL_EVERY.value, num_eval_episodes=5) - - -if __name__ == '__main__': - app.run(main) + config = build_experiment_config() + if _RUN_DISTRIBUTED.value: + program = experiments.make_distributed_experiment( + experiment=config, num_actors=4 + ) + lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program)) + else: + experiments.run_experiment( + experiment=config, eval_every=_EVAL_EVERY.value, num_eval_episodes=5 + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/offline/bc_utils.py b/examples/offline/bc_utils.py index c4237bcb14..734bfb704e 100644 --- a/examples/offline/bc_utils.py +++ b/examples/offline/bc_utils.py @@ -18,77 +18,76 @@ import operator from typing import Callable -from acme import core -from acme import environment_loop -from acme import specs -from acme import types -from acme.agents.jax import actor_core as actor_core_lib -from acme.agents.jax import actors -from acme.agents.jax import bc -from acme.agents.tf.dqfd import bsuite_demonstrations -from acme.jax import networks as networks_lib -from acme.jax import types as jax_types -from acme.jax import utils -from acme.jax import variable_utils -from acme.jax.deprecated import offline_distributed_layout -from acme.utils import counting -from acme.utils import loggers -from acme.wrappers import single_precision import bsuite import dm_env import haiku as hk import jax import jax.numpy as jnp -from jax.scipy import special import rlax import tensorflow as tf import tree +from jax.scipy import special + +from 
acme import core, environment_loop, specs, types +from acme.agents.jax import actor_core as actor_core_lib +from acme.agents.jax import actors, bc +from acme.agents.tf.dqfd import bsuite_demonstrations +from acme.jax import networks as networks_lib +from acme.jax import types as jax_types +from acme.jax import utils, variable_utils +from acme.jax.deprecated import offline_distributed_layout +from acme.utils import counting, loggers +from acme.wrappers import single_precision def make_network(spec: specs.EnvironmentSpec) -> bc.BCNetworks: - """Creates networks used by the agent.""" - num_actions = spec.actions.num_values + """Creates networks used by the agent.""" + num_actions = spec.actions.num_values - def actor_fn(obs, is_training=True, key=None): - # is_training and key allows to utilize train/test dependant modules - # like dropout. - del is_training - del key - mlp = hk.Sequential( - [hk.Flatten(), - hk.nets.MLP([64, 64, num_actions])]) - return mlp(obs) + def actor_fn(obs, is_training=True, key=None): + # is_training and key allows to utilize train/test dependant modules + # like dropout. + del is_training + del key + mlp = hk.Sequential([hk.Flatten(), hk.nets.MLP([64, 64, num_actions])]) + return mlp(obs) - policy = hk.without_apply_rng(hk.transform(actor_fn)) + policy = hk.without_apply_rng(hk.transform(actor_fn)) - # Create dummy observations to create network parameters. - dummy_obs = utils.zeros_like(spec.observations) - dummy_obs = utils.add_batch_dim(dummy_obs) + # Create dummy observations to create network parameters. + dummy_obs = utils.zeros_like(spec.observations) + dummy_obs = utils.add_batch_dim(dummy_obs) - policy_network = bc.BCPolicyNetwork(lambda key: policy.init(key, dummy_obs), - policy.apply) + policy_network = bc.BCPolicyNetwork( + lambda key: policy.init(key, dummy_obs), policy.apply + ) - def sample_fn(logits: networks_lib.NetworkOutput, - key: jax_types.PRNGKey) -> networks_lib.Action: - return rlax.epsilon_greedy(epsilon=0.0).sample(key, logits) + def sample_fn( + logits: networks_lib.NetworkOutput, key: jax_types.PRNGKey + ) -> networks_lib.Action: + return rlax.epsilon_greedy(epsilon=0.0).sample(key, logits) - def log_prob(logits: networks_lib.Params, - actions: networks_lib.Action) -> networks_lib.LogProb: - logits_actions = jnp.sum( - jax.nn.one_hot(actions, logits.shape[-1]) * logits, axis=-1) - logits_actions = logits_actions - special.logsumexp(logits, axis=-1) - return logits_actions + def log_prob( + logits: networks_lib.Params, actions: networks_lib.Action + ) -> networks_lib.LogProb: + logits_actions = jnp.sum( + jax.nn.one_hot(actions, logits.shape[-1]) * logits, axis=-1 + ) + logits_actions = logits_actions - special.logsumexp(logits, axis=-1) + return logits_actions - return bc.BCNetworks(policy_network, sample_fn, log_prob) + return bc.BCNetworks(policy_network, sample_fn, log_prob) def _n_step_transition_from_episode( observations: types.NestedTensor, actions: tf.Tensor, rewards: tf.Tensor, - discounts: tf.Tensor, n_step: int, - additional_discount: float) -> types.Transition: - """Produce Reverb-like N-step transition from a full episode. + discounts: tf.Tensor, + n_step: int, + additional_discount: float, +) -> types.Transition: + """Produce Reverb-like N-step transition from a full episode. Observations, actions, rewards and discounts have the same length. This function will ignore the first reward and discount and the last action. @@ -105,62 +104,61 @@ def _n_step_transition_from_episode( A types.Transition. 
""" - max_index = tf.shape(rewards)[0] - 1 - first = tf.random.uniform( - shape=(), minval=0, maxval=max_index - 1, dtype=tf.int32) - last = tf.minimum(first + n_step, max_index) + max_index = tf.shape(rewards)[0] - 1 + first = tf.random.uniform(shape=(), minval=0, maxval=max_index - 1, dtype=tf.int32) + last = tf.minimum(first + n_step, max_index) - o_t = tree.map_structure(operator.itemgetter(first), observations) - a_t = tree.map_structure(operator.itemgetter(first), actions) - o_tp1 = tree.map_structure(operator.itemgetter(last), observations) + o_t = tree.map_structure(operator.itemgetter(first), observations) + a_t = tree.map_structure(operator.itemgetter(first), actions) + o_tp1 = tree.map_structure(operator.itemgetter(last), observations) - # 0, 1, ..., n-1. - discount_range = tf.cast(tf.range(last - first), tf.float32) - # 1, g, ..., g^{n-1}. - additional_discounts = tf.pow(additional_discount, discount_range) - # 1, d_t, d_t * d_{t+1}, ..., d_t * ... * d_{t+n-2}. - discounts = tf.concat([[1.], tf.math.cumprod(discounts[first:last - 1])], 0) - # 1, g * d_t, ..., g^{n-1} * d_t * ... * d_{t+n-2}. - discounts *= additional_discounts - # r_t + g * d_t * r_{t+1} + ... + g^{n-1} * d_t * ... * d_{t+n-2} * r_{t+n-1} - # We have to shift rewards by one so last=max_index corresponds to transitions - # that include the last reward. - r_t = tf.reduce_sum(rewards[first + 1:last + 1] * discounts) + # 0, 1, ..., n-1. + discount_range = tf.cast(tf.range(last - first), tf.float32) + # 1, g, ..., g^{n-1}. + additional_discounts = tf.pow(additional_discount, discount_range) + # 1, d_t, d_t * d_{t+1}, ..., d_t * ... * d_{t+n-2}. + discounts = tf.concat([[1.0], tf.math.cumprod(discounts[first : last - 1])], 0) + # 1, g * d_t, ..., g^{n-1} * d_t * ... * d_{t+n-2}. + discounts *= additional_discounts + #  r_t + g * d_t * r_{t+1} + ... + g^{n-1} * d_t * ... * d_{t+n-2} * r_{t+n-1} + # We have to shift rewards by one so last=max_index corresponds to transitions + # that include the last reward. + r_t = tf.reduce_sum(rewards[first + 1 : last + 1] * discounts) - # g^{n-1} * d_{t} * ... * d_{t+n-1}. - d_t = discounts[-1] + # g^{n-1} * d_{t} * ... * d_{t+n-1}. + d_t = discounts[-1] - return types.Transition(o_t, a_t, r_t, d_t, o_tp1) + return types.Transition(o_t, a_t, r_t, d_t, o_tp1) def make_environment(training: bool = True): - del training - env = bsuite.load(experiment_name='deep_sea', kwargs={'size': 10}) - return single_precision.SinglePrecisionWrapper(env) + del training + env = bsuite.load(experiment_name="deep_sea", kwargs={"size": 10}) + return single_precision.SinglePrecisionWrapper(env) -def make_demonstrations(env: dm_env.Environment, - batch_size: int) -> tf.data.Dataset: - """Prepare the dataset of demonstrations.""" - batch_dataset = bsuite_demonstrations.make_dataset(env, stochastic=False) - # Combine with demonstration dataset. - transition = functools.partial( - _n_step_transition_from_episode, n_step=1, additional_discount=1.) +def make_demonstrations(env: dm_env.Environment, batch_size: int) -> tf.data.Dataset: + """Prepare the dataset of demonstrations.""" + batch_dataset = bsuite_demonstrations.make_dataset(env, stochastic=False) + # Combine with demonstration dataset. + transition = functools.partial( + _n_step_transition_from_episode, n_step=1, additional_discount=1.0 + ) - dataset = batch_dataset.map(transition) + dataset = batch_dataset.map(transition) - # Batch and prefetch. 
- dataset = dataset.batch(batch_size, drop_remainder=True) - dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) + # Batch and prefetch. + dataset = dataset.batch(batch_size, drop_remainder=True) + dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) - return dataset + return dataset def make_actor_evaluator( environment_factory: Callable[[bool], dm_env.Environment], evaluator_network: actor_core_lib.FeedForwardPolicy, ) -> offline_distributed_layout.EvaluatorFactory: - """Makes an evaluator that runs the agent on the environment. + """Makes an evaluator that runs the agent on the environment. Args: environment_factory: Function that creates a dm_env. @@ -170,37 +168,35 @@ def make_actor_evaluator( actor_evaluator: Function that returns a Worker that will be executed by launchpad. """ - def actor_evaluator( - random_key: networks_lib.PRNGKey, - variable_source: core.VariableSource, - counter: counting.Counter, - ): - """The evaluation process.""" - # Create the actor loading the weights from variable source. - actor_core = actor_core_lib.batched_feed_forward_to_actor_core( - evaluator_network) - # Inference happens on CPU, so it's better to move variables there too. - variable_client = variable_utils.VariableClient(variable_source, 'policy', - device='cpu') - actor = actors.GenericActor( - actor_core, random_key, variable_client, backend='cpu') - - # Logger. - logger = loggers.make_default_logger( - 'evaluator', steps_key='evaluator_steps') - - # Create environment and evaluator networks - environment = environment_factory(False) - - # Create logger and counter. - counter = counting.Counter(counter, 'evaluator') - - # Create the run loop and return it. - return environment_loop.EnvironmentLoop( - environment, - actor, - counter, - logger, - ) - return actor_evaluator + def actor_evaluator( + random_key: networks_lib.PRNGKey, + variable_source: core.VariableSource, + counter: counting.Counter, + ): + """The evaluation process.""" + # Create the actor loading the weights from variable source. + actor_core = actor_core_lib.batched_feed_forward_to_actor_core( + evaluator_network + ) + # Inference happens on CPU, so it's better to move variables there too. + variable_client = variable_utils.VariableClient( + variable_source, "policy", device="cpu" + ) + actor = actors.GenericActor( + actor_core, random_key, variable_client, backend="cpu" + ) + + # Logger. + logger = loggers.make_default_logger("evaluator", steps_key="evaluator_steps") + + # Create environment and evaluator networks + environment = environment_factory(False) + + # Create logger and counter. + counter = counting.Counter(counter, "evaluator") + + # Create the run loop and return it. 
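The discount bookkeeping in _n_step_transition_from_episode above is easier to follow with concrete numbers. The NumPy mirror below reproduces the same arithmetic as the TensorFlow code, with made-up rewards and discounts, n_step = 3 and additional_discount g = 0.9:

import numpy as np

g = 0.9                                       # additional_discount
rewards = np.array([0.0, 1.0, 0.0, 2.0])      # index 0 is ignored by the sum
discounts = np.array([1.0, 1.0, 1.0, 0.0])
first, last = 0, 3                            # last = min(first + n_step, max_index)

additional = g ** np.arange(last - first)                             # 1, g, g^2
cum = np.concatenate([[1.0], np.cumprod(discounts[first:last - 1])])  # 1, d_t, d_t*d_{t+1}
weights = additional * cum
r_t = np.sum(rewards[first + 1:last + 1] * weights)  # 1.0 + 0.9*0.0 + 0.81*2.0 = 2.62
d_t = weights[-1]                                     # 0.81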
+ return environment_loop.EnvironmentLoop(environment, actor, counter, logger,) + + return actor_evaluator diff --git a/examples/offline/run_bc.py b/examples/offline/run_bc.py index 9d45180144..7d77d0fdae 100644 --- a/examples/offline/run_bc.py +++ b/examples/offline/run_bc.py @@ -17,53 +17,54 @@ import functools import operator -from absl import app -from absl import flags -import acme -from acme import specs -from acme import types -from acme.agents.tf import actors -from acme.agents.tf.bc import learning -from acme.agents.tf.dqfd import bsuite_demonstrations -from acme.tf import utils as tf2_utils -from acme.utils import counting -from acme.utils import loggers -from acme.wrappers import single_precision import bsuite import reverb import sonnet as snt import tensorflow as tf import tree import trfl +from absl import app, flags + +import acme +from acme import specs, types +from acme.agents.tf import actors +from acme.agents.tf.bc import learning +from acme.agents.tf.dqfd import bsuite_demonstrations +from acme.tf import utils as tf2_utils +from acme.utils import counting, loggers +from acme.wrappers import single_precision # Bsuite flags -flags.DEFINE_string('bsuite_id', 'deep_sea/0', 'Bsuite id.') -flags.DEFINE_string('results_dir', '/tmp/bsuite', 'CSV results directory.') -flags.DEFINE_boolean('overwrite', False, 'Whether to overwrite csv results.') +flags.DEFINE_string("bsuite_id", "deep_sea/0", "Bsuite id.") +flags.DEFINE_string("results_dir", "/tmp/bsuite", "CSV results directory.") +flags.DEFINE_boolean("overwrite", False, "Whether to overwrite csv results.") # Agent flags -flags.DEFINE_float('learning_rate', 2e-4, 'Learning rate.') -flags.DEFINE_integer('batch_size', 16, 'Batch size.') -flags.DEFINE_float('epsilon', 0., 'Epsilon for the epsilon greedy in the env.') -flags.DEFINE_integer('evaluate_every', 100, 'Evaluation period.') -flags.DEFINE_integer('evaluation_episodes', 10, 'Evaluation episodes.') +flags.DEFINE_float("learning_rate", 2e-4, "Learning rate.") +flags.DEFINE_integer("batch_size", 16, "Batch size.") +flags.DEFINE_float("epsilon", 0.0, "Epsilon for the epsilon greedy in the env.") +flags.DEFINE_integer("evaluate_every", 100, "Evaluation period.") +flags.DEFINE_integer("evaluation_episodes", 10, "Evaluation episodes.") FLAGS = flags.FLAGS def make_policy_network(action_spec: specs.DiscreteArray) -> snt.Module: - return snt.Sequential([ - snt.Flatten(), - snt.nets.MLP([64, 64, action_spec.num_values]), - ]) + return snt.Sequential( + [snt.Flatten(), snt.nets.MLP([64, 64, action_spec.num_values]),] + ) # TODO(b/152733199): Move this function to acme utils. -def _n_step_transition_from_episode(observations: types.NestedTensor, - actions: tf.Tensor, rewards: tf.Tensor, - discounts: tf.Tensor, n_step: int, - additional_discount: float): - """Produce Reverb-like N-step transition from a full episode. +def _n_step_transition_from_episode( + observations: types.NestedTensor, + actions: tf.Tensor, + rewards: tf.Tensor, + discounts: tf.Tensor, + n_step: int, + additional_discount: float, +): + """Produce Reverb-like N-step transition from a full episode. Observations, actions, rewards and discounts have the same length. This function will ignore the first reward and discount and the last action. @@ -80,116 +81,121 @@ def _n_step_transition_from_episode(observations: types.NestedTensor, (o_t, a_t, r_t, d_t, o_tp1) tuple. 
""" - max_index = tf.shape(rewards)[0] - 1 - first = tf.random.uniform( - shape=(), minval=0, maxval=max_index - 1, dtype=tf.int32) - last = tf.minimum(first + n_step, max_index) - - o_t = tree.map_structure(operator.itemgetter(first), observations) - a_t = tree.map_structure(operator.itemgetter(first), actions) - o_tp1 = tree.map_structure(operator.itemgetter(last), observations) - - # 0, 1, ..., n-1. - discount_range = tf.cast(tf.range(last - first), tf.float32) - # 1, g, ..., g^{n-1}. - additional_discounts = tf.pow(additional_discount, discount_range) - # 1, d_t, d_t * d_{t+1}, ..., d_t * ... * d_{t+n-2}. - discounts = tf.concat([[1.], tf.math.cumprod(discounts[first:last - 1])], 0) - # 1, g * d_t, ..., g^{n-1} * d_t * ... * d_{t+n-2}. - discounts *= additional_discounts - # r_t + g * d_t * r_{t+1} + ... + g^{n-1} * d_t * ... * d_{t+n-2} * r_{t+n-1} - # We have to shift rewards by one so last=max_index corresponds to transitions - # that include the last reward. - r_t = tf.reduce_sum(rewards[first + 1:last + 1] * discounts) - - # g^{n-1} * d_{t} * ... * d_{t+n-1}. - d_t = discounts[-1] - - # Reverb requires every sample to be given a key and priority. - # In the supervised learning case for BC, neither of those will be used. - # We set the key to `0` and the priorities probabilities to `1`, but that - # should not matter much. - key = tf.constant(0, tf.uint64) - probability = tf.constant(1.0, tf.float64) - table_size = tf.constant(1, tf.int64) - priority = tf.constant(1.0, tf.float64) - times_sampled = tf.constant(1, tf.int32) - info = reverb.SampleInfo( - key=key, - probability=probability, - table_size=table_size, - priority=priority, - times_sampled=times_sampled, - ) - - return reverb.ReplaySample(info=info, data=(o_t, a_t, r_t, d_t, o_tp1)) + max_index = tf.shape(rewards)[0] - 1 + first = tf.random.uniform(shape=(), minval=0, maxval=max_index - 1, dtype=tf.int32) + last = tf.minimum(first + n_step, max_index) + + o_t = tree.map_structure(operator.itemgetter(first), observations) + a_t = tree.map_structure(operator.itemgetter(first), actions) + o_tp1 = tree.map_structure(operator.itemgetter(last), observations) + + # 0, 1, ..., n-1. + discount_range = tf.cast(tf.range(last - first), tf.float32) + # 1, g, ..., g^{n-1}. + additional_discounts = tf.pow(additional_discount, discount_range) + # 1, d_t, d_t * d_{t+1}, ..., d_t * ... * d_{t+n-2}. + discounts = tf.concat([[1.0], tf.math.cumprod(discounts[first : last - 1])], 0) + # 1, g * d_t, ..., g^{n-1} * d_t * ... * d_{t+n-2}. + discounts *= additional_discounts + #  r_t + g * d_t * r_{t+1} + ... + g^{n-1} * d_t * ... * d_{t+n-2} * r_{t+n-1} + # We have to shift rewards by one so last=max_index corresponds to transitions + # that include the last reward. + r_t = tf.reduce_sum(rewards[first + 1 : last + 1] * discounts) + + # g^{n-1} * d_{t} * ... * d_{t+n-1}. + d_t = discounts[-1] + + # Reverb requires every sample to be given a key and priority. + # In the supervised learning case for BC, neither of those will be used. + # We set the key to `0` and the priorities probabilities to `1`, but that + # should not matter much. 
+ key = tf.constant(0, tf.uint64) + probability = tf.constant(1.0, tf.float64) + table_size = tf.constant(1, tf.int64) + priority = tf.constant(1.0, tf.float64) + times_sampled = tf.constant(1, tf.int32) + info = reverb.SampleInfo( + key=key, + probability=probability, + table_size=table_size, + priority=priority, + times_sampled=times_sampled, + ) + + return reverb.ReplaySample(info=info, data=(o_t, a_t, r_t, d_t, o_tp1)) def main(_): - # Create an environment and grab the spec. - raw_environment = bsuite.load_and_record_to_csv( - bsuite_id=FLAGS.bsuite_id, - results_dir=FLAGS.results_dir, - overwrite=FLAGS.overwrite, - ) - environment = single_precision.SinglePrecisionWrapper(raw_environment) - environment_spec = specs.make_environment_spec(environment) - - # Build demonstration dataset. - if hasattr(raw_environment, 'raw_env'): - raw_environment = raw_environment.raw_env - - batch_dataset = bsuite_demonstrations.make_dataset(raw_environment, - stochastic=False) - # Combine with demonstration dataset. - transition = functools.partial( - _n_step_transition_from_episode, n_step=1, additional_discount=1.) - - dataset = batch_dataset.map(transition) - - # Batch and prefetch. - dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True) - dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) - - # Create the networks to optimize. - policy_network = make_policy_network(environment_spec.actions) - - # If the agent is non-autoregressive use epsilon=0 which will be a greedy - # policy. - evaluator_network = snt.Sequential([ - policy_network, - lambda q: trfl.epsilon_greedy(q, epsilon=FLAGS.epsilon).sample(), - ]) - - # Ensure that we create the variables before proceeding (maybe not needed). - tf2_utils.create_variables(policy_network, [environment_spec.observations]) - - counter = counting.Counter() - learner_counter = counting.Counter(counter, prefix='learner') - - # Create the actor which defines how we take actions. - evaluation_network = actors.FeedForwardActor(evaluator_network) - - eval_loop = acme.EnvironmentLoop( - environment=environment, - actor=evaluation_network, - counter=counter, - logger=loggers.TerminalLogger('evaluation', time_delta=1.)) - - # The learner updates the parameters (and initializes them). - learner = learning.BCLearner( - network=policy_network, - learning_rate=FLAGS.learning_rate, - dataset=dataset, - counter=learner_counter) - - # Run the environment loop. - while True: - for _ in range(FLAGS.evaluate_every): - learner.step() - learner_counter.increment(learner_steps=FLAGS.evaluate_every) - eval_loop.run(FLAGS.evaluation_episodes) - - -if __name__ == '__main__': - app.run(main) + # Create an environment and grab the spec. + raw_environment = bsuite.load_and_record_to_csv( + bsuite_id=FLAGS.bsuite_id, + results_dir=FLAGS.results_dir, + overwrite=FLAGS.overwrite, + ) + environment = single_precision.SinglePrecisionWrapper(raw_environment) + environment_spec = specs.make_environment_spec(environment) + + # Build demonstration dataset. + if hasattr(raw_environment, "raw_env"): + raw_environment = raw_environment.raw_env + + batch_dataset = bsuite_demonstrations.make_dataset( + raw_environment, stochastic=False + ) + # Combine with demonstration dataset. + transition = functools.partial( + _n_step_transition_from_episode, n_step=1, additional_discount=1.0 + ) + + dataset = batch_dataset.map(transition) + + # Batch and prefetch. 
+ dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True) + dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) + + # Create the networks to optimize. + policy_network = make_policy_network(environment_spec.actions) + + # If the agent is non-autoregressive use epsilon=0 which will be a greedy + # policy. + evaluator_network = snt.Sequential( + [ + policy_network, + lambda q: trfl.epsilon_greedy(q, epsilon=FLAGS.epsilon).sample(), + ] + ) + + # Ensure that we create the variables before proceeding (maybe not needed). + tf2_utils.create_variables(policy_network, [environment_spec.observations]) + + counter = counting.Counter() + learner_counter = counting.Counter(counter, prefix="learner") + + # Create the actor which defines how we take actions. + evaluation_network = actors.FeedForwardActor(evaluator_network) + + eval_loop = acme.EnvironmentLoop( + environment=environment, + actor=evaluation_network, + counter=counter, + logger=loggers.TerminalLogger("evaluation", time_delta=1.0), + ) + + # The learner updates the parameters (and initializes them). + learner = learning.BCLearner( + network=policy_network, + learning_rate=FLAGS.learning_rate, + dataset=dataset, + counter=learner_counter, + ) + + # Run the environment loop. + while True: + for _ in range(FLAGS.evaluate_every): + learner.step() + learner_counter.increment(learner_steps=FLAGS.evaluate_every) + eval_loop.run(FLAGS.evaluation_episodes) + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/offline/run_bc_jax.py b/examples/offline/run_bc_jax.py index dc21b23a0e..1163b6eee6 100644 --- a/examples/offline/run_bc_jax.py +++ b/examples/offline/run_bc_jax.py @@ -14,85 +14,81 @@ """An example BC running on BSuite.""" -from absl import app -from absl import flags +import haiku as hk +import jax +import optax +import rlax +from absl import app, flags + import acme from acme import specs from acme.agents.jax import actor_core as actor_core_lib -from acme.agents.jax import actors -from acme.agents.jax import bc +from acme.agents.jax import actors, bc from acme.examples.offline import bc_utils -from acme.jax import utils -from acme.jax import variable_utils +from acme.jax import utils, variable_utils from acme.utils import loggers -import haiku as hk -import jax -import optax -import rlax # Agent flags -flags.DEFINE_float('learning_rate', 1e-3, 'Learning rate.') -flags.DEFINE_integer('batch_size', 64, 'Batch size.') -flags.DEFINE_float('evaluation_epsilon', 0., - 'Epsilon for the epsilon greedy in the evaluation agent.') -flags.DEFINE_integer('evaluate_every', 20, 'Evaluation period.') -flags.DEFINE_integer('evaluation_episodes', 10, 'Evaluation episodes.') -flags.DEFINE_integer('seed', 0, 'Random seed for learner and evaluator.') +flags.DEFINE_float("learning_rate", 1e-3, "Learning rate.") +flags.DEFINE_integer("batch_size", 64, "Batch size.") +flags.DEFINE_float( + "evaluation_epsilon", 0.0, "Epsilon for the epsilon greedy in the evaluation agent." +) +flags.DEFINE_integer("evaluate_every", 20, "Evaluation period.") +flags.DEFINE_integer("evaluation_episodes", 10, "Evaluation episodes.") +flags.DEFINE_integer("seed", 0, "Random seed for learner and evaluator.") FLAGS = flags.FLAGS def main(_): - # Create an environment and grab the spec. - environment = bc_utils.make_environment() - environment_spec = specs.make_environment_spec(environment) - - # Unwrap the environment to get the demonstrations. 
- dataset = bc_utils.make_demonstrations(environment.environment, - FLAGS.batch_size) - dataset = dataset.as_numpy_iterator() - - # Create the networks to optimize. - bc_networks = bc_utils.make_network(environment_spec) - - key = jax.random.PRNGKey(FLAGS.seed) - key, key1 = jax.random.split(key, 2) - - loss_fn = bc.logp() - - learner = bc.BCLearner( - networks=bc_networks, - random_key=key1, - loss_fn=loss_fn, - optimizer=optax.adam(FLAGS.learning_rate), - prefetching_iterator=utils.sharded_prefetch(dataset), - num_sgd_steps_per_step=1) - - def evaluator_network( - params: hk.Params, key: jax.Array, observation: jax.Array - ) -> jax.Array: - dist_params = bc_networks.policy_network.apply(params, observation) - return rlax.epsilon_greedy(FLAGS.evaluation_epsilon).sample( - key, dist_params) - - actor_core = actor_core_lib.batched_feed_forward_to_actor_core( - evaluator_network) - variable_client = variable_utils.VariableClient( - learner, 'policy', device='cpu') - evaluator = actors.GenericActor( - actor_core, key, variable_client, backend='cpu') - - eval_loop = acme.EnvironmentLoop( - environment=environment, - actor=evaluator, - logger=loggers.TerminalLogger('evaluation', time_delta=0.)) - - # Run the environment loop. - while True: - for _ in range(FLAGS.evaluate_every): - learner.step() - eval_loop.run(FLAGS.evaluation_episodes) - - -if __name__ == '__main__': - app.run(main) + # Create an environment and grab the spec. + environment = bc_utils.make_environment() + environment_spec = specs.make_environment_spec(environment) + + # Unwrap the environment to get the demonstrations. + dataset = bc_utils.make_demonstrations(environment.environment, FLAGS.batch_size) + dataset = dataset.as_numpy_iterator() + + # Create the networks to optimize. + bc_networks = bc_utils.make_network(environment_spec) + + key = jax.random.PRNGKey(FLAGS.seed) + key, key1 = jax.random.split(key, 2) + + loss_fn = bc.logp() + + learner = bc.BCLearner( + networks=bc_networks, + random_key=key1, + loss_fn=loss_fn, + optimizer=optax.adam(FLAGS.learning_rate), + prefetching_iterator=utils.sharded_prefetch(dataset), + num_sgd_steps_per_step=1, + ) + + def evaluator_network( + params: hk.Params, key: jax.Array, observation: jax.Array + ) -> jax.Array: + dist_params = bc_networks.policy_network.apply(params, observation) + return rlax.epsilon_greedy(FLAGS.evaluation_epsilon).sample(key, dist_params) + + actor_core = actor_core_lib.batched_feed_forward_to_actor_core(evaluator_network) + variable_client = variable_utils.VariableClient(learner, "policy", device="cpu") + evaluator = actors.GenericActor(actor_core, key, variable_client, backend="cpu") + + eval_loop = acme.EnvironmentLoop( + environment=environment, + actor=evaluator, + logger=loggers.TerminalLogger("evaluation", time_delta=0.0), + ) + + # Run the environment loop. 
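bc_utils.py and run_bc_jax.py above both draw evaluation actions from rlax's epsilon-greedy distribution. In isolation the call looks like this (toy Q-values; epsilon values chosen arbitrarily):

import jax
import jax.numpy as jnp
import rlax

q_values = jnp.array([0.1, 0.5, 0.2])
key = jax.random.PRNGKey(0)
greedy = rlax.epsilon_greedy(epsilon=0.0).sample(key, q_values)    # Always argmax, i.e. action 1.
explore = rlax.epsilon_greedy(epsilon=0.05).sample(key, q_values)  # Argmax with 5% random actions.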
+ while True: + for _ in range(FLAGS.evaluate_every): + learner.step() + eval_loop.run(FLAGS.evaluation_episodes) + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/offline/run_bcq.py b/examples/offline/run_bcq.py index 3c2eefa883..7f8ea880fe 100644 --- a/examples/offline/run_bcq.py +++ b/examples/offline/run_bcq.py @@ -31,114 +31,117 @@ --num_shards=1 """ -from absl import app -from absl import flags -import acme -from acme import specs -from acme.agents.tf import actors -from acme.agents.tf import bcq -from acme.tf import networks -from acme.tf import utils as tf2_utils -from acme.utils import counting -from acme.utils import loggers import sonnet as snt import tensorflow as tf import trfl +from absl import app, flags +from deepmind_research.rl_unplugged import atari # type: ignore -from deepmind_research.rl_unplugged import atari # type: ignore +import acme +from acme import specs +from acme.agents.tf import actors, bcq +from acme.tf import networks +from acme.tf import utils as tf2_utils +from acme.utils import counting, loggers # Atari dataset flags -flags.DEFINE_string('dataset_path', None, 'Dataset path.') -flags.DEFINE_string('game', 'Pong', 'Dataset path.') -flags.DEFINE_integer('run', 1, 'Dataset path.') -flags.DEFINE_integer('num_shards', 100, 'Number of dataset shards.') -flags.DEFINE_integer('batch_size', 16, 'Batch size.') +flags.DEFINE_string("dataset_path", None, "Dataset path.") +flags.DEFINE_string("game", "Pong", "Dataset path.") +flags.DEFINE_integer("run", 1, "Dataset path.") +flags.DEFINE_integer("num_shards", 100, "Number of dataset shards.") +flags.DEFINE_integer("batch_size", 16, "Batch size.") # Agent flags -flags.DEFINE_float('bcq_threshold', 0.5, 'BCQ threshold.') -flags.DEFINE_float('learning_rate', 1e-4, 'Learning rate.') -flags.DEFINE_float('discount', 0.99, 'Discount.') -flags.DEFINE_float('importance_sampling_exponent', 0.2, - 'Importance sampling exponent.') -flags.DEFINE_integer('target_update_period', 2500, - ('Number of learner steps to perform before updating' - 'the target networks.')) +flags.DEFINE_float("bcq_threshold", 0.5, "BCQ threshold.") +flags.DEFINE_float("learning_rate", 1e-4, "Learning rate.") +flags.DEFINE_float("discount", 0.99, "Discount.") +flags.DEFINE_float("importance_sampling_exponent", 0.2, "Importance sampling exponent.") +flags.DEFINE_integer( + "target_update_period", + 2500, + ("Number of learner steps to perform before updating" "the target networks."), +) # Evaluation flags. -flags.DEFINE_float('epsilon', 0., 'Epsilon for the epsilon greedy in the env.') -flags.DEFINE_integer('evaluate_every', 100, 'Evaluation period.') -flags.DEFINE_integer('evaluation_episodes', 10, 'Evaluation episodes.') +flags.DEFINE_float("epsilon", 0.0, "Epsilon for the epsilon greedy in the env.") +flags.DEFINE_integer("evaluate_every", 100, "Evaluation period.") +flags.DEFINE_integer("evaluation_episodes", 10, "Evaluation episodes.") FLAGS = flags.FLAGS def make_network(action_spec: specs.DiscreteArray) -> snt.Module: - return snt.Sequential([ - lambda x: tf.image.convert_image_dtype(x, tf.float32), - networks.DQNAtariNetwork(action_spec.num_values) - ]) + return snt.Sequential( + [ + lambda x: tf.image.convert_image_dtype(x, tf.float32), + networks.DQNAtariNetwork(action_spec.num_values), + ] + ) def main(_): - # Create an environment and grab the spec. - environment = atari.environment(FLAGS.game) - environment_spec = specs.make_environment_spec(environment) - - # Create dataset. 
- dataset = atari.dataset(path=FLAGS.dataset_path, - game=FLAGS.game, - run=FLAGS.run, - num_shards=FLAGS.num_shards) - # Discard extra inputs - dataset = dataset.map(lambda x: x._replace(data=x.data[:5])) - - # Batch and prefetch. - dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True) - dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) - - # Build network. - g_network = make_network(environment_spec.actions) - q_network = make_network(environment_spec.actions) - network = networks.DiscreteFilteredQNetwork(g_network=g_network, - q_network=q_network, - threshold=FLAGS.bcq_threshold) - tf2_utils.create_variables(network, [environment_spec.observations]) - - evaluator_network = snt.Sequential([ - q_network, - lambda q: trfl.epsilon_greedy(q, epsilon=FLAGS.epsilon).sample(), - ]) - - # Counters. - counter = counting.Counter() - learner_counter = counting.Counter(counter, prefix='learner') - - # Create the actor which defines how we take actions. - evaluation_network = actors.FeedForwardActor(evaluator_network) - - eval_loop = acme.EnvironmentLoop( - environment=environment, - actor=evaluation_network, - counter=counter, - logger=loggers.TerminalLogger('evaluation', time_delta=1.)) - - # The learner updates the parameters (and initializes them). - learner = bcq.DiscreteBCQLearner( - network=network, - dataset=dataset, - learning_rate=FLAGS.learning_rate, - discount=FLAGS.discount, - importance_sampling_exponent=FLAGS.importance_sampling_exponent, - target_update_period=FLAGS.target_update_period, - counter=counter) - - # Run the environment loop. - while True: - for _ in range(FLAGS.evaluate_every): - learner.step() - learner_counter.increment(learner_steps=FLAGS.evaluate_every) - eval_loop.run(FLAGS.evaluation_episodes) - - -if __name__ == '__main__': - app.run(main) + # Create an environment and grab the spec. + environment = atari.environment(FLAGS.game) + environment_spec = specs.make_environment_spec(environment) + + # Create dataset. + dataset = atari.dataset( + path=FLAGS.dataset_path, + game=FLAGS.game, + run=FLAGS.run, + num_shards=FLAGS.num_shards, + ) + # Discard extra inputs + dataset = dataset.map(lambda x: x._replace(data=x.data[:5])) + + # Batch and prefetch. + dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True) + dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) + + # Build network. + g_network = make_network(environment_spec.actions) + q_network = make_network(environment_spec.actions) + network = networks.DiscreteFilteredQNetwork( + g_network=g_network, q_network=q_network, threshold=FLAGS.bcq_threshold + ) + tf2_utils.create_variables(network, [environment_spec.observations]) + + evaluator_network = snt.Sequential( + [q_network, lambda q: trfl.epsilon_greedy(q, epsilon=FLAGS.epsilon).sample(),] + ) + + # Counters. + counter = counting.Counter() + learner_counter = counting.Counter(counter, prefix="learner") + + # Create the actor which defines how we take actions. + evaluation_network = actors.FeedForwardActor(evaluator_network) + + eval_loop = acme.EnvironmentLoop( + environment=environment, + actor=evaluation_network, + counter=counter, + logger=loggers.TerminalLogger("evaluation", time_delta=1.0), + ) + + # The learner updates the parameters (and initializes them). 
+ learner = bcq.DiscreteBCQLearner( + network=network, + dataset=dataset, + learning_rate=FLAGS.learning_rate, + discount=FLAGS.discount, + importance_sampling_exponent=FLAGS.importance_sampling_exponent, + target_update_period=FLAGS.target_update_period, + counter=counter, + ) + + # Run the environment loop. + while True: + for _ in range(FLAGS.evaluate_every): + learner.step() + learner_counter.increment(learner_steps=FLAGS.evaluate_every) + eval_loop.run(FLAGS.evaluation_episodes) + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/offline/run_cql_jax.py b/examples/offline/run_cql_jax.py index 15482933cf..6200900ddc 100644 --- a/examples/offline/run_cql_jax.py +++ b/examples/offline/run_cql_jax.py @@ -14,101 +14,109 @@ """An example CQL running on locomotion datasets (mujoco) from D4rl.""" -from absl import app -from absl import flags +import haiku as hk +import jax +import optax +from absl import app, flags + import acme from acme import specs from acme.agents.jax import actor_core as actor_core_lib -from acme.agents.jax import actors -from acme.agents.jax import cql +from acme.agents.jax import actors, cql from acme.datasets import tfds from acme.examples.offline import helpers as gym_helpers from acme.jax import variable_utils from acme.utils import loggers -import haiku as hk -import jax -import optax # Agent flags -flags.DEFINE_integer('batch_size', 64, 'Batch size.') -flags.DEFINE_integer('evaluate_every', 20, 'Evaluation period.') -flags.DEFINE_integer('evaluation_episodes', 10, 'Evaluation episodes.') +flags.DEFINE_integer("batch_size", 64, "Batch size.") +flags.DEFINE_integer("evaluate_every", 20, "Evaluation period.") +flags.DEFINE_integer("evaluation_episodes", 10, "Evaluation episodes.") flags.DEFINE_integer( - 'num_demonstrations', 10, - 'Number of demonstration episodes to load from the dataset. If None, loads the full dataset.' + "num_demonstrations", + 10, + "Number of demonstration episodes to load from the dataset. If None, loads the full dataset.", ) -flags.DEFINE_integer('seed', 0, 'Random seed for learner and evaluator.') +flags.DEFINE_integer("seed", 0, "Random seed for learner and evaluator.") # CQL specific flags. -flags.DEFINE_float('policy_learning_rate', 3e-5, 'Policy learning rate.') -flags.DEFINE_float('critic_learning_rate', 3e-4, 'Critic learning rate.') -flags.DEFINE_float('fixed_cql_coefficient', None, - 'Fixed CQL coefficient. If None, an adaptive one is used.') -flags.DEFINE_float('cql_lagrange_threshold', 10., - 'Lagrange threshold for the adaptive CQL coefficient.') +flags.DEFINE_float("policy_learning_rate", 3e-5, "Policy learning rate.") +flags.DEFINE_float("critic_learning_rate", 3e-4, "Critic learning rate.") +flags.DEFINE_float( + "fixed_cql_coefficient", + None, + "Fixed CQL coefficient. If None, an adaptive one is used.", +) +flags.DEFINE_float( + "cql_lagrange_threshold", + 10.0, + "Lagrange threshold for the adaptive CQL coefficient.", +) # Environment flags. -flags.DEFINE_string('env_name', 'HalfCheetah-v2', - 'Gym mujoco environment name.') +flags.DEFINE_string("env_name", "HalfCheetah-v2", "Gym mujoco environment name.") flags.DEFINE_string( - 'dataset_name', 'd4rl_mujoco_halfcheetah/v2-medium', - 'D4rl dataset name. Can be any locomotion dataset from ' - 'https://www.tensorflow.org/datasets/catalog/overview#d4rl.') + "dataset_name", + "d4rl_mujoco_halfcheetah/v2-medium", + "D4rl dataset name. 
Can be any locomotion dataset from " + "https://www.tensorflow.org/datasets/catalog/overview#d4rl.", +) FLAGS = flags.FLAGS def main(_): - key = jax.random.PRNGKey(FLAGS.seed) - key_demonstrations, key_learner = jax.random.split(key, 2) - - # Create an environment and grab the spec. - environment = gym_helpers.make_environment(task=FLAGS.env_name) - environment_spec = specs.make_environment_spec(environment) - - # Get a demonstrations dataset. - transitions_iterator = tfds.get_tfds_dataset(FLAGS.dataset_name, - FLAGS.num_demonstrations) - demonstrations = tfds.JaxInMemoryRandomSampleIterator( - transitions_iterator, key=key_demonstrations, batch_size=FLAGS.batch_size) - - # Create the networks to optimize. - networks = cql.make_networks(environment_spec) - - # Create the learner. - learner = cql.CQLLearner( - batch_size=FLAGS.batch_size, - networks=networks, - random_key=key_learner, - policy_optimizer=optax.adam(FLAGS.policy_learning_rate), - critic_optimizer=optax.adam(FLAGS.critic_learning_rate), - fixed_cql_coefficient=FLAGS.fixed_cql_coefficient, - cql_lagrange_threshold=FLAGS.cql_lagrange_threshold, - demonstrations=demonstrations, - num_sgd_steps_per_step=1) - - def evaluator_network( - params: hk.Params, key: jax.Array, observation: jax.Array - ) -> jax.Array: - dist_params = networks.policy_network.apply(params, observation) - return networks.sample_eval(dist_params, key) - - actor_core = actor_core_lib.batched_feed_forward_to_actor_core( - evaluator_network) - variable_client = variable_utils.VariableClient( - learner, 'policy', device='cpu') - evaluator = actors.GenericActor( - actor_core, key, variable_client, backend='cpu') - - eval_loop = acme.EnvironmentLoop( - environment=environment, - actor=evaluator, - logger=loggers.TerminalLogger('evaluation', time_delta=0.)) - - # Run the environment loop. - while True: - for _ in range(FLAGS.evaluate_every): - learner.step() - eval_loop.run(FLAGS.evaluation_episodes) - - -if __name__ == '__main__': - app.run(main) + key = jax.random.PRNGKey(FLAGS.seed) + key_demonstrations, key_learner = jax.random.split(key, 2) + + # Create an environment and grab the spec. + environment = gym_helpers.make_environment(task=FLAGS.env_name) + environment_spec = specs.make_environment_spec(environment) + + # Get a demonstrations dataset. + transitions_iterator = tfds.get_tfds_dataset( + FLAGS.dataset_name, FLAGS.num_demonstrations + ) + demonstrations = tfds.JaxInMemoryRandomSampleIterator( + transitions_iterator, key=key_demonstrations, batch_size=FLAGS.batch_size + ) + + # Create the networks to optimize. + networks = cql.make_networks(environment_spec) + + # Create the learner. 
+ learner = cql.CQLLearner( + batch_size=FLAGS.batch_size, + networks=networks, + random_key=key_learner, + policy_optimizer=optax.adam(FLAGS.policy_learning_rate), + critic_optimizer=optax.adam(FLAGS.critic_learning_rate), + fixed_cql_coefficient=FLAGS.fixed_cql_coefficient, + cql_lagrange_threshold=FLAGS.cql_lagrange_threshold, + demonstrations=demonstrations, + num_sgd_steps_per_step=1, + ) + + def evaluator_network( + params: hk.Params, key: jax.Array, observation: jax.Array + ) -> jax.Array: + dist_params = networks.policy_network.apply(params, observation) + return networks.sample_eval(dist_params, key) + + actor_core = actor_core_lib.batched_feed_forward_to_actor_core(evaluator_network) + variable_client = variable_utils.VariableClient(learner, "policy", device="cpu") + evaluator = actors.GenericActor(actor_core, key, variable_client, backend="cpu") + + eval_loop = acme.EnvironmentLoop( + environment=environment, + actor=evaluator, + logger=loggers.TerminalLogger("evaluation", time_delta=0.0), + ) + + # Run the environment loop. + while True: + for _ in range(FLAGS.evaluate_every): + learner.step() + eval_loop.run(FLAGS.evaluation_episodes) + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/offline/run_crr_jax.py b/examples/offline/run_crr_jax.py index 14fc21b368..3d1c218a13 100644 --- a/examples/offline/run_crr_jax.py +++ b/examples/offline/run_crr_jax.py @@ -14,124 +14,126 @@ """An example CRR running on locomotion datasets (mujoco) from D4rl.""" -from absl import app -from absl import flags +import haiku as hk +import jax +import optax +import rlds +from absl import app, flags + import acme from acme import specs from acme.agents.jax import actor_core as actor_core_lib -from acme.agents.jax import actors -from acme.agents.jax import crr +from acme.agents.jax import actors, crr from acme.datasets import tfds from acme.examples.offline import helpers as gym_helpers from acme.jax import variable_utils from acme.types import Transition from acme.utils import loggers -import haiku as hk -import jax -import optax -import rlds # Agent flags -flags.DEFINE_integer('batch_size', 64, 'Batch size.') -flags.DEFINE_integer('evaluate_every', 20, 'Evaluation period.') -flags.DEFINE_integer('evaluation_episodes', 10, 'Evaluation episodes.') +flags.DEFINE_integer("batch_size", 64, "Batch size.") +flags.DEFINE_integer("evaluate_every", 20, "Evaluation period.") +flags.DEFINE_integer("evaluation_episodes", 10, "Evaluation episodes.") flags.DEFINE_integer( - 'num_demonstrations', 10, - 'Number of demonstration episodes to load from the dataset. If None, loads the full dataset.' + "num_demonstrations", + 10, + "Number of demonstration episodes to load from the dataset. If None, loads the full dataset.", ) -flags.DEFINE_integer('seed', 0, 'Random seed for learner and evaluator.') +flags.DEFINE_integer("seed", 0, "Random seed for learner and evaluator.") # CQL specific flags. 
-flags.DEFINE_float('policy_learning_rate', 3e-5, 'Policy learning rate.') -flags.DEFINE_float('critic_learning_rate', 3e-4, 'Critic learning rate.') -flags.DEFINE_float('discount', 0.99, 'Discount.') -flags.DEFINE_integer('target_update_period', 100, 'Target update periode.') -flags.DEFINE_integer('grad_updates_per_batch', 1, 'Grad updates per batch.') +flags.DEFINE_float("policy_learning_rate", 3e-5, "Policy learning rate.") +flags.DEFINE_float("critic_learning_rate", 3e-4, "Critic learning rate.") +flags.DEFINE_float("discount", 0.99, "Discount.") +flags.DEFINE_integer("target_update_period", 100, "Target update periode.") +flags.DEFINE_integer("grad_updates_per_batch", 1, "Grad updates per batch.") flags.DEFINE_bool( - 'use_sarsa_target', True, - 'Compute on-policy target using iterator actions rather than sampled ' - 'actions.' + "use_sarsa_target", + True, + "Compute on-policy target using iterator actions rather than sampled " "actions.", ) # Environment flags. -flags.DEFINE_string('env_name', 'HalfCheetah-v2', - 'Gym mujoco environment name.') +flags.DEFINE_string("env_name", "HalfCheetah-v2", "Gym mujoco environment name.") flags.DEFINE_string( - 'dataset_name', 'd4rl_mujoco_halfcheetah/v2-medium', - 'D4rl dataset name. Can be any locomotion dataset from ' - 'https://www.tensorflow.org/datasets/catalog/overview#d4rl.') + "dataset_name", + "d4rl_mujoco_halfcheetah/v2-medium", + "D4rl dataset name. Can be any locomotion dataset from " + "https://www.tensorflow.org/datasets/catalog/overview#d4rl.", +) FLAGS = flags.FLAGS def _add_next_action_extras(double_transitions: Transition) -> Transition: - return Transition( - observation=double_transitions.observation[0], - action=double_transitions.action[0], - reward=double_transitions.reward[0], - discount=double_transitions.discount[0], - next_observation=double_transitions.next_observation[0], - extras={'next_action': double_transitions.action[1]}) + return Transition( + observation=double_transitions.observation[0], + action=double_transitions.action[0], + reward=double_transitions.reward[0], + discount=double_transitions.discount[0], + next_observation=double_transitions.next_observation[0], + extras={"next_action": double_transitions.action[1]}, + ) def main(_): - key = jax.random.PRNGKey(FLAGS.seed) - key_demonstrations, key_learner = jax.random.split(key, 2) - - # Create an environment and grab the spec. - environment = gym_helpers.make_environment(task=FLAGS.env_name) - environment_spec = specs.make_environment_spec(environment) - - # Get a demonstrations dataset with next_actions extra. - transitions = tfds.get_tfds_dataset( - FLAGS.dataset_name, FLAGS.num_demonstrations) - double_transitions = rlds.transformations.batch( - transitions, size=2, shift=1, drop_remainder=True) - transitions = double_transitions.map(_add_next_action_extras) - demonstrations = tfds.JaxInMemoryRandomSampleIterator( - transitions, key=key_demonstrations, batch_size=FLAGS.batch_size) - - # Create the networks to optimize. - networks = crr.make_networks(environment_spec) - - # CRR policy loss function. - policy_loss_coeff_fn = crr.policy_loss_coeff_advantage_exp - - # Create the learner. 
- learner = crr.CRRLearner( - networks=networks, - random_key=key_learner, - discount=FLAGS.discount, - target_update_period=FLAGS.target_update_period, - policy_loss_coeff_fn=policy_loss_coeff_fn, - iterator=demonstrations, - policy_optimizer=optax.adam(FLAGS.policy_learning_rate), - critic_optimizer=optax.adam(FLAGS.critic_learning_rate), - grad_updates_per_batch=FLAGS.grad_updates_per_batch, - use_sarsa_target=FLAGS.use_sarsa_target) - - def evaluator_network( - params: hk.Params, key: jax.Array, observation: jax.Array - ) -> jax.Array: - dist_params = networks.policy_network.apply(params, observation) - return networks.sample_eval(dist_params, key) - - actor_core = actor_core_lib.batched_feed_forward_to_actor_core( - evaluator_network) - variable_client = variable_utils.VariableClient( - learner, 'policy', device='cpu') - evaluator = actors.GenericActor( - actor_core, key, variable_client, backend='cpu') - - eval_loop = acme.EnvironmentLoop( - environment=environment, - actor=evaluator, - logger=loggers.TerminalLogger('evaluation', time_delta=0.)) - - # Run the environment loop. - while True: - for _ in range(FLAGS.evaluate_every): - learner.step() - eval_loop.run(FLAGS.evaluation_episodes) - - -if __name__ == '__main__': - app.run(main) + key = jax.random.PRNGKey(FLAGS.seed) + key_demonstrations, key_learner = jax.random.split(key, 2) + + # Create an environment and grab the spec. + environment = gym_helpers.make_environment(task=FLAGS.env_name) + environment_spec = specs.make_environment_spec(environment) + + # Get a demonstrations dataset with next_actions extra. + transitions = tfds.get_tfds_dataset(FLAGS.dataset_name, FLAGS.num_demonstrations) + double_transitions = rlds.transformations.batch( + transitions, size=2, shift=1, drop_remainder=True + ) + transitions = double_transitions.map(_add_next_action_extras) + demonstrations = tfds.JaxInMemoryRandomSampleIterator( + transitions, key=key_demonstrations, batch_size=FLAGS.batch_size + ) + + # Create the networks to optimize. + networks = crr.make_networks(environment_spec) + + # CRR policy loss function. + policy_loss_coeff_fn = crr.policy_loss_coeff_advantage_exp + + # Create the learner. + learner = crr.CRRLearner( + networks=networks, + random_key=key_learner, + discount=FLAGS.discount, + target_update_period=FLAGS.target_update_period, + policy_loss_coeff_fn=policy_loss_coeff_fn, + iterator=demonstrations, + policy_optimizer=optax.adam(FLAGS.policy_learning_rate), + critic_optimizer=optax.adam(FLAGS.critic_learning_rate), + grad_updates_per_batch=FLAGS.grad_updates_per_batch, + use_sarsa_target=FLAGS.use_sarsa_target, + ) + + def evaluator_network( + params: hk.Params, key: jax.Array, observation: jax.Array + ) -> jax.Array: + dist_params = networks.policy_network.apply(params, observation) + return networks.sample_eval(dist_params, key) + + actor_core = actor_core_lib.batched_feed_forward_to_actor_core(evaluator_network) + variable_client = variable_utils.VariableClient(learner, "policy", device="cpu") + evaluator = actors.GenericActor(actor_core, key, variable_client, backend="cpu") + + eval_loop = acme.EnvironmentLoop( + environment=environment, + actor=evaluator, + logger=loggers.TerminalLogger("evaluation", time_delta=0.0), + ) + + # Run the environment loop. 
+ while True: + for _ in range(FLAGS.evaluate_every): + learner.step() + eval_loop.run(FLAGS.evaluation_episodes) + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/offline/run_dqfd.py b/examples/offline/run_dqfd.py index d5b3a1d0df..2df84e8077 100644 --- a/examples/offline/run_dqfd.py +++ b/examples/offline/run_dqfd.py @@ -15,69 +15,73 @@ """Example running DQfD on BSuite in a single process. """ -from absl import app -from absl import flags +import bsuite +import sonnet as snt +from absl import app, flags import acme -from acme import specs -from acme import wrappers +from acme import specs, wrappers from acme.agents.tf import dqfd from acme.agents.tf.dqfd import bsuite_demonstrations -import bsuite -import sonnet as snt - - # Bsuite flags -flags.DEFINE_string('bsuite_id', 'deep_sea/0', 'Bsuite id.') -flags.DEFINE_string('results_dir', '/tmp/bsuite', 'CSV results directory.') -flags.DEFINE_boolean('overwrite', False, 'Whether to overwrite csv results.') +flags.DEFINE_string("bsuite_id", "deep_sea/0", "Bsuite id.") +flags.DEFINE_string("results_dir", "/tmp/bsuite", "CSV results directory.") +flags.DEFINE_boolean("overwrite", False, "Whether to overwrite csv results.") # Agent flags -flags.DEFINE_float('demonstration_ratio', 0.5, - ('Proportion of demonstration transitions in the replay ' - 'buffer.')) -flags.DEFINE_integer('n_step', 5, - ('Number of steps to squash into a single transition.')) -flags.DEFINE_float('samples_per_insert', 8, - ('Number of samples to take from replay for every insert ' - 'that is made.')) -flags.DEFINE_float('learning_rate', 1e-3, 'Learning rate.') +flags.DEFINE_float( + "demonstration_ratio", + 0.5, + ("Proportion of demonstration transitions in the replay " "buffer."), +) +flags.DEFINE_integer( + "n_step", 5, ("Number of steps to squash into a single transition.") +) +flags.DEFINE_float( + "samples_per_insert", + 8, + ("Number of samples to take from replay for every insert " "that is made."), +) +flags.DEFINE_float("learning_rate", 1e-3, "Learning rate.") FLAGS = flags.FLAGS def make_network(action_spec: specs.DiscreteArray) -> snt.Module: - return snt.Sequential([ - snt.Flatten(), - snt.nets.MLP([50, 50, action_spec.num_values]), - ]) + return snt.Sequential( + [snt.Flatten(), snt.nets.MLP([50, 50, action_spec.num_values]),] + ) def main(_): - # Create an environment and grab the spec. - raw_environment = bsuite.load_and_record_to_csv( - bsuite_id=FLAGS.bsuite_id, - results_dir=FLAGS.results_dir, - overwrite=FLAGS.overwrite, - ) - environment = wrappers.SinglePrecisionWrapper(raw_environment) - environment_spec = specs.make_environment_spec(environment) - - # Construct the agent. - agent = dqfd.DQfD( - environment_spec=environment_spec, - network=make_network(environment_spec.actions), - demonstration_dataset=bsuite_demonstrations.make_dataset( - raw_environment, stochastic=False), - demonstration_ratio=FLAGS.demonstration_ratio, - samples_per_insert=FLAGS.samples_per_insert, - learning_rate=FLAGS.learning_rate) - - # Run the environment loop. - loop = acme.EnvironmentLoop(environment, agent) - loop.run(num_episodes=environment.bsuite_num_episodes) # pytype: disable=attribute-error - - -if __name__ == '__main__': - app.run(main) + # Create an environment and grab the spec. 
+ raw_environment = bsuite.load_and_record_to_csv( + bsuite_id=FLAGS.bsuite_id, + results_dir=FLAGS.results_dir, + overwrite=FLAGS.overwrite, + ) + environment = wrappers.SinglePrecisionWrapper(raw_environment) + environment_spec = specs.make_environment_spec(environment) + + # Construct the agent. + agent = dqfd.DQfD( + environment_spec=environment_spec, + network=make_network(environment_spec.actions), + demonstration_dataset=bsuite_demonstrations.make_dataset( + raw_environment, stochastic=False + ), + demonstration_ratio=FLAGS.demonstration_ratio, + samples_per_insert=FLAGS.samples_per_insert, + learning_rate=FLAGS.learning_rate, + ) + + # Run the environment loop. + loop = acme.EnvironmentLoop(environment, agent) + loop.run( + num_episodes=environment.bsuite_num_episodes + ) # pytype: disable=attribute-error + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/offline/run_mbop_jax.py b/examples/offline/run_mbop_jax.py index 62a58cfd86..f3b384f327 100644 --- a/examples/offline/run_mbop_jax.py +++ b/examples/offline/run_mbop_jax.py @@ -16,8 +16,11 @@ import functools -from absl import app -from absl import flags +import jax +import optax +import tensorflow_datasets +from absl import app, flags + import acme from acme import specs from acme.agents.jax import mbop @@ -25,111 +28,124 @@ from acme.examples.offline import helpers as gym_helpers from acme.jax import running_statistics from acme.utils import loggers -import jax -import optax -import tensorflow_datasets # Training flags. -_NUM_NETWORKS = flags.DEFINE_integer('num_networks', 10, - 'Number of ensemble networks.') -_LEARNING_RATE = flags.DEFINE_float('learning_rate', 1e-3, 'Learning rate.') -_BATCH_SIZE = flags.DEFINE_integer('batch_size', 64, 'Batch size.') -_HIDDEN_LAYER_SIZES = flags.DEFINE_multi_integer('hidden_layer_sizes', [64, 64], - 'Sizes of the hidden layers.') +_NUM_NETWORKS = flags.DEFINE_integer("num_networks", 10, "Number of ensemble networks.") +_LEARNING_RATE = flags.DEFINE_float("learning_rate", 1e-3, "Learning rate.") +_BATCH_SIZE = flags.DEFINE_integer("batch_size", 64, "Batch size.") +_HIDDEN_LAYER_SIZES = flags.DEFINE_multi_integer( + "hidden_layer_sizes", [64, 64], "Sizes of the hidden layers." +) _NUM_SGD_STEPS_PER_STEP = flags.DEFINE_integer( - 'num_sgd_steps_per_step', 1, - 'Denotes how many gradient updates perform per one learner step.') + "num_sgd_steps_per_step", + 1, + "Denotes how many gradient updates perform per one learner step.", +) _NUM_NORMALIZATION_BATCHES = flags.DEFINE_integer( - 'num_normalization_batches', 50, - 'Number of batches used for calculating the normalization statistics.') -_EVALUATE_EVERY = flags.DEFINE_integer('evaluate_every', 20, - 'Evaluation period.') -_EVALUATION_EPISODES = flags.DEFINE_integer('evaluation_episodes', 10, - 'Evaluation episodes.') -_SEED = flags.DEFINE_integer('seed', 0, - 'Random seed for learner and evaluator.') + "num_normalization_batches", + 50, + "Number of batches used for calculating the normalization statistics.", +) +_EVALUATE_EVERY = flags.DEFINE_integer("evaluate_every", 20, "Evaluation period.") +_EVALUATION_EPISODES = flags.DEFINE_integer( + "evaluation_episodes", 10, "Evaluation episodes." +) +_SEED = flags.DEFINE_integer("seed", 0, "Random seed for learner and evaluator.") # Environment flags. -_ENV_NAME = flags.DEFINE_string('env_name', 'HalfCheetah-v2', - 'Gym mujoco environment name.') +_ENV_NAME = flags.DEFINE_string( + "env_name", "HalfCheetah-v2", "Gym mujoco environment name." 
+) _DATASET_NAME = flags.DEFINE_string( - 'dataset_name', 'd4rl_mujoco_halfcheetah/v2-medium', - 'D4rl dataset name. Can be any locomotion dataset from ' - 'https://www.tensorflow.org/datasets/catalog/overview#d4rl.') + "dataset_name", + "d4rl_mujoco_halfcheetah/v2-medium", + "D4rl dataset name. Can be any locomotion dataset from " + "https://www.tensorflow.org/datasets/catalog/overview#d4rl.", +) def main(_): - # Create an environment and grab the spec. - environment = gym_helpers.make_environment(task=_ENV_NAME.value) - spec = specs.make_environment_spec(environment) - - key = jax.random.PRNGKey(_SEED.value) - key, dataset_key, evaluator_key = jax.random.split(key, 3) - - # Load the dataset. - dataset = tensorflow_datasets.load(_DATASET_NAME.value)['train'] - # Unwrap the environment to get the demonstrations. - dataset = mbop.episodes_to_timestep_batched_transitions( - dataset, return_horizon=10) - dataset = tfds.JaxInMemoryRandomSampleIterator( - dataset, key=dataset_key, batch_size=_BATCH_SIZE.value) - - # Apply normalization to the dataset. - mean_std = mbop.get_normalization_stats(dataset, - _NUM_NORMALIZATION_BATCHES.value) - apply_normalization = jax.jit( - functools.partial(running_statistics.normalize, mean_std=mean_std)) - dataset = (apply_normalization(sample) for sample in dataset) - - # Create the networks. - networks = mbop.make_networks( - spec, hidden_layer_sizes=tuple(_HIDDEN_LAYER_SIZES.value)) - - # Use the default losses. - losses = mbop.MBOPLosses() - - def logger_fn(label: str, steps_key: str): - return loggers.make_default_logger(label, steps_key=steps_key) - - def make_learner(name, logger_fn, counter, rng_key, dataset, network, loss): - return mbop.make_ensemble_regressor_learner( - name, - _NUM_NETWORKS.value, - logger_fn, - counter, - rng_key, - dataset, - network, - loss, - optax.adam(_LEARNING_RATE.value), - _NUM_SGD_STEPS_PER_STEP.value, + # Create an environment and grab the spec. + environment = gym_helpers.make_environment(task=_ENV_NAME.value) + spec = specs.make_environment_spec(environment) + + key = jax.random.PRNGKey(_SEED.value) + key, dataset_key, evaluator_key = jax.random.split(key, 3) + + # Load the dataset. + dataset = tensorflow_datasets.load(_DATASET_NAME.value)["train"] + # Unwrap the environment to get the demonstrations. + dataset = mbop.episodes_to_timestep_batched_transitions(dataset, return_horizon=10) + dataset = tfds.JaxInMemoryRandomSampleIterator( + dataset, key=dataset_key, batch_size=_BATCH_SIZE.value ) - learner = mbop.MBOPLearner(networks, losses, dataset, key, logger_fn, - functools.partial(make_learner, 'world_model'), - functools.partial(make_learner, 'policy_prior'), - functools.partial(make_learner, 'n_step_return')) + # Apply normalization to the dataset. + mean_std = mbop.get_normalization_stats(dataset, _NUM_NORMALIZATION_BATCHES.value) + apply_normalization = jax.jit( + functools.partial(running_statistics.normalize, mean_std=mean_std) + ) + dataset = (apply_normalization(sample) for sample in dataset) - planning_config = mbop.MPPIConfig() + # Create the networks. + networks = mbop.make_networks( + spec, hidden_layer_sizes=tuple(_HIDDEN_LAYER_SIZES.value) + ) - assert planning_config.n_trajectories % _NUM_NETWORKS.value == 0, ( - 'Number of trajectories must be a multiple of the number of networks.') + # Use the default losses. 
+ losses = mbop.MBOPLosses() + + def logger_fn(label: str, steps_key: str): + return loggers.make_default_logger(label, steps_key=steps_key) + + def make_learner(name, logger_fn, counter, rng_key, dataset, network, loss): + return mbop.make_ensemble_regressor_learner( + name, + _NUM_NETWORKS.value, + logger_fn, + counter, + rng_key, + dataset, + network, + loss, + optax.adam(_LEARNING_RATE.value), + _NUM_SGD_STEPS_PER_STEP.value, + ) + + learner = mbop.MBOPLearner( + networks, + losses, + dataset, + key, + logger_fn, + functools.partial(make_learner, "world_model"), + functools.partial(make_learner, "policy_prior"), + functools.partial(make_learner, "n_step_return"), + ) - actor_core = mbop.make_ensemble_actor_core( - networks, planning_config, spec, mean_std, use_round_robin=False) - evaluator = mbop.make_actor(actor_core, evaluator_key, learner) + planning_config = mbop.MPPIConfig() - eval_loop = acme.EnvironmentLoop( - environment=environment, - actor=evaluator, - logger=loggers.TerminalLogger('evaluation', time_delta=0.)) + assert ( + planning_config.n_trajectories % _NUM_NETWORKS.value == 0 + ), "Number of trajectories must be a multiple of the number of networks." + + actor_core = mbop.make_ensemble_actor_core( + networks, planning_config, spec, mean_std, use_round_robin=False + ) + evaluator = mbop.make_actor(actor_core, evaluator_key, learner) + + eval_loop = acme.EnvironmentLoop( + environment=environment, + actor=evaluator, + logger=loggers.TerminalLogger("evaluation", time_delta=0.0), + ) - # Train the agent. - while True: - for _ in range(_EVALUATE_EVERY.value): - learner.step() - eval_loop.run(_EVALUATION_EPISODES.value) + # Train the agent. + while True: + for _ in range(_EVALUATE_EVERY.value): + learner.step() + eval_loop.run(_EVALUATION_EPISODES.value) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/examples/offline/run_offline_td3_jax.py b/examples/offline/run_offline_td3_jax.py index 6684278eb8..dc4d4b14e5 100644 --- a/examples/offline/run_offline_td3_jax.py +++ b/examples/offline/run_offline_td3_jax.py @@ -14,131 +14,137 @@ """An example offline TD3 running on locomotion datasets (mujoco) from D4rl.""" -from absl import app -from absl import flags +import haiku as hk +import jax +import optax +import reverb +import rlds +import tensorflow as tf +import tree +from absl import app, flags + import acme from acme import specs from acme.agents.jax import actor_core as actor_core_lib -from acme.agents.jax import actors -from acme.agents.jax import td3 +from acme.agents.jax import actors, td3 from acme.datasets import tfds from acme.examples.offline import helpers as gym_helpers from acme.jax import variable_utils from acme.types import Transition from acme.utils import loggers -import haiku as hk -import jax -import optax -import reverb -import rlds -import tensorflow as tf -import tree # Agent flags -flags.DEFINE_integer('batch_size', 64, 'Batch size.') -flags.DEFINE_integer('evaluate_every', 20, 'Evaluation period.') -flags.DEFINE_integer('evaluation_episodes', 10, 'Evaluation episodes.') +flags.DEFINE_integer("batch_size", 64, "Batch size.") +flags.DEFINE_integer("evaluate_every", 20, "Evaluation period.") +flags.DEFINE_integer("evaluation_episodes", 10, "Evaluation episodes.") flags.DEFINE_integer( - 'num_demonstrations', 10, - 'Number of demonstration episodes to load from the dataset. If None, loads the full dataset.' + "num_demonstrations", + 10, + "Number of demonstration episodes to load from the dataset. 
If None, loads the full dataset.", ) -flags.DEFINE_integer('seed', 0, 'Random seed for learner and evaluator.') +flags.DEFINE_integer("seed", 0, "Random seed for learner and evaluator.") # TD3 specific flags. -flags.DEFINE_float('discount', 0.99, 'Discount.') -flags.DEFINE_float('policy_learning_rate', 3e-4, 'Policy learning rate.') -flags.DEFINE_float('critic_learning_rate', 3e-4, 'Critic learning rate.') -flags.DEFINE_float('bc_alpha', 2.5, - 'Add a bc regularization term to the policy loss.' - 'If set to None, TD3 is run without bc regularisation.') +flags.DEFINE_float("discount", 0.99, "Discount.") +flags.DEFINE_float("policy_learning_rate", 3e-4, "Policy learning rate.") +flags.DEFINE_float("critic_learning_rate", 3e-4, "Critic learning rate.") +flags.DEFINE_float( + "bc_alpha", + 2.5, + "Add a bc regularization term to the policy loss." + "If set to None, TD3 is run without bc regularisation.", +) flags.DEFINE_bool( - 'use_sarsa_target', True, - 'Compute on-policy target using iterator actions rather than sampled ' - 'actions.' + "use_sarsa_target", + True, + "Compute on-policy target using iterator actions rather than sampled " "actions.", ) # Environment flags. -flags.DEFINE_string('env_name', 'HalfCheetah-v2', - 'Gym mujoco environment name.') +flags.DEFINE_string("env_name", "HalfCheetah-v2", "Gym mujoco environment name.") flags.DEFINE_string( - 'dataset_name', 'd4rl_mujoco_halfcheetah/v2-medium', - 'D4rl dataset name. Can be any locomotion dataset from ' - 'https://www.tensorflow.org/datasets/catalog/overview#d4rl.') + "dataset_name", + "d4rl_mujoco_halfcheetah/v2-medium", + "D4rl dataset name. Can be any locomotion dataset from " + "https://www.tensorflow.org/datasets/catalog/overview#d4rl.", +) FLAGS = flags.FLAGS -def _add_next_action_extras(double_transitions: Transition - ) -> reverb.ReplaySample: - # As TD3 is online by default, it expects an iterator over replay samples. - info = tree.map_structure(lambda dtype: tf.ones([], dtype), - reverb.SampleInfo.tf_dtypes()) - return reverb.ReplaySample( - info=info, - data=Transition( - observation=double_transitions.observation[0], - action=double_transitions.action[0], - reward=double_transitions.reward[0], - discount=double_transitions.discount[0], - next_observation=double_transitions.next_observation[0], - extras={'next_action': double_transitions.action[1]})) +def _add_next_action_extras(double_transitions: Transition) -> reverb.ReplaySample: + # As TD3 is online by default, it expects an iterator over replay samples. + info = tree.map_structure( + lambda dtype: tf.ones([], dtype), reverb.SampleInfo.tf_dtypes() + ) + return reverb.ReplaySample( + info=info, + data=Transition( + observation=double_transitions.observation[0], + action=double_transitions.action[0], + reward=double_transitions.reward[0], + discount=double_transitions.discount[0], + next_observation=double_transitions.next_observation[0], + extras={"next_action": double_transitions.action[1]}, + ), + ) def main(_): - key = jax.random.PRNGKey(FLAGS.seed) - key_demonstrations, key_learner = jax.random.split(key, 2) - - # Create an environment and grab the spec. - environment = gym_helpers.make_environment(task=FLAGS.env_name) - environment_spec = specs.make_environment_spec(environment) - - # Get a demonstrations dataset with next_actions extra. 
- transitions = tfds.get_tfds_dataset( - FLAGS.dataset_name, FLAGS.num_demonstrations) - double_transitions = rlds.transformations.batch( - transitions, size=2, shift=1, drop_remainder=True) - transitions = double_transitions.map(_add_next_action_extras) - demonstrations = tfds.JaxInMemoryRandomSampleIterator( - transitions, key=key_demonstrations, batch_size=FLAGS.batch_size) - - # Create the networks to optimize. - networks = td3.make_networks(environment_spec) - - # Create the learner. - learner = td3.TD3Learner( - networks=networks, - random_key=key_learner, - discount=FLAGS.discount, - iterator=demonstrations, - policy_optimizer=optax.adam(FLAGS.policy_learning_rate), - critic_optimizer=optax.adam(FLAGS.critic_learning_rate), - twin_critic_optimizer=optax.adam(FLAGS.critic_learning_rate), - use_sarsa_target=FLAGS.use_sarsa_target, - bc_alpha=FLAGS.bc_alpha, - num_sgd_steps_per_step=1) - - def evaluator_network( - params: hk.Params, key: jax.Array, observation: jax.Array - ) -> jax.Array: - del key - return networks.policy_network.apply(params, observation) - - actor_core = actor_core_lib.batched_feed_forward_to_actor_core( - evaluator_network) - variable_client = variable_utils.VariableClient( - learner, 'policy', device='cpu') - evaluator = actors.GenericActor( - actor_core, key, variable_client, backend='cpu') - - eval_loop = acme.EnvironmentLoop( - environment=environment, - actor=evaluator, - logger=loggers.TerminalLogger('evaluation', time_delta=0.)) - - # Run the environment loop. - while True: - for _ in range(FLAGS.evaluate_every): - learner.step() - eval_loop.run(FLAGS.evaluation_episodes) - - -if __name__ == '__main__': - app.run(main) + key = jax.random.PRNGKey(FLAGS.seed) + key_demonstrations, key_learner = jax.random.split(key, 2) + + # Create an environment and grab the spec. + environment = gym_helpers.make_environment(task=FLAGS.env_name) + environment_spec = specs.make_environment_spec(environment) + + # Get a demonstrations dataset with next_actions extra. + transitions = tfds.get_tfds_dataset(FLAGS.dataset_name, FLAGS.num_demonstrations) + double_transitions = rlds.transformations.batch( + transitions, size=2, shift=1, drop_remainder=True + ) + transitions = double_transitions.map(_add_next_action_extras) + demonstrations = tfds.JaxInMemoryRandomSampleIterator( + transitions, key=key_demonstrations, batch_size=FLAGS.batch_size + ) + + # Create the networks to optimize. + networks = td3.make_networks(environment_spec) + + # Create the learner. + learner = td3.TD3Learner( + networks=networks, + random_key=key_learner, + discount=FLAGS.discount, + iterator=demonstrations, + policy_optimizer=optax.adam(FLAGS.policy_learning_rate), + critic_optimizer=optax.adam(FLAGS.critic_learning_rate), + twin_critic_optimizer=optax.adam(FLAGS.critic_learning_rate), + use_sarsa_target=FLAGS.use_sarsa_target, + bc_alpha=FLAGS.bc_alpha, + num_sgd_steps_per_step=1, + ) + + def evaluator_network( + params: hk.Params, key: jax.Array, observation: jax.Array + ) -> jax.Array: + del key + return networks.policy_network.apply(params, observation) + + actor_core = actor_core_lib.batched_feed_forward_to_actor_core(evaluator_network) + variable_client = variable_utils.VariableClient(learner, "policy", device="cpu") + evaluator = actors.GenericActor(actor_core, key, variable_client, backend="cpu") + + eval_loop = acme.EnvironmentLoop( + environment=environment, + actor=evaluator, + logger=loggers.TerminalLogger("evaluation", time_delta=0.0), + ) + + # Run the environment loop. 
+ while True: + for _ in range(FLAGS.evaluate_every): + learner.step() + eval_loop.run(FLAGS.evaluation_episodes) + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/open_spiel/run_dqn.py b/examples/open_spiel/run_dqn.py index 5264b35624..6f1ca23822 100644 --- a/examples/open_spiel/run_dqn.py +++ b/examples/open_spiel/run_dqn.py @@ -14,8 +14,9 @@ """Example running DQN on OpenSpiel game in a single process.""" -from absl import app -from absl import flags +import sonnet as snt +from absl import app, flags +from open_spiel.python import rl_environment import acme from acme import wrappers @@ -23,53 +24,53 @@ from acme.environment_loops import open_spiel_environment_loop from acme.tf.networks import legal_actions from acme.wrappers import open_spiel_wrapper -import sonnet as snt - -from open_spiel.python import rl_environment -flags.DEFINE_string('game', 'tic_tac_toe', 'Name of the game') -flags.DEFINE_integer('num_players', None, 'Number of players') +flags.DEFINE_string("game", "tic_tac_toe", "Name of the game") +flags.DEFINE_integer("num_players", None, "Number of players") FLAGS = flags.FLAGS def main(_): - # Create an environment and grab the spec. - env_configs = {'players': FLAGS.num_players} if FLAGS.num_players else {} - raw_environment = rl_environment.Environment(FLAGS.game, **env_configs) + # Create an environment and grab the spec. + env_configs = {"players": FLAGS.num_players} if FLAGS.num_players else {} + raw_environment = rl_environment.Environment(FLAGS.game, **env_configs) - environment = open_spiel_wrapper.OpenSpielWrapper(raw_environment) - environment = wrappers.SinglePrecisionWrapper(environment) # type: open_spiel_wrapper.OpenSpielWrapper # pytype: disable=annotation-type-mismatch - environment_spec = acme.make_environment_spec(environment) + environment = open_spiel_wrapper.OpenSpielWrapper(raw_environment) + environment = wrappers.SinglePrecisionWrapper( + environment + ) # type: open_spiel_wrapper.OpenSpielWrapper # pytype: disable=annotation-type-mismatch + environment_spec = acme.make_environment_spec(environment) - # Build the networks. - networks = [] - policy_networks = [] - for _ in range(environment.num_players): - network = legal_actions.MaskedSequential([ - snt.Flatten(), - snt.nets.MLP([50, 50, environment_spec.actions.num_values]) - ]) - policy_network = snt.Sequential( - [network, - legal_actions.EpsilonGreedy(epsilon=0.1, threshold=-1e8)]) - networks.append(network) - policy_networks.append(policy_network) + # Build the networks. + networks = [] + policy_networks = [] + for _ in range(environment.num_players): + network = legal_actions.MaskedSequential( + [snt.Flatten(), snt.nets.MLP([50, 50, environment_spec.actions.num_values])] + ) + policy_network = snt.Sequential( + [network, legal_actions.EpsilonGreedy(epsilon=0.1, threshold=-1e8)] + ) + networks.append(network) + policy_networks.append(policy_network) - # Construct the agents. - agents = [] + # Construct the agents. + agents = [] - for network, policy_network in zip(networks, policy_networks): - agents.append( - dqn.DQN(environment_spec=environment_spec, + for network, policy_network in zip(networks, policy_networks): + agents.append( + dqn.DQN( + environment_spec=environment_spec, network=network, - policy_network=policy_network)) + policy_network=policy_network, + ) + ) - # Run the environment loop. - loop = open_spiel_environment_loop.OpenSpielEnvironmentLoop( - environment, agents) - loop.run(num_episodes=100000) + # Run the environment loop. 
+ loop = open_spiel_environment_loop.OpenSpielEnvironmentLoop(environment, agents) + loop.run(num_episodes=100000) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/examples/tf/control_suite/helpers.py b/examples/tf/control_suite/helpers.py index d00ce9567f..679c9b5735 100644 --- a/examples/tf/control_suite/helpers.py +++ b/examples/tf/control_suite/helpers.py @@ -16,42 +16,48 @@ from typing import Optional -from acme import wrappers import dm_env +from acme import wrappers + def make_environment( evaluation: bool = False, - domain_name: str = 'cartpole', - task_name: str = 'balance', + domain_name: str = "cartpole", + task_name: str = "balance", from_pixels: bool = False, frames_to_stack: int = 3, flatten_stack: bool = False, num_action_repeats: Optional[int] = None, ) -> dm_env.Environment: - """Implements a control suite environment factory.""" - # Load dm_suite lazily not require Mujoco license when not using it. - from dm_control import suite # pylint: disable=g-import-not-at-top - from acme.wrappers import mujoco as mujoco_wrappers # pylint: disable=g-import-not-at-top - - # Load raw control suite environment. - environment = suite.load(domain_name, task_name) - - # Maybe wrap to get pixel observations from environment state. - if from_pixels: - environment = mujoco_wrappers.MujocoPixelWrapper(environment) - environment = wrappers.FrameStackingWrapper( - environment, num_frames=frames_to_stack, flatten=flatten_stack) - environment = wrappers.CanonicalSpecWrapper(environment, clip=True) - - if num_action_repeats: - environment = wrappers.ActionRepeatWrapper( - environment, num_repeats=num_action_repeats) - environment = wrappers.SinglePrecisionWrapper(environment) - - if evaluation: - # The evaluator in the distributed agent will set this to True so you can - # use this clause to, e.g., set up video recording by the evaluator. - pass - - return environment + """Implements a control suite environment factory.""" + # Load dm_suite lazily not require Mujoco license when not using it. + from dm_control import suite # pylint: disable=g-import-not-at-top + + from acme.wrappers import ( + mujoco as mujoco_wrappers, + ) # pylint: disable=g-import-not-at-top + + # Load raw control suite environment. + environment = suite.load(domain_name, task_name) + + # Maybe wrap to get pixel observations from environment state. + if from_pixels: + environment = mujoco_wrappers.MujocoPixelWrapper(environment) + environment = wrappers.FrameStackingWrapper( + environment, num_frames=frames_to_stack, flatten=flatten_stack + ) + environment = wrappers.CanonicalSpecWrapper(environment, clip=True) + + if num_action_repeats: + environment = wrappers.ActionRepeatWrapper( + environment, num_repeats=num_action_repeats + ) + environment = wrappers.SinglePrecisionWrapper(environment) + + if evaluation: + # The evaluator in the distributed agent will set this to True so you can + # use this clause to, e.g., set up video recording by the evaluator. 
+ pass + + return environment diff --git a/examples/tf/control_suite/lp_d4pg.py b/examples/tf/control_suite/lp_d4pg.py index bf7578f874..776edef66b 100644 --- a/examples/tf/control_suite/lp_d4pg.py +++ b/examples/tf/control_suite/lp_d4pg.py @@ -17,77 +17,81 @@ import functools from typing import Callable, Dict, Sequence, Union -from absl import app -from absl import flags -from acme import specs -from acme.agents.tf import d4pg import helpers -from acme.tf import networks -from acme.tf import utils as tf2_utils import launchpad as lp import numpy as np import sonnet as snt import tensorflow as tf +from absl import app, flags +from acme import specs +from acme.agents.tf import d4pg +from acme.tf import networks +from acme.tf import utils as tf2_utils # Flags which modify the behavior of the launcher. FLAGS = flags.FLAGS _MAX_ACTOR_STEPS = flags.DEFINE_integer( - 'max_actor_steps', None, - 'Number of actor steps to run; defaults to None for an endless loop.') -_DOMAIN = flags.DEFINE_string('domain', 'cartpole', - 'Control suite domain name.') -_TASK = flags.DEFINE_string('task', 'balance', 'Control suite task name.') + "max_actor_steps", + None, + "Number of actor steps to run; defaults to None for an endless loop.", +) +_DOMAIN = flags.DEFINE_string("domain", "cartpole", "Control suite domain name.") +_TASK = flags.DEFINE_string("task", "balance", "Control suite task name.") def make_networks( action_spec: specs.BoundedArray, policy_layer_sizes: Sequence[int] = (256, 256, 256), critic_layer_sizes: Sequence[int] = (512, 512, 256), - vmin: float = -150., - vmax: float = 150., + vmin: float = -150.0, + vmax: float = 150.0, num_atoms: int = 51, ) -> Dict[str, Union[snt.Module, Callable[[tf.Tensor], tf.Tensor]]]: - """Creates networks used by the agent.""" - - num_dimensions = np.prod(action_spec.shape, dtype=int) - - policy_network = snt.Sequential([ - networks.LayerNormMLP(policy_layer_sizes, activate_final=True), - networks.NearZeroInitializedLinear(num_dimensions), - networks.TanhToSpec(action_spec) - ]) - # The multiplexer concatenates the (maybe transformed) observations/actions. - critic_network = snt.Sequential([ - networks.CriticMultiplexer(), - networks.LayerNormMLP(critic_layer_sizes, activate_final=True), - networks.DiscreteValuedHead(vmin, vmax, num_atoms), - ]) - - return { - 'policy': policy_network, - 'critic': critic_network, - 'observation': tf2_utils.batch_concat, - } + """Creates networks used by the agent.""" + + num_dimensions = np.prod(action_spec.shape, dtype=int) + + policy_network = snt.Sequential( + [ + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.NearZeroInitializedLinear(num_dimensions), + networks.TanhToSpec(action_spec), + ] + ) + # The multiplexer concatenates the (maybe transformed) observations/actions. + critic_network = snt.Sequential( + [ + networks.CriticMultiplexer(), + networks.LayerNormMLP(critic_layer_sizes, activate_final=True), + networks.DiscreteValuedHead(vmin, vmax, num_atoms), + ] + ) + + return { + "policy": policy_network, + "critic": critic_network, + "observation": tf2_utils.batch_concat, + } def main(_): - # Configure the environment factory with requested task. - make_environment = functools.partial( - helpers.make_environment, - domain_name=_DOMAIN.value, - task_name=_TASK.value) + # Configure the environment factory with requested task. + make_environment = functools.partial( + helpers.make_environment, domain_name=_DOMAIN.value, task_name=_TASK.value + ) - # Construct the program. 
- program_builder = d4pg.DistributedD4PG( - make_environment, - make_networks, - max_actor_steps=_MAX_ACTOR_STEPS.value, - num_actors=4) + # Construct the program. + program_builder = d4pg.DistributedD4PG( + make_environment, + make_networks, + max_actor_steps=_MAX_ACTOR_STEPS.value, + num_actors=4, + ) - # Launch experiment. - lp.launch(programs=program_builder.build()) + # Launch experiment. + lp.launch(programs=program_builder.build()) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/examples/tf/control_suite/lp_ddpg.py b/examples/tf/control_suite/lp_ddpg.py index 0ba95dc977..62ca0b4434 100644 --- a/examples/tf/control_suite/lp_ddpg.py +++ b/examples/tf/control_suite/lp_ddpg.py @@ -17,27 +17,26 @@ import functools from typing import Dict, Sequence -from absl import app -from absl import flags -from acme import specs -from acme import types -from acme.agents.tf import ddpg import helpers -from acme.tf import networks -from acme.tf import utils as tf2_utils import launchpad as lp import numpy as np import sonnet as snt +from absl import app, flags +from acme import specs, types +from acme.agents.tf import ddpg +from acme.tf import networks +from acme.tf import utils as tf2_utils # Flags which modify the behavior of the launcher. FLAGS = flags.FLAGS _MAX_ACTOR_STEPS = flags.DEFINE_integer( - 'max_actor_steps', None, - 'Number of actor steps to run; defaults to None for an endless loop.') -_DOMAIN = flags.DEFINE_string('domain', 'cartpole', - 'Control suite domain name.') -_TASK = flags.DEFINE_string('task', 'balance', 'Control suite task name.') + "max_actor_steps", + None, + "Number of actor steps to run; defaults to None for an endless loop.", +) +_DOMAIN = flags.DEFINE_string("domain", "cartpole", "Control suite domain name.") +_TASK = flags.DEFINE_string("task", "balance", "Control suite task name.") def make_networks( @@ -45,46 +44,50 @@ def make_networks( policy_layer_sizes: Sequence[int] = (256, 256, 256), critic_layer_sizes: Sequence[int] = (512, 512, 256), ) -> Dict[str, types.TensorTransformation]: - """Creates networks used by the agent.""" - - num_dimensions = np.prod(action_spec.shape, dtype=int) - - policy_network = snt.Sequential([ - networks.LayerNormMLP(policy_layer_sizes, activate_final=True), - networks.NearZeroInitializedLinear(num_dimensions), - networks.TanhToSpec(action_spec) - ]) - critic_network = snt.Sequential([ - # The multiplexer concatenates the observations/actions. - networks.CriticMultiplexer(), - networks.LayerNormMLP(critic_layer_sizes, activate_final=True), - networks.NearZeroInitializedLinear(1), - ]) - - return { - 'policy': policy_network, - 'critic': critic_network, - 'observation': tf2_utils.batch_concat, - } + """Creates networks used by the agent.""" + + num_dimensions = np.prod(action_spec.shape, dtype=int) + + policy_network = snt.Sequential( + [ + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.NearZeroInitializedLinear(num_dimensions), + networks.TanhToSpec(action_spec), + ] + ) + critic_network = snt.Sequential( + [ + # The multiplexer concatenates the observations/actions. + networks.CriticMultiplexer(), + networks.LayerNormMLP(critic_layer_sizes, activate_final=True), + networks.NearZeroInitializedLinear(1), + ] + ) + + return { + "policy": policy_network, + "critic": critic_network, + "observation": tf2_utils.batch_concat, + } def main(_): - # Configure the environment factory with requested task. 
- make_environment = functools.partial( - helpers.make_environment, - domain_name=_DOMAIN.value, - task_name=_TASK.value) + # Configure the environment factory with requested task. + make_environment = functools.partial( + helpers.make_environment, domain_name=_DOMAIN.value, task_name=_TASK.value + ) - # Construct the program. - program_builder = ddpg.DistributedDDPG( - make_environment, - make_networks, - max_actor_steps=_MAX_ACTOR_STEPS.value, - num_actors=4) + # Construct the program. + program_builder = ddpg.DistributedDDPG( + make_environment, + make_networks, + max_actor_steps=_MAX_ACTOR_STEPS.value, + num_actors=4, + ) - # Launch experiment. - lp.launch(programs=program_builder.build()) + # Launch experiment. + lp.launch(programs=program_builder.build()) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/examples/tf/control_suite/lp_dmpo.py b/examples/tf/control_suite/lp_dmpo.py index a314637916..462a1d5d17 100644 --- a/examples/tf/control_suite/lp_dmpo.py +++ b/examples/tf/control_suite/lp_dmpo.py @@ -17,80 +17,86 @@ import functools from typing import Dict, Sequence -from absl import app -from absl import flags -from acme import specs -from acme import types -from acme.agents.tf import dmpo import helpers -from acme.tf import networks -from acme.tf import utils as tf2_utils import launchpad as lp import numpy as np import sonnet as snt +from absl import app, flags + +from acme import specs, types +from acme.agents.tf import dmpo +from acme.tf import networks +from acme.tf import utils as tf2_utils # Flags which modify the behavior of the launcher. FLAGS = flags.FLAGS _MAX_ACTOR_STEPS = flags.DEFINE_integer( - 'max_actor_steps', None, - 'Number of actor steps to run; defaults to None for an endless loop.') -_DOMAIN = flags.DEFINE_string('domain', 'cartpole', - 'Control suite domain name.') -_TASK = flags.DEFINE_string('task', 'balance', 'Control suite task name.') + "max_actor_steps", + None, + "Number of actor steps to run; defaults to None for an endless loop.", +) +_DOMAIN = flags.DEFINE_string("domain", "cartpole", "Control suite domain name.") +_TASK = flags.DEFINE_string("task", "balance", "Control suite task name.") def make_networks( action_spec: specs.BoundedArray, policy_layer_sizes: Sequence[int] = (256, 256, 256), critic_layer_sizes: Sequence[int] = (512, 512, 256), - vmin: float = -150., - vmax: float = 150., + vmin: float = -150.0, + vmax: float = 150.0, num_atoms: int = 51, ) -> Dict[str, types.TensorTransformation]: - """Creates networks used by the agent.""" - - num_dimensions = np.prod(action_spec.shape, dtype=int) - - policy_network = snt.Sequential([ - networks.LayerNormMLP(policy_layer_sizes, activate_final=True), - networks.MultivariateNormalDiagHead( - num_dimensions, init_scale=0.7, use_tfd_independent=True) - ]) - - # The multiplexer concatenates the (maybe transformed) observations/actions. 
- multiplexer = networks.CriticMultiplexer( - action_network=networks.ClipToSpec(action_spec)) - critic_network = snt.Sequential([ - multiplexer, - networks.LayerNormMLP(critic_layer_sizes, activate_final=True), - networks.DiscreteValuedHead(vmin, vmax, num_atoms) - ]) - - return { - 'policy': policy_network, - 'critic': critic_network, - 'observation': tf2_utils.batch_concat, - } + """Creates networks used by the agent.""" + + num_dimensions = np.prod(action_spec.shape, dtype=int) + + policy_network = snt.Sequential( + [ + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, init_scale=0.7, use_tfd_independent=True + ), + ] + ) + + # The multiplexer concatenates the (maybe transformed) observations/actions. + multiplexer = networks.CriticMultiplexer( + action_network=networks.ClipToSpec(action_spec) + ) + critic_network = snt.Sequential( + [ + multiplexer, + networks.LayerNormMLP(critic_layer_sizes, activate_final=True), + networks.DiscreteValuedHead(vmin, vmax, num_atoms), + ] + ) + + return { + "policy": policy_network, + "critic": critic_network, + "observation": tf2_utils.batch_concat, + } def main(_): - # Configure the environment factory with requested task. - make_environment = functools.partial( - helpers.make_environment, - domain_name=_DOMAIN.value, - task_name=_TASK.value) - - # Construct the program. - program_builder = dmpo.DistributedDistributionalMPO( - make_environment, - make_networks, - target_policy_update_period=25, - max_actor_steps=_MAX_ACTOR_STEPS.value, - num_actors=4) - - # Launch experiment. - lp.launch(programs=program_builder.build()) - - -if __name__ == '__main__': - app.run(main) + # Configure the environment factory with requested task. + make_environment = functools.partial( + helpers.make_environment, domain_name=_DOMAIN.value, task_name=_TASK.value + ) + + # Construct the program. + program_builder = dmpo.DistributedDistributionalMPO( + make_environment, + make_networks, + target_policy_update_period=25, + max_actor_steps=_MAX_ACTOR_STEPS.value, + num_actors=4, + ) + + # Launch experiment. + lp.launch(programs=program_builder.build()) + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/tf/control_suite/lp_dmpo_pixels.py b/examples/tf/control_suite/lp_dmpo_pixels.py index e9afd4b833..1e485011f2 100644 --- a/examples/tf/control_suite/lp_dmpo_pixels.py +++ b/examples/tf/control_suite/lp_dmpo_pixels.py @@ -17,89 +17,92 @@ import functools from typing import Dict, Sequence -from absl import app -from absl import flags -from acme import specs -from acme import types -from acme.agents.tf import dmpo import helpers -from acme.tf import networks import launchpad as lp import numpy as np import sonnet as snt +from absl import app, flags + +from acme import specs, types +from acme.agents.tf import dmpo +from acme.tf import networks # Flags which modify the behavior of the launcher. 
FLAGS = flags.FLAGS _MAX_ACTOR_STEPS = flags.DEFINE_integer( - 'max_actor_steps', None, - 'Number of actor steps to run; defaults to None for an endless loop.') -_DOMAIN = flags.DEFINE_string('domain', 'cartpole', - 'Control suite domain name.') -_TASK = flags.DEFINE_string('task', 'balance', 'Control suite task name.') + "max_actor_steps", + None, + "Number of actor steps to run; defaults to None for an endless loop.", +) +_DOMAIN = flags.DEFINE_string("domain", "cartpole", "Control suite domain name.") +_TASK = flags.DEFINE_string("task", "balance", "Control suite task name.") def make_networks( action_spec: specs.BoundedArray, policy_layer_sizes: Sequence[int] = (256, 256, 256), critic_layer_sizes: Sequence[int] = (512, 512, 256), - vmin: float = -150., - vmax: float = 150., + vmin: float = -150.0, + vmax: float = 150.0, num_atoms: int = 51, ) -> Dict[str, types.TensorTransformation]: - """Creates networks used by the agent.""" - - num_dimensions = np.prod(action_spec.shape, dtype=int) - - policy_network = snt.Sequential([ - networks.LayerNormMLP(policy_layer_sizes, activate_final=True), - networks.MultivariateNormalDiagHead( - num_dimensions, - tanh_mean=False, - init_scale=1.0, - fixed_scale=False, - use_tfd_independent=True) - ]) - - # The multiplexer concatenates the (maybe transformed) observations/actions. - critic_network = networks.CriticMultiplexer( - critic_network=networks.LayerNormMLP( - critic_layer_sizes, activate_final=True), - action_network=networks.ClipToSpec(action_spec)) - critic_network = snt.Sequential( - [critic_network, - networks.DiscreteValuedHead(vmin, vmax, num_atoms)]) - observation_network = networks.ResNetTorso() - - return { - 'policy': policy_network, - 'critic': critic_network, - 'observation': observation_network, - } + """Creates networks used by the agent.""" + + num_dimensions = np.prod(action_spec.shape, dtype=int) + + policy_network = snt.Sequential( + [ + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, + tanh_mean=False, + init_scale=1.0, + fixed_scale=False, + use_tfd_independent=True, + ), + ] + ) + + # The multiplexer concatenates the (maybe transformed) observations/actions. + critic_network = networks.CriticMultiplexer( + critic_network=networks.LayerNormMLP(critic_layer_sizes, activate_final=True), + action_network=networks.ClipToSpec(action_spec), + ) + critic_network = snt.Sequential( + [critic_network, networks.DiscreteValuedHead(vmin, vmax, num_atoms)] + ) + observation_network = networks.ResNetTorso() + + return { + "policy": policy_network, + "critic": critic_network, + "observation": observation_network, + } def main(_): - # Configure the environment factory with requested task. - make_environment = functools.partial( - helpers.make_environment, - domain_name=_DOMAIN.value, - task_name=_TASK.value, - from_pixels=True, - frames_to_stack=3, - num_action_repeats=2) - - # Construct the program. - program_builder = dmpo.DistributedDistributionalMPO( - make_environment, - make_networks, - n_step=3, # Reduce the n-step to account for action-repeat. - max_actor_steps=_MAX_ACTOR_STEPS.value, - num_actors=4) - - # Launch experiment. - lp.launch( - programs=program_builder.build() - ) - - -if __name__ == '__main__': - app.run(main) + # Configure the environment factory with requested task. 
+ make_environment = functools.partial( + helpers.make_environment, + domain_name=_DOMAIN.value, + task_name=_TASK.value, + from_pixels=True, + frames_to_stack=3, + num_action_repeats=2, + ) + + # Construct the program. + program_builder = dmpo.DistributedDistributionalMPO( + make_environment, + make_networks, + n_step=3, # Reduce the n-step to account for action-repeat. + max_actor_steps=_MAX_ACTOR_STEPS.value, + num_actors=4, + ) + + # Launch experiment. + lp.launch(programs=program_builder.build()) + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/tf/control_suite/lp_dmpo_pixels_drqv2.py b/examples/tf/control_suite/lp_dmpo_pixels_drqv2.py index 3b405fcb0d..8e6e37263b 100644 --- a/examples/tf/control_suite/lp_dmpo_pixels_drqv2.py +++ b/examples/tf/control_suite/lp_dmpo_pixels_drqv2.py @@ -17,109 +17,123 @@ import functools from typing import Dict, Sequence -from absl import app -from absl import flags -from acme import specs -from acme.agents.tf import dmpo -from acme.datasets import image_augmentation import helpers -from acme.tf import networks import launchpad as lp import numpy as np import sonnet as snt import tensorflow as tf +from absl import app, flags + +from acme import specs +from acme.agents.tf import dmpo +from acme.datasets import image_augmentation +from acme.tf import networks # Flags which modify the behavior of the launcher. FLAGS = flags.FLAGS _MAX_ACTOR_STEPS = flags.DEFINE_integer( - 'max_actor_steps', None, - 'Number of actor steps to run; defaults to None for an endless loop.') -_DOMAIN = flags.DEFINE_string('domain', 'cartpole', - 'Control suite domain name.') -_TASK = flags.DEFINE_string('task', 'balance', 'Control suite task name.') + "max_actor_steps", + None, + "Number of actor steps to run; defaults to None for an endless loop.", +) +_DOMAIN = flags.DEFINE_string("domain", "cartpole", "Control suite domain name.") +_TASK = flags.DEFINE_string("task", "balance", "Control suite task name.") def make_networks( action_spec: specs.BoundedArray, policy_layer_sizes: Sequence[int] = (50, 1024, 1024), critic_layer_sizes: Sequence[int] = (50, 1024, 1024), - vmin: float = -150., - vmax: float = 150., + vmin: float = -150.0, + vmax: float = 150.0, num_atoms: int = 51, ) -> Dict[str, snt.Module]: - """Creates networks used by the agent.""" - - num_dimensions = np.prod(action_spec.shape, dtype=int) - - policy_network = snt.Sequential([ - networks.LayerNormMLP( - policy_layer_sizes, - w_init=snt.initializers.Orthogonal(), - activation=tf.nn.relu, - activate_final=True), - networks.MultivariateNormalDiagHead( - num_dimensions, - tanh_mean=False, - init_scale=1.0, - fixed_scale=False, - use_tfd_independent=True, - w_init=snt.initializers.Orthogonal()) - ]) - - # The multiplexer concatenates the (maybe transformed) observations/actions. 
- critic_network = networks.CriticMultiplexer( - observation_network=snt.Sequential([ - snt.Linear(critic_layer_sizes[0], - w_init=snt.initializers.Orthogonal()), - snt.LayerNorm( - axis=slice(1, None), create_scale=True, create_offset=True), - tf.nn.tanh]), - critic_network=snt.nets.MLP( - critic_layer_sizes[1:], - w_init=snt.initializers.Orthogonal(), - activation=tf.nn.relu, - activate_final=True), - action_network=networks.ClipToSpec(action_spec)) - critic_network = snt.Sequential( - [critic_network, - networks.DiscreteValuedHead(vmin, vmax, num_atoms, - w_init=snt.initializers.Orthogonal()) - ]) - observation_network = networks.DrQTorso() - - return { - 'policy': policy_network, - 'critic': critic_network, - 'observation': observation_network, - } + """Creates networks used by the agent.""" + + num_dimensions = np.prod(action_spec.shape, dtype=int) + + policy_network = snt.Sequential( + [ + networks.LayerNormMLP( + policy_layer_sizes, + w_init=snt.initializers.Orthogonal(), + activation=tf.nn.relu, + activate_final=True, + ), + networks.MultivariateNormalDiagHead( + num_dimensions, + tanh_mean=False, + init_scale=1.0, + fixed_scale=False, + use_tfd_independent=True, + w_init=snt.initializers.Orthogonal(), + ), + ] + ) + + # The multiplexer concatenates the (maybe transformed) observations/actions. + critic_network = networks.CriticMultiplexer( + observation_network=snt.Sequential( + [ + snt.Linear(critic_layer_sizes[0], w_init=snt.initializers.Orthogonal()), + snt.LayerNorm( + axis=slice(1, None), create_scale=True, create_offset=True + ), + tf.nn.tanh, + ] + ), + critic_network=snt.nets.MLP( + critic_layer_sizes[1:], + w_init=snt.initializers.Orthogonal(), + activation=tf.nn.relu, + activate_final=True, + ), + action_network=networks.ClipToSpec(action_spec), + ) + critic_network = snt.Sequential( + [ + critic_network, + networks.DiscreteValuedHead( + vmin, vmax, num_atoms, w_init=snt.initializers.Orthogonal() + ), + ] + ) + observation_network = networks.DrQTorso() + + return { + "policy": policy_network, + "critic": critic_network, + "observation": observation_network, + } def main(_): - # Configure the environment factory with requested task. - make_environment = functools.partial( - helpers.make_environment, - domain_name=_DOMAIN.value, - task_name=_TASK.value, - from_pixels=True, - frames_to_stack=3, - flatten_stack=True, - num_action_repeats=2) - - # Construct the program. - program_builder = dmpo.DistributedDistributionalMPO( - make_environment, - make_networks, - target_policy_update_period=100, - max_actor_steps=_MAX_ACTOR_STEPS.value, - num_actors=4, - samples_per_insert=256, - n_step=3, # Reduce the n-step to account for action-repeat. - observation_augmentation=image_augmentation.pad_and_crop, - ) - - # Launch experiment. - lp.launch(programs=program_builder.build()) - - -if __name__ == '__main__': - app.run(main) + # Configure the environment factory with requested task. + make_environment = functools.partial( + helpers.make_environment, + domain_name=_DOMAIN.value, + task_name=_TASK.value, + from_pixels=True, + frames_to_stack=3, + flatten_stack=True, + num_action_repeats=2, + ) + + # Construct the program. + program_builder = dmpo.DistributedDistributionalMPO( + make_environment, + make_networks, + target_policy_update_period=100, + max_actor_steps=_MAX_ACTOR_STEPS.value, + num_actors=4, + samples_per_insert=256, + n_step=3, # Reduce the n-step to account for action-repeat. + observation_augmentation=image_augmentation.pad_and_crop, + ) + + # Launch experiment. 
+ lp.launch(programs=program_builder.build()) + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/tf/control_suite/lp_mpo.py b/examples/tf/control_suite/lp_mpo.py index 63a05d5eaa..92231938a7 100644 --- a/examples/tf/control_suite/lp_mpo.py +++ b/examples/tf/control_suite/lp_mpo.py @@ -17,27 +17,26 @@ import functools from typing import Dict, Sequence -from absl import app -from absl import flags -from acme import specs -from acme import types -from acme.agents.tf import mpo import helpers -from acme.tf import networks -from acme.tf import utils as tf2_utils import launchpad as lp import numpy as np import sonnet as snt +from absl import app, flags +from acme import specs, types +from acme.agents.tf import mpo +from acme.tf import networks +from acme.tf import utils as tf2_utils # Flags which modify the behavior of the launcher. FLAGS = flags.FLAGS _MAX_ACTOR_STEPS = flags.DEFINE_integer( - 'max_actor_steps', None, - 'Number of actor steps to run; defaults to None for an endless loop.') -_DOMAIN = flags.DEFINE_string('domain', 'cartpole', - 'Control suite domain name.') -_TASK = flags.DEFINE_string('task', 'balance', 'Control suite task name.') + "max_actor_steps", + None, + "Number of actor steps to run; defaults to None for an endless loop.", +) +_DOMAIN = flags.DEFINE_string("domain", "cartpole", "Control suite domain name.") +_TASK = flags.DEFINE_string("task", "balance", "Control suite task name.") def make_networks( @@ -45,49 +44,55 @@ def make_networks( policy_layer_sizes: Sequence[int] = (256, 256, 256), critic_layer_sizes: Sequence[int] = (512, 512, 256), ) -> Dict[str, types.TensorTransformation]: - """Creates networks used by the agent.""" - - num_dimensions = np.prod(action_spec.shape, dtype=int) - - policy_network = snt.Sequential([ - networks.LayerNormMLP(policy_layer_sizes, activate_final=True), - networks.MultivariateNormalDiagHead( - num_dimensions, init_scale=0.7, use_tfd_independent=True) - ]) - - # The multiplexer concatenates the (maybe transformed) observations/actions. - multiplexer = networks.CriticMultiplexer( - action_network=networks.ClipToSpec(action_spec)) - critic_network = snt.Sequential([ - multiplexer, - networks.LayerNormMLP(critic_layer_sizes, activate_final=True), - networks.NearZeroInitializedLinear(1), - ]) - - return { - 'policy': policy_network, - 'critic': critic_network, - 'observation': tf2_utils.batch_concat, - } + """Creates networks used by the agent.""" + + num_dimensions = np.prod(action_spec.shape, dtype=int) + + policy_network = snt.Sequential( + [ + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, init_scale=0.7, use_tfd_independent=True + ), + ] + ) + + # The multiplexer concatenates the (maybe transformed) observations/actions. + multiplexer = networks.CriticMultiplexer( + action_network=networks.ClipToSpec(action_spec) + ) + critic_network = snt.Sequential( + [ + multiplexer, + networks.LayerNormMLP(critic_layer_sizes, activate_final=True), + networks.NearZeroInitializedLinear(1), + ] + ) + + return { + "policy": policy_network, + "critic": critic_network, + "observation": tf2_utils.batch_concat, + } def main(_): - # Configure the environment factory with requested task. - make_environment = functools.partial( - helpers.make_environment, - domain_name=_DOMAIN.value, - task_name=_TASK.value) + # Configure the environment factory with requested task. 
+ make_environment = functools.partial( + helpers.make_environment, domain_name=_DOMAIN.value, task_name=_TASK.value + ) - # Construct the program. - program_builder = mpo.DistributedMPO( - make_environment, - make_networks, - target_policy_update_period=25, - max_actor_steps=_MAX_ACTOR_STEPS.value, - num_actors=4) + # Construct the program. + program_builder = mpo.DistributedMPO( + make_environment, + make_networks, + target_policy_update_period=25, + max_actor_steps=_MAX_ACTOR_STEPS.value, + num_actors=4, + ) - lp.launch(programs=program_builder.build()) + lp.launch(programs=program_builder.build()) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/setup.py b/setup.py index f14377ab50..1379e231e6 100755 --- a/setup.py +++ b/setup.py @@ -15,16 +15,15 @@ """Install script for setuptools.""" import datetime -from importlib import util as import_util import os import sys +from importlib import util as import_util -from setuptools import find_packages -from setuptools import setup import setuptools.command.build_py import setuptools.command.develop +from setuptools import find_packages, setup -spec = import_util.spec_from_file_location('_metadata', 'acme/_metadata.py') +spec = import_util.spec_from_file_location("_metadata", "acme/_metadata.py") _metadata = import_util.module_from_spec(spec) spec.loader.exec_module(_metadata) @@ -37,54 +36,51 @@ # sure this constraint is upheld. tensorflow = [ - 'tensorflow==2.8.0', - 'tensorflow_probability==0.15.0', - 'tensorflow_datasets==4.6.0', - 'dm-reverb==0.7.2', - 'dm-launchpad==0.5.2', + "tensorflow==2.8.0", + "tensorflow_probability==0.15.0", + "tensorflow_datasets==4.6.0", + "dm-reverb==0.7.2", + "dm-launchpad==0.5.2", ] core_requirements = [ - 'absl-py', - 'dm-env', - 'dm-tree', - 'numpy', - 'pillow', - 'typing-extensions', + "absl-py", + "dm-env", + "dm-tree", + "numpy", + "pillow", + "typing-extensions", ] jax_requirements = [ - 'jax>=0.4.3', - 'chex', - 'dm-haiku', - 'flax', - 'optax', - 'rlax', + "jax>=0.4.3", + "chex", + "dm-haiku", + "flax", + "optax", + "rlax", ] + tensorflow -tf_requirements = [ - 'dm-sonnet', - 'trfl', -] + tensorflow +tf_requirements = ["dm-sonnet", "trfl",] + tensorflow testing_requirements = [ - 'pytype==2021.8.11', # TODO(b/206926677): update to new version. - 'pytest-xdist', + "pytype==2021.8.11", # TODO(b/206926677): update to new version. + "pytest-xdist", ] envs_requirements = [ - 'atari-py', - 'bsuite', - 'dm-control', - 'gym==0.25.0', - 'gym[atari]', - 'pygame==2.1.0', - 'rlds', + "atari-py", + "bsuite", + "dm-control", + "gym==0.25.0", + "gym[atari]", + "pygame==2.1.0", + "rlds", ] def generate_requirements_file(path=None): - """Generates requirements.txt file with the Acme's dependencies. + """Generates requirements.txt file with the Acme's dependencies. It is used by Launchpad GCP runtime to generate Acme requirements to be installed inside the docker image. Acme itself is not installed from pypi, @@ -94,12 +90,13 @@ def generate_requirements_file(path=None): Args: path: path to the requirements.txt file to generate. 
""" - if not path: - path = os.path.join(os.path.dirname(__file__), 'acme/requirements.txt') - with open(path, 'w') as f: - for package in set(core_requirements + jax_requirements + tf_requirements + - envs_requirements): - f.write(f'{package}\n') + if not path: + path = os.path.join(os.path.dirname(__file__), "acme/requirements.txt") + with open(path, "w") as f: + for package in set( + core_requirements + jax_requirements + tf_requirements + envs_requirements + ): + f.write(f"{package}\n") long_description = """Acme is a library of reinforcement learning (RL) agents @@ -115,58 +112,57 @@ def generate_requirements_file(path=None): version = _metadata.__version__ # If we're releasing a nightly/dev version append to the version string. -if '--nightly' in sys.argv: - sys.argv.remove('--nightly') - version += '.dev' + datetime.datetime.now().strftime('%Y%m%d') +if "--nightly" in sys.argv: + sys.argv.remove("--nightly") + version += ".dev" + datetime.datetime.now().strftime("%Y%m%d") class BuildPy(setuptools.command.build_py.build_py): - - def run(self): - generate_requirements_file() - setuptools.command.build_py.build_py.run(self) + def run(self): + generate_requirements_file() + setuptools.command.build_py.build_py.run(self) class Develop(setuptools.command.develop.develop): + def run(self): + generate_requirements_file() + setuptools.command.develop.develop.run(self) - def run(self): - generate_requirements_file() - setuptools.command.develop.develop.run(self) cmdclass = { - 'build_py': BuildPy, - 'develop': Develop, + "build_py": BuildPy, + "develop": Develop, } setup( - name='dm-acme', + name="dm-acme", version=version, cmdclass=cmdclass, - description='A Python library for Reinforcement Learning.', + description="A Python library for Reinforcement Learning.", long_description=long_description, - long_description_content_type='text/markdown', - author='DeepMind', - license='Apache License, Version 2.0', - keywords='reinforcement-learning python machine learning', + long_description_content_type="text/markdown", + author="DeepMind", + license="Apache License, Version 2.0", + keywords="reinforcement-learning python machine learning", packages=find_packages(), - package_data={'': ['requirements.txt']}, + package_data={"": ["requirements.txt"]}, include_package_data=True, install_requires=core_requirements, extras_require={ - 'jax': jax_requirements, - 'tf': tf_requirements, - 'testing': testing_requirements, - 'envs': envs_requirements, + "jax": jax_requirements, + "tf": tf_requirements, + "testing": testing_requirements, + "envs": envs_requirements, }, classifiers=[ - 'Development Status :: 3 - Alpha', - 'Environment :: Console', - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: Apache Software License', - 'Operating System :: POSIX :: Linux', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', + "Development Status :: 3 - Alpha", + "Environment :: Console", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: POSIX :: Linux", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Topic :: Scientific/Engineering :: Artificial Intelligence", ], )