Skip to content

Commit 4376eb7

Browse files
committed
Fix recursive Crazyhouse move generation (fixes #893)
1 parent d9cee54 commit 4376eb7

File tree

2 files changed

+21
-13
lines changed

2 files changed

+21
-13
lines changed

chess/variant.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -829,39 +829,40 @@ class CrazyhousePocket:
829829
"""A Crazyhouse pocket with a counter for each piece type."""
830830

831831
def __init__(self, symbols: Iterable[str] = "") -> None:
832-
self.pieces: Dict[chess.PieceType, int] = {}
832+
self.reset()
833833
for symbol in symbols:
834834
self.add(chess.PIECE_SYMBOLS.index(symbol))
835835

836+
def reset(self) -> None:
837+
"""Clears the pocket."""
838+
self._pieces = [None, 0, 0, 0, 0, 0, 0]
839+
836840
def add(self, piece_type: chess.PieceType) -> None:
837841
"""Adds a piece of the given type to this pocket."""
838-
self.pieces[piece_type] = self.pieces.get(piece_type, 0) + 1
842+
self._pieces[piece_type] += 1
839843

840844
def remove(self, piece_type: chess.PieceType) -> None:
841845
"""Removes a piece of the given type from this pocket."""
842-
self.pieces[piece_type] -= 1
846+
assert self._pieces[piece_type], f"cannot remove {chess.piece_symbol(piece_type)} from {self!r}"
847+
self._pieces[piece_type] -= 1
843848

844849
def count(self, piece_type: chess.PieceType) -> int:
845850
"""Returns the number of pieces of the given type in the pocket."""
846-
return self.pieces.get(piece_type, 0)
847-
848-
def reset(self) -> None:
849-
"""Clears the pocket."""
850-
self.pieces.clear()
851+
return self._pieces[piece_type]
851852

852853
def __str__(self) -> str:
853854
return "".join(chess.piece_symbol(pt) * self.count(pt) for pt in reversed(chess.PIECE_TYPES))
854855

855856
def __len__(self) -> int:
856-
return sum(self.pieces.values())
857+
return sum(self._pieces[1:])
857858

858859
def __repr__(self) -> str:
859860
return f"CrazyhousePocket('{self}')"
860861

861862
def copy(self: CrazyhousePocketT) -> CrazyhousePocketT:
862863
"""Returns a copy of this pocket."""
863864
pocket = type(self)()
864-
pocket.pieces = copy.copy(self.pieces)
865+
pocket._pieces = self._pieces[:]
865866
return pocket
866867

867868
class CrazyhouseBoard(chess.Board):
@@ -959,9 +960,9 @@ def is_legal(self, move: chess.Move) -> bool:
959960
return super().is_legal(move)
960961

961962
def generate_pseudo_legal_drops(self, to_mask: chess.Bitboard = chess.BB_ALL) -> Iterator[chess.Move]:
962-
for to_square in chess.scan_forward(to_mask & ~self.occupied):
963-
for pt, count in self.pockets[self.turn].pieces.items():
964-
if count and (pt != chess.PAWN or not chess.BB_BACKRANKS & chess.BB_SQUARES[to_square]):
963+
for pt in chess.PIECE_TYPES:
964+
if self.pockets[self.turn].count(pt):
965+
for to_square in chess.scan_forward(to_mask & ~self.occupied & (~chess.BB_BACKRANKS if pt == chess.PAWN else chess.BB_ALL)):
965966
yield chess.Move(to_square, to_square, drop=pt)
966967

967968
def generate_legal_drops(self, to_mask: chess.Bitboard = chess.BB_ALL) -> Iterator[chess.Move]:

examples/perft/crazyhouse.perft

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,10 @@ perft 1 20
2323
perft 2 360
2424
perft 3 5445
2525
perft 4 132758
26+
27+
id zh-midgame
28+
epd 2rn1b1r/1pp2n1p/B1PBk1pP/5p1P/P5p1/2P2RP1/RP1QNP2/1NBK4[QPp] w - -
29+
perft 1 99
30+
perft 2 3932
31+
perft 3 314782
32+
perft 4 10118606

0 commit comments

Comments
 (0)