diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index f701756559..6770df5c90 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -57,6 +57,8 @@ import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.expression.WindowFunction; import org.opensearch.sql.ast.tree.AD; +import org.opensearch.sql.ast.tree.AddColTotals; +import org.opensearch.sql.ast.tree.AddTotals; import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.Append; import org.opensearch.sql.ast.tree.AppendCol; @@ -521,6 +523,16 @@ public LogicalPlan visitEval(Eval node, AnalysisContext context) { return new LogicalEval(child, expressionsBuilder.build()); } + @Override + public LogicalPlan visitAddTotals(AddTotals node, AnalysisContext context) { + throw getOnlyForCalciteException("addtotals"); + } + + @Override + public LogicalPlan visitAddColTotals(AddColTotals node, AnalysisContext context) { + throw getOnlyForCalciteException("addcoltotals"); + } + /** Build {@link ParseExpression} to context and skip to child nodes. 
*/ @Override public LogicalPlan visitParse(Parse node, AnalysisContext context) { diff --git a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 320723fd57..460322a3c4 100644 --- a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -45,6 +45,8 @@ import org.opensearch.sql.ast.statement.Query; import org.opensearch.sql.ast.statement.Statement; import org.opensearch.sql.ast.tree.AD; +import org.opensearch.sql.ast.tree.AddColTotals; +import org.opensearch.sql.ast.tree.AddTotals; import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.Append; import org.opensearch.sql.ast.tree.AppendCol; @@ -451,4 +453,12 @@ public T visitAppend(Append node, C context) { public T visitMultisearch(Multisearch node, C context) { return visitChildren(node, context); } + + public T visitAddTotals(AddTotals node, C context) { + return visitChildren(node, context); + } + + public T visitAddColTotals(AddColTotals node, C context) { + return visitChildren(node, context); + } } diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/AddColTotals.java b/core/src/main/java/org/opensearch/sql/ast/tree/AddColTotals.java new file mode 100644 index 0000000000..91aa4084f1 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/tree/AddColTotals.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import java.util.List; +import java.util.Map; +import lombok.*; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.Literal; + +@Getter +@Setter +@ToString +@EqualsAndHashCode(callSuper = false) +@RequiredArgsConstructor +public class AddColTotals extends 
UnresolvedPlan { + private final List fieldList; + private final Map options; + private UnresolvedPlan child; + + @Override + public AddColTotals attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return child == null ? ImmutableList.of() : ImmutableList.of(child); + } + + @Override + public T accept(AbstractNodeVisitor visitor, C context) { + return visitor.visitAddColTotals(this, context); + } +} diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/AddTotals.java b/core/src/main/java/org/opensearch/sql/ast/tree/AddTotals.java new file mode 100644 index 0000000000..93ff5ccfbe --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/tree/AddTotals.java @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import java.util.List; +import java.util.Map; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.Setter; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.Literal; + +@Getter +@Setter +@ToString +@EqualsAndHashCode(callSuper = false) +@RequiredArgsConstructor +public class AddTotals extends UnresolvedPlan { + private final List fieldList; + private final Map options; + private UnresolvedPlan child; + + @Override + public AddTotals attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return child == null ? 
ImmutableList.of() : ImmutableList.of(child); + } + + @Override + public T accept(AbstractNodeVisitor visitor, C context) { + return visitor.visitAddTotals(this, context); + } +} diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java index 9408695261..431bfadfbe 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java @@ -102,6 +102,8 @@ import org.opensearch.sql.ast.expression.WindowFunction; import org.opensearch.sql.ast.expression.subquery.SubqueryExpression; import org.opensearch.sql.ast.tree.AD; +import org.opensearch.sql.ast.tree.AddColTotals; +import org.opensearch.sql.ast.tree.AddTotals; import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.Append; import org.opensearch.sql.ast.tree.AppendCol; @@ -2411,6 +2413,239 @@ private String getAggFieldAlias(UnresolvedExpression aggregateFunction) { return sb.toString(); } + /** Transforms visitAddTotals command into SQL-based operations. 
*/ + @Override + public RelNode visitAddColTotals(AddColTotals node, CalcitePlanContext context) { + visitChildren(node, context); + + // Parse options from the AddTotals node + Map options = node.getOptions(); + String label = getOptionValue(options, "label", "Total"); + String labelField = getOptionValue(options, "labelfield", null); + // Determine which fields to aggregate + + // Handle row=true option: add a new field that sums all specified fields for each row + List fieldsToAggregate = node.getFieldList(); + return buildAddRowTotalAggregate( + context, fieldsToAggregate, false, true, null, labelField, label); + } + + public RelNode buildAddRowTotalAggregate( + CalcitePlanContext context, + List fieldsToAggregate, + boolean addTotalsForEachRow, + boolean addTotalsForEachColumn, + String newColTotalsFieldName, + String labelField, + String label) { + + // Build aggregation calls for totals calculation + boolean extraColTotalField = false; + RexNode sumExpression = null; + List aggCalls = new ArrayList<>(); + List fieldNameToSum = new ArrayList<>(); + RelNode originalData = context.relBuilder.peek(); + + boolean foundLableField = false; + int labelLength = + (labelField != null) && (labelField.length() > label.length()) + ? 
labelField.length() + : label.length(); + + RelDataType labelVarcharType = + context.relBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR, labelLength); + + // If no specific fields specified, use all numeric fields + if (fieldsToAggregate.isEmpty()) { + fieldsToAggregate = getAllNumericFields(originalData, context); + } + + List fieldsToSum = new ArrayList<>(); + java.util.List fieldList = + originalData.getRowType().getFieldList(); + for (RelDataTypeField fieldDataType : fieldList) { + if (shouldAggregateField(fieldDataType.getName(), fieldsToAggregate)) { + RexNode fieldRef = context.relBuilder.field(fieldDataType.getName()); + if (isNumericField(fieldRef, context)) { + fieldsToSum.add(fieldRef); + if (addTotalsForEachColumn) { + AggCall sumCall = context.relBuilder.sum(fieldRef).as(fieldDataType.getName()); + aggCalls.add(sumCall); + } + fieldNameToSum.add(fieldDataType.getName()); + if (addTotalsForEachRow) { + if (sumExpression == null) { + sumExpression = fieldRef; + } else { + sumExpression = + context.relBuilder.call( + org.apache.calcite.sql.fun.SqlStdOperatorTable.PLUS, sumExpression, fieldRef); + } + } + } + } + if (addTotalsForEachColumn && fieldDataType.getName().equals(labelField)) { + // Use specified label field for the label + foundLableField = true; + } + } + if (addTotalsForEachRow && !fieldsToSum.isEmpty()) { + // Add the new column with the sum + context.relBuilder.projectPlus( + context.relBuilder.alias(sumExpression, newColTotalsFieldName)); + if (newColTotalsFieldName.equals(labelField)) { + foundLableField = true; + } + } + if (addTotalsForEachColumn) { + if (!foundLableField && (labelField != null)) { + context.relBuilder.projectPlus( + context.relBuilder.alias( + context.relBuilder.getRexBuilder().makeNullLiteral(labelVarcharType), labelField)); + extraColTotalField = true; + } + } + originalData = context.relBuilder.build(); + context.relBuilder.push(originalData); + if (addTotalsForEachColumn) { + // Perform aggregation (no 
group by - single totals row) + context.relBuilder.aggregate( + context.relBuilder.groupKey(), // Empty group key for single totals row + aggCalls); + // 3. Build the totals row with proper field order and labels + List selectList = new ArrayList<>(); + + fieldList = originalData.getRowType().getFieldList(); + for (RelDataTypeField fieldDataType : fieldList) { + if (fieldNameToSum.contains(fieldDataType.getName())) { + selectList.add( + context.relBuilder.alias( + context.relBuilder.field(fieldDataType.getName()), fieldDataType.getName())); + + } else if (fieldDataType.getName().equals(labelField) + && (extraColTotalField + || fieldDataType.getType().getFamily() == SqlTypeFamily.CHARACTER)) { + // Use specified label field for the label - cast to match original field type + RexNode labelLiteral = + context.relBuilder.getRexBuilder().makeLiteral(label, fieldDataType.getType(), true); + selectList.add(context.relBuilder.alias(labelLiteral, fieldDataType.getName())); + + } else { + // Other fields get NULL in totals row - cast to match original field type + selectList.add( + context.relBuilder.alias( + context.relBuilder.getRexBuilder().makeNullLiteral(fieldDataType.getType()), + fieldDataType.getName())); + } + } + + // Project the totals row with proper field order and labels + context.relBuilder.project(selectList); + RelNode totalsRow = context.relBuilder.build(); + // 4. Union original data with totals row + context.relBuilder.push(originalData); + context.relBuilder.push(totalsRow); + context.relBuilder.union(true); // Use UNION ALL to preserve order + } + return context.relBuilder.peek(); + } + + /** Transforms visitAddTotals command into SQL-based operations. */ + @Override + public RelNode visitAddTotals(AddTotals node, CalcitePlanContext context) { + // 1. 
Process child plan first + visitChildren(node, context); + + // Parse options from the AddTotals node + Map options = node.getOptions(); + String label = + getOptionValue( + options, "label", "Total"); // when col=true , add summary event with this label + String labelField = + getOptionValue( + options, + "labelfield", + null); // when col=true , add summary event with this label field at the end of rows + String newColTotalsFieldName = + getOptionValue( + options, "fieldname", "Total"); // when row=true , add new field as new column + boolean addTotalsForEachRow = getBooleanOptionValue(options, "row", true); + boolean addTotalsForEachColumn = + getBooleanOptionValue(options, "col", false); // when col=true/false check + + // Determine which fields to aggregate + List fieldsToAggregate = node.getFieldList(); + + // Handle row=true option: add a new field that sums all specified fields for each row + return buildAddRowTotalAggregate( + context, + fieldsToAggregate, + addTotalsForEachRow, + addTotalsForEachColumn, + newColTotalsFieldName, + labelField, + label); + } + + /** Helper method to extract option values from the options map */ + private String getOptionValue(Map options, String key, String defaultValue) { + if (options.containsKey(key)) { + return options.get(key).toString().replace("\"", ""); + } + return defaultValue; + } + + /** Helper method to extract boolean option values */ + private boolean getBooleanOptionValue( + Map options, String key, boolean defaultValue) { + if (options.containsKey(key)) { + Object value = options.get(key).getValue(); + if (value instanceof Boolean) { + return (Boolean) value; + } + if (value instanceof String) { + return Boolean.parseBoolean((String) value); + } + } + return defaultValue; + } + + /** Get all numeric fields from the RelNode */ + private List getAllNumericFields(RelNode relNode, CalcitePlanContext context) { + List numericFields = new ArrayList<>(); + for (String fieldName : 
relNode.getRowType().getFieldNames()) { + if (isNumericFieldName(fieldName, relNode)) { + numericFields.add( + new Field(new org.opensearch.sql.ast.expression.QualifiedName(fieldName))); + } + } + return numericFields; + } + + /** Check if a field should be aggregated based on the field list */ + private boolean shouldAggregateField(String fieldName, List fieldsToAggregate) { + if (fieldsToAggregate.isEmpty()) { + return true; // Aggregate all fields when none specified + } + return fieldsToAggregate.stream() + .anyMatch(field -> field.getField().toString().equals(fieldName)); + } + + /** Check if a RexNode represents a numeric field */ + private boolean isNumericField(RexNode rexNode, CalcitePlanContext context) { + return rexNode.getType().getSqlTypeName().getFamily() == SqlTypeFamily.NUMERIC; + } + + /** Check if a field name represents a numeric field in the RelNode */ + private boolean isNumericFieldName(String fieldName, RelNode relNode) { + try { + RelDataTypeField field = relNode.getRowType().getField(fieldName, false, false); + return field != null && field.getType().getSqlTypeName().getFamily() == SqlTypeFamily.NUMERIC; + } catch (Exception e) { + return false; + } + } + @Override public RelNode visitChart(Chart node, CalcitePlanContext context) { visitChildren(node, context); diff --git a/docs/category.json b/docs/category.json index f126904da6..87d0e9e0b4 100644 --- a/docs/category.json +++ b/docs/category.json @@ -24,6 +24,8 @@ "ppl_cli_calcite": [ "user/ppl/cmd/ad.rst", "user/ppl/cmd/append.rst", + "user/ppl/cmd/addtotals.rst", + "user/ppl/cmd/addcoltotals.rst", "user/ppl/cmd/bin.rst", "user/ppl/cmd/dedup.rst", "user/ppl/cmd/describe.rst", diff --git a/docs/user/ppl/cmd/addcoltotals.rst b/docs/user/ppl/cmd/addcoltotals.rst new file mode 100644 index 0000000000..9bef4e02e8 --- /dev/null +++ b/docs/user/ppl/cmd/addcoltotals.rst @@ -0,0 +1,86 @@ +========== +AddColTotals +========== + +.. rubric:: Table of contents + +.. 
contents:: + :local: + :depth: 2 + +Description +=========== + +The ``addcoltotals`` command computes the sum of each column and adds a summary event at the end to show the total of each column. This command works the same way as the ``addtotals`` command with the options row=false and col=true. This is useful for creating summary reports with subtotals or grand totals. +The ``addcoltotals`` command only sums numeric fields (integers, floats, doubles). Non-numeric fields are ignored, whether they are listed in field-list or no field-list is specified. + +Syntax +====== + +``addcoltotals [field-list] [label=<string>] [labelfield=<field>]`` + +* ``field-list``: Optional. Comma-separated list of numeric fields to sum. If not specified, all numeric fields are summed. +* ``labelfield=<field>``: Optional. Field name to place the label. If it specifies a non-existing field, adds the field and shows the label in this field on the summary event row. This is applicable when col=true. +* ``label=<string>``: Optional. Custom text for the totals row labelfield's label. Default is "Total". This is applicable when col=true. + + +Example 1: Basic Example +========================= + +The example shows placing the label in an existing field. + +PPL query:: + + os> source=accounts | fields firstname, balance | head 3 | addcoltotals labelfield='firstname'; + fetched rows / total rows = 4/4 + +-----------+---------+ + | firstname | balance | + |-----------+---------| + | Amber | 39225 | + | Hattie | 5686 | + | Nanette | 32838 | + | Total | 77749 | + +-----------+---------+ + + +Example 2: Adding column totals and adding a summary event with label specified. +================================================================================= + +The example shows adding column totals after a stats command, where the final summary event label is 'Sum'. It also adds a new field +specified by labelfield because it does not match an existing field. 
+ + +PPL query:: + + os> source=accounts | stats count() by gender | addcoltotals `count()` label='Sum' labelfield='Total'; + fetched rows / total rows = 3/3 + +---------+--------+-------+ + | count() | gender | Total | + |---------+--------+-------| + | 1 | F | null | + | 3 | M | null | + | 4 | null | Sum | + +---------+--------+-------+ + +Example 3: With all options +============================ + +The example shows using addcoltotals with all options set. + +PPL query:: + + os> source=accounts | where age > 30 | stats avg(balance) as avg_balance, count() as count by state | head 3 | addcoltotals avg_balance, count label='Sum' labelfield='Column Total'; + fetched rows / total rows = 4/4 + +-------------+-------+-------+--------------+ + | avg_balance | count | state | Column Total | + |-------------+-------+-------+--------------| + | 39225.0 | 1 | IL | null | + | 4180.0 | 1 | MD | null | + | 5686.0 | 1 | TN | null | + | 49091.0 | 3 | null | Sum | + +-------------+-------+-------+--------------+ + + + + + diff --git a/docs/user/ppl/cmd/addtotals.rst b/docs/user/ppl/cmd/addtotals.rst new file mode 100644 index 0000000000..58bc7f0a4e --- /dev/null +++ b/docs/user/ppl/cmd/addtotals.rst @@ -0,0 +1,110 @@ +========== +AddTotals +========== + +.. rubric:: Table of contents + +.. contents:: + :local: + :depth: 2 + +Description +=========== + +The ``addtotals`` command computes the sum of numeric fields. By default (row=true) it adds a field that stores the total of each row; with col=true it also appends a row with the total of each column to the result. This is useful for creating summary reports with subtotals or grand totals. +The ``addtotals`` command only sums numeric fields (integers, floats, doubles). Non-numeric fields are ignored, whether they are listed in field-list or no field-list is specified. + +Syntax +====== + +``addtotals [field-list] [label=<string>] [labelfield=<field>] [row=<boolean>] [col=<boolean>] [fieldname=<field>]`` + +* ``field-list``: Optional. 
Comma-separated list of numeric fields to sum. If not specified, all numeric fields are summed. +* ``row=<boolean>``: Optional. Calculates the total of each row and adds a new field with the total. Default is true. +* ``col=<boolean>``: Optional. Calculates the total of each column and adds a new event at the end of all events with the totals. Default is false. +* ``labelfield=<field>``: Optional. Field name to place the label. If it specifies a non-existing field, adds the field and shows the label in this field on the summary event row. This is applicable when col=true. +* ``label=<string>``: Optional. Custom text for the totals row labelfield's label. Default is "Total". This is applicable when col=true. This does not have any effect when the labelfield and fieldname parameters both have the same value. +* ``fieldname=<field>``: Optional. Name of the new field that stores the total of each row. Default is "Total". This is applicable when row=true. + + +Example 1: Basic Example +========================= + +The example shows placing the label in an existing field. + +PPL query:: + + os> source=accounts | head 3|fields firstname, balance | addtotals col=true labelfield='firstname' label='Total'; + fetched rows / total rows = 4/4 + +-----------+---------+-------+ + | firstname | balance | Total | + |-----------+---------+-------| + | Amber | 39225 | 39225 | + | Hattie | 5686 | 5686 | + | Nanette | 32838 | 32838 | + | Total | 77749 | null | + +-----------+---------+-------+ + +Example 2: Adding column totals and adding a summary event with label specified. +================================================================================= + +The example shows adding column totals with col=true and row=false, where the final summary event label is 'Sum'. It also adds a new field +specified by labelfield because it does not match an existing field. 
+ + +PPL query:: + + os> source=accounts | addtotals col=true row=false label='Sum' labelfield='Total'; + fetched rows / total rows = 5/5 + +----------------+-----------+----------------------+---------+--------+--------+----------+-------+-----+-----------------------+----------+-------+ + | account_number | firstname | address | balance | gender | city | employer | state | age | email | lastname | Total | + |----------------+-----------+----------------------+---------+--------+--------+----------+-------+-----+-----------------------+----------+-------| + | 1 | Amber | 880 Holmes Lane | 39225 | M | Brogan | Pyrami | IL | 32 | amberduke@pyrami.com | Duke | null | + | 6 | Hattie | 671 Bristol Street | 5686 | M | Dante | Netagy | TN | 36 | hattiebond@netagy.com | Bond | null | + | 13 | Nanette | 789 Madison Street | 32838 | F | Nogal | Quility | VA | 28 | null | Bates | null | + | 18 | Dale | 467 Hutchinson Court | 4180 | M | Orick | null | MD | 33 | daleadams@boink.com | Adams | null | + | 38 | null | null | 81929 | null | null | null | null | 129 | null | null | Sum | + +----------------+-----------+----------------------+---------+--------+--------+----------+-------+-----+-----------------------+----------+-------+ + + + +If row=true, the column added for row totals and the label field added for column totals are the same field 'Total'. In that case the final event row shows null instead of 'Sum', because the column is numeric and cannot hold a string label. 
+PPL query:: + + os> source=accounts | addtotals col=true row=true label='Sum' labelfield='Total'; + fetched rows / total rows = 5/5 + +----------------+-----------+----------------------+---------+--------+--------+----------+-------+-----+-----------------------+----------+---------+ + | account_number | firstname | address | balance | gender | city | employer | state | age | email | lastname | Total | + |----------------+-----------+----------------------+---------+--------+--------+----------+-------+-----+-----------------------+----------+---------| + | 1 | Amber | 880 Holmes Lane | 39225 | M | Brogan | Pyrami | IL | 32 | amberduke@pyrami.com | Duke | 39258.0 | + | 6 | Hattie | 671 Bristol Street | 5686 | M | Dante | Netagy | TN | 36 | hattiebond@netagy.com | Bond | 5728.0 | + | 13 | Nanette | 789 Madison Street | 32838 | F | Nogal | Quility | VA | 28 | null | Bates | 32879.0 | + | 18 | Dale | 467 Hutchinson Court | 4180 | M | Orick | null | MD | 33 | daleadams@boink.com | Adams | 4231.0 | + | 38 | null | null | 81929 | null | null | null | null | 129 | null | null | null | + +----------------+-----------+----------------------+---------+--------+--------+----------+-------+-----+-----------------------+----------+---------+ + + + + +Example 3: With all options +============================ + +The example shows using addtotals with all options set. 
+ +PPL query:: + + os> source=accounts | where age > 30 | stats avg(balance) as avg_balance, count() as count by state | head 3 | addtotals avg_balance, count row=true col=true fieldname='Row Total' label='Sum' labelfield='Column Total'; + fetched rows / total rows = 4/4 + +-------------+-------+-------+-----------+--------------+ + | avg_balance | count | state | Row Total | Column Total | + |-------------+-------+-------+-----------+--------------| + | 39225.0 | 1 | IL | 39226.0 | null | + | 4180.0 | 1 | MD | 4181.0 | null | + | 5686.0 | 1 | TN | 5687.0 | null | + | 49091.0 | 3 | null | null | Sum | + +-------------+-------+-------+-----------+--------------+ + + + + + diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/AddColTotalsCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/AddColTotalsCommandIT.java new file mode 100644 index 0000000000..139e5eec30 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/AddColTotalsCommandIT.java @@ -0,0 +1,190 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl; + +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_ACCOUNT; +import static org.opensearch.sql.util.MatcherUtils.schema; +import static org.opensearch.sql.util.MatcherUtils.verifySchema; + +import java.io.IOException; +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.Test; + +public class AddColTotalsCommandIT extends PPLIntegTestCase { + + @Override + public void init() throws IOException { + loadIndex(Index.ACCOUNT); + loadIndex(Index.BANK); + } + + @Test + public void testAddColTotalsTotalWithTotalField() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 25 | fields age, balance | addcoltotals", + TEST_INDEX_ACCOUNT)); + + // Verify that we get original rows plus totals row + verifySchema(result, schema("age", "bigint"), 
schema("balance", "bigint")); + + // Should have original data plus one totals row + var dataRows = result.getJSONArray("datarows"); + // Iterate through all data rows + ArrayList field_indexes = new ArrayList<>(); + field_indexes.add(0); + field_indexes.add(1); + + verifyColTotals(dataRows, field_indexes, null); + } + + @Test + public void testAddColTotalsRowWithSpecificFields() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 25 | fields age, balance | addcoltotals balance", + TEST_INDEX_ACCOUNT)); + + // Verify that we get original rows plus totals row + verifySchema(result, schema("age", "bigint"), schema("balance", "bigint")); + + var dataRows = result.getJSONArray("datarows"); + ArrayList field_indexes = new ArrayList<>(); + field_indexes.add(1); + + verifyColTotals(dataRows, field_indexes, null); + } + + public static boolean isNumeric(String str) { + return str != null && str.matches("-?\\d+(\\.\\d+)?"); + } + + public void verifyColTotals( + org.json.JSONArray dataRows, List field_indexes, String finalSummaryEventLevel) { + + BigDecimal[] cColTotals = new BigDecimal[field_indexes.size()]; + for (int i = 0; i < dataRows.length() - 1; i++) { + var row = dataRows.getJSONArray(i); + + // Iterate through each field in the row + for (int j = 0; j < field_indexes.size(); j++) { + + int colIndex = field_indexes.get(j); + if (cColTotals[j] == null) { + cColTotals[j] = new BigDecimal(0); + } + Object value = row.isNull(colIndex) ? 
0 : row.get(colIndex); + if (value instanceof Integer) { + cColTotals[j] = cColTotals[j].add(new BigDecimal((Integer) (value))); + } else if (value instanceof Double) { + cColTotals[j] = cColTotals[j].add(new BigDecimal((Double) (value))); + } else if (value instanceof BigDecimal) { + cColTotals[j] = cColTotals[j].add((BigDecimal) value); + + } else if (value instanceof String) { + if (AddColTotalsCommandIT.isNumeric((String) value)) { + cColTotals[j] = cColTotals[j].add(new BigDecimal((String) (value))); + } + } + } + } + var total_row = dataRows.getJSONArray((dataRows.length() - 1)); + for (int j = 0; j < field_indexes.size(); j++) { + int colIndex = field_indexes.get(j); + BigDecimal foundTotal = total_row.getBigDecimal(colIndex); + assertEquals(foundTotal.doubleValue(), cColTotals[j].doubleValue(), 0.000001); + } + if (finalSummaryEventLevel != null) { + String foundSummaryEventLabel = total_row.getString(total_row.length() - 1); + + assertEquals(foundSummaryEventLabel, finalSummaryEventLevel); + } + } + + @Test + public void testAddColTotalsRowFieldsNonNumeric() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 25 |fields age address balance | addcoltotals ", + TEST_INDEX_ACCOUNT)); + + // Verify that we get original rows plus totals row + verifySchema( + result, schema("age", "bigint"), schema("address", "string"), schema("balance", "bigint")); + + var dataRows = result.getJSONArray("datarows"); + ArrayList field_indexes = new ArrayList<>(); + field_indexes.add(0); + field_indexes.add(2); + + verifyColTotals(dataRows, field_indexes, null); + } + + @Test + public void testAddColTotalsWithCustomLabel() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 25 | head 2|fields age, balance | addcoltotals label='Sum'" + + " labelfield='Grand Total'", + TEST_INDEX_ACCOUNT)); + + verifySchema( + result, + schema("age", "bigint"), + schema("balance", "bigint"), + schema("Grand 
Total", "string")); + + var dataRows = result.getJSONArray("datarows"); + ArrayList field_indexes = new ArrayList<>(); + field_indexes.add(0); + field_indexes.add(1); + + verifyColTotals(dataRows, field_indexes, "Sum"); + } + + @Test + public void testAddColTotalsWithNoData() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 1000 | fields age, balance | addcoltotals", + TEST_INDEX_ACCOUNT)); + + // Should still have totals row even with no input data + var dataRows = result.getJSONArray("datarows"); + assertEquals(1, dataRows.length()); // Only totals row + } + + @Test + public void testAddColTotalsWithLabelAndLabelField() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 25 |head 3| fields age, balance,firstname | addcoltotals " + + " age balance label='Sum' labelfield='firstname'", + TEST_INDEX_ACCOUNT)); + + // Verify schema includes custom fieldname + verifySchema( + result, + schema("age", "bigint"), + schema("balance", "bigint"), + schema("firstname", "string")); + + var dataRows = result.getJSONArray("datarows"); + ArrayList field_indexes = new ArrayList<>(); + field_indexes.add(0); + field_indexes.add(1); + + verifyColTotals(dataRows, field_indexes, "Sum"); + } +} diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/AddTotalsCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/AddTotalsCommandIT.java new file mode 100644 index 0000000000..bc2c1f5cc8 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/AddTotalsCommandIT.java @@ -0,0 +1,366 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl; + +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_ACCOUNT; +import static org.opensearch.sql.util.MatcherUtils.schema; +import static org.opensearch.sql.util.MatcherUtils.verifySchema; + +import java.io.IOException; +import java.math.BigDecimal; +import 
java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.Test; + +public class AddTotalsCommandIT extends PPLIntegTestCase { + + @Override + public void init() throws IOException { + loadIndex(Index.ACCOUNT); + loadIndex(Index.BANK); + } + + /** + * default test without parameters + * + * @throws IOException + */ + @Test + public void testAddTotalsTotalWithTotalField() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 25 | fields age, balance | addtotals", + TEST_INDEX_ACCOUNT)); + + // Verify that we get original rows plus totals row + verifySchema( + result, schema("age", "bigint"), schema("balance", "bigint"), schema("Total", "bigint")); + + // Should have original data plus one totals row + var dataRows = result.getJSONArray("datarows"); + // Iterate through all data rows + for (int i = 0; i < dataRows.length(); i++) { + var row = dataRows.getJSONArray(i); + + BigDecimal cRowTotal = new BigDecimal(0); + // Iterate through each field in the row + for (int j = 0; j < row.length() - 1; j++) { + Object value = row.isNull(j) ? 
0 : row.get(j); + if (value instanceof Integer) { + cRowTotal = cRowTotal.add(new BigDecimal((Integer) (value))); + } else if (value instanceof Double) { + cRowTotal = cRowTotal.add(new BigDecimal((Double) (value))); + } else if (value instanceof String) { + cRowTotal = cRowTotal.add(new BigDecimal((String) (value))); + } + } + BigDecimal foundTotal = row.getBigDecimal(row.length() - 1); + assertEquals(foundTotal.doubleValue(), cRowTotal.doubleValue(), 0.000001); + } + } + + @Test + public void testAddTotalsRowWithSpecificFields() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 25 | fields age, balance | addtotals balance", + TEST_INDEX_ACCOUNT)); + + // Verify that we get original rows plus totals row + verifySchema( + result, schema("age", "bigint"), schema("balance", "bigint"), schema("Total", "bigint")); + + // sum for balance, "Total" for label + var dataRows = result.getJSONArray("datarows"); + // Iterate through all data rows + for (int i = 0; i < dataRows.length(); i++) { + var row = dataRows.getJSONArray(i); + + BigDecimal cRowTotal = new BigDecimal(0); + // Iterate through each field in the row + + Object value = row.isNull(1) ? 
0 : row.get(1); + if (value instanceof Integer) { + cRowTotal = cRowTotal.add(new BigDecimal((Integer) (value))); + } else if (value instanceof Double) { + cRowTotal = cRowTotal.add(new BigDecimal((Double) (value))); + } else if (value instanceof String) { + cRowTotal = cRowTotal.add(new BigDecimal((String) (value))); + } + + BigDecimal foundTotal = row.getBigDecimal(row.length() - 1); + assertEquals(foundTotal.doubleValue(), cRowTotal.doubleValue(), 0.000001); + } + } + + public static boolean isNumeric(String str) { + return str != null && str.matches("-?\\d+(\\.\\d+)?"); + } + + public void compareDataRowTotals( + org.json.JSONArray dataRows, List field_indexes, int totalColIndex) { + for (int i = 0; i < dataRows.length(); i++) { + var row = dataRows.getJSONArray(i); + + BigDecimal cRowTotal = new BigDecimal(0); + // Iterate through each field in the row + for (int j = 0; j < field_indexes.size(); j++) { + int colIndex = field_indexes.get(j); + Object value = row.isNull(colIndex) ? 0 : row.get(colIndex); + if (value instanceof Integer) { + cRowTotal = cRowTotal.add(new BigDecimal((Integer) (value))); + } else if (value instanceof Double) { + cRowTotal = cRowTotal.add(new BigDecimal((Double) (value))); + } else if (value instanceof BigDecimal) { + cRowTotal = cRowTotal.add((BigDecimal) value); + + } else if (value instanceof String) { + if (AddTotalsCommandIT.isNumeric((String) value)) { + cRowTotal = cRowTotal.add(new BigDecimal((String) (value))); + } + } + } + BigDecimal foundTotal = row.getBigDecimal(totalColIndex); + assertEquals(foundTotal.doubleValue(), cRowTotal.doubleValue(), 0.000001); + } + } + + public void verifyColTotals( + org.json.JSONArray dataRows, List field_indexes, String finalSummaryEventLevel) { + + BigDecimal[] cColTotals = new BigDecimal[field_indexes.size()]; + for (int i = 0; i < dataRows.length() - 1; i++) { + var row = dataRows.getJSONArray(i); + + // Iterate through each field in the row + for (int j = 0; j < field_indexes.size(); 
j++) { + + int colIndex = field_indexes.get(j); + if (cColTotals[j] == null) { + cColTotals[j] = new BigDecimal(0); + } + Object value = row.isNull(colIndex) ? 0 : row.get(colIndex); + if (value instanceof Integer) { + cColTotals[j] = cColTotals[j].add(new BigDecimal((Integer) (value))); + } else if (value instanceof Double) { + cColTotals[j] = cColTotals[j].add(new BigDecimal((Double) (value))); + } else if (value instanceof BigDecimal) { + cColTotals[j] = cColTotals[j].add((BigDecimal) value); + + } else if (value instanceof String) { + if (AddTotalsCommandIT.isNumeric((String) value)) { + cColTotals[j] = cColTotals[j].add(new BigDecimal((String) (value))); + } + } + } + } + // The last data row is the summary row; every row before it was summed above + var total_row = dataRows.getJSONArray((dataRows.length() - 1)); + for (int j = 0; j < field_indexes.size(); j++) { + int colIndex = field_indexes.get(j); + BigDecimal foundTotal = total_row.getBigDecimal(colIndex); + // JUnit order is (expected, actual): the locally summed column is the expectation + assertEquals(cColTotals[j].doubleValue(), foundTotal.doubleValue(), 0.000001); + } + // The label is read from the last column of the summary row + String foundSummaryEventLabel = total_row.getString(total_row.length() - 1); + assertEquals(finalSummaryEventLevel, foundSummaryEventLabel); + } + + @Test + public void testAddTotalsRowFieldsNonNumeric() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 25 |fields age address balance | addtotals ", + TEST_INDEX_ACCOUNT)); + + // Verify that we get original rows plus totals row + verifySchema( + result, + schema("age", "bigint"), + schema("address", "string"), + schema("balance", "bigint"), + schema("Total", "bigint")); + + // sum for balance, "Total" for label + // Should have original data plus one totals row + var dataRows = result.getJSONArray("datarows"); + // Iterate through all data rows + for (int i = 0; i < dataRows.length(); i++) { + var row = dataRows.getJSONArray(i); + + BigDecimal cRowTotal = new BigDecimal(0); + // Iterate through each field in the row + for (int j = 0; j < row.length() - 1; j++) { + Object value = row.isNull(j) ?
0 : row.get(j); + if (value instanceof Integer) { + cRowTotal = cRowTotal.add(new BigDecimal((Integer) (value))); + } else if (value instanceof Double) { + cRowTotal = cRowTotal.add(new BigDecimal((Double) (value))); + } else if (value instanceof String) { + if (AddTotalsCommandIT.isNumeric((String) value)) { + cRowTotal = cRowTotal.add(new BigDecimal((String) (value))); + } + } + } + BigDecimal foundTotal = row.getBigDecimal(row.length() - 1); + assertEquals(foundTotal.doubleValue(), cRowTotal.doubleValue(), 0.000001); + } + } + + @Test + public void testAddTotalsWithCustomLabel() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 25 | head 2|fields age, balance | addtotals" + + " fieldname='Grand Total'", + TEST_INDEX_ACCOUNT)); + + verifySchema( + result, + schema("age", "bigint"), + schema("balance", "bigint"), + schema("Grand Total", "bigint")); + } + + @Test + public void testAddTotalsAfterStats() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | stats count() by gender | addtotals `count()`", TEST_INDEX_ACCOUNT)); + + var dataRows = result.getJSONArray("datarows"); + for (int i = 0; i < dataRows.length(); i++) { + var row = dataRows.getJSONArray(i); + assertEquals(row.get(0), row.get(2)); + } + } + + @Test + public void testAddTotalsWithNoData() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 1000 | fields age, balance | addtotals", + TEST_INDEX_ACCOUNT)); + + // addtotals without col=true appends no summary row, so an empty input yields an empty result + var dataRows = result.getJSONArray("datarows"); + assertEquals(0, dataRows.length()); // No rows expected + } + + @Test + public void testAddTotalsInComplexPipeline() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 25 | stats avg(balance) as avg_balance, count() as" + + " total_count by gender | addtotals avg_balance, total_count", + TEST_INDEX_ACCOUNT)); + + var dataRows =
result.getJSONArray("datarows"); + ArrayList field_indexes = new ArrayList<>(); + field_indexes.add(0); + field_indexes.add(1); + + compareDataRowTotals(dataRows, field_indexes, 3); + } + + @Test + public void testAddTotalsWithRowFalse() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 25 | fields age, balance | addtotals row=false", + TEST_INDEX_ACCOUNT)); + + // NOTE(review): row=false appears to control the per-row Total column, not a summary row - confirm + var dataRows = result.getJSONArray("datarows"); + + // Verify that no totals row was added - all rows should have actual data + for (int i = 0; i < dataRows.length(); i++) { + var row = dataRows.getJSONArray(i); + // None of these rows should have "Total" label + for (int j = 0; j < row.length(); j++) { + if (!row.isNull(j) && row.get(j).equals("Total")) { + fail("Found totals row when row=false was specified"); + } + } + } + } + + @Test + public void testAddTotalsWithLabelAndLabelField() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 25 |head 3| fields age, balance | addtotals row=false" + + " col=true label='Sum' labelfield='Total Summary'", + TEST_INDEX_ACCOUNT)); + + // Verify schema: no Total column (row=false) and the custom labelfield column is present + verifySchema( + result, + schema("age", "bigint"), + schema("balance", "bigint"), + schema("Total Summary", "string")); + + var dataRows = result.getJSONArray("datarows"); + ArrayList field_indexes = new ArrayList<>(); + field_indexes.add(0); + field_indexes.add(1); + + verifyColTotals(dataRows, field_indexes, "Sum"); + } + + @Test + public void testAddTotalsWithFieldnameAndSpecificFields() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 25 |head 2| fields age, balance | addtotals balance" + + " fieldname='BalanceSum'", + TEST_INDEX_ACCOUNT)); + + verifySchema( + result, + schema("age", "bigint"), + schema("balance", "bigint"), + schema("BalanceSum", "bigint")); + + var dataRows =
result.getJSONArray("datarows"); + ArrayList field_indexes = new ArrayList<>(); + field_indexes.add(1); + + compareDataRowTotals(dataRows, field_indexes, 2); + } + + @Test + public void testAddTotalsWithFieldnameNoRow() throws IOException { + var result = + executeQuery( + String.format( + "source=%s | where age > 25 | fields age, balance | " + + "addtotals balance fieldname='CustomSum' row=false", + TEST_INDEX_ACCOUNT)); + + // With row=false, should not append totals row regardless of fieldname + var dataRows = result.getJSONArray("datarows"); + + // Verify that no totals row was added + for (int i = 0; i < dataRows.length(); i++) { + var row = dataRows.getJSONArray(i); + // None of these rows should have "CustomSum" label + for (int j = 0; j < row.length(); j++) { + if (!row.isNull(j) && row.get(j).equals("CustomSum")) { + fail("Found totals row when row=false was specified"); + } + } + } + } +} diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index b0650c2442..e940dc26b9 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -49,6 +49,10 @@ TRENDLINE: 'TRENDLINE'; CHART: 'CHART'; TIMECHART: 'TIMECHART'; APPENDCOL: 'APPENDCOL'; +ADDTOTALS: 'ADDTOTALS'; +ADDCOLTOTALS: 'ADDCOLTOTALS'; +ROW: 'ROW'; +COL: 'COL'; EXPAND: 'EXPAND'; SIMPLE_PATTERN: 'SIMPLE_PATTERN'; BRAIN: 'BRAIN'; @@ -59,6 +63,9 @@ MAX_SAMPLE_COUNT: 'MAX_SAMPLE_COUNT'; MAX_MATCH: 'MAX_MATCH'; OFFSET_FIELD: 'OFFSET_FIELD'; BUFFER_LIMIT: 'BUFFER_LIMIT'; +FIELDLIST: 'FIELDLIST'; +LABELFIELD: 'LABELFIELD'; +FIELDNAME: 'FIELDNAME'; LABEL: 'LABEL'; SHOW_NUMBERED_TOKEN: 'SHOW_NUMBERED_TOKEN'; AGGREGATION: 'AGGREGATION'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 3857c8557b..09819cb379 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -72,6 +72,8 @@ commands | fillnullCommand | trendlineCommand | 
appendcolCommand + | addtotalsCommand + | addcoltotalsCommand | appendCommand | expandCommand | flattenCommand @@ -117,6 +119,8 @@ commandName | EXPLAIN | REVERSE | REGEX + | ADDTOTALS + | ADDCOLTOTALS | APPEND | MULTISEARCH | REX @@ -565,6 +569,27 @@ mlArg : (argName = ident EQUAL argValue = literalValue) ; +addtotalsCommand + : ADDTOTALS (fieldList)? addtotalsOption* + ; + +addtotalsOption + : (LABEL EQUAL stringLiteral) + | (LABELFIELD EQUAL stringLiteral) + | (FIELDNAME EQUAL stringLiteral) + | (ROW EQUAL booleanLiteral) + | (COL EQUAL booleanLiteral) + ; + +addcoltotalsCommand + : ADDCOLTOTALS (fieldList)? addcoltotalsOption* + ; + +addcoltotalsOption + : (LABEL EQUAL stringLiteral) + | (LABELFIELD EQUAL stringLiteral) + ; + // clauses fromClause : SOURCE EQUAL tableOrSubqueryClause diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 09e9b4c77e..9620c47d83 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -70,6 +70,8 @@ import org.opensearch.sql.ast.expression.WindowFrame; import org.opensearch.sql.ast.expression.WindowFunction; import org.opensearch.sql.ast.tree.AD; +import org.opensearch.sql.ast.tree.AddColTotals; +import org.opensearch.sql.ast.tree.AddTotals; import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.Append; import org.opensearch.sql.ast.tree.AppendCol; @@ -1388,4 +1390,43 @@ private boolean hasActualWildcards(OpenSearchPPLParser.FieldsCommandBodyContext } return false; } + + @Override + public UnresolvedPlan visitAddtotalsCommand(OpenSearchPPLParser.AddtotalsCommandContext ctx) { + + List fieldList = new ArrayList<>(); + if (ctx.fieldList() != null) { + fieldList = getFieldList(ctx.fieldList()); + } + ImmutableMap.Builder cmdOptionsBuilder = ImmutableMap.builder(); + ctx.addtotalsOption() + .forEach( + option -> { + String 
argName = option.children.get(0).toString(); + Literal value = (Literal) internalVisitExpression(option.children.get(2)); + cmdOptionsBuilder.put(argName, value); + }); + java.util.Map options = cmdOptionsBuilder.build(); + return new AddTotals(fieldList, options); + } + + @Override + public UnresolvedPlan visitAddcoltotalsCommand( + OpenSearchPPLParser.AddcoltotalsCommandContext ctx) { + + List fieldList = new ArrayList<>(); + if (ctx.fieldList() != null) { + fieldList = getFieldList(ctx.fieldList()); + } + ImmutableMap.Builder cmdOptionsBuilder = ImmutableMap.builder(); + ctx.addcoltotalsOption() + .forEach( + option -> { + String argName = option.children.get(0).toString(); + Literal value = (Literal) internalVisitExpression(option.children.get(2)); + cmdOptionsBuilder.put(argName, value); + }); + java.util.Map options = cmdOptionsBuilder.build(); + return new AddColTotals(fieldList, options); + } } diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java index 0971924295..9db1905783 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java @@ -52,6 +52,8 @@ import org.opensearch.sql.ast.statement.Explain; import org.opensearch.sql.ast.statement.Query; import org.opensearch.sql.ast.statement.Statement; +import org.opensearch.sql.ast.tree.AddColTotals; +import org.opensearch.sql.ast.tree.AddTotals; import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.Append; import org.opensearch.sql.ast.tree.AppendCol; @@ -772,6 +774,41 @@ public String visitSpath(SPath node, String context) { return builder.toString(); } + public void appendAddTotalsOptionParameters( + List fieldList, java.util.Map options, StringBuilder builder) { + + if (!fieldList.isEmpty()) { + builder.append(visitExpressionList(fieldList, " ")); + } + if 
(!options.isEmpty()) { + for (String key : options.keySet()) { + String value = options.get(key).toString(); + if (value.matches(".*\\s.*")) { + value = StringUtils.format("'%s'", value); + } + builder.append(" ").append(key).append("=").append(value); + } + } + } + + @Override + public String visitAddTotals(AddTotals node, String context) { + String child = node.getChild().get(0).accept(this, context); + StringBuilder builder = new StringBuilder(); + builder.append(child).append(" | addtotals"); + appendAddTotalsOptionParameters(node.getFieldList(), node.getOptions(), builder); + return builder.toString(); + } + + @Override + public String visitAddColTotals(AddColTotals node, String context) { + String child = node.getChild().get(0).accept(this, context); + StringBuilder builder = new StringBuilder(); + builder.append(child).append(" | addcoltotals"); + appendAddTotalsOptionParameters(node.getFieldList(), node.getOptions(), builder); + return builder.toString(); + } + @Override public String visitPatterns(Patterns node, String context) { String child = node.getChild().get(0).accept(this, context); diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/antlr/PPLSyntaxParserTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/antlr/PPLSyntaxParserTest.java index 67403741a8..3459cf1487 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/antlr/PPLSyntaxParserTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/antlr/PPLSyntaxParserTest.java @@ -914,4 +914,36 @@ public void testWhereCommandWithDoubleEqual() { "SOURCE=test | WHERE query_string(['field1', 'field2' ^ 3.2], 'test query'," + " analyzer='keyword')")); } + + @Test + public void testAddTotalsCommandShouldPass() { + ParseTree tree = new PPLSyntaxParser().parse("source=t | addtotals"); + assertNotEquals(null, tree); + } + + @Test + public void testAddTotalsCommandWithFieldsShouldPass() { + ParseTree tree = new PPLSyntaxParser().parse("source=t | addtotals price, quantity"); + assertNotEquals(null, tree); + 
} + + @Test + public void testAddTotalsCommandWithLabelShouldPass() { + ParseTree tree = new PPLSyntaxParser().parse("source=t | addtotals label='Grand Total'"); + assertNotEquals(null, tree); + } + + @Test + public void testAddTotalsCommandWithLabelFieldShouldPass() { + ParseTree tree = new PPLSyntaxParser().parse("source=t | addtotals labelfield='category'"); + assertNotEquals(null, tree); + } + + @Test + public void testAddTotalsCommandWithAllOptionsShouldPass() { + ParseTree tree = + new PPLSyntaxParser() + .parse("source=t | addtotals price, quantity label='Total' labelfield='type'"); + assertNotEquals(null, tree); + } } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAddColTotalsTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAddColTotalsTest.java new file mode 100644 index 0000000000..d14dd00c20 --- /dev/null +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAddColTotalsTest.java @@ -0,0 +1,274 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.calcite; + +import java.io.IOException; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.test.CalciteAssert; +import org.junit.Test; + +public class CalcitePPLAddColTotalsTest extends CalcitePPLAbstractTest { + + public CalcitePPLAddColTotalsTest() { + super(CalciteAssert.SchemaSpec.SCOTT_WITH_TEMPORAL); + } + + @Test + public void testAddColTotals() throws IOException { + String ppl = "source=EMP | fields DEPTNO, SAL, JOB | addcoltotals "; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(DEPTNO=[$0], SAL=[$1], JOB=[null:VARCHAR(9)])\n" + + " LogicalAggregate(group=[{}], DEPTNO=[SUM($0)], SAL=[SUM($1)])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + 
verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK\n" + + "DEPTNO=310; SAL=29025.00; JOB=null\n"; + + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`\n" + + "FROM `scott`.`EMP`\n" + + "UNION ALL\n" + + "SELECT SUM(`DEPTNO`) `DEPTNO`, SUM(`SAL`) `SAL`, CAST(NULL AS STRING) `JOB`\n" + + "FROM `scott`.`EMP`"; + + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddColTotalsFieldSpecified() throws IOException { + String ppl = "source=EMP | fields DEPTNO, SAL, JOB | addcoltotals SAL "; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(DEPTNO=[null:TINYINT], SAL=[$0], JOB=[null:VARCHAR(9)])\n" + + " LogicalAggregate(group=[{}], SAL=[SUM($0)])\n" + + " LogicalProject(SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER\n" + + "DEPTNO=20; SAL=3000.00; 
JOB=ANALYST\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK\n" + + "DEPTNO=null; SAL=29025.00; JOB=null\n"; + + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`\n" + + "FROM `scott`.`EMP`\n" + + "UNION ALL\n" + + "SELECT CAST(NULL AS TINYINT) `DEPTNO`, SUM(`SAL`) `SAL`, CAST(NULL AS STRING)" + + " `JOB`\n" + + "FROM `scott`.`EMP`"; + + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddColTotalsAllFields() throws IOException { + String ppl = "source=EMP | fields DEPTNO, SAL, JOB | addcoltotals "; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(DEPTNO=[$0], SAL=[$1], JOB=[null:VARCHAR(9)])\n" + + " LogicalAggregate(group=[{}], DEPTNO=[SUM($0)], SAL=[SUM($1)])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK\n" + + "DEPTNO=310; SAL=29025.00; JOB=null\n"; + + verifyResult(root, expectedResult); + + String 
expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`\n" + + "FROM `scott`.`EMP`\n" + + "UNION ALL\n" + + "SELECT SUM(`DEPTNO`) `DEPTNO`, SUM(`SAL`) `SAL`, CAST(NULL AS STRING) `JOB`\n" + + "FROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddColTotalsMultiFields() throws IOException { + String ppl = "source=EMP | fields DEPTNO, SAL, JOB | addcoltotals DEPTNO SAL "; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(DEPTNO=[$0], SAL=[$1], JOB=[null:VARCHAR(9)])\n" + + " LogicalAggregate(group=[{}], DEPTNO=[SUM($0)], SAL=[SUM($1)])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK\n" + + "DEPTNO=310; SAL=29025.00; JOB=null\n"; + + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`\n" + + "FROM `scott`.`EMP`\n" + + "UNION ALL\n" + + "SELECT SUM(`DEPTNO`) `DEPTNO`, SUM(`SAL`) `SAL`, CAST(NULL AS STRING) `JOB`\n" + + "FROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddColTotalsWithAllOptions() throws IOException { + String ppl = + "source=EMP | fields 
DEPTNO, SAL, JOB | addcoltotals SAL label='GrandTotal'" + + " labelfield='all_emp_total' "; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2]," + + " all_emp_total=[null:VARCHAR(13)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(DEPTNO=[null:TINYINT], SAL=[$0], JOB=[null:VARCHAR(9)]," + + " all_emp_total=['GrandTotal':VARCHAR(13)])\n" + + " LogicalAggregate(group=[{}], SAL=[SUM($0)])\n" + + " LogicalProject(SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK; all_emp_total=null\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN; all_emp_total=null\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; all_emp_total=null\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER; all_emp_total=null\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; all_emp_total=null\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER; all_emp_total=null\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER; all_emp_total=null\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; all_emp_total=null\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT; all_emp_total=null\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN; all_emp_total=null\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK; all_emp_total=null\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK; all_emp_total=null\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; all_emp_total=null\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK; all_emp_total=null\n" + + "DEPTNO=null; SAL=29025.00; JOB=null; all_emp_total=GrandTotal\n"; + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`, CAST(NULL AS STRING) `all_emp_total`\n" + + "FROM `scott`.`EMP`\n" + + "UNION ALL\n" + + "SELECT CAST(NULL AS TINYINT) `DEPTNO`, SUM(`SAL`) `SAL`, CAST(NULL AS STRING) `JOB`," + + " 'GrandTotal' `all_emp_total`\n" + + "FROM `scott`.`EMP`"; + + verifyPPLToSparkSQL(root, 
expectedSparkSql); + } + + @Test + public void testAddColTotalsMatchingLabelFieldWithExisting() throws IOException { + String ppl = + "source=EMP | fields DEPTNO, SAL, JOB | addcoltotals SAL label='GrandTotal'" + + " labelfield='JOB' "; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(DEPTNO=[null:TINYINT], SAL=[$0], JOB=['GrandTota':VARCHAR(9)])\n" + + " LogicalAggregate(group=[{}], SAL=[SUM($0)])\n" + + " LogicalProject(SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK\n" + + "DEPTNO=null; SAL=29025.00; JOB=GrandTota\n"; + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`\n" + + "FROM `scott`.`EMP`\n" + + "UNION ALL\n" + + "SELECT CAST(NULL AS TINYINT) `DEPTNO`, SUM(`SAL`) `SAL`, 'GrandTota' `JOB`\n" + + "FROM `scott`.`EMP`"; + + verifyPPLToSparkSQL(root, expectedSparkSql); + } +} diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAddTotalsTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAddTotalsTest.java new file mode 100644 index 0000000000..386baa5ff9 --- /dev/null +++ 
b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAddTotalsTest.java @@ -0,0 +1,426 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.calcite; + +import java.io.IOException; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.test.CalciteAssert; +import org.junit.Test; + +public class CalcitePPLAddTotalsTest extends CalcitePPLAbstractTest { + + public CalcitePPLAddTotalsTest() { + super(CalciteAssert.SchemaSpec.SCOTT_WITH_TEMPORAL); + } + + @Test + public void testAddTotals() throws IOException { + String ppl = "source=EMP | fields DEPTNO, SAL, JOB | addtotals SAL "; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2], Total=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK; Total=800.00\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN; Total=1600.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; Total=1250.00\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER; Total=2975.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; Total=1250.00\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER; Total=2850.00\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER; Total=2450.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; Total=3000.00\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT; Total=5000.00\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN; Total=1500.00\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK; Total=1100.00\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK; Total=950.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; Total=3000.00\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK; Total=1300.00\n"; + + verifyResult(root, expectedResult); + + String expectedSparkSql = "SELECT `DEPTNO`, `SAL`, `JOB`, `SAL` `Total`\nFROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddTotalsAllFields() throws IOException { + 
String ppl = "source=EMP | fields DEPTNO, SAL, JOB | addtotals "; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2], Total=[+($7, $5)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK; Total=820.00\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN; Total=1630.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; Total=1280.00\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER; Total=2995.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; Total=1280.00\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER; Total=2880.00\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER; Total=2460.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; Total=3020.00\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT; Total=5010.00\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN; Total=1530.00\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK; Total=1120.00\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK; Total=980.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; Total=3020.00\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK; Total=1310.00\n"; + + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`, `DEPTNO` + `SAL` `Total`\nFROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddTotalsMultiFields() throws IOException { + String ppl = "source=EMP | fields DEPTNO, SAL, JOB | addtotals DEPTNO SAL "; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2], Total=[+($7, $5)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK; Total=820.00\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN; Total=1630.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; Total=1280.00\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER; Total=2995.00\n" + + "DEPTNO=30; SAL=1250.00; 
JOB=SALESMAN; Total=1280.00\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER; Total=2880.00\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER; Total=2460.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; Total=3020.00\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT; Total=5010.00\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN; Total=1530.00\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK; Total=1120.00\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK; Total=980.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; Total=3020.00\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK; Total=1310.00\n"; + + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`, `DEPTNO` + `SAL` `Total`\nFROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + /** Checks that addtotals names the row-total column after the fieldname option. */ + @Test + public void testAddTotalsWithFieldname() throws IOException { + String ppl = "source=EMP | fields DEPTNO, SAL, JOB | addtotals SAL fieldname='CustomSum' "; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2], CustomSum=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK; CustomSum=800.00\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN; CustomSum=1600.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; CustomSum=1250.00\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER; CustomSum=2975.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; CustomSum=1250.00\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER; CustomSum=2850.00\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER; CustomSum=2450.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; CustomSum=3000.00\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT; CustomSum=5000.00\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN; CustomSum=1500.00\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK; CustomSum=1100.00\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK; CustomSum=950.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; CustomSum=3000.00\n" + + "DEPTNO=10; 
SAL=1300.00; JOB=CLERK; CustomSum=1300.00\n"; + + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`, `SAL` `CustomSum`\nFROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddTotalsWithFieldnameRowOptionTrue() throws IOException { + String ppl = + "source=EMP | fields DEPTNO, SAL, JOB | addtotals SAL fieldname='CustomSum' row=true "; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2], CustomSum=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK; CustomSum=800.00\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN; CustomSum=1600.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; CustomSum=1250.00\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER; CustomSum=2975.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; CustomSum=1250.00\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER; CustomSum=2850.00\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER; CustomSum=2450.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; CustomSum=3000.00\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT; CustomSum=5000.00\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN; CustomSum=1500.00\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK; CustomSum=1100.00\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK; CustomSum=950.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; CustomSum=3000.00\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK; CustomSum=1300.00\n"; + + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`, `SAL` `CustomSum`\nFROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddTotalsWithFieldnameRowOptionFalse() throws IOException { + String ppl = + "source=EMP | fields DEPTNO, SAL, JOB | addtotals SAL fieldname='CustomSum' row=false "; + RelNode root = getRelNode(ppl); + String 
expectedLogical = + "LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK\n"; + + verifyResult(root, expectedResult); + + String expectedSparkSql = "SELECT `DEPTNO`, `SAL`, `JOB`\nFROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddTotalsWithColTrueNoSummaryLabel() throws IOException { + String ppl = "source=EMP | fields DEPTNO, SAL, JOB | addtotals SAL col=true"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2], Total=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(DEPTNO=[null:TINYINT], SAL=[$0], JOB=[null:VARCHAR(9)]," + + " Total=[null:DECIMAL(7, 2)])\n" + + " LogicalAggregate(group=[{}], SAL=[SUM($0)])\n" + + " LogicalProject(SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK; Total=800.00\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN; Total=1600.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; Total=1250.00\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER; Total=2975.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; Total=1250.00\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER; 
Total=2850.00\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER; Total=2450.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; Total=3000.00\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT; Total=5000.00\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN; Total=1500.00\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK; Total=1100.00\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK; Total=950.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; Total=3000.00\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK; Total=1300.00\n" + + "DEPTNO=null; SAL=29025.00; JOB=null; Total=null\n"; + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`, `SAL` `Total`\n" + + "FROM `scott`.`EMP`\n" + + "UNION ALL\n" + + "SELECT CAST(NULL AS TINYINT) `DEPTNO`, SUM(`SAL`) `SAL`, CAST(NULL AS STRING) `JOB`," + + " CAST(NULL AS DECIMAL(7, 2)) `Total`\n" + + "FROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddTotalsWithColTrueRowFalseNoSummaryLabel() throws IOException { + String ppl = "source=EMP | fields DEPTNO, SAL, JOB | addtotals SAL col=true row=false"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(DEPTNO=[null:TINYINT], SAL=[$0], JOB=[null:VARCHAR(9)])\n" + + " LogicalAggregate(group=[{}], SAL=[SUM($0)])\n" + + " LogicalProject(SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT\n" + + "DEPTNO=30; 
SAL=1500.00; JOB=SALESMAN\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK\n" + + "DEPTNO=null; SAL=29025.00; JOB=null\n"; + + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`\n" + + "FROM `scott`.`EMP`\n" + + "UNION ALL\n" + + "SELECT CAST(NULL AS TINYINT) `DEPTNO`, SUM(`SAL`) `SAL`, CAST(NULL AS STRING)" + + " `JOB`\n" + + "FROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddTotalsWithAllOptionsIncludingDefaultFieldname() throws IOException { + String ppl = + "source=EMP | fields DEPTNO, SAL, JOB | addtotals SAL label='ColTotal'" + + " labelfield='Total' col=true"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2], Total=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(DEPTNO=[null:TINYINT], SAL=[$0], JOB=[null:VARCHAR(9)]," + + " Total=[null:DECIMAL(7, 2)])\n" + + " LogicalAggregate(group=[{}], SAL=[SUM($0)])\n" + + " LogicalProject(SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK; Total=800.00\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN; Total=1600.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; Total=1250.00\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER; Total=2975.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; Total=1250.00\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER; Total=2850.00\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER; Total=2450.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; Total=3000.00\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT; Total=5000.00\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN; Total=1500.00\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK; Total=1100.00\n" + + "DEPTNO=30; SAL=950.00; 
JOB=CLERK; Total=950.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; Total=3000.00\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK; Total=1300.00\n" + + "DEPTNO=null; SAL=29025.00; JOB=null; Total=null\n"; + // by default row=true , new field added as 'Total' and labelfield='Total' will have conflict + // and 'ColTotal' will not be set in Total column as it will be number type being row=true + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`, `SAL` `Total`\n" + + "FROM `scott`.`EMP`\n" + + "UNION ALL\n" + + "SELECT CAST(NULL AS TINYINT) `DEPTNO`, SUM(`SAL`) `SAL`, CAST(NULL AS STRING) `JOB`," + + " CAST(NULL AS DECIMAL(7, 2)) `Total`\n" + + "FROM `scott`.`EMP`"; + + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddTotalsWithAllOptionsIncludingFieldname() throws IOException { + String ppl = + "source=EMP | fields DEPTNO, SAL, JOB | addtotals SAL label='ColTotal'" + + " fieldname='CustomSum' labelfield='all_emp_total' row=true col=true"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2], CustomSum=[$5]," + + " all_emp_total=[null:NULL])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(DEPTNO=[null:NULL], SAL=[$0], JOB=[null:NULL]," + + " CustomSum=[null:NULL], all_emp_total=['ColTotal'])\n" + + " LogicalAggregate(group=[{}], SAL=[SUM($0)])\n" + + " LogicalProject(SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + // TODO(review): the logical-plan check below is disabled, so expectedLogical is never + // asserted. It lists null:NULL where expectedSparkSql in this test shows typed NULL casts; + // presumably the string is stale -- confirm the actual plan, fix it, and re-enable. + // verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK; CustomSum=800.00; all_emp_total=null\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN; CustomSum=1600.00; all_emp_total=null\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; CustomSum=1250.00; all_emp_total=null\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER; CustomSum=2975.00; all_emp_total=null\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; CustomSum=1250.00; 
all_emp_total=null\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER; CustomSum=2850.00; all_emp_total=null\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER; CustomSum=2450.00; all_emp_total=null\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; CustomSum=3000.00; all_emp_total=null\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT; CustomSum=5000.00; all_emp_total=null\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN; CustomSum=1500.00; all_emp_total=null\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK; CustomSum=1100.00; all_emp_total=null\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK; CustomSum=950.00; all_emp_total=null\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; CustomSum=3000.00; all_emp_total=null\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK; CustomSum=1300.00; all_emp_total=null\n" + + "DEPTNO=null; SAL=29025.00; JOB=null; CustomSum=null; all_emp_total=ColTotal\n"; + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`, `SAL` `CustomSum`, CAST(NULL AS STRING) `all_emp_total`\n" + + "FROM `scott`.`EMP`\n" + + "UNION ALL\n" + + "SELECT CAST(NULL AS TINYINT) `DEPTNO`, SUM(`SAL`) `SAL`, CAST(NULL AS STRING) `JOB`," + + " CAST(NULL AS DECIMAL(7, 2)) `CustomSum`, 'ColTotal' `all_emp_total`\n" + + "FROM `scott`.`EMP`"; + + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testAddTotalsMatchingLabelFieldWithExisting() throws IOException { + String ppl = + "source=EMP | fields DEPTNO, SAL, JOB | addtotals SAL DEPTNO col=true label='ColTotal'" + + " labelfield='JOB' "; + // default is row=true for addtotals + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2], Total=[+($7, $5)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(DEPTNO=[$0], SAL=[$1], JOB=['ColTotal':VARCHAR(9)]," + + " Total=[null:DECIMAL(8, 2)])\n" + + " LogicalAggregate(group=[{}], DEPTNO=[SUM($0)], SAL=[SUM($1)])\n" + + " LogicalProject(DEPTNO=[$7], 
SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "DEPTNO=20; SAL=800.00; JOB=CLERK; Total=820.00\n" + + "DEPTNO=30; SAL=1600.00; JOB=SALESMAN; Total=1630.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; Total=1280.00\n" + + "DEPTNO=20; SAL=2975.00; JOB=MANAGER; Total=2995.00\n" + + "DEPTNO=30; SAL=1250.00; JOB=SALESMAN; Total=1280.00\n" + + "DEPTNO=30; SAL=2850.00; JOB=MANAGER; Total=2880.00\n" + + "DEPTNO=10; SAL=2450.00; JOB=MANAGER; Total=2460.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; Total=3020.00\n" + + "DEPTNO=10; SAL=5000.00; JOB=PRESIDENT; Total=5010.00\n" + + "DEPTNO=30; SAL=1500.00; JOB=SALESMAN; Total=1530.00\n" + + "DEPTNO=20; SAL=1100.00; JOB=CLERK; Total=1120.00\n" + + "DEPTNO=30; SAL=950.00; JOB=CLERK; Total=980.00\n" + + "DEPTNO=20; SAL=3000.00; JOB=ANALYST; Total=3020.00\n" + + "DEPTNO=10; SAL=1300.00; JOB=CLERK; Total=1310.00\n" + + "DEPTNO=310; SAL=29025.00; JOB=ColTotal; Total=null\n"; + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT `DEPTNO`, `SAL`, `JOB`, `DEPTNO` + `SAL` `Total`\n" + + "FROM `scott`.`EMP`\n" + + "UNION ALL\n" + + "SELECT SUM(`DEPTNO`) `DEPTNO`, SUM(`SAL`) `SAL`, 'ColTotal' `JOB`, CAST(NULL AS" + + " DECIMAL(8, 2)) `Total`\n" + + "FROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } +} diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java index ec87000b5b..167574a828 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java @@ -479,6 +479,23 @@ public void testAppendcol() { anonymize("source=t | appendcol override=false [ where a = 1 ]")); } + @Test + public void testAddTotals() { + assertEquals( + "source=table | addtotals row=true col=true label=identifier 
labelfield=identifier" + + " fieldname=identifier", + anonymize( + "source=table | addtotals row=true col=true label='identifier' labelfield='identifier'" + + " fieldname='identifier'")); + } + + @Test + public void testAddColTotals() { + assertEquals( + "source=table | addcoltotals label=identifier labelfield=identifier", + anonymize("source=table | addcoltotals label='identifier' labelfield='identifier'")); + } + @Test public void testAppend() { assertEquals( @@ -678,8 +695,8 @@ public void testGrok() { @Test public void testReplaceCommandSingleField() { assertEquals( - "source=table | replace *** WITH *** IN Field(field=fieldname, fieldArgs=[])", - anonymize("source=EMP | replace \"value\" WITH \"newvalue\" IN fieldname")); + "source=table | replace *** WITH *** IN Field(field=fieldname1, fieldArgs=[])", + anonymize("source=EMP | replace \"value\" WITH \"newvalue\" IN fieldname1")); } @Test @@ -706,8 +723,8 @@ public void testReplaceCommandSpecialCharactersInFields() { @Test public void testReplaceCommandWithWildcards() { assertEquals( - "source=table | replace *** WITH *** IN Field(field=fieldname, fieldArgs=[])", - anonymize("source=EMP | replace \"CLERK*\" WITH \"EMPLOYEE*\" IN fieldname")); + "source=table | replace *** WITH *** IN Field(field=fieldname1, fieldArgs=[])", + anonymize("source=EMP | replace \"CLERK*\" WITH \"EMPLOYEE*\" IN fieldname1")); } @Test @@ -745,9 +762,9 @@ public void testPatterns() { @Test public void testRegex() { assertEquals( - "source=table | regex identifier=***", anonymize("source=t | regex fieldname='pattern'")); + "source=table | regex identifier=***", anonymize("source=t | regex field='pattern'")); assertEquals( - "source=table | regex identifier!=***", anonymize("source=t | regex fieldname!='pattern'")); + "source=table | regex identifier!=***", anonymize("source=t | regex field!='pattern'")); assertEquals( "source=table | regex identifier=*** | fields + identifier", anonymize("source=t | regex email='.*@domain.com' | fields 
email"));