diff --git a/src/daft-core/src/array/extension_array.rs b/src/daft-core/src/array/extension_array.rs new file mode 100644 index 0000000000..b9a5d76d26 --- /dev/null +++ b/src/daft-core/src/array/extension_array.rs @@ -0,0 +1,154 @@ +use std::sync::Arc; + +use arrow::{array::ArrayRef, buffer::NullBuffer}; +use common_error::DaftResult; +use daft_schema::{dtype::DataType, field::Field}; + +use crate::{datatypes::DaftArrayType, series::Series}; + +#[derive(Clone, Debug)] +pub struct ExtensionArray { + field: Arc, + /// Extension type name (e.g. "geoarrow.point") + extension_name: Arc, + /// Extension metadata (e.g. '{"crs": "WGS84"}') + metadata: Option>, + /// The underlying storage data + pub physical: Series, +} + +impl ExtensionArray { + pub fn new(field: Arc, physical: Series) -> Self { + let DataType::Extension(ext_name, _, ext_metadata) = &field.dtype else { + panic!( + "ExtensionArray field must have Extension dtype, got {}", + field.dtype + ); + }; + Self { + extension_name: Arc::from(ext_name.as_str()), + metadata: ext_metadata.as_deref().map(Arc::from), + field, + physical, + } + } + + pub fn name(&self) -> &str { + self.field.name.as_ref() + } + + pub fn data_type(&self) -> &DataType { + &self.field.dtype + } + + pub fn extension_name(&self) -> &str { + &self.extension_name + } + + pub fn extension_metadata(&self) -> Option<&str> { + self.metadata.as_deref() + } + + pub fn field(&self) -> &Field { + &self.field + } + + pub fn len(&self) -> usize { + self.physical.len() + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub fn rename(&self, name: &str) -> Self { + Self { + field: Arc::new(Field::new(name, self.field.dtype.clone())), + extension_name: self.extension_name.clone(), + metadata: self.metadata.clone(), + physical: self.physical.rename(name), + } + } + + /// Replace the underlying physical `Series` of this `ExtensionArray`. + pub fn with_physical(&self, physical: Series) -> Self { + Self { + field: self.field.clone(), + extension_name: self.extension_name.clone(), + metadata: self.metadata.clone(), + physical, + } + } + + pub fn nulls(&self) -> Option<&NullBuffer> { + self.physical.inner.nulls() + } + + pub fn to_arrow(&self) -> DaftResult { + let arr = self.physical.to_arrow()?; + let target_field = self.field.to_arrow()?; + if arr.data_type() != target_field.data_type() { + Ok(arrow::compute::cast(&arr, target_field.data_type())?) + } else { + Ok(arr) + } + } + + pub fn slice(&self, start: usize, end: usize) -> DaftResult { + Ok(self.with_physical(self.physical.slice(start, end)?)) + } + + pub fn concat(arrays: &[&Self]) -> DaftResult { + if arrays.is_empty() { + return Err(common_error::DaftError::ValueError( + "Cannot concat empty list of ExtensionArrays".to_string(), + )); + } + let first = arrays[0]; + let physical_arrays: Vec<&Series> = arrays.iter().map(|a| &a.physical).collect(); + let physical = Series::concat(&physical_arrays)?; + Ok(first.with_physical(physical)) + } +} + +impl DaftArrayType for ExtensionArray { + fn data_type(&self) -> &DataType { + &self.field.dtype + } +} + +impl crate::array::ops::from_arrow::FromArrow for ExtensionArray { + fn from_arrow>( + field: F, + arrow_arr: ArrayRef, + ) -> DaftResult { + let field: daft_schema::field::FieldRef = field.into(); + let DataType::Extension(_, storage_type, _) = &field.dtype else { + return Err(common_error::DaftError::TypeError(format!( + "Expected Extension dtype for ExtensionArray, got {}", + field.dtype + ))); + }; + let storage_field = Arc::new(Field::new(field.name.as_ref(), *storage_type.clone())); + let physical = Series::from_arrow(storage_field, arrow_arr)?; + Ok(Self::new(field, physical)) + } +} + +impl crate::array::ops::full::FullNull for ExtensionArray { + fn full_null(name: &str, dtype: &DataType, length: usize) -> Self { + let DataType::Extension(_, storage_type, _) = dtype else { + panic!("Expected Extension dtype for ExtensionArray::full_null, got {dtype}"); + }; + let physical = Series::full_null(name, storage_type, length); + Self::new(Arc::new(Field::new(name, dtype.clone())), physical) + } + + fn empty(name: &str, dtype: &DataType) -> Self { + let DataType::Extension(_, storage_type, _) = dtype else { + panic!("Expected Extension dtype for ExtensionArray::empty, got {dtype}"); + }; + let physical = Series::empty(name, storage_type); + Self::new(Arc::new(Field::new(name, dtype.clone())), physical) + } +} diff --git a/src/daft-core/src/array/growable/extension_growable.rs b/src/daft-core/src/array/growable/extension_growable.rs new file mode 100644 index 0000000000..77d083153d --- /dev/null +++ b/src/daft-core/src/array/growable/extension_growable.rs @@ -0,0 +1,54 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use daft_schema::{dtype::DataType, field::Field}; + +use super::Growable; +use crate::{ + array::extension_array::ExtensionArray, + series::{IntoSeries, Series}, +}; + +pub struct ExtensionGrowable<'a> { + name: String, + dtype: DataType, + physical_growable: Box, +} + +impl<'a> ExtensionGrowable<'a> { + pub fn new( + name: &str, + dtype: &DataType, + arrays: Vec<&'a ExtensionArray>, + use_validity: bool, + capacity: usize, + ) -> Self { + let DataType::Extension(_, storage_type, _) = dtype else { + panic!("Expected Extension dtype for ExtensionGrowable, got {dtype}"); + }; + let physical_series: Vec<&Series> = arrays.iter().map(|a| &a.physical).collect(); + let physical_growable = + super::make_growable(name, storage_type, physical_series, use_validity, capacity); + Self { + name: name.to_string(), + dtype: dtype.clone(), + physical_growable, + } + } +} + +impl Growable for ExtensionGrowable<'_> { + fn extend(&mut self, index: usize, start: usize, len: usize) { + self.physical_growable.extend(index, start, len); + } + + fn add_nulls(&mut self, additional: usize) { + self.physical_growable.add_nulls(additional); + } + + fn build(&mut self) -> DaftResult { + let physical = self.physical_growable.build()?; + let field = Arc::new(Field::new(self.name.as_str(), self.dtype.clone())); + Ok(ExtensionArray::new(field, physical).into_series()) + } +} diff --git a/src/daft-core/src/array/growable/mod.rs b/src/daft-core/src/array/growable/mod.rs index b2712c7dfe..c0ffa14908 100644 --- a/src/daft-core/src/array/growable/mod.rs +++ b/src/daft-core/src/array/growable/mod.rs @@ -1,7 +1,10 @@ use common_error::DaftResult; +use extension_growable::ExtensionGrowable; use crate::{ - array::{FixedSizeListArray, ListArray, StructArray, prelude::*}, + array::{ + FixedSizeListArray, ListArray, StructArray, extension_array::ExtensionArray, prelude::*, + }, datatypes::{FileArray, prelude::*}, file::DaftMediaType, series::Series, @@ -10,6 +13,7 @@ use crate::{ mod arrow_growable; mod bitmap_growable; +mod extension_growable; mod fixed_size_list_growable; mod list_growable; mod logical_growable; @@ -170,10 +174,19 @@ impl_growable_array!( arrow_growable::ArrowGrowable<'a, FixedSizeBinaryType> ); impl_growable_array!(Utf8Array, arrow_growable::ArrowGrowable<'a, Utf8Type>); -impl_growable_array!( - ExtensionArray, - arrow_growable::ArrowGrowable<'a, ExtensionType> -); +impl GrowableArray for ExtensionArray { + type GrowableType<'a> = ExtensionGrowable<'a>; + + fn make_growable<'a>( + name: &str, + dtype: &DataType, + arrays: Vec<&'a Self>, + use_validity: bool, + capacity: usize, + ) -> Self::GrowableType<'a> { + ExtensionGrowable::new(name, dtype, arrays, use_validity, capacity) + } +} impl_growable_array!( FixedSizeListArray, fixed_size_list_growable::FixedSizeListGrowable<'a> diff --git a/src/daft-core/src/array/mod.rs b/src/daft-core/src/array/mod.rs index a8b2f55825..fabd7d09fc 100644 --- a/src/daft-core/src/array/mod.rs +++ b/src/daft-core/src/array/mod.rs @@ -1,3 +1,4 @@ +pub mod extension_array; pub mod file_array; mod fixed_size_list_array; pub mod from; diff --git a/src/daft-core/src/array/ops/broadcast.rs b/src/daft-core/src/array/ops/broadcast.rs index 8cf8488a3c..1fa1bdb380 100644 --- a/src/daft-core/src/array/ops/broadcast.rs +++ b/src/daft-core/src/array/ops/broadcast.rs @@ -200,9 +200,9 @@ macro_rules! impl_broadcast_via_concat { impl_broadcast_via_concat!(FixedSizeListArray); impl_broadcast_via_concat!(ListArray); -impl_broadcast_via_concat!(ExtensionArray); #[cfg(feature = "python")] impl_broadcast_via_concat!(PythonArray); +impl_broadcast_via_concat!(ExtensionArray); impl Broadcastable for StructArray { fn broadcast(&self, num: usize) -> DaftResult { diff --git a/src/daft-core/src/array/ops/get.rs b/src/daft-core/src/array/ops/get.rs index 4d5530fe71..45f5c39c3d 100644 --- a/src/daft-core/src/array/ops/get.rs +++ b/src/daft-core/src/array/ops/get.rs @@ -194,10 +194,9 @@ impl ExtensionArray { idx, self.len() ); - let is_valid = self.is_valid(idx); - if is_valid { + if self.physical.is_valid(idx) { let scalar = self.slice(idx, idx + 1).unwrap(); - let scalar = Scalar::new(scalar.to_arrow()); + let scalar = Scalar::new(scalar.to_arrow().unwrap()); Some(scalar) } else { None diff --git a/src/daft-core/src/array/ops/get_lit.rs b/src/daft-core/src/array/ops/get_lit.rs index 5fd7fc4f5b..631a4f572b 100644 --- a/src/daft-core/src/array/ops/get_lit.rs +++ b/src/daft-core/src/array/ops/get_lit.rs @@ -180,7 +180,7 @@ impl ExtensionArray { self.len() ); - if self.is_valid(idx) { + if self.physical.is_valid(idx) { Literal::Extension(self.slice(idx, idx + 1).unwrap().into_series()) } else { Literal::Null diff --git a/src/daft-core/src/array/ops/null.rs b/src/daft-core/src/array/ops/null.rs index f8841e2fba..0106373601 100644 --- a/src/daft-core/src/array/ops/null.rs +++ b/src/daft-core/src/array/ops/null.rs @@ -200,3 +200,12 @@ impl StructArray { } } } +impl ExtensionArray { + #[inline] + pub fn is_valid(&self, idx: usize) -> bool { + match self.nulls() { + None => true, + Some(nulls) => nulls.is_valid(idx), + } + } +} diff --git a/src/daft-core/src/array/serdes.rs b/src/daft-core/src/array/serdes.rs index 3e5e12b1c7..c0c994420f 100644 --- a/src/daft-core/src/array/serdes.rs +++ b/src/daft-core/src/array/serdes.rs @@ -7,11 +7,11 @@ use super::{DataArray, FixedSizeListArray, ListArray, StructArray}; use crate::prelude::PythonArray; use crate::{ datatypes::{ - BinaryArray, BooleanArray, DaftLogicalType, DaftPrimitiveType, DataType, ExtensionArray, - Field, FixedSizeBinaryArray, Int64Array, IntervalArray, NullArray, Utf8Array, + BinaryArray, BooleanArray, DaftLogicalType, DaftPrimitiveType, ExtensionArray, + FixedSizeBinaryArray, Int64Array, IntervalArray, NullArray, Utf8Array, logical::LogicalArray, }, - series::{IntoSeries, Series}, + series::IntoSeries, }; pub struct IterSer @@ -102,15 +102,7 @@ impl serde::Serialize for ExtensionArray { { let mut s = serializer.serialize_map(Some(2))?; s.serialize_entry("field", self.field())?; - let DataType::Extension(_, inner, _) = self.data_type() else { - panic!("Expected Extension Type!") - }; - let values = Series::from_arrow( - Field::new("physical", inner.as_ref().clone()), - self.to_arrow(), - ) - .unwrap(); - s.serialize_entry("values", &values)?; + s.serialize_entry("values", &self.physical)?; s.end() } } diff --git a/src/daft-core/src/datatypes/matching.rs b/src/daft-core/src/datatypes/matching.rs index a4a452da24..4d2b342d01 100644 --- a/src/daft-core/src/datatypes/matching.rs +++ b/src/daft-core/src/datatypes/matching.rs @@ -91,7 +91,6 @@ macro_rules! with_match_physical_daft_types { DataType::FixedSizeList(_, _) => __with_ty__! { FixedSizeListType }, DataType::List(_) => __with_ty__! { ListType }, DataType::Struct(_) => __with_ty__! { StructType }, - DataType::Extension(_, _, _) => __with_ty__! { ExtensionType }, DataType::Interval => __with_ty__! { IntervalType }, #[cfg(feature = "python")] DataType::Python => __with_ty__! { PythonType }, @@ -129,7 +128,6 @@ macro_rules! with_match_arrow_daft_types { DataType::Float64 => __with_ty__! { Float64Type }, DataType::Decimal128(..) => __with_ty__! { Decimal128Type }, DataType::List(_) => __with_ty__! { ListType }, - DataType::Extension(_, _, _) => __with_ty__! { ExtensionType }, DataType::Utf8 => __with_ty__! { Utf8Type }, _ => panic!("{:?} not implemented", $key_type) diff --git a/src/daft-core/src/datatypes/mod.rs b/src/daft-core/src/datatypes/mod.rs index 1833268c50..542ce3bc56 100644 --- a/src/daft-core/src/datatypes/mod.rs +++ b/src/daft-core/src/datatypes/mod.rs @@ -235,7 +235,18 @@ impl_daft_arrow_datatype!(Float64Type, Float64); impl_daft_arrow_datatype!(BinaryType, Binary); impl_daft_arrow_datatype!(FixedSizeBinaryType, Unknown); impl_daft_arrow_datatype!(Utf8Type, Utf8); -impl_daft_arrow_datatype!(ExtensionType, Unknown); +// ExtensionType is a logical type backed by a variable physical type (stored as Series). +// It is neither DaftPhysicalType nor DaftArrowBackedType. +#[derive(Clone, Debug)] +pub struct ExtensionType {} + +impl DaftDataType for ExtensionType { + #[inline] + fn get_dtype() -> DataType { + DataType::Unknown + } + type ArrayType = crate::array::extension_array::ExtensionArray; +} impl_daft_arrow_datatype!(Decimal128Type, Unknown); impl_nested_datatype!(FixedSizeListType, FixedSizeListArray); @@ -461,6 +472,6 @@ pub type Float64Array = DataArray; pub type BinaryArray = DataArray; pub type FixedSizeBinaryArray = DataArray; pub type Utf8Array = DataArray; -pub type ExtensionArray = DataArray; +pub use crate::array::extension_array::ExtensionArray; pub type IntervalArray = DataArray; pub type Decimal128Array = DataArray; diff --git a/src/daft-core/src/series/array_impl/data_array.rs b/src/daft-core/src/series/array_impl/data_array.rs index 1d53ea4231..f25bb93203 100644 --- a/src/daft-core/src/series/array_impl/data_array.rs +++ b/src/daft-core/src/series/array_impl/data_array.rs @@ -184,16 +184,3 @@ impl_series_like_for_data_array!(Float64Array); impl_series_like_for_data_array!(Utf8Array); impl_series_like_for_data_array!(IntervalArray); impl_series_like_for_data_array!(Decimal128Array); -impl_series_like_for_data_array!(ExtensionArray, { - fn to_arrow(&self) -> DaftResult { - let arr: ArrayRef = self.0.to_arrow(); - // Reverse the coercion applied during from_arrow (e.g. LargeBinary → Binary) - // so callers see the original storage type. - let target_field = self.0.field.to_arrow()?; - if arr.data_type() != target_field.data_type() { - Ok(arrow::compute::cast(&arr, target_field.data_type())?) - } else { - Ok(arr) - } - } -}); diff --git a/src/daft-core/src/series/array_impl/extension_array.rs b/src/daft-core/src/series/array_impl/extension_array.rs new file mode 100644 index 0000000000..668792c119 --- /dev/null +++ b/src/daft-core/src/series/array_impl/extension_array.rs @@ -0,0 +1,163 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use daft_schema::field::Field; + +use super::{ArrayWrapper, IntoSeries}; +use crate::{ + array::{extension_array::ExtensionArray, ops::GroupIndices}, + datatypes::{BooleanArray, DataType}, + lit::Literal, + prelude::UInt64Array, + series::{Series, series_like::SeriesLike}, +}; + +impl IntoSeries for ExtensionArray { + fn into_series(self) -> Series { + Series { + inner: Arc::new(ArrayWrapper(self)), + } + } +} + +impl SeriesLike for ArrayWrapper { + fn into_series(&self) -> Series { + self.0.clone().into_series() + } + + fn to_arrow(&self) -> DaftResult { + self.0.to_arrow() + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn with_nulls(&self, nulls: Option) -> DaftResult { + Ok(self + .0 + .with_physical(self.0.physical.with_nulls(nulls)?) + .into_series()) + } + + fn nulls(&self) -> Option<&arrow::buffer::NullBuffer> { + self.0.nulls() + } + + fn min(&self, groups: Option<&GroupIndices>) -> DaftResult { + self.0.physical.min(groups) + } + + fn max(&self, groups: Option<&GroupIndices>) -> DaftResult { + self.0.physical.max(groups) + } + + fn agg_list(&self, groups: Option<&GroupIndices>) -> DaftResult { + self.0.physical.agg_list(groups) + } + + fn agg_set(&self, groups: Option<&GroupIndices>) -> DaftResult { + self.0.physical.agg_set(groups) + } + + fn broadcast(&self, num: usize) -> DaftResult { + Ok(self + .0 + .with_physical(self.0.physical.broadcast(num)?) + .into_series()) + } + + fn cast(&self, datatype: &DataType) -> DaftResult { + match datatype { + DataType::Extension(_, storage_type, _) => { + let casted_storage = self.0.physical.cast(storage_type)?; + let field = Arc::new(Field::new(self.0.name(), datatype.clone())); + Ok(ExtensionArray::new(field, casted_storage).into_series()) + } + _ => self.0.physical.cast(datatype), + } + } + + fn filter(&self, mask: &BooleanArray) -> DaftResult { + Ok(self + .0 + .with_physical(self.0.physical.filter(mask)?) + .into_series()) + } + + fn if_else(&self, other: &Series, predicate: &Series) -> DaftResult { + let other_physical = match other.downcast::() { + Ok(other_ext) => &other_ext.physical, + Err(_) => other, + }; + Ok(self + .0 + .with_physical(self.0.physical.if_else(other_physical, predicate)?) + .into_series()) + } + + fn data_type(&self) -> &DataType { + self.0.data_type() + } + + fn field(&self) -> &Field { + self.0.field() + } + + fn len(&self) -> usize { + self.0.len() + } + + fn name(&self) -> &str { + self.0.name() + } + + fn rename(&self, name: &str) -> Series { + self.0.rename(name).into_series() + } + + fn size_bytes(&self) -> usize { + self.0.physical.size_bytes() + } + + fn is_null(&self) -> DaftResult { + self.0.physical.is_null() + } + + fn not_null(&self) -> DaftResult { + self.0.physical.not_null() + } + + fn sort(&self, descending: bool, nulls_first: bool) -> DaftResult { + Ok(self + .0 + .with_physical(self.0.physical.sort(descending, nulls_first)?) + .into_series()) + } + + fn head(&self, num: usize) -> DaftResult { + Ok(self + .0 + .with_physical(self.0.physical.head(num)?) + .into_series()) + } + + fn slice(&self, start: usize, end: usize) -> DaftResult { + Ok(self.0.slice(start, end)?.into_series()) + } + + fn take(&self, idx: &UInt64Array) -> DaftResult { + Ok(self + .0 + .with_physical(self.0.physical.take(idx)?) + .into_series()) + } + + fn str_value(&self, idx: usize) -> DaftResult { + self.0.physical.inner.str_value(idx) + } + + fn get_lit(&self, idx: usize) -> Literal { + self.0.get_lit(idx) + } +} diff --git a/src/daft-core/src/series/array_impl/mod.rs b/src/daft-core/src/series/array_impl/mod.rs index 1308d146bf..de3ac7ed62 100644 --- a/src/daft-core/src/series/array_impl/mod.rs +++ b/src/daft-core/src/series/array_impl/mod.rs @@ -1,4 +1,5 @@ pub mod data_array; +pub mod extension_array; pub mod logical_array; pub mod nested_array; #[cfg(feature = "python")] diff --git a/src/daft-core/src/series/serdes.rs b/src/daft-core/src/series/serdes.rs index be4c0e5bd6..f089df2760 100644 --- a/src/daft-core/src/series/serdes.rs +++ b/src/daft-core/src/series/serdes.rs @@ -153,11 +153,7 @@ impl<'d> serde::Deserialize<'d> for Series { .into_series()), DataType::Extension(..) => { let physical = map.next_value::()?; - let physical = physical.to_arrow().unwrap(); - - Ok(ExtensionArray::from_arrow(Arc::new(field), physical) - .unwrap() - .into_series()) + Ok(ExtensionArray::new(Arc::new(field), physical).into_series()) } DataType::Map { .. } => { let physical = map.next_value::()?; diff --git a/src/daft-ext-core/src/function.rs b/src/daft-ext-core/src/function.rs index b6e1a9ce7c..3a89f5684f 100644 --- a/src/daft-ext-core/src/function.rs +++ b/src/daft-ext-core/src/function.rs @@ -140,7 +140,7 @@ mod tests { Ok(export_schema(&Schema::new(vec![field]))) } - fn call(&self, mut args: Vec) -> DaftResult { + fn call(&self, args: Vec) -> DaftResult { let input_array = import_array(args.into_iter().next().unwrap()); let input = input_array .as_any() diff --git a/src/daft-schema/src/dtype.rs b/src/daft-schema/src/dtype.rs index 83e54331ca..396012b6a8 100644 --- a/src/daft-schema/src/dtype.rs +++ b/src/daft-schema/src/dtype.rs @@ -379,6 +379,7 @@ impl DataType { Field::new("indices", List(Box::new(minimal_indices_dtype))) }, ]), + Extension(_, storage, _) => storage.to_physical(), File(..) => Struct(vec![ Field::new("url", Utf8), Field::new("io_config", Binary), @@ -798,6 +799,7 @@ impl DataType { | Self::SparseTensor(..) | Self::FixedShapeSparseTensor(..) | Self::Map { .. } + | Self::Extension(..) | Self::File(..) ) }