4545from _ast import AST
4646from datetime import datetime
4747from pathlib import Path
48- from typing import Type , cast
48+ from typing import Any , Type , cast
4949
5050
5151class ClassTransformer (ast .NodeTransformer ):
@@ -118,21 +118,36 @@ def has_model_config(self, node: ast.ClassDef) -> ast.Assign | None:
118118 return item
119119 return None
120120
121- def visit_ClassDef (self , _node : ast .ClassDef ) -> ast .ClassDef : # noqa: N802
122- """Visit and transform a class definition node.
123-
124- Args:
125- node: The ClassDef AST node to transform.
121+ def visit_AnnAssign (self , node : ast .AnnAssign ) -> ast .AnnAssign :
122+ """Visit and transform annotated assignment."""
123+ if isinstance (node .annotation , ast .Name ) and node .annotation .id == 'Role' :
124+ node .annotation = ast .BinOp (
125+ left = ast .Name (id = 'Role' , ctx = ast .Load ()),
126+ op = ast .BitOr (),
127+ right = ast .Name (id = 'str' , ctx = ast .Load ()),
128+ )
129+ self .modified = True
130+ return node
126131
127- Returns:
128- The transformed ClassDef node.
129- """
132+ def visit_ClassDef (self , node : ast .ClassDef ) -> Any :
133+ # Visit and transform a class definition node.
134+ #
135+ # Args:
136+ # node: The ClassDef AST node to transform.
137+ #
138+ # Returns:
139+ # The transformed ClassDef node.
130140 # First apply base class transformations recursively
131- node = super ().generic_visit (_node )
141+ node = cast ( ast . ClassDef , super ().generic_visit (node ) )
132142 new_body : list [ast .stmt | ast .Constant | ast .Assign ] = []
133143
134144 # Handle Docstrings
135- if not node .body or not isinstance (node .body [0 ], ast .Expr ) or not isinstance (node .body [0 ].value , ast .Constant ):
145+ if (
146+ not node .body
147+ or not isinstance (node .body [0 ], ast .Expr )
148+ or not isinstance (node .body [0 ].value , ast .Constant )
149+ or not isinstance (node .body [0 ].value .value , str )
150+ ):
136151 # Generate a more descriptive docstring based on class type
137152 if self .is_rootmodel_class (node ):
138153 docstring = f'Root model for { node .name .lower ().replace ("_" , " " )} .'
@@ -151,13 +166,21 @@ def visit_ClassDef(self, _node: ast.ClassDef) -> ast.ClassDef: # noqa: N802
151166
152167 # Handle model_config for BaseModel and RootModel
153168 existing_model_config_assign = self .has_model_config (node )
169+
154170 existing_model_config_call = None
155171 if existing_model_config_assign and isinstance (existing_model_config_assign .value , ast .Call ):
156172 existing_model_config_call = existing_model_config_assign .value
157173
158174 # Determine start index for iterating original body (skip docstring)
159175 body_start_index = (
160- 1 if (node .body and isinstance (node .body [0 ], ast .Expr ) and isinstance (node .body [0 ].value , ast .Str )) else 0
176+ 1
177+ if (
178+ node .body
179+ and isinstance (node .body [0 ], ast .Expr )
180+ and isinstance (node .body [0 ].value , ast .Constant )
181+ and isinstance (node .body [0 ].value .value , str )
182+ )
183+ else 0
161184 )
162185
163186 if self .is_rootmodel_class (node ):
0 commit comments