Skip to content

Commit b63c1c4

Browse files
committed
refactor: avoid Python 3.9 dataclass unsafe_hash inheritance error
1 parent 0634620 commit b63c1c4

File tree

1 file changed

+106
-44
lines changed

1 file changed

+106
-44
lines changed

schema_salad/rust_codegen.py

Lines changed: 106 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
__all__ = ["RustCodeGen"]
44

5-
import dataclasses
65
import functools
76
import itertools
87
import json
@@ -11,7 +10,6 @@
1110
import sys
1211
from abc import ABC, abstractmethod
1312
from collections.abc import Iterator, MutableMapping, MutableSequence, Sequence
14-
from dataclasses import dataclass
1513
from importlib.resources import files as resource_files
1614
from io import StringIO
1715
from pathlib import Path
@@ -150,11 +148,13 @@ def convert_to_dict(j4: Any) -> Any:
150148
RustIdent = str # alias
151149

152150

153-
@dataclass # ASSERT: Immutable class
154151
class RustLifetime:
155152
"""Represents a Rust lifetime parameter (e.g., `'a`)."""
156153

157-
ident: RustIdent
154+
__slots__ = ("ident",)
155+
156+
def __init__(self, ident: RustIdent):
157+
self.ident = ident
158158

159159
def __hash__(self) -> int:
160160
return hash(self.ident)
@@ -175,11 +175,16 @@ class RustMeta(ABC):
175175
pass
176176

177177

178-
@dataclass(unsafe_hash=True) # ASSERT: Immutable class
179178
class RustAttribute:
180179
"""Represents a Rust attribute (e.g., `#[derive(Debug)]`)."""
181180

182-
meta: RustMeta
181+
__slots__ = ("meta",)
182+
183+
def __init__(self, meta: RustMeta):
184+
self.meta = meta
185+
186+
def __hash__(self) -> int:
187+
return hash(self.meta)
183188

184189
def __str__(self) -> str:
185190
return f"#[{str(self.meta)}]"
@@ -193,17 +198,22 @@ def __str__(self) -> str:
193198
RustGenericsMut = MutableSequence[Union[RustLifetime, "RustPath"]] # alias
194199

195200

196-
@dataclass(unsafe_hash=True) # ASSERT: Immutable class
197201
class RustPathSegment:
198202
"""Represents a segment in a Rust path with optional generics."""
199203

200-
ident: RustIdent
201-
generics: RustGenerics = dataclasses.field(default_factory=tuple)
204+
__slots__ = ("ident", "generics")
202205

203206
REX: ClassVar[Pattern[str]] = re.compile(
204207
r"^([a-zA-Z_]\w*)(?:<([ \w\t,'<>]+)>)?$"
205208
) # Using `re.Pattern[str]` raise CI build errors
206209

210+
def __init__(self, ident: RustIdent, generics: Optional[RustGenerics] = None):
211+
self.ident = ident
212+
self.generics = () if generics is None else generics
213+
214+
def __hash__(self) -> int:
215+
return hash((self.ident, self.generics))
216+
207217
def __str__(self) -> str:
208218
if not self.generics:
209219
return self.ident
@@ -256,13 +266,18 @@ def parse_generics_string(value_generics: str) -> RustGenerics:
256266
RustPathSegmentsMut = MutableSequence[RustPathSegment] # alias
257267

258268

259-
@dataclass(unsafe_hash=True) # ASSERT: Immutable class
260269
class RustPath(RustMeta):
261270
"""Represents a complete Rust path (e.g., `::std::vec::Vec<T>`)."""
262271

272+
__slots__ = ("segments", "leading_colon")
273+
263274
# ASSERT: Never initialized with an empty sequence
264-
segments: RustPathSegments
265-
leading_colon: bool = False
275+
def __init__(self, segments: RustPathSegments, leading_colon: bool = False):
276+
self.segments = segments
277+
self.leading_colon = leading_colon
278+
279+
def __hash__(self) -> int:
280+
return hash((self.segments, self.leading_colon))
266281

