Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -1065,13 +1065,17 @@ protected static <T extends Function> FunctionDefinition def(Class<T> function,
if (children.size() > 4 || children.size() < 2) {
throw new QlIllegalArgumentException("expects minimum two, maximum four arguments");
}
} else if (ThreeOptionalArguments.class.isAssignableFrom(function)) {
if (children.size() > 4 || children.isEmpty()) {
throw new QlIllegalArgumentException("expects minimum one, maximum four arguments");
}
} else if (children.size() != 4) {
throw new QlIllegalArgumentException("expects exactly four arguments");
}
return ctorRef.build(
source,
children.get(0),
children.get(1),
children.size() > 1 ? children.get(1) : null,
children.size() > 2 ? children.get(2) : null,
children.size() > 3 ? children.get(3) : null
);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.expression.function;

/**
* Marker interface indicating that a function accepts three optional arguments (the last three).
* This is used by the {@link EsqlFunctionRegistry} to perform validation of function declaration.
*/
public interface ThreeOptionalArguments {

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,16 @@
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.function.Example;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.OptionalArgument;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.expression.function.ThreeOptionalArguments;
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.TimeZone;

import static org.elasticsearch.common.time.DateFormatter.forPattern;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
Expand All @@ -38,7 +41,7 @@
import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.DEFAULT_DATE_TIME_FORMATTER;
import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.dateTimeToLong;

public class DateParse extends EsqlScalarFunction implements OptionalArgument {
public class DateParse extends EsqlScalarFunction implements ThreeOptionalArguments {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
Expression.class,
"DateParse",
Expand All @@ -47,6 +50,8 @@ public class DateParse extends EsqlScalarFunction implements OptionalArgument {

private final Expression field;
private final Expression format;
private final Expression locale;
private final Expression timezone;

@FunctionInfo(
returnType = "date",
Expand All @@ -63,17 +68,38 @@ public DateParse(
name = "dateString",
type = { "keyword", "text" },
description = "Date expression as a string. If `null` or an empty string, the function returns `null`."
) Expression second
) Expression second,
@Param(name = "dateLocale", type = { "keyword", "text" }, description = "The locale to parse with") Expression third,
@Param(name = "dateTimezone", type = { "keyword", "text" }, description = "The timezone to parse with") Expression forth
) {
super(source, second != null ? List.of(first, second) : List.of(first));
super(source, fields(first, second, third, forth));
this.field = second != null ? second : first;
this.format = second != null ? first : null;
this.locale = third;
this.timezone = forth;
}

private static List<Expression> fields(Expression field, Expression format, Expression locale, Expression timezone) {
List<Expression> list = new ArrayList<>(3);
list.add(field);
if (format != null) {
list.add(format);
}
if (locale != null) {
list.add(locale);
}
if (timezone != null) {
list.add(timezone);
}
return list;
}

private DateParse(StreamInput in) throws IOException {
this(
Source.readFrom((PlanStreamInput) in),
in.readNamedWriteable(Expression.class),
in.readOptionalNamedWriteable(Expression.class),
in.readOptionalNamedWriteable(Expression.class),
in.readOptionalNamedWriteable(Expression.class)
);
}
Expand All @@ -82,7 +108,9 @@ private DateParse(StreamInput in) throws IOException {
public void writeTo(StreamOutput out) throws IOException {
source().writeTo(out);
out.writeNamedWriteable(children().get(0));
out.writeOptionalNamedWriteable(children().size() == 2 ? children().get(1) : null);
out.writeOptionalNamedWriteable(children().size() > 1 ? children().get(1) : null);
out.writeOptionalNamedWriteable(children().size() > 2 ? children().get(2) : null);
out.writeOptionalNamedWriteable(children().size() > 3 ? children().get(3) : null);
}

@Override
Expand Down Expand Up @@ -141,9 +169,23 @@ public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
if (DataType.isString(format.dataType()) == false) {
throw new IllegalArgumentException("unsupported data type for date_parse [" + format.dataType() + "]");
}
String localeAsString = locale == null ? null : ((BytesRef) locale.fold(toEvaluator.foldCtx())).utf8ToString();
Locale locale = localeAsString == null ? null : Locale.forLanguageTag(localeAsString);
if (localeAsString != null && locale == null) {
throw new IllegalArgumentException("unsupported locale [" + localeAsString + "]");
}

String timezoneAsString = timezone == null ? null : ((BytesRef) timezone.fold(toEvaluator.foldCtx())).utf8ToString();
TimeZone timezone = timezoneAsString == null ? null : TimeZone.getTimeZone(timezoneAsString);
if (format.foldable()) {
try {
DateFormatter formatter = toFormatter(format.fold(toEvaluator.foldCtx()));
if (locale != null) {
formatter = formatter.withLocale(locale);
}
if (timezone != null) {
formatter = formatter.withZone(timezone.toZoneId());
}
return new DateParseConstantEvaluator.Factory(source(), fieldEvaluator, formatter);
} catch (IllegalArgumentException e) {
throw new InvalidArgumentException(e, "invalid date pattern for [{}]: {}", sourceText(), e.getMessage());
Expand All @@ -159,13 +201,19 @@ private static DateFormatter toFormatter(Object format) {

@Override
public Expression replaceChildren(List<Expression> newChildren) {
return new DateParse(source(), newChildren.get(0), newChildren.size() > 1 ? newChildren.get(1) : null);
return new DateParse(
source(),
newChildren.get(0),
newChildren.size() > 1 ? newChildren.get(1) : null,
newChildren.size() > 2 ? newChildren.get(2) : null,
newChildren.size() > 3 ? newChildren.get(3) : null
);
}

@Override
protected NodeInfo<? extends Expression> info() {
Expression first = format != null ? format : field;
Expression second = format != null ? field : null;
return NodeInfo.create(this, DateParse::new, first, second);
return NodeInfo.create(this, DateParse::new, first, second, locale, timezone);
}
}