Skip to content

Commit 0a9d97e

Browse files
leonardhussenotcopybara-github
authored andcommitted
Adding a wrapper that concatenates observation in a single tensor.
It allows to run seamlessly on DMControl and Gym environments. PiperOrigin-RevId: 413688077 Change-Id: Id59f6bf2800088c71438e3a0e0eaa5d9debdf9ed
1 parent 32156b2 commit 0a9d97e

File tree

3 files changed

+99
-2
lines changed

3 files changed

+99
-2
lines changed

acme/wrappers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from acme.wrappers.base import EnvironmentWrapper
2020
from acme.wrappers.base import wrap_all
2121
from acme.wrappers.canonical_spec import CanonicalSpecWrapper
22+
from acme.wrappers.concatenate_observations import ConcatObservationWrapper
2223
from acme.wrappers.frame_stacking import FrameStackingWrapper
2324
from acme.wrappers.gym_wrapper import GymAtariAdapter
2425
from acme.wrappers.gym_wrapper import GymWrapper
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# python3
2+
# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Wrapper that implements concatenation of observation fields."""
17+
18+
from typing import Sequence, Optional
19+
20+
from acme import types
21+
from acme.jax import utils
22+
from acme.wrappers import base
23+
import dm_env
24+
import numpy as np
25+
import tree
26+
27+
28+
def _concat(values: types.NestedArray) -> np.ndarray:
29+
"""Concatenates the leaves of `values` along the leading dimension.
30+
31+
Treats scalars as 1d arrays and expects that the shapes of all leaves are
32+
the same except for the leading dimension.
33+
34+
Args:
35+
values: the nested arrays to concatenate.
36+
Returns:
37+
The concatenated array.
38+
"""
39+
leaves = list(map(np.atleast_1d, tree.flatten(values)))
40+
return np.concatenate(leaves)
41+
42+
43+
class ConcatObservationWrapper(base.EnvironmentWrapper):
44+
"""Wrapper that concatenates observation fields.
45+
46+
It takes an environment with nested observations and concatenates the fields
47+
in a single tensor. The orginial fields should be 1-dimensional.
48+
Observation fields that are not in name_filter are dropped.
49+
"""
50+
51+
def __init__(self, environment: dm_env.Environment,
52+
name_filter: Optional[Sequence[str]] = None):
53+
"""Initializes a new ConcatObservationWrapper.
54+
55+
Args:
56+
environment: Environment to wrap.
57+
name_filter: Sequence of observation names to keep. None keeps them all.
58+
"""
59+
super().__init__(environment)
60+
observation_spec = environment.observation_spec()
61+
if name_filter is None:
62+
name_filter = list(observation_spec.keys())
63+
self._obs_names = [x for x in name_filter if x in observation_spec.keys()]
64+
65+
dummy_obs = utils.zeros_like(observation_spec)
66+
dummy_obs = self._convert_observation(dummy_obs)
67+
self._observation_spec = dm_env.specs.BoundedArray(
68+
shape=dummy_obs.shape,
69+
dtype=dummy_obs.dtype,
70+
minimum=-np.inf,
71+
maximum=np.inf,
72+
name='state')
73+
74+
def _convert_observation(self, observation):
75+
obs = {k: observation[k] for k in self._obs_names}
76+
return _concat(obs)
77+
78+
def step(self, action) -> dm_env.TimeStep:
79+
timestep = self._environment.step(action)
80+
return timestep._replace(
81+
observation=self._convert_observation(timestep.observation))
82+
83+
def reset(self) -> dm_env.TimeStep:
84+
timestep = self._environment.reset()
85+
return timestep._replace(
86+
observation=self._convert_observation(timestep.observation))
87+
88+
def observation_spec(self) -> types.NestedSpec:
89+
return self._observation_spec

examples/control/helpers.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,19 @@
2222

2323
def make_environment(evaluation: bool = False,
2424
domain_name: str = 'cartpole',
25-
task_name: str = 'balance') -> dm_env.Environment:
25+
task_name: str = 'balance',
26+
concatenate_observations: bool = False
27+
) -> dm_env.Environment:
2628
"""Implements a control suite environment factory."""
2729
# Nothing special to be done for evaluation environment.
2830
del evaluation
2931

3032
environment = suite.load(domain_name, task_name)
3133
environment = wrappers.SinglePrecisionWrapper(environment)
32-
34+
timestep = environment.reset()
35+
obs_names = list(timestep.observation.keys())
36+
if concatenate_observations:
37+
environment = wrappers.ConcatObservationWrapper(environment, obs_names)
3338
return environment
39+
40+

0 commit comments

Comments
 (0)