diff --git a/src/betterproto2_compiler/plugin/models.py b/src/betterproto2_compiler/plugin/models.py index 8e2166bf..66a47d41 100644 --- a/src/betterproto2_compiler/plugin/models.py +++ b/src/betterproto2_compiler/plugin/models.py @@ -410,10 +410,43 @@ def py_type(self) -> str: def unwrapped_py_type(self) -> str: return self._py_type(wrap=False) + @property + def annotations(self) -> list[str]: + """List of the Pydantic annotation to add to the field.""" + assert self.output_file.settings.pydantic_dataclasses + + annotations = [] + + if self.proto_obj.type in (FieldType.TYPE_INT32, FieldType.TYPE_SFIXED32, FieldType.TYPE_SINT32): + annotations.append("pydantic.Field(ge=-2**31, le=2**31 - 1)") + + elif self.proto_obj.type in (FieldType.TYPE_UINT32, FieldType.TYPE_FIXED32): + annotations.append("pydantic.Field(ge=0, le=2**32 - 1)") + + elif self.proto_obj.type in (FieldType.TYPE_INT64, FieldType.TYPE_SFIXED64, FieldType.TYPE_SINT64): + annotations.append("pydantic.Field(ge=-2**63, le=2**63 - 1)") + + elif self.proto_obj.type in (FieldType.TYPE_UINT64, FieldType.TYPE_FIXED64): + annotations.append("pydantic.Field(ge=0, le=2**64 - 1)") + + elif self.proto_obj.type == FieldType.TYPE_FLOAT: + annotations.append("pydantic.AfterValidator(betterproto2.validators.validate_float32)") + + elif self.proto_obj.type == FieldType.TYPE_STRING: + annotations.append("pydantic.AfterValidator(betterproto2.validators.validate_string)") + + return annotations + @property def annotation(self) -> str: py_type = self.py_type + # Add the pydantic annotation if needed + if self.output_file.settings.pydantic_dataclasses: + annotations = self.annotations + if annotations: + py_type = f"typing.Annotated[{py_type}, {', '.join(annotations)}]" + if self.use_builtins: py_type = f"builtins.{py_type}" if self.repeated: diff --git a/src/betterproto2_compiler/templates/header.py.j2 b/src/betterproto2_compiler/templates/header.py.j2 index 41c0882e..d6bb310f 100644 --- a/src/betterproto2_compiler/templates/header.py.j2 +++ b/src/betterproto2_compiler/templates/header.py.j2 @@ -28,6 +28,7 @@ import typing from typing import TYPE_CHECKING {% if output_file.settings.pydantic_dataclasses %} +import pydantic from pydantic.dataclasses import dataclass from pydantic import model_validator {%- else -%} diff --git a/tests/inputs/validation/validation.proto b/tests/inputs/validation/validation.proto new file mode 100644 index 00000000..e0e79092 --- /dev/null +++ b/tests/inputs/validation/validation.proto @@ -0,0 +1,20 @@ +syntax = "proto3"; + +package validation; + +message Message { + int32 int32_value = 1; + int64 int64_value = 2; + uint32 uint32_value = 3; + uint64 uint64_value = 4; + sint32 sint32_value = 5; + sint64 sint64_value = 6; + fixed32 fixed32_value = 7; + fixed64 fixed64_value = 8; + sfixed32 sfixed32_value = 9; + sfixed64 sfixed64_value = 10; + + float float_value = 11; + + string string_value = 12; +}