Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1361,7 +1361,7 @@ class PyExpr:
def skew(self) -> PyExpr: ...
def agg_list(self) -> PyExpr: ...
def agg_set(self) -> PyExpr: ...
def agg_concat(self) -> PyExpr: ...
def agg_concat(self, delimiter: str | None = None) -> PyExpr: ...
def over(self, window_spec: WindowSpec) -> PyExpr: ...
def offset(self, offset: int, default: PyExpr | None = None) -> PyExpr: ...
def __add__(self, other: PyExpr) -> PyExpr: ...
Expand Down
17 changes: 10 additions & 7 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3674,13 +3674,16 @@ def agg_set(self, *cols: ColumnInputType) -> "DataFrame":
return self._apply_agg_fn(Expression.list_agg_distinct, cols)

@DataframePublicAPI
def agg_concat(self, *cols: ColumnInputType) -> "DataFrame":
"""Performs a global list concatenation agg on the DataFrame.
def agg_concat(self, *cols: ColumnInputType, delimiter: str | None = None) -> "DataFrame":
"""Performs a global concatenation agg on the DataFrame.

Args:
*cols (Union[str, Expression]): columns that are lists to concatenate
*cols (Union[str, Expression]): columns that are lists or strings to concatenate
delimiter: Optional delimiter to insert between concatenated string values. Only supported for string
columns.

Returns:
DataFrame: Globally aggregated list. Should be a single row.
DataFrame: Globally aggregated list or string. Should be a single row.

Examples:
>>> import daft
Expand All @@ -3698,7 +3701,7 @@ def agg_concat(self, *cols: ColumnInputType) -> "DataFrame":
<BLANKLINE>
(Showing first 1 of 1 rows)
"""
return self._apply_agg_fn(Expression.string_agg, cols)
return self._apply_agg_fn(lambda expr: Expression.string_agg(expr, delimiter=delimiter), cols)

@DataframePublicAPI
def agg(self, *to_agg: Expression | Iterable[Expression]) -> "DataFrame":
Expand Down Expand Up @@ -4898,13 +4901,13 @@ def list_agg_distinct(self, *cols: ColumnInputType) -> DataFrame:
"""
return self.df._apply_agg_fn(Expression.list_agg_distinct, cols, self.group_by)

def string_agg(self, *cols: ColumnInputType) -> DataFrame:
def string_agg(self, *cols: ColumnInputType, delimiter: str | None = None) -> DataFrame:
"""Performs grouped string concat on this GroupedDataFrame.

Returns:
DataFrame: DataFrame with grouped string concatenated per column.
"""
return self.df._apply_agg_fn(Expression.string_agg, cols, self.group_by)
return self.df._apply_agg_fn(lambda expr: Expression.string_agg(expr, delimiter=delimiter), cols, self.group_by)

def agg(self, *to_agg: Expression | Iterable[Expression]) -> DataFrame:
"""Perform aggregations on this GroupedDataFrame. Allows for mixed aggregations.
Expand Down
7 changes: 5 additions & 2 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,15 +1067,18 @@ def list_agg_distinct(self) -> Expression:

return list_agg_distinct(self)

def string_agg(self) -> Expression:
def string_agg(self, delimiter: str | None = None) -> Expression:
"""Aggregates the values in the expression into a single string by concatenating them.

Args:
delimiter: Optional delimiter to insert between concatenated values. Only supported for string columns.