267282
def __truediv__(self, other: Union["RustPath", RustPathSegment]) -> "RustPath":
268283
if self.segments[-1].generics:
@@ -304,24 +319,31 @@ def from_str(cls, value: str) -> "RustPath":
304319
return cls(segments=tuple(segments), leading_colon=leading_colon)
305320

306321

307-
@dataclass(unsafe_hash=True) # ASSERT: Immutable class
308322
class RustTypeTuple(RustType):
309323
"""Represents a Rust tuple type (e.g., `(T, U)`)."""
310324

325+
__slots__ = ("types",)
326+
311327
# ASSERT: Never initialized with an empty sequence
312-
types: Sequence[RustPath]
328+
def __init__(self, types: Sequence[RustPath]):
329+
self.types = types
330+
331+
def __hash__(self) -> int:
332+
return hash(self.types)
313333

314334
def __str__(self) -> str:
315335
types_str = ", ".join(str(ty) for ty in self.types)
316336
return f"({types_str})"
317337

318338

319-
@dataclass # ASSERT: Immutable class
320339
class RustMetaList(RustMeta):
321340
"""Represents attribute meta list information (e.g., `derive(Debug, Clone)`).."""
322341

323-
path: RustPath
324-
metas: Sequence[RustMeta] = tuple()
342+
__slots__ = ("path", "metas")
343+
344+
def __init__(self, path: RustPath, metas: Optional[Sequence[RustMeta]] = None):
345+
self.path = path
346+
self.metas = () if metas is None else metas
325347

326348
def __hash__(self) -> int:
327349
return hash(self.path)
@@ -331,12 +353,14 @@ def __str__(self) -> str:
331353
return f"{str(self.path)}(" + meta_str + ")"
332354

333355

334-
@dataclass # ASSERT: Immutable class
335356
class RustMetaNameValue(RustMeta):
336357
"""Represents attribute meta name-value information (e.g., `key = value`)."""
337358

338-
path: RustPath
339-
value: Any = True
359+
__slots__ = ("path", "value")
360+
361+
def __init__(self, path: RustPath, value: Any = True):
362+
self.path = path
363+
self.value = value
340364

341365
def __hash__(self) -> int:
342366
return hash(self.path)
@@ -350,13 +374,17 @@ def __str__(self) -> str:
350374
#
351375

352376

353-
@dataclass
354377
class RustNamedType(ABC): # ABC class
355378
"""Abstract class for Rust struct and enum types."""
356379

357-
ident: RustIdent
358-
attrs: RustAttributes = dataclasses.field(default_factory=list)
359-
visibility: str = "pub"
380+
__slots__ = ("ident", "attrs", "visibility")
381+
382+
def __init__(
383+
self, ident: RustIdent, attrs: Optional[RustAttributes] = None, visibility: str = "pub"
384+
):
385+
self.ident = ident
386+
self.attrs = () if attrs is None else attrs
387+
self.visibility = visibility
360388

361389
def __hash__(self) -> int:
362390
return hash(self.ident)
@@ -371,13 +399,15 @@ def __str__(self) -> str:
371399
return output.getvalue()
372400

373401

374-
@dataclass # ASSERT: Immutable class
375402
class RustField:
376403
"""Represents a field in a Rust struct."""
377404

378-
ident: RustIdent
379-
type: RustPath
380-
attrs: RustAttributes = dataclasses.field(default_factory=list)
405+
__slots__ = ("ident", "type", "attrs")
406+
407+
def __init__(self, ident: RustIdent, type: RustPath, attrs: Optional[RustAttributes] = None):
408+
self.ident = ident
409+
self.type = type
410+
self.attrs = () if attrs is None else attrs
381411

382412
def __hash__(self) -> int:
383413
return hash(self.ident)
@@ -394,11 +424,21 @@ def write_to(self, writer: IO[str], depth: int = 0) -> None:
394424
RustFieldsMut = Union[MutableSequence[RustField], RustTypeTuple] # alias
395425

396426

397-
@dataclass
398427
class RustStruct(RustNamedType):
399428
"""Represents a Rust struct definition."""
400429

