@@ -901,3 +901,65 @@ def decode_actions(self, hidden):
901901
902902 values = self .value (hidden )
903903 return logits , values
904+
905+
class G2048(nn.Module):
    """MLP encoder with separate action-logit and value heads.

    The network flattens each observation, rescales its leading
    ``TILE_FEATURES`` entries, passes the result through a three-layer GELU
    encoder, and feeds the encoded vector to two independent two-layer heads:
    ``decoder`` (one logit per action) and ``value`` (a single scalar).

    NOTE(review): the class name suggests this targets a 2048-style game
    environment; nothing else in this block confirms that.

    Args:
        env: Object exposing ``single_observation_space.shape`` and
            ``single_action_space.n``.
        hidden_size: Width of the encoder output and of both heads' hidden
            layer. Values up to 256 use fixed 512 -> 256 intermediate encoder
            widths; larger values use ``2 * hidden_size -> hidden_size``.
    """

    # How many leading flattened features are rescaled, and by what divisor.
    # NOTE(review): the original inline comment described these features as
    # "feat 1 (tile**1.5)"; confirm against the environment's observation
    # layout before changing either constant.
    TILE_FEATURES = 16
    TILE_SCALE = 100.0

    def __init__(self, env, hidden_size=128):
        super().__init__()
        self.hidden_size = hidden_size
        self.is_continuous = False

        # Flattened observation length, as a plain Python int for layer sizes.
        num_obs = int(np.prod(env.single_observation_space.shape))

        # Only the two intermediate widths differ between small and large
        # models; the layer sequence (and therefore the state_dict keys) is
        # the same in both cases.
        if hidden_size <= 256:
            widths = (512, 256)
        else:
            widths = (2 * hidden_size, hidden_size)

        self.encoder = torch.nn.Sequential(
            pufferlib.pytorch.layer_init(nn.Linear(num_obs, widths[0])),
            nn.GELU(),
            pufferlib.pytorch.layer_init(nn.Linear(widths[0], widths[1])),
            nn.GELU(),
            pufferlib.pytorch.layer_init(nn.Linear(widths[1], hidden_size)),
            nn.GELU(),
        )

        num_atns = env.single_action_space.n
        self.decoder = torch.nn.Sequential(
            pufferlib.pytorch.layer_init(nn.Linear(hidden_size, hidden_size)),
            nn.GELU(),
            pufferlib.pytorch.layer_init(nn.Linear(hidden_size, num_atns), std=0.01),
        )
        self.value = torch.nn.Sequential(
            pufferlib.pytorch.layer_init(nn.Linear(hidden_size, hidden_size)),
            nn.GELU(),
            pufferlib.pytorch.layer_init(nn.Linear(hidden_size, 1), std=1.0),
        )

    def forward_eval(self, observations, state=None):
        """Encode ``observations`` and return ``(logits, values)``.

        ``state`` is forwarded to ``encode_observations``, which does not
        read it.
        """
        hidden = self.encode_observations(observations, state=state)
        logits, values = self.decode_actions(hidden)
        return logits, values

    def forward(self, observations, state=None):
        """Alias of ``forward_eval``; returns ``(logits, values)``."""
        return self.forward_eval(observations, state)

    def encode_observations(self, observations, state=None):
        """Flatten, rescale and encode a batch of observations.

        Args:
            observations: Tensor whose first dimension is the batch; all
                remaining dimensions are flattened together.
            state: Unused; accepted so the signature matches ``forward_eval``.

        Returns:
            The output of ``self.encoder`` applied to the float, flattened
            observations with the first ``TILE_FEATURES`` columns divided by
            ``TILE_SCALE``.
        """
        batch_size = observations.shape[0]
        # reshape (not view) so non-contiguous inputs are accepted too.
        flat = observations.reshape(batch_size, -1).float()

        # Build the scaled tensor out-of-place. For an input that is already
        # contiguous float32, `flat` shares storage with the caller's tensor,
        # so an in-place slice assignment would rescale the caller's buffer
        # (and rescale it again on every later call with the same buffer).
        num_scaled = self.TILE_FEATURES
        scaled = torch.cat(
            (flat[:, :num_scaled] / self.TILE_SCALE, flat[:, num_scaled:]),
            dim=1,
        )

        return self.encoder(scaled)

    def decode_actions(self, hidden):
        """Apply both heads to ``hidden`` and return ``(logits, values)``."""
        logits = self.decoder(hidden)
        values = self.value(hidden)
        return logits, values
0 commit comments