66
66
# TODO Check strings for Unicode standard for `XID_Start` and `XID_Continue`
67
67
# @functools.cache
68
68
def rust_sanitize_field_ident (value : str ) -> str :
69
- """
70
- Checks whether the field name is a Rust reserved world, or escapes it.
71
- """
72
- # value = functools.reduce(lambda s, r: re.sub(*r, s), __FIELD_NAME_REX_DICT, value)
73
- # value = value.lower()
69
+ """Check whether the field name is a Rust reserved world, or escape it."""
74
70
if value in __RUST_RESERVED_WORDS :
75
71
return f"r#{ value } "
76
72
return value
@@ -79,18 +75,15 @@ def rust_sanitize_field_ident(value: str) -> str:
79
75
# TODO Check strings for Unicode standard for `XID_Start` and `XID_Continue`
80
76
@functools .cache
81
77
def rust_sanitize_type_ident (value : str ) -> str :
82
- """
83
- Converts an input string into a valid Rust type name (PascalCase).
78
+ """Convert an input string into a valid Rust type name (PascalCase).
79
+
84
80
Results are cached for performance optimization.
85
81
"""
86
82
return functools .reduce (lambda s , r : re .sub (r [0 ], r [1 ], s ), __TYPE_NAME_REX_DICT , value )
87
83
88
84
89
85
def rust_sanitize_doc_iter (value : Union [Sequence [str ], str ]) -> Iterator [str ]:
90
- """
91
- Sanitizes Markdown doc-strings by splitting lines and wrapping non-hyperlinked
92
- URLs in angle brackets.
93
- """
86
+ """Sanitize Markdown doc-strings by splitting lines and wrapping non-hyperlinked URLs in angle brackets."""
94
87
return map (
95
88
lambda v : re .sub (__MD_NON_HYPERLINK_REX , lambda m : f"<{ str (m .group ())} >" , v ),
96
89
itertools .chain .from_iterable (map ( # flat_map
@@ -102,8 +95,8 @@ def rust_sanitize_doc_iter(value: Union[Sequence[str], str]) -> Iterator[str]:
102
95
103
96
@functools .cache
104
97
def to_rust_literal (value : Any ) -> str :
105
- """
106
- Convert Python values to their equivalent Rust literal representation.
98
+ """Convert Python values to their equivalent Rust literal representation.
99
+
107
100
Results are cached for performance optimization.
108
101
"""
109
102
if isinstance (value , bool ):
@@ -123,10 +116,7 @@ def to_rust_literal(value: Any) -> str:
123
116
124
117
125
118
def make_avro (items : MutableSequence [JsonDataType ]) -> MutableSequence [NamedSchema ]:
126
- """
127
- Processes a list of dictionaries to generate a list of Avro schemas.
128
- """
129
-
119
+ """Process a list of dictionaries to generate a list of Avro schemas."""
130
120
# Same as `from .utils import convert_to_dict`, which, however, is not public
131
121
def convert_to_dict (j4 : Any ) -> Any :
132
122
"""Convert generic Mapping objects to dicts recursively."""
@@ -162,9 +152,7 @@ def convert_to_dict(j4: Any) -> Any:
162
152
163
153
@dataclass # ASSERT: Immutable class
164
154
class RustLifetime :
165
- """
166
- Represents a Rust lifetime parameter (e.g., `'a`).
167
- """
155
+ """Represents a Rust lifetime parameter (e.g., `'a`)."""
168
156
169
157
ident : RustIdent
170
158
@@ -176,26 +164,20 @@ def __str__(self) -> str:
176
164
177
165
178
166
class RustType (ABC ):
179
- """
180
- Abstract class for Rust types.
181
- """
167
+ """Abstract class for Rust types."""
182
168
183
169
pass
184
170
185
171
186
172
class RustMeta (ABC ):
187
- """
188
- Abstract class for Rust attribute metas.
189
- """
173
+ """Abstract class for Rust attribute metas."""
190
174
191
175
pass
192
176
193
177
194
178
@dataclass (unsafe_hash = True ) # ASSERT: Immutable class
195
179
class RustAttribute :
196
- """
197
- Represents a Rust attribute (e.g., `#[derive(Debug)]`).
198
- """
180
+ """Represents a Rust attribute (e.g., `#[derive(Debug)]`)."""
199
181
200
182
meta : RustMeta
201
183
@@ -213,9 +195,7 @@ def __str__(self) -> str:
213
195
214
196
@dataclass (unsafe_hash = True ) # ASSERT: Immutable class
215
197
class RustPathSegment :
216
- """
217
- Represents a segment in a Rust path with optional generics.
218
- """
198
+ """Represents a segment in a Rust path with optional generics."""
219
199
220
200
ident : RustIdent
221
201
generics : RustGenerics = dataclasses .field (default_factory = tuple )
@@ -232,8 +212,8 @@ def __str__(self) -> str:
232
212
@classmethod
233
213
@functools .cache
234
214
def from_str (cls , value : str ) -> "RustPathSegment" :
235
- """
236
- Parses a string into RustPathSegment class.
215
+ """Parse a string into RustPathSegment class.
216
+
237
217
Results are cached for performance optimization.
238
218
"""
239
219
@@ -276,9 +256,7 @@ def parse_generics_string(value_generics: str) -> RustGenerics:
276
256
277
257
@dataclass (unsafe_hash = True ) # ASSERT: Immutable class
278
258
class RustPath (RustMeta ):
279
- """
280
- Represents a complete Rust path (e.g., `::std::vec::Vec<T>`).
281
- """
259
+ """Represents a complete Rust path (e.g., `::std::vec::Vec<T>`)."""
282
260
283
261
# ASSERT: Never initialized with an empty sequence
284
262
segments : RustPathSegments
@@ -309,8 +287,8 @@ def __str__(self) -> str:
309
287
@classmethod
310
288
@functools .cache
311
289
def from_str (cls , value : str ) -> "RustPath" :
312
- """
313
- Parses a string into RustPath class.
290
+ """Parse a string into RustPath class.
291
+
314
292
Results are cached for performance optimization.
315
293
"""
316
294
norm_value , leading_colon = (value [2 :], True ) if value .startswith ("::" ) else (value , False )
@@ -323,21 +301,10 @@ def from_str(cls, value: str) -> "RustPath":
323
301
raise ValueError (f"Poorly formatted Rust path: '{ value } '" )
324
302
return cls (segments = tuple (segments ), leading_colon = leading_colon )
325
303
326
- # def parent(self) -> "RustPath":
327
- # """
328
- # Returns a new RustPath containing all but the last segment.
329
- # """
330
- # return RustPath(
331
- # segments=self.segments[:-1],
332
- # leading_colon=self.leading_colon,
333
- # )
334
-
335
304
336
305
@dataclass (unsafe_hash = True ) # ASSERT: Immutable class
337
306
class RustTypeTuple (RustType ):
338
- """
339
- Represents a Rust tuple type (e.g., `(T, U)`).
340
- """
307
+ """Represents a Rust tuple type (e.g., `(T, U)`)."""
341
308
342
309
# ASSERT: Never initialized with an empty sequence
343
310
types : Sequence [RustPath ]
@@ -349,9 +316,7 @@ def __str__(self) -> str:
349
316
350
317
@dataclass # ASSERT: Immutable class
351
318
class RustMetaList (RustMeta ):
352
- """
353
- Represents attribute meta list information (e.g., `derive(Debug, Clone)`)
354
- """
319
+ """Represents attribute meta list information (e.g., `derive(Debug, Clone)`).."""
355
320
356
321
path : RustPath
357
322
metas : Sequence [RustMeta ] = tuple ()
@@ -366,9 +331,7 @@ def __str__(self) -> str:
366
331
367
332
@dataclass # ASSERT: Immutable class
368
333
class RustMetaNameValue (RustMeta ):
369
- """
370
- Represents attribute meta name-value information (e.g., `key = value`)
371
- """
334
+ """Represents attribute meta name-value information (e.g., `key = value`)."""
372
335
373
336
path : RustPath
374
337
value : Any = True
@@ -387,9 +350,7 @@ def __str__(self) -> str:
387
350
388
351
@dataclass
389
352
class RustNamedType (ABC ): # ABC class
390
- """
391
- Abstract class for Rust struct and enum types.
392
- """
353
+ """Abstract class for Rust struct and enum types."""
393
354
394
355
ident : RustIdent
395
356
attrs : RustAttributes = dataclasses .field (default_factory = list )
@@ -410,9 +371,7 @@ def __str__(self) -> str:
410
371
411
372
@dataclass # ASSERT: Immutable class
412
373
class RustField :
413
- """
414
- Represents a field in a Rust struct.
415
- """
374
+ """Represents a field in a Rust struct."""
416
375
417
376
ident : RustIdent
418
377
type : RustPath
@@ -435,9 +394,7 @@ def write_to(self, writer: IO[str], depth: int = 0) -> None:
435
394
436
395
@dataclass
437
396
class RustStruct (RustNamedType ):
438
- """
439
- Represents a Rust struct definition.
440
- """
397
+ """Represents a Rust struct definition."""
441
398
442
399
fields : Optional [RustFields ] = None
443
400
@@ -462,9 +419,7 @@ def write_to(self, writer: IO[str], depth: int = 0) -> None:
462
419
463
420
@dataclass # ASSERT: Immutable class
464
421
class RustVariant :
465
- """
466
- Represents a variant in a Rust enum.
467
- """
422
+ """Represents a variant in a Rust enum."""
468
423
469
424
ident : RustIdent
470
425
tuple : Optional [RustTypeTuple ] = None
@@ -507,9 +462,7 @@ def from_path(cls, path: RustPath) -> "RustVariant":
507
462
508
463
@dataclass
509
464
class RustEnum (RustNamedType ):
510
- """
511
- Represents a Rust enum definition.
512
- """
465
+ """Represents a Rust enum definition."""
513
466
514
467
variants : RustVariants = dataclasses .field (default_factory = tuple )
515
468
@@ -528,9 +481,7 @@ def write_to(self, writer: IO[str], depth: int = 0) -> None:
528
481
529
482
# Wrapper for the RustNamedType `write_to()` method call
530
483
def salad_macro_write_to (ty : RustNamedType , writer : IO [str ], depth : int = 0 ) -> None :
531
- """
532
- Writes a RustNamedType wrapping it in the Schema Salad macro
533
- """
484
+ """Write a RustNamedType wrapping it in the Schema Salad macro."""
534
485
indent = " " * depth
535
486
writer .write (indent + "salad_core::define_type! {\n " )
536
487
ty .write_to (writer , 1 )
@@ -544,9 +495,7 @@ def salad_macro_write_to(ty: RustNamedType, writer: IO[str], depth: int = 0) ->
544
495
545
496
@dataclass
546
497
class RustModuleTree :
547
- """
548
- Represents a Rust module with submodules and named types
549
- """
498
+ """Represents a Rust module with submodules and named types."""
550
499
551
500
ident : RustIdent # ASSERT: Immutable field
552
501
parent : Optional ["RustModuleTree" ] # ASSERT: Immutable field
@@ -559,9 +508,7 @@ def __hash__(self) -> int:
559
508
return hash ((self .ident , self .parent ))
560
509
561
510
def get_rust_path (self ) -> RustPath :
562
- """
563
- Returns the complete Rust path from root to this module.
564
- """
511
+ """Return the complete Rust path from root to this module."""
565
512
segments : list [RustPathSegment ] = []
566
513
current : Optional ["RustModuleTree" ] = self
567
514
@@ -571,9 +518,7 @@ def get_rust_path(self) -> RustPath:
571
518
return RustPath (segments = tuple (reversed (segments )))
572
519
573
520
def add_submodule (self , path : Union [RustPath , str ]) -> "RustModuleTree" :
574
- """
575
- Creates a new submodule or returns an existing one with the given path.
576
- """
521
+ """Create a new submodule or returns an existing one with the given path."""
577
522
if isinstance (path , str ):
578
523
path = RustPath .from_str (path )
579
524
segments = iter (path .segments )
@@ -598,25 +543,10 @@ def add_submodule(self, path: Union[RustPath, str]) -> "RustModuleTree":
598
543
)
599
544
return current
600
545
601
- # def get_submodule(self, path: Union[RustPath, str]) -> Optional["RustModuleTree"]:
602
- # """
603
- # Returns a submodule from this module tree by its Rust path, if any.
604
- # """
605
- # if isinstance(path, str):
606
- # path = RustPath.from_str(path)
607
- # current, last_segment_idx = self, len(path.segments) - 1
608
- # for idx, segment in enumerate(path.segments):
609
- # if (idx == last_segment_idx) and (current.ident == segment.ident):
610
- # return current
611
- # current = current.submodules.get(segment.ident)
612
- # if not current:
613
- # return None
614
- # return None
615
-
616
546
def add_named_type (self , ty : RustNamedType ) -> RustPath :
617
- """
618
- Adds a named type to this module tree and returns its complete Rust path.
619
- Raises `ValueError` if type with same name already exists
547
+ """Add a named type to this module tree and returns its complete Rust path.
548
+
549
+ Raises `ValueError` if type with same name already exists.
620
550
"""
621
551
module_rust_path = self .get_rust_path ()
622
552
if ty .ident in self .named_types :
@@ -630,9 +560,7 @@ def add_named_type(self, ty: RustNamedType) -> RustPath:
630
560
# return None
631
561
632
562
def write_to_fs (self , base_path : Path ) -> None :
633
- """
634
- Writes the module tree to the filesystem under the given base path.
635
- """
563
+ """Write the module tree to the filesystem under the given base path."""
636
564
637
565
def write_module_file (module : "RustModuleTree" , path : Path , mode : str = "wt" ) -> None :
638
566
with open (path , mode = mode ) as module_rs :
@@ -715,9 +643,7 @@ def rust_type_list(rust_ty: RustPath) -> RustPath:
715
643
716
644
717
645
class RustCodeGen (CodeGenBase ):
718
- """
719
- Rust code generator for schema salad definitions.
720
- """
646
+ """Rust code generator for schema salad definitions."""
721
647
722
648
# Static
723
649
CRATE_VERSION : ClassVar [str ] = "0.1.0" # Version of the generated crate
@@ -738,6 +664,7 @@ def __init__(
738
664
salad_version : str ,
739
665
target : Optional [str ] = None ,
740
666
) -> None :
667
+ """Initialize the RustCodeGen class."""
741
668
self .package = package
742
669
self .package_version = self .__generate_crate_version (salad_version )
743
670
self .output_dir = Path (target or "." ).resolve ()
@@ -755,6 +682,7 @@ def __init__(
755
682
)
756
683
757
684
def parse (self , items : MutableSequence [JsonDataType ]) -> None :
685
+ """Parse the provided item list to generate the corresponding Rust types."""
758
686
# Create output directory
759
687
self .__init_output_directory ()
760
688
@@ -1120,9 +1048,7 @@ def __get_submodule_path(self, schema: NamedSchema) -> RustPath:
1120
1048
return RustPath (segments = segments )
1121
1049
1122
1050
def __init_output_directory (self ) -> None :
1123
- """
1124
- Initialize the output directory structure.
1125
- """
1051
+ """Initialize the output directory structure."""
1126
1052
if self .output_dir .is_file ():
1127
1053
raise ValueError (f"Output directory cannot be a file: { self .output_dir } " )
1128
1054
if not self .output_dir .exists ():
0 commit comments