3535from bitproto .renderer .renderer import Renderer
3636from bitproto .utils import cached_property , override
3737
38+ _enum_field_proxy_prefix = "_enum_field_proxy__"
39+
3840
3941class BlockProtoDocstring (BlockBindProto [F ]):
4042 @override (Block )
@@ -47,7 +49,8 @@ class BlockGeneralImports(Block[F]):
4749 def render (self ) -> None :
4850 self .push ("import json" )
4951 self .push ("from dataclasses import dataclass, field" )
50- self .push ("from typing import ClassVar, Dict, List" )
52+ self .push ("from typing import ClassVar, Dict, List, Union" )
53+ self .push ("from enum import IntEnum, unique" )
5154 self .push_empty_line ()
5255 self .push ("from bitprotolib import bp" )
5356
@@ -140,7 +143,7 @@ class BlockEnumField(BlockBindEnumField[F]):
140143 def render (self ) -> None :
141144 self .push_definition_comments ()
142145 self .push (
143- f"{ self .enum_field_name } : { self .enum_field_type } = { self .enum_field_value } "
146+ f"{ self .enum_field_name } : { self .enum_field_type } = { self .enum_field_type } . { self . enum_field_name } "
144147 )
145148
146149
@@ -159,20 +162,58 @@ class BlockEnumFieldListWrapper(BlockBindEnum[F], BlockWrapper[F]):
159162 def wraps (self ) -> Block :
160163 return BlockEnumFieldList (self .d )
161164
165+ @override (BlockWrapper )
166+ def before (self ) -> None :
167+ self .push_comment ("Aliases for backwards compatibility" )
168+
169+ @override (BlockWrapper )
170+ def after (self ) -> None :
171+ self .push_empty_line ()
172+
173+
174+ class BlockIntEnumField (BlockBindEnumField [F ]):
175+ @override (Block )
176+ def render (self ) -> None :
177+ self .push_definition_comments ()
178+ self .push (f" { self .enum_field_name } = { self .enum_field_value } " )
179+
180+
181+ class BlockIntEnumFieldList (BlockBindEnum [F ], BlockComposition [F ]):
182+ @override (BlockComposition )
183+ def blocks (self ) -> List [Block ]:
184+ return [BlockIntEnumField (field ) for field in self .d .fields ()]
185+
186+ @override (BlockComposition )
187+ def separator (self ) -> str :
188+ return "\n "
189+
190+
191+ class BlockIntEnumFieldListWrapper (BlockBindEnum [F ], BlockWrapper [F ]):
192+ @override (BlockWrapper )
193+ def wraps (self ) -> Block :
194+ return BlockIntEnumFieldList (self .d )
195+
162196 @override (BlockWrapper )
163197 def before (self ) -> None :
164198 self .render_enum_type ()
165199
200+ @override (BlockWrapper )
201+ def after (self ) -> None :
202+ self .push_empty_line ()
203+
166204 def render_enum_type (self ) -> None :
167205 self .push_definition_comments ()
168- self .push (f"{ self .enum_name } = int" )
206+ self .push ("@unique" )
207+ self .push (f"class { self .enum_name } (IntEnum):" )
169208 self .push_typing_hint_inline_comment ()
170209
171210
172211class BlockEnumValueToNameMapItem (BlockBindEnumField [F ]):
173212 @override (Block )
174213 def render (self ) -> None :
175- self .push (f'{ self .enum_field_value } : "{ self .enum_field_name } ",' )
214+ self .push (
215+ f'{ self .enum_field_type } .{ self .enum_field_name } : "{ self .enum_field_name } ",'
216+ )
176217
177218
178219class BlockEnumValueToNameMapItemList (BlockBindEnum [F ], BlockComposition [F ]):
@@ -215,6 +256,7 @@ class BlockEnum(BlockBindEnum[F], BlockComposition[F]):
215256 @override (BlockComposition [F ])
216257 def blocks (self ) -> List [Block ]:
217258 return [
259+ BlockIntEnumFieldListWrapper (self .d ),
218260 BlockEnumFieldListWrapper (self .d ),
219261 BlockEnumValueToNameMap (self .d ),
220262 BlockEnumMethodProcessor (self .d ),
@@ -229,9 +271,22 @@ def message_field_default_value(self) -> str:
229271 @override (Block )
230272 def render (self ) -> None :
231273 self .push_definition_comments ()
232- self .push (
233- f"{ self .message_field_name } : { self .message_field_type } = { self .message_field_default_value } "
234- )
274+ if issubclass (type (self .d .type ), Enum ):
275+ # use union of int and IntEnum so that field assignment works as before without typing problem.
276+ self .push (
277+ f"{ self .message_field_name } : Union[int, { self .message_field_type } ] = { self .message_field_default_value } "
278+ )
279+ # push a proxy field for enum fields to hold integer value
280+ self .push_comment (
281+ f"This field is a proxy to hold integer value of enum field '{ self .message_field_name } '"
282+ )
283+ self .push (
284+ f"{ _enum_field_proxy_prefix } { self .message_field_name } : int = field(init=False, repr=False)"
285+ )
286+ else :
287+ self .push (
288+ f"{ self .message_field_name } : { self .message_field_type } = { self .message_field_default_value } "
289+ )
235290 self .push_typing_hint_inline_comment ()
236291
237292
@@ -289,6 +344,19 @@ def before(self) -> None:
289344 self .push_definition_docstring (indent = 4 )
290345
291346
347+ class BlockMessageDictFactory (BlockMessageBase , BlockWrapper [F ]):
348+ def wraps (self ) -> Block [F ]:
349+ pass
350+
351+ @override (BlockWrapper )
352+ def before (self ) -> None :
353+ self .push ("@staticmethod" )
354+ self .push ("def dict_factory(kv_pairs):" )
355+ self .push (
356+ f" return {{k: v for k, v in kv_pairs if not k.startswith('{ _enum_field_proxy_prefix } ')}}"
357+ )
358+
359+
292360class BlockMessageMethodProcessorFieldItem (BlockBindMessageField [F ]):
293361 @override (Block )
294362 def render (self ) -> None :
@@ -310,6 +378,104 @@ def separator(self) -> str:
310378 return "\n "
311379
312380
381+ class BlockMessagePostInitField (BlockBindMessageField [F ]):
382+ @override (Block )
383+ def render (self ) -> None :
384+ self .push_comment (
385+ f"initialize handling of enum field '{ self .message_field_name } ' as `enum.IntEnum`"
386+ )
387+ self .push (
388+ f'if not isinstance(getattr({ self .d .message .name } , "{ self .message_field_name } ", False), property):'
389+ )
390+ self .push (
391+ f" self.{ _enum_field_proxy_prefix } { self .message_field_name } = self.{ self .message_field_name } "
392+ )
393+ self .push (
394+ f" { self .d .message .name } .{ self .message_field_name } = property("
395+ f"{ self .d .message .name } ._get_{ self .message_field_name } , "
396+ f"{ self .d .message .name } ._set_{ self .message_field_name } )"
397+ )
398+
399+
400+ class BlockMessagePostInitPass (BlockBindMessage [F ]):
401+ @override (Block )
402+ def render (self ) -> None :
403+ self .push ("pass" )
404+
405+
406+ class BlockMessagePostInitItemList (BlockMessageBase , BlockComposition [F ]):
407+ @override (BlockComposition )
408+ def blocks (self ) -> List [Block [F ]]:
409+ if any (
410+ (issubclass (type (field .type ), Enum ) for field in self .d .sorted_fields ())
411+ ):
412+ b : List [Block [F ]] = [
413+ BlockMessagePostInitField (field , indent = self .indent )
414+ for field in self .d .sorted_fields ()
415+ if issubclass (type (field .type ), Enum )
416+ ]
417+ return b
418+ else :
419+ return [BlockMessagePostInitPass (self .d , indent = self .indent )]
420+
421+ @override (BlockComposition )
422+ def separator (self ) -> str :
423+ return "\n "
424+
425+
426+ class BlockMessagePostInit (BlockBindMessage [F ], BlockWrapper [F ]):
427+ @override (BlockWrapper )
428+ def wraps (self ) -> Block [F ]:
429+ return BlockMessagePostInitItemList (self .d , indent = self .indent + 4 )
430+
431+ @override (BlockWrapper )
432+ def before (self ) -> None :
433+ self .push ("def __post_init__(self):" )
434+
435+
436+ class BlockMessageEnumProxyFieldAccessorField (BlockBindMessageField [F ]):
437+ @override (Block )
438+ def render (self ) -> None :
439+ self .push (
440+ f"def _get_{ self .message_field_name } (self) -> { self .message_field_type } :"
441+ )
442+ self .push (f' """property getter for enum proxy field"""' )
443+ self .push (
444+ f" return { self .message_field_type } (self.{ _enum_field_proxy_prefix } { self .message_field_name } )\n "
445+ )
446+ self .push (f"def _set_{ self .message_field_name } (self, val):" )
447+ self .push (f' """property setter for enum proxy field"""' )
448+ self .push (f" self.{ _enum_field_proxy_prefix } { self .message_field_name } = val" )
449+
450+
451+ class BlockMessageEnumProxyFieldAccessorFieldList (
452+ BlockMessageBase , BlockComposition [F ]
453+ ):
454+ @override (BlockComposition )
455+ def blocks (self ) -> List [Block [F ]]:
456+ if any (
457+ (issubclass (type (field .type ), Enum ) for field in self .d .sorted_fields ())
458+ ):
459+ b : List [Block [F ]] = [
460+ BlockMessageEnumProxyFieldAccessorField (field , indent = self .indent )
461+ for field in self .d .sorted_fields ()
462+ if issubclass (type (field .type ), Enum )
463+ ]
464+ return b
465+ else :
466+ return []
467+
468+ @override (BlockComposition )
469+ def separator (self ) -> str :
470+ return "\n "
471+
472+
473+ class BlockMessageEnumProxyFieldAccessors (BlockBindMessage [F ], BlockWrapper [F ]):
474+ @override (BlockWrapper )
475+ def wraps (self ) -> Block [F ]:
476+ return BlockMessageEnumProxyFieldAccessorFieldList (self .d , indent = self .indent )
477+
478+
313479class BlockMessageMethodProcessor (BlockMessageBase , BlockWrapper [F ]):
314480 @override (BlockWrapper )
315481 def wraps (self ) -> Block [F ]:
@@ -606,6 +772,9 @@ class BlockMessage(BlockMessageBase, BlockComposition[F]):
606772 def blocks (self ) -> List [Block [F ]]:
607773 return [
608774 BlockMessageClass (self .d ),
775+ BlockMessagePostInit (self .d , indent = 4 ),
776+ BlockMessageDictFactory (self .d , indent = 4 ),
777+ BlockMessageEnumProxyFieldAccessors (self .d , indent = 4 ),
609778 BlockMessageMethodProcessor (self .d , indent = 4 ),
610779 BlockMessageMethodSetByte (self .d , indent = 4 ),
611780 BlockMessageMethodGetByte (self .d , indent = 4 ),
0 commit comments