Skip to content

Commit e38042e

Browse files
committed
refactor: avoid Python 3.9 dataclass unsafe_hash inheritance error
1 parent 0634620 commit e38042e

File tree

1 file changed

+107
-44
lines changed

1 file changed

+107
-44
lines changed

schema_salad/rust_codegen.py

Lines changed: 107 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
__all__ = ["RustCodeGen"]
44

5-
import dataclasses
65
import functools
76
import itertools
87
import json
@@ -11,7 +10,6 @@
1110
import sys
1211
from abc import ABC, abstractmethod
1312
from collections.abc import Iterator, MutableMapping, MutableSequence, Sequence
14-
from dataclasses import dataclass
1513
from importlib.resources import files as resource_files
1614
from io import StringIO
1715
from pathlib import Path
@@ -117,6 +115,7 @@ def to_rust_literal(value: Any) -> str:
117115

118116
def make_avro(items: MutableSequence[JsonDataType]) -> MutableSequence[NamedSchema]:
119117
"""Process a list of dictionaries to generate a list of Avro schemas."""
118+
120119
# Same as `from .utils import convert_to_dict`, which, however, is not public
121120
def convert_to_dict(j4: Any) -> Any:
122121
"""Convert generic Mapping objects to dicts recursively."""
@@ -150,11 +149,13 @@ def convert_to_dict(j4: Any) -> Any:
150149
RustIdent = str # alias
151150

152151

153-
@dataclass # ASSERT: Immutable class
154152
class RustLifetime:
155153
"""Represents a Rust lifetime parameter (e.g., `'a`)."""
156154

157-
ident: RustIdent
155+
__slots__ = ("ident",)
156+
157+
def __init__(self, ident: RustIdent):
158+
self.ident = ident
158159

159160
def __hash__(self) -> int:
160161
return hash(self.ident)
@@ -175,11 +176,16 @@ class RustMeta(ABC):
175176
pass
176177

177178

178-
@dataclass(unsafe_hash=True) # ASSERT: Immutable class
179179
class RustAttribute:
180180
"""Represents a Rust attribute (e.g., `#[derive(Debug)]`)."""
181181

182-
meta: RustMeta
182+
__slots__ = ("meta",)
183+
184+
def __init__(self, meta: RustMeta):
185+
self.meta = meta
186+
187+
def __hash__(self) -> int:
188+
return hash(self.meta)
183189

184190
def __str__(self) -> str:
185191
return f"#[{str(self.meta)}]"
@@ -193,17 +199,22 @@ def __str__(self) -> str:
193199
RustGenericsMut = MutableSequence[Union[RustLifetime, "RustPath"]] # alias
194200

195201

196-
@dataclass(unsafe_hash=True) # ASSERT: Immutable class
197202
class RustPathSegment:
198203
"""Represents a segment in a Rust path with optional generics."""
199204

200-
ident: RustIdent
201-
generics: RustGenerics = dataclasses.field(default_factory=tuple)
205+
__slots__ = ("ident", "generics")
202206

203207
REX: ClassVar[Pattern[str]] = re.compile(
204208
r"^([a-zA-Z_]\w*)(?:<([ \w\t,'<>]+)>)?$"
205209
) # Using `re.Pattern[str]` raise CI build errors
206210

211+
def __init__(self, ident: RustIdent, generics: Optional[RustGenerics] = None):
212+
self.ident = ident
213+
self.generics = () if generics is None else generics
214+
215+
def __hash__(self) -> int:
216+
return hash((self.ident, self.generics))
217+
207218
def __str__(self) -> str:
208219
if not self.generics:
209220
return self.ident
@@ -256,13 +267,18 @@ def parse_generics_string(value_generics: str) -> RustGenerics:
256267
RustPathSegmentsMut = MutableSequence[RustPathSegment] # alias
257268

258269

259-
@dataclass(unsafe_hash=True) # ASSERT: Immutable class
260270
class RustPath(RustMeta):
261271
"""Represents a complete Rust path (e.g., `::std::vec::Vec<T>`)."""
262272

273+
__slots__ = ("segments", "leading_colon")
274+
263275
# ASSERT: Never initialized with an empty sequence
264-
segments: RustPathSegments
265-
leading_colon: bool = False
276+
def __init__(self, segments: RustPathSegments, leading_colon: bool = False):
277+
self.segments = segments
278+
self.leading_colon = leading_colon
279+
280+
def __hash__(self) -> int:
281+
return hash((self.segments, self.leading_colon))
266282

