Skip to content

Commit e4287d5

Browse files
committed
feat: Add QuerySetT type parameter to InheritanceManagerMixin
Add support for custom QuerySet types in InheritanceManagerMixin by introducing a second type parameter QuerySetT with a default value. This enables subclasses to specify their own QuerySet type while maintaining full backwards compatibility with existing code. Changes: - Add QuerySetT TypeVar with default=InheritanceQuerySet[ModelT] - Update InheritanceManagerMixin to be Generic[ModelT, QuerySetT] - Update all return types in InheritanceManagerMixin to use QuerySetT - Add typing_extensions dependency for Python < 3.13 - Add docstring with usage examples Backwards compatibility: - Existing InheritanceManager[MyModel] code works unchanged - QuerySetT defaults to InheritanceQuerySet[ModelT] when not specified Usage example: class MyQuerySet(InheritanceQuerySetMixin[MyModel], QuerySet[MyModel]): def custom_filter(self) -> Self: return self.filter(active=True) class MyManager(InheritanceManager[MyModel, MyQuerySet]): _queryset_class = MyQuerySet
1 parent 62a49f1 commit e4287d5

File tree

2 files changed

+72
-30
lines changed

2 files changed

+72
-30
lines changed

model_utils/managers.py

Lines changed: 68 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

3+
import sys
34
import warnings
4-
from typing import TYPE_CHECKING, Any, Generic, Sequence, TypeVar, cast, overload
5+
from typing import TYPE_CHECKING, Any, Generic, Sequence, cast, overload
56

67
from django.core.exceptions import ObjectDoesNotExist
78
from django.db import connection, models
@@ -10,8 +11,21 @@
1011
from django.db.models.query import ModelIterable, QuerySet
1112
from django.db.models.sql.datastructures import Join
1213

14+
if sys.version_info >= (3, 13):
15+
from typing import TypeVar
16+
else:
17+
from typing_extensions import TypeVar
18+
1319
ModelT = TypeVar('ModelT', bound=models.Model, covariant=True)
1420

21+
# TypeVar for QuerySet with default (enables backwards compatibility)
22+
# When only ModelT is specified, QuerySetT defaults to InheritanceQuerySet[ModelT]
23+
QuerySetT = TypeVar(
24+
'QuerySetT',
25+
bound='InheritanceQuerySet[Any]',
26+
default='InheritanceQuerySet[ModelT]',
27+
)
28+
1529
if TYPE_CHECKING:
1630
from collections.abc import Iterator
1731

@@ -226,57 +240,82 @@ def instance_of(self, *models: type[ModelT]) -> InheritanceQuerySet[ModelT]:
226240
)
227241

228242

229-
class InheritanceManagerMixin(Generic[ModelT]):
230-
_queryset_class = InheritanceQuerySet
243+
class InheritanceManagerMixin(Generic[ModelT, QuerySetT]):
244+
"""
245+
Mixin for Manager classes that provides inheritance-aware queryset methods.
246+
247+
This mixin supports an optional second type parameter ``QuerySetT`` which
248+
defaults to ``InheritanceQuerySet[ModelT]``. This allows subclasses to
249+
specify a custom QuerySet type while maintaining backwards compatibility
250+
with existing code that only specifies the model type.
251+
252+
Example usage with default QuerySet::
253+
254+
class MyManager(InheritanceManager[MyModel]):
255+
pass # get_queryset() returns InheritanceQuerySet[MyModel]
256+
257+
Example usage with custom QuerySet::
258+
259+
class MyQuerySet(InheritanceQuerySetMixin[MyModel], QuerySet[MyModel]):
260+
def custom_filter(self) -> Self:
261+
return self.filter(active=True)
262+
263+
class MyManager(InheritanceManager[MyModel, MyQuerySet]):
264+
_queryset_class = MyQuerySet
265+
266+
def get_queryset(self) -> MyQuerySet:
267+
return MyQuerySet(self.model, using=self._db)
268+
"""
269+
_queryset_class: type[QuerySetT] = InheritanceQuerySet # type: ignore[assignment]
231270

232271
if TYPE_CHECKING:
233272
from collections.abc import Sequence
234273

235-
def none(self) -> InheritanceQuerySet[ModelT]:
274+
def none(self) -> QuerySetT:
236275
...
237276

238-
def all(self) -> InheritanceQuerySet[ModelT]:
277+
def all(self) -> QuerySetT:
239278
...
240279

241-
def filter(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]:
280+
def filter(self, *args: Any, **kwargs: Any) -> QuerySetT:
242281
...
243282

244-
def exclude(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]:
283+
def exclude(self, *args: Any, **kwargs: Any) -> QuerySetT:
245284
...
246285

247-
def complex_filter(self, filter_obj: Any) -> InheritanceQuerySet[ModelT]:
286+
def complex_filter(self, filter_obj: Any) -> QuerySetT:
248287
...
249288

