From b01ebd66e47e8ed2b1f5f9d4c584fff6400d7d8b Mon Sep 17 00:00:00 2001 From: LJ Date: Tue, 25 Mar 2025 19:41:37 -0700 Subject: [PATCH] For OpenAI, make all fields required in JSON schema. We use union type (with `null`) to mark fields optional. Background: https://platform.openai.com/docs/guides/structured-outputs/supported-schemas#all-fields-must-be-required --- src/base/json_schema.rs | 46 +++++++++++++++++++++-------- src/llm/mod.rs | 15 ++++++++++ src/llm/openai.rs | 4 +++ src/ops/functions/extract_by_llm.rs | 8 +++-- 4 files changed, 59 insertions(+), 14 deletions(-) diff --git a/src/base/json_schema.rs b/src/base/json_schema.rs index 5a0840dfd..2a1124f78 100644 --- a/src/base/json_schema.rs +++ b/src/base/json_schema.rs @@ -3,12 +3,19 @@ use schemars::schema::{ ArrayValidation, InstanceType, Metadata, ObjectValidation, Schema, SchemaObject, SingleOrVec, }; +pub struct ToJsonSchemaOptions { + /// If true, mark all fields as required. + /// Use union type (with `null`) for optional fields instead. + /// Models like OpenAI will reject the schema if a field is not required. + pub fields_always_required: bool, +} + pub trait ToJsonSchema { - fn to_json_schema(&self) -> SchemaObject; + fn to_json_schema(&self, options: &ToJsonSchemaOptions) -> SchemaObject; } impl ToJsonSchema for schema::BasicValueType { - fn to_json_schema(&self) -> SchemaObject { + fn to_json_schema(&self, options: &ToJsonSchemaOptions) -> SchemaObject { let mut schema = SchemaObject::default(); match self { schema::BasicValueType::Str => { @@ -59,7 +66,7 @@ impl ToJsonSchema for schema::BasicValueType { schema.instance_type = Some(SingleOrVec::Single(Box::new(InstanceType::Array))); schema.array = Some(Box::new(ArrayValidation { items: Some(SingleOrVec::Single(Box::new( - s.element_type.to_json_schema().into(), + s.element_type.to_json_schema(options).into(), ))), min_items: s.dimension.and_then(|d| u32::try_from(d).ok()), max_items: s.dimension.and_then(|d| u32::try_from(d).ok()), @@ -72,7 +79,7 @@ impl ToJsonSchema for schema::BasicValueType { } impl ToJsonSchema for schema::StructSchema { - fn to_json_schema(&self) -> SchemaObject { + fn to_json_schema(&self, options: &ToJsonSchemaOptions) -> SchemaObject { SchemaObject { metadata: Some(Box::new(Metadata { description: self.description.as_ref().map(|s| s.to_string()), @@ -83,12 +90,25 @@ impl ToJsonSchema for schema::StructSchema { properties: self .fields .iter() - .map(|f| (f.name.to_string(), f.value_type.to_json_schema().into())) + .map(|f| { + let mut schema = f.value_type.to_json_schema(options); + if options.fields_always_required && f.value_type.nullable { + if let Some(instance_type) = &mut schema.instance_type { + let mut types = match instance_type { + SingleOrVec::Single(t) => vec![**t], + SingleOrVec::Vec(t) => std::mem::take(t), + }; + types.push(InstanceType::Null); + *instance_type = SingleOrVec::Vec(types); + } + } + (f.name.to_string(), schema.into()) + }) .collect(), required: self .fields .iter() - .filter(|&f| (!f.value_type.nullable)) + .filter(|&f| (options.fields_always_required || !f.value_type.nullable)) .map(|f| f.name.to_string()) .collect(), additional_properties: Some(Schema::Bool(false).into()), @@ -100,14 +120,16 @@ impl ToJsonSchema for schema::StructSchema { } impl ToJsonSchema for schema::ValueType { - fn to_json_schema(&self) -> SchemaObject { + fn to_json_schema(&self, options: &ToJsonSchemaOptions) -> SchemaObject { match self { - schema::ValueType::Basic(b) => b.to_json_schema(), - schema::ValueType::Struct(s) => s.to_json_schema(), + schema::ValueType::Basic(b) => b.to_json_schema(options), + schema::ValueType::Struct(s) => s.to_json_schema(options), schema::ValueType::Collection(c) => SchemaObject { instance_type: Some(SingleOrVec::Single(Box::new(InstanceType::Array))), array: Some(Box::new(ArrayValidation { - items: Some(SingleOrVec::Single(Box::new(c.row.to_json_schema().into()))), + items: Some(SingleOrVec::Single(Box::new( + c.row.to_json_schema(options).into(), + ))), ..Default::default() })), ..Default::default() @@ -117,7 +139,7 @@ impl ToJsonSchema for schema::ValueType { } impl ToJsonSchema for schema::EnrichedValueType { - fn to_json_schema(&self) -> SchemaObject { - self.typ.to_json_schema() + fn to_json_schema(&self, options: &ToJsonSchemaOptions) -> SchemaObject { + self.typ.to_json_schema(options) } } diff --git a/src/llm/mod.rs b/src/llm/mod.rs index 7c0b2b324..39cc71a84 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -5,6 +5,8 @@ use async_trait::async_trait; use schemars::schema::SchemaObject; use serde::{Deserialize, Serialize}; +use crate::base::json_schema::ToJsonSchemaOptions; + #[derive(Debug, Clone, Serialize, Deserialize)] pub enum LlmApiType { Ollama, @@ -44,6 +46,19 @@ pub trait LlmGenerationClient: Send + Sync { &self, request: LlmGenerateRequest<'req>, ) -> Result; + + /// If true, the LLM only accepts a JSON schema with all fields required. + /// This is a limitation of LLM models such as OpenAI. + /// Otherwise, the LLM will accept a JSON schema with optional fields. + fn json_schema_fields_always_required(&self) -> bool { + false + } + + fn to_json_schema_options(&self) -> ToJsonSchemaOptions { + ToJsonSchemaOptions { + fields_always_required: self.json_schema_fields_always_required(), + } + } } mod ollama; diff --git a/src/llm/openai.rs b/src/llm/openai.rs index 8df7f4406..47c82932d 100644 --- a/src/llm/openai.rs +++ b/src/llm/openai.rs @@ -97,4 +97,8 @@ impl LlmGenerationClient for Client { Ok(super::LlmGenerateResponse { text }) } + + fn json_schema_fields_always_required(&self) -> bool { + true + } } diff --git a/src/ops/functions/extract_by_llm.rs b/src/ops/functions/extract_by_llm.rs index 35a06474d..765a7508f 100644 --- a/src/ops/functions/extract_by_llm.rs +++ b/src/ops/functions/extract_by_llm.rs @@ -47,10 +47,14 @@ Output only the JSON without any additional messages or explanations." impl Executor { async fn new(spec: Spec, args: Args) -> Result { + let client = new_llm_generation_client(spec.llm_spec).await?; + let output_json_schema = spec + .output_type + .to_json_schema(&client.to_json_schema_options()); Ok(Self { args, - client: new_llm_generation_client(spec.llm_spec).await?, - output_json_schema: spec.output_type.to_json_schema(), + client, + output_json_schema, output_type: spec.output_type, system_prompt: get_system_prompt(&spec.instruction), })