@@ -829,39 +829,40 @@ class CrazyhousePocket:
829829 """A Crazyhouse pocket with a counter for each piece type."""
830830
831831 def __init__ (self , symbols : Iterable [str ] = "" ) -> None :
832- self .pieces : Dict [ chess . PieceType , int ] = {}
832+ self .reset ()
833833 for symbol in symbols :
834834 self .add (chess .PIECE_SYMBOLS .index (symbol ))
835835
836+ def reset (self ) -> None :
837+ """Clears the pocket."""
838+ self ._pieces = [None , 0 , 0 , 0 , 0 , 0 , 0 ]
839+
836840 def add (self , piece_type : chess .PieceType ) -> None :
837841 """Adds a piece of the given type to this pocket."""
838- self .pieces [piece_type ] = self . pieces . get ( piece_type , 0 ) + 1
842+ self ._pieces [piece_type ] += 1
839843
840844 def remove (self , piece_type : chess .PieceType ) -> None :
841845 """Removes a piece of the given type from this pocket."""
842- self .pieces [piece_type ] -= 1
846+ assert self ._pieces [piece_type ], f"cannot remove { chess .piece_symbol (piece_type )} from { self !r} "
847+ self ._pieces [piece_type ] -= 1
843848
844849 def count (self , piece_type : chess .PieceType ) -> int :
845850 """Returns the number of pieces of the given type in the pocket."""
846- return self .pieces .get (piece_type , 0 )
847-
848- def reset (self ) -> None :
849- """Clears the pocket."""
850- self .pieces .clear ()
851+ return self ._pieces [piece_type ]
851852
852853 def __str__ (self ) -> str :
853854 return "" .join (chess .piece_symbol (pt ) * self .count (pt ) for pt in reversed (chess .PIECE_TYPES ))
854855
855856 def __len__ (self ) -> int :
856- return sum (self .pieces . values () )
857+ return sum (self ._pieces [ 1 :] )
857858
858859 def __repr__ (self ) -> str :
859860 return f"CrazyhousePocket('{ self } ')"
860861
861862 def copy (self : CrazyhousePocketT ) -> CrazyhousePocketT :
862863 """Returns a copy of this pocket."""
863864 pocket = type (self )()
864- pocket .pieces = copy . copy ( self .pieces )
865+ pocket ._pieces = self ._pieces [:]
865866 return pocket
866867
867868class CrazyhouseBoard (chess .Board ):
@@ -959,9 +960,9 @@ def is_legal(self, move: chess.Move) -> bool:
959960 return super ().is_legal (move )
960961
961962 def generate_pseudo_legal_drops (self , to_mask : chess .Bitboard = chess .BB_ALL ) -> Iterator [chess .Move ]:
962- for to_square in chess .scan_forward ( to_mask & ~ self . occupied ) :
963- for pt , count in self .pockets [self .turn ].pieces . items ( ):
964- if count and ( pt != chess .PAWN or not chess .BB_BACKRANKS & chess .BB_SQUARES [ to_square ] ):
963+ for pt in chess .PIECE_TYPES :
964+ if self .pockets [self .turn ].count ( pt ):
965+ for to_square in chess . scan_forward ( to_mask & ~ self . occupied & ( ~ chess .BB_BACKRANKS if pt == chess .PAWN else chess .BB_ALL ) ):
965966 yield chess .Move (to_square , to_square , drop = pt )
966967
967968 def generate_legal_drops (self , to_mask : chess .Bitboard = chess .BB_ALL ) -> Iterator [chess .Move ]:
0 commit comments