1- import inspect
2- import warnings
3- from typing import TypeVar , Union , get_args , get_origin , Callable , Any , Optional , cast
1+ from typing import Optional
42
5- from pydantic import BaseModel
3+ from pydantic import BaseModel , Field
4+ from pydantic_prompt import prompt_schema
65
7- T = TypeVar ('T' , bound = BaseModel )
86
7+ def test_basic_docstring_extraction ():
8+ @prompt_schema
9+ class BasicModel (BaseModel ):
10+ name : str
11+ """The user's name"""
912
def prompt_schema(cls: type[T], *, warn_undocumented: bool = True) -> type[T]:
    """
    Decorator to add LLM documentation methods to a Pydantic model.

    The decorated class gains a ``format_for_llm`` classmethod that renders a
    header line with the class name followed by one line per field in the form
    ``- name (type[, optional]): docstring``. When called with
    ``include_validation=True`` each constrained field additionally gets a
    ``[Constraints: ...]`` suffix built from the model's JSON schema.

    Args:
        cls: The Pydantic model class to decorate
        warn_undocumented: Whether to emit warnings for fields without docstrings

    Returns:
        The decorated class with format_for_llm method
    """
    # JSON-schema keyword -> label shown in the prompt. Insertion order is the
    # order in which constraints are listed, so the original five keywords stay
    # first and previously rendered output keeps its shape.
    constraint_labels = {
        "minLength": "min_length",
        "maxLength": "max_length",
        "minimum": "ge",
        "maximum": "le",
        "pattern": "pattern",
        "exclusiveMinimum": "gt",
        "exclusiveMaximum": "lt",
        "multipleOf": "multiple_of",
    }

    def format_for_llm_impl(cls_param: type[T], include_validation: bool = False) -> str:
        """Format this model's fields and docstrings for LLM prompts."""
        lines = [f"{cls_param.__name__}:"]

        # The schema is only generated when constraints were requested.
        # by_alias=False keeps "properties" keyed by field name, so fields that
        # declare an alias are still found by the lookup below.
        properties = {}
        if include_validation:
            json_schema = cls_param.model_json_schema(by_alias=False)
            properties = json_schema.get("properties", {})

        for name, field_info in cls_param.model_fields.items():
            field_type = _get_field_type_name(field_info)
            docstring = _extract_field_docstring(cls_param, name)

            # Warn if field is not documented
            if warn_undocumented and not docstring:
                warnings.warn(
                    f"Field '{name}' in {cls_param.__name__} has no docstring. "
                    "Add a docstring for better LLM prompts.",
                    UserWarning,
                    stacklevel=2,
                )

            # Optionality is reported separately from the type name.
            optional_str = "" if field_info.is_required() else ", optional"
            field_line = f"- {name} ({field_type}{optional_str}): {docstring}"

            # properties is empty unless include_validation was requested, so
            # this yields no constraints in the default mode.
            field_schema = properties.get(name, {})
            constraints = [
                f"{label}: {field_schema[key]}"
                for key, label in constraint_labels.items()
                if key in field_schema
            ]
            if constraints:
                field_line += f" [Constraints: {', '.join(constraints)}]"

            lines.append(field_line)

        return "\n".join(lines)

    # Attach the formatter as a classmethod so it is callable on the model class.
    cls.format_for_llm = classmethod(format_for_llm_impl)  # type: ignore

    return cls
89-
90-
def _extract_field_docstring(cls: type, field_name: str) -> str:
    """Extract the attribute docstring for a field from the class source code.

    The class source is parsed with ``ast`` and the statement that assigns or
    annotates exactly ``field_name`` in the class body is located. A docstring
    is reported only when the statement immediately following that field is a
    bare string literal, so an undocumented field never picks up a neighbour's
    docstring and ``name`` is never confused with ``first_name``.

    Args:
        cls: Class whose source should be inspected.
        field_name: Name of the field to look up.

    Returns:
        The cleaned docstring text, or ``""`` when the field is not defined in
        this class body, has no docstring, or the source cannot be obtained
        or parsed.
    """
    import ast

    try:
        source = inspect.getsource(cls)
        # Classes defined inside a function or another class come back
        # indented; wrapping them in a dummy block makes the snippet parseable
        # without touching the contents of multi-line strings.
        if source[:1].isspace():
            source = "if True:\n" + source
        tree = ast.parse(source)
    except (OSError, TypeError, SyntaxError, ValueError):
        # Best effort: no retrievable/parsable source means no docstring.
        return ""

    # ast.walk is breadth-first, so the first ClassDef is the outermost one,
    # i.e. the class whose source was requested.
    class_node = next(
        (node for node in ast.walk(tree) if isinstance(node, ast.ClassDef)),
        None,
    )
    if class_node is None:
        return ""

    body = class_node.body
    for index, stmt in enumerate(body):
        if isinstance(stmt, ast.AnnAssign):
            targets = [stmt.target]
        elif isinstance(stmt, ast.Assign):
            targets = stmt.targets
        else:
            continue

        if not any(
            isinstance(target, ast.Name) and target.id == field_name
            for target in targets
        ):
            continue

        # An attribute docstring is a bare string expression directly after
        # the field definition.
        follower = body[index + 1] if index + 1 < len(body) else None
        if (
            isinstance(follower, ast.Expr)
            and isinstance(follower.value, ast.Constant)
            and isinstance(follower.value.value, str)
        ):
            return inspect.cleandoc(follower.value.value)
        return ""

    return ""
