Skip to content

Commit 42e80a8

Browse files
Simplify dataclass fields (#38)
* Simplify dataclass fields * Remove Struct test * Fix pickling test * Fix test stream stream * Partially fix test_features * Fix casing test * Fix feature optional tests * More features tests * More features tests * More features tests * Fix oneof pattern matching test * Add missing file, code quality
1 parent 48ebb30 commit 42e80a8

File tree

9 files changed

+4468
-4109
lines changed

9 files changed

+4468
-4109
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ protoc-gen-python_betterproto = "betterproto2.plugin:main"
4848
rust-codec = ["betterproto-rust-codec"]
4949

5050
[tool.ruff]
51-
extend-exclude = ["tests/output_*"]
51+
extend-exclude = ["tests/output_*", "src/betterproto2/lib"]
5252
target-version = "py38"
5353
line-length = 120
5454

src/betterproto2/__init__.py

Lines changed: 34 additions & 282 deletions
Original file line numberDiff line numberDiff line change
@@ -192,300 +192,52 @@ def get(field: dataclasses.Field) -> "FieldMetadata":
192192
return field.metadata["betterproto"]
193193

194194

195-
def dataclass_field(
195+
def field(
196196
number: int,
197197
proto_type: str,
198-
default_factory: Callable[[], Any],
199198
*,
199+
default_factory: Callable[[], Any] | None = None,
200200
map_types: Optional[Tuple[str, str]] = None,
201201
group: Optional[str] = None,
202202
wraps: Optional[str] = None,
203203
optional: bool = False,
204204
repeated: bool = False,
205-
) -> dataclasses.Field:
205+
) -> Any: # Return type is Any to pass type checking
206206
"""Creates a dataclass field with attached protobuf metadata."""
207207
if repeated:
208208
default_factory = list
209209

210210
elif optional or group:
211211
default_factory = type(None)
212212

213+
else:
214+
default_factory = {
215+
TYPE_ENUM: default_factory,
216+
TYPE_BOOL: bool,
217+
TYPE_INT32: int,
218+
TYPE_INT64: int,
219+
TYPE_UINT32: int,
220+
TYPE_UINT64: int,
221+
TYPE_SINT32: int,
222+
TYPE_SINT64: int,
223+
TYPE_FLOAT: float,
224+
TYPE_DOUBLE: float,
225+
TYPE_FIXED32: int,
226+
TYPE_SFIXED32: int,
227+
TYPE_FIXED64: int,
228+
TYPE_SFIXED64: int,
229+
TYPE_STRING: str,
230+
TYPE_BYTES: bytes,
231+
TYPE_MESSAGE: type(None),
232+
TYPE_MAP: dict,
233+
}[proto_type]
234+
213235
return dataclasses.field(
214236
default_factory=default_factory,
215237
metadata={"betterproto": FieldMetadata(number, proto_type, map_types, group, wraps, optional)},
216238
)
217239

218240

219-
# Note: the fields below return `Any` to prevent type errors in the generated
220-
# data classes since the types won't match with `Field` and they get swapped
221-
# out at runtime. The generated dataclass variables are still typed correctly.
222-
223-
224-
def enum_field(
225-
number: int,
226-
enum_default_value: Callable[[], Enum],
227-
group: Optional[str] = None,
228-
optional: bool = False,
229-
repeated: bool = False,
230-
) -> Any:
231-
return dataclass_field(
232-
number,
233-
TYPE_ENUM,
234-
enum_default_value,
235-
group=group,
236-
optional=optional,
237-
repeated=repeated,
238-
)
239-
240-
241-
def bool_field(
242-
number: int,
243-
group: Optional[str] = None,
244-
optional: bool = False,
245-
repeated: bool = False,
246-
) -> Any:
247-
return dataclass_field(
248-
number,
249-
TYPE_BOOL,
250-
bool,
251-
group=group,
252-
optional=optional,
253-
repeated=repeated,
254-
)
255-
256-
257-
def int32_field(
258-
number: int,
259-
group: Optional[str] = None,
260-
optional: bool = False,
261-
repeated: bool = False,
262-
) -> Any:
263-
return dataclass_field(number, TYPE_INT32, int, group=group, optional=optional, repeated=repeated)
264-
265-
266-
def int64_field(
267-
number: int,
268-
group: Optional[str] = None,
269-
optional: bool = False,
270-
repeated: bool = False,
271-
) -> Any:
272-
return dataclass_field(number, TYPE_INT64, int, group=group, optional=optional, repeated=repeated)
273-
274-
275-
def uint32_field(
276-
number: int,
277-
group: Optional[str] = None,
278-
optional: bool = False,
279-
repeated: bool = False,
280-
) -> Any:
281-
return dataclass_field(
282-
number,
283-
TYPE_UINT32,
284-
int,
285-
group=group,
286-
optional=optional,
287-
repeated=repeated,
288-
)
289-
290-
291-
def uint64_field(
292-
number: int,
293-
group: Optional[str] = None,
294-
optional: bool = False,
295-
repeated: bool = False,
296-
) -> Any:
297-
return dataclass_field(
298-
number,
299-
TYPE_UINT64,
300-
int,
301-
group=group,
302-
optional=optional,
303-
repeated=repeated,
304-
)
305-
306-
307-
def sint32_field(
308-
number: int,
309-
group: Optional[str] = None,
310-
optional: bool = False,
311-
repeated: bool = False,
312-
) -> Any:
313-
return dataclass_field(
314-
number,
315-
TYPE_SINT32,
316-
int,
317-
group=group,
318-
optional=optional,
319-
repeated=repeated,
320-
)
321-
322-
323-
def sint64_field(
324-
number: int,
325-
group: Optional[str] = None,
326-
optional: bool = False,
327-
repeated: bool = False,
328-
) -> Any:
329-
return dataclass_field(
330-
number,
331-
TYPE_SINT64,
332-
int,
333-
group=group,
334-
optional=optional,
335-
repeated=repeated,
336-
)
337-
338-
339-
def float_field(
340-
number: int,
341-
group: Optional[str] = None,
342-
optional: bool = False,
343-
repeated: bool = False,
344-
) -> Any:
345-
return dataclass_field(
346-
number,
347-
TYPE_FLOAT,
348-
float,
349-
group=group,
350-
optional=optional,
351-
repeated=repeated,
352-
)
353-
354-
355-
def double_field(
356-
number: int,
357-
group: Optional[str] = None,
358-
optional: bool = False,
359-
repeated: bool = False,
360-
) -> Any:
361-
return dataclass_field(
362-
number,
363-
TYPE_DOUBLE,
364-
float,
365-
group=group,
366-
optional=optional,
367-
repeated=repeated,
368-
)
369-
370-
371-
def fixed32_field(
372-
number: int,
373-
group: Optional[str] = None,
374-
optional: bool = False,
375-
repeated: bool = False,
376-
) -> Any:
377-
return dataclass_field(
378-
number,
379-
TYPE_FIXED32,
380-
float,
381-
group=group,
382-
optional=optional,
383-
repeated=repeated,
384-
)
385-
386-
387-
def fixed64_field(
388-
number: int,
389-
group: Optional[str] = None,
390-
optional: bool = False,
391-
repeated: bool = False,
392-
) -> Any:
393-
return dataclass_field(
394-
number,
395-
TYPE_FIXED64,
396-
float,
397-
group=group,
398-
optional=optional,
399-
repeated=repeated,
400-
)
401-
402-
403-
def sfixed32_field(
404-
number: int,
405-
group: Optional[str] = None,
406-
optional: bool = False,
407-
repeated: bool = False,
408-
) -> Any:
409-
return dataclass_field(
410-
number,
411-
TYPE_SFIXED32,
412-
float,
413-
group=group,
414-
optional=optional,
415-
repeated=repeated,
416-
)
417-
418-
419-
def sfixed64_field(
420-
number: int,
421-
group: Optional[str] = None,
422-
optional: bool = False,
423-
repeated: bool = False,
424-
) -> Any:
425-
return dataclass_field(
426-
number,
427-
TYPE_SFIXED64,
428-
float,
429-
group=group,
430-
optional=optional,
431-
repeated=repeated,
432-
)
433-
434-
435-
def string_field(
436-
number: int,
437-
group: Optional[str] = None,
438-
optional: bool = False,
439-
repeated: bool = False,
440-
) -> Any:
441-
return dataclass_field(
442-
number,
443-
TYPE_STRING,
444-
str,
445-
group=group,
446-
optional=optional,
447-
repeated=repeated,
448-
)
449-
450-
451-
def bytes_field(
452-
number: int,
453-
group: Optional[str] = None,
454-
optional: bool = False,
455-
repeated: bool = False,
456-
) -> Any:
457-
return dataclass_field(
458-
number,
459-
TYPE_BYTES,
460-
bytes,
461-
group=group,
462-
optional=optional,
463-
repeated=repeated,
464-
)
465-
466-
467-
def message_field(
468-
number: int,
469-
group: Optional[str] = None,
470-
wraps: Optional[str] = None,
471-
optional: bool = False,
472-
repeated: bool = False,
473-
) -> Any:
474-
return dataclass_field(
475-
number,
476-
TYPE_MESSAGE,
477-
type(None),
478-
group=group,
479-
wraps=wraps,
480-
optional=optional,
481-
repeated=repeated,
482-
)
483-
484-
485-
def map_field(number: int, key_type: str, value_type: str, group: Optional[str] = None) -> Any:
486-
return dataclass_field(number, TYPE_MAP, dict, map_types=(key_type, value_type), group=group)
487-
488-
489241
def _pack_fmt(proto_type: str) -> str:
490242
"""Returns a little-endian format string for reading/writing binary."""
491243
return {
@@ -774,31 +526,31 @@ def _get_default_gen(cls: Type["Message"], fields: Iterable[dataclasses.Field])
774526
def _get_cls_by_field(cls: Type["Message"], fields: Iterable[dataclasses.Field]) -> Dict[str, Type]:
775527
field_cls = {}
776528

777-
for field in fields:
778-
meta = FieldMetadata.get(field)
529+
for field_ in fields:
530+
meta = FieldMetadata.get(field_)
779531
if meta.proto_type == TYPE_MAP:
780532
assert meta.map_types
781-
kt = cls._cls_for(field, index=0)
782-
vt = cls._cls_for(field, index=1)
783-
field_cls[field.name] = dataclasses.make_dataclass(
533+
kt = cls._cls_for(field_, index=0)
534+
vt = cls._cls_for(field_, index=1)
535+
field_cls[field_.name] = dataclasses.make_dataclass(
784536
"Entry",
785537
[
786538
(
787539
"key",
788540
kt,
789-
dataclass_field(1, meta.map_types[0], default_factory=kt),
541+
field(1, meta.map_types[0], default_factory=kt),
790542
),
791543
(
792544
"value",
793545
vt,
794-
dataclass_field(2, meta.map_types[1], default_factory=vt),
546+
field(2, meta.map_types[1], default_factory=vt),
795547
),
796548
],
797549
bases=(Message,),
798550
)
799-
field_cls[f"{field.name}.value"] = vt
551+
field_cls[f"{field_.name}.value"] = vt
800552
else:
801-
field_cls[field.name] = cls._cls_for(field)
553+
field_cls[field_.name] = cls._cls_for(field_)
802554

803555
return field_cls
804556

0 commit comments

Comments
 (0)