232232import java .util .HashMap ;
233233import java .util .List ;
234234import java .util .Map ;
235- import java .util .Objects ;
236235import java .util .Optional ;
237236import java .util .StringJoiner ;
238237import java .util .concurrent .ConcurrentHashMap ;
262261import org .apache .logging .log4j .LogManager ;
263262import org .apache .logging .log4j .Logger ;
264263import org .opensearch .sql .calcite .CalcitePlanContext ;
265- import org .opensearch .sql .calcite .utils .OpenSearchTypeFactory ;
266264import org .opensearch .sql .calcite .utils .PPLOperandTypes ;
267265import org .opensearch .sql .calcite .utils .PlanUtils ;
268266import org .opensearch .sql .calcite .utils .UserDefinedFunctionUtils ;
@@ -408,25 +406,40 @@ public void registerExternalAggOperator(
408406 aggExternalFunctionRegistry .put (functionName , Pair .of (signature , handler ));
409407 }
410408
409+ public void validateAggFunctionSignature (
410+ BuiltinFunctionName functionName , RexNode field , List <RexNode > argList ) {
411+ var implementation = getImplementation (functionName );
412+ validateFunctionArgs (implementation , functionName , field , argList );
413+ }
414+
411415 public RelBuilder .AggCall resolveAgg (
412416 BuiltinFunctionName functionName ,
413417 boolean distinct ,
414418 RexNode field ,
415419 List <RexNode > argList ,
416420 CalcitePlanContext context ) {
417- var implementation = aggExternalFunctionRegistry .get (functionName );
418- if (implementation == null ) {
419- implementation = aggFunctionRegistry .get (functionName );
420- }
421- if (implementation == null ) {
422- throw new IllegalStateException (String .format ("Cannot resolve function: %s" , functionName ));
423- }
421+ var implementation = getImplementation (functionName );
422+
423+ // Validation is done based on original argument types to generate error from user perspective.
424+ validateFunctionArgs (implementation , functionName , field , argList );
425+
426+ var handler = implementation .getValue ();
427+ return handler .apply (distinct , field , argList , context );
428+ }
429+
430+ static void validateFunctionArgs (
431+ Pair <CalciteFuncSignature , AggHandler > implementation ,
432+ BuiltinFunctionName functionName ,
433+ RexNode field ,
434+ List <RexNode > argList ) {
424435 CalciteFuncSignature signature = implementation .getKey ();
436+
425437 List <RelDataType > argTypes = new ArrayList <>();
426438 if (field != null ) {
427439 argTypes .add (field .getType ());
428440 }
429- // Currently only PERCENTILE_APPROX and TAKE have additional arguments.
441+
442+ // Currently only PERCENTILE_APPROX, TAKE, EARLIEST, and LATEST have additional arguments.
430443 // Their additional arguments will always come as a map of <argName, value>
431444 List <RelDataType > additionalArgTypes =
432445 argList .stream ().map (PlanUtils ::derefMapCall ).map (RexNode ::getType ).toList ();
@@ -442,10 +455,20 @@ public RelBuilder.AggCall resolveAgg(
442455 errorMessagePattern ,
443456 functionName ,
444457 signature .typeChecker ().getAllowedSignatures (),
445- getActualSignature (argTypes )));
458+ PlanUtils . getActualSignature (argTypes )));
446459 }
447- var handler = implementation .getValue ();
448- return handler .apply (distinct , field , argList , context );
460+ }
461+
462+ private Pair <CalciteFuncSignature , AggHandler > getImplementation (
463+ BuiltinFunctionName functionName ) {
464+ var implementation = aggExternalFunctionRegistry .get (functionName );
465+ if (implementation == null ) {
466+ implementation = aggFunctionRegistry .get (functionName );
467+ }
468+ if (implementation == null ) {
469+ throw new IllegalStateException (String .format ("Cannot resolve function: %s" , functionName ));
470+ }
471+ return implementation ;
449472 }
450473
451474 public RexNode resolve (final RexBuilder builder , final String functionName , RexNode ... args ) {
@@ -493,7 +516,7 @@ public RexNode resolve(
493516 throw new ExpressionEvaluationException (
494517 String .format (
495518 "Cannot resolve function: %s, arguments: %s, caused by: %s" ,
496- functionName , getActualSignature (argTypes ), e .getMessage ()),
519+ functionName , PlanUtils . getActualSignature (argTypes ), e .getMessage ()),
497520 e );
498521 }
499522 StringJoiner allowedSignatures = new StringJoiner ("," );
@@ -506,7 +529,7 @@ functionName, getActualSignature(argTypes), e.getMessage()),
506529 throw new ExpressionEvaluationException (
507530 String .format (
508531 "%s function expects {%s}, but got %s" ,
509- functionName , allowedSignatures , getActualSignature (argTypes )));
532+ functionName , allowedSignatures , PlanUtils . getActualSignature (argTypes )));
510533 }
511534
512535 /**
@@ -1074,21 +1097,6 @@ void registerOperator(BuiltinFunctionName functionName, SqlAggFunction aggFuncti
10741097 register (functionName , handler , typeChecker );
10751098 }
10761099
1077- private static RexNode resolveTimeField (List <RexNode > argList , CalcitePlanContext ctx ) {
1078- if (argList .isEmpty ()) {
1079- // Try to find @timestamp field
1080- var timestampField =
1081- ctx .relBuilder .peek ().getRowType ().getField ("@timestamp" , false , false );
1082- if (timestampField == null ) {
1083- throw new IllegalArgumentException (
1084- "Default @timestamp field not found. Please specify a time field explicitly." );
1085- }
1086- return ctx .rexBuilder .makeInputRef (timestampField .getType (), timestampField .getIndex ());
1087- } else {
1088- return PlanUtils .derefMapCall (argList .get (0 ));
1089- }
1090- }
1091-
10921100 void populate () {
10931101 registerOperator (MAX , SqlStdOperatorTable .MAX );
10941102 registerOperator (MIN , SqlStdOperatorTable .MIN );
@@ -1118,8 +1126,7 @@ void populate() {
11181126 return ctx .relBuilder .count (distinct , null , field );
11191127 }
11201128 },
1121- wrapSqlOperandTypeChecker (
1122- SqlStdOperatorTable .COUNT .getOperandTypeChecker (), COUNT .name (), false ));
1129+ wrapSqlOperandTypeChecker (PPLOperandTypes .OPTIONAL_ANY , COUNT .name (), false ));
11231130
11241131 register (
11251132 PERCENTILE_APPROX ,
@@ -1166,20 +1173,22 @@ void populate() {
11661173 register (
11671174 EARLIEST ,
11681175 (distinct , field , argList , ctx ) -> {
1169- RexNode timeField = resolveTimeField (argList , ctx );
1170- return ctx .relBuilder .aggregateCall (SqlStdOperatorTable .ARG_MIN , field , timeField );
1176+ List <RexNode > args = resolveTimeField (argList , ctx );
1177+ return UserDefinedFunctionUtils .makeAggregateCall (
1178+ SqlStdOperatorTable .ARG_MIN , List .of (field ), args , ctx .relBuilder );
11711179 },
11721180 wrapSqlOperandTypeChecker (
1173- SqlStdOperatorTable . ARG_MIN . getOperandTypeChecker () , EARLIEST .name (), false ));
1181+ PPLOperandTypes . ANY_OPTIONAL_TIMESTAMP , EARLIEST .name (), false ));
11741182
11751183 register (
11761184 LATEST ,
11771185 (distinct , field , argList , ctx ) -> {
1178- RexNode timeField = resolveTimeField (argList , ctx );
1179- return ctx .relBuilder .aggregateCall (SqlStdOperatorTable .ARG_MAX , field , timeField );
1186+ List <RexNode > args = resolveTimeField (argList , ctx );
1187+ return UserDefinedFunctionUtils .makeAggregateCall (
1188+ SqlStdOperatorTable .ARG_MAX , List .of (field ), args , ctx .relBuilder );
11801189 },
11811190 wrapSqlOperandTypeChecker (
1182- SqlStdOperatorTable . ARG_MAX . getOperandTypeChecker (), LATEST .name (), false ));
1191+ PPLOperandTypes . ANY_OPTIONAL_TIMESTAMP , EARLIEST .name (), false ));
11831192
11841193 // Register FIRST function - uses document order
11851194 register (
@@ -1203,19 +1212,19 @@ void populate() {
12031212 }
12041213 }
12051214
1206- /**
1207- * Get a string representation of the argument types expressed in ExprType for error messages.
1208- *
1209- * @param argTypes the list of argument types as {@link RelDataType}
1210- * @return a string in the format [type1,type2,...] representing the argument types
1211- */
1212- private static String getActualSignature ( List < RelDataType > argTypes ) {
1213- return "["
1214- + argTypes . stream ()
1215- . map ( OpenSearchTypeFactory :: convertRelDataTypeToExprType )
1216- . map ( Objects :: toString )
1217- . collect (Collectors .joining ( "," ))
1218- + "]" ;
1215+ static List < RexNode > resolveTimeField ( List < RexNode > argList , CalcitePlanContext ctx ) {
1216+ if ( argList . isEmpty ()) {
1217+ // Try to find @timestamp field
1218+ var timestampField = ctx . relBuilder . peek (). getRowType (). getField ( "@timestamp" , false , false );
1219+ if ( timestampField == null ) {
1220+ throw new IllegalArgumentException (
1221+ "Default @timestamp field not found. Please specify a time field explicitly." );
1222+ }
1223+ return List . of (
1224+ ctx . rexBuilder . makeInputRef ( timestampField . getType (), timestampField . getIndex ()));
1225+ } else {
1226+ return argList . stream (). map ( PlanUtils :: derefMapCall ). collect (Collectors .toList ());
1227+ }
12191228 }
12201229
12211230 /**
@@ -1259,6 +1268,8 @@ private static PPLTypeChecker wrapSqlOperandTypeChecker(
12591268 pplTypeChecker = PPLTypeChecker .wrapComparable (comparableTypeChecker );
12601269 } else if (typeChecker instanceof UDFOperandMetadata .UDTOperandMetadata udtOperandMetadata ) {
12611270 pplTypeChecker = PPLTypeChecker .wrapUDT (udtOperandMetadata .allowedParamTypes ());
1271+ } else if (typeChecker != null ) {
1272+ pplTypeChecker = PPLTypeChecker .wrapDefault (typeChecker );
12621273 } else {
12631274 logger .info (
12641275 "Cannot create type checker for function: {}. Will skip its type checking" , functionName );
0 commit comments