Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 36 additions & 45 deletions src/daft-functions-list/src/kernels.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
#![allow(deprecated, reason = "arrow2 migration")]

use std::{iter::repeat_n, sync::Arc};

use arrow::array::{BooleanBufferBuilder, make_comparator};
Expand Down Expand Up @@ -140,14 +138,14 @@ impl ListArrayExtension for ListArray {

let keys = self.flat_child.filter(&include_mask)?;

let keys = Series::try_from_field_and_arrow_array(
Field::new("key", key_type.clone()),
keys.to_arrow2(),
let keys = Series::from_arrow(
Arc::new(Field::new("key", key_type.clone())),
keys.to_arrow()?,
)?;

let values = Series::try_from_field_and_arrow_array(
Field::new("value", count_type.clone()),
values.to_arrow2(),
let values = Series::from_arrow(
Arc::new(Field::new("value", count_type.clone())),
values.to_arrow()?,
)?;

let struct_type = DataType::Struct(vec![
Expand Down Expand Up @@ -257,23 +255,21 @@ impl ListArrayExtension for ListArray {
} else {
assert_eq!(delimiter.len(), self.len());

Box::new(delimiter.as_arrow2().iter())
Box::new(delimiter.into_iter())
};
let self_iter = (0..self.len()).map(|i| self.get(i));

let result = self_iter
let result: Utf8Array = self_iter
.zip(delimiter_iter)
.map(|(list_element, delimiter)| {
join_arrow_list_of_utf8s(
list_element.as_ref().map(|l| l.utf8().unwrap().data()),
join_list_of_utf8s(
list_element.as_ref().map(|l| l.utf8().unwrap()),
delimiter.unwrap_or(""),
)
});
})
.collect();

Ok(Utf8Array::from((
self.name(),
Box::new(daft_arrow::array::Utf8Array::from_iter(result)),
)))
Ok(result.rename(self.name()))
}

fn get_children(&self, idx: &Int64Array, default: &Series) -> DaftResult<Series> {
Expand Down Expand Up @@ -333,8 +329,10 @@ impl ListArrayExtension for ListArray {
)?
}
} else {
let desc_iter = desc.as_arrow2().values_iter();
let nulls_first_iter = nulls_first.as_arrow2().values_iter();
let desc_arrow = desc.as_arrow()?;
let nulls_first_arrow = nulls_first.as_arrow()?;
let desc_iter = desc_arrow.values().iter();
let nulls_first_iter = nulls_first_arrow.values().iter();
if let Some(nulls) = self.nulls() {
list_sort_helper(
&self.flat_child,
Expand Down Expand Up @@ -404,9 +402,9 @@ impl ListArrayExtension for ListArray {
let mut all_true = true;
let bool_slice = slice.bool()?;
let bool_nulls = bool_slice.nulls();
let bool_data = bool_slice.as_arrow2().values();
let bool_arrow = bool_slice.as_arrow()?;
for j in 0..bool_slice.len() {
if bool_nulls.is_none_or(|v| v.is_valid(j)) && !bool_data.get_bit(j) {
if bool_nulls.is_none_or(|v| v.is_valid(j)) && !bool_arrow.value(j) {
all_true = false;
break;
}
Expand Down Expand Up @@ -455,9 +453,9 @@ impl ListArrayExtension for ListArray {
let mut any_true = false;
let bool_slice = slice.bool()?;
let bool_nulls = bool_slice.nulls();
let bool_data = bool_slice.as_arrow2().values();
let bool_arrow = bool_slice.as_arrow()?;
for j in 0..bool_slice.len() {
if bool_nulls.is_none_or(|v| v.is_valid(j)) && bool_data.get_bit(j) {
if bool_nulls.is_none_or(|v| v.is_valid(j)) && bool_arrow.value(j) {
any_true = true;
break;
}
Expand All @@ -467,8 +465,7 @@ impl ListArrayExtension for ListArray {
}

let null_buffer = daft_arrow::buffer::NullBuffer::from_iter(result_nulls.iter().copied());
let values = daft_arrow::bitmap::Bitmap::from_iter(result.iter().copied());
BooleanArray::from_iter_values(values)
BooleanArray::from_iter_values(result)
.rename(self.name())
.with_nulls(Some(null_buffer))
}
Expand Down Expand Up @@ -543,23 +540,21 @@ impl ListArrayExtension for FixedSizeListArray {
Box::new(repeat_n(delimiter.get(0), self.len()))
} else {
assert_eq!(delimiter.len(), self.len());
Box::new(delimiter.as_arrow2().iter())
Box::new(delimiter.into_iter())
};
let self_iter = (0..self.len()).map(|i| self.get(i));

let result = self_iter
let result: Utf8Array = self_iter
.zip(delimiter_iter)
.map(|(list_element, delimiter)| {
join_arrow_list_of_utf8s(
list_element.as_ref().map(|l| l.utf8().unwrap().data()),
join_list_of_utf8s(
list_element.as_ref().map(|l| l.utf8().unwrap()),
delimiter.unwrap_or(""),
)
});
})
.collect();

Ok(Utf8Array::from((
self.name(),
Box::new(daft_arrow::array::Utf8Array::from_iter(result)),
)))
Ok(result.rename(self.name()))
}

fn get_children(&self, idx: &Int64Array, default: &Series) -> DaftResult<Series> {
Expand Down Expand Up @@ -654,8 +649,10 @@ impl ListArrayExtension for FixedSizeListArray {
)?
}
} else {
let desc_iter = desc.as_arrow2().values_iter();
let nulls_first_iter = nulls_first.as_arrow2().values_iter();
let desc_arrow = desc.as_arrow()?;
let nulls_first_arrow = nulls_first.as_arrow()?;
let desc_iter = desc_arrow.values().iter();
let nulls_first_iter = nulls_first_arrow.values().iter();
if let Some(nulls) = self.nulls() {
list_sort_helper_fixed_size(
&self.flat_child,
Expand Down Expand Up @@ -685,17 +682,11 @@ impl ListArrayExtension for FixedSizeListArray {
}
}

fn join_arrow_list_of_utf8s(
list_element: Option<&dyn daft_arrow::array::Array>,
delimiter_str: &str,
) -> Option<String> {
fn join_list_of_utf8s(list_element: Option<&Utf8Array>, delimiter_str: &str) -> Option<String> {
list_element
.map(|list_element| {
list_element
.as_any()
.downcast_ref::<daft_arrow::array::Utf8Array<i64>>()
.unwrap()
.iter()
.into_iter()
.fold(String::new(), |acc, str_item| {
acc + str_item.unwrap_or("") + delimiter_str
})
Expand All @@ -719,7 +710,7 @@ fn create_iter<'a>(arr: &'a Int64Array, len: usize) -> Box<dyn Iterator<Item = i
1 => Box::new(repeat_n(arr.get(0).unwrap(), len)),
arr_len => {
assert_eq!(arr_len, len);
Box::new(arr.as_arrow2().iter().map(|x| *x.unwrap()))
Box::new(arr.into_iter().map(|x| *x.unwrap()))
}
}
}
Expand Down
Loading