@@ -190,47 +190,6 @@ def decode_actions(self, hidden):
190190 return action , value
191191
192192
193- class G2048 (nn .Module ):
194- def __init__ (self , env , cnn_channels = 32 , hidden_size = 128 ):
195- super ().__init__ ()
196- self .hidden_size = hidden_size
197- self .is_continuous = False
198-
199- self .cnn = nn .Sequential (
200- pufferlib .pytorch .layer_init (
201- nn .Conv2d (1 , cnn_channels , 2 , stride = 1 )),
202- nn .GELU (),
203- pufferlib .pytorch .layer_init (
204- nn .Conv2d (cnn_channels , cnn_channels , 2 , stride = 1 )),
205- nn .Flatten (),
206- nn .GELU (),
207- pufferlib .pytorch .layer_init (
208- nn .Linear (128 , hidden_size ), std = 0.01 ),
209- )
210-
211- self .decoder = pufferlib .pytorch .layer_init (
212- nn .Linear (hidden_size , env .single_action_space .n ), std = 0.01 )
213- self .value = pufferlib .pytorch .layer_init (
214- nn .Linear (hidden_size , 1 ), std = 1 )
215-
216- def forward_eval (self , observations , state = None ):
217- hidden = self .encode_observations (observations )
218- actions , value = self .decode_actions (hidden )
219- return actions , value
220-
221- def forward (self , x , state = None ):
222- return self .forward_eval (x , state )
223-
224- def encode_observations (self , observations , state = None ):
225- #observations = F.one_hot(observations.long(), 16).view(-1, 16, 4, 4).float()
226- observations = observations .float ().view (- 1 , 1 , 4 , 4 )
227- return self .cnn (observations )
228-
229- def decode_actions (self , hidden ):
230- action = self .decoder (hidden )
231- value = self .value (hidden )
232- return action , value
233-
234193class Snake (nn .Module ):
235194 def __init__ (self , env , cnn_channels = 32 , hidden_size = 128 ):
236195 super ().__init__ ()
0 commit comments