Skip to content

Commit d818a1f

Browse files
committed
feat: improve nested type handling and rename decorator to prompt_schema
- Enhance type formatting to properly show nested types (like list[Address]) - Rename decorator from llm_documented to prompt_schema for better clarity - Add warning system for undocumented fields - Make tests more robust across different Python versions - Fix type extraction from ForwardRef and complex generic types
1 parent b99a22f commit d818a1f

File tree

2 files changed

+328
-108
lines changed

2 files changed

+328
-108
lines changed

pydantic_prompt/core.py

Lines changed: 113 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,60 @@
11
import inspect
2-
from typing import Any, TypeVar, Union, get_args, get_origin
2+
import warnings
3+
from typing import TypeVar, Union, get_args, get_origin, Callable, Any, Optional, cast
34

45
from pydantic import BaseModel
56

6-
T = TypeVar("T", bound=BaseModel)
7+
T = TypeVar('T', bound=BaseModel)
78

89

9-
def llm_documented(cls: type[T]) -> type[T]:
10-
"""Decorator to add LLM documentation methods to a Pydantic model."""
11-
12-
def format_for_llm_impl(
13-
cls_param: type[T], include_validation: bool = False
14-
) -> str:
10+
def prompt_schema(cls: type[T], *, warn_undocumented: bool = True) -> type[T]:
11+
"""
12+
Decorator to add LLM documentation methods to a Pydantic model.
13+
14+
Args:
15+
cls: The Pydantic model class to decorate
16+
warn_undocumented: Whether to emit warnings for fields without docstrings
17+
18+
Returns:
19+
The decorated class with format_for_llm method
20+
"""
21+
22+
def format_for_llm_impl(cls_param: type[T], include_validation: bool = False) -> str:
1523
"""Format this model's fields and docstrings for LLM prompts."""
1624
lines = [f"{cls_param.__name__}:"]
17-
25+
1826
# Get JSON schema to extract validation info if needed
1927
json_schema = cls_param.model_json_schema() if include_validation else {}
2028
properties = json_schema.get("properties", {})
21-
29+
2230
# Iterate through each field
2331
for name, field_info in cls_param.model_fields.items():
2432
# Get the field's type
2533
field_type = _get_field_type_name(field_info)
26-
34+
2735
# Get docstring for the field
2836
docstring = _extract_field_docstring(cls_param, name)
29-
37+
38+
# Warn if field is not documented
39+
if warn_undocumented and not docstring:
40+
warnings.warn(
41+
f"Field '{name}' in {cls_param.__name__} has no docstring. "
42+
"Add a docstring for better LLM prompts.",
43+
UserWarning,
44+
stacklevel=2
45+
)
46+
3047
# Determine if field is optional
3148
is_optional = not field_info.is_required()
3249
optional_str = ", optional" if is_optional else ""
33-
50+
3451
# Format the field line
3552
field_line = f"- {name} ({field_type}{optional_str}): {docstring}"
36-
53+
3754
# Add validation info if requested
3855
if include_validation and name in properties:
3956
field_schema = properties[name]
40-
57+
4158
constraints = []
4259
# Common validation keywords
4360
for key in ["minLength", "maxLength", "minimum", "maximum", "pattern"]:
@@ -52,54 +69,58 @@ def format_for_llm_impl(
5269
elif display_key == "maximum":
5370
display_key = "le"
5471
elif display_key == "min_length":
55-
display_key = "min_length"
72+
display_key = "min_length"
5673
elif display_key == "max_length":
5774
display_key = "max_length"
58-
75+
5976
constraints.append(f"{display_key}: {field_schema[key]}")
60-
77+
6178
if constraints:
6279
field_line += f" [Constraints: {', '.join(constraints)}]"
63-
80+
6481
lines.append(field_line)
65-
82+
6683
return "\n".join(lines)
67-
84+
6885
# Add the format_for_llm method to the class using the classmethod decorator
69-
cls.format_for_llm = classmethod(format_for_llm_impl) # type: ignore
70-
86+
setattr(cls, "format_for_llm", classmethod(format_for_llm_impl)) # type: ignore
87+
7188
return cls
7289

7390

7491
def _extract_field_docstring(cls: type, field_name: str) -> str:
7592
"""Extract docstring for a field from class source code."""
7693
try:
7794
source = inspect.getsource(cls)
78-
95+
7996
# Look for field definition
80-
patterns = [f"{field_name}:", f"{field_name} :", f"{field_name} ="]
81-
97+
patterns = [
98+
f"{field_name}:",
99+
f"{field_name} :",
100+
f"{field_name} ="
101+
]
102+
82103
field_pos = -1
83104
for pattern in patterns:
84105
pos = source.find(pattern)
85106
if pos != -1:
86107
field_pos = pos
87108
break
88-
109+
89110
if field_pos == -1:
90111
return ""
91-
112+
92113
# Look for triple-quoted docstring
93114
for quote in ['"""', "'''"]:
94115
doc_start = source.find(quote, field_pos)
95116
if doc_start != -1:
96117
doc_end = source.find(quote, doc_start + 3)
97118
if doc_end != -1:
98-
return source[doc_start + 3 : doc_end].strip()
99-
119+
return source[doc_start + 3:doc_end].strip()
120+
100121
except Exception:
101122
pass
102-
123+
103124
return ""
104125

105126

@@ -127,17 +148,73 @@ def _get_field_type_name(field_info: Any) -> str:
127148

128149
# Handle list types
129150
if origin is list or str(origin).endswith("list"):
130-
# Just return "list" for simplicity and cross-version compatibility
131-
return "list"
151+
arg_type = args[0]
152+
inner_type = ""
153+
154+
# Get the name of the inner type more reliably
155+
if isinstance(arg_type, type):
156+
inner_type = arg_type.__name__
157+
elif hasattr(arg_type, "_name") and arg_type._name:
158+
inner_type = arg_type._name
159+
else:
160+
# Fall back to string representation with cleanup
161+
inner_type = str(arg_type).replace("typing.", "").strip("'<>")
162+
163+
# Extract class name from ForwardRef if needed
164+
if "ForwardRef" in inner_type:
165+
import re
166+
match = re.search(r"ForwardRef\('([^']+)'\)", inner_type)
167+
if match:
168+
inner_type = match.group(1)
169+
170+
return f"list[{inner_type}]"
132171

133172
# Handle dict types
134173
if origin is dict or str(origin).endswith("dict"):
135-
return "dict"
174+
key_type = args[0]
175+
val_type = args[1]
176+
177+
# Get key type name
178+
if isinstance(key_type, type):
179+
key_name = key_type.__name__
180+
else:
181+
key_name = str(key_type).replace("typing.", "")
182+
183+
# Get value type name
184+
if isinstance(val_type, type):
185+
val_name = val_type.__name__
186+
else:
187+
val_name = str(val_type).replace("typing.", "")
188+
189+
return f"dict[{key_name}, {val_name}]"
136190

137191
# Handle other generic types
138192
origin_name = origin.__name__ if hasattr(origin, "__name__") else str(origin)
139193
origin_name = origin_name.lower() # Convert List to list, etc.
140-
return origin_name
194+
195+
arg_strs = []
196+
for arg in args:
197+
if isinstance(arg, type):
198+
arg_strs.append(arg.__name__)
199+
else:
200+
arg_str = str(arg).replace("typing.", "")
201+
if "ForwardRef" in arg_str:
202+
import re
203+
match = re.search(r"ForwardRef\('([^']+)'\)", arg_str)
204+
if match:
205+
arg_str = match.group(1)
206+
arg_strs.append(arg_str)
207+
208+
return f"{origin_name}[{', '.join(arg_strs)}]"
141209

142210
# For any other types
143-
return str(annotation).replace("typing.", "")
211+
type_str = str(annotation).replace("typing.", "")
212+
213+
# Clean up ForwardRef representation
214+
if "ForwardRef" in type_str:
215+
import re
216+
match = re.search(r"ForwardRef\('([^']+)'\)", type_str)
217+
if match:
218+
return match.group(1)
219+
220+
return type_str

0 commit comments

Comments
 (0)