|
| 1 | +import numpy as np |
| 2 | + |
| 3 | +import torch |
| 4 | +import torch.nn as nn |
| 5 | + |
| 6 | +import pufferlib |
1 | 7 | from pufferlib.models import Default as Policy |
2 | 8 | from pufferlib.models import LSTMWrapper as Recurrent |
| 9 | + |
class FakePolicy(nn.Module):
    '''Continuous-action PyTorch policy: an MLP actor (Gaussian mean + learned
    log-std) and a separate MLP critic over flat observations.

    PufferLib is not a framework. It does not enforce a base class.
    You can use any PyTorch policy that returns actions and values.
    We structure our forward methods as encode_observations and decode_actions
    to make it easier to wrap policies with LSTMs. You can do that and use
    our LSTM wrapper or implement your own. To port an existing policy
    for use with our LSTM wrapper, simply put everything from forward() before
    the recurrent cell into encode_observations and put everything after
    into decode_actions.
    '''
    def __init__(self, env, hidden_size=256):
        '''Build actor and critic MLPs sized from the env's spaces.

        Args:
            env: Vectorized env exposing ``single_observation_space`` and
                ``single_action_space`` (continuous, 1-D action shape assumed).
            hidden_size: Width of every hidden layer (default 256).
        '''
        super().__init__()
        self.hidden_size = hidden_size

        # int(...) converts the NumPy scalar from np.prod into a plain int.
        n_obs = int(np.prod(env.single_observation_space.shape))
        n_atn = env.single_action_space.shape[0]

        # Actor head: 3 hidden tanh layers, small init std on the output
        # layer so initial action means stay near zero.
        self.decoder_mean = nn.Sequential(
            pufferlib.pytorch.layer_init(nn.Linear(n_obs, hidden_size)),
            nn.Tanh(),
            pufferlib.pytorch.layer_init(nn.Linear(hidden_size, hidden_size)),
            nn.Tanh(),
            pufferlib.pytorch.layer_init(nn.Linear(hidden_size, hidden_size)),
            nn.Tanh(),
            pufferlib.pytorch.layer_init(nn.Linear(hidden_size, n_atn), std=0.01),
        )
        # State-independent log standard deviation, one per action dim.
        self.decoder_logstd = nn.Parameter(torch.zeros(1, n_atn))

        # Critic head: same topology, scalar value output.
        self.value = nn.Sequential(
            pufferlib.pytorch.layer_init(nn.Linear(n_obs, hidden_size)),
            nn.Tanh(),
            pufferlib.pytorch.layer_init(nn.Linear(hidden_size, hidden_size)),
            nn.Tanh(),
            pufferlib.pytorch.layer_init(nn.Linear(hidden_size, hidden_size)),
            nn.Tanh(),
            pufferlib.pytorch.layer_init(nn.Linear(hidden_size, 1), std=1),
        )

    def forward_eval(self, observations, state=None):
        '''Encode observations, then return (action distribution, values).'''
        hidden = self.encode_observations(observations, state=state)
        logits, values = self.decode_actions(hidden)
        return logits, values

    def forward(self, observations, state=None):
        '''Alias for forward_eval (no separate train-time path).'''
        return self.forward_eval(observations, state)

    def encode_observations(self, observations, state=None):
        '''Encodes a batch of observations into hidden states. Here this is
        an identity pass-through; both heads consume raw flat observations.
        Assumes no time dimension (handled by LSTM wrappers).'''
        return observations

    def decode_actions(self, hidden):
        '''Decodes a batch of hidden states into a continuous (Normal) action
        distribution and value estimates. Assumes no time dimension
        (handled by LSTM wrappers).'''
        mean = self.decoder_mean(hidden)
        logstd = self.decoder_logstd.expand_as(mean)
        std = torch.exp(logstd)
        logits = torch.distributions.Normal(mean, std)
        values = self.value(hidden)
        return logits, values
0 commit comments