Skip to content

Commit 3e8eeb3

Browse files
committed
restore tests
1 parent 817093b commit 3e8eeb3

File tree

1 file changed

+74
-215
lines changed

1 file changed

+74
-215
lines changed

tests/test_pydantic_prompt.py

Lines changed: 74 additions & 215 deletions
Original file line number | Diff line number | Diff line change
@@ -1,220 +1,79 @@
1-
import inspect
2-
import warnings
3-
from typing import TypeVar, Union, get_args, get_origin, Callable, Any, Optional, cast
1+
from typing import Optional
42

5-
from pydantic import BaseModel
3+
from pydantic import BaseModel, Field
4+
from pydantic_prompt import prompt_schema
65

7-
T = TypeVar('T', bound=BaseModel)
86

7+
def test_basic_docstring_extraction():
8+
@prompt_schema
9+
class BasicModel(BaseModel):
10+
name: str
11+
"""The user's name"""
912

10-
def prompt_schema(cls: type[T], *, warn_undocumented: bool = True) -> type[T]:
    """
    Attach a ``format_for_llm`` classmethod to a Pydantic model class.

    Args:
        cls: The Pydantic model class being decorated
        warn_undocumented: When true, a ``UserWarning`` is emitted for each
            field whose extracted docstring is empty while formatting

    Returns:
        The same class object, now carrying ``format_for_llm``
    """
    # JSON-schema keyword -> label used in the rendered constraint list.
    # Insertion order is the order in which constraints are printed.
    constraint_labels = {
        "minLength": "min_length",
        "maxLength": "max_length",
        "minimum": "ge",
        "maximum": "le",
        "pattern": "pattern",
    }

    def format_for_llm_impl(cls_param: type[T], include_validation: bool = False) -> str:
        """Render the model's fields, types and docstrings as prompt text."""
        # The JSON schema is only generated when constraints were requested.
        if include_validation:
            schema_props = cls_param.model_json_schema().get("properties", {})
        else:
            schema_props = {}

        rendered = [f"{cls_param.__name__}:"]

        for name, field_info in cls_param.model_fields.items():
            type_name = _get_field_type_name(field_info)
            doc = _extract_field_docstring(cls_param, name)

            # Nudge the author towards documenting every field.
            if warn_undocumented and not doc:
                warnings.warn(
                    f"Field '{name}' in {cls_param.__name__} has no docstring. "
                    "Add a docstring for better LLM prompts.",
                    UserWarning,
                    stacklevel=2
                )

            # Fields that are not required get an ", optional" marker.
            suffix = "" if field_info.is_required() else ", optional"
            entry = f"- {name} ({type_name}{suffix}): {doc}"

            if include_validation and name in schema_props:
                prop_schema = schema_props[name]
                found = [
                    f"{label}: {prop_schema[keyword]}"
                    for keyword, label in constraint_labels.items()
                    if keyword in prop_schema
                ]
                if found:
                    entry += f" [Constraints: {', '.join(found)}]"

            rendered.append(entry)

        return "\n".join(rendered)

    # Expose the formatter on the decorated class as a classmethod.
    cls.format_for_llm = classmethod(format_for_llm_impl)  # type: ignore

    return cls
89-
90-
91-
def _extract_field_docstring(cls: type, field_name: str) -> str:
    """Return the attribute docstring written directly below a field.

    The class source is parsed with ``ast`` and the class-body statement that
    annotates or assigns ``field_name`` is located. Only a bare string literal
    that is the very next statement counts as that field's docstring. An
    undocumented field therefore yields ``""`` instead of picking up the
    docstring of a later field, and a name that is a suffix of another name
    (``name`` vs ``nickname``) cannot match the wrong line.

    Args:
        cls: Class whose source code is inspected.
        field_name: Name of the field as written in the class body.

    Returns:
        The docstring with surrounding whitespace stripped, or ``""`` when the
        field is not found, has no docstring, or the class source cannot be
        retrieved or parsed.
    """
    import ast

    try:
        source = inspect.getsource(cls)
        # A class nested in a function or another class comes back indented.
        # Wrapping it in a dummy block makes it parseable without altering
        # the contents of multi-line string literals.
        if source[:1].isspace():
            source = "if True:\n" + source
        tree = ast.parse(source)
    except (OSError, TypeError, SyntaxError, ValueError):
        # Best effort: built-ins, dynamically created classes and unparsable
        # source simply have no extractable docstrings.
        return ""

    # ast.walk is breadth-first, so the first ClassDef is the outermost one,
    # i.e. the class whose source was requested.
    class_node = next(
        (node for node in ast.walk(tree) if isinstance(node, ast.ClassDef)),
        None,
    )
    if class_node is None:
        return ""

    statements = class_node.body
    for index, stmt in enumerate(statements):
        if isinstance(stmt, ast.AnnAssign):
            targets = [stmt.target]
        elif isinstance(stmt, ast.Assign):
            targets = stmt.targets
        else:
            continue

        if not any(
            isinstance(target, ast.Name) and target.id == field_name
            for target in targets
        ):
            continue

        # The docstring must be the statement immediately after the field.
        following = statements[index + 1] if index + 1 < len(statements) else None
        if (
            isinstance(following, ast.Expr)
            and isinstance(following.value, ast.Constant)
            and isinstance(following.value.value, str)
        ):
            return following.value.value.strip()
        return ""

    return ""
