diff --git a/orm/fields.py b/orm/fields.py index 9015a76..6488f8b 100644 --- a/orm/fields.py +++ b/orm/fields.py @@ -16,13 +16,22 @@ def __init__( **kwargs: typing.Any, ) -> None: if primary_key: - kwargs["read_only"] = True + default_value = kwargs.get("default", None) + self.raise_if_pk_without_default(default_value) + kwargs['allow_null'] = True + self.allow_null = kwargs.get("allow_null", False) self.primary_key = primary_key self.index = index self.unique = unique self.validator = self.get_validator(**kwargs) + def raise_if_pk_without_default(self, default: typing.Any): + if not default: + raise ValueError( + f"You need to specify default value for {self.__class__.__name__} primary key field" + ) + def get_column(self, name: str) -> sqlalchemy.Column: column_type = self.get_column_type() constraints = self.get_constraints() @@ -70,6 +79,9 @@ def get_column_type(self): class Integer(ModelField): + def raise_if_pk_without_default(self, default: typing.Any): + pass + def get_validator(self, **kwargs) -> typesystem.Field: return typesystem.Integer(**kwargs) diff --git a/orm/models.py b/orm/models.py index b402814..dd5c135 100644 --- a/orm/models.py +++ b/orm/models.py @@ -410,14 +410,14 @@ def _validate_kwargs(self, **kwargs): for key, value in fields.items(): if value.validator.read_only and value.validator.has_default(): kwargs[key] = value.validator.get_default_value() - return kwargs + + return {key: value for key, value in kwargs.items() if value is not None} async def create(self, **kwargs): kwargs = self._validate_kwargs(**kwargs) instance = self.model_cls(**kwargs) expr = self.table.insert().values(**kwargs) - - if self.pkname not in kwargs: + if not self.pkname in kwargs: instance.pk = await self.database.execute(expr) else: await self.database.execute(expr) diff --git a/tests/test_columns.py b/tests/test_columns.py index 278aecd..54bb9b0 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -131,6 +131,12 @@ async def test_model_crud(): assert product.updated_date == last_updated_date +async def test_create_user_with_custom_uuid(): + custom_uuid = uuid.uuid4() + user = await User.objects.create(id=custom_uuid) + assert user.pk == custom_uuid + + async def test_both_auto_now_and_auto_now_add_raise_error(): with pytest.raises(ValueError): @@ -159,3 +165,13 @@ async def test_bulk_create(): assert products[1].data == {"foo": 456} assert products[1].value == 456.789 assert products[1].status == StatusEnum.DRAFT + + +async def test_create_with_pk_not_integer_and_without_default_value(): + with pytest.raises(ValueError): + + class Post(orm.Model): + registry = models + fields = { + "id": orm.UUID(primary_key=True) + }