Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1362,7 +1362,7 @@ class PyExpr:
def skew(self) -> PyExpr: ...
def agg_list(self) -> PyExpr: ...
def agg_set(self) -> PyExpr: ...
def agg_concat(self) -> PyExpr: ...
def agg_concat(self, delimiter: str | None = None) -> PyExpr: ...
def over(self, window_spec: WindowSpec) -> PyExpr: ...
def offset(self, offset: int, default: PyExpr | None = None) -> PyExpr: ...
def __add__(self, other: PyExpr) -> PyExpr: ...
Expand Down
17 changes: 10 additions & 7 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3690,13 +3690,16 @@ def agg_set(self, *cols: ColumnInputType) -> "DataFrame":
return self._apply_agg_fn(Expression.list_agg_distinct, cols)

@DataframePublicAPI
def agg_concat(self, *cols: ColumnInputType) -> "DataFrame":
"""Performs a global list concatenation agg on the DataFrame.
def agg_concat(self, *cols: ColumnInputType, delimiter: str | None = None) -> "DataFrame":
"""Performs a global concatenation agg on the DataFrame.

Args:
*cols (Union[str, Expression]): columns that are lists to concatenate
*cols (Union[str, Expression]): columns that are lists or strings to concatenate
delimiter: Optional delimiter to insert between concatenated string values. Only supported for string
columns.

Returns:
DataFrame: Globally aggregated list. Should be a single row.
DataFrame: Globally aggregated list or string. Should be a single row.

Examples:
>>> import daft
Expand All @@ -3714,7 +3717,7 @@ def agg_concat(self, *cols: ColumnInputType) -> "DataFrame":
<BLANKLINE>
(Showing first 1 of 1 rows)
"""
return self._apply_agg_fn(Expression.string_agg, cols)
return self._apply_agg_fn(lambda expr: Expression.string_agg(expr, delimiter=delimiter), cols)

@DataframePublicAPI
def agg(self, *to_agg: Expression | Iterable[Expression]) -> "DataFrame":
Expand Down Expand Up @@ -4914,13 +4917,13 @@ def list_agg_distinct(self, *cols: ColumnInputType) -> DataFrame:
"""
return self.df._apply_agg_fn(Expression.list_agg_distinct, cols, self.group_by)

def string_agg(self, *cols: ColumnInputType) -> DataFrame:
def string_agg(self, *cols: ColumnInputType, delimiter: str | None = None) -> DataFrame:
"""Performs grouped string concat on this GroupedDataFrame.

Returns:
DataFrame: DataFrame with grouped string concatenated per column.
"""
return self.df._apply_agg_fn(Expression.string_agg, cols, self.group_by)
return self.df._apply_agg_fn(lambda expr: Expression.string_agg(expr, delimiter=delimiter), cols, self.group_by)

def agg(self, *to_agg: Expression | Iterable[Expression]) -> DataFrame:
"""Perform aggregations on this GroupedDataFrame. Allows for mixed aggregations.
Expand Down
7 changes: 5 additions & 2 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,15 +1082,18 @@ def list_agg_distinct(self) -> Expression:

return list_agg_distinct(self)

def string_agg(self) -> Expression:
def string_agg(self, delimiter: str | None = None) -> Expression:
"""Aggregates the values in the expression into a single string by concatenating them.

Args:
delimiter: Optional delimiter to insert between concatenated values. Only supported for string columns.

