2
2
3
3
__all__ = ["RustCodeGen" ]
4
4
5
- import dataclasses
6
5
import functools
7
6
import itertools
8
7
import json
11
10
import sys
12
11
from abc import ABC , abstractmethod
13
12
from collections .abc import Iterator , MutableMapping , MutableSequence , Sequence
14
- from dataclasses import dataclass
15
13
from importlib .resources import files as resource_files
16
14
from io import StringIO
17
15
from pathlib import Path
@@ -150,11 +148,13 @@ def convert_to_dict(j4: Any) -> Any:
150
148
RustIdent = str # alias
151
149
152
150
153
- @dataclass # ASSERT: Immutable class
154
151
class RustLifetime :
155
152
"""Represents a Rust lifetime parameter (e.g., `'a`)."""
156
153
157
- ident : RustIdent
154
+ __slots__ = ("ident" ,)
155
+
156
+ def __init__ (self , ident : RustIdent ):
157
+ self .ident = ident
158
158
159
159
def __hash__ (self ) -> int :
160
160
return hash (self .ident )
@@ -175,11 +175,16 @@ class RustMeta(ABC):
175
175
pass
176
176
177
177
178
- @dataclass (unsafe_hash = True ) # ASSERT: Immutable class
179
178
class RustAttribute :
180
179
"""Represents a Rust attribute (e.g., `#[derive(Debug)]`)."""
181
180
182
- meta : RustMeta
181
+ __slots__ = ("meta" ,)
182
+
183
+ def __init__ (self , meta : RustMeta ):
184
+ self .meta = meta
185
+
186
+ def __hash__ (self ) -> int :
187
+ return hash (self .meta )
183
188
184
189
def __str__ (self ) -> str :
185
190
return f"#[{ str (self .meta )} ]"
@@ -193,17 +198,22 @@ def __str__(self) -> str:
193
198
RustGenericsMut = MutableSequence [Union [RustLifetime , "RustPath" ]] # alias
194
199
195
200
196
- @dataclass (unsafe_hash = True ) # ASSERT: Immutable class
197
201
class RustPathSegment :
198
202
"""Represents a segment in a Rust path with optional generics."""
199
203
200
- ident : RustIdent
201
- generics : RustGenerics = dataclasses .field (default_factory = tuple )
204
+ __slots__ = ("ident" , "generics" )
202
205
203
206
REX : ClassVar [Pattern [str ]] = re .compile (
204
207
r"^([a-zA-Z_]\w*)(?:<([ \w\t,'<>]+)>)?$"
205
208
) # Using `re.Pattern[str]` raise CI build errors
206
209
210
+ def __init__ (self , ident : RustIdent , generics : Optional [RustGenerics ] = None ):
211
+ self .ident = ident
212
+ self .generics = () if generics is None else generics
213
+
214
+ def __hash__ (self ) -> int :
215
+ return hash ((self .ident , self .generics ))
216
+
207
217
def __str__ (self ) -> str :
208
218
if not self .generics :
209
219
return self .ident
@@ -256,13 +266,18 @@ def parse_generics_string(value_generics: str) -> RustGenerics:
256
266
RustPathSegmentsMut = MutableSequence [RustPathSegment ] # alias
257
267
258
268
259
- @dataclass (unsafe_hash = True ) # ASSERT: Immutable class
260
269
class RustPath (RustMeta ):
261
270
"""Represents a complete Rust path (e.g., `::std::vec::Vec<T>`)."""
262
271
272
+ __slots__ = ("segments" , "leading_colon" )
273
+
263
274
# ASSERT: Never initialized with an empty sequence
264
- segments : RustPathSegments
265
- leading_colon : bool = False
275
+ def __init__ (self , segments : RustPathSegments , leading_colon : bool = False ):
276
+ self .segments = segments
277
+ self .leading_colon = leading_colon
278
+
279
+ def __hash__ (self ) -> int :
280
+ return hash ((self .segments , self .leading_colon ))
266
281
267
282
def __truediv__ (self , other : Union ["RustPath" , RustPathSegment ]) -> "RustPath" :
268
283
if self .segments [- 1 ].generics :
@@ -304,24 +319,31 @@ def from_str(cls, value: str) -> "RustPath":
304
319
return cls (segments = tuple (segments ), leading_colon = leading_colon )
305
320
306
321
307
- @dataclass (unsafe_hash = True ) # ASSERT: Immutable class
308
322
class RustTypeTuple (RustType ):
309
323
"""Represents a Rust tuple type (e.g., `(T, U)`)."""
310
324
325
+ __slots__ = ("types" ,)
326
+
311
327
# ASSERT: Never initialized with an empty sequence
312
- types : Sequence [RustPath ]
328
+ def __init__ (self , types : Sequence [RustPath ]):
329
+ self .types = types
330
+
331
+ def __hash__ (self ) -> int :
332
+ return hash (self .types )
313
333
314
334
def __str__ (self ) -> str :
315
335
types_str = ", " .join (str (ty ) for ty in self .types )
316
336
return f"({ types_str } )"
317
337
318
338
319
- @dataclass # ASSERT: Immutable class
320
339
class RustMetaList (RustMeta ):
321
340
"""Represents attribute meta list information (e.g., `derive(Debug, Clone)`).."""
322
341
323
- path : RustPath
324
- metas : Sequence [RustMeta ] = tuple ()
342
+ __slots__ = ("path" , "metas" )
343
+
344
+ def __init__ (self , path : RustPath , metas : Optional [Sequence [RustMeta ]] = None ):
345
+ self .path = path
346
+ self .metas = () if metas is None else metas
325
347
326
348
def __hash__ (self ) -> int :
327
349
return hash (self .path )
@@ -331,12 +353,14 @@ def __str__(self) -> str:
331
353
return f"{ str (self .path )} (" + meta_str + ")"
332
354
333
355
334
- @dataclass # ASSERT: Immutable class
335
356
class RustMetaNameValue (RustMeta ):
336
357
"""Represents attribute meta name-value information (e.g., `key = value`)."""
337
358
338
- path : RustPath
339
- value : Any = True
359
+ __slots__ = ("path" , "value" )
360
+
361
+ def __init__ (self , path : RustPath , value : Any = True ):
362
+ self .path = path
363
+ self .value = value
340
364
341
365
def __hash__ (self ) -> int :
342
366
return hash (self .path )
@@ -350,13 +374,17 @@ def __str__(self) -> str:
350
374
#
351
375
352
376
353
- @dataclass
354
377
class RustNamedType (ABC ): # ABC class
355
378
"""Abstract class for Rust struct and enum types."""
356
379
357
- ident : RustIdent
358
- attrs : RustAttributes = dataclasses .field (default_factory = list )
359
- visibility : str = "pub"
380
+ __slots__ = ("ident" , "attrs" , "visibility" )
381
+
382
+ def __init__ (
383
+ self , ident : RustIdent , attrs : Optional [RustAttributes ] = None , visibility : str = "pub"
384
+ ):
385
+ self .ident = ident
386
+ self .attrs = () if attrs is None else attrs
387
+ self .visibility = visibility
360
388
361
389
def __hash__ (self ) -> int :
362
390
return hash (self .ident )
@@ -371,13 +399,15 @@ def __str__(self) -> str:
371
399
return output .getvalue ()
372
400
373
401
374
- @dataclass # ASSERT: Immutable class
375
402
class RustField :
376
403
"""Represents a field in a Rust struct."""
377
404
378
- ident : RustIdent
379
- type : RustPath
380
- attrs : RustAttributes = dataclasses .field (default_factory = list )
405
+ __slots__ = ("ident" , "type" , "attrs" )
406
+
407
+ def __init__ (self , ident : RustIdent , type : RustPath , attrs : Optional [RustAttributes ] = None ):
408
+ self .ident = ident
409
+ self .type = type
410
+ self .attrs = () if attrs is None else attrs
381
411
382
412
def __hash__ (self ) -> int :
383
413
return hash (self .ident )
@@ -394,11 +424,21 @@ def write_to(self, writer: IO[str], depth: int = 0) -> None:
394
424
RustFieldsMut = Union [MutableSequence [RustField ], RustTypeTuple ] # alias
395
425
396
426
397
- @dataclass
398
427
class RustStruct (RustNamedType ):
399
428
"""Represents a Rust struct definition."""
400
429
401
- fields : Optional [RustFields ] = None
430
+ __slots__ = ("fields" ,)
431
+
432
+ def __init__ (
433
+ self ,
434
+ ident : RustIdent ,
435
+ fields : Optional [RustFields ] = None ,
436
+ attrs : Optional [RustAttributes ] = None ,
437
+ visibility : str = "pub" ,
438
+ ):
439
+ _attrs = () if attrs is None else attrs
440
+ super ().__init__ (ident , _attrs , visibility )
441
+ self .fields = fields
402
442
403
443
def write_to (self , writer : IO [str ], depth : int = 0 ) -> None :
404
444
indent = " " * depth
@@ -419,13 +459,20 @@ def write_to(self, writer: IO[str], depth: int = 0) -> None:
419
459
writer .write (f"{ indent } }}\n " )
420
460
421
461
422
- @dataclass # ASSERT: Immutable class
423
462
class RustVariant :
424
463
"""Represents a variant in a Rust enum."""
425
464
426
- ident : RustIdent
427
- tuple : Optional [RustTypeTuple ] = None
428
- attrs : RustAttributes = dataclasses .field (default_factory = list )
465
+ __slots__ = ("ident" , "tuple" , "attrs" )
466
+
467
+ def __init__ (
468
+ self ,
469
+ ident : RustIdent ,
470
+ tuple : Optional [RustTypeTuple ] = None ,
471
+ attrs : Optional [RustAttributes ] = None ,
472
+ ):
473
+ self .ident = ident
474
+ self .tuple = tuple
475
+ self .attrs = () if attrs is None else attrs
429
476
430
477
def __hash__ (self ) -> int :
431
478
return hash (self .ident )
@@ -435,7 +482,6 @@ def write_to(self, writer: IO[str], depth: int = 0) -> None:
435
482
436
483
if self .attrs :
437
484
writer .write ("\n " .join (f"{ indent } { str (attr )} " for attr in self .attrs ) + "\n " )
438
-
439
485
writer .write (f"{ indent } { self .ident } " )
440
486
if self .tuple :
441
487
writer .write (str (self .tuple ))
@@ -462,11 +508,21 @@ def from_path(cls, path: RustPath) -> "RustVariant":
462
508
RustVariantsMut = MutableSequence [RustVariant ] # alias
463
509
464
510
465
- @dataclass
466
511
class RustEnum (RustNamedType ):
467
512
"""Represents a Rust enum definition."""
468
513
469
- variants : RustVariants = dataclasses .field (default_factory = tuple )
514
+ __slots__ = ("variants" ,)
515
+
516
+ def __init__ (
517
+ self ,
518
+ ident : RustIdent ,
519
+ variants : Optional [RustVariants ] = None ,
520
+ attrs : Optional [RustAttributes ] = None ,
521
+ visibility : str = "pub" ,
522
+ ):
523
+ _attrs = () if attrs is None else attrs
524
+ super ().__init__ (ident , _attrs , visibility )
525
+ self .variants = () if variants is None else variants
470
526
471
527
def write_to (self , writer : IO [str ], depth : int = 0 ) -> None :
472
528
indent = " " * depth
@@ -495,16 +551,22 @@ def salad_macro_write_to(ty: RustNamedType, writer: IO[str], depth: int = 0) ->
495
551
#
496
552
497
553
498
- @dataclass
499
554
class RustModuleTree :
500
555
"""Represents a Rust module with submodules and named types."""
501
556
502
- ident : RustIdent # ASSERT: Immutable field
503
- parent : Optional ["RustModuleTree" ] # ASSERT: Immutable field
504
- named_types : MutableMapping [RustIdent , RustNamedType ] = dataclasses .field (default_factory = dict )
505
- submodules : MutableMapping [RustIdent , "RustModuleTree" ] = dataclasses .field (
506
- default_factory = dict
507
- )
557
+ __slots__ = ("ident" , "parent" , "named_types" , "submodules" )
558
+
559
+ def __init__ (
560
+ self ,
561
+ ident : RustIdent , # ASSERT: Immutable field
562
+ parent : Optional ["RustModuleTree" ] = None , # ASSERT: Immutable field
563
+ named_types : Optional [MutableMapping [RustIdent , RustNamedType ]] = None ,
564
+ submodules : Optional [MutableMapping [RustIdent , "RustModuleTree" ]] = None ,
565
+ ):
566
+ self .ident = ident
567
+ self .parent = parent
568
+ self .named_types = {} if named_types is None else named_types
569
+ self .submodules = {} if submodules is None else submodules
508
570
509
571
def __hash__ (self ) -> int :
510
572
return hash ((self .ident , self .parent ))
0 commit comments