Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions core/src/main/java/org/opensearch/sql/analysis/Analyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.expression.WindowFunction;
import org.opensearch.sql.ast.tree.AD;
import org.opensearch.sql.ast.tree.AddColTotals;
import org.opensearch.sql.ast.tree.AddTotals;
import org.opensearch.sql.ast.tree.Aggregation;
import org.opensearch.sql.ast.tree.Append;
import org.opensearch.sql.ast.tree.AppendCol;
Expand Down Expand Up @@ -521,6 +523,16 @@ public LogicalPlan visitEval(Eval node, AnalysisContext context) {
return new LogicalEval(child, expressionsBuilder.build());
}

@Override
public LogicalPlan visitAddTotals(AddTotals node, AnalysisContext context) {
throw getOnlyForCalciteException("addtotals");
}

@Override
public LogicalPlan visitAddColTotals(AddColTotals node, AnalysisContext context) {
throw getOnlyForCalciteException("addcoltotals");
}

/** Build {@link ParseExpression} to context and skip to child nodes. */
@Override
public LogicalPlan visitParse(Parse node, AnalysisContext context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
import org.opensearch.sql.ast.statement.Query;
import org.opensearch.sql.ast.statement.Statement;
import org.opensearch.sql.ast.tree.AD;
import org.opensearch.sql.ast.tree.AddColTotals;
import org.opensearch.sql.ast.tree.AddTotals;
import org.opensearch.sql.ast.tree.Aggregation;
import org.opensearch.sql.ast.tree.Append;
import org.opensearch.sql.ast.tree.AppendCol;
Expand Down Expand Up @@ -451,4 +453,12 @@ public T visitAppend(Append node, C context) {
public T visitMultisearch(Multisearch node, C context) {
return visitChildren(node, context);
}

public T visitAddTotals(AddTotals node, C context) {
return visitChildren(node, context);
}

public T visitAddColTotals(AddColTotals node, C context) {
return visitChildren(node, context);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.tree;

import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Map;
import lombok.*;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.Literal;

@Getter
@Setter
@ToString
@EqualsAndHashCode(callSuper = false)
@RequiredArgsConstructor
public class AddColTotals extends UnresolvedPlan {
private final List<Field> fieldList;
private final Map<String, Literal> options;
private UnresolvedPlan child;

@Override
public AddColTotals attach(UnresolvedPlan child) {
this.child = child;
return this;
}

@Override
public List<UnresolvedPlan> getChild() {
return child == null ? ImmutableList.of() : ImmutableList.of(child);
}

@Override
public <T, C> T accept(AbstractNodeVisitor<T, C> visitor, C context) {
return visitor.visitAddColTotals(this, context);
}
}
45 changes: 45 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/tree/AddTotals.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.tree;

import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Map;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.Literal;

@Getter
@Setter
@ToString
@EqualsAndHashCode(callSuper = false)
@RequiredArgsConstructor
public class AddTotals extends UnresolvedPlan {
private final List<Field> fieldList;
private final Map<String, Literal> options;
private UnresolvedPlan child;

@Override
public AddTotals attach(UnresolvedPlan child) {
this.child = child;
return this;
}

@Override
public List<UnresolvedPlan> getChild() {
return child == null ? ImmutableList.of() : ImmutableList.of(child);
}

@Override
public <T, C> T accept(AbstractNodeVisitor<T, C> visitor, C context) {
return visitor.visitAddTotals(this, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@
import org.opensearch.sql.ast.expression.WindowFunction;
import org.opensearch.sql.ast.expression.subquery.SubqueryExpression;
import org.opensearch.sql.ast.tree.AD;
import org.opensearch.sql.ast.tree.AddColTotals;
import org.opensearch.sql.ast.tree.AddTotals;
import org.opensearch.sql.ast.tree.Aggregation;
import org.opensearch.sql.ast.tree.Append;
import org.opensearch.sql.ast.tree.AppendCol;
Expand Down Expand Up @@ -2411,6 +2413,239 @@ private String getAggFieldAlias(UnresolvedExpression aggregateFunction) {
return sb.toString();
}

/** Transforms visitAddTotals command into SQL-based operations. */
@Override
public RelNode visitAddColTotals(AddColTotals node, CalcitePlanContext context) {
visitChildren(node, context);

// Parse options from the AddTotals node
Map<String, Literal> options = node.getOptions();
String label = getOptionValue(options, "label", "Total");
String labelField = getOptionValue(options, "labelfield", null);
// Determine which fields to aggregate

// Handle row=true option: add a new field that sums all specified fields for each row
List<Field> fieldsToAggregate = node.getFieldList();
return buildAddRowTotalAggregate(
context, fieldsToAggregate, false, true, null, labelField, label);
}

public RelNode buildAddRowTotalAggregate(
CalcitePlanContext context,
List<Field> fieldsToAggregate,
boolean addTotalsForEachRow,
boolean addTotalsForEachColumn,
String newColTotalsFieldName,
String labelField,
String label) {

// Build aggregation calls for totals calculation
boolean extraColTotalField = false;
RexNode sumExpression = null;
List<AggCall> aggCalls = new ArrayList<>();
List<String> fieldNameToSum = new ArrayList<>();
RelNode originalData = context.relBuilder.peek();

boolean foundLableField = false;
int labelLength =
(labelField != null) && (labelField.length() > label.length())
? labelField.length()
: label.length();

RelDataType labelVarcharType =
context.relBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR, labelLength);

// If no specific fields specified, use all numeric fields
if (fieldsToAggregate.isEmpty()) {
fieldsToAggregate = getAllNumericFields(originalData, context);
}

List<RexNode> fieldsToSum = new ArrayList<>();
java.util.List<org.apache.calcite.rel.type.RelDataTypeField> fieldList =
originalData.getRowType().getFieldList();
for (RelDataTypeField fieldDataType : fieldList) {
if (shouldAggregateField(fieldDataType.getName(), fieldsToAggregate)) {
RexNode fieldRef = context.relBuilder.field(fieldDataType.getName());
if (isNumericField(fieldRef, context)) {
fieldsToSum.add(fieldRef);
if (addTotalsForEachColumn) {
AggCall sumCall = context.relBuilder.sum(fieldRef).as(fieldDataType.getName());
aggCalls.add(sumCall);
}
fieldNameToSum.add(fieldDataType.getName());
if (addTotalsForEachRow) {
if (sumExpression == null) {
sumExpression = fieldRef;
} else {
sumExpression =
context.relBuilder.call(
org.apache.calcite.sql.fun.SqlStdOperatorTable.PLUS, sumExpression, fieldRef);
}
}
}
}
if (addTotalsForEachColumn && fieldDataType.getName().equals(labelField)) {
// Use specified label field for the label
foundLableField = true;
}
}
if (addTotalsForEachRow && !fieldsToSum.isEmpty()) {
// Add the new column with the sum
context.relBuilder.projectPlus(
context.relBuilder.alias(sumExpression, newColTotalsFieldName));
if (newColTotalsFieldName.equals(labelField)) {
foundLableField = true;
}
}
if (addTotalsForEachColumn) {
if (!foundLableField && (labelField != null)) {
context.relBuilder.projectPlus(
context.relBuilder.alias(
context.relBuilder.getRexBuilder().makeNullLiteral(labelVarcharType), labelField));
extraColTotalField = true;
}
}
originalData = context.relBuilder.build();
context.relBuilder.push(originalData);
if (addTotalsForEachColumn) {
// Perform aggregation (no group by - single totals row)
context.relBuilder.aggregate(
context.relBuilder.groupKey(), // Empty group key for single totals row
aggCalls);
// 3. Build the totals row with proper field order and labels
List<RexNode> selectList = new ArrayList<>();

fieldList = originalData.getRowType().getFieldList();
for (RelDataTypeField fieldDataType : fieldList) {
if (fieldNameToSum.contains(fieldDataType.getName())) {
selectList.add(
context.relBuilder.alias(
context.relBuilder.field(fieldDataType.getName()), fieldDataType.getName()));

} else if (fieldDataType.getName().equals(labelField)
&& (extraColTotalField
|| fieldDataType.getType().getFamily() == SqlTypeFamily.CHARACTER)) {
// Use specified label field for the label - cast to match original field type
RexNode labelLiteral =
context.relBuilder.getRexBuilder().makeLiteral(label, fieldDataType.getType(), true);
selectList.add(context.relBuilder.alias(labelLiteral, fieldDataType.getName()));

} else {
// Other fields get NULL in totals row - cast to match original field type
selectList.add(
context.relBuilder.alias(
context.relBuilder.getRexBuilder().makeNullLiteral(fieldDataType.getType()),
fieldDataType.getName()));
}
}

// Project the totals row with proper field order and labels
context.relBuilder.project(selectList);
RelNode totalsRow = context.relBuilder.build();
// 4. Union original data with totals row
context.relBuilder.push(originalData);
context.relBuilder.push(totalsRow);
context.relBuilder.union(true); // Use UNION ALL to preserve order
}
return context.relBuilder.peek();
}

/** Transforms visitAddTotals command into SQL-based operations. */
@Override
public RelNode visitAddTotals(AddTotals node, CalcitePlanContext context) {
// 1. Process child plan first
visitChildren(node, context);

// Parse options from the AddTotals node
Map<String, Literal> options = node.getOptions();
String label =
getOptionValue(
options, "label", "Total"); // when col=true , add summary event with this label
String labelField =
getOptionValue(
options,
"labelfield",
null); // when col=true , add summary event with this label field at the end of rows
String newColTotalsFieldName =
getOptionValue(
options, "fieldname", "Total"); // when row=true , add new field as new column
boolean addTotalsForEachRow = getBooleanOptionValue(options, "row", true);
boolean addTotalsForEachColumn =
getBooleanOptionValue(options, "col", false); // when col=true/false check

// Determine which fields to aggregate
List<Field> fieldsToAggregate = node.getFieldList();

// Handle row=true option: add a new field that sums all specified fields for each row
return buildAddRowTotalAggregate(
context,
fieldsToAggregate,
addTotalsForEachRow,
addTotalsForEachColumn,
newColTotalsFieldName,
labelField,
label);
}

/** Helper method to extract option values from the options map */
private String getOptionValue(Map<String, Literal> options, String key, String defaultValue) {
if (options.containsKey(key)) {
return options.get(key).toString().replace("\"", "");
}
return defaultValue;
}

/** Helper method to extract boolean option values */
private boolean getBooleanOptionValue(
Map<String, Literal> options, String key, boolean defaultValue) {
if (options.containsKey(key)) {
Object value = options.get(key).getValue();
if (value instanceof Boolean) {
return (Boolean) value;
}
if (value instanceof String) {
return Boolean.parseBoolean((String) value);
}
}
return defaultValue;
}

/** Get all numeric fields from the RelNode */
private List<Field> getAllNumericFields(RelNode relNode, CalcitePlanContext context) {
List<Field> numericFields = new ArrayList<>();
for (String fieldName : relNode.getRowType().getFieldNames()) {
if (isNumericFieldName(fieldName, relNode)) {
numericFields.add(
new Field(new org.opensearch.sql.ast.expression.QualifiedName(fieldName)));
}
}
return numericFields;
}

/** Check if a field should be aggregated based on the field list */
private boolean shouldAggregateField(String fieldName, List<Field> fieldsToAggregate) {
if (fieldsToAggregate.isEmpty()) {
return true; // Aggregate all fields when none specified
}
return fieldsToAggregate.stream()
.anyMatch(field -> field.getField().toString().equals(fieldName));
}

/** Check if a RexNode represents a numeric field */
private boolean isNumericField(RexNode rexNode, CalcitePlanContext context) {
return rexNode.getType().getSqlTypeName().getFamily() == SqlTypeFamily.NUMERIC;
}

/** Check if a field name represents a numeric field in the RelNode */
private boolean isNumericFieldName(String fieldName, RelNode relNode) {
try {
RelDataTypeField field = relNode.getRowType().getField(fieldName, false, false);
return field != null && field.getType().getSqlTypeName().getFamily() == SqlTypeFamily.NUMERIC;
} catch (Exception e) {
return false;
}
}

@Override
public RelNode visitChart(Chart node, CalcitePlanContext context) {
visitChildren(node, context);
Expand Down
2 changes: 2 additions & 0 deletions docs/category.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
"ppl_cli_calcite": [
"user/ppl/cmd/ad.rst",
"user/ppl/cmd/append.rst",
"user/ppl/cmd/addtotals.rst",
"user/ppl/cmd/addcoltotals.rst",
"user/ppl/cmd/bin.rst",
"user/ppl/cmd/dedup.rst",
"user/ppl/cmd/describe.rst",
Expand Down
Loading
Loading