diff --git a/src/base/json_schema.rs b/src/base/json_schema.rs index 66b834666..874d7087a 100644 --- a/src/base/json_schema.rs +++ b/src/base/json_schema.rs @@ -1,7 +1,12 @@ -use super::schema; +use crate::utils::immutable::RefList; + +use super::{schema, spec::FieldName}; +use anyhow::Result; +use indexmap::IndexMap; use schemars::schema::{ - ArrayValidation, InstanceType, Metadata, ObjectValidation, Schema, SchemaObject, SingleOrVec, + ArrayValidation, InstanceType, ObjectValidation, Schema, SchemaObject, SingleOrVec, }; +use std::fmt::Write; pub struct ToJsonSchemaOptions { /// If true, mark all fields as required. @@ -11,16 +16,47 @@ pub struct ToJsonSchemaOptions { /// If true, the JSON schema supports the `format` keyword. pub supports_format: bool, + + /// If true, extract descriptions to a separate extra instruction. + pub extract_descriptions: bool, } -pub trait ToJsonSchema { - fn to_json_schema(&self, options: &ToJsonSchemaOptions) -> SchemaObject; +struct JsonSchemaBuilder { + options: ToJsonSchemaOptions, + extra_instructions_per_field: IndexMap, } -impl ToJsonSchema for schema::BasicValueType { - fn to_json_schema(&self, options: &ToJsonSchemaOptions) -> SchemaObject { +impl JsonSchemaBuilder { + fn new(options: ToJsonSchemaOptions) -> Self { + Self { + options, + extra_instructions_per_field: IndexMap::new(), + } + } + + fn set_description( + &mut self, + schema: &mut SchemaObject, + description: impl ToString, + field_path: RefList<'_, &'_ FieldName>, + ) { + if self.options.extract_descriptions { + let mut fields: Vec<_> = field_path.iter().map(|f| f.as_str()).collect(); + fields.reverse(); + self.extra_instructions_per_field + .insert(fields.join("."), description.to_string()); + } else { + schema.metadata.get_or_insert_default().description = Some(description.to_string()); + } + } + + fn for_basic_value_type( + &mut self, + basic_type: &schema::BasicValueType, + field_path: RefList<'_, &'_ FieldName>, + ) -> SchemaObject { let mut schema = SchemaObject::default(); - match self { + match basic_type { schema::BasicValueType::Str => { schema.instance_type = Some(SingleOrVec::Single(Box::new(InstanceType::String))); } @@ -52,51 +88,66 @@ impl ToJsonSchema for schema::BasicValueType { max_items: Some(2), ..Default::default() })); - schema.metadata.get_or_insert_default().description = - Some("A range, start pos (inclusive), end pos (exclusive).".to_string()); + self.set_description( + &mut schema, + "A range represented by a list of two positions, start pos (inclusive), end pos (exclusive).", + field_path, + ); } schema::BasicValueType::Uuid => { schema.instance_type = Some(SingleOrVec::Single(Box::new(InstanceType::String))); - if options.supports_format { + if self.options.supports_format { schema.format = Some("uuid".to_string()); - } else { - schema.metadata.get_or_insert_default().description = - Some("A UUID, e.g. 123e4567-e89b-12d3-a456-426614174000".to_string()); } + self.set_description( + &mut schema, + "A UUID, e.g. 123e4567-e89b-12d3-a456-426614174000", + field_path, + ); } schema::BasicValueType::Date => { schema.instance_type = Some(SingleOrVec::Single(Box::new(InstanceType::String))); - if options.supports_format { + if self.options.supports_format { schema.format = Some("date".to_string()); - } else { - schema.metadata.get_or_insert_default().description = - Some("A date, e.g. 2025-03-27".to_string()); } + self.set_description( + &mut schema, + "A date in YYYY-MM-DD format, e.g. 2025-03-27", + field_path, + ); } schema::BasicValueType::Time => { schema.instance_type = Some(SingleOrVec::Single(Box::new(InstanceType::String))); - if options.supports_format { + if self.options.supports_format { schema.format = Some("time".to_string()); - } else { - schema.metadata.get_or_insert_default().description = - Some("A time, e.g. 13:32:12".to_string()); } + self.set_description( + &mut schema, + "A time in HH:MM:SS format, e.g. 13:32:12", + field_path, + ); } schema::BasicValueType::LocalDateTime => { schema.instance_type = Some(SingleOrVec::Single(Box::new(InstanceType::String))); - if options.supports_format { + if self.options.supports_format { schema.format = Some("date-time".to_string()); } - schema.metadata.get_or_insert_default().description = - Some("Date time without timezone offset, e.g. 2025-03-27T13:32:12".to_string()); + self.set_description( + &mut schema, + "Date time without timezone offset in YYYY-MM-DDTHH:MM:SS format, e.g. 2025-03-27T13:32:12", + field_path, + ); } schema::BasicValueType::OffsetDateTime => { schema.instance_type = Some(SingleOrVec::Single(Box::new(InstanceType::String))); - if options.supports_format { + if self.options.supports_format { schema.format = Some("date-time".to_string()); } - schema.metadata.get_or_insert_default().description = - Some("Date time with timezone offset in RFC3339, e.g. 2025-03-27T13:32:12Z, 2025-03-27T07:32:12.313-06:00".to_string()); + self.set_description( + &mut schema, + "Date time with timezone offset in RFC3339, e.g. 2025-03-27T13:32:12Z, 2025-03-27T07:32:12.313-06:00", + field_path, + ); } schema::BasicValueType::Json => { // Can be any value. No type constraint. @@ -105,7 +156,8 @@ impl ToJsonSchema for schema::BasicValueType { schema.instance_type = Some(SingleOrVec::Single(Box::new(InstanceType::Array))); schema.array = Some(Box::new(ArrayValidation { items: Some(SingleOrVec::Single(Box::new( - s.element_type.to_json_schema(options).into(), + self.for_basic_value_type(&s.element_type, field_path) + .into(), ))), min_items: s.dimension.and_then(|d| u32::try_from(d).ok()), max_items: s.dimension.and_then(|d| u32::try_from(d).ok()), @@ -115,59 +167,62 @@ impl ToJsonSchema for schema::BasicValueType { } schema } -} -impl ToJsonSchema for schema::StructSchema { - fn to_json_schema(&self, options: &ToJsonSchemaOptions) -> SchemaObject { - SchemaObject { - metadata: Some(Box::new(Metadata { - description: self.description.as_ref().map(|s| s.to_string()), - ..Default::default() - })), - instance_type: Some(SingleOrVec::Single(Box::new(InstanceType::Object))), - object: Some(Box::new(ObjectValidation { - properties: self - .fields - .iter() - .map(|f| { - let mut schema = f.value_type.to_json_schema(options); - if options.fields_always_required && f.value_type.nullable { - if let Some(instance_type) = &mut schema.instance_type { - let mut types = match instance_type { - SingleOrVec::Single(t) => vec![**t], - SingleOrVec::Vec(t) => std::mem::take(t), - }; - types.push(InstanceType::Null); - *instance_type = SingleOrVec::Vec(types); - } + fn for_struct_schema( + &mut self, + struct_schema: &schema::StructSchema, + field_path: RefList<'_, &'_ FieldName>, + ) -> SchemaObject { + let mut schema = SchemaObject::default(); + if let Some(description) = &struct_schema.description { + self.set_description(&mut schema, description, field_path); + } + schema.instance_type = Some(SingleOrVec::Single(Box::new(InstanceType::Object))); + schema.object = Some(Box::new(ObjectValidation { + properties: struct_schema + .fields + .iter() + .map(|f| { + let mut schema = + self.for_enriched_value_type(&f.value_type, field_path.prepend(&f.name)); + if self.options.fields_always_required && f.value_type.nullable { + if let Some(instance_type) = &mut schema.instance_type { + let mut types = match instance_type { + SingleOrVec::Single(t) => vec![**t], + SingleOrVec::Vec(t) => std::mem::take(t), + }; + types.push(InstanceType::Null); + *instance_type = SingleOrVec::Vec(types); } - (f.name.to_string(), schema.into()) - }) - .collect(), - required: self - .fields - .iter() - .filter(|&f| (options.fields_always_required || !f.value_type.nullable)) - .map(|f| f.name.to_string()) - .collect(), - additional_properties: Some(Schema::Bool(false).into()), - ..Default::default() - })), + } + (f.name.to_string(), schema.into()) + }) + .collect(), + required: struct_schema + .fields + .iter() + .filter(|&f| (self.options.fields_always_required || !f.value_type.nullable)) + .map(|f| f.name.to_string()) + .collect(), + additional_properties: Some(Schema::Bool(false).into()), ..Default::default() - } + })); + schema } -} -impl ToJsonSchema for schema::ValueType { - fn to_json_schema(&self, options: &ToJsonSchemaOptions) -> SchemaObject { - match self { - schema::ValueType::Basic(b) => b.to_json_schema(options), - schema::ValueType::Struct(s) => s.to_json_schema(options), + fn for_value_type( + &mut self, + value_type: &schema::ValueType, + field_path: RefList<'_, &'_ FieldName>, + ) -> SchemaObject { + match value_type { + schema::ValueType::Basic(b) => self.for_basic_value_type(b, field_path), + schema::ValueType::Struct(s) => self.for_struct_schema(s, field_path), schema::ValueType::Collection(c) => SchemaObject { instance_type: Some(SingleOrVec::Single(Box::new(InstanceType::Array))), array: Some(Box::new(ArrayValidation { items: Some(SingleOrVec::Single(Box::new( - c.row.to_json_schema(options).into(), + self.for_struct_schema(&c.row, field_path).into(), ))), ..Default::default() })), @@ -175,10 +230,43 @@ impl ToJsonSchema for schema::ValueType { }, } } -} -impl ToJsonSchema for schema::EnrichedValueType { - fn to_json_schema(&self, options: &ToJsonSchemaOptions) -> SchemaObject { - self.typ.to_json_schema(options) + fn for_enriched_value_type( + &mut self, + enriched_value_type: &schema::EnrichedValueType, + field_path: RefList<'_, &'_ FieldName>, + ) -> SchemaObject { + self.for_value_type(&enriched_value_type.typ, field_path) + } + + fn build_extra_instructions(&self) -> Result> { + if self.extra_instructions_per_field.is_empty() { + return Ok(None); + } + + let mut instructions = String::new(); + write!(&mut instructions, "Instructions for specific fields:\n\n")?; + for (field_path, instruction) in self.extra_instructions_per_field.iter() { + write!( + &mut instructions, + "- {}: {}\n\n", + if field_path.is_empty() { + "(root object)" + } else { + field_path.as_str() + }, + instruction + )?; + } + Ok(Some(instructions)) } } + +pub fn build_json_schema( + value_type: &schema::EnrichedValueType, + options: ToJsonSchemaOptions, +) -> Result<(SchemaObject, Option)> { + let mut builder = JsonSchemaBuilder::new(options); + let schema = builder.for_enriched_value_type(value_type, RefList::Nil); + Ok((schema, builder.build_extra_instructions()?)) +} diff --git a/src/base/value.rs b/src/base/value.rs index f4149627f..cfa8583fd 100644 --- a/src/base/value.rs +++ b/src/base/value.rs @@ -1,7 +1,7 @@ use crate::{api_bail, api_error}; use super::schema::*; -use anyhow::{Context, Result}; +use anyhow::Result; use base64::prelude::*; use chrono::Offset; use log::warn; @@ -10,7 +10,7 @@ use serde::{ ser::{SerializeMap, SerializeSeq, SerializeTuple}, Deserialize, Serialize, }; -use std::{collections::BTreeMap, ops::Deref, str::FromStr, sync::Arc}; +use std::{collections::BTreeMap, ops::Deref, sync::Arc}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct RangeValue { diff --git a/src/llm/mod.rs b/src/llm/mod.rs index be7cde23d..e362fb310 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -47,24 +47,7 @@ pub trait LlmGenerationClient: Send + Sync { request: LlmGenerateRequest<'req>, ) -> Result; - /// If true, the LLM only accepts a JSON schema with all fields required. - /// This is a limitation of LLM models such as OpenAI. - /// Otherwise, the LLM will accept a JSON schema with optional fields. - fn json_schema_fields_always_required(&self) -> bool { - false - } - - /// If true, the LLM supports the `format` keyword in the JSON schema. - fn json_schema_supports_format(&self) -> bool { - true - } - - fn to_json_schema_options(&self) -> ToJsonSchemaOptions { - ToJsonSchemaOptions { - fields_always_required: self.json_schema_fields_always_required(), - supports_format: self.json_schema_supports_format(), - } - } + fn json_schema_options(&self) -> ToJsonSchemaOptions; } mod ollama; diff --git a/src/llm/ollama.rs b/src/llm/ollama.rs index 92c4688fa..560af5443 100644 --- a/src/llm/ollama.rs +++ b/src/llm/ollama.rs @@ -75,4 +75,12 @@ impl LlmGenerationClient for Client { text: json.response, }) } + + fn json_schema_options(&self) -> super::ToJsonSchemaOptions { + super::ToJsonSchemaOptions { + fields_always_required: false, + supports_format: true, + extract_descriptions: true, + } + } } diff --git a/src/llm/openai.rs b/src/llm/openai.rs index ca16dede9..119be6ea1 100644 --- a/src/llm/openai.rs +++ b/src/llm/openai.rs @@ -98,11 +98,11 @@ impl LlmGenerationClient for Client { Ok(super::LlmGenerateResponse { text }) } - fn json_schema_fields_always_required(&self) -> bool { - true - } - - fn json_schema_supports_format(&self) -> bool { - false + fn json_schema_options(&self) -> super::ToJsonSchemaOptions { + super::ToJsonSchemaOptions { + fields_always_required: true, + supports_format: false, + extract_descriptions: false, + } } } diff --git a/src/ops/functions/extract_by_llm.rs b/src/ops/functions/extract_by_llm.rs index 765a7508f..767077c7e 100644 --- a/src/ops/functions/extract_by_llm.rs +++ b/src/ops/functions/extract_by_llm.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use schemars::schema::SchemaObject; use serde::Serialize; -use crate::base::json_schema::ToJsonSchema; +use crate::base::json_schema::build_json_schema; use crate::llm::{ new_llm_generation_client, LlmGenerateRequest, LlmGenerationClient, LlmSpec, OutputFormat, }; @@ -29,7 +29,7 @@ struct Executor { system_prompt: String, } -fn get_system_prompt(instructions: &Option) -> String { +fn get_system_prompt(instructions: &Option, extra_instructions: Option) -> String { let mut message = "You are a helpful assistant that extracts structured information from text. \ Your task is to analyze the input text and output valid JSON that matches the specified schema. \ @@ -42,21 +42,25 @@ Output only the JSON without any additional messages or explanations." message.push_str(custom_instructions); } + if let Some(extra_instructions) = extra_instructions { + message.push_str("\n\n"); + message.push_str(&extra_instructions); + } + message } impl Executor { async fn new(spec: Spec, args: Args) -> Result { let client = new_llm_generation_client(spec.llm_spec).await?; - let output_json_schema = spec - .output_type - .to_json_schema(&client.to_json_schema_options()); + let (output_json_schema, extra_instructions) = + build_json_schema(&spec.output_type, client.json_schema_options())?; Ok(Self { args, client, output_json_schema, output_type: spec.output_type, - system_prompt: get_system_prompt(&spec.instruction), + system_prompt: get_system_prompt(&spec.instruction, extra_instructions), }) } }