Skip to content

Commit 07124d0

Browse files
committed
Add udf interface (#3374)
* add udf/udaf interface and take/sqrt function Signed-off-by: xinyual <xinyual@amazon.com> * add UT Signed-off-by: xinyual <xinyual@amazon.com> * add POW, Atan, Atan2 and corresponding UT Signed-off-by: xinyual <xinyual@amazon.com> * apply spotless Signed-off-by: xinyual <xinyual@amazon.com> * fix table for join it Signed-off-by: xinyual <xinyual@amazon.com> * add java doc Signed-off-by: xinyual <xinyual@amazon.com> * apply spotless Signed-off-by: xinyual <xinyual@amazon.com> --------- Signed-off-by: xinyual <xinyual@amazon.com>
1 parent 874d6fb commit 07124d0

File tree

16 files changed

+455
-11
lines changed

16 files changed

+455
-11
lines changed

core/src/main/java/org/opensearch/sql/calcite/CalciteAggCallVisitor.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
package org.opensearch.sql.calcite;
77

8+
import java.util.ArrayList;
9+
import java.util.List;
810
import org.apache.calcite.rex.RexNode;
911
import org.apache.calcite.tools.RelBuilder.AggCall;
1012
import org.opensearch.sql.ast.AbstractNodeVisitor;
@@ -33,6 +35,10 @@ public AggCall visitAlias(Alias node, CalcitePlanContext context) {
3335
@Override
3436
public AggCall visitAggregateFunction(AggregateFunction node, CalcitePlanContext context) {
3537
RexNode field = rexNodeVisitor.analyze(node.getField(), context);
36-
return AggregateUtils.translate(node, field, context);
38+
List<RexNode> argList = new ArrayList<>();
39+
for (UnresolvedExpression arg : node.getArgList()) {
40+
argList.add(rexNodeVisitor.analyze(arg, context));
41+
}
42+
return AggregateUtils.translate(node, field, context, argList);
3743
}
3844
}

core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import static org.opensearch.sql.ast.expression.SpanUnit.NONE;
99
import static org.opensearch.sql.ast.expression.SpanUnit.UNKNOWN;
10+
import static org.opensearch.sql.calcite.utils.BuiltinFunctionUtils.translateArgument;
1011

1112
import java.math.BigDecimal;
1213
import java.util.List;
@@ -254,7 +255,8 @@ public RexNode visitFunction(Function node, CalcitePlanContext context) {
254255
List<RexNode> arguments =
255256
node.getFuncArgs().stream().map(arg -> analyze(arg, context)).collect(Collectors.toList());
256257
return context.rexBuilder.makeCall(
257-
BuiltinFunctionUtils.translate(node.getFuncName()), arguments);
258+
BuiltinFunctionUtils.translate(node.getFuncName()),
259+
translateArgument(node.getFuncName(), arguments, context));
258260
}
259261

