Skip to content

Commit 08bf97f

Browse files
Merge pull request FLAIROx#138 from FLAIROx/hanabi-observation-bug-fix
Fixed issue FLAIROx#132
2 parents 8dceff9 + 3a35186 commit 08bf97f

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

jaxmarl/environments/hanabi/hanabi.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
import chex
1010
from typing import Tuple, Dict
1111
from functools import partial
12-
from jaxmarl.environments.spaces import Discrete
12+
from jaxmarl.environments.spaces import Discrete, Box
1313
from .hanabi_game import HanabiGame, State
1414

1515

@@ -133,7 +133,7 @@ def __init__(
133133
if action_spaces is None:
134134
self.action_spaces = {i: Discrete(self.num_moves) for i in self.agents}
135135
if observation_spaces is None:
136-
self.observation_spaces = {i: Discrete(self.obs_size) for i in self.agents}
136+
self.observation_spaces = {i: Box(low=0, high=1, shape=self.obs_size) for i in self.agents}
137137

138138
@partial(jax.jit, static_argnums=[0])
139139
def reset(self, key: chex.PRNGKey) -> Tuple[Dict, State]:
@@ -194,7 +194,7 @@ def get_obs(
194194
"""Get all agents' observations."""
195195

196196
# no agent-specific obs
197-
board_fats = self.get_board_feats(new_state)
197+
board_feats = self.get_board_feats(new_state)
198198
discard_feats = self._binarize_discard_pile(new_state.discard_pile)
199199

200200
def _observe(aidx: int):
@@ -225,7 +225,7 @@ def _observe(aidx: int):
225225
return jnp.concatenate(
226226
(
227227
hands_feats,
228-
board_fats,
228+
board_feats,
229229
discard_feats,
230230
last_action_feats,
231231
belief_v0_feats,
@@ -808,4 +808,4 @@ def get_card_knowledge_str(card_idx: int) -> str:
808808
legal_actions = [self.action_encoding[int(a)] for a in np.where(legal_moves)[0]]
809809
output += f"Legal Actions: {legal_actions}\n"
810810

811-
return output
811+
return output

0 commit comments

Comments
 (0)