11import inspect
2- from typing import Any , TypeVar , Union , get_args , get_origin
2+ import warnings
3+ from typing import TypeVar , Union , get_args , get_origin , Callable , Any , Optional , cast
34
45from pydantic import BaseModel
56
6- T = TypeVar ("T" , bound = BaseModel )
7+ T = TypeVar ('T' , bound = BaseModel )
78
89
9- def llm_documented (cls : type [T ]) -> type [T ]:
10- """Decorator to add LLM documentation methods to a Pydantic model."""
11-
12- def format_for_llm_impl (
13- cls_param : type [T ], include_validation : bool = False
14- ) -> str :
10+ def prompt_schema (cls : type [T ], * , warn_undocumented : bool = True ) -> type [T ]:
11+ """
12+ Decorator to add LLM documentation methods to a Pydantic model.
13+
14+ Args:
15+ cls: The Pydantic model class to decorate
16+ warn_undocumented: Whether to emit warnings for fields without docstrings
17+
18+ Returns:
19+ The decorated class with format_for_llm method
20+ """
21+
22+ def format_for_llm_impl (cls_param : type [T ], include_validation : bool = False ) -> str :
1523 """Format this model's fields and docstrings for LLM prompts."""
1624 lines = [f"{ cls_param .__name__ } :" ]
17-
25+
1826 # Get JSON schema to extract validation info if needed
1927 json_schema = cls_param .model_json_schema () if include_validation else {}
2028 properties = json_schema .get ("properties" , {})
21-
29+
2230 # Iterate through each field
2331 for name , field_info in cls_param .model_fields .items ():
2432 # Get the field's type
2533 field_type = _get_field_type_name (field_info )
26-
34+
2735 # Get docstring for the field
2836 docstring = _extract_field_docstring (cls_param , name )
29-
37+
38+ # Warn if field is not documented
39+ if warn_undocumented and not docstring :
40+ warnings .warn (
41+ f"Field '{ name } ' in { cls_param .__name__ } has no docstring. "
42+ "Add a docstring for better LLM prompts." ,
43+ UserWarning ,
44+ stacklevel = 2
45+ )
46+
3047 # Determine if field is optional
3148 is_optional = not field_info .is_required ()
3249 optional_str = ", optional" if is_optional else ""
33-
50+
3451 # Format the field line
3552 field_line = f"- { name } ({ field_type } { optional_str } ): { docstring } "
36-
53+
3754 # Add validation info if requested
3855 if include_validation and name in properties :
3956 field_schema = properties [name ]
40-
57+
4158 constraints = []
4259 # Common validation keywords
4360 for key in ["minLength" , "maxLength" , "minimum" , "maximum" , "pattern" ]:
@@ -52,54 +69,58 @@ def format_for_llm_impl(
5269 elif display_key == "maximum" :
5370 display_key = "le"
5471 elif display_key == "min_length" :
55- display_key = "min_length"
72+ display_key = "min_length"
5673 elif display_key == "max_length" :
5774 display_key = "max_length"
58-
75+
5976 constraints .append (f"{ display_key } : { field_schema [key ]} " )
60-
77+
6178 if constraints :
6279 field_line += f" [Constraints: { ', ' .join (constraints )} ]"
63-
80+
6481 lines .append (field_line )
65-
82+
6683 return "\n " .join (lines )
67-
84+
6885 # Add the format_for_llm method to the class using the classmethod decorator
69- cls . format_for_llm = classmethod (format_for_llm_impl ) # type: ignore
70-
86+ setattr ( cls , " format_for_llm" , classmethod (format_for_llm_impl ) ) # type: ignore
87+
7188 return cls
7289
7390
7491def _extract_field_docstring (cls : type , field_name : str ) -> str :
7592 """Extract docstring for a field from class source code."""
7693 try :
7794 source = inspect .getsource (cls )
78-
95+
7996 # Look for field definition
80- patterns = [f"{ field_name } :" , f"{ field_name } :" , f"{ field_name } =" ]
81-
97+ patterns = [
98+ f"{ field_name } :" ,
99+ f"{ field_name } :" ,
100+ f"{ field_name } ="
101+ ]
102+
82103 field_pos = - 1
83104 for pattern in patterns :
84105 pos = source .find (pattern )
85106 if pos != - 1 :
86107 field_pos = pos
87108 break
88-
109+
89110 if field_pos == - 1 :
90111 return ""
91-
112+
92113 # Look for triple-quoted docstring
93114 for quote in ['"""' , "'''" ]:
94115 doc_start = source .find (quote , field_pos )
95116 if doc_start != - 1 :
96117 doc_end = source .find (quote , doc_start + 3 )
97118 if doc_end != - 1 :
98- return source [doc_start + 3 : doc_end ].strip ()
99-
119+ return source [doc_start + 3 : doc_end ].strip ()
120+
100121 except Exception :
101122 pass
102-
123+
103124 return ""
104125
105126
@@ -127,17 +148,73 @@ def _get_field_type_name(field_info: Any) -> str:
127148
128149 # Handle list types
129150 if origin is list or str (origin ).endswith ("list" ):
130- # Just return "list" for simplicity and cross-version compatibility
131- return "list"
151+ arg_type = args [0 ]
152+ inner_type = ""
153+
154+ # Get the name of the inner type more reliably
155+ if isinstance (arg_type , type ):
156+ inner_type = arg_type .__name__
157+ elif hasattr (arg_type , "_name" ) and arg_type ._name :
158+ inner_type = arg_type ._name
159+ else :
160+ # Fall back to string representation with cleanup
161+ inner_type = str (arg_type ).replace ("typing." , "" ).strip ("'<>" )
162+
163+ # Extract class name from ForwardRef if needed
164+ if "ForwardRef" in inner_type :
165+ import re
166+ match = re .search (r"ForwardRef\('([^']+)'\)" , inner_type )
167+ if match :
168+ inner_type = match .group (1 )
169+
170+ return f"list[{ inner_type } ]"
132171
133172 # Handle dict types
134173 if origin is dict or str (origin ).endswith ("dict" ):
135- return "dict"
174+ key_type = args [0 ]
175+ val_type = args [1 ]
176+
177+ # Get key type name
178+ if isinstance (key_type , type ):
179+ key_name = key_type .__name__
180+ else :
181+ key_name = str (key_type ).replace ("typing." , "" )
182+
183+ # Get value type name
184+ if isinstance (val_type , type ):
185+ val_name = val_type .__name__
186+ else :
187+ val_name = str (val_type ).replace ("typing." , "" )
188+
189+ return f"dict[{ key_name } , { val_name } ]"
136190
137191 # Handle other generic types
138192 origin_name = origin .__name__ if hasattr (origin , "__name__" ) else str (origin )
139193 origin_name = origin_name .lower () # Convert List to list, etc.
140- return origin_name
194+
195+ arg_strs = []
196+ for arg in args :
197+ if isinstance (arg , type ):
198+ arg_strs .append (arg .__name__ )
199+ else :
200+ arg_str = str (arg ).replace ("typing." , "" )
201+ if "ForwardRef" in arg_str :
202+ import re
203+ match = re .search (r"ForwardRef\('([^']+)'\)" , arg_str )
204+ if match :
205+ arg_str = match .group (1 )
206+ arg_strs .append (arg_str )
207+
208+ return f"{ origin_name } [{ ', ' .join (arg_strs )} ]"
141209
142210 # For any other types
143- return str (annotation ).replace ("typing." , "" )
211+ type_str = str (annotation ).replace ("typing." , "" )
212+
213+ # Clean up ForwardRef representation
214+ if "ForwardRef" in type_str :
215+ import re
216+ match = re .search (r"ForwardRef\('([^']+)'\)" , type_str )
217+ if match :
218+ return match .group (1 )
219+
220+ return type_str
0 commit comments