diff --git a/python/cocoindex/convert.py b/python/cocoindex/convert.py index 8fe86dad7..c6171148c 100644 --- a/python/cocoindex/convert.py +++ b/python/cocoindex/convert.py @@ -184,6 +184,9 @@ def decode_vector(value: Any) -> Any | None: return decode_vector + if src_type_kind == "Union": + return lambda value: value[1] + return lambda value: value diff --git a/python/cocoindex/tests/test_convert.py b/python/cocoindex/tests/test_convert.py index a3ff67d02..0e44c24bf 100644 --- a/python/cocoindex/tests/test_convert.py +++ b/python/cocoindex/tests/test_convert.py @@ -549,6 +549,48 @@ def test_field_position_cases( assert decoder(engine_val) == PythonOrder(**expected_dict) +def test_roundtrip_union_simple() -> None: + t = int | str | float + value = 10.4 + validate_full_roundtrip(value, t) + + +def test_roundtrip_union_with_active_uuid() -> None: + t = str | uuid.UUID | int + value = uuid.uuid4().bytes + validate_full_roundtrip(value, t) + + +def test_roundtrip_union_with_inactive_uuid() -> None: + t = str | uuid.UUID | int + value = "5a9f8f6a-318f-4f1f-929d-566d7444a62d" # it's a string + validate_full_roundtrip(value, t) + + +def test_roundtrip_union_offset_datetime() -> None: + t = str | uuid.UUID | float | int | datetime.datetime + value = datetime.datetime.now(datetime.UTC) + validate_full_roundtrip(value, t) + + +def test_roundtrip_union_date() -> None: + t = str | uuid.UUID | float | int | datetime.date + value = datetime.date.today() + validate_full_roundtrip(value, t) + + +def test_roundtrip_union_time() -> None: + t = str | uuid.UUID | float | int | datetime.time + value = datetime.time() + validate_full_roundtrip(value, t) + + +def test_roundtrip_union_timedelta() -> None: + t = str | uuid.UUID | float | int | datetime.timedelta + value = datetime.timedelta(hours=39, minutes=10, seconds=1) + validate_full_roundtrip(value, t) + + def test_roundtrip_ltable() -> None: t = list[Order] value = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)] diff --git a/python/cocoindex/typing.py b/python/cocoindex/typing.py index 863494538..c89b55670 100644 --- a/python/cocoindex/typing.py +++ b/python/cocoindex/typing.py @@ -161,6 +161,7 @@ class AnalyzedTypeInfo: attrs: dict[str, Any] | None nullable: bool = False + union_variant_types: typing.List[ElementType] | None = None # For Union def analyze_type_info(t: Any) -> AnalyzedTypeInfo: @@ -181,18 +182,6 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo: if base_type is Annotated: annotations = t.__metadata__ t = t.__origin__ - elif base_type is types.UnionType: - possible_types = typing.get_args(t) - non_none_types = [ - arg for arg in possible_types if arg not in (None, types.NoneType) - ] - if len(non_none_types) != 1: - raise ValueError( - f"Expect exactly one non-None choice for Union type, but got {len(non_none_types)}: {t}" - ) - t = non_none_types[0] - if len(possible_types) > 1: - nullable = True else: break @@ -211,6 +200,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo: struct_type: type | None = None elem_type: ElementType | None = None + union_variant_types: typing.List[ElementType] | None = None key_type: type | None = None np_number_type: type | None = None if _is_struct_type(t): @@ -251,6 +241,24 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo: args = typing.get_args(t) elem_type = (args[0], args[1]) kind = "KTable" + elif base_type is types.UnionType: + possible_types = typing.get_args(t) + non_none_types = [ + arg for arg in possible_types if arg not in (None, types.NoneType) + ] + + if len(non_none_types) == 0: + return analyze_type_info(None) + + nullable = len(non_none_types) < len(possible_types) + + if len(non_none_types) == 1: + result = analyze_type_info(non_none_types[0]) + result.nullable = nullable + return result + + kind = "Union" + union_variant_types = non_none_types elif kind is None: if t is bytes: kind = "Bytes" @@ -279,6 +287,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo: kind=kind, vector_info=vector_info, elem_type=elem_type, + union_variant_types=union_variant_types, key_type=key_type, struct_type=struct_type, np_number_type=np_number_type, @@ -338,6 +347,14 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]: encoded_type["element_type"] = _encode_type(elem_type_info) encoded_type["dimension"] = type_info.vector_info.dim + elif type_info.kind == "Union": + if type_info.union_variant_types is None: + raise ValueError("Union type must have a variant type list") + encoded_type["types"] = [ + _encode_type(analyze_type_info(typ)) + for typ in type_info.union_variant_types + ] + elif type_info.kind in TABLE_TYPES: if type_info.elem_type is None: raise ValueError(f"{type_info.kind} type must have an element type") diff --git a/src/base/json_schema.rs b/src/base/json_schema.rs index 9ae9b2b47..ae5a3995e 100644 --- a/src/base/json_schema.rs +++ b/src/base/json_schema.rs @@ -3,6 +3,7 @@ use crate::prelude::*; use crate::utils::immutable::RefList; use schemars::schema::{ ArrayValidation, InstanceType, ObjectValidation, Schema, SchemaObject, SingleOrVec, + SubschemaValidation, }; use std::fmt::Write; @@ -176,6 +177,17 @@ impl JsonSchemaBuilder { ..Default::default() })); } + schema::BasicValueType::Union(s) => { + schema.subschemas = Some(Box::new(SubschemaValidation { + one_of: Some( + s.types + .iter() + .map(|t| Schema::Object(self.for_basic_value_type(t, field_path))) + .collect(), + ), + ..Default::default() + })); + } } schema } diff --git a/src/base/schema.rs b/src/base/schema.rs index b13ba5cae..0c8ae38c0 100644 --- a/src/base/schema.rs +++ b/src/base/schema.rs @@ -9,6 +9,11 @@ pub struct VectorTypeSchema { pub dimension: Option, } +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct UnionTypeSchema { + pub types: Vec, +} + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[serde(tag = "kind")] pub enum BasicValueType { @@ -56,6 +61,9 @@ pub enum BasicValueType { /// A vector of values (usually numbers, for embeddings). Vector(VectorTypeSchema), + + /// A union + Union(UnionTypeSchema), } impl std::fmt::Display for BasicValueType { @@ -82,6 +90,17 @@ impl std::fmt::Display for BasicValueType { } write!(f, "]") } + BasicValueType::Union(s) => { + write!(f, "Union[")?; + for (i, typ) in s.types.iter().enumerate() { + if i > 0 { + // Add type delimiter + write!(f, " | ")?; + } + write!(f, "{}", typ)?; + } + write!(f, "]") + } } } } diff --git a/src/base/value.rs b/src/base/value.rs index 8235e5bce..e1c8713f6 100644 --- a/src/base/value.rs +++ b/src/base/value.rs @@ -379,6 +379,10 @@ pub enum BasicValue { TimeDelta(chrono::Duration), Json(Arc), Vector(Arc<[BasicValue]>), + UnionVariant { + tag_id: usize, + value: Box, + }, } impl From for BasicValue { @@ -496,7 +500,8 @@ impl BasicValue { | BasicValue::OffsetDateTime(_) | BasicValue::TimeDelta(_) | BasicValue::Json(_) - | BasicValue::Vector(_) => api_bail!("invalid key value type"), + | BasicValue::Vector(_) + | BasicValue::UnionVariant { .. } => api_bail!("invalid key value type"), }; Ok(result) } @@ -517,7 +522,8 @@ impl BasicValue { | BasicValue::OffsetDateTime(_) | BasicValue::TimeDelta(_) | BasicValue::Json(_) - | BasicValue::Vector(_) => api_bail!("invalid key value type"), + | BasicValue::Vector(_) + | BasicValue::UnionVariant { .. } => api_bail!("invalid key value type"), }; Ok(result) } @@ -539,6 +545,7 @@ impl BasicValue { BasicValue::TimeDelta(_) => "timedelta", BasicValue::Json(_) => "json", BasicValue::Vector(_) => "vector", + BasicValue::UnionVariant { .. } => "union", } } } @@ -892,6 +899,12 @@ impl serde::Serialize for BasicValue { BasicValue::TimeDelta(v) => serializer.serialize_str(&v.to_string()), BasicValue::Json(v) => v.serialize(serializer), BasicValue::Vector(v) => v.serialize(serializer), + BasicValue::UnionVariant { tag_id, value } => { + let mut s = serializer.serialize_tuple(2)?; + s.serialize_element(tag_id)?; + s.serialize_element(value)?; + s.end() + } } } } @@ -956,6 +969,40 @@ impl BasicValue { .collect::>>()?; BasicValue::Vector(Arc::from(vec)) } + (v, BasicValueType::Union(typ)) => { + let arr = match v { + serde_json::Value::Array(arr) => arr, + _ => anyhow::bail!("Invalid JSON value for union, expect array"), + }; + + if arr.len() != 2 { + anyhow::bail!( + "Invalid union tuple: expect 2 values, received {}", + arr.len() + ); + } + + let mut obj_iter = arr.into_iter(); + + // Take first element + let tag_id = obj_iter + .next() + .and_then(|value| value.as_u64().map(|num_u64| num_u64 as usize)) + .unwrap(); + + // Take second element + let value = obj_iter.next().unwrap(); + + let cur_type = typ + .types + .get(tag_id) + .ok_or_else(|| anyhow::anyhow!("No type in `tag_id` \"{tag_id}\" found"))?; + + BasicValue::UnionVariant { + tag_id, + value: Box::new(BasicValue::from_json(value, cur_type)?), + } + } (v, t) => { anyhow::bail!("Value and type not matched.\nTarget type {t:?}\nJSON value: {v}\n") } @@ -1088,7 +1135,17 @@ impl Serialize for TypedValue<'_> { fn serialize(&self, serializer: S) -> Result { match (self.t, self.v) { (_, Value::Null) => serializer.serialize_none(), - (ValueType::Basic(_), v) => v.serialize(serializer), + (ValueType::Basic(t), v) => match t { + BasicValueType::Union(_) => match v { + Value::Basic(BasicValue::UnionVariant { value, .. }) => { + value.serialize(serializer) + } + _ => Err(serde::ser::Error::custom( + "Unmatched union type and value for `TypedValue`", + )), + }, + _ => v.serialize(serializer), + }, (ValueType::Struct(s), Value::Struct(field_values)) => TypedFieldsValue { schema: &s.fields, values_iter: field_values.fields.iter(), diff --git a/src/llm/litellm.rs b/src/llm/litellm.rs index 1dc628d9e..27648747a 100644 --- a/src/llm/litellm.rs +++ b/src/llm/litellm.rs @@ -1,16 +1,22 @@ -use async_openai::config::OpenAIConfig; use async_openai::Client as OpenAIClient; +use async_openai::config::OpenAIConfig; pub use super::openai::Client; impl Client { pub async fn new_litellm(spec: super::LlmSpec) -> anyhow::Result { - let address = spec.address.clone().unwrap_or_else(|| "http://127.0.0.1:4000".to_string()); + let address = spec + .address + .clone() + .unwrap_or_else(|| "http://127.0.0.1:4000".to_string()); let api_key = std::env::var("LITELLM_API_KEY").ok(); let mut config = OpenAIConfig::new().with_api_base(address); if let Some(api_key) = api_key { config = config.with_api_key(api_key); } - Ok(Client::from_parts(OpenAIClient::with_config(config), spec.model)) + Ok(Client::from_parts( + OpenAIClient::with_config(config), + spec.model, + )) } } diff --git a/src/llm/mod.rs b/src/llm/mod.rs index a3652955f..ea4aa58ee 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -56,9 +56,9 @@ pub trait LlmGenerationClient: Send + Sync { mod anthropic; mod gemini; +mod litellm; mod ollama; mod openai; -mod litellm; mod openrouter; pub async fn new_llm_generation_client(spec: LlmSpec) -> Result> { @@ -78,11 +78,8 @@ pub async fn new_llm_generation_client(spec: LlmSpec) -> Result { Box::new(litellm::Client::new_litellm(spec).await?) as Box } - LlmApiType::OpenRouter => { - Box::new(openrouter::Client::new_openrouter(spec).await?) as Box - } - - + LlmApiType::OpenRouter => Box::new(openrouter::Client::new_openrouter(spec).await?) + as Box, }; Ok(client) } diff --git a/src/llm/openrouter.rs b/src/llm/openrouter.rs index 5dde06b91..cb7757889 100644 --- a/src/llm/openrouter.rs +++ b/src/llm/openrouter.rs @@ -1,16 +1,22 @@ -use async_openai::config::OpenAIConfig; use async_openai::Client as OpenAIClient; +use async_openai::config::OpenAIConfig; pub use super::openai::Client; impl Client { pub async fn new_openrouter(spec: super::LlmSpec) -> anyhow::Result { - let address = spec.address.clone().unwrap_or_else(|| "https://openrouter.ai/api/v1".to_string()); + let address = spec + .address + .clone() + .unwrap_or_else(|| "https://openrouter.ai/api/v1".to_string()); let api_key = std::env::var("OPENROUTER_API_KEY").ok(); let mut config = OpenAIConfig::new().with_api_base(address); if let Some(api_key) = api_key { config = config.with_api_key(api_key); } - Ok(Client::from_parts(OpenAIClient::with_config(config), spec.model)) + Ok(Client::from_parts( + OpenAIClient::with_config(config), + spec.model, + )) } } diff --git a/src/ops/targets/kuzu.rs b/src/ops/targets/kuzu.rs index 069f980eb..0a4ceaf9e 100644 --- a/src/ops/targets/kuzu.rs +++ b/src/ops/targets/kuzu.rs @@ -123,7 +123,7 @@ fn basic_type_to_kuzu(basic_type: &BasicValueType) -> Result { t.dimension .map_or_else(|| "".to_string(), |d| d.to_string()) ), - t @ (BasicValueType::Time | BasicValueType::Json) => { + t @ (BasicValueType::Union(_) | BasicValueType::Time | BasicValueType::Json) => { api_bail!("{t} is not supported in Kuzu") } }) @@ -377,7 +377,7 @@ fn append_basic_value(cypher: &mut CypherBuilder, basic_value: &BasicValue) -> R } write!(cypher.query_mut(), "]")?; } - v @ (BasicValue::Time(_) | BasicValue::Json(_)) => { + v @ (BasicValue::UnionVariant { .. } | BasicValue::Time(_) | BasicValue::Json(_)) => { bail!("value types are not supported in Kuzu: {}", v.kind()); } } diff --git a/src/ops/targets/neo4j.rs b/src/ops/targets/neo4j.rs index 90295ee26..bfc4c4156 100644 --- a/src/ops/targets/neo4j.rs +++ b/src/ops/targets/neo4j.rs @@ -226,6 +226,17 @@ fn basic_value_to_bolt(value: &BasicValue, schema: &BasicValueType) -> Result anyhow::bail!("Non-vector type got vector value: {}", schema), }, BasicValue::Json(v) => json_value_to_bolt_value(v)?, + BasicValue::UnionVariant { tag_id, value } => match schema { + BasicValueType::Union(s) => { + let typ = s + .types + .get(*tag_id) + .ok_or_else(|| anyhow::anyhow!("Invalid `tag_id`: {}", tag_id))?; + + basic_value_to_bolt(value, typ)? + } + _ => anyhow::bail!("Non-union type got union value: {}", schema), + }, }; Ok(bolt_value) } diff --git a/src/ops/targets/postgres.rs b/src/ops/targets/postgres.rs index 747b60c03..a9b8507f4 100644 --- a/src/ops/targets/postgres.rs +++ b/src/ops/targets/postgres.rs @@ -154,6 +154,12 @@ fn bind_value_field<'arg>( builder.push_bind(sqlx::types::Json(v)); } }, + BasicValue::UnionVariant { .. } => { + builder.push_bind(sqlx::types::Json(TypedValue { + t: &field_schema.value_type.typ, + v: value, + })); + } }, Value::Null => { builder.push("NULL"); @@ -383,6 +389,7 @@ fn to_column_type_sql(column_type: &ValueType) -> String { "jsonb".into() } } + BasicValueType::Union(_) => "jsonb".into(), }, _ => "jsonb".into(), } diff --git a/src/py/convert.rs b/src/py/convert.rs index 43afae95e..2ccc19055 100644 --- a/src/py/convert.rs +++ b/src/py/convert.rs @@ -1,6 +1,7 @@ use bytes::Bytes; use numpy::{PyArray1, PyArrayDyn, PyArrayMethods}; use pyo3::IntoPyObjectExt; +use pyo3::exceptions::PyTypeError; use pyo3::types::PyAny; use pyo3::types::{PyList, PyTuple}; use pyo3::{exceptions::PyException, prelude::*}; @@ -76,6 +77,9 @@ fn basic_value_to_py_object<'py>( value::BasicValue::TimeDelta(v) => v.into_bound_py_any(py)?, value::BasicValue::Json(v) => pythonize(py, v).into_py_result()?, value::BasicValue::Vector(v) => handle_vector_to_py(py, v)?, + value::BasicValue::UnionVariant { tag_id, value } => { + (*tag_id, basic_value_to_py_object(py, &value)?).into_bound_py_any(py)? + } }; Ok(result) } @@ -162,6 +166,27 @@ fn basic_value_from_py_object<'py>( )) } } + schema::BasicValueType::Union(s) => { + let mut valid_value = None; + + // Try parsing the value + for (i, typ) in s.types.iter().enumerate() { + if let Ok(value) = basic_value_from_py_object(typ, v) { + valid_value = Some(value::BasicValue::UnionVariant { + tag_id: i, + value: Box::new(value), + }); + break; + } + } + + valid_value.ok_or_else(|| { + PyErr::new::(format!( + "invalid union value: {}, available types: {:?}", + v, s.types + )) + })? + } }; Ok(result) }