diff --git a/docs/declaring_models.md b/docs/declaring_models.md index 1099ef4..e46964d 100644 --- a/docs/declaring_models.md +++ b/docs/declaring_models.md @@ -18,8 +18,13 @@ models = orm.ModelRegistry(database=database) class Note(orm.Model): + class MyQuerySet(QuerySet): + ... + tablename = "notes" registry = models + # or do not define the queryset_class + queryset_class = MyQuerySet fields = { "id": orm.Integer(primary_key=True), "text": orm.String(max_length=100), diff --git a/orm/__init__.py b/orm/__init__.py index ee3ef5c..553c96a 100644 --- a/orm/__init__.py +++ b/orm/__init__.py @@ -20,7 +20,7 @@ Text, Time, ) -from orm.models import Model, ModelRegistry +from orm.models import Model, ModelRegistry, QuerySet __version__ = "0.3.1" __all__ = [ @@ -49,4 +49,5 @@ "UUID", "Model", "ModelRegistry", + "QuerySet" ] diff --git a/orm/models.py b/orm/models.py index b402814..c001585 100644 --- a/orm/models.py +++ b/orm/models.py @@ -83,6 +83,8 @@ def __new__(cls, name, bases, attrs): if "tablename" not in attrs: setattr(model_class, "tablename", name.lower()) + model_class.queryset_class = attrs.get("queryset_class") + for name, field in attrs.get("fields", {}).items(): setattr(field, "registry", attrs.get("registry")) if field.primary_key: @@ -485,7 +487,6 @@ def _prepare_order_by(self, order_by: str): class Model(metaclass=ModelMeta): - objects = QuerySet() def __init__(self, **kwargs): if "pk" in kwargs: @@ -497,6 +498,10 @@ def __init__(self, **kwargs): ) setattr(self, key, value) + @property + def objects(self) -> QuerySet: + return self.queryset_class() if self.queryset_class else QuerySet() + @property def pk(self): return getattr(self, self.pkname) diff --git a/tests/test_models.py b/tests/test_models.py index 8e24437..c6b6bcf 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -32,6 +32,25 @@ class Product(orm.Model): } +class Book(orm.Model): + class MyQuerySet(QuerySet): + + async def get_or_none(self, **kwargs): + try: + return await super().get(**kwargs) + except NoMatch: + # or raise HttpException(404) + return None + + tablename = "products" + registry = models + queryset_class = QuerySet + fields = { + "id": orm.Integer(primary_key=True), + "name": orm.String(max_length=100), + } + + @pytest.fixture(autouse=True, scope="function") async def create_test_database(): await models.create_all() @@ -333,3 +352,13 @@ async def test_model_sqlalchemy_filter_operators(): shirt == await Product.objects.filter(Product.columns.name.contains("Cotton")).get() ) + + +async def test_queryset_class(): + await Book.objects.create(name="book") + + b = await Book.objects.get_or_none(name="book") + assert b + + b = await Book.objects.get_or_none(name="books") + assert b is None