Tip: See Also
[`daft.functions.string_agg`](https://docs.daft.ai/en/stable/api/functions/string_agg/)
"""
from daft.functions import string_agg

return string_agg(self)
return string_agg(self, delimiter=delimiter)

def apply(self, func: Callable[..., Any], return_dtype: DataTypeLike) -> Expression:
"""Apply a function on each value in a given expression.
Expand Down
10 changes: 7 additions & 3 deletions daft/functions/agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,10 @@ def list_agg_distinct(expr: Expression) -> Expression:
return Expression._from_pyexpr(expr._expr.agg_set())


def string_agg(expr: Expression) -> Expression:
"""Aggregates the values in the expression into a single string by concatenating them."""
return Expression._from_pyexpr(expr._expr.agg_concat())
def string_agg(expr: Expression, delimiter: str | None = None) -> Expression:
"""Aggregates the values in the expression into a single string by concatenating them.

Args:
delimiter: Optional delimiter to insert between concatenated values. Only supported for string columns.
"""
return Expression._from_pyexpr(expr._expr.agg_concat(delimiter))
92 changes: 69 additions & 23 deletions src/daft-core/src/series/ops/agg.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use common_error::{DaftError, DaftResult};
use daft_arrow::{array::PrimitiveArray, offset::OffsetsBuffer};
use daft_arrow::offset::OffsetsBuffer;

use crate::{
array::{
Expand All @@ -23,6 +23,19 @@ fn deduplicate_indices(series: &Series) -> DaftResult<Vec<u64>> {
Ok(unique_indices)
}

fn join_with_delimiter<'a, I>(mut iter: I, delimiter: &str) -> Option<String>
where
I: Iterator<Item = &'a str>,
{
let first = iter.next()?;
let mut output = String::from(first);
for value in iter {
output.push_str(delimiter);
output.push_str(value);
}
Some(output)
}

impl Series {
pub fn count(&self, groups: Option<&GroupIndices>, mode: CountMode) -> DaftResult<Self> {
let s = self.as_physical()?;
Expand Down Expand Up @@ -269,18 +282,17 @@ impl Series {
}

pub fn any_value(&self, groups: Option<&GroupIndices>, ignore_nulls: bool) -> DaftResult<Self> {
let indices = match groups {
let indices: UInt64Array = match groups {
Some(groups) => {
if self.data_type().is_null() {
PrimitiveArray::new_null(daft_arrow::datatypes::DataType::UInt64, groups.len())
std::iter::repeat_n(None, groups.len()).collect()
} else if ignore_nulls && let Some(nulls) = self.nulls() {
PrimitiveArray::from_trusted_len_iter(
groups
.iter()
.map(|g| g.iter().find(|i| nulls.is_valid(**i as usize)).copied()),
)
groups
.iter()
.map(|g| g.iter().find(|i| nulls.is_valid(**i as usize)).copied())
.collect()
} else {
PrimitiveArray::from_trusted_len_iter(groups.iter().map(|g| g.first().copied()))
groups.iter().map(|g| g.first().copied()).collect()
}
}
None => {
Expand All @@ -292,11 +304,11 @@ impl Series {
Some(0)
};

PrimitiveArray::from([idx])
std::iter::once(idx).collect()
}
};

self.take(&UInt64Array::from(("", Box::new(indices))))
self.take(&indices)
}

pub fn agg_list(&self, groups: Option<&GroupIndices>) -> DaftResult<Self> {
Expand All @@ -307,25 +319,59 @@ impl Series {
self.inner.agg_set(groups)
}

pub fn agg_concat(&self, groups: Option<&GroupIndices>) -> DaftResult<Self> {
pub fn agg_concat(
&self,
groups: Option<&GroupIndices>,
delimiter: Option<&str>,
) -> DaftResult<Self> {
use crate::array::ops::DaftConcatAggable;

let has_delimiter = delimiter.is_some_and(|d| !d.is_empty());

match self.data_type() {
DataType::List(..) => {
let downcasted = self.downcast::<ListArray>()?;
match groups {
Some(groups) => {
Ok(DaftConcatAggable::grouped_concat(downcasted, groups)?.into_series())
}
None => Ok(DaftConcatAggable::concat(downcasted)?.into_series()),
if has_delimiter {
return Err(DaftError::TypeError(
"concat aggregation delimiter is only supported for Utf8".to_string(),
));
}
let downcasted = self.downcast::<ListArray>()?;
let result = match groups {
Some(groups) => DaftConcatAggable::grouped_concat(downcasted, groups)?,
None => DaftConcatAggable::concat(downcasted)?,
};
Ok(result.into_series())
}
DataType::Utf8 => {
let downcasted = self.downcast::<Utf8Array>()?;
match groups {
Some(groups) => {
Ok(DaftConcatAggable::grouped_concat(downcasted, groups)?.into_series())
}
None => Ok(DaftConcatAggable::concat(downcasted)?.into_series()),
if let Some(delimiter) = delimiter.filter(|d| !d.is_empty()) {
let result: Utf8Array = match groups {
Some(groups) => groups
.iter()
.map(|g| {
let iter = g.iter().filter_map(|&idx| downcasted.get(idx as usize));
join_with_delimiter(iter, delimiter)
})
.collect(),
None => {
let output = if downcasted.is_empty() {
Some(String::new())
} else if downcasted.null_count() == downcasted.len() {
None
} else {
let iter = downcasted.into_iter().flatten();
join_with_delimiter(iter, delimiter)
};
std::iter::once(output).collect()
}
};
Ok(result.rename(downcasted.name()).into_series())
} else {
let result = match groups {
Some(groups) => DaftConcatAggable::grouped_concat(downcasted, groups)?,
None => DaftConcatAggable::concat(downcasted)?,
};
Ok(result.into_series())
}
}
_ => Err(DaftError::TypeError(format!(
Expand Down
4 changes: 3 additions & 1 deletion src/daft-dsl/src/expr/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ pub fn extract_agg_expr(expr: &ExprRef) -> DaftResult<AggExpr> {
}
AggExpr::List(e) => AggExpr::List(Expr::Alias(e, name.clone()).into()),
AggExpr::Set(e) => AggExpr::Set(Expr::Alias(e, name.clone()).into()),
AggExpr::Concat(e) => AggExpr::Concat(Expr::Alias(e, name.clone()).into()),
AggExpr::Concat(e, delimiter) => {
AggExpr::Concat(Expr::Alias(e, name.clone()).into(), delimiter)
}
AggExpr::Skew(e) => AggExpr::Skew(Expr::Alias(e, name.clone()).into()),
AggExpr::MapGroups { func, inputs } => AggExpr::MapGroups {
func,
Expand Down
35 changes: 23 additions & 12 deletions src/daft-dsl/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -439,8 +439,8 @@ pub enum AggExpr {
#[display("set({_0})")]
Set(ExprRef),

#[display("list({_0})")]
Concat(ExprRef),
#[display("concat({_0}, delimiter={_1:?})")]
Concat(ExprRef, Option<String>),

#[display("skew({_0}")]
Skew(ExprRef),
Expand Down Expand Up @@ -545,7 +545,7 @@ impl AggExpr {
Self::AnyValue(_, _) => "Any Value",
Self::List(_) => "List",
Self::Set(_) => "Set",
Self::Concat(_) => "Concat",
Self::Concat(_, _) => "Concat",
Self::Skew(_) => "Skew",
Self::MapGroups { .. } => "Map Groups",
}
Expand All @@ -570,7 +570,7 @@ impl AggExpr {
| Self::AnyValue(expr, _)
| Self::List(expr)
| Self::Set(expr)
| Self::Concat(expr)
| Self::Concat(expr, _)
| Self::Skew(expr) => expr.name(),
Self::MapGroups { func: _, inputs } => inputs.first().unwrap().name(),
}
Expand Down Expand Up @@ -659,9 +659,9 @@ impl AggExpr {
let child_id = _expr.semantic_id(schema);
FieldID::new(format!("{child_id}.local_set()"))
}
Self::Concat(expr) => {
Self::Concat(expr, delimiter) => {
let child_id = expr.semantic_id(schema);
FieldID::new(format!("{child_id}.local_concat()"))
FieldID::new(format!("{child_id}.local_concat(delimiter={delimiter:?})"))
}
Self::Skew(expr) => {
let child_id = expr.semantic_id(schema);
Expand Down Expand Up @@ -690,7 +690,7 @@ impl AggExpr {
| Self::AnyValue(expr, _)
| Self::List(expr)
| Self::Set(expr)
| Self::Concat(expr)
| Self::Concat(expr, _)
| Self::Skew(expr) => vec![expr.clone()],
Self::MapGroups { func: _, inputs } => inputs.clone(),
}
Expand All @@ -717,7 +717,7 @@ impl AggExpr {
Self::AnyValue(_, ignore_nulls) => Self::AnyValue(first_child(), *ignore_nulls),
Self::List(_) => Self::List(first_child()),
Self::Set(_expr) => Self::Set(first_child()),
Self::Concat(_) => Self::Concat(first_child()),
Self::Concat(_, delimiter) => Self::Concat(first_child(), delimiter.clone()),
Self::Skew(_) => Self::Skew(first_child()),
Self::MapGroups { func, inputs: _ } => Self::MapGroups {
func: func.with_new_children(children.clone()),
Expand Down Expand Up @@ -851,10 +851,20 @@ impl AggExpr {
Ok(Field::new(field.name.as_str(), DataType::Boolean))
}

Self::Concat(expr) => {
Self::Concat(expr, delimiter) => {
let field = expr.to_field(schema)?;
let has_delimiter = delimiter.as_deref().is_some_and(|d| !d.is_empty());
match field.dtype {
DataType::List(..) => Ok(field),
DataType::List(..) => {
if has_delimiter {
Err(DaftError::TypeError(format!(
"Concat Agg delimiter is only supported for Utf8 types, got dtype {} for column \"{}\"",
field.dtype, field.name
)))
} else {
Ok(field)
}
}
DataType::Utf8 => Ok(field),
_ => Err(DaftError::TypeError(format!(
"We can only perform Concat Agg on List or Utf8 types, got dtype {} for column \"{}\"",
Expand Down Expand Up @@ -1161,8 +1171,9 @@ impl Expr {
Self::Agg(AggExpr::Set(self)).into()
}

pub fn agg_concat(self: ExprRef) -> ExprRef {
Self::Agg(AggExpr::Concat(self)).into()
pub fn agg_concat(self: ExprRef, delimiter: Option<String>) -> ExprRef {
let delimiter = delimiter.filter(|d| !d.is_empty());
Self::Agg(AggExpr::Concat(self, delimiter)).into()
}

pub fn row_number() -> ExprRef {
Expand Down
5 changes: 3 additions & 2 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -514,8 +514,9 @@ impl PyExpr {
Ok(self.expr.clone().agg_set().into())
}

pub fn agg_concat(&self) -> PyResult<Self> {
Ok(self.expr.clone().agg_concat().into())
#[pyo3(signature = (delimiter=None))]
pub fn agg_concat(&self, delimiter: Option<String>) -> PyResult<Self> {
Ok(self.expr.clone().agg_concat(delimiter).into())
}

pub fn __add__(&self, other: &Self) -> PyResult<Self> {
Expand Down
13 changes: 7 additions & 6 deletions src/daft-local-plan/src/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ pub fn populate_aggregation_stages_bound_with_schema(
}
AggExpr::CountDistinct(expr) => {
let set_agg_col = first_stage!(AggExpr::Set(expr.clone()));
let concat_col = second_stage!(AggExpr::Concat(set_agg_col));
let concat_col = second_stage!(AggExpr::Concat(set_agg_col, None));
final_stage(count_distinct(concat_col));
}
AggExpr::Sum(expr) => {
Expand Down Expand Up @@ -217,17 +217,18 @@ pub fn populate_aggregation_stages_bound_with_schema(
}
AggExpr::List(expr) => {
let list_col = first_stage!(AggExpr::List(expr.clone()));
let concat_col = second_stage!(AggExpr::Concat(list_col));
let concat_col = second_stage!(AggExpr::Concat(list_col, None));
final_stage(concat_col);
}
AggExpr::Set(expr) => {
let set_col = first_stage!(AggExpr::Set(expr.clone()));
let concat_col = second_stage!(AggExpr::Concat(set_col));
let concat_col = second_stage!(AggExpr::Concat(set_col, None));
final_stage(distinct(concat_col));
}
AggExpr::Concat(expr) => {
let concat_col = first_stage!(AggExpr::Concat(expr.clone()));
let global_concat_col = second_stage!(AggExpr::Concat(concat_col));
AggExpr::Concat(expr, delimiter) => {
let concat_col = first_stage!(AggExpr::Concat(expr.clone(), delimiter.clone()));
let global_concat_col =
second_stage!(AggExpr::Concat(concat_col, delimiter.clone()));
final_stage(global_concat_col);
}
AggExpr::Skew(expr) => {
Expand Down
8 changes: 5 additions & 3 deletions src/daft-logical-plan/src/ops/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -619,9 +619,11 @@ fn replace_column_with_semantic_id_aggexpr(
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema)
.map_yes_no(AggExpr::Set, |_| e)
}
AggExpr::Concat(ref child) => {
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema)
.map_yes_no(AggExpr::Concat, |_| e)
AggExpr::Concat(ref child, ref delimiter) => {
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema).map_yes_no(
|transformed_child| AggExpr::Concat(transformed_child, delimiter.clone()),
|_| AggExpr::Concat(child.clone(), delimiter.clone()),
)
}
AggExpr::Skew(ref child) => {
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema)
Expand Down
Loading
Loading