Skip to content

Commit c9b3d20

Browse files
authored
Logic for instances of subclasses of strings (#294)
* prevent subclasses of str to strict string * strict and lax, add tests * one more test
1 parent 49d77de commit c9b3d20

File tree

5 files changed

+74
-4
lines changed

5 files changed

+74
-4
lines changed

pydantic_core/core_schema.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,6 +1036,7 @@ def json_schema(schema: CoreSchema | None = None, *, ref: str | None = None, ext
10361036
'iterable_type',
10371037
'iteration_error',
10381038
'string_type',
1039+
'string_sub_type',
10391040
'string_unicode',
10401041
'string_too_short',
10411042
'string_too_long',

src/errors/kinds.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ pub enum ErrorKind {
131131
// string errors
132132
#[strum(message = "Input should be a valid string")]
133133
StringType,
134+
#[strum(message = "Input should be a string, not an instance of a subclass of str")]
135+
StringSubType,
134136
#[strum(message = "Input should be a valid string, unable to parse raw data as a unicode string")]
135137
StringUnicode,
136138
#[strum(message = "String should have at least {min_length} characters")]

src/input/input_python.rs

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use pyo3::types::{
1010
};
1111
#[cfg(not(PyPy))]
1212
use pyo3::types::{PyDictItems, PyDictKeys, PyDictValues};
13-
use pyo3::{intern, AsPyPointer};
13+
use pyo3::{intern, AsPyPointer, PyTypeInfo};
1414

1515
use crate::errors::{py_err_string, ErrorKind, InputValue, LocItem, ValError, ValResult};
1616

@@ -142,15 +142,25 @@ impl<'a> Input<'a> for PyAny {
142142

143143
fn strict_str(&'a self) -> ValResult<EitherString<'a>> {
144144
if let Ok(py_str) = self.cast_as::<PyString>() {
145-
Ok(py_str.into())
145+
if is_builtin_str(py_str) {
146+
Ok(py_str.into())
147+
} else {
148+
Err(ValError::new(ErrorKind::StringSubType, self))
149+
}
146150
} else {
147151
Err(ValError::new(ErrorKind::StringType, self))
148152
}
149153
}
150154

151155
fn lax_str(&'a self) -> ValResult<EitherString<'a>> {
152156
if let Ok(py_str) = self.cast_as::<PyString>() {
153-
Ok(py_str.into())
157+
if is_builtin_str(py_str) {
158+
Ok(py_str.into())
159+
} else {
160+
// force to a rust string to make sure behaviour is consistent whether or not we go via a
161+
// rust string in StrConstrainedValidator - e.g. to_lower
162+
Ok(py_string_str(py_str)?.into())
163+
}
154164
} else if let Ok(bytes) = self.cast_as::<PyBytes>() {
155165
let str = match from_utf8(bytes.as_bytes()) {
156166
Ok(s) => s,
@@ -672,3 +682,7 @@ fn import_type(py: Python, module: &str, attr: &str) -> PyResult<Py<PyType>> {
672682
let obj = py.import(module)?.getattr(attr)?;
673683
Ok(obj.cast_as::<PyType>()?.into())
674684
}
685+
686+
fn is_builtin_str(py_str: &PyString) -> bool {
687+
py_str.get_type().is(PyString::type_object(py_str.py()))
688+
}

tests/benchmarks/test_micro_benchmarks.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,12 +212,24 @@ class MyCoreModel:
212212

213213
@pytest.mark.benchmark(group='string')
214214
def test_core_string_lax(benchmark):
215-
validator = SchemaValidator({'type': 'str'})
215+
validator = SchemaValidator(core_schema.string_schema())
216216
input_str = 'Hello ' * 20
217217

218+
assert validator.validate_python(input_str) == input_str
219+
218220
benchmark(validator.validate_python, input_str)
219221

220222

223+
@pytest.mark.benchmark(group='string')
def test_core_string_strict(benchmark):
    """Benchmark strict-mode string validation of a plain `str` input."""
    validator = SchemaValidator(core_schema.string_schema(strict=True))
    input_str = 'Hello ' * 20

    # Sanity check: the validator returns the input unchanged before we time it.
    assert validator.validate_python(input_str) == input_str

    # Benchmark the same input that was just verified (previously the literal 'foo'
    # was timed), so the result is comparable with test_core_string_lax.
    benchmark(validator.validate_python, input_str)
231+
232+
221233
@pytest.fixture
222234
def recursive_model_data():
223235
data = {'width': -1}

tests/validators/test_string.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,44 @@ def test_regex_error():
186186
def test_default_validator():
    """Check the repr of a validator built from a strict string schema plus a config dict."""
    schema = core_schema.string_schema(strict=True, to_lower=False)
    config = {'str_strip_whitespace': False}
    v = SchemaValidator(schema, config)
    expected_repr = 'SchemaValidator(name="str",validator=Str(StrValidator{strict:true}),slots=[])'
    assert plain_repr(v) == expected_repr
189+
190+
191+
@pytest.fixture(scope='session', name='FruitEnum')
def fruit_enum_fixture():
    """Session-scoped fixture returning an enum class whose members are `str` subclass instances."""
    import enum

    class FruitEnum(str, enum.Enum):
        pear = 'pear'
        banana = 'banana'

    return FruitEnum
200+
201+
202+
@pytest.mark.parametrize('kwargs', [{}, {'to_lower': True}], ids=repr)
def test_strict_subclass(FruitEnum, kwargs):
    """Strict string validation accepts a plain `str` but rejects bytes and `str` subclass instances."""
    schema = core_schema.string_schema(strict=True, **kwargs)
    v = SchemaValidator(schema)

    # A plain str passes through unchanged.
    assert v.validate_python('foobar') == 'foobar'

    # Bytes are not a string at all in strict mode.
    with pytest.raises(ValidationError, match='kind=string_type,'):
        v.validate_python(b'foobar')

    # An enum member that subclasses str gets the dedicated sub-type error.
    with pytest.raises(ValidationError, match='kind=string_sub_type,') as exc_info:
        v.validate_python(FruitEnum.pear)

    # insert_assert(exc_info.value.errors())
    expected_errors = [
        {
            'kind': 'string_sub_type',
            'loc': [],
            'message': 'Input should be a string, not an instance of a subclass of str',
            'input_value': FruitEnum.pear,
        }
    ]
    assert exc_info.value.errors() == expected_errors
219+
220+
221+
@pytest.mark.parametrize('kwargs', [{}, {'to_lower': True}], ids=repr)
def test_lax_subclass(FruitEnum, kwargs):
    """Lax string validation accepts str, bytes and `str` subclass instances, returning a plain `str`."""
    schema = core_schema.string_schema(**kwargs)
    v = SchemaValidator(schema)

    assert v.validate_python('foobar') == 'foobar'
    assert v.validate_python(b'foobar') == 'foobar'

    # The enum member must come back as an exact builtin str, not the subclass instance.
    result = v.validate_python(FruitEnum.pear)
    assert result == 'pear'
    assert type(result) is str
    assert repr(result) == "'pear'"

0 commit comments

Comments
 (0)