Skip to content

Commit c21d78b

Browse files
committed
add trained weights
1 parent 6ae10b3 commit c21d78b

File tree

3 files changed

+21
-43
lines changed

3 files changed

+21
-43
lines changed

pufferlib/ocean/g2048/g2048.c

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,25 @@
11
#include "g2048.h"
22
#include "puffernet.h"
33

4+
// Network with hidden size 256. Should go to puffernet
5+
LinearLSTM* make_linearlstm_256(Weights* weights, int num_agents, int input_dim, int logit_sizes[], int num_actions) {
6+
LinearLSTM* net = calloc(1, sizeof(LinearLSTM));
7+
net->num_agents = num_agents;
8+
net->obs = calloc(num_agents*input_dim, sizeof(float));
9+
int hidden_dim = 256;
10+
net->encoder = make_linear(weights, num_agents, input_dim, hidden_dim);
11+
net->gelu1 = make_gelu(num_agents, hidden_dim);
12+
int atn_sum = 0;
13+
for (int i = 0; i < num_actions; i++) {
14+
atn_sum += logit_sizes[i];
15+
}
16+
net->actor = make_linear(weights, num_agents, hidden_dim, atn_sum);
17+
net->value_fn = make_linear(weights, num_agents, hidden_dim, 1);
18+
net->lstm = make_lstm(weights, num_agents, hidden_dim, hidden_dim);
19+
net->multidiscrete = make_multidiscrete(num_agents, logit_sizes, num_actions);
20+
return net;
21+
}
22+
423
int main() {
524
srand(time(NULL));
625
Game env;
@@ -14,9 +33,9 @@ int main() {
1433
env.actions = actions;
1534
env.rewards = rewards;
1635

17-
Weights* weights = load_weights("resources/g2048/g2048_weights.bin", 134917);
36+
Weights* weights = load_weights("resources/g2048/g2048_weights.bin", 531973);
1837
int logit_sizes[1] = {4};
19-
LinearLSTM* net = make_linearlstm(weights, 1, 16, logit_sizes, 1);
38+
LinearLSTM* net = make_linearlstm_256(weights, 1, 16, logit_sizes, 1);
2039
c_reset(&env);
2140
c_render(&env);
2241

pufferlib/ocean/torch.py

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -190,47 +190,6 @@ def decode_actions(self, hidden):
190190
return action, value
191191

192192

193-
class G2048(nn.Module):
    """Convolutional policy/value network for the 4x4 g2048 board.

    Observations are reshaped to a single-channel 4x4 image, passed through
    two 2x2 convolutions and a linear layer, and the resulting hidden vector
    feeds an action-logit head and a scalar value head.

    Args:
        env: environment; ``env.single_action_space.n`` sets the number of
            action logits.
        cnn_channels: number of channels produced by each convolution.
        hidden_size: width of the hidden vector produced by the encoder.
    """

    def __init__(self, env, cnn_channels=32, hidden_size=128):
        super().__init__()
        self.hidden_size = hidden_size
        self.is_continuous = False

        # Each 2x2, stride-1, unpadded convolution shrinks a side by one:
        # 4x4 -> 3x3 -> 2x2. The flattened width therefore depends on
        # cnn_channels (it is 128 only for the default of 32 channels).
        flat_features = cnn_channels * 2 * 2

        self.cnn = nn.Sequential(
            pufferlib.pytorch.layer_init(
                nn.Conv2d(1, cnn_channels, 2, stride=1)),
            nn.GELU(),
            pufferlib.pytorch.layer_init(
                nn.Conv2d(cnn_channels, cnn_channels, 2, stride=1)),
            nn.Flatten(),
            nn.GELU(),
            pufferlib.pytorch.layer_init(
                nn.Linear(flat_features, hidden_size), std=0.01),
        )

        self.decoder = pufferlib.pytorch.layer_init(
            nn.Linear(hidden_size, env.single_action_space.n), std=0.01)
        self.value = pufferlib.pytorch.layer_init(
            nn.Linear(hidden_size, 1), std=1)

    def forward_eval(self, observations, state=None):
        """Return ``(action_logits, value)`` for a batch of observations.

        ``state`` is accepted but not used.
        """
        hidden = self.encode_observations(observations)
        actions, value = self.decode_actions(hidden)
        return actions, value

    def forward(self, x, state=None):
        """Same as :meth:`forward_eval`."""
        return self.forward_eval(x, state)

    def encode_observations(self, observations, state=None):
        """Reshape observations to ``(-1, 1, 4, 4)`` floats and run the CNN.

        ``state`` is accepted but not used.
        """
        #observations = F.one_hot(observations.long(), 16).view(-1, 16, 4, 4).float()
        observations = observations.float().view(-1, 1, 4, 4)
        return self.cnn(observations)

    def decode_actions(self, hidden):
        """Map the hidden vector to ``(action_logits, value)``."""
        action = self.decoder(hidden)
        value = self.value(hidden)
        return action, value
233-
234193
class Snake(nn.Module):
235194
def __init__(self, env, cnn_channels=32, hidden_size=128):
236195
super().__init__()
1.51 MB
Binary file not shown.

0 commit comments

Comments
 (0)