Skip to content

Commit 40e84b6

Browse files
authored
feat: support union type for basic types (#510)
* Add union type and JSON schema conversion * Add basic postgres conversion * Compact lines * Workaround for JSON conversion in union * Add stub impl for Python union conversion * Add impl for python object union conversion * Update error message * Rename type item * Add str parsing method * Add union conversion for Qdrant * Add basic string parsing for union type * Fix union conversion for Qdrant * Replace if guards with matches * Add extra parsing for string * Add rustdoc for parsing method * Turn string parsing into a util function * Update union parsing for serde value * Add vector union type parsing for Qdrant * Switch to BTreeSet for union types * Remove nested union detection * Remove TODO: Support struct/table * Add union type helper struct * Add comments * Update Python type conversion for union type * Use reversed iteration for union type matching * Add test cases for union fmt * Update comments * Add test cases * Remove "undetected JSON" parsing * Update union analysis in Python API * Add union type encoding for Python API * Add single type checking for union type analysis * Update union type * Add union decoding * Revert "Add union decoding" This reverts commit f8eb3cc. * Update encoded type field * Update union types field in Python * Update type serialization * Revert "Update type serialization" This reverts commit 030281c. * Add `UnionVariant` and conversions in `BasicValue` * Add union value binding for Postgres * Update type guessing for union from python object * Replace direct return with break * Use `Vec` to remove auto-sort * Revert "Use `Vec` to remove auto-sort" This reverts commit d194117. 
* Use `Vec` for union type * Add union processing for KuzuDB * Update Cypher generation for union type * Use 0-based index for `val{i}` * Update tuple * Take values for JSON conversion for union * Update variable name * Use typed value conversion for union in Postgres * Replace union conversion with error in `from_pg_value()` * Update union conversion for Qdrant * Update `PyErr` message for union * Move `UnionType` to `schema.rs` as `UnionTypeSchema` * Use `to_value()` for union value conversion * Use `bail!()` for early return * Update error message for union tuple conversion * Move union type checking to the loop * Replace `.ok_or_else()` with `.unwrap()` * Update union variant serialization * Match quote styling * Break infinite loops * Added a union test case * Fix union typing * Make `union_variant_types` optional * Update test case * Fix JSON serde and decoding * Add UUID union test cases * Add test cases for union type * Update the offset datetime test case for unions * Remove union implementation for Kuzu * Update `union_variant_types` typing * Add union value serialization for `TypedValue` * Update union tuple check for basic value * Reformat files * Remove explicit type for array, the type is obvious
1 parent cdeb887 commit 40e84b6

File tree

13 files changed

+231
-29
lines changed

13 files changed

+231
-29
lines changed

python/cocoindex/convert.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,9 @@ def decode_vector(value: Any) -> Any | None:
184184

185185
return decode_vector
186186

187+
if src_type_kind == "Union":
188+
return lambda value: value[1]
189+
187190
return lambda value: value
188191

189192

python/cocoindex/tests/test_convert.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,48 @@ def test_field_position_cases(
549549
assert decoder(engine_val) == PythonOrder(**expected_dict)
550550

551551

552+
def test_roundtrip_union_simple() -> None:
553+
t = int | str | float
554+
value = 10.4
555+
validate_full_roundtrip(value, t)
556+
557+
558+
def test_roundtrip_union_with_active_uuid() -> None:
559+
t = str | uuid.UUID | int
560+
value = uuid.uuid4().bytes
561+
validate_full_roundtrip(value, t)
562+
563+
564+
def test_roundtrip_union_with_inactive_uuid() -> None:
565+
t = str | uuid.UUID | int
566+
value = "5a9f8f6a-318f-4f1f-929d-566d7444a62d" # it's a string
567+
validate_full_roundtrip(value, t)
568+
569+
570+
def test_roundtrip_union_offset_datetime() -> None:
571+
t = str | uuid.UUID | float | int | datetime.datetime
572+
value = datetime.datetime.now(datetime.UTC)
573+
validate_full_roundtrip(value, t)
574+
575+
576+
def test_roundtrip_union_date() -> None:
577+
t = str | uuid.UUID | float | int | datetime.date
578+
value = datetime.date.today()
579+
validate_full_roundtrip(value, t)
580+
581+
582+
def test_roundtrip_union_time() -> None:
583+
t = str | uuid.UUID | float | int | datetime.time
584+
value = datetime.time()
585+
validate_full_roundtrip(value, t)
586+
587+
588+
def test_roundtrip_union_timedelta() -> None:
589+
t = str | uuid.UUID | float | int | datetime.timedelta
590+
value = datetime.timedelta(hours=39, minutes=10, seconds=1)
591+
validate_full_roundtrip(value, t)
592+
593+
552594
def test_roundtrip_ltable() -> None:
553595
t = list[Order]
554596
value = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]

python/cocoindex/typing.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ class AnalyzedTypeInfo:
161161

162162
attrs: dict[str, Any] | None
163163
nullable: bool = False
164+
union_variant_types: typing.List[ElementType] | None = None # For Union
164165

165166

166167
def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
@@ -181,18 +182,6 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
181182
if base_type is Annotated:
182183
annotations = t.__metadata__
183184
t = t.__origin__
184-
elif base_type is types.UnionType:
185-
possible_types = typing.get_args(t)
186-
non_none_types = [
187-
arg for arg in possible_types if arg not in (None, types.NoneType)
188-
]
189-
if len(non_none_types) != 1:
190-
raise ValueError(
191-
f"Expect exactly one non-None choice for Union type, but got {len(non_none_types)}: {t}"
192-
)
193-
t = non_none_types[0]
194-
if len(possible_types) > 1:
195-
nullable = True
196185
else:
197186
break
198187

@@ -211,6 +200,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
211200

212201
struct_type: type | None = None
213202
elem_type: ElementType | None = None
203+
union_variant_types: typing.List[ElementType] | None = None
214204
key_type: type | None = None
215205
np_number_type: type | None = None
216206
if _is_struct_type(t):
@@ -251,6 +241,24 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
251241
args = typing.get_args(t)
252242
elem_type = (args[0], args[1])
253243
kind = "KTable"
244+
elif base_type is types.UnionType:
245+
possible_types = typing.get_args(t)
246+
non_none_types = [
247+
arg for arg in possible_types if arg not in (None, types.NoneType)
248+
]
249+
250+
if len(non_none_types) == 0:
251+
return analyze_type_info(None)
252+
253+
nullable = len(non_none_types) < len(possible_types)
254+
255+
if len(non_none_types) == 1:
256+
result = analyze_type_info(non_none_types[0])
257+
result.nullable = nullable
258+
return result
259+
260+
kind = "Union"
261+
union_variant_types = non_none_types
254262
elif kind is None:
255263
if t is bytes:
256264
kind = "Bytes"
@@ -279,6 +287,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
279287
kind=kind,
280288
vector_info=vector_info,
281289
elem_type=elem_type,
290+
union_variant_types=union_variant_types,
282291
key_type=key_type,
283292
struct_type=struct_type,
284293
np_number_type=np_number_type,
@@ -338,6 +347,14 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
338347
encoded_type["element_type"] = _encode_type(elem_type_info)
339348
encoded_type["dimension"] = type_info.vector_info.dim
340349

350+
elif type_info.kind == "Union":
351+
if type_info.union_variant_types is None:
352+
raise ValueError("Union type must have a variant type list")
353+
encoded_type["types"] = [
354+
_encode_type(analyze_type_info(typ))
355+
for typ in type_info.union_variant_types
356+
]
357+
341358
elif type_info.kind in TABLE_TYPES:
342359
if type_info.elem_type is None:
343360
raise ValueError(f"{type_info.kind} type must have an element type")

src/base/json_schema.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use crate::prelude::*;
33
use crate::utils::immutable::RefList;
44
use schemars::schema::{
55
ArrayValidation, InstanceType, ObjectValidation, Schema, SchemaObject, SingleOrVec,
6+
SubschemaValidation,
67
};
78
use std::fmt::Write;
89

@@ -176,6 +177,17 @@ impl JsonSchemaBuilder {
176177
..Default::default()
177178
}));
178179
}
180+
schema::BasicValueType::Union(s) => {
181+
schema.subschemas = Some(Box::new(SubschemaValidation {
182+
one_of: Some(
183+
s.types
184+
.iter()
185+
.map(|t| Schema::Object(self.for_basic_value_type(t, field_path)))
186+
.collect(),
187+
),
188+
..Default::default()
189+
}));
190+
}
179191
}
180192
schema
181193
}

