Skip to content

Commit 7a63185

Browse files
authored
Add support for type in json schema to be a list (#138)
Close #126 --------- Signed-off-by: Benjamin <[email protected]>
1 parent 616e803 commit 7a63185

File tree

3 files changed

+61
-8
lines changed

3 files changed

+61
-8
lines changed

src/error.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ pub enum Error {
5757
ExternalReferencesNotSupported(Box<str>),
5858
#[error("Invalid reference format: {0}")]
5959
InvalidReferenceFormat(Box<str>),
60-
#[error("'type' must be a string")]
61-
TypeMustBeAString,
60+
#[error("'type' must be a string or an array of string")]
61+
TypeMustBeAStringOrArray,
6262
#[error("Unsupported type: {0}")]
6363
UnsupportedType(Box<str>),
6464
#[error("maxLength must be greater than or equal to minLength")]
@@ -80,7 +80,8 @@ impl Error {
8080
#[cfg(feature = "python-bindings")]
8181
impl From<Error> for pyo3::PyErr {
8282
fn from(e: Error) -> Self {
83-
use pyo3::{exceptions::PyValueError, PyErr};
83+
use pyo3::exceptions::PyValueError;
84+
use pyo3::PyErr;
8485
PyErr::new::<PyValueError, _>(e.to_string())
8586
}
8687
}

src/json_schema/mod.rs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@
100100
//! An empty object means unconstrained, allowing any JSON type.
101101
102102
use serde_json::Value;
103+
pub use types::*;
103104

104105
mod parsing;
105106
pub mod types;
@@ -189,9 +190,10 @@ pub fn regex_from_value(json: &Value, whitespace_pattern: Option<&str>) -> Resul
189190

190191
#[cfg(test)]
191192
mod tests {
192-
use super::*;
193193
use regex::Regex;
194194

195+
use super::*;
196+
195197
fn should_match(re: &Regex, value: &str) {
196198
// Asserts that value is fully matched.
197199
match re.find(value) {
@@ -1084,6 +1086,25 @@ mod tests {
10841086
"[email protected]", // multiple errors in domain
10851087
]
10861088
),
1089+
1090+
// ==========================================================
1091+
// Multiple types
1092+
// ==========================================================
1093+
(
1094+
r#"{
1095+
"title": "Foo",
1096+
"type": ["string", "number", "boolean"]
1097+
}"#,
1098+
format!(r#"((?:"{STRING_INNER}*")|(?:{NUMBER})|(?:{BOOLEAN}))"#).as_str(),
1099+
vec!["12.3", "true", r#""a""#],
1100+
vec![
1101+
"null",
1102+
"",
1103+
"12true",
1104+
r#"1.3"a""#,
1105+
r#"12.3true"a""#,
1106+
],
1107+
),
10871108
// Confirm that oneOf doesn't produce illegal lookaround: https://github.com/dottxt-ai/outlines/issues/823
10881109
//
10891110
// The pet field uses the discriminator field to decide which schema (Cat or Dog) applies, based on the pet_type property.

src/json_schema/parsing.rs

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ use serde_json::json;
77
use serde_json::Value;
88

99
use crate::json_schema::types;
10-
use crate::{Error, Result};
10+
use crate::Error;
11+
use crate::Result;
1112

1213
pub(crate) struct Parser<'a> {
1314
root: &'a Value,
@@ -305,9 +306,39 @@ impl<'a> Parser<'a> {
305306
}
306307

307308
fn parse_type(&mut self, obj: &serde_json::Map<String, Value>) -> Result<String> {
308-
let instance_type = obj["type"]
309-
.as_str()
310-
.ok_or_else(|| Error::TypeMustBeAString)?;
309+
match obj.get("type") {
310+
Some(Value::String(instance_type)) => self.parse_type_string(instance_type, obj),
311+
Some(Value::Array(instance_types)) => self.parse_type_array(instance_types, obj),
312+
_ => Err(Error::TypeMustBeAStringOrArray),
313+
}
314+
}
315+
316+
fn parse_type_array(
317+
&mut self,
318+
instance_types: &[serde_json::Value],
319+
obj: &serde_json::Map<String, Value>,
320+
) -> Result<String> {
321+
let xor_patterns = instance_types
322+
.iter()
323+
.map(|instance_type| match instance_type.as_str() {
324+
Some(instance_type) => {
325+
let sub_regex = self.parse_type_string(instance_type, obj)?;
326+
327+
Ok(format!(r"(?:{})", sub_regex))
328+
}
329+
None => Err(Error::TypeMustBeAStringOrArray),
330+
})
331+
.collect::<Result<Vec<String>>>()?
332+
.join("|");
333+
334+
Ok(format!(r"({})", xor_patterns))
335+
}
336+
337+
fn parse_type_string(
338+
&mut self,
339+
instance_type: &str,
340+
obj: &serde_json::Map<String, Value>,
341+
) -> Result<String> {
311342
match instance_type {
312343
"string" => self.parse_string_type(obj),
313344
"number" => self.parse_number_type(obj),

0 commit comments

Comments
 (0)