260262
@Override
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.calcite.udf;
7+
8+
public interface UserDefinedAggFunction<S extends UserDefinedAggFunction.Accumulator> {
9+
/**
10+
* @return {@link Accumulator}
11+
*/
12+
S init();
13+
14+
/**
15+
* @param {@link Accumulator}
16+
* @return final result
17+
*/
18+
Object result(S accumulator);
19+
20+
/**
21+
* Add values to the accumulator. Notice some init argument will also be here like the 50 in
22+
* Percentile(field, 50).
23+
*
24+
* @param acc {@link Accumulator}
25+
* @param values the value to add to accumulator
26+
* @return {@link Accumulator}
27+
*/
28+
S add(S acc, Object... values);
29+
30+
/** Maintain the state when {@link UserDefinedAggFunction} add values */
31+
interface Accumulator {
32+
/**
33+
* @return the final aggregation value
34+
*/
35+
Object value();
36+
}
37+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.calcite.udf;
7+
8+
public interface UserDefinedFunction {
9+
Object eval(Object... args);
10+
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.calcite.udf.mathUDF;
7+
8+
import static java.lang.Math.sqrt;
9+
10+
import org.opensearch.sql.calcite.udf.UserDefinedFunction;
11+
12+
public class SqrtFunction implements UserDefinedFunction {
13+
@Override
14+
public Object eval(Object... args) {
15+
if (args.length < 1) {
16+
throw new IllegalArgumentException("At least one argument is required");
17+
}
18+
19+
// Get the input value
20+
Object input = args[0];
21+
22+
// Handle numbers dynamically
23+
if (input instanceof Number) {
24+
double num = ((Number) input).doubleValue();
25+
26+
if (num < 0) {
27+
throw new ArithmeticException("Cannot compute square root of a negative number");
28+
}
29+
30+
return sqrt(num); // Computes sqrt using Math.sqrt()
31+
} else {
32+
throw new IllegalArgumentException("Invalid argument type: Expected a numeric value");
33+
}
34+
}
35+
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.calcite.udf.udaf;
7+
8+
import java.util.ArrayList;
9+
import java.util.List;
10+
import org.opensearch.sql.calcite.udf.UserDefinedAggFunction;
11+
12+
public class TakeAggFunction implements UserDefinedAggFunction<TakeAggFunction.TakeAccumulator> {
13+
14+
@Override
15+
public TakeAccumulator init() {
16+
return new TakeAccumulator();
17+
}
18+
19+
@Override
20+
public Object result(TakeAccumulator accumulator) {
21+
return accumulator.value();
22+
}
23+
24+
@Override
25+
public TakeAccumulator add(TakeAccumulator acc, Object... values) {
26+
Object candidateValue = values[0];
27+
int size = 0;
28+
if (values.length > 1) {
29+
size = (int) values[1];
30+
} else {
31+
size = 10;
32+
}
33+
if (size > acc.size()) {
34+
acc.add(candidateValue);
35+
}
36+
return acc;
37+
}
38+
39+
public static class TakeAccumulator implements Accumulator {
40+
private List<Object> hits;
41+
42+
public TakeAccumulator() {
43+
hits = new ArrayList<>();
44+
}
45+
46+
@Override
47+
public Object value() {
48+
return hits;
49+
}
50+
51+
public void add(Object value) {
52+
hits.add(value);
53+
}
54+
55+
public int size() {
56+
return hits.size();
57+
}
58+
}
59+
}

core/src/main/java/org/opensearch/sql/calcite/utils/AggregateUtils.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55

66
package org.opensearch.sql.calcite.utils;
77

8+
import static org.opensearch.sql.calcite.utils.UserDefineFunctionUtils.TransferUserDefinedAggFunction;
9+
810
import com.google.common.collect.ImmutableList;
11+
import java.util.List;
912
import org.apache.calcite.rel.RelCollations;
1013
import org.apache.calcite.rel.core.AggregateCall;
1114
import org.apache.calcite.rex.RexInputRef;
@@ -15,12 +18,13 @@
1518
import org.apache.calcite.tools.RelBuilder;
1619
import org.opensearch.sql.ast.expression.AggregateFunction;
1720
import org.opensearch.sql.calcite.CalcitePlanContext;
21+
import org.opensearch.sql.calcite.udf.udaf.TakeAggFunction;
1822
import org.opensearch.sql.expression.function.BuiltinFunctionName;
1923

2024
public interface AggregateUtils {
2125

2226
static RelBuilder.AggCall translate(
23-
AggregateFunction agg, RexNode field, CalcitePlanContext context) {
27+
AggregateFunction agg, RexNode field, CalcitePlanContext context, List<RexNode> argList) {
2428
if (BuiltinFunctionName.ofAggregation(agg.getFuncName()).isEmpty())
2529
throw new IllegalStateException("Unexpected value: " + agg.getFuncName());
2630

@@ -50,6 +54,14 @@ static RelBuilder.AggCall translate(
5054
// case PERCENTILE_APPROX:
5155
// return
5256
// context.relBuilder.aggregateCall(SqlStdOperatorTable.PERCENTILE_CONT, field);
57+
case TAKE:
58+
return TransferUserDefinedAggFunction(
59+
TakeAggFunction.class,
60+
"TAKE",
61+
UserDefineFunctionUtils.getReturnTypeInferenceForArray(),
62+
List.of(field),
63+
argList,
64+
context.relBuilder);
5365
case PERCENTILE_APPROX:
5466
throw new UnsupportedOperationException("PERCENTILE_APPROX is not supported in PPL");
5567
// case APPROX_COUNT_DISTINCT:

core/src/main/java/org/opensearch/sql/calcite/utils/BuiltinFunctionUtils.java

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,19 @@
55

66
package org.opensearch.sql.calcite.utils;
77

8+
import static org.opensearch.sql.calcite.utils.UserDefineFunctionUtils.TransferUserDefinedFunction;
9+
10+
import java.math.BigDecimal;
11+
import java.util.ArrayList;
12+
import java.util.List;
813
import java.util.Locale;
14+
import org.apache.calcite.rex.RexNode;
915
import org.apache.calcite.sql.SqlOperator;
1016
import org.apache.calcite.sql.fun.SqlLibraryOperators;
1117
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
18+
import org.apache.calcite.sql.type.ReturnTypes;
19+
import org.opensearch.sql.calcite.CalcitePlanContext;
20+
import org.opensearch.sql.calcite.udf.mathUDF.SqrtFunction;
1221

1322
public interface BuiltinFunctionUtils {
1423

@@ -51,6 +60,12 @@ static SqlOperator translate(String op) {
5160
// Built-in Math Functions
5261
case "ABS":
5362
return SqlStdOperatorTable.ABS;
63+
case "SQRT":
64+
return TransferUserDefinedFunction(SqrtFunction.class, "SQRT", ReturnTypes.DOUBLE);
65+
case "ATAN", "ATAN2":
66+
return SqlStdOperatorTable.ATAN2;
67+
case "POW", "POWER":
68+
return SqlStdOperatorTable.POWER;
5469
// Built-in Date Functions
5570
case "CURRENT_TIMESTAMP":
5671
return SqlStdOperatorTable.CURRENT_TIMESTAMP;
@@ -67,4 +82,32 @@ static SqlOperator translate(String op) {
6782
throw new IllegalArgumentException("Unsupported operator: " + op);
6883
}
6984
}
85+
86+
/**
87+
* Translates function arguments to align with Calcite's expectations, ensuring compatibility with
88+
* PPL (Piped Processing Language). This is necessary because Calcite's input argument order or
89+
* default values may differ from PPL's function definitions.
90+
*
91+
* @param op The function name as a string.
92+
* @param argList A list of {@link RexNode} representing the parsed arguments from the PPL
93+
* statement.
94+
* @param context The {@link CalcitePlanContext} providing necessary utilities such as {@code
95+
* rexBuilder}.
96+
* @return A modified list of {@link RexNode} that correctly maps to Calcite’s function
97+
* expectations.
98+
*/
99+
static List<RexNode> translateArgument(
100+
String op, List<RexNode> argList, CalcitePlanContext context) {
101+
switch (op.toUpperCase(Locale.ROOT)) {
102+
case "ATAN":
103+
List<RexNode> AtanArgs = new ArrayList<>(argList);
104+
if (AtanArgs.size() == 1) {
105+
BigDecimal divideNumber = BigDecimal.valueOf(1);
106+
AtanArgs.add(context.rexBuilder.makeBigintLiteral(divideNumber));
107+
}
108+
return AtanArgs;
109+
default:
110+
return argList;
111+
}
112+
}
70113
}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.calcite.utils;
7+
8+
import static org.apache.calcite.sql.type.SqlTypeUtil.createArrayType;
9+
10+
import java.util.ArrayList;
11+
import java.util.Collections;
12+
import java.util.List;
13+
import org.apache.calcite.linq4j.tree.Types;
14+
import org.apache.calcite.rel.type.RelDataType;
15+
import org.apache.calcite.rel.type.RelDataTypeFactory;
16+
import org.apache.calcite.rex.RexNode;
17+
import org.apache.calcite.schema.ScalarFunction;
18+
import org.apache.calcite.schema.impl.AggregateFunctionImpl;
19+
import org.apache.calcite.schema.impl.ScalarFunctionImpl;
20+
import org.apache.calcite.sql.SqlIdentifier;
21+
import org.apache.calcite.sql.SqlKind;
22+
import org.apache.calcite.sql.SqlOperator;
23+
import org.apache.calcite.sql.parser.SqlParserPos;
24+
import org.apache.calcite.sql.type.SqlReturnTypeInference;
25+
import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction;
26+
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
27+
import org.apache.calcite.tools.RelBuilder;
28+
import org.apache.calcite.util.Optionality;
29+
import org.opensearch.sql.calcite.udf.UserDefinedAggFunction;
30+
import org.opensearch.sql.calcite.udf.UserDefinedFunction;
31+
32+
public class UserDefineFunctionUtils {
33+
public static RelBuilder.AggCall TransferUserDefinedAggFunction(
34+
Class<? extends UserDefinedAggFunction> UDAF,
35+
String functionName,
36+
SqlReturnTypeInference returnType,
37+
List<RexNode> fields,
38+
List<RexNode> argList,
39+
RelBuilder relBuilder) {
40+
SqlUserDefinedAggFunction sqlUDAF =
41+
new SqlUserDefinedAggFunction(
42+
new SqlIdentifier(functionName, SqlParserPos.ZERO),
43+
SqlKind.OTHER_FUNCTION,
44+
returnType,
45+
null,
46+
null,
47+
AggregateFunctionImpl.create(UDAF),
48+
false,
49+
false,
50+
Optionality.FORBIDDEN);
51+
List<RexNode> addArgList = new ArrayList<>(fields);
52+
addArgList.addAll(argList);
53+
return relBuilder.aggregateCall(sqlUDAF, addArgList);
54+
}
55+
56+
public static SqlOperator TransferUserDefinedFunction(
57+
Class<? extends UserDefinedFunction> UDF,
58+
String functionName,
59+
SqlReturnTypeInference returnType) {
60+
final ScalarFunction udfFunction =
61+
ScalarFunctionImpl.create(Types.lookupMethod(UDF, "eval", Object[].class));
62+
SqlIdentifier udfLtrimIdentifier =
63+
new SqlIdentifier(Collections.singletonList(functionName), null, SqlParserPos.ZERO, null);
64+
return new SqlUserDefinedFunction(
65+
udfLtrimIdentifier, SqlKind.OTHER_FUNCTION, returnType, null, null, udfFunction);
66+
}
67+
68+
public static SqlReturnTypeInference getReturnTypeInferenceForArray() {
69+
return opBinding -> {
70+
RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
71+
72+
// Get argument types
73+
List<RelDataType> argTypes = opBinding.collectOperandTypes();
74+
75+
if (argTypes.isEmpty()) {
76+
throw new IllegalArgumentException("Function requires at least one argument.");
77+
}
78+
RelDataType firstArgType = argTypes.getFirst();
79+
return createArrayType(typeFactory, firstArgType, true);
80+
};
81+
}
82+
}

integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalcitePPLAggregationIT.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import static org.opensearch.sql.util.MatcherUtils.verifySchema;
1313

1414
import java.io.IOException;
15+
import java.util.List;
1516
import org.json.JSONObject;
1617
import org.junit.Ignore;
1718
import org.junit.jupiter.api.Test;
@@ -221,4 +222,13 @@ public void testSimpleTwoLevelStats() {
221222
verifySchema(actual, schema("avg_avg", "double"));
222223
verifyDataRows(actual, rows(28432.625));
223224
}
225+
226+
@Test
227+
public void testTake() {
228+
JSONObject actual =
229+
executeQuery(
230+
String.format("source=%s | stats take(firstname, 2) as take", TEST_INDEX_BANK));
231+
verifySchema(actual, schema("take", "array"));
232+
verifyDataRows(actual, rows(List.of("Amber JOHnny", "Hattie")));
233+
}
224234
}

0 commit comments

Comments
 (0)