Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 101 additions & 5 deletions src/daft-logical-plan/src/builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ use common_treenode::TreeNode;
use daft_algebra::boolean::combine_conjunction;
use daft_core::join::{JoinStrategy, JoinType};
use daft_dsl::{
Column, Expr, ExprRef, UnresolvedColumn, WindowSpec, left_col, resolved_col, right_col,
unresolved_col,
Column, Expr, ExprRef, UnresolvedColumn, WindowSpec, has_agg, left_col, resolved_col,
right_col, unresolved_col,
};
use daft_schema::schema::{Schema, SchemaRef};
use indexmap::IndexSet;
Expand Down Expand Up @@ -241,10 +241,106 @@ impl LogicalPlanBuilder {
.allow_explode(true)
.build();

let to_select = expr_resolver.resolve(to_select, self.plan.clone())?;
// If the select list contains aggregations, treat it as a global aggregation (no group by).
// Resolve such expressions in an aggregation context (which permits aggregations).
let to_select_has_aggs = to_select.iter().any(has_agg);
let to_select = if to_select_has_aggs {
let empty_groupby: Vec<ExprRef> = vec![];
let agg_resolver = ExprResolver::builder()
.in_agg_context(true)
.groupby(&empty_groupby)
.build();
agg_resolver.resolve(to_select, self.plan.clone())?
} else {
expr_resolver.resolve(to_select, self.plan.clone())?
};

let logical_plan: LogicalPlan = ops::Project::try_new(self.plan.clone(), to_select)?.into();
Ok(self.with_new_plan(logical_plan))
// if a SELECT contains aggregations and there is no GROUP BY, treat it as a global aggregation.
if to_select.iter().any(has_agg) {
let schema = self.plan.schema();

// Split into aggregate expressions and foldable non-aggregate expressions.
// Any non-aggregate expression that references input columns is rejected.
// Tracking (internal_semantic_id, user_visible_name) so we don't need nested aliases.
let mut agg_exprs: Vec<(Arc<str>, String, ExprRef)> = vec![];
let mut foldable_exprs: Vec<ExprRef> = vec![];

for expr in to_select {
let original_name = expr.name().to_string();
// Strip alias, but keep the original name for output.
let (base, out_name) = match expr.as_ref() {
Expr::Alias(child, name) => (child.clone(), Some(name.clone())),
_ => (expr.clone(), None),
};

if has_agg(&base) {
// LiftProjectFromAgg expects Aggregate.aggregations to be top-level Expr::Agg (or Alias(Agg)).
// We also ensure unique names by aliasing each agg to its semantic_id.
let agg = match base.as_ref() {
Expr::Agg(_) => base.clone(),
_ => {
return Err(DaftError::ValueError(format!(
"select() with aggregations can only contain aggregation expressions (and foldable literals).\nReceived a nested aggregation expression: {expr}"
)));
}
};

let id = agg.semantic_id(schema.as_ref()).id;
agg_exprs.push((id, original_name, agg));
} else {
// Non-agg: must be foldable (no computation / no input column refs).
// Also validate it against the aggregation context (no groupby) to provide a clear error
// for mixed agg + non-agg column references.
let empty_groupby: Vec<ExprRef> = vec![];
if ExprResolver::builder()
.groupby(&empty_groupby)
.build()
.resolve_single(base.clone(), self.plan.clone())
.is_err()
{
return Err(DaftError::ValueError(
"Expected aggregation (or a foldable literal) in select() only when aggregation expressions are present without groupby.".to_string(),
));
}

if daft_dsl::optimization::requires_computation(base.as_ref()) {
return Err(DaftError::ValueError(
"Expected aggregation (or a foldable literal) in select() only when aggregation expressions are present without groupby.".to_string(),
));
}

// Keep original expr (with alias) for final output.
if let Some(name) = out_name {
foldable_exprs.push(base.alias(&*name));
} else {
foldable_exprs.push(base);
}
}
}

// Build Aggregate first, using semantic-id aliases for internal column names.
let mut agg_internal: Vec<ExprRef> = Vec::with_capacity(agg_exprs.len());
let mut final_exprs: Vec<ExprRef> =
Vec::with_capacity(agg_exprs.len() + foldable_exprs.len());

for (id, user_name, agg) in agg_exprs {
agg_internal.push(agg.alias(id.clone()));
final_exprs.push(resolved_col(id).alias(user_name));
}

let agg_plan: LogicalPlan =
ops::Aggregate::try_new(self.plan.clone(), agg_internal, vec![])?.into();
let agg_plan = Arc::new(agg_plan);

final_exprs.extend(foldable_exprs);

let logical_plan: LogicalPlan = ops::Project::try_new(agg_plan, final_exprs)?.into();
Ok(self.with_new_plan(logical_plan))
} else {
let logical_plan: LogicalPlan =
ops::Project::try_new(self.plan.clone(), to_select)?.into();
Ok(self.with_new_plan(logical_plan))
}
}