250-
def union(self, *other_qs: Any, all: bool = ...) -> InheritanceQuerySet[ModelT]:
289+
def union(self, *other_qs: Any, all: bool = ...) -> QuerySetT:
251290
...
252291

253-
def intersection(self, *other_qs: Any) -> InheritanceQuerySet[ModelT]:
292+
def intersection(self, *other_qs: Any) -> QuerySetT:
254293
...
255294

256-
def difference(self, *other_qs: Any) -> InheritanceQuerySet[ModelT]:
295+
def difference(self, *other_qs: Any) -> QuerySetT:
257296
...
258297

259298
def select_for_update(
260299
self, nowait: bool = ..., skip_locked: bool = ..., of: Sequence[str] = ..., no_key: bool = ...
261-
) -> InheritanceQuerySet[ModelT]:
300+
) -> QuerySetT:
262301
...
263302

264-
def select_related(self, *fields: Any) -> InheritanceQuerySet[ModelT]:
303+
def select_related(self, *fields: Any) -> QuerySetT:
265304
...
266305

267-
def prefetch_related(self, *lookups: Any) -> InheritanceQuerySet[ModelT]:
306+
def prefetch_related(self, *lookups: Any) -> QuerySetT:
268307
...
269308

270-
def annotate(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]:
309+
def annotate(self, *args: Any, **kwargs: Any) -> QuerySetT:
271310
...
272311

273-
def alias(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]:
312+
def alias(self, *args: Any, **kwargs: Any) -> QuerySetT:
274313
...
275314

276-
def order_by(self, *field_names: Any) -> InheritanceQuerySet[ModelT]:
315+
def order_by(self, *field_names: Any) -> QuerySetT:
277316
...
278317

279-
def distinct(self, *field_names: Any) -> InheritanceQuerySet[ModelT]:
318+
def distinct(self, *field_names: Any) -> QuerySetT:
280319
...
281320

282321
def extra(
@@ -287,38 +326,38 @@ def extra(
287326
tables: list[str] | None = ...,
288327
order_by: Sequence[str] | None = ...,
289328
select_params: Sequence[Any] | None = ...,
290-
) -> InheritanceQuerySet[Any]:
329+
) -> QuerySetT:
291330
...
292331

293-
def reverse(self) -> InheritanceQuerySet[ModelT]:
332+
def reverse(self) -> QuerySetT:
294333
...
295334

296-
def defer(self, *fields: Any) -> InheritanceQuerySet[ModelT]:
335+
def defer(self, *fields: Any) -> QuerySetT:
297336
...
298337

299-
def only(self, *fields: Any) -> InheritanceQuerySet[ModelT]:
338+
def only(self, *fields: Any) -> QuerySetT:
300339
...
301340

302-
def using(self, alias: str | None) -> InheritanceQuerySet[ModelT]:
341+
def using(self, alias: str | None) -> QuerySetT:
303342
...
304343

305-
def get_queryset(self) -> InheritanceQuerySet[ModelT]:
344+
def get_queryset(self) -> QuerySetT:
306345
model: type[ModelT] = self.model # type: ignore[attr-defined]
307346
return self._queryset_class(model)
308347

309348
def select_subclasses(
310349
self, *subclasses: str | type[models.Model]
311-
) -> InheritanceQuerySet[ModelT]:
312-
return self.get_queryset().select_subclasses(*subclasses)
350+
) -> QuerySetT:
351+
return self.get_queryset().select_subclasses(*subclasses) # type: ignore[return-value]
313352

314353
def get_subclass(self, *args: object, **kwargs: object) -> ModelT:
315354
return self.get_queryset().get_subclass(*args, **kwargs)
316355

317-
def instance_of(self, *models: type[ModelT]) -> InheritanceQuerySet[ModelT]:
318-
return self.get_queryset().instance_of(*models)
356+
def instance_of(self, *models: type[ModelT]) -> QuerySetT:
357+
return self.get_queryset().instance_of(*models) # type: ignore[return-value]
319358

320359

321-
class InheritanceManager(InheritanceManagerMixin[ModelT], models.Manager[ModelT]):
360+
class InheritanceManager(InheritanceManagerMixin[ModelT, QuerySetT], models.Manager[ModelT]):
322361
pass
323362

324363

setup.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@ def long_desc(root_path):
3030
url='https://github.com/jazzband/django-model-utils',
3131
packages=find_packages(exclude=['tests*']),
3232
python_requires=">=3.10",
33-
install_requires=['Django>=4.2'],
33+
install_requires=[
34+
'Django>=4.2',
35+
'typing_extensions>=4.0.0; python_version < "3.13"',
36+
],
3437
classifiers=[
3538
'Development Status :: 5 - Production/Stable',
3639
'Environment :: Web Environment',

0 commit comments

Comments
 (0)