267283
def __truediv__(self, other: Union["RustPath", RustPathSegment]) -> "RustPath":
268284
if self.segments[-1].generics:
@@ -304,24 +320,31 @@ def from_str(cls, value: str) -> "RustPath":
304320
return cls(segments=tuple(segments), leading_colon=leading_colon)
305321

306322

307-
@dataclass(unsafe_hash=True) # ASSERT: Immutable class
308323
class RustTypeTuple(RustType):
309324
"""Represents a Rust tuple type (e.g., `(T, U)`)."""
310325

326+
__slots__ = ("types",)
327+
311328
# ASSERT: Never initialized with an empty sequence
312-
types: Sequence[RustPath]
329+
def __init__(self, types: Sequence[RustPath]):
330+
self.types = types
331+
332+
def __hash__(self) -> int:
333+
return hash(self.types)
313334

314335
def __str__(self) -> str:
315336
types_str = ", ".join(str(ty) for ty in self.types)
316337
return f"({types_str})"
317338

318339

319-
@dataclass # ASSERT: Immutable class
320340
class RustMetaList(RustMeta):
321341
"""Represents attribute meta list information (e.g., `derive(Debug, Clone)`).."""
322342

323-
path: RustPath
324-
metas: Sequence[RustMeta] = tuple()
343+
__slots__ = ("path", "metas")
344+
345+
def __init__(self, path: RustPath, metas: Optional[Sequence[RustMeta]] = None):
346+
self.path = path
347+
self.metas = () if metas is None else metas
325348

326349
def __hash__(self) -> int:
327350
return hash(self.path)
@@ -331,12 +354,14 @@ def __str__(self) -> str:
331354
return f"{str(self.path)}(" + meta_str + ")"
332355

333356

334-
@dataclass # ASSERT: Immutable class
335357
class RustMetaNameValue(RustMeta):
336358
"""Represents attribute meta name-value information (e.g., `key = value`)."""
337359

338-
path: RustPath
339-
value: Any = True
360+
__slots__ = ("path", "value")
361+
362+
def __init__(self, path: RustPath, value: Any = True):
363+
self.path = path
364+
self.value = value
340365

341366
def __hash__(self) -> int:
342367
return hash(self.path)
@@ -350,13 +375,17 @@ def __str__(self) -> str:
350375
#
351376

352377

353-
@dataclass
354378
class RustNamedType(ABC): # ABC class
355379
"""Abstract class for Rust struct and enum types."""
356380

357-
ident: RustIdent
358-
attrs: RustAttributes = dataclasses.field(default_factory=list)
359-
visibility: str = "pub"
381+
__slots__ = ("ident", "attrs", "visibility")
382+
383+
def __init__(
384+
self, ident: RustIdent, attrs: Optional[RustAttributes] = None, visibility: str = "pub"
385+
):
386+
self.ident = ident
387+
self.attrs = () if attrs is None else attrs
388+
self.visibility = visibility
360389

361390
def __hash__(self) -> int:
362391
return hash(self.ident)
@@ -371,13 +400,15 @@ def __str__(self) -> str:
371400
return output.getvalue()
372401

373402

374-
@dataclass # ASSERT: Immutable class
375403
class RustField:
376404
"""Represents a field in a Rust struct."""
377405

378-
ident: RustIdent
379-
type: RustPath
380-
attrs: RustAttributes = dataclasses.field(default_factory=list)
406+
__slots__ = ("ident", "type", "attrs")
407+
408+
def __init__(self, ident: RustIdent, type: RustPath, attrs: Optional[RustAttributes] = None):
409+
self.ident = ident
410+
self.type = type
411+
self.attrs = () if attrs is None else attrs
381412

382413
def __hash__(self) -> int:
383414
return hash(self.ident)
@@ -394,11 +425,21 @@ def write_to(self, writer: IO[str], depth: int = 0) -> None:
394425
RustFieldsMut = Union[MutableSequence[RustField], RustTypeTuple] # alias
395426

396427

397-
@dataclass
398428
class RustStruct(RustNamedType):
399429
"""Represents a Rust struct definition."""
400430