pub fn with_columns(&self, columns: Vec<ExprRef>) -> DaftResult<Self> {
Expand Down
2 changes: 1 addition & 1 deletion src/daft-logical-plan/src/builder/resolve_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ impl ExprResolver<'_> {
fn validate_expr(&self, expr: ExprRef) -> DaftResult<ExprRef> {
if has_agg(&expr) {
return Err(DaftError::ValueError(format!(
"Aggregation expressions are currently only allowed in agg, pivot, and window: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383",
"Aggregation expressions are currently only allowed in agg, pivot, window, or as a global aggregation in select(): {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383",
)));
}

Expand Down
59 changes: 59 additions & 0 deletions tests/dataframe/test_select_global_agg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from __future__ import annotations

import pytest

import daft
from daft import col, lit


def test_select_global_agg_returns_single_row() -> None:
    """A lone aggregation passed to ``select`` yields exactly one output row."""
    frame = daft.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]})
    total_a = col("a").sum().alias("sum_a")

    collected = frame.select(total_a).collect()

    # Only the aliased aggregate column is present, holding the single total.
    assert collected.to_pydict() == {"sum_a": [6]}


def test_select_global_agg_allows_literals() -> None:
    """Literals may sit alongside aggregations in an aggregating ``select``."""
    frame = daft.from_pydict({"a": [1, 2, 3]})
    projections = [
        col("a").sum().alias("sum_a"),
        lit(1).alias("one"),
    ]

    observed = frame.select(*projections).collect().to_pydict()

    # The literal is emitted once, matching the single aggregated row.
    assert observed == {"sum_a": [6], "one": [1]}


def test_select_global_agg_allows_multiple_aggs() -> None:
    """Several aggregations over the same column can share one ``select``."""
    frame = daft.from_pydict({"a": [1, 2, 3]})
    projections = [
        col("a").sum().alias("sum_a"),
        col("a").count().alias("cnt_a"),
    ]

    observed = frame.select(*projections).collect().to_pydict()

    # Each aggregate lands in its own aliased column, in the order requested.
    assert observed == {"sum_a": [6], "cnt_a": [3]}


def test_select_global_agg_without_alias() -> None:
    """An un-aliased aggregation keeps the name of the column it aggregates."""
    frame = daft.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]})
    unaliased_sum = col("a").sum()

    collected = frame.select(unaliased_sum).collect()

    # The output column is expected under the source column's name, "a".
    assert collected.to_pydict() == {"a": [6]}


def test_select_global_agg_rejects_non_agg_column_reference() -> None:
    """Mixing an aggregation with a bare column reference must raise ``ValueError``."""
    frame = daft.from_pydict({"a": [1, 2, 3]})

    # Both plan construction and execution stay inside the context manager so the
    # error is caught regardless of which step raises it.
    with pytest.raises(ValueError, match="Expressions in aggregations"):
        projected = frame.select(col("a").sum().alias("sum_a"), col("a"))
        projected.collect()
Loading