125-
126-
127- def _get_field_type_name (field_info : Any ) -> str :
128- """Get a user-friendly type name from a field."""
129- annotation = field_info .annotation
130-
131- # Handle Optional types
132- if get_origin (annotation ) is Union and type (None ) in get_args (annotation ):
133- args = get_args (annotation )
134- for arg in args :
135- if arg is not type (None ):
136- # Remove Optional wrapper, we handle optionality separately
137- annotation = arg
138- break
139-
140- # Handle basic types
141- if isinstance (annotation , type ):
142- return annotation .__name__
143-
144- # Handle parameterized generics
145- origin = get_origin (annotation )
146- if origin is not None :
147- args = get_args (annotation )
148-
149- # Handle list types
150- if origin is list or str (origin ).endswith ("list" ):
151- arg_type = args [0 ]
152- inner_type = ""
153-
154- # Get the name of the inner type more reliably
155- if isinstance (arg_type , type ):
156- inner_type = arg_type .__name__
157- elif hasattr (arg_type , "_name" ) and arg_type ._name :
158- inner_type = arg_type ._name
159- else :
160- # Fall back to string representation with cleanup
161- inner_type = str (arg_type ).replace ("typing." , "" ).strip ("'<>" )
162-
163- # Extract class name from ForwardRef if needed
164- if "ForwardRef" in inner_type :
165- import re
166- match = re .search (r"ForwardRef\('([^']+)'\)" , inner_type )
167- if match :
168- inner_type = match .group (1 )
169-
170- return f"list[{ inner_type } ]"
171-
172- # Handle dict types
173- if origin is dict or str (origin ).endswith ("dict" ):
174- key_type = args [0 ]
175- val_type = args [1 ]
176-
177- # Get key type name
178- if isinstance (key_type , type ):
179- key_name = key_type .__name__
180- else :
181- key_name = str (key_type ).replace ("typing." , "" )
182-
183- # Get value type name
184- if isinstance (val_type , type ):
185- val_name = val_type .__name__
186- else :
187- val_name = str (val_type ).replace ("typing." , "" )
188-
189- return f"dict[{ key_name } , { val_name } ]"
190-
191- # Handle other generic types
192- origin_name = origin .__name__ if hasattr (origin , "__name__" ) else str (origin )
193- origin_name = origin_name .lower () # Convert List to list, etc.
194-
195- arg_strs = []
196- for arg in args :
197- if isinstance (arg , type ):
198- arg_strs .append (arg .__name__ )
199- else :
200- arg_str = str (arg ).replace ("typing." , "" )
201- if "ForwardRef" in arg_str :
202- import re
203- match = re .search (r"ForwardRef\('([^']+)'\)" , arg_str )
204- if match :
205- arg_str = match .group (1 )
206- arg_strs .append (arg_str )
207-
208- return f"{ origin_name } [{ ', ' .join (arg_strs )} ]"
209-
210- # For any other types
211- type_str = str (annotation ).replace ("typing." , "" )
212-
213- # Clean up ForwardRef representation
214- if "ForwardRef" in type_str :
215- import re
216- match = re .search (r"ForwardRef\('([^']+)'\)" , type_str )
217- if match :
218- return match .group (1 )
13+ age : int
14+ """Age in years"""
15+
16+ output = BasicModel .format_for_llm ()
17+ assert "name (str): The user's name" in output
18+ assert "age (int): Age in years" in output
19+
20+
def test_optional_fields():
    """Fields with a default are tagged ``optional``; required ones are not."""

    @prompt_schema
    class OptionalFieldsModel(BaseModel):
        required: str
        """Required field"""

        optional: Optional[str] = None
        """Optional field"""

    rendered = OptionalFieldsModel.format_for_llm()

    # Each fragment pins the "name (type[, optional]):" prefix of one field.
    for fragment in ("required (str):", "optional (str, optional):"):
        assert fragment in rendered
33+
34+
def test_validation_rules():
    """Constraints are rendered only when ``include_validation=True``."""

    @prompt_schema
    class ValidationModel(BaseModel):
        name: str = Field(min_length=2, max_length=50)
        """User name"""

        age: int = Field(ge=0, le=120)
        """Age in years"""

    # Default rendering carries no constraint annotations at all.
    assert "Constraints" not in ValidationModel.format_for_llm()

    # Opting in adds one bracketed constraint list per constrained field.
    detailed = ValidationModel.format_for_llm(include_validation=True)
    expected_fragments = (
        "Constraints: min_length: 2, max_length: 50",
        "Constraints: ge: 0, le: 120",
    )
    for fragment in expected_fragments:
        assert fragment in detailed
52+
53+
54+ def test_nested_models ():
55+ @prompt_schema
56+ class Address (BaseModel ):
57+ street : str
58+ """Street address"""
59+
60+ city : str
61+ """City name"""
62+
63+ @prompt_schema
64+ class Person (BaseModel ):
65+ name : str
66+ """Person's name"""
67+
68+ addresses : list [Address ] = []
69+ """List of addresses"""
70+
71+ output = Person .format_for_llm ()
72+ assert "name (str): Person's name" in output
21973
220- return type_str
74+ # More flexible assertion that checks for the important parts
75+ assert "addresses (list[Address], optional): List of addresses" in output or (
76+ "addresses (list[" in output and
77+ "Address" in output and
78+ "optional): List of addresses" in output
79+ )
0 commit comments