diff --git a/src/base/field_attrs.rs b/src/base/field_attrs.rs index a519672eb..1cc80fee8 100644 --- a/src/base/field_attrs.rs +++ b/src/base/field_attrs.rs @@ -2,10 +2,16 @@ use const_format::concatcp; pub static COCOINDEX_PREFIX: &str = "cocoindex.io/"; -/// Expected mime types for bytes and str. -pub static _MIME_TYPE: &str = concatcp!(COCOINDEX_PREFIX, "mime_type"); +/// Present for bytes and str. It points to fields that represents the original file name for the data. +/// Type: AnalyzedValueMapping +pub static CONTENT_FILENAME: &str = concatcp!(COCOINDEX_PREFIX, "content_filename"); -/// Base text for chunks. +/// Present for bytes and str. It points to fields that represents mime types for the data. +/// Type: AnalyzedValueMapping +pub static CONTENT_MIME_TYPE: &str = concatcp!(COCOINDEX_PREFIX, "content_mime_type"); + +/// Present for chunks. It points to fields that the chunks are for. +/// Type: AnalyzedValueMapping pub static CHUNK_BASE_TEXT: &str = concatcp!(COCOINDEX_PREFIX, "chunk_base_text"); /// Base text for an embedding vector. diff --git a/src/base/schema.rs b/src/base/schema.rs index af35957f3..4c54ca38c 100644 --- a/src/base/schema.rs +++ b/src/base/schema.rs @@ -65,7 +65,7 @@ impl std::fmt::Display for BasicValueType { } } -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] pub struct StructSchema { pub fields: Arc>, @@ -172,17 +172,10 @@ impl std::fmt::Display for CollectionSchema { } impl CollectionSchema { - pub fn new( - kind: CollectionKind, - fields: Vec, - description: Option>, - ) -> Self { + pub fn new(kind: CollectionKind, row: StructSchema) -> Self { Self { kind, - row: StructSchema { - fields: Arc::new(fields), - description, - }, + row, collectors: Default::default(), } } diff --git a/src/ops/functions/split_recursively.rs b/src/ops/functions/split_recursively.rs index 69306ead3..2f61dd27e 100644 --- a/src/ops/functions/split_recursively.rs +++ b/src/ops/functions/split_recursively.rs @@ -585,18 +585,23 @@ impl SimpleFunctionFactoryBase for Factory { .next_optional_arg("language")? .expect_type(&ValueType::Basic(BasicValueType::Str))?, }; - let output_schema = make_output_type(CollectionSchema::new( - CollectionKind::Table, - vec![ - FieldSchema::new("location", make_output_type(BasicValueType::Range)), - FieldSchema::new("text", make_output_type(BasicValueType::Str)), - ], - None, - )) - .with_attr( - field_attrs::CHUNK_BASE_TEXT, - serde_json::to_value(args_resolver.get_analyze_value(&args.text))?, - ); + + let mut struct_schema = StructSchema::default(); + let mut schema_builder = StructSchemaBuilder::new(&mut struct_schema); + schema_builder.add_field(FieldSchema::new( + "location", + make_output_type(BasicValueType::Range), + )); + schema_builder.add_field(FieldSchema::new( + "text", + make_output_type(BasicValueType::Str), + )); + let output_schema = + make_output_type(CollectionSchema::new(CollectionKind::Table, struct_schema)) + .with_attr( + field_attrs::CHUNK_BASE_TEXT, + serde_json::to_value(args_resolver.get_analyze_value(&args.text))?, + ); Ok((args, output_schema)) } diff --git a/src/ops/sdk.rs b/src/ops/sdk.rs index 8631ac500..bbe970a69 100644 --- a/src/ops/sdk.rs +++ b/src/ops/sdk.rs @@ -1,6 +1,12 @@ +use crate::builder::plan::AnalyzedFieldReference; +use crate::builder::plan::AnalyzedLocalFieldReference; +use std::collections::BTreeMap; +use std::sync::Arc; + pub use super::factory_bases::*; pub use super::interface::*; pub use crate::base::schema::*; +pub use crate::base::spec::*; pub use crate::base::value::*; pub use anyhow::Result; pub use axum::async_trait; @@ -46,3 +52,75 @@ macro_rules! fields_value { $crate::base::value::FieldValues { fields: std::vec![ $(($field).into()),+ ] } }; } + +pub struct SchemaBuilderFieldRef(AnalyzedLocalFieldReference); + +impl SchemaBuilderFieldRef { + pub fn to_field_ref(&self) -> AnalyzedFieldReference { + AnalyzedFieldReference { + local: self.0.clone(), + scope_up_level: 0, + } + } +} +pub struct StructSchemaBuilder<'a> { + base_fields_idx: Vec, + target: &'a mut StructSchema, +} + +impl<'a> StructSchemaBuilder<'a> { + pub fn new(target: &'a mut StructSchema) -> Self { + Self { + base_fields_idx: Vec::new(), + target, + } + } + + pub fn set_description(&mut self, description: impl Into>) { + self.target.description = Some(description.into()); + } + + pub fn add_field(&mut self, field_schema: FieldSchema) -> SchemaBuilderFieldRef { + let current_idx = self.target.fields.len() as u32; + Arc::make_mut(&mut self.target.fields).push(field_schema); + let mut fields_idx = self.base_fields_idx.clone(); + fields_idx.push(current_idx); + SchemaBuilderFieldRef(AnalyzedLocalFieldReference { fields_idx }) + } + + pub fn add_struct_field<'b>( + &'b mut self, + name: impl Into, + nullable: bool, + attrs: Arc>, + ) -> (StructSchemaBuilder<'b>, SchemaBuilderFieldRef) { + let field_schema = FieldSchema::new( + name.into(), + EnrichedValueType { + typ: ValueType::Struct(StructSchema { + fields: Arc::new(Vec::new()), + description: None, + }), + nullable, + attrs, + }, + ); + let local_ref = self.add_field(field_schema); + let struct_schema = match &mut Arc::make_mut(&mut self.target.fields) + .last_mut() + .unwrap() + .value_type + .typ + { + ValueType::Struct(s) => s, + _ => unreachable!(), + }; + ( + StructSchemaBuilder { + base_fields_idx: local_ref.0.fields_idx.clone(), + target: struct_schema, + }, + local_ref, + ) + } +} diff --git a/src/ops/sources/google_drive.rs b/src/ops/sources/google_drive.rs index 6d6fb8cf1..7d219f1ec 100644 --- a/src/ops/sources/google_drive.rs +++ b/src/ops/sources/google_drive.rs @@ -14,6 +14,7 @@ use hyper_util::client::legacy::connect::HttpConnector; use indexmap::IndexSet; use log::warn; +use crate::base::field_attrs; use crate::ops::sdk::*; struct ExportMimeType { @@ -277,22 +278,39 @@ impl SourceFactoryBase for Factory { spec: &Spec, _context: &FlowInstanceContext, ) -> Result { + let mut struct_schema = StructSchema::default(); + let mut schema_builder = StructSchemaBuilder::new(&mut struct_schema); + schema_builder.add_field(FieldSchema::new( + "file_id", + make_output_type(BasicValueType::Str), + )); + let filename_field = schema_builder.add_field(FieldSchema::new( + "filename", + make_output_type(BasicValueType::Str), + )); + let mime_type_field = schema_builder.add_field(FieldSchema::new( + "mime_type", + make_output_type(BasicValueType::Str), + )); + schema_builder.add_field(FieldSchema::new( + "content", + make_output_type(if spec.binary { + BasicValueType::Bytes + } else { + BasicValueType::Str + }) + .with_attr( + field_attrs::CONTENT_FILENAME, + serde_json::to_value(filename_field.to_field_ref())?, + ) + .with_attr( + field_attrs::CONTENT_MIME_TYPE, + serde_json::to_value(mime_type_field.to_field_ref())?, + ), + )); Ok(make_output_type(CollectionSchema::new( CollectionKind::Table, - vec![ - FieldSchema::new("file_id", make_output_type(BasicValueType::Str)), - FieldSchema::new("filename", make_output_type(BasicValueType::Str)), - FieldSchema::new("mime_type", make_output_type(BasicValueType::Str)), - FieldSchema::new( - "content", - make_output_type(if spec.binary { - BasicValueType::Bytes - } else { - BasicValueType::Str - }), - ), - ], - None, + struct_schema, ))) } diff --git a/src/ops/sources/local_file.rs b/src/ops/sources/local_file.rs index 3e5b5bef4..1fb5aec7b 100644 --- a/src/ops/sources/local_file.rs +++ b/src/ops/sources/local_file.rs @@ -2,6 +2,7 @@ use globset::{Glob, GlobSet, GlobSetBuilder}; use log::warn; use std::{path::PathBuf, sync::Arc}; +use crate::base::field_attrs; use crate::{fields_value, ops::sdk::*}; #[derive(Debug, Deserialize)] @@ -99,20 +100,28 @@ impl SourceFactoryBase for Factory { spec: &Spec, _context: &FlowInstanceContext, ) -> Result { + let mut struct_schema = StructSchema::default(); + let mut schema_builder = StructSchemaBuilder::new(&mut struct_schema); + let filename_field = schema_builder.add_field(FieldSchema::new( + "filename", + make_output_type(BasicValueType::Str), + )); + schema_builder.add_field(FieldSchema::new( + "content", + make_output_type(if spec.binary { + BasicValueType::Bytes + } else { + BasicValueType::Str + }) + .with_attr( + field_attrs::CONTENT_FILENAME, + serde_json::to_value(filename_field.to_field_ref())?, + ), + )); + Ok(make_output_type(CollectionSchema::new( CollectionKind::Table, - vec![ - FieldSchema::new("filename", make_output_type(BasicValueType::Str)), - FieldSchema::new( - "content", - make_output_type(if spec.binary { - BasicValueType::Bytes - } else { - BasicValueType::Str - }), - ), - ], - None, + struct_schema, ))) }