125-
126-
127-
def _get_field_type_name(field_info: Any) -> str:
128-
"""Get a user-friendly type name from a field."""
129-
annotation = field_info.annotation
130-
131-
# Handle Optional types
132-
if get_origin(annotation) is Union and type(None) in get_args(annotation):
133-
args = get_args(annotation)
134-
for arg in args:
135-
if arg is not type(None):
136-
# Remove Optional wrapper, we handle optionality separately
137-
annotation = arg
138-
break
139-
140-
# Handle basic types
141-
if isinstance(annotation, type):
142-
return annotation.__name__
143-
144-
# Handle parameterized generics
145-
origin = get_origin(annotation)
146-
if origin is not None:
147-
args = get_args(annotation)
148-
149-
# Handle list types
150-
if origin is list or str(origin).endswith("list"):
151-
arg_type = args[0]
152-
inner_type = ""
153-
154-
# Get the name of the inner type more reliably
155-
if isinstance(arg_type, type):
156-
inner_type = arg_type.__name__
157-
elif hasattr(arg_type, "_name") and arg_type._name:
158-
inner_type = arg_type._name
159-
else:
160-
# Fall back to string representation with cleanup
161-
inner_type = str(arg_type).replace("typing.", "").strip("'<>")
162-
163-
# Extract class name from ForwardRef if needed
164-
if "ForwardRef" in inner_type:
165-
import re
166-
match = re.search(r"ForwardRef\('([^']+)'\)", inner_type)
167-
if match:
168-
inner_type = match.group(1)
169-
170-
return f"list[{inner_type}]"
171-
172-
# Handle dict types
173-
if origin is dict or str(origin).endswith("dict"):
174-
key_type = args[0]
175-
val_type = args[1]
176-
177-
# Get key type name
178-
if isinstance(key_type, type):
179-
key_name = key_type.__name__
180-
else:
181-
key_name = str(key_type).replace("typing.", "")
182-
183-
# Get value type name
184-
if isinstance(val_type, type):
185-
val_name = val_type.__name__
186-
else:
187-
val_name = str(val_type).replace("typing.", "")
188-
189-
return f"dict[{key_name}, {val_name}]"
190-
191-
# Handle other generic types
192-
origin_name = origin.__name__ if hasattr(origin, "__name__") else str(origin)
193-
origin_name = origin_name.lower() # Convert List to list, etc.
194-
195-
arg_strs = []
196-
for arg in args:
197-
if isinstance(arg, type):
198-
arg_strs.append(arg.__name__)
199-
else:
200-
arg_str = str(arg).replace("typing.", "")
201-
if "ForwardRef" in arg_str:
202-
import re
203-
match = re.search(r"ForwardRef\('([^']+)'\)", arg_str)
204-
if match:
205-
arg_str = match.group(1)
206-
arg_strs.append(arg_str)
207-
208-
return f"{origin_name}[{', '.join(arg_strs)}]"
209-
210-
# For any other types
211-
type_str = str(annotation).replace("typing.", "")
212-
213-
# Clean up ForwardRef representation
214-
if "ForwardRef" in type_str:
215-
import re
216-
match = re.search(r"ForwardRef\('([^']+)'\)", type_str)
217-
if match:
218-
return match.group(1)
13+
age: int
14+
"""Age in years"""
15+
16+
output = BasicModel.format_for_llm()
17+
assert "name (str): The user's name" in output
18+
assert "age (int): Age in years" in output
19+
20+
21+
def test_optional_fields():
    """A field with a default is rendered with an ``optional`` marker."""

    @prompt_schema
    class OptionalFieldsModel(BaseModel):
        required: str
        """Required field"""

        optional: Optional[str] = None
        """Optional field"""

    rendered = OptionalFieldsModel.format_for_llm()

    # The required field has no marker; the defaulted one is flagged optional.
    for expected in ("required (str):", "optional (str, optional):"):
        assert expected in rendered
33+
34+
35+
def test_validation_rules():
    """Constraint text appears only when ``include_validation`` is passed."""

    @prompt_schema
    class ValidationModel(BaseModel):
        name: str = Field(min_length=2, max_length=50)
        """User name"""

        age: int = Field(ge=0, le=120)
        """Age in years"""

    # Default call: no constraint text at all
    plain = ValidationModel.format_for_llm()
    assert "Constraints" not in plain

    # Opt-in call: each field reports its own bounds
    detailed = ValidationModel.format_for_llm(include_validation=True)
    expected_snippets = (
        "Constraints: min_length: 2, max_length: 50",
        "Constraints: ge: 0, le: 120",
    )
    for snippet in expected_snippets:
        assert snippet in detailed
52+
53+
54+
def test_nested_models():
55+
@prompt_schema
56+
class Address(BaseModel):
57+
street: str
58+
"""Street address"""
59+
60+
city: str
61+
"""City name"""
62+
63+
@prompt_schema
64+
class Person(BaseModel):
65+
name: str
66+
"""Person's name"""
67+
68+
addresses: list[Address] = []
69+
"""List of addresses"""
70+
71+
output = Person.format_for_llm()
72+
assert "name (str): Person's name" in output
21973

220-
return type_str
74+
# More flexible assertion that checks for the important parts
75+
assert "addresses (list[Address], optional): List of addresses" in output or (
76+
"addresses (list[" in output and
77+
"Address" in output and
78+
"optional): List of addresses" in output
79+
)

0 commit comments

Comments
 (0)