Skip to content

Commit 888a16d

Browse files
committed
add model
1 parent 23f25ac commit 888a16d

File tree

1 file changed

+62
-0
lines changed

1 file changed

+62
-0
lines changed

pufferlib/ocean/torch.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -901,3 +901,65 @@ def decode_actions(self, hidden):
901901

902902
values = self.value(hidden)
903903
return logits, values
904+
905+
906+
class G2048(nn.Module):
    """Feed-forward policy/value network for the 2048 environment.

    The network is a three-layer GELU MLP encoder followed by two separate
    two-layer heads: ``decoder`` produces action logits and ``value``
    produces a scalar value estimate.

    Args:
        env: Environment exposing ``single_observation_space.shape`` and a
            discrete ``single_action_space`` with attribute ``n``.
        hidden_size: Width of the encoder output and of both heads. For
            ``hidden_size <= 256`` the encoder widths are 512 -> 256 ->
            ``hidden_size``; otherwise they are ``2 * hidden_size`` ->
            ``hidden_size`` -> ``hidden_size``.
    """

    # The first NUM_SCALED_FEATURES flattened observation entries are divided
    # by FEATURE_SCALE before encoding. The original comment described these
    # entries as "feat 1 (tile**1.5)".
    NUM_SCALED_FEATURES = 16
    FEATURE_SCALE = 100.0

    def __init__(self, env, hidden_size=128):
        super().__init__()
        self.hidden_size = hidden_size
        self.is_continuous = False

        # Plain Python int so layer sizes are not numpy scalars.
        num_obs = int(np.prod(env.single_observation_space.shape))

        # Small models get a fixed wide front end; larger ones scale with
        # hidden_size. Layer order (and thus state_dict keys and init RNG
        # consumption) is the same in both cases.
        if hidden_size <= 256:
            first_width, second_width = 512, 256
        else:
            first_width, second_width = 2 * hidden_size, hidden_size

        self.encoder = torch.nn.Sequential(
            pufferlib.pytorch.layer_init(nn.Linear(num_obs, first_width)),
            nn.GELU(),
            pufferlib.pytorch.layer_init(nn.Linear(first_width, second_width)),
            nn.GELU(),
            pufferlib.pytorch.layer_init(nn.Linear(second_width, hidden_size)),
            nn.GELU(),
        )

        num_atns = env.single_action_space.n
        self.decoder = torch.nn.Sequential(
            pufferlib.pytorch.layer_init(nn.Linear(hidden_size, hidden_size)),
            nn.GELU(),
            pufferlib.pytorch.layer_init(nn.Linear(hidden_size, num_atns), std=0.01),
        )
        self.value = torch.nn.Sequential(
            pufferlib.pytorch.layer_init(nn.Linear(hidden_size, hidden_size)),
            nn.GELU(),
            pufferlib.pytorch.layer_init(nn.Linear(hidden_size, 1), std=1.0),
        )

    def forward_eval(self, observations, state=None):
        """Return ``(logits, values)`` for a batch of observations."""
        hidden = self.encode_observations(observations, state=state)
        logits, values = self.decode_actions(hidden)
        return logits, values

    def forward(self, observations, state=None):
        """Alias of :meth:`forward_eval`."""
        return self.forward_eval(observations, state)

    def encode_observations(self, observations, state=None):
        """Flatten, rescale and encode a batch of observations.

        The input tensor is never modified. Previously the rescaling was done
        with an in-place slice assignment, which wrote through to the caller's
        tensor whenever the input was already float32 (``.float()`` returns
        the same tensor in that case and ``.view`` shares its storage).

        Args:
            observations: Tensor whose first dimension is the batch.
            state: Unused here; accepted for interface compatibility.

        Returns:
            The encoder output for the flattened, rescaled observations.
        """
        batch_size = observations.shape[0]
        # reshape (unlike view) also accepts non-contiguous inputs.
        flat = observations.reshape(batch_size, -1).float()

        # Out-of-place rescale of the leading features; the remaining
        # features pass through unchanged.
        split = self.NUM_SCALED_FEATURES
        flat = torch.cat(
            (flat[:, :split] / self.FEATURE_SCALE, flat[:, split:]),
            dim=1,
        )

        return self.encoder(flat)

    def decode_actions(self, hidden):
        """Map encoder output to ``(logits, values)`` via the two heads."""
        logits = self.decoder(hidden)
        values = self.value(hidden)
        return logits, values

0 commit comments

Comments
 (0)