Skip to content

Commit e8a6290

Browse files
authored
LLM extraction allows top-level to be non-object (e.g. list) for OpenAI. (#302)
1 parent e470f80 commit e8a6290

File tree

6 files changed

+90
-34
lines changed

6 files changed

+90
-34
lines changed

examples/docs_to_kg/main.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,6 @@ class Relationship:
1313
predicate: str
1414
object: str
1515

16-
@dataclasses.dataclass
17-
class Relationships:
18-
"""Describe a relationship between two nodes."""
19-
relationships: list[Relationship]
20-
2116
@cocoindex.flow_def(name="DocsToKG")
2217
def docs_to_kg_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope):
2318
"""
@@ -48,13 +43,16 @@ def docs_to_kg_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.D
4843
cocoindex.functions.ExtractByLlm(
4944
llm_spec=cocoindex.LlmSpec(
5045
api_type=cocoindex.LlmApiType.OPENAI, model="gpt-4o"),
51-
output_type=Relationships,
46+
# Replace with the spec below to use the Ollama API instead of OpenAI
47+
# llm_spec=cocoindex.LlmSpec(
48+
# api_type=cocoindex.LlmApiType.OLLAMA, model="llama3.2"),
49+
output_type=list[Relationship],
5250
instruction=(
5351
"Please extract relationships from CocoIndex documents. "
5452
"Focus on concepts and ingnore specific examples. "
5553
"Each relationship should be a tuple of (subject, predicate, object).")))
5654

57-
with chunk["relationships"]["relationships"].row() as relationship:
55+
with chunk["relationships"].row() as relationship:
5856
relationship["subject_embedding"] = relationship["subject"].transform(
5957
cocoindex.functions.SentenceTransformerEmbed(
6058
model="sentence-transformers/all-MiniLM-L6-v2"))

src/base/json_schema.rs

Lines changed: 72 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
use crate::utils::immutable::RefList;
1+
use crate::prelude::*;
22

3-
use super::{schema, spec::FieldName};
4-
use anyhow::Result;
5-
use indexmap::IndexMap;
3+
use crate::utils::immutable::RefList;
64
use schemars::schema::{
75
ArrayValidation, InstanceType, ObjectValidation, Schema, SchemaObject, SingleOrVec,
86
};
@@ -19,6 +17,9 @@ pub struct ToJsonSchemaOptions {
1917

2018
/// If true, extract descriptions to a separate extra instruction.
2119
pub extract_descriptions: bool,
20+
21+
/// If true, the top level must be a JSON object.
22+
pub top_level_must_be_object: bool,
2223
}
2324

2425
struct JsonSchemaBuilder {
@@ -38,7 +39,7 @@ impl JsonSchemaBuilder {
3839
&mut self,
3940
schema: &mut SchemaObject,
4041
description: impl ToString,
41-
field_path: RefList<'_, &'_ FieldName>,
42+
field_path: RefList<'_, &'_ spec::FieldName>,
4243
) {
4344
if self.options.extract_descriptions {
4445
let mut fields: Vec<_> = field_path.iter().map(|f| f.as_str()).collect();
@@ -53,7 +54,7 @@ impl JsonSchemaBuilder {
5354
fn for_basic_value_type(
5455
&mut self,
5556
basic_type: &schema::BasicValueType,
56-
field_path: RefList<'_, &'_ FieldName>,
57+
field_path: RefList<'_, &'_ spec::FieldName>,
5758
) -> SchemaObject {
5859
let mut schema = SchemaObject::default();
5960
match basic_type {
@@ -171,7 +172,7 @@ impl JsonSchemaBuilder {
171172
fn for_struct_schema(
172173
&mut self,
173174
struct_schema: &schema::StructSchema,
174-
field_path: RefList<'_, &'_ FieldName>,
175+
field_path: RefList<'_, &'_ spec::FieldName>,
175176
) -> SchemaObject {
176177
let mut schema = SchemaObject::default();
177178
if let Some(description) = &struct_schema.description {
@@ -213,7 +214,7 @@ impl JsonSchemaBuilder {
213214
fn for_value_type(
214215
&mut self,
215216
value_type: &schema::ValueType,
216-
field_path: RefList<'_, &'_ FieldName>,
217+
field_path: RefList<'_, &'_ spec::FieldName>,
217218
) -> SchemaObject {
218219
match value_type {
219220
schema::ValueType::Basic(b) => self.for_basic_value_type(b, field_path),
@@ -234,7 +235,7 @@ impl JsonSchemaBuilder {
234235
fn for_enriched_value_type(
235236
&mut self,
236237
enriched_value_type: &schema::EnrichedValueType,
237-
field_path: RefList<'_, &'_ FieldName>,
238+
field_path: RefList<'_, &'_ spec::FieldName>,
238239
) -> SchemaObject {
239240
self.for_value_type(&enriched_value_type.typ, field_path)
240241
}
@@ -262,11 +263,69 @@ impl JsonSchemaBuilder {
262263
}
263264
}
264265

266+
pub struct ValueExtractor {
267+
value_type: schema::ValueType,
268+
object_wrapper_field_name: Option<String>,
269+
}
270+
271+
impl ValueExtractor {
272+
pub fn extract_value(&self, json_value: serde_json::Value) -> Result<value::Value> {
273+
let unwrapped_json_value =
274+
if let Some(object_wrapper_field_name) = &self.object_wrapper_field_name {
275+
match json_value {
276+
serde_json::Value::Object(mut o) => o
277+
.remove(object_wrapper_field_name)
278+
.unwrap_or(serde_json::Value::Null),
279+
_ => {
280+
bail!("Field `{}` not found", object_wrapper_field_name)
281+
}
282+
}
283+
} else {
284+
json_value
285+
};
286+
let result = value::Value::from_json(unwrapped_json_value, &self.value_type)?;
287+
Ok(result)
288+
}
289+
}
290+
291+
pub struct BuildJsonSchemaOutput {
292+
pub schema: SchemaObject,
293+
pub extra_instructions: Option<String>,
294+
pub value_extractor: ValueExtractor,
295+
}
296+
265297
pub fn build_json_schema(
266-
value_type: &schema::EnrichedValueType,
298+
value_type: schema::EnrichedValueType,
267299
options: ToJsonSchemaOptions,
268-
) -> Result<(SchemaObject, Option<String>)> {
300+
) -> Result<BuildJsonSchemaOutput> {
269301
let mut builder = JsonSchemaBuilder::new(options);
270-
let schema = builder.for_enriched_value_type(value_type, RefList::Nil);
271-
Ok((schema, builder.build_extra_instructions()?))
302+
let (schema, object_wrapper_field_name) = if builder.options.top_level_must_be_object
303+
&& !matches!(value_type.typ, schema::ValueType::Struct(_))
304+
{
305+
let object_wrapper_field_name = "value".to_string();
306+
let wrapper_struct = schema::StructSchema {
307+
fields: Arc::new(vec![schema::FieldSchema {
308+
name: object_wrapper_field_name.clone(),
309+
value_type: value_type.clone(),
310+
}]),
311+
description: None,
312+
};
313+
(
314+
builder.for_struct_schema(&wrapper_struct, RefList::Nil),
315+
Some(object_wrapper_field_name),
316+
)
317+
} else {
318+
(
319+
builder.for_enriched_value_type(&value_type, RefList::Nil),
320+
None,
321+
)
322+
};
323+
Ok(BuildJsonSchemaOutput {
324+
schema,
325+
extra_instructions: builder.build_extra_instructions()?,
326+
value_extractor: ValueExtractor {
327+
value_type: value_type.typ,
328+
object_wrapper_field_name,
329+
},
330+
})
272331
}

src/llm/ollama.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ impl LlmGenerationClient for Client {
8181
fields_always_required: false,
8282
supports_format: true,
8383
extract_descriptions: true,
84+
top_level_must_be_object: false,
8485
}
8586
}
8687
}

src/llm/openai.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ impl LlmGenerationClient for Client {
103103
fields_always_required: true,
104104
supports_format: false,
105105
extract_descriptions: false,
106+
top_level_must_be_object: true,
106107
}
107108
}
108109
}

src/ops/functions/extract_by_llm.rs

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
1-
use std::borrow::Cow;
2-
use std::sync::Arc;
3-
4-
use schemars::schema::SchemaObject;
5-
use serde::Serialize;
1+
use crate::prelude::*;
62

7-
use crate::base::json_schema::build_json_schema;
83
use crate::llm::{
94
new_llm_generation_client, LlmGenerateRequest, LlmGenerationClient, LlmSpec, OutputFormat,
105
};
116
use crate::ops::sdk::*;
7+
use base::json_schema::build_json_schema;
8+
use schemars::schema::SchemaObject;
9+
use std::borrow::Cow;
1210

1311
#[derive(Debug, Clone, Serialize, Deserialize)]
1412
pub struct Spec {
@@ -25,8 +23,8 @@ struct Executor {
2523
args: Args,
2624
client: Box<dyn LlmGenerationClient>,
2725
output_json_schema: SchemaObject,
28-
output_type: EnrichedValueType,
2926
system_prompt: String,
27+
value_extractor: base::json_schema::ValueExtractor,
3028
}
3129

3230
fn get_system_prompt(instructions: &Option<String>, extra_instructions: Option<String>) -> String {
@@ -53,14 +51,13 @@ Output only the JSON without any additional messages or explanations."
5351
impl Executor {
5452
async fn new(spec: Spec, args: Args) -> Result<Self> {
5553
let client = new_llm_generation_client(spec.llm_spec).await?;
56-
let (output_json_schema, extra_instructions) =
57-
build_json_schema(&spec.output_type, client.json_schema_options())?;
54+
let schema_output = build_json_schema(spec.output_type, client.json_schema_options())?;
5855
Ok(Self {
5956
args,
6057
client,
61-
output_json_schema,
62-
output_type: spec.output_type,
63-
system_prompt: get_system_prompt(&spec.instruction, extra_instructions),
58+
output_json_schema: schema_output.schema,
59+
system_prompt: get_system_prompt(&spec.instruction, schema_output.extra_instructions),
60+
value_extractor: schema_output.value_extractor,
6461
})
6562
}
6663
}
@@ -87,7 +84,7 @@ impl SimpleFunctionExecutor for Executor {
8784
};
8885
let res = self.client.generate(req).await?;
8986
let json_value: serde_json::Value = serde_json::from_str(res.text.as_str())?;
90-
let value = Value::from_json(json_value, &self.output_type.typ)?;
87+
let value = self.value_extractor.extract_value(json_value)?;
9188
Ok(value)
9289
}
9390
}

src/prelude.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ pub(crate) use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
1717
pub(crate) use std::hash::Hash;
1818
pub(crate) use std::sync::{Arc, LazyLock, Mutex, OnceLock, RwLock, Weak};
1919

20-
pub(crate) use crate::base::{schema, spec, value};
20+
pub(crate) use crate::base::{self, schema, spec, value};
2121
pub(crate) use crate::builder::{self, plan};
2222
pub(crate) use crate::execution;
2323
pub(crate) use crate::lib_context::{get_lib_context, get_runtime, FlowContext, LibContext};

0 commit comments

Comments
 (0)