Tip: See Also
[`daft.functions.string_agg`](https://docs.daft.ai/en/stable/api/functions/string_agg/)
"""
from daft.functions import string_agg

return string_agg(self)
return string_agg(self, delimiter=delimiter)

def apply(self, func: Callable[..., Any], return_dtype: DataTypeLike) -> Expression:
"""Apply a function on each value in a given expression.
Expand Down
10 changes: 7 additions & 3 deletions daft/functions/agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,10 @@ def list_agg_distinct(expr: Expression) -> Expression:
return Expression._from_pyexpr(expr._expr.agg_set())


def string_agg(expr: Expression) -> Expression:
"""Aggregates the values in the expression into a single string by concatenating them."""
return Expression._from_pyexpr(expr._expr.agg_concat())
def string_agg(expr: Expression, delimiter: str | None = None) -> Expression:
    """Aggregates the values in the expression into a single string by concatenating them.

    Args:
        expr: Expression whose values are concatenated.
        delimiter: Optional delimiter to insert between concatenated values. Only supported for string columns.

    Returns:
        Expression: A new expression wrapping the concat aggregation of ``expr``.
    """
    # Build the underlying concat aggregation first, then wrap it back into an Expression.
    concat_pyexpr = expr._expr.agg_concat(delimiter)
    return Expression._from_pyexpr(concat_pyexpr)
94 changes: 62 additions & 32 deletions src/daft-core/src/series/ops/agg.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use common_error::{DaftError, DaftResult};
use daft_arrow::{array::PrimitiveArray, offset::OffsetsBuffer};
use daft_arrow::offset::OffsetsBuffer;
use itertools::Itertools;

use crate::{
array::{
ListArray,
growable::make_growable,
ops::{
DaftApproxSketchAggable, DaftCountAggable, DaftHllMergeAggable, DaftMeanAggable,
DaftProductAggable, DaftSetAggable, DaftSkewAggable as _, DaftStddevAggable,
DaftSumAggable, DaftVarianceAggable, GroupIndices,
DaftApproxSketchAggable, DaftBoolAggable, DaftConcatAggable, DaftCountAggable,
DaftHllMergeAggable, DaftMeanAggable, DaftMergeSketchAggable, DaftProductAggable,
DaftSetAggable, DaftSkewAggable, DaftStddevAggable, DaftSumAggable,
DaftVarianceAggable, GroupIndices,
},
},
count_mode::CountMode,
Expand Down Expand Up @@ -96,7 +98,6 @@ impl Series {
}

pub fn product(&self, groups: Option<&GroupIndices>) -> DaftResult<Self> {
use crate::datatypes::try_product_supertype;
match self.data_type() {
// intX -> int64 (in line with numpy)
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
Expand Down Expand Up @@ -187,10 +188,8 @@ impl Series {
}

pub fn merge_sketch(&self, groups: Option<&GroupIndices>) -> DaftResult<Self> {
use crate::{array::ops::DaftMergeSketchAggable, datatypes::DataType::*};

match self.data_type() {
Struct(_) => match groups {
DataType::Struct(_) => match groups {
Some(groups) => Ok(DaftMergeSketchAggable::grouped_merge_sketch(
&self.struct_()?,
groups,
Expand Down Expand Up @@ -290,18 +289,17 @@ impl Series {
}

pub fn any_value(&self, groups: Option<&GroupIndices>, ignore_nulls: bool) -> DaftResult<Self> {
let indices = match groups {
let indices: UInt64Array = match groups {
Some(groups) => {
if self.data_type().is_null() {
PrimitiveArray::new_null(daft_arrow::datatypes::DataType::UInt64, groups.len())
std::iter::repeat_n(None, groups.len()).collect()
} else if ignore_nulls && let Some(nulls) = self.nulls() {
PrimitiveArray::from_trusted_len_iter(
groups
.iter()
.map(|g| g.iter().find(|i| nulls.is_valid(**i as usize)).copied()),
)
groups
.iter()
.map(|g| g.iter().find(|i| nulls.is_valid(**i as usize)).copied())
.collect()
} else {
PrimitiveArray::from_trusted_len_iter(groups.iter().map(|g| g.first().copied()))
groups.iter().map(|g| g.first().copied()).collect()
}
}
None => {
Expand All @@ -313,11 +311,11 @@ impl Series {
Some(0)
};

PrimitiveArray::from([idx])
std::iter::once(idx).collect()
}
};

self.take(&UInt64Array::from(("", Box::new(indices))))
self.take(&indices)
}

pub fn agg_list(&self, groups: Option<&GroupIndices>) -> DaftResult<Self> {
Expand All @@ -328,25 +326,58 @@ impl Series {
self.inner.agg_set(groups)
}

pub fn agg_concat(&self, groups: Option<&GroupIndices>) -> DaftResult<Self> {
use crate::array::ops::DaftConcatAggable;
pub fn agg_concat(
&self,
groups: Option<&GroupIndices>,
delimiter: Option<&str>,
) -> DaftResult<Self> {
match self.data_type() {
DataType::List(..) => {
let downcasted = self.downcast::<ListArray>()?;
match groups {
Some(groups) => {
Ok(DaftConcatAggable::grouped_concat(downcasted, groups)?.into_series())
}
None => Ok(DaftConcatAggable::concat(downcasted)?.into_series()),
let has_delimiter = delimiter.is_some_and(|d| !d.is_empty());
if has_delimiter {
return Err(DaftError::TypeError(
"concat aggregation delimiter is only supported for Utf8".to_string(),
));
}
let downcasted = self.downcast::<ListArray>()?;
let result = match groups {
Some(groups) => DaftConcatAggable::grouped_concat(downcasted, groups)?,
None => DaftConcatAggable::concat(downcasted)?,
};
Ok(result.into_series())
}
DataType::Utf8 => {
let downcasted = self.downcast::<Utf8Array>()?;
match groups {
Some(groups) => {
Ok(DaftConcatAggable::grouped_concat(downcasted, groups)?.into_series())
}
None => Ok(DaftConcatAggable::concat(downcasted)?.into_series()),
if let Some(delimiter) = delimiter.filter(|d| !d.is_empty()) {
let result: Utf8Array = match groups {
Some(groups) => groups
.iter()
.map(|group| {
let values: Vec<_> = group
.iter()
.filter_map(|&idx| downcasted.get(idx as usize))
.collect();
(!values.is_empty()).then(|| values.join(delimiter))
})
.collect(),
None => {
let output = if downcasted.is_empty() {
Some(String::new())
} else if downcasted.null_count() == downcasted.len() {
None
} else {
Some(downcasted.into_iter().flatten().join(delimiter))
};
std::iter::once(output).collect()
}
};
Ok(result.rename(downcasted.name()).into_series())
} else {
let result = match groups {
Some(groups) => DaftConcatAggable::grouped_concat(downcasted, groups)?,
None => DaftConcatAggable::concat(downcasted)?,
};
Ok(result.into_series())
}
}
_ => Err(DaftError::TypeError(format!(
Expand All @@ -357,7 +388,6 @@ impl Series {
}

pub fn bool_and(&self, groups: Option<&GroupIndices>) -> DaftResult<Self> {
use crate::array::ops::DaftBoolAggable;
match self.data_type() {
DataType::Boolean => {
let downcasted = self.bool()?;
Expand Down
4 changes: 3 additions & 1 deletion src/daft-dsl/src/expr/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ pub fn extract_agg_expr(expr: &ExprRef) -> DaftResult<AggExpr> {
}
AggExpr::List(e) => AggExpr::List(Expr::Alias(e, name.clone()).into()),
AggExpr::Set(e) => AggExpr::Set(Expr::Alias(e, name.clone()).into()),
AggExpr::Concat(e) => AggExpr::Concat(Expr::Alias(e, name.clone()).into()),
AggExpr::Concat(e, delimiter) => {
AggExpr::Concat(Expr::Alias(e, name.clone()).into(), delimiter)
}
AggExpr::Skew(e) => AggExpr::Skew(Expr::Alias(e, name.clone()).into()),
AggExpr::MapGroups { func, inputs } => AggExpr::MapGroups {
func,
Expand Down
35 changes: 23 additions & 12 deletions src/daft-dsl/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,8 +442,8 @@ pub enum AggExpr {
#[display("set({_0})")]
Set(ExprRef),

#[display("list({_0})")]
Concat(ExprRef),
#[display("concat({_0}, delimiter={_1:?})")]
Concat(ExprRef, Option<String>),

#[display("skew({_0}")]
Skew(ExprRef),
Expand Down Expand Up @@ -549,7 +549,7 @@ impl AggExpr {
Self::AnyValue(_, _) => "Any Value",
Self::List(_) => "List",
Self::Set(_) => "Set",
Self::Concat(_) => "Concat",
Self::Concat(_, _) => "Concat",
Self::Skew(_) => "Skew",
Self::MapGroups { .. } => "Map Groups",
}
Expand All @@ -575,7 +575,7 @@ impl AggExpr {
| Self::AnyValue(expr, _)
| Self::List(expr)
| Self::Set(expr)
| Self::Concat(expr)
| Self::Concat(expr, _)
| Self::Skew(expr) => expr.name(),
Self::MapGroups { func: _, inputs } => inputs.first().unwrap().name(),
}
Expand Down Expand Up @@ -668,9 +668,9 @@ impl AggExpr {
let child_id = _expr.semantic_id(schema);
FieldID::new(format!("{child_id}.local_set()"))
}
Self::Concat(expr) => {
Self::Concat(expr, delimiter) => {
let child_id = expr.semantic_id(schema);
FieldID::new(format!("{child_id}.local_concat()"))
FieldID::new(format!("{child_id}.local_concat(delimiter={delimiter:?})"))
}
Self::Skew(expr) => {
let child_id = expr.semantic_id(schema);
Expand Down Expand Up @@ -700,7 +700,7 @@ impl AggExpr {
| Self::AnyValue(expr, _)
| Self::List(expr)
| Self::Set(expr)
| Self::Concat(expr)
| Self::Concat(expr, _)
| Self::Skew(expr) => vec![expr.clone()],
Self::MapGroups { func: _, inputs } => inputs.clone(),
}
Expand Down Expand Up @@ -728,7 +728,7 @@ impl AggExpr {
Self::AnyValue(_, ignore_nulls) => Self::AnyValue(first_child(), *ignore_nulls),
Self::List(_) => Self::List(first_child()),
Self::Set(_expr) => Self::Set(first_child()),
Self::Concat(_) => Self::Concat(first_child()),
Self::Concat(_, delimiter) => Self::Concat(first_child(), delimiter.clone()),
Self::Skew(_) => Self::Skew(first_child()),
Self::MapGroups { func, inputs: _ } => Self::MapGroups {
func: func.with_new_children(children.clone()),
Expand Down Expand Up @@ -869,10 +869,20 @@ impl AggExpr {
Ok(Field::new(field.name.as_str(), DataType::Boolean))
}

Self::Concat(expr) => {
Self::Concat(expr, delimiter) => {
let field = expr.to_field(schema)?;
let has_delimiter = delimiter.as_deref().is_some_and(|d| !d.is_empty());
match field.dtype {
DataType::List(..) => Ok(field),
DataType::List(..) => {
if has_delimiter {
Err(DaftError::TypeError(format!(
"Concat Agg delimiter is only supported for Utf8 types, got dtype {} for column \"{}\"",
field.dtype, field.name
)))
} else {
Ok(field)
}
}
DataType::Utf8 => Ok(field),
_ => Err(DaftError::TypeError(format!(
"We can only perform Concat Agg on List or Utf8 types, got dtype {} for column \"{}\"",
Expand Down Expand Up @@ -1183,8 +1193,9 @@ impl Expr {
Self::Agg(AggExpr::Set(self)).into()
}

pub fn agg_concat(self: ExprRef) -> ExprRef {
Self::Agg(AggExpr::Concat(self)).into()
/// Wraps this expression in a `Concat` aggregation.
///
/// An empty-string delimiter is normalised to `None`, so `Some("")` and
/// `None` produce the same `AggExpr::Concat` node.
pub fn agg_concat(self: ExprRef, delimiter: Option<String>) -> ExprRef {
    let separator = match delimiter {
        Some(text) if !text.is_empty() => Some(text),
        _ => None,
    };
    let aggregation = AggExpr::Concat(self, separator);
    Self::Agg(aggregation).into()
}

pub fn row_number() -> ExprRef {
Expand Down
5 changes: 3 additions & 2 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -518,8 +518,9 @@ impl PyExpr {
Ok(self.expr.clone().agg_set().into())
}

pub fn agg_concat(&self) -> PyResult<Self> {
Ok(self.expr.clone().agg_concat().into())
/// Python-facing concat aggregation.
///
/// Clones the wrapped expression and forwards `delimiter` to its
/// `agg_concat`; on the Python side `delimiter` defaults to `None`.
#[pyo3(signature = (delimiter=None))]
pub fn agg_concat(&self, delimiter: Option<String>) -> PyResult<Self> {
    let concat_expr = self.expr.clone().agg_concat(delimiter);
    Ok(concat_expr.into())
}

pub fn __add__(&self, other: &Self) -> PyResult<Self> {
Expand Down
Loading
Loading