99import chex
1010from typing import Tuple , Dict
1111from functools import partial
12- from jaxmarl .environments .spaces import Discrete
12+ from jaxmarl .environments .spaces import Discrete , Box
1313from .hanabi_game import HanabiGame , State
1414
1515
@@ -133,7 +133,7 @@ def __init__(
133133 if action_spaces is None :
134134 self .action_spaces = {i : Discrete (self .num_moves ) for i in self .agents }
135135 if observation_spaces is None :
136- self .observation_spaces = {i : Discrete ( self .obs_size ) for i in self .agents }
136+ self .observation_spaces = {i : Box ( low = 0 , high = 1 , shape = self .obs_size ) for i in self .agents }
137137
138138 @partial (jax .jit , static_argnums = [0 ])
139139 def reset (self , key : chex .PRNGKey ) -> Tuple [Dict , State ]:
@@ -194,7 +194,7 @@ def get_obs(
194194 """Get all agents' observations."""
195195
196196 # no agent-specific obs
197- board_fats = self .get_board_feats (new_state )
197+ board_feats = self .get_board_feats (new_state )
198198 discard_feats = self ._binarize_discard_pile (new_state .discard_pile )
199199
200200 def _observe (aidx : int ):
@@ -225,7 +225,7 @@ def _observe(aidx: int):
225225 return jnp .concatenate (
226226 (
227227 hands_feats ,
228- board_fats ,
228+ board_feats ,
229229 discard_feats ,
230230 last_action_feats ,
231231 belief_v0_feats ,
@@ -808,4 +808,4 @@ def get_card_knowledge_str(card_idx: int) -> str:
808808 legal_actions = [self .action_encoding [int (a )] for a in np .where (legal_moves )[0 ]]
809809 output += f"Legal Actions: { legal_actions } \n "
810810
811- return output
811+ return output
0 commit comments