11// Validator for things inside of a typing.Literal[]
22// which can be an int, a string, bytes or an Enum value (including `class Foo(str, Enum)` type enums)
33use core:: fmt:: Debug ;
4- use std:: cmp:: Ordering ;
54
65use pyo3:: prelude:: * ;
76use pyo3:: types:: { PyDict , PyInt , PyList } ;
@@ -35,7 +34,7 @@ pub struct LiteralLookup<T: Debug> {
3534 // Catch all for hashable types like Enum and bytes (the latter only because it is seldom used)
3635 expected_py_dict : Option < Py < PyDict > > ,
3736 // Catch all for unhashable types like list
38- expected_py_list : Option < Py < PyList > > ,
37+ expected_py_values : Option < Vec < ( Py < PyAny > , usize ) > > ,
3938
4039 pub values : Vec < T > ,
4140}
@@ -46,7 +45,7 @@ impl<T: Debug> LiteralLookup<T> {
4645 let mut expected_int = AHashMap :: new ( ) ;
4746 let mut expected_str: AHashMap < String , usize > = AHashMap :: new ( ) ;
4847 let expected_py_dict = PyDict :: new_bound ( py) ;
49- let expected_py_list = PyList :: empty_bound ( py ) ;
48+ let mut expected_py_values = Vec :: new ( ) ;
5049 let mut values = Vec :: new ( ) ;
5150 for ( k, v) in expected {
5251 let id = values. len ( ) ;
@@ -71,7 +70,7 @@ impl<T: Debug> LiteralLookup<T> {
7170 . map_err ( |_| py_schema_error_type ! ( "error extracting str {:?}" , k) ) ?;
7271 expected_str. insert ( str. to_string ( ) , id) ;
7372 } else if expected_py_dict. set_item ( & k, id) . is_err ( ) {
74- expected_py_list . append ( ( & k , id) ) ? ;
73+ expected_py_values . push ( ( k . as_unbound ( ) . clone_ref ( py ) , id) ) ;
7574 }
7675 }
7776
@@ -92,9 +91,9 @@ impl<T: Debug> LiteralLookup<T> {
9291 true => None ,
9392 false => Some ( expected_py_dict. into ( ) ) ,
9493 } ,
95- expected_py_list : match expected_py_list . is_empty ( ) {
94+ expected_py_values : match expected_py_values . is_empty ( ) {
9695 true => None ,
97- false => Some ( expected_py_list . into ( ) ) ,
96+ false => Some ( expected_py_values ) ,
9897 } ,
9998 values,
10099 } )
@@ -143,23 +142,23 @@ impl<T: Debug> LiteralLookup<T> {
143142 }
144143 }
145144 }
145+ // cache py_input if needed, since we might need it for multiple lookups
146+ let mut py_input = None ;
146147 if let Some ( expected_py_dict) = & self . expected_py_dict {
148+ let py_input = py_input. get_or_insert_with ( || input. to_object ( py) ) ;
147149 // We don't use ? to unpack the result of `get_item` in the next line because unhashable
148150 // inputs will produce a TypeError, which in this case we just want to treat equivalently
149151 // to a failed lookup
150- if let Ok ( Some ( v) ) = expected_py_dict. bind ( py) . get_item ( input ) {
152+ if let Ok ( Some ( v) ) = expected_py_dict. bind ( py) . get_item ( & * py_input ) {
151153 let id: usize = v. extract ( ) . unwrap ( ) ;
152154 return Ok ( Some ( ( input, & self . values [ id] ) ) ) ;
153155 }
154156 } ;
155- if let Some ( expected_py_list) = & self . expected_py_list {
156- for item in expected_py_list. bind ( py) {
157- let ( k, id) : ( Bound < PyAny > , usize ) = item. extract ( ) ?;
158- if k. compare ( input. to_object ( py) . bind ( py) )
159- . unwrap_or ( Ordering :: Less )
160- . is_eq ( )
161- {
162- return Ok ( Some ( ( input, & self . values [ id] ) ) ) ;
157+ if let Some ( expected_py_values) = & self . expected_py_values {
158+ let py_input = py_input. get_or_insert_with ( || input. to_object ( py) ) ;
159+ for ( k, id) in expected_py_values {
160+ if k. bind ( py) . eq ( & * py_input) . unwrap_or ( false ) {
161+ return Ok ( Some ( ( input, & self . values [ * id] ) ) ) ;
163162 }
164163 }
165164 } ;
0 commit comments