Skip to content

Commit 203b395

Browse files
Fix validation of Literal from JSON keys when used as dict key (#1075)
Co-authored-by: David Montague <[email protected]>
1 parent 3fea833 commit 203b395

File tree

2 files changed

+45
-4
lines changed

2 files changed

+45
-4
lines changed

src/validators/literal.rs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use pyo3::{intern, PyTraverseError, PyVisit};
99

1010
use crate::build_tools::{py_schema_err, py_schema_error_type};
1111
use crate::errors::{ErrorType, ValError, ValResult};
12-
use crate::input::Input;
12+
use crate::input::{Input, ValidationMatch};
1313
use crate::py_gc::PyGcTraverse;
1414
use crate::tools::SchemaDict;
1515

@@ -116,8 +116,18 @@ impl<T: Debug> LiteralLookup<T> {
116116
}
117117
}
118118
if let Some(expected_strings) = &self.expected_str {
119-
// dbg!(expected_strings);
120-
if let Ok(either_str) = input.exact_str() {
119+
let validation_result = if input.is_python() {
120+
input.exact_str()
121+
} else {
122+
// Strings coming from JSON are treated as "strict" but not "exact" for reasons
123+
// of parsing types like UUID; see the implementation of `validate_str` for Json
124+
// inputs for justification. We might change that eventually, but for now we need
125+
// to work around this when loading from JSON
126+
// V3 TODO: revisit making this "exact" for JSON inputs
127+
input.validate_str(true, false).map(ValidationMatch::into_inner)
128+
};
129+
130+
if let Ok(either_str) = validation_result {
121131
let cow = either_str.as_cow()?;
122132
if let Some(id) = expected_strings.get(cow.as_ref()) {
123133
return Ok(Some((input, &self.values[*id])));

tests/test.rs

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#[cfg(test)]
22
mod tests {
3-
use _pydantic_core::SchemaSerializer;
3+
use _pydantic_core::{SchemaSerializer, SchemaValidator};
44
use pyo3::prelude::*;
55
use pyo3::types::PyDict;
66

@@ -86,4 +86,35 @@ a = A()
8686
assert_eq!(serialized, b"{\"b\":\"b\"}");
8787
});
8888
}
89+
90+
#[test]
91+
fn test_literal_schema() {
92+
Python::with_gil(|py| {
93+
let code = r#"
94+
schema = {
95+
"type": "dict",
96+
"keys_schema": {
97+
"type": "literal",
98+
"expected": ["a", "b"],
99+
},
100+
"values_schema": {
101+
"type": "str",
102+
},
103+
"strict": False,
104+
}
105+
json_input = '{"a": "something"}'
106+
"#;
107+
let locals = PyDict::new(py);
108+
py.run(code, None, Some(locals)).unwrap();
109+
let schema: &PyDict = locals.get_item("schema").unwrap().unwrap().extract().unwrap();
110+
let json_input: &PyAny = locals.get_item("json_input").unwrap().unwrap().extract().unwrap();
111+
let binding = SchemaValidator::py_new(py, schema, None)
112+
.unwrap()
113+
.validate_json(py, json_input, None, None, None)
114+
.unwrap();
115+
let validation_result: &PyAny = binding.extract(py).unwrap();
116+
let repr = format!("{}", validation_result.repr().unwrap());
117+
assert_eq!(repr, "{'a': 'something'}");
118+
});
119+
}
89120
}

0 commit comments

Comments
 (0)