src/base/schema.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@ pub struct VectorTypeSchema {
99
pub dimension: Option<usize>,
1010
}
1111

12+
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
13+
pub struct UnionTypeSchema {
14+
pub types: Vec<BasicValueType>,
15+
}
16+
1217
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
1318
#[serde(tag = "kind")]
1419
pub enum BasicValueType {
@@ -56,6 +61,9 @@ pub enum BasicValueType {
5661

5762
/// A vector of values (usually numbers, for embeddings).
5863
Vector(VectorTypeSchema),
64+
65+
/// A union of basic types: a value holds exactly one of the listed variant types.
66+
Union(UnionTypeSchema),
5967
}
6068

6169
impl std::fmt::Display for BasicValueType {
@@ -82,6 +90,17 @@ impl std::fmt::Display for BasicValueType {
8290
}
8391
write!(f, "]")
8492
}
93+
BasicValueType::Union(s) => {
94+
write!(f, "Union[")?;
95+
for (i, typ) in s.types.iter().enumerate() {
96+
if i > 0 {
97+
// Add type delimiter
98+
write!(f, " | ")?;
99+
}
100+
write!(f, "{}", typ)?;
101+
}
102+
write!(f, "]")
103+
}
85104
}
86105
}
87106
}

src/base/value.rs

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,10 @@ pub enum BasicValue {
379379
TimeDelta(chrono::Duration),
380380
Json(Arc<serde_json::Value>),
381381
Vector(Arc<[BasicValue]>),
382+
UnionVariant {
383+
tag_id: usize,
384+
value: Box<BasicValue>,
385+
},
382386
}
383387