401-
fields: Optional[RustFields] = None
430+
__slots__ = ("fields",)
431+
432+
def __init__(
433+
self,
434+
ident: RustIdent,
435+
fields: Optional[RustFields] = None,
436+
attrs: Optional[RustAttributes] = None,
437+
visibility: str = "pub",
438+
):
439+
_attrs = () if attrs is None else attrs
440+
super().__init__(ident, _attrs, visibility)
441+
self.fields = fields
402442

403443
def write_to(self, writer: IO[str], depth: int = 0) -> None:
404444
indent = " " * depth
@@ -419,13 +459,20 @@ def write_to(self, writer: IO[str], depth: int = 0) -> None:
419459
writer.write(f"{indent}}}\n")
420460

421461

422-
@dataclass # ASSERT: Immutable class
423462
class RustVariant:
424463
"""Represents a variant in a Rust enum."""
425464

426-
ident: RustIdent
427-
tuple: Optional[RustTypeTuple] = None
428-
attrs: RustAttributes = dataclasses.field(default_factory=list)
465+
__slots__ = ("ident", "tuple", "attrs")
466+
467+
def __init__(
468+
self,
469+
ident: RustIdent,
470+
tuple: Optional[RustTypeTuple] = None,
471+
attrs: Optional[RustAttributes] = None,
472+
):
473+
self.ident = ident
474+
self.tuple = tuple
475+
self.attrs = () if attrs is None else attrs
429476

430477
def __hash__(self) -> int:
431478
return hash(self.ident)
@@ -435,7 +482,6 @@ def write_to(self, writer: IO[str], depth: int = 0) -> None:
435482

436483
if self.attrs:
437484
writer.write("\n".join(f"{indent}{str(attr)}" for attr in self.attrs) + "\n")
438-
439485
writer.write(f"{indent}{self.ident}")
440486
if self.tuple:
441487
writer.write(str(self.tuple))
@@ -462,11 +508,21 @@ def from_path(cls, path: RustPath) -> "RustVariant":
462508
RustVariantsMut = MutableSequence[RustVariant] # alias
463509

464510

465-
@dataclass
466511
class RustEnum(RustNamedType):
467512
"""Represents a Rust enum definition."""
468513

469-
variants: RustVariants = dataclasses.field(default_factory=tuple)
514+
__slots__ = ("variants",)
515+
516+
def __init__(
517+
self,
518+
ident: RustIdent,
519+
variants: Optional[RustVariants] = None,
520+
attrs: Optional[RustAttributes] = None,
521+
visibility: str = "pub",
522+
):
523+
_attrs = () if attrs is None else attrs
524+
super().__init__(ident, _attrs, visibility)
525+
self.variants = () if variants is None else variants
470526

471527
def write_to(self, writer: IO[str], depth: int = 0) -> None:
472528
indent = " " * depth
@@ -495,16 +551,22 @@ def salad_macro_write_to(ty: RustNamedType, writer: IO[str], depth: int = 0) ->
495551
#
496552

497553

498-
@dataclass
499554
class RustModuleTree:
500555
"""Represents a Rust module with submodules and named types."""
501556

502-
ident: RustIdent # ASSERT: Immutable field
503-
parent: Optional["RustModuleTree"] # ASSERT: Immutable field
504-
named_types: MutableMapping[RustIdent, RustNamedType] = dataclasses.field(default_factory=dict)
505-
submodules: MutableMapping[RustIdent, "RustModuleTree"] = dataclasses.field(
506-
default_factory=dict
507-
)
557+
__slots__ = ("ident", "parent", "named_types", "submodules")
558+
559+
def __init__(
560+
self,
561+
ident: RustIdent, # ASSERT: Immutable field
562+
parent: Optional["RustModuleTree"] = None, # ASSERT: Immutable field
563+
named_types: Optional[MutableMapping[RustIdent, RustNamedType]] = None,
564+
submodules: Optional[MutableMapping[RustIdent, "RustModuleTree"]] = None,
565+
):
566+
self.ident = ident
567+
self.parent = parent
568+
self.named_types = {} if named_types is None else named_types
569+
self.submodules = {} if submodules is None else submodules
508570

509571
def __hash__(self) -> int:
510572
return hash((self.ident, self.parent))

0 commit comments

Comments
 (0)