diff --git a/src/daft-logical-plan/src/builder/mod.rs b/src/daft-logical-plan/src/builder/mod.rs
index 2e923377d7..8cd36d7713 100644
--- a/src/daft-logical-plan/src/builder/mod.rs
+++ b/src/daft-logical-plan/src/builder/mod.rs
@@ -18,8 +18,8 @@ use common_treenode::TreeNode;
 use daft_algebra::boolean::combine_conjunction;
 use daft_core::join::{JoinStrategy, JoinType};
 use daft_dsl::{
-    Column, Expr, ExprRef, UnresolvedColumn, WindowSpec, left_col, resolved_col, right_col,
-    unresolved_col,
+    Column, Expr, ExprRef, UnresolvedColumn, WindowSpec, has_agg, left_col, resolved_col,
+    right_col, unresolved_col,
 };
 use daft_schema::schema::{Schema, SchemaRef};
 use indexmap::IndexSet;
@@ -241,10 +241,90 @@ impl LogicalPlanBuilder {
             .allow_explode(true)
             .build();
 
-        let to_select = expr_resolver.resolve(to_select, self.plan.clone())?;
+        // If the select list contains aggregations, treat it as a global aggregation (no group by).
+        // Resolve such expressions in an aggregation context (which permits aggregations).
+        let to_select_has_aggs = to_select.iter().any(has_agg);
+        let to_select = if to_select_has_aggs {
+            let empty_groupby: Vec<ExprRef> = vec![];
+            let agg_resolver = ExprResolver::builder()
+                .in_agg_context(true)
+                .groupby(&empty_groupby)
+                .build();
+            agg_resolver.resolve(to_select, self.plan.clone())?
+        } else {
+            expr_resolver.resolve(to_select, self.plan.clone())?
+        };
 
-        let logical_plan: LogicalPlan = ops::Project::try_new(self.plan.clone(), to_select)?.into();
-        Ok(self.with_new_plan(logical_plan))
+        // If a SELECT contains aggregations and there is no GROUP BY, treat it as a global aggregation.
+        if to_select.iter().any(has_agg) {
+            let schema = self.plan.schema();
+
+            // Each distinct aggregation is computed once by an Aggregate node under an internal
+            // semantic-id name. The Project on top restores the user-visible names and keeps the
+            // expressions in the order they were written in select().
+            let mut agg_internal: Vec<ExprRef> = vec![];
+            let mut seen_agg_ids: IndexSet<Arc<str>> = IndexSet::new();
+            let mut final_exprs: Vec<ExprRef> = Vec::with_capacity(to_select.len());
+
+            for expr in to_select {
+                // Strip a top-level alias; the user-visible name is re-applied in the Project.
+                let base = match expr.as_ref() {
+                    Expr::Alias(child, _) => child.clone(),
+                    _ => expr.clone(),
+                };
+
+                if has_agg(&base) {
+                    // LiftProjectFromAgg expects Aggregate.aggregations to be top-level Expr::Agg (or Alias(Agg)).
+                    if !matches!(base.as_ref(), Expr::Agg(_)) {
+                        return Err(DaftError::ValueError(format!(
+                            "select() with aggregations can only contain aggregation expressions (and foldable literals).\nReceived a nested aggregation expression: {expr}"
+                        )));
+                    }
+
+                    // Alias each aggregation to its semantic id so internal names are unique, and
+                    // emit an aggregation that is selected more than once only one time.
+                    let id = base.semantic_id(schema.as_ref()).id;
+                    if seen_agg_ids.insert(id.clone()) {
+                        agg_internal.push(base.alias(id.clone()));
+                    }
+                    final_exprs.push(resolved_col(id).alias(expr.name()));
+                } else {
+                    // Non-agg: must be foldable (no computation / no input column refs).
+                    // Also validate it against the aggregation context (no groupby) to provide a clear error
+                    // for mixed agg + non-agg column references.
+                    let empty_groupby: Vec<ExprRef> = vec![];
+                    if ExprResolver::builder()
+                        .groupby(&empty_groupby)
+                        .build()
+                        .resolve_single(base.clone(), self.plan.clone())
+                        .is_err()
+                    {
+                        return Err(DaftError::ValueError(
+                            "Expected aggregation (or a foldable literal) in select() only when aggregation expressions are present without groupby.".to_string(),
+                        ));
+                    }
+
+                    if daft_dsl::optimization::requires_computation(base.as_ref()) {
+                        return Err(DaftError::ValueError(
+                            "Expected aggregation (or a foldable literal) in select() only when aggregation expressions are present without groupby.".to_string(),
+                        ));
+                    }
+
+                    // Keep the expression as written (alias included) at its select() position.
+                    final_exprs.push(expr);
+                }
+            }
+
+            let agg_plan: LogicalPlan =
+                ops::Aggregate::try_new(self.plan.clone(), agg_internal, vec![])?.into();
+            let logical_plan: LogicalPlan =
+                ops::Project::try_new(Arc::new(agg_plan), final_exprs)?.into();
+            Ok(self.with_new_plan(logical_plan))
+        } else {
+            let logical_plan: LogicalPlan =
+                ops::Project::try_new(self.plan.clone(), to_select)?.into();
+            Ok(self.with_new_plan(logical_plan))
+        }
     }
 
     pub fn with_columns(&self, columns: Vec<ExprRef>) -> DaftResult<Self> {
diff --git a/src/daft-logical-plan/src/builder/resolve_expr.rs b/src/daft-logical-plan/src/builder/resolve_expr.rs
index 72e34dfaa4..38e501d663 100644
--- a/src/daft-logical-plan/src/builder/resolve_expr.rs
+++ b/src/daft-logical-plan/src/builder/resolve_expr.rs
@@ -445,7 +445,7 @@ impl ExprResolver<'_> {
     fn validate_expr(&self, expr: ExprRef) -> DaftResult<ExprRef> {
         if has_agg(&expr) {
             return Err(DaftError::ValueError(format!(
-                "Aggregation expressions are currently only allowed in agg, pivot, and window: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383",
+                "Aggregation expressions are currently only allowed in agg, pivot, window, or as a global aggregation in select(): {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383",
             )));
         }
 
diff --git a/tests/dataframe/test_select_global_agg.py b/tests/dataframe/test_select_global_agg.py
new file mode 100644
index 0000000000..cac9a7e2a4
--- /dev/null
+++ b/tests/dataframe/test_select_global_agg.py
@@ -0,0 +1,90 @@
+from __future__ import annotations
+
+import pytest
+
+import daft
+from daft import col, lit
+
+
+def test_select_global_agg_returns_single_row() -> None:
+    df = daft.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]})
+
+    res = df.select(col("a").sum().alias("sum_a")).collect().to_pydict()
+
+    assert res == {"sum_a": [6]}
+
+
+def test_select_global_agg_allows_literals() -> None:
+    df = daft.from_pydict({"a": [1, 2, 3]})
+
+    res = (
+        df.select(
+            col("a").sum().alias("sum_a"),
+            lit(1).alias("one"),
+        )
+        .collect()
+        .to_pydict()
+    )
+
+    assert res == {"sum_a": [6], "one": [1]}
+
+
+def test_select_global_agg_allows_multiple_aggs() -> None:
+    df = daft.from_pydict({"a": [1, 2, 3]})
+
+    res = (
+        df.select(
+            col("a").sum().alias("sum_a"),
+            col("a").count().alias("cnt_a"),
+        )
+        .collect()
+        .to_pydict()
+    )
+
+    assert res == {"sum_a": [6], "cnt_a": [3]}
+
+
+def test_select_global_agg_preserves_select_order() -> None:
+    df = daft.from_pydict({"a": [1, 2, 3]})
+
+    res = (
+        df.select(
+            lit(1).alias("one"),
+            col("a").sum().alias("sum_a"),
+        )
+        .collect()
+        .to_pydict()
+    )
+
+    assert list(res.keys()) == ["one", "sum_a"]
+    assert res == {"one": [1], "sum_a": [6]}
+
+
+def test_select_global_agg_allows_repeated_agg() -> None:
+    df = daft.from_pydict({"a": [1, 2, 3]})
+
+    res = (
+        df.select(
+            col("a").sum().alias("x"),
+            col("a").sum().alias("y"),
+        )
+        .collect()
+        .to_pydict()
+    )
+
+    assert res == {"x": [6], "y": [6]}
+
+
+def test_select_global_agg_without_alias() -> None:
+    df = daft.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]})
+
+    res = df.select(col("a").sum()).collect().to_pydict()
+
+    assert res == {"a": [6]}
+
+
+def test_select_global_agg_rejects_non_agg_column_reference() -> None:
+    df = daft.from_pydict({"a": [1, 2, 3]})
+
+    with pytest.raises(ValueError, match="Expressions in aggregations"):
+        df.select(col("a").sum().alias("sum_a"), col("a")).collect()