Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 164 additions & 76 deletions src/base/json_schema.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
use super::schema;
use crate::utils::immutable::RefList;

use super::{schema, spec::FieldName};
use anyhow::Result;
use indexmap::IndexMap;
use schemars::schema::{
ArrayValidation, InstanceType, Metadata, ObjectValidation, Schema, SchemaObject, SingleOrVec,
ArrayValidation, InstanceType, ObjectValidation, Schema, SchemaObject, SingleOrVec,
};
use std::fmt::Write;

pub struct ToJsonSchemaOptions {
/// If true, mark all fields as required.
Expand All @@ -11,16 +16,47 @@ pub struct ToJsonSchemaOptions {

/// If true, the JSON schema supports the `format` keyword.
pub supports_format: bool,

/// If true, extract descriptions to a separate extra instruction.
pub extract_descriptions: bool,
}

pub trait ToJsonSchema {
fn to_json_schema(&self, options: &ToJsonSchemaOptions) -> SchemaObject;
struct JsonSchemaBuilder {
options: ToJsonSchemaOptions,
extra_instructions_per_field: IndexMap<String, String>,
}

impl ToJsonSchema for schema::BasicValueType {
fn to_json_schema(&self, options: &ToJsonSchemaOptions) -> SchemaObject {
impl JsonSchemaBuilder {
fn new(options: ToJsonSchemaOptions) -> Self {
Self {
options,
extra_instructions_per_field: IndexMap::new(),
}
}

fn set_description(
&mut self,
schema: &mut SchemaObject,
description: impl ToString,
field_path: RefList<'_, &'_ FieldName>,
) {
if self.options.extract_descriptions {
let mut fields: Vec<_> = field_path.iter().map(|f| f.as_str()).collect();
fields.reverse();
self.extra_instructions_per_field
.insert(fields.join("."), description.to_string());
} else {
schema.metadata.get_or_insert_default().description = Some(description.to_string());
}
}

fn for_basic_value_type(
&mut self,
basic_type: &schema::BasicValueType,
field_path: RefList<'_, &'_ FieldName>,
) -> SchemaObject {
let mut schema = SchemaObject::default();
match self {
match basic_type {
schema::BasicValueType::Str => {
schema.instance_type = Some(SingleOrVec::Single(Box::new(InstanceType::String)));
}
Expand Down Expand Up @@ -52,51 +88,66 @@ impl ToJsonSchema for schema::BasicValueType {
max_items: Some(2),
..Default::default()
}));
schema.metadata.get_or_insert_default().description =
Some("A range, start pos (inclusive), end pos (exclusive).".to_string());
self.set_description(
&mut schema,
"A range represented by a list of two positions, start pos (inclusive), end pos (exclusive).",
field_path,
);
}
schema::BasicValueType::Uuid => {
schema.instance_type = Some(SingleOrVec::Single(Box::new(InstanceType::String)));
if options.supports_format {
if self.options.supports_format {
schema.format = Some("uuid".to_string());
} else {
schema.metadata.get_or_insert_default().description =
Some("A UUID, e.g. 123e4567-e89b-12d3-a456-426614174000".to_string());
}
self.set_description(
&mut schema,
"A UUID, e.g. 123e4567-e89b-12d3-a456-426614174000",
field_path,
);
}
schema::BasicValueType::Date => {
schema.instance_type = Some(SingleOrVec::Single(Box::new(InstanceType::String)));
if options.supports_format {
if self.options.supports_format {
schema.format = Some("date".to_string());
} else {
schema.metadata.get_or_insert_default().description =
Some("A date, e.g. 2025-03-27".to_string());
}
self.set_description(
&mut schema,
"A date in YYYY-MM-DD format, e.g. 2025-03-27",
field_path,
);
}
schema::BasicValueType::Time => {
schema.instance_type = Some(SingleOrVec::Single(Box::new(InstanceType::String)));
if options.supports_format {
if self.options.supports_format {
schema.format = Some("time".to_string());
} else {
schema.metadata.get_or_insert_default().description =
Some("A time, e.g. 13:32:12".to_string());
}
self.set_description(
&mut schema,
"A time in HH:MM:SS format, e.g. 13:32:12",
field_path,
);
}
schema::BasicValueType::LocalDateTime => {
schema.instance_type = Some(SingleOrVec::Single(Box::new(InstanceType::String)));
if options.supports_format {
if self.options.supports_format {
schema.format = Some("date-time".to_string());
}
schema.metadata.get_or_insert_default().description =
Some("Date time without timezone offset, e.g. 2025-03-27T13:32:12".to_string());
self.set_description(
&mut schema,
"Date time without timezone offset in YYYY-MM-DDTHH:MM:SS format, e.g. 2025-03-27T13:32:12",
field_path,
);
}
schema::BasicValueType::OffsetDateTime => {
schema.instance_type = Some(SingleOrVec::Single(Box::new(InstanceType::String)));
if options.supports_format {
if self.options.supports_format {
schema.format = Some("date-time".to_string());
}
schema.metadata.get_or_insert_default().description =
Some("Date time with timezone offset in RFC3339, e.g. 2025-03-27T13:32:12Z, 2025-03-27T07:32:12.313-06:00".to_string());
self.set_description(
&mut schema,
"Date time with timezone offset in RFC3339, e.g. 2025-03-27T13:32:12Z, 2025-03-27T07:32:12.313-06:00",
field_path,
);
}
schema::BasicValueType::Json => {
// Can be any value. No type constraint.
Expand All @@ -105,7 +156,8 @@ impl ToJsonSchema for schema::BasicValueType {
schema.instance_type = Some(SingleOrVec::Single(Box::new(InstanceType::Array)));
schema.array = Some(Box::new(ArrayValidation {
items: Some(SingleOrVec::Single(Box::new(
s.element_type.to_json_schema(options).into(),
self.for_basic_value_type(&s.element_type, field_path)
.into(),
))),
min_items: s.dimension.and_then(|d| u32::try_from(d).ok()),
max_items: s.dimension.and_then(|d| u32::try_from(d).ok()),
Expand All @@ -115,70 +167,106 @@ impl ToJsonSchema for schema::BasicValueType {
}
schema
}
}

impl ToJsonSchema for schema::StructSchema {
fn to_json_schema(&self, options: &ToJsonSchemaOptions) -> SchemaObject {
SchemaObject {
metadata: Some(Box::new(Metadata {
description: self.description.as_ref().map(|s| s.to_string()),
..Default::default()
})),
instance_type: Some(SingleOrVec::Single(Box::new(InstanceType::Object))),
object: Some(Box::new(ObjectValidation {
properties: self
.fields
.iter()
.map(|f| {
let mut schema = f.value_type.to_json_schema(options);
if options.fields_always_required && f.value_type.nullable {
if let Some(instance_type) = &mut schema.instance_type {
let mut types = match instance_type {
SingleOrVec::Single(t) => vec![**t],
SingleOrVec::Vec(t) => std::mem::take(t),
};
types.push(InstanceType::Null);
*instance_type = SingleOrVec::Vec(types);
}
fn for_struct_schema(
&mut self,
struct_schema: &schema::StructSchema,
field_path: RefList<'_, &'_ FieldName>,
) -> SchemaObject {
let mut schema = SchemaObject::default();
if let Some(description) = &struct_schema.description {
self.set_description(&mut schema, description, field_path);
}
schema.instance_type = Some(SingleOrVec::Single(Box::new(InstanceType::Object)));
schema.object = Some(Box::new(ObjectValidation {
properties: struct_schema
.fields
.iter()
.map(|f| {
let mut schema =
self.for_enriched_value_type(&f.value_type, field_path.prepend(&f.name));
if self.options.fields_always_required && f.value_type.nullable {
if let Some(instance_type) = &mut schema.instance_type {
let mut types = match instance_type {
SingleOrVec::Single(t) => vec![**t],
SingleOrVec::Vec(t) => std::mem::take(t),
};
types.push(InstanceType::Null);
*instance_type = SingleOrVec::Vec(types);
}
(f.name.to_string(), schema.into())
})
.collect(),
required: self
.fields
.iter()
.filter(|&f| (options.fields_always_required || !f.value_type.nullable))
.map(|f| f.name.to_string())
.collect(),
additional_properties: Some(Schema::Bool(false).into()),
..Default::default()
})),
}
(f.name.to_string(), schema.into())
})
.collect(),
required: struct_schema
.fields
.iter()
.filter(|&f| (self.options.fields_always_required || !f.value_type.nullable))
.map(|f| f.name.to_string())
.collect(),
additional_properties: Some(Schema::Bool(false).into()),
..Default::default()
}
}));
schema
}
}

impl ToJsonSchema for schema::ValueType {
fn to_json_schema(&self, options: &ToJsonSchemaOptions) -> SchemaObject {
match self {
schema::ValueType::Basic(b) => b.to_json_schema(options),
schema::ValueType::Struct(s) => s.to_json_schema(options),
fn for_value_type(
&mut self,
value_type: &schema::ValueType,
field_path: RefList<'_, &'_ FieldName>,
) -> SchemaObject {
match value_type {
schema::ValueType::Basic(b) => self.for_basic_value_type(b, field_path),
schema::ValueType::Struct(s) => self.for_struct_schema(s, field_path),
schema::ValueType::Collection(c) => SchemaObject {
instance_type: Some(SingleOrVec::Single(Box::new(InstanceType::Array))),
array: Some(Box::new(ArrayValidation {
items: Some(SingleOrVec::Single(Box::new(
c.row.to_json_schema(options).into(),
self.for_struct_schema(&c.row, field_path).into(),
))),
..Default::default()
})),
..Default::default()
},
}
}
}

impl ToJsonSchema for schema::EnrichedValueType {
fn to_json_schema(&self, options: &ToJsonSchemaOptions) -> SchemaObject {
self.typ.to_json_schema(options)
fn for_enriched_value_type(
&mut self,
enriched_value_type: &schema::EnrichedValueType,
field_path: RefList<'_, &'_ FieldName>,
) -> SchemaObject {
self.for_value_type(&enriched_value_type.typ, field_path)
}

fn build_extra_instructions(&self) -> Result<Option<String>> {
if self.extra_instructions_per_field.is_empty() {
return Ok(None);
}

let mut instructions = String::new();
write!(&mut instructions, "Instructions for specific fields:\n\n")?;
for (field_path, instruction) in self.extra_instructions_per_field.iter() {
write!(
&mut instructions,
"- {}: {}\n\n",
if field_path.is_empty() {
"(root object)"
} else {
field_path.as_str()
},
instruction
)?;
}
Ok(Some(instructions))
}
}

pub fn build_json_schema(
value_type: &schema::EnrichedValueType,
options: ToJsonSchemaOptions,
) -> Result<(SchemaObject, Option<String>)> {
let mut builder = JsonSchemaBuilder::new(options);
let schema = builder.for_enriched_value_type(value_type, RefList::Nil);
Ok((schema, builder.build_extra_instructions()?))
}
4 changes: 2 additions & 2 deletions src/base/value.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{api_bail, api_error};

use super::schema::*;
use anyhow::{Context, Result};
use anyhow::Result;
use base64::prelude::*;
use chrono::Offset;
use log::warn;
Expand All @@ -10,7 +10,7 @@ use serde::{
ser::{SerializeMap, SerializeSeq, SerializeTuple},
Deserialize, Serialize,
};
use std::{collections::BTreeMap, ops::Deref, str::FromStr, sync::Arc};
use std::{collections::BTreeMap, ops::Deref, sync::Arc};

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct RangeValue {
Expand Down
19 changes: 1 addition & 18 deletions src/llm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,24 +47,7 @@ pub trait LlmGenerationClient: Send + Sync {
request: LlmGenerateRequest<'req>,
) -> Result<LlmGenerateResponse>;

/// If true, the LLM only accepts a JSON schema with all fields required.
/// This is a limitation of LLM models such as OpenAI.
/// Otherwise, the LLM will accept a JSON schema with optional fields.
fn json_schema_fields_always_required(&self) -> bool {
false
}

/// If true, the LLM supports the `format` keyword in the JSON schema.
fn json_schema_supports_format(&self) -> bool {
true
}

fn to_json_schema_options(&self) -> ToJsonSchemaOptions {
ToJsonSchemaOptions {
fields_always_required: self.json_schema_fields_always_required(),
supports_format: self.json_schema_supports_format(),
}
}
fn json_schema_options(&self) -> ToJsonSchemaOptions;
}

mod ollama;
Expand Down
8 changes: 8 additions & 0 deletions src/llm/ollama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,12 @@ impl LlmGenerationClient for Client {
text: json.response,
})
}

fn json_schema_options(&self) -> super::ToJsonSchemaOptions {
super::ToJsonSchemaOptions {
fields_always_required: false,
supports_format: true,
extract_descriptions: true,
}
}
}
12 changes: 6 additions & 6 deletions src/llm/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,11 @@ impl LlmGenerationClient for Client {
Ok(super::LlmGenerateResponse { text })
}

fn json_schema_fields_always_required(&self) -> bool {
true
}

fn json_schema_supports_format(&self) -> bool {
false
fn json_schema_options(&self) -> super::ToJsonSchemaOptions {
super::ToJsonSchemaOptions {
fields_always_required: true,
supports_format: false,
extract_descriptions: false,
}
}
}
Loading