-
Notifications
You must be signed in to change notification settings - Fork 190
Add udf interface #3374
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add udf interface #3374
Changes from 6 commits
4cbf2ca
89207a1
28869cf
d1f8f11
2464c85
df6728a
958cac9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,38 @@ | ||
| /* | ||
| * Copyright OpenSearch Contributors | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| package org.opensearch.sql.calcite.udf; | ||
|
|
||
| public interface UserDefinedAggFunction<S extends UserDefinedAggFunction.Accumulator> { | ||
| /** | ||
| * @return {@link Accumulator} | ||
| */ | ||
| S init(); | ||
|
|
||
| /** | ||
| * | ||
| * @param accumulator {@link Accumulator} | ||
| * @return final result | ||
| */ | ||
| Object result(S accumulator); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in UDF, the method name is
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The method names of an aggregate function are defined in Calcite here: https://github.com/apache/calcite/blob/1793ba79a328c61fb42842f443334cc1353c985f/core/src/main/java/org/apache/calcite/schema/impl/AggregateFunctionImpl.java#L91. We cannot modify them. I will leave a comment. |
||
|
|
||
| /** | ||
| * Add values to the accumulator. Note that initialization arguments are also passed in here, such as the 50 in Percentile(field, 50). | ||
| * @param acc {@link Accumulator} | ||
| * @param values the value to add to accumulator | ||
| * @return {@link Accumulator} | ||
| */ | ||
| S add(S acc, Object... values); | ||
|
|
||
| /** | ||
| * Maintains the state while a {@link UserDefinedAggFunction} adds values. | ||
| */ | ||
| interface Accumulator { | ||
| /** | ||
| * @return the final aggregation value | ||
|
||
| */ | ||
| Object value(); | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| /* | ||
| * Copyright OpenSearch Contributors | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| package org.opensearch.sql.calcite.udf; | ||
|
|
||
| public interface UserDefinedFunction { | ||
| Object eval(Object... args); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add comments. |
||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,35 @@ | ||
| /* | ||
| * Copyright OpenSearch Contributors | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| package org.opensearch.sql.calcite.udf.mathUDF; | ||
|
|
||
| import static java.lang.Math.sqrt; | ||
|
|
||
| import org.opensearch.sql.calcite.udf.UserDefinedFunction; | ||
|
|
||
| public class SqrtFunction implements UserDefinedFunction { | ||
| @Override | ||
| public Object eval(Object... args) { | ||
| if (args.length < 1) { | ||
| throw new IllegalArgumentException("At least one argument is required"); | ||
| } | ||
|
|
||
| // Get the input value | ||
| Object input = args[0]; | ||
|
|
||
| // Handle numbers dynamically | ||
| if (input instanceof Number) { | ||
| double num = ((Number) input).doubleValue(); | ||
|
|
||
| if (num < 0) { | ||
| throw new ArithmeticException("Cannot compute square root of a negative number"); | ||
| } | ||
|
|
||
| return sqrt(num); // Computes sqrt using Math.sqrt() | ||
| } else { | ||
| throw new IllegalArgumentException("Invalid argument type: Expected a numeric value"); | ||
| } | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,60 @@ | ||
| /* | ||
| * Copyright OpenSearch Contributors | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| package org.opensearch.sql.calcite.udf.udaf; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. copyright header missing
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
|
|
||
| import java.util.ArrayList; | ||
| import java.util.List; | ||
|
|
||
| import org.opensearch.sql.calcite.udf.UserDefinedAggFunction; | ||
|
|
||
| public class TakeAggFunction implements UserDefinedAggFunction<TakeAggFunction.TakeAccumulator> { | ||
|
|
||
| @Override | ||
| public TakeAccumulator init() { | ||
| return new TakeAccumulator(); | ||
| } | ||
|
|
||
| @Override | ||
| public Object result(TakeAccumulator accumulator) { | ||
| return accumulator.value(); | ||
| } | ||
|
|
||
| @Override | ||
| public TakeAccumulator add(TakeAccumulator acc, Object... values) { | ||
| Object candidateValue = values[0]; | ||
| int size = 0; | ||
| if (values.length > 1) { | ||
| size = (int) values[1]; | ||
| } else { | ||
| size = 10; | ||
| } | ||
| if (size > acc.size()) { | ||
| acc.add(candidateValue); | ||
| } | ||
| return acc; | ||
| } | ||
|
|
||
| public static class TakeAccumulator implements Accumulator { | ||
| private List<Object> hits; | ||
|
|
||
| public TakeAccumulator() { | ||
| hits = new ArrayList<>(); | ||
| } | ||
|
|
||
| @Override | ||
| public Object value() { | ||
| return hits; | ||
| } | ||
|
Comment on lines
47
to
49
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Code indentation problem, please run
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
|
|
||
| public void add(Object value) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just compare the interface |
||
| hits.add(value); | ||
| } | ||
|
|
||
| public int size() { | ||
| return hits.size(); | ||
| } | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,10 +5,19 @@ | |
|
|
||
| package org.opensearch.sql.calcite.utils; | ||
|
|
||
| import static org.opensearch.sql.calcite.utils.UserDefineFunctionUtils.TransferUserDefinedFunction; | ||
|
|
||
| import java.math.BigDecimal; | ||
| import java.util.ArrayList; | ||
| import java.util.List; | ||
| import java.util.Locale; | ||
| import org.apache.calcite.rex.RexNode; | ||
| import org.apache.calcite.sql.SqlOperator; | ||
| import org.apache.calcite.sql.fun.SqlLibraryOperators; | ||
| import org.apache.calcite.sql.fun.SqlStdOperatorTable; | ||
| import org.apache.calcite.sql.type.ReturnTypes; | ||
| import org.opensearch.sql.calcite.CalcitePlanContext; | ||
| import org.opensearch.sql.calcite.udf.mathUDF.SqrtFunction; | ||
|
|
||
| public interface BuiltinFunctionUtils { | ||
|
|
||
|
|
@@ -51,6 +60,12 @@ static SqlOperator translate(String op) { | |
| // Built-in Math Functions | ||
| case "ABS": | ||
| return SqlStdOperatorTable.ABS; | ||
| case "SQRT": | ||
| return TransferUserDefinedFunction(SqrtFunction.class, "SQRT", ReturnTypes.DOUBLE); | ||
| case "ATAN", "ATAN2": | ||
| return SqlStdOperatorTable.ATAN2; | ||
| case "POW", "POWER": | ||
| return SqlStdOperatorTable.POWER; | ||
| // Built-in Date Functions | ||
| case "CURRENT_TIMESTAMP": | ||
| return SqlStdOperatorTable.CURRENT_TIMESTAMP; | ||
|
|
@@ -67,4 +82,29 @@ static SqlOperator translate(String op) { | |
| throw new IllegalArgumentException("Unsupported operator: " + op); | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Translates function arguments to align with Calcite's expectations, ensuring compatibility | ||
| * with PPL (Piped Processing Language). This is necessary because Calcite's input argument | ||
| * order or default values may differ from PPL's function definitions. | ||
| * | ||
| * @param op The function name as a string. | ||
| * @param argList A list of {@link RexNode} representing the parsed arguments from the PPL statement. | ||
| * @param context The {@link CalcitePlanContext} providing necessary utilities such as {@code rexBuilder}. | ||
| * @return A modified list of {@link RexNode} that correctly maps to Calcite’s function expectations. | ||
| */ | ||
| static List<RexNode> translateArgument( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since this is framework code, we will change this method frequently. Could you add some comments explaining this method for developers?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, already added. |
||
| String op, List<RexNode> argList, CalcitePlanContext context) { | ||
| switch (op.toUpperCase(Locale.ROOT)) { | ||
| case "ATAN": | ||
| List<RexNode> AtanArgs = new ArrayList<>(argList); | ||
| if (AtanArgs.size() == 1) { | ||
| BigDecimal divideNumber = BigDecimal.valueOf(1); | ||
| AtanArgs.add(context.rexBuilder.makeBigintLiteral(divideNumber)); | ||
| } | ||
| return AtanArgs; | ||
| default: | ||
| return argList; | ||
| } | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,82 @@ | ||
| /* | ||
| * Copyright OpenSearch Contributors | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| package org.opensearch.sql.calcite.utils; | ||
|
|
||
| import static org.apache.calcite.sql.type.SqlTypeUtil.createArrayType; | ||
|
|
||
| import java.util.ArrayList; | ||
| import java.util.Collections; | ||
| import java.util.List; | ||
| import org.apache.calcite.linq4j.tree.Types; | ||
| import org.apache.calcite.rel.type.RelDataType; | ||
| import org.apache.calcite.rel.type.RelDataTypeFactory; | ||
| import org.apache.calcite.rex.RexNode; | ||
| import org.apache.calcite.schema.ScalarFunction; | ||
| import org.apache.calcite.schema.impl.AggregateFunctionImpl; | ||
| import org.apache.calcite.schema.impl.ScalarFunctionImpl; | ||
| import org.apache.calcite.sql.SqlIdentifier; | ||
| import org.apache.calcite.sql.SqlKind; | ||
| import org.apache.calcite.sql.SqlOperator; | ||
| import org.apache.calcite.sql.parser.SqlParserPos; | ||
| import org.apache.calcite.sql.type.SqlReturnTypeInference; | ||
| import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction; | ||
| import org.apache.calcite.sql.validate.SqlUserDefinedFunction; | ||
| import org.apache.calcite.tools.RelBuilder; | ||
| import org.apache.calcite.util.Optionality; | ||
| import org.opensearch.sql.calcite.udf.UserDefinedAggFunction; | ||
| import org.opensearch.sql.calcite.udf.UserDefinedFunction; | ||
|
|
||
| public class UserDefineFunctionUtils { | ||
| public static RelBuilder.AggCall TransferUserDefinedAggFunction( | ||
| Class<? extends UserDefinedAggFunction> UDAF, | ||
| String functionName, | ||
| SqlReturnTypeInference returnType, | ||
| List<RexNode> fields, | ||
| List<RexNode> argList, | ||
| RelBuilder relBuilder) { | ||
| SqlUserDefinedAggFunction sqlUDAF = | ||
| new SqlUserDefinedAggFunction( | ||
| new SqlIdentifier(functionName, SqlParserPos.ZERO), | ||
| SqlKind.OTHER_FUNCTION, | ||
| returnType, | ||
| null, | ||
| null, | ||
| AggregateFunctionImpl.create(UDAF), | ||
| false, | ||
| false, | ||
| Optionality.FORBIDDEN); | ||
| List<RexNode> addArgList = new ArrayList<>(fields); | ||
| addArgList.addAll(argList); | ||
| return relBuilder.aggregateCall(sqlUDAF, addArgList); | ||
| } | ||
|
|
||
| public static SqlOperator TransferUserDefinedFunction( | ||
| Class<? extends UserDefinedFunction> UDF, | ||
| String functionName, | ||
| SqlReturnTypeInference returnType) { | ||
| final ScalarFunction udfFunction = | ||
| ScalarFunctionImpl.create(Types.lookupMethod(UDF, "eval", Object[].class)); | ||
| SqlIdentifier udfLtrimIdentifier = | ||
| new SqlIdentifier(Collections.singletonList(functionName), null, SqlParserPos.ZERO, null); | ||
| return new SqlUserDefinedFunction( | ||
| udfLtrimIdentifier, SqlKind.OTHER_FUNCTION, returnType, null, null, udfFunction); | ||
| } | ||
|
|
||
| public static SqlReturnTypeInference getReturnTypeInferenceForArray() { | ||
| return opBinding -> { | ||
| RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); | ||
|
|
||
| // Get argument types | ||
| List<RelDataType> argTypes = opBinding.collectOperandTypes(); | ||
|
|
||
| if (argTypes.isEmpty()) { | ||
| throw new IllegalArgumentException("Function requires at least one argument."); | ||
| } | ||
| RelDataType firstArgType = argTypes.getFirst(); | ||
| return createArrayType(typeFactory, firstArgType, true); | ||
| }; | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
format issue
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.