Skip to content

Commit 66a09ec

Browse files
committed
Update union type
1 parent 052815a commit 66a09ec

File tree

3 files changed

+71
-38
lines changed

3 files changed

+71
-38
lines changed

python/cocoindex/typing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
233233
encoded_type['dimension'] = type_info.vector_info.dim
234234

235235
elif type_info.kind == 'Union':
236-
if type_info.elem_type is not UnionType:
236+
if type_info.elem_type is not types.UnionType:
237237
raise ValueError("Union type must have a union-typed element type")
238238
encoded_type['element_type'] = [
239239
_encode_type(analyze_type_info(typ)) for typ in typing.get_args(type_info.elem_type)

src/base/schema.rs

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -437,58 +437,57 @@ pub struct OpArgSchema {
437437
#[cfg(test)]
438438
mod tests {
439439
use super::*;
440+
use utils::union::UnionParseResult;
440441

441442
#[test]
442443
fn test_union_fmt_empty() {
443-
let typ = BasicValueType::Union(UnionType::default());
444-
let expected = "Union[]";
445-
446-
assert_eq!(typ.to_string(), expected);
444+
let result = UnionType::parse_from(Vec::new());
445+
assert!(matches!(result, Err(_)));
447446
}
448447

449448
#[test]
450449
fn test_union_fmt_single() {
451-
let typ = BasicValueType::Union(vec![BasicValueType::Uuid].into());
452-
let expected = "Union[Uuid]";
453-
454-
assert_eq!(typ.to_string(), expected);
450+
let result = UnionType::parse_from(vec![BasicValueType::Uuid]).unwrap();
451+
assert!(matches!(result, UnionParseResult::Single(BasicValueType::Uuid)));
455452
}
456453

457454
#[test]
458-
fn test_union_fmt_autosort_basic() {
459-
// Uuid | Date | Str | Bytes = Bytes | Str | Uuid | Date
460-
let typ = BasicValueType::Union(vec![
455+
fn test_union_fmt_flat() {
456+
let union = UnionType::coerce_from(vec![
461457
BasicValueType::Uuid,
462458
BasicValueType::Date,
463459
BasicValueType::Str,
464460
BasicValueType::Bytes,
465-
].into());
461+
]);
462+
463+
// Uuid | Date | Str | Bytes = Bytes | Str | Uuid | Date
464+
let typ = BasicValueType::Union(union);
466465
let expected = "Union[Bytes | Str | Uuid | Date]";
467466

468467
assert_eq!(typ.to_string(), expected);
469468
}
470469

471470
#[test]
472-
fn test_union_fmt_nested_auto_unpack() {
471+
fn test_union_fmt_nested() {
473472
// Uuid | Date | (Date | OffsetDateTime | Time | Bytes | (Bytes | Uuid | Time)) | Str |
474473
// Bytes = Bytes | Str | Uuid | Date | Time | OffsetDateTime
475-
let typ = BasicValueType::Union(vec![
474+
let typ = BasicValueType::Union(UnionType::coerce_from(vec![
476475
BasicValueType::Uuid,
477476
BasicValueType::Date,
478-
BasicValueType::Union(vec![
477+
BasicValueType::Union(UnionType::coerce_from(vec![
479478
BasicValueType::Date,
480479
BasicValueType::OffsetDateTime,
481480
BasicValueType::Time,
482481
BasicValueType::Bytes,
483-
BasicValueType::Union(vec![
482+
BasicValueType::Union(UnionType::coerce_from(vec![
484483
BasicValueType::Bytes,
485484
BasicValueType::Uuid,
486485
BasicValueType::Time,
487-
].into()),
488-
].into()),
486+
])),
487+
])),
489488
BasicValueType::Str,
490489
BasicValueType::Bytes,
491-
].into());
490+
]));
492491
let expected = "Union[Bytes | Str | Uuid | Date | Time | OffsetDateTime]";
493492

494493
assert_eq!(typ.to_string(), expected);

src/utils/union.rs

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,23 @@ use std::{str::FromStr, sync::Arc};
22

33
use crate::{base::{schema::BasicValueType, value::BasicValue}, prelude::*};
44

5+
#[derive(Debug, Clone)]
6+
pub enum UnionParseResult {
7+
Union(UnionType),
8+
Single(BasicValueType),
9+
}
10+
511
/// Union type helper storing an auto-sorted set of types excluding `Union`
6-
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
12+
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
713
pub struct UnionType {
814
types: BTreeSet<BasicValueType>,
915
}
1016

1117
impl UnionType {
18+
fn new() -> Self {
19+
Self { types: BTreeSet::new() }
20+
}
21+
1222
pub fn types(&self) -> &BTreeSet<BasicValueType> {
1323
&self.types
1424
}
@@ -31,32 +41,56 @@ impl UnionType {
3141
}
3242
}
3343

34-
pub fn unpack(self) -> Self {
35-
self.types.into()
36-
}
37-
}
44+
fn resolve(self) -> Result<UnionParseResult> {
45+
if self.types().is_empty() {
46+
anyhow::bail!("The union is empty");
47+
}
3848

39-
impl From<Vec<BasicValueType>> for UnionType {
40-
fn from(value: Vec<BasicValueType>) -> Self {
41-
let mut union = Self::default();
49+
if self.types().len() == 1 {
50+
let mut type_tree: BTreeSet<BasicValueType> = self.into();
51+
return Ok(UnionParseResult::Single(type_tree.pop_first().unwrap()));
52+
}
4253

43-
for typ in value {
54+
Ok(UnionParseResult::Union(self))
55+
}
56+
57+
/// Move an iterable and parse it into a union type.
58+
/// If there is only one single unique type, it returns a single `BasicValueType`.
59+
pub fn parse_from<T>(
60+
input: impl IntoIterator<Item = BasicValueType, IntoIter = T>,
61+
) -> Result<UnionParseResult>
62+
where
63+
T: Iterator<Item = BasicValueType>,
64+
{
65+
let mut union = Self::new();
66+
67+
for typ in input {
4468
union.insert(typ);
4569
}
4670

47-
union
71+
union.resolve()
4872
}
49-
}
50-
51-
impl From<BTreeSet<BasicValueType>> for UnionType {
52-
fn from(value: BTreeSet<BasicValueType>) -> Self {
53-
let mut union = Self::default();
5473

55-
for typ in value {
56-
union.insert(typ);
74+
/// Assume the input already contains multiple unique types, panic otherwise.
75+
///
76+
/// This method is meant for streamlining the code for test cases.
77+
/// Use `parse_from()` instead unless you know the input.
78+
pub fn coerce_from<T>(
79+
input: impl IntoIterator<Item = BasicValueType, IntoIter = T>,
80+
) -> Self
81+
where
82+
T: Iterator<Item = BasicValueType>,
83+
{
84+
match Self::parse_from(input) {
85+
Ok(UnionParseResult::Union(union)) => union,
86+
_ => panic!("Do not use `coerce_from()` for basic type lists that can possibly be one type."),
5787
}
88+
}
89+
}
5890

59-
union
91+
impl Into<BTreeSet<BasicValueType>> for UnionType {
92+
fn into(self) -> BTreeSet<BasicValueType> {
93+
self.types
6094
}
6195
}
6296

0 commit comments

Comments
 (0)