401-
fields: Optional[RustFields] = None
431+
__slots__ = ("fields",)
432+
433+
def __init__(
434+
self,
435+
ident: RustIdent,
436+
fields: Optional[RustFields] = None,
437+
attrs: Optional[RustAttributes] = None,
438+
visibility: str = "pub",
439+
):
440+
_attrs = () if attrs is None else attrs
441+
super().__init__(ident, _attrs, visibility)
442+
self.fields = fields
402443

403444
def write_to(self, writer: IO[str], depth: int = 0) -> None:
404445
indent = " " * depth
@@ -419,13 +460,20 @@ def write_to(self, writer: IO[str], depth: int = 0) -> None:
419460
writer.write(f"{indent}}}\n")
420461

421462

422-
@dataclass # ASSERT: Immutable class
423463
class RustVariant:
424464
"""Represents a variant in a Rust enum."""
425465

426-
ident: RustIdent
427-
tuple: Optional[RustTypeTuple] = None
428-
attrs: RustAttributes = dataclasses.field(default_factory=list)
466+
__slots__ = ("ident", "tuple", "attrs")
467+
468+
def __init__(
469+
self,
470+
ident: RustIdent,
471+
tuple: Optional[RustTypeTuple] = None,
472+
attrs: Optional[RustAttributes] = None,
473+
):
474+
self.ident = ident
475+
self.tuple = tuple
476+
self.attrs = () if attrs is None else attrs
429477

430478
def __hash__(self) -> int:
431479
return hash(self.ident)
@@ -435,7 +483,6 @@ def write_to(self, writer: IO[str], depth: int = 0) -> None:
435483

436484
if self.attrs:
437485
writer.write("\n".join(f"{indent}{str(attr)}" for attr in self.attrs) + "\n")
438-
439486
writer.write(f"{indent}{self.ident}")
440487
if self.tuple:
441488
writer.write(str(self.tuple))
@@ -462,11 +509,21 @@ def from_path(cls, path: RustPath) -> "RustVariant":
462509
RustVariantsMut = MutableSequence[RustVariant] # alias
463510

464511

465-
@dataclass
466512
class RustEnum(RustNamedType):
467513
"""Represents a Rust enum definition."""
468514

469-
variants: RustVariants = dataclasses.field(default_factory=tuple)
515+
__slots__ = ("variants",)
516+
517+
def __init__(
518+
self,
519+
ident: RustIdent,
520+
variants: Optional[RustVariants] = None,
521+
attrs: Optional[RustAttributes] = None,
522+
visibility: str = "pub",
523+
):
524+
_attrs = () if attrs is None else attrs
525+
super().__init__(ident, _attrs, visibility)
526+
self.variants = () if variants is None else variants
470527

471528
def write_to(self, writer: IO[str], depth: int = 0) -> None:
472529
indent = " " * depth
@@ -495,16 +552,22 @@ def salad_macro_write_to(ty: RustNamedType, writer: IO[str], depth: int = 0) ->
495552
#
496553

497554

498-
@dataclass
499555
class RustModuleTree:
500556
"""Represents a Rust module with submodules and named types."""
501557

502-
ident: RustIdent # ASSERT: Immutable field
503-
parent: Optional["RustModuleTree"] # ASSERT: Immutable field
504-
named_types: MutableMapping[RustIdent, RustNamedType] = dataclasses.field(default_factory=dict)
505-
submodules: MutableMapping[RustIdent, "RustModuleTree"] = dataclasses.field(
506-
default_factory=dict
507-
)
558+
__slots__ = ("ident", "parent", "named_types", "submodules")
559+
560+
def __init__(
561+
self,
562+
ident: RustIdent, # ASSERT: Immutable field
563+
parent: Optional["RustModuleTree"] = None, # ASSERT: Immutable field
564+
named_types: Optional[MutableMapping[RustIdent, RustNamedType]] = None,
565+
submodules: Optional[MutableMapping[RustIdent, "RustModuleTree"]] = None,
566+
):
567+
self.ident = ident
568+
self.parent = parent
569+
self.named_types = {} if named_types is None else named_types
570+
self.submodules = {} if submodules is None else submodules
508571

509572
def __hash__(self) -> int:
510573
return hash((self.ident, self.parent))

0 commit comments

Comments
 (0)