Skip to content

Commit e549244

Browse files
committed
simplify group init and sprite operations with recursive _SpriteOrIterable type
1 parent 984c780 commit e549244

File tree

1 file changed

+13
-32
lines changed

1 file changed

+13
-32
lines changed

buildconfig/stubs/pygame/sprite.pyi

Lines changed: 13 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ _TSprite2 = TypeVar("_TSprite2", bound=_SpriteSupportsGroup)
135135
# almost the same as _TSprite but bound to DirtySprite
136136
_TDirtySprite = TypeVar("_TDirtySprite", bound=_DirtySpriteSupportsGroup)
137137

138+
_SpriteOrIterable = Union[_TSprite, Iterable[_SpriteOrIterable]]
139+
138140
# Below code demonstrates the advantages of the _SpriteSupportsGroup protocol
139141

140142
# typechecker should error, regular Sprite does not support Group.draw due to
@@ -158,7 +160,9 @@ _TDirtySprite = TypeVar("_TDirtySprite", bound=_DirtySpriteSupportsGroup)
158160
class AbstractGroup(Generic[_TSprite]):
159161
spritedict: dict[_TSprite, Optional[Union[FRect, Rect]]]
160162
lostsprites: list[Union[FRect, Rect]]
161-
def __class_getitem__(cls, item: type[_SupportsSpriteGroup], /) -> types.GenericAlias: ...
163+
def __class_getitem__(
164+
cls, item: type[_SupportsSpriteGroup], /
165+
) -> types.GenericAlias: ...
162166
def __init__(self) -> None: ...
163167
def __len__(self) -> int: ...
164168
def __iter__(self) -> Iterator[_TSprite]: ...
@@ -169,15 +173,9 @@ class AbstractGroup(Generic[_TSprite]):
169173
def has_internal(self, sprite: _TSprite) -> bool: ...
170174
def copy(self) -> Self: ...
171175
def sprites(self) -> list[_TSprite]: ...
172-
def add(
173-
self, *sprites: Union[_TSprite, AbstractGroup[_TSprite], Iterable[_TSprite]]
174-
) -> None: ...
175-
def remove(
176-
self, *sprites: Union[_TSprite, AbstractGroup[_TSprite], Iterable[_TSprite]]
177-
) -> None: ...
178-
def has(
179-
self, *sprites: Union[_TSprite, AbstractGroup[_TSprite], Iterable[_TSprite]]
180-
) -> bool: ...
176+
def add(self, *sprites: _SpriteOrIterable[_TSprite]) -> None: ...
177+
def remove(self, *sprites: _SpriteOrIterable[_TSprite]) -> None: ...
178+
def has(self, *sprites: _SpriteOrIterable[_TSprite]) -> bool: ...
181179
def update(self, *args: Any, **kwargs: Any) -> None: ...
182180
def draw(
183181
self, surface: Surface, bgd: Optional[Surface] = None, special_flags: int = 0
@@ -190,16 +188,14 @@ class AbstractGroup(Generic[_TSprite]):
190188
def empty(self) -> None: ...
191189

192190
class Group(AbstractGroup[_TSprite]):
193-
def __init__(
194-
self, *sprites: Union[_TSprite, AbstractGroup[_TSprite], Iterable[_TSprite]]
195-
) -> None: ...
191+
def __init__(self, *sprites: _SpriteOrIterable[_TSprite]) -> None: ...
196192

197193
# these are aliased in the code too
198194
@deprecated("Use `pygame.sprite.Group` instead")
199-
class RenderPlain(Group): ...
195+
class RenderPlain(Group[_TSprite]): ...
200196

201197
@deprecated("Use `pygame.sprite.Group` instead")
202-
class RenderClear(Group): ...
198+
class RenderClear(Group[_TSprite]): ...
203199

204200
class RenderUpdates(Group[_TSprite]): ...
205201

@@ -208,23 +204,9 @@ class OrderedUpdates(RenderUpdates[_TSprite]): ...
208204

209205
class LayeredUpdates(AbstractGroup[_TSprite]):
210206
def __init__(
211-
self,
212-
*sprites: Union[
213-
_TSprite,
214-
AbstractGroup[_TSprite],
215-
Iterable[Union[_TSprite, AbstractGroup[_TSprite]]],
216-
],
217-
**kwargs: Any,
218-
) -> None: ...
219-
def add(
220-
self,
221-
*sprites: Union[
222-
_TSprite,
223-
AbstractGroup[_TSprite],
224-
Iterable[Union[_TSprite, AbstractGroup[_TSprite]]],
225-
],
226-
**kwargs: Any,
207+
self, *sprites: _SpriteOrIterable[_TSprite], **kwargs: Any
227208
) -> None: ...
209+
def add(self, *sprites: _SpriteOrIterable[_TSprite], **kwargs: Any) -> None: ...
228210
def get_sprites_at(self, pos: Point) -> list[_TSprite]: ...
229211
def get_sprite(self, idx: int) -> _TSprite: ...
230212
def remove_sprites_of_layer(self, layer_nr: int) -> list[_TSprite]: ...
@@ -240,7 +222,6 @@ class LayeredUpdates(AbstractGroup[_TSprite]):
240222
def switch_layer(self, layer1_nr: int, layer2_nr: int) -> None: ...
241223

242224
class LayeredDirty(LayeredUpdates[_TDirtySprite]):
243-
def __init__(self, *sprites: _TDirtySprite, **kwargs: Any) -> None: ...
244225
def draw(
245226
self,
246227
surface: Surface,

0 commit comments

Comments
 (0)