384388
impl From<Bytes> for BasicValue {
@@ -496,7 +500,8 @@ impl BasicValue {
496500
| BasicValue::OffsetDateTime(_)
497501
| BasicValue::TimeDelta(_)
498502
| BasicValue::Json(_)
499-
| BasicValue::Vector(_) => api_bail!("invalid key value type"),
503+
| BasicValue::Vector(_)
504+
| BasicValue::UnionVariant { .. } => api_bail!("invalid key value type"),
500505
};
501506
Ok(result)
502507
}
@@ -517,7 +522,8 @@ impl BasicValue {
517522
| BasicValue::OffsetDateTime(_)
518523
| BasicValue::TimeDelta(_)
519524
| BasicValue::Json(_)
520-
| BasicValue::Vector(_) => api_bail!("invalid key value type"),
525+
| BasicValue::Vector(_)
526+
| BasicValue::UnionVariant { .. } => api_bail!("invalid key value type"),
521527
};
522528
Ok(result)
523529
}
@@ -539,6 +545,7 @@ impl BasicValue {
539545
BasicValue::TimeDelta(_) => "timedelta",
540546
BasicValue::Json(_) => "json",
541547
BasicValue::Vector(_) => "vector",
548+
BasicValue::UnionVariant { .. } => "union",
542549
}
543550
}
544551
}
@@ -892,6 +899,12 @@ impl serde::Serialize for BasicValue {
892899
BasicValue::TimeDelta(v) => serializer.serialize_str(&v.to_string()),
893900
BasicValue::Json(v) => v.serialize(serializer),
894901
BasicValue::Vector(v) => v.serialize(serializer),
902+
BasicValue::UnionVariant { tag_id, value } => {
903+
let mut s = serializer.serialize_tuple(2)?;
904+
s.serialize_element(tag_id)?;
905+
s.serialize_element(value)?;
906+
s.end()
907+
}
895908
}
896909
}
897910
}
@@ -956,6 +969,40 @@ impl BasicValue {
956969
.collect::<Result<Vec<_>>>()?;
957970
BasicValue::Vector(Arc::from(vec))
958971
}
972+
(v, BasicValueType::Union(typ)) => {
973+
let arr = match v {
974+
serde_json::Value::Array(arr) => arr,
975+
_ => anyhow::bail!("Invalid JSON value for union, expect array"),
976+
};
977+
978+
if arr.len() != 2 {
979+
anyhow::bail!(
980+
"Invalid union tuple: expect 2 values, received {}",
981+
arr.len()
982+
);
983+
}
984+
985+
let mut obj_iter = arr.into_iter();
986+
987+
// Take first element
988+
let tag_id = obj_iter
989+
.next()
990+
.and_then(|value| value.as_u64().map(|num_u64| num_u64 as usize))
991+
.unwrap();
992+
993+
// Take second element
994+
let value = obj_iter.next().unwrap();
995+
996+
let cur_type = typ
997+
.types
998+
.get(tag_id)
999+
.ok_or_else(|| anyhow::anyhow!("No type in `tag_id` \"{tag_id}\" found"))?;
1000+
1001+
BasicValue::UnionVariant {
1002+
tag_id,
1003+
value: Box::new(BasicValue::from_json(value, cur_type)?),
1004+
}
1005+
}
9591006
(v, t) => {
9601007
anyhow::bail!("Value and type not matched.\nTarget type {t:?}\nJSON value: {v}\n")
9611008
}
@@ -1088,7 +1135,17 @@ impl Serialize for TypedValue<'_> {
10881135
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
10891136
match (self.t, self.v) {
10901137
(_, Value::Null) => serializer.serialize_none(),
1091-
(ValueType::Basic(_), v) => v.serialize(serializer),
1138+
(ValueType::Basic(t), v) => match t {
1139+
BasicValueType::Union(_) => match v {
1140+
Value::Basic(BasicValue::UnionVariant { value, .. }) => {
1141+
value.serialize(serializer)
1142+
}
1143+
_ => Err(serde::ser::Error::custom(
1144+
"Unmatched union type and value for `TypedValue`",
1145+
)),
1146+
},
1147+
_ => v.serialize(serializer),
1148+
},
10921149
(ValueType::Struct(s), Value::Struct(field_values)) => TypedFieldsValue {
10931150
schema: &s.fields,
10941151
values_iter: field_values.fields.iter(),

src/llm/litellm.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
1-
use async_openai::config::OpenAIConfig;
21
use async_openai::Client as OpenAIClient;
2+
use async_openai::config::OpenAIConfig;
33

44
pub use super::openai::Client;
55

66
impl Client {
77
pub async fn new_litellm(spec: super::LlmSpec) -> anyhow::Result<Self> {
8-
let address = spec.address.clone().unwrap_or_else(|| "http://127.0.0.1:4000".to_string());
8+
let address = spec
9+
.address
10+
.clone()
11+
.unwrap_or_else(|| "http://127.0.0.1:4000".to_string());
912
let api_key = std::env::var("LITELLM_API_KEY").ok();
1013
let mut config = OpenAIConfig::new().with_api_base(address);
1114
if let Some(api_key) = api_key {
1215
config = config.with_api_key(api_key);
1316
}
14-
Ok(Client::from_parts(OpenAIClient::with_config(config), spec.model))
17+
Ok(Client::from_parts(
18+
OpenAIClient::with_config(config),
19+
spec.model,
20+
))
1521
}
1622
}

src/llm/mod.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@ pub trait LlmGenerationClient: Send + Sync {
5656

5757
mod anthropic;
5858
mod gemini;
59+
mod litellm;
5960
mod ollama;
6061
mod openai;
61-
mod litellm;
6262
mod openrouter;
6363

6464
pub async fn new_llm_generation_client(spec: LlmSpec) -> Result<Box<dyn LlmGenerationClient>> {
@@ -78,11 +78,8 @@ pub async fn new_llm_generation_client(spec: LlmSpec) -> Result<Box<dyn LlmGener
7878
LlmApiType::LiteLlm => {
7979
Box::new(litellm::Client::new_litellm(spec).await?) as Box<dyn LlmGenerationClient>
8080
}
81-
LlmApiType::OpenRouter => {
82-
Box::new(openrouter::Client::new_openrouter(spec).await?) as Box<dyn LlmGenerationClient>
83-
}
84-
85-
81+
LlmApiType::OpenRouter => Box::new(openrouter::Client::new_openrouter(spec).await?)
82+
as Box<dyn LlmGenerationClient>,
8683
};
8784
Ok(client)
8885
}

src/llm/openrouter.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
1-
use async_openai::config::OpenAIConfig;
21
use async_openai::Client as OpenAIClient;
2+
use async_openai::config::OpenAIConfig;
33

44
pub use super::openai::Client;
55

66
impl Client {
77
pub async fn new_openrouter(spec: super::LlmSpec) -> anyhow::Result<Self> {
8-
let address = spec.address.clone().unwrap_or_else(|| "https://openrouter.ai/api/v1".to_string());
8+
let address = spec
9+
.address
10+
.clone()
11+
.unwrap_or_else(|| "https://openrouter.ai/api/v1".to_string());
912
let api_key = std::env::var("OPENROUTER_API_KEY").ok();
1013
let mut config = OpenAIConfig::new().with_api_base(address);
1114
if let Some(api_key) = api_key {
1215
config = config.with_api_key(api_key);
1316
}
14-
Ok(Client::from_parts(OpenAIClient::with_config(config), spec.model))
17+
Ok(Client::from_parts(
18+
OpenAIClient::with_config(config),
19+
spec.model,
20+
))
1521
}
1622
}

0 commit comments